## Mount Google Drive



In [None]:
from google.colab import drive
drive.mount('/gdrive/')
!ls /gdrive

Mounted at /gdrive/
MyDrive


In [None]:
# Change working directory
# Relative paths used later in the notebook (the raw.tar.gz download and the
# 'raw/raw' data file) are resolved against this folder.
import os
os.chdir("/content/drive/MyDrive/Colab Notebooks")
# Print the current directory to confirm the change took effect.
!pwd

/content/drive/MyDrive/Colab Notebooks


## Download Dataset

In [None]:
# English-Japanese Translation Dataset
!wget https://nlp.stanford.edu/projects/jesc/data/raw.tar.gz
!tar -xf raw.tar.gz

--2021-12-25 06:10:32--  https://nlp.stanford.edu/projects/jesc/data/raw.tar.gz
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 102198198 (97M) [application/x-gzip]
Saving to: ‘raw.tar.gz’


2021-12-25 06:10:38 (17.8 MB/s) - ‘raw.tar.gz’ saved [102198198/102198198]



In [None]:
# Preview the first lines of the extracted corpus. Each line holds an English
# sentence and a Japanese sentence separated by a tab (the parsing cell below
# splits on '\t').
!head raw/raw

you are back, aren't you, harold?	あなたは戻ったのね ハロルド?
my opponent is shark.	俺の相手は シャークだ。
this is one thing in exchange for another.	引き換えだ ある事とある物の
yeah, i'm fine.	もういいよ ごちそうさま ううん
don't come to the office anymore. don't call me either.	もう会社には来ないでくれ 電話もするな
looks beautiful.	きれいだ。
get him out of here, because i will fucking kill him.	連れて行け 殺しそうだ わかったか?
you killed him!	殺したのか!
okay, then who?	わぁ~! いつも すみません。 いいのよ~。
it seems a former employee...	カンパニーの元社員が


## Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
import numpy as np
import multiprocessing
from math import sin, cos, sqrt

## Data Generation

In [None]:
# Build dictionaries to map tokens to indices and collect the sentence pairs.
#
# Produces (module-level names used by later cells):
#   enc_dic  - English token    -> index ('<pad>' is 0)
#   dec_dic  - Japanese char    -> index ('<sos>', '<eos>', '<pad>' are 0, 1, 2)
#   dataset  - list of (english_tokens, japanese_chars) pairs whose lengths fit
#              within max_seq1_len / max_seq2_len
import string
data = 'raw/raw' # path to the data file
dataset = []
enc_dic = {}
dec_dic = {}
dec_dic['<sos>'] = len(dec_dic)
dec_dic['<eos>'] = len(dec_dic)
enc_dic['<pad>'] = len(enc_dic)
dec_dic['<pad>'] = len(dec_dic)

max_seq1_len = 31
max_seq2_len = 30

# The corpus contains Japanese text, so decode explicitly as UTF-8 instead of
# relying on the platform default encoding.
with open(data, encoding='utf-8') as f:
    for line in f:
        # Strip only the line terminator; slicing off the last character would
        # eat a real character when the final line has no trailing newline.
        fields = line.rstrip('\r\n').split('\t')
        if len(fields) != 2:
            # Not an "english<TAB>japanese" pair - skip instead of crashing
            # on the tuple unpack.
            continue
        s1, s2 = fields

        # English side: whitespace tokens, with one trailing punctuation mark
        # split off into its own token.
        enc_seq = []
        for token in s1.split(' '):
            if not token:
                # Consecutive spaces yield empty strings; token[-1] would
                # raise IndexError on them.
                continue
            if token[-1] in string.punctuation:
                stem, mark = token[:-1], token[-1]
                # A token that is only punctuation has an empty stem; do not
                # register '' as a vocabulary entry.
                tokens_to_add = [stem, mark] if stem else [mark]
            else:
                tokens_to_add = [token]

            for t in tokens_to_add:
                # len(enc_dic) is evaluated before insertion, so a new token
                # receives the next free index.
                enc_dic.setdefault(t, len(enc_dic))
            enc_seq.extend(tokens_to_add)

        # Japanese side: one token per character.
        dec_seq = list(s2)
        for char in dec_seq:
            dec_dic.setdefault(char, len(dec_dic))

        # Keep only pairs that fit the fixed padded lengths. An empty English
        # side is dropped too: it would be padded entirely with <pad>, leaving
        # every attention key masked out.
        if enc_seq and len(enc_seq) <= max_seq1_len and len(dec_seq) <= max_seq2_len:
            dataset.append((enc_seq, dec_seq))

In [None]:
# Define Dataset
class EngJpDataset(torch.utils.data.Dataset):
    """Serve (English tokens, Japanese characters) pairs as padded index tensors.

    Index lookups go through the notebook-level ``enc_dic`` / ``dec_dic``
    dictionaries, and the padded lengths come from the notebook-level
    ``max_seq1_len`` / ``max_seq2_len``.
    """

    def __init__(self, data, transforms=None):
        super().__init__()
        self.data = data
        # Stored only; nothing in this class applies it.
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(encoder_indices, decoder_indices)`` for pair ``idx``.

        The encoder tensor has ``max_seq1_len`` entries. The decoder tensor has
        ``max_seq2_len + 2`` entries: ``<sos>``, the characters, ``<eos>``,
        then ``<pad>`` up to the fixed length.
        """
        src_tokens, tgt_chars = self.data[idx]

        # Encoder side: token indices, right-padded with <pad>.
        enc_ids = [enc_dic[tok] for tok in src_tokens]
        enc_ids.extend(enc_dic['<pad>'] for _ in range(len(enc_ids), max_seq1_len))

        # Decoder side: wrap the characters in <sos> ... <eos>, then right-pad.
        dec_ids = [dec_dic['<sos>'], *(dec_dic[ch] for ch in tgt_chars), dec_dic['<eos>']]
        dec_ids.extend(dec_dic['<pad>'] for _ in range(len(dec_ids), max_seq2_len + 2))

        return torch.tensor(enc_ids), torch.tensor(dec_ids)

## Define Modules

In [None]:
# A lookup table mapping integer index to an embedding
class Embedding(nn.Module):
    """Thin wrapper around ``nn.Embedding``.

    Args:
        num_unique_tokens: number of rows in the lookup table.
        embed_dim: length of each embedding vector.
    """

    def __init__(self, num_unique_tokens, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(num_unique_tokens, embed_dim)

    def forward(self, x):
        # Indices of shape N x Seq come in; vectors of shape
        # N x Seq x embed_dim go out.
        return self.embed(x)

In [None]:
# A method for encoding positional information using sin and cos waves
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    For position ``pos`` and even feature index ``i`` the table holds

        table[pos, i]     = sin(pos / 10000 ** (i / embed_dim))
        table[pos, i + 1] = cos(pos / 10000 ** (i / embed_dim))

    so each sin/cos pair shares one frequency. ``i`` is already the even
    feature index, so the exponent is ``i / embed_dim`` rather than
    ``2 * i / embed_dim``.

    Args:
        max_sequence_len: number of positions to precompute.
        embed_dim: feature size of the embeddings (odd sizes are supported;
            the last feature then only receives the sin term).
        device: device on which the table is created.
    """

    def __init__(self, max_sequence_len, embed_dim, device):
        super(PositionalEncoding, self).__init__()

        # Column vector of positions and row vector of even feature indices;
        # broadcasting yields one angle per (position, sin/cos pair).
        positions = torch.arange(max_sequence_len, dtype=torch.float32, device=device).unsqueeze(1)
        even_dims = torch.arange(0, embed_dim, 2, dtype=torch.float32, device=device)
        angles = positions / torch.pow(10000.0, even_dims / embed_dim)

        position_matrix = torch.zeros((max_sequence_len, embed_dim), device=device)
        position_matrix[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # columns, so trim the angles to match.
        position_matrix[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])

        # Registered as a non-persistent buffer: it follows the module through
        # .to()/.cuda() but adds no key to state_dict(), and it is not a
        # trainable parameter. Leading axis of size 1 broadcasts over the batch.
        self.register_buffer('position_matrix', position_matrix.unsqueeze(0), persistent=False)

    def forward(self, x):
        # x is indexed as (batch, sequence, feature); only the first
        # `sequence_len` rows of the table are added.
        sequence_len = x.shape[1]
        x = x + self.position_matrix[:,:sequence_len]
        return x

In [None]:
# Normalization module
class Normalization(nn.Module):
    """Feature-wise normalization with a learnable scale and shift.

    Args:
        embed_dim: size of the last (feature) axis of the input.
        method: ``'L'`` for layer normalization. ``'B'`` is accepted but batch
            normalization is NOT implemented: the input is returned unchanged.
            Any other value falls back to ``'L'``.
    """

    def __init__(self, embed_dim, method='L'):
        super(Normalization, self).__init__()
        self.method = method

        # Unknown methods silently fall back to layer normalization.
        if self.method not in ['L', 'B']:
            self.method = 'L'
        self.gamma = nn.Parameter(torch.ones(embed_dim))
        self.beta = nn.Parameter(torch.zeros(embed_dim))
        self.eps = 1e-7

    def forward(self, x):
        # x -> N x seq x embed_dim

        # Batch Norm
        if self.method == 'B':
            # NOTE: placeholder - no normalization is applied on this path.
            pass
        # Layer Norm
        elif self.method == 'L':
            mu = torch.mean(x, dim=-1, keepdim=True)
            # Layer normalization uses the biased (population) variance.
            # torch.var defaults to the unbiased estimator, which divides by
            # (embed_dim - 1) and therefore produces NaN when embed_dim == 1.
            var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
            x = (x - mu) / torch.sqrt(var + self.eps)
            x = self.gamma * x + self.beta

        return x

In [None]:
# Attention mechanism
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input,
    and the scores are divided by ``sqrt(latent_dim)``.

    Args:
        embed_dim: feature size of the input.
        latent_dim: feature size of the projected queries/keys/values.
        device: device on which the projection layers are created.
    """

    def __init__(self, embed_dim, latent_dim, device):
        super().__init__()

        self.W_Q = nn.Linear(embed_dim, latent_dim, device=device)
        self.W_K = nn.Linear(embed_dim, latent_dim, device=device)
        self.W_V = nn.Linear(embed_dim, latent_dim, device=device)

        self.scale = sqrt(latent_dim)

    def forward(self, x, mask=None):
        queries, keys, values = self.W_Q(x), self.W_K(x), self.W_V(x)

        # Axes 1 and 2 of the keys are swapped, so the input is handled as a
        # 3-D (batch, sequence, feature) tensor.
        logits = torch.matmul(queries, keys.transpose(1, 2)) / self.scale

        # Entries where the mask equals 0 receive -inf and so get zero weight
        # after the softmax.
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, values)

class MultiHeadedAttention(nn.Module):
    """Run several ``Attention`` heads in parallel and merge their outputs.

    Each head projects the input to ``latent_dim`` features. The head outputs
    are concatenated along the last axis and projected back to ``embed_dim``.

    Args:
        embed_dim: feature size of the input and of the output.
        latent_dim: feature size produced by each head.
        device: device on which the heads and the output projection are created.
        num_heads: number of heads; values below 1 are treated as 1.
    """

    def __init__(self, embed_dim, latent_dim, device, num_heads):
        super(MultiHeadedAttention, self).__init__()

        num_heads = max(num_heads, 1)

        self.attention_heads = nn.ModuleList([Attention(embed_dim, latent_dim, device) for _ in range(num_heads)])
        # Bring the concatenated heads back to the original input dimensions.
        # Created on the same device as the heads, so the module is usable
        # without a separate .to(device) call.
        self.W = nn.Linear(latent_dim * num_heads, embed_dim, device=device)

    def forward(self, x, mask=None):
        # Every head receives the same input and mask.
        z = torch.cat([head(x, mask) for head in self.attention_heads], dim=-1)
        return self.W(z)


> ### Encoder





In [None]:
class Encoder(nn.Module):
    """One encoder layer.

    Two sub-layers, each wrapped as ``normalize(x + dropout(sublayer(x)))``:
    multi-head self-attention, then a single linear layer followed by ReLU.

    Args:
        embed_dim: feature size of the input and output.
        latent_dim: per-head feature size passed to the attention module.
        device: device passed to the attention module.
        num_heads: number of attention heads.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_dim, latent_dim, device, num_heads, dropout=0.2):
        super().__init__()

        self.multi_head_attention = MultiHeadedAttention(embed_dim, latent_dim, device, num_heads)
        self.normalize1 = Normalization(embed_dim)
        self.normalize2 = Normalization(embed_dim)

        self.dp1 = nn.Dropout(dropout)
        self.dp2 = nn.Dropout(dropout)

        self.feed_forward = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x -> Batch_size x Seq_len x embed_dim

        # Sub-layer 1: self-attention with a residual connection.
        attended = self.dp1(self.multi_head_attention(x, mask))
        hidden = self.normalize1(x + attended)

        # Sub-layer 2: linear + ReLU with a residual connection.
        transformed = self.dp2(F.relu(self.feed_forward(hidden)))
        return self.normalize2(hidden + transformed)

In [None]:
class EncoderStack(nn.Module):
    """Embed the source tokens, run them through the encoder layers, and
    project the result into the keys and values used by the decoder.

    Args:
        num_unique_tokens: size of the source vocabulary.
        max_sequence_len: number of positions in the positional-encoding table.
        embed_dim: embedding / model feature size.
        latent_dim: per-head feature size inside each encoder layer.
        device: device passed to the positional encoding and encoder layers.
        num_heads: attention heads per encoder layer.
        num_encoders: number of stacked encoder layers.
        dropout: dropout probability inside each encoder layer.
    """

    def __init__(self, num_unique_tokens, max_sequence_len, embed_dim, latent_dim, device, num_heads, num_encoders, dropout=0.2):
        super().__init__()

        self.embedding = Embedding(num_unique_tokens, embed_dim)
        self.pos_encoding = PositionalEncoding(max_sequence_len, embed_dim, device)

        layers = []
        for _ in range(num_encoders):
            layers.append(Encoder(embed_dim, latent_dim, device, num_heads, dropout))
        self.encoders = nn.ModuleList(layers)

        # Projections of the final encoder output into keys and values.
        self.W_K = nn.Linear(embed_dim, embed_dim)
        self.W_V = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        hidden = self.pos_encoding(self.embedding(x))
        for layer in self.encoders:
            hidden = layer(hidden, mask)
        return self.W_K(hidden), self.W_V(hidden)


> ### Decoder



In [None]:
class Decoder(nn.Module):
    """One decoder layer.

    Three sub-layers, each wrapped as ``normalize(x + dropout(sublayer(x)))``:
    masked multi-head self-attention, single-head cross-attention over the
    keys/values supplied by the encoder, and a linear layer followed by ReLU.

    Args:
        embed_dim: feature size of the input and output.
        latent_dim: per-head feature size passed to the self-attention module.
        device: device passed to the self-attention module.
        num_heads: number of self-attention heads.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_dim, latent_dim, device, num_heads, dropout=0.2):
        super().__init__()

        self.multi_head_attention = MultiHeadedAttention(embed_dim, latent_dim, device, num_heads)
        self.normalize1 = Normalization(embed_dim)
        self.normalize2 = Normalization(embed_dim)
        self.normalize3 = Normalization(embed_dim)

        self.dp1 = nn.Dropout(dropout)
        self.dp2 = nn.Dropout(dropout)
        self.dp3 = nn.Dropout(dropout)

        self.W_Q = nn.Linear(embed_dim, embed_dim)
        self.feed_forward = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, K, V, enc_mask=None, dec_mask=None):
        # x -> Batch_size x Seq_len x embed_dim

        # Sub-layer 1: self-attention, restricted by the decoder mask.
        self_attended = self.dp1(self.multi_head_attention(x, dec_mask))
        hidden = self.normalize1(x + self_attended)

        # Sub-layer 2: cross-attention. Queries come from the decoder state,
        # keys and values from the encoder; scores are scaled by
        # sqrt(feature size of the decoder state).
        queries = self.W_Q(hidden)
        logits = torch.matmul(queries, K.transpose(1, 2)) / sqrt(hidden.shape[-1])

        # Encoder positions where the mask equals 0 get zero weight.
        if enc_mask is not None:
            logits = logits.masked_fill(enc_mask == 0, float('-inf'))

        weights = F.softmax(logits, dim=-1)
        hidden = self.normalize2(hidden + self.dp2(torch.matmul(weights, V)))

        # Sub-layer 3: linear + ReLU.
        transformed = self.dp3(F.relu(self.feed_forward(hidden)))
        return self.normalize3(hidden + transformed)

In [None]:
class DecoderStack(nn.Module):
    """Embed the target tokens, run them through the decoder layers, and
    project each position onto the target vocabulary.

    Args:
        num_unique_tokens: size of the target vocabulary (also the size of the
            output logits).
        max_sequence_len: number of positions in the positional-encoding table.
        embed_dim: embedding / model feature size.
        latent_dim: per-head feature size inside each decoder layer.
        device: device passed to the positional encoding and decoder layers.
        num_heads: self-attention heads per decoder layer.
        num_decoders: number of stacked decoder layers.
        dropout: dropout probability inside each decoder layer.
    """

    def __init__(self, num_unique_tokens, max_sequence_len, embed_dim, latent_dim, device, num_heads, num_decoders, dropout=0.2):
        super().__init__()

        self.embedding = Embedding(num_unique_tokens, embed_dim)
        self.pos_encoding = PositionalEncoding(max_sequence_len, embed_dim, device)

        layers = []
        for _ in range(num_decoders):
            layers.append(Decoder(embed_dim, latent_dim, device, num_heads, dropout))
        self.decoders = nn.ModuleList(layers)

        self.linear = nn.Linear(embed_dim, num_unique_tokens)

    def forward(self, x, K, V, enc_mask=None, dec_mask=None):
        hidden = self.pos_encoding(self.embedding(x))
        # Every layer receives the same encoder keys/values and masks.
        for layer in self.decoders:
            hidden = layer(hidden, K, V, enc_mask, dec_mask)
        # Raw logits are returned; the softmax is applied by the cross entropy
        # loss function.
        return self.linear(hidden)

> ### Transformer Module



In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model: an ``EncoderStack`` feeding a ``DecoderStack``.

    Both stacks share ``max_sequence_len``, ``embed_dim``, ``latent_dim``,
    ``device``, ``num_heads``, ``num_stacks`` and ``dropout``; only the
    vocabulary size differs between them.
    """

    def __init__(self, num_unique_input_tokens, num_unique_output_tokens, max_sequence_len, embed_dim, latent_dim, device, num_heads, num_stacks, dropout=0.2):
        super().__init__()
        # Arguments common to both stacks, in the order their constructors take them.
        shared = (max_sequence_len, embed_dim, latent_dim, device, num_heads, num_stacks, dropout)
        self.encoder_stack = EncoderStack(num_unique_input_tokens, *shared)
        self.decoder_stack = DecoderStack(num_unique_output_tokens, *shared)

    def forward(self, enc_seq, dec_seq, enc_mask, dec_mask):
        # The encoder yields the keys and values that every decoder layer
        # attends over; the encoder mask is reused for that cross-attention.
        keys, values = self.encoder_stack(enc_seq, enc_mask)
        return self.decoder_stack(dec_seq, keys, values, enc_mask, dec_mask)

# Model Training

In [None]:
# Training configuration
EPOCHS = 100
BATCH_SIZE = 25
SEED = 0

# Seed PyTorch's RNG so weight initialisation, random_split and DataLoader
# shuffling start from a fixed state on every "Restart & Run All".
torch.manual_seed(SEED)

full_dataset = EngJpDataset(dataset)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device', device)

# Positional arguments: max_sequence_len=80, embed_dim=512, latent_dim=64,
# num_heads=8, num_stacks=6.
model = Transformer(len(enc_dic), len(dec_dic), 80, 512, 64, device, 8, 6, dropout=0.1)
optim = torch.optim.Adam(model.parameters(), lr=0.01)

# 80/20 train/test split.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])


#num_workers = multiprocessing.cpu_count()
#print('num workers:', num_workers)

kwargs = {'num_workers': 1, #num_workers,
          'pin_memory': True} if use_cuda else {}

train = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                            shuffle=True, **kwargs)
test = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                            shuffle=True, **kwargs)
model = model.to(device)
model.train()

enc_pad = enc_dic['<pad>']
dec_pad = dec_dic['<pad>']

# Lower-triangular mask: decoder position i may attend to positions <= i.
# It is identical for every batch (decoder inputs always have
# max_seq2_len + 1 positions), so build it once, directly on the device,
# instead of recreating it on the CPU and copying it over at every step.
causal_mask = torch.tril(torch.ones((max_seq2_len + 1, max_seq2_len + 1), dtype=torch.bool, device=device))

for epoch in range(EPOCHS):
    for idx, (enc_input, dec_input) in enumerate(train):
        optim.zero_grad()

        # enc_input, dec_input -> N x Seq
        enc_input, dec_input = enc_input.to(device), dec_input.to(device)

        # Teacher forcing: the decoder sees tokens [0, Seq-1) and is trained
        # to predict tokens [1, Seq).
        dec_in, dec_target = dec_input[:, :-1], dec_input[:, 1:]

        # Masks are True where attention is allowed.
        enc_mask = (enc_input != enc_pad).unsqueeze(1)
        dec_mask = (dec_in != dec_pad).unsqueeze(1) & causal_mask

        output = model(enc_input, dec_in, enc_mask, dec_mask)

        ground_truth = dec_target.reshape(-1)
        output = output.reshape(-1, output.shape[-1]) # output -> (N * Seq) x Vocab_size

        # ignore_index leaves <pad> targets out of the mean, so the model is
        # not penalized for not learning the paddings.
        loss = F.cross_entropy(output, ground_truth, ignore_index=dec_pad)
        loss.backward()
        optim.step()

        if idx % 10 == 0:
            print('Epoch:', epoch)
            print('Loss:', loss.item())
            #torch.save(model, '/content/drive/MyDrive/Colab Notebooks/models/' + 'transformer.pt')

print('Finished training model')

# Testing

In [None]:
# Greedy decoding demo: translate one English sentence character by character.
model.eval()

# Reverse lookup: decoder index -> token.
idx_to_token = {value: key for key, value in dec_dic.items()}

# NOTE: every word must already be in enc_dic; an unseen word raises KeyError.
enc_input = 'hi nice to meet you'.split(' ')
enc_input = torch.tensor([enc_dic[token] for token in enc_input]).unsqueeze(0).to(device)

# Start from a single <sos> token and extend by one prediction per step.
dec_input = torch.tensor([dec_dic['<sos>']]).unsqueeze(0).to(device)

# Inference needs no gradients; no_grad avoids building the autograd graph
# at every decoding step.
with torch.no_grad():
    # Training fed the decoder max_seq2_len + 1 positions, so generate at most
    # that many tokens.
    for _ in range(max_seq2_len + 1):
        # Use the same causal masking the model was trained with, so earlier
        # positions do not attend to tokens generated after them.
        cur_len = dec_input.shape[1]
        dec_mask = torch.tril(torch.ones((cur_len, cur_len), dtype=torch.bool, device=device))

        output = model(enc_input, dec_input, None, dec_mask)

        # Greedy choice: most likely token at the last position.
        idx = torch.argmax(output[0][-1], dim=-1).item()
        token = idx_to_token[idx]
        print(token)

        if token == '<eos>':
            break
        dec_input = torch.cat((dec_input, torch.tensor([[idx]], device=device)), dim=-1)
