# In[ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define the VAE model
class VAE(nn.Module):
    """Variational autoencoder whose decoder is a non-negative linear map.

    The encoder is a two-layer ReLU MLP (input_dim -> 128 -> 64) followed by
    two linear heads that produce the latent mean and log-variance. The decoder
    has no hidden layers: it multiplies latent codes ("exposures") by a
    learnable ``signature_matrix`` of shape (latent_dim, input_dim). It takes
    the absolute value of that matrix, so the effective signatures are
    non-negative.

    Args:
        input_dim: Number of features per sample (last dimension of ``x``).
        latent_dim: Number of latent factors / signatures.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder trunk.
        self.encoder_fc1 = nn.Linear(input_dim, 128)
        self.encoder_fc2 = nn.Linear(128, 64)
        # Heads that parameterize the approximate posterior q(z | x).
        self.z_mean = nn.Linear(64, latent_dim)
        self.z_log_var = nn.Linear(64, latent_dim)

        # Raw signature weights (latent_dim x input_dim). They are initialised
        # non-negative, and decode() applies abs(), so the signatures actually
        # used stay non-negative however the optimizer moves the raw values.
        initial_signatures = torch.randn(latent_dim, input_dim).abs()
        self.signature_matrix = nn.Parameter(initial_signatures)

    def encode(self, x):
        """Map ``x`` to the parameters of q(z | x).

        Returns:
            Tuple ``(z_mean, z_log_var)``, each with last dimension ``latent_dim``.
        """
        hidden = x
        for layer in (self.encoder_fc1, self.encoder_fc2):
            hidden = torch.relu(layer(hidden))
        return self.z_mean(hidden), self.z_log_var(hidden)

    def reparameterize(self, z_mean, z_log_var):
        """Sample z = mean + std * eps with eps ~ N(0, I) (reparameterization trick)."""
        noise = torch.randn_like(z_mean)
        std = (0.5 * z_log_var).exp()
        return z_mean + std * noise

    def decode(self, exposures):
        """Reconstruct inputs as ``exposures @ |signature_matrix|``.

        ``exposures`` has last dimension ``latent_dim``. The result has last
        dimension ``input_dim``.
        """
        signatures = self.signature_matrix.abs()
        return exposures @ signatures

    def forward(self, x):
        """Encode ``x``, draw one latent sample, and decode it.

        Returns:
            Tuple ``(reconstructed_x, z_mean, z_log_var)``.
        """
        mean, log_var = self.encode(x)
        sample = self.reparameterize(mean, log_var)
        return self.decode(sample), mean, log_var

# Loss function
def vae_loss(reconstructed_x, x, z_mean, z_log_var, exposures, *, beta=1.0, penalty_weight=1.0):
    """Return the summed VAE loss for one batch.

    loss = SSE(reconstructed_x, x)
           + beta * KL(N(z_mean, exp(z_log_var)) || N(0, I))
           + penalty_weight * sum(relu(-exposures))

    Every term is summed (not averaged) over all elements of the batch. With
    the default weights of 1.0 this is exactly the unweighted sum of the three
    terms.

    Args:
        reconstructed_x: Decoder output, same shape as ``x``.
        x: Target batch.
        z_mean: Posterior means.
        z_log_var: Posterior log-variances, same shape as ``z_mean``.
        exposures: Latent exposures; negative entries are penalised linearly.
        beta: Weight on the KL term (keyword-only). Exposed because the SSE
            term can be very small relative to KL on normalized inputs.
        penalty_weight: Weight on the negative-exposure hinge penalty
            (keyword-only).

    Returns:
        Scalar loss tensor.
    """
    reconstruction_loss = nn.functional.mse_loss(reconstructed_x, x, reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I).
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    # Hinge penalty: zero for non-negative exposures, linear in the magnitude otherwise.
    negative_exposure_penalty = torch.sum(torch.relu(-exposures))
    return reconstruction_loss + beta * kl_loss + penalty_weight * negative_exposure_penalty

def normalize_rows(counts):
    """Scale each row of ``counts`` so that it sums to 1.

    A row whose total is zero would otherwise evaluate 0/0 and become NaN,
    and a single NaN row poisons every later gradient step. Such rows are
    divided by 1 instead and stay all-zero. Rows with a positive total are
    divided by that total, exactly as before. Works for float or integer
    count tensors.
    """
    totals = counts.sum(dim=1, keepdim=True)
    safe_totals = torch.where(totals > 0, totals, torch.ones_like(totals))
    return counts / safe_totals


# Parameters
input_dim = 96  # Number of mutation contexts
latent_dim = 10  # Number of mutation signatures
batch_size = 16
epochs = 50
learning_rate = 0.001

# Example Data (replace with your mutation count data)
torch.manual_seed(42)
data = torch.poisson(torch.full((100, input_dim), 5.0))  # Simulated mutation data
data = normalize_rows(data)  # Normalize to probabilities (zero-count rows stay zero)

# DataLoader
dataset = TensorDataset(data)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, optimizer
model = VAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for (x,) in data_loader:  # TensorDataset yields 1-element batches
        optimizer.zero_grad()
        reconstructed_x, z_mean, z_log_var = model(x)
        # The penalised exposures are the posterior means, which forward() has
        # already computed. Reusing z_mean avoids a second encoder pass; encode()
        # is deterministic, so the values and gradients are identical.
        exposures = z_mean
        loss = vae_loss(reconstructed_x, x, z_mean, z_log_var, exposures)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Sum of the per-batch losses over the whole epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")

# Extract latent exposures and learned signatures
model.eval()
with torch.no_grad():
    # Posterior means (no sampling) serve as the exposure estimates.
    exposures = model.encode(data)[0]
    # Effective non-negative signatures, exactly as decode() uses them.
    signatures = model.signature_matrix.abs()

# Reconstruct the original matrix: (n_samples x latent_dim) @ (latent_dim x input_dim)
reconstructed_matrix = exposures @ signatures

print("Original Matrix:", data)
print("Reconstructed Matrix:", reconstructed_matrix)
