In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# ---------------------------------------------
# 1) LOAD & PREPROCESS DATA
# ---------------------------------------------
def load_mutation_data(csv_path):
    # Load CSV file
    df = pd.read_csv(csv_path)
    # Keep only the Patient column and the mutation columns (adjust column indices as needed)
    df = df.iloc[:, [0] + list(range(263, len(df.columns) - 2))]
    
    # Extract patient IDs
    patient_ids = df['Patient'].values
    # Get binary mutation features; convert to float32 for PyTorch
    X = df.drop(columns='Patient').values.astype(np.float32)
    return X, patient_ids

# ---------------------------------------------
# 2) COMPUTE WEIGHTED LOSS FACTOR (POS_WEIGHT)
# ---------------------------------------------
def compute_pos_weight(X):
    """
    Compute per-feature positive weight as (# negatives) / (# positives).
    To avoid division by zero, add a small epsilon.
    """
    eps = 1e-6
    # X is (num_samples, num_features)
    num_positives = np.sum(X, axis=0)
    num_negatives = X.shape[0] - num_positives
    pos_weight = num_negatives / (num_positives + eps)
    return torch.tensor(pos_weight, dtype=torch.float32)

# ---------------------------------------------
# 3) DEFINE THE VAE MODEL
# ---------------------------------------------
class MutationVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=64, dropout_p=0.5):
        super(MutationVAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # ENCODER
        # reduce overfitting by adding dropout layers
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_p)
        )
        # Two separate linear layers to output mean and log variance.
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # DECODER
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(512, input_dim)
            # No Sigmoid here because BCEWithLogitsLoss applies sigmoid internally.
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        logits = self.decoder(z)
        return logits

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        logits = self.decode(z)
        return logits, mu, logvar

# ---------------------------------------------
# 4) DEFINE A CUSTOM DATASET
# ---------------------------------------------
class MutationDataset(Dataset):
    def __init__(self, X):
        self.X = torch.from_numpy(X)  # shape: [num_samples, num_features]

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        # Autoencoder: input is also target.
        return self.X[idx], self.X[idx]

def make_dataloaders(X, train_ratio=0.8, batch_size=32):
    dataset = MutationDataset(X)
    n = len(dataset)
    n_train = int(train_ratio * n)
    n_val = n - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# ---------------------------------------------
# 5) DEFINE THE VAE LOSS FUNCTION
# ---------------------------------------------
def vae_loss_function(logits, target, mu, logvar, pos_weight, kl_weight=0.001):
    # Reconstruction loss: weighted BCEWithLogitsLoss (summed over batch and features)
    bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')
    recon_loss = bce_loss(logits, target)
    # KL divergence loss: sum over latent dimensions
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kl_loss


def train_vae(model, train_loader, val_loader, pos_weight, 
                      num_epochs=50, lr=1e-3, device='cpu', kl_weight=0.001):
    model.to(device)
    # Use BCEWithLogitsLoss with the computed pos_weight.
    # pos_weight should be of shape (input_dim,)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_x, batch_target in train_loader:
            batch_x = batch_x.to(device)
            batch_target = batch_target.to(device)
            # Forward pass: note that model returns logits.
            logits, mu, logvar = model(batch_x)
            
            loss = vae_loss_function(logits, batch_target, mu, logvar, pos_weight.to(device), kl_weight)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(batch_x)
        avg_train_loss = train_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_x, batch_target in val_loader:
                batch_x = batch_x.to(device)
                batch_target = batch_target.to(device)
                logits, mu, logvar = model(batch_x)
                loss = vae_loss_function(logits, batch_target, mu, logvar, pos_weight.to(device), kl_weight)
                val_loss += loss.item() * len(batch_x)
        avg_val_loss = val_loss / len(val_loader.dataset)

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    print("Training complete.")

# ---------------------------------------------
# 6) VISUALIZE LATENT REPRESENTATIONS WITH t-SNE
# ---------------------------------------------
def visualize_latent(model, dataset, device='cpu'):
    model.eval()
    latent_list = []
    with torch.no_grad():
        loader = DataLoader(dataset, batch_size=32, shuffle=False)
        for batch_x, _ in loader:
            batch_x = batch_x.to(device)
            mu, _ = model.encode(batch_x)  # use mu as the latent representation
            latent_list.append(mu.cpu().numpy())
    latent_all = np.concatenate(latent_list, axis=0)
    print("Latent shape:", latent_all.shape)
    
    # Use t-SNE to reduce latent dimensions to 2D for visualization
    tsne = TSNE(n_components=2, random_state=42)
    latent_2d = tsne.fit_transform(latent_all)
    
    plt.figure(figsize=(8,6))
    plt.scatter(latent_2d[:, 0], latent_2d[:, 1], s=10, alpha=0.7)
    plt.title("t-SNE Visualization of Latent Representations")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    # save plot
    plt.savefig("../fig/latent_tsne.png")
    plt.show()

# ---------------------------------------------
# 7) MAIN: TRAIN MODEL, SAVE LATENT, VISUALIZE
# ---------------------------------------------
if __name__ == "__main__":
    # Load data
    csv_path = "../data/msk_2024_mutations_final.csv"
    X, patient_ids = load_mutation_data(csv_path)
    print("Data shape:", X.shape)

    # Compute the positive weight per feature (for weighted loss)
    pos_weight = compute_pos_weight(X)
    print("Computed pos_weight shape:", pos_weight.shape)

    # Create dataloaders
    train_loader, val_loader = make_dataloaders(X, train_ratio=0.8, batch_size=32)

    # Initialize the autoencoder model
    input_dim = X.shape[1]
    latent_dim = 128  # Adjust latent dimension as needed
    model = MutationVAE(input_dim=input_dim, latent_dim=latent_dim, dropout_p=0.5)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_vae(model, train_loader, val_loader, pos_weight, num_epochs=50, lr=1e-3, device=device, kl_weight=0.001)

    # Use full dataset for latent extraction (no noise)
    full_dataset = MutationDataset(X, noise_level=0.0)
    model.eval()
    all_latents = []
    with torch.no_grad():
        full_loader = DataLoader(full_dataset, batch_size=32, shuffle=False)
        for batch_x, _ in full_loader:
            batch_x = batch_x.to(device)
            mu, _ = model.encode(batch_x)
            all_latents.append(mu.cpu().numpy())
    all_latents = np.concatenate(all_latents, axis=0)
    
    # Save latent representations along with patient IDs to CSV.
    latent_cols = [f"latent_{i}" for i in range(latent_dim)]
    latent_df = pd.DataFrame(all_latents, columns=latent_cols)
    latent_df.insert(0, "Patient", patient_ids)
    latent_df.to_csv("latent_representations.csv", index=False)
    print("Saved latent representations to 'gene_vae.csv'.")

    # Visualize the latent space using t-SNE.
    visualize_latent(model, full_dataset, device=device)


Data shape: (23544, 2792)
Computed pos_weight shape: torch.Size([2792])
Epoch [1/50] Train Loss: 2.6873 | Val Loss: 1.6549
Epoch [2/50] Train Loss: 1.5861 | Val Loss: 1.6592
Epoch [3/50] Train Loss: 1.3913 | Val Loss: 1.7426
Epoch [4/50] Train Loss: 1.3299 | Val Loss: 1.5649
Epoch [5/50] Train Loss: 1.2963 | Val Loss: 1.4908
Epoch [6/50] Train Loss: 1.2690 | Val Loss: 1.6211
Epoch [7/50] Train Loss: 1.2434 | Val Loss: 1.5015
Epoch [8/50] Train Loss: 1.2401 | Val Loss: 1.5999
Epoch [9/50] Train Loss: 1.2280 | Val Loss: 1.5353
Epoch [10/50] Train Loss: 1.2297 | Val Loss: 1.6253
Epoch [11/50] Train Loss: 1.2090 | Val Loss: 1.5282
Epoch [12/50] Train Loss: 1.2114 | Val Loss: 1.5246
Epoch [13/50] Train Loss: 1.2132 | Val Loss: 1.5964
Epoch [14/50] Train Loss: 1.1923 | Val Loss: 1.6386
Epoch [15/50] Train Loss: 1.2007 | Val Loss: 1.6030
Epoch [16/50] Train Loss: 1.1927 | Val Loss: 1.6249
Epoch [17/50] Train Loss: 1.1827 | Val Loss: 1.5376
Epoch [18/50] Train Loss: 1.1770 | Val Loss: 1.6569
E

TypeError: MutationDataset.__init__() got an unexpected keyword argument 'noise_level'