<a href="https://colab.research.google.com/github/Zebreu/DeorphaNN/blob/main/DeorphaNN_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Output directory for trained weights; must end with "/" because file names
# are appended to it directly.
save_to = "/content/"

In [None]:
#@title Install Dependencies
%%capture
# Reinstall torch at 2.4.0 so it matches the PyG extension wheels below, which
# are fetched from the torch-2.4.0+cu121 wheel index.
# NOTE(review): %%capture hides all pip output, including install errors.
!pip uninstall torch -y
!pip install torch==2.4.0
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install torch-geometric
!pip install optuna
!pip install numpy-indexed

import glob
import warnings
import numpy as np
import numpy_indexed as npi
import pandas as pd
import scipy
from collections import defaultdict

from sklearn.metrics import roc_auc_score, confusion_matrix, average_precision_score

import torch
from torch.nn import Linear
import torch.nn.functional as F

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.nn import aggr
from torch_geometric.nn.norm import GraphNorm

from sklearn import preprocessing

import optuna

# Seed torch's global RNG so weight initialisation and sampling are reproducible.
torch.manual_seed(111)
from huggingface_hub import hf_hub_download, list_repo_files
import h5py
import random


In [None]:
#@title Import files from HuggingFace repo
%%capture
repo_id = "lariferg/DeorphaNN"
all_files = list_repo_files(repo_id, repo_type="dataset")

pdbs_paths = sorted(
    hf_hub_download(repo_id, f, repo_type="dataset")
    for f in all_files
    if f.startswith("DeorphaNN_training/nov7relaxed") and f.endswith(".parquet")
)

labels = hf_hub_download(
    repo_id,
    next(f for f in all_files if f.startswith("DeorphaNN_training/") and f.endswith("Dataset_Labels_full - Sheet1.csv")),
    repo_type="dataset"
)
beetsdata = pd.read_csv(labels)


min_dis = hf_hub_download(
    repo_id,
    next(f for f in all_files if f.startswith("DeorphaNN_training/") and f.endswith("mindistance_active_bias.csv")),
    repo_type="dataset"
)
outsidepocket = pd.read_csv(min_dis)


new_contacts_paths = sorted(
    hf_hub_download(repo_id, f, repo_type="dataset")
    for f in all_files
    if f.startswith("DeorphaNN_training/nov7relaxed") and f.endswith("_arpeggio_contacts3.parquet")
)

hdfs_int = sorted(
    hf_hub_download(repo_id, f, repo_type="dataset")
    for f in all_files
    if f.startswith("pair_representations/") and f.endswith("_interaction.h5")
)


hdfs_t = sorted(
    hf_hub_download(repo_id, f, repo_type="dataset")
    for f in all_files
    if f.startswith("pair_representations/") and f.endswith("T.h5")
)


### Preparing data

In [None]:
# Parse every structure parquet into one DataFrame per predicted complex and
# collect per-complex confidence summaries from the b_factor column (treated
# here as pLDDT). Chain 'A' is the GPCR, chain 'B' the peptide.
gpcrs = []
peptides = []
plddts = []
paths = []
plddt_peptides = []
plddt_gpcrs = []
plddt_atoms = []
pdb_frames = dict()
for pdbs_path in pdbs_paths:
    print(pdbs_path)
    pdbs = pd.read_parquet(pdbs_path)
    for key, st in pdbs.groupby('path'):
        # The raw path is the lookup key for pdb_frames. It is bound on every
        # iteration; previously a path without the 'amber_r_' prefix reused a
        # stale key from an earlier complex (or raised NameError on the first).
        original_key = key
        # The prefix is stripped only to parse "<gpcr>_<peptide>_..." names.
        key = key.replace('amber_r_', '')
        name_parts = key.split('/')[-1].split('_')
        gpcrs.append(name_parts[0])
        peptides.append(name_parts[1])
        peptide_atoms = st[st['chain_id'] == 'B']
        gpcr_atoms = st[st['chain_id'] == 'A']
        # Per-residue score: first atom's b_factor per residue, averaged over residues.
        plddt_peptides.append(peptide_atoms.groupby('residue_seq_id')['b_factor'].first().mean())
        plddt_gpcrs.append(gpcr_atoms.groupby('residue_seq_id')['b_factor'].first().mean())
        # Per-atom mean over the peptide chain.
        plddt_atoms.append(peptide_atoms['b_factor'].mean())
        paths.append(original_key)
        pdb_frames[original_key] = st
    del pdbs

In [None]:
# One row per parsed complex, left-joined onto the label sheet by (GPCR, peptide);
# label rows without a structure keep NaN in the structure columns.
st = pd.DataFrame(
    dict(
        path=paths,
        gpcr=gpcrs,
        peptide=peptides,
        plddt_peptides=plddt_peptides,
        plddt_gpcrs=plddt_gpcrs,
        plddt_atoms=plddt_atoms,
    )
)
merged = beetsdata.merge(st, how='left', left_on=['GPCR name', 'Peptide'], right_on=['gpcr', 'peptide'])
# Family = GPCR name without its final two characters.
merged['gpcr_family'] = merged['GPCR name'].str[:-2]
# Integer target: 1 where 'binds' equals True, otherwise 0.
merged['y'] = merged['binds'].map(lambda binds: 1 if binds == True else 0)

In [None]:
# Keep only label rows that have a matching predicted structure. `.copy()`
# makes gpcr_hits an independent frame, so the column assignments below do not
# write through a slice of `merged` (SettingWithCopyWarning).
gpcr_hits = merged[merged['gpcr'].notna()].copy()
# Chain length = highest residue number in the chain (chain A = GPCR,
# chain B = peptide); presumably residues are numbered 1..N.
gpcrs_lens = []
peps_lens = []
for _, row in gpcr_hits.iterrows():
    pdb = pdb_frames[row['path']]
    gpcrs_lens.append(pdb.loc[pdb['chain_id'] == 'A', 'residue_seq_id'].max())
    peps_lens.append(pdb.loc[pdb['chain_id'] == 'B', 'residue_seq_id'].max())
gpcr_hits['gpcr_len'] = gpcrs_lens
gpcr_hits['pep_len'] = peps_lens

In [None]:
# "<GPCR name>_<Peptide>" key, matched against gpcr_hits['pair'] when filtering.
receptor_names = outsidepocket['GPCR name']
ligand_names = outsidepocket['Peptide']
outsidepocket['pair'] = receptor_names + '_' + ligand_names

In [None]:
# Drop every complex whose pair appears in the out-of-pocket table.
outside_pairs = set(outsidepocket['pair'].values)
gpcr_hits = gpcr_hits[~gpcr_hits['pair'].isin(outside_pairs)]

In [None]:
# Stack all arpeggio contact tables into one DataFrame (original indices kept).
allcontacts_new = pd.concat([pd.read_parquet(contact_path) for contact_path in new_contacts_paths])

In [None]:
# Total number of contact rows across all contact files.
len(allcontacts_new)

In [None]:
# Inspect the combined contact table.
allcontacts_new

In [None]:
# Count contact pairs per GPCR-peptide complex. A missing 'contacts' entry
# (None, or a float NaN from the parquet) counts as zero instead of making
# len() raise; the edge-building cell likewise skips None/empty contacts.
def _count_contacts(contacts):
    """Return the number of contact pairs, treating None/NaN as zero."""
    if contacts is None or isinstance(contacts, float):
        return 0
    return len(contacts)


total_interactions = (
    allcontacts_new.groupby('gpcr_peptide')['contacts']
    .apply(lambda group: sum(_count_contacts(c) for c in group))
    .reset_index(name='total_interactions')
)


In [None]:
# Inspect per-complex contact counts.
total_interactions

In [None]:
# Inspect the filtered GPCR-peptide table with chain lengths.
gpcr_hits

In [None]:
# Attach each complex's total contact count (NaN when the complex has no row in
# total_interactions) and drop the duplicated join key.
gpcr_hits_bonds = gpcr_hits.merge(
    total_interactions,
    how='left',
    left_on='pair',
    right_on='gpcr_peptide',
).drop(columns='gpcr_peptide')


In [None]:
# Inspect GPCR-peptide rows with their total contact counts.
gpcr_hits_bonds

In [None]:
# Contact rows grouped by complex key (the 'gpcr_peptide' value, e.g. 'DMSR-5-1_FLP-1-7').
interactions = {complex_key: rows for complex_key, rows in allcontacts_new.groupby('gpcr_peptide')}

In [None]:
# Legacy edge builder working from per-atom contact records ('bgn' / 'end' /
# 'contact' columns). Peptide residues are offset by gpcr_len so GPCR and
# peptide residues share one node index space.
# NOTE(review): `interactions` is keyed by 'gpcr_peptide' values, but it is
# looked up here with g['path'] (a structure path). Unless a path equals such a
# key, every row hits `continue` and this dict stays empty. Its result is not
# read anywhere in the visible notebook; graphs use
# gpcr_hits_interaction_edges_new instead.
gpcr_hits_interaction_edges = dict()
for index, g in gpcr_hits.iterrows():
    pdb = pdb_frames[g['path']].copy()
    if g['path'] not in interactions:
        continue
    # NOTE(review): assigning columns below mutates the group frame stored in `interactions`.
    bonds = interactions[g['path']]

    gpcr_len = g['gpcr_len']

    # they're 1-indexed so -1
    bonds['source'] = bonds['bgn'].apply(lambda x: x['auth_seq_id'] if x['auth_asym_id'] == "A" else x['auth_seq_id'] + gpcr_len) - 1
    bonds['target'] = bonds['end'].apply(lambda x: x['auth_seq_id'] if x['auth_asym_id'] == "A" else x['auth_seq_id'] + gpcr_len) - 1
    # Collapse duplicate (source, target) pairs, unioning their bond types.
    bonds = bonds.groupby(['source', 'target'])['contact'].agg(lambda x: {bondtype for array in x for bondtype in array}).reset_index()
    sources = bonds['source'].values
    targets = bonds['target'].values

    # (2, n_edges) edge index keyed by "<gpcr>_<peptide>".
    h_edge_index = np.vstack([sources,targets])
    key = g['gpcr']+'_'+g['peptide']
    gpcr_hits_interaction_edges[key] = h_edge_index

In [None]:
# Build a (2, n_contacts) edge index per complex from its contact pairs.
# Contact residue numbers are 1-based and are shifted to 0-based node ids.
# Complexes whose contacts are missing (None, or a float NaN from the parquet)
# or empty get no entry; the graph-building cell records those as missed.
gpcr_hits_interaction_edges_new = dict()
for _, row in allcontacts_new.iterrows():
    key = row['gpcr_peptide']
    contacts = row['contacts']
    if contacts is None or isinstance(contacts, float) or len(contacts) == 0:
        continue
    # Stack the (source, target) pairs to (N, 2), then transpose to (2, N).
    arr = np.vstack(contacts).T
    arr -= 1  # 1-indexed residue numbers -> 0-indexed node ids
    gpcr_hits_interaction_edges_new[key] = arr

In [None]:
# Number of complexes with a non-empty contact edge index.
len(gpcr_hits_interaction_edges_new)

In [None]:
# Number of distinct complexes in the contact table (compare with len(gpcr_hits_interaction_edges_new)).
len(interactions)

In [None]:
# Spot-check one complex's 0-based edge index (2 rows: sources, targets).
gpcr_hits_interaction_edges_new['DMSR-5-1_FLP-1-7']

In [None]:
# GPCR length per receptor name, taken from the first row of each receptor.
lens = gpcr_hits.groupby('GPCR name')['gpcr_len'].first()

In [None]:
# Number of downloaded interaction (pair-representation) HDF5 files.
len(hdfs_int)

In [None]:
# Load the pair-representation ("interaction") embeddings of each complex and
# keep only the GPCR residue positions that occur in at least one contact.
# A complex that cannot be processed (GPCR missing from `lens`, no contact
# edge index, malformed dataset key, ...) is recorded in `pairmissed` and skipped.
emb_map_interaction = dict()
emb_map_interaction_gpcrindex = dict()
pairmissed = []
for hdf in hdfs_int:
    with h5py.File(hdf, "r") as f:
        keys = list(f.keys())
        print(keys)
        # Every dataset in a file is attributed to the GPCR named by its first key.
        gpcr = keys[0].split('_')[0]
        for k in keys:
            try:
                # Replace NaNs with 0 in place. `copy`/`nan` are passed by keyword:
                # the second positional parameter of np.nan_to_num is `copy`,
                # not the fill value.
                array = np.nan_to_num(f[k][()], copy=False, nan=0.0)
                peptide = k.split('_')[1]
                mapkey = gpcr + '_' + peptide
                maximum = lens[gpcr]
                contact_edges = gpcr_hits_interaction_edges_new[mapkey]
                # Contact endpoints below the GPCR length are GPCR residues.
                contact_nodes = set(contact_edges[0]) | set(contact_edges[1])
                indices_to_keep = sorted(i for i in contact_nodes if i < maximum)
                emb_map_interaction[mapkey] = array[:, indices_to_keep, :]
                emb_map_interaction_gpcrindex[mapkey] = np.array(indices_to_keep)
            except Exception as err:
                # Best effort: record and skip this complex, but report why.
                # Catching Exception (not a bare except) keeps Ctrl-C working.
                pairmissed.append(k)
                print(f"missed {k} ({type(err).__name__}: {err})")
            else:
                print("success " + k)

In [None]:
# Number of complexes with loaded interaction embeddings.
len(emb_map_interaction)

In [None]:
# Number of interaction datasets that could not be processed.
len(pairmissed)

In [None]:
# Read the node-feature embeddings from the "*T.h5" files. A path containing
# "_pep_T" holds peptide embeddings and one containing "_gpcr_T" holds GPCR
# embeddings; each file contributes a list of dataset names plus a parallel
# list of arrays.
all_peptide_arrays = []
peptide_keys = []
all_gpcr_arrays = []
gpcr_keys = []

for hdf_path in hdfs_t:
    with h5py.File(hdf_path, "r") as handle:
        names = list(handle.keys())
        datasets = [handle[name][()] for name in names]
    if "_pep_T" in hdf_path:
        all_peptide_arrays.append(datasets)
        peptide_keys.append(names)
    if "_gpcr_T" in hdf_path:
        all_gpcr_arrays.append(datasets)
        gpcr_keys.append(names)

In [None]:
# Map "<gpcr>_<peptide>" -> GPCR node-feature array. The GPCR name comes from
# the first key of each file and the peptide from each dataset key. A file
# whose key list is empty or malformed (no "_" separator) is reported and the
# rest of that file is skipped; any other error now surfaces instead of being
# swallowed by a bare except.
emb_map_gpcr = dict()
for keys, arrays in zip(gpcr_keys, all_gpcr_arrays):
    try:
        gpcr = keys[0].split('_')[0]
        for i, array in enumerate(arrays):
            peptide = keys[i].split('_')[1]
            emb_map_gpcr[gpcr + '_' + peptide] = array
    except IndexError:
        print('oops')

In [None]:
# Map "<gpcr>_<peptide>" -> peptide node-feature array. The GPCR name comes
# from the first key of each file and the peptide from each dataset key. A file
# whose key list is empty or malformed (no "_" separator) is reported and the
# rest of that file is skipped; any other error now surfaces instead of being
# swallowed by a bare except.
emb_map_peptide = dict()
for keys, arrays in zip(peptide_keys, all_peptide_arrays):
    try:
        gpcr = keys[0].split('_')[0]
        for i, array in enumerate(arrays):
            peptide = keys[i].split('_')[1]
            emb_map_peptide[gpcr + '_' + peptide] = array
    except IndexError:
        print('oops')

In [None]:
# One row per embedding dataset: GPCR / peptide keys and their embeddings
# averaged over axis 0. GPCR and peptide columns are paired purely by position
# in the flattened file order, so both sides must have the same number of rows.
embst = pd.DataFrame({
    'gpcr_keys': [name for names in gpcr_keys for name in names],
    'gpcr_embedding': [emb.mean(axis=0) for embs in all_gpcr_arrays for emb in embs],
    'peptide_keys': [name for names in peptide_keys for name in names],
    'peptide_embedding': [emb.mean(axis=0) for embs in all_peptide_arrays for emb in embs],
})

In [None]:
# Recover the peptide (2nd "_" field) and GPCR (1st field) names from the dataset keys.
embst['peptide'] = embst['peptide_keys'].map(lambda name: name.split('_')[1])
embst['gpcr'] = embst['gpcr_keys'].map(lambda name: name.split('_')[0])

In [None]:
# Per-GPCR loss weight = 1 / (number of binding peptides of that GPCR), as a
# one-column frame indexed by 'gpcr' and sorted by binder count; train_weighted
# multiplies each graph's loss by it. Receptors with no binders are clipped to a
# count of 1: a raw 1/0 gives an infinite weight, which turns the weighted
# training loss into inf/NaN.
positives_per_gpcr = gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).sort_values(by='y')
gpcrweight = 1 / positives_per_gpcr.clip(lower=1)

In [None]:
# Inspect per-GPCR loss weights.
gpcrweight

In [None]:
# Inspect the complexes that graphs are built from.
gpcr_hits

In [None]:
# Build one torch_geometric graph per GPCR-peptide complex.
# Nodes: GPCR residues (rows of xg) followed by peptide residues (rows of xp).
# The peptide chain edges start at gpcr_len, so this assumes len(xg) == gpcr_len.
# Edges: the 0-based contact pairs plus a chain linking consecutive peptide residues.
# Complexes without a contact edge index, or (with edge weights) without
# interaction embeddings, are recorded in `missed` as (mapkey, label) and skipped.
# Toggles: keep only the k-hop neighbourhood of the peptide nodes; k; attach
# interaction embeddings as edge features.
subgraphing = True
subgraph_hops = 1
with_edge_weights = True
missed = []
all_graphs = []
for index, g in gpcr_hits.iterrows():
    mapkey = g['gpcr']+'_'+g['peptide']
    if mapkey not in gpcr_hits_interaction_edges_new:
        missed.append((mapkey, g['y']))
        continue
    gpcr_len = g['gpcr_len']
    h_edge_index = gpcr_hits_interaction_edges_new[g['gpcr']+'_'+g['peptide']]
    xg = emb_map_gpcr[g['gpcr']+'_'+g['peptide']]
    xp = emb_map_peptide[g['gpcr']+'_'+g['peptide']]
    x = np.concatenate([xg, xp])
    x = torch.from_numpy(x).type(torch.float32)
    # Peptide backbone edges i -> i+1 for consecutive peptide nodes.
    pep_edge_index = np.vstack([np.array(range(g['gpcr_len'], len(x)-1)), np.array(range(g['gpcr_len']+1, len(x)))])
    edge_index = torch.cat([torch.from_numpy(h_edge_index), torch.from_numpy(pep_edge_index)], dim=1)
    if with_edge_weights:
        if mapkey not in emb_map_interaction:
            missed.append((mapkey, g['y']))
            continue
        edgefeatures = emb_map_interaction[mapkey]
        edgeindices = emb_map_interaction_gpcrindex[mapkey]
        # Re-express GPCR endpoints as positions in the kept-column list;
        # values not in edgeindices (peptide nodes) are left unchanged by remap.
        sources = npi.remap(h_edge_index[0], edgeindices, np.arange(len(edgeindices)))
        targets = npi.remap(h_edge_index[1], edgeindices, np.arange(len(edgeindices)))
        # Orient every contact as (GPCR position, peptide residue): swap endpoints
        # where the source is a peptide node or the target is a GPCR node.
        # NOTE(review): assumes each contact joins a GPCR and a peptide residue;
        # a same-chain contact would index edgefeatures incorrectly.
        sourcewherever = np.where(sources >= gpcr_len)[0]
        targetwherever = np.where(targets < gpcr_len)[0]
        newsources = np.array(sources)
        newtargets = np.array(targets)
        newsources[sourcewherever] = targets[sourcewherever]
        newtargets[targetwherever] = sources[targetwherever]
        newtargets -= gpcr_len
        # edgefeatures appears to be indexed (peptide residue, kept GPCR position,
        # feature), inferred from this lookup and the [:, indices_to_keep, :]
        # slice used to build it -- confirm against the HDF5 layout.
        edge_attrs = edgefeatures[newtargets, newsources, :]
        # Peptide backbone edges get the mean contact-edge feature. 128 is the
        # hard-coded edge feature width (matches edge_dim=128 in PeptideGNN).
        # np.ones is float64, so these rows stay float64 until move_to_cuda
        # casts edge_attr to float32.
        pep_edge_attrs = np.ones(shape=(len(pep_edge_index[0]),128))*edge_attrs.mean(axis=0)
        edge_attrs = torch.from_numpy(edge_attrs).type(torch.float32)
        edge_attrs = torch.cat([edge_attrs, torch.from_numpy(pep_edge_attrs)], dim=0)
    # NOTE(review): without edge weights the graph is left directed, so the
    # k-hop expansion below depends on edge direction.
    if with_edge_weights:
        # convert to undirected first, so hops are symmetric
        edge_index, edge_attrs = torch_geometric.utils.to_undirected(edge_index, edge_attrs, reduce='mean')
    if subgraphing:
        to_keep = torch.tensor([i for i in range(gpcr_len, len(x))]) #hopping from peptide nodes
        # to_keep = torch.unique(torch.from_numpy(h_edge_index[0])) #hopping from gpcr nodes

        nodes, edges, _, _ = torch_geometric.utils.k_hop_subgraph(to_keep, subgraph_hops, edge_index, relabel_nodes=True, num_nodes=len(x))
        # mask = (nodes >= gpcr_len) | (torch.isin(nodes, to_keep))
        # nodes = nodes[mask]
        if with_edge_weights:
            # Re-derive the induced edges with subgraph() so the matching
            # edge_attrs are kept (k_hop_subgraph's edges are discarded here).
            edges, new_edge_attrs = torch_geometric.utils.subgraph(nodes, edge_index, edge_attrs, relabel_nodes=True)
            # edges, new_edge_attrs = torch_geometric.utils.to_undirected(edges, new_edge_attrs, reduce='mean')
            graph = Data(x=x[nodes], edge_index=edges, edge_attr=new_edge_attrs, y=torch.tensor(g['y']))
        else:
            graph = Data(x=x[nodes], edge_index=edges, y=torch.tensor(g['y']))
    else:
        graph = Data(x=x, edge_index=edge_index, y=torch.tensor(g['y']))
    # Metadata used later for fold selection and per-GPCR evaluation.
    graph.peptide = g['peptide']
    graph.gpcr = g['gpcr']
    graph.gpcr_family = g['gpcr_family']
    #graph.zscore = g['modifiedzscore']
    # gpcrweight is a one-column frame, so .iloc[0] extracts the scalar weight.
    gpcrw = gpcrweight.loc[g['gpcr']].iloc[0]
    all_graphs.append({'graph': graph, 'peptide':g['peptide'], 'gpcr':g['gpcr'], 'gpcr_family': g['gpcr_family'], 'y': g['y'], 'gpcrweight': gpcrw})


In [None]:
# Number of graphs built.
len(all_graphs)

In [None]:
# Number of complexes skipped during graph construction.
len(missed)

In [None]:
# (mapkey, label) pairs skipped during graph construction.
print(missed)

# Train

In [None]:
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score

def train(model, criterion, optimizer, train_loader):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch) # edge_attr
        loss = criterion(logits, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_loader.dataset)

def train_weighted(model, criterion, optimizer, train_loader):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch) # edge_attr
        loss = criterion(logits, data.y)
        loss = (loss*data.gpcrweight).mean()
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_roc(model, criterion, loader):
     model.eval()
     aucs = 0
     total = len(loader.dataset)
     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.edge_attr, data.batch) # edge_attr
         aucs += roc_auc_score(data.y.detach().cpu(), torch.softmax(out.detach(),dim=1).cpu()[:, 1])*(len(out)/total)
     return aucs

@torch.no_grad()
def test_without_crash(model, criterion, loader):
    model.eval()
    all_logits = []
    atrues = []
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch) # data.edge_attr
        all_logits.append(logits.cpu().detach()[:,1])
        atrues.append(data.y.cpu())
    return roc_auc_score(np.concatenate(atrues), np.concatenate(all_logits))

@torch.no_grad()
def nope_test_without_crash(model, criterion, loader):
    model.eval()
    all_logits = []
    atrues = []
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch) # data.edge_attr
        all_logits.append(torch.sigmoid(logits.squeeze()).cpu().detach())
        atrues.append(data.y.cpu())
    return roc_auc_score(np.concatenate(atrues), np.concatenate(all_logits))

@torch.no_grad()
def test(model, criterion, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch) # data.edge_attr
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(test_loader.dataset)

In [None]:
def move_to_cuda(g, device='cuda'):
    """Move a graph's tensors to ``device`` in place and return the same graph.

    ``x``, ``edge_index`` and ``y`` are moved unchanged. ``edge_attr`` is also
    cast to float32, because the peptide-chain edge features are built in
    float64. Graphs built without edge features (``edge_attr`` is None) are
    passed through instead of raising AttributeError.

    Args:
        g: graph exposing ``x``, ``edge_index``, ``edge_attr`` and ``y`` tensors.
        device: target device; defaults to ``'cuda'`` (the previous hard-coded target).

    Returns:
        ``g`` itself, mutated.
    """
    g.x = g.x.to(device)
    g.edge_index = g.edge_index.to(device)
    if g.edge_attr is not None:
        g.edge_attr = g.edge_attr.to(device).type(torch.float32)
    g.y = g.y.to(device)
    return g

In [None]:
from torch_geometric.nn.norm import LayerNorm, BatchNorm
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import aggr
import sklearn
class PeptideGNN(torch.nn.Module):
    """One-layer GATv2 graph classifier that outputs two logits per graph.

    Pipeline: BatchNorm on node features -> GATv2Conv (heads averaged, 128-dim
    edge features) -> ReLU -> per-graph pooling -> dropout -> Linear(hidden, 2).

    Args:
        hidden_channels: output width of the GATv2 layer.
        input_channels: node-feature dimension.
        gatheads: number of attention heads (averaged, since concat=False).
        gatdropout: attention dropout inside GATv2Conv.
        finaldropout: dropout applied to the pooled graph embedding.
    """

    def __init__(self, hidden_channels, input_channels=4, gatheads=10, gatdropout=0.5, finaldropout=0.5):
        super().__init__()
        self.finaldropout = finaldropout
        # Re-seed immediately before creating layers so every instance gets the
        # same initial weights (this also resets torch's global RNG). Layers are
        # created in this order so the seeded weights and state_dict keys
        # (norm / conv1 / lin) stay stable.
        torch.manual_seed(111)
        self.norm = BatchNorm(input_channels)
        self.conv1 = GATv2Conv(
            input_channels,
            hidden_channels,
            dropout=gatdropout,
            heads=gatheads,
            concat=False,
            edge_dim=128,
        )
        self.pooling = global_mean_pool
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch, hidden=False):
        """Return per-graph logits, or per-node embeddings when ``hidden`` is True."""
        node_repr = F.relu(self.conv1(self.norm(x), edge_index, edge_attr))
        if hidden:
            return node_repr
        graph_repr = self.pooling(node_repr, batch)
        graph_repr = F.dropout(graph_repr, p=self.finaldropout, training=self.training)
        return self.lin(graph_repr)

In [None]:
# Number of binding peptides per GPCR, ascending.
gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).sort_values(by='y')

In [None]:
# Attach each sample's GPCR loss weight to its graph as a CUDA tensor;
# train_weighted multiplies per-graph losses by the batched data.gpcrweight.
for entry in all_graphs:
    weight = entry['gpcrweight']
    entry['graph'].gpcrweight = torch.tensor(weight, device='cuda')

In [None]:
# original splits (gpcr)
# validation_peptides = [['NPR-43', 'CKR-1', 'NPR-39', 'AEX-2', 'DMSR-2', 'NPR-41'],
# ['NPR-11', 'SPRR-2', 'SPRR-1', 'NPR-10', 'DMSR-3', 'GNRR-6'],
# ['NPR-5', 'DMSR-8', 'NPR-2', 'FRPR-9', 'NPR-42', 'NPR-32'],
# ['FRPR-8', 'NPR-40', 'FRPR-16', 'NPR-1', 'FRPR-6', 'FRPR-4'],
# ['NPR-6', 'NMUR-2', 'FRPR-7', 'NPR-13', 'FRPR-19', 'TRHR-1'],
# ['GNRR-1', 'FRPR-18', 'NPR-37', 'PDFR-1', 'FRPR-3'],
# ['NPR-22', 'EGL-6', 'CKR-2', 'NMUR-1', 'NPR-4', 'FRPR-15'],
# ['NPR-24', 'SEB-3', 'DMSR-6', 'NPR-12', 'DMSR-7'],
# ['GNRR-3', 'NPR-35', 'TKR-2', 'NTR-1', 'DMSR-5'],
# ['NPR-8', 'DMSR-1', 'NPR-3', 'TKR-1']]

#phylogenetic splits (gpcr)
validation_peptides = [['FRPR-16', 'FRPR-18', 'FRPR-4', 'FRPR-6', 'NPR-22', 'NMUR-2'],
['AEX-2', 'DMSR-5', 'DMSR-6', 'DMSR-7', 'DMSR-8','NPR-32'],
['FRPR-7', 'FRPR-7', 'NPR-6', 'GNRR-3', 'EGL-6'],
['TKR-1', 'TKR-2', 'DMSR-1', 'DMSR-2', 'NPR-40', 'GNRR-6'],
['NPR-42', 'FRPR-9', 'FRPR-15', 'FRPR-19', 'NMUR-1'],
['NPR-8', 'NPR-24', 'NPR-37', 'NPR-43', 'FRPR-3'],
['FRPR-8', 'GNRR-1', 'SPRR-1', 'SPRR-2', 'NPR-11', 'NPR-12'],
['NPR-41', 'NPR-1', 'NPR-2', 'NPR-3', 'PDFR-1', 'NTR-1'],
['TRHR-1', 'NPR-35', 'NPR-13', 'NPR-5', 'SEB-3'],
['DMSR-3', 'NPR-10', 'NPR-4', 'NPR-39', 'CKR-1', 'CKR-2']]




In [None]:
# Hyper-parameter-tuning fold paired with each outer fold: the fold two
# positions earlier, wrapping around (i.e. the list rotated right by 2).
n_folds = len(validation_peptides)
hpt_peptides = [validation_peptides[(pos - 2) % n_folds] for pos in range(n_folds)]

In [None]:
# Inspect the graph records (graph plus metadata and loss weight).
all_graphs

In [None]:
# Nested cross-validation over GPCR families.
# Outer loop: each fold of validation_peptides is held out as the test set.
# Inner loop: the paired fold of hpt_peptides is held out for a 40-trial Optuna
# random search over (hidden_units, batch_size). The best setting is then
# retrained on all non-test families, saved, and scored on the test fold.
test_logits = []
test_labels = []
candidatesgnn = []
hitmapgnn = []
subsampling_factor = 4
average_precisions = dict()
hpt_results = []

for validation_peptide,hpt_peptide in zip(validation_peptides, hpt_peptides):
    # All positives (y == 1) are kept; negatives are subsampled below to
    # subsampling_factor x the number of positives.
    training_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['y'] == 1]
    hpt_training_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['gpcr_family'] not in hpt_peptide and g['y'] == 1]
    hpt_to_shuffle = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['gpcr_family'] not in hpt_peptide and g['y'] == 0]
    to_shuffle = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['y'] == 0]
    random.Random(111).shuffle(to_shuffle)
    random.Random(111).shuffle(hpt_to_shuffle)
    trainings = []
    hpt_trainings = []
    # 20 resampled training sets (seeds 0..19, shuffles are cumulative); each
    # epoch below picks one of them.
    for i in range(20):
        random.Random(i).shuffle(to_shuffle)
        trainings.append(training_graphs + to_shuffle[0:len(training_graphs)*subsampling_factor])
        # NOTE(review): this re-maps every accumulated list on each iteration.
        # move_to_cuda mutates graphs in place, so this is redundant work, not a bug.
        trainings = [list(map(move_to_cuda, h)) for h in trainings]
        random.Random(i).shuffle(hpt_to_shuffle)
        hpt_trainings.append(hpt_training_graphs + hpt_to_shuffle[0:len(hpt_training_graphs)*subsampling_factor])
        hpt_trainings = [list(map(move_to_cuda, h)) for h in hpt_trainings]
    validation_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] in validation_peptide]
    hpt_validation_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] in hpt_peptide]
    validation_graphs = list(map(move_to_cuda, validation_graphs))
    hpt_validation_graphs = list(map(move_to_cuda, hpt_validation_graphs))
    test_loader = DataLoader(validation_graphs, batch_size=256, shuffle=False)
    hpt_test_loader = DataLoader(hpt_validation_graphs, batch_size=256, shuffle=False)
    hpt_maps = []
    hpt_logits = []
    hpt_labels = []
    # Optuna objective: train on hpt_trainings, return the mean per-GPCR average
    # precision on the HPT fold (GPCRs without binders are skipped).
    def objective(trial):
        hidden_channels = trial.suggest_int('hidden_units', 50, 100)
        batch_size = trial.suggest_int('batch_size', 50, 200)
        lr = 0.0005
        model = PeptideGNN(hidden_channels, input_channels=128).cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.5, reduction='none')
        for epoch in range(1, 30):
            training_graphs = random.Random(epoch+111).choice(hpt_trainings)
            train_loader = DataLoader(training_graphs, batch_size=batch_size, shuffle=True)
            loss = train_weighted(model, criterion, optimizer, train_loader)
        model.eval()
        all_logits = []
        atrues = []
        with torch.no_grad():
            for data in hpt_test_loader:
                logits = model(data.x, data.edge_index, data.edge_attr, data.batch)

                if torch.isnan(data.x).any():
                    print(f"NaN in node features for {hpt_peptide}")
                if torch.isnan(data.edge_attr).any():
                    print(f"NaN in edge attributes for {hpt_peptide}")
                if torch.isinf(data.x).any() or torch.isinf(data.edge_attr).any():
                    print(f"Infinite values in data for {hpt_peptide}")

                # Check for NaNs in model output
                # NOTE(review): skipping a batch leaves all_logits shorter than
                # hpt_validation_graphs, so the zip() with val_gpcr/val_peptide
                # below would pair scores with the wrong GPCR/peptide.
                if torch.isnan(logits).any():
                    print(f"NaN detected in model output during Optuna eval for {hpt_peptide}")
                    continue  # skip this batch safely

                # Score = raw class-1 logit (no softmax), used for AP and AUC throughout.
                all_logits.append(logits.cpu().detach()[:,1])
                atrues.append(data.y.cpu())
        hpt_logits.append(np.concatenate(all_logits))
        hpt_labels.append(np.concatenate(atrues))
        hpt_average_precisions = []
        val_gpcr = [g.gpcr for g in hpt_validation_graphs]
        val_peptide = [g.peptide for g in hpt_validation_graphs]
        # Columns: 0 = label, 1 = score, 2 = gpcr, 3 = peptide.
        result = pd.DataFrame(zip(hpt_labels[-1], hpt_logits[-1], val_gpcr, val_peptide))
        for gpcr, r in result.groupby(2):
            if r[0].sum() > 0:
                if np.isnan(r[0]).any():
                    print(f"NaNs in TRUE labels for group: {gpcr}")
                if np.isnan(r[1]).any():
                    print(f"NaNs in PREDICTIONS for group: {gpcr}")
                hpt_average_precisions.append(sklearn.metrics.average_precision_score(r[0], r[1]))
            else:
                print('what')
        hpt_maps.append((np.mean(hpt_average_precisions), (hidden_channels, batch_size, lr)))
        return np.mean(hpt_average_precisions)
    sampler = optuna.samplers.RandomSampler(seed=111)
    study = optuna.create_study(direction='maximize', sampler=sampler)
    study.optimize(objective, n_trials=40)
    hpt_results.append(hpt_maps)
    # Pick the trial with the highest mean AP (ties broken by the params tuple).
    # NOTE(review): a NaN mean (no GPCR with binders) would make this ordering unreliable.
    _, params = sorted(hpt_maps)[-1]
    model = PeptideGNN(params[0], input_channels=128).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params[2])
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.5, reduction='none')
    print(validation_peptide)
    for epoch in range(1, 30):
        training_graphs = random.Random(epoch+111).choice(trainings)
        train_loader = DataLoader(training_graphs, batch_size=params[1], shuffle=True)
        loss = train_weighted(model, criterion, optimizer, train_loader)
        # NOTE(review): the "Train Acc" value printed below is an AUC from test_without_crash.
        if epoch % 14 == 0:
            test_acc = test_without_crash(model, criterion, test_loader)
            print(f'Epoch: {epoch:02d}, Train Acc: {test_without_crash(model, criterion, train_loader):.4f}, Test AUC: {test_acc:.4f}')

    torch.save(model.state_dict(), f'{save_to}pretrained_{validation_peptide[0]}.pth')

    model.eval()
    all_logits = []
    atrues = []
    with torch.no_grad():
        for data in test_loader:
            logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
            all_logits.append(logits.cpu().detach()[:,1])
            atrues.append(data.y.cpu())

    test_logits.append(np.concatenate(all_logits))
    test_labels.append(np.concatenate(atrues))

    val_gpcr = [g.gpcr for g in validation_graphs]
    val_peptide = [g.peptide for g in validation_graphs]
    # Columns: 0 = label, 1 = score, 2 = gpcr, 3 = peptide.
    result = pd.DataFrame(zip(test_labels[-1], test_logits[-1], val_gpcr, val_peptide))
    for gpcr, r in result.groupby(2):
        # Binders among this GPCR's 17 highest-scored peptides.
        hitmapgnn.append((gpcr, r.sort_values(by=1).iloc[-17:][0].sum()))
        candidatesgnn.append((gpcr,r.sort_values(by=1)))
        # AP is only defined for GPCRs with at least one binder.
        if r[0].sum() > 0:
            average_precisions[gpcr] = sklearn.metrics.average_precision_score(r[0], r[1])
            print(gpcr, average_precisions[gpcr])

# Pooled ROC AUC over all outer test folds.
print(roc_auc_score(np.concatenate(test_labels), np.concatenate(test_logits)))

In [None]:
# Mean average precision over held-out GPCRs that have at least one binder.
np.mean(list(average_precisions.values()))

In [None]:
# Mean average precision expressed in thousandths, rounded to the nearest integer.
# Rounding the scaled value avoids the off-by-one truncation that
# int(round(v, 3) * 1000) can produce through float error.
map_value = int(round(np.mean(list(average_precisions.values())) * 1000))
# Concatenate the per-GPCR ranked prediction tables (columns: 0 = label,
# 1 = score, 2 = gpcr, 3 = peptide) and write them to the configured output dir.
st = pd.concat([ranked for _, ranked in candidatesgnn])
st.to_csv(f'{save_to}average_precision_values.csv', index=False)