In [2]:
import math
import torch
import time
import os
import numpy as np
from collections import Counter, defaultdict
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax

In [3]:
# Integer id for each nucleotide letter.
alphabet_dict = dict(A = 1, T = 2, C = 3, G = 4)
data = ['ATCGACTACG', 'CTGACTGAT']
# Encode each string as a (1, sequence_length) tensor of int64 ids.
src = torch.tensor([[alphabet_dict[base] for base in data[0]]], dtype = torch.long)
tgt = torch.tensor([[alphabet_dict[base] for base in data[1]]], dtype = torch.long)

In [4]:
# Hyper-parameters shared by the embedding, positional-encoding and transformer cells below.
ntokens = 10  # number of rows in the nn.Embedding table (token ids 0..9)
d_model = 512  # embedding width and transformer model dimension
dropout = 0.1  # dropout probability for the positional embedding and transformer layers
max_len = 1000  # number of positions precomputed by PositionalEmbedding
nhead = 8  # attention heads per transformer layer
encoder_layer_nums = 6  # number of stacked encoder layers
decoder_layer_nums = 6  # number of stacked decoder layers
dim_ff = 512  # passed as the feed-forward width of each transformer layer

In [5]:
class PositionalEmbedding(nn.Module):
    '''
    Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Even feature columns hold sin(position * frequency) and odd columns hold
    cos(position * frequency). The table is precomputed once for `max_len`
    positions and stored as the non-trainable buffer `pe` of shape
    (1, max_len, d_model).

    Args:
        d_model: size of the last (feature) dimension of the inputs.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions to precompute; inputs must not be longer.
    '''

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p = dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000) / d_model) )
        # (max_len, ceil(d_model / 2)) table of angles shared by the sine and cosine columns.
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so trim the angle table instead of failing on a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor):
        '''
        Args:
            x: tensor whose dimension 1 is the sequence position and whose last
               dimension is d_model, e.g. (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape as `x` with position encodings added and
            dropout applied.
        '''
        # Buffers never require grad, so the slice can be added directly.
        x = x + self.pe[0, :x.size(1), :]
        return self.dropout(x)

In [6]:
# Token-id lookup table followed by the sinusoidal position encoder.
emb = nn.Embedding(num_embeddings = ntokens, embedding_dim = d_model)
pos_encode = PositionalEmbedding(d_model = d_model, dropout = dropout, max_len = max_len)

In [None]:
# Smoke test of a single decoder layer on random inputs.
# batch_first is left at its default here, so tensors are (seq_len, batch, d_model).
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_ff, dropout)
# The feature dimension must equal the layer's d_model; a hard-coded 128
# does not match d_model and makes the forward pass raise.
memory = torch.rand(10, 32, d_model)
tgt = torch.rand(9, 32, d_model)
out = decoder_layer(tgt, memory)
# A decoder layer returns a tensor shaped like its target input.
assert out.shape == tgt.shape

In [8]:
# type 1: the all-in-one encoder-decoder module
transformer1 = nn.Transformer(
    d_model = d_model,
    nhead = nhead,
    num_encoder_layers = encoder_layer_nums,
    num_decoder_layers = decoder_layer_nums,
    dim_feedforward = dim_ff,
    dropout = dropout,
)

In [25]:
# type 2: build the encoder and decoder stacks separately (batch-first tensors)
encoder_layer = TransformerEncoderLayer(
    d_model = d_model,
    nhead = nhead,
    dim_feedforward = dim_ff,
    dropout = dropout,
    batch_first = True,
)
encoder = TransformerEncoder(encoder_layer, num_layers = encoder_layer_nums)
decoder_layer = TransformerDecoderLayer(
    d_model = d_model,
    nhead = nhead,
    dim_feedforward = dim_ff,
    dropout = dropout,
    batch_first = True,
)
decoder = TransformerDecoder(decoder_layer, num_layers = decoder_layer_nums)


In [22]:
# test data: a batch of two source and two target id sequences
src = torch.tensor(
    [
        [0, 8, 3, 5, 5, 9, 6, 1, 2, 2, 2],
        [0, 6, 6, 8, 9, 1, 2, 2, 2, 2, 2],
    ],
    dtype = torch.long,
)

tgt = torch.tensor(
    [
        [0, 8, 3, 5, 5, 9, 6, 1, 2],
        [0, 6, 6, 8, 9, 1, 2, 2, 2],
    ],
    dtype = torch.long,
)

print("src's shape is {}\ntgt's shape is {}".format(src.shape, tgt.shape))

src's shape is torch.Size([2, 11])
tgt's shape is torch.Size([2, 9])


In [23]:
def get_pad(tokens, pad_id = 2):
    '''
    Build an additive key-padding mask for a tensor of token ids.

    Args:
        tokens: tensor of token ids; the mask has the same shape.
        pad_id: id that marks padding positions (defaults to 2).

    Returns:
        Float tensor on the same device as `tokens`, holding -inf where
        `tokens == pad_id` and 0 everywhere else.
    '''
    # Allocate on the tokens' device so the mask can be used next to them
    # without a device mismatch.
    key_padding_mask = torch.zeros(tokens.shape, device = tokens.device)
    key_padding_mask[tokens == pad_id] = -torch.inf
    return key_padding_mask

# Additive padding masks (0 = keep, -inf = ignore) for both sequences.
src_key_padding_mask = get_pad(src)
tgt_key_padding_mask = get_pad(tgt)

# Causal mask so each target position only attends to itself and earlier positions.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(-1))

transformer = nn.Transformer(d_model, batch_first = True)

res = transformer(emb(src),emb(tgt),
          tgt_mask = tgt_mask,
          src_key_padding_mask = src_key_padding_mask,
          tgt_key_padding_mask = tgt_key_padding_mask,
          # The encoder output keeps the source layout, so the source padding
          # must also be hidden from the decoder's cross-attention.
          memory_key_padding_mask = src_key_padding_mask
          )

In [7]:
# Same construction as the "type 2" cell: re-running it creates fresh,
# randomly initialised encoder and decoder stacks.
layer_kwargs = dict(
    d_model = d_model,
    nhead = nhead,
    dim_feedforward = dim_ff,
    dropout = dropout,
    batch_first = True,
)
encoder_layer = TransformerEncoderLayer(**layer_kwargs)
encoder = TransformerEncoder(encoder_layer, num_layers = encoder_layer_nums)
decoder_layer = TransformerDecoderLayer(**layer_kwargs)
decoder = TransformerDecoder(decoder_layer, num_layers = decoder_layer_nums)

In [10]:
# Hide source padding from the encoder's self-attention, matching the masks
# given to the nn.Transformer call.
encode_out = encoder(emb(src), src_key_padding_mask = src_key_padding_mask)
output = decoder(emb(tgt), encode_out,
                 tgt_mask = tgt_mask,
                 tgt_key_padding_mask = tgt_key_padding_mask,
                 # memory has the source layout, so reuse the source padding mask
                 memory_key_padding_mask = src_key_padding_mask)

In [16]:
# Decoder fed with position-encoded embeddings on both inputs.
# NOTE(review): memory is the embedded source itself, not the encoder output,
# and no masks are passed — confirm this is intentional.
output = decoder(
    tgt = pos_encode(emb(tgt)),
    memory = pos_encode(emb(src)),
)

In [41]:
# Random batch: 50 source sequences of length 100 and 50 target sequences of length 50.
src = torch.randint(0, 10, (50, 100))
tgt = torch.randint(0, 10, (50, 50))

# Row r keeps its first 3 + r positions and masks every later one with -inf.
n_rows, n_cols = tgt.shape
first_masked = 3 + torch.arange(n_rows).unsqueeze(1)
col_index = torch.arange(n_cols).unsqueeze(0)
tgt_key_mask = torch.zeros(tgt.size()).masked_fill(col_index >= first_masked, -torch.inf)

# Causal mask over the target positions.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[-1])

In [89]:
def expand_dim(data, IO, pad_value = 7):
  '''
  Pad a 2-D tensor on the right and bottom up to a fixed size.

  Args:
    data: 2-D tensor to pad.
    IO: truthy to pad an input to 50 x 500, falsy to pad an output to 50 x 50.
    pad_value: value written into the added cells (defaults to 7). It should
      not also occur as a real value in `data`, otherwise a mask built by
      comparing against it cannot tell padding from content.

  Returns:
    Tensor of the target size with `data` in the top-left corner and
    `pad_value` elsewhere. It keeps the dtype and device of `data`, so integer
    token ids stay integers.

  Raises:
    AssertionError: if `data` is larger than the target size.
  '''
  max_rows = 50
  max_cols = 500 if IO else 50

  row, col = data.size()
  # Bounds depend on the mode, and an exact fit is valid (it needs no padding).
  assert row <= max_rows and col <= max_cols, 'Index out of range.'

  # Fill the target first, then copy the data in: unlike concatenating with
  # float padding blocks, this does not promote integer inputs to float.
  res = torch.full((max_rows, max_cols), pad_value, dtype = data.dtype, device = data.device)
  res[:row, :col] = data

  return res

def get_padding_mask(data, pad_value = 7):
  '''
  Build an additive padding mask for a tensor.

  Args:
    data: tensor to inspect; the mask has the same shape.
    pad_value: value that marks padding cells (defaults to 7).

  Returns:
    Float tensor on the same device as `data`, holding -inf where
    `data == pad_value` and 0 everywhere else.
  '''
  # Allocate on the data's device so the mask can be used next to it.
  mask = torch.zeros(data.shape, device = data.device)
  mask[data == pad_value] = -torch.inf
  return mask

# Causal mask sized to the last dimension of the current `tgt`.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[-1])



In [91]:
# Pad a random 20 x 30 id matrix in "output" mode and show the resulting shape.
d = torch.randint(0, 10, (20, 30))
res = expand_dim(d, IO = False)
print(res.shape)

torch.Size([50, 50])


In [73]:
# Two small 2 x 3 integer matrices for experimenting with torch.cat.
m1 = torch.arange(0, 6).view(2, 3)
m2 = torch.arange(6, 12).view(2, 3)

In [86]:
# Column-wise concatenation: m1, then m1 again, then m2.
torch.cat([m1, m1, m2], dim = 1)

tensor([[ 0,  1,  2,  0,  1,  2,  6,  7,  8],
        [ 3,  4,  5,  3,  4,  5,  9, 10, 11]])