In [1]:
import numpy as np
import sys
from os import listdir
from os.path import isfile, join
from tqdm import tqdm
import matplotlib.pyplot as plt

from site_utils import *
from model_utils import *
from data_utils import *
from model import *

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torchnet.dataset import ListDataset
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.nn.parallel.data_parallel import DataParallel

## Model Testing

In this notebook, we load our current best-performing model and compute evaluation metrics on the training and testing sets. Additionally, we provide code for running the model on a specific protein so that the result can later be visualized in PyMOL.

Future work will integrate PyMOL more directly into the scripting workflow.

In [2]:
# Ask the project's check_gpu() helper for the runtime settings reused below:
# `device` is where the model is moved, and `num_workers` / `pin_memory` are
# forwarded to the DataLoaders.
device, num_workers, pin_memory, dtype = check_gpu()

Using GPU.


In [3]:
# Rebuild the network and restore the saved weights. `map_location` remaps the
# checkpoint's tensors onto the device selected by check_gpu(), so a checkpoint
# written on a GPU machine still loads on a CPU-only (or differently numbered) host.
model = ConvNet4()
model.load_state_dict(torch.load('results/cnn4_full_xSmooth1.pt', map_location=device))

# Replicate the model across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    print("Using ", torch.cuda.device_count(), " GPUs.")
    print('------------------------------------')
    model = DataParallel(model)

# This notebook only evaluates the model, so switch to inference mode: any
# dropout / batch-norm layers in the network then use their evaluation behaviour.
# The result is assigned back so the cell does not display the module repr.
model = model.to(device=device).eval()

Using  2  GPUs.
------------------------------------


## Load Data and Set Parameters Needed for Evaluation

In [4]:
batch_size = 250
split = 0.8
# Number of leading samples held out as the test set.
n_test = 500

data = np.load('./datasets/dataset_400_maps.npy')
labels = np.load('./datasets/dataset_400_smoothlabels.npy')

# Maps and labels are paired by row index further down (TensorDataset), so a
# length mismatch would misalign every sample; catch it here with a clear message.
assert len(data) == len(labels), (
    'maps/labels length mismatch: {} vs {}'.format(len(data), len(labels))
)
# The hold-out must leave at least one sample for the training portion.
assert len(data) > n_test, (
    'dataset has {} samples; need more than n_test={}'.format(len(data), n_test)
)

# Collapse the smoothed labels to a binary mask: every positive value becomes 1.
labels[labels > 0] = 1

# Pull Out Testing Data
# The first n_test rows form the test set; `data` / `labels` keep the remainder
# and are used as the training set by the following cells.
test_data, data = data[:n_test], data[n_test:]
test_labels, labels = labels[:n_test], labels[n_test:]

## Load Dataset into PyTorch Loaders

In [16]:
# Wrap the NumPy arrays as tensors. The maps are cast to float32; the labels
# keep whatever dtype they were stored with.
x_train = torch.from_numpy(data).float()
x_test = torch.from_numpy(test_data).float()
y_train = torch.from_numpy(labels)
y_test = torch.from_numpy(test_labels)

# Pair every map with its label row.
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

# No shuffling is requested, so both loaders iterate in stored order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [17]:
# Evaluate the loaded model on the training split. check_metrics (project helper)
# prints precision, recall, F1 and accuracy tagged with the given split name;
# whatever it returns is kept in `training_results`.
training_results = check_metrics(train_loader, model, 'train')

train ==> Precision: 0.8291 | Recall: 0.7204 | F1: 0.77 | Accuracy: 0.963


In [18]:
# Evaluate the loaded model on the held-out test split. check_metrics (project
# helper) prints precision, recall, F1 and accuracy tagged with the given split
# name; whatever it returns is kept in `testing_results`.
testing_results = check_metrics(test_loader, model, 'test')

test ==> Precision: 0.5451 | Recall: 0.3772 | F1: 0.45 | Accuracy: 0.915


## Testing for Individual Proteins

These results can be fed into the PyMOL script.

Set `pdb_id` below to the ID of a protein that is included in the dataset. The following cell then saves a NumPy array named `<pdb_id>_output.npy` (for example, `5e7d_output.npy`), which is the file that interfaces with the PyMOL script.

In [13]:
# PDB ID of the protein to run through the model. The next cell loads
# <pdb_id>.npy from both the parsed maps and parsed labels directories.
pdb_id = '5e7d'

In [14]:
# Directories holding the per-protein parsed inputs (one <pdb_id>.npy file each).
maps_dir = '../data/parsed/maps/'
labels_dir = '../data/parsed/labels/'

# Load the map and the label array for the selected protein.
protein_map = np.load('{}{}.npy'.format(maps_dir, pdb_id))
protein_label = np.load('{}{}.npy'.format(labels_dir, pdb_id))

# Run the project's generate_output helper with the loaded model and write the
# result next to the notebook as <pdb_id>_output.npy.
output = generate_output(protein_map, protein_label, model)
np.save('{}_output.npy'.format(pdb_id), output)
print('Saved file!')

Saved file!


## Find PDBs from Testing Set to Use for Qualitative Validation

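To select PDBs specifically from the testing set, list the first 500 entries of the dataset's name file; the test split above was taken from the first 500 samples.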

In [5]:
# Load the array of sample file names for dataset_400 (entries look like '<pdb_id>.npy').
names = np.load('./datasets/dataset_400.npy')

In [6]:
# Show the first 500 names -- the same leading slice that was held out as the
# test set (data[:500] / labels[:500]).
# NOTE(review): assumes the names are stored in the same order as the maps and
# labels arrays -- confirm against the dataset-building script.
names[:500]

array(['2vk5.npy', '5d8g.npy', '1gro.npy', '2owq.npy', '5kl0.npy',
       '5vve.npy', '2eak.npy', '3rs8.npy', '5c6w.npy', '4nye.npy',
       '3nbx.npy', '3u4l.npy', '4ymb.npy', '4ej5.npy', '4od3.npy',
       '2yqc.npy', '5g3l.npy', '3ex3.npy', '4xcc.npy', '1kyn.npy',
       '5ehg.npy', '4m8j.npy', '4d5k.npy', '4pw9.npy', '4kzj.npy',
       '3mdm.npy', '5htp.npy', '2fg0.npy', '1nt0.npy', '2qnu.npy',
       '2ewm.npy', '3roa.npy', '6c4r.npy', '5ob5.npy', '3vny.npy',
       '1lvb.npy', '2h52.npy', '4jkc.npy', '1dfp.npy', '4htf.npy',
       '5tmw.npy', '3oli.npy', '2r0p.npy', '4jk9.npy', '4gr9.npy',
       '5n5p.npy', '5fv3.npy', '3thq.npy', '3p5d.npy', '4o82.npy',
       '1ro8.npy', '1dlj.npy', '1bcr.npy', '4jge.npy', '4d25.npy',
       '3mez.npy', '3uon.npy', '5x79.npy', '3lkk.npy', '1zct.npy',
       '3bca.npy', '5qbu.npy', '3p6n.npy', '5g3f.npy', '5jkv.npy',
       '1so4.npy', '5fy4.npy', '3iv0.npy', '3dwo.npy', '5dbv.npy',
       '3x3f.npy', '5vaw.npy', '5az7.npy', '5fz3.npy', '1q0y.n