# Transformer
Some parts of the code were taken from the [nus-cs4248x GitHub repository](https://github.com/chrisvdweth/nus-cs4248x/blob/master/3-neural-nlp/Section%204.2%20-%20Transformer%20Machine%20Translation.ipynb).

Baseline Transformer with BERT encoding. The decoding strategies tried are greedy decoding and beam search. The notebook loads a saved model checkpoint if one exists and saves a checkpoint after every epoch.

In [None]:
import torch
import math
import os
import pandas as pd
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
from torch.nn.utils.rnn import pad_sequence
from timeit import default_timer as timer
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertModel

In [None]:
# Prefer the first CUDA GPU and fall back to the CPU when PyTorch cannot see one.
use_cuda = torch.cuda.is_available()
if use_cuda:
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

print("Available device: {}".format(DEVICE))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    The sine/cosine table is computed once in ``__init__`` and stored as the
    non-trainable buffer ``pos_embedding`` of shape ``(maxlen, 1, emb_size)``,
    so it broadcasts over the second dimension of the input.

    Args:
        emb_size: Size of the last (feature) dimension. May be odd.
        dropout: Dropout probability applied after the addition.
        maxlen: Number of positions for which encodings are precomputed.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        # Even feature columns carry the sine, odd columns the cosine.
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # With an odd emb_size there is one odd column fewer than even columns,
        # so the frequency vector is truncated to the number of odd columns.
        pos_embedding[:, 1::2] = torch.cos(pos * den[:emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dimension 1 of the input.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Return ``dropout(token_embedding + encoding)``.

        The first ``token_embedding.size(0)`` rows of the table are used, i.e.
        dimension 0 of the input is treated as the position axis.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

class TokenEmbedding(nn.Module):
    """Learned token-id lookup whose output is scaled by ``sqrt(emb_size)``.

    A variant that embedded tokens with a pretrained ``BertModel``
    ("bert-base-cased") and used the padding mask as its attention mask was
    previously sketched here in commented-out form; only the plain
    ``nn.Embedding`` lookup is active.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor, padding_mask=None):
        """Embed ``tokens`` (cast to ``long``) and scale the result.

        ``padding_mask`` is accepted for call-site compatibility but is not
        used by this implementation.
        """
        token_ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(token_ids) * scale

In [None]:
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder model built on ``torch.nn.Transformer``.

    Source and target token ids are embedded, combined with positional
    encodings, passed through the transformer and projected to
    target-vocabulary logits by ``generator``.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # Sub-modules are created in this fixed order so that parameter
        # registration (and therefore any seeded initialisation) is stable.
        core_config = dict(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer = Transformer(**core_config)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full model and return logits over the target vocabulary."""
        src_emb = self.src_tok_emb(src, src_padding_mask)
        src_emb = self.positional_encoding(src_emb)
        tgt_emb = self.tgt_tok_emb(trg, tgt_padding_mask)
        tgt_emb = self.positional_encoding(tgt_emb)
        hidden = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Return the encoder output ("memory") for ``src``; no padding mask is applied."""
        embedded = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(embedded, mask=src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Return decoder hidden states for ``tgt`` given ``memory``; logits still need ``generator``."""
        embedded = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(embedded, memory, tgt_mask=tgt_mask)

In [None]:
def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` boolean causal mask on ``DEVICE``.

    Entry ``[i, j]`` is ``True`` when ``j > i`` (a later position, which must
    not be attended to) and ``False`` otherwise.
    """
    positions = torch.arange(sz, device=DEVICE)
    # Column index greater than row index <=> strictly above the diagonal.
    return positions.unsqueeze(0) > positions.unsqueeze(1)

# Token id treated as padding: create_mask marks it in the key-padding masks and the
# loss ignores it. It has to agree with the value used to pad the datasets
# (pad_sequence is called with its default fill value).
PAD_IDX = 0

def create_mask(src, tgt):
    """Build the attention and padding masks for one batch.

    ``src`` and ``tgt`` are indexed as ``(seq_len, batch)``: dimension 0 is
    used as the sequence length and the padding masks are transposed to
    ``(batch, seq_len)``.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-``False`` square mask (nothing blocked),
        ``tgt_mask`` is the causal mask, and the padding masks are ``True``
        wherever the token equals ``PAD_IDX``.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # The encoder may attend everywhere; only the decoder is causally restricted.
    src_mask = torch.zeros((src_len, src_len), device=DEVICE).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_is_pad = src == PAD_IDX
    tgt_is_pad = tgt == PAD_IDX

    return src_mask, tgt_mask, src_is_pad.transpose(0, 1), tgt_is_pad.transpose(0, 1)

In [None]:
def train_epoch(model, optimizer, criterion, train_iter):
    """Train ``model`` for one pass over ``train_iter`` and return the mean batch loss.

    Reads the module-level ``BATCH_SIZE`` and ``DEVICE``.

    Args:
        model: Called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to flattened logits and flattened targets.
        train_iter: Dataset yielding batch-first ``(src, tgt)`` tensor pairs.

    Raises:
        ZeroDivisionError: If the dataset yields no batches.
    """
    model.train()
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE)  # removed collate_fn

    total_loss = 0.0
    num_batches = 0
    # Batches are counted while iterating instead of materialising the whole
    # loader with list(...); tqdm reads the total from len() when available.
    for src, tgt in tqdm(train_dataloader):
        # The model expects sequence-first tensors.
        src = src.transpose(0, 1).to(DEVICE)
        tgt = tgt.transpose(0, 1).to(DEVICE)

        # Decoder input: drop the last entry of all target sequences (typically PAD, can be EOS).
        tgt_input = tgt[:-1, :]
        # Prediction target: drop the leading <SOS> from all targets.
        tgt_out = tgt[1:, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        optimizer.zero_grad()
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


def evaluate(model, val_iter):
    """Return the mean batch loss of ``model`` over ``val_iter`` without updating it.

    Reads the module-level ``criterion``, ``BATCH_SIZE`` and ``DEVICE``.

    Args:
        model: Called with the same arguments as in ``train_epoch``.
        val_iter: Dataset yielding batch-first ``(src, tgt)`` tensor pairs.

    Raises:
        ZeroDivisionError: If the dataset yields no batches.
    """
    model.eval()
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE)  # removed collate_fn

    total_loss = 0.0
    num_batches = 0
    # No gradients are needed for validation; skipping autograd saves memory.
    with torch.no_grad():
        # Batches are counted while iterating instead of materialising the
        # whole loader with list(...).
        for src, tgt in tqdm(val_dataloader):
            # The model expects sequence-first tensors.
            src = src.transpose(0, 1).to(DEVICE)
            tgt = tgt.transpose(0, 1).to(DEVICE)

            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            # logits are expected to be [seq len, batch size, vocab size]
            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches

In [None]:
# Fix the RNG so that weight initialisation is reproducible.
torch.manual_seed(0)

# NOTE(review): 28996 is presumably the bert-base-cased vocabulary size — confirm
# against the tokenizer used to produce the datasets.
SRC_VOCAB_SIZE = 28996 # len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = 28996 # len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 768 # bert
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# Create model
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

# Define optimizer
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

# Initialize weights (Xavier-uniform for every parameter with more than one dimension)
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Checkpoint file that is loaded here (if present) and rewritten after every epoch.
model_path = "baseline_bert_encoding_len50.pth"

# Move model to device (ideally GPU, otherwise CPU)
transformer = transformer.to(DEVICE)

START_EPOCH = 1
if os.path.exists(model_path):
    print('----------Resume training----------')
    # map_location lets a checkpoint written on a GPU be resumed on a CPU-only
    # machine (and vice versa) instead of failing while deserialising tensors.
    loaded_obj = torch.load(model_path, map_location=DEVICE)
    transformer.load_state_dict(loaded_obj['model_state_dict'])
    optimizer.load_state_dict(loaded_obj['optimizer_state_dict'])
    # Continue with the epoch after the last one that was saved.
    START_EPOCH = loaded_obj['epoch'] + 1
    print(loaded_obj['epoch'], loaded_obj['loss'])

# Define loss function (padding positions do not contribute to the loss)
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)


In [None]:
NUM_EPOCHS = 20
# Training pairs whose input has more tokens than this are dropped.
MAX_INPUT_LENGTH = 50


def _read_token_pairs(csv_path):
    """Read a header-less two-column CSV of space-separated token ids into lists of ints."""
    frame = pd.read_csv(csv_path, header=None, names=["input_tokens", "output_tokens"])
    for column in ("input_tokens", "output_tokens"):
        frame[column] = frame[column].apply(lambda x: list(map(int, x.split())))
    return frame


def _pad_token_column(column):
    """Turn a column of id lists into one batch-first tensor padded with pad_sequence's default value."""
    # A plain list is handed to pad_sequence rather than a pandas Series, so the
    # call never depends on label-based indexing of a filtered frame.
    return pad_sequence([torch.tensor(ids) for ids in column], batch_first=True)


train = _read_token_pairs('train_dataset.csv')
train = train[train['input_tokens'].apply(len) <= MAX_INPUT_LENGTH]
print(train.shape)
train_original = _pad_token_column(train['input_tokens'])
train_corrected = _pad_token_column(train['output_tokens'])
train_dataset = TensorDataset(train_original, train_corrected)

# The validation set is used as-is (no length filter is applied to it).
val = _read_token_pairs('val_dataset.csv')
print(val.shape)
val_original = _pad_token_column(val['input_tokens'])
val_corrected = _pad_token_column(val['output_tokens'])
val_dataset = TensorDataset(val_original, val_corrected)

for epoch in range(START_EPOCH, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, criterion, train_dataset)
    # The reported epoch time covers the training pass only, not validation.
    end_time = timer()
    val_loss = evaluate(transformer, val_dataset)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time (total) = {(end_time - start_time):.3f}s"))
    # Overwrite the checkpoint after every epoch so training can be resumed.
    torch.save({
      'epoch': epoch,
      'model_state_dict': transformer.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': train_loss,
    }, model_path)

In [None]:
# Tokenizer used by translate() to turn a sentence into ids and decoded ids back into text.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# Ids that start a decoded sequence and that stop decoding when produced.
# NOTE(review): presumably the [CLS] / [SEP] ids of bert-base-cased — confirm against
# tokenizer.cls_token_id / tokenizer.sep_token_id.
SOS_IDX = 101 # bert
EOS_IDX = 102 # bert

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate an output sequence by always taking the highest-scoring next token.

    Reads the module-level ``DEVICE`` and ``EOS_IDX``.

    Args:
        model: Object providing ``encode``, ``decode`` and ``generator``.
        src: Source token ids for a single sentence (one column; the step
            result is read with ``.item()``, so the batch size must be 1).
        src_mask: Attention mask passed to ``model.encode``.
        max_len: Upper bound on the returned length, start symbol included.
        start_symbol: Token id the output starts with.

    Returns:
        A ``long`` tensor of shape ``(length, 1)`` on ``DEVICE`` that starts
        with ``start_symbol`` and ends with ``EOS_IDX`` if it was produced
        before ``max_len`` was reached.
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(0))
            out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
            # Only the last position predicts the next token.
            logits = model.generator(out[:, -1])
            _, next_word = torch.max(logits, dim=1)
            next_word = next_word.item()

            # Token ids stay integral regardless of the dtype ``src`` came in with.
            next_token = torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys

def beam_search(model, src, src_mask, beam_width, max_length, start_symbol):
    """Generate an output sequence with beam search.

    Each hypothesis is scored by the sum of the log-probabilities of its
    tokens. Decoding stops as soon as any kept hypothesis ends with
    ``EOS_IDX``; the best-scoring such hypothesis is returned. If none
    finishes within ``max_length`` tokens, the best-scoring hypothesis is
    returned.

    Reads the module-level ``DEVICE`` and ``EOS_IDX``.

    Args:
        model: Object providing ``encode``, ``decode`` and ``generator``.
        src: Source token ids for a single sentence (one column).
        src_mask: Attention mask passed to ``model.encode``.
        beam_width: Number of hypotheses kept after every step (>= 1).
        max_length: Upper bound on the returned length, start symbol included.
        start_symbol: Token id every hypothesis starts with.

    Returns:
        A ``long`` tensor of shape ``(length, 1)`` on ``DEVICE``.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
        # Initialize beams with start-of-sequence token
        beams = [(ys, 0.0)]  # (sequence, summed log-probability)

        for _ in range(max_length - 1):
            candidates = []
            for sequence, score in beams:
                tgt_mask = generate_square_subsequent_mask(sequence.size(0))
                out = model.decode(sequence, memory, tgt_mask).transpose(0, 1)
                # Normalise the logits: raw logits are not comparable across
                # different prefixes, log-probabilities are.
                log_probs = torch.log_softmax(model.generator(out[:, -1]), dim=-1)
                # Never request more candidates than the vocabulary holds.
                top = torch.topk(log_probs, min(beam_width, log_probs.size(-1)))
                for token, token_log_prob in zip(top.indices[0].tolist(), top.values[0].tolist()):
                    # Extend the sequence with the candidate token (kept as long ids).
                    next_token = torch.full((1, 1), token, dtype=torch.long, device=DEVICE)
                    candidate_sequence = torch.cat([sequence, next_token], dim=0)
                    candidates.append((candidate_sequence, score + token_log_prob))

            # Prune beams to keep the top-k sequences, best first.
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]

            # Beams are sorted, so the first one ending in EOS is the best finished one.
            for sequence, _ in beams:
                if sequence[-1].item() == EOS_IDX:
                    return sequence

    # Nothing finished in time: return the sequence with the highest score.
    return beams[0][0]


In [None]:
def translate(model: torch.nn.Module, src_sentence: str):
    """Run ``src_sentence`` through ``model`` and return the decoded output text.

    The sentence is encoded with the module-level ``tokenizer``, decoded with
    beam search (width 10, at most 5 tokens longer than the input) and turned
    back into text with special tokens removed. The model is put into eval
    mode as a side effect.
    """
    model.eval()
    # Token ids are integers; build a long column vector of shape (num_tokens, 1).
    src = torch.tensor(tokenizer.encode(src_sentence), dtype=torch.long).view(-1, 1)
    num_tokens = src.shape[0]
    # All-False mask: the encoder may attend to every source position.
    src_mask = torch.zeros((num_tokens, num_tokens), dtype=torch.bool)
    with torch.no_grad():
        # Alternative decoder: greedy_decode(model, src, src_mask, max_len=num_tokens+5, start_symbol=SOS_IDX)
        tgt_tokens = beam_search(model, src, src_mask, beam_width=10, max_length=num_tokens+5, start_symbol=SOS_IDX).flatten()
    return tokenizer.decode(tgt_tokens.tolist(), skip_special_tokens=True)

In [None]:
# Spot-check the model on a few sentences.
for sentence in (
    "I am working as a teacher in Spanish school with children aged between 8 and 14 .",
    "I completely agree with you , maybe each word you say .",
    "Conclusión",
):
    print(translate(transformer, sentence))
# Tokenizer round trip for the non-ASCII example, for comparison with the model output above.
print(tokenizer.decode(tokenizer.encode("Conclusión")))

In [None]:
# spaCy is used to tokenize the model output, the same way as in the BEA shared task.
# The two "!" lines are notebook shell commands: install spaCy and download its small English pipeline.
!pip install spacy
!python -m spacy download en_core_web_sm
import spacy

# Load the en_core_web_sm pipeline downloaded above; test() uses it to split each output into tokens.
nlp = spacy.load('en_core_web_sm')

In [None]:
def test(model: torch.nn.Module, test_name, output_name="eval_baseline_bert_encoding_beam10.txt"):
    """Translate every line of a text file and write the tokenized results to disk.

    Each input line is passed to ``translate``; the output is split into
    tokens with the module-level spaCy pipeline ``nlp`` and written as one
    space-joined line per input line.

    Args:
        model: Model handed to ``translate``.
        test_name: Path of the input file, one sentence per line.
        output_name: Path of the file to write; defaults to the file name
            this notebook has always used.
    """
    outputs = []
    # The data contains non-ASCII text, so the encoding is fixed rather than
    # left to the platform default.
    with open(test_name, 'r', encoding="utf-8") as f:
        for line in f:
            output = translate(model, line)
            tokens = [token.text for token in nlp(output)]
            outputs.append(f"{' '.join(tokens)}\n")

    print("done testing")
    with open(output_name, "w", encoding="utf-8") as f:
        f.writelines(outputs)
        print("Done writing to file")


In [None]:
# Translate every line of eval_orig.txt with the trained model and write the tokenized outputs to a file.
test(transformer, "eval_orig.txt")