In [1]:
#%% library import
import numpy as np
import pandas as pd
import networkx as nx
import torch as tc
import torch
import pprint
import pickle
import time

from rdkit.Chem import AllChem as chem
from rdkit.Chem import Draw as draw
from torch.autograd import Variable
from sklearn.utils import shuffle
from sklearn.preprocessing import Normalizer
from matplotlib import pyplot as plt
from matplotlib import image as img
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from functools import partial

In [2]:
#%% Load dataset and cuda
dataset = pd.read_csv("datasets/mini-dataset.csv")
datalen = len(dataset)
cuda = tc.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

Using device: cuda
GeForce RTX 2080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [3]:
#%% protein-ligand-kiba split
protein = dataset.loc[:, 'uniprotID']
ligand = dataset.loc[:, 'chemblID']
kiba = list(dataset['KIBA'])
# del dataset

In [4]:
#%% protein sequence load
f = open('datasets/dictionaries/seq_codeinfo.txt', 'rb')
seq_voc, seq_len = pickle.load(f)
f.close()

sequence = np.zeros(((2**13)+(2**8), 4128))
for i, s in enumerate(protein):
    sequence[i] = seq_voc[s]
    
sequence_len = np.zeros(((2**13)+(2**8),))
for i, s in enumerate(protein):
    sequence_len[i] = seq_len[s]

In [5]:
#%% ligand image load
def rgb2gray(rgb):
    return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])

image = np.zeros(((2**13)+(2**8), 1, 280, 280))
for i, l in enumerate(ligand):
    im = img.imread("datasets/dictionaries/ligand_img/{}.png".format(l))
    image[i][0] = rgb2gray(im)[10:290, 10:290]

In [6]:
#%% dataset zip
revised_dataset = list(zip(sequence, sequence_len, image, kiba))
shuffled_dataset = shuffle(revised_dataset); del revised_dataset
trainset = shuffled_dataset[:2**13]
validset = shuffled_dataset[2**13:(2**13)+(2**8)]
del shuffled_dataset

In [7]:
#%% Make collate func.
def collate(samples):
    # The input `samples` is a list of pairs [(graph, label),(graph, label)].
    sequences, sequence_lens, images, labels = map(list, zip(*samples))
    return tc.LongTensor(sequences).cuda(), tc.LongTensor(sequence_lens), tc.tensor(images, dtype=tc.float).cuda(), tc.tensor(labels).cuda()

In [8]:
#%% network module 선언
class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False) # verify bias false
        self.bn = nn.BatchNorm2d(out_planes,
                                 eps=0.001, # value found in tensorflow
                                 momentum=0.1, # default pytorch value
                                 affine=True)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        out = self.relu(x)
        return out
    

class InceptionResnet_Ablock(nn.Module):
    def __init__(self, scale=1.0):
        super(InceptionResnet_Ablock, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(4, 4, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(4, 4, kernel_size=1, stride=1),
            BasicConv2d(4, 4, kernel_size=3, stride=1, padding=1)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(4, 4, kernel_size=1, stride=1),
            BasicConv2d(4, 6, kernel_size=3, stride=1, padding=1),
            BasicConv2d(6, 8, kernel_size=3, stride=1, padding=1)
        )

        self.conv2d = nn.Conv2d(16, 4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out
    

class Reduction_Ablock(nn.Module):
    def __init__(self):
        super(Reduction_Ablock, self).__init__()

        self.branch0 = BasicConv2d(4, 6, kernel_size=3, stride=2)

        self.branch1 = nn.Sequential(
            BasicConv2d(4, 4, kernel_size=1, stride=1),
            BasicConv2d(4, 4, kernel_size=3, stride=1, padding=1),
            BasicConv2d(4, 6, kernel_size=3, stride=2)
        )

        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        return out


class InceptionResnet_Bblock(nn.Module):
    def __init__(self, scale=1.0):
        super(InceptionResnet_Bblock, self).__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(16, 16, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(16, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 20, kernel_size=(1,7), stride=1, padding=(0,3)),
            BasicConv2d(20, 24, kernel_size=(7,1), stride=1, padding=(3,0))
        )

        self.conv2d = nn.Conv2d(40, 16, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        out = self.relu(out)
        return out
    

class Reduction_Bblock(nn.Module):
    def __init__(self):
        super(Reduction_Bblock, self).__init__()

        self.branch0 = nn.Sequential(
            BasicConv2d(16, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 24, kernel_size=3, stride=2)
        )

        self.branch1 = nn.Sequential(
            BasicConv2d(16, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 20, kernel_size=3, stride=2)
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(16, 16, kernel_size=1, stride=1),
            BasicConv2d(16, 18, kernel_size=(3,1), stride=1, padding=(1,0)),
            BasicConv2d(18, 20, kernel_size=3, stride=2)
        )

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)
        return out

    
class InceptionResnet_Cblock(nn.Module):
    def __init__(self, scale=1.0, noReLU=False):
        super(InceptionResnet_Cblock, self).__init__()

        self.scale = scale
        self.noReLU = noReLU

        self.branch0 = BasicConv2d(80, 80, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(80, 80, kernel_size=1, stride=1),
            BasicConv2d(80, 93, kernel_size=(1,3), stride=1, padding=(0,1)),
            BasicConv2d(93, 106, kernel_size=(3,1), stride=1, padding=(1,0))
        )

        self.conv2d = nn.Conv2d(186, 80, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        out = out * self.scale + x
        return out


class SqueezeExcitation(nn.Module):
    def __init__(self, channel):
        super(SqueezeExcitation, self).__init__()
        
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channel, channel // 2),
            nn.ReLU(),
            nn.Linear(channel // 2, channel),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        b, c, _, _ = x.size()
        out = self.squeeze(x).view(b, c)
        out = self.excitation(out).view(b, c, 1, 1)
        out = x * out.expand_as(x)
        return out

In [9]:
#%% learning module 선언
class Regressor(nn.Module):
    def __init__(self):
        super(Regressor, self).__init__()    # method 상속받고 __init__()은 여기서 하겠다.
        
        self.emlayer = nn.Embedding(21, 10)
        self.lslayer = nn.LSTM(10, 64, num_layers=1, bidirectional=True, batch_first=True)
        
        self.imlayers = nn.Sequential(
                        BasicConv2d(1, 4, kernel_size=4, stride=1),
                        InceptionResnet_Ablock(scale=0.17),
                        SqueezeExcitation(channel=4),
                        InceptionResnet_Ablock(scale=0.17),
                        SqueezeExcitation(channel=4),
                        InceptionResnet_Ablock(scale=0.17),
                        SqueezeExcitation(channel=4),
                        Reduction_Ablock(), 
                        InceptionResnet_Bblock(scale=0.10),
                        SqueezeExcitation(channel=16),
                        InceptionResnet_Bblock(scale=0.10),
                        SqueezeExcitation(channel=16),
                        InceptionResnet_Bblock(scale=0.10),
                        SqueezeExcitation(channel=16),
                        Reduction_Bblock(),
                        InceptionResnet_Cblock(scale=0.20),
                        SqueezeExcitation(channel=80),
                        InceptionResnet_Cblock(scale=0.20),
                        SqueezeExcitation(channel=80),
                        InceptionResnet_Cblock(scale=0.20),
                        SqueezeExcitation(channel=80)
                        )

        self.avgpool = nn.AvgPool2d(68, count_include_pad=False)
        self.regress = nn.Linear(208, 1)

    def forward(self, seq, seq_len, image):
        sorted_seq_len, sorted_idx = seq_len.sort(0, descending=True)
        seq = seq[sorted_idx]
        
        ls_i = self.emlayer(seq)
        ls_i = pack_padded_sequence(ls_i, sorted_seq_len.tolist(), batch_first=True)
        ls_h = torch.zeros(2, 16, 64).cuda()     # (num_layers * num_directions, batch, hidden_size)
        ls_c = torch.zeros(2, 16, 64).cuda()
        
        ls_o, (ls_h, ls_c) = self.lslayer(ls_i, (ls_h, ls_c))
        ls_o, _ = pad_packed_sequence(ls_o, batch_first=True)

        # 순서 다시 바로잡아주기        
        _, sortedback_idx = sorted_idx.sort(0)
        ls_o = ls_o[sortedback_idx]
        
        # 각 sample의 last output vector 추출
        for_o = []
        for idx, o in enumerate(ls_o):
            for_o.append(o[seq_len[idx]-1, :64].view(1, 64))
        for_o = torch.cat(for_o, 0)
        back_o = ls_o[:, 0, 64:]
        concat_o = tc.cat((for_o, back_o), axis=1)   # batch, hidden*2
        
        im_h = self.imlayers(image)
        im_h = self.avgpool(im_h)
        dim = 1
        for d in im_h.size()[1:]: #16, 4, 4
            dim = dim * d
        im_h = im_h.view(-1, dim)      # batch * 80
        
        cat = tc.cat((concat_o, im_h), axis=1).cuda()
       
        return self.regress(cat).cuda()

In [10]:
#%% Set hyperparameter
hp_d = {}

# FIXME: 학습 관련 하이퍼파라미터
hp_d['batch_size'] = 16
hp_d['num_epochs'] = 300

hp_d['init_learning_rate'] = 10 ** -3.70183
hp_d['eps'] = 10 ** -8.39981
hp_d['weight_decay'] = 10 ** -3.59967

In [11]:
#%% learning and validation
tr_data_loader = DataLoader(trainset, batch_size=hp_d['batch_size'], shuffle=False, collate_fn=collate)
va_data_loader = DataLoader(validset, batch_size=hp_d['batch_size'], shuffle=False, collate_fn=collate)

model = Regressor().to(torch.device('cuda:0'))
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
loss_func = nn.MSELoss(reduction='mean').cuda()
optimizer = optim.Adam(model.parameters(), lr=hp_d['init_learning_rate'], 
    weight_decay=hp_d['weight_decay'], eps=hp_d['eps'])

print('tr_var:', np.var(np.array([s[3] for s in trainset])))
print('va_var:', np.var(np.array([s[3] for s in validset])))
print('total params:', total_params)

tr_epoch_losses = []
va_epoch_losses = []

start = time.time()

for epoch in range(hp_d['num_epochs']):                          #!! epoch-loop
    # training session
    model.train()
    tr_epoch_loss = 0

    for iter, (seq, seq_len, image, label) in enumerate(tr_data_loader):  #!! batch-loop
        prediction = model(seq, seq_len, image).view(-1).cuda()
        loss = loss_func(prediction, label).cuda()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        tr_epoch_loss += loss.detach().item()
    
    tr_epoch_loss /= (iter + 1)
    print('Training epoch {}, loss {:.4f}'.format(epoch, tr_epoch_loss))
    tr_epoch_losses.append(tr_epoch_loss)

# ===========================================================================
    # validation session
    model.eval()
    va_epoch_loss = 0

    for iter, (seq, seq_len, image, label) in enumerate(va_data_loader):  # batch-loop
        prediction = model(seq, seq_len, image).view(-1).cuda()
        loss = loss_func(prediction, label).cuda()
        
        va_epoch_loss += loss.detach().item()
        
    va_epoch_loss /= (iter + 1)
    print('Validation epoch {}, loss {:.4f}'.format(epoch, va_epoch_loss))
    va_epoch_losses.append(va_epoch_loss)
    
end = time.time()
print('time elapsed:', end-start)

tr_var: 0.7881695087571461
va_var: 0.9048297662299764
total params: 336749
Training epoch 0, loss 20.5127
Validation epoch 0, loss 0.9358
Training epoch 1, loss 0.7319
Validation epoch 1, loss 1.1567
Training epoch 2, loss 0.6576
Validation epoch 2, loss 18.2908
Training epoch 3, loss 0.6105
Validation epoch 3, loss 2.3776
Training epoch 4, loss 0.5535
Validation epoch 4, loss 20.1493
Training epoch 5, loss 0.4932
Validation epoch 5, loss 1.6746
Training epoch 6, loss 0.4465
Validation epoch 6, loss 14.9868
Training epoch 7, loss 0.4084
Validation epoch 7, loss 15.0104
Training epoch 8, loss 0.3808
Validation epoch 8, loss 25.2584
Training epoch 9, loss 0.3601
Validation epoch 9, loss 24.9943
Training epoch 10, loss 0.3457
Validation epoch 10, loss 20.7846
Training epoch 11, loss 0.3320
Validation epoch 11, loss 10.1866
Training epoch 12, loss 0.3206
Validation epoch 12, loss 3073.2799
Training epoch 13, loss 0.3128
Validation epoch 13, loss 15.0069
Training epoch 14, loss 0.3049
Valid

Validation epoch 125, loss 26.2150
Training epoch 126, loss 0.2064
Validation epoch 126, loss 1.6775
Training epoch 127, loss 0.1995
Validation epoch 127, loss 1.8132
Training epoch 128, loss 0.2084
Validation epoch 128, loss 2.7229
Training epoch 129, loss 0.2119
Validation epoch 129, loss 91.6913
Training epoch 130, loss 0.2109
Validation epoch 130, loss 6678.9907
Training epoch 131, loss 0.2150
Validation epoch 131, loss 977.2655
Training epoch 132, loss 0.2173
Validation epoch 132, loss 63729.4697
Training epoch 133, loss 0.2151
Validation epoch 133, loss 2422.5505
Training epoch 134, loss 0.2105
Validation epoch 134, loss 3.7489
Training epoch 135, loss 0.2053
Validation epoch 135, loss 1.9392
Training epoch 136, loss 0.2013
Validation epoch 136, loss 3.8222
Training epoch 137, loss 0.2071
Validation epoch 137, loss 2.4300
Training epoch 138, loss 0.2061
Validation epoch 138, loss 3.1655
Training epoch 139, loss 0.2062
Validation epoch 139, loss 3.6976
Training epoch 140, loss 0.2

Validation epoch 249, loss 4.2385
Training epoch 250, loss 0.1575
Validation epoch 250, loss 1.3465
Training epoch 251, loss 0.1647
Validation epoch 251, loss 5.3815
Training epoch 252, loss 0.1702
Validation epoch 252, loss 4.5094
Training epoch 253, loss 0.1606
Validation epoch 253, loss 1.1437
Training epoch 254, loss 0.1588
Validation epoch 254, loss 2.1495
Training epoch 255, loss 0.1602
Validation epoch 255, loss 0.9572
Training epoch 256, loss 0.1529
Validation epoch 256, loss 2.1719
Training epoch 257, loss 0.1612
Validation epoch 257, loss 179.7643
Training epoch 258, loss 0.1622
Validation epoch 258, loss 209.8333
Training epoch 259, loss 0.1624
Validation epoch 259, loss 197.5218
Training epoch 260, loss 0.1673
Validation epoch 260, loss 106.1998
Training epoch 261, loss 0.1614
Validation epoch 261, loss 346.1191
Training epoch 262, loss 0.1608
Validation epoch 262, loss 37.7326
Training epoch 263, loss 0.1535
Validation epoch 263, loss 6.4393
Training epoch 264, loss 0.1683

In [13]:
np.save('lstm+chemception_tr_losses_v2', tr_epoch_losses)
np.save('lstm+chemception_va_losses_v2', va_epoch_losses)