In [1]:
import pickle
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings
from sklearn.impute import KNNImputer
import matplotlib.pyplot as plt


In [2]:
warnings.filterwarnings('ignore')

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
df_embeds = pd.read_csv('../../classical_ML/data/Orthrus_data/orthrus_features_clean.tsv.gz', sep='\t')
df_embeds.dropna(subset = ['embed0'], inplace=True)
embed_cols = [f'embed{i}' for i in range(512)]
df_embeds = df_embeds[['Transcript ID']+embed_cols]
df_embeds.rename(columns={"Transcript ID": "Transcript_ID"}, inplace=True)
df_embeds

Unnamed: 0,Transcript_ID,embed0,embed1,embed2,embed3,embed4,embed5,embed6,embed7,embed8,...,embed502,embed503,embed504,embed505,embed506,embed507,embed508,embed509,embed510,embed511
0,ENST00000263100,0.122282,0.177534,-0.048803,-0.165734,-0.034072,-0.264644,-0.057117,-0.227631,-0.529456,...,-0.041493,0.312705,0.183835,0.282550,0.097991,0.247955,-0.156811,0.038840,-0.064718,-0.006802
1,ENST00000373997,0.342975,0.134082,0.039255,-0.224912,-0.052249,-0.148360,-0.062987,-0.092132,-0.442369,...,-0.327766,0.412198,0.237658,0.315500,0.118175,0.254118,-0.152559,0.185892,-0.129714,0.020999
2,ENST00000318602,0.331549,0.211808,-0.040499,-0.144805,-0.050687,-0.184935,0.127531,-0.119810,-0.621050,...,-0.248826,0.385974,0.509238,0.150244,0.153647,0.098654,-0.023546,0.113986,-0.099528,0.030393
3,ENST00000299698,0.337860,0.260349,-0.019156,-0.144985,-0.045211,-0.194701,0.104279,-0.138256,-0.562558,...,-0.233801,0.381406,0.526234,0.110568,0.128577,0.106606,-0.047536,0.093703,-0.121776,-0.045927
4,ENST00000442999,0.120457,0.181139,0.107973,-0.104078,-0.039713,-0.133477,0.049908,-0.193681,-0.174398,...,-0.172032,0.451827,0.141332,0.319514,0.230903,0.197436,-0.121235,0.083669,-0.072020,-0.016692
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18479,ENST00000371528,0.189914,0.113874,0.061029,-0.105412,-0.031373,-0.201013,-0.110164,-0.084314,-0.437392,...,-0.185700,0.380479,0.177239,0.300294,0.138531,0.178986,-0.119901,0.136907,-0.033310,-0.126740
18480,ENST00000294353,0.223297,0.086196,0.052600,-0.128977,-0.027397,-0.177775,-0.117537,-0.079050,-0.434602,...,-0.188737,0.404469,0.187927,0.321878,0.120758,0.228710,-0.140578,0.149540,0.003069,-0.073322
18481,ENST00000322764,0.254627,0.138448,0.156764,-0.182271,0.062424,-0.127684,-0.138878,-0.132653,-0.478736,...,-0.293639,0.455482,0.232366,0.347795,0.366158,0.279273,0.026741,0.107467,0.051708,-0.081295
18482,ENST00000381638,0.442461,0.263657,-0.102592,-0.080843,0.024760,-0.054935,0.484756,-0.382065,-0.816725,...,-0.153948,0.554390,1.176475,-0.155350,0.156531,0.152959,0.140080,-0.022837,-0.029495,-0.261842


In [4]:
df_unq = pd.read_csv('../data/tcga_annotated_clean4ML.tsv.gz', sep='\t')
df_unq = df_unq.merge(df_embeds, how='left')
df_unq

Unnamed: 0,CHROM,POS,REF,ALT,Transcript_ID,HGVSc,KIM_PTC_to_start_codon,KIM_upstream_exon_count,KIM_downstream_exon_count,KIM_last_exon,...,embed502,embed503,embed504,embed505,embed506,embed507,embed508,embed509,embed510,embed511
0,1,944753,C,T,ENST00000327044,ENST00000327044:c.2191C>T,2193,18,0,1,...,-0.266964,0.572462,0.137608,0.364898,0.278375,0.192540,-0.055143,0.177699,-0.033967,-0.119414
1,1,952113,G,A,ENST00000327044,ENST00000327044:c.1218G>A,1218,10,8,0,...,-0.266964,0.572462,0.137608,0.364898,0.278375,0.192540,-0.055143,0.177699,-0.033967,-0.119414
2,1,1255304,G,T,ENST00000349431,ENST00000349431:c.679G>T,681,6,0,1,...,-0.255506,0.340774,0.216285,0.367062,0.385427,0.249874,-0.126066,0.061456,-0.044280,0.038091
3,1,1338573,G,T,ENST00000378888,ENST00000378888:c.1288G>T,1290,11,3,0,...,-0.230311,0.540800,0.241533,0.349425,0.209922,0.237617,-0.086651,0.132743,-0.047608,-0.005701
4,1,1387314,C,T,ENST00000400809,ENST00000400809:c.1480C>T,1482,10,0,1,...,-0.111227,0.485855,0.080609,0.412174,0.167230,0.145487,-0.071526,0.118111,-0.018258,0.048301
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4013,X,152920717,C,T,ENST00000370268,ENST00000370268:c.622C>T,624,8,14,0,...,,,,,,,,,,
4014,X,153650171,G,T,ENST00000342782,ENST00000342782:c.1021G>T,1023,3,0,1,...,-0.196150,0.426428,0.199716,0.310246,0.163403,0.159778,-0.084737,0.111034,-0.035579,-0.101326
4015,X,154030948,C,T,ENST00000303391,ENST00000303391:c.880C>T,882,3,0,1,...,-0.123834,0.466796,0.185562,0.355276,0.198612,0.238267,-0.030854,0.173761,0.059287,-0.104817
4016,X,154354015,C,A,ENST00000369850,ENST00000369850:c.5586C>A,5586,34,13,0,...,-0.152846,0.604602,1.504313,-0.309990,0.027346,0.059884,0.179337,-0.128790,0.015879,-0.266402


In [5]:
prot_feat_df = pd.read_csv('../data/protein_AA_features.tsv.gz', sep='\t')
prot_feat_df.rename(columns={'MANE-Select':'Transcript_ID'}, inplace=True)
prot_feat_df.drop(columns=['Unnamed: 0', 'Entry'], axis=1, inplace=True)
prot_feat_df

Unnamed: 0,Transcript_ID,actual_or_pred_ACT_SITE,annotation_actual_ACT_SITE,annotation_pred_ACT_SITE,actual_or_pred_BINDING,annotation_actual_BINDING,annotation_pred_BINDING,actual_or_pred_COILED,annotation_actual_COILED,annotation_pred_COILED,...,annotation_pred_TRANSIT,actual_or_pred_TRANSMEM,annotation_actual_TRANSMEM,annotation_pred_TRANSMEM,actual_or_pred_TURN,annotation_actual_TURN,annotation_pred_TURN,actual_or_pred_ZN_FING,annotation_actual_ZN_FING,annotation_pred_ZN_FING
0,ENST00000436697,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000001100100000...,...,1111000000000000000000000000000000000000000000...,actual,0000111111111111111111111000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
1,ENST00000709217,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
2,ENST00000374922,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
3,ENST00000637218,actual,0000000000000000000000000000000000000000000000...,,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111000000000000000000000000...,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
4,ENST00000469902,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,actual,0000000000000000000000000000000000001111111111...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18412,ENST00000295896,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111100000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18413,ENST00000295898,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111011001000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18414,ENST00000306862,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18415,ENST00000450660,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111100000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...


In [6]:
# Calculate the proportion of lost features upon stop-gain
df_unq = df_unq.merge(prot_feat_df, how='left')

prot_feat_list = ['ACT_SITE', 'BINDING', 'COILED', 'COMPBIAS', 'DISULFID', 'BINDING', 'DOMAIN', 
                 'HELIX', 'MOD_RES', 'MOTIF', 'PROPEP', 'REGION', 'REPEAT', 'SIGNAL', 'STRAND',
                 'TOPO_DOM', 'TRANSIT', 'TRANSMEM', 'TURN', 'ZN_FING']

for prot_feat in prot_feat_list:
    df_unq[f'prop_lost_{prot_feat}'] = None

for idx, row in df_unq.iterrows():
    
    for prot_feat_name in prot_feat_list:
        temp_pred_or_actual = row[f'actual_or_pred_{prot_feat_name}']
        
        if not pd.isna(temp_pred_or_actual):
            
            prot_feat_binary = row[f'annotation_{temp_pred_or_actual}_{prot_feat_name}']
            
            if row['Protein_position'] <= len(prot_feat_binary):
                prot_feat_binary_lost = prot_feat_binary[row['Protein_position']:]
                df_unq.loc[idx, f'prop_lost_{prot_feat_name}'] = prot_feat_binary_lost.count('1') / len(prot_feat_binary)

prot_lost_feat_list = [f'prop_lost_{x}' for x in prot_feat_list]

In [7]:
def apply_baseline_rules(df, exon_len_threshold=407, penultimate_threshold=55, start_threshold=100):

    # long exon rule
    df['long_exon'] = 0
    df.loc[df['current_exon_len'] > exon_len_threshold, 'long_exon'] = 1

    # penultimate exon rule
    df['penultimate_flag'] = 0
    df.loc[
        (df['DIST_FROM_LAST_EXON'] < penultimate_threshold) &
        (df['DIST_FROM_LAST_EXON'] >= 0), 
        'penultimate_flag'
    ] = 1

    # close to start rule
    df['close_to_start'] = 0
    df.loc[df['CDS_position'] < start_threshold, 'close_to_start'] = 1

    return df


In [8]:
# last exon rule
df_unq['last_exon'] = 0
df_unq.loc[df_unq['current_exon_number']==df_unq['total_exon_numbers'], 'last_exon'] = 1

df_unq = apply_baseline_rules(df_unq, exon_len_threshold=355, penultimate_threshold=49, start_threshold=120)
df_unq

Unnamed: 0,CHROM,POS,REF,ALT,Transcript_ID,HGVSc,KIM_PTC_to_start_codon,KIM_upstream_exon_count,KIM_downstream_exon_count,KIM_last_exon,...,prop_lost_STRAND,prop_lost_TOPO_DOM,prop_lost_TRANSIT,prop_lost_TRANSMEM,prop_lost_TURN,prop_lost_ZN_FING,last_exon,long_exon,penultimate_flag,close_to_start
0,1,944753,C,T,ENST00000327044,ENST00000327044:c.2191C>T,2193,18,0,1,...,0.0,0.024032,0.0,0.0,0.0,0.0,1,1,0,0
1,1,952113,G,A,ENST00000327044,ENST00000327044:c.1218G>A,1218,10,8,0,...,0.0,0.457944,0.0,0.05474,0.0,0.0,0,0,0,0
2,1,1255304,G,T,ENST00000349431,ENST00000349431:c.679G>T,681,6,0,1,...,0.0,0.046332,0.0,0.07722,0.0,0.0,1,1,0,0
3,1,1338573,G,T,ENST00000378888,ENST00000378888:c.1288G>T,1290,11,3,0,...,0.0,0.381295,0.0,0.020144,0.0,0.0,0,0,0,0
4,1,1387314,C,T,ENST00000400809,ENST00000400809:c.1480C>T,1482,10,0,1,...,0.0,0.05,0.0,0.0,0.0,0.0,1,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4013,X,152920717,C,T,ENST00000370268,ENST00000370268:c.622C>T,624,8,14,0,...,,,,,,,0,0,0,0
4014,X,153650171,G,T,ENST00000342782,ENST00000342782:c.1021G>T,1023,3,0,1,...,0.0,0.111979,0.0,0.0,0.0,0.0,1,1,0,0
4015,X,154030948,C,T,ENST00000303391,ENST00000303391:c.880C>T,882,3,0,1,...,,,,,,,1,1,0,0
4016,X,154354015,C,A,ENST00000369850,ENST00000369850:c.5586C>A,5586,34,13,0,...,0.144314,0.296562,0.0,0.0,0.005667,0.0,0,0,0,0


In [9]:
KIM_features = ['KIM_PTC_to_start_codon', 'KIM_upstream_exon_count',
       'KIM_downstream_exon_count', 'KIM_last_exon', 'KIM_50nt_to_last_EJ',
       'KIM_PTC_exon_length', 'KIM_PTC_to_intron', 'KIM_dist_to_stop_codon',
       'KIM_mRNA_half_life', 'KIM_AF', 'KIM_LOEUF', 'KIM_5UTR_length',
       'KIM_3UTR_length', 'KIM_Transcript_length']

baseline_features = ['last_exon', 'long_exon', 'penultimate_flag', 'close_to_start'] 

my_features = ['last_exon', 'long_exon', 'penultimate_flag', 'close_to_start', 'hl',
       'mrl', 'cDNA_position', 'PERCENTILE', 'GERP_DIST', 'BP_DIST',
       'utr5_len', 'current_exon_len', 'VEST4_score', 'CADD_phred',
       'phyloP100way_vertebrate', 'dn_ds', 'abundance', 'shet',
       'lof.oe_ci.upper', 'lof.pRec', 'UTR5_GC', 'connectedness',
       'prop_lost_DOMAIN', 'prop_lost_HELIX', 'prop_lost_REGION',
       'prop_lost_TOPO_DOM', 'LoF_HC', 'NearestExonJB_dist', 'TF', 'tau',
       'phyloP17way_primate', 'phyloP470way_mammalian',
       'phastCons100way_vertebrate', 'phastCons17way_primate',
       'phastCons470way_mammalian', 'mis.z_score', 'syn.z_score', 'lof.pNull',
       'exp_var', 'utr3_len', 'total_exons_len', 'fathmm-XF_coding_score',
       'Nucleus', 'Cytosol', 'Cytoplasm', 'Ribosome', 'Membrane',
       'Endoplasmic_reticulum', 'embed6', 'embed8', 'embed90', 'embed145',
       'embed205', 'embed219', 'embed230', 'embed240', 'embed254', 'embed309',
       'embed356', 'embed430'] # selected based on ML_feature_selection notebook

#NearestExonJB_dist, Endoplasmic_reticulum, exp_var, LoF_HC, Nucleus, prop_lost_HELIX, phyloP17way_primate, GERP_DIST, total_exons_len, connectedness,
# syn.z_score, Cytosol, prop_lost_REGION
# utr3_len, lof.pNull
# prop_lost_TOPO_DOM
# embed6
# phyloP470way_mammalian

### Splits

In [10]:
test_var_ids = df_unq[df_unq['CHROM'].isin(['20', '21', '22'])]['HGVSc'].values.tolist()
test_var_ids = list(set(test_var_ids))
test_var_ids[0:5], len(test_var_ids)

(['ENST00000314328:c.2410C>T',
  'ENST00000322927:c.2098C>T',
  'ENST00000252934:c.1036G>T',
  'ENST00000399151:c.3241C>T',
  'ENST00000291688:c.5803C>T'],
 224)

In [11]:
val_var_ids = df_unq[df_unq['CHROM']=='19']['HGVSc'].values.tolist()
val_var_ids = list(set(val_var_ids))
val_var_ids[0:5], len(val_var_ids)

(['ENST00000263377:c.1267C>T',
  'ENST00000221480:c.190C>T',
  'ENST00000282286:c.853G>T',
  'ENST00000359866:c.181C>T',
  'ENST00000253193:c.487C>T'],
 204)

In [12]:
train_var_ids = list(set(df_unq.HGVSc.values.tolist()) - set(val_var_ids) - set(test_var_ids))
train_var_ids[0:5], len(train_var_ids)

(['ENST00000283943:c.4123G>T',
  'ENST00000255082:c.456C>A',
  'ENST00000330636:c.202C>T',
  'ENST00000301785:c.1015C>T',
  'ENST00000283875:c.1129C>T'],
 3590)

### loaders

In [13]:
X_train_base = df_unq[df_unq['HGVSc'].isin(train_var_ids)][baseline_features]
X_train_kim = df_unq[df_unq['HGVSc'].isin(train_var_ids)][KIM_features]
X_train_mine = df_unq[df_unq['HGVSc'].isin(train_var_ids)][my_features]
y_train = df_unq[df_unq['HGVSc'].isin(train_var_ids)]['NMD_efficiency']

X_test_base = df_unq[df_unq['HGVSc'].isin(test_var_ids)][baseline_features]
X_test_kim = df_unq[df_unq['HGVSc'].isin(test_var_ids)][KIM_features]
X_test_mine = df_unq[df_unq['HGVSc'].isin(test_var_ids)][my_features]
y_test = df_unq[df_unq['HGVSc'].isin(test_var_ids)]['NMD_efficiency']

X_val_base = df_unq[df_unq['HGVSc'].isin(val_var_ids)][baseline_features]
X_val_kim = df_unq[df_unq['HGVSc'].isin(val_var_ids)][KIM_features]
X_val_mine = df_unq[df_unq['HGVSc'].isin(val_var_ids)][my_features]
y_val = df_unq[df_unq['HGVSc'].isin(val_var_ids)]['NMD_efficiency']

In [14]:
def correct_dtype(X):
    X = X.apply(pd.to_numeric, errors="coerce")
    return X

X_train_mine = correct_dtype(X_train_mine)
X_test_mine = correct_dtype(X_test_mine)
X_val_mine = correct_dtype(X_val_mine)

In [15]:
train_feature_means_base = X_train_base.mean()
train_feature_means_kim = X_train_kim.mean()
train_feature_means_mine = X_train_mine.mean()

train_feature_stds_base = X_train_base.std()
train_feature_stds_kim = X_train_kim.std()
train_feature_stds_mine = X_train_mine.std()

In [16]:
train_mean_std_df = pd.DataFrame({'mean':train_feature_means_mine,
                                 'std':train_feature_stds_mine})
train_mean_std_df.to_csv('../data/train_mean_std_df.csv')

In [17]:
class CustomDataset(Dataset):
    def __init__(self, data, targets, feature_means=None, feature_stds=None):
        """
        Args:
            data (pd.DataFrame): Feature matrix with possible NaNs.
            targets (pd.Series or torch.Tensor): Target values.
            feature_means (pd.Series): Feature means calculated from the training dataset.
            feature_stds (pd.Series): Feature standard deviations from the training dataset.
        """
        
        if feature_means is not None:
            # Fill missing values using the training feature means
            data = data.fillna(feature_means)
        
        if feature_means is not None and feature_stds is not None:
            # Normalize using training statistics
            data = (data - feature_means) / (feature_stds + 1e-8)  # Add epsilon to prevent division by zero
        
        self.data = data.values
        self.targets = targets.values if isinstance(targets, pd.Series) else targets
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx], dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        return x, y


In [18]:
# Create datasets
train_dataset_base = CustomDataset(X_train_base, y_train, train_feature_means_base, train_feature_stds_base)
test_dataset_base = CustomDataset(X_test_base, y_test, train_feature_means_base, train_feature_stds_base)
val_dataset_base = CustomDataset(X_val_base, y_val, train_feature_means_base, train_feature_stds_base)

train_dataset_kim = CustomDataset(X_train_kim, y_train, train_feature_means_kim, train_feature_stds_kim)
test_dataset_kim = CustomDataset(X_test_kim, y_test, train_feature_means_kim, train_feature_stds_kim)
val_dataset_kim = CustomDataset(X_val_kim, y_val, train_feature_means_kim, train_feature_stds_kim)

train_dataset_mine = CustomDataset(X_train_mine, y_train, train_feature_means_mine, train_feature_stds_mine)
test_dataset_mine = CustomDataset(X_test_mine, y_test, train_feature_means_mine, train_feature_stds_mine)
val_dataset_mine = CustomDataset(X_val_mine, y_val, train_feature_means_mine, train_feature_stds_mine)

# Create dataloaders
batch_size = 32
train_loader_base = DataLoader(train_dataset_base, batch_size=batch_size, shuffle=True)
test_loader_base = DataLoader(test_dataset_base, batch_size=batch_size, shuffle=False)
val_loader_base = DataLoader(val_dataset_base, batch_size=batch_size, shuffle=False)

train_loader_kim = DataLoader(train_dataset_kim, batch_size=batch_size, shuffle=True)
test_loader_kim = DataLoader(test_dataset_kim, batch_size=batch_size, shuffle=False)
val_loader_kim = DataLoader(val_dataset_kim, batch_size=batch_size, shuffle=False)

train_loader_mine = DataLoader(train_dataset_mine, batch_size=batch_size, shuffle=True)
test_loader_mine = DataLoader(test_dataset_mine, batch_size=batch_size, shuffle=False)
val_loader_mine = DataLoader(val_dataset_mine, batch_size=batch_size, shuffle=False)

### train MLP

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import pandas as pd
import numpy as np
import torch.nn.functional as F

# Hyperparameters
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 5e-4
HIDDEN_DIMS = [8, 8]
DROPOUT = 0.25
N_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
TRANSFORMATION = 'none'  # Options: 'z-score', 'min-max', 'log', 'none'

# Evaluation Metrics
def evaluate_regression_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy().flatten()
    y_pred = y_pred.cpu().numpy().flatten()
    loss = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(loss)
    r2 = r2_score(y_true, y_pred)
    
    if np.std(y_true) == 0 or np.std(y_pred) == 0:
        spearman_corr = np.nan
        pearson_corr = np.nan
    else:
        try:
            spearman_corr, _ = spearmanr(y_true, y_pred)
            pearson_corr, _ = pearsonr(y_true, y_pred)
        except ValueError:
            spearman_corr = np.nan
            pearson_corr = np.nan
    
    return {
        'loss': loss,
        'spearman_corr': spearman_corr,
        'pearson_corr': pearson_corr,
        'mae': mae,
        'rmse': rmse,
        'r2': r2
    }

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, dropout):
        super(MLP, self).__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim, dtype=torch.float32))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1, dtype=torch.float32))
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

# Transformation Functions
def apply_transformation(y, transformation, stats):
    if transformation == 'log':
        return torch.log1p(torch.clamp(y - stats['min'] + 1, min=1e-8))
    elif transformation == 'z-score':
        return (y - stats['mean']) / (stats['std'] + 1e-8)
    elif transformation == 'min-max':
        return (y - stats['min']) / (stats['max'] - stats['min'] + 1e-8)
    return y

def inverse_transformation(y, stats, transformation):
    if transformation == 'log':
        return torch.expm1(y) + stats['min'] - 1
    elif transformation == 'z-score':
        return y * stats['std'] + stats['mean']
    elif transformation == 'min-max':
        return y * (stats['max'] - stats['min']) + stats['min']
    return y

# Calculate Transformation Statistics
def calculate_transformation_stats(dataset, transformation):
    if transformation == 'none':
        return {}
    
    y = torch.cat([targets for _, targets in dataset], dim=0)
    stats = {}
    if transformation == 'log':
        stats['min'] = y.min()
    elif transformation == 'z-score':
        stats['mean'] = y.mean()
        stats['std'] = y.std()
    elif transformation == 'min-max':
        stats['min'] = y.min()
        stats['max'] = y.max()
    return stats

# Training Function
def train_mlp(train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = next(iter(train_loader))[0].shape[1]
    model = MLP(input_dim, HIDDEN_DIMS, DROPOUT).to(device)
    
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Calculate global transformation stats
    transformation_stats = calculate_transformation_stats(train_loader, TRANSFORMATION)
    
    metrics = {'epoch': [], 'phase': [], 'loss': [], 'spearman_corr': [], 'pearson_corr':[], 'mae': [], 'rmse': [], 'r2': []}
    best_val_loss = float('inf')
    patience_counter = 0
    
    best_model_state = None
    
    for epoch in range(1, N_EPOCHS + 1):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader
            
            all_preds, all_targets = [], []
            running_loss = 0.0
            
            with torch.set_grad_enabled(phase == 'train'):
                for inputs, targets in loader:
                    inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
                    original_targets = targets.clone()
                    
                    if TRANSFORMATION != 'none':
                        targets = apply_transformation(targets, TRANSFORMATION, transformation_stats)
                    
                    optimizer.zero_grad()
                    outputs = model(inputs).squeeze()
                    loss = criterion(outputs, targets)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    running_loss += loss.item() * inputs.size(0)
                    all_preds.append(outputs.detach())
                    all_targets.append(original_targets.detach())
            
            epoch_loss = running_loss / len(loader.dataset)
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            
            if TRANSFORMATION != 'none':
                all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
            
            epoch_metrics = evaluate_regression_metrics(all_targets, all_preds)
            epoch_metrics['loss'] = epoch_loss
            
            # Log metrics
            metrics['epoch'].append(epoch)
            metrics['phase'].append(phase)
            for key, value in epoch_metrics.items():
                metrics[key].append(value)
            
            # Early stopping logic
            if phase == 'val':
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    patience_counter = 0
                    best_model_state = model.state_dict()
                else:
                    patience_counter += 1
                    if patience_counter >= EARLY_STOPPING_PATIENCE:
                        print(f"Early stopping at epoch {epoch} with best validation loss: {best_val_loss:.4f}")
                        model.load_state_dict(best_model_state)
                        metrics_df = pd.DataFrame(metrics)
                        return metrics_df, evaluate_regression_metrics(all_targets, all_preds), model
            
            # Print metrics every 5 epochs for both phases
            if epoch % 5 == 0:
                print(
                    f"[Epoch {epoch:03d}] Phase: {phase:5s} | "
                    f"Loss: {epoch_metrics['loss']:.4f} | "
                    f"Spearman: {epoch_metrics['spearman_corr']:.4f} | "
                    f"Pearson: {epoch_metrics['pearson_corr']:.4f} | "
                    f"MAE: {epoch_metrics['mae']:.4f} | "
                    f"RMSE: {epoch_metrics['rmse']:.4f} | "
                    f"R²: {epoch_metrics['r2']:.4f}"
                )
    
    # Load the best model before final testing
    model.load_state_dict(best_model_state)
    
    # Final Test Evaluation
    model.eval()
    all_preds, all_targets = [], []
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
            outputs = model(inputs).squeeze()
            all_preds.append(outputs.detach())
            all_targets.append(targets.detach())
    
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    
    if TRANSFORMATION != 'none':
        all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
    
    test_metrics = evaluate_regression_metrics(all_targets, all_preds)
    print(
        "\n📊 Final Test Metrics | "
        f"Loss: {test_metrics['loss']:.4f} | "
        f"Spearman: {test_metrics['spearman_corr']:.4f} | "
        f"Pearson: {epoch_metrics['pearson_corr']:.4f} | "
        f"MAE: {test_metrics['mae']:.4f} | "
        f"RMSE: {test_metrics['rmse']:.4f} | "
        f"R²: {test_metrics['r2']:.4f}"
    )
    
    metrics_df = pd.DataFrame(metrics)
    return metrics_df, test_metrics, model


In [19]:
# train and evaluate
print("START TRAINING")
metrics_df_base, test_metrics_base, model_base = train_mlp(train_loader_base, val_loader_base, test_loader_base)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.2764 | Spearman: 0.5012 | Pearson: 0.4630 | MAE: 0.8757 | RMSE: 1.1298 | R²: 0.1906
[Epoch 005] Phase: val   | Loss: 1.0151 | Spearman: 0.7014 | Pearson: 0.6495 | MAE: 0.7653 | RMSE: 1.0075 | R²: 0.3783
[Epoch 010] Phase: train | Loss: 1.1785 | Spearman: 0.5542 | Pearson: 0.5159 | MAE: 0.8360 | RMSE: 1.0856 | R²: 0.2527
[Epoch 010] Phase: val   | Loss: 0.9662 | Spearman: 0.7064 | Pearson: 0.6585 | MAE: 0.7460 | RMSE: 0.9830 | R²: 0.4083
[Epoch 015] Phase: train | Loss: 1.1807 | Spearman: 0.5532 | Pearson: 0.5120 | MAE: 0.8270 | RMSE: 1.0866 | R²: 0.2513
[Epoch 015] Phase: val   | Loss: 0.9625 | Spearman: 0.7107 | Pearson: 0.6606 | MAE: 0.7485 | RMSE: 0.9811 | R²: 0.4105
[Epoch 020] Phase: train | Loss: 1.1372 | Spearman: 0.5626 | Pearson: 0.5339 | MAE: 0.8145 | RMSE: 1.0664 | R²: 0.2789
[Epoch 020] Phase: val   | Loss: 0.9580 | Spearman: 0.7107 | Pearson: 0.6604 | MAE: 0.7511 | RMSE: 0.9788 | R²: 0.4133
[Epoch 025] Phase: train | Loss: 

In [20]:
# train and evaluate
print("START TRAINING")
metrics_df_kim, test_metrics_kim, model_kim = train_mlp(train_loader_kim, val_loader_kim, test_loader_kim)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.3374 | Spearman: 0.4596 | Pearson: 0.4055 | MAE: 0.8975 | RMSE: 1.1565 | R²: 0.1519
[Epoch 005] Phase: val   | Loss: 1.1315 | Spearman: 0.5833 | Pearson: 0.5833 | MAE: 0.8137 | RMSE: 1.0637 | R²: 0.3070
[Epoch 010] Phase: train | Loss: 1.1550 | Spearman: 0.5475 | Pearson: 0.5209 | MAE: 0.8214 | RMSE: 1.0747 | R²: 0.2676
[Epoch 010] Phase: val   | Loss: 1.0475 | Spearman: 0.6603 | Pearson: 0.6073 | MAE: 0.7818 | RMSE: 1.0235 | R²: 0.3585
[Epoch 015] Phase: train | Loss: 1.1300 | Spearman: 0.5521 | Pearson: 0.5351 | MAE: 0.8166 | RMSE: 1.0630 | R²: 0.2835
[Epoch 015] Phase: val   | Loss: 1.0491 | Spearman: 0.6559 | Pearson: 0.6070 | MAE: 0.7792 | RMSE: 1.0242 | R²: 0.3575
[Epoch 020] Phase: train | Loss: 1.1051 | Spearman: 0.5625 | Pearson: 0.5492 | MAE: 0.8078 | RMSE: 1.0512 | R²: 0.2993
[Epoch 020] Phase: val   | Loss: 1.0418 | Spearman: 0.6552 | Pearson: 0.6069 | MAE: 0.7783 | RMSE: 1.0207 | R²: 0.3620
[Epoch 025] Phase: train | Loss: 

In [21]:
print("START TRAINING")
metrics_df_mine, test_metrics_mine, model_mine = train_mlp(train_loader_mine, val_loader_mine, test_loader_mine)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.2480 | Spearman: 0.5237 | Pearson: 0.4757 | MAE: 0.8695 | RMSE: 1.1171 | R²: 0.2086
[Epoch 005] Phase: val   | Loss: 0.8864 | Spearman: 0.7451 | Pearson: 0.7173 | MAE: 0.7178 | RMSE: 0.9415 | R²: 0.4572
[Epoch 010] Phase: train | Loss: 1.0527 | Spearman: 0.6229 | Pearson: 0.5828 | MAE: 0.7841 | RMSE: 1.0260 | R²: 0.3325
[Epoch 010] Phase: val   | Loss: 0.7929 | Spearman: 0.7632 | Pearson: 0.7331 | MAE: 0.6648 | RMSE: 0.8905 | R²: 0.5144
[Epoch 015] Phase: train | Loss: 0.9921 | Spearman: 0.6445 | Pearson: 0.6127 | MAE: 0.7570 | RMSE: 0.9961 | R²: 0.3709
[Epoch 015] Phase: val   | Loss: 0.7956 | Spearman: 0.7597 | Pearson: 0.7297 | MAE: 0.6684 | RMSE: 0.8920 | R²: 0.5127
[Epoch 020] Phase: train | Loss: 0.9889 | Spearman: 0.6374 | Pearson: 0.6134 | MAE: 0.7577 | RMSE: 0.9945 | R²: 0.3729
[Epoch 020] Phase: val   | Loss: 0.7883 | Spearman: 0.7651 | Pearson: 0.7313 | MAE: 0.6656 | RMSE: 0.8878 | R²: 0.5173
[Epoch 025] Phase: train | Loss: 

In [22]:
metrics_df_base.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_base_optim.csv', index=False)

test_metrics_df_base = pd.DataFrame({'base_optim_test_metrics':test_metrics_base})
test_metrics_df_base.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_base_optim.csv', index=False)
test_metrics_df_base

Unnamed: 0,base_optim_test_metrics
loss,0.961744
spearman_corr,0.706688
pearson_corr,0.6602
mae,0.754939
rmse,0.980686
r2,0.411026


In [23]:
metrics_df_kim.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_kim.csv', index=False)

test_metrics_df_kim = pd.DataFrame({'kim_test_metrics':test_metrics_kim})
test_metrics_df_kim.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_kim.csv', index=False)
test_metrics_df_kim

Unnamed: 0,kim_test_metrics
loss,1.128748
spearman_corr,0.628702
pearson_corr,0.559957
mae,0.779216
rmse,1.062425
r2,0.300161


In [24]:
metrics_df_mine.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_mine.csv', index=False)

test_metrics_df_mine = pd.DataFrame({'mine_test_metrics':test_metrics_mine})
test_metrics_df_mine.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_mine.csv', index=False)
test_metrics_df_mine

Unnamed: 0,mine_test_metrics
loss,0.795399
spearman_corr,0.759828
pearson_corr,0.727155
mae,0.671177
rmse,0.891851
r2,0.512896


In [25]:
torch.save(model_kim.state_dict(), '../res/models/model_base_optim.pth')
torch.save(model_kim.state_dict(), '../res/models/model_kim.pth')
torch.save(model_mine.state_dict(), '../res/models/model_mine.pth')

### Interpretability

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_mine = MLP(next(iter(train_loader_mine))[0].shape[1], HIDDEN_DIMS, DROPOUT).to(device)
model_mine.load_state_dict(torch.load('../res/models/model_mine.pth', map_location=device))
model_mine.eval()  

MLP(
  (model): Sequential(
    (0): Linear(in_features=60, out_features=8, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.25, inplace=False)
    (6): Linear(in_features=8, out_features=1, bias=True)
  )
)

In [25]:
import shap
def model_predict(x):
    model_mine.eval()
    with torch.no_grad():
        x_tensor = torch.tensor(x, dtype=torch.float32).to(device)
        return model_mine(x_tensor).cpu().numpy()
        
X = test_dataset_mine.data

explainer = shap.Explainer(model_predict, X)
shap_values = explainer(X)
shap_values.feature_names = X_test_mine.columns.tolist()

In [27]:
plt.figure(figsize=(10, 20))
shap.summary_plot(shap_values, X, feature_names=X_test_mine.columns,  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_summary.png", dpi=300, bbox_inches='tight')  
plt.close()

In [29]:
plt.figure(figsize=(10, 20))
shap.plots.beeswarm(shap_values,  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_beeswarm.png", dpi=300, bbox_inches='tight')  
plt.close()

In [30]:
plt.figure(figsize=(10, 20))
shap.plots.bar(shap_values,  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_bar.png", dpi=300, bbox_inches='tight')  
plt.close()

In [31]:
plt.figure(figsize=(10, 20))
shap.plots.violin(shap_values,  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_violoin.png", dpi=300, bbox_inches='tight')  
plt.close()

In [32]:
plt.figure(figsize=(20, 10))
shap.plots.heatmap(shap_values,  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_heatmap.png", dpi=300, bbox_inches='tight')  
plt.close()

In [33]:
# Visualize SHAP values for a sample
plt.figure(figsize=(10, 20))
shap.plots.waterfall(shap_values[3],  max_display=15, show=False)
plt.savefig("../res/plots/shap/shap_waterfall.png", dpi=300, bbox_inches='tight')  
plt.close()