In [None]:
import torch
import torch.nn as nn
from torchvision import models
import os
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import warnings
import numpy as np
warnings.filterwarnings('ignore')

In [None]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    r"""
    3x3 convolution with padding
    - in_planes: in_channels
    - out_channels: out_channels
    - bias=False: BatchNorm에 bias가 포함되어 있으므로, conv2d는 bias=False로 설정.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

In [None]:
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        r"""
         - inplanes: input channel size
         - planes: output channel size
         - groups, base_width: ResNext나 Wide ResNet의 경우 사용
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            
        # Basic Block의 구조
        self.conv1 = conv3x3(inplanes, planes, stride)  # conv1에서 downsample
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        
        # short connection
        if self.downsample is not None:
            identity = self.downsample(x)
            
        # identity mapping시 identity mapping후 ReLU를 적용합니다.
        # 그 이유는, ReLU를 통과하면 양의 값만 남기 때문에 Residual의 의미가 제대로 유지되지 않기 때문입니다.
        out += identity
        out = self.relu(out)

        return out

In [None]:
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=30, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        # default values
        self.inplanes = 1 # input feature map
        self.dilation = 1
        # stride를 dilation으로 대체할지 선택
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        
        r"""
        - 처음 입력에 적용되는 self.conv1과 self.bn1, self.relu는 모든 ResNet에서 동일 
        - 3: 입력으로 RGB 이미지를 사용하기 때문에 convolution layer에 들어오는 input의 channel 수는 3
        """
        # self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=4, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        r"""
        - 아래부터 block 형태와 갯수가 ResNet층마다 변화
        - self.layer1 ~ 4: 필터의 개수는 각 block들을 거치면서 증가(64->128->256->512)
        - self.avgpool: 모든 block을 거친 후에는 Adaptive AvgPool2d를 적용하여 (n, 512, 1, 1)의 텐서로
        - self.fc: 이후 fc layer를 연결
        """
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, # 여기서부터 downsampling적용
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.layer5 = self._make_layer(block, 1024, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.layer6 = self._make_layer(block, 2048, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.drop = nn.Dropout2d()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)


    def _lazy_linear(out_features):
        super(nn.LazyLinear).__init__()
        lazy = nn.LazyLinear(out_features=out_features)
        return(lazy)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        r"""
        convolution layer 생성 함수
        - block: block종류 지정
        - planes: feature map size (input shape)
        - blocks: layers[0]와 같이, 해당 블록이 몇개 생성돼야하는지, 블록의 갯수 (layer 반복해서 쌓는 개수)
        - stride와 dilate은 고정
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        
        # the number of filters is doubled: self.inplanes와 planes 사이즈를 맞춰주기 위한 projection shortcut
        # the feature map size is halved: stride=2로 downsampling
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # 블록 내 시작 layer, downsampling 필요
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion # inplanes 업데이트
        # 동일 블록 반복
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):

        aa_len = x.shape[2]
        lazy1 = nn.LazyLinear(30)
        lazy2 = nn.LazyLinear(aa_len)


        # print('OG');print(x.shape)        
        x = self.conv1(x)
        # print('After Conv2D'); print(x.shape)        
        x = self.bn1(x)
        # print('After BN');print(x.shape)        
        x = self.relu(x)
        # print('After ReLU');print(x.shape)        
        x = self.layer1(x)
        # print('After Layer1');print(x.shape)        
        x = self.layer2(x)
        # print('After Layer2');print(x.shape)        
        x = self.layer3(x)
        # print('After Layer3');print(x.shape)        
        x = self.layer4(x)
        # print('After Layer4');print(x.shape)        
        x = F.avg_pool2d(x,2)
        # print('After Avgpool');print(x.shape)        
        x = torch.squeeze(x,-1)
        # print('After Squeeze');print(x.shape)        
        x = lazy1(x)
        # print('After Lazy1');print(x.shape)        
        x = x.permute(0,2,1)
        x = lazy2(x)
        # print('After Lazy2');print(x.shape)        
        x = self.drop(x)
        # print('After Drop');print(x.shape)        
        x = x.permute(0,2,1)
        # print('Final');print(x.shape)        
        # x = self.fc(x)
        return x

    def forward(self, x):
        return self._forward_impl(x)
        

In [293]:
class PSSMDataset(torch.utils.data.Dataset):
    def __init__(self,pssm_in,label_in):
        self.pssm_lst = sorted(os.listdir(pssm_in))
        self.pssm_files = [pssm_in+f for f in self.pssm_lst]
        self.label_files = [label_in+f.split('.')[0]+'.label.txt' for f in self.pssm_lst]
    def __len__(self):
        return len(self.label_files)
    def __getitem__(self, idx):
        # self.pssm_tensor = torch.tensor(pd.read_table(self.pssm_files[idx],index_col=0).astype(float).values) # Processed PSSM
        self.pssm_tensor = torch.tensor(pd.read_table(self.pssm_files[idx],index_col=0).iloc[:,0:20].astype(float).values) # Raw PSSM
        self.label_tensor = torch.tensor(list(map(int,open(self.label_files[idx],'r').readlines()[0].split(','))))
        return self.pssm_tensor, self.label_tensor

In [294]:
def my_collate(batch):
    data = [item[0] for item in batch]
    len_data = [len(item[0]) for item in batch]

    target = [item[1] for item in batch]
    dim_data = [d.size()[0] for d in data]
    max_seqlen = max(dim_data)
    data_transformed = [torch.nn.functional.normalize(F.pad(d.T,(0,int(max_seqlen - d.size()[0])),'constant',0.0)).T 
                        if d.size()[0]!=max_seqlen else d for d in data]
    target_transformed = [F.pad(d,(0,int(max_seqlen - d.size()[0])),'constant',0.0) 
                          if d.size()[0]!=max_seqlen else d for d in target]
    target_onehot = [torch.nn.functional.one_hot(t,num_classes = 30) for t in target_transformed]
    
    data_transformed = torch.nn.utils.rnn.pad_sequence(data_transformed,batch_first=True)
    target_result = torch.nn.utils.rnn.pad_sequence(target_onehot,batch_first=True)
    return [data_transformed, target_result,len_data]

In [295]:
batch_size = 16
pssminpath = '/Users/suhancho/data/Uniprot_metalbinding_challenge/processed_pssm/'
labelinpath = '/Users/suhancho/data/Uniprot_metalbinding_challenge/processed_label/'
train_ds = PSSMDataset(pssminpath,labelinpath)
train_dl = DataLoader(train_ds,batch_size = batch_size,shuffle = True,collate_fn=my_collate)
model = ResNet(BasicBlock,[2,2,2,2])

from torch import optim
device = 'cpu'

lr = 1e-05
num_epochs = 5
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss().to(device)
# loss_function = nn.MSELoss()
params = {
    'num_epochs':num_epochs,
    'optimizer':optimizer,
    'loss_function':loss_function,
    'device':device
}

In [296]:
# https://pseudo-lab.github.io/pytorch-guide/docs/ch04-1.html
from torch import softmax
soft = nn.Softmax(dim=1)

def train(model, params):
    loss_function=params["loss_function"]
    device=params["device"]

    for epoch in range(0, num_epochs):
        for X_batch, y_batch,lendata in train_dl:
            y_batch = torch.tensor(y_batch,dtype=torch.float)
            X_batch = X_batch.unsqueeze(1).float()

            optimizer.zero_grad() 
            outputs = model(X_batch)
            outputs = torch.tensor(outputs,dtype=torch.float)

            train_separate_loss = [loss_function(soft(o[0:i]),yb[0:i]) for o,yb,i in zip(outputs,y_batch,lendata)]
            train_separate_loss = np.median(train_separate_loss)

            train_separate_loss = torch.tensor(train_separate_loss)

            train_separate_loss.requires_grad_(True)
            train_separate_loss.backward()
            optimizer.step()
            print('Epoch: %d/%d, Train loss: %.6f' %(epoch+1, num_epochs, train_separate_loss.item()))


            # train_loss = loss_function(outputs, y_batch)
            # train_loss.requires_grad_(True)
            # train_loss.backward()
            # optimizer.step()
            # print(train_loss)

    #   # test accuracy 계산
    #     total = 0
    #     correct = 0
    #     accuracy = []
    #     for i, data in enumerate(test_dataloader, 0):
    #     inputs, labels = data
    #     inputs = inputs.to(device)
    #     labels = labels.to(device)

    #     # 결과값 연산
        # outputs = model(inputs)

        # _, predicted = torch.max(outputs.data, 1)
        # total += labels.size(0)
        # correct += (predicted == labels).sum().item()
        # test_loss = loss_function(outputs, labels).item()
        # accuracy.append(100 * correct/total)

        # 학습 결과 출력
        # print('Epoch: %d/%d, Train loss: %.6f, Test loss: %.6f, Accuracy: %.2f' %(epoch+1, num_epochs, train_loss.item(), test_loss, 100*correct/total))
        # print('Epoch: %d/%d, Train loss: %.6f' %(epoch+1, num_epochs, train_loss.item()))

In [298]:
train(model, params)

Epoch: 1/5, Train loss: 3.401473
Epoch: 1/5, Train loss: 3.401046
Epoch: 1/5, Train loss: 3.401270
Epoch: 1/5, Train loss: 3.401528
Epoch: 1/5, Train loss: 3.401429
Epoch: 1/5, Train loss: 3.401556
Epoch: 1/5, Train loss: 3.401413
Epoch: 1/5, Train loss: 3.401435
Epoch: 1/5, Train loss: 3.401595
Epoch: 1/5, Train loss: 3.401098
Epoch: 1/5, Train loss: 3.401277
Epoch: 1/5, Train loss: 3.401502


KeyboardInterrupt: 