In [1]:
import pickle
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings

In [2]:
warnings.filterwarnings('ignore')

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
df = pd.read_csv('../data/tcga_processed.tsv.gz', sep='\t')

### some variants are repeated multiple times
df = df.drop(columns=['Cancer_type', 'Cancer_type_count', 'NMF_cluster'])

agg_columns = [
    'NMD_efficiency', 't_vaf', 't_depth', 'n_vaf', 'n_depth',
    'VAF_RNA', 'depth_RNA', 'VAF_DNA_RNA_ratio', 'tpm_unstranded_x'
]

# Group by all other columns except the ones to aggregate
group_by_columns = [col for col in df.columns if col not in agg_columns]

# Perform grouping and aggregation without dropping NaNs
df_unq = df.groupby(
    group_by_columns, dropna=False, as_index=False
).agg({col: 'mean' for col in agg_columns})

df_unq

Unnamed: 0,build,chromosome,start,end,Hugo_Symbol,Transcript_ID,HGVSc,HGVSp,PTC_pos_codon,PTC_to_start_codon,...,var_token_idx,NMD_efficiency,t_vaf,t_depth,n_vaf,n_depth,VAF_RNA,depth_RNA,VAF_DNA_RNA_ratio,tpm_unstranded_x
0,GRCh38,chr1,944753,944753,NOC2L,ENST00000327044,c.2191C>T,p.Gln731Ter,731,2193,...,2240,-0.508648,0.311111,45.0,0.0,43.0,0.442623,244.0,1.422717,76.8018
1,GRCh38,chr1,952113,952113,NOC2L,ENST00000327044,c.1218G>A,p.Trp406Ter,406,1218,...,1267,1.629509,0.390244,41.0,0.0,14.0,0.126126,111.0,0.323198,49.1731
2,GRCh38,chr1,1255304,1255304,UBE2J2,ENST00000349431,c.679G>T,p.Gly227Ter,227,681,...,898,0.047557,0.480000,50.0,0.0,32.0,0.464435,239.0,0.967573,47.4110
3,GRCh38,chr1,1338573,1338573,DVL1,ENST00000378888,c.1288G>T,p.Glu430Ter,430,1290,...,1572,1.505167,0.413043,46.0,0.0,63.0,0.145511,323.0,0.352289,50.0535
4,GRCh38,chr1,1387314,1387314,CCNL2,ENST00000400809,c.1480C>T,p.Arg494Ter,494,1482,...,1485,-0.424007,0.407692,130.0,0.0,152.0,0.546980,298.0,1.341649,27.1883
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4088,GRCh38,chrX,152920717,152920717,ZNF185,ENST00000370268,c.622C>T,p.Gln208Ter,208,624,...,658,1.785660,0.328358,67.0,0.0,65.0,0.095238,42.0,0.290043,52.3474
4089,GRCh38,chrX,153650171,153650171,DUSP9,ENST00000342782,c.1021G>T,p.Glu341Ter,341,1023,...,1285,-1.360998,0.389313,131.0,0.0,185.0,1.000000,35.0,2.568627,16.9024
4090,GRCh38,chrX,154030948,154030948,MECP2,ENST00000303391,c.880C>T,p.Arg294Ter,294,882,...,1129,-1.166585,0.406977,86.0,0.0,70.0,0.913580,81.0,2.244797,11.5544
4091,GRCh38,chrX,154354015,154354015,FLNA,ENST00000369850,c.5586C>A,p.Tyr1862Ter,1862,5586,...,5834,0.332048,0.904762,189.0,0.0,134.0,0.718750,32.0,0.794408,63.2541


### Splits

In [4]:
test_var_ids = df_unq[df_unq['chromosome'].isin(['chr20', 'chr21', 'chr22'])]['var_id'].values.tolist()
test_var_ids = list(set(test_var_ids))
test_var_ids[0:5], len(test_var_ids)

(['ENST00000397527:c.3463C>T',
  'ENST00000375687:c.2077C>T',
  'ENST00000314328:c.2410C>T',
  'ENST00000449058:c.1366C>T',
  'ENST00000217244:c.916C>T'],
 232)

In [5]:
val_var_ids = df_unq[df_unq['chromosome']=='chr19']['var_id'].values.tolist()
val_var_ids = list(set(val_var_ids))
val_var_ids[0:5], len(val_var_ids)

(['ENST00000344099:c.1099C>T',
  'ENST00000420124:c.2992C>T',
  'ENST00000313434:c.1195C>T',
  'ENST00000300853:c.130C>T',
  'ENST00000317683:c.62C>A'],
 210)

In [6]:
train_var_ids = list(set(df_unq.var_id.values.tolist()) - set(val_var_ids) - set(test_var_ids))
train_var_ids[0:5], len(train_var_ids)

(['ENST00000360663:c.1228C>T',
  'ENST00000275235:c.355A>T',
  'ENST00000249776:c.169G>T',
  'ENST00000458591:c.1600G>T',
  'ENST00000521381:c.1740C>A'],
 3651)

In [7]:
# create dictionaries with NMD_efficiency as values
test_dict = {}
val_dict = {}
train_dict = {}
var_pos_idx_dict = {}

for index, row in df_unq.iterrows():
    if row['var_id'] in test_var_ids:
        test_dict[row['var_id']] = row['NMD_efficiency']
    elif row['var_id'] in val_var_ids:
        val_dict[row['var_id']] = row['NMD_efficiency']
    else:
        train_dict[row['var_id']] = row['NMD_efficiency']

    var_pos_idx_dict[row['var_id']] = row['var_token_idx']

### loaders

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader

def aggregate_tensors(embed_dict, value_dict, aggregation, var_pos_idx_dict=None):
    """
    Aggregates tensors based on the specified aggregation method.
    
    Args:
        embed_dict (dict): Dictionary where keys are IDs and values are 2D tensors.
        value_dict (dict): Dictionary where keys are IDs and values are single numbers.
        aggregation (str): Aggregation method ('mean', 'max', 'sum', 'product', 'token').
        var_pos_idx_dict (dict, optional): Dictionary where keys are IDs and values are token indices for aggregation.
        
    Returns:
        dict: Dictionary with aggregated values.
    """
    if aggregation not in ['mean', 'max', 'sum', 'product', 'token']:
        raise ValueError("Invalid aggregation method. Choose from 'mean', 'max', 'sum', 'product', 'token'.")
    
    if aggregation == 'token' and var_pos_idx_dict is None:
        raise ValueError("'var_pos_idx_dict' must be provided when aggregation is 'token'.")
    
    aggregated_dict = {}
    for key, tensor in embed_dict.items():
        if key not in value_dict:
            continue  # Ensure matching keys
        
        if aggregation == 'mean':
            aggregated_dict[key] = torch.mean(tensor, dim=0)
        elif aggregation == 'max':
            aggregated_dict[key], _ = torch.max(tensor, dim=0)
        elif aggregation == 'sum':
            aggregated_dict[key] = torch.sum(tensor, dim=0)
        elif aggregation == 'product':
            aggregated_dict[key] = torch.prod(tensor, dim=0)
        elif aggregation == 'token':
            if key not in var_pos_idx_dict:
                raise KeyError(f"'var_pos_idx_dict' must contain a token index for key '{key}'")
            
            var_pos_idx = var_pos_idx_dict[key]
            if var_pos_idx >= tensor.shape[0]:
                raise IndexError(f"Index {var_pos_idx} out of range for tensor with shape {tensor.shape}")
            aggregated_dict[key] = tensor[var_pos_idx]
    
    return aggregated_dict


class AggregatedDataset(Dataset):
    def __init__(self, aggregated_tensors, value_dict):
        self.data = []
        for key, tensor in aggregated_tensors.items():
            if key in value_dict:
                self.data.append((tensor, value_dict[key]))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]


def create_data_loader(embed_dict, value_dict, aggregation, var_pos_idx_dict=None, shuffle=False, batch_size=32):
    """
    Creates a DataLoader for the aggregated dataset.
    
    Args:
        embed_dict (dict): Dictionary of embedding tensors.
        value_dict (dict): Dictionary of values associated with embeddings.
        aggregation (str): Aggregation method.
        var_pos_idx_dict (dict, optional): Token indices for 'token' aggregation.
        shuffle (bool): Whether to shuffle the dataset.
        batch_size (int): Batch size for DataLoader.
        
    Returns:
        DataLoader: PyTorch DataLoader instance.
    """
    if aggregation == 'token':
        agg_tensor = aggregate_tensors(embed_dict, value_dict, aggregation, var_pos_idx_dict=var_pos_idx_dict)
    else:
        agg_tensor = aggregate_tensors(embed_dict, value_dict, aggregation)
    
    agg_dataset = AggregatedDataset(agg_tensor, value_dict)
    dataloader = DataLoader(agg_dataset, batch_size=batch_size, shuffle=shuffle)
    
    return dataloader


### ref (mean, max, token)

In [9]:
with open('../data/tcga_ref_embeds.pkl', 'rb') as file:
    ref_embeds = pickle.load(file)

In [10]:
# Mean aggregation
train_mean_loader_ref = create_data_loader(ref_embeds, train_dict, 'mean', shuffle=True)
val_mean_loader_ref = create_data_loader(ref_embeds, val_dict, 'mean', shuffle=False)
test_mean_loader_ref = create_data_loader(ref_embeds, test_dict, 'mean', shuffle=False)

# Max aggregation
train_max_loader_ref = create_data_loader(ref_embeds, train_dict, 'max', shuffle=True)
val_max_loader_ref = create_data_loader(ref_embeds, val_dict, 'max', shuffle=False)
test_max_loader_ref = create_data_loader(ref_embeds, test_dict, 'max', shuffle=False)

# Token aggregation
train_token_loader_ref = create_data_loader(ref_embeds, train_dict, 'token', shuffle=True, var_pos_idx_dict=var_pos_idx_dict)
val_token_loader_ref = create_data_loader(ref_embeds, val_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)
test_token_loader_ref = create_data_loader(ref_embeds, test_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)

# Sum aggregation
train_sum_loader_ref = create_data_loader(ref_embeds, train_dict, 'sum', shuffle=True)
val_sum_loader_ref = create_data_loader(ref_embeds, val_dict, 'sum', shuffle=False)
test_sum_loader_ref = create_data_loader(ref_embeds, test_dict, 'sum', shuffle=False)

# Product aggregation
train_product_loader_ref = create_data_loader(ref_embeds, train_dict, 'product', shuffle=True)
val_product_loader_ref = create_data_loader(ref_embeds, val_dict, 'product', shuffle=False)
test_product_loader_ref = create_data_loader(ref_embeds, test_dict, 'product', shuffle=False)

### alt (mean, max, token)

In [11]:
with open('../data/tcga_alt_embeds.pkl', 'rb') as file:
    alt_embeds = pickle.load(file)

In [12]:
# Mean aggregation
train_mean_loader_alt = create_data_loader(alt_embeds, train_dict, 'mean', shuffle=True)
val_mean_loader_alt = create_data_loader(alt_embeds, val_dict, 'mean', shuffle=False)
test_mean_loader_alt = create_data_loader(alt_embeds, test_dict, 'mean', shuffle=False)

# Max aggregation
train_max_loader_alt = create_data_loader(alt_embeds, train_dict, 'max', shuffle=True)
val_max_loader_alt = create_data_loader(alt_embeds, val_dict, 'max', shuffle=False)
test_max_loader_alt = create_data_loader(alt_embeds, test_dict, 'max', shuffle=False)

# Token aggregation
train_token_loader_alt = create_data_loader(alt_embeds, train_dict, 'token', shuffle=True, var_pos_idx_dict=var_pos_idx_dict)
val_token_loader_alt = create_data_loader(alt_embeds, val_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)
test_token_loader_alt = create_data_loader(alt_embeds, test_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)

# Sum aggregation
train_sum_loader_alt = create_data_loader(alt_embeds, train_dict, 'sum', shuffle=True)
val_sum_loader_alt = create_data_loader(alt_embeds, val_dict, 'sum', shuffle=False)
test_sum_loader_alt = create_data_loader(alt_embeds, test_dict, 'sum', shuffle=False)

# Product aggregation
train_product_loader_alt = create_data_loader(alt_embeds, train_dict, 'product', shuffle=True)
val_product_loader_alt = create_data_loader(alt_embeds, val_dict, 'product', shuffle=False)
test_product_loader_alt = create_data_loader(alt_embeds, test_dict, 'product', shuffle=False)

### alt - ref (mean, max, token)

In [13]:
# create alt - ref embeds
altref_embeds = {}

for var_id, alt_embed in alt_embeds.items():
    altref_embeds[var_id] = alt_embed - ref_embeds[var_id]

In [14]:
# Mean aggregation
train_mean_loader_altref = create_data_loader(altref_embeds, train_dict, 'mean', shuffle=True)
val_mean_loader_altref = create_data_loader(altref_embeds, val_dict, 'mean', shuffle=False)
test_mean_loader_altref = create_data_loader(altref_embeds, test_dict, 'mean', shuffle=False)

# Max aggregation
train_max_loader_altref = create_data_loader(altref_embeds, train_dict, 'max', shuffle=True)
val_max_loader_altref = create_data_loader(altref_embeds, val_dict, 'max', shuffle=False)
test_max_loader_altref = create_data_loader(altref_embeds, test_dict, 'max', shuffle=False)

# Token aggregation
train_token_loader_altref = create_data_loader(altref_embeds, train_dict, 'token', shuffle=True, var_pos_idx_dict=var_pos_idx_dict)
val_token_loader_altref = create_data_loader(altref_embeds, val_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)
test_token_loader_altref = create_data_loader(altref_embeds, test_dict, 'token', shuffle=False, var_pos_idx_dict=var_pos_idx_dict)

# Sum aggregation
train_sum_loader_altref = create_data_loader(altref_embeds, train_dict, 'sum', shuffle=True)
val_sum_loader_altref = create_data_loader(altref_embeds, val_dict, 'sum', shuffle=False)
test_sum_loader_altref = create_data_loader(altref_embeds, test_dict, 'sum', shuffle=False)

# Product aggregation
train_product_loader_altref = create_data_loader(altref_embeds, train_dict, 'product', shuffle=True)
val_product_loader_altref = create_data_loader(altref_embeds, val_dict, 'product', shuffle=False)
test_product_loader_altref = create_data_loader(altref_embeds, test_dict, 'product', shuffle=False)

### Train and evaluate

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import pandas as pd
import numpy as np

# Hyperparameters
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 5e-4
HIDDEN_DIMS = [8, 8]
DROPOUT = 0.25
N_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
TRANSFORMATION = 'none'  # Options: 'z-score', 'min-max', 'log', 'none'

# Evaluation Metrics
def evaluate_regression_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy().flatten()
    y_pred = y_pred.cpu().numpy().flatten()
    loss = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(loss)
    r2 = r2_score(y_true, y_pred)
    
    try:
        spearman_corr, _ = spearmanr(y_true, y_pred)
    except ValueError:
        spearman_corr = np.nan

    try:
        pearson_corr, _ = pearsonr(y_true, y_pred)
    except ValueError:
        pearson_corr = np.nan
    
    return {
        'loss': loss,
        'spearman_corr': spearman_corr,
        'pearson_corr': pearson_corr,
        'mae': mae,
        'rmse': rmse,
        'r2': r2
    }

# MLP Model
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, dropout):
        super(MLP, self).__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim, dtype=torch.float32))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1, dtype=torch.float32))
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

# Transformation Functions
def apply_transformation(y, transformation, stats):
    if transformation == 'log':
        return torch.log1p(torch.clamp(y - stats['min'] + 1, min=1e-8))
    elif transformation == 'z-score':
        return (y - stats['mean']) / (stats['std'] + 1e-8)
    elif transformation == 'min-max':
        return (y - stats['min']) / (stats['max'] - stats['min'] + 1e-8)
    return y

def inverse_transformation(y, stats, transformation):
    if transformation == 'log':
        return torch.expm1(y) + stats['min'] - 1
    elif transformation == 'z-score':
        return y * stats['std'] + stats['mean']
    elif transformation == 'min-max':
        return y * (stats['max'] - stats['min']) + stats['min']
    return y

# Calculate Transformation Statistics
def calculate_transformation_stats(dataset, transformation):
    if transformation == 'none':
        return {}
    
    y = torch.cat([targets for _, targets in dataset], dim=0)
    stats = {}
    if transformation == 'log':
        stats['min'] = y.min()
    elif transformation == 'z-score':
        stats['mean'] = y.mean()
        stats['std'] = y.std()
    elif transformation == 'min-max':
        stats['min'] = y.min()
        stats['max'] = y.max()
    return stats

# Training Function
def train_mlp(train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = next(iter(train_loader))[0].shape[1]
    model = MLP(input_dim, HIDDEN_DIMS, DROPOUT).to(device)
    
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Calculate global transformation stats
    transformation_stats = calculate_transformation_stats(train_loader, TRANSFORMATION)
    
    metrics = {'epoch': [], 'phase': [], 'loss': [], 'spearman_corr': [], 'pearson_corr':[], 'mae': [], 'rmse': [], 'r2': []}
    best_val_loss = float('inf')
    patience_counter = 0
    
    best_model_state = None
    
    for epoch in range(1, N_EPOCHS + 1):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader
            
            all_preds, all_targets = [], []
            running_loss = 0.0
            
            with torch.set_grad_enabled(phase == 'train'):
                for inputs, targets in loader:
                    inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
                    original_targets = targets.clone()
                    
                    if TRANSFORMATION != 'none':
                        targets = apply_transformation(targets, TRANSFORMATION, transformation_stats)
                    
                    optimizer.zero_grad()
                    outputs = model(inputs).squeeze()
                    loss = criterion(outputs, targets)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    running_loss += loss.item() * inputs.size(0)
                    all_preds.append(outputs.detach())
                    all_targets.append(original_targets.detach())
            
            epoch_loss = running_loss / len(loader.dataset)
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            
            if TRANSFORMATION != 'none':
                all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
            
            epoch_metrics = evaluate_regression_metrics(all_targets, all_preds)
            epoch_metrics['loss'] = epoch_loss
            
            # Log metrics
            metrics['epoch'].append(epoch)
            metrics['phase'].append(phase)
            for key, value in epoch_metrics.items():
                metrics[key].append(value)
            
            # Early stopping logic
            if phase == 'val':
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    patience_counter = 0
                    best_model_state = model.state_dict()
                else:
                    patience_counter += 1
                    if patience_counter >= EARLY_STOPPING_PATIENCE:
                        print(f"Early stopping at epoch {epoch} with best validation loss: {best_val_loss:.4f}")
                        model.load_state_dict(best_model_state)
                        metrics_df = pd.DataFrame(metrics)
                        return metrics_df, evaluate_regression_metrics(all_targets, all_preds)
            
            # Print metrics every 5 epochs for both phases
            if epoch % 5 == 0:
                print(
                    f"[Epoch {epoch:03d}] Phase: {phase:5s} | "
                    f"Loss: {epoch_metrics['loss']:.4f} | "
                    f"Spearman: {epoch_metrics['spearman_corr']:.4f} | "
                    f"Pearson: {epoch_metrics['pearson_corr']:.4f} | "
                    f"MAE: {epoch_metrics['mae']:.4f} | "
                    f"RMSE: {epoch_metrics['rmse']:.4f} | "
                    f"R²: {epoch_metrics['r2']:.4f}"
                )
    
    # Load the best model before final testing
    model.load_state_dict(best_model_state)
    
    # Final Test Evaluation
    model.eval()
    all_preds, all_targets = [], []
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
            outputs = model(inputs).squeeze()
            all_preds.append(outputs.detach())
            all_targets.append(targets.detach())
    
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    
    if TRANSFORMATION != 'none':
        all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
    
    test_metrics = evaluate_regression_metrics(all_targets, all_preds)
    print(
        "\n📊 Final Test Metrics | "
        f"Loss: {test_metrics['loss']:.4f} | "
        f"Spearman: {test_metrics['spearman_corr']:.4f} | "
        f"Pearson: {epoch_metrics['pearson_corr']:.4f} | "
        f"MAE: {test_metrics['mae']:.4f} | "
        f"RMSE: {test_metrics['rmse']:.4f} | "
        f"R²: {test_metrics['r2']:.4f}"
    )
    
    metrics_df = pd.DataFrame(metrics)
    return metrics_df, test_metrics


In [16]:
### Ref embeds

# ref mean
print("START TRAINING REF MEAN")
ref_mean_metrics_df, ref_mean_test_metrics = train_mlp(train_mean_loader_ref, val_mean_loader_ref, test_mean_loader_ref)
ref_mean_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/ref_mean_history_df.csv', index=False)

# ref max
print("START TRAINING REF MAX")
ref_max_metrics_df, ref_max_test_metrics = train_mlp(train_max_loader_ref, val_max_loader_ref, test_max_loader_ref)
ref_max_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/ref_max_metrics_df.csv', index=False)

# ref token
print("START TRAINING REF TOKEN")
ref_token_metrics_df, ref_token_test_metrics = train_mlp(train_token_loader_ref, val_token_loader_ref, test_token_loader_ref)
ref_token_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/ref_token_metrics_df.csv', index=False)

# ref sum
print("START TRAINING REF SUM")
ref_sum_metrics_df, ref_sum_test_metrics = train_mlp(train_sum_loader_ref, val_sum_loader_ref, test_sum_loader_ref)
ref_sum_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/ref_sum_metrics_df.csv', index=False)

START TRAINING REF MEAN
[Epoch 005] Phase: train | Loss: 1.6557 | Spearman: 0.0864 | Pearson: 0.0670 | MAE: 1.0480 | RMSE: 1.2867 | R²: -0.0462
[Epoch 005] Phase: val   | Loss: 1.6853 | Spearman: 0.4560 | Pearson: 0.4350 | MAE: 1.0419 | RMSE: 1.2982 | R²: 0.0394
[Epoch 010] Phase: train | Loss: 1.5316 | Spearman: 0.1867 | Pearson: 0.1972 | MAE: 1.0138 | RMSE: 1.2376 | R²: 0.0322
[Epoch 010] Phase: val   | Loss: 1.5830 | Spearman: 0.4541 | Pearson: 0.4374 | MAE: 1.0123 | RMSE: 1.2582 | R²: 0.0977
[Epoch 015] Phase: train | Loss: 1.4852 | Spearman: 0.2429 | Pearson: 0.2517 | MAE: 1.0012 | RMSE: 1.2187 | R²: 0.0615
[Epoch 015] Phase: val   | Loss: 1.5307 | Spearman: 0.4449 | Pearson: 0.4310 | MAE: 1.0032 | RMSE: 1.2372 | R²: 0.1275
[Epoch 020] Phase: train | Loss: 1.4801 | Spearman: 0.2718 | Pearson: 0.2579 | MAE: 1.0008 | RMSE: 1.2166 | R²: 0.0647
[Epoch 020] Phase: val   | Loss: 1.5198 | Spearman: 0.4614 | Pearson: 0.4409 | MAE: 0.9995 | RMSE: 1.2328 | R²: 0.1338
[Epoch 025] Phase: trai

In [17]:
### Alt embeds

# alt mean
print("START TRAINING ALT MEAN")
alt_mean_metrics_df, alt_mean_test_metrics = train_mlp(train_mean_loader_alt, val_mean_loader_alt, test_mean_loader_alt)
alt_mean_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/alt_mean_metrics_df.csv', index=False)

# alt max
print("START TRAINING ALT MAX")
alt_max_metrics_df, alt_max_test_metrics = train_mlp(train_max_loader_alt, val_max_loader_alt, test_max_loader_alt)
alt_max_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/alt_max_metrics_df.csv', index=False)

# alt token
print("START TRAINING ALT TOKEN")
alt_token_metrics_df, alt_token_test_metrics = train_mlp(train_token_loader_alt, val_token_loader_alt, test_token_loader_alt)
alt_token_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/alt_token_metrics_df.csv', index=False)

# alt sum
print("START TRAINING ALT SUM")
alt_sum_metrics_df, alt_sum_test_metrics = train_mlp(train_sum_loader_alt, val_sum_loader_alt, test_sum_loader_alt)
alt_sum_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/alt_sum_metrics_df.csv', index=False)

START TRAINING ALT MEAN
[Epoch 005] Phase: train | Loss: 1.5815 | Spearman: 0.1402 | Pearson: 0.1415 | MAE: 1.0288 | RMSE: 1.2576 | R²: 0.0007
[Epoch 005] Phase: val   | Loss: 1.6117 | Spearman: 0.4397 | Pearson: 0.4295 | MAE: 1.0143 | RMSE: 1.2695 | R²: 0.0814
[Epoch 010] Phase: train | Loss: 1.4400 | Spearman: 0.3103 | Pearson: 0.3040 | MAE: 0.9832 | RMSE: 1.2000 | R²: 0.0901
[Epoch 010] Phase: val   | Loss: 1.4863 | Spearman: 0.4491 | Pearson: 0.4363 | MAE: 0.9732 | RMSE: 1.2191 | R²: 0.1528
[Epoch 015] Phase: train | Loss: 1.4027 | Spearman: 0.3668 | Pearson: 0.3387 | MAE: 0.9635 | RMSE: 1.1843 | R²: 0.1137
[Epoch 015] Phase: val   | Loss: 1.4399 | Spearman: 0.4638 | Pearson: 0.4518 | MAE: 0.9674 | RMSE: 1.2000 | R²: 0.1793
[Epoch 020] Phase: train | Loss: 1.3761 | Spearman: 0.3699 | Pearson: 0.3619 | MAE: 0.9573 | RMSE: 1.1731 | R²: 0.1305
[Epoch 020] Phase: val   | Loss: 1.4277 | Spearman: 0.4674 | Pearson: 0.4497 | MAE: 0.9573 | RMSE: 1.1949 | R²: 0.1862
[Epoch 025] Phase: train

In [18]:
### Alt - Ref embeds

# altref mean
print("START TRAINING ALTREF MEAN")
altref_mean_metrics_df, altref_mean_test_metrics = train_mlp(train_mean_loader_altref, val_mean_loader_altref, test_mean_loader_altref)
altref_mean_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/altref_mean_metrics_df.csv', index=False)

# altref max
print("START TRAINING ALTREF MAX")
altref_max_metrics_df, altref_max_test_metrics = train_mlp(train_max_loader_altref, val_max_loader_altref, test_max_loader_altref)
altref_max_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/altref_max_metrics_df.csv', index=False)

# altref token
print("START TRAINING ALTREF TOKEN")
altref_token_metrics_df, altref_token_test_metrics = train_mlp(train_token_loader_altref, val_token_loader_altref, test_token_loader_altref)
altref_token_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/altref_token_metrics_df.csv', index=False)

# altref sum
print("START TRAINING ALTREF SUM")
altref_sum_metrics_df, altref_sum_test_metrics = train_mlp(train_sum_loader_altref, val_sum_loader_altref, test_sum_loader_altref)
altref_sum_metrics_df.to_csv('../res/metrics/per_epoch_agg1st/altref_sum_metrics_df.csv', index=False)

START TRAINING ALTREF MEAN
[Epoch 005] Phase: train | Loss: 1.5977 | Spearman: 0.0552 | Pearson: 0.0610 | MAE: 1.0480 | RMSE: 1.2640 | R²: -0.0095
[Epoch 005] Phase: val   | Loss: 1.7273 | Spearman: 0.2779 | Pearson: 0.1955 | MAE: 1.0686 | RMSE: 1.3143 | R²: 0.0155
[Epoch 010] Phase: train | Loss: 1.5154 | Spearman: 0.2123 | Pearson: 0.2101 | MAE: 1.0186 | RMSE: 1.2310 | R²: 0.0424
[Epoch 010] Phase: val   | Loss: 1.6991 | Spearman: 0.3062 | Pearson: 0.2250 | MAE: 1.0602 | RMSE: 1.3035 | R²: 0.0316
[Epoch 015] Phase: train | Loss: 1.4869 | Spearman: 0.2601 | Pearson: 0.2471 | MAE: 1.0078 | RMSE: 1.2194 | R²: 0.0604
[Epoch 015] Phase: val   | Loss: 1.6652 | Spearman: 0.3279 | Pearson: 0.2489 | MAE: 1.0401 | RMSE: 1.2904 | R²: 0.0509
[Epoch 020] Phase: train | Loss: 1.4261 | Spearman: 0.3431 | Pearson: 0.3165 | MAE: 0.9791 | RMSE: 1.1942 | R²: 0.0989
[Epoch 020] Phase: val   | Loss: 1.6243 | Spearman: 0.3660 | Pearson: 0.2876 | MAE: 1.0163 | RMSE: 1.2745 | R²: 0.0742
[Epoch 025] Phase: t

In [19]:
mean_metrics = pd.DataFrame({ 'ref':ref_mean_test_metrics,
                'alt': alt_mean_test_metrics, 
                'altref': altref_mean_test_metrics})

mean_metrics.to_csv('../res/metrics/test_agg1st/mean_metrics.csv', index=False)
mean_metrics

Unnamed: 0,ref,alt,altref
loss,1.365319,1.377834,1.491862
spearman_corr,0.46414,0.450427,0.391467
pearson_corr,0.444606,0.414521,0.320416
mae,0.89675,0.902667,0.935173
rmse,1.168469,1.173812,1.221418
r2,0.160706,0.153013,0.082917


In [20]:
max_metrics = pd.DataFrame({ 'ref':ref_max_test_metrics,
                'alt': alt_max_test_metrics, 
                'altref': altref_max_test_metrics})

max_metrics.to_csv('../res/metrics/test_agg1st/max_metrics.csv', index=False)
max_metrics

Unnamed: 0,ref,alt,altref
loss,1.758156,1.757412,1.708832
spearman_corr,,,0.241421
pearson_corr,,,0.198627
mae,1.087674,1.08636,1.061846
rmse,1.325955,1.325674,1.307223
r2,-0.002118,-0.001694,0.025995


In [21]:
token_metrics = pd.DataFrame({ 'ref':ref_token_test_metrics,
                'alt': alt_token_test_metrics, 
                'altref': altref_token_test_metrics})

token_metrics.to_csv('../res/metrics/test_agg1st/token_metrics.csv', index=False)
token_metrics

Unnamed: 0,ref,alt,altref
loss,1.669656,1.609213,1.720941
spearman_corr,0.256755,0.267639,0.200032
pearson_corr,0.225077,0.201883,0.153836
mae,1.048005,0.987705,1.06809
rmse,1.292152,1.268547,1.311847
r2,0.048325,0.010779,0.019093


In [22]:
sum_metrics = pd.DataFrame({ 'ref':ref_sum_test_metrics,
                'alt': alt_sum_test_metrics, 
                'altref': altref_sum_test_metrics})

sum_metrics.to_csv('../res/metrics/test_agg1st/sum_metrics.csv', index=False)
sum_metrics

Unnamed: 0,ref,alt,altref
loss,1.692386,1.759012,1.503381
spearman_corr,0.261613,,0.445995
pearson_corr,0.212503,,0.378416
mae,1.053869,1.089069,0.949296
rmse,1.300917,1.326277,1.226124
r2,0.03537,-0.002606,0.143099
