In [1]:
import pickle
import re
import numpy as np
import sys
import os
from glob import glob
import torch
import torch_geometric
import random
import yaml
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_isolated_nodes
from torch import nn
from torch_geometric.nn import GCN2Conv
from torch_geometric.nn import SAGPooling
from torch_geometric.nn import MLP
from torch_geometric.nn import AttentiveFP
from torch_geometric.nn.aggr import AttentionalAggregation
from copy import deepcopy 
from torch_geometric.nn import GATConv, MessagePassing, global_add_pool
from torch.nn import TripletMarginLoss
import importlib.util
from torch_geometric.nn import radius_graph
import itertools
from sklearn.metrics import roc_auc_score
import numpy as np



import torch
import torch.nn.functional as F



In [9]:
root_path     = '/xdisk/twheeler/jgaiser/deepvs3/deepvs/'
params_path   = root_path + 'params.yaml'
config_path   = root_path + 'config.yaml'
function_path = root_path + 'code/utils/data_processing_utils.py'

def load_class_from_file(file_path):
    class_name = file_path.split("/")[-1].split(".")[0]
    spec = importlib.util.spec_from_file_location(class_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


def load_function_from_file(file_path):
    function_name = file_path.split("/")[-1].split(".")[0]
    spec = importlib.util.spec_from_file_location(
        os.path.basename(file_path), file_path
    )
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return getattr(module, function_name) 


with open(params_path, "r") as param_file:
    params = yaml.safe_load(param_file)
    
with open(config_path, "r") as config_file:
    config = yaml.safe_load(config_file)

mol_graph_ft = params['data_dir'] + config['mol_graph_file_template'] 
poxel_graph_ft = params['data_dir'] + config['full_pocket_graph_file_template']

In [97]:
from torch_geometric.utils import add_self_loops, to_undirected

torch.tensor((1,2,3.0))

tensor([1., 2., 3.])

In [3]:
poxel_data = []
pdb_ids = []
mol_data = {}

vox_interaction_count = torch.zeros(9)

for g_file in glob(poxel_graph_ft.replace('%s', '*')):
    g = pickle.load(open(g_file, 'rb'))
    
    pdb_id = g_file.split('/')[-1].split('_')[0]
    g.pdb_id = pdb_id
    vox_interaction_count += torch.sum(g.y, dim=0)
        
    pdb_ids.append(pdb_id) 
    poxel_data.append(g)
    
print(vox_interaction_count)

tensor([  2548.,  78945.,  60011., 126140.,   1062.,   2616.,  13124.,  11344.,
          4288.])


In [4]:
for g_file in glob(mol_graph_ft.replace('%s', '*')):
    pdb_id = g_file.split('/')[-1].split('_')[0]
    g = pickle.load(open(g_file, 'rb'))
    g.pdb_id = pdb_id
    mol_data[pdb_id] = g

In [84]:
pdbbind_dict = {}
candidate_proteins = []
validation_ids = []

with open('/xdisk/twheeler/jgaiser/data/pdbbind/index_readme/general/index/INDEX_general_PL_name.2020', 'r') as index_in:
    for line_i, line in enumerate(index_in):
        if line[0] == '#':
            continue
        
        line_arr = line.rstrip().split('  ')
         
        if len(line_arr) != 4:
             continue
                
        protein_name = line_arr[-1]
        
        if protein_name not in pdbbind_dict:
            pdbbind_dict[protein_name] = [line_arr[0]]
        else:
            pdbbind_dict[protein_name].append(line_arr[0])
            
for k,v in pdbbind_dict.items():
    if len(v) >= 10 and len(v) <= 18:
        candidate_proteins.append(k)
    
random.shuffle(candidate_proteins)

for k in candidate_proteins:
    if sum([len(x) for x in validation_ids]) > 250:
        break
    
    validation_ids.append(pdbbind_dict[k])

print(len(validation_ids))

pox_train_set = []
pox_val_set = []

mol_train_set = []
mol_val_set = []

for pox_sample in poxel_data:
    sample_id = pox_sample.pdb_id
    
    if True in [sample_id in x for x in validation_ids]:
        pox_val_set.append(pox_sample)
        mol_val_set.append(mol_data[pox_sample.pdb_id])
        continue
    
    pox_train_set.append(pox_sample)
    mol_train_set.append(mol_data[pox_sample.pdb_id])
    
def get_random_batch_pair(set_a, set_b, device='cpu', batch_size=32, min_prob=0.1):
    batch_indices = torch.randperm(len(set_a))

    poxel_graph_list = []
    
    for i in batch_indices[:batch_size]:
        g = deepcopy(set_a[i])
        
#         random_prob = random.randrange(int(min_prob*10), 11) / 10
#         random_node_indices = torch.randperm(len(g.x))[:int(len(g.x)*random_prob)]
#         g.x = g.x[random_node_indices]
#         g.pos = g.pos[random_node_indices]
        
        poxel_graph_list.append(g)
        
    
    batch_a = Batch.from_data_list(poxel_graph_list)
    batch_b = Batch.from_data_list([set_b[x] for x in batch_indices[:batch_size]])
    return batch_a.to(device), batch_b.to(device)

pox_val_data = []
mol_val_data = []

for target in validation_ids:
    target_poxels = []
    target_mols = []
    
    for pdb_id in target:
        for poxel_graph in pox_val_set:
            if poxel_graph.pdb_id == pdb_id:
                g = deepcopy(poxel_graph)
        #         poxel_graph_list.append(g)
#                 random_prob = random.randrange(int(min_prob*10), 11) / 10
#                 random_prob = 0.4
#                 random_node_indices = torch.randperm(len(g.x))[:int(len(g.x)*random_prob)]
#                 g.x = g.x[random_node_indices]
#                 g.pos = g.pos[random_node_indices]
                target_poxels.append(g)
                
        for mol_graph in mol_val_set:
            if mol_graph.pdb_id == pdb_id:
                target_mols.append(mol_graph)
                
    pox_val_batch = Batch.from_data_list(target_poxels)
    pox_val_data.append(Batch.from_data_list(target_poxels))
    mol_val_data.append(Batch.from_data_list(target_mols))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i in range(len(pox_val_data)):
    pox_val_data[i] = pox_val_data[i].to(device)
    mol_val_data[i] = mol_val_data[i].to(device)
    
del poxel_data
del mol_data

20


NameError: name 'poxel_data' is not defined

In [6]:
ac_weights = config['active_classifier_weights'] % root_path
ac_weights

'/xdisk/twheeler/jgaiser/deepvs3/deepvs//models/weights/ac_classifier_7-16.m'

In [7]:
sigmoid = nn.Sigmoid()

def validate(model, pox_val_data, mol_val_data):
    model.eval()
    
    total_auc_scores = []
    total_pos_scores = []
    total_neg_scores = []
    total_mean_scores = []
    all_scores = []
    
    for i in range(len(pox_val_data)):
        print('Evaluating validation target %s' % i)
        val_target_batch = pox_val_data[i]
        val_pos_batch = mol_val_data[i]
        
        target_neg_scores = np.array([])

        for j in range(len(mol_val_data)):
            out, y = model(val_target_batch, 
                           val_pos_batch, 
                           decoy_batch=mol_val_data[j])

            out = sigmoid(out)

            if i == j:
                target_pos_scores = out[torch.where(y==1)[0]].detach().cpu().numpy()
            else:
                negative_scores = out[torch.where(y==0)[0]].detach().cpu().numpy()
                
                if len(target_neg_scores) == 0:
                    target_neg_scores = negative_scores
                    
                target_neg_scores = np.concatenate((target_neg_scores, 
                                                    negative_scores))
                
        # Combine the scores
        scores = np.concatenate([target_pos_scores, target_neg_scores])
        all_scores.append(scores)
        total_mean_scores.append(np.mean(scores))

        # Create labels: 1 for positive class, 0 for negative class
        labels = np.concatenate([np.ones_like(target_pos_scores), 
                                 np.zeros_like(target_neg_scores)])

        # Compute AUC
        auc = roc_auc_score(labels, scores)
        pos_mean = np.mean(target_pos_scores)
        neg_mean = np.mean(target_neg_scores)
        
        total_auc_scores.append(auc)
        total_pos_scores.append(pos_mean)
        total_neg_scores.append(neg_mean)
        
        print(pos_mean, '--', neg_mean)
        print(auc)
        print('----------------------------------')
    
    return np.array(total_auc_scores), np.array(total_pos_scores), np.array(total_neg_scores), np.array(total_mean_scores), all_scores

# a,p,n,m,s = validate(ac_model, pox_val_data, mol_val_data)
# mean_a = np.mean(a)       

# print(np.mean(a))
# print(np.mean(p))
# print(np.mean(n))

In [11]:
mol_class_freqs = torch.tensor(config['MOL_LABEL_COUNT'])
vox_class_freqs = torch.tensor(config['POCKET_LABEL_COUNT'])

mol_class_weights = 1./mol_class_freqs
mol_class_weights = mol_class_weights * mol_class_freqs.sum() / len(mol_class_freqs)

vox_class_weights = 1./vox_class_freqs
vox_class_weights = vox_class_weights * vox_class_freqs.sum() / len(vox_class_freqs)


In [31]:
# # torch.manual_seed(1234)
# with open(config_path, "r") as config_file:
#     config = yaml.safe_load(config_file)

# VoxEmbedder = load_class_from_file(config['vox_embedder_model'] % root_path)
# PoxelAggregator = load_class_from_file(config['poxel_aggregator_model'] % root_path)
# MolEmbedder = load_class_from_file(config['mol_embedder_model'] % root_path)
# MolAggregator = load_class_from_file(config['mol_aggregator_model'] % root_path)

# config['active_classifier_hyperparams']['in_dim'] = config['mol_aggregator_hyperparams']['out_dim'] + config['poxel_aggregator_hyperparams']['out_dim'] 
    
# ActiveClassifier = load_class_from_file(config['active_classifier_model'] % root_path)

# ac_model = ActiveClassifier(
#     voxel_embedder=(VoxEmbedder, config['vox_embedder_hyperparams']),
#     poxel_model=(PoxelAggregator, config['poxel_aggregator_hyperparams']),
#     mol_embed_model=(MolEmbedder, config['mol_embedder_hyperparams']),  
#     mol_agg_model = (MolAggregator, config['mol_aggregator_hyperparams']),
#     **config['active_classifier_hyperparams']).to(device)

# # Define the learning rates for each module
# learning_rates = {
#     'vox_embedder': 1e-3,
#     'pox_agg': 1e-5,
#     'mol_embedder': 1e-3,
#     'mol_agg': 1e-4,
#     'ac_model': 1e-4,
# }

# # Create separate optimizers for each module with their respective learning rates
# optimizers = {
#     'vox_embedder': torch.optim.Adam(ac_model.vox_embedder.parameters(), lr=learning_rates['vox_embedder']),
#     'pox_agg': torch.optim.Adam(ac_model.pox_agg.parameters(), lr=learning_rates['pox_agg']),
#     'mol_embedder': torch.optim.Adam(ac_model.mol_embedder.parameters(), lr=learning_rates['mol_embedder']),
#     'mol_agg': torch.optim.Adam(ac_model.mol_agg.parameters(), lr=learning_rates['mol_agg']),
#     'ac_model': torch.optim.Adam(ac_model.parameters(), lr=learning_rates['ac_model'])
# }



# BATCH_SIZE = 16  
# # optimizer = torch.optim.AdamW(ac_model.parameters(), lr=1e-5)
# optimizer = torch.optim.Adam(ac_model.parameters(), lr=1e-4)

# vox_prediction_loss = nn.BCEWithLogitsLoss(pos_weight=vox_class_weights).to(device) 
# mol_prediction_loss = nn.BCEWithLogitsLoss(pos_weight=mol_class_weights).to(device) 
# ac_loss = nn.BCEWithLogitsLoss().to(device) 
    
# ac_model.train()
# # ac_model.eval()
# best_a = 0.8

for epoch in range(1, 100):
    for batch_index in range(int(len(pox_train_set) / BATCH_SIZE)):  
        torch.cuda.empty_cache()
        
        pox_batch, mol_batch = get_random_batch_pair(pox_train_set, mol_train_set, device, batch_size=BATCH_SIZE, min_prob=0.1)
        decoy_pocket_batch, _ = get_random_batch_pair(pox_train_set, mol_train_set, device, batch_size=BATCH_SIZE, min_prob=0.1)
        _, decoy_mol_batch = get_random_batch_pair(pox_train_set, mol_train_set, device, batch_size=BATCH_SIZE, min_prob=0.1)
        
        vox_interaction_indices = torch.nonzero(pox_batch.y)[:,0]
        decoy_vox_interaction_indices = torch.nonzero(decoy_pocket_batch.y)[:,0]
        
        mol_interaction_indices = torch.nonzero(mol_batch.y)[:,0]
        decoy_mol_interaction_indices = torch.nonzero(decoy_mol_batch.y)[:,0]
        
        vox_interaction_indices = (torch.hstack((vox_interaction_indices, 
                                                len(pox_batch.y) + decoy_vox_interaction_indices)))
        
        mol_interaction_indices = (torch.hstack((mol_interaction_indices, 
                                                len(mol_batch.y) + decoy_mol_interaction_indices)))
        
        vox_pred_y = torch.vstack((pox_batch.y, decoy_pocket_batch.y))[vox_interaction_indices]
        mol_pred_y = torch.vstack((mol_batch.y, decoy_mol_batch.y))[mol_interaction_indices]
        
        out, y, vox_preds, mol_preds = ac_model(pox_batch, mol_batch, decoy_pocket_batch, decoy_mol_batch)
    
        v_pred_loss = vox_prediction_loss(vox_preds[vox_interaction_indices], vox_pred_y)
        m_pred_loss = mol_prediction_loss(mol_preds[mol_interaction_indices], mol_pred_y)
        
        pos_indices = torch.where(y==1)[0]
        neg_indices = torch.where(y==0)[0]
            
        l1 = ac_loss(out, y.unsqueeze(1).to(device))
        l = l1 + 0.5*v_pred_loss + 0.1*m_pred_loss
        
        for k,o in optimizers.items():
            o.zero_grad()
            
        l.backward()
        
        for k,o in optimizers.items():
            o.step()
        
        if batch_index % 100 == 0:
            print('--------------------------------------------')
            print(l1.item(), v_pred_loss.item(), m_pred_loss.item())
            
            random_interaction_indices = torch.randperm(len(vox_interaction_indices))[:2]
            for i in random_interaction_indices:
                print(sigmoid(vox_preds[vox_interaction_indices][i]))
                print(vox_pred_y[i])
                print('-')
                
            print('----')
            
            random_interaction_indices = torch.randperm(len(mol_interaction_indices))[:2]
            for i in random_interaction_indices:
                print(sigmoid(mol_preds[mol_interaction_indices][i]))
                print(mol_pred_y[i])
                print('-')    
                
            print('----')
            print(" ".join(["%.3f" % x for x in sigmoid(out[pos_indices].flatten())]))
            print(torch.mean(sigmoid(out[pos_indices].flatten())))
            print('----')
            print(" ".join(["%.3f" % x for x in sigmoid(out[neg_indices].flatten())]))
            print(torch.mean(sigmoid(out[neg_indices].flatten())))
            
    print("EPOCH %s COMPLETE" % (epoch))
    
#     if (epoch+1) % 5 == 0:
#         ac_model.eval()

    if epoch < 10:
        continue
        
    if epoch % 3 == 0:
        a,p,n,_,_ = validate(ac_model, pox_val_data, mol_val_data)
        mean_a = np.mean(a)       
        
        print(np.mean(a))
        print(np.mean(p))
        print(np.mean(n))
        
        if mean_a > best_a:
            torch.save(ac_model.state_dict(), ac_weights)
            print('WEIGHTS UPDATED')
            best_a = mean_a
            
        ac_model.train()

--------------------------------------------
0.7470303773880005 0.4084528386592865 0.11995812505483627
tensor([0.1113, 0.1417, 0.1284, 0.1583, 0.1516, 0.1108, 0.1054, 0.1140, 0.1110],
       device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0., 0., 0., 1., 0., 0., 0., 0., 0.], device='cuda:0')
-
tensor([0.1113, 0.1417, 0.1284, 0.1583, 0.1516, 0.1108, 0.1054, 0.1140, 0.1110],
       device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([0., 0., 0., 1., 0., 0., 0., 0., 0.], device='cuda:0')
-
----
tensor([0.9940, 0.0465, 0.0050, 0.2103, 0.0493, 0.2289, 0.1631, 0.1043, 0.0228],
       device='cuda:0', grad_fn=<SigmoidBackward0>)
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')
-
tensor([3.2666e-04, 6.4137e-01, 9.3707e-01, 1.9382e-03, 1.5700e-02, 2.5839e-02,
        1.4594e-02, 7.5217e-04, 7.4295e-03], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
tensor([0., 0., 1., 0., 0., 0., 0., 0., 0.], device='cuda:0')
-
----
0.542 0.815 0.135 0.818 0.402 0.768 0.749 0.820 0.115 0

RuntimeError: CUDA out of memory. Tried to allocate 244.00 MiB (GPU 0; 15.89 GiB total capacity; 14.67 GiB already allocated; 168.12 MiB free; 14.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF