In [1]:
from typing import Dict, List, Tuple, Union, Optional, Callable

import sys
import wandb
import numpy as np
import math
import pandas as pd
import os
from copy import deepcopy
from argparse import Namespace
from pathlib import Path
import logging
import torch
from scipy import stats
from torch.optim import Adam
from torch.utils.data import DataLoader as DL_torch
from torch_geometric.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from functools import partial
# import pdb

from abag_affinity.dataset import AffinityDataset
from abag_affinity.dataset.advanced_data_utils import complexes_from_dms_datasets, get_bucket_dataloader, load_datasets
from abag_affinity.model import AffinityGNN, TwinWrapper
from abag_affinity.train.wandb_config import configure
from abag_affinity.utils.config import get_data_paths, read_config
from abag_affinity.utils.visualize import plot_correlation
from abag_affinity.utils.argparse_utils import read_args_from_file, parse_args
from abag_affinity.train.utils import load_model
from abag_affinity.model import regression_heads
%load_ext autoreload
%autoreload 2



In [2]:
args_file = "base_args.txt"

sys.argv = sys.argv[:1]
args = parse_args(args_file=args_file)

config = read_config(args.config_file)

args.batch_size = 1
args.cuda = False
use_cuda = False
device = 'cpu'

train_data, val_datas = load_datasets(config, args.target_dataset, args.validation_set, args)

# wandb_benchmark_log = run_and_log_benchmarks(model, args, wandb_inst)

In [3]:
# model = load_model(train_data.num_features, train_datasets[0].num_edge_features, dataset_names, args, device)
# def load_model(num_node_features: int, num_edge_features: int, dataset_names: List[str], args: Namespace,
#                device: torch.device = torch.device("cpu")) -> AffinityGNN:

model = AffinityGNN.load_from_checkpoint('/home/mihail/Documents/workspace/ag_binding_affinity/results/models/2023-12-20_19-00-38_fix_labels_abag_test/model.pt',
                                         map_location='cpu')

model.to('cpu')
train_dataloader, val_dataloaders = get_bucket_dataloader(args, [train_data], val_datas)

In [4]:
preds_per_residue = {}
for data in train_dataloader:
    data_copy = deepcopy(data)
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    output = model(data_copy['input'])
    #
    print(output)
    res_types = np.where(data['input']['graph']['node'].x[:, :20] == 1)[1]
    # with np.printoptions(threshold=np.inf):
    #     print('data graph', data['input']['graph']['node'].x.numpy())

    out2 = model.graph_conv(data['input']['graph'])

    x = out2["node"].x

    batch = regression_heads.get_node_batches(out2).to(x.device)

    if model.regression_head.aggregation_method in ["interface_sum", "interface_mean", "interface_size"]:
        # get interface edges
        interface_node_indices = out2["node", "interface", "node"].edge_index.view(-1).unique()
        batch = batch[interface_node_indices]
        x = x[interface_node_indices]
        res = res_types[interface_node_indices]
    # compute node-wise affinity contribution from graph embedding
    for fc_layer in model.regression_head.fc_layers[:-1]:
        x = fc_layer(x)
        x = model.regression_head.activation(x)
    x = model.regression_head.fc_layers[-1](x)
    print(x.shape)
    for i in range(res.shape[0]):
        if res[i] in preds_per_residue.keys():
            preds_per_residue[res[i].item()].append(x[i].item())
        else:
            preds_per_residue[res[i].item()] = [x[i].item()]

print(preds_per_residue)

{'-log(Kd)': tensor([[0.2581]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.2581]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([25, 1])
{'-log(Kd)': tensor([[0.4305]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4305]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([35, 1])
{'-log(Kd)': tensor([[0.5579]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.5579]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([45, 1])
{'-log(Kd)': tensor([[0.5485]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.5485]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([45, 1])
{'-log(Kd)': tensor([[0.4269]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4269]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([38, 1])
{'-log(Kd)': tensor([[0.4866]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4866]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([64, 1])
{'-log(Kd)': tensor([[0.4451]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4451]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([36, 1])
{'-log(Kd)': tensor([[0.396

torch.Size([70, 1])
{'-log(Kd)': tensor([[0.4265]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4265]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([52, 1])
{'-log(Kd)': tensor([[0.3005]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3005]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([30, 1])
{'-log(Kd)': tensor([[0.4124]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4124]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([31, 1])
{'-log(Kd)': tensor([[0.3578]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3578]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([46, 1])
{'-log(Kd)': tensor([[0.2928]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.2928]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([47, 1])
{'-log(Kd)': tensor([[0.3837]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3837]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([37, 1])
{'-log(Kd)': tensor([[0.1837]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.1837]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([24, 1])
{'-log(

torch.Size([62, 1])
{'-log(Kd)': tensor([[0.3481]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3481]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([35, 1])
{'-log(Kd)': tensor([[0.2884]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.2884]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([41, 1])
{'-log(Kd)': tensor([[0.4938]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4938]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([45, 1])
{'-log(Kd)': tensor([[0.1707]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.1707]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([33, 1])
{'-log(Kd)': tensor([[0.4258]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4258]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([46, 1])
{'-log(Kd)': tensor([[0.1366]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.1366]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([10, 1])
{'-log(Kd)': tensor([[0.3221]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3221]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([31, 1])
{'-log(

torch.Size([50, 1])
{'-log(Kd)': tensor([[0.6031]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.6031]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([51, 1])
{'-log(Kd)': tensor([[0.3961]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3961]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([52, 1])
{'-log(Kd)': tensor([[0.3429]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.3429]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([46, 1])
{'-log(Kd)': tensor([[0.4243]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4243]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([35, 1])
{'-log(Kd)': tensor([[0.4835]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4835]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([42, 1])
{'-log(Kd)': tensor([[0.2153]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.2153]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([24, 1])
{'-log(Kd)': tensor([[0.4217]], grad_fn=<ScatterAddBackward0>), 'E': tensor([[0.4217]], grad_fn=<UnsqueezeBackward0>)}
torch.Size([38, 1])
{'-log(

In [5]:
means_per_residue = [np.mean(preds_per_residue[i]) for i in preds_per_residue.keys()]
stds_per_residue = [np.std(preds_per_residue[i]) for i in preds_per_residue.keys()]
print('average_residue_scores', means_per_residue, '\n')
print('standard dev per residue', stds_per_residue, '\n')
print('standard dev across residues', np.std(means_per_residue))

average_residue_scores [0.010711523841639033, 0.003934089707710394, 0.011648642151160146, 0.012113928873564156, 0.011569175101983045, 0.00898664374913659, 0.011942671090935383, 0.01165766550982294, 0.0005000357357960827, 0.011726194668723188, 0.012600864674965245, 0.006464829072171346, 0.012431729584932327, 0.012211878732411576, 0.00770779219886805, 0.008358863152955708, 0.01030030148842978, 0.010337830510702027, 0.0124819566861042, 0.010311691487421755] 

standard dev per residue [0.00947383488956815, 0.015577367353685817, 0.007164931421099245, 0.004495882606226215, 0.007870973435335701, 0.010032136095196162, 0.006465767111056731, 0.007942211970434057, 0.016641159587417926, 0.0025803447812764225, 0.0043868171659347316, 0.013104958287461003, 0.005167455236347555, 0.005128948883858969, 0.011722457147185424, 0.010763864148001848, 0.009380145136344319, 0.009251854293922715, 0.00506111723675758, 0.009319889997197423] 

standard dev across residues 0.0030984251847454595


In [6]:
preds_per_residue = {}
for data in val_dataloaders[0]:
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    # output = model(data['input'])
    #
    # print(output)
    res_types = np.where(data['input']['graph']['node'].x[:, :20] == 1)[1]
    # with np.printoptions(threshold=np.inf):
    #     print('data graph', data['input']['graph']['node'].x.numpy())

    out2 = model.graph_conv(data['input']['graph'])

    x = out2["node"].x

    batch = regression_heads.get_node_batches(out2).to(x.device)

    if model.regression_head.aggregation_method in ["interface_sum", "interface_mean", "interface_size"]:
        # get interface edges
        interface_node_indices = out2["node", "interface", "node"].edge_index.view(-1).unique()
        batch = batch[interface_node_indices]
        x = x[interface_node_indices]
        res = res_types[interface_node_indices]
    # compute node-wise affinity contribution from graph embedding
    for fc_layer in model.regression_head.fc_layers[:-1]:
        x = fc_layer(x)
        x = model.regression_head.activation(x)
    x = model.regression_head.fc_layers[-1](x)
    print(x.shape)
    for i in range(res.shape[0]):
        if res[i] in preds_per_residue.keys():
            preds_per_residue[res[i].item()].append(x[i].item())
        else:
            preds_per_residue[res[i].item()] = [x[i].item()]

print(preds_per_residue)

torch.Size([51, 1])
torch.Size([48, 1])
torch.Size([26, 1])
torch.Size([36, 1])
torch.Size([35, 1])
torch.Size([41, 1])
torch.Size([43, 1])
torch.Size([29, 1])
torch.Size([30, 1])
torch.Size([43, 1])
torch.Size([29, 1])
torch.Size([35, 1])
torch.Size([24, 1])
torch.Size([47, 1])
torch.Size([52, 1])
torch.Size([29, 1])
torch.Size([27, 1])
torch.Size([48, 1])
torch.Size([53, 1])
torch.Size([52, 1])
torch.Size([67, 1])
torch.Size([40, 1])
torch.Size([47, 1])
torch.Size([54, 1])
torch.Size([54, 1])
torch.Size([50, 1])
torch.Size([45, 1])
torch.Size([26, 1])
torch.Size([38, 1])
torch.Size([35, 1])
torch.Size([45, 1])
torch.Size([38, 1])
torch.Size([52, 1])
torch.Size([55, 1])
torch.Size([32, 1])
torch.Size([42, 1])
torch.Size([28, 1])
torch.Size([49, 1])
torch.Size([31, 1])
torch.Size([32, 1])
torch.Size([47, 1])
torch.Size([53, 1])
torch.Size([41, 1])
torch.Size([31, 1])
torch.Size([32, 1])
torch.Size([39, 1])
torch.Size([44, 1])
torch.Size([33, 1])
torch.Size([45, 1])
torch.Size([55, 1])


In [7]:
means_per_residue = [np.mean(preds_per_residue[i]) for i in preds_per_residue.keys()]
stds_per_residue = [np.std(preds_per_residue[i]) for i in preds_per_residue.keys()]
print('average_residue_scores', means_per_residue, '\n')
print('standard dev per residue', stds_per_residue, '\n')
print('standard dev across residues', np.std(means_per_residue))

average_residue_scores [0.010162440771087605, 0.011797318669190085, 0.012215439545906197, 0.009930931344753554, 0.012735733158991371, 0.011337465995194897, 0.01178576446348621, 0.010341293662786483, 0.009073433776696524, 0.008368170076562452, 0.010813374621946304, 0.012727018761432778, 0.011550032609217876, 0.01255566442714018, 0.007711605075746775, -0.001368281257859731, 0.011744956175486246, 0.0041930282355419225, 0.012175623262137697, 0.009403808943686946] 

standard dev per residue [0.009091622064256583, 0.0012877779003953615, 0.003940162058644751, 0.010347332186076145, 0.003944342931108092, 0.007946693107695012, 0.006828222378863008, 0.007711415989955541, 0.010240237132764902, 0.011591385570799105, 0.008875201023675719, 0.004604837562964154, 0.007553212776470551, 0.004656241180987467, 0.011372757286480278, 0.017210664416088175, 0.007163409617147534, 0.01592018012372943, 0.005137931023401835, 0.011623923670915928] 

standard dev across residues 0.003298508570948213


In [9]:
dataset = AffinityDataset(args.config, args.relaxed_pdbs, "AntibodyBenchmark", "L2",
                          node_type=args.node_type,
                          max_nodes=args.max_num_nodes,
                          interface_distance_cutoff=args.interface_distance_cutoff,
                          interface_hull_size=args.interface_hull_size,
                          max_edge_distance=args.max_edge_distance,
                          pretrained_model=args.pretrained_model,
                          scale_values=args.scale_values,
                          scale_min=args.scale_min,
                          scale_max=args.scale_max,
                          relative_data=False,
                          save_graphs=args.save_graphs,
                          force_recomputation=args.force_recomputation,
                          preprocess_data=args.preprocess_graph,
                          preprocessed_to_scratch=args.preprocessed_to_scratch,
                          num_threads=args.num_workers,
                          load_embeddings=None if not args.embeddings_type else (args.embeddings_type, args.embeddings_path)
                          )

dataloader = DL_torch(dataset, num_workers=args.num_workers, batch_size=1,
                      collate_fn=AffinityDataset.collate)

In [10]:
preds_per_residue = {}
for data in dataloader:
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    # output = model(data['input'])
    #
    # print(output)
    res_types = np.where(data['input']['graph']['node'].x[:, :20] == 1)[1]
    # with np.printoptions(threshold=np.inf):
    #     print('data graph', data['input']['graph']['node'].x.numpy())

    out2 = model.graph_conv(data['input']['graph'])

    x = out2["node"].x

    batch = regression_heads.get_node_batches(out2).to(x.device)

    if model.regression_head.aggregation_method in ["interface_sum", "interface_mean", "interface_size"]:
        # get interface edges
        interface_node_indices = out2["node", "interface", "node"].edge_index.view(-1).unique()
        batch = batch[interface_node_indices]
        x = x[interface_node_indices]
        res = res_types[interface_node_indices]
    # compute node-wise affinity contribution from graph embedding
    for fc_layer in model.regression_head.fc_layers[:-1]:
        x = fc_layer(x)
        x = model.regression_head.activation(x)
    x = model.regression_head.fc_layers[-1](x)
    print(x.shape)
    for i in range(res.shape[0]):
        if res[i] in preds_per_residue.keys():
            preds_per_residue[res[i].item()].append(x[i].item())
        else:
            preds_per_residue[res[i].item()] = [x[i].item()]

print(preds_per_residue)

torch.Size([49, 1])
torch.Size([46, 1])
torch.Size([58, 1])
torch.Size([42, 1])
torch.Size([34, 1])
torch.Size([46, 1])
torch.Size([45, 1])
torch.Size([24, 1])
torch.Size([66, 1])
torch.Size([47, 1])
torch.Size([36, 1])
torch.Size([71, 1])
torch.Size([51, 1])
torch.Size([41, 1])
torch.Size([32, 1])
torch.Size([58, 1])
torch.Size([39, 1])
torch.Size([27, 1])
torch.Size([51, 1])
torch.Size([45, 1])
torch.Size([46, 1])
torch.Size([47, 1])
torch.Size([75, 1])
torch.Size([45, 1])
torch.Size([49, 1])
torch.Size([52, 1])
torch.Size([48, 1])
torch.Size([52, 1])
torch.Size([40, 1])
{15: [-0.006137251853942871, 0.007200092077255249, 0.013697996735572815, 0.013819903135299683, 0.013775259256362915, -0.000992894172668457, -0.01751810312271118, 0.0133228600025177, 0.013379216194152832, 0.013655215501785278, 0.013762235641479492, 0.013773813843727112, -0.0008594393730163574, 0.012547403573989868, -0.015477687120437622, 0.01374274492263794, 0.013918191194534302, -0.00789344310760498, 0.01389622688293

In [11]:
means_per_residue = [np.mean(preds_per_residue[i]) for i in preds_per_residue.keys()]
stds_per_residue = [np.std(preds_per_residue[i]) for i in preds_per_residue.keys()]
print('average_residue_scores', means_per_residue, '\n')
print('standard dev per residue', stds_per_residue, '\n')
print('standard dev across residues', np.std(means_per_residue))

average_residue_scores [0.007108517648542628, 0.012492223113190894, 0.0058677345514297485, -0.0034413284693772975, 0.011869000938703429, 0.011703838977743597, 0.003763270080089569, 0.0030266740654088273, 0.01169963280359904, 0.009914862031632282, 0.012297654853147618, 0.012142408325011471, 0.01169093462025247, 0.013218113627189245, 0.013237503943619904, 0.008379712700843811, 0.008121915681417598, 0.013153266480990819, 0.012504073845989564, -0.0027164727449417113] 

standard dev per residue [0.011659005929381908, 0.0024442587534300193, 0.013240491628527102, 0.018275410154498004, 0.007578405392824837, 0.007967901263765, 0.015405430493539556, 0.016313058812029612, 0.007134159909880966, 0.008009626404017341, 0.0037873652203230133, 0.007079280710352627, 0.0013590548574433875, 0.00048059237191799354, 0.0003669234114973012, 0.01141306300036984, 0.01256134924099337, 0.000561904319548317, 0.005228300247484347, 0.018726993750871522] 

standard dev across residues 0.0049902619100495845


In [12]:
train_y = []
train_pred = []
val_y = []
bench_y = []
val_pred = []
bench_pred = []

preds_per_residue = {}
for data in train_dataloader:
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    # output = model(data['input'])
    #
    # print(output)
    train_y.append(data['input']['graph']['-log(Kd)'].item())
    out = model(data['input'])
    train_pred.append(out['-log(Kd)'].item())
train_pred = np.array(train_pred)
train_y = np.array(train_y)
# print(train_pred)
# print(train_y)

for data in val_dataloaders[0]:
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    # output = model(data['input'])
    #
    # print(output)
    val_y.append(data['input']['graph']['-log(Kd)'].item())
    out = model(data['input'])
    val_pred.append(out['-log(Kd)'].item())
val_pred = np.array(val_pred)
val_y = np.array(val_y)

for data in dataloader:
    # print('data', data)
    # print('data input size', data['input']['graph']['node'].x.shape)
    # output = model(data['input'])
    #
    # print(output)
    bench_y.append(data['input']['graph']['-log(Kd)'].item())
    out = model(data['input'])
    bench_pred.append(out['-log(Kd)'].item())

bench_pred = np.array(bench_pred)
bench_y = np.array(bench_y)

train_mean = np.mean(train_y)

train_train_rmse = np.sqrt(np.mean((train_y - train_mean) ** 2))
train_val_rmse = np.sqrt(np.mean((val_y - train_mean) ** 2))
train_bench_rmse = np.sqrt(np.mean((bench_y - train_mean) ** 2))

train_rmse = np.sqrt(np.mean((train_y - train_pred) ** 2))
val_rmse = np.sqrt(np.mean((val_y - val_pred) ** 2))
bench_rmse = np.sqrt(np.mean((bench_y - bench_pred) ** 2))

print(train_rmse, val_rmse, bench_rmse)
print(train_train_rmse, train_val_rmse, train_bench_rmse)

0.1155047127776308 0.16149641672163967 0.1660761818237867
0.16213827280404855 0.16570292877263126 0.1620403425991364


In [13]:
np.corrcoef([train_rmse, val_rmse, bench_rmse], [train_train_rmse, train_val_rmse, train_bench_rmse])

array([[1.        , 0.40608924],
       [0.40608924, 1.        ]])