In [None]:
# Environment setup: replace the preinstalled torch with 2.4.0 so that it matches
# the PyG extension wheels fetched from the torch-2.4.0+cu121 index below.
!pip uninstall torch -y
!pip install torch==2.4.0
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
# NOTE(review): the remaining packages are installed without a version pin.
!pip install torch-geometric
!pip install optuna
!pip install numpy-indexed

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Collecting torch==2.4.0
  Downloading torch-2.4.0-cp311-cp311-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.4.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.4.0)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.4.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata 

In [None]:
import glob
import random
import warnings
from collections import defaultdict

import h5py
import matplotlib.pyplot as plt
import numpy as np
import numpy_indexed as npi
import optuna
import pandas as pd
import scipy
import seaborn
import seaborn as sns
import sklearn
import sklearn.metrics
import torch
import torch.nn.functional as F
import torch_geometric
from sklearn.metrics import roc_auc_score
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.nn import aggr
from torch_geometric.nn.norm import GraphNorm, LayerNorm, BatchNorm

In [None]:
# Mount Google Drive at /content/drive; every data path below is read from there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Parquet shards matching jan29relaxed*.parquet, in sorted path order.
pdbs_paths = sorted(glob.glob('/content/drive/MyDrive/gpcrpeptidedesign/orphanpreprocessing/jan29relaxed*.parquet'))

In [None]:
# Per-complex summaries, filled in parallel (one entry per 'path' group).
gpcrs = []
peptides = []
plddts = []
paths = []
plddt_peptides = []
plddt_gpcrs = []
plddt_atoms = []

# Per-complex atom tables, keyed by the path exactly as stored in the parquet shard.
pdb_frames = dict()
for shard_path in pdbs_paths:
    print(shard_path)
    shard = pd.read_parquet(shard_path)
    for original_key, frame in shard.groupby('path'):
        # The stored path is always the dictionary key. Previously it was only
        # assigned when the path contained 'amber_r_', so any other row silently
        # reused the key of the preceding complex (or raised NameError on the first row).
        key = original_key.replace('amber_r_', '')

        # File name layout: <gpcr>_<peptide>_... (first two '_' fields are used).
        name_fields = key.split('/')[-1].split('_')
        gpcrs.append(name_fields[0])
        peptides.append(name_fields[1])

        receptor_atoms = frame[frame['chain_id'] == 'A']
        peptide_atoms = frame[frame['chain_id'] == 'B']
        # Residue-level mean: take the first atom's b_factor per residue, then average.
        plddt_peptides.append(peptide_atoms.groupby('residue_seq_id')['b_factor'].first().mean())
        plddt_gpcrs.append(receptor_atoms.groupby('residue_seq_id')['b_factor'].first().mean())
        # Atom-level mean over every chain-B atom.
        plddt_atoms.append(peptide_atoms['b_factor'].mean())

        paths.append(original_key)
        pdb_frames[original_key] = frame
    # Release the shard before reading the next one.
    del shard

/content/drive/MyDrive/gpcrpeptidedesign/orphanpreprocessing/jan29relaxedpdbs_0.parquet


In [None]:
# Label table; joined onto the pair table on its GPCR and Peptide columns further down.
orphandatalabels = pd.read_csv('/content/drive/MyDrive/gpcrpeptidedesign/non-dataset_worm_LABELS.csv')

In [None]:
# Per-pair table; used as the left side of the joins below via its GPCR and Peptide columns.
orphandata = pd.read_csv('/content/drive/MyDrive/gpcrpeptidedesign/non-dataset_worm_plddt_iptm_activebias.csv')

In [None]:
# One row per loaded structure: its path, the names parsed from the file name and the b_factor means.
st = pd.DataFrame({'path': paths, 'gpcr': gpcrs, 'peptide': peptides, 'plddt_peptides': plddt_peptides, 'plddt_gpcrs': plddt_gpcrs, 'plddt_atoms': plddt_atoms})
# Left join keeps every orphandata row; pairs without a loaded structure get NaN in the st columns.
merged = pd.merge(orphandata, st, how='left', left_on=['GPCR', 'Peptide'], right_on=['gpcr', 'peptide'])
# Family label = GPCR name with its last two characters removed.
merged['gpcr_family'] = merged['GPCR'].str[:-2]

In [None]:
# Attach the labels; the left join keeps pairs that have no label row.
merged = pd.merge(merged, orphandatalabels, how='left', on=['GPCR', 'Peptide'])
# Binary target: 1 where 'binds' is present and truthy, 0 where it is missing or falsy.
# Built from explicit masks instead of fillna(False).astype(int), which relies on
# pandas' deprecated object-dtype downcasting and emits a warning.
binds = merged['binds']
merged['y'] = (binds.notna() & binds.astype(bool)).astype(int)

  merged['y'] = merged['binds'].fillna(False).astype(int)


In [None]:
# Keep only the pairs that matched a loaded structure. The explicit copy makes
# gpcr_hits an independent frame, so the column assignments below do not write
# into a slice of `merged` (which triggers SettingWithCopyWarning).
gpcr_hits = merged[merged['gpcr'].notna()].copy()

gpcrs_lens = []
peps_lens = []
for path in gpcr_hits['path']:
    pdb = pdb_frames[path]
    # Chain length is taken as the largest residue_seq_id on the chain
    # (chain A = receptor, chain B = peptide).
    gpcrs_lens.append(pdb.loc[pdb['chain_id'] == 'A', 'residue_seq_id'].max())
    peps_lens.append(pdb.loc[pdb['chain_id'] == 'B', 'residue_seq_id'].max())

gpcr_hits['gpcr_len'] = gpcrs_lens
gpcr_hits['pep_len'] = peps_lens

In [None]:
# Read every contact shard and stack them into a single table.
contact_paths = glob.glob('/content/drive/MyDrive/gpcrpeptidedesign/november28/nov28contacts_*.parquet')
allcontacts = pd.concat([pd.read_parquet(shard) for shard in contact_paths])

# Number of contact rows per structure path, as a one-column frame indexed by 'path'.
total_interactions = pd.DataFrame(
    {'total_interactions': allcontacts.groupby('path')['path'].count()}
)

# Left join so pairs without any recorded contact are kept (NaN count).
gpcr_hits_bonds = pd.merge(gpcr_hits, total_interactions, how='left', on='path')

In [None]:
# Contact rows of each structure, looked up by its path.
interactions = {path: contact_rows for path, contact_rows in allcontacts.groupby('path')}

In [None]:
# One counter per pair: how many times each contact type occurs among its
# contact rows, plus the pair's 'path' so the counts can be joined back later.
bond_type_array = []
for path in gpcr_hits['path']:
    bond_type_mapping = defaultdict(int)
    if path in interactions:
        # Each 'contact' cell is a collection of type labels; iterating the column
        # directly avoids building a Series for every row as iterrows() does.
        for contact_types in interactions[path]['contact']:
            for bond_type in contact_types:
                bond_type_mapping[bond_type] += 1

    # Pairs without contacts still get a row (only the path), so they end up with zero counts.
    bond_type_mapping['path'] = path
    bond_type_array.append(bond_type_mapping)

In [None]:
bond_st = pd.DataFrame(bond_type_array)
important_bonds = ['hydrophobic', 'polar', 'weak_polar', 'hbond', 'weak_hbond', 'ionic', 'aromatic', 'CARBONPI', 'vdw', 'vdw_clash', 'AMIDERING', 'AMIDEAMIDE', 'CATIONPI', 'METSULPHURPI']
# reindex (rather than plain column selection) so a contact type that never occurs
# in this dataset becomes an all-zero column instead of raising KeyError.
bond_st = bond_st.reindex(columns=important_bonds+['path']).fillna(0)

In [None]:
# Attach the per-type contact counts to each pair (left join on 'path').
# NOTE(review): the frame is merged into itself, so re-running this cell alone would duplicate the count columns.
gpcr_hits_bonds = pd.merge(gpcr_hits_bonds, bond_st, how='left', on='path')

In [None]:
# From here on gpcr_hits is an independent copy of the table with contact counts attached.
gpcr_hits = gpcr_hits_bonds.copy()

In [None]:
# Duplicate the GPCR column under the name 'GPCR name'.
gpcr_hits['GPCR name'] = gpcr_hits['GPCR']

In [None]:
# gpcr_len of the first row of each GPCR, indexed by 'GPCR name'.
lens = gpcr_hits.groupby(['GPCR name'])['gpcr_len'].first()

In [None]:
def _contact_node_index(atom, gpcr_len):
    """Map one contact endpoint to a 0-based node index.

    Chain "A" residues keep their own numbering (auth_seq_id - 1); residues of any
    other chain are shifted by gpcr_len so they follow the receptor nodes.
    """
    seq_id = atom['auth_seq_id']
    if atom['auth_asym_id'] != "A":
        seq_id = seq_id + gpcr_len
    return seq_id - 1


# Unique (source, target) node pairs per complex, keyed by '<gpcr>_<peptide>'.
gpcr_hits_interaction_edges = dict()
for index, g in gpcr_hits.iterrows():
    if g['path'] not in interactions:
        continue
    contacts = interactions[g['path']]
    gpcr_len = g['gpcr_len']

    # Build the endpoints in a fresh frame: the previous version wrote 'source' and
    # 'target' columns into the frame stored in `interactions`, mutating shared state.
    endpoints = pd.DataFrame({
        'source': contacts['bgn'].apply(_contact_node_index, args=(gpcr_len,)),
        'target': contacts['end'].apply(_contact_node_index, args=(gpcr_len,)),
    })
    # One edge per residue pair, ordered by source then target.
    endpoints = endpoints.drop_duplicates().sort_values(['source', 'target'])

    h_edge_index = np.vstack([endpoints['source'].values, endpoints['target'].values])
    key = g['gpcr']+'_'+g['peptide']
    gpcr_hits_interaction_edges[key] = h_edge_index

In [None]:
# .h5 files of the interaction_region_avg directory, in sorted path order.
hdfs = sorted(glob.glob('/content/drive/MyDrive/peptide/ReP-Pair/AF2/multistate_embeddings/worm_orphans/activebias_pair_representations/average_of_5_models/interaction_region_avg/*.h5'))

In [None]:
# Per-pair arrays restricted to the receptor nodes that take part in a contact,
# and the (sorted) receptor node indices that were kept.
emb_map_interaction = dict()
emb_map_interaction_gpcrindex = dict()
# Dataset names that could not be processed (unknown pair, unknown GPCR, bad shape, ...).
pairmissed = []
for hdf in hdfs:
    with h5py.File(hdf, "r") as f:
        keys = list(f.keys())
        if not keys:
            # An empty file has no first key to take the GPCR name from.
            continue
        # The GPCR name is read once from the first dataset name and applied to the whole file.
        gpcr = keys[0].split('_')[0]
        for k in keys:
            try:
                # nan must be passed by keyword: the second positional argument of
                # np.nan_to_num is `copy`, so the old `nan_to_num(x, 0)` set copy=False.
                array = np.nan_to_num(f[k][()], nan=0.0)
                peptide = k.split('_')[1]
                mapkey = gpcr+'_'+peptide
                maximum = lens[gpcr]
                edge_nodes = gpcr_hits_interaction_edges[mapkey]
                # Nodes below `maximum` are receptor residues; peptide nodes are dropped.
                indices_to_keep = set(edge_nodes[0]) | set(edge_nodes[1])
                indices_to_keep = sorted(i for i in indices_to_keep if i < maximum)
                emb_map_interaction[mapkey] = array[:,indices_to_keep,:]
                emb_map_interaction_gpcrindex[mapkey] = np.array(indices_to_keep)
            except Exception:
                # Best effort: remember the failing dataset and move on. `Exception`
                # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
                pairmissed.append(k)
                continue


In [None]:
# Sorted like the other file lists in this notebook so the load order (and any
# table built from it by position) is the same on every run.
hdfs2 = sorted(glob.glob('/content/drive/MyDrive/peptide/ReP-Pair/AF2/multistate_embeddings/worm_orphans/activebias_pair_representations/average_of_5_models/2D_t-average/*.h5'))

In [None]:
# Per-file lists of arrays and of dataset names, split by which marker the file path contains.
all_peptide_arrays = []
peptide_keys = []
all_gpcr_arrays = []
gpcr_keys = []
for h5_path in hdfs2:
    with h5py.File(h5_path, "r") as handle:
        dataset_names = list(handle.keys())
        # Read every dataset of the file fully into memory.
        loaded = [handle[name][()] for name in dataset_names]
    # A path is routed by substring; both checks are independent.
    if "_pep_T" in h5_path:
        all_peptide_arrays.append(loaded)
        peptide_keys.append(dataset_names)
    if "_gpcr_T" in h5_path:
        all_gpcr_arrays.append(loaded)
        gpcr_keys.append(dataset_names)

In [None]:
# Receptor-side embedding of each pair, keyed by '<gpcr>_<peptide>'.
emb_map_gpcr = dict()
for keys, arrays in zip(gpcr_keys, all_gpcr_arrays):
    try:
        # The GPCR name comes from the first dataset name of the file.
        gpcr = keys[0].split('_')[0]
        for dataset_name, array in zip(keys, arrays):
            peptide = dataset_name.split('_')[1]
            emb_map_gpcr[gpcr+'_'+peptide] = array
    except Exception:
        # Skip the rest of this file on a malformed name; unlike a bare except,
        # this does not swallow KeyboardInterrupt/SystemExit.
        print('GPCR embedding error')
        continue

In [None]:
# Peptide-side embedding of each pair, keyed by '<gpcr>_<peptide>'.
emb_map_peptide = dict()
for keys, arrays in zip(peptide_keys, all_peptide_arrays):
    try:
        # The GPCR name comes from the first dataset name of the file.
        gpcr = keys[0].split('_')[0]
        for dataset_name, array in zip(keys, arrays):
            peptide = dataset_name.split('_')[1]
            emb_map_peptide[gpcr+'_'+peptide] = array
    except Exception:
        # Skip the rest of this file on a malformed name; unlike a bare except,
        # this does not swallow KeyboardInterrupt/SystemExit.
        print('peptide embedding error')
        continue

In [None]:
# Flatten the per-file key lists and arrays into one table; each array is averaged over axis 0.
# NOTE(review): the gpcr and peptide columns are paired purely by position, so this assumes
# both flattened lists have the same length and describe the same pairs in the same order — confirm.
embst = pd.DataFrame({'gpcr_keys': [kk for k in gpcr_keys for kk in k ], 'gpcr_embedding': [aa.mean(axis=0) for a in all_gpcr_arrays for aa in a ], 'peptide_keys': [kk for k in peptide_keys for kk in k], 'peptide_embedding': [aa.mean(axis=0) for a in all_peptide_arrays for aa in a]})

In [None]:
# Peptide name = second '_'-separated field of the peptide key.
embst['peptide'] = embst['peptide_keys'].apply(lambda x: x.split('_')[1])
# GPCR name = first '_'-separated field of the gpcr key.
embst['gpcr'] = embst['gpcr_keys'].apply(lambda x: x.split('_')[0])

In [None]:
# Inner join: only pairs present in both gpcr_hits and embst are kept.
for_classification = pd.merge(left=gpcr_hits, right=embst, how='inner', left_on=['gpcr', 'peptide'], right_on=['gpcr', 'peptide'])

In [None]:
# Per-GPCR weight = 1 / (number of positive labels for that GPCR), indexed by 'gpcr'.
# NOTE(review): a GPCR whose labels sum to 0 would get a non-finite weight — confirm this cannot happen.
gpcrweight = 1/gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).sort_values(by='y')

In [None]:
# Graph construction switches.
subgraphing = True      # keep only the peptide nodes and their k-hop neighbourhood
subgraph_hops = 1       # k for the neighbourhood above

with_edge_weights = True  # attach edge feature vectors taken from emb_map_interaction

# (mapkey, label) of pairs that could not be turned into a graph.
missed = []
# One dict per graph: the Data object plus its metadata and per-GPCR weight.
all_graphs = []
for index, g in gpcr_hits.iterrows():
    # Compute the key first so the bookkeeping below always refers to THIS pair
    # (it used to be read before assignment, recording the previous pair's key).
    mapkey = g['gpcr']+'_'+g['peptide']
    if g['path'] not in interactions:
        missed.append((mapkey, g['y']))
        continue

    gpcr_len = g['gpcr_len']

    # Contact edges between residue nodes (receptor nodes first, then peptide nodes).
    h_edge_index = gpcr_hits_interaction_edges[mapkey]

    # Node features: receptor rows followed by peptide rows.
    xg = emb_map_gpcr[mapkey]
    xp = emb_map_peptide[mapkey]
    x = torch.from_numpy(np.concatenate([xg, xp])).type(torch.float32)

    # Chain edges linking each node from index gpcr_len onwards to its successor.
    # np.arange keeps an integer dtype even when the range is empty.
    pep_edge_index = np.vstack([np.arange(gpcr_len, len(x)-1), np.arange(gpcr_len+1, len(x))])

    edge_index = torch.cat([torch.from_numpy(h_edge_index), torch.from_numpy(pep_edge_index)], dim=1)

    edge_attrs = None
    if with_edge_weights:
        if mapkey not in emb_map_interaction:
            missed.append((mapkey, g['y']))
            continue
        edgefeatures = emb_map_interaction[mapkey]
        edgeindices = emb_map_interaction_gpcrindex[mapkey]

        # Translate receptor node ids into positions along the kept-receptor axis of
        # edgefeatures; ids not listed in edgeindices (peptide nodes) pass through unchanged.
        sources = npi.remap(h_edge_index[0], edgeindices, np.arange(len(edgeindices)))
        targets = npi.remap(h_edge_index[1], edgeindices, np.arange(len(edgeindices)))
        # Orient every edge as (receptor position, peptide node) regardless of its stored direction.
        sourcewherever = np.where(sources >= gpcr_len)[0]
        targetwherever = np.where(targets < gpcr_len)[0]
        newsources = np.array(sources)
        newtargets = np.array(targets)
        newsources[sourcewherever] = targets[sourcewherever]
        newtargets[targetwherever] = sources[targetwherever]
        # Peptide node id -> 0-based peptide position.
        newtargets -= gpcr_len
        edge_attrs = edgefeatures[newtargets, newsources, :]

        # Chain edges have no feature of their own: give each the mean contact feature.
        # The width follows the data instead of a hard-coded 128.
        pep_edge_attrs = np.ones(shape=(pep_edge_index.shape[1], edge_attrs.shape[1]))*edge_attrs.mean(axis=0)

        # Cast both parts to float32 before concatenating (the chain part is float64 in numpy).
        edge_attrs = torch.cat([
            torch.from_numpy(edge_attrs).type(torch.float32),
            torch.from_numpy(pep_edge_attrs).type(torch.float32),
        ], dim=0)

    label = torch.tensor(g['y'])
    if subgraphing:
        # Seed nodes: every node from index gpcr_len onwards.
        to_keep = torch.arange(int(gpcr_len), len(x))
        nodes, edges, _, _ = torch_geometric.utils.k_hop_subgraph(to_keep, subgraph_hops, edge_index, relabel_nodes=True, num_nodes=len(x))
        new_edge_attrs = None
        if with_edge_weights:
            # Re-extract the induced subgraph so the edge features are filtered with the edges.
            # num_nodes is passed explicitly so the mask always covers every node of x.
            edges, new_edge_attrs = torch_geometric.utils.subgraph(nodes, edge_index, edge_attrs, relabel_nodes=True, num_nodes=len(x))
        graph = Data(x=x[nodes], edge_index=edges, edge_attr=new_edge_attrs, y=label)
    else:
        # Keep the edge features on the full graph too (they used to be dropped here).
        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attrs, y=label)

    graph.peptide = g['peptide']
    graph.gpcr = g['gpcr']
    graph.gpcr_family = g['gpcr_family']
    gpcrw = gpcrweight.loc[g['gpcr']].iloc[0]
    all_graphs.append({'graph': graph, 'peptide':g['peptide'], 'gpcr':g['gpcr'], 'gpcr_family': g['gpcr_family'], 'y': g['y'], 'gpcrweight': gpcrw})






In [None]:
def train(model, criterion, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Each batch is fed to the model as (x, edge_index, edge_attr, batch); the
    criterion must return a scalar loss for (logits, y).

    Returns:
        The batch losses weighted by ``num_graphs`` and divided by the size of
        ``train_loader.dataset``.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        batch_loss = criterion(prediction, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight by graph count so a smaller final batch does not skew the average.
        running_loss += float(batch_loss) * batch.num_graphs

    dataset_size = len(train_loader.dataset)
    return running_loss / dataset_size

def train_weighted(model, criterion, optimizer, train_loader):
    """Run one optimisation pass with per-graph loss weights.

    The criterion's output is multiplied element-wise by ``data.gpcrweight``
    and then averaged, so it is expected to return one loss per graph
    (i.e. no reduction).

    Returns:
        The weighted batch losses scaled by ``num_graphs`` and divided by the
        size of ``train_loader.dataset``.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        per_graph_loss = criterion(prediction, batch.y)
        weighted_loss = (per_graph_loss * batch.gpcrweight).mean()
        weighted_loss.backward()
        optimizer.step()
        running_loss += float(weighted_loss) * batch.num_graphs

    dataset_size = len(train_loader.dataset)
    return running_loss / dataset_size

@torch.no_grad()
def test_roc(model, criterion, loader):
    """Size-weighted average of the per-batch ROC AUC.

    For every batch the softmax probability of class 1 is scored against
    ``data.y`` with ``roc_auc_score`` (from ``sklearn.metrics``), and the batch
    AUCs are combined with weights ``len(batch) / len(loader.dataset)``.

    ``criterion`` is accepted for signature parity with the other helpers and
    is not used.

    Note: ``roc_auc_score`` raises ``ValueError`` for a batch that contains a
    single class; this function does not guard against that.
    """
    model.eval()
    total = len(loader.dataset)
    aucs = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        positive_prob = torch.softmax(out.detach(), dim=1).cpu()[:, 1]
        aucs += roc_auc_score(data.y.detach().cpu(), positive_prob) * (len(out) / total)
    return aucs

@torch.no_grad()
def test_without_crash(model, criterion, loader):
    """ROC AUC computed once over all batches of ``loader``.

    The score of each graph is the raw second output column of the model
    (no softmax is applied). ``criterion`` is not used.
    """
    model.eval()
    scores, labels = [], []
    for batch in loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        scores.append(out.cpu().detach()[:, 1])
        labels.append(batch.y.cpu())
    y_true = np.concatenate(labels)
    y_score = np.concatenate(scores)
    return roc_auc_score(y_true, y_score)

@torch.no_grad()
def nope_test_without_crash(model, criterion, loader):
    """ROC AUC over all batches for a model whose output is passed through a sigmoid.

    The model output is flattened to one value per row before the sigmoid, so
    this variant presumably targets a single-output head — TODO confirm with
    the caller. ``criterion`` is not used.
    """
    model.eval()
    all_logits = []
    atrues = []
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # reshape(-1) instead of squeeze(): squeeze() turns a one-graph batch into a
        # 0-d tensor, which np.concatenate below refuses to join.
        all_logits.append(torch.sigmoid(logits.reshape(-1)).cpu().detach())
        atrues.append(data.y.cpu())
    return roc_auc_score(np.concatenate(atrues), np.concatenate(all_logits))

@torch.no_grad()
def test(model, criterion, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(test_loader.dataset)

In [None]:
def move_to_cuda(g, device='cuda'):
    """Move a graph's tensors to ``device`` in place and return the graph.

    Args:
        g: object exposing ``x``, ``edge_index``, ``edge_attr`` and ``y`` tensors.
        device: target device; defaults to ``'cuda'``, the previous hard-coded target.

    ``edge_attr`` is also cast to float32. A graph without edge features
    (``edge_attr`` is None) is accepted and left without them.
    """
    g.x = g.x.to(device)
    g.edge_index = g.edge_index.to(device)
    if g.edge_attr is not None:
        g.edge_attr = g.edge_attr.to(device).type(torch.float32)
    g.y = g.y.to(device)
    return g

In [None]:
class DeorphaNN(torch.nn.Module):
    """Graph classifier: BatchNorm -> one GATv2 layer -> ReLU -> mean pooling -> linear head.

    Args:
        hidden_channels: output width of the GATv2 layer and input width of the head.
        input_channels: node feature width.
        gatheads: number of attention heads (their outputs are not concatenated).
        gatdropout: dropout passed to the GATv2 layer.
        finaldropout: dropout applied to the pooled graph vector before the head.

    The head produces two values per graph. Edge features are expected to have
    width 128 (``edge_dim`` of the attention layer).
    """

    def __init__(self, hidden_channels, input_channels=128, gatheads=10, gatdropout=0.5, finaldropout=0.5):
        super().__init__()
        self.finaldropout = finaldropout
        # Seed before the layers are created so their initial weights are reproducible.
        # Note that this resets torch's global RNG.
        torch.manual_seed(111)
        self.norm = BatchNorm(input_channels)
        self.conv1 = GATv2Conv(
            input_channels,
            hidden_channels,
            heads=gatheads,
            concat=False,
            dropout=gatdropout,
            edge_dim=128,
        )
        self.pooling = global_mean_pool
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch, hidden=False):
        """Return per-graph outputs, or the per-node activations when ``hidden`` is true."""
        node_states = self.conv1(self.norm(x), edge_index, edge_attr).relu()
        if hidden:
            return node_states

        graph_vector = self.pooling(node_states, batch)
        graph_vector = F.dropout(graph_vector, p=self.finaldropout, training=self.training)
        return self.lin(graph_vector)

In [None]:
# Number of positive labels (sum of y) per GPCR, in ascending order.
gpcr_hits.groupby(['gpcr']).agg({'y': 'sum'}).sort_values(by='y')

Unnamed: 0_level_0,y
gpcr,Unnamed: 1_level_1
DMSR-11-1,1
NPR-33-1,1
NPR-34-1,1
SEB-2-1,1
H23L24-4-1,2


In [None]:
# Store each graph's per-GPCR weight on its Data object as a CUDA tensor (requires a GPU).
for g in all_graphs:
    g['graph'].gpcrweight = torch.tensor(g['gpcrweight']).cuda()

In [None]:
# Sorted so the checkpoint order is the same on every run; positional uses of this
# list (first element, slices of the loaded models) otherwise depend on filesystem order.
model_state_dicts = sorted(glob.glob('/content/drive/MyDrive/gpcrpeptidedesign/pretrainedmodels/*'))

In [None]:
# State dict of the first checkpoint in the list (tensors only, via weights_only=True).
# NOTE(review): `weights` is not read again in the cells shown here.
weights = torch.load(model_state_dicts[0], weights_only=True)

In [None]:
# Rebuild one DeorphaNN per checkpoint, in evaluation mode on the GPU.
models = []
for state in model_state_dicts:
    weight = torch.load(state, weights_only=True)
    # The hidden width is not stored separately: recover it from the input
    # dimension of the final linear layer.
    units = weight['lin.weight'].shape[1]
    print(units)
    model = DeorphaNN(units)
    # Reuse the state dict loaded above instead of reading the file a second time.
    model.load_state_dict(weight)
    models.append(model.cuda().eval())

53
73
72
66
50
84
51
71
66
73
64
64
64
64
64
64
64
64
64
64
64
64
53
73
72
66
50
71
73
52
66
78


In [None]:
# Push every graph to the GPU, then serve them in their original order, 256 per batch.
to_move = [move_to_cuda(entry['graph']) for entry in all_graphs]
test_loader = DataLoader(to_move, batch_size=256, shuffle=False)

In [None]:
# Score every graph with each model and report the per-GPCR average precision.
# NOTE(review): models[1:] skips the first loaded model — confirm this is intended.
for model in models[1:]:
    print('Ensemble:')
    # All collections below are re-created per model, so after the loop they
    # hold the results of the last model only.
    test_logits = []
    test_labels = []

    candidatesgnn = []
    hitmapgnn = []
    average_precisions = dict()

    all_logits = []
    atrues = []
    with torch.no_grad():
        for data in test_loader:
            logits = model(data.x, data.edge_index, data.edge_attr, data.batch)
            # Score = raw second output column (no softmax applied).
            all_logits.append(logits.cpu().detach()[:,1])
            atrues.append(data.y.cpu())

    test_logits.append(np.concatenate(all_logits))
    test_labels.append(np.concatenate(atrues))

    # test_loader does not shuffle, so predictions line up with to_move by position.
    val_gpcr = [g.gpcr for g in to_move]
    val_peptide = [g.peptide for g in to_move]
    # Columns: 0 = label, 1 = score, 2 = gpcr, 3 = peptide.
    result = pd.DataFrame(zip(test_labels[-1], test_logits[-1], val_gpcr, val_peptide))
    for gpcr, r in result.groupby(2):
        # Number of positives among the 17 highest-scoring peptides of this GPCR.
        hitmapgnn.append((gpcr, r.sort_values(by=1).iloc[-17:][0].sum()))
        # Full ranking for this GPCR, ascending by score.
        candidatesgnn.append((gpcr,r.sort_values(by=1)))
        # Average precision is only computed for GPCRs with at least one positive.
        if r[0].sum() > 0:
            average_precisions[gpcr] = sklearn.metrics.average_precision_score(r[0], r[1])
            print(gpcr, average_precisions[gpcr])

GPCR:
DMSR-11-1 0.005208333333333333
H23L24-4-1 0.013081617086193745
NPR-33-1 0.5
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.004629629629629629
H23L24-4-1 0.1736111111111111
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.009523809523809525
H23L24-4-1 0.29166666666666663
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.005235602094240838
H23L24-4-1 0.29166666666666663
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.011235955056179775
H23L24-4-1 0.04949944382647386
NPR-33-1 0.5
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.0053475935828877
H23L24-4-1 0.1736111111111111
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.005494505494505495
H23L24-4-1 0.0900735294117647
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.008264462809917356
H23L24-4-1 0.25
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.005917159763313609
H23L24-4-1 0.24285714285714285
NPR-33-1 1.0
NPR-34-1 1.0
SEB-2-1 1.0
GPCR:
DMSR-11-1 0.030303030303030304
H23L24-4-1 0.5
NPR-33-1 0.5
