In [1]:
import pickle
import re
import numpy as np
import sys
import os
from glob import glob
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch.nn import TripletMarginLoss
import importlib
import yaml
import matplotlib.pyplot as plt
from torch_geometric.nn import radius_graph

import torch
import torch.nn.functional as F

# holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
#                      "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
#                      "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
#                      "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
#                      "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
#                      "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
#                      "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
#                      "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
#                      "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

In [49]:
root_path      = '/xdisk/twheeler/jgaiser/deepvs3/deepvs/'
params_path    = root_path + 'params.yaml'
config_path    = root_path + 'config.yaml'
function_path  = root_path + 'code/utils/data_processing_utils.py'

# Validation PDB ids. They are flattened with `sum(validation_ids, [])` where used,
# so presumably a list of lists — confirm.
# NOTE: pickle can execute arbitrary code on load — only use trusted files.
with open('/xdisk/twheeler/jgaiser/deepvs3/training_data/validation_ids.pkl', 'rb') as validation_ids_file:
    validation_ids = pickle.load(validation_ids_file)

def load_class_from_file(file_path):
    """Execute the Python file at `file_path` and return the object named after its stem.

    The file stem (basename up to the first '.') is used both as the module name
    and as the attribute fetched, so ``.../VoxEncoder.py`` must define ``VoxEncoder``.

    Args:
        file_path: path to a Python source file.

    Returns:
        The attribute named after the file stem from the freshly executed module.

    Raises:
        ImportError: if no import spec can be built for `file_path` (e.g. unsupported suffix).
        AttributeError: if the module does not define the expected name.
    """
    # `import importlib` alone does not guarantee the `importlib.util`
    # submodule is loaded, so import it explicitly.
    import importlib.util

    class_name = os.path.basename(file_path).split(".")[0]
    spec = importlib.util.spec_from_file_location(class_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot build an import spec for %r" % (file_path,))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


def load_function_from_file(file_path):
    """Execute the Python file at `file_path` and return the function named after its stem.

    The module is registered in ``sys.modules`` under the file's basename
    (e.g. ``"data_processing_utils.py"``) before it is executed.

    Args:
        file_path: path to a Python source file.

    Returns:
        The attribute named after the file stem (basename up to the first '.').

    Raises:
        ImportError: if no import spec can be built for `file_path` (e.g. unsupported suffix).
        AttributeError: if the module does not define the expected name.
    """
    # `import importlib` alone does not guarantee `importlib.util` is loaded.
    import importlib.util

    base_name = os.path.basename(file_path)
    function_name = base_name.split(".")[0]
    spec = importlib.util.spec_from_file_location(base_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot build an import spec for %r" % (file_path,))
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return getattr(module, function_name)


# Run parameters and project configuration, both YAML.
with open(params_path, "r") as params_fh:
    params = yaml.safe_load(params_fh)

with open(config_path, "r") as config_fh:
    config = yaml.safe_load(config_fh)

ATOM_LABELS, INTERACTION_LABELS = config['POCKET_ATOM_LABELS'], config['INTERACTION_LABELS']

# File-name template for the training-sample pickles (its '%s' is globbed later).
training_sample_ft = params['data_dir'] + config['training_sample_file_template']

# Inverse-frequency class weights scaled by sum(freqs)/n_classes: each weight
# equals mean(freqs)/freq, so a class of average frequency gets weight 1.
pocket_class_freqs = torch.tensor(config['POCKET_LABEL_COUNT'])
pocket_class_weights = (1. / pocket_class_freqs) * pocket_class_freqs.sum() / len(pocket_class_freqs)

In [41]:
# Load every training-sample pickle and route whole files to the training or
# validation split, keyed on the pdb_id of each file's first graph.
# Flatten once into a set: the per-file `sum(validation_ids, [])` rebuilt the list every iteration.
validation_pdb_ids = set(sum(validation_ids, []))

# sorted() makes the file (and hence validation batch) order reproducible across runs.
training_sample_files = sorted(glob(training_sample_ft.replace('%s', '*')))
vox_training_data = []
vox_val_data = []

for graph_collection_file in training_sample_files:
    # NOTE: pickle can execute arbitrary code — only load trusted files.
    with open(graph_collection_file, 'rb') as graph_fh:
        graph_col = pickle.load(graph_fh)

    if len(graph_col) == 0:
        continue  # nothing to split; graph_col[0] below would fail

    # Add a leading dim to each label vector so batched `y` comes out one row per
    # graph (presumably y is 1-D per graph — confirm).
    for graph in graph_col:
        graph.y = graph.y.unsqueeze(0)

    if graph_col[0].pdb_id in validation_pdb_ids:
        vox_val_data.extend(graph_col)
    else:
        vox_training_data.extend(graph_col)

In [4]:
def batch_logit_accuracy(logits_batch, labels_batch):
    """Mean per-sample top-k hit rate for multi-label logits.

    For each sample with k positive labels (entries equal to 1), the k
    highest-scoring indices are taken. The sample scores (#of those indices that
    are positive) / k. Samples with no positive label have an undefined score
    and are skipped.

    Args:
        logits_batch: logits, one row per sample (iterated over its first dim).
        labels_batch: matching 0/1 targets, one row per sample.

    Returns:
        Python float: mean score over samples with at least one positive label,
        or NaN if there are no such samples.
    """
    scores = []
    for logits, labels in zip(logits_batch, labels_batch):
        positive = labels == 1
        num_pos = int(positive.sum().item())
        if num_pos == 0:
            continue  # hits / 0 is undefined (and raised ZeroDivisionError before)

        topk_indices = torch.topk(logits, num_pos).indices
        # Count set intersection of predicted and true indices. Comparing the two
        # sorted index lists position-by-position undercounts, e.g. {1,2} vs {0,1}.
        hits = positive[topk_indices].sum().item()
        scores.append(hits / num_pos)

    if not scores:
        return float('nan')
    return torch.tensor(scores).mean().item()
    

In [54]:
BATCH_SIZE = 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training batches are reshuffled every epoch; the validation order stays fixed.
train_loader, validation_loader = (
    DataLoader(vox_training_data, shuffle=True, batch_size=BATCH_SIZE),
    DataLoader(vox_val_data, shuffle=False, batch_size=BATCH_SIZE),
)

In [45]:
# Hand unused cached CUDA memory back to the driver (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [10]:
# Soften the class weights by pulling each one a fraction `delta_mod` of the way
# toward their mean: pos_weight = w + delta_mod * (mean(w) - w).
weight_delta = torch.mean(pocket_class_weights) - pocket_class_weights
delta_mod = 0.25

softened_pos_weight = pocket_class_weights + weight_delta * delta_mod
criterion = nn.BCEWithLogitsLoss(pos_weight=softened_pos_weight).to(device)
val_criterion = nn.BCEWithLogitsLoss().to(device)  # unweighted

# The encoder class is loaded dynamically from the file named in the config.
VoxEncoder = load_class_from_file(config['vox_embedder_model'] % root_path)
voxel_model = VoxEncoder(**config['vox_embedder_hyperparams']).to(device)

optimizer = torch.optim.Adam(voxel_model.parameters(), lr=1e-4)
sigmoid = nn.Sigmoid()

In [52]:
# Checkpoint path: the config template filled with the repo root minus its last character (the trailing '/').
vox_embedder_weights = config['vox_embedder_weights'] % (root_path[:-1],)
vox_embedder_weights

'/xdisk/twheeler/jgaiser/deepvs3/deepvs/models/weights/vox_embedder_7-17.m'

In [None]:
validation_loss_history = []
validation_accuracy_history = []

training_loss_history = []
training_accuracy_history = []

min_val_loss = float('inf')
max_val_accuracy = float('-inf')

N_EPOCHS = 100
LOG_EVERY = 1000        # batches between training-progress printouts
POS_JITTER_STD = 0.25   # std of Gaussian noise added to node coordinates (augmentation)
N_EXAMPLES_SHOWN = 3    # random samples whose predictions are printed at each log


def _format_row(values):
    """Format a sequence of floats as space-separated two-decimal numbers."""
    return ("%.2f " * len(values)) % tuple(values)


def _evaluate(model, loader):
    """Return (mean batch loss, mean batch accuracy) of `model` over `loader`.

    Uses the unweighted `val_criterion`; puts `model` back in train mode after.
    Returns (nan, nan) if `loader` yields no batches.
    """
    batch_losses = []
    batch_accuracies = []
    model.eval()
    with torch.no_grad():
        for val_batch in loader:
            val_batch = val_batch.to(device)
            _, val_preds = model(val_batch)
            val_preds = val_preds[val_batch.x[:, -1] == 1]
            batch_losses.append(val_criterion(val_preds, val_batch.y.float()).item())
            batch_accuracies.append(batch_logit_accuracy(val_preds, val_batch.y))
    model.train()

    if not batch_losses:
        return float('nan'), float('nan')
    return (sum(batch_losses) / len(batch_losses),
            sum(batch_accuracies) / len(batch_accuracies))


for epoch in range(N_EPOCHS):
    print("EPOCH %s" % epoch)
    window_losses = []  # detached per-batch losses since the last printout

    for batch_index, batch in enumerate(train_loader):
        batch = batch.to(device)
        # Augmentation: jitter coordinates with noise drawn directly on the batch's device.
        batch.pos += torch.randn_like(batch.pos) * POS_JITTER_STD

        _, interaction_preds = voxel_model(batch)
        # Only nodes whose last feature column is 1 are scored against batch.y.
        interaction_preds = interaction_preds[batch.x[:, -1] == 1]

        optimizer.zero_grad()
        loss = criterion(interaction_preds, batch.y.float())
        loss.backward()
        optimizer.step()

        # Keep detached scalars: vstack-ing live loss tensors retained autograd
        # nodes and re-copied the whole window on every batch.
        window_losses.append(loss.detach())

        if batch_index % LOG_EVERY == 0:
            window_loss = torch.stack(window_losses).mean().item()
            train_accuracy = batch_logit_accuracy(interaction_preds, batch.y)
            training_loss_history.append(window_loss)
            training_accuracy_history.append(train_accuracy)
            print("Loss: %s" % window_loss)
            print("Accuracy: %s" % train_accuracy)

            shown_probs = sigmoid(interaction_preds.detach())
            for i in torch.randperm(len(batch.y))[:N_EXAMPLES_SHOWN]:
                print(_format_row(shown_probs[i].tolist()))
                print(_format_row(batch.y[i].tolist()))
                print("")

            window_losses = []

    validation_loss, validation_accuracy = _evaluate(voxel_model, validation_loader)
    validation_loss_history.append(validation_loss)
    validation_accuracy_history.append(validation_accuracy)
    max_val_accuracy = max(max_val_accuracy, validation_accuracy)

    # Checkpoint whenever validation loss reaches a new minimum.
    if validation_loss < min_val_loss:
        min_val_loss = validation_loss
        torch.save(voxel_model.state_dict(), vox_embedder_weights)
        print('WEIGHTS UPDATED')

    print("VALIDATION LOSS:", validation_loss)
    print("VALIDATION ACC:", validation_accuracy)

EPOCH 0
Loss: 0.13817700743675232
Accuracy: 0.59375
0.05 0.02 0.97 0.01 0.00 0.00 0.00 0.00 0.00 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.10 0.07 0.03 0.30 0.11 0.02 0.38 0.05 0.07 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.02 0.00 0.00 0.85 0.00 0.01 0.04 0.01 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.2065134346485138
Accuracy: 0.625
0.01 0.52 0.07 0.00 0.00 0.00 0.00 0.27 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 1.00 0.00 

0.05 0.34 0.03 0.09 0.00 0.03 0.03 0.06 0.01 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.23 0.01 0.01 0.69 0.07 0.01 0.02 0.00 0.02 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.2073463797569275
Accuracy: 0.71875
0.01 0.03 0.02 0.14 0.19 0.51 0.59 0.29 0.03 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.23 0.01 0.62 0.08 0.00 0.00 0.00 0.00 0.00 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.00 0.00 0.00 0.93 0.00 0.08 0.03 0.00 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20068731904029846
A

Loss: 0.2025631219148636
Accuracy: 0.59375
0.04 0.11 0.07 0.07 0.39 0.00 0.04 0.01 0.26 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.12 0.21 0.22 0.07 0.00 0.01 0.01 0.02 0.01 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.08 0.58 0.06 0.01 0.00 0.01 0.00 0.16 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.19989173114299774
Accuracy: 0.59375
0.01 0.02 0.02 0.21 0.86 0.02 0.31 0.00 0.60 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.25 0.06 0.02 0.73 0.07 0.01 0.02 0.00 0.01 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.05 0.30 0.36 0.02 0.00 0.13 0.00 0.19 0.08 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20372061431407928
Accuracy: 0.65625
0.01 0.46 0.11 0.00 0.00 0.00 0.00 0.23 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.13 0.01 0.00 0.96 0.00 0.00 0.01 0.00 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.10 0.05 0.02 0.41 0.21 0.04 0.25 0.01 0.10 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20637142658233643
Accura

Loss: 0.19998304545879364
Accuracy: 0.59375
0.02 0.25 0.00 0.30 0.39 0.00 0.05 0.00 0.01 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.11 0.84 0.04 0.01 0.00 0.00 0.00 0.02 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.01 0.95 0.00 0.01 0.00 0.00 0.01 0.05 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.19484567642211914
Accuracy: 0.65625
0.04 0.18 0.73 0.01 0.00 0.01 0.00 0.04 0.03 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.05 0.16 0.45 0.02 0.52 0.00 0.01 0.01 0.24 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.36 0.10 0.12 0.18 0.00 0.00 0.00 0.00 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.19770392775535583
Accuracy: 0.78125
0.01 0.11 0.00 0.32 0.00 0.37 0.03 0.05 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.05 0.47 0.13 0.01 0.00 0.00 0.00 0.01 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.03 0.13 0.51 0.01 0.48 0.01 0.00 0.01 0.14 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20344004034996033
Accur

Loss: 0.2008899450302124
Accuracy: 0.59375
0.06 0.06 0.39 0.11 0.01 0.00 0.01 0.02 0.00 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.01 0.00 0.00 0.81 0.01 0.00 0.27 0.00 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.34 0.50 0.07 0.05 0.00 0.00 0.00 0.00 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20102132856845856
Accuracy: 0.71875
0.03 0.02 0.01 0.41 0.06 0.02 0.27 0.01 0.01 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.00 0.03 0.85 0.00 0.00 0.00 0.00 0.00 0.00 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.11 0.05 0.33 0.08 0.02 0.03 0.01 0.01 0.08 
0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.19812646508216858
Accuracy: 0.65625
0.03 0.02 0.00 0.40 0.02 0.73 0.07 0.07 0.09 
0.00 0.00 0.00 0.00 0.00 1.00 0.00 0.00 0.00 

0.24 0.24 0.61 0.02 0.00 0.01 0.00 0.04 0.01 
0.00 1.00 1.00 0.00 0.00 0.00 0.00 0.00 1.00 

0.04 0.87 0.02 0.02 0.00 0.01 0.02 0.01 0.00 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.20077234506607056
Accura

Loss: 0.19844287633895874
Accuracy: 0.6875
0.03 0.19 0.07 0.09 0.01 0.24 0.04 0.10 0.07 
0.00 1.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.05 0.33 0.24 0.04 0.00 0.00 0.00 0.02 0.04 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.01 0.57 0.33 0.00 0.00 0.00 0.00 0.04 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

WEIGHTS UPDATED
VALIDATION LOSS: 0.19260944316492362
VALIDATION ACC: 0.6433823529411765
EPOCH 11
Loss: 0.11959711462259293
Accuracy: 0.625
0.00 0.00 0.00 0.20 0.39 0.01 0.75 0.00 0.05 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.10 0.05 0.02 0.56 0.01 0.00 0.06 0.00 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.00 0.70 0.30 0.00 0.00 0.00 0.00 0.00 0.02 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

Loss: 0.19579488039016724
Accuracy: 0.71875
0.00 0.00 0.00 0.62 0.03 0.01 0.13 0.00 0.21 
0.00 0.00 0.00 1.00 0.00 0.00 0.00 0.00 0.00 

0.42 0.44 0.23 0.02 0.00 0.00 0.00 0.00 0.00 
0.00 1.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 

0.05 0.23 0.37 0.04 0.01 0.17 

In [80]:
# Scratch look at a stronger (std 0.5) positional jitter. Work on a copy so the
# last training `batch` is not silently modified for later cells.
jittered_pos = batch.pos + torch.randn_like(batch.pos) * 0.5
jittered_pos

tensor([[ 54.3097,   8.6678,  50.1004],
        [ 55.5176,  11.4416,  48.5374],
        [ 56.4517,  11.5105,  46.6047],
        ...,
        [-35.1678,  85.6733,  34.5262],
        [-34.8897,  83.2536,  29.7269],
        [-30.2410,  91.6662,  29.4116]], device='cuda:0')

In [None]:
# 5 
# VALIDATION LOSS: 0.35250553488731384
# VALIDATION ACC: 0.4176136255264282
# EPOCH 24

#10
# VALIDATION LOSS: 0.3194412887096405
# VALIDATION ACC: 0.521842360496521
# EPOCH 61

#10
# VALIDATION LOSS: ----
# VALIDATION ACC: 0.54
# EPOCH 61

# 10 batch size 32
# WEIGHTS UPDATED
# VALIDATION LOSS: 0.23447826504707336
# VALIDATION ACC: 0.5934844017028809
# EPOCH 88





In [None]:
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GCN2Conv

def handle_data(batch):
    """Turn a PyG batch into the (x, edge_index, edge_weight) triple used by `GCN`.

    `batch.beta / 100` is appended to `batch.x` as one extra feature column, and
    `batch.edge_attr / 12` is reshaped into a column vector.
    """
    scaled_beta = (batch.beta / 100).unsqueeze(1)
    node_features = torch.cat((batch.x, scaled_beta), dim=1)
    scaled_edges = batch.edge_attr.unsqueeze(1) / 12
    return node_features, batch.edge_index, scaled_edges

class GCN(torch.nn.Module):
    """Three-layer GCNII node encoder.

    `data_transform` turns the input batch into (x, edge_index, edge_weight).
    x is projected to `hidden_dim`. Three GCN2Conv layers (alpha=0.25) then each
    take the current state plus that initial projection. Each layer is followed
    by batch norm, ReLU and dropout (p=0.25). A final linear layer maps hidden
    states to `out_dim` per node.

    forward returns (hidden node states, per-node outputs).

    NOTE(review): one BatchNorm module is reused after all three convs, so its
    running statistics are shared across layers; kept because saved checkpoints
    depend on these parameter names.
    """

    def __init__(self, feature_dim, hidden_dim, out_dim, data_transform):
        super().__init__()
        self.data_transform = data_transform
        self.linear1 = torch.nn.Linear(feature_dim, hidden_dim)

        # Attribute names conv1..conv3 are state_dict keys — do not rename.
        for layer_name in ("conv1", "conv2", "conv3"):
            setattr(self, layer_name, GCN2Conv(hidden_dim, 0.25))

        self.dropout = torch.nn.Dropout(p=0.25)
        self.batchnorm = torch_geometric.nn.norm.BatchNorm(hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        raw_x, edge_index, edge_weights = self.data_transform(data)
        x0 = self.linear1(raw_x)

        h = x0
        for conv in (self.conv1, self.conv2, self.conv3):
            h = conv(h, x0, edge_index, edge_weights)
            h = self.dropout(F.relu(self.batchnorm(h)))

        return h, self.linear2(h)

voxel_model = GCN(39, 512, 9, data_transform=handle_data)

# Smoke test: push one batch from the first sample file through the GCN and
# print the shape of its hidden node states.
# NOTE(review): `sample_col` is not defined anywhere in this notebook and must be
# created first. This cell also rebinds `voxel_model`, replacing the encoder
# trained above.
for sample_file in sample_col:
    # NOTE: pickle can execute arbitrary code — only load trusted files.
    with open(sample_file, 'rb') as sample_fh:
        sample_collection = DataLoader(pickle.load(sample_fh), batch_size=32, shuffle=True)

    first_batch = next(iter(sample_collection), None)
    if first_batch is not None:
        batch = first_batch
        out = voxel_model(batch)
        print(out[0].shape)
    break  # only the first sample file is inspected

In [None]:
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

class ME(AttentiveFP):
    """AttentiveFP variant that also returns per-atom embeddings.

    Must be constructed with `hidden_channels` passed by keyword (it is read from
    kwargs). `forward` takes a PyG batch and returns
    ``(atom_classifier(final atom states), lin2(pooled molecule state))``.
    `atom_classifier` is a hidden_channels -> 512 linear layer.

    NOTE(review): forward treats ``atom_convs[0]`` / ``atom_grus[0]`` as the
    edge-aware first layer, which matches the parent's module layout in the PyG
    version this was written against — confirm with the installed version.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.atom_classifier = nn.Linear(kwargs['hidden_channels'], 512)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        def gated_update(conv_out, gru, state):
            # ELU -> dropout -> GRU(message, previous state) -> ReLU.
            conv_out = F.dropout(F.elu_(conv_out), p=self.dropout, training=self.training)
            return gru(conv_out, state).relu_()

        # Atom embedding; only the first conv receives edge attributes.
        x = F.leaky_relu_(self.lin1(x))
        x = gated_update(self.atom_convs[0](x, edge_index, edge_attr), self.atom_grus[0], x)
        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            x = gated_update(conv(x, edge_index), gru, x)

        # Molecule embedding: every atom node gets an edge to its graph's readout slot.
        node_ids = torch.arange(batch.size(0), device=batch.device)
        atom_to_mol = torch.stack([node_ids, batch], dim=0)

        mol_state = global_add_pool(x, batch).relu_()
        for _ in range(self.num_timesteps):
            mol_state = gated_update(self.mol_conv((x, mol_state), atom_to_mol), self.mol_gru, mol_state)

        # Predictor.
        mol_state = F.dropout(mol_state, p=self.dropout, training=self.training)
        return self.atom_classifier(x), self.lin2(mol_state)
    
mol_model = ME(in_channels=38,
               hidden_channels=512,
               out_channels=512,
               edge_dim=1,
               num_layers=3,
               num_timesteps=3,
               dropout=0.5)

# Debug peek: print up to 10 randomly chosen heavy-atom node indices from the
# molecules matching the first batch of `sample_collection` (set by an earlier cell).
# NOTE(review): `get_mol_batch` and `mol_dict` are not defined in this notebook.
# The original lines after `break` (forward pass / ptr print) were unreachable
# and have been removed.
first_batch = next(iter(sample_collection), None)
if first_batch is not None:
    batch = first_batch
    mol_batch = get_mol_batch(mol_dict, batch.pdb_id)
    heavy_indices = torch.where(mol_batch.heavy == 1)[0]
    # Permute by the number of matches (not sum(heavy)) so indexing can never overrun.
    print(heavy_indices[torch.randperm(heavy_indices.numel(), device=heavy_indices.device)][:10])

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mol_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/mol_embedder/mol_embedder_holdout_3-20.m"
vox_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/vox_embedder/vox_embedder_holdout_3-20.m"

mol_model = ME(in_channels=38,
               hidden_channels=512,
               out_channels=512,
               edge_dim=1,
               num_layers=3,
               num_timesteps=3,
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

# Warm-start both encoders; map_location lets GPU-saved checkpoints load on CPU.
mol_model.load_state_dict(torch.load(mol_weights, map_location=device))
vox_model.load_state_dict(torch.load(vox_weights, map_location=device))

EPOCHS = 1000
triplet_loss1 = TripletMarginLoss(margin=1.0, p=2).to(device)
triplet_loss2 = TripletMarginLoss(margin=1.0, p=2).to(device)
v_optimizer = torch.optim.Adam(vox_model.parameters(), lr=1e-6)
m_optimizer = torch.optim.Adam(mol_model.parameters(), lr=1e-6)

# NOTE(review): `sample_col`, `holdout_complexes` (commented out in the first
# cell), `get_mol_batch`, `mol_dict` and `voxel_label_index` are not defined in
# this notebook and must be provided before this cell can run.
holdout_set = set(holdout_complexes)
min_loss = float('inf')


def _sample_negative_atoms(mol_batch, n):
    """Return `n` random heavy-atom node indices of `mol_batch`, or None if it has none.

    Samples without replacement when there are at least `n` heavy atoms,
    otherwise with replacement, so the result always has length `n`.
    """
    heavy = torch.where(mol_batch.heavy == 1)[0]
    if heavy.numel() == 0:
        return None
    if heavy.numel() >= n:
        pick = torch.randperm(heavy.numel(), device=heavy.device)[:n]
    else:
        pick = torch.randint(heavy.numel(), (n,), device=heavy.device)
    return heavy[pick]


for epoch in range(EPOCHS):
    print("EPOCH %s" % epoch)
    epoch_loss = []

    for sample_path in sample_col:
        # NOTE: pickle can execute arbitrary code — only load trusted files.
        with open(sample_path, 'rb') as sample_fh:
            sample_graphs = pickle.load(sample_fh)

        filtered_sample_graphs = [g for g in sample_graphs if g.pdb_id not in holdout_set]
        sample_collection = DataLoader(filtered_sample_graphs, batch_size=64, shuffle=True)
        sample_col_loss1 = []
        sample_col_loss2 = []
        sample_col_loss = []

        for vox_batch in sample_collection:
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)
            voxel_indices = torch.where(vox_batch.x[:, voxel_label_index] == 1)[0]

            # contact_map is indexed alongside voxel_indices and offset by ptr[:-1]
            # (one entry per molecule), so this assumes one labelled voxel per graph
            # — TODO confirm. -1 marks "no contact".
            true_contacts = torch.where(vox_batch.contact_map != -1)[0]
            n_contacts = true_contacts.numel()
            if n_contacts == 0:
                continue  # empty triplets would give a NaN loss

            random_mol_indices = _sample_negative_atoms(mol_batch, n_contacts)
            if random_mol_indices is None:
                continue

            contact_indices = (vox_batch.contact_map + mol_batch.ptr[:-1])[true_contacts]
            voxel_indices = voxel_indices[true_contacts]
            shuffled_voxel_indices = voxel_indices[torch.randperm(n_contacts, device=voxel_indices.device)]

            # Gradients must be cleared every step; they used to accumulate forever.
            v_optimizer.zero_grad()
            m_optimizer.zero_grad()

            vox_out, _ = vox_model(vox_batch)
            mol_out, _ = mol_model(mol_batch)

            # Positives pair each contact voxel with its contacted atom. Negatives
            # are random heavy atoms (voxel anchors) and shuffled voxels (atom anchors).
            vox_pos = vox_out[voxel_indices]
            vox_neg = vox_out[shuffled_voxel_indices]
            mol_pos = mol_out[contact_indices]
            mol_neg = mol_out[random_mol_indices]

            loss1 = triplet_loss1(vox_pos, mol_pos, mol_neg)
            loss2 = triplet_loss2(mol_pos, vox_pos, vox_neg)
            loss = loss1 + loss2
            if not torch.isfinite(loss):
                continue  # don't let a NaN/inf step corrupt the weights

            sample_col_loss1.append(loss1.item())
            sample_col_loss2.append(loss2.item())
            sample_col_loss.append(loss.item())
            loss.backward()
            v_optimizer.step()
            m_optimizer.step()

        if not sample_col_loss:
            continue  # no usable batches in this file

        avg_col_loss1 = sum(sample_col_loss1) / len(sample_col_loss1)
        avg_col_loss2 = sum(sample_col_loss2) / len(sample_col_loss2)
        avg_col_loss = sum(sample_col_loss) / len(sample_col_loss)
        epoch_loss.append(avg_col_loss)
        print(avg_col_loss1, avg_col_loss2, avg_col_loss)

    if not epoch_loss:
        print("Epoch %s: no valid batches" % epoch)
        continue

    avg_epoch_loss = sum(epoch_loss) / len(epoch_loss)
    print("Epoch %s loss: %s" % (epoch, avg_epoch_loss))

    if avg_epoch_loss < min_loss:
        min_loss = avg_epoch_loss
        torch.save(mol_model.state_dict(), mol_weights)
        torch.save(vox_model.state_dict(), vox_weights)
        print("Weights Updated")


In [None]:
# Persist the current weights of both encoders to their checkpoint paths.
for model_to_save, checkpoint_path in ((mol_model, mol_weights), (vox_model, vox_weights)):
    torch.save(model_to_save.state_dict(), checkpoint_path)

In [None]:
# NaN never compares equal to itself, so this cell displays False.
num = float("nan")
(num == num)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild both encoders with the same hyperparameters the contrastive cell uses.
mol_model = ME(in_channels=38,
               hidden_channels=512,
               out_channels=512,
               edge_dim=1,
               num_layers=3,
               num_timesteps=3,
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
mol_model.load_state_dict(torch.load(mol_weights, map_location=device))
vox_model.load_state_dict(torch.load(vox_weights, map_location=device))

# Inference mode: disables dropout and uses stored batch-norm statistics.
mol_model.eval()
vox_model.eval()


In [None]:
# Smoke test: embed the first batch of the first sample file with both encoders
# and print the voxel encoder's first node embedding.
# NOTE(review): `sample_col`, `get_mol_batch` and `mol_dict` are not defined in this notebook.
with torch.no_grad():
    for sample_file in sample_col:
        # NOTE: pickle can execute arbitrary code — only load trusted files.
        with open(sample_file, 'rb') as sample_fh:
            sample_collection = DataLoader(pickle.load(sample_fh), batch_size=64, shuffle=True)

        vox_batch = next(iter(sample_collection), None)
        if vox_batch is not None:
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)

            vox_out, _ = vox_model(vox_batch)
            mol_out, _ = mol_model(mol_batch)

            print(vox_out[0])
        break  # only the first sample file is inspected