In [1]:
import sys
import pandas as pd
import ast
import pickle
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
import scanpy as sc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, precision_recall_curve
from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv())

sys.path.append('../src/null-effect-net')
import utils
import models
import dataset
import train_utils

In [2]:
class SimpleMLPAutoencoder(nn.Module):
    """Symmetric MLP autoencoder with one hidden layer on each side.

    The encoder maps ``input_dim -> hidden_dim -> latent_dim`` and the decoder
    mirrors it back to ``input_dim``. Dropout is applied after the first
    activation of both halves.

    Args:
        input_dim: Size of each input (and reconstructed) vector.
        hidden_dim: Width of the intermediate layer in encoder and decoder.
        latent_dim: Size of the bottleneck representation.
        dropout_rate: Dropout probability used in both halves.
    """

    def __init__(self, input_dim, hidden_dim=1024, latent_dim=256, dropout_rate=0.3):
        super().__init__()

        # Encoder layers are created first so that parameter initialisation
        # consumes the RNG in a fixed order (encoder, then decoder).
        # The trailing ReLU means latent codes are always non-negative.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder ends on a plain linear layer: the output is unbounded, which
        # suits a regression-style reconstruction loss such as MSE.
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the latent representation of ``x``."""
        return self.encoder(x)

    def decode(self, latent):
        """Return the reconstruction produced from ``latent``."""
        return self.decoder(latent)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decode(self.encode(x))

    def compute_loss(self, reconstructed, original):
        """Return the mean squared error between reconstruction and input."""
        return F.mse_loss(reconstructed, original)


In [3]:
# Minimal dataset that serves pre-computed embedding vectors to the autoencoder
class EmbeddingDataset(torch.utils.data.Dataset):
    """Map-style dataset over the ``'Concat Embedding'`` column of a DataFrame.

    All embeddings are stacked into a single float32 tensor that is created
    directly on ``device`` at construction time; items are rows of that tensor.

    Args:
        node_features_df: DataFrame with a ``'Concat Embedding'`` column whose
            entries can be stacked by ``np.stack`` (i.e. all the same length).
        device: Device on which the stacked tensor is allocated.
    """

    def __init__(self, node_features_df, device):
        stacked = np.stack(node_features_df['Concat Embedding'].values)
        self.embeddings = torch.tensor(stacked, dtype=torch.float32, device=device)

    def __len__(self):
        # Number of rows in the stacked embedding matrix.
        return self.embeddings.shape[0]

    def __getitem__(self, idx):
        # Each item is just the embedding vector; there is no label.
        return self.embeddings[idx]


In [4]:
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module exposing ``compute_loss(reconstructed, original)``; it is
            switched to training mode.
        loader: Iterable of input batches (tensors whose first dimension is the
            batch dimension).
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``{'loss': avg_loss}`` where ``avg_loss`` is the batch-size-weighted
        mean of the per-batch losses over the samples actually processed.
        It is NaN if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0

    for batch in loader:
        batch = batch.to(device)
        reconstructed = model(batch)
        loss = model.compute_loss(reconstructed, batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean
        # (assumes compute_loss returns a per-batch mean, as F.mse_loss does
        # by default).
        batch_size = batch.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ when the loader uses drop_last=True or a sampler that
    # does not visit every item, and an empty loader must not divide by zero.
    avg_loss = total_loss / num_samples if num_samples else float('nan')
    return {'loss': avg_loss}

def evaluate(model, loader, device):
    """Compute the mean loss of ``model`` over ``loader`` without updating it.

    Args:
        model: Module exposing ``compute_loss(reconstructed, original)``; it is
            switched to evaluation mode.
        loader: Iterable of input batches (tensors whose first dimension is the
            batch dimension).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``{'loss': avg_loss}`` where ``avg_loss`` is the batch-size-weighted
        mean of the per-batch losses over the samples actually processed.
        It is NaN if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            reconstructed = model(batch)
            loss = model.compute_loss(reconstructed, batch)

            # Weight by batch size so a smaller final batch does not skew the
            # mean (assumes compute_loss returns a per-batch mean, as
            # F.mse_loss does by default).
            batch_size = batch.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    # Divide by the samples actually seen rather than len(loader.dataset):
    # the two differ when the loader uses drop_last=True or a sampler that
    # does not visit every item, and an empty loader must not divide by zero.
    avg_loss = total_loss / num_samples if num_samples else float('nan')
    return {'loss': avg_loss}


In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before any model or loader is created so the run is repeatable.
train_utils.set_seed(42)

# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; only load embeddings produced by a trusted pipeline.
with open('../data/embeddings.pkl', 'rb') as f:
    node_features_df = pickle.load(f)

# NOTE(review): `+` concatenates when the embedding cells are Python lists but
# adds element-wise when they are numpy arrays -- confirm the stored type
# matches the intended "concat" semantics.
node_features_df['Concat Embedding'] = node_features_df['ESM Embedding'] + node_features_df['SubCell Embedding'] + node_features_df['PINNACLE Embedding']

train_df = pd.read_csv('../data/train.csv')

active_nodes_df = pd.read_csv('../data/expression_reference/expression_reference.csv', index_col=0)

# Positional 90/10 train/validation split in the DataFrame's existing row
# order (no shuffling). Computed once so both slices share the same boundary.
split_idx = int(0.9 * len(node_features_df))
train_dataset = EmbeddingDataset(node_features_df.iloc[:split_idx], device=device)
val_dataset = EmbeddingDataset(node_features_df.iloc[split_idx:], device=device)

# .iloc[0] reads the first row by position; a plain [0] is a label lookup and
# breaks (or relies on a deprecated fallback) when the index has no label 0.
input_dim = len(node_features_df['Concat Embedding'].iloc[0])

model = SimpleMLPAutoencoder(input_dim, hidden_dim=1024, latent_dim=256, dropout_rate=0.3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


In [7]:
# Training schedule and early-stopping state.
num_epochs = 20
patience = 5                    # epochs without improvement tolerated before stopping
best_val_loss = float('inf')
counter = 0                     # consecutive epochs without a new best validation loss

for epoch in range(num_epochs):
    print(f"\n==== Epoch {epoch+1}/{num_epochs} ====")
    train_metrics = train_one_epoch(model, train_loader, optimizer, device)
    val_metrics = evaluate(model, val_loader, device)

    print(f"Train loss: {train_metrics['loss']:.4f}")
    print(f"Val loss: {val_metrics['loss']:.4f}")

    # Strict "<" comparison: a tie (or a NaN loss) counts as no improvement.
    improved = val_metrics['loss'] < best_val_loss

    if not improved:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break
        continue

    # New best validation loss: checkpoint the model and reset the counter.
    best_val_loss = val_metrics['loss']
    train_utils.save_model(model, "../models/simple_mlp_autoencoder_best.pt")
    counter = 0



==== Epoch 1/20 ====
Train loss: 0.0006
Val loss: 0.0005

==== Epoch 2/20 ====
Train loss: 0.0006
Val loss: 0.0005

==== Epoch 3/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 4/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 5/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 6/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 7/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 8/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 9/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 10/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 11/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 12/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 13/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 14/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 15/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 16/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== Epoch 17/20 ====
Train loss: 0.0005
Val loss: 0.0005

==== 