In [1]:
import pickle
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings
from sklearn.impute import KNNImputer

In [2]:
warnings.filterwarnings('ignore')

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(1234)

In [3]:
df_unq = pd.read_csv('../data/tcga_annotated_clean4ML.tsv.gz', sep='\t')
df_unq

Unnamed: 0,CHROM,POS,REF,ALT,Transcript_ID,HGVSc,KIM_PTC_to_start_codon,KIM_upstream_exon_count,KIM_downstream_exon_count,KIM_last_exon,...,phyloP100way_vertebrate,phyloP17way_primate,phyloP470way_mammalian,phastCons100way_vertebrate,phastCons17way_primate,phastCons470way_mammalian,gnomad41_genome_AF,gnomad41_exome_AF,LoF_HC,single_exon
0,1,944753,C,T,ENST00000327044,ENST00000327044:c.2191C>T,2193,18,0,1,...,1.489,0.654,,0.095,0.022,0.993,0.000000,6.906000e-07,1,0
1,1,952113,G,A,ENST00000327044,ENST00000327044:c.1218G>A,1218,10,8,0,...,7.149,0.549,,1.000,0.961,1.000,0.000000,6.843000e-07,1,0
2,1,1255304,G,T,ENST00000349431,ENST00000349431:c.679G>T,681,6,0,1,...,7.448,0.599,7.591,1.000,0.996,1.000,0.000000,0.000000e+00,1,0
3,1,1338573,G,T,ENST00000378888,ENST00000378888:c.1288G>T,1290,11,3,0,...,7.850,0.596,7.673,1.000,0.061,1.000,0.000000,0.000000e+00,1,0
4,1,1387314,C,T,ENST00000400809,ENST00000400809:c.1480C>T,1482,10,0,1,...,0.468,0.656,4.959,0.998,0.372,1.000,0.000007,4.789000e-06,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4013,X,152920717,C,T,ENST00000370268,ENST00000370268:c.622C>T,624,8,14,0,...,1.169,0.599,,0.365,0.530,1.000,0.000000,0.000000e+00,1,0
4014,X,153650171,G,T,ENST00000342782,ENST00000342782:c.1021G>T,1023,3,0,1,...,6.742,0.672,,1.000,0.869,1.000,0.000000,0.000000e+00,0,0
4015,X,154030948,C,T,ENST00000303391,ENST00000303391:c.880C>T,882,3,0,1,...,3.156,0.658,5.360,1.000,0.982,1.000,0.000000,0.000000e+00,1,0
4016,X,154354015,C,A,ENST00000369850,ENST00000369850:c.5586C>A,5586,34,13,0,...,-0.708,-3.409,-10.746,0.000,0.584,0.000,0.000000,0.000000e+00,1,0


In [4]:
prot_feat_df = pd.read_csv('../data/protein_AA_features.tsv.gz', sep='\t')
prot_feat_df.rename(columns={'MANE-Select':'Transcript_ID'}, inplace=True)
prot_feat_df.drop(columns=['Unnamed: 0', 'Entry'], axis=1, inplace=True)
prot_feat_df

Unnamed: 0,Transcript_ID,actual_or_pred_ACT_SITE,annotation_actual_ACT_SITE,annotation_pred_ACT_SITE,actual_or_pred_BINDING,annotation_actual_BINDING,annotation_pred_BINDING,actual_or_pred_COILED,annotation_actual_COILED,annotation_pred_COILED,...,annotation_pred_TRANSIT,actual_or_pred_TRANSMEM,annotation_actual_TRANSMEM,annotation_pred_TRANSMEM,actual_or_pred_TURN,annotation_actual_TURN,annotation_pred_TURN,actual_or_pred_ZN_FING,annotation_actual_ZN_FING,annotation_pred_ZN_FING
0,ENST00000436697,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000001100100000...,...,1111000000000000000000000000000000000000000000...,actual,0000111111111111111111111000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
1,ENST00000709217,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
2,ENST00000374922,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
3,ENST00000637218,actual,0000000000000000000000000000000000000000000000...,,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111000000000000000000000000...,actual,0000000000000000000000000000000000000000000000...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
4,ENST00000469902,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,actual,0000000000000000000000000000000000001111111111...,,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18412,ENST00000295896,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111100000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18413,ENST00000295898,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111011001000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18414,ENST00000306862,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...
18415,ENST00000450660,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,...,1111111111111111111111100000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...,pred,,0000000000000000000000000000000000000000000000...


In [5]:
# Calculate the proportion of lost features upon stop-gain
df_unq = df_unq.merge(prot_feat_df, how='left')

prot_feat_list = ['ACT_SITE', 'BINDING', 'COILED', 'COMPBIAS', 'DISULFID', 'BINDING', 'DOMAIN', 
                 'HELIX', 'MOD_RES', 'MOTIF', 'PROPEP', 'REGION', 'REPEAT', 'SIGNAL', 'STRAND',
                 'TOPO_DOM', 'TRANSIT', 'TRANSMEM', 'TURN', 'ZN_FING']

for prot_feat in prot_feat_list:
    df_unq[f'prop_lost_{prot_feat}'] = None

for idx, row in df_unq.iterrows():
    
    for prot_feat_name in prot_feat_list:
        temp_pred_or_actual = row[f'actual_or_pred_{prot_feat_name}']
        
        if not pd.isna(temp_pred_or_actual):
            
            prot_feat_binary = row[f'annotation_{temp_pred_or_actual}_{prot_feat_name}']
            
            if row['Protein_position'] <= len(prot_feat_binary):
                prot_feat_binary_lost = prot_feat_binary[row['Protein_position']:]
                df_unq.loc[idx, f'prop_lost_{prot_feat_name}'] = prot_feat_binary_lost.count('1') / len(prot_feat_binary)

prot_lost_feat_list = [f'prop_lost_{x}' for x in prot_feat_list]

In [6]:
def apply_baseline_rules(df, exon_len_threshold=407, penultimate_threshold=55, start_threshold=100):

    # long exon rule
    df['long_exon'] = 0
    df.loc[df['current_exon_len'] > exon_len_threshold, 'long_exon'] = 1

    # penultimate exon rule
    df['penultimate_flag'] = 0
    df.loc[
        (df['DIST_FROM_LAST_EXON'] < penultimate_threshold) &
        (df['DIST_FROM_LAST_EXON'] >= 0), 
        'penultimate_flag'
    ] = 1

    # close to start rule
    df['close_to_start'] = 0
    df.loc[df['CDS_position'] < start_threshold, 'close_to_start'] = 1

    return df


In [7]:
# last exon rule
df_unq['last_exon'] = 0
df_unq.loc[df_unq['current_exon_number']==df_unq['total_exon_numbers'], 'last_exon'] = 1

df_unq = apply_baseline_rules(df_unq, exon_len_threshold=355, penultimate_threshold=49, start_threshold=120)
df_unq

Unnamed: 0,CHROM,POS,REF,ALT,Transcript_ID,HGVSc,KIM_PTC_to_start_codon,KIM_upstream_exon_count,KIM_downstream_exon_count,KIM_last_exon,...,prop_lost_STRAND,prop_lost_TOPO_DOM,prop_lost_TRANSIT,prop_lost_TRANSMEM,prop_lost_TURN,prop_lost_ZN_FING,last_exon,long_exon,penultimate_flag,close_to_start
0,1,944753,C,T,ENST00000327044,ENST00000327044:c.2191C>T,2193,18,0,1,...,0.0,0.024032,0.0,0.0,0.0,0.0,1,1,0,0
1,1,952113,G,A,ENST00000327044,ENST00000327044:c.1218G>A,1218,10,8,0,...,0.0,0.457944,0.0,0.05474,0.0,0.0,0,0,0,0
2,1,1255304,G,T,ENST00000349431,ENST00000349431:c.679G>T,681,6,0,1,...,0.0,0.046332,0.0,0.07722,0.0,0.0,1,1,0,0
3,1,1338573,G,T,ENST00000378888,ENST00000378888:c.1288G>T,1290,11,3,0,...,0.0,0.381295,0.0,0.020144,0.0,0.0,0,0,0,0
4,1,1387314,C,T,ENST00000400809,ENST00000400809:c.1480C>T,1482,10,0,1,...,0.0,0.05,0.0,0.0,0.0,0.0,1,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4013,X,152920717,C,T,ENST00000370268,ENST00000370268:c.622C>T,624,8,14,0,...,,,,,,,0,0,0,0
4014,X,153650171,G,T,ENST00000342782,ENST00000342782:c.1021G>T,1023,3,0,1,...,0.0,0.111979,0.0,0.0,0.0,0.0,1,1,0,0
4015,X,154030948,C,T,ENST00000303391,ENST00000303391:c.880C>T,882,3,0,1,...,,,,,,,1,1,0,0
4016,X,154354015,C,A,ENST00000369850,ENST00000369850:c.5586C>A,5586,34,13,0,...,0.144314,0.296562,0.0,0.0,0.005667,0.0,0,0,0,0


In [8]:
KIM_features = ['KIM_PTC_to_start_codon', 'KIM_upstream_exon_count',
       'KIM_downstream_exon_count', 'KIM_last_exon', 'KIM_50nt_to_last_EJ',
       'KIM_PTC_exon_length', 'KIM_PTC_to_intron', 'KIM_dist_to_stop_codon',
       'KIM_mRNA_half_life', 'KIM_AF', 'KIM_LOEUF', 'KIM_5UTR_length',
       'KIM_3UTR_length', 'KIM_Transcript_length']

baseline_features = ['last_exon', 'long_exon', 'penultimate_flag', 'close_to_start'] 

my_features = ['last_exon', 'long_exon', 'penultimate_flag', 'close_to_start', 'hl',
       'mrl', 'cDNA_position', 'PERCENTILE', 'GERP_DIST', 'BP_DIST',
       'utr5_len', 'current_exon_len', 'VEST4_score', 'CADD_phred',
       'phyloP100way_vertebrate', 'dn_ds', 'abundance', 'shet',
       'lof.oe_ci.upper', 'lof.pRec', 'UTR5_GC', 'connectedness',
       'prop_lost_DOMAIN', 'prop_lost_HELIX', 'prop_lost_REGION',
       'prop_lost_TOPO_DOM', 'LoF_HC', 'NearestExonJB_dist', 'TF', 'tau',
       'phyloP17way_primate', 'phyloP470way_mammalian',
       'phastCons100way_vertebrate', 'phastCons17way_primate',
       'phastCons470way_mammalian', 'mis.z_score', 'syn.z_score', 'lof.pNull',
       'exp_var', 'utr3_len', 'total_exons_len', 'fathmm-XF_coding_score',
       'Nucleus', 'Cytosol', 'Cytoplasm', 'Ribosome', 'Membrane',
       'Endoplasmic_reticulum'] # selected based on ML_feature_selection notebook


### Splits

In [9]:
'''test_var_ids = df_unq[df_unq['CHROM'].isin(['20', '21', '22'])]['HGVSc'].values.tolist()
test_var_ids = list(set(test_var_ids))
test_var_ids[0:5], len(test_var_ids)'''

"test_var_ids = df_unq[df_unq['CHROM'].isin(['20', '21', '22'])]['HGVSc'].values.tolist()\ntest_var_ids = list(set(test_var_ids))\ntest_var_ids[0:5], len(test_var_ids)"

In [10]:
'''val_var_ids = df_unq[df_unq['CHROM']=='19']['HGVSc'].values.tolist()
val_var_ids = list(set(val_var_ids))
val_var_ids[0:5], len(val_var_ids)'''

"val_var_ids = df_unq[df_unq['CHROM']=='19']['HGVSc'].values.tolist()\nval_var_ids = list(set(val_var_ids))\nval_var_ids[0:5], len(val_var_ids)"

In [11]:
'''train_var_ids = list(set(df_unq.HGVSc.values.tolist()) - set(val_var_ids) - set(test_var_ids))
train_var_ids[0:5], len(train_var_ids)'''

'train_var_ids = list(set(df_unq.HGVSc.values.tolist()) - set(val_var_ids) - set(test_var_ids))\ntrain_var_ids[0:5], len(train_var_ids)'

In [12]:
from sklearn.model_selection import train_test_split

# Step 1: Group rows by Transcript_ID
groups = df_unq.groupby('Transcript_ID')

# Step 2: Aggregate a representative NMD_efficiency per Transcript_ID (e.g., mean)
group_data = groups['NMD_efficiency'].mean().reset_index()

# Step 3: Stratify by binning NMD_efficiency into quantiles
group_data['strata'] = pd.qcut(group_data['NMD_efficiency'], q=5, labels=False, duplicates='drop')

# Step 4: Split into train, temp (test+val) using stratification
train_ids, temp_ids = train_test_split(
    group_data['Transcript_ID'], test_size=0.1, stratify=group_data['strata'], random_state=42
)

# Step 5: Split temp into test and validation
test_ids, val_ids = train_test_split(
    temp_ids, test_size=0.5, stratify=group_data.loc[group_data['Transcript_ID'].isin(temp_ids), 'strata'], random_state=42
)

# Step 6: Assign the corresponding rows to each split
train_df = df_unq[df_unq['Transcript_ID'].isin(train_ids)]
test_df = df_unq[df_unq['Transcript_ID'].isin(test_ids)]
val_df = df_unq[df_unq['Transcript_ID'].isin(val_ids)]

### loaders

In [13]:
'''X_train_base = df_unq[df_unq['HGVSc'].isin(train_var_ids)][baseline_features]
X_train_kim = df_unq[df_unq['HGVSc'].isin(train_var_ids)][KIM_features]
X_train_mine = df_unq[df_unq['HGVSc'].isin(train_var_ids)][my_features]
y_train = df_unq[df_unq['HGVSc'].isin(train_var_ids)]['NMD_efficiency']

X_test_base = df_unq[df_unq['HGVSc'].isin(test_var_ids)][baseline_features]
X_test_kim = df_unq[df_unq['HGVSc'].isin(test_var_ids)][KIM_features]
X_test_mine = df_unq[df_unq['HGVSc'].isin(test_var_ids)][my_features]
y_test = df_unq[df_unq['HGVSc'].isin(test_var_ids)]['NMD_efficiency']

X_val_base = df_unq[df_unq['HGVSc'].isin(val_var_ids)][baseline_features]
X_val_kim = df_unq[df_unq['HGVSc'].isin(val_var_ids)][KIM_features]
X_val_mine = df_unq[df_unq['HGVSc'].isin(val_var_ids)][my_features]
y_val = df_unq[df_unq['HGVSc'].isin(val_var_ids)]['NMD_efficiency']'''

"X_train_base = df_unq[df_unq['HGVSc'].isin(train_var_ids)][baseline_features]\nX_train_kim = df_unq[df_unq['HGVSc'].isin(train_var_ids)][KIM_features]\nX_train_mine = df_unq[df_unq['HGVSc'].isin(train_var_ids)][my_features]\ny_train = df_unq[df_unq['HGVSc'].isin(train_var_ids)]['NMD_efficiency']\n\nX_test_base = df_unq[df_unq['HGVSc'].isin(test_var_ids)][baseline_features]\nX_test_kim = df_unq[df_unq['HGVSc'].isin(test_var_ids)][KIM_features]\nX_test_mine = df_unq[df_unq['HGVSc'].isin(test_var_ids)][my_features]\ny_test = df_unq[df_unq['HGVSc'].isin(test_var_ids)]['NMD_efficiency']\n\nX_val_base = df_unq[df_unq['HGVSc'].isin(val_var_ids)][baseline_features]\nX_val_kim = df_unq[df_unq['HGVSc'].isin(val_var_ids)][KIM_features]\nX_val_mine = df_unq[df_unq['HGVSc'].isin(val_var_ids)][my_features]\ny_val = df_unq[df_unq['HGVSc'].isin(val_var_ids)]['NMD_efficiency']"

In [14]:
X_train_base = train_df[baseline_features]
X_train_kim = train_df[KIM_features]
X_train_mine = train_df[my_features]
y_train = train_df['NMD_efficiency']

X_test_base = test_df[baseline_features]
X_test_kim = test_df[KIM_features]
X_test_mine = test_df[my_features]
y_test = test_df['NMD_efficiency']

X_val_base = val_df[baseline_features]
X_val_kim = val_df[KIM_features]
X_val_mine = val_df[my_features]
y_val = val_df['NMD_efficiency']

In [15]:
def correct_dtype(X):
    X = X.apply(pd.to_numeric, errors="coerce")
    return X

X_train_mine = correct_dtype(X_train_mine)
X_test_mine = correct_dtype(X_test_mine)
X_val_mine = correct_dtype(X_val_mine)

In [16]:
train_feature_means_base = X_train_base.mean()
train_feature_means_kim = X_train_kim.mean()
train_feature_means_mine = X_train_mine.mean()

train_feature_stds_base = X_train_base.std()
train_feature_stds_kim = X_train_kim.std()
train_feature_stds_mine = X_train_mine.std()

In [17]:
class CustomDataset(Dataset):
    def __init__(self, data, targets, feature_means=None, feature_stds=None):
        """
        Args:
            data (pd.DataFrame): Feature matrix with possible NaNs.
            targets (pd.Series or torch.Tensor): Target values.
            feature_means (pd.Series): Feature means calculated from the training dataset.
            feature_stds (pd.Series): Feature standard deviations from the training dataset.
        """
        
        if feature_means is not None:
            # Fill missing values using the training feature means
            data = data.fillna(feature_means)
        
        if feature_means is not None and feature_stds is not None:
            # Normalize using training statistics
            data = (data - feature_means) / (feature_stds + 1e-8)  # Add epsilon to prevent division by zero
        
        self.data = data.values
        self.targets = targets.values if isinstance(targets, pd.Series) else targets
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx], dtype=torch.float32)
        y = torch.tensor(self.targets[idx], dtype=torch.float32)
        return x, y


In [18]:
# Create datasets
train_dataset_base = CustomDataset(X_train_base, y_train, train_feature_means_base, train_feature_stds_base)
test_dataset_base = CustomDataset(X_test_base, y_test, train_feature_means_base, train_feature_stds_base)
val_dataset_base = CustomDataset(X_val_base, y_val, train_feature_means_base, train_feature_stds_base)

train_dataset_kim = CustomDataset(X_train_kim, y_train, train_feature_means_kim, train_feature_stds_kim)
test_dataset_kim = CustomDataset(X_test_kim, y_test, train_feature_means_kim, train_feature_stds_kim)
val_dataset_kim = CustomDataset(X_val_kim, y_val, train_feature_means_kim, train_feature_stds_kim)

train_dataset_mine = CustomDataset(X_train_mine, y_train, train_feature_means_mine, train_feature_stds_mine)
test_dataset_mine = CustomDataset(X_test_mine, y_test, train_feature_means_mine, train_feature_stds_mine)
val_dataset_mine = CustomDataset(X_val_mine, y_val, train_feature_means_mine, train_feature_stds_mine)

# Create dataloaders
batch_size = 32
train_loader_base = DataLoader(train_dataset_base, batch_size=batch_size, shuffle=True)
test_loader_base = DataLoader(test_dataset_base, batch_size=batch_size, shuffle=False)
val_loader_base = DataLoader(val_dataset_base, batch_size=batch_size, shuffle=False)

train_loader_kim = DataLoader(train_dataset_kim, batch_size=batch_size, shuffle=True)
test_loader_kim = DataLoader(test_dataset_kim, batch_size=batch_size, shuffle=False)
val_loader_kim = DataLoader(val_dataset_kim, batch_size=batch_size, shuffle=False)

train_loader_mine = DataLoader(train_dataset_mine, batch_size=batch_size, shuffle=True)
test_loader_mine = DataLoader(test_dataset_mine, batch_size=batch_size, shuffle=False)
val_loader_mine = DataLoader(val_dataset_mine, batch_size=batch_size, shuffle=False)

### train MLP

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import spearmanr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import pandas as pd
import numpy as np
import torch.nn.functional as F

# Hyperparameters
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 5e-4
HIDDEN_DIMS = [8, 8]
DROPOUT = 0.2 #0.25
N_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
TRANSFORMATION = 'none'  # Options: 'z-score', 'min-max', 'log', 'none'

# Evaluation Metrics
def evaluate_regression_metrics(y_true, y_pred):
    y_true = y_true.cpu().numpy().flatten()
    y_pred = y_pred.cpu().numpy().flatten()
    loss = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(loss)
    r2 = r2_score(y_true, y_pred)
    
    if np.std(y_true) == 0 or np.std(y_pred) == 0:
        spearman_corr = np.nan
    else:
        try:
            spearman_corr, _ = spearmanr(y_true, y_pred)
        except ValueError:
            spearman_corr = np.nan
    
    return {
        'loss': loss,
        'spearman_corr': spearman_corr,
        'mae': mae,
        'rmse': rmse,
        'r2': r2
    }



class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, dropout):
        super(MLP, self).__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim, dtype=torch.float32))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1, dtype=torch.float32))
        self.model = nn.Sequential(*layers)
        self.dropout = dropout  # Store dropout value

    def forward(self, x):
        for layer in self.model[:-1]:  # Apply dropout before the last layer
            x = layer(x)
            if self.training:  # Dropout only during training
                x = F.dropout(x, p=self.dropout, training=True)
        return self.model[-1](x)  # Last layer without dropout


# Transformation Functions
def apply_transformation(y, transformation, stats):
    if transformation == 'log':
        return torch.log1p(torch.clamp(y - stats['min'] + 1, min=1e-8))
    elif transformation == 'z-score':
        return (y - stats['mean']) / (stats['std'] + 1e-8)
    elif transformation == 'min-max':
        return (y - stats['min']) / (stats['max'] - stats['min'] + 1e-8)
    return y

def inverse_transformation(y, stats, transformation):
    if transformation == 'log':
        return torch.expm1(y) + stats['min'] - 1
    elif transformation == 'z-score':
        return y * stats['std'] + stats['mean']
    elif transformation == 'min-max':
        return y * (stats['max'] - stats['min']) + stats['min']
    return y

# Calculate Transformation Statistics
def calculate_transformation_stats(dataset, transformation):
    if transformation == 'none':
        return {}
    
    y = torch.cat([targets for _, targets in dataset], dim=0)
    stats = {}
    if transformation == 'log':
        stats['min'] = y.min()
    elif transformation == 'z-score':
        stats['mean'] = y.mean()
        stats['std'] = y.std()
    elif transformation == 'min-max':
        stats['min'] = y.min()
        stats['max'] = y.max()
    return stats

# Training Function
def train_mlp(train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_dim = next(iter(train_loader))[0].shape[1]
    model = MLP(input_dim, HIDDEN_DIMS, DROPOUT).to(device)
    
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    
    # Calculate global transformation stats
    transformation_stats = calculate_transformation_stats(train_loader, TRANSFORMATION)
    
    metrics = {'epoch': [], 'phase': [], 'loss': [], 'spearman_corr': [], 'mae': [], 'rmse': [], 'r2': []}
    best_val_loss = float('inf')
    patience_counter = 0
    
    best_model_state = None
    
    for epoch in range(1, N_EPOCHS + 1):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader
            
            all_preds, all_targets = [], []
            running_loss = 0.0
            
            with torch.set_grad_enabled(phase == 'train'):
                for inputs, targets in loader:
                    inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
                    original_targets = targets.clone()
                    
                    if TRANSFORMATION != 'none':
                        targets = apply_transformation(targets, TRANSFORMATION, transformation_stats)
                    
                    optimizer.zero_grad()
                    outputs = model(inputs).squeeze()
                    loss = criterion(outputs, targets)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    running_loss += loss.item() * inputs.size(0)
                    all_preds.append(outputs.detach())
                    all_targets.append(original_targets.detach())
            
            epoch_loss = running_loss / len(loader.dataset)
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            
            if TRANSFORMATION != 'none':
                all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
            
            epoch_metrics = evaluate_regression_metrics(all_targets, all_preds)
            epoch_metrics['loss'] = epoch_loss
            
            # Log metrics
            metrics['epoch'].append(epoch)
            metrics['phase'].append(phase)
            for key, value in epoch_metrics.items():
                metrics[key].append(value)
            
            # Early stopping logic
            if phase == 'val':
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    patience_counter = 0
                    best_model_state = model.state_dict()
                else:
                    patience_counter += 1
                    if patience_counter >= EARLY_STOPPING_PATIENCE:
                        print(f"Early stopping at epoch {epoch} with best validation loss: {best_val_loss:.4f}")
                        model.load_state_dict(best_model_state)
                        metrics_df = pd.DataFrame(metrics)
                        return metrics_df, evaluate_regression_metrics(all_targets, all_preds), model
            
            # Print metrics every 5 epochs for both phases
            if epoch % 5 == 0:
                print(
                    f"[Epoch {epoch:03d}] Phase: {phase:5s} | "
                    f"Loss: {epoch_metrics['loss']:.4f} | "
                    f"Spearman: {epoch_metrics['spearman_corr']:.4f} | "
                    f"MAE: {epoch_metrics['mae']:.4f} | "
                    f"RMSE: {epoch_metrics['rmse']:.4f} | "
                    f"R²: {epoch_metrics['r2']:.4f}"
                )
    
    # Load the best model before final testing
    model.load_state_dict(best_model_state)
    
    # Final Test Evaluation
    model.eval()
    all_preds, all_targets = [], []
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device, dtype=torch.float32), targets.to(device, dtype=torch.float32)
            outputs = model(inputs).squeeze()
            all_preds.append(outputs.detach())
            all_targets.append(targets.detach())
    
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    
    if TRANSFORMATION != 'none':
        all_preds = inverse_transformation(all_preds, transformation_stats, TRANSFORMATION)
    
    test_metrics = evaluate_regression_metrics(all_targets, all_preds)
    print(
        "\n📊 Final Test Metrics | "
        f"Loss: {test_metrics['loss']:.4f} | "
        f"Spearman: {test_metrics['spearman_corr']:.4f} | "
        f"MAE: {test_metrics['mae']:.4f} | "
        f"RMSE: {test_metrics['rmse']:.4f} | "
        f"R²: {test_metrics['r2']:.4f}"
    )
    
    metrics_df = pd.DataFrame(metrics)
    return metrics_df, test_metrics, model


In [20]:
# train and evaluate
print("START TRAINING")
metrics_df_base, test_metrics_base, model_base = train_mlp(train_loader_base, val_loader_base, test_loader_base)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.5092 | Spearman: 0.3592 | MAE: 0.9589 | RMSE: 1.2285 | R²: 0.0451
[Epoch 005] Phase: val   | Loss: 1.2794 | Spearman: 0.6560 | MAE: 0.8708 | RMSE: 1.1311 | R²: 0.2454
[Epoch 010] Phase: train | Loss: 1.3439 | Spearman: 0.4954 | MAE: 0.8799 | RMSE: 1.1593 | R²: 0.1497
[Epoch 010] Phase: val   | Loss: 1.1603 | Spearman: 0.6570 | MAE: 0.8087 | RMSE: 1.0772 | R²: 0.3156
[Epoch 015] Phase: train | Loss: 1.2704 | Spearman: 0.5204 | MAE: 0.8595 | RMSE: 1.1271 | R²: 0.1962
[Epoch 015] Phase: val   | Loss: 1.1365 | Spearman: 0.6568 | MAE: 0.8011 | RMSE: 1.0661 | R²: 0.3297
[Epoch 020] Phase: train | Loss: 1.2393 | Spearman: 0.5279 | MAE: 0.8556 | RMSE: 1.1132 | R²: 0.2159
[Epoch 020] Phase: val   | Loss: 1.1306 | Spearman: 0.6609 | MAE: 0.8049 | RMSE: 1.0633 | R²: 0.3331
[Epoch 025] Phase: train | Loss: 1.1870 | Spearman: 0.5388 | MAE: 0.8468 | RMSE: 1.0895 | R²: 0.2490
[Epoch 025] Phase: val   | Loss: 1.1204 | Spearman: 0.6610 | MAE: 0.8075 | R

In [21]:
# train and evaluate
print("START TRAINING")
metrics_df_kim, test_metrics_kim, model_kim = train_mlp(train_loader_kim, val_loader_kim, test_loader_kim)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.5703 | Spearman: 0.3770 | MAE: 0.9694 | RMSE: 1.2531 | R²: 0.0065
[Epoch 005] Phase: val   | Loss: 1.5014 | Spearman: 0.5114 | MAE: 0.9446 | RMSE: 1.2253 | R²: 0.1144
[Epoch 010] Phase: train | Loss: 1.2938 | Spearman: 0.4724 | MAE: 0.8833 | RMSE: 1.1374 | R²: 0.1814
[Epoch 010] Phase: val   | Loss: 1.1652 | Spearman: 0.6272 | MAE: 0.8265 | RMSE: 1.0795 | R²: 0.3127
[Epoch 015] Phase: train | Loss: 1.2594 | Spearman: 0.4869 | MAE: 0.8723 | RMSE: 1.1223 | R²: 0.2031
[Epoch 015] Phase: val   | Loss: 1.1217 | Spearman: 0.6676 | MAE: 0.8140 | RMSE: 1.0591 | R²: 0.3384
[Epoch 020] Phase: train | Loss: 1.2362 | Spearman: 0.4936 | MAE: 0.8671 | RMSE: 1.1119 | R²: 0.2178
[Epoch 020] Phase: val   | Loss: 1.1086 | Spearman: 0.6749 | MAE: 0.8135 | RMSE: 1.0529 | R²: 0.3461
[Epoch 025] Phase: train | Loss: 1.2262 | Spearman: 0.5052 | MAE: 0.8628 | RMSE: 1.1074 | R²: 0.2241
[Epoch 025] Phase: val   | Loss: 1.1006 | Spearman: 0.6723 | MAE: 0.8106 | R

In [22]:
print("START TRAINING")
metrics_df_mine, test_metrics_mine, model_mine = train_mlp(train_loader_mine, val_loader_mine, test_loader_mine)

START TRAINING
[Epoch 005] Phase: train | Loss: 1.4047 | Spearman: 0.4535 | MAE: 0.9101 | RMSE: 1.1852 | R²: 0.1112
[Epoch 005] Phase: val   | Loss: 1.2140 | Spearman: 0.6040 | MAE: 0.8664 | RMSE: 1.1018 | R²: 0.2839
[Epoch 010] Phase: train | Loss: 1.2029 | Spearman: 0.5602 | MAE: 0.8417 | RMSE: 1.0968 | R²: 0.2389
[Epoch 010] Phase: val   | Loss: 1.0745 | Spearman: 0.6569 | MAE: 0.8194 | RMSE: 1.0366 | R²: 0.3662
[Epoch 015] Phase: train | Loss: 1.1556 | Spearman: 0.5880 | MAE: 0.8249 | RMSE: 1.0750 | R²: 0.2688
[Epoch 015] Phase: val   | Loss: 1.0375 | Spearman: 0.6727 | MAE: 0.8052 | RMSE: 1.0186 | R²: 0.3881
[Epoch 020] Phase: train | Loss: 1.1380 | Spearman: 0.6012 | MAE: 0.8237 | RMSE: 1.0668 | R²: 0.2800
[Epoch 020] Phase: val   | Loss: 1.0447 | Spearman: 0.6753 | MAE: 0.8055 | RMSE: 1.0221 | R²: 0.3838
[Epoch 025] Phase: train | Loss: 1.1205 | Spearman: 0.6054 | MAE: 0.8198 | RMSE: 1.0585 | R²: 0.2911
[Epoch 025] Phase: val   | Loss: 1.0155 | Spearman: 0.6822 | MAE: 0.8005 | R

In [23]:
metrics_df_base.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_base_optim.csv', index=False)

test_metrics_df_base = pd.DataFrame({'base_optim_test_metrics':test_metrics_base})
test_metrics_df_base.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_base_optim.csv', index=False)
test_metrics_df_base

Unnamed: 0,base_optim_test_metrics
loss,1.12526
spearman_corr,0.66164
mae,0.817617
rmse,1.060783
r2,0.336292


In [24]:
metrics_df_kim.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_kim.csv', index=False)

test_metrics_df_kim = pd.DataFrame({'kim_test_metrics':test_metrics_kim})
test_metrics_df_kim.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_kim.csv', index=False)
test_metrics_df_kim

Unnamed: 0,kim_test_metrics
loss,0.981379
spearman_corr,0.63489
mae,0.794323
rmse,0.990646
r2,0.347105


In [25]:
metrics_df_mine.to_csv('../res/metrics/base_and_kim_and_mine/per_epoch_metric_mine.csv', index=False)

test_metrics_df_mine = pd.DataFrame({'mine_test_metrics':test_metrics_mine})
test_metrics_df_mine.to_csv('../res/metrics/base_and_kim_and_mine/test_metrics_mine.csv', index=False)
test_metrics_df_mine

Unnamed: 0,mine_test_metrics
loss,1.023792
spearman_corr,0.684003
mae,0.803887
rmse,1.011826
r2,0.396141


In [26]:
torch.save(model_kim.state_dict(), '../res/models/model_base_optim.pth')
torch.save(model_kim.state_dict(), '../res/models/model_kim.pth')
torch.save(model_mine.state_dict(), '../res/models/model_mine.pth')

In [None]:
from xgboost import XGBRegressor
from sklearn.pipeline import Pipeline
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import RandomizedSearchCV, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr, pearsonr
import numpy as np

def train_xgb_with_cv(X_train, X_test, y_train, y_test, scale_y=True):
    # Scale the target variable
    if scale_y:
        y_scaler = StandardScaler()
        y_train_scaled = y_scaler.fit_transform(y_train.reshape(-1, 1)).flatten()
        y_test_scaled = y_scaler.transform(y_test.reshape(-1, 1)).flatten()
    else:
        y_train_scaled = y_train
        y_test_scaled = y_test
    
    # Define the pipeline
    pipeline = Pipeline([
        ('imputer', KNNImputer(n_neighbors=5)),
        ('scaler', StandardScaler()),
        ('regressor', XGBRegressor(objective='reg:squarederror', random_state=42))
    ])
    
    # Define the parameter distribution
    param_dist = {
        'regressor__n_estimators': np.arange(50, 501, 50),
        'regressor__learning_rate': [0.01, 0.05, 0.1, 0.2],
        'regressor__max_depth': [1, 2, 3, 5, 7, 10, 20],
        'regressor__min_child_weight': [1, 3, 5, 10, 15, 20],
        'regressor__subsample': [0.6, 0.8, 1.0],
        'regressor__colsample_bytree': [0.6, 0.8, 1.0],
        'regressor__gamma': [0, 0.1, 0.2, 0.3],
        'regressor__reg_alpha': [0, 0.01, 0.1, 1, 5, 10],  # L1 regularization
        'regressor__reg_lambda': [0.1, 1, 5, 10]       # L2 regularization
    }
    
    # Perform randomized search with cross-validation
    cv = KFold(n_splits=5, shuffle=True, random_state=42)
    random_search = RandomizedSearchCV(pipeline, param_dist, n_iter=40, cv=cv, scoring='neg_mean_squared_error', n_jobs=-1, verbose=2, random_state=42)
    random_search.fit(X_train, y_train_scaled)
    
    # Evaluate on train data
    y_train_pred_scaled = random_search.best_estimator_.predict(X_train)
    if scale_y:
        y_train_pred = y_scaler.inverse_transform(y_train_pred_scaled.reshape(-1, 1)).flatten()
    else:
        y_train_pred = y_train_pred_scaled
    train_mse = mean_squared_error(y_train, y_train_pred)
    train_rmse = train_mse ** 0.5
    train_mae = mean_absolute_error(y_train, y_train_pred)
    train_r2 = r2_score(y_train, y_train_pred)
    train_spearman, _ = spearmanr(y_train, y_train_pred)
    train_pearson, _ = pearsonr(y_train, y_train_pred)
    
    # Evaluate on test data
    y_test_pred_scaled = random_search.best_estimator_.predict(X_test)
    if scale_y:
        y_test_pred = y_scaler.inverse_transform(y_test_pred_scaled.reshape(-1, 1)).flatten()
    else:
        y_test_pred = y_test_pred_scaled
    test_mse = mean_squared_error(y_test, y_test_pred)
    test_rmse = test_mse ** 0.5
    test_mae = mean_absolute_error(y_test, y_test_pred)
    test_r2 = r2_score(y_test, y_test_pred)
    test_spearman, _ = spearmanr(y_test, y_test_pred)
    test_pearson, _ = pearsonr(y_test, y_test_pred)
    
    print(f"Best Parameters: {random_search.best_params_}")
    print(f"Train Metrics: R2={train_r2:.4f}, RMSE={train_rmse:.4f}, MAE={train_mae:.4f}, Spearman={train_spearman:.4f}, Pearson={train_pearson:.4f}")
    print(f"Test Metrics: R2={test_r2:.4f}, RMSE={test_rmse:.4f}, MAE={test_mae:.4f}, Spearman={test_spearman:.4f}, Pearson={train_pearson:0.4f}")
    
    return random_search.best_estimator_, test_mse

    
train_xgb_with_cv(X_train_mine, X_test_mine, y_train.values, y_test.values, scale_y=False)