In [None]:
#@title Upload the GPCR .npy file
from google.colab import files
import numpy as np
import torch
import pandas as pd
# files.upload() returns {filename: bytes}; np.load below reads the file by name,
# relying on Colab having saved it to the working directory.
uploaded_gpcr = files.upload()
if len(uploaded_gpcr) != 1:
    raise ValueError(f"Expected exactly one GPCR .npy file, got {len(uploaded_gpcr)}.")
gpcr_file = next(iter(uploaded_gpcr))
# GPCR residue embeddings; later stacked as the first block of graph node features.
gpcr_emb = np.load(gpcr_file)

In [None]:
#@title Upload the peptide .npy file
from google.colab import files
import numpy as np
import torch
import pandas as pd
# files.upload() returns {filename: bytes}; np.load below reads the file by name,
# relying on Colab having saved it to the working directory.
uploaded_pep = files.upload()
if len(uploaded_pep) != 1:
    raise ValueError(f"Expected exactly one peptide .npy file, got {len(uploaded_pep)}.")
pep_file = next(iter(uploaded_pep))
# Peptide residue embeddings; later stacked after the GPCR node features.
pep_emb = np.load(pep_file)

In [None]:
#@title Upload the interaction .npy file
from google.colab import files
import numpy as np
import torch
import pandas as pd
# files.upload() returns {filename: bytes}; np.load below reads the file by name,
# relying on Colab having saved it to the working directory.
uploaded_edges = files.upload()
if len(uploaded_edges) != 1:
    raise ValueError(f"Expected exactly one interaction .npy file, got {len(uploaded_edges)}.")
edge_file = next(iter(uploaded_edges))
# Pairwise GPCR–peptide interaction features, used as GNN edge attributes.
edge_features = np.load(edge_file)

In [None]:
#@title Upload the Arpeggio contact file (.parquet)
from google.colab import files
import numpy as np
import torch
import pandas as pd
# files.upload() returns {filename: bytes}; read_parquet below reads the file by
# name, relying on Colab having saved it to the working directory.
uploaded_parquet = files.upload()
if len(uploaded_parquet) != 1:
    raise ValueError(f"Expected exactly one Arpeggio .parquet file, got {len(uploaded_parquet)}.")
parquet_file = next(iter(uploaded_parquet))
# Arpeggio atom-level contacts; converted to residue-level graph edges later.
bonds = pd.read_parquet(parquet_file)
print(f"Loaded {parquet_file} successfully.")


In [None]:
#@title Install dependencies
# !pip uninstall torch -y > /dev/null
# !pip install torch==2.4.0 > /dev/null
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html > /dev/null
!pip install torch-geometric > /dev/null
!pip install optuna > /dev/null
!pip install numpy-indexed > /dev/null
import glob
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy

import sklearn

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.nn import aggr
from torch_geometric.nn.norm import GraphNorm, LayerNorm, BatchNorm
import torch_geometric

import seaborn
import optuna
import h5py
import numpy_indexed as npi
import random

from collections import defaultdict
import os
import subprocess as sp

In [None]:
#@title Import pretrained models
folder = "/content/pretrained_models"
os.makedirs(folder, exist_ok=True)

# Ensemble checkpoints. Deliberately NOT named `files`: that would shadow the
# `google.colab.files` module used by the upload cells and break re-running them.
pretrained_files = [f"pretrained_{i}.pth" for i in range(1, 11)]

base_url = "https://huggingface.co/datasets/lariferg/DeorphaNN/resolve/main/pretrained_models/"

for fname in pretrained_files:
    path = os.path.join(folder, fname)
    if os.path.exists(path):
        continue  # already downloaded in an earlier run
    url = base_url + fname
    exit_code = sp.call(["wget", "-q", "-c", url, "-P", folder])
    if exit_code != 0:
        # Drop any partial file so the existence check above does not treat it
        # as a complete checkpoint on the next run.
        if os.path.exists(path):
            os.remove(path)
        raise RuntimeError(f"Failed to download {url} (wget exit code {exit_code}).")



class DeorphaNN(torch.nn.Module):
    """Single-layer GATv2 graph classifier producing 2 logits per graph.

    Pipeline: BatchNorm on node features -> one GATv2 attention layer (heads
    averaged, ``concat=False``) that also consumes edge attributes -> ReLU ->
    per-graph mean pooling -> dropout -> linear layer with 2 outputs.

    Args:
        hidden_channels: Output width of the GATv2 layer (= input width of the
            final linear layer).
        input_channels: Width of the node feature vectors.
        gatheads: Number of attention heads in the GATv2 layer.
        gatdropout: Attention dropout inside the GATv2 layer.
        finaldropout: Dropout probability applied to the pooled graph embedding.
        edge_dim: Width of the edge attribute vectors. Defaults to 128, the value
            previously hard-coded here.
    """

    def __init__(self, hidden_channels, input_channels=128, gatheads=10, gatdropout=0.5, finaldropout=0.5, edge_dim=128):
        super().__init__()
        self.finaldropout = finaldropout
        # NOTE: reseeds the *global* torch RNG so freshly initialised weights are
        # reproducible; it must run before the submodules below are created.
        torch.manual_seed(111)
        self.norm = BatchNorm(input_channels)
        self.conv1 = GATv2Conv(input_channels, hidden_channels, dropout=gatdropout, heads=gatheads, concat=False, edge_dim=edge_dim)
        self.pooling = global_mean_pool
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch, hidden=False):
        """Run the network on a (batched) graph.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity in COO format.
            edge_attr: Edge feature matrix.
            batch: Graph assignment vector for each node (used for pooling).
            hidden: If True, return the per-node embeddings after the GATv2
                layer and ReLU instead of the graph logits.

        Returns:
            Per-node embeddings when ``hidden`` is True, otherwise one row of
            2 logits per graph in ``batch``.
        """
        x = self.norm(x)
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()

        if hidden:
            return x
        x = self.pooling(x, batch)
        x = F.dropout(x, p=self.finaldropout, training=self.training)
        x = self.lin(x)
        return x


# One checkpoint per ensemble member; sorted because glob order is
# filesystem-dependent and the model order should be deterministic.
model_state_dicts = sorted(glob.glob(os.path.join(folder, '*.pth')))
if not model_state_dicts:
    raise FileNotFoundError(f"No pretrained model files found in {folder}; re-run the download step.")

# Fall back to CPU when the runtime has no GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models = []
for state in model_state_dicts:
    # Load each checkpoint once; weights_only=True restricts unpickling to
    # tensors and plain containers. map_location handles GPU-saved checkpoints
    # on a CPU runtime.
    weight = torch.load(state, map_location=device, weights_only=True)
    # The hidden width is recovered from the input dimension of the final linear layer.
    units = weight['lin.weight'].shape[1]
    model = DeorphaNN(units)
    model.load_state_dict(weight)
    models.append(model.to(device).eval())

In [None]:
#@title Run DeorphaNN
# --- Node features -----------------------------------------------------------
# GPCR residues occupy node indices [0, gpcr_len); peptide residues follow.
xg = gpcr_emb
xp = pep_emb
gpcr_len = xg.shape[0]
pep_len = xp.shape[0]
x = torch.from_numpy(np.concatenate([xg, xp])).float()


def _to_node_index(atom, gpcr_len):
    """Return the 0-based graph node index for an Arpeggio atom record.

    Residues on chain "A" map into the GPCR block; residues on any other chain
    are shifted by ``gpcr_len`` into the peptide block. Assumes ``auth_seq_id``
    is 1-based and contiguous from 1 (TODO: confirm against the Arpeggio output).
    """
    seq_id = atom['auth_seq_id']
    if atom['auth_asym_id'] != "A":
        seq_id = seq_id + gpcr_len
    return seq_id - 1


# --- Contact edges -------------------------------------------------------------
# Build a new frame instead of adding columns to (and then overwriting) `bonds`,
# so this cell can be re-run without re-uploading the parquet file.
contacts = bonds.assign(
    source=bonds['bgn'].apply(lambda atom: _to_node_index(atom, gpcr_len)),
    target=bonds['end'].apply(lambda atom: _to_node_index(atom, gpcr_len)),
)
# One edge per distinct (source, target) residue pair, ordered by (source, target).
h_edge_index = (
    contacts[['source', 'target']]
    .drop_duplicates()
    .sort_values(['source', 'target'])
    .to_numpy()
    .T
)
if h_edge_index.shape[1] == 0:
    # Without contacts the mean edge feature below would be NaN and so would the prediction.
    raise ValueError("The Arpeggio file contains no residue contacts; cannot build the interaction graph.")

# Orient each edge as (GPCR residue, peptide residue) for the feature lookup.
src, dst = h_edge_index
gpcr_side = np.where(src >= gpcr_len, dst, src)
pep_side = np.where(dst < gpcr_len, src, dst) - gpcr_len

# Only GPCR–peptide contacts map to a meaningful feature entry; an intra-chain
# contact yields mismatched (possibly negative, i.e. wrapped-around) indices.
intra_chain = (src < gpcr_len) == (dst < gpcr_len)
if intra_chain.any():
    warnings.warn(
        f"{int(intra_chain.sum())} intra-chain contact(s) found; their edge features are "
        "looked up with mismatched indices. Expected only GPCR-peptide contacts."
    )

# edge_features is indexed as [peptide residue, GPCR residue, feature]
# (presumably shape (pep_len, gpcr_len, 128) — TODO confirm).
edge_attrs = edge_features[pep_side, gpcr_side, :]

# --- Peptide backbone edges ------------------------------------------------------
# Chain consecutive peptide residues i -> i+1; these edges all receive the mean
# contact feature vector.
pep_nodes = np.arange(gpcr_len, gpcr_len + pep_len)
pep_edge_index = np.vstack([pep_nodes[:-1], pep_nodes[1:]])
pep_edge_attrs = np.tile(edge_attrs.mean(axis=0), (pep_edge_index.shape[1], 1))

# --- Assemble graph ----------------------------------------------------------------
edge_index = torch.from_numpy(np.concatenate([h_edge_index, pep_edge_index], axis=1)).long()
edge_attr = torch.from_numpy(np.concatenate([edge_attrs, pep_edge_attrs], axis=0)).float()

graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def move_to_cuda(g, device="cuda"):
    """Move a graph's ``x``, ``edge_index`` and ``edge_attr`` to ``device``.

    Mutates ``g`` in place and also returns it; ``edge_attr`` is cast to float32.
    Despite the name, ``device`` may be any torch device (e.g. "cpu"), so the
    notebook also works on a CPU-only runtime. The default keeps the old behavior.
    """
    g.x = g.x.to(device)
    g.edge_index = g.edge_index.to(device)
    g.edge_attr = g.edge_attr.to(device, dtype=torch.float32)
    return g


if not models:
    raise RuntimeError("No pretrained DeorphaNN models were loaded; check the download/load steps.")

# Put the graph on whatever device the loaded models live on.
device = next(models[0].parameters()).device
graph = move_to_cuda(graph, device)
# Single-graph batch: every node belongs to graph 0.
graph.batch = torch.zeros(graph.x.size(0), dtype=torch.long, device=device)

with torch.no_grad():
    # Probability of class 1 (interaction) from each ensemble member.
    model_probs = [
        torch.softmax(model(graph.x, graph.edge_index, graph.edge_attr, graph.batch), dim=1)[:, 1].item()
        for model in models
    ]

# Ensemble prediction: unweighted mean of the member probabilities.
mean_prob = sum(model_probs) / len(model_probs)
print(f"Predicted interaction probability: {mean_prob:.4f}")