In [1]:
# Colab environment setup: replace the preinstalled torch with 2.4.0 so it matches the
# prebuilt PyG extension wheels from the torch-2.4.0+cu121 index below.
# NOTE(review): torch-geometric, optuna and numpy-indexed are unpinned, so a fresh run may
# resolve different versions than the ones that produced the outputs in this notebook.
!pip uninstall torch -y
!pip install torch==2.4.0
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install torch-geometric
!pip install optuna
!pip install numpy-indexed

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Collecting torch==2.4.0
  Downloading torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata 

In [2]:
import glob
import random
import warnings
from collections import defaultdict

import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy_indexed as npi
import optuna
import pandas as pd
import scipy
import scipy.stats
import seaborn
import seaborn as sns
import sklearn
import sklearn.metrics
# roc_auc_score is used by the baseline, the evaluation helpers and the CV loop.
from sklearn.metrics import roc_auc_score

import torch
import torch.nn.functional as F
import torch_geometric
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.nn import aggr
from torch_geometric.nn.norm import GraphNorm, LayerNorm, BatchNorm

In [3]:
# Mount Google Drive: every data file and model checkpoint used below lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Cross-validation folds. Despite the variable name, each inner list holds GPCR *families*:
# they are matched against `gpcr_family` ('GPCR name' without its last two characters).
# Each fold is held out once as the test set in the evaluation loops below.
validation_peptides = [['NPR-43', 'CKR-1', 'NPR-39', 'AEX-2', 'DMSR-2', 'NPR-41'],
['NPR-11', 'SPRR-2', 'SPRR-1', 'NPR-10', 'DMSR-3', 'GNRR-6'],
['NPR-5', 'DMSR-8', 'NPR-2', 'FRPR-9', 'NPR-42', 'NPR-32'],
['FRPR-8', 'NPR-40', 'FRPR-16', 'NPR-1', 'FRPR-6', 'FRPR-4'],
['NPR-6', 'NMUR-2', 'FRPR-7', 'NPR-13', 'FRPR-19', 'TRHR-1'],
['GNRR-1', 'FRPR-18', 'NPR-37', 'PDFR-1', 'FRPR-3'],
['NPR-22', 'EGL-6', 'CKR-2', 'NMUR-1', 'NPR-4', 'FRPR-15'],
['NPR-24', 'SEB-3', 'DMSR-6', 'NPR-12', 'DMSR-7'],
['GNRR-3', 'NPR-35', 'TKR-2', 'NTR-1', 'DMSR-5'],
['NPR-8', 'DMSR-1', 'NPR-3', 'TKR-1']]

In [7]:
# Structure-table parquet shards (atom rows with 'path', 'chain_id', 'residue_seq_id', 'b_factor'); sorted for a deterministic load order.
pdbs_paths = sorted(glob.glob('/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxed*.parquet'))

In [8]:
# Per-structure summaries parsed from the structure parquet shards. File names are
# '<gpcr>_<peptide>_...', optionally prefixed with 'amber_r_'. Chain A is the GPCR and
# chain B the peptide; the b_factor column holds pLDDT (per the variable names).
gpcrs = []
peptides = []
plddts = []
paths = []
plddt_peptides = []   # mean per-residue pLDDT of chain B
plddt_gpcrs = []      # mean per-residue pLDDT of chain A
plddt_atoms = []      # mean per-atom pLDDT of chain B

# Dictionary to store PDB data, keyed by the raw 'path' value (the key later cells look up)
pdb_frames = dict()
for pdbs in pdbs_paths:
    print(pdbs)
    pdbs = pd.read_parquet(pdbs)
    for key, st in pdbs.groupby('path'):
        # Record the raw key for every structure. Assigning it only for 'amber_r_' paths
        # would leave it unbound, or stale from the previous structure, for any other path.
        original_key = key
        key = key.replace('amber_r_', '')
        name_parts = key.split('/')[-1].split('_')
        gpcrs.append(name_parts[0])
        peptides.append(name_parts[1])
        peptide_chain = st[st['chain_id'] == 'B']
        gpcr_chain = st[st['chain_id'] == 'A']
        # First b_factor per residue as the residue's pLDDT (assumes all atoms of a residue share it).
        plddt_peptides.append(peptide_chain.groupby('residue_seq_id')['b_factor'].first().mean())
        plddt_gpcrs.append(gpcr_chain.groupby('residue_seq_id')['b_factor'].first().mean())
        plddt_atoms.append(peptide_chain['b_factor'].mean())
        paths.append(original_key)
        pdb_frames[original_key] = st
    del pdbs

/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_0.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_1.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_2.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_3.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_4.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_5.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_6.parquet
/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov7relaxedpdbs_7.parquet


In [9]:
# Binding labels keyed by ('GPCR name', 'Peptide') with a 'binds' column (presumably one row per pair — TODO confirm).
beetsdata = pd.read_csv('/content/drive/MyDrive/gpcrpeptidedesign/Dataset_Labels_full - Sheet1.csv')

In [10]:
# Join the binding labels with the per-structure pLDDT summaries. It is a left join, so labelled
# pairs without a predicted structure are kept with NaN structure columns.
st = pd.DataFrame({
    'path': paths,
    'gpcr': gpcrs,
    'peptide': peptides,
    'plddt_peptides': plddt_peptides,
    'plddt_gpcrs': plddt_gpcrs,
    'plddt_atoms': plddt_atoms,
})
merged = beetsdata.merge(st, how='left', left_on=['GPCR name', 'Peptide'], right_on=['gpcr', 'peptide'])
# Family = GPCR name minus its two-character isoform suffix (e.g. 'NPR-10-1' -> 'NPR-10').
merged['gpcr_family'] = merged['GPCR name'].str[:-2]
# Binary label: 1 only where 'binds' equals True, 0 otherwise (including missing values).
merged['y'] = merged['binds'].eq(True).astype(int)

In [11]:
# Baseline: within each GPCR, score every peptide by the percentile of its atom-level peptide
# pLDDT. Evaluate that score with ROC-AUC per validation fold and pooled over all folds.
namesare = []
hard = []
candidates = []
hitmap = []
for validation_peptide in validation_peptides:
    validation_graphs = merged[merged['gpcr_family'].isin(validation_peptide)]

    all_probs = []
    all_trues = []
    for key, df in validation_graphs.groupby('gpcr'):
        atom_scores = df['plddt_atoms']
        percentiles = scipy.stats.percentileofscore(atom_scores, atom_scores) / 100.0
        labels = df['y'].values
        all_probs.append(percentiles)
        all_trues.append(labels)
        result = pd.DataFrame(zip(labels, percentiles, df['peptide'].values))
        ranked = result.sort_values(by=1)
        # Number of true binders among the 17 highest-scoring peptides of this GPCR.
        hitmap.append((key, ranked.iloc[-17:][0].sum()))
        candidates.append((key, ranked.iloc[-30:]))

    aprobs = np.concatenate(all_probs)
    atrues = np.concatenate(all_trues)
    aucs = roc_auc_score(atrues, aprobs)
    namesare.append(aprobs)
    hard.append(atrues)
    print(validation_peptide, aucs)
print(roc_auc_score(np.concatenate(hard), np.concatenate(namesare)))

['NPR-43', 'CKR-1', 'NPR-39', 'AEX-2', 'DMSR-2', 'NPR-41'] 0.8159581766139143
['NPR-11', 'SPRR-2', 'SPRR-1', 'NPR-10', 'DMSR-3', 'GNRR-6'] 0.859034258815834
['NPR-5', 'DMSR-8', 'NPR-2', 'FRPR-9', 'NPR-42', 'NPR-32'] 0.806256206554121
['FRPR-8', 'NPR-40', 'FRPR-16', 'NPR-1', 'FRPR-6', 'FRPR-4'] 0.9017338712924711
['NPR-6', 'NMUR-2', 'FRPR-7', 'NPR-13', 'FRPR-19', 'TRHR-1'] 0.9126146630802123
['GNRR-1', 'FRPR-18', 'NPR-37', 'PDFR-1', 'FRPR-3'] 0.9406022171177469
['NPR-22', 'EGL-6', 'CKR-2', 'NMUR-1', 'NPR-4', 'FRPR-15'] 0.960018604514891
['NPR-24', 'SEB-3', 'DMSR-6', 'NPR-12', 'DMSR-7'] 0.6889119942493326
['GNRR-3', 'NPR-35', 'TKR-2', 'NTR-1', 'DMSR-5'] 0.6433549524947373
['NPR-8', 'DMSR-1', 'NPR-3', 'TKR-1'] 0.8554254683140693
0.8380926010019525


In [12]:
# Per-GPCR robust (modified) z-score of the peptide pLDDT, Iglewicz & Hoaglin:
#     M_i = 0.6745 * (x_i - median(x)) / MAD,    MAD = median(|x_i - median(x)|)
# 0.6745 = 1 / 1.4826 puts M_i on the scale of an ordinary z-score for normal data. The
# constant multiplies the numerator; multiplying by 1.48 inflates every score ~2.2x.
# NOTE(review): `gpcrs` / `peptides` reuse names from the structure-loading cell (the next
# cell reads them), so re-running that earlier merge cell after this one picks up these lists.
probs = []
gpcrs = []
peptides = []
for key, df in merged.groupby('gpcr'):
    median = df['plddt_peptides'].median()
    deviation = np.abs(df['plddt_peptides'] - median)
    mad = deviation.median()
    if mad == 0:
        # More than half the values tie at the median: the score is undefined, so emit NaN
        # rather than +/-inf.
        mad = np.nan
    modifiedzscores = 0.6745 * (df['plddt_peptides'] - median) / mad
    probs.append(modifiedzscores)
    gpcrs.append([key] * len(df))
    peptides.append(df['peptide'].values)

In [13]:
# Attach the modified z-scores to `merged` by (gpcr, peptide).
probs_st = pd.DataFrame({
    'gpcr': np.concatenate(gpcrs),
    'peptide': np.concatenate(peptides),
    'modifiedzscore': np.concatenate(probs),
})
# Drop any previous score column first so that re-running this cell does not create
# duplicated 'modifiedzscore_x' / 'modifiedzscore_y' columns.
merged = pd.merge(merged.drop(columns='modifiedzscore', errors='ignore'), probs_st,
                  how='left', on=['gpcr', 'peptide'])

In [14]:
# Keep only labelled pairs that matched a predicted structure. `.copy()` makes this an
# independent frame, so adding columns below does not trigger SettingWithCopyWarning on a
# slice of `merged`.
gpcr_hits = merged[merged['gpcr'].notna()].copy()
# Chain lengths taken as the highest residue number of chain A (GPCR) / chain B (peptide)
# (assumes residues are numbered 1..L without gaps — TODO confirm).
gpcrs_lens = []
peps_lens = []
for structure_path in gpcr_hits['path']:
    pdb = pdb_frames[structure_path]
    rec = pdb[pdb['chain_id'] == 'A']
    pep = pdb[pdb['chain_id'] == 'B']
    gpcrs_lens.append(rec['residue_seq_id'].max())
    peps_lens.append(pep['residue_seq_id'].max())
gpcr_hits['gpcr_len'] = gpcrs_lens
gpcr_hits['pep_len'] = peps_lens

In [15]:
# removing outside pocket
outsidepocket = pd.read_csv('/content/drive/MyDrive/gpcrpeptidedesign/mindistance_active_bias.csv')
outsidepocket['pair'] = outsidepocket['GPCR name']+'_'+outsidepocket['Peptide']
gpcr_hits = gpcr_hits[-gpcr_hits['pair'].isin(set(outsidepocket['pair'].values))]

In [16]:
# Residue-contact tables: each row has 'bgn'/'end' endpoints, a 'contact' list of contact
# types and the structure 'path'. Sorted so the load order is deterministic across runs.
contact_paths = sorted(glob.glob('/content/drive/MyDrive/gpcrpeptidedesign/october24data/nov9contacts_*.parquet'))
allcontacts = pd.concat([pd.read_parquet(cpath) for cpath in contact_paths])
# Number of contact rows per structure, joined onto the hits (NaN where a structure has none).
total_interactions = allcontacts.groupby('path').size().to_frame('total_interactions')
gpcr_hits_bonds = pd.merge(gpcr_hits, total_interactions, how='left', on='path')

In [17]:
# Contact rows grouped per structure, keyed by structure path, for fast per-complex lookup.
interactions = {path_key: contact_rows for path_key, contact_rows in allcontacts.groupby('path')}

In [18]:
# Per structure, count how often each contact type occurs. A contact row can carry several
# types in its 'contact' list. Structures without contacts get only their 'path'.
bond_type_array = []
for structure_path in gpcr_hits['path']:
    bond_type_mapping = defaultdict(int)
    if structure_path in interactions:
        for contact_types in interactions[structure_path]['contact']:
            for bond_type in contact_types:
                bond_type_mapping[bond_type] += 1

    bond_type_mapping['path'] = structure_path
    bond_type_array.append(bond_type_mapping)

In [19]:
# One row per structure with its counts for the contact types of interest. `reindex` turns a
# type that never occurs in any structure into an all-zero column instead of a KeyError.
bond_st = pd.DataFrame(bond_type_array)
important_bonds = ['hydrophobic', 'polar', 'weak_polar', 'hbond', 'weak_hbond', 'ionic', 'aromatic', 'CARBONPI', 'vdw', 'vdw_clash', 'AMIDERING', 'AMIDEAMIDE', 'CATIONPI', 'METSULPHURPI']
bond_st = bond_st.reindex(columns=important_bonds + ['path']).fillna(0)

In [20]:
# Attach the per-type contact counts. Count columns from an earlier run are dropped first so
# that re-running this cell does not create duplicated `_x` / `_y` columns.
gpcr_hits_bonds = pd.merge(gpcr_hits_bonds.drop(columns=important_bonds, errors='ignore'),
                           bond_st, how='left', on='path')

In [21]:
# From here on `gpcr_hits` also carries 'total_interactions' and the per-type contact counts.
gpcr_hits = gpcr_hits_bonds.copy()

In [22]:
# Highest chain-A residue number per 'GPCR name'; used to crop the interaction embeddings.
lens = gpcr_hits.groupby('GPCR name')['gpcr_len'].first()

In [23]:
def _contact_node_id(endpoint, gpcr_len):
    """Return the 0-based graph node id of one contact endpoint.

    A chain 'A' residue r maps to node r - 1. Any other chain (the peptide) maps to
    gpcr_len + r - 1, so the peptide nodes follow the GPCR nodes.
    """
    seq_id = endpoint['auth_seq_id']
    if endpoint['auth_asym_id'] != "A":
        seq_id = seq_id + gpcr_len
    return seq_id - 1


# Directed residue-level contact edges per complex, keyed by '<gpcr>_<peptide>'. Each entry is
# a 2 x E array of unique (source, target) node pairs in sorted order. Complexes without
# contacts are skipped.
gpcr_hits_interaction_edges = dict()
for index, g in gpcr_hits.iterrows():
    if g['path'] not in interactions:
        continue
    contacts = interactions[g['path']]

    gpcr_len = g['gpcr_len']

    # Built as a new frame so the cached contact tables in `interactions` are not modified.
    endpoints = pd.DataFrame({
        'source': contacts['bgn'].apply(_contact_node_id, args=(gpcr_len,)).to_numpy(),
        'target': contacts['end'].apply(_contact_node_id, args=(gpcr_len,)).to_numpy(),
    })
    # groupby yields each (source, target) pair once, sorted by (source, target).
    unique_pairs = endpoints.groupby(['source', 'target']).size().reset_index()
    sources = unique_pairs['source'].values
    targets = unique_pairs['target'].values

    h_edge_index = np.vstack([sources, targets])
    key = g['gpcr'] + '_' + g['peptide']
    gpcr_hits_interaction_edges[key] = h_edge_index

In [24]:
# HDF5 files with the interaction-region pair representations; the GPCR of each file is read from its first key.
hdfs = sorted(glob.glob('/content/drive/MyDrive/peptide/ReP-Pair/AF2/multistate_embeddings/worm_dataset/activebias_pair_representations/average_of_5_models/interaction_region_avg/*.h5'))

In [25]:
# For each complex, crop the interaction-region pair representation along axis 1 to the GPCR
# residues that take part in a contact. Also keep which node id each kept column belongs to.
# (Axis 1 is assumed to index GPCR residues — TODO confirm against the embedding export.)
emb_map_interaction = dict()
emb_map_interaction_gpcrindex = dict()
pairmissed = []
for hdf in hdfs:
    with h5py.File(hdf, "r") as f:
        keys = list(f.keys())
        if not keys:
            continue
        # All keys in a file are assumed to share the GPCR of the first key.
        gpcr = keys[0].split('_')[0]
        for k in keys:
            try:
                # `nan=` must be passed by keyword: np.nan_to_num's 2nd positional parameter is `copy`.
                array = np.nan_to_num(f[k][()], nan=0.0)
                peptide = k.split('_')[1]
                mapkey = gpcr + '_' + peptide
                maximum = lens[gpcr]
                edges = gpcr_hits_interaction_edges[mapkey]
                # Node ids below the GPCR length are GPCR residues; peptide nodes are dropped.
                indices_to_keep = sorted(i for i in set(edges[0]) | set(edges[1]) if i < maximum)
                emb_map_interaction[mapkey] = array[:, indices_to_keep, :]
                emb_map_interaction_gpcrindex[mapkey] = np.array(indices_to_keep)
            except Exception:
                # Best-effort: pairs without contact edges, an unknown GPCR, or a malformed key
                # are recorded and skipped. `except Exception` (not a bare `except:`) so that
                # KeyboardInterrupt / SystemExit still stop the cell.
                pairmissed.append(k)
                continue


In [40]:
# 2D t-averaged embedding files (GPCR-side '_gpcr_T' and peptide-side '_pep_T'); sorted so the
# load order, and hence the order of every list built from it, is deterministic.
hdfs2 = sorted(glob.glob('/content/drive/MyDrive/peptide/ReP-Pair/AF2/multistate_embeddings/worm_dataset/activebias_pair_representations/average_of_5_models/2D_t-average/*.h5'))

In [41]:
# Read every dataset from the 2D t-averaged HDF5 files. Each file goes to the peptide and/or
# GPCR collection according to the marker in its filename; other files are ignored.
all_peptide_arrays = []
peptide_keys = []
all_gpcr_arrays = []
gpcr_keys = []
for hdf in hdfs2:
    with h5py.File(hdf, "r") as f:
        keys = list(f.keys())
        arrays = [f[dataset_name][()] for dataset_name in keys]
    if "_pep_T" in hdf:
        all_peptide_arrays.append(arrays)
        peptide_keys.append(keys)
    if "_gpcr_T" in hdf:
        all_gpcr_arrays.append(arrays)
        gpcr_keys.append(keys)

In [42]:
# GPCR-side node embeddings keyed by '<gpcr>_<peptide>'. Every key of a file is assumed to
# share the GPCR of its first key.
emb_map_gpcr = dict()
for keys, arrays in zip(gpcr_keys, all_gpcr_arrays):
    try:
        gpcr = keys[0].split('_')[0]
        for entry_key, array in zip(keys, arrays):
            peptide = entry_key.split('_')[1]
            emb_map_gpcr[gpcr + '_' + peptide] = array
    except IndexError:
        # Empty file, or a key without a '<gpcr>_<peptide>' structure: skip the rest of the file.
        print('GPCR embedding error')
        continue

In [43]:
# Peptide-side node embeddings keyed by '<gpcr>_<peptide>'. Every key of a file is assumed to
# share the GPCR of its first key.
emb_map_peptide = dict()
for keys, arrays in zip(peptide_keys, all_peptide_arrays):
    try:
        gpcr = keys[0].split('_')[0]
        for entry_key, array in zip(keys, arrays):
            peptide = entry_key.split('_')[1]
            emb_map_peptide[gpcr + '_' + peptide] = array
    except IndexError:
        # Empty file, or a key without a '<gpcr>_<peptide>' structure: skip the rest of the file.
        print('peptide embedding error')
        continue

In [44]:
# Mean-pooled (axis 0) GPCR and peptide embeddings per complex, one row per GPCR/peptide pair.
# The two sides come from separate file lists, so they are aligned on their shared
# '<gpcr>_<peptide>' key prefix rather than by position. Zipping them positionally silently
# pairs the wrong complexes if the two lists are ordered differently.
gpcr_emb = pd.DataFrame({
    'gpcr_keys': [kk for k in gpcr_keys for kk in k],
    'gpcr_embedding': [aa.mean(axis=0) for a in all_gpcr_arrays for aa in a],
})
pep_emb = pd.DataFrame({
    'peptide_keys': [kk for k in peptide_keys for kk in k],
    'peptide_embedding': [aa.mean(axis=0) for a in all_peptide_arrays for aa in a],
})
gpcr_emb['_pair'] = gpcr_emb['gpcr_keys'].map(lambda k: '_'.join(k.split('_')[:2]))
pep_emb['_pair'] = pep_emb['peptide_keys'].map(lambda k: '_'.join(k.split('_')[:2]))
embst = gpcr_emb.merge(pep_emb, on='_pair', how='inner').drop(columns='_pair')

In [45]:
# Parse the join keys back out of the HDF5 dataset names ('<gpcr>_<peptide>_...').
embst['peptide'] = [name.split('_')[1] for name in embst['peptide_keys']]
embst['gpcr'] = [name.split('_')[0] for name in embst['gpcr_keys']]

In [46]:
# Hits that have both pooled embeddings available (inner join on gpcr/peptide).
for_classification = gpcr_hits.merge(embst, how='inner', on=['gpcr', 'peptide'])

In [47]:
# Per-GPCR sample weight = 1 / (number of binding peptides of that GPCR). It is later applied to
# every graph of that GPCR in the weighted loss. The count is floored at 1: a GPCR without
# binders would otherwise get weight 1/0 = inf, turning its losses into inf/NaN.
gpcrweight = 1 / gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).clip(lower=1).sort_values(by='y')

In [48]:
# One PyG graph per GPCR/peptide complex. Nodes are the GPCR residues (0..gpcr_len-1) followed
# by the peptide residues. Edges are the inter-residue contacts plus a peptide backbone chain.
subgraphing = True         # keep only the peptide nodes and their `subgraph_hops`-hop neighbourhood
subgraph_hops = 1

with_edge_weights = True   # attach the cropped pair representation as edge features

missed = []
all_graphs = []
for index, g in gpcr_hits.iterrows():
    # Computed first so that skipped rows are recorded under their own key.
    mapkey = g['gpcr'] + '_' + g['peptide']
    if g['path'] not in interactions:
        missed.append((mapkey, g['y']))
        continue

    gpcr_len = g['gpcr_len']

    h_edge_index = gpcr_hits_interaction_edges[mapkey]

    xg = emb_map_gpcr[mapkey]
    xp = emb_map_peptide[mapkey]
    x = torch.from_numpy(np.concatenate([xg, xp])).type(torch.float32)

    # Peptide backbone: peptide node i -> node i + 1.
    pep_edge_index = np.vstack([np.arange(gpcr_len, len(x) - 1), np.arange(gpcr_len + 1, len(x))])

    edge_index = torch.cat([torch.from_numpy(h_edge_index), torch.from_numpy(pep_edge_index)], dim=1)

    if with_edge_weights:
        if mapkey not in emb_map_interaction:
            missed.append((mapkey, g['y']))
            continue
        edgefeatures = emb_map_interaction[mapkey]
        edgeindices = emb_map_interaction_gpcrindex[mapkey]

        # Re-index GPCR nodes to their column in the cropped pair representation. Peptide node
        # ids (>= gpcr_len) are not in `edgeindices` and pass through unchanged.
        sources = npi.remap(h_edge_index[0], edgeindices, np.arange(len(edgeindices)))
        targets = npi.remap(h_edge_index[1], edgeindices, np.arange(len(edgeindices)))
        # Orient every contact as (GPCR column, peptide row) whatever the edge direction.
        sourcewherever = np.where(sources >= gpcr_len)[0]
        targetwherever = np.where(targets < gpcr_len)[0]
        newsources = np.array(sources)
        newtargets = np.array(targets)
        newsources[sourcewherever] = targets[sourcewherever]
        newtargets[targetwherever] = sources[targetwherever]
        newtargets -= gpcr_len
        contact_attrs = edgefeatures[newtargets, newsources, :]

        # Backbone edges get the mean contact feature. Its width is taken from the contact
        # features rather than hard-coded, and the whole tensor is float32 like the node features.
        pep_edge_attrs = np.tile(contact_attrs.mean(axis=0), (pep_edge_index.shape[1], 1))
        edge_attrs = torch.cat([torch.from_numpy(contact_attrs), torch.from_numpy(pep_edge_attrs)], dim=0).type(torch.float32)

    if subgraphing:
        peptide_nodes = torch.arange(int(gpcr_len), len(x))
        # NOTE(review): k_hop_subgraph's default flow='source_to_target' only adds nodes with an
        # edge *into* the peptide nodes — confirm the contact direction makes this intended.
        nodes, edges, _, _ = torch_geometric.utils.k_hop_subgraph(peptide_nodes, subgraph_hops, edge_index, relabel_nodes=True, num_nodes=len(x))
        if with_edge_weights:
            edges, new_edge_attrs = torch_geometric.utils.subgraph(nodes, edge_index, edge_attrs, relabel_nodes=True, num_nodes=len(x))
            graph = Data(x=x[nodes], edge_index=edges, edge_attr=new_edge_attrs, y=torch.tensor(g['y']))
        else:
            graph = Data(x=x[nodes], edge_index=edges, y=torch.tensor(g['y']))
    elif with_edge_weights:
        # Keep the edge features on the full graph too; the model and move_to_cuda read edge_attr.
        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attrs, y=torch.tensor(g['y']))
    else:
        graph = Data(x=x, edge_index=edge_index, y=torch.tensor(g['y']))

    graph.peptide = g['peptide']
    graph.gpcr = g['gpcr']
    graph.gpcr_family = g['gpcr_family']
    gpcrw = gpcrweight.loc[g['gpcr']].iloc[0]
    all_graphs.append({'graph': graph, 'peptide': g['peptide'], 'gpcr': g['gpcr'], 'gpcr_family': g['gpcr_family'], 'y': g['y'], 'gpcrweight': gpcrw})






In [49]:
def train(model, criterion, optimizer, train_loader):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = criterion(logits, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_loader.dataset)

def train_weighted(model, criterion, optimizer, train_loader):
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = criterion(logits, data.y)
        loss = (loss*data.gpcrweight).mean()
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test_roc(model, criterion, loader):
    """ROC-AUC of the softmax class-1 probability over every graph in `loader`.

    Scores are pooled across batches before computing the AUC. A size-weighted mean of
    per-batch AUCs is not the AUC of the whole set, and roc_auc_score raises on any batch that
    contains a single class. `criterion` is unused and kept for signature parity with the
    other helpers.
    """
    model.eval()
    probs = []
    labels = []
    for data in loader:
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        probs.append(torch.softmax(out.detach(), dim=1).cpu()[:, 1])
        labels.append(data.y.detach().cpu())
    return roc_auc_score(np.concatenate(labels), np.concatenate(probs))

@torch.no_grad()
def test_without_crash(model, criterion, loader):
    """ROC-AUC over the whole `loader`, scoring each graph by its raw class-1 logit.

    Scores from all batches are pooled before computing the AUC, so a batch holding a single
    class does not raise (the pooled set still needs both classes). `criterion` is unused.
    """
    model.eval()
    scores = []
    labels = []
    for batch in loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        scores.append(out.cpu().detach()[:, 1])
        labels.append(batch.y.cpu())
    # NOTE(review): ranks by logits[:, 1] alone. For a 2-logit head this ordering can differ
    # from softmax(logits)[:, 1] (which test_roc uses); confirm which score is intended.
    y_true = np.concatenate(labels)
    y_score = np.concatenate(scores)
    return roc_auc_score(y_true, y_score)

@torch.no_grad()
def nope_test_without_crash(model, criterion, loader):
    """ROC-AUC for a model emitting one logit per graph, scored by sigmoid(logit).

    Scores from all batches are pooled before computing the AUC. The model output is expected
    to have one value per graph (shape (n,) or (n, 1)). `criterion` is unused.
    """
    model.eval()
    all_logits = []
    atrues = []
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # reshape(-1) instead of squeeze(): squeeze turns a 1-graph batch into a 0-d tensor,
        # which np.concatenate cannot join.
        all_logits.append(torch.sigmoid(logits.reshape(-1)).cpu().detach())
        atrues.append(data.y.cpu())
    return roc_auc_score(np.concatenate(atrues), np.concatenate(all_logits))

@torch.no_grad()
def test(model, criterion, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(test_loader.dataset)

In [50]:
def move_to_cuda(g, device='cuda'):
    """Move a graph's tensors to `device` in place and return the same object.

    Args:
        g: graph with `x`, `edge_index`, `y` tensor attributes and an optional `edge_attr`
            (a torch_geometric `Data` in this notebook).
        device: target device. Defaults to 'cuda', the previously hard-coded target.

    Returns:
        `g` itself (so the function works with `map`). `edge_attr`, when present, is cast to
        float32.
    """
    g.x = g.x.to(device)
    g.edge_index = g.edge_index.to(device)
    # Graphs built without edge weights carry no edge_attr; skip it instead of failing on None.
    if getattr(g, 'edge_attr', None) is not None:
        g.edge_attr = g.edge_attr.to(device=device, dtype=torch.float32)
    g.y = g.y.to(device)
    return g

In [51]:
class DeorphaNN(torch.nn.Module):
    """Single-layer GATv2 graph classifier producing 2 logits per graph.

    Pipeline: BatchNorm on node features -> GATv2Conv (heads averaged, edge features used)
    -> ReLU -> global mean pool per graph -> dropout -> Linear(hidden_channels, 2).

    Args:
        hidden_channels: output width of the GATv2 layer.
        input_channels: node feature width.
        gatheads: number of attention heads (averaged, since concat=False).
        gatdropout: dropout passed to GATv2Conv.
        finaldropout: dropout applied to the pooled graph embedding during training.
        edge_dim: edge feature width given to GATv2Conv. It defaults to 128, the previously
            hard-coded value.
    """

    def __init__(self, hidden_channels, input_channels=128, gatheads=10, gatdropout=0.5, finaldropout=0.5, edge_dim=128):
        super(DeorphaNN, self).__init__()
        self.finaldropout = finaldropout
        # Re-seeds the *global* torch RNG so every instance starts from identical weights.
        # Side effect: it also resets the RNG stream used by any later random torch op.
        torch.manual_seed(111)
        self.norm = BatchNorm(input_channels)
        self.conv1 = GATv2Conv(input_channels, hidden_channels, dropout=gatdropout, heads=gatheads, concat=False, edge_dim=edge_dim)
        self.pooling = global_mean_pool
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch, hidden=False):
        """Return per-graph logits (num_graphs, 2), or post-ReLU node embeddings if `hidden`."""
        x = self.norm(x)
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()

        if hidden:
            return x
        x = self.pooling(x, batch)
        x = F.dropout(x, p=self.finaldropout, training=self.training)
        x = self.lin(x)
        return x

In [52]:
# Number of binding peptides per GPCR, ascending (gpcrweight is built from the same counts).
gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).sort_values(by='y')

Unnamed: 0_level_0,y
gpcr,Unnamed: 1_level_1
AEX-2-1,1
FRPR-15-1,1
GNRR-1-1,1
NPR-2-2,1
NPR-13-1,1
...,...
NPR-10-1,19
FRPR-8-1,41
DMSR-1-1,42
DMSR-1-2,45


In [53]:
# Attach each graph's GPCR weight (read as data.gpcrweight by train_weighted) as a CUDA tensor.
for entry in all_graphs:
    weight = torch.tensor(entry['gpcrweight'])
    entry['graph'].gpcrweight = weight.cuda()

In [54]:
# Leave-one-family-out folds: every GPCR family of every fold becomes its own single-item fold.
logo_validation_peptides = [[family] for fold in validation_peptides for family in fold]

In [55]:
# Tuning folds: the fold list rotated by two, so outer test fold i is tuned on fold i-2 (mod number of folds), never on itself.
hpt_peptides = validation_peptides[-2:]+validation_peptides[0:-2]

In [56]:
# Nested cross-validation. For each outer fold:
#   1. tune (hidden_units, batch_size) with Optuna, training without the test and tuning
#      families and scoring mean per-GPCR average precision on the tuning fold;
#   2. retrain on all non-test families with the best setting and save the weights;
#   3. score the held-out fold (per-GPCR average precision; pooled ROC-AUC at the end).
test_logits = []
test_labels = []

candidatesgnn = []
hitmapgnn = []

subsampling_factor = 4  # negatives kept per positive graph in each training subsample

average_precisions = dict()

hpt_results = []

for validation_peptide,hpt_peptide in zip(validation_peptides, hpt_peptides):

    # Positives: all non-test families (outer) / non-test and non-tuning families (tuning).
    training_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['y'] == 1]
    hpt_training_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['gpcr_family'] not in hpt_peptide and g['y'] == 1]

    # Negatives from the same families, subsampled below.
    hpt_to_shuffle = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['gpcr_family'] not in hpt_peptide and g['y'] == 0]
    to_shuffle = [g['graph'] for g in all_graphs if g['gpcr_family'] not in validation_peptide and g['y'] == 0]

    random.Random(111).shuffle(to_shuffle)
    random.Random(111).shuffle(hpt_to_shuffle)

    # 20 training subsamples: all positives plus up to subsampling_factor x as many negatives.
    trainings = []
    hpt_trainings = []
    for i in range(20):
        random.Random(i).shuffle(to_shuffle)
        trainings.append(training_graphs + to_shuffle[0:len(training_graphs)*subsampling_factor])
        random.Random(i).shuffle(hpt_to_shuffle)
        hpt_trainings.append(hpt_training_graphs + hpt_to_shuffle[0:len(hpt_training_graphs)*subsampling_factor])
    # Move to the GPU once, after all subsamples are drawn. move_to_cuda mutates the shared graph
    # objects in place, so re-mapping every earlier subsample on each of the 20 iterations
    # (quadratic work) yields the same lists.
    trainings = [list(map(move_to_cuda, h)) for h in trainings]
    hpt_trainings = [list(map(move_to_cuda, h)) for h in hpt_trainings]

    validation_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] in validation_peptide]
    hpt_validation_graphs = [g['graph'] for g in all_graphs if g['gpcr_family'] in hpt_peptide]

    validation_graphs = list(map(move_to_cuda, validation_graphs))
    hpt_validation_graphs = list(map(move_to_cuda, hpt_validation_graphs))

    test_loader = DataLoader(validation_graphs, batch_size=256, shuffle=False)
    hpt_test_loader = DataLoader(hpt_validation_graphs, batch_size=256, shuffle=False)

    hpt_maps = []
    hpt_logits = []
    hpt_labels = []

    def objective(trial):
        """Train on the tuning split and return mean per-GPCR average precision on the tuning fold."""
        hidden_channels = trial.suggest_int('hidden_units', 50, 100)
        batch_size = trial.suggest_int('batch_size', 50, 200)

        lr = 0.0005

        model = DeorphaNN(hidden_channels, input_channels=128).cuda()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.5, reduction='none')
        for epoch in range(1, 30):
            # Each epoch trains on one of the 20 subsamples, picked with an epoch-seeded RNG.
            training_graphs = random.Random(epoch+111).choice(hpt_trainings)
            train_loader = DataLoader(training_graphs, batch_size=batch_size, shuffle=True)
            loss = train_weighted(model, criterion, optimizer, train_loader)

        model.eval()
        all_logits = []
        atrues = []
        with torch.no_grad():
            for data in hpt_test_loader:
                logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
                all_logits.append(logits.cpu().detach()[:,1])
                atrues.append(data.y.cpu())

        hpt_logits.append(np.concatenate(all_logits))
        hpt_labels.append(np.concatenate(atrues))

        hpt_average_precisions = []

        val_gpcr = [g.gpcr for g in hpt_validation_graphs]
        val_peptide = [g.peptide for g in hpt_validation_graphs]
        result = pd.DataFrame(zip(hpt_labels[-1], hpt_logits[-1], val_gpcr, val_peptide))
        for gpcr, r in result.groupby(2):
            if r[0].sum() > 0:
                hpt_average_precisions.append(sklearn.metrics.average_precision_score(r[0], r[1]))
            else:
                # A tuning GPCR without binders has undefined average precision; it is skipped.
                print('what')

        hpt_maps.append((np.mean(hpt_average_precisions), (hidden_channels, batch_size, lr)))

        return np.mean(hpt_average_precisions)

    sampler = optuna.samplers.RandomSampler(seed=111)
    study = optuna.create_study(direction='maximize', sampler=sampler)
    study.optimize(objective, n_trials=40)

    hpt_results.append(hpt_maps)
    # Highest tuning score wins; params = (hidden_channels, batch_size, lr).
    _, params = sorted(hpt_maps)[-1]

    model = DeorphaNN(params[0], input_channels=128).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params[2])
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.5, reduction='none')
    print(validation_peptide)
    for epoch in range(1, 30):
        training_graphs = random.Random(epoch+111).choice(trainings)
        train_loader = DataLoader(training_graphs, batch_size=params[1], shuffle=True)
        loss = train_weighted(model, criterion, optimizer, train_loader)
        if epoch % 14 == 0:
            # test_without_crash returns ROC-AUC, so both logged numbers are AUCs.
            test_auc = test_without_crash(model, criterion, test_loader)
            print(f'Epoch: {epoch:02d}, Train AUC: {test_without_crash(model, criterion, train_loader):.4f}, Test AUC: {test_auc:.4f}')

    torch.save(model.state_dict(), f'/content/drive/MyDrive/gpcrpeptidedesign/pretrainedmodels/pretrained_testing_{validation_peptide[0]}.pth')

    model.eval()
    all_logits = []
    atrues = []
    with torch.no_grad():
        for data in test_loader:
            logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
            all_logits.append(logits.cpu().detach()[:,1])
            atrues.append(data.y.cpu())

    test_logits.append(np.concatenate(all_logits))
    test_labels.append(np.concatenate(atrues))

    val_gpcr = [g.gpcr for g in validation_graphs]
    val_peptide = [g.peptide for g in validation_graphs]
    result = pd.DataFrame(zip(test_labels[-1], test_logits[-1], val_gpcr, val_peptide))
    for gpcr, r in result.groupby(2):
        # Number of true binders among the 17 highest-scoring peptides of this GPCR.
        hitmapgnn.append((gpcr, r.sort_values(by=1).iloc[-17:][0].sum()))
        candidatesgnn.append((gpcr,r.sort_values(by=1)))
        if r[0].sum() > 0:
            average_precisions[gpcr] = sklearn.metrics.average_precision_score(r[0], r[1])
            print(gpcr, average_precisions[gpcr])

print(roc_auc_score(np.concatenate(test_labels), np.concatenate(test_logits)))

[I 2025-03-23 13:31:28,672] A new study created in memory with name: no-name-3d5678c7-fc77-4e9b-bfb8-9ff31b95fb20
[I 2025-03-23 13:31:41,723] Trial 0 finished with value: 0.14414462890638094 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.14414462890638094.
[I 2025-03-23 13:31:50,849] Trial 1 finished with value: 0.21034323642617117 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 1 with value: 0.21034323642617117.
[I 2025-03-23 13:32:01,529] Trial 2 finished with value: 0.18008333898567933 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 1 with value: 0.21034323642617117.
[I 2025-03-23 13:32:10,404] Trial 3 finished with value: 0.11682723294457338 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 1 with value: 0.21034323642617117.
[I 2025-03-23 13:32:20,126] Trial 4 finished with value: 0.11238989394170842 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 1 with v

['NPR-43', 'CKR-1', 'NPR-39', 'AEX-2', 'DMSR-2', 'NPR-41']
Epoch: 14, Train Acc: 0.8129, Test AUC: 0.6947
Epoch: 28, Train Acc: 0.8753, Test AUC: 0.6817
AEX-2-1 0.007042253521126761
CKR-1-1 1.0
DMSR-2-1 0.7554761904761904
NPR-39-1 0.02261178861788618
NPR-41-1 0.0171780639123962
NPR-43-1 0.10416666666666666


[I 2025-03-23 13:38:45,255] A new study created in memory with name: no-name-6032f0cd-a07a-453d-bcfe-1cbbee5c81c3
[I 2025-03-23 13:38:54,056] Trial 0 finished with value: 0.553449786419408 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.553449786419408.
[I 2025-03-23 13:39:01,237] Trial 1 finished with value: 0.5258116282332193 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 0 with value: 0.553449786419408.
[I 2025-03-23 13:39:09,570] Trial 2 finished with value: 0.4520480061740724 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 0 with value: 0.553449786419408.
[I 2025-03-23 13:39:16,565] Trial 3 finished with value: 0.4891701993102231 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 0 with value: 0.553449786419408.
[I 2025-03-23 13:39:23,967] Trial 4 finished with value: 0.47369659022140964 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 0 with value: 0.55344

['NPR-11', 'SPRR-2', 'SPRR-1', 'NPR-10', 'DMSR-3', 'GNRR-6']
Epoch: 14, Train Acc: 0.8973, Test AUC: 0.6612
Epoch: 28, Train Acc: 0.9379, Test AUC: 0.5910
DMSR-3-1 0.05555555555555555
GNRR-6-1 0.015968655005597538
NPR-10-1 0.05176164837434381
NPR-11-1 0.5708292311311167
SPRR-1-1 0.6428571428571428
SPRR-2-1 0.6106516290726817


[I 2025-03-23 13:44:28,411] A new study created in memory with name: no-name-cb79816d-f1fc-4797-9ecc-5786ef33d859
[I 2025-03-23 13:44:39,603] Trial 0 finished with value: 0.33046190583231466 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.33046190583231466.
[I 2025-03-23 13:44:49,023] Trial 1 finished with value: 0.4017340023390387 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 1 with value: 0.4017340023390387.
[I 2025-03-23 13:44:59,864] Trial 2 finished with value: 0.3171113256521589 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 1 with value: 0.4017340023390387.
[I 2025-03-23 13:45:08,515] Trial 3 finished with value: 0.3216135789249492 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 1 with value: 0.4017340023390387.
[I 2025-03-23 13:45:18,178] Trial 4 finished with value: 0.2953758617290736 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 1 with value: 0

['NPR-5', 'DMSR-8', 'NPR-2', 'FRPR-9', 'NPR-42', 'NPR-32']
Epoch: 14, Train Acc: 0.8105, Test AUC: 0.8803
Epoch: 28, Train Acc: 0.8634, Test AUC: 0.9097
DMSR-8-2 1.0
FRPR-9-1 0.13287401574803148
FRPR-9-2 0.3055555555555556
NPR-2-1 1.0
NPR-2-2 0.2
NPR-32-2 0.004166666666666667
NPR-42-1 0.06666666666666667
NPR-5-1 0.2778413033565296
NPR-5-3 0.1645407142590252


[I 2025-03-23 13:51:38,190] A new study created in memory with name: no-name-d9545fa1-bc60-47ef-b922-c3ade1368446
[I 2025-03-23 13:51:48,083] Trial 0 finished with value: 0.280774722882617 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.280774722882617.
[I 2025-03-23 13:51:56,077] Trial 1 finished with value: 0.38689733070345506 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 1 with value: 0.38689733070345506.
[I 2025-03-23 13:52:05,459] Trial 2 finished with value: 0.2869723469617687 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 1 with value: 0.38689733070345506.
[I 2025-03-23 13:52:13,117] Trial 3 finished with value: 0.38445427756867523 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 1 with value: 0.38689733070345506.
[I 2025-03-23 13:52:21,365] Trial 4 finished with value: 0.2707526025390216 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 1 with value: 

['FRPR-8', 'NPR-40', 'FRPR-16', 'NPR-1', 'FRPR-6', 'FRPR-4']
Epoch: 14, Train Acc: 0.8217, Test AUC: 0.8724
Epoch: 28, Train Acc: 0.8699, Test AUC: 0.8745
FRPR-16-1 0.42715049656226123
FRPR-4-1 0.31988513006750086
FRPR-6-1 0.23333333333333334
FRPR-8-1 0.5199691742176636
NPR-1-1 0.3333333333333333
NPR-40-1 0.021953405017921146


[I 2025-03-23 13:57:57,176] A new study created in memory with name: no-name-f9188a42-2880-4ef0-b72f-c4c5aee5c8f1
[I 2025-03-23 13:58:08,410] Trial 0 finished with value: 0.33831229485345743 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.33831229485345743.
[I 2025-03-23 13:58:17,658] Trial 1 finished with value: 0.24813234622817623 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 0 with value: 0.33831229485345743.
[I 2025-03-23 13:58:28,257] Trial 2 finished with value: 0.35365238690170203 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 2 with value: 0.35365238690170203.
[I 2025-03-23 13:58:37,082] Trial 3 finished with value: 0.30708617713275216 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 2 with value: 0.35365238690170203.
[I 2025-03-23 13:58:46,504] Trial 4 finished with value: 0.32866406029882966 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 2 with v

['NPR-6', 'NMUR-2', 'FRPR-7', 'NPR-13', 'FRPR-19', 'TRHR-1']
Epoch: 14, Train Acc: 0.9121, Test AUC: 0.8912
Epoch: 28, Train Acc: 0.9480, Test AUC: 0.8535
FRPR-19-1 0.26785714285714285
FRPR-19-2 0.6666666666666666
FRPR-7-1 0.6816287878787879
FRPR-7-2 0.7102678571428572
NMUR-2-1 0.019776409559807858
NPR-13-1 0.021739130434782608
NPR-6-1 0.25
TRHR-1-1 0.04190981432360743


[I 2025-03-23 14:05:10,625] A new study created in memory with name: no-name-d260ce30-f38c-4c8f-8f25-83294e72e30e
[I 2025-03-23 14:05:21,184] Trial 0 finished with value: 0.3358917289928412 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.3358917289928412.
[I 2025-03-23 14:05:29,729] Trial 1 finished with value: 0.2839771179243941 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 0 with value: 0.3358917289928412.
[I 2025-03-23 14:05:39,805] Trial 2 finished with value: 0.3254740757985191 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 0 with value: 0.3358917289928412.
[I 2025-03-23 14:05:47,952] Trial 3 finished with value: 0.2216404529660663 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 0 with value: 0.3358917289928412.
[I 2025-03-23 14:05:56,896] Trial 4 finished with value: 0.3342934857900996 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 0 with value: 0.3

['GNRR-1', 'FRPR-18', 'NPR-37', 'PDFR-1', 'FRPR-3']
Epoch: 14, Train Acc: 0.8782, Test AUC: 0.9592
Epoch: 28, Train Acc: 0.9183, Test AUC: 0.9828
FRPR-18-1 0.75
FRPR-18-2 1.0
FRPR-3-1 0.2265151515151515
GNRR-1-1 0.017857142857142856
NPR-37-1 0.7916666666666666
NPR-37-2 1.0
PDFR-1-1 0.9166666666666665


[I 2025-03-23 14:11:55,480] A new study created in memory with name: no-name-3880441f-c7c1-4c43-86e2-1dc9a5cdb390
[I 2025-03-23 14:12:05,954] Trial 0 finished with value: 0.4158387115992711 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.4158387115992711.
[I 2025-03-23 14:12:14,539] Trial 1 finished with value: 0.4271912542777274 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 1 with value: 0.4271912542777274.
[I 2025-03-23 14:12:24,331] Trial 2 finished with value: 0.3530384478370541 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 1 with value: 0.4271912542777274.
[I 2025-03-23 14:12:32,598] Trial 3 finished with value: 0.3943169736400593 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 1 with value: 0.4271912542777274.
[I 2025-03-23 14:12:41,496] Trial 4 finished with value: 0.4509253931991662 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 4 with value: 0.4

['NPR-22', 'EGL-6', 'CKR-2', 'NMUR-1', 'NPR-4', 'FRPR-15']
Epoch: 14, Train Acc: 0.8693, Test AUC: 0.9159
Epoch: 28, Train Acc: 0.9163, Test AUC: 0.9532
CKR-2-1 1.0
EGL-6-1 0.3219316508837573
EGL-6-2 0.46201141822044556
FRPR-15-1 1.0
NMUR-1-1 0.022607022607022607
NPR-22-1 0.4991060025542784
NPR-4-1 0.36746732281340755


[I 2025-03-23 14:18:37,871] A new study created in memory with name: no-name-b8c10a30-88b7-44cc-a626-f5979162b2e2
[I 2025-03-23 14:18:48,313] Trial 0 finished with value: 0.5282317994950477 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.5282317994950477.
[I 2025-03-23 14:18:56,788] Trial 1 finished with value: 0.5250752598728726 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 0 with value: 0.5282317994950477.
[I 2025-03-23 14:19:06,522] Trial 2 finished with value: 0.5345547156640504 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 2 with value: 0.5345547156640504.
[I 2025-03-23 14:19:14,640] Trial 3 finished with value: 0.4213129028294404 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 2 with value: 0.5345547156640504.
[I 2025-03-23 14:19:23,424] Trial 4 finished with value: 0.4888112138112138 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 2 with value: 0.5

['NPR-24', 'SEB-3', 'DMSR-6', 'NPR-12', 'DMSR-7']
Epoch: 14, Train Acc: 0.8955, Test AUC: 0.7876
Epoch: 28, Train Acc: 0.9498, Test AUC: 0.8848
DMSR-6-1 0.14890922397361903
DMSR-7-1 0.4814313071963441
NPR-12-1 0.5
NPR-24-1 0.5
SEB-3-1 1.0


[I 2025-03-23 14:25:15,742] A new study created in memory with name: no-name-ef841eb9-f4d7-425d-aa96-cd56633cd027
[I 2025-03-23 14:25:26,460] Trial 0 finished with value: 0.43217625330655013 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.43217625330655013.
[I 2025-03-23 14:25:35,048] Trial 1 finished with value: 0.5926439315931099 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 1 with value: 0.5926439315931099.
[I 2025-03-23 14:25:45,149] Trial 2 finished with value: 0.5774893012737693 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 1 with value: 0.5926439315931099.
[I 2025-03-23 14:25:53,441] Trial 3 finished with value: 0.5355139298590816 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 1 with value: 0.5926439315931099.
[I 2025-03-23 14:26:02,350] Trial 4 finished with value: 0.5813578175948483 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 1 with value: 0

['GNRR-3', 'NPR-35', 'TKR-2', 'NTR-1', 'DMSR-5']
Epoch: 14, Train Acc: 0.8549, Test AUC: 0.6248
Epoch: 28, Train Acc: 0.8885, Test AUC: 0.6672
DMSR-5-1 0.07197210301579396
GNRR-3-1 0.008024257334236102
NPR-35-1 0.8409090909090909
NTR-1-1 0.24285714285714285
TKR-2-1 0.009128508540812881


[I 2025-03-23 14:31:57,462] A new study created in memory with name: no-name-94a87fcc-a52f-40f3-bf37-2113b13727ec
[I 2025-03-23 14:32:05,343] Trial 0 finished with value: 0.5786758262777576 and parameters: {'hidden_units': 81, 'batch_size': 75}. Best is trial 0 with value: 0.5786758262777576.
[I 2025-03-23 14:32:11,854] Trial 1 finished with value: 0.4404905487549752 and parameters: {'hidden_units': 72, 'batch_size': 166}. Best is trial 0 with value: 0.5786758262777576.
[I 2025-03-23 14:32:19,325] Trial 2 finished with value: 0.5034339497840914 and parameters: {'hidden_units': 65, 'batch_size': 72}. Best is trial 0 with value: 0.5786758262777576.
[I 2025-03-23 14:32:25,551] Trial 3 finished with value: 0.43711633407969935 and parameters: {'hidden_units': 51, 'batch_size': 113}. Best is trial 0 with value: 0.5786758262777576.
[I 2025-03-23 14:32:32,491] Trial 4 finished with value: 0.5133370743568217 and parameters: {'hidden_units': 62, 'batch_size': 100}. Best is trial 0 with value: 0.

['NPR-8', 'DMSR-1', 'NPR-3', 'TKR-1']
Epoch: 14, Train Acc: 0.9015, Test AUC: 0.8646
Epoch: 28, Train Acc: 0.9338, Test AUC: 0.8788
DMSR-1-1 0.4141280351475994
DMSR-1-2 0.4095538530499172
NPR-3-1 1.0
NPR-8-1 0.8500000000000001
NPR-8-2 0.75
TKR-1-1 0.011228095234375552
0.7315620394854854
