In [1]:
import pickle
import re
import numpy as np
import sys
import os
from glob import glob
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch.nn import TripletMarginLoss
import importlib
import yaml
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

# holdout_complexes = ["3gdt", "3g1v", "3w07", "3g1d", "1loq", "3wjw", "2zz1", "2zz2", "1km3", "1x1z", 
#                      "6cbg", "5j7q", "6cbf", "4wrb", "6b1k", "5hvs", "5hvt", "3rf5", "3rf4", "1mfi", 
#                      "5efh", "6csq", "5efj", "6csr", "6css", "6csp", "5een", "5ef7", "5eek", "5eei",
#                      "3ozt", "3u81", "4p58", "5k03", "3ozr", "3ozs", "3oe5", "3oe4", "3hvi", "3hvj",
#                      "3g2y", "3g2z", "3g30", "3g31", "3g34", "3g32", "4de2", "3g35", "4de0", "4de1",
#                      "2exm", "4i3z", "1e1v", "5jq5", "1jsv", "1e1x", "4bcp", "4eor", "1b38", "1pxp", "2xnb", "4bco", "4bcm", "1pxn", "4bcn", "1h1s", "4bck", "2fvd", "1pxo", "2xmy",
#                      "4xoe", "5fs5", "1uwf", "4att", "4av4", "4av5", "4avh", "4avj", "4avi", "4auj", "4x50", "4lov", "4x5r", "4buq", "4x5p", "4css", "4xoc", "4cst", "4xo8", "4x5q",
#                      "1gpk", "3zv7", "1gpn", "5bwc", "5nau", "5nap", "1h23", "1h22", "1e66", "4m0e", "4m0f", "2ha3", "2whp", "2ha6", "2ha2", "1n5r", "4arb", "4ara", "5ehq", "1q84",
#                      "2z1w", "3rr4", "1s38", "1q65", "4q4q", "4q4p", "4q4r", "4kwo", "1r5y", "4leq", "4lbu", "1f3e", "4pum", "4q4s", "3gc5", "2qzr", "4q4o", "3gc4", "5jxq", "3ge7"]

In [2]:
# Root of the deepvs checkout. Override with the DEEPVS_ROOT environment variable
# instead of editing this cell. The value must end with '/': paths in this notebook
# are built by plain string concatenation, and `root_path[:-1]` strips that slash
# before appending config paths that start with '/'.
root_path = os.environ.get('DEEPVS_ROOT', '/xdisk/twheeler/jgaiser/deepvs3/deepvs/')
if not root_path.endswith('/'):
    root_path += '/'
params_path = root_path + 'params.yaml'
config_path = root_path + 'config.yaml'

def load_class_from_file(file_path):
    """Import a Python source file and return the class named after the file.

    The class name is the file's basename up to its first '.', e.g.
    ``models/MolEmbedder.py`` -> ``MolEmbedder``.

    Args:
        file_path: path to a ``.py`` file defining a class with that name.

    Returns:
        The class object from the freshly executed module.

    Raises:
        ImportError: if no import spec/loader can be built for ``file_path``
            (e.g. an unrecognised file extension).
        AttributeError: if the module has no attribute with the derived name.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    class_name = os.path.basename(file_path).split(".")[0]
    spec = importlib.util.spec_from_file_location(class_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("cannot build an import spec for %r" % (file_path,))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


def load_function_from_file(file_path):
    """Import a Python source file and return the function named after the file.

    The function name is the file's basename up to its first '.'. The module
    is registered in ``sys.modules`` under the full basename, extension
    included (e.g. ``"score.py"``), before it is executed.

    Args:
        file_path: path to a ``.py`` file defining a function with that name.

    Returns:
        The function object from the freshly executed module.

    Raises:
        ImportError: if no import spec/loader can be built for ``file_path``.
        AttributeError: if the module has no attribute with the derived name.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    base_name = os.path.basename(file_path)
    function_name = base_name.split(".")[0]
    spec = importlib.util.spec_from_file_location(base_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("cannot build an import spec for %r" % (file_path,))
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return getattr(module, function_name)


def _read_yaml(path):
    """Open `path` and parse it with yaml.safe_load, closing the file afterwards."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)


# Parsed contents of params.yaml and config.yaml.
params = _read_yaml(params_path)
config = _read_yaml(config_path)

In [3]:
# Label vocabularies read from config.yaml.
ATOM_LABELS, MOL_ATOM_LABELS, EDGE_LABELS, INTERACTION_LABELS = (
    config[key]
    for key in ('POCKET_ATOM_LABELS', 'MOL_ATOM_LABELS', 'POCKET_EDGE_LABELS', 'INTERACTION_LABELS')
)

# Path template for the pickled molecule graphs; '%s' is later replaced by a glob wildcard.
data_dir = params['data_dir']
mol_graph_ft = data_dir + config['mol_graph_file_template']

In [4]:
# Per-class positive weights for BCEWithLogitsLoss(pos_weight=...):
# weight_c = (mean class count) / (count of class c), so rarer classes weigh more.
mol_class_freqs = torch.tensor(config['MOL_LABEL_COUNT'])

# A zero count would give an inf pos_weight, which can turn the loss into NaN
# (inf * 0 for negative targets). Treat absent classes as if they occurred once.
safe_class_freqs = torch.where(mol_class_freqs > 0, mol_class_freqs, torch.ones_like(mol_class_freqs))
mol_class_weights = 1./safe_class_freqs
mol_class_weights = mol_class_weights * mol_class_freqs.sum() / len(mol_class_freqs)

In [5]:
# Load every pickled molecule graph matching the file template, then shuffle.
# NOTE(review): pickle.load can execute arbitrary code -- only point this at files you produced.
SHUFFLE_SEED = None  # set to an int to make the shuffle (and hence the train/val split) reproducible

# sorted(): glob order depends on the filesystem, so a seed alone would not be reproducible.
training_sample_files = sorted(glob(mol_graph_ft.replace('%s', '*')))
mol_training_data = []

for graph_file in training_sample_files:
    with open(graph_file, 'rb') as graph_handle:  # close each file instead of leaking the handle
        mol_training_data.append(pickle.load(graph_handle))

if SHUFFLE_SEED is not None:
    random.seed(SHUFFLE_SEED)
random.shuffle(mol_training_data)

In [28]:
# Quick look: sum of each row of the first molecule's label matrix `y`
# (presumably one row per atom -- TODO confirm).
mol_training_data[0].y.sum(dim=1)

tensor([0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [24]:
# Greedy split: a molecule goes to the validation set if any interaction class
# it contains still has a validation total below `val_class_N`; everything else
# with at least one positive label goes to training. Unlabelled molecules are dropped.
val_class_N = 40

val_set = []
train_set = []

# Running per-class totals of `y` (summed over dim 0) for the validation molecules.
val_class_counts = torch.zeros(len(INTERACTION_LABELS))

for mol_g in mol_training_data:
    if torch.sum(mol_g.y) == 0:
        continue

    class_totals = mol_g.y.sum(dim=0)
    present_classes = torch.where(class_totals > 0)[0]
    wants_more = any(val_class_counts[c] < val_class_N for c in present_classes)

    if wants_more:
        val_set.append(mol_g)
        val_class_counts += class_totals
    else:
        train_set.append(mol_g)

def batch_logit_accuracy(logits_batch, labels_batch):
    """Mean top-k hit rate over a batch of multi-label rows.

    For each row, k is the number of labels equal to 1. The score is the
    fraction of those k positive classes that appear among the k highest
    logits (order-insensitive set overlap).

    Args:
        logits_batch: 2-D tensor, one row of class logits per example.
        labels_batch: same-shaped tensor of 0/1 targets.

    Returns:
        Python float: the mean per-row score over rows that have at least one
        positive label. Rows without positives are skipped, since they have no
        defined score. Returns NaN when no row qualifies.
    """
    row_scores = []
    for logits, labels in zip(logits_batch, labels_batch):
        label_indices = (labels == 1).nonzero(as_tuple=True)[0]
        k = label_indices.numel()
        if k == 0:
            continue
        topk_indices = torch.topk(logits, k).indices
        # Set overlap: comparing the two sorted index lists elementwise would
        # miss hits that sit at different positions.
        hits = torch.isin(topk_indices, label_indices).sum().item()
        row_scores.append(hits / k)

    if not row_scores:
        return float('nan')
    return sum(row_scores) / len(row_scores)

BATCH_SIZE = 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Shuffled mini-batches for training.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

# The whole validation set is collated into a single batch and moved to
# `device` once, so every epoch is scored on the same data.
_full_val_loader = DataLoader(val_set, batch_size=len(val_set), shuffle=False)
validation_batch = next(iter(_full_val_loader)).to(device)

In [50]:
# Sanity check: the file the training cell saves model weights to
# (root_path's trailing '/' is dropped because the config path starts with '/').
f"{root_path[:-1]}{config['mol_embedder_weights']}"

'/xdisk/twheeler/jgaiser/deepvs3/deepvs/models/weights/mol_embedder_6-1.m'

In [None]:
# Train the molecule embedder on per-atom interaction labels.
# Only atoms with at least one positive label contribute to loss and accuracy.
# After every epoch the model is scored on the fixed `validation_batch`, and its
# state_dict is written to disk whenever validation accuracy reaches a new best.
criterion = nn.BCEWithLogitsLoss(pos_weight=mol_class_weights).to(device)
val_criterion = nn.BCEWithLogitsLoss().to(device)  # unweighted (no pos_weight)

MolEmbedder = load_class_from_file(root_path+config['mol_embedder_model'])
mol_model = MolEmbedder(**config['mol_embedder_hyperparams']).to(device)

optimizer = torch.optim.Adam(mol_model.parameters(), lr=1e-3)
sigmoid = nn.Sigmoid()

N_EPOCHS = 100
LOG_EVERY = 1000  # report training progress every this many batches
mol_weights_path = root_path[:-1] + config['mol_embedder_weights']

validation_loss_history = []
validation_accuracy_history = []

training_loss_history = []
training_accuracy_history = []

for epoch in range(N_EPOCHS):
    print("EPOCH %s" % epoch)
    # Per-step losses since the last report, detached so that each step's
    # autograd graph is freed instead of being kept alive until the report.
    recent_losses = []

    for batch_index, batch in enumerate(train_loader):
        batch = batch.to(device)
        interacting_atoms = torch.where(torch.sum(batch.y, dim=1) > 0)[0]
        if len(interacting_atoms) == 0:
            # No labelled atoms: the mean BCE over an empty selection is NaN.
            continue
        y = batch.y[interacting_atoms]

        # Clear gradients from the previous step; without this they accumulate
        # across every step of the run.
        optimizer.zero_grad()
        _, interaction_preds, _ = mol_model(batch)
        interaction_preds = interaction_preds[interacting_atoms]

        loss = criterion(interaction_preds, y.float())
        loss.backward()
        optimizer.step()

        recent_losses.append(loss.detach())

        if batch_index % LOG_EVERY == 0:
            mean_loss = torch.stack(recent_losses).mean().item()
            report_preds = interaction_preds.detach()
            train_accuracy = batch_logit_accuracy(report_preds, y)

            training_loss_history.append(mean_loss)
            training_accuracy_history.append(train_accuracy)
            print("Loss: %s" % mean_loss)
            print("Accuracy: %s" % train_accuracy)

            # Predicted probabilities vs. targets for up to three random labelled atoms.
            for i in torch.randperm(len(y))[:3]:
                print("%.2f "*len(INTERACTION_LABELS) % tuple(sigmoid(report_preds[i]).tolist()))
                print("%.2f "*len(INTERACTION_LABELS) % tuple(y[i].tolist()))
                print("")

            recent_losses = []

    mol_model.eval()
    with torch.no_grad():
        _, validation_preds, _ = mol_model(validation_batch)

        val_interacting_atoms = torch.where(torch.sum(validation_batch.y, dim=1) > 0)[0]
        validation_y = validation_batch.y[val_interacting_atoms]
        validation_preds = validation_preds[val_interacting_atoms]

        validation_loss = val_criterion(validation_preds, validation_y.float()).item()
        validation_accuracy = batch_logit_accuracy(validation_preds, validation_y)

    # Checkpoint after the first epoch as well as on every later improvement,
    # so a weights file from this run always exists.
    if not validation_accuracy_history or validation_accuracy > max(validation_accuracy_history):
        torch.save(mol_model.state_dict(), mol_weights_path)
        print('WEIGHTS UPDATED')

    validation_loss_history.append(validation_loss)
    validation_accuracy_history.append(validation_accuracy)

    print("VALIDATION LOSS:", validation_loss)
    print("VALIDATION ACC:", validation_accuracy)
    mol_model.train()