In [1]:
import pickle
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings

In [2]:
# NOTE(review): this silences every warning for the rest of the notebook,
# including library deprecation notices; narrow the filter if those matter.
warnings.filterwarnings('ignore')


def seed_everything(seed: int) -> None:
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and every CUDA device), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Kept local because the notebook's import cell does not import these.
    import os
    import random

    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the hash seed of the running kernel is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner may pick a different convolution algorithm on each run,
    # which defeats the determinism requested above, so it must stay off.
    torch.backends.cudnn.benchmark = False


seed_everything(1234)

In [3]:
# Columns of the variant table used as model inputs. Their 'Yes'/'No' values
# are mapped to 1/0 when the train/val/test matrices are built.
baseline_features = ['Last exon', 'Last 50nt penultimate exon', 'Exon length longer than 407nt', 'Less than 150nt to start']

In [4]:
# Read the processed TCGA table and discard the cancer-type annotation
# columns in one go.
df = pd.read_csv('../data/tcga_processed.tsv.gz', sep='\t').drop(
    columns=['Cancer_type', 'Cancer_type_count', 'NMF_cluster']
)

# Some variants are repeated multiple times; these per-observation
# measurements are averaged so that each variant ends up on a single row.
agg_columns = [
    'NMD_efficiency', 't_vaf', 't_depth', 'n_vaf', 'n_depth',
    'VAF_RNA', 'depth_RNA', 'VAF_DNA_RNA_ratio', 'tpm_unstranded_x'
]

# Every remaining column acts as part of the grouping key (original order kept).
group_by_columns = df.columns[~df.columns.isin(agg_columns)].tolist()

# dropna=False keeps groups whose key contains NaN instead of discarding them.
grouped = df.groupby(group_by_columns, dropna=False, as_index=False)
df_unq = grouped.agg(dict.fromkeys(agg_columns, 'mean'))

df_unq

Unnamed: 0,build,chromosome,start,end,Hugo_Symbol,Transcript_ID,HGVSc,HGVSp,PTC_pos_codon,PTC_to_start_codon,...,var_token_idx,NMD_efficiency,t_vaf,t_depth,n_vaf,n_depth,VAF_RNA,depth_RNA,VAF_DNA_RNA_ratio,tpm_unstranded_x
0,GRCh38,chr1,944753,944753,NOC2L,ENST00000327044,c.2191C>T,p.Gln731Ter,731,2193,...,2240,-0.508648,0.311111,45.0,0.0,43.0,0.442623,244.0,1.422717,76.8018
1,GRCh38,chr1,952113,952113,NOC2L,ENST00000327044,c.1218G>A,p.Trp406Ter,406,1218,...,1267,1.629509,0.390244,41.0,0.0,14.0,0.126126,111.0,0.323198,49.1731
2,GRCh38,chr1,1255304,1255304,UBE2J2,ENST00000349431,c.679G>T,p.Gly227Ter,227,681,...,898,0.047557,0.480000,50.0,0.0,32.0,0.464435,239.0,0.967573,47.4110
3,GRCh38,chr1,1338573,1338573,DVL1,ENST00000378888,c.1288G>T,p.Glu430Ter,430,1290,...,1572,1.505167,0.413043,46.0,0.0,63.0,0.145511,323.0,0.352289,50.0535
4,GRCh38,chr1,1387314,1387314,CCNL2,ENST00000400809,c.1480C>T,p.Arg494Ter,494,1482,...,1485,-0.424007,0.407692,130.0,0.0,152.0,0.546980,298.0,1.341649,27.1883
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4088,GRCh38,chrX,152920717,152920717,ZNF185,ENST00000370268,c.622C>T,p.Gln208Ter,208,624,...,658,1.785660,0.328358,67.0,0.0,65.0,0.095238,42.0,0.290043,52.3474
4089,GRCh38,chrX,153650171,153650171,DUSP9,ENST00000342782,c.1021G>T,p.Glu341Ter,341,1023,...,1285,-1.360998,0.389313,131.0,0.0,185.0,1.000000,35.0,2.568627,16.9024
4090,GRCh38,chrX,154030948,154030948,MECP2,ENST00000303391,c.880C>T,p.Arg294Ter,294,882,...,1129,-1.166585,0.406977,86.0,0.0,70.0,0.913580,81.0,2.244797,11.5544
4091,GRCh38,chrX,154354015,154354015,FLNA,ENST00000369850,c.5586C>A,p.Tyr1862Ter,1862,5586,...,5834,0.332048,0.904762,189.0,0.0,134.0,0.718750,32.0,0.794408,63.2541


### Splits

In [5]:
# Hold out every variant located on chr20-chr22 as the test set.
is_test_row = df_unq['chromosome'].isin(['chr20', 'chr21', 'chr22'])
# Round-trip through a set to drop duplicate variant ids.
test_var_ids = list(set(df_unq.loc[is_test_row, 'var_id'].values.tolist()))
test_var_ids[:5], len(test_var_ids)

(['ENST00000377813:c.481C>T',
  'ENST00000338074:c.2596G>T',
  'ENST00000380393:c.1566C>G',
  'ENST00000371242:c.781C>T',
  'ENST00000370873:c.106C>T'],
 232)

In [6]:
# Variants on chr19 form the validation set.
is_val_row = df_unq['chromosome'] == 'chr19'
# Round-trip through a set to drop duplicate variant ids.
val_var_ids = list(set(df_unq.loc[is_val_row, 'var_id'].values.tolist()))
val_var_ids[:5], len(val_var_ids)

(['ENST00000309311:c.193G>T',
  'ENST00000301454:c.310C>T',
  'ENST00000541714:c.2695C>T',
  'ENST00000309061:c.1594G>T',
  'ENST00000250896:c.1105C>T'],
 210)

In [7]:
# Everything that is neither validation nor test is used for training.
all_var_ids = set(df_unq['var_id'].values.tolist())
train_var_ids = list(all_var_ids - set(val_var_ids) - set(test_var_ids))
train_var_ids[:5], len(train_var_ids)

(['ENST00000538716:c.865C>T',
  'ENST00000380773:c.1171C>T',
  'ENST00000301030:c.5227C>T',
  'ENST00000261866:c.6856C>T',
  'ENST00000293831:c.931C>T'],
 3651)

### Feature/target matrices and data loaders

In [8]:
def _select_split(frame, var_ids, feature_columns, target_column='NMD_efficiency'):
    """Return the ``(features, target)`` pair for the rows whose ``var_id`` is in ``var_ids``.

    Args:
        frame: Variant table containing ``var_id``, the feature columns and
            the target column.
        var_ids: Variant identifiers that belong to the split.
        feature_columns: Names of the columns used as model inputs.
        target_column: Name of the regression target column.

    Returns:
        Tuple ``(X, y)``: ``X`` holds the feature columns with 'Yes'/'No'
        mapped to 1/0, ``y`` is the target column of the same rows.
    """
    # Compute the row mask once and reuse it for both features and target.
    rows = frame[frame['var_id'].isin(var_ids)]
    # `replace` alone relies on pandas silently downcasting object columns,
    # which is deprecated; `infer_objects` makes the numeric dtype explicit so
    # the later tensor conversion does not receive an object array.
    features = rows[feature_columns].replace({'Yes': 1, 'No': 0}).infer_objects()
    return features, rows[target_column]


X_train, y_train = _select_split(df_unq, train_var_ids, baseline_features)
X_test, y_test = _select_split(df_unq, test_var_ids, baseline_features)
X_val, y_val = _select_split(df_unq, val_var_ids, baseline_features)

In [9]:
# Define custom dataset
class CustomDataset(Dataset):
    """In-memory dataset pairing each feature row with its regression target.

    Both inputs are converted to ``float32`` tensors once, at construction.

    Args:
        features: Object exposing ``.values`` (e.g. a DataFrame), one row per sample.
        targets: Object exposing ``.values`` (e.g. a Series), one value per sample.
    """

    def __init__(self, features, targets):
        feature_array = features.values
        target_array = targets.values
        self.features = torch.tensor(feature_array, dtype=torch.float32)
        self.targets = torch.tensor(target_array, dtype=torch.float32)

    def __len__(self):
        # Number of samples equals the size of the leading feature dimension.
        return self.features.shape[0]

    def __getitem__(self, idx):
        sample = self.features[idx]
        label = self.targets[idx]
        return sample, label

In [10]:
# Wrap each split in a dataset (construction order: train, test, val).
train_dataset, test_dataset, val_dataset = (
    CustomDataset(split_X, split_y)
    for split_X, split_y in ((X_train, y_train), (X_test, y_test), (X_val, y_val))
)

# Only the training loader reshuffles between epochs; evaluation loaders keep
# a fixed order.
batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

### Train the baseline MLP

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import pandas as pd
import numpy as np
import torch.nn.functional as F

# Hyperparameters (module-level constants read by train_mlp)
# Step size and weight decay passed to the AdamW optimizer.
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 5e-4
# Width of each hidden layer of the MLP, in order.
HIDDEN_DIMS = [8, 8]
# Dropout probability applied after every hidden layer.
DROPOUT = 0.25
# Maximum number of training epochs.
N_EPOCHS = 50
# Number of consecutive validation epochs without improvement before stopping.
EARLY_STOPPING_PATIENCE = 10
# Target transformation applied before computing the loss.
TRANSFORMATION = 'none'  # Options: 'z-score', 'min-max', 'log', 'none'

# Evaluation Metrics
def evaluate_regression_metrics(y_true, y_pred):
    """Compute regression metrics for a pair of tensors.

    Args:
        y_true: Tensor of ground-truth values (any shape; flattened).
        y_pred: Tensor of predictions (any shape; flattened).

    Returns:
        Dict with keys ``loss`` (mean squared error), ``spearman_corr``,
        ``pearson_corr``, ``mae``, ``rmse`` and ``r2``. Both correlations are
        NaN when either input is constant or the correlation routines raise
        ``ValueError``.
    """
    truth = y_true.cpu().numpy().flatten()
    guess = y_pred.cpu().numpy().flatten()

    mse = mean_squared_error(truth, guess)
    # Correlations start out as NaN and are only filled in when both succeed.
    result = {
        'loss': mse,
        'spearman_corr': np.nan,
        'pearson_corr': np.nan,
        'mae': mean_absolute_error(truth, guess),
        'rmse': np.sqrt(mse),
        'r2': r2_score(truth, guess),
    }

    # Correlation is undefined when either vector has zero variance.
    if np.std(truth) != 0 and np.std(guess) != 0:
        try:
            rank_corr, _ = spearmanr(truth, guess)
            linear_corr, _ = pearsonr(truth, guess)
        except ValueError:
            # Leave both correlations at their NaN defaults.
            pass
        else:
            result['spearman_corr'] = rank_corr
            result['pearson_corr'] = linear_corr

    return result

class MLP(nn.Module):
    """Fully connected regressor: (Linear -> ReLU -> Dropout) per hidden layer, then Linear to 1.

    Args:
        input_dim: Number of input features.
        hidden_dims: Sizes of the hidden layers, in order (may be empty).
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim, hidden_dims, dropout):
        super().__init__()
        # Consecutive pairs of this list are the (in, out) sizes of each hidden layer.
        widths = [input_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [
                nn.Linear(fan_in, fan_out, dtype=torch.float32),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        # Output head: one scalar prediction per sample.
        stack.append(nn.Linear(widths[-1], 1, dtype=torch.float32))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)

# Transformation Functions
def apply_transformation(y, transformation, stats):
    """Map targets into the space in which the loss is computed.

    Args:
        y: Tensor of raw target values.
        transformation: One of 'log', 'z-score', 'min-max'; any other value
            returns ``y`` unchanged.
        stats: Dict of statistics needed by the chosen transformation
            ('min' for log; 'mean' and 'std' for z-score; 'min' and 'max'
            for min-max).

    Returns:
        The transformed tensor.
    """
    if transformation == 'z-score':
        centred = y - stats['mean']
        # Epsilon keeps the division finite when the spread is zero.
        return centred / (stats['std'] + 1e-8)

    if transformation == 'min-max':
        offset = y - stats['min']
        span = stats['max'] - stats['min'] + 1e-8
        return offset / span

    if transformation == 'log':
        # Shift by the stored minimum, then clamp so log1p receives a positive value.
        shifted = torch.clamp(y - stats['min'] + 1, min=1e-8)
        return torch.log1p(shifted)

    return y

def inverse_transformation(y, stats, transformation):
    """Map values from the transformed target space back to the original scale.

    Note the argument order: ``stats`` comes before ``transformation`` here,
    unlike in ``apply_transformation``.

    Args:
        y: Tensor in the transformed space (e.g. model predictions).
        stats: Dict of statistics used by the forward transformation
            ('min' for log; 'mean' and 'std' for z-score; 'min' and 'max'
            for min-max).
        transformation: One of 'log', 'z-score', 'min-max'; any other value
            returns ``y`` unchanged.

    Returns:
        The tensor on the original target scale.
    """
    if transformation == 'log':
        return torch.expm1(y) + stats['min'] - 1
    elif transformation == 'z-score':
        # The forward transform divides by (std + 1e-8); multiply by the same
        # quantity so that the round trip is exact.
        return y * (stats['std'] + 1e-8) + stats['mean']
    elif transformation == 'min-max':
        # Same epsilon as the forward transform's denominator.
        return y * (stats['max'] - stats['min'] + 1e-8) + stats['min']
    return y

# Calculate Transformation Statistics
def calculate_transformation_stats(dataset, transformation):
    """Collect the target statistics required by a transformation.

    Args:
        dataset: Iterable yielding ``(inputs, targets)`` pairs. ``targets``
            may be a batch (1-D tensor, as produced by a DataLoader) or a
            single value (0-D tensor, as produced by a Dataset).
        transformation: One of 'log', 'z-score', 'min-max', 'none'.

    Returns:
        Dict of 0-D tensors: ``{'min'}`` for log, ``{'mean', 'std'}`` for
        z-score, ``{'min', 'max'}`` for min-max, and an empty dict for 'none'
        or any unrecognised name.

    Raises:
        ValueError: If a transformation is requested but ``dataset`` yields
            no samples.
    """
    if transformation == 'none':
        return {}

    # Flatten every target so that per-sample 0-D tensors can be concatenated
    # just like batched 1-D tensors (torch.cat rejects 0-D inputs).
    pieces = [torch.as_tensor(targets).reshape(-1) for _, targets in dataset]
    if not pieces:
        raise ValueError("Cannot compute transformation statistics from an empty dataset")
    y = torch.cat(pieces, dim=0)

    stats = {}
    if transformation == 'log':
        stats['min'] = y.min()
    elif transformation == 'z-score':
        stats['mean'] = y.mean()
        stats['std'] = y.std()
    elif transformation == 'min-max':
        stats['min'] = y.min()
        stats['max'] = y.max()
    return stats

# Training Function
def train_mlp(train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader):
    """Train the MLP with early stopping and evaluate the best checkpoint on the test set.

    The model, optimizer and schedule are configured from the module-level
    hyperparameter constants. After every epoch the model is evaluated on the
    validation loader; the weights with the lowest validation loss are kept
    and restored before the final test evaluation, whether training ran for
    all epochs or stopped early.

    Args:
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader used for model selection and early stopping.
        test_loader: Loader used once, for the final evaluation.

    Returns:
        Tuple ``(metrics_df, test_metrics, model)``: a DataFrame with one row
        per (epoch, phase), the dict of test-set metrics for the restored
        best model, and the model itself.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = next(iter(train_loader))[0].shape[1]
    model = MLP(input_dim, HIDDEN_DIMS, DROPOUT).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Target statistics come from the training data only.
    transformation_stats = calculate_transformation_stats(train_loader, TRANSFORMATION)

    metrics = {'epoch': [], 'phase': [], 'loss': [], 'spearman_corr': [], 'pearson_corr':[], 'mae': [], 'rmse': [], 'r2': []}
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    stop_early = False

    for epoch in range(1, N_EPOCHS + 1):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader

            all_preds, all_targets = [], []
            running_loss = 0.0

            with torch.set_grad_enabled(is_train):
                for inputs, targets in loader:
                    inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
                    original_targets = targets.clone()

                    if TRANSFORMATION != 'none':
                        targets = apply_transformation(targets, TRANSFORMATION, transformation_stats)

                    # squeeze(-1) rather than squeeze(): a batch of one sample
                    # must stay 1-D so it can be concatenated later.
                    outputs = model(inputs).squeeze(-1)
                    loss = criterion(outputs, targets)

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    # Weight by batch size so the epoch loss is a per-sample mean.
                    running_loss += loss.item() * inputs.size(0)
                    all_preds.append(outputs.detach())
                    all_targets.append(original_targets.detach())

            epoch_loss = running_loss / len(loader.dataset)
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)

            # Metrics are reported on the original target scale.
            if TRANSFORMATION != 'none':
                all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)

            epoch_metrics = evaluate_regression_metrics(all_targets, all_preds)
            # Report the criterion value (transformed space) as the epoch loss.
            epoch_metrics['loss'] = epoch_loss

            # Log metrics
            metrics['epoch'].append(epoch)
            metrics['phase'].append(phase)
            for key, value in epoch_metrics.items():
                metrics[key].append(value)

            # Early stopping logic
            if phase == 'val':
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    patience_counter = 0
                    # state_dict() returns references to the live parameters;
                    # clone them so later updates do not overwrite the snapshot.
                    best_model_state = {
                        name: tensor.detach().clone()
                        for name, tensor in model.state_dict().items()
                    }
                else:
                    patience_counter += 1
                    if patience_counter >= EARLY_STOPPING_PATIENCE:
                        print(f"Early stopping at epoch {epoch} with best validation loss: {best_val_loss:.4f}")
                        stop_early = True
                        break

            # Print metrics every 5 epochs for both phases
            if epoch % 5 == 0:
                print(
                    f"[Epoch {epoch:03d}] Phase: {phase:5s} | "
                    f"Loss: {epoch_metrics['loss']:.4f} | "
                    f"Spearman: {epoch_metrics['spearman_corr']:.4f} | "
                    f"Pearson: {epoch_metrics['pearson_corr']:.4f} | "
                    f"MAE: {epoch_metrics['mae']:.4f} | "
                    f"RMSE: {epoch_metrics['rmse']:.4f} | "
                    f"R²: {epoch_metrics['r2']:.4f}"
                )

        if stop_early:
            # Fall through to the test evaluation instead of returning
            # validation metrics in place of test metrics.
            break

    # Restore the best checkpoint; it is None only if the validation loss
    # never improved on infinity (e.g. it was NaN from the first epoch).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final Test Evaluation
    model.eval()
    all_preds, all_targets = [], []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
            outputs = model(inputs).squeeze(-1)
            all_preds.append(outputs.detach())
            all_targets.append(targets.detach())

    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    if TRANSFORMATION != 'none':
        all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)

    test_metrics = evaluate_regression_metrics(all_targets, all_preds)
    print(
        "\n📊 Final Test Metrics | "
        f"Loss: {test_metrics['loss']:.4f} | "
        f"Spearman: {test_metrics['spearman_corr']:.4f} | "
        f"MAE: {test_metrics['mae']:.4f} | "
        f"RMSE: {test_metrics['rmse']:.4f} | "
        f"R²: {test_metrics['r2']:.4f}"
    )

    metrics_df = pd.DataFrame(metrics)
    return metrics_df, test_metrics, model


In [13]:
# Train on the training loader, select the model on the validation loader and
# collect the per-epoch metrics, the final metrics dict and the trained model.
print("START TRAINING")
metrics_df, test_metrics, model = train_mlp(train_loader, val_loader, test_loader)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.4660 | Spearman: 0.3431 | Pearson: 0.3268 | MAE: 0.9494 | RMSE: 1.2108 | R²: 0.0737
[Epoch 005] Phase: val   | Loss: 1.3598 | Spearman: 0.6445 | Pearson: 0.5639 | MAE: 0.8726 | RMSE: 1.1661 | R²: 0.2249
[Epoch 010] Phase: train | Loss: 1.2483 | Spearman: 0.4972 | Pearson: 0.4701 | MAE: 0.8688 | RMSE: 1.1173 | R²: 0.2112
[Epoch 010] Phase: val   | Loss: 1.1841 | Spearman: 0.6518 | Pearson: 0.5906 | MAE: 0.8210 | RMSE: 1.0882 | R²: 0.3251
[Epoch 015] Phase: train | Loss: 1.1860 | Spearman: 0.5422 | Pearson: 0.5044 | MAE: 0.8440 | RMSE: 1.0890 | R²: 0.2506
[Epoch 015] Phase: val   | Loss: 1.1632 | Spearman: 0.6518 | Pearson: 0.5986 | MAE: 0.8145 | RMSE: 1.0785 | R²: 0.3370
[Epoch 020] Phase: train | Loss: 1.1757 | Spearman: 0.5511 | Pearson: 0.5111 | MAE: 0.8383 | RMSE: 1.0843 | R²: 0.2571
[Epoch 020] Phase: val   | Loss: 1.1429 | Spearman: 0.6658 | Pearson: 0.6081 | MAE: 0.8065 | RMSE: 1.0691 | R²: 0.3486
[Epoch 025] Phase: train | Loss: 

In [14]:
# Per-epoch metrics: one row per (epoch, phase); the index carries no information.
metrics_df.to_csv('../res/metrics/baseline/per_epoch_metric.csv', index=False)

# Building the frame from a dict of dicts puts the metric names in the index,
# so the index must be written out; otherwise the CSV is a single column of
# unlabeled numbers.
test_metrics_df = pd.DataFrame({'baseline_test_metrics':test_metrics})
test_metrics_df.to_csv('../res/metrics/baseline/test_metrics.csv', index_label='metric')
test_metrics_df

Unnamed: 0,baseline_test_metrics
loss,1.136854
mae,0.804565
pearson_corr,0.604776
r2,0.352013
rmse,1.066234
spearman_corr,0.674507
