# In [None]:
# Transformer VAE + Regressor for Species Abundance

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# --- Feature Embedding ---
class FeatureEmbedding(nn.Module):
    """Learned embedding with one vector per feature (column) index.

    The *values* in the input are ignored; only its shape and device are used.
    The output therefore tells downstream layers which feature each token is,
    not what value it holds.

    Args:
        num_features: Size of the embedding table, i.e. the largest number of
            features an input may have.
        d_model: Width of each embedding vector.
    """

    def __init__(self, num_features, d_model):
        super().__init__()
        self.embedding = nn.Embedding(num_features, d_model)

    def forward(self, x):
        """Return embeddings of shape ``(batch, num_features, d_model)``.

        Args:
            x: 2-D tensor of shape ``(batch, num_features)``. ``num_features``
                must not exceed the table size given at construction.
        """
        batch_size, num_features = x.shape
        # Create the index row directly on x's device, avoiding a per-call
        # host->device copy. Broadcast it over the batch with expand(),
        # a view, instead of materialising a (batch, num_features) copy with
        # repeat().
        indices = torch.arange(num_features, device=x.device)
        return self.embedding(indices.unsqueeze(0).expand(batch_size, num_features))

# --- Transformer Encoder ---
class TransformerEncoderVAE(nn.Module):
    """Encode a flat feature vector into diagonal-Gaussian parameters.

    Each scalar feature becomes one token. Its value is linearly projected to
    ``d_model`` and summed with a learned per-feature embedding. The token
    sequence runs through a stack of Transformer encoder layers, is
    flattened, and two linear heads produce ``mu`` and ``logvar``.

    Args:
        input_dim: Number of features (tokens) per sample. It is fixed,
            because the heads consume the flattened ``input_dim * d_model``
            vector.
        latent_dim: Size of the latent space.
        d_model: Token width. PyTorch requires it to be divisible by
            ``n_heads``.
        n_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_dim, latent_dim, d_model=128, n_heads=4, num_layers=2):
        super().__init__()
        flat_width = input_dim * d_model
        # Construction order determines the order in which parameter init
        # draws from the RNG. Keep it stable so seeded runs stay reproducible.
        self.input_proj = nn.Linear(1, d_model)
        self.feature_embed = FeatureEmbedding(input_dim, d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True),
            num_layers=num_layers,
        )
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(flat_width, latent_dim)
        self.fc_logvar = nn.Linear(flat_width, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.
        """
        # (B, F) -> (B, F, 1) -> (B, F, d_model): one token per feature value.
        tokens = self.input_proj(x.unsqueeze(-1))
        # Add the learned per-feature embedding (depends only on x's shape).
        tokens = tokens + self.feature_embed(x)
        encoded = self.flatten(self.transformer_encoder(tokens))
        return self.fc_mu(encoded), self.fc_logvar(encoded)

# --- Decoder ---
class TransformerDecoder(nn.Module):
    """Map a latent code back to feature space with a two-layer MLP.

    Despite the class name, this decoder uses no attention. It is
    ``Linear -> ReLU -> Linear``, stored as ``self.decoder``.

    Args:
        latent_dim: Size of the incoming latent vector.
        output_dim: Number of reconstructed features.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, latent_dim, output_dim, hidden_dim=128):
        super().__init__()
        into_hidden = nn.Linear(latent_dim, hidden_dim)
        to_output = nn.Linear(hidden_dim, output_dim)
        self.decoder = nn.Sequential(into_hidden, nn.ReLU(), to_output)

    def forward(self, z):
        """Decode ``z`` of shape ``(batch, latent_dim)`` to ``(batch, output_dim)``."""
        reconstruction = self.decoder(z)
        return reconstruction

# --- Full VAE ---
class TransformerVAE(nn.Module):
    """Variational autoencoder: Transformer encoder + MLP decoder.

    ``forward`` returns ``(recon_x, mu, logvar, z)``. ``z`` is sampled with the
    reparameterization trick on every call, in train and eval mode alike.

    Args:
        input_dim: Number of features per sample; also the decoder output size.
        latent_dim: Size of the latent space.
        d_model: Encoder token width (keyword-only).
        n_heads: Encoder attention heads (keyword-only).
        num_layers: Number of encoder layers (keyword-only).
        hidden_dim: Decoder hidden width (keyword-only).

    The keyword defaults equal the defaults of ``TransformerEncoderVAE`` and
    ``TransformerDecoder``. ``TransformerVAE(input_dim, latent_dim)`` therefore
    builds those sub-modules with their standard settings.
    """

    def __init__(self, input_dim, latent_dim, *, d_model=128, n_heads=4,
                 num_layers=2, hidden_dim=128):
        super().__init__()
        # Encoder is built before decoder so seeded parameter init is unchanged.
        self.encoder = TransformerEncoderVAE(
            input_dim, latent_dim,
            d_model=d_model, n_heads=n_heads, num_layers=num_layers,
        )
        self.decoder = TransformerDecoder(latent_dim, input_dim, hidden_dim=hidden_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent ``z``, and decode it.

        Returns:
            ``(recon_x, mu, logvar, z)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar, z

# --- MLP Regressor ---
class LatentToSpeciesRegressor(nn.Module):
    """MLP head mapping a latent vector to ``output_dim`` unbounded outputs.

    Layer order is ``Linear -> ReLU -> BatchNorm1d -> Dropout(0.3) -> Linear``.
    Batch normalisation is applied after the activation, not before. The
    final layer has no activation, so outputs are unconstrained. In training
    mode, PyTorch's ``BatchNorm1d`` needs more than one sample per batch.

    Args:
        latent_dim: Size of the input latent vector.
        output_dim: Number of predicted values.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, latent_dim, output_dim, hidden_dim=128):
        super().__init__()
        stack = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Map ``z`` of shape ``(batch, latent_dim)`` to ``(batch, output_dim)``."""
        prediction = self.model(z)
        return prediction

# --- Loss Function ---
def vae_loss_function(recon_x, x, mu, logvar, *, beta=1.0):
    """Compute the VAE objective: reconstruction MSE plus a weighted KL term.

    Args:
        recon_x: Decoder output, same shape as ``x``.
        x: Input batch; ``x.size(0)`` is used as the batch size.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Keyword-only weight on the KL term in the returned total.
            ``1.0`` gives the plain sum; other values trade reconstruction
            against regularisation (beta-VAE style).

    Returns:
        ``(total, recon_loss, kld)`` with ``total = recon_loss + beta * kld``.
        ``recon_loss`` and ``kld`` are returned unweighted for logging.

    Note:
        The two terms are normalised differently. ``recon_loss`` is averaged
        over every element (batch * elements-per-sample). ``kld`` is summed
        over latent dimensions and averaged over the batch only. Compared with
        the usual per-sample ELBO, the KL term is therefore weighted more
        heavily by a factor equal to the number of elements per sample.
        ``beta`` can be lowered to compensate.
    """
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements
    # and divided by the batch size.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + beta * kld, recon_loss, kld

# --- Example Training Loop Placeholder ---
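# To use: load your data into a float tensor `X` of shape (num_samples, num_features),
# wrap it as `dataloader = DataLoader(TensorDataset(X), batch_size=..., shuffle=True)`,
# set `epochs`, and run the loop below. As written, the regressor is trained to predict
# `x` itself; pass a separate target tensor to `F.mse_loss` if it should predict something else.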

# Example usage:
# vae = TransformerVAE(input_dim=X.shape[1], latent_dim=32)
# regressor = LatentToSpeciesRegressor(latent_dim=32, output_dim=X.shape[1])
# optimizer = optim.Adam(list(vae.parameters()) + list(regressor.parameters()), lr=1e-3)
# for epoch in range(epochs):
#     for batch in dataloader:
#         x = batch[0]
#         recon_x, mu, logvar, z = vae(x)
#         y_pred = regressor(z)
#         loss_vae, recon_loss, kld = vae_loss_function(recon_x, x, mu, logvar)
#         loss_pred = F.mse_loss(y_pred, x)
#         loss = loss_vae + loss_pred
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch}: VAE Loss={loss_vae.item():.4f}, Pred Loss={loss_pred.item():.4f}")
