In [1]:
import math
import torch
import time
import os
import numpy as np
from collections import Counter, defaultdict
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer,\
    TransformerDecoder, TransformerDecoderLayer
from torch.nn.functional import softmax

In [2]:
alphabet_dict = {'A' : 1, 'T' : 2, 'C' : 3, 'G' : 4}
data = ['ATCGACTACG','CTGACTGAT']
src = Tensor(list(map(lambda x: alphabet_dict[x], data[0]))).reshape(1,-1).long()
tgt = Tensor(list(map(lambda x: alphabet_dict[x], data[1]))).reshape(1,-1).long()

In [10]:
ntokens = 10
d_model = 128
dropout = 0.1
max_len = 1000
nhead = 8
encoder_layer_nums = 6
decoder_layer_nums = 6
dim_ff = 100

In [11]:
class PositionalEmbedding(nn.Module): #done
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000) / d_model) )
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x: Tensor):
        x = x + self.pe[0, :x.size(1), :].requires_grad_(False)
        output = self.dropout(x)
        return output

In [16]:
emb = nn.Embedding(ntokens, d_model)
pos_encode = PositionalEmbedding(d_model, dropout, max_len)
encoder = TransformerEncoder(TransformerEncoderLayer(d_model, nhead, dim_ff, dropout), encoder_layer_nums)
decoder = TransformerDecoder(TransformerDecoderLayer(d_model, nhead, dim_ff, dropout), decoder_layer_nums)


In [28]:
encoder(pos_encode(emb(src))).shape

torch.Size([1, 10, 128])

In [36]:
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_ff, dropout)
memory = torch.rand(10, 32, 128)
tgt = torch.rand(9, 32, 128)
out = decoder_layer(tgt, memory)

In [20]:
decoder(pos_encode(emb(tgt)),encoder(pos_encode(emb(src))))

RuntimeError: shape '[1, 72, 16]' is invalid for input of size 1280

In [22]:
len(data[1])

9

In [24]:
transformer = nn.Transformer(d_model, batch_first = True)

In [25]:
res = transformer(pos_encode(emb(src)), pos_encode(emb(tgt)))

In [26]:
res.shape

torch.Size([1, 9, 128])

In [43]:
src = torch.LongTensor([
    [0,8,3,5,5,9,6,1,2,2,2],
    [0,6,6,8,9,1,2,2,2,2,2]
])

tgt = torch.LongTensor([
    [0,8,3,5,5,9,6,1,2],
    [0,6,6,8,9,1,2,2,2]
])

print('src\'s shape is {}\ntgt\'s shape is {}'\
      .format(src.shape, tgt.shape))

src's shape is torch.Size([2, 11])
tgt's shape is torch.Size([2, 9])


In [44]:
def get_pad(tokens):
    key_padding_mask = torch.zeros(tokens.size())
    key_padding_mask[tokens == 2] = -torch.inf
    return key_padding_mask

src_key_padding_mask = get_pad(src)
tgt_key_padding_mask = get_pad(tgt)

tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(-1))

transformer = nn.Transformer(d_model, batch_first = True)

res = transformer(emb(src),emb(tgt),
                  tgt_mask = tgt_mask,
                  src_key_padding_mask = src_key_padding_mask,
                  tgt_key_padding_mask = tgt_key_padding_mask
                 )

In [42]:
print(emb(src).shape)

torch.Size([2, 11, 128])
