In [None]:
import pickle
import re
import numpy as np
import sys
import os
from glob import glob
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch.nn import TripletMarginLoss
import importlib
import yaml
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

In [2]:
holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
                     "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
                     "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
                     "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
                     "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
                     "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
                     "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
                     "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
                     "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

In [3]:
params_path = '/xdisk/twheeler/jgaiser/deepvs3/deepvs_params.yaml'
config_path = '/xdisk/twheeler/jgaiser/deepvs3/deepvs/config.yaml'
root_path = '/xdisk/twheeler/jgaiser/deepvs3/deepvs/'
function_path = root_path + 'code/utils/data_processing_utils.py'

def load_class_from_file(file_path):
    class_name = file_path.split("/")[-1].split(".")[0]
    spec = importlib.util.spec_from_file_location(class_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


def load_function_from_file(file_path):
    function_name = file_path.split("/")[-1].split(".")[0]
    spec = importlib.util.spec_from_file_location(
        os.path.basename(file_path), file_path
    )
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return getattr(module, function_name) 


with open(params_path, "r") as param_file:
    params = yaml.safe_load(param_file)
    
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

In [4]:
ATOM_LABELS = config['POCKET_ATOM_LABELS']
MOL_ATOM_LABELS = config['MOL_ATOM_LABELS']
EDGE_LABELS = config['POCKET_EDGE_LABELS']
INTERACTION_LABELS = config['INTERACTION_LABELS']

mol_graph_ft = params['data_dir'] + config['mol_graph_file_template']
training_sample_ft = params['data_dir'] + config['training_sample_file_template']

In [5]:
pocket_class_freqs = torch.tensor(config['POCKET_LABEL_COUNT'])
mol_class_freqs = torch.tensor(config['MOL_LABEL_COUNT'])

pocket_class_weights = 1./pocket_class_freqs
pocket_class_weights = pocket_class_weights * pocket_class_freqs.sum() / len(pocket_class_freqs)

mol_class_weights = 1./mol_class_freqs
mol_class_weights = mol_class_weights * mol_class_freqs.sum() / len(mol_class_freqs)

In [9]:
mol_dict = {}
mol_graph_files = glob(mol_graph_ft.replace('%s','*'))

for graph_file in mol_graph_files:
    pdb_id = graph_file.split('/')[-1].split('_')[0]
    g = pickle.load(open(graph_file, 'rb'))
    g.pdb_id = pdb_id
    mol_dict[pdb_id] = g

In [10]:
def get_vox_batch(vox_dict, id_list):
    batch_a = []
    batch_b = []
    
    for pdb_id in id_list:
        sample_a, sample_b = random.choices(vox_dict[pdb_id], k=2)
        batch_a.append(sample_a)
        batch_b.append(sample_b)
        
    return (next(iter(DataLoader(batch_a, 
                                shuffle=False, 
                                batch_size=len(id_list)))), 
            
            next(iter(DataLoader(batch_b,                                   
                                 shuffle=False,
                                 batch_size=len(id_list)))))


def get_mol_batch(mol_collection, pdb_ids):
    return next(iter(DataLoader([mol_collection[x] for x in pdb_ids], 
                                shuffle=False, 
                                batch_size = len(pdb_ids))))    


def get_batch_indices(vox_batch, mol_batch):
    vox_interaction_indices = torch.where(torch.sum(vox_batch.y, dim=1) != 0)[0]
    vox_contact_indices = torch.where(vox_batch.contact_map != -1)[0]
    
    mol_interaction_indices = torch.where(torch.sum(mol_batch.y, dim=1) != 0)[0]
    mol_contact_indices = mol_batch.ptr[vox_contact_indices] + vox_batch.contact_map[vox_contact_indices]
    
    return (vox_interaction_indices, 
            vox_contact_indices, 
            mol_interaction_indices, 
            mol_contact_indices)


def get_shuffled_mol_indices(mol_batch):
    shuffled_indices = []
    
    for ptr_i in range(1, mol_batch.ptr.size(0)):
        mol_start = mol_batch.ptr[ptr_i-1]
        mol_stop = mol_batch.ptr[ptr_i]
        
        mol_indices = torch.arange(mol_start, mol_stop)
        heavy_indices = mol_indices[mol_batch.heavy[mol_indices]==1]
        heavy_indices = heavy_indices[torch.randperm(heavy_indices.size(0))]
        light_indices = mol_indices[mol_batch.heavy[mol_indices]==0]
        
        shuffled_indices.extend([heavy_indices, light_indices])
    
    return torch.hstack(shuffled_indices)

In [34]:
BATCH_SIZE = 64 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

criterion1 = nn.BCEWithLogitsLoss(pos_weight=pocket_class_weights).to(device)
criterion2 = nn.BCEWithLogitsLoss(pos_weight=mol_class_weights).to(device)

criterion3 = TripletMarginLoss(margin=0.5, p=2).to(device)
criterion4 = TripletMarginLoss(margin=0.5, p=2).to(device)

VoxEncoder = load_class_from_file(params['vox_encoder_model'])
vox_data_transform = load_function_from_file(params['vox_encoder_data_transform'])
params['vox_encoder_hyperparams']['data_transform'] = vox_data_transform
voxel_model = VoxEncoder(**params['vox_encoder_hyperparams']).to(device)

MolEncoder = load_class_from_file(params['mol_encoder_model'])
params['mol_encoder_hyperparams']['data_transform'] = None
mol_model = MolEncoder(**params['mol_encoder_hyperparams']).to(device)

v_optimizer = torch.optim.Adam(voxel_model.parameters(), lr=1e-3)
m_optimizer = torch.optim.Adam(mol_model.parameters(), lr=1e-3)

In [30]:
VOX_COUNT = 0
for k,v in training_sample_dict.items():
     VOX_COUNT += len(v)
VOX_COUNT

582505

In [None]:
for epoch in range(200):
    print("EPOCH %s" % epoch)
    loss_history = None
    
    for batch_index in range(int(VOX_COUNT/BATCH_SIZE)):
        random_id_list = random.choices(pdb_ids, k=BATCH_SIZE)
        
        vox_batch_a, vox_batch_b = get_vox_batch(training_sample_dict, random_id_list)
        mol_batch = get_mol_batch(mol_dict, random_id_list)

        vox_a_interxn_i, vox_a_contact_i, mol_interxn_i, mol_a_contact_i = get_batch_indices(vox_batch_a, mol_batch) 
        vox_b_interxn_i, _, _, _ = get_batch_indices(vox_batch_b, mol_batch) 

        vox_embed_a, vox_pred_a = voxel_model(vox_batch_a.to(device))
        vox_embed_b, vox_pred_b = voxel_model(vox_batch_b.to(device))
        mol_embed, mol_pred, _ = mol_model(mol_batch.to(device))

        atom_preds = torch.vstack((vox_pred_a[vox_a_interxn_i], vox_pred_b[vox_b_interxn_i]))
        atom_labels = torch.vstack((vox_batch_a.y[vox_a_interxn_i], vox_batch_b.y[vox_b_interxn_i])).float()

        l1 = criterion1(atom_preds, atom_labels)
        l2 = criterion2(mol_pred[mol_interxn_i], mol_batch.y[mol_interxn_i])
        l3 = criterion3(vox_embed_a[vox_a_contact_i], 
                   mol_embed[mol_a_contact_i],
                   mol_embed[get_shuffled_mol_indices(mol_batch)[mol_a_contact_i]])

        l4 = criterion4(mol_embed[mol_a_contact_i], vox_embed_a[vox_a_contact_i], vox_embed_b[vox_a_contact_i])
        
#         loss = l1+l2+5*l3+5*l4
        loss = l3+l4
        
        loss.backward()
        
        v_optimizer.step()
        m_optimizer.step()
        
        l_tensor = torch.hstack([l1,l2,l3,l4]).unsqueeze(0)
        
        if loss_history is None:
            loss_history = l_tensor
        else:
            loss_history = torch.vstack((loss_history, l_tensor))
            
        if batch_index % 1000 == 0:
            print(torch.mean(loss_history, dim=0))
            loss_history = l_tensor

EPOCH 0
tensor([0.7429, 0.7346, 0.4932, 0.5037], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7234, 0.6970, 0.5008, 0.4993], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7213, 0.6984, 0.5000, 0.5019], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7251, 0.7101, 0.5002, 0.4994], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7269, 0.7115, 0.4994, 0.4987], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7294, 0.7149, 0.5001, 0.4977], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7281, 0.7140, 0.4998, 0.4995], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7284, 0.7145, 0.4999, 0.4986], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7277, 0.7143, 0.4997, 0.4965], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7274, 0.7145, 0.4995, 0.4976], device='cuda:0',
       grad_fn=<MeanBackward1>)
EPOCH 1
tensor([0.7103, 0.6994, 0.4705, 0.4488], device='cuda:0',
       grad_fn=<MeanBackwa

tensor([0.7311, 0.7140, 0.4993, 0.5008], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7331, 0.7142, 0.4997, 0.4997], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7314, 0.7139, 0.4993, 0.4971], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([0.7301, 0.7146, 0.4999, 0.5018], device='cuda:0',
       grad_fn=<MeanBackward1>)


In [123]:
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GCN2Conv

def handle_data(batch):
    beta = batch.beta/100 
    x = torch.hstack((batch.x, beta.unsqueeze(1)))
    edge_attr = batch.edge_attr.unsqueeze(1) / 12
    return x, batch.edge_index, edge_attr

class GCN(torch.nn.Module):
    def __init__(self, feature_dim, hidden_dim, out_dim, data_transform):
        super().__init__()
        self.data_transform = data_transform
        self.linear1 = torch.nn.Linear(feature_dim, hidden_dim)
        
        self.conv1 = GCN2Conv(hidden_dim, 0.25)
        self.conv2 = GCN2Conv(hidden_dim, 0.25)
        self.conv3 = GCN2Conv(hidden_dim, 0.25)

        self.dropout = torch.nn.Dropout(p=0.25)
        self.batchnorm = torch_geometric.nn.norm.BatchNorm(hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        x, edge_index, edge_weights = self.data_transform(data)

        x = self.linear1(x)

        h = self.conv1(x, x, edge_index, edge_weights)
        h = self.batchnorm(h)
        h = self.dropout(F.relu(h))

        h = self.conv2(h, x, edge_index, edge_weights)
        h = self.batchnorm(h)
        h = self.dropout(F.relu(h))

        h = self.conv3(h, x, edge_index, edge_weights)
        h = self.batchnorm(h)
        h = self.dropout(F.relu(h))
        
        o = self.linear2(h)
        return h,o

voxel_model = GCN(39, 512, 9, data_transform=handle_data)

for sample_file in sample_col:
    sample_collection = DataLoader(pickle.load(open(sample_file, 'rb')), batch_size=32, shuffle=True)
    
    for batch in sample_collection:
        out = voxel_model(batch)
        print(out[0].shape)
        break
    break

NameError: name 'sample_col' is not defined

In [None]:
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

class ME(AttentiveFP):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.atom_classifier = nn.Linear(kwargs['hidden_channels'], 512)
        
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.atom_classifier(x), self.lin2(out)
    
mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5)

for batch in sample_collection:
    mol_batch = get_mol_batch(mol_dict, batch.pdb_id)
    print(torch.where(mol_batch.heavy == 1)[0][torch.randperm(torch.sum(mol_batch.heavy))][:10])
    break
    print(mol_batch.ptr)
    out = mol_model(mol_batch)
    print(out[0].shape)
    break

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device='cpu'

mol_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/mol_embedder/mol_embedder_holdout_3-20.m"
vox_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/vox_embedder/vox_embedder_holdout_3-20.m"

mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

mol_model.load_state_dict(torch.load(mol_weights))
vox_model.load_state_dict(torch.load(vox_weights))

EPOCHS = 1000
triplet_loss1 = TripletMarginLoss(margin=1.0, p=2).to(device)
triplet_loss2 = TripletMarginLoss(margin=1.0, p=2).to(device)
v_optimizer = torch.optim.Adam(vox_model.parameters(), lr=1e-6)
m_optimizer = torch.optim.Adam(mol_model.parameters(), lr=1e-6)

min_loss = 999

for epoch in range(EPOCHS):
    print("EPOCH %s" % epoch)
    epoch_loss = []
    
    for sample_file in sample_col:
        sample_file = pickle.load(open(sample_file, 'rb'))
        filtered_sample_graphs = []
        
        for sample_graph in sample_file:
            if sample_graph.pdb_id not in holdout_complexes:
                filtered_sample_graphs.append(sample_graph)
            
        sample_collection = DataLoader(filtered_sample_graphs, batch_size=64, shuffle=True)
        sample_col_loss1 = []
        sample_col_loss2 = []
        sample_col_loss = []

        for batch_i, vox_batch in enumerate(sample_collection):
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)
            voxel_indices = torch.where(vox_batch.x[:, voxel_label_index]==1)[0]

            true_contacts = torch.where(vox_batch.contact_map != -1)[0]
            contact_indices = vox_batch.contact_map + mol_batch.ptr[:-1]

            contact_indices = contact_indices[true_contacts]
            voxel_indices = voxel_indices[true_contacts]
            shuffled_voxel_indices = voxel_indices[torch.randperm(voxel_indices.size(0))]
            
#             random_mol_indices = torch.randint(0, mol_batch.x.size(0), (true_contacts.size(0),))
            random_mol_indices = torch.where(mol_batch.heavy == 1)[0]
            random_mol_indices = random_mol_indices[torch.randperm(torch.sum(mol_batch.heavy))]
            random_mol_indices = random_mol_indices[:true_contacts.size(0)]

            vox_out,_ = vox_model(vox_batch)
            mol_out,_ = mol_model(mol_batch)
            
            loss1 = triplet_loss1(vox_pos, mol_pos, mol_neg)
            loss2 = triplet_loss2(mol_pos, vox_pos, vox_neg)
            loss = loss1+loss2
            
            sample_col_loss1.append(loss1.item())
            sample_col_loss2.append(loss2.item())
            sample_col_loss.append(loss.item())
            loss.backward()
            v_optimizer.step()
            m_optimizer.step()

        avg_col_loss1 = sum(sample_col_loss1) / len(sample_col_loss1)
        avg_col_loss2 = sum(sample_col_loss2) / len(sample_col_loss2)
        avg_col_loss = sum(sample_col_loss) / len(sample_col_loss)
        
        if avg_col_loss==avg_col_loss:
            epoch_loss.append(avg_col_loss)
            
        print(avg_col_loss1, avg_col_loss2, avg_col_loss)

    avg_epoch_loss = sum(epoch_loss)/len(epoch_loss)
    print("Epoch %s loss: %s" % (epoch, avg_epoch_loss))
    
    if avg_epoch_loss < min_loss:
        min_loss = avg_epoch_loss
        torch.save(mol_model.state_dict(), mol_weights)
        torch.save(vox_model.state_dict(), vox_weights)
        print("Weights Updated")


In [None]:
torch.save(mol_model.state_dict(), mol_weights)
torch.save(vox_model.state_dict(), vox_weights)

In [None]:
num = float('nan')
num==num

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

mol_model.load_state_dict(torch.load(mol_weights))
vox_model.load_state_dict(torch.load(vox_weights))

mol_model.eval()
vox_model.eval()


In [None]:
with torch.no_grad():
    for sample_file in sample_col:
        sample_collection = DataLoader(pickle.load(open(sample_file, 'rb')), batch_size=64, shuffle=True)
        sample_col_loss = []

        for batch_i, vox_batch in enumerate(sample_collection):
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)
            voxel_indices = torch.where(vox_batch.x[:, voxel_label_index]==1)[0]

            vox_out,_ = vox_model(vox_batch)
            mol_out,_ = mol_model(mol_batch)
            
            print(vox_out[0])
            
            break
        break