In [1]:
import os

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score, root_mean_squared_error
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

In [2]:

class RegressionMLP(nn.Module):
    """Feed-forward MLP that regresses a single scalar from a feature vector.

    Each hidden stage is ``Linear -> LeakyReLU -> Dropout``, followed by a final
    ``Linear(last_width, 1)`` head. The defaults reproduce the original
    1280 -> 512 -> 256 -> 128 -> 1 architecture. The module layout is unchanged:
    ``self.model`` is an ``nn.Sequential`` with the same layer indices
    ``model.0``, ``model.3``, ``model.6``, ``model.9``. Previously saved
    state dicts therefore still load, and seeded initialisation is identical.

    Args:
        input_dim: Size of each input feature vector (default 1280).
        hidden_dims: Widths of the hidden layers, in order (default (512, 256, 128)).
            An empty sequence yields a single linear layer ``input_dim -> 1``.
        dropout: Dropout probability applied after every hidden activation
            (default 0.3).
    """

    def __init__(self, input_dim=1280, hidden_dims=(512, 256, 128), dropout=0.3):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            # Same per-stage module order as the original hand-written Sequential,
            # so parameter names and RNG consumption order are preserved.
            layers += [nn.Linear(in_features, width), nn.LeakyReLU(), nn.Dropout(dropout)]
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(..., input_dim)`` tensor to ``(..., 1)`` predictions."""
        return self.model(x)

In [3]:
# Map each regression target column in `adata.obs` to the short name used for
# prediction keys, printed metrics and saved-model filenames.
target_variables_short = {
    "pearson_correlation_min_max": "pear_corr_min_max",
    "pearson_correlation_sigmoid": "pear_corr_sigmoid",
    "cosine_similarity_min_max": "cos_sim_min_max",
    "cosine_similarity_sigmoid": "cos_sim_sigmoid"
}

DATA_PATH = "../../../data/anndata/train_val_adata.h5ad"
MODEL_DIR = "../../../data/models/mlp"
SEED = 42
BATCH_SIZE = 128
EPOCHS = 100
LR = 1e-3
WEIGHT_DECAY = 1e-2
LOG_EVERY = 10  # log training/validation loss every LOG_EVERY optimizer steps

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

adata = sc.read(DATA_PATH)

adata_train, adata_val = adata[adata.obs["split"] == "train"], adata[adata.obs["split"] == "validation"]


def _to_float_tensor(matrix):
    """Convert an AnnData ``.X`` matrix (dense or scipy-sparse) to a float32 tensor."""
    # AnnData often stores .X as a scipy sparse matrix, which torch.tensor cannot ingest.
    if hasattr(matrix, "toarray"):
        matrix = matrix.toarray()
    return torch.tensor(np.asarray(matrix), dtype=torch.float32)


def _mean_loss(model, loader, criterion):
    """Return the sample-weighted mean of ``criterion`` over all batches of ``loader``.

    Leaves ``model`` in eval mode; callers that continue training must call
    ``model.train()`` afterwards.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # MSELoss averages within a batch; re-weight by batch size so a short
            # final batch does not skew the dataset-level mean.
            total += criterion(model(inputs), labels).item() * labels.shape[0]
            count += labels.shape[0]
    return total / count


def _predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode; return flat (predictions, targets) arrays."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            preds.append(model(inputs.to(device)).cpu().numpy())
            targets.append(labels.numpy())
    # ravel (not squeeze) keeps a 1-D result even for a single-sample split.
    return np.concatenate(preds).ravel(), np.concatenate(targets).ravel()


# Feature matrices are identical for every target, so build them once.
X_train = _to_float_tensor(adata_train.X)
X_val = _to_float_tensor(adata_val.X)
os.makedirs(MODEL_DIR, exist_ok=True)

predictions = {}
for target_var, target_var_short in target_variables_short.items():
    print(f"Fitting model on {target_var}...")
    # Re-seed per target so each model's init/shuffling is reproducible and
    # independent of how many targets were trained before it.
    torch.manual_seed(SEED)

    y_train = torch.tensor(adata_train.obs[target_var].values, dtype=torch.float32).unsqueeze(1)
    y_val = torch.tensor(adata_val.obs[target_var].values, dtype=torch.float32).unsqueeze(1)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=BATCH_SIZE, shuffle=False)

    model = RegressionMLP().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    criterion = nn.MSELoss()

    writer = SummaryWriter(log_dir=f'runs/run_{target_var}')
    try:
        running_loss = []
        for epoch in range(EPOCHS):
            model.train()
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                optimizer.step()
                running_loss.append(loss.item())

                if i % LOG_EVERY == LOG_EVERY - 1:
                    # Use one shared global step so the training and validation
                    # curves are plotted on the same x-axis in TensorBoard.
                    global_step = epoch * len(train_loader) + i
                    writer.add_scalar("Training loss", sum(running_loss) / len(running_loss), global_step)
                    running_loss = []
                    writer.add_scalar("Validation loss", _mean_loss(model, val_loader, criterion), global_step)
                    # _mean_loss switched to eval mode; re-enable dropout for training.
                    model.train()
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()

    # Final evaluation on the validation split
    all_predictions, all_targets = _predict(model, val_loader)

    # Store predictions for bootstrapping
    predictions[f"{target_var_short}-true"] = all_targets
    predictions[f"{target_var_short}-pred"] = all_predictions

    mse = mean_squared_error(all_targets, all_predictions)
    rmse = root_mean_squared_error(all_targets, all_predictions)
    r2 = r2_score(all_targets, all_predictions)

    print(f"Validation {target_var_short} MSE: {mse:.4f}")
    print(f"Validation {target_var_short} RMSE: {rmse:.4f}")
    print(f"Validation {target_var_short} R2: {r2:.4f}\n")

    # NOTE(review): this pickles the whole module, so loading requires the
    # RegressionMLP class to be importable; saving model.state_dict() would be
    # more portable, but downstream loaders (not visible here) expect this format.
    torch.save(model, os.path.join(MODEL_DIR, f"{target_var_short}_mlp.pt"))

Fitting model on pearson_correlation_min_max...
Validation pear_corr_min_max MSE: 0.0015
Validation pear_corr_min_max RMSE: 0.0381
Validation pear_corr_min_max R2: 0.7387

Fitting model on pearson_correlation_sigmoid...
Validation pear_corr_sigmoid MSE: 0.0014
Validation pear_corr_sigmoid RMSE: 0.0375
Validation pear_corr_sigmoid R2: 0.7662

Fitting model on cosine_similarity_min_max...
Validation cos_sim_min_max MSE: 0.0028
Validation cos_sim_min_max RMSE: 0.0530
Validation cos_sim_min_max R2: 0.6850

Fitting model on cosine_similarity_sigmoid...
Validation cos_sim_sigmoid MSE: 0.0004
Validation cos_sim_sigmoid RMSE: 0.0191
Validation cos_sim_sigmoid R2: 0.9601

