In [None]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from wj_autoencoders import Autoencoder, VariationalAutoencoder

In [15]:
seed = 37

# Seed NumPy's and PyTorch's global RNGs so repeated runs draw the same numbers.
np.random.seed(seed)
torch.manual_seed(seed)

# Pick the compute device: CUDA first, then Apple's MPS backend, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Additionally seed every visible CUDA device.
    torch.cuda.manual_seed_all(seed)
elif torch.backends.mps.is_available():
    # torch.backends.mps is used for the availability check because the
    # torch.mps.is_available() shortcut is not present in older PyTorch builds.
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device {device} with manual seed {seed}.")

Using device mps with manual seed 37.


In [16]:
# Location of the annotated MoA CSV. Set the LISH_MOA_CSV environment variable to
# load it from somewhere else; the original local path remains the default.
data_path = os.environ.get('LISH_MOA_CSV', '/Users/wjohns/Downloads/lish_moa_annotated.csv')
df = pd.read_csv(data_path, low_memory=False)

# Separate treated samples from control samples
treated_df = df[df['cp_type'] == 'trt_cp'].reset_index(drop=True)
control_df = df[df['cp_type'] == 'ctrl_vehicle'].reset_index(drop=True)

# Extract gene expression ('g-') and cell viability ('c-') feature columns
gene_cols = [col for col in df.columns if col.startswith('g-')]
cell_cols = [col for col in df.columns if col.startswith('c-')]
feature_cols = gene_cols + cell_cols
print(f"Total features: {len(feature_cols)} (Genes: {len(gene_cols)}, Cells: {len(cell_cols)})")

# MoA label columns are whatever remains once feature and metadata columns are removed
metadata_cols = ['sig_id', 'drug_id', 'training', 'cp_type', 'cp_time', 'cp_dose']
moa_cols = [col for col in df.columns
            if not col.startswith(('g-', 'c-')) and col not in metadata_cols]
print(f"Total MoA labels: {len(moa_cols)}")

Total features: 872 (Genes: 772, Cells: 100)
Total MoA labels: 608


In [17]:
# Pull the feature matrix and the MoA label matrix out of the treated samples
X, y = treated_df[feature_cols].values, treated_df[moa_cols].values

# Hold out 20% of the rows for validation
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardise the features; the scaler is fitted on the training split only
scaler = StandardScaler()
X_train_scaled = scaler.fit(X_train).transform(X_train)
X_val_scaled = scaler.transform(X_val)

# Wrap each array as a float PyTorch tensor
X_train_tensor, X_val_tensor, y_train_tensor, y_val_tensor = (
    torch.FloatTensor(arr)
    for arr in (X_train_scaled, X_val_scaled, y_train, y_val)
)

In [18]:
# Build the datasets and their DataLoaders
batch_size = 1024 * 8
num_workers = 10  # Set to the number of performance cores on your Apple Silicon chip

# An autoencoder reconstructs its input, so each dataset pairs X with itself
train_dataset = TensorDataset(X_train_tensor, X_train_tensor)
val_dataset = TensorDataset(X_val_tensor, X_val_tensor)

# Options shared by both loaders; only the training loader shuffles
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

In [21]:
# Architecture hyperparameters
input_dim = len(feature_cols)
hidden_dims = [512, 256, 128]  # Hidden layer sizes, largest first
latent_dim = 64  # Width of the latent bottleneck

# Toggle between the variational and the standard autoencoder
use_vae = True  # False selects the standard autoencoder

model_cls, banner = (
    (VariationalAutoencoder, "Using Variational Autoencoder")
    if use_vae
    else (Autoencoder, "Using Standard Autoencoder")
)
model = model_cls(input_dim, hidden_dims, latent_dim)
print(banner)

# Move the model to the selected device (last expression, so the repr is displayed)
model.to(device)

Using Variational Autoencoder


VariationalAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=872, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=256, out_features=128, bias=True)
    (9): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.2, inplace=False)
  )
  (fc_mu): Linear(in_features=128, out_features=64, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=64, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
 

In [22]:
# Define the loss function
def vae_loss(recon_x, x, mu, log_var, beta=0.1):
    """Return the VAE objective: summed reconstruction error plus weighted KL term.

    Args:
        recon_x: Reconstruction produced by the model.
        x: Target tensor the reconstruction is compared against.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.
        beta: Weight applied to the KL divergence term. Defaults to 0.1.

    Returns:
        A scalar tensor ``MSE + beta * KLD``. Both terms are summed (not
        averaged) over every element of their inputs, so the value grows with
        the batch size.
    """
    # Reconstruction loss, summed over all elements
    MSE = F.mse_loss(recon_x, x, reduction='sum')

    # KL divergence between N(mu, exp(log_var)) and a standard normal, summed
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return MSE + beta * KLD

# Adam optimiser, plus a scheduler that halves the learning rate when the
# monitored loss stops decreasing (patience of 20 scheduler steps)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=20
)

In [23]:
# Training function
def _batch_loss(model, data, use_vae):
    """Run one forward pass and return ``(loss, scale)`` for a batch.

    ``loss`` is the tensor to back-propagate. ``scale`` converts ``loss.item()``
    into the batch's contribution to an epoch total that is later divided by the
    number of samples: ``vae_loss`` is already summed over the batch (scale 1),
    whereas the default ``F.mse_loss`` is a mean over the batch and therefore has
    to be weighted by the batch size.
    """
    if use_vae:
        recon_batch, mu, log_var = model(data)
        return vae_loss(recon_batch, data, mu, log_var), 1
    recon_batch = model(data)
    return F.mse_loss(recon_batch, data), data.size(0)


def train_epoch(model, train_loader, optimizer, device, use_vae):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Network to optimise. With ``use_vae`` it must return
            ``(reconstruction, mu, log_var)``; otherwise just the reconstruction.
        train_loader: Iterable of ``(input, target)`` batches; the target is
            ignored because the input is reconstructed.
        optimizer: Optimizer stepping the model's parameters.
        device: Device each batch is moved to before the forward pass.
        use_vae: Selects ``vae_loss`` (True) or mean squared error (False).

    Returns:
        The epoch loss averaged over the samples in ``train_loader.dataset``.
    """
    model.train()
    train_loss = 0

    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        loss, scale = _batch_loss(model, data, use_vae)

        loss.backward()
        optimizer.step()
        # Weight by ``scale`` so mean-reduced batch losses add up correctly
        # even when the last batch is smaller than the others.
        train_loss += loss.item() * scale

    return train_loss / len(train_loader.dataset)

# Validation function
def validate(model, val_loader, device, use_vae):
    """Evaluate ``model`` on ``val_loader`` without updating any parameters.

    Args:
        model: Network to evaluate; same output contract as in ``train_epoch``.
        val_loader: Iterable of ``(input, target)`` batches; the target is ignored.
        device: Device each batch is moved to before the forward pass.
        use_vae: Selects ``vae_loss`` (True) or mean squared error (False).

    Returns:
        The loss averaged over the samples in ``val_loader.dataset``, computed
        on the same scale as the value returned by ``train_epoch``.
    """
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for data, _ in val_loader:
            data = data.to(device)
            loss, scale = _batch_loss(model, data, use_vae)
            val_loss += loss.item() * scale

    return val_loss / len(val_loader.dataset)

In [24]:
# Fit the model, recording per-epoch losses and checkpointing the best weights
num_epochs = 500
best_val_loss = float('inf')
train_losses, val_losses = [], []

print("Starting training...")
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, device, use_vae)
    val_loss = validate(model, val_loader, device, use_vae)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Checkpoint whenever validation improves on the best value seen so far
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_autoencoder_model.pt')

    # Feed this epoch's validation loss to the plateau scheduler
    scheduler.step(val_loss)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}')

Starting training...
Epoch 1/500, Train Loss: 1013.355367, Val Loss: 866.510753
Epoch 2/500, Train Loss: 878.618541, Val Loss: 773.051417
Epoch 3/500, Train Loss: 790.684973, Val Loss: 661.460313
Epoch 4/500, Train Loss: 720.795889, Val Loss: 748.529130
Epoch 5/500, Train Loss: 692.331867, Val Loss: 876.924242
Epoch 6/500, Train Loss: 675.782006, Val Loss: 995.968426
Epoch 7/500, Train Loss: 665.276715, Val Loss: 1021.887488
Epoch 8/500, Train Loss: 657.726475, Val Loss: 1002.379961
Epoch 9/500, Train Loss: 654.300093, Val Loss: 915.030890
Epoch 10/500, Train Loss: 649.027594, Val Loss: 797.385973
Epoch 11/500, Train Loss: 644.688872, Val Loss: 707.491642
Epoch 12/500, Train Loss: 641.361307, Val Loss: 662.263734
Epoch 13/500, Train Loss: 637.266596, Val Loss: 634.925709
Epoch 14/500, Train Loss: 632.314733, Val Loss: 614.062805
Epoch 15/500, Train Loss: 626.177005, Val Loss: 596.513587
Epoch 16/500, Train Loss: 620.314990, Val Loss: 586.933871
Epoch 17/500, Train Loss: 614.583284, Val

In [1]:
### Plot the train and validation loss history with Matplotlib
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Train', marker='o')
ax.plot(val_losses, label='Validation', marker='x')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Losses')
ax.legend()
ax.grid(True)
plt.show()

NameError: name 'plt' is not defined

In [26]:
# Load the best model (the checkpoint written by the training loop).
# map_location puts the stored tensors on the active device, so a checkpoint
# saved on one accelerator still loads on another machine; weights_only=True
# restricts unpickling to tensor data, which is all a state_dict contains.
model.load_state_dict(
    torch.load('best_autoencoder_model.pt', map_location=device, weights_only=True)
)

# Generate embeddings using the trained autoencoder
def get_embeddings(model, data_loader, device, use_vae):
    """Encode every batch in ``data_loader`` and stack the latent vectors.

    Args:
        model: Trained network. With ``use_vae`` its ``encode`` method must
            return a pair whose first item is used as the embedding; otherwise
            its ``get_latent`` method is called.
        data_loader: Iterable of ``(input, target)`` batches; targets are ignored.
        device: Device each batch is moved to before encoding.
        use_vae: Chooses between the ``encode`` and ``get_latent`` code paths.

    Returns:
        A NumPy array with the per-batch embeddings stacked row-wise, in the
        order the loader yielded them.
    """
    model.eval()

    def _encode(batch):
        # For the VAE only the first element of the encoder output is kept.
        if use_vae:
            mu, _ = model.encode(batch)
            return mu
        return model.get_latent(batch)

    with torch.no_grad():
        chunks = [
            _encode(inputs.to(device)).cpu().numpy()
            for inputs, _ in data_loader
        ]

    return np.vstack(chunks)

# Get embeddings for train and validation sets.
# train_loader shuffles its batches, so iterating it would return rows in a
# random order that no longer lines up with X_train / y_train. Use an ordered
# loader over the same dataset for the training embeddings instead.
train_eval_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
train_embeddings = get_embeddings(model, train_eval_loader, device, use_vae)
val_embeddings = get_embeddings(model, val_loader, device, use_vae)

print(f"Train embeddings shape: {train_embeddings.shape}")
print(f"Validation embeddings shape: {val_embeddings.shape}")

Train embeddings shape: (20457, 64)
Validation embeddings shape: (5115, 64)
