In [1]:
import pickle as pkl
import numpy as np
import os
os.chdir('/scratch/cl1205/protease-gcnn-pytorch/model')
from utils import *
import torch
from torch import nn

In [3]:
def TestIndexSave(dataset_str, seed=None):
    """Draw a random 70/30 train/test split and save the test indices to disk.

    The label array for ``dataset_str`` is obtained through ``load_input``; its
    first dimension defines the number of samples. The sample indices are
    shuffled, the last 30% are taken as the test set and written (one integer
    per line) to ``../data/ind.<dataset_str>.test.index``, relative to the
    current working directory. Any existing file at that path is overwritten.

    Parameters
    ----------
    dataset_str : str
        Dataset name, used both for ``load_input`` and for the output file name.
    seed : int or None, optional
        When given, the shuffle uses a dedicated ``RandomState(seed)`` so the
        split is reproducible. When ``None`` (default) the global NumPy RNG is
        used, exactly as before.

    Returns
    -------
    numpy.ndarray
        The shuffled indices assigned to the test set.
    """
    names = ['x', 'y', 'graph', 'sequences', 'proteases', 'labelorder']
    # load_input is expected to yield one object per requested name, in order;
    # only the labels are needed to size the split.
    _, y_arr, _, _, _, _ = tuple(load_input(dataset_str, names, input_type='train'))

    n_samples = y_arr.shape[0]
    idx = np.arange(n_samples)
    print(n_samples)

    if seed is None:
        np.random.shuffle(idx)
    else:
        np.random.RandomState(seed).shuffle(idx)

    # First 70% of the shuffled indices are training data, the remaining 30% are test data.
    cutoff = int(0.7 * len(idx))
    idx_test = idx[cutoff:]
    print(len(idx_test))

    np.savetxt('../data/ind.' + dataset_str + '.test.index', idx_test, fmt='%d')
    return idx_test


In [4]:
# Build and save the 70/30 test indices for the TEV dataset.
# NOTE(review): the shuffle is unseeded, so re-running this cell overwrites the
# saved index file with a different split.
idx = TestIndexSave('TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond')

5425
1628


In [5]:
# Display the test indices returned by TestIndexSave.
idx

array([3947, 4140,  777, ...,  224, 5388, 1575])

# Train / Validation / Test Split Indices (80/10/10)

In [14]:
def ValTestIndex(dataset_str, seed=None):
    """Draw a random 80/10/10 train/validation/test split and save the val/test indices.

    The label array for ``dataset_str`` is obtained through ``load_input``; its
    first dimension defines the number of samples. After shuffling the sample
    indices, the first 80% are training data, the next 10% validation data and
    the last 10% test data. The validation and test indices are written (one
    integer per line) to ``../data/ind.<dataset_str>.trisplit.val.index`` and
    ``../data/ind.<dataset_str>.trisplit.test.index``, relative to the current
    working directory. Existing files at those paths are overwritten. The
    training indices are not saved; they are whatever is in neither file.

    Parameters
    ----------
    dataset_str : str
        Dataset name, used both for ``load_input`` and for the output file names.
    seed : int or None, optional
        When given, the shuffle uses a dedicated ``RandomState(seed)`` so the
        split is reproducible. When ``None`` (default) the global NumPy RNG is
        used, exactly as before.

    Returns
    -------
    tuple of numpy.ndarray
        ``(idx_val, idx_test)``, the shuffled validation and test indices.
    """
    names = ['x', 'y', 'graph', 'sequences', 'proteases', 'labelorder']
    # load_input is expected to yield one object per requested name, in order;
    # only the labels are needed to size the split.
    _, y_arr, _, _, _, _ = tuple(load_input(dataset_str, names, input_type='train'))

    idx = np.arange(y_arr.shape[0])
    if seed is None:
        np.random.shuffle(idx)
    else:
        np.random.RandomState(seed).shuffle(idx)

    train_end = int(0.8 * len(idx))  # first 80% -> training
    val_end = int(0.9 * len(idx))    # next 10% -> validation, final 10% -> test
    idx_train = idx[:train_end]
    idx_val = idx[train_end:val_end]
    idx_test = idx[val_end:]
    print(len(idx_train), len(idx_val), len(idx_test))

    np.savetxt('../data/ind.' + dataset_str + '.trisplit.test.index', idx_test, fmt='%d')
    np.savetxt('../data/ind.' + dataset_str + '.trisplit.val.index', idx_val, fmt='%d')
    return idx_val, idx_test


In [15]:
# Create and save validation/test index files for every dataset variant.
# NOTE(review): idx_val / idx_test are overwritten on each iteration, so after the
# loop they hold only the split of the last dataset in the list.
for data in ['TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_WT_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_A171T_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_D183A_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_Triple_binary_10_ang_aa_energy_7_energyedge_5_hbond']:
    print(data)
    idx_val, idx_test = ValTestIndex(data)
    

TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond
4340 542 543
HCV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond
31399 3925 3925
HCV_WT_binary_10_ang_aa_energy_7_energyedge_5_hbond
5873 734 735
HCV_A171T_binary_10_ang_aa_energy_7_energyedge_5_hbond
10564 1320 1321
HCV_D183A_binary_10_ang_aa_energy_7_energyedge_5_hbond
9491 1186 1187
HCV_Triple_binary_10_ang_aa_energy_7_energyedge_5_hbond
5470 684 684


# Simple Statistics for the Training / Validation / Test Splits

In [2]:
def raw_data_statistics(dataset, data_dir='/scratch/cl1205/protease-gcnn-pytorch/data'):
    """Print cleaved/uncleaved counts for the train, validation and test splits.

    Reads four files from ``data_dir``:

    * ``ind.<dataset>.y`` - pickled label array, indexed by sample along axis 0
    * ``ind.<dataset>.sequences`` - pickled sequence of sample names
    * ``ind.<dataset>.trisplit.test.index`` - test indices, one integer per line
    * ``ind.<dataset>.trisplit.val.index`` - validation indices, one integer per line

    Every sample that is in neither index file is counted as training data.
    A row counts as "cleaved" when its first label column equals 1 and as
    "uncleaved" when it equals 0 (labels ``[1, 0]`` mean cleaved).

    Parameters
    ----------
    dataset : str
        Dataset name embedded in the file names.
    data_dir : str, optional
        Directory holding the ``ind.*`` files. Defaults to the original
        hard-coded project data directory.

    Returns
    -------
    tuple
        ``(sequences, test_index, val_index)`` where both index arrays are
        sorted in ascending order.
    """
    def _path(suffix):
        # Full path of one of the dataset's ``ind.<dataset>.<suffix>`` files.
        return os.path.join(data_dir, 'ind.{}.{}'.format(dataset, suffix))

    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    # Context managers make sure the handles are closed after loading.
    with open(_path('y'), 'rb') as handle:
        idy = pkl.load(handle)
    with open(_path('sequences'), 'rb') as handle:
        sequences = pkl.load(handle)

    # ndmin=1 keeps a one-line index file as a 1-D array; np.sort fails on 0-d input.
    # The legacy '.test.index' file is intentionally not read: only the tri-split is used.
    test_index = np.sort(np.loadtxt(_path('trisplit.test.index'), dtype=int, ndmin=1))
    val_index = np.sort(np.loadtxt(_path('trisplit.val.index'), dtype=int, ndmin=1))
    y_val = idy[val_index]
    y_test = idy[test_index]

    # Training samples are those whose index is in neither held-out set.
    held_out = np.concatenate([test_index, val_index])
    train_mask = ~np.isin(np.arange(idy.shape[0]), held_out)
    y_train = idy[train_mask]

    # Show the first training sample name as a quick sanity check (skipped if there is none).
    train_sequences = np.array(sequences)[train_mask]
    if len(train_sequences):
        print(train_sequences[0])  # 1 0 means cleaved

    for label, y_split in (('Train', y_train), ('Val', y_val), ('Test', y_test)):
        print('{}:| Cleaved {} | Uncleaved {} | Total {} |'.format(label,
                                                                   np.sum(y_split == [1, 0], axis=0)[0],
                                                                   np.sum(y_split == [0, 1], axis=0)[0],
                                                                   y_split.shape[0]))
    print('Total: {}'.format(y_train.shape[0] + y_val.shape[0] + y_test.shape[0]))
    return sequences, test_index, val_index

In [32]:
# Print the train/val/test class counts for the TEV dataset.
# NOTE(review): the next cell repeats this call and keeps the return values;
# this cell looks redundant - confirm before removing.
raw_data_statistics('TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond')

N176I_TAHLYFQSGT.pdb
Train:| Cleaved 2111 | Uncleaved 2229 | Total 4340 |
Val:| Cleaved 259 | Uncleaved 283 | Total 542 |
Test:| Cleaved 238 | Uncleaved 305 | Total 543 |
Total: 5425


In [3]:
# Keep the sample names and the sorted test/validation indices for the lookups below.
sequences, test_index, val_index = raw_data_statistics('TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond')

N176I_TAHLYFQSGT.pdb
Train:| Cleaved 2111 | Uncleaved 2229 | Total 4340 |
Val:| Cleaved 259 | Uncleaved 283 | Total 542 |
Test:| Cleaved 238 | Uncleaved 305 | Total 543 |
Total: 5425


In [7]:
# Sample names that fall in the test split.
test_sequences = np.array(sequences)[test_index]

In [9]:
# Check whether the entry named 'WT_TENLYFQSGT.pdb' landed in the test split ...
for seq in test_sequences:
    if seq == 'WT_TENLYFQSGT.pdb':
        print('in test')
# ... or in the validation split. No output from either loop means it is in
# neither held-out set (it is training data, or absent from `sequences`).
val_sequences = np.array(sequences)[val_index]
for seq in val_sequences:
    if seq == 'WT_TENLYFQSGT.pdb':
        print('in_val')

In [27]:
# Count the rows of y_train whose first label column equals 1 (cleaved).
# NOTE(review): y_train is only a local variable inside raw_data_statistics and is
# not defined by any cell visible here; on a fresh kernel this cell would raise
# NameError. Looks like leftover scratch state - confirm or remove.
np.sum(y_train==[1,0], axis=0)[0]

2111

In [33]:
# Print the train/val/test class counts for every dataset variant.
# The return values are discarded; only the printed summary is used here.
for data in ['TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_WT_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_A171T_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_D183A_binary_10_ang_aa_energy_7_energyedge_5_hbond',
            'HCV_Triple_binary_10_ang_aa_energy_7_energyedge_5_hbond']:
    print(data)
    raw_data_statistics(data)

TEV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond
N176I_TAHLYFQSGT.pdb
Train:| Cleaved 2111 | Uncleaved 2229 | Total 4340 |
Val:| Cleaved 259 | Uncleaved 283 | Total 542 |
Test:| Cleaved 238 | Uncleaved 305 | Total 543 |
Total: 5425
HCV_all_binary_10_ang_aa_energy_7_energyedge_5_hbond
AYYYEPC.ASHL
Train:| Cleaved 10404 | Uncleaved 20995 | Total 31399 |
Val:| Cleaved 1319 | Uncleaved 2606 | Total 3925 |
Test:| Cleaved 1338 | Uncleaved 2587 | Total 3925 |
Total: 39249
HCV_WT_binary_10_ang_aa_energy_7_energyedge_5_hbond
AYYYEPC.ASHL
Train:| Cleaved 1566 | Uncleaved 4307 | Total 5873 |
Val:| Cleaved 175 | Uncleaved 559 | Total 734 |
Test:| Cleaved 191 | Uncleaved 544 | Total 735 |
Total: 7342
HCV_A171T_binary_10_ang_aa_energy_7_energyedge_5_hbond
AETMLLC.ASHL
Train:| Cleaved 2905 | Uncleaved 7659 | Total 10564 |
Val:| Cleaved 366 | Uncleaved 954 | Total 1320 |
Test:| Cleaved 373 | Uncleaved 948 | Total 1321 |
Total: 13205
HCV_D183A_binary_10_ang_aa_energy_7_energyedge_5_hbond
ADLMDDC.AS