In [None]:
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

In [None]:
# Mount the user's Google Drive under /content/drive (needs the google.colab package, i.e. a Colab runtime).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
tokenizer_en = AutoTokenizer.from_pretrained('bert-base-cased')
tokenizer_zh = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

# hyperparameters
# model
# len(tokenizer) counts added special tokens for both slow and fast tokenizers,
# so embedding / output layers sized from it cover every id the tokenizer can emit.
vocab_size_en = len(tokenizer_en)
vocab_size_zh = len(tokenizer_zh)
print(vocab_size_en,vocab_size_zh)
max_length = 512        # max length of the input sequence
n_emb = 512             # embedding size
n_head = 8              # number of heads in multi-head attention
head_size = 64          # number of 'features' output by a single-head self-attention
n_blocks = 3            # number of blocks in a encoder or decoder
n_hidden = 1024         # feed-forward width inside each block
assert head_size*n_head == n_emb, f'head_size*n_head ({head_size*n_head}) must equal n_emb ({n_emb})'

# training
num_epochs = 20
batch_size = 128
learning_rate = 8e-5
seed = 42               # fixed seed so weight init / shuffling are reproducible on Restart & Run All
torch.manual_seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class embedding(nn.Module):
  """Token embedding plus a fixed sinusoidal positional encoding.

  Args:
    vocab_size: number of rows in the token-embedding table.
    n_emb: embedding dimension (odd values are supported).
    max_len: longest sequence the positional table covers.
  """
  def __init__(self,vocab_size,n_emb,max_len):
    super().__init__()
    self.n_emb = n_emb

    self.word_embedding = nn.Embedding(vocab_size,n_emb)

    # Sinusoidal table: even columns hold sin, odd columns hold cos, with
    # geometrically spaced frequencies 10000^(-2i/n_emb).
    pe = torch.zeros(max_len, n_emb)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)                       # [max_len,1]
    div_term = torch.exp(torch.arange(0, n_emb, 2).float() * (-math.log(10000.0) / n_emb))   # [ceil(n_emb/2)]
    pe[:, 0::2] = torch.sin(position * div_term)
    # With odd n_emb there is one fewer odd column than frequencies, so trim to fit.
    pe[:, 1::2] = torch.cos(position * div_term[: n_emb // 2])
    pe = pe.unsqueeze(0)  # Add batch dimension -> [1,max_len,n_emb]
    # Buffer, not a parameter: follows .to(device) and is saved in state_dict, but is not trained.
    self.register_buffer('pe', pe)

  def forward(self,x):
    """Embed token ids ``x`` (presumably ``[B,T]`` integer ids) and add positions.

    Raises:
      ValueError: if the sequence is longer than the positional table.
    """
    word_emb = self.word_embedding(x) * math.sqrt(self.n_emb)         # [B,T,n_emb]
    seq_len = word_emb.size(1)
    if seq_len > self.pe.size(1):
      raise ValueError(
          f'sequence length {seq_len} exceeds max_len {self.pe.size(1)} of the positional table')
    pos_emb = self.pe[:, :seq_len, :]                                 # [1,T,n_emb], broadcast over batch
    return pos_emb + word_emb                                         # [B,T,n_emb]

In [None]:
# # test embedding
# x = torch.randint(low=0,high=20,size=(2,5))
# emb = embedding(vocab_size_en,n_emb,max_length)
# out = emb(x)
# out.shape

torch.Size([2, 5, 8])
torch.Size([1, 32, 8])


torch.Size([2, 5, 8])

In [None]:
class TorchTransformer(nn.Module):
  """Encoder-decoder translation model built on a batch-first ``nn.Transformer``.

  Source and target ids are embedded with ``embedding`` (token embedding +
  sinusoidal positions), run through ``nn.Transformer`` and projected to
  target-vocabulary logits by ``self.linear``.

  Args:
    n_emb: model width (``d_model``).
    head_size: accepted for interface compatibility but unused;
      ``nn.Transformer`` uses ``n_emb // n_head`` per head.
    n_head: number of attention heads.
    n_blocks: number of encoder layers and of decoder layers.
    vocab_size_enc: source vocabulary size.
    vocab_size_dec: target vocabulary size.
    n_hidden: feed-forward width inside each layer.
    max_len: longest sequence the positional tables cover.
  """
  def __init__(self,n_emb,head_size,n_head,n_blocks,vocab_size_enc,vocab_size_dec,n_hidden,max_len):
    super().__init__()
    self.embedding_enc = embedding(vocab_size_enc,n_emb,max_len)
    self.embedding_dec = embedding(vocab_size_dec,n_emb,max_len)
    self.transformer = nn.Transformer(d_model=n_emb,nhead=n_head,num_encoder_layers=n_blocks,num_decoder_layers=n_blocks,dim_feedforward=n_hidden,batch_first=True)
    self.linear = nn.Linear(n_emb,vocab_size_dec)
    self.max_len = max_len

  def forward(self,seq_enc,seq_dec,mask_enc=None,mask_dec=None,mask_enc_padding=None,mask_dec_padding=None,memory_key_padding_mask=None):
    """Teacher-forced pass returning target-vocabulary logits.

    Masks follow ``nn.Transformer`` conventions: ``mask_enc``/``mask_dec`` are
    ``src_mask``/``tgt_mask``; ``mask_enc_padding``/``mask_dec_padding`` are the
    source/target key-padding masks; ``memory_key_padding_mask`` masks padded
    encoder outputs in cross-attention. Assuming ``seq_enc``/``seq_dec`` are
    ``[B,T]`` ids, the result is ``[B,T_dec,vocab_size_dec]``.
    """
    emb_enc = self.embedding_enc(seq_enc)
    emb_dec = self.embedding_dec(seq_dec)
    out = self.transformer(src=emb_enc,tgt=emb_dec,
                           src_mask=mask_enc,tgt_mask=mask_dec,
                           src_key_padding_mask=mask_enc_padding,tgt_key_padding_mask=mask_dec_padding,
                           memory_key_padding_mask=memory_key_padding_mask)
    out = self.linear(out)
    return out

  def encode(self,seq_enc,attention_mask_input,attention_mask_input_padding):
    """Run only the encoder and return its output (the decoder's ``memory``)."""
    emb_enc = self.embedding_enc(seq_enc)
    return self.transformer.encoder(emb_enc,mask=attention_mask_input,
                                    src_key_padding_mask=attention_mask_input_padding)

  def decode(self,seq_dec,memory,mask_dec,memory_key_padding_mask=None):
    """Run only the decoder; returns hidden states (apply ``self.linear`` for logits).

    Pass the encoder's padding mask as ``memory_key_padding_mask`` so that
    cross-attention ignores padded source positions, as ``forward`` allows.
    """
    emb_dec = self.embedding_dec(seq_dec)
    return self.transformer.decoder(emb_dec,memory,tgt_mask=mask_dec,
                                    memory_key_padding_mask=memory_key_padding_mask)

  @torch.no_grad()  # inference only: don't build an autograd graph at every step
  def generate(self,input_ids,attention_mask_input,attention_mask_input_padding,max_length=max_length,device=device,bos_token_id=101,eos_token_id=102):
    """Greedy-decode a single sequence.

    Args:
      input_ids: source token ids; only one sequence is supported because each
        step reads the next token with ``.item()`` (presumably ``[1,T_src]``).
      attention_mask_input: encoder ``src_mask`` or None.
      attention_mask_input_padding: encoder key-padding mask or None; it is also
        used as the decoder's ``memory_key_padding_mask``.
      max_length: maximum number of ids returned (including the start id);
        clamped to ``self.max_len`` so positions stay inside the positional table.
      device: device the inputs and masks are moved to.
      bos_token_id: id the output starts with (101 is ``[CLS]`` in the BERT tokenizers).
      eos_token_id: id that stops decoding (102 is ``[SEP]`` in the BERT tokenizers).

    Returns:
      list[int]: generated ids, starting with ``bos_token_id`` and ending with
      ``eos_token_id`` unless ``max_length`` was reached first.
    """
    max_length = min(max_length, self.max_len)
    input_ids = input_ids.to(device)
    # Masks must be on the same device as the activations they are applied to.
    if attention_mask_input is not None:
      attention_mask_input = attention_mask_input.to(device)
    if attention_mask_input_padding is not None:
      attention_mask_input_padding = attention_mask_input_padding.to(device)
    memory = self.encode(input_ids,attention_mask_input,attention_mask_input_padding).to(device)

    output_token = [bos_token_id]
    while len(output_token) < max_length:
      seq_dec = torch.tensor(output_token, device=device).unsqueeze(0)       # [1,T]
      seq_len = seq_dec.size(1)
      # Causal mask: True strictly above the diagonal blocks attention to future positions.
      mask_dec = torch.tril(torch.ones(seq_len, seq_len, device=device)) == 0
      hidden = self.decode(seq_dec,memory,mask_dec,
                           memory_key_padding_mask=attention_mask_input_padding)
      logits = self.linear(hidden[:,-1,:])                                  # [1,vocab_size_dec]
      # Greedy choice; argmax of the logits equals argmax of their softmax.
      next_token = logits.argmax(dim=-1).item()
      output_token.append(next_token)
      if next_token == eos_token_id:
        break
    return output_token

In [None]:
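# # test TorchTransformer
# # NOTE: stale -- this calls model.generate(seq_enc, seq_dec, test), which does not match
# # TorchTransformer.generate(input_ids, attention_mask_input, attention_mask_input_padding, ...);
# # update the call before uncommenting.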
# import spacy
# nlp = spacy.load('en_core_web_sm')

# def generate(input_seq,test=False):
#   tokens = nlp(input_seq)
#   tokens = [stoi_en['<SOS>']] + [stoi_en[token.text.lower()] for token in tokens] + [stoi_en['<EOS>']]
#   seq_enc = torch.tensor(tokens).to(device)
#   seq_dec = torch.tensor([stoi_zh['<SOS>']]).to(device)
#   output = model.generate(seq_enc,seq_dec,test)
#   output = [itos_zh[o.item()] for o in output][1:]
#   return ''.join(output)

# model = TorchTransformer(n_emb,head_size,n_head,n_blocks,vocab_size_en,vocab_size_zh,n_hidden,max_length)
# input_seq = "harry potter"
# output_seq = generate(input_seq,test=True)