In [1]:
import pickle
import re
import numpy as np
import sys
import os
from glob import glob
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch.nn import TripletMarginLoss

import torch
import torch.nn.functional as F

In [2]:
# Identifiers of the held-out complexes. Sample graphs whose `pdb_id` appears in
# this list are skipped when the embedding data points are collected further down.
holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
                     "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
                     "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
                     "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
                     "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
                     "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
                     "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
                     "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
                     "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

In [3]:
# Root directory of the project data. Every path below is built from it by plain
# string concatenation, so it has to keep its trailing slash.
root = '/xdisk/twheeler/jgaiser/deepvs2/deepvs/'

# Read the project configuration from the YAML file under `root`.
with open(root + 'config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file) 

# Label vocabularies taken from the `constants` -> `pocket` section of the config.
ATOM_LABELS = config['constants']['pocket']['HEAVY_ATOM_LABELS']
EDGE_LABELS = config['constants']['pocket']['EDGE_LABELS']
INTERACTION_LABELS = config['constants']['pocket']['INTERACTION_LABELS']

# Position of the 'VOXEL' entry in ATOM_LABELS; later cells use it as a column
# index into the node feature matrix to find voxel nodes.
voxel_label_index = ATOM_LABELS.index('VOXEL')

# Directories searched for pickled voxel sample partitions and molecule graphs.
voxel_dir = root + "data/training_data/graph_data/training_samples/1.0_angstroms/partitions/"
mol_dir = root + 'data/training_data/graph_data/molecules/'

# %-style file name templates for a sample partition and a per-complex molecule graph.
sample_col_ft = voxel_dir + "training_samples_%s.pkl"
mol_ft = mol_dir + "%s_mol.pkl"

In [4]:
# Build a lookup from complex identifier to its pickled molecule graph.
# NOTE: pickle can execute arbitrary code while loading; only read trusted files.
mol_dict = {}
mol_graph_files = glob(mol_dir + "*.pkl")

for graph_file in mol_graph_files:
    # File names follow "<pdb_id>_mol.pkl"; the identifier is the part before
    # the first underscore of the base name.
    pdb_id = os.path.basename(graph_file).split('_')[0]
    # The context manager closes each file as soon as its graph is loaded,
    # instead of leaking one open handle per molecule.
    with open(graph_file, 'rb') as graph_handle:
        mol_dict[pdb_id] = pickle.load(graph_handle)

In [5]:
# glob returns matches in arbitrary, filesystem-dependent order. The cells below
# stop after the first partition, so sort the paths to make that choice reproducible.
sample_col = sorted(glob(voxel_dir + "*.pkl"))

In [6]:
def get_mol_batch(mol_collection, pdb_ids):
    """Collate the molecule graphs for ``pdb_ids`` into a single batch.

    Args:
        mol_collection: Mapping from complex identifier to its molecule graph.
        pdb_ids: Identifiers to look up, in the order they should be batched.

    Returns:
        The first (and only) batch produced by an unshuffled ``DataLoader``
        whose batch size equals ``len(pdb_ids)``.
    """
    selected_graphs = [mol_collection[pdb_id] for pdb_id in pdb_ids]
    single_batch_loader = DataLoader(
        selected_graphs,
        shuffle=False,
        batch_size=len(pdb_ids),
    )
    return next(iter(single_batch_loader))

In [7]:
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import GCN2Conv

def handle_data(batch):
    """Turn a batch into the ``(x, edge_index, edge_attr)`` triple used by ``GCN``.

    Args:
        batch: Object exposing ``x``, ``beta``, ``edge_index`` and ``edge_attr``
            tensors.

    Returns:
        A tuple ``(x, edge_index, edge_attr)`` where ``x`` is ``batch.x`` with
        ``batch.beta / 100`` appended as one extra column, ``edge_index`` is
        passed through untouched, and ``edge_attr`` is ``batch.edge_attr``
        divided by 12 with a trailing dimension of size one added.
    """
    # Scale beta and make it a column so it can be stacked next to the features.
    beta_column = (batch.beta / 100).unsqueeze(1)
    node_features = torch.hstack((batch.x, beta_column))

    # Scale edge attributes and give them a trailing dimension of size one.
    scaled_edge_attr = batch.edge_attr.unsqueeze(1) / 12

    return node_features, batch.edge_index, scaled_edge_attr

class GCN(torch.nn.Module):
    """Graph encoder: linear projection, three GCN2Conv layers, linear head.

    Args:
        feature_dim: Width of the input node features produced by ``data_transform``.
        hidden_dim: Width of the hidden node embeddings.
        out_dim: Width of the output produced by the final linear layer.
        data_transform: Callable mapping the ``forward`` input to a
            ``(x, edge_index, edge_weights)`` triple.
    """

    def __init__(self, feature_dim, hidden_dim, out_dim, data_transform):
        super().__init__()
        self.data_transform = data_transform
        self.linear1 = torch.nn.Linear(feature_dim, hidden_dim)

        # Attribute names and creation order are kept as they are so that
        # previously saved state dicts keep loading.
        alpha = 0.25
        self.conv1 = GCN2Conv(hidden_dim, alpha)
        self.conv2 = GCN2Conv(hidden_dim, alpha)
        self.conv3 = GCN2Conv(hidden_dim, alpha)

        self.dropout = torch.nn.Dropout(p=0.25)
        # One normalisation module is shared by all three convolution layers.
        self.batchnorm = torch_geometric.nn.norm.BatchNorm(hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        """Encode ``data``.

        Returns:
            A tuple ``(h, o)``: the hidden embeddings after the third
            convolution block, and ``linear2`` applied to them.
        """
        node_features, edge_index, edge_weights = self.data_transform(data)

        # The projected input is passed to every convolution as its second argument.
        initial = self.linear1(node_features)

        hidden = initial
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden, initial, edge_index, edge_weights)
            hidden = self.batchnorm(hidden)
            hidden = F.relu(hidden)
            hidden = self.dropout(hidden)

        return hidden, self.linear2(hidden)

# Smoke test: push one shuffled batch from the first sample partition through an
# untrained voxel model and print the shape of the hidden embeddings.
voxel_model = GCN(39, 512, 9, data_transform=handle_data)

for sample_file in sample_col:
    # Context manager closes the partition file once it has been unpickled.
    # NOTE: pickle can execute arbitrary code while loading; only read trusted files.
    with open(sample_file, 'rb') as sample_handle:
        sample_graphs = pickle.load(sample_handle)
    # `sample_collection` is reused by the molecule smoke test in the next cell.
    sample_collection = DataLoader(sample_graphs, batch_size=32, shuffle=True)

    for batch in sample_collection:
        out = voxel_model(batch)
        print(out[0].shape)
        break  # one batch is enough for a shape check
    break  # only the first partition is inspected

torch.Size([352, 512])


In [8]:
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool

class ME(AttentiveFP):
    """AttentiveFP variant that also returns per-atom embeddings.

    Accepts the same constructor arguments as ``AttentiveFP`` and adds an
    ``atom_classifier`` linear layer that maps the final atom states to 512
    features.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Read the hidden width from the parent's output layer rather than from
        # kwargs['hidden_channels'], which raised KeyError whenever the argument
        # was passed positionally.
        self.atom_classifier = nn.Linear(self.lin2.in_features, 512)

    def forward(self, data):
        """Embed a batch of molecule graphs.

        Args:
            data: Batch exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.

        Returns:
            A tuple ``(atom_out, mol_out)``: ``atom_classifier`` applied to the
            final atom states, and ``lin2`` applied to the pooled molecule states.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        # Only the first atom convolution receives the edge attributes.
        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        # Bipartite edges that connect every atom to the index of its molecule.
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()

        for _ in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.atom_classifier(x), self.lin2(out)
    
# Untrained molecule embedder, built here only for the shape check below.
mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5)

# NOTE: `sample_collection` is the loader left behind by the voxel smoke-test
# cell above, so that cell has to be run before this one.
for batch in sample_collection:
    # Collate the molecule graphs that belong to the complexes in this voxel batch.
    mol_batch = get_mol_batch(mol_dict, batch.pdb_id)
    # NOTE(review): `ptr` appears to hold the cumulative node offsets of the batched molecules — confirm.
    print(mol_batch.ptr)
    out = mol_model(mol_batch)
    # First output of the model: the per-atom embeddings.
    print(out[0].shape)
    break

tensor([   0,   64,  152,  393,  481,  592,  667,  863,  934,  967, 1005, 1057,
        1116, 1326, 1363, 1438, 1482, 1593, 1802, 1869, 1939, 1986, 2017, 2077,
        2115, 2188, 2263, 2510, 2571, 2618, 2814, 2940, 3007])
torch.Size([3007, 512])


In [9]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Saved weights for the molecule and voxel embedders.
mol_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/mol_embedder/mol_embedder_holdout_3-20.m"
vox_weights = "/xdisk/twheeler/jgaiser/deepvs2/contrastive_loss_exp/vox_embedder/vox_embedder_holdout_3-20.m"

mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU still loads when this notebook falls back to the CPU.
mol_model.load_state_dict(torch.load(mol_weights, map_location=device))
vox_model.load_state_dict(torch.load(vox_weights, map_location=device))

# Inference mode for both models.
mol_model.eval()
vox_model.eval()


GCN(
  (linear1): Linear(in_features=39, out_features=512, bias=True)
  (conv1): GCN2Conv(512, alpha=0.25, beta=1.0)
  (conv2): GCN2Conv(512, alpha=0.25, beta=1.0)
  (conv3): GCN2Conv(512, alpha=0.25, beta=1.0)
  (dropout): Dropout(p=0.25, inplace=False)
  (batchnorm): BatchNorm(512)
  (linear2): Linear(in_features=512, out_features=9, bias=True)
)

In [48]:
import umap.umap_ as umap
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# 2-D t-SNE projector; random_state is fixed so the seed does not change between runs.
# NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter` — confirm against the installed version.
tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42)

In [53]:
# Collect voxel-model embeddings (`data_points`) and one interaction class per
# embedding (`data_classes`) from the first sample partition, skipping the
# held-out complexes.
holdout_set = set(holdout_complexes)  # set membership instead of a list scan per graph

with torch.no_grad():
    for sample_file in sample_col:
        # Keep the path and the loaded graphs under separate names, and close
        # the file as soon as it has been unpickled.
        # NOTE: pickle can execute arbitrary code while loading; only read trusted files.
        with open(sample_file, 'rb') as sample_handle:
            sample_graphs = pickle.load(sample_handle)

        filtered_sample_graphs = [sample_graph for sample_graph in sample_graphs
                                  if sample_graph.pdb_id not in holdout_set]

        sample_collection = DataLoader(filtered_sample_graphs, batch_size=64, shuffle=True)

        # Per-batch embedding chunks; stacked once after the loop instead of
        # re-allocating a growing tensor on every batch.
        embedding_chunks = []
        data_classes = []

        for batch_i, vox_batch in enumerate(sample_collection):
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)
            voxel_indices = torch.where(vox_batch.x[:, voxel_label_index]==1)[0]

            # Rows of `y` with at least one active interaction label.
            interaction_indices = torch.where(torch.sum(vox_batch.y, dim=1))[0]
            # Entries of `contact_map` that are not the -1 placeholder.
            true_contacts = torch.where(vox_batch.contact_map != -1)[0]

            print(interaction_indices)
            print(true_contacts)

            # Keep only indices that have both an interaction label and a contact.
            true_contacts = np.intersect1d(interaction_indices.cpu().detach().numpy(),
                                           true_contacts.cpu().detach().numpy())

            # Shift each contact index by the node offset of its molecule in the batch.
            contact_indices = vox_batch.contact_map + mol_batch.ptr[:-1]
            contact_indices = contact_indices[true_contacts]

            vox_out,_ = vox_model(vox_batch)
            mol_out,_ = mol_model(mol_batch)

            vox_pos = vox_out[voxel_indices]
            mol_pos = mol_out[contact_indices]

            # One class per selected sample, drawn at random among its active labels.
            for sample_val in vox_batch.y[interaction_indices]:
                sample_class = np.random.choice(torch.where(sample_val==1)[0].cpu().detach().numpy())
                data_classes.append(sample_class)

            # NOTE(review): `interaction_indices` index rows of `vox_batch.y`, while
            # `vox_out` has one row per node. If `y` is per graph, then
            # `vox_pos[interaction_indices]` may be what was intended — confirm.
            embedding_chunks.append(vox_out[interaction_indices])

        # Stays None when the partition produced no batches.
        data_points = torch.vstack(embedding_chunks) if embedding_chunks else None

        break  # only the first partition is analysed
    
# data = np.array(data_points.cpu().detach())
# labels = np.array(data_classes)

# # reducer = umap.UMAP(n_neighbors=5, n_components=2, metric='euclidean')
# # embedding = reducer.fit_transform(data)
# embedding = tsne.fit_transform(data)
        
# plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='viridis')
# plt.colorbar()
# plt.show()


tensor([ 1,  4,  5,  7,  8,  9, 10, 12, 13, 15, 17, 22, 28, 31, 32, 37, 38, 47,
        48, 50, 51, 58, 61, 63], device='cuda:0')
tensor([ 0,  1,  2,  3,  5,  6,  7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 39, 40, 41,
        42, 43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62,
        63], device='cuda:0')
[ 1  5  7 10 13 15 17 22 28 31 32 50 51 61 63]


In [12]:
# Rebuild both models and reload their saved weights.
# NOTE: `mol_weights` and `vox_weights` are the checkpoint paths defined in the
# earlier model-loading cell, so that cell has to be run first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mol_model = ME(in_channels=38, 
               hidden_channels=512, 
               out_channels=512, 
               edge_dim=1, 
               num_layers=3, 
               num_timesteps=3, 
               dropout=0.5).to(device)

vox_model = GCN(39, 512, 9, data_transform=handle_data).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU still loads when this notebook falls back to the CPU.
mol_model.load_state_dict(torch.load(mol_weights, map_location=device))
vox_model.load_state_dict(torch.load(vox_weights, map_location=device))

# Inference mode for both models.
mol_model.eval()
vox_model.eval()


GCN(
  (linear1): Linear(in_features=39, out_features=512, bias=True)
  (conv1): GCN2Conv(512, alpha=0.25, beta=1.0)
  (conv2): GCN2Conv(512, alpha=0.25, beta=1.0)
  (conv3): GCN2Conv(512, alpha=0.25, beta=1.0)
  (dropout): Dropout(p=0.25, inplace=False)
  (batchnorm): BatchNorm(512)
  (linear2): Linear(in_features=512, out_features=9, bias=True)
)

In [15]:
# Inspect a single embedding: run one shuffled batch of the first partition
# through both loaded models and print the first row of the voxel model output.
with torch.no_grad():
    for sample_file in sample_col:
        # Context manager closes the partition file once it has been unpickled.
        # NOTE: pickle can execute arbitrary code while loading; only read trusted files.
        with open(sample_file, 'rb') as sample_handle:
            sample_graphs = pickle.load(sample_handle)
        sample_collection = DataLoader(sample_graphs, batch_size=64, shuffle=True)
        sample_col_loss = []

        for batch_i, vox_batch in enumerate(sample_collection):
            vox_batch = vox_batch.to(device)
            mol_batch = get_mol_batch(mol_dict, vox_batch.pdb_id).to(device)
            voxel_indices = torch.where(vox_batch.x[:, voxel_label_index]==1)[0]

            vox_out,_ = vox_model(vox_batch)
            mol_out,_ = mol_model(mol_batch)

            print(vox_out[0])

            break  # one batch is enough for inspection
        break  # only the first partition is inspected

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1374, 0.0000,
        0.0000, 5.3852, 0.6956, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1381, 1.4832, 0.0000,
        3.6136, 0.0000, 0.0000, 2.8386, 3.4378, 0.0000, 0.1881, 4.5349, 0.8032,
        0.0000, 0.0000, 0.0000, 0.2161, 0.0000, 3.3003, 0.0000, 0.0000, 0.0000,
        2.8792, 0.0000, 1.7987, 0.0000, 0.0000, 0.0000, 0.0000, 1.4856, 0.0000,
        0.7792, 0.0000, 5.3847, 0.2224, 4.6548, 0.0000, 0.0000, 3.6778, 2.7847,
        0.0000, 0.0000, 0.1906, 0.0000, 0.0000, 0.0000, 2.0430, 1.1994, 0.0000,
        1.4317, 0.0000, 0.0000, 2.5312, 0.0000, 0.6340, 0.0000, 0.0000, 0.0000,
        0.6127, 0.8837, 3.2478, 0.0000, 1.2316, 0.5463, 1.4025, 0.0000, 0.0000,
        0.0000, 1.0905, 0.0000, 0.5936, 0.0000, 8.7095, 0.0000, 0.0000, 0.0000,
        0.0000, 2.2545, 0.7675, 0.0000, 0.0000, 0.0000, 0.9533, 1.4715, 3.0845,
        5.6001, 0.0000, 0.0000, 0.0000, 