In [1]:
#default_exp model.base

In [2]:
#export
import os
from typing import IO, Tuple, Type, Union
from zipfile import ZipFile

import numpy as np
import pandas as pd
import torch

# ModelImplBase
The base class for the operations shared by all models. It does not define the network itself (the `torch.nn.Module`); it only provides common APIs, including `load()` to load models, `save()` to save models, ...

In [3]:
#export
class ModelImplBase(object):
    """Common scaffolding shared by all model implementations.

    The class does not define a network itself. `build()` instantiates a
    `torch.nn.Module` into `self.model`, and the remaining methods provide
    device selection, optimizer/loss setup, saving/loading of weights and
    single-batch train/predict helpers. Subclasses are expected to implement
    `train()` and `predict()`.

    Attributes set by the methods below:
        device:    `torch.device` chosen by `use_GPU()`.
        model:     the network created by `build()`.
        optimizer: Adam optimizer created by `init_train()`.
        loss_func: L1 loss created by `init_train()`.
    """

    def __init__(self, *args, **kargs):
        """Select the device.

        Only the optional keyword `GPU` is read here (default True); all other
        arguments are accepted and ignored.
        """
        self.use_GPU(kargs.get('GPU', True))

    def use_GPU(self, GPU=True):
        """Set `self.device` to CUDA when requested and available, else CPU."""
        use_cuda = bool(GPU) and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def init_train(self, lr=0.001):
        """Create the Adam optimizer (learning rate `lr`) and the L1 loss.

        Requires `self.model` to exist, i.e. `build()` must have been called.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_func = torch.nn.L1Loss()

    def build(self,
        model_class: Type[torch.nn.Module],
        lr = 0.001,
        *args, **kargs
    ):
        """Instantiate `model_class(*args, **kargs)`, move it to `self.device`
        and set up training with learning rate `lr`.
        """
        self.model = model_class(*args, **kargs)
        self.model.to(self.device)
        self.init_train(lr)

    def get_parameter_num(self):
        """Return the total number of elements over all model parameters."""
        return np.sum([p.numel() for p in self.model.parameters()])

    def save(self, save_as):
        """Write the model's state dict to `save_as` and a text description
        of the architecture (`str(self.model)`) to `save_as + '.txt'`.

        The parent directory of `save_as` is created when it does not exist.
        """
        save_dir = os.path.dirname(save_as)
        # An empty dirname means the current directory, which always exists.
        if save_dir:
            # exist_ok avoids the check-then-create race of exists()+makedirs().
            os.makedirs(save_dir, exist_ok=True)
        torch.save(self.model.state_dict(), save_as)
        with open(save_as+'.txt','w') as f:
            f.write(str(self.model))

    def _load_model_file(self, stream):
        """Load a state dict from `stream` (anything `torch.load` accepts)
        into `self.model`, mapping tensors onto `self.device`.
        """
        self.model.load_state_dict(torch.load(
            stream, map_location=self.device)
        )

    def load(
        self,
        model_file: Union[str, IO],
        model_name_in_zip: str = None,
        *args, **kargs
    ):
        """Load model weights into `self.model`.

        Args:
            model_file: either a path (str) or an already opened binary
                stream. A path ending in '.zip' (case-insensitive) is treated
                as a zip archive that contains the weights file.
            model_name_in_zip: name of the weights file inside the archive;
                required when `model_file` is a zip path, ignored otherwise.
            *args, **kargs: accepted for subclasses; unused here.

        Raises:
            ValueError: if `model_file` is a zip path and `model_name_in_zip`
                is not given.
        """
        if isinstance(model_file, str):
            # We may release all models (msms, rt, ccs, ...) in a single zip file
            if model_file.lower().endswith('.zip'):
                if model_name_in_zip is None:
                    raise ValueError(
                        'model_name_in_zip is required when model_file is a zip file'
                    )
                # ZipFile only accepts 'r', 'w', 'x' or 'a'; the former 'rb'
                # made every zip load fail with ValueError.
                with ZipFile(model_file, 'r') as model_zip:
                    with model_zip.open(model_name_in_zip) as pt_file:
                        self._load_model_file(pt_file)
            else:
                with open(model_file,'rb') as pt_file:
                    self._load_model_file(pt_file)
        else:
            self._load_model_file(model_file)

    def _train_one_batch(
        self, 
        targets:torch.Tensor, 
        *features
    ):
        """Run one optimization step on a single batch.

        `features` are moved to `self.device` and passed positionally to the
        model; the loss against `targets` is back-propagated and the optimizer
        is stepped. Returns the loss tensor of this batch.
        """
        self.optimizer.zero_grad()
        predicts = self.model(*[fea.to(self.device) for fea in features])
        cost = self.loss_func(predicts, targets.to(self.device))
        cost.backward()
        self.optimizer.step()
        return cost

    def _predict_one_batch(self,
        *features
    ):
        """Return the model output for one batch of `features` (moved to
        `self.device` first). Gradient tracking is not disabled here.
        """
        return self.model(*[fea.to(self.device) for fea in features])

    def train(self, batch_size=1024, epoch=20, *args, **kargs):
        """Training loop; to be implemented by subclasses."""
        raise NotImplementedError('train() function is not finished yet')

    def predict(self, batch_size=1024, *args, **kargs):
        """Prediction loop; to be implemented by subclasses."""
        raise NotImplementedError('predict() function is not finished yet')

## Here we provide some basic torch sub-models.

In [4]:
#export
def zero_param(*shape):
    """Return a frozen (`requires_grad=False`) parameter of the given shape,
    filled with zeros.
    """
    zeros = torch.zeros(shape)
    return torch.nn.Parameter(zeros, requires_grad=False)

def xavier_param(*shape):
    """Return a frozen (`requires_grad=False`) parameter of the given shape,
    filled with Xavier-uniform random values.
    """
    values = torch.empty(shape)
    torch.nn.init.xavier_uniform_(values)
    return torch.nn.Parameter(values, requires_grad=False)

# Initializer used for the RNN initial states below.
init_state = xavier_param

def aa_embedding(embedding_size):
    """Return an embedding layer with 27 rows of width `embedding_size`.

    Index 0 is the padding index, so it always embeds to a zero vector.
    NOTE(review): 27 presumably covers padding plus the 26 letters produced
    by `parse_aa_indices` -- confirm against the featurizer.
    """
    return torch.nn.Embedding(
        num_embeddings=27,
        embedding_dim=embedding_size,
        padding_idx=0,
    )

In [5]:
from alphadeep.model.featurize import parse_aa_indices

# Demo inputs; `sequence`, `embedding_hidden` and `x` are reused by the test cells further down.
sequence = 'ACDEFGIK'
embedding_hidden = 4

# Turn the peptide into integer indices and embed them.
embedding = aa_embedding(embedding_hidden)
aa_indices = torch.LongTensor(parse_aa_indices([sequence]))
x = embedding(aa_indices)
x

tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
         [-0.5638,  0.8596, -0.0549, -0.0189],
         [ 0.9118,  1.0247,  3.3288,  0.7046],
         [-0.3026,  0.4432,  0.0145,  0.8551],
         [-1.5997,  2.1119,  0.9526, -0.5363],
         [ 0.1958,  0.9696, -0.3995, -0.9213],
         [-1.0171,  0.9094,  1.2428, -1.2176],
         [-0.4703, -1.4687,  0.7420, -0.6106],
         [-1.3442,  0.8077,  1.5787, -0.2008],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=<EmbeddingBackward>)

`SeqCNN` (a TextCNN-style module) extracts sequence features using `nn.Conv1d` layers with different kernel sizes (3, 5, 7), and then concatenates the input with the outputs of these three convolutions along the feature dimension.

In [6]:
#export
class SeqCNN(torch.nn.Module):
    """Multi-scale 1D convolution over a sequence.

    Three same-length convolutions (kernel sizes 3, 5 and 7) are applied to
    the input, and their outputs are concatenated with the input itself, so
    the feature dimension grows from `embedding_hidden` to
    `4 * embedding_hidden`.
    """

    def __init__(self, embedding_hidden):
        super().__init__()

        # padding = (kernel_size - 1) / 2 keeps the sequence length unchanged.
        self.cnn_short = torch.nn.Conv1d(
            in_channels=embedding_hidden,
            out_channels=embedding_hidden,
            kernel_size=3,
            padding=1,
        )
        self.cnn_medium = torch.nn.Conv1d(
            in_channels=embedding_hidden,
            out_channels=embedding_hidden,
            kernel_size=5,
            padding=2,
        )
        self.cnn_long = torch.nn.Conv1d(
            in_channels=embedding_hidden,
            out_channels=embedding_hidden,
            kernel_size=7,
            padding=3,
        )

    def forward(self, x):
        """Swap dims 1 and 2 so Conv1d sees channels in dim 1, convolve,
        concatenate input and conv outputs along the channel dim, and swap
        the dims back.
        """
        channels_first = x.transpose(1, 2)
        branches = [channels_first]
        for conv in (self.cnn_short, self.cnn_medium, self.cnn_long):
            branches.append(conv(channels_first))
        stacked = torch.cat(branches, dim=1)
        return stacked.transpose(1, 2)

`SeqLSTM` and `SeqGRU` take a sequence of features as the input (for example embedded sequences processed by `SeqCNN`) and output the RNN results for every position.

In [7]:
#export
class SeqLSTM(torch.nn.Module):
    """Batch-first (optionally bidirectional) LSTM with stored initial states.

    The initial hidden and cell states are parameters created by `init_state`
    with a batch dimension of 1; `forward` repeats them to the batch size of
    the input.
    """

    def __init__(self, in_features, hidden,
                 rnn_layer=2, bidirectional=True
        ):
        super().__init__()

        # One initial state per layer and per direction.
        n_states = rnn_layer * (1 + bidirectional)
        self.rnn_h0 = init_state(n_states, 1, hidden)
        self.rnn_c0 = init_state(n_states, 1, hidden)
        self.rnn = torch.nn.LSTM(
            input_size=in_features,
            hidden_size=hidden,
            num_layers=rnn_layer,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x:torch.Tensor):
        """Return the LSTM outputs for every position of `x`; the final
        hidden/cell states are discarded.
        """
        batch_size = x.size(0)
        initial_state = (
            self.rnn_h0.repeat(1, batch_size, 1),
            self.rnn_c0.repeat(1, batch_size, 1),
        )
        outputs, _ = self.rnn(x, initial_state)
        return outputs
    

class SeqGRU(torch.nn.Module):
    """Batch-first (optionally bidirectional) GRU with a stored initial state.

    The initial hidden state is a parameter created by `init_state` with a
    batch dimension of 1; `forward` repeats it to the batch size of the input.
    """

    def __init__(self, in_features, hidden,
                 rnn_layer=2, bidirectional=True
        ):
        super().__init__()

        # One initial state per layer and per direction.
        n_states = rnn_layer * (1 + bidirectional)
        self.rnn_h0 = init_state(n_states, 1, hidden)
        self.rnn = torch.nn.GRU(
            input_size=in_features,
            hidden_size=hidden,
            num_layers=rnn_layer,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x:torch.Tensor):
        """Return the GRU outputs for every position of `x`; the final
        hidden state is discarded.
        """
        batch_size = x.size(0)
        initial_state = self.rnn_h0.repeat(1, batch_size, 1)
        outputs, _ = self.rnn(x, initial_state)
        return outputs

In [24]:
#export
class SeqAttentionSum(torch.nn.Module):
    """Collapse dim 1 of the input by an attention-weighted sum.

    A bias-free linear layer scores every position, the scores are normalized
    with a softmax over dim 1, and the input is summed over dim 1 using those
    weights.
    """

    def __init__(self, in_features):
        super().__init__()
        scorer = torch.nn.Linear(in_features, 1, bias=False)
        normalizer = torch.nn.Softmax(dim=1)
        self.attn = torch.nn.Sequential(scorer, normalizer)

    def forward(self, x):
        """Return `sum_over_dim1(x * softmax(linear(x)))`."""
        weights = self.attn(x)
        # `weights` has a trailing dim of 1 and broadcasts over the features.
        return (x * weights).sum(dim=1)
        

In [26]:
# Sanity-check SeqAttentionSum on a small one-hot batch (2 sequences, 6 positions, 7 classes).
# Dedicated names are used on purpose: rebinding `x` here would replace the embedded
# peptide that the SeqEncoder test further down feeds into `encoder(x)`, breaking a
# top-to-bottom run of the notebook.
aa_ids = torch.LongTensor([[1,2,3,4,5,6],[1,2,3,1,2,3]])
onehot_x = torch.nn.functional.one_hot(aa_ids, 7).float()
attn_sum = SeqAttentionSum(7)
attn_sum(onehot_x)

tensor([[0.0000, 0.1061, 0.1317, 0.1841, 0.1706, 0.1997, 0.2078],
        [0.0000, 0.2515, 0.3121, 0.4363, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SumBackward1>)

In [None]:
#export
class LinearDecoder(torch.nn.Module):
    """Three-layer perceptron: in_features -> 128 -> 32 -> out_features,
    with a PReLU activation after each of the first two linear layers.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        layers = [
            torch.nn.Linear(in_features, 128),
            torch.nn.PReLU(),
            torch.nn.Linear(128, 32),
            torch.nn.PReLU(),
            torch.nn.Linear(32, out_features),
        ]
        self.nn = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to `x`."""
        decoded = self.nn(x)
        return decoded

### Test these basic models

In [None]:
class SeqEncoder(torch.nn.Module):
    """Test encoder chaining SeqCNN -> SeqLSTM -> dropout -> SeqAttentionSum.

    The attention sum removes dim 1, so the output has one vector of width
    `embedding_hidden * 8` per input sequence.
    """

    def __init__(self, embedding_hidden, dropout=0.2):
        super().__init__()

        cnn_features = embedding_hidden*4   # SeqCNN concatenates input + 3 conv outputs
        rnn_features = embedding_hidden*8   # 4 for MultiScaleCNN output, and 2 for BiRNN output

        self.dropout = torch.nn.Dropout(dropout)
        self.input_cnn = SeqCNN(embedding_hidden)
        self.input_nn = SeqLSTM(cnn_features, cnn_features, rnn_layer=2)
        self.attn_sum = SeqAttentionSum(rnn_features)

    def forward(self, x):
        """Encode the embedded sequences `x` into one vector per sequence."""
        cnn_out = self.input_cnn(x)
        rnn_out = self.input_nn(cnn_out)
        return self.attn_sum(self.dropout(rnn_out))

In [None]:
# Encode `x` with a freshly initialized SeqEncoder.
# NOTE(review): `x` is assumed to be the embedded peptide from the embedding demo
# cell (last dim == embedding_hidden) -- confirm no intermediate cell rebinds it.
# The module is left in training mode, so dropout is active and the output varies between runs.
encoder = SeqEncoder(embedding_hidden)
code = encoder(x)
code

tensor([[ 0.3658,  1.1804,  2.5408, -1.0424,  1.7857, -0.4062,  0.4979,  0.8366,
          0.4606,  1.4251, -0.3487, -0.4141, -1.8821, -1.1386,  1.0718,  0.4254,
         -0.4744, -1.6219, -0.7583, -1.6310,  1.3904,  0.4034,  0.6103, -1.9193,
          0.5267, -1.7338, -0.0106, -0.2590, -0.6797, -0.2902, -0.2994,  1.3211]],
       grad_fn=<SumBackward1>)

In [None]:
#hide
# Display the module tree of the encoder built above.
encoder

SeqEncoder(
  (dropout): Dropout(p=0.2, inplace=False)
  (input_cnn): SeqCNN(
    (cnn_short): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(1,))
    (cnn_medium): Conv1d(4, 4, kernel_size=(5,), stride=(1,), padding=(2,))
    (cnn_long): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,))
  )
  (input_nn): SeqLSTM(
    (rnn): LSTM(16, 16, num_layers=2, batch_first=True, bidirectional=True)
  )
  (attn_sum): SeqAttentionSum(
    (pos_weight): Sequential(
      (0): Linear(in_features=32, out_features=1, bias=False)
      (1): Softmax(dim=2)
    )
  )
)

In [None]:
class SeqDecoder(torch.nn.Module):
    """Test decoder: a single-layer unidirectional GRU (hidden size 128)
    followed by a bias-free linear projection to `out_features`, applied at
    every position.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        hidden = 128
        # One layer, one direction -> a single initial hidden state.
        self.rnn_h0 = init_state(1, 1, hidden)
        self.rnn = torch.nn.GRU(
            input_size=in_features,
            hidden_size=hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=False,
        )

        self.output_nn = torch.nn.Linear(
            hidden, out_features, bias=False
        )

    def forward(self, x):
        """Run the GRU over `x` and project each position's output."""
        batch_size = x.size(0)
        initial_state = self.rnn_h0.repeat(1, batch_size, 1)
        rnn_out, _ = self.rnn(x, initial_state)
        return self.output_nn(rnn_out)

In [None]:
# The decoder input width is embedding_hidden*8, the same width SeqEncoder's attn_sum is built for.
decoder = SeqDecoder(embedding_hidden*8, 2)

# Insert a dim at position 1 and repeat the code len(sequence) times along it,
# so the GRU receives the same vector at every position.
decode = decoder(code.unsqueeze(1).repeat(1, len(sequence), 1))
decode

tensor([[[-0.1205, -0.0147],
         [-0.1780, -0.0156],
         [-0.2151, -0.0192],
         [-0.2380, -0.0241],
         [-0.2519, -0.0286],
         [-0.2603, -0.0321],
         [-0.2655, -0.0346],
         [-0.2688, -0.0362]]], grad_fn=<UnsafeViewBackward>)