In [1]:
import pandas as pd
import numpy as np
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (add_peptide_bonds,
                                             add_hydrogen_bond_interactions,
                                             add_disulfide_interactions,
                                             add_ionic_interactions,
                                             add_aromatic_interactions,
                                             add_aromatic_sulphur_interactions,
                                             add_cation_pi_interactions
                                            )
from graphein.protein.features.nodes import (
    amino_acid_one_hot,
    expasy_protein_scale,
    hydrogen_bond_acceptor,
    hydrogen_bond_donor,
    meiler_embedding
)
from graphein.protein.features.nodes.geometry import add_sidechain_vector, add_beta_carbon_vector, add_sequence_neighbour_vector
import torch
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import torch_geometric
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
import pickle
import random
import warnings
import os
import h5py
import g2papi

In [2]:
def seed_everything(seed=1234):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()
warnings.filterwarnings("ignore")

In [3]:
af_db_path = "/work/gr-fe/databases/alpha_fold/human/"
prefix = "AF-"
suffix = "-F1-model_v4.pdb"

In [4]:
# define edge and annotation functions

edge_funcs = {"edge_construction_functions": [add_peptide_bonds,
                                              add_aromatic_interactions,
                                              add_hydrogen_bond_interactions,
                                              add_disulfide_interactions,
                                              add_ionic_interactions,
                                              add_aromatic_sulphur_interactions,
                                              add_cation_pi_interactions]}


all_node_metadata = {"node_metadata_functions" : [amino_acid_one_hot,
                                                 expasy_protein_scale,
                                                 hydrogen_bond_acceptor,
                                                 hydrogen_bond_donor,
                                                 meiler_embedding]}

#all_graph_metadata = {"graph_metadata_functions": [esm_residue_embedding]}  

#config = ProteinGraphConfig(**{**edge_funcs, **all_node_metadata, **all_graph_metadata}) 
config = ProteinGraphConfig(**{**edge_funcs, **all_node_metadata}) 

In [5]:
# read ground_truth dataframe

df = pd.read_csv("../data/DN_LOF_GOF_truth.tsv", sep="\t")
df

Unnamed: 0,HGNC,DN,LOF,GOF,UniprotEntry
0,AARS1,1,0,0,P49588
1,ABCA1,1,0,0,O95477
2,ACD,1,0,0,Q96AP0
3,ACTA1,1,0,0,P68133
4,ACTB,1,1,1,P60709
...,...,...,...,...,...
1271,ZFPM2,0,1,0,Q8WW38
1272,ZIC2,0,1,0,O95409
1273,ZMYM2,0,1,0,Q9UBW7
1274,ZMYND11,0,1,0,Q15326


In [6]:
# extract labels and uniprot ids

all_uniprot_ids = [] 
all_hgnc = []
all_labels = [] 

for index, row in df.iterrows():
    all_uniprot_ids.append(row["UniprotEntry"])
    all_hgnc.append(row["HGNC"])
    all_labels.append([row["DN"], row["LOF"], row["GOF"]])

In [7]:
# read extra features from describeProt

describe_prot = pd.read_csv("../data/describePROT_clean.tsv", sep="\t")
describe_prot_featnames = describe_prot.columns[(describe_prot.columns != "seqlength") & (describe_prot.columns != "UniprotEntry") & (describe_prot.columns != "SignalP_score")] #SignalP_score doesnt match the protein length
describe_prot

Unnamed: 0,UniprotEntry,ASAquick_normscore,DFLpredScore,DRNApredDNAscore,DRNApredRNAscore,DisoDNAscore,DisoPROscore,DisoRNAscore,MMseq2_conservation_score,MoRFchibiScore,PSIPRED_helix,PSIPRED_strand,SCRIBERscore,SignalP_score,flDPnn_score,seqlength,PTMbinary
0,A0A024R1R8,"0.747,0.578,0.611,0.572,0.581,0.494,0.366,0.52...","0.026,0.021,0.022,0.020,0.018,0.019,0.023,0.02...","0.036,0.092,0.094,0.123,0.070,0.081,0.080,0.09...","0.441,0.247,0.219,0.259,0.200,0.280,0.291,0.30...","0.478,0.444,0.516,0.472,0.430,0.489,0.541,0.50...","0.582,0.621,0.632,0.638,0.652,0.622,0.609,0.61...","0.034,0.029,0.024,0.024,0.024,0.024,0.023,0.01...","3.69,2.85,2.55,2.69,2.60,2.57,2.57,2.76,2.52,2...","0.835,0.839,0.848,0.862,0.859,0.859,0.871,0.89...","0,8,28,32,41,46,59,78,73,92,78,50,60,20,61,93,...","0,1,1,2,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...","0.949,0.884,0.764,0.823,0.763,0.447,0.292,0.25...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.96,0.94,0.93,0.92,0.89,0.89,0.93,0.90,0.84,0...",64,"0,1,1,0,0,0,0,5,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
1,A0A024RBG1,"0.695,0.392,0.521,0.209,0.444,0.421,0.455,0.40...","0.024,0.032,0.040,0.038,0.043,0.055,0.059,0.07...","0.099,0.090,0.134,0.082,0.236,0.132,0.199,0.36...","0.081,0.078,0.067,0.048,0.060,0.059,0.057,0.06...","0.172,0.142,0.135,0.145,0.110,0.116,0.094,0.12...","0.505,0.495,0.473,0.483,0.483,0.486,0.484,0.49...","0.060,0.057,0.055,0.053,0.056,0.048,0.046,0.05...","3.69,3.68,2.81,2.84,2.81,3.23,3.09,3.37,2.97,2...","0.653,0.663,0.677,0.673,0.675,0.673,0.669,0.66...","0,3,2,2,1,3,7,5,4,10,4,2,2,22,49,57,78,80,67,3...","0,4,7,8,3,3,3,7,19,24,39,27,9,4,1,1,2,4,11,32,...","0.278,0.234,0.214,0.239,0.288,0.276,0.251,0.29...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.35,0.38,0.42,0.46,0.45,0.45,0.45,0.42,0.38,0...",181,"0,0,6,0,6,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8,..."
2,A0A075B6H5,"0.694,0.501,0.365,0.286,0.237,0.290,0.280,0.20...","0.068,0.082,0.084,0.107,0.138,0.158,0.191,0.22...","0.181,0.109,0.309,0.136,0.190,0.342,0.386,0.13...","0.052,0.044,0.032,0.030,0.031,0.034,0.034,0.03...","0.060,0.069,0.076,0.081,0.088,0.086,0.095,0.10...","0.481,0.501,0.490,0.473,0.473,0.476,0.468,0.42...","0.038,0.030,0.037,0.030,0.031,0.035,0.030,0.03...","3.68,2.88,2.96,2.74,2.61,2.96,2.96,2.19,3.23,2...","0.500,0.485,0.511,0.541,0.541,0.546,0.550,0.55...","0,1,1,2,3,3,2,3,4,10,16,7,1,2,3,70,77,95,94,94...","0,30,59,73,82,77,55,16,10,7,3,2,2,3,3,1,1,1,1,...","0.377,0.330,0.265,0.216,0.230,0.296,0.278,0.22...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.33,0.32,0.33,0.28,0.33,0.30,0.35,0.38,0.40,0...",130,"0,0,2,0,0,2,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
3,A0A075B6H7,"0.721,0.579,0.358,0.389,0.238,0.276,0.085,0.08...","0.047,0.046,0.043,0.040,0.038,0.043,0.040,0.04...","0.152,0.203,0.157,0.115,0.130,0.130,0.099,0.09...","0.066,0.065,0.062,0.062,0.063,0.063,0.060,0.06...","0.072,0.077,0.081,0.048,0.052,0.053,0.054,0.05...","0.430,0.428,0.410,0.434,0.466,0.439,0.434,0.43...","0.042,0.027,0.018,0.022,0.020,0.024,0.023,0.01...","3.69,1.81,1.94,3.06,2.10,3.33,1.73,2.16,2.10,2...","0.546,0.539,0.545,0.549,0.559,0.571,0.587,0.58...","0,24,67,83,94,99,99,100,100,100,99,96,83,62,3,...","0,2,1,2,1,0,0,0,0,0,0,0,1,1,0,2,3,7,11,24,44,4...","0.146,0.148,0.107,0.170,0.141,0.178,0.136,0.17...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.14,0.13,0.12,0.07,0.10,0.07,0.08,0.08,0.11,0...",116,"0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
4,A0A075B6H8,"0.649,0.510,0.364,0.481,0.218,0.241,0.259,0.37...","0.060,0.054,0.067,0.061,0.056,0.052,0.048,0.05...","0.431,0.393,0.548,0.570,0.354,0.274,0.292,0.26...","0.038,0.038,0.040,0.037,0.036,0.036,0.036,0.03...","0.174,0.172,0.171,0.169,0.168,0.105,0.107,0.10...","0.252,0.282,0.286,0.287,0.311,0.357,0.381,0.35...","0.105,0.088,0.059,0.047,0.048,0.056,0.047,0.05...","3.68,2.92,3.67,2.93,1.84,3.23,2.11,3.35,2.04,2...","0.459,0.468,0.512,0.522,0.529,0.530,0.536,0.53...","0,10,29,48,82,91,97,99,99,99,99,100,99,99,95,8...","0,1,3,3,1,1,0,0,0,0,0,0,0,0,0,1,1,4,5,12,36,71...","0.395,0.278,0.208,0.353,0.186,0.211,0.164,0.25...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.26,0.17,0.18,0.13,0.12,0.07,0.09,0.09,0.09,0...",117,"0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20449,V9GZ13,"0.752,0.627,0.587,0.419,0.434,0.395,0.326,0.47...","0.070,0.089,0.095,0.084,0.076,0.080,0.081,0.09...","0.818,0.894,0.837,0.867,0.869,0.882,0.437,0.88...","0.020,0.020,0.020,0.020,0.020,0.020,0.020,0.02...","0.054,0.059,0.064,0.069,0.074,0.080,0.081,0.09...","0.147,0.147,0.172,0.176,0.182,0.171,0.170,0.17...","0.052,0.057,0.056,0.061,0.047,0.044,0.041,0.04...","3.68,2.83,3.16,3.36,2.84,4.34,1.60,2.95,2.30,3...","0.860,0.890,0.901,0.982,0.981,0.981,0.981,0.98...","0,4,15,31,63,80,83,83,87,75,70,65,80,80,64,70,...","0,2,3,4,2,2,2,2,1,2,9,25,13,8,12,10,8,13,37,24...","0.873,0.716,0.745,0.722,0.672,0.855,0.643,0.76...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.41,0.44,0.45,0.41,0.32,0.19,0.12,0.13,0.13,0...",50,"0,6,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
20450,W5XKT8,"0.590,0.263,0.221,0.247,0.248,0.137,0.144,0.32...","0.045,0.050,0.062,0.074,0.075,0.068,0.062,0.06...","0.119,0.088,0.078,0.082,0.103,0.097,0.119,0.32...","0.066,0.051,0.049,0.049,0.051,0.050,0.050,0.05...","0.185,0.199,0.198,0.195,0.192,0.191,0.189,0.18...","0.259,0.224,0.224,0.269,0.271,0.258,0.261,0.25...","0.302,0.303,0.292,0.298,0.214,0.161,0.166,0.14...","3.68,2.56,2.34,1.76,2.37,2.60,2.24,2.60,1.59,2...","0.530,0.534,0.533,0.525,0.515,0.505,0.454,0.42...","0,83,91,95,98,96,81,54,74,64,87,98,99,99,100,9...","0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...","0.193,0.137,0.173,0.192,0.117,0.175,0.151,0.16...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.15,0.16,0.14,0.14,0.14,0.13,0.12,0.11,0.11,0...",324,"0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."
20451,W6CW81,"0.703,0.524,0.455,0.454,0.260,0.369,0.272,0.16...","0.025,0.027,0.027,0.034,0.043,0.041,0.044,0.05...","0.257,0.144,0.442,0.246,0.349,0.418,0.124,0.11...","0.058,0.044,0.050,0.046,0.049,0.050,0.043,0.04...","0.156,0.144,0.158,0.158,0.150,0.122,0.121,0.11...","0.315,0.311,0.310,0.331,0.356,0.355,0.340,0.34...","0.022,0.026,0.017,0.018,0.019,0.021,0.018,0.01...","3.69,2.83,2.82,2.57,3.44,2.26,2.86,2.41,2.16,2...","0.548,0.555,0.586,0.602,0.602,0.611,0.618,0.62...","0,37,57,79,83,88,90,91,79,78,76,59,39,51,31,25...","0,0,1,2,2,2,3,7,14,7,13,17,12,5,3,2,1,0,0,0,0,...","0.247,0.219,0.167,0.181,0.205,0.161,0.161,0.15...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.41,0.36,0.32,0.20,0.26,0.26,0.19,0.19,0.15,0...",113,"0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,..."
20452,X5D2U9,"0.570,0.295,0.150,0.142,0.357,0.109,0.425,0.51...","0.077,0.092,0.100,0.075,0.077,0.086,0.081,0.09...","0.077,0.072,0.071,0.066,0.143,0.080,0.097,0.12...","0.147,0.116,0.110,0.073,0.093,0.077,0.114,0.09...","0.301,0.305,0.295,0.286,0.275,0.253,0.247,0.23...","0.386,0.373,0.355,0.304,0.291,0.305,0.308,0.31...","0.048,0.050,0.058,0.068,0.048,0.055,0.042,0.05...","3.67,2.52,3.68,2.26,3.79,2.37,3.17,2.13,2.58,2...","0.535,0.551,0.561,0.566,0.566,0.560,0.523,0.44...","0,1,1,2,1,1,2,4,19,25,31,21,29,20,18,28,24,52,...","0,52,75,84,83,42,8,2,4,9,27,41,49,67,65,65,69,...","0.119,0.120,0.112,0.128,0.118,0.142,0.214,0.15...","0.0000,0.0000,0.0000,0.0000,0.0000,0.0000,0.00...","0.34,0.33,0.33,0.25,0.24,0.21,0.22,0.17,0.11,0...",266,"0,0,8,0,6,0,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,..."


In [8]:
uniprot = pd.read_csv("../data/uniprot_all_human_proteins_annotated.txt", sep="\t")
uniprot_features = ['ACT_SITE', 'BINDING', 'DNA_BIND',
                   'TOPO_DOM', 'TRANSMEM',
                   'DISULFID',  'PROPEP', 'SIGNAL', 'TRANSIT',
                   'STRAND', 'HELIX',
                   'COILED', 'COMPBIAS', 'DOMAIN', 'REGION', 'REPEAT', 'ZN_FING']

# Initialize a new DataFrame to store the output
uniprot_clean = pd.DataFrame()
uniprot_clean['Entry'] = uniprot['Entry']  # Copy 'Entry' column Length
uniprot_clean['Length'] = uniprot['Length'] 

# Iterate over each feature and assign the correct value to the output
for feature in uniprot_features:
    actual_or_pred_col = f'actual_or_pred_{feature}'
    actual_annotation_col = f'annotation_actual_{feature}'
    pred_annotation_col = f'annotation_pred_{feature}'
    
    # Use `np.where` to conditionally select the 'actual' or 'pred' annotation
    uniprot_clean[feature] = uniprot.apply(
        lambda row: row[actual_annotation_col] if row[actual_or_pred_col] == 'actual' else row[pred_annotation_col], 
        axis=1
    )

# Display the resulting DataFrame
uniprot_clean

Unnamed: 0,Entry,Length,ACT_SITE,BINDING,DNA_BIND,TOPO_DOM,TRANSMEM,DISULFID,PROPEP,SIGNAL,TRANSIT,STRAND,HELIX,COILED,COMPBIAS,DOMAIN,REGION,REPEAT,ZN_FING
0,A0A087X1C5,515,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1100000000000000000000011111111111111111111111...,0011111111111111111111100000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111111000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000001...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000011...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
1,A0A0B4J2F0,54,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111000000000000000000000111111111111111111111...,0000111111111111111111111000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111110000000000000000000000000...,1100000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000001111111111111111111111111111111111111000...,0000000000000000000000000011111111111111111111...,0000000000000000000000000000000000000000000011...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000011111110000000000...,0000000000000001000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
2,A0A0B4J2F2,783,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000011111111100000...,0000000000000000000000000000000000000000000000...,1111111111111111111111111111111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000011111111100011111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000011111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
3,A0A0C5B5G6,16,0000000000000000,0000000000000000,0000000000000000,1111111111111111,0000000000000000,0000000000000000,0000000000000000,0000000000000000,0000000000000000,0000000011100000,0000000000000000,0000000000000000,0000000000000000,0000000000000000,0000000000000000,0000000000000000,0000000000000000
4,A0A0K2S4Q6,201,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000001111111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000001000...,0000000000000000000000000000000000000000000000...,1111111111111111111111110000000000000000000000...,1111111111111110000010000000000000000000000000...,0000000000000000000000000000111111000011111110...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000001111111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
20429,Q9UI54,55,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000001111111111...,0000000000000011111111111111111111100000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111111111111111111100000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
20430,Q9UI72,69,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111111111111111111111111111111...,1111111111000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111111100000000000000000000000...,1000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000111111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000111111111111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
20431,Q9Y3F1,56,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111000000000000000000000001111111111111111111...,0011001111111111111111111110000000000000000000...,0000000000000000000000000000100010000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111100000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...
20432,Q9Y6C7,94,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0110000000000000000000000000000000000000000000...,0000000000110011110011111111100111111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1111111111111111111111111110000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000010111111111111111111000000111111111111...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...


In [9]:
# conservation
phylop = pd.read_csv("../data/proteins_phylop_perresidue.tsv.gz", sep="\t")
phylop_features = ['phyloP100way_vertebrate', 'phyloP30way_mammalian']
phylop

Unnamed: 0,Entry,phyloP100way_vertebrate,phyloP30way_mammalian
0,A0A075B759,"1.6,0.1,1.2,1.4,-0.1,3.1,1.5,2.3,0.7,-0.4,-1.7...","0.3,0.3,-0.5,-1.5,-0.8,-0.3,-0.1,0,-0.1,-0.1,-..."
1,A0A075B767,"1.8,1.8,1.8,1.8,1.8,1.8,1.8,1.8,1.8,1.8,1.8,1....","0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3,0...."
2,A0A087WTH1,"3.1,2.4,2.8,1.8,4.1,2.7,2.2,1.8,2.3,0.3,1.9,1....","1.2,0.6,1.2,0.9,1.2,1.3,1.1,0.7,0.9,-0.1,1.1,0..."
3,A0A087WUV0,"0.6,0.8,-0.4,-0.9,0.5,0.6,2.5,-0.1,0.5,0.9,1.3...","0.9,0.7,0.8,-0.2,0.5,1,0.6,0,0.4,0.1,0.1,0.9,0..."
4,A0A087WV53,"3,2.3,3.1,0.9,1.7,0.6,0.1,2.1,2.1,0.2,-0.4,1,0...","1,0.9,0.7,0.9,0.9,0.9,0.1,0.8,0.7,0.5,0.3,0.5,..."
...,...,...,...
18174,Q9Y6Z7,"4.4,4.4,4.4,4.4,4.4,4.4,4.4,4.4,4.4,4.4,4.4,4....","1.2,1.2,1.2,1.2,1.2,1.2,1.2,1.2,1.2,1.2,1.2,1...."
18175,U3KPV4,"3.7,2.2,2.9,0.6,2.7,2.7,0.1,-0.4,7.4,2.8,0.4,2...","0.3,0.8,1,0.7,1.1,1.1,-0.1,0.4,1.2,1.1,-0.1,0...."
18176,W5XKT8,"1.1,1.2,-0.2,-0.2,-0.6,-0.9,-1.9,-0.9,-0.2,0.3...","1.2,0.6,-0.2,-0.7,-0.3,-0.7,0.2,-0.4,-0.7,-0.8..."
18177,X6R8D5,"-0.9,0.5,0.5,-0.3,1.2,-1.6,0.3,0.2,0.2,0.2,0,0...","0.3,-0.4,0.3,-0.2,0.8,-0.1,1.1,0.1,0.3,0.1,0.2..."


In [None]:
# read protein (graph) features
archs_df = pd.read_csv("../data/archs_protein_embeds.tsv.gz", sep="\t")
go_df = pd.read_csv("../data/go_protein_embeds.tsv.gz", sep="\t")
string_df = pd.read_csv("../data/string_protein_embeds.tsv.gz", sep="\t")
gnomad_df = pd.read_csv("../data/gnomadv4_constraints.tsv.gz", sep="\t")

In [None]:

# creat PyG objects
all_pyg_graphs = []
all_graphein_features = ["amino_acid_one_hot", "expasy", 'hbond_acceptors', 'hbond_donors', 'meiler', 'sidechain_vector', 'c_beta_vector', 'sequence_neighbour_vector_n_to_c']  # AAonehot 20, expasy = 61, hbond_acceptors=1, hbond_donors=1, meiler=7, sidechain_vector=3, c_beta_vector=3, sequence_neighbour_vector_n_to_c=3
#all_g2p_features = ["Accessible surface area (Å²)*", "Phi angle (degrees)*", "Psi angle (degrees)*", "ss_B", "ss_C", "ss_H"]

counter = 1
for uniprot_id, hgnc, label in zip(all_uniprot_ids, all_hgnc, all_labels):

    # remove those that are not in the describe_prot database
    if uniprot_id not in describe_prot["UniprotEntry"].values:
        continue
        
    # construct a networkx graph from AlphaFold predictions
    g = construct_graph(config=config, path=(af_db_path + prefix + uniprot_id + suffix))

    add_sidechain_vector(g)
    add_beta_carbon_vector(g)
    add_sequence_neighbour_vector(g)
    
    # convert to PyG object
    g2 = from_networkx(g)

    '''
    # ignore proteins with problematic length
    if len(g2.residue_name) != describe_prot[describe_prot["UniprotEntry"] == uniprot_id]["seqlength"].values[0]:
        continue

    if len(g2.residue_name) != uniprot_clean[uniprot_clean["Entry"] == uniprot_id]["Length"].values[0]:
        continue
    '''
    
    # add graphein features
    g2.x = g2[all_graphein_features[0]] 
    for feature in all_graphein_features[1:]:
        g2.x = torch.cat((g2.x, g2[feature]), dim=1)

    # add describe_prot features
    for feature in describe_prot_featnames:
        try:
            temp_feature = describe_prot[describe_prot["UniprotEntry"] == uniprot_id][feature].values[0]
            temp_num_list = [float(num) for num in temp_feature.split(',')]
            temp_num_tensor = torch.tensor(temp_num_list).reshape(-1, 1)
            g2.x = torch.cat((g2.x, temp_num_tensor), dim=1)
        except:
            g2.x = torch.cat((g2.x, torch.full((len(g2.residue_name), 1), float('nan'))), dim=1)

    # add uniprot features
    for feature in uniprot_features:
        try:
            temp_feature = uniprot_clean[uniprot_clean["Entry"] == uniprot_id][feature].values[0]
            temp_num_list = [int(char) for char in temp_feature]
            temp_num_tensor = torch.tensor(temp_num_list).reshape(-1, 1)
            g2.x = torch.cat((g2.x, temp_num_tensor), dim=1)
        except:
            g2.x = torch.cat((g2.x, torch.full((len(g2.residue_name), 1), float('nan'))), dim=1)

    # add phylop features
    for feature in phylop_features:
        try:
            temp_feature = phylop[phylop["Entry"] == uniprot_id][feature].values[0]
            temp_num_list = [float(num) for num in temp_feature.split(',')]
            temp_num_tensor = torch.tensor(temp_num_list).reshape(-1, 1)
            g2.x = torch.cat((g2.x, temp_num_tensor), dim=1)
        except:
            g2.x = torch.cat((g2.x, torch.full((len(g2.residue_name), 1), float('nan'))), dim=1)
    
    # add protT5 embeddings
    with h5py.File('../data/protT5_per_residue.h5', 'r') as f:
        embeds = f[uniprot_id][:]  

    '''
    if embeds.shape[0] != g2.x.shape[0]:
        continue
    '''
    
    try:
        g2.x = torch.cat((g2.x, torch.from_numpy(embeds)), dim=1)
    except:
        g2.x = torch.cat((g2.x, torch.full((len(g2.residue_name), 1024), float('nan'))), dim=1)
    
    
    '''
    # add G2P features
    g2p_features = g2papi.get_protein_features(geneName=hgnc, uniprotId=uniprot_id)
    g2p_features['ss_B'] = g2p_features['Secondary structure (DSSP 3-state)*'].apply(lambda x: 1 if x == 'B (strand)' else 0)
    g2p_features['ss_C'] = g2p_features['Secondary structure (DSSP 3-state)*'].apply(lambda x: 1 if x == 'C (loop/coil)' else 0)
    g2p_features['ss_H'] = g2p_features['Secondary structure (DSSP 3-state)*'].apply(lambda x: 1 if x == 'H (helix)' else 0)
    g2p_features2 = g2p_features[all_g2p_features]
    g2.x = torch.cat((g2.x, torch.tensor(g2p_features2.values, dtype=torch.float32)), dim=1) 
    '''

    # add protein (graph) features
    try:
        archs_embed = torch.tensor(archs_df[archs_df["uniprot_ids"] == uniprot_id].values[0][1:].tolist())
    except:
        archs_embed = torch.full((256,), float('nan'))

    try:
        go_embed = torch.tensor(go_df[go_df["uniprot_ids"] == uniprot_id].values[0][1:].tolist())
    except:
        go_embed = torch.full((256,), float('nan'))

    try:
        string_embed = torch.tensor(string_df[string_df["uniprot_ids"] == uniprot_id].values[0][1:].tolist())
    except:
        string_embed = torch.full((256,), float('nan'))

    try:
        gnomad_constraints = torch.tensor(gnomad_df[gnomad_df["uniprot_ids"] == uniprot_id].values[0][1:].tolist())
    except:
        gnomad_constraints = torch.full((6,), float('nan'))

    g2.u = torch.tensor([])
    g2.u = torch.cat((g2.u, archs_embed, go_embed, string_embed, gnomad_constraints))

    # add label
    g2.y = label

    all_pyg_graphs.append(g2)
    print(f"finished {counter} proteins")
    counter += 1


In [None]:
# create a simple version of all pyg graphs

all_pyg_graphs_simple = []

for pyg in all_pyg_graphs:
    
    simple_pyg = Data(
        x=pyg.x,               # Node features
        y=pyg.y,               # Labels
        batch=pyg.batch,       # Batch information
        edge_index=pyg.edge_index,  # Edge connections
        name = pyg.name,
        u=pyg.u,
        coords=pyg.coords
    )

    all_pyg_graphs_simple.append(simple_pyg)


In [None]:
# save all_pyg_graphs and all_pyg_graphs_simple
'''
with open('../res/pyg_graphs/all_pyg_graphs_simple.pkl', 'wb') as f:
    pickle.dump(all_pyg_graphs_simple, f)

with open('../res/pyg_graphs/all_pyg_graphs.pkl', 'wb') as f:
    pickle.dump(all_pyg_graphs, f)
'''

In [10]:
# load all_pyg_graphs_simple

with open('../res/pyg_graphs/all_pyg_graphs_simple.pkl', 'rb') as file:
    all_pyg_graphs_simple = pickle.load(file)

### Train test split + Normalization + Imputation

In [11]:
# squeeze graph freatures from 1*1*L to L
for graph in all_pyg_graphs_simple:
    graph.u = torch.squeeze(graph.u)

In [12]:
from sklearn.model_selection import train_test_split

def stratified_split(graphs, train_size=0.8, val_size=0.1, test_size=0.1, random_state=42):
    # Extract the labels (y) as tuples for stratification
    labels = [tuple(graph.y) for graph in graphs]

    # Split the data into training and temp (to split further into val and test)
    train_graphs, temp_graphs, train_labels, temp_labels = train_test_split(
        graphs, labels, test_size=(1 - train_size), stratify=labels, random_state=random_state
    )

    # Determine the proportion of the temp set to split into validation and test
    val_test_ratio = val_size / (val_size + test_size)

    # Split the remaining data (temp) into validation and test sets
    val_graphs, test_graphs, val_labels, test_labels = train_test_split(
        temp_graphs, temp_labels, test_size=(1 - val_test_ratio), stratify=temp_labels, random_state=random_state
    )

    return train_graphs, val_graphs, test_graphs

# Example usage:
train_list, val_list, test_list = stratified_split(all_pyg_graphs_simple)


In [15]:
from collections import defaultdict

def count_label_occurrences(graphs):
    # Create a dictionary to store the count of each unique label tuple
    label_count = defaultdict(int)
    
    # Iterate over the list of graphs
    for graph in graphs:
        # Convert the label list to a tuple (as tuples are hashable and can be used as dict keys)
        label_tuple = tuple(graph.y)
        # Increment the count of this label
        label_count[label_tuple] += 1
    
    return dict(label_count)

print(f"train:\n {count_label_occurrences(train_list)}\n")
print(f"test:\n {count_label_occurrences(test_list)}\n")
print(f"val:\n {count_label_occurrences(val_list)}\n")

train:
 {(0, 0, 1): 200, (0, 1, 0): 297, (1, 0, 1): 92, (1, 0, 0): 198, (1, 1, 1): 63, (1, 1, 0): 90, (0, 1, 1): 71}

test:
 {(1, 1, 1): 7, (0, 1, 0): 37, (0, 0, 1): 25, (1, 0, 0): 25, (1, 1, 0): 12, (1, 0, 1): 12, (0, 1, 1): 9}

val:
 {(1, 1, 0): 11, (1, 1, 1): 8, (1, 0, 0): 25, (0, 0, 1): 25, (0, 1, 0): 37, (1, 0, 1): 11, (0, 1, 1): 9}



In [13]:
def impute_and_normalize(graph_list):
    for graph in graph_list:
        # Get node features (x)
        x = graph.x

        # Step 1: Imputation (replace NaNs in non-one-hot columns using median)
        col_median = torch.nanmedian(x, dim=0).values  # Calculate column median, ignoring NaNs

        # Identify one-hot encoded columns (assuming values are exactly 0 or 1)
        onehot_mask = torch.all((x == 0) | (x == 1), dim=0)

        # Impute only non-one-hot columns
        non_onehot_mask = ~onehot_mask
        x[:, non_onehot_mask] = torch.where(torch.isnan(x[:, non_onehot_mask]), col_median[non_onehot_mask], x[:, non_onehot_mask])

        # Step 2: Normalization (only normalize non-one-hot columns)
        max_values = torch.max(x[:, non_onehot_mask], dim=0).values  # Find max of each non-one-hot column
        max_values[max_values == 0] = 1  # Avoid division by zero
        x[:, non_onehot_mask] = x[:, non_onehot_mask] / max_values  # Normalize non-one-hot columns

        # Assign the updated node features back to the graph
        graph.x = x
    
    return graph_list


train_list2= impute_and_normalize(train_list)
val_list2 = impute_and_normalize(val_list)
test_list2 = impute_and_normalize(test_list)

In [None]:
'''
import torch
from torch_geometric.data import DataLoader

def stratified_split(dataset, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05):

    # Calculate label frequencies
    label_counts = {}
    for data in dataset:
        label = tuple(data.y)
        label_counts[label] = label_counts.get(label, 0) + 1

    # Determine split sizes
    train_size = int(len(dataset) * train_ratio)
    val_size = int(len(dataset) * val_ratio)
    test_size = len(dataset) - train_size - val_size

    # Create stratified splits
    train_idx = []
    val_idx = []
    test_idx = []
    for label, count in label_counts.items():
        indices = torch.where(dataset.data.y == label)[0]
        torch.manual_seed(42)  # Set a seed for reproducibility
        indices = indices[torch.randperm(count)]
        train_idx.extend(indices[:train_size * count // len(dataset)])
        val_idx.extend(indices[train_size * count // len(dataset):train_size * count // len(dataset) + val_size * count // len(dataset)])
        test_idx.extend(indices[train_size * count // len(dataset) + val_size * count // len(dataset):])

    return dataset[train_idx], dataset[val_idx], dataset[test_idx]

train_list, val_list, test_list = stratified_split(all_pyg_graphs_simple)
'''

In [18]:
def compute_global_min_max_median(graphs):
    ### Node features
    # Concatenate features of all graphs along the node dimension
    all_features = torch.cat([data.x for data in graphs], dim=0)
    
    # Mask NaN values by creating a boolean mask
    nan_mask = torch.isnan(all_features)
    
    # Replace NaN values with a very large number for min and very small for max
    all_features_min = all_features.clone()
    all_features_max = all_features.clone()
    
    all_features_min[nan_mask] = float('inf')  # Use infinity for min
    all_features_max[nan_mask] = -float('inf')  # Use negative infinity for max
    
    # Compute global min, max and median ignoring NaNs
    global_min_node = torch.min(all_features_min, dim=0, keepdim=True)[0]
    global_max_node = torch.max(all_features_max, dim=0, keepdim=True)[0]
    
    # For median, use nan-to-num to replace NaNs with a large number temporarily, then calculate the median
    all_features_median = torch.nan_to_num(all_features, nan=0)
    global_median_node = torch.median(all_features_median, dim=0, keepdim=True)[0]

    ### graph features
    # Concatenate features of all graphs along the node dimension
    all_features = torch.stack([torch.squeeze(data.u) for data in graphs])
    
    # Mask NaN values by creating a boolean mask
    nan_mask = torch.isnan(all_features)
    
    # Replace NaN values with a very large number for min and very small for max
    all_features_min = all_features.clone()
    all_features_max = all_features.clone()
    
    all_features_min[nan_mask] = float('inf')  # Use infinity for min
    all_features_max[nan_mask] = -float('inf')  # Use negative infinity for max
    
    # Compute global min, max and median ignoring NaNs
    global_min_graph = torch.min(all_features_min, dim=0, keepdim=True)[0]
    global_max_graph = torch.max(all_features_max, dim=0, keepdim=True)[0]
    
    # For median, use nan-to-num to replace NaNs with a large number temporarily, then calculate the median
    all_features_median = torch.nan_to_num(all_features, nan=0)
    global_median_graph = torch.median(all_features_median, dim=0, keepdim=True)[0]
    return global_min_node, global_max_node, global_median_node, global_min_graph, global_max_graph, global_median_graph 


def replace_nan_with_median(graphs, global_median_node, global_mdeian_graph):
    # Replace NaN values with the median for each graph
    for data in graphs:
        # Replace NaN values in node feature matrix with the global median
        data.x = torch.where(torch.isnan(data.x), global_median_node, data.x)

        data.u = torch.where(torch.isnan(data.u), global_median_graph, data.u)
        
    return graphs


def min_max_normalize_features_global(graphs, global_min_node, global_max_node, global_min_graph, global_max_graph):

    global_range_node = global_max_node - global_min_node + 1e-9  # To avoid division by zero
    global_range_graph = global_max_graph - global_min_graph + 1e-9  # To avoid division by zero
    
    # Normalize each graph's features using the global min and max
    for data in graphs:
        data.x = (data.x - global_min_node) / global_range_node
        data.u = (data.u - global_min_graph) / global_range_graph
        
    return graphs

# Compute global min and max based on train list
global_min_node, global_max_node, global_median_node, global_min_graph, global_max_graph, global_median_graph = compute_global_min_max_median(train_list2)

# replace NaN
train_list_norm = replace_nan_with_median(train_list2, global_median_node, global_median_graph)
test_list_norm = replace_nan_with_median(test_list2, global_median_node, global_median_graph)
val_list_norm = replace_nan_with_median(val_list2, global_median_node, global_median_graph)

'''
# normalize
train_list_norm = min_max_normalize_features_global(train_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)
test_list_norm = min_max_normalize_features_global(test_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)
val_list_norm = min_max_normalize_features_global(val_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)
'''

'\n# normalize\ntrain_list_norm = min_max_normalize_features_global(train_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)\ntest_list_norm = min_max_normalize_features_global(test_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)\nval_list_norm = min_max_normalize_features_global(val_list_impute, global_min_node, global_max_node, global_min_graph, global_max_graph)\n'

In [None]:
'''
def compute_global_mean_std_median(graphs):
    ### Node features
    # Concatenate features of all graphs along the node dimension
    all_features = torch.cat([data.x for data in graphs], dim=0)
    
    # Mask NaN values by creating a boolean mask
    nan_mask = torch.isnan(all_features)
    
    # Replace NaN values with zeros temporarily for mean and std computation
    all_features_for_mean_std = torch.nan_to_num(all_features, nan=0)
    
    # Compute the global mean and standard deviation ignoring NaNs
    global_mean_node = torch.sum(all_features_for_mean_std, dim=0, keepdim=True) / (~nan_mask).sum(dim=0, keepdim=True).float()
    global_std_node = torch.sqrt(torch.sum((all_features_for_mean_std - global_mean_node)**2, dim=0, keepdim=True) / (~nan_mask).sum(dim=0, keepdim=True).float())
    
    # For median, use nan-to-num to replace NaNs with a large number temporarily, then calculate the median
    global_median_node = torch.median(torch.nan_to_num(all_features, nan=0), dim=0, keepdim=True)[0]

    ### Graph features (data.u)
    # Stack all graph-level features (u)
    all_graph_features = torch.stack([torch.squeeze(data.u) for data in graphs])
    
    # Mask NaN values by creating a boolean mask
    nan_mask_graph = torch.isnan(all_graph_features)
    
    # Replace NaN values with zeros temporarily for mean and std computation
    all_graph_features_for_mean_std = torch.nan_to_num(all_graph_features, nan=0)
    
    # Compute the global mean and standard deviation ignoring NaNs
    global_mean_graph = torch.sum(all_graph_features_for_mean_std, dim=0, keepdim=True) / (~nan_mask_graph).sum(dim=0, keepdim=True).float()
    global_std_graph = torch.sqrt(torch.sum((all_graph_features_for_mean_std - global_mean_graph)**2, dim=0, keepdim=True) / (~nan_mask_graph).sum(dim=0, keepdim=True).float())
    
    # For median, calculate it similarly
    global_median_graph = torch.median(torch.nan_to_num(all_graph_features, nan=0), dim=0, keepdim=True)[0]
    
    return global_mean_node, global_std_node, global_median_node, global_mean_graph, global_std_graph, global_median_graph


def zscore_normalize_features_global(graphs, global_mean_node, global_std_node, global_mean_graph, global_std_graph):
    
    # Normalize each graph's node and graph features using the global mean and std
    for data in graphs:
        data.x = (data.x - global_mean_node) / (global_std_node + 1e-9)  # Add epsilon to avoid division by zero
        data.u = (data.u - global_mean_graph) / (global_std_graph + 1e-9)
        
    return graphs


# Compute global mean, std, and median based on the train list
global_mean_node, global_std_node, global_median_node, global_mean_graph, global_std_graph, global_median_graph = compute_global_mean_std_median(train_list2)

# Replace NaN values with median (same as before)
train_list_norm = replace_nan_with_median(train_list2, global_median_node, global_median_graph)
test_list_norm = replace_nan_with_median(test_list2, global_median_node, global_median_graph)
val_list_norm = replace_nan_with_median(val_list2, global_median_node, global_median_graph)


# Normalize using Z-score normalization
train_list_norm = zscore_normalize_features_global(train_list_norm, global_mean_node, global_std_node, global_mean_graph, global_std_graph)
test_list_norm = zscore_normalize_features_global(test_list_norm, global_mean_node, global_std_node, global_mean_graph, global_std_graph)
val_list_norm = zscore_normalize_features_global(val_list_norm, global_mean_node, global_std_node, global_mean_graph, global_std_graph)
'''

### feature selection

In [22]:
# keep names for features 

temp_g = construct_graph(config=config, path=(af_db_path + prefix + "P49588" + suffix))

# AA onehot 20
all_feature_names = ["aa"+str(i) for i in range(20)]

# expasy 61
for n, d in temp_g.nodes(data=True):
    all_feature_names += d["expasy"].index.values.tolist()
    break

# hbond acceptor and donor
all_feature_names += ["hbond_acc", "hbond_donor"]

# meiler
all_feature_names += ["meiler"+str(i) for i in range(7)]

# sidechain_vector
all_feature_names += ["sidechain_vector"+str(i) for i in range(3)]

# c_beta_vector
all_feature_names += ["c_beta_vector"+str(i) for i in range(3)]

# sequence_neighbour_vector_n_to_c
all_feature_names += ["sequence_neighbour_vector_n_to_c"+str(i) for i in range(3)]

# describeProt
all_feature_names += describe_prot_featnames.tolist()

# UniProt
all_feature_names += uniprot_features

# Phylop
all_feature_names += phylop_features

# protLM embeddings 
all_feature_names += ["protLM"+str(i) for i in range(1024)]

all_feature_names

Output()

['aa0',
 'aa1',
 'aa2',
 'aa3',
 'aa4',
 'aa5',
 'aa6',
 'aa7',
 'aa8',
 'aa9',
 'aa10',
 'aa11',
 'aa12',
 'aa13',
 'aa14',
 'aa15',
 'aa16',
 'aa17',
 'aa18',
 'aa19',
 'pka_cooh_alpha',
 'pka_nh3',
 'pka_rgroup',
 'isoelectric_points',
 'molecularweight',
 'numbercodons',
 'bulkiness',
 'polarityzimmerman',
 'polaritygrantham',
 'refractivity',
 'recognitionfactors',
 'hphob_eisenberg',
 'hphob_sweet',
 'hphob_woods',
 'hphob_doolittle',
 'hphob_manavalan',
 'hphob_leo',
 'hphob_black',
 'hphob_breese',
 'hphob_fauchere',
 'hphob_guy',
 'hphob_janin',
 'hphob_miyazawa',
 'hphob_argos',
 'hphob_roseman',
 'hphob_tanford',
 'hphob_wolfenden',
 'hphob_welling',
 'hphob_wilson',
 'hphob_parker',
 'hphob_ph3_4',
 'hphob_ph7_5',
 'hphob_mobility',
 'hplchfba',
 'hplctfa',
 'transmembranetendency',
 'hplc2_1',
 'hplc7_4',
 'buriedresidues',
 'accessibleresidues',
 'hphob_chothia',
 'hphob_rose',
 'ratioside',
 'averageburied',
 'averageflexibility',
 'alpha_helixfasman',
 'beta_sheetfasman

In [23]:
import torch
import pandas as pd
import numpy as np
from sklearn.feature_selection import VarianceThreshold

# Step 1: Concatenate all graph features
# Assuming `all_feature_names` corresponds to columns in the node feature matrix `x`
def concatenate_graph_features(graphs):
    feature_list = []
    for graph in graphs:
        feature_list.append(graph.x)
    return torch.cat(feature_list, dim=0)

# Assuming `graphs` is the list of Data objects
concatenated_features = concatenate_graph_features(train_list_norm)

# sample 10k amino acids
concatenated_features = concatenated_features[torch.randperm(concatenated_features.size(0))[:10_000]]

# Convert to pandas DataFrame for easier manipulation
df_features = pd.DataFrame(concatenated_features.numpy(), columns=all_feature_names)

# Step 2: Remove features with low variance (except for binary features)
# Identify binary features (one-hot encoded)
binary_columns = df_features.columns[df_features.nunique() == 2]
non_binary_columns = df_features.columns.difference(binary_columns)

# Apply variance thresholding on non-binary features
selector = VarianceThreshold(threshold=0.1)
filtered_features = selector.fit_transform(df_features[non_binary_columns])

# Keep the selected features
selected_features = non_binary_columns[selector.get_support()]

# Combine back with binary columns that were not filtered
df_filtered = pd.concat([df_features[binary_columns], df_features[selected_features]], axis=1)

# Step 3: Remove highly correlated features
# Compute correlation matrix
corr_matrix = df_filtered.corr().abs()

# Identify highly correlated features (correlation > 0.8)
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))

# Select columns to drop based on the correlation threshold
to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > 0.8)]

# Drop the highly correlated features
df_final = df_filtered.drop(columns=to_drop)

# Step 4: Update all_feature_names
all_feature_names_filtered = df_final.columns.tolist()

all_feature_names_filtered  # Updated feature names


['aa0',
 'aa1',
 'aa2',
 'aa3',
 'aa4',
 'aa5',
 'aa6',
 'aa7',
 'aa8',
 'aa9',
 'aa10',
 'aa11',
 'aa12',
 'aa13',
 'aa14',
 'aa15',
 'aa16',
 'aa17',
 'aa18',
 'aa19',
 'ACT_SITE',
 'BINDING',
 'DNA_BIND',
 'TOPO_DOM',
 'DISULFID',
 'PROPEP',
 'SIGNAL',
 'TRANSIT',
 'STRAND',
 'HELIX',
 'COILED',
 'COMPBIAS',
 'DOMAIN',
 'REGION',
 'REPEAT',
 'ZN_FING',
 'PSIPRED_helix',
 'buriedresidues',
 'c_beta_vector0',
 'c_beta_vector1',
 'c_beta_vector2',
 'hbond_acc',
 'hphob_argos',
 'hphob_welling',
 'protLM100',
 'protLM1008',
 'protLM1009',
 'protLM1010',
 'protLM1011',
 'protLM1015',
 'protLM1017',
 'protLM1018',
 'protLM1020',
 'protLM109',
 'protLM110',
 'protLM117',
 'protLM120',
 'protLM121',
 'protLM126',
 'protLM133',
 'protLM138',
 'protLM141',
 'protLM144',
 'protLM148',
 'protLM15',
 'protLM154',
 'protLM158',
 'protLM159',
 'protLM161',
 'protLM162',
 'protLM164',
 'protLM167',
 'protLM181',
 'protLM182',
 'protLM188',
 'protLM189',
 'protLM196',
 'protLM197',
 'protLM20',
 'pr

In [25]:
selected_feat_indexes = [all_feature_names.index(item) for item in all_feature_names_filtered]
selected_feat_indexes

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 113,
 114,
 115,
 116,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 108,
 58,
 93,
 94,
 95,
 81,
 43,
 47,
 232,
 1140,
 1141,
 1142,
 1143,
 1147,
 1149,
 1150,
 1152,
 241,
 242,
 249,
 252,
 253,
 258,
 265,
 270,
 273,
 276,
 280,
 147,
 286,
 290,
 291,
 293,
 294,
 296,
 299,
 313,
 314,
 320,
 321,
 328,
 329,
 152,
 332,
 335,
 337,
 342,
 345,
 349,
 351,
 154,
 357,
 155,
 364,
 367,
 374,
 375,
 378,
 379,
 380,
 383,
 392,
 394,
 398,
 159,
 403,
 405,
 406,
 409,
 411,
 414,
 426,
 162,
 444,
 450,
 164,
 453,
 458,
 464,
 465,
 467,
 468,
 470,
 166,
 484,
 485,
 487,
 488,
 489,
 498,
 502,
 504,
 505,
 521,
 523,
 526,
 530,
 531,
 172,
 533,
 534,
 537,
 539,
 173,
 543,
 546,
 553,
 560,
 564,
 568,
 577,
 581,
 177,
 585,
 593,
 602,
 180,
 616,
 619,
 631,
 182,
 632,
 633,
 635,
 636,
 639,
 641,
 643,
 645,
 649,
 651,
 653,
 655,
 657,
 669,
 67

In [26]:
for i in range(len(train_list_norm)):
    train_list_norm[i].x = train_list_norm[i].x[:, selected_feat_indexes]

for i in range(len(test_list_norm)):
    test_list_norm[i].x = test_list_norm[i].x[:, selected_feat_indexes]

for i in range(len(val_list_norm)):
    val_list_norm[i].x = val_list_norm[i].x[:, selected_feat_indexes]


### PCA

In [30]:
import torch
from sklearn.decomposition import PCA

# Concatenate node features from all graphs in train_list
all_train_features = torch.cat([graph.x for graph in train_list_norm], dim=0)

# Fit PCA on the concatenated features
#pca = PCA(n_components=64)  # Specify the number of components
pca = PCA(n_components=0.8) 
pca.fit(all_train_features.numpy())  # Convert to numpy for PCA


In [31]:
import copy

# Function to apply PCA to the node features of a list of graphs
def apply_pca(graph_list, pca_model):
    for graph in graph_list:
        graph.x = torch.tensor(pca_model.transform(graph.x.numpy()), dtype=torch.float)

#make_copy 
train_list_norm_pca = copy.deepcopy(train_list_norm)
val_list_norm_pca = copy.deepcopy(val_list_norm)
test_list_norm_pca = copy.deepcopy(test_list_norm)

# Apply PCA to train, val, and test lists
apply_pca(train_list_norm_pca, pca)
apply_pca(val_list_norm_pca, pca)
apply_pca(test_list_norm_pca, pca)

### Define a GNN

In [32]:
batch_size=32
train_loader_nopca = DataLoader(train_list_norm, batch_size=batch_size, shuffle=True)
test_loader_nopca = DataLoader(test_list_norm, batch_size=batch_size, shuffle=True)
val_loader_nopca = DataLoader(val_list_norm, batch_size=batch_size, shuffle=True)


train_loader_pca = DataLoader(train_list_norm_pca, batch_size=batch_size, shuffle=True)
test_loader_pca = DataLoader(test_list_norm_pca, batch_size=batch_size, shuffle=True)
val_loader_pca = DataLoader(val_list_norm_pca, batch_size=batch_size, shuffle=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [33]:
import torch
import torch.nn as nn

class FbetaLoss(nn.Module):
    def __init__(self, beta=0.5, epsilon=1e-8, reduction='mean'):
        """
        F-beta loss (for custom balancing precision and recall).
        
        Args:
            beta: Weighting between precision and recall (beta < 1 favors precision).
            epsilon: Small constant to avoid division by zero.
            reduction: How to reduce the loss ('mean', 'sum', or 'none').
        """
        super(FbetaLoss, self).__init__()
        self.beta = beta
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, target):
        """
        Forward pass for F-beta loss.
        
        Args:
            logits: Raw output (logits) from the model of shape [batch_size, num_classes].
            target: Ground truth binary labels of shape [batch_size, num_classes].
        
        Returns:
            loss: The negative F-beta score (to maximize it).
        """
        # Apply sigmoid to logits to get predicted probabilities
        probs = torch.sigmoid(logits)

        # Use soft probabilities instead of hard binary predictions
        true_positives = (probs * target).sum(dim=0)
        false_positives = ((1 - target) * probs).sum(dim=0)
        false_negatives = (target * (1 - probs)).sum(dim=0)

        # Precision: TP / (TP + FP)
        precision = true_positives / (true_positives + false_positives + self.epsilon)
        
        # Recall: TP / (TP + FN)
        recall = true_positives / (true_positives + false_negatives + self.epsilon)
        
        # F-beta score: (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall)
        beta_squared = self.beta ** 2
        f_beta = (1 + beta_squared) * (precision * recall) / (beta_squared * precision + recall + self.epsilon)

        # Loss is 1 - F-beta score (we want to maximize F-beta, so minimize 1 - F-beta)
        loss = 1 - f_beta

        # Reduce the loss over all classes
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'product':
            return torch.prod(loss)
        else:
            return loss  # No reduction, return per-class loss


In [139]:
import torch
import torch.nn as nn

class MCCLoss(nn.Module):
    def __init__(self, epsilon=1e-8, reduction='mean'):
        """
        Custom loss to maximize the Matthews Correlation Coefficient (MCC).
        
        Args:
            epsilon: Small constant to avoid division by zero.
            reduction: How to reduce the loss ('mean', 'sum', or 'none').
        """
        super(MCCLoss, self).__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, logits, target):
        """
        Forward pass for MCC loss.
        
        Args:
            logits: The raw output (logits) from the model of shape [batch_size, num_classes].
            target: Ground truth binary labels of shape [batch_size, num_classes].
        
        Returns:
            loss: Negative MCC (we minimize the negative to maximize MCC).
        """
        # Apply sigmoid to logits to get predicted probabilities
        probs = torch.sigmoid(logits)

        # Compute predicted soft labels (use probabilities instead of hard labels)
        pred_pos = probs  # Predicted positives (soft)
        pred_neg = 1 - probs  # Predicted negatives (soft)

        # Compute actual positives and negatives
        true_pos = target  # Ground truth positives
        true_neg = 1 - target  # Ground truth negatives

        # True Positives, False Positives, True Negatives, False Negatives
        TP = (pred_pos * true_pos).sum(dim=0)
        TN = (pred_neg * true_neg).sum(dim=0)
        FP = (pred_pos * true_neg).sum(dim=0)
        FN = (pred_neg * true_pos).sum(dim=0)

        # MCC numerator: TP * TN - FP * FN
        numerator = TP * TN - FP * FN

        # MCC denominator: sqrt((TP + FP)(TP + FN)(TN + FP)(TN + FN))
        denominator = torch.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN) + self.epsilon)

        # MCC: numerator / denominator
        mcc = numerator / (denominator + self.epsilon)

        # Loss is the negative MCC (to maximize it)
        loss = 1 - mcc  # Maximize MCC by minimizing (1 - MCC)

        # Reduction: 'mean' or 'sum' over all classes
        if self.reduction == 'mean':
            return loss.mean()  # Mean loss across classes
        elif self.reduction == 'sum':
            return loss.sum()   # Sum loss across classes
        elif self.reduction == 'product':
            return torch.prod(loss)
        else:
            return loss  # No reduction, return per-class loss


In [35]:
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_geometric.nn import GINConv, GATv2Conv, GCNConv, global_add_pool
import torch
import torch.nn.functional as F


class GNN(torch.nn.Module):
    """GNN model for multi-label classification"""
    def __init__(self, arch, dim_in, dim_h, dim_out, dim_graph_feature, n_layer=2, heads=2, use_graph_feat=False, dropout_p=0.5):
        super(GNN, self).__init__()
        # Store the initialization arguments
        self.args = (arch, dim_in, dim_h, dim_out, dim_graph_feature, n_layer, heads, use_graph_feat, dropout_p)
        self.use_graph_feat = use_graph_feat
        self.dropout_p = dropout_p
        self.n_layer = n_layer
        self.heads = heads

        hidden_dims = [dim_h]  # Initialize with the input hidden dimension
        for i in range(1, n_layer):
            hidden_dims.append(hidden_dims[-1] // 2)  # Divide by 2 progressively for each layer

        if arch == "GCN":
            self.convs = torch.nn.ModuleList([GCNConv(dim_in if i == 0 else hidden_dims[i-1], hidden_dims[i]) for i in range(n_layer)])

        elif arch == "GAT":
            self.convs = torch.nn.ModuleList([
                GATv2Conv(dim_in if i == 0 else hidden_dims[i-1] * heads, hidden_dims[i], heads=heads, concat=True)
                for i in range(n_layer)
            ])

        elif arch == "GIN":
            self.convs = torch.nn.ModuleList([
                GINConv(
                    Sequential(Linear(dim_in if i == 0 else hidden_dims[i-1], hidden_dims[i]),
                               BatchNorm1d(hidden_dims[i]), ReLU(),
                               Linear(hidden_dims[i], hidden_dims[i]), ReLU())
                )
                for i in range(n_layer)
            ])

        # Calculate total input dimension for lin1 based on the number of layers and heads (if applicable)
        if arch == "GAT":
            total_dim = sum([dim * heads for dim in hidden_dims])
        else:
            total_dim = sum(hidden_dims)

        if self.use_graph_feat:
            total_dim += dim_graph_feature  # Include graph features if applicable

        self.lin1 = Linear(total_dim, dim_h * n_layer)  # Adjusted lin1 to match the total dimension
        self.lin2 = Linear(dim_h * n_layer, dim_out)    # Output layer for multi-label prediction

    def forward(self, x, edge_index, batch, graph_features):
        #x, edge_index, batch, graph_features = data.x, data.edge_index, data.batch, data.u

        h_list = []
        h = x
        for conv in self.convs:
            h = conv(h, edge_index)
            h = F.dropout(h, p=self.dropout_p, training=self.training)
            h_list.append(global_add_pool(h, batch))

        # Concatenate graph embeddings from each layer
        h = torch.cat(h_list, dim=1)

        if self.use_graph_feat:
            # Concatenate graph-level features to the pooled graph embeddings
            graph_features = graph_features.reshape(h.shape[0], -1)
            h = torch.cat((h, graph_features), dim=1)

        # Fully connected layers for classification
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=self.dropout_p, training=self.training)  # Dropout for regularization
        h = self.lin2(h)

        return h  # Returning raw logits for BCEWithLogitsLoss


In [36]:
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    matthews_corrcoef,
    precision_score,
    recall_score,
    roc_auc_score
)

def calculate_metrics(logits, target, threshold=0.5):
    # Convert logits to probabilities using sigmoid
    probs = torch.sigmoid(logits)
    
    # Convert probabilities to binary predictions (threshold 0.5)
    pred = (probs > threshold).float()  # Binary predictions
    pred_np = pred.cpu().numpy()
    target_np = target.cpu().numpy()

    # F1 Score
    f1_per_class = f1_score(target_np, pred_np, average=None, zero_division=0) # (Per class)
    f1_micro = f1_score(target_np, pred_np, average='micro', zero_division=0)
    f1_macro = f1_score(target_np, pred_np, average='macro', zero_division=0) 
    
    # Precision and Recall
    precision_per_class = precision_score(target_np, pred_np, average=None, zero_division=0)
    precision_micro = precision_score(target_np, pred_np, average='micro', zero_division=0)
    precision_macro = precision_score(target_np, pred_np, average='macro', zero_division=0) 
    
    recall_per_class = recall_score(target_np, pred_np, average=None, zero_division=0)
    recall_micro = recall_score(target_np, pred_np, average='micro', zero_division=0)
    recall_macro = recall_score(target_np, pred_np, average='macro', zero_division=0) 

    # Matthews Correlation Coefficient (MCC) per class
    class_mccs = []
    for i in range(target_np.shape[1]):  # Iterate over each class (column in multi-label)
        class_mcc = matthews_corrcoef(target_np[:, i], pred_np[:, i])
        class_mccs.append(class_mcc)
        
    # MCC for the entire dataset (flattened)
    mcc_overall = matthews_corrcoef(target_np.ravel(), pred_np.ravel())
    
    return {
        'f1_per_class': f1_per_class,
        'f1_micro': f1_micro,
        'f1_macro': f1_macro,
        'precision_per_class': precision_per_class,
        'precision_micro': precision_micro,
        'precision_macro': precision_macro,
        'recall_per_class': recall_per_class,
        'recall_micro': recall_micro,
        'recall_macro': recall_macro,
        'mcc_per_class': class_mccs,
        'mcc_overall': mcc_overall
    }

In [37]:
num_classes = 3  # Number of classes in the multi-label setting

# Initialize counters for positive and negative labels
positive_counts = torch.zeros(num_classes)
negative_counts = torch.zeros(num_classes)

# Loop through each graph and count positives and negatives
for graph in train_list_norm:
    label = torch.tensor(graph.y)  # Assume graph.y is a tensor of shape [3] for the label [1, 1, 0]
    
    positive_counts += (label == 1).float()  # Count positives (1s)
    negative_counts += (label == 0).float()  # Count negatives (0s)

# Calculate the pos_weight for each class
pos_weight = negative_counts / positive_counts
pos_weight = pos_weight.to(torch.float32)
pos_weight

tensor([1.2822, 0.9405, 1.3732])

In [42]:
# Set up training parameters
input_dim_nopca = train_list_norm[0].x.shape[1]  # Adjust this based on your feature size
input_dim_pca = train_list_norm_pca[0].x.shape[1]  # Adjust this based on your feature size
hidden_dim = 128
output_dim = 3  # Number of labels
n_layer = 2
graph_features_dim = train_list_norm[0].u.shape[0]

# define potential models
gcn_pca   = GNN("GCN", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, use_graph_feat=False).to(device)
gcn_nopca = GNN("GCN", input_dim_nopca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, use_graph_feat=False).to(device)

gat_pca   = GNN("GAT", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, heads=1, use_graph_feat=False).to(device)
gat_nopca = GNN("GAT", input_dim_nopca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, heads=1, use_graph_feat=False).to(device)

gin_pca   = GNN("GIN", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, use_graph_feat=False).to(device)
gin_nopca = GNN("GIN", input_dim_nopca, hidden_dim, output_dim, graph_features_dim, n_layer=n_layer, use_graph_feat=False).to(device)

LR = 5e-4
WD = 5e-4
# define optimizers
opt_gcn_pca = torch.optim.AdamW(gcn_pca.parameters(), lr=LR, weight_decay=WD)
opt_gcn_nopca = torch.optim.AdamW(gcn_nopca.parameters(), lr=LR, weight_decay=WD)

opt_gat_pca = torch.optim.AdamW(gat_pca.parameters(), lr=LR, weight_decay=WD)
opt_gat_nopca = torch.optim.AdamW(gat_nopca.parameters(), lr=LR, weight_decay=WD)

opt_gin_pca = torch.optim.AdamW(gin_pca.parameters(), lr=LR, weight_decay=WD)
opt_gin_nopca = torch.optim.AdamW(gin_nopca.parameters(), lr=LR, weight_decay=WD)

# define schedulers
scheduler_gcn_pca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gcn_pca, mode='min', factor=0.1, patience=3)
scheduler_gcn_nopca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gcn_nopca, mode='min', factor=0.1, patience=3)

scheduler_gat_pca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gat_pca, mode='min', factor=0.1, patience=3)
scheduler_gat_nopca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gat_nopca, mode='min', factor=0.1, patience=3)

scheduler_gin_pca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gin_pca, mode='min', factor=0.1, patience=3)
scheduler_gin_nopca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_gin_nopca, mode='min', factor=0.1, patience=3)

# define loss
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#criterion = MCCLoss(reduction="product").to(device)

In [132]:
'''def train(model, optimizer, train_loader, val_loader, model_arch, scheduler=None, epochs=100, patience=7):
    best_val_loss = float('inf')  # Initialize best validation loss as infinity
    best_model_wts = None  # Variable to store the best model weights
    patience_counter = 0  # Counter for early stopping
    
    for epoch in range(epochs):
        # Train phase
        model.train()
        total_train_loss = 0
        all_train_preds = []
        all_train_targets = []
        
        for data in train_loader:
            data = data.to(device)
            
            # Ensure input features are in float32
            data.x = data.x.float()
            
            # If data.y is a list, convert it to a tensor, then cast to float32
            if isinstance(data.y, list):
                data.y = torch.tensor(data.y, dtype=torch.float32).to(device)
            else:
                data.y = data.y.float()  # Otherwise, cast directly
            
            optimizer.zero_grad()
            if model_arch == "EGNN":
                out = model(data.x, data.edge_index, data.coords.float(), data.batch)
            else:
                out = model(data.x, data.edge_index, data.batch, data.u)
            loss = criterion(out, data.y)  # Multi-label target
            loss.backward()
            total_train_loss += loss.item()
            optimizer.step()
            
            all_train_preds.append(out.detach())
            all_train_targets.append(data.y)
        
        all_train_preds = torch.cat(all_train_preds, dim=0)
        all_train_targets = torch.cat(all_train_targets, dim=0)
        
        # Calculate metrics for train
        train_metrics = calculate_metrics(all_train_preds, all_train_targets)
        avg_train_loss = total_train_loss / len(train_loader)
        
        # Evaluate on validation set
        val_loss, val_metrics = evaluate(model, val_loader, model_arch)

        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.3f}, Train metrics: {train_metrics}')
        print(f'            Val Loss:   {val_loss:.3f}, Val F1:   {val_metrics}')
        print('*****')
        
        # Check if validation loss has improved
        if val_loss < best_val_loss:
            print(f'Validation loss decreased ({best_val_loss:.4f} -> {val_loss:.4f}). Saving model...')
            best_val_loss = val_loss
            best_model_wts = model.state_dict()  # Save the best model weights
            patience_counter = 0  # Reset patience counter when improvement occurs
        else:
            patience_counter += 1
            print(f'No improvement for {patience_counter} epochs.')

        # Early stopping: stop if no improvement for 'patience' number of epochs
        if patience_counter >= patience:
            print(f'Early stopping triggered after {patience_counter} epochs without improvement.')
            break

        # Step the learning rate scheduler (if provided)
        if scheduler:
            scheduler.step(val_loss)
            print(f"Learning rate adjusted: {scheduler.get_last_lr()}")
    
    # Load the best model weights at the end of training
    if best_model_wts is not None:
        # Step 1: Create a new instance of the same model class (assume the class is called `model.__class__`)
        copied_model = model.__class__(*model.args)  # Make sure to pass the same initialization arguments
        
        # Step 2: Load the state_dict of the original model into the copied model
        copied_model.load_state_dict(best_model_wts)

        # Step 3: Move the copied model to the same device as the original model (if you're using GPU)
        copied_model = copied_model.to(next(model.parameters()).device)
    
    return copied_model
    '''

def train(model, optimizer, train_loader, val_loader, model_arch, scheduler=None, epochs=100, patience=5):
    best_val_loss = float('inf')  # Initialize best validation loss as infinity
    best_model_wts = None  # Variable to store the best model weights
    patience_counter = 0  # Counter for early stopping
    
    for epoch in range(epochs):
        # Train phase
        model.train()
        total_train_loss = 0
        all_train_preds = []
        all_train_targets = []
        
        for data in train_loader:
            data = data.to(device)
            
            # Ensure input features are in float32
            data.x = data.x.float()
            
            # If data.y is a list, convert it to a tensor, then cast to float32
            if isinstance(data.y, list):
                data.y = torch.tensor(data.y, dtype=torch.float32).to(device)
            else:
                data.y = data.y.float()  # Otherwise, cast directly
            
            optimizer.zero_grad()
            if model_arch == "EGNN":
                out = model(data.x, data.edge_index, data.coords.float(), data.batch)
            else:
                out = model(data.x, data.edge_index, data.batch, data.u)
            loss = criterion(out, data.y)  # Multi-label target
            loss.backward()
            total_train_loss += loss.item()
            optimizer.step()
            
            all_train_preds.append(out.detach())
            all_train_targets.append(data.y)
        
        all_train_preds = torch.cat(all_train_preds, dim=0)
        all_train_targets = torch.cat(all_train_targets, dim=0)
        
        # Calculate metrics for train
        train_metrics = calculate_metrics(all_train_preds, all_train_targets)
        avg_train_loss = total_train_loss / len(train_loader)
        
        # Evaluate on validation set
        val_loss, val_metrics = evaluate(model, val_loader, model_arch)

        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.3f}, Train metrics: {train_metrics}')
        print(f'            Val Loss:   {val_loss:.3f}, Val F1:   {val_metrics}')
        print('*****')
        
        # Check if validation loss has improved
        if val_loss < best_val_loss:
            print(f'Validation loss decreased ({best_val_loss:.4f} -> {val_loss:.4f}). Saving model...')
            best_val_loss = val_loss
            # Save both model state and arguments used to create it
            best_model_wts = {
                'state_dict': model.state_dict(),
                'args': model.args  # Save initialization arguments
            }
            patience_counter = 0  # Reset patience counter when improvement occurs
        else:
            patience_counter += 1
            print(f'No improvement for {patience_counter} epochs.')

        # Early stopping: stop if no improvement for 'patience' number of epochs
        if patience_counter >= patience:
            print(f'Early stopping triggered after {patience_counter} epochs without improvement.')
            break

        # Step the learning rate scheduler (if provided)
        if scheduler:
            scheduler.step(val_loss)
            print(f"Learning rate adjusted: {scheduler.get_last_lr()}")
    
    # Load the best model weights at the end of training
    if best_model_wts is not None:
        # Step 1: Create a new instance of the same model class with saved args
        copied_model = model.__class__(*best_model_wts['args'])  # Use saved args
        
        # Step 2: Load the state_dict of the original model into the copied model
        copied_model.load_state_dict(best_model_wts['state_dict'])

        # Step 3: Move the copied model to the same device as the original model (if you're using GPU)
        copied_model = copied_model.to(next(model.parameters()).device)
    
    return copied_model


# Evaluation loop remains the same
def evaluate(model, loader, model_arch, threshold=0.5):
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            
            # Ensure input features are in float32
            data.x = data.x.float()
            
            # If data.y is a list, convert it to a tensor, then cast to float32
            if isinstance(data.y, list):
                data.y = torch.tensor(data.y, dtype=torch.float32).to(device)
            else:
                data.y = data.y.float()  # Otherwise, cast directly

            if model_arch == "EGNN":
                out = model(data.x, data.edge_index, data.coords.float(), data.batch)
            else:
                out = model(data.x, data.edge_index, data.batch, data.u)
            loss = criterion(out, data.y)
            total_loss += loss.item()
            
            all_preds.append(out)
            all_targets.append(data.y)
    
    all_preds = torch.cat(all_preds, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    
    # Calculate metrics for validation
    metrics = calculate_metrics(all_preds, all_targets, threshold)
    
    return total_loss / len(loader), metrics


In [44]:
# GCN PCA
trained_gcn_pca = train(gcn_pca, opt_gcn_pca, train_loader_pca, val_loader_pca, model_arch="GCN", scheduler=scheduler_gcn_pca, epochs=100)
print("\nBest model:")
trained_gcn_pca.eval()
val_loss, val_metrics = evaluate(trained_gcn_pca, test_loader_pca, model_arch="GCN")
val_loss, val_metrics

Epoch 1, Train Loss: 803.724, Train metrics: {'f1_per_class': array([0.5075594 , 0.57459506, 0.52291105]), 'f1_micro': 0.537359900373599, 'f1_macro': 0.5350218339582626, 'precision_per_class': array([0.48654244, 0.51687117, 0.42358079]), 'precision_micro': 0.47365532381997805, 'precision_macro': 0.47566479824485164, 'recall_per_class': array([0.53047404, 0.64683301, 0.68309859]), 'recall_micro': 0.620863309352518, 'recall_macro': 0.6201352152056835, 'mcc_per_class': [0.09322658644708566, 0.0041521986071913625, 0.006534389430713968], 'mcc_overall': 0.0378207498513887}
            Val Loss:   106.778, Val F1:   {'f1_per_class': array([0.58974359, 0.65945946, 0.59217877]), 'f1_micro': 0.6153846153846154, 'f1_macro': 0.6137939400509232, 'precision_per_class': array([0.45544554, 0.50833333, 0.42063492]), 'precision_micro': 0.4610951008645533, 'precision_macro': 0.46147126617423645, 'recall_per_class': array([0.83636364, 0.93846154, 1.        ]), 'recall_micro': 0.9248554913294798, 'recall_m

(326.05749130249023,
 {'f1_per_class': array([0.06779661, 0.23913043, 0.54545455]),
  'f1_micro': 0.35374149659863946,
  'f1_macro': 0.2841271968022152,
  'precision_per_class': array([0.66666667, 0.40740741, 0.43333333]),
  'precision_micro': 0.43333333333333335,
  'precision_macro': 0.5024691358024691,
  'recall_per_class': array([0.03571429, 0.16923077, 0.73584906]),
  'recall_micro': 0.2988505747126437,
  'recall_macro': 0.3135980371829428,
  'mcc_per_class': [0.07071373773001521,
   -0.10852970474001519,
   0.05063784071198787],
  'mcc_overall': -0.03179805667945898})

In [45]:
# GCN no PCA
trained_gcn_nopca = train(gcn_nopca, opt_gcn_nopca, train_loader_nopca, val_loader_nopca, model_arch="GCN", epochs=100)
print("\nBest model:")
val_loss, val_metrics = evaluate(trained_gcn_nopca, test_loader_nopca, model_arch="GCN")
val_loss, val_metrics

Epoch 1, Train Loss: 699.970, Train metrics: {'f1_per_class': array([0.45708155, 0.50579151, 0.39206534]), 'f1_micro': 0.4552212389380531, 'f1_macro': 0.4516461316933069, 'precision_per_class': array([0.43558282, 0.50873786, 0.38979118]), 'precision_micro': 0.4480836236933798, 'precision_macro': 0.44470395648607425, 'recall_per_class': array([0.48081264, 0.50287908, 0.3943662 ]), 'recall_micro': 0.46258992805755395, 'recall_macro': 0.45935263898714584, 'mcc_per_class': [-0.0050663980947809484, -0.013443503017492616, -0.05512136207557225], 'mcc_overall': -0.019415365383042005}
            Val Loss:   161.274, Val F1:   {'f1_per_class': array([0.53781513, 0.43971631, 0.55757576]), 'f1_micro': 0.5129411764705882, 'f1_macro': 0.5117023985609718, 'precision_per_class': array([0.5       , 0.40789474, 0.41071429]), 'precision_micro': 0.43253968253968256, 'precision_macro': 0.4395363408521303, 'recall_per_class': array([0.58181818, 0.47692308, 0.86792453]), 'recall_micro': 0.630057803468208, '

(92.01296520233154,
 {'f1_per_class': array([0.41904762, 0.59210526, 0.48333333]),
  'f1_micro': 0.5092838196286472,
  'f1_macro': 0.4981620718462824,
  'precision_per_class': array([0.44897959, 0.51724138, 0.43283582]),
  'precision_micro': 0.4729064039408867,
  'precision_macro': 0.4663522640142006,
  'recall_per_class': array([0.39285714, 0.69230769, 0.54716981]),
  'recall_micro': 0.5517241379310345,
  'recall_macro': 0.5441115488285299,
  'mcc_per_class': [0.01282630240723671,
   0.01602171880926478,
   0.03324352819429573],
  'mcc_overall': 0.034759996137153616})

In [46]:
# GAT PCA
trained_gat_pca = train(gat_pca, opt_gat_pca, train_loader_pca, val_loader_pca, model_arch="GAT", scheduler=scheduler_gat_pca, epochs=100)
print("\nBest model:")
trained_gat_pca.eval()
val_loss, val_metrics = evaluate(trained_gat_pca, test_loader_pca, model_arch="GAT")
val_loss, val_metrics

Epoch 1, Train Loss: 600.572, Train metrics: {'f1_per_class': array([0.44281217, 0.52486188, 0.49230769]), 'f1_micro': 0.4884702825592725, 'f1_macro': 0.48666058094962455, 'precision_per_class': array([0.41372549, 0.50442478, 0.41693811]), 'precision_micro': 0.44523386619301364, 'precision_macro': 0.4450294599021087, 'recall_per_class': array([0.47629797, 0.54702495, 0.60093897]), 'recall_micro': 0.5410071942446043, 'recall_macro': 0.5414206291829322, 'mcc_per_class': [-0.04972796221297122, -0.02456289160605297, -0.011149478562935787], 'mcc_overall': -0.029379643716940292}
            Val Loss:   239.678, Val F1:   {'f1_per_class': array([0.50406504, 0.36893204, 0.44210526]), 'f1_micro': 0.4423676012461059, 'f1_macro': 0.43836744754775087, 'precision_per_class': array([0.45588235, 0.5       , 0.5       ]), 'precision_micro': 0.4797297297297297, 'precision_macro': 0.4852941176470588, 'recall_per_class': array([0.56363636, 0.29230769, 0.39622642]), 'recall_micro': 0.41040462427745666, 'r

(326.95761489868164,
 {'f1_per_class': array([0.31707317, 0.19277108, 0.34615385]),
  'f1_micro': 0.2899628252788104,
  'f1_macro': 0.2853327004076343,
  'precision_per_class': array([0.5       , 0.44444444, 0.35294118]),
  'precision_micro': 0.4105263157894737,
  'precision_macro': 0.43246187363834426,
  'recall_per_class': array([0.23214286, 0.12307692, 0.33962264]),
  'recall_micro': 0.22413793103448276,
  'recall_macro': 0.23161414057640473,
  'mcc_per_class': [0.06034816410142874,
   -0.054766967746567065,
   -0.10695236098677958],
  'mcc_overall': -0.053416057880103555})

In [47]:
# GAT no PCA
trained_gat_nopca = train(gat_nopca, opt_gat_nopca, train_loader_nopca, val_loader_nopca, model_arch="GAT", epochs=100)
print("\nBest model:")
val_loss, val_metrics = evaluate(trained_gat_nopca, test_loader_nopca, model_arch="GAT")
val_loss, val_metrics

Epoch 1, Train Loss: 507.109, Train metrics: {'f1_per_class': array([0.47276688, 0.50898204, 0.45022624]), 'f1_micro': 0.47860199714693297, 'f1_macro': 0.47732505493454186, 'precision_per_class': array([0.45684211, 0.53014553, 0.43449782]), 'precision_micro': 0.47454031117397455, 'precision_macro': 0.47382848400085814, 'recall_per_class': array([0.48984199, 0.48944338, 0.46713615]), 'recall_micro': 0.4827338129496403, 'recall_macro': 0.48214050493657523, 'mcc_per_class': [0.0354078624077418, 0.02823880597868139, 0.02420454739836681], 'mcc_overall': 0.030475618527708043}
            Val Loss:   67.757, Val F1:   {'f1_per_class': array([0.32608696, 0.51666667, 0.58108108]), 'f1_micro': 0.49444444444444446, 'f1_macro': 0.474611568089829, 'precision_per_class': array([0.40540541, 0.56363636, 0.45263158]), 'precision_micro': 0.47593582887700536, 'precision_macro': 0.4738911159963792, 'recall_per_class': array([0.27272727, 0.47692308, 0.81132075]), 'recall_micro': 0.5144508670520231, 'recall

(215.9801540374756,
 {'f1_per_class': array([0.50406504, 0.57142857, 0.54961832]),
  'f1_micro': 0.5431472081218274,
  'f1_macro': 0.5417039775632216,
  'precision_per_class': array([0.46268657, 0.53333333, 0.46153846]),
  'precision_micro': 0.4863636363636364,
  'precision_macro': 0.48585278734532467,
  'recall_per_class': array([0.55357143, 0.61538462, 0.67924528]),
  'recall_micro': 0.6149425287356322,
  'recall_macro': 0.6160671089916373,
  'mcc_per_class': [0.04627383716579194,
   0.051709344744503984,
   0.11312926733198664],
  'mcc_overall': 0.0696292488302312})

In [48]:

# GIN PCA
trained_gin_pca = train(gin_pca, opt_gin_pca, train_loader_pca, val_loader_pca, model_arch="GIN", scheduler=scheduler_gin_pca, epochs=100)
print("\nBest model:")
trained_gin_pca.eval()
val_loss, val_metrics = evaluate(trained_gin_pca, test_loader_pca, model_arch="GIN")
val_loss, val_metrics


Epoch 1, Train Loss: 13.200, Train metrics: {'f1_per_class': array([0.45073375, 0.49245283, 0.45474138]), 'f1_micro': 0.46702923181509176, 'f1_macro': 0.4659759873731897, 'precision_per_class': array([0.42074364, 0.48423006, 0.42031873]), 'precision_micro': 0.4426546391752577, 'precision_macro': 0.4417641402266503, 'recall_per_class': array([0.48532731, 0.50095969, 0.49530516]), 'recall_micro': 0.49424460431654677, 'recall_macro': 0.4938640569957577, 'mcc_per_class': [-0.035526865704130446, -0.06650222435516744, -0.002104267941422873], 'mcc_overall': -0.03212782667556666}
            Val Loss:   1.194, Val F1:   {'f1_per_class': array([0.03508772, 0.12820513, 0.16129032]), 'f1_micro': 0.1116751269035533, 'f1_macro': 0.10819439002800631, 'precision_per_class': array([0.5       , 0.38461538, 0.55555556]), 'precision_micro': 0.4583333333333333, 'precision_macro': 0.48005698005698005, 'recall_per_class': array([0.01818182, 0.07692308, 0.09433962]), 'recall_micro': 0.06358381502890173, 'rec

(0.747761681675911,
 {'f1_per_class': array([0.        , 0.63157895, 0.27586207]),
  'f1_micro': 0.391304347826087,
  'f1_macro': 0.3024803387779794,
  'precision_per_class': array([0.        , 0.61764706, 0.35294118]),
  'precision_micro': 0.5294117647058824,
  'precision_macro': 0.3235294117647059,
  'recall_per_class': array([0.        , 0.64615385, 0.22641509]),
  'recall_micro': 0.3103448275862069,
  'recall_macro': 0.290856313497823,
  'mcc_per_class': [0.0, 0.2273072095172018, -0.07894235325030699],
  'mcc_overall': 0.0882693997132293})

In [49]:
# GIN no PCA
trained_gin_nopca = train(gin_nopca, opt_gin_nopca, train_loader_nopca, val_loader_nopca, model_arch="GIN", epochs=100)
print("\nBest model:")
val_loss, val_metrics = evaluate(trained_gin_nopca, test_loader_nopca, model_arch="GIN")
val_loss, val_metrics

Epoch 1, Train Loss: 12.525, Train metrics: {'f1_per_class': array([0.41939121, 0.51509434, 0.46069869]), 'f1_micro': 0.46804051694027243, 'f1_macro': 0.4650614119641298, 'precision_per_class': array([0.41891892, 0.50649351, 0.43061224]), 'precision_micro': 0.4548540393754243, 'precision_macro': 0.45200822343679486, 'recall_per_class': array([0.41986456, 0.52399232, 0.49530516]), 'recall_micro': 0.48201438848920863, 'recall_macro': 0.4797206821984919, 'mcc_per_class': [-0.03435232894293502, -0.01889749306670323, 0.018161873963814593], 'mcc_overall': -0.006705039817868339}
            Val Loss:   1.511, Val F1:   {'f1_per_class': array([0.49315068, 0.35294118, 0.13559322]), 'f1_micro': 0.3793103448275862, 'f1_macro': 0.32722836058035937, 'precision_per_class': array([0.3956044 , 0.75      , 0.66666667]), 'precision_micro': 0.4700854700854701, 'precision_macro': 0.604090354090354, 'recall_per_class': array([0.65454545, 0.23076923, 0.0754717 ]), 'recall_micro': 0.3179190751445087, 'recall

(0.7496405690908432,
 {'f1_per_class': array([0.39316239, 0.6013986 , 0.18918919]),
  'f1_micro': 0.437125748502994,
  'f1_macro': 0.3945833945833946,
  'precision_per_class': array([0.37704918, 0.55128205, 0.33333333]),
  'precision_micro': 0.45625,
  'precision_macro': 0.42055485498108447,
  'recall_per_class': array([0.41071429, 0.66153846, 0.13207547]),
  'recall_micro': 0.41954022988505746,
  'recall_macro': 0.4014427396502868,
  'mcc_per_class': [-0.12372148548351541,
   0.09962742211111307,
   -0.07581089846549775],
  'mcc_overall': -0.0007565677749389509})

In [None]:
# save models
torch.save(trained_gcn_pca.state_dict(), '../res/trained_models/trained_gcn_pca.pth')
torch.save(trained_gat_pca.state_dict(), '../res/trained_models/trained_gat_pca.pth')
torch.save(trained_gin_pca.state_dict(), '../res/trained_models/trained_gin_pca.pth')


### EGNN

In [50]:
from torch import nn
import torch


class E_GCL(nn.Module):
    """
    E(n) Equivariant Convolutional Layer
    re
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        edge_coords_nf = 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.residual:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        if self.coords_agg == 'sum':
            agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        elif self.coords_agg == 'mean':
            agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        else:
            raise Exception('Wrong coords_agg parameter' % self.coords_agg)
        coord += agg
        return coord

    def coord2radial(self, edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff**2, 1).unsqueeze(1)

        if self.normalize:
            norm = torch.sqrt(radial).detach() + self.epsilon
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

        return h, coord, edge_attr


class EGNN(nn.Module):
    def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4, residual=True, attention=False, normalize=False, tanh=False):
        '''

        :param in_node_nf: Number of features for 'h' at the input
        :param hidden_nf: Number of hidden features
        :param out_node_nf: Number of features for 'h' at the output
        :param in_edge_nf: Number of features for the edge features
        :param device: Device (e.g. 'cpu', 'cuda:0',...)
        :param act_fn: Non-linearity
        :param n_layers: Number of layer for the EGNN
        :param residual: Use residual connections, we recommend not changing this one
        :param attention: Whether using attention or not
        :param normalize: Normalizes the coordinates messages such that:
                    instead of: x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)
                    we get:     x^{l+1}_i = x^{l}_i + Σ(x_i - x_j)phi_x(m_ij)/||x_i - x_j||
                    We noticed it may help in the stability or generalization in some future works.
                    We didn't use it in our paper.
        :param tanh: Sets a tanh activation function at the output of phi_x(m_ij). I.e. it bounds the output of
                        phi_x(m_ij) which definitely improves in stability but it may decrease in accuracy.
                        We didn't use it in our paper.
        '''

        super(EGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
                                                act_fn=act_fn, residual=residual, attention=attention,
                                                normalize=normalize, tanh=tanh))
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr):
        h = self.embedding_in(h)
        for i in range(0, self.n_layers):
            h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr)
        h = self.embedding_out(h)
        return h, x


def unsorted_segment_sum(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)


def get_edges(n_nodes):
    rows, cols = [], []
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j:
                rows.append(i)
                cols.append(j)

    edges = [rows, cols]
    return edges


def get_edges_batch(n_nodes, batch_size):
    edges = get_edges(n_nodes)
    edge_attr = torch.ones(len(edges[0]) * batch_size, 1)
    edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
    if batch_size == 1:
        return edges, edge_attr
    elif batch_size > 1:
        rows, cols = [], []
        for i in range(batch_size):
            rows.append(edges[0] + n_nodes * i)
            cols.append(edges[1] + n_nodes * i)
        edges = [torch.cat(rows), torch.cat(cols)]
    return edges, edge_attr

In [117]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

class EGNN_Graph_Classifier(torch.nn.Module):
    """EGNN model for graph classification"""
    def __init__(self, dim_in, dim_h, dim_out, dim_edge_attr, n_layers=4, n_layers_egnn=1, use_graph_feat=False, dim_graph_feature=0, dropout_p=0.5):
        super(EGNN_Graph_Classifier, self).__init__()
        self.args = (dim_in, dim_h, dim_out, dim_edge_attr, n_layers, n_layers_egnn, use_graph_feat, dim_graph_feature, dropout_p)
        self.use_graph_feat = use_graph_feat
        self.dropout_p = dropout_p
        self.n_layers = n_layers
        
        # EGNN layers
        self.egnn_layers = torch.nn.ModuleList([EGNN(in_node_nf=dim_in if i == 0 else dim_h // (2**i),
                                                     hidden_nf=dim_h // (2**i),
                                                     out_node_nf=dim_h // (2**(i+1)),
                                                     in_edge_nf=dim_edge_attr, 
                                                     n_layers=n_layers_egnn)
                                                for i in range(n_layers)])
        
        # Calculate the total dimension after pooling all EGNN layers
        total_dim = sum([dim_h // (2**i) for i in range(n_layers)])  # Sum of pooled dimensions

        if use_graph_feat:
            total_dim += dim_graph_feature  # Adjust for graph-level features

        # Linear layers for classification
        self.lin1 = torch.nn.Linear(total_dim//2, dim_h // 4)
        self.lin2 = torch.nn.Linear(dim_h // 4, dim_out)

    def forward(self, h, edge_index, x, batch, graph_features=None):
        # h: node features
        # x: coordinates
        h_list = []
        
        # Apply EGNN layers
        for layer in self.egnn_layers:
            h, x = layer(h, x, edge_index, edge_attr=None)
            h_add = global_add_pool(h, batch)
            h_list.append(h_add)
        
        # Concatenate pooled representations from all layers
        h = torch.cat(h_list, dim=1)

        # Concatenate graph-level features if applicable
        if self.use_graph_feat and graph_features is not None:
            graph_features = graph_features.view(h.shape[0], -1)
            h = torch.cat([h, graph_features], dim=1)

        # Fully connected layers for classification
        h = self.lin1(h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = self.lin2(h)

        return h  # Return logits


In [77]:
'''import torch
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, global_max_pool

class EGNN_Graph_Classifier(torch.nn.Module):
    """EGNN model for graph classification"""
    def __init__(self, dim_in, dim_h, dim_out, dim_edge_attr, n_layers=4, use_graph_feat=False, dropout_p=0.5):
        super(EGNN_Graph_Classifier, self).__init__()
        self.args = (dim_in, dim_h, dim_out, dim_edge_attr, n_layers, use_graph_feat, dropout_p)
        self.use_graph_feat = use_graph_feat
        self.dropout_p = dropout_p

        # EGNN layer
        self.egnn = EGNN(in_node_nf=dim_in, hidden_nf=dim_h, out_node_nf=dim_h//2, in_edge_nf=dim_edge_attr, n_layers=n_layers)
        #self.egnn2 = EGNN(in_node_nf=dim_h//2, hidden_nf=dim_h//4, out_node_nf=dim_h//8, in_edge_nf=dim_edge_attr, n_layers=n_layers)
        
        # Linear layers for classification
        self.lin1 = torch.nn.Linear(dim_h , dim_h//2)
        self.lin2 = torch.nn.Linear(dim_h//2, dim_out)

    def forward(self, h, edge_index, x, batch):
        #h, edge_index, x, batch = data.x, data.edge_index, data.coords.float(), data.batch

        h1, x1 = self.egnn(h, x, edge_index, edge_attr=None)
        h_add = global_add_pool(h1, batch)
        h_max = global_max_pool(h1, batch)
        h = torch.cat([h_add, h_max], dim=1)


        # Fully connected layers for classification
        h = self.lin1(h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = self.lin2(h)

        return h  # Return logits
'''

'import torch\nimport torch.nn.functional as F\nfrom torch_geometric.nn import global_add_pool, global_max_pool\n\nclass EGNN_Graph_Classifier(torch.nn.Module):\n    """EGNN model for graph classification"""\n    def __init__(self, dim_in, dim_h, dim_out, dim_edge_attr, n_layers=4, use_graph_feat=False, dropout_p=0.5):\n        super(EGNN_Graph_Classifier, self).__init__()\n        self.args = (dim_in, dim_h, dim_out, dim_edge_attr, n_layers, use_graph_feat, dropout_p)\n        self.use_graph_feat = use_graph_feat\n        self.dropout_p = dropout_p\n\n        # EGNN layer\n        self.egnn = EGNN(in_node_nf=dim_in, hidden_nf=dim_h, out_node_nf=dim_h//2, in_edge_nf=dim_edge_attr, n_layers=n_layers)\n        #self.egnn2 = EGNN(in_node_nf=dim_h//2, hidden_nf=dim_h//4, out_node_nf=dim_h//8, in_edge_nf=dim_edge_attr, n_layers=n_layers)\n        \n        # Linear layers for classification\n        self.lin1 = torch.nn.Linear(dim_h , dim_h//2)\n        self.lin2 = torch.nn.Linear(dim_h//2,

In [226]:
class FocalLoss(torch.nn.Module):
    def __init__(self, pos_weight=torch.tensor([1,1,1]), alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
        #self.mccloss = MCCLoss(reduction="none").to(device)
        
    def forward(self, inputs, targets):
        loss = self.bce_loss(inputs, targets) #+ self.mccloss(inputs, targets)
        pt = torch.exp(-loss)  # pt is the probability of the correct class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * loss
        return focal_loss.mean()


In [242]:
# Set up training parameters
input_dim_nopca = train_list_norm[0].x.shape[1]  # Adjust this based on your feature size
input_dim_pca = train_list_norm_pca[0].x.shape[1]  # Adjust this based on your feature size
hidden_dim = 128
n_layers = 1
n_layers_egnn = 1
output_dim = 3  # Number of labels
graph_features_dim = train_list_norm[0].u.shape[0]

# define potential models
egnn_pca   = EGNN_Graph_Classifier(input_dim_pca, hidden_dim, output_dim, dim_edge_attr=0, n_layers=n_layers, n_layers_egnn=n_layers_egnn, use_graph_feat=False, dropout_p=0.5).to(device) 
egnn_nopca   = EGNN_Graph_Classifier(input_dim_nopca, hidden_dim, output_dim, dim_edge_attr=0, n_layers=n_layers, n_layers_egnn=n_layers_egnn, use_graph_feat=False, dropout_p=0.5).to(device)


LR = 5e-4
WD = 5e-4
# define optimizers
opt_egnn_pca = torch.optim.AdamW(egnn_pca.parameters(), lr=LR, weight_decay=WD)
opt_egnn_nopca = torch.optim.AdamW(egnn_nopca.parameters(), lr=LR, weight_decay=WD)

# define schedulers
scheduler_egnn_pca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_egnn_pca, mode='min', factor=0.1, patience=1)
scheduler_egnn_nopca = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_egnn_nopca, mode='min', factor=0.1, patience=1)

# define loss
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
#criterion = FocalLoss(pos_weight=pos_weight, alpha=0.5, gamma=2)
#criterion = MCCLoss(reduction="product").to(device)

In [243]:
'''# EGNN PCA
trained_egnn_pca = train(egnn_pca, opt_egnn_pca, train_loader_pca, val_loader_pca, model_arch="EGNN", scheduler=scheduler_egnn_pca, epochs=100)
print("\nBest model:")
trained_egnn_pca.eval()
val_loss, val_metrics = evaluate(trained_egnn_pca, test_loader_pca, model_arch="EGNN")
val_loss, val_metrics
'''

'# EGNN PCA\ntrained_egnn_pca = train(egnn_pca, opt_egnn_pca, train_loader_pca, val_loader_pca, model_arch="EGNN", scheduler=scheduler_egnn_pca, epochs=100)\nprint("\nBest model:")\ntrained_egnn_pca.eval()\nval_loss, val_metrics = evaluate(trained_egnn_pca, test_loader_pca, model_arch="EGNN")\nval_loss, val_metrics\n'

In [244]:
# EGNN no PCA
trained_egnn_nopca = train(egnn_nopca, opt_egnn_nopca, train_loader_nopca, val_loader_nopca, model_arch="EGNN", scheduler=scheduler_egnn_nopca, epochs=100, patience=3)
print("\nBest model:")
trained_egnn_nopca.eval()
val_loss, val_metrics = evaluate(trained_egnn_nopca, test_loader_nopca, model_arch="EGNN")
val_loss, val_metrics

Epoch 1, Train Loss: 155.228, Train metrics: {'f1_per_class': array([0.44191344, 0.55026455, 0.49038462]), 'f1_micro': 0.4986893840104849, 'f1_macro': 0.4941875350949003, 'precision_per_class': array([0.44597701, 0.50897227, 0.41530945]), 'precision_micro': 0.45788206979542717, 'precision_macro': 0.4567529084283431, 'recall_per_class': array([0.43792325, 0.59884837, 0.59859155]), 'recall_micro': 0.5474820143884892, 'recall_macro': 0.5451210561273938, 'mcc_per_class': [0.013656379719610914, -0.015791295478481546, -0.015251415545151246], 'mcc_overall': -0.0009061081937617749}
            Val Loss:   14.293, Val F1:   {'f1_per_class': array([0.42718447, 0.58503401, 0.58333333]), 'f1_micro': 0.5454545454545454, 'f1_macro': 0.5318506043193977, 'precision_per_class': array([0.45833333, 0.52439024, 0.42608696]), 'precision_micro': 0.46530612244897956, 'precision_macro': 0.4696035112525038, 'recall_per_class': array([0.4       , 0.66153846, 0.9245283 ]), 'recall_micro': 0.6589595375722543, 're

(10.925203949213028,
 {'f1_per_class': array([0.12698413, 0.65921788, 0.58426966]),
  'f1_micro': 0.5476190476190477,
  'f1_macro': 0.4568238890001491,
  'precision_per_class': array([0.57142857, 0.51754386, 0.416     ]),
  'precision_micro': 0.46747967479674796,
  'precision_macro': 0.501657477025898,
  'recall_per_class': array([0.07142857, 0.90769231, 0.98113208]),
  'recall_micro': 0.6609195402298851,
  'recall_macro': 0.6534176515308591,
  'mcc_per_class': [0.06347389635014633,
   0.033962642018834095,
   -0.02120779397396682],
  'mcc_overall': 0.029231891693933854})

In [245]:
trained_egnn_nopca.eval()
evaluate(trained_egnn_nopca, test_loader_nopca, model_arch="EGNN", threshold=0.55)

(10.952672839164734,
 {'f1_per_class': array([0.12698413, 0.08571429, 0.2       ]),
  'f1_micro': 0.13793103448275862,
  'f1_macro': 0.13756613756613756,
  'precision_per_class': array([0.57142857, 0.6       , 0.41176471]),
  'precision_micro': 0.4827586206896552,
  'precision_macro': 0.5277310924369747,
  'recall_per_class': array([0.07142857, 0.04615385, 0.13207547]),
  'recall_micro': 0.08045977011494253,
  'recall_macro': 0.0832192964268436,
  'mcc_per_class': [0.06347389635014633,
   0.03571663917955086,
   -0.004431049974208991],
  'mcc_overall': 0.01501973810980609})

In [118]:
import torch

# Define hyperparameter ranges
n_layers_list = [1, 2, 3]  # Example: different numbers of layers
n_layers_egnn_list = [1, 2, 3]  # Example: different numbers of EGNN layers
hidden_dim_list = [32, 64, 128, 256]  # Example: different hidden dimensions

# Example training function (assuming train and evaluate are already defined)
def train_and_evaluate_models(input_dim, output_dim, dim_edge_attr, device, train_loader, val_loader, test_loader, epochs=100):
    results = []
    
    for n_layers in n_layers_list:
        for n_layers_egnn in n_layers_egnn_list:
            for hidden_dim in hidden_dim_list:
                print(f"\nTraining model with n_layers={n_layers}, n_layers_egnn={n_layers_egnn}, hidden_dim={hidden_dim}")
                
                # Instantiate the model
                model = EGNN_Graph_Classifier(input_dim, hidden_dim, output_dim, dim_edge_attr=dim_edge_attr, 
                                              n_layers=n_layers, n_layers_egnn=n_layers_egnn, 
                                              use_graph_feat=False, dropout_p=0.5).to(device)
                
               # define optimizers
                optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
                
                # define schedulers
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
                
                # Train the model
                trained_model = train(model, optimizer, train_loader, val_loader, model_arch="EGNN", scheduler=scheduler, epochs=epochs)
                
                # Evaluate the model on the validation set
                trained_model.eval()
                val_loss, val_metrics = evaluate(trained_model, test_loader, model_arch="EGNN")
                
                # Log results
                results.append({
                    'n_layers': n_layers,
                    'n_layers_egnn': n_layers_egnn,
                    'hidden_dim': hidden_dim,
                    'val_loss': val_loss,
                    'val_metrics': val_metrics
                })

                print(results)
    
    # Return the results for all combinations
    return results

# Example call to the function
results = train_and_evaluate_models(input_dim_nopca, output_dim, dim_edge_attr=0, device=device, 
                                    train_loader=train_loader_nopca, val_loader=val_loader_nopca, 
                                    test_loader=test_loader_nopca, epochs=100)

# Print results
for result in results:
    print(f"Model with n_layers={result['n_layers']}, n_layers_egnn={result['n_layers_egnn']}, hidden_dim={result['hidden_dim']}:")
    print(f"Validation Loss: {result['val_loss']}, Validation Metrics: {result['val_metrics']}\n")



Training model with n_layers=1, n_layers_egnn=1, hidden_dim=32
Epoch 1, Train Loss: 57.232, Train metrics: {'f1_per_class': array([0.42012579, 0.61742984, 0.34993614]), 'f1_micro': 0.4924965893587995, 'f1_macro': 0.46249725557385907, 'precision_per_class': array([0.47443182, 0.50180072, 0.3837535 ]), 'precision_micro': 0.4682230869001297, 'precision_macro': 0.45332867995683124, 'recall_per_class': array([0.37697517, 0.80230326, 0.32159624]), 'recall_micro': 0.5194244604316547, 'recall_macro': 0.5002915587958451, 'mcc_per_class': [0.053399022271381026, -0.0585686111988711, -0.05627741276482471], 'mcc_overall': 0.02026941093045445}
            Val Loss:   3.920, Val F1:   {'f1_per_class': array([0.21621622, 0.68085106, 0.25      ]), 'f1_micro': 0.4742857142857143, 'f1_macro': 0.3823557600153345, 'precision_per_class': array([0.42105263, 0.5203252 , 0.31428571]), 'precision_micro': 0.4689265536723164, 'precision_macro': 0.41855451637223134, 'recall_per_class': array([0.14545455, 0.984615

KeyboardInterrupt: 

In [None]:
n_layers=1, n_layers_egnn=2, hidden_dim=128

In [None]:
q = next(iter(val_loader_nopca)).to(device)

In [225]:
idx = 8
trained_egnn_nopca.eval()
trained_egnn_nopca(train_list_norm[idx].x.float(), train_list_norm[idx].edge_index, train_list_norm[idx].coords.float(), train_list_norm[idx].batch)#, train_list_norm[idx].u)


tensor([[-648.9427,  412.3538,  634.8348]], grad_fn=<AddmmBackward0>)

In [None]:
train_list_norm[idx].x.float()

In [None]:
out = model(data.x, data.edge_index, data.coords.float(), data.batch)

In [None]:
torch.save(trained_egnn_pca.state_dict(), '../res/trained_models/trained_egnn_pca.pth')
torch.save(trained_egnn_nopca.state_dict(), '../res/trained_models/trained_egnn_nopca.pth')

### test

In [None]:
# read models

trained_gcn_pca = GNN("GCN", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=3, use_graph_feat=False).to(device)
trained_gcn_pca.load_state_dict(torch.load('../res/trained_models/trained_gcn_pca.pth'))
trained_gcn_pca.eval()

trained_gat_pca = GNN("GAT", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=3, heads=2, use_graph_feat=False).to(device)
trained_gat_pca.load_state_dict(torch.load('../res/trained_models/trained_gat_pca.pth'))
trained_gat_pca.eval()

trained_gin_pca = GNN("GIN", input_dim_pca, hidden_dim, output_dim, graph_features_dim, n_layer=3, use_graph_feat=False).to(device)
trained_gin_pca.load_state_dict(torch.load('../res/trained_models/trained_gin_pca.pth'))
trained_gin_pca.eval()

trained_egnn_pca = EGNN_Graph_Classifier(input_dim_pca, hidden_dim, output_dim, dim_edge_attr=0, n_layers=2, use_graph_feat=False, dropout_p=0.5).to(device) 
trained_egnn_pca.load_state_dict(torch.load('../res/trained_models/trained_egnn_pca.pth'))
trained_egnn_pca.eval()

In [None]:
# GCN

gcn_loss, gcn_test_metrics = evaluate(trained_gcn_pca, test_loader_pca)
gcn_loss, gcn_test_metrics

In [None]:
# GAT
gat_loss, gat_test_metrics = evaluate(trained_gat_pca, test_loader_pca)
gat_loss, gat_test_metrics

In [None]:
# GIN
gin_loss, gin_test_metrics = evaluate(trained_gin_pca, test_loader_pca)
gin_loss, gin_test_metrics

In [None]:
# EGNN

egnn_loss, egnn_test_metrics = evaluate(trained_egnn_pca, test_loader_pca)
egnn_loss, egnn_test_metrics

### Explain

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool

class EGNN_Graph_Classifier(torch.nn.Module):
    """EGNN model for graph classification"""
    def __init__(self, dim_in, dim_h, dim_out, dim_edge_attr, n_layers=4, use_graph_feat=False, dropout_p=0.5):
        super(EGNN_Graph_Classifier, self).__init__()
        self.args = (dim_in, dim_h, dim_out, dim_edge_attr, n_layers, use_graph_feat, dropout_p)
        self.use_graph_feat = use_graph_feat
        self.dropout_p = dropout_p

        # EGNN layer
        self.egnn = EGNN(in_node_nf=dim_in, hidden_nf=dim_h, out_node_nf=dim_h, in_edge_nf=dim_edge_attr, n_layers=n_layers)

        # Linear layers for classification
        self.lin1 = torch.nn.Linear(dim_h, dim_h)
        self.lin2 = torch.nn.Linear(dim_h, dim_out)

    def forward(self, x, edge_index, batch, h): # this line changed due to GNNExplainer requirements

        '''
        # Pass through the EGNN model
        h1, x1 = self.egnn(h, x, edge_index, edge_attr=None)
        h1 = F.dropout(h1, p=self.dropout_p, training=self.training)
        
        h2, x2 = self.egnn(h1, x1, edge_index, edge_attr=None)
        h2 = F.dropout(h2, p=self.dropout_p, training=self.training)
        
        h3, x3 = self.egnn(h2, x2, edge_index, edge_attr=None)
        h3 = F.dropout(h3, p=self.dropout_p, training=self.training)

        # Global pooling (sum/mean) over the nodes for graph classification
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate pooled embeddings
        h = torch.cat([h1, h2, h3], dim=1)
         '''
        h1, x1 = self.egnn(h, x, edge_index, edge_attr=None)
        h = global_add_pool(h1, batch)

        # Fully connected layers for classification
        h = self.lin1(h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_p, training=self.training)
        h = self.lin2(h)

        return h  # Return logits


In [None]:
trained_egnn_pca = EGNN_Graph_Classifier(input_dim_pca, hidden_dim, output_dim, dim_edge_attr=0, n_layers=2, use_graph_feat=False, dropout_p=0.5).to(device) 
trained_egnn_pca.load_state_dict(torch.load('../res/trained_models/trained_egnn_pca.pth'))
trained_egnn_pca.eval()

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
            model=trained_egnn_pca,
            algorithm=GNNExplainer(epochs=250),
            explanation_type='model',
            node_mask_type='object',
            edge_mask_type=None,
            model_config=dict(
                mode='multiclass_classification',
                task_level='graph',
                return_type='raw',
            ),
        )

In [None]:
idx = 10
explanation = explainer(x=test_list_norm_pca[idx].coords.float(), edge_index=test_list_norm_pca[idx].edge_index, batch=test_list_norm_pca[idx].batch, h=test_list_norm_pca[idx].x)


In [None]:
explanation.node_mask[1:5]