In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.autograd import Variable

import torch.utils.data

import pandas as pd
import numpy as np

import pdb
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm

from torch.nn.utils import clip_grad_norm
import os

# ---------------------------------------------------------------------------
# GPU selection.
# The physical device ids listed in GPU_ids are exposed to CUDA through the
# CUDA_VISIBLE_DEVICES environment variable; afterwards GPU_ids is rebound to
# the 0-based indices of those now-visible devices.
# ---------------------------------------------------------------------------
use_cuda = True

GPU_ids = [2]
# remember the first physical id before GPU_ids is renumbered below
GPU_id = GPU_ids[0]

os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, GPU_ids))
GPU_ids = list(range(len(GPU_ids)))

# cell output: how many CUDA devices torch reports
torch.cuda.device_count()

1

In [2]:
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

import torch
import torch.utils.data 
from torch.autograd import Variable

import pandas as pd
import numpy as np

from seq2loc.utils import *

import math


class SequenceDataset(torch.utils.data.Dataset):
    """Dataset of (encoded sequence, multi-hot label vector) pairs read from a CSV.

    The CSV must contain a 'GO id' column (';'-separated labels) and a
    'Sequence' column. Rows without a 'GO id' are dropped, labels are
    binarized with a MultiLabelBinarizer, and sequences longer than the
    99.5th length percentile are discarded.

    Args:
        sequence_path: path of the CSV file to load.
        max_seq_len: maximum number of characters returned per item; longer
            sequences are randomly cropped to this length on every access.
        mlb: an already-fitted label binarizer to reuse (e.g. the one fitted
            on the training split). When None a new MultiLabelBinarizer is
            fitted on this file's labels.

    Attributes:
        pd_sequences: pandas Series of the kept sequence strings.
        somehot: pandas Series of float32 multi-hot label arrays, aligned
            with pd_sequences.
        mlb: the label binarizer that produced the label arrays.
    """

    def __init__(self, sequence_path = './data/uniprot.csv', max_seq_len = 25, mlb = None):
        self.max_seq_len = max_seq_len

        column_name = 'GO id'

        df = pd.read_csv(sequence_path)
        df = df.dropna(subset=[column_name]).reset_index(drop=True)

        # split the ';'-separated label string once and reuse it for fit/transform
        labels = df[column_name].str.split(';').tolist()

        if mlb is None:
            mlb = MultiLabelBinarizer()
            mlb.fit(labels)

        some_hot_targets = mlb.transform(labels)
        # one float32 row vector per sample, indexed 0..n-1 like df
        df_somehot = pd.Series(tuple(some_hot_targets.astype(np.float32)))

        df_sequences = df['Sequence']

        #Trim out the bonkers long sequences
        seq_lengths = np.array([len(seq) for seq in df_sequences])
        max_len = np.percentile(seq_lengths, 99.5)

        # explicit boolean mask (array vs. scalar comparison)
        keep_inds = seq_lengths <= max_len

        self.pd_sequences = df_sequences[keep_inds].reset_index(drop=True)
        self.somehot = df_somehot[keep_inds].reset_index(drop=True)

        self.mlb = mlb

    def __getitem__(self, index):
        """Return (tensor_indices, target) for the sample at `index`.

        `tensor_indices` is whatever `lineToIndices` returns for the
        (possibly cropped) sequence string; `target` is the multi-hot label
        vector as a torch.Tensor. Sequences longer than `max_seq_len` are
        cropped to a window chosen uniformly at random with np.random.
        """
        seq = self.pd_sequences[index]

        if len(seq) > self.max_seq_len:
            # Valid window starts are 0 .. len(seq) - max_seq_len inclusive.
            # randint's upper bound is exclusive, hence the +1; without it the
            # last window (and so the final character) could never be sampled.
            start = np.random.randint(len(seq) - self.max_seq_len + 1)
            seq = seq[start:(start + self.max_seq_len)]

        tensor_indices = lineToIndices(seq)

        somehot = self.somehot[index]

        return tensor_indices, torch.Tensor(somehot)

    def __len__(self):
        """Number of sequences kept after filtering."""
        return len(self.pd_sequences)
    
class PaddedSequenceDataset(torch.utils.data.Dataset):
    """Batch-level view over a per-sample sequence dataset.

    Indexing with an iterable of sample indices fetches each sample from the
    wrapped dataset, right-pads the index sequences to the longest one in the
    batch using the value ``n_letters() - 1`` (the stop character), converts
    the padded index matrix with ``indicesToTensor`` and stacks the targets.

    Original author's note: returns tensors in <batch, len, channels> order.
    NOTE(review): the final layout depends on what ``indicesToTensor``
    returns, which is not visible here - confirm.

    Args:
        sequenceDataset: dataset whose items are (index tensor, target tensor).
        GPU_id: CUDA device to move the batch to, or None to leave it as is.
    """

    def __init__(self, sequenceDataset, GPU_id = None):
        self.sequenceDataset = sequenceDataset
        self.GPU_id = GPU_id

    def __getitem__(self, indices):
        """Build one padded batch from the samples listed in `indices`."""
        # fetch every sample exactly once, in the order given
        samples = [self.sequenceDataset[index] for index in indices]
        index_seqs = [Variable(tensor_indices) for tensor_indices, _ in samples]
        targets = [Variable(somehot) for _, somehot in samples]

        # length every sequence is padded to
        longest = max(len(seq_inds) for seq_inds in index_seqs)
        stop_index = n_letters() - 1

        # pad all shorter sequences with the stop character
        padded = []
        for seq_inds in index_seqs:
            n_pad = longest - seq_inds.shape[0]
            filler = Variable(torch.ones(n_pad).long() * stop_index)
            padded.append(torch.cat([seq_inds, filler]))

        # one column per sample: shape (longest, batch)
        batch_indices = torch.stack(padded, 1)

        somehots = torch.stack(targets)

        sequence_tensors = Variable(indicesToTensor(batch_indices))
        sequence_tensors = sequence_tensors.transpose(1, 0).transpose(2, 1)

        if self.GPU_id is not None:
            sequence_tensors = sequence_tensors.cuda(self.GPU_id)
            somehots = somehots.cuda(self.GPU_id)

        return sequence_tensors, somehots

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.sequenceDataset)

In [None]:
import os
from decimal import Decimal

import seq2loc.models
import seq2loc.utils as utils

from sklearn.metrics import precision_recall_fscore_support, average_precision_score

from tensorboardX import SummaryWriter


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
GPU_id = 0
LR = 0.001
N_EPOCHS = 500
hidden_size = 256
batch_size = 128
num_gru_layers = 2

N_LETTERS = utils.n_letters()

#really big number to trim sequences to
max_seq_len = 4000

# ---------------------------------------------------------------------------
# Data: the label binarizer fitted on the training split is reused for the
# validation and test splits so that all three share the same label columns.
# ---------------------------------------------------------------------------
ds = PaddedSequenceDataset(SequenceDataset('./data/hpa_data_resized_train.csv', max_seq_len = max_seq_len), GPU_id = GPU_id)
mlb = ds.sequenceDataset.mlb
ds_validate = PaddedSequenceDataset(SequenceDataset('./data/hpa_data_resized_validate.csv', max_seq_len = max_seq_len, mlb = mlb), GPU_id = GPU_id)
ds_test = PaddedSequenceDataset(SequenceDataset('./data/hpa_data_resized_test.csv', max_seq_len = max_seq_len, mlb = mlb), GPU_id = GPU_id)


# ---------------------------------------------------------------------------
# Model, loss, optimizer, logging
# ---------------------------------------------------------------------------
# BCEWithLogitsLoss applies the sigmoid itself, so it must be fed raw scores.
criterion = torch.nn.BCEWithLogitsLoss()

# NOTE(review): 33 is presumably the number of output classes and would then
# have to equal len(mlb.classes_) - confirm against SeqConvResidClassifier.
enc = seq2loc.models.SeqConvResidClassifier(N_LETTERS, 33, kernel_size = 4, layers_deep = 10, ch_intermed = 128, pooling_type='avg').cuda(GPU_id)

opt = optim.Adam(enc.parameters(), lr = LR)

writer = SummaryWriter()

save_dir = './classifier_conv_resid/{}'

# global step counter used as the x-axis for every logged scalar
iteration = 0

for epoch in range(N_EPOCHS):

    # -----------------------------------------------------------------------
    # Training pass
    # -----------------------------------------------------------------------
    epoch_inds = utils.get_epoch_inds(len(ds), batch_size)
    pbar = tqdm(epoch_inds)

    epoch_losses = list()

    for batch in pbar:
        opt.zero_grad()

        x, y = ds[batch]

        y_hat = enc(x)

        loss = criterion(y_hat, y)

        loss.backward()
        opt.step()

        losses_np = np.squeeze(loss.detach().cpu().numpy())

        epoch_losses += [losses_np]
        pbar.set_description('%.4E' % Decimal(str(losses_np)))

        writer.add_scalars(save_dir.format('train'), {'loss': losses_np}, iteration)

        iteration += 1

    ###########################
    ### Write out test results
    ###########################
    # (evaluated on ds_validate; the 'test' / 'test_stats' log tags are kept
    # unchanged so existing logs stay comparable)
    enc.train(False)

    epoch_inds = utils.get_epoch_inds(len(ds_validate), batch_size)

    y_list = list()
    y_hat_list = list()
    losses_test = list()

    for batch in epoch_inds:

        x, y = ds_validate[batch]

        with torch.no_grad():
            logits = enc(x)
            # The loss takes the raw scores: passing sigmoid outputs here would
            # apply the sigmoid twice. The criterion already averages over all
            # elements, so no extra division by the batch size either - this
            # keeps the value on the same scale as the training loss.
            loss = criterion(logits, y)
            # probabilities are only needed for thresholding / ranking metrics
            y_hat = torch.sigmoid(logits)

        losses_test += [np.squeeze(loss.detach().cpu().numpy())]

        y_list += [y]
        y_hat_list += [y_hat]

    writer.add_scalars(save_dir.format('test'), {'loss': np.mean(losses_test)}, iteration)

    y = torch.cat(y_list).cpu().numpy()
    y_hat = torch.cat(y_hat_list).cpu().numpy()

    thresh = 0.5

    true_labs = y
    pred_acts = y_hat
    pred_labs = np.zeros_like(pred_acts)
    pred_labs[pred_acts > thresh] = 1

    # one record per class; turned into df_stats once, after the loop
    stats_records = list()
    for i, col in enumerate(mlb.classes_):

        # get true labels and predicted activations
        true_labs_col = true_labs[:, i]
        pred_acts_col = pred_acts[:, i]
        pred_labs_col = pred_labs[:, i]

        # compute one against all prec + recall stats
        p, r, f, _ = precision_recall_fscore_support(true_labs_col, pred_labs_col, average='binary')
        auprc = average_precision_score(true_labs_col, pred_acts_col)

        writer.add_scalars(save_dir.format('test_stats'), {'precision_{}'.format(col): p,
                                                    'recall_{}'.format(col): r,
                                                    'f1score_{}'.format(col): f,
                                                    'auprc_{}'.format(col): auprc,
                                                    }, iteration)

        stats_records += [{'class': col,
                           'precision': p,
                           'recall': r,
                           'f1score': f,
                           'auprc': auprc,
                           'support': int(true_labs_col.sum())}]

    # per-class statistics of the most recent evaluation
    df_stats = pd.DataFrame(stats_records)

    enc.train(True)

    pbar.set_description('%.4E' % Decimal(str(np.mean(epoch_losses))))
    


HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




  'precision', 'predicted', average, warn_for)
  recall = tps / tps[-1]


HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




  'recall', 'true', average, warn_for)


HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))




HBox(children=(IntProgress(value=0, max=465), HTML(value='')))

In [None]:
# Toy check of AvgPool1d on a (1, 1, 10) signal: all ones, with positions 4-5
# set to two.
x = Variable(torch.ones(1, 1, 10))
x[..., 4:6] = 2

print(x)

kernel_size = 2
stride = 2
padding = 0

layer = torch.nn.AvgPool1d(
    kernel_size,
    stride=stride,
    padding=padding,
    ceil_mode=False,
    count_include_pad=False,
)

# pooled values, then the pooled shape
print(layer(x))
print(layer(x).shape)

In [None]:
# Toy check of a strided Conv1d on a (1, 1, 100) signal: all ones, with
# positions 4-5 set to two.
x = Variable(torch.ones(1,1,100))
x[:,:,4:6] = 2

print(x)

kernel_size = 4
stride = 2
# Floor division keeps the padding an int: with true division this is the
# float 1.0, which Conv1d cannot use as a padding size.
padding = kernel_size // 2 - 1

layer = torch.nn.Conv1d(1, 2, kernel_size= kernel_size, stride=stride, padding=padding)

# With kernel 4, stride 2 and padding 1 the length-100 input gives 50 outputs.
print(layer(x))
print(layer(x).shape)

In [None]:
# Scratch arithmetic: N_LETTERS (set in the training cell) plus 128, the same
# value that cell passes as ch_intermed.
# NOTE(review): presumably a channel-count sanity check - confirm or delete.
N_LETTERS + 128

In [None]:
# Display df_stats, the DataFrame assigned in the evaluation section of the
# training loop (requires the training cell to have run at least one epoch).
df_stats

In [None]:
# Per-class hit rate on the positives: for each label column, the fraction of
# samples carrying that label whose predicted score exceeds 0.5.
# Uses y_list / y_hat_list left behind by the last evaluation pass.
y = torch.cat(y_list).cpu().numpy()
y_hat = torch.cat(y_hat_list).cpu().numpy()

n_classes = y.shape[1]
for i in range(n_classes):
    # rows where class i is present in the ground truth
    label_inds = np.flatnonzero(y[:, i] > 0)

    predicted_positive = y_hat[label_inds, i] > 0.5
    print(np.mean(y[label_inds, i] == predicted_positive))


In [None]:
# Boolean mask of y_hat > 0.5 for the rows in label_inds and column i.
# Relies on y_hat, label_inds and i left over from the previous cell's loop,
# i.e. it shows the last class that loop visited.
(y_hat[label_inds,i]>0.5)

In [None]:
import matplotlib.pyplot as plt

# Shape of whatever `x` currently is in the kernel (several earlier cells
# assign it), followed by a blank line.
print(x.shape)
print()

# NOTE(review): `losses` is not defined in any visible cell - presumably a
# list of loss values (epoch_losses?); confirm, otherwise this raises
# NameError on a fresh Restart & Run All.
plt.plot(losses)
plt.show()

In [None]:
# Element-wise agreement between the targets and the sigmoid of y_hat
# thresholded at 0.5.
# NOTE(review): this calls .cpu() on y and applies a Sigmoid to y_hat, so it
# expects torch tensors of raw scores. The evaluation code above rebinds y and
# y_hat to NumPy arrays of already-sigmoided values, so this cell looks stale -
# confirm before relying on it.
np.equal(y.cpu().data.numpy(), nn.Sigmoid()(y_hat).cpu().data.numpy()>0.5)

In [None]:
# Scratch arithmetic: log(exp(a) / 28), i.e. a - log(28) for a = 3.2191...
# NOTE(review): the origin of both constants is not shown in this notebook.
np.log(np.exp(3.219125824868201)/28)

In [None]:
# Qualitative reconstruction check: encode one randomly chosen training
# sequence, then decode step by step starting from the stop character and
# print the input next to the decoded characters.
# NOTE(review): `dec` is not defined in any visible cell, and `enc` is used
# here with `initHidden` and a two-argument call while the training cell calls
# it as enc(x). This looks left over from an encoder/decoder experiment;
# `stopChar` / `tensorToChar` presumably come from the seq2loc.utils star
# import. Confirm before running.
enc.train(False)
dec.train(False)

# a single-sample batch drawn at random from the training set
x_tmp, _ = ds[[np.random.randint(len(ds))]]

# x = torch.unsqueeze(x[:,0,:],1)
batch_size_tmp = x_tmp.shape[1]

hidden = enc.initHidden(batch_size_tmp).cuda(GPU_id)
out, hidden = enc(x_tmp, hidden)

#input the stop character to the stream    
out = Variable(stopChar(batch_size_tmp)).cuda(GPU_id)


#     pdb.set_trace()
out_chars = list()

# one decoder step per position along the first axis of x_tmp
for i in range(x_tmp.shape[0]):

    out, hidden = dec(out, hidden) 
    
    out_chars += [tensorToChar(out)[0,0]]
    
# restore training mode
enc.train(True)
dec.train(True)

print(''.join(np.hstack(tensorToChar(x_tmp))))
print(''.join(out_chars))

In [None]:
# Scratch arithmetic: 3e-4 multiplied by ln(15).
# NOTE(review): the meaning of 3e-4 and 15 is not shown in this notebook.
3E-4*np.log(15)

In [None]:
# Scratch arithmetic: 6.77e-05 divided by ln(15).
# NOTE(review): the meaning of 6.77e-05 and 15 is not shown in this notebook.
6.77e-05/np.log(15)