In [117]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import spacy
import datasets
import time
from datasets import Dataset, DatasetDict, load_dataset
import torchtext
import tqdm
import evaluate
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using", device, "device")
torch.cuda.empty_cache()
import os
import sys
from datasets import load_dataset
from transformers import AutoTokenizer

Using cuda device


In [118]:
seed = 1234

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

## SPM Tokeniser for Python ver < 3.10

In [119]:



def initialise_tokenizer(pretrained):
    return AutoTokenizer.from_pretrained(pretrained)

def tokenize_wrap(sentence, tokenizer, max_length=128, return_token_type_ids=False):
    try:
        tokenized = tokenizer.encode_plus(
            sentence,
            max_length=max_length,  # None for padding to max length in training set (1024), but max 512 for BERT
            truncation=True,
            padding="max_length",
            return_token_type_ids=return_token_type_ids,  # set True for BERT
            return_attention_mask=False,
            return_tensors=None, # "pt" for PyTorch, "tf" for TensorFlow, or None
        )  
        return tokenized

    except Exception as e:
        print(f'Error encountered during tokenisation of "{sentence}"')
        print(e)

        return None


# Uncomment to test on one sentence in command line
"""
raw_datasets = load_dataset("iwslt2017", "iwslt2017-en-zh") 
example = raw_datasets["train"][123456]["translation"]["zh"] # change language here

tokenizer = initialise_tokenizer(pretrained="njcay/bert_dataset_zh_tokenizer") # initialise pretrained tokeniser
tokens = tokenizer.tokenize(example)
tokenized = tokenize_wrap(example, tokenizer, return_token_type_ids=False) # tokenise

print(f"Example: \n{example}")
print(f"Tokens: \n{tokens}")
print(f"Tokenized: \n{tokenized}")
"""

zh_tokenizer = initialise_tokenizer(pretrained="njcay/bert_dataset_zh_tokenizer") 
en_tokenizer = initialise_tokenizer(pretrained="njcay/bert_dataset_en_tokenizer") 

tokens = zh_tokenizer.tokenize('这是')
tokenized = tokenize_wrap('这是', zh_tokenizer, return_token_type_ids=False, max_length=3) # tokenise
print(f"Tokens: \n{tokens}")
print(f"Tokenized: \n{tokenized}")

tokens = zh_tokenizer.tokenize('this is')
tokenized = tokenize_wrap('this is', en_tokenizer, return_token_type_ids=False, max_length=3) # tokenise

print(f"Tokens: \n{tokens}")
print(f"Tokenized: \n{tokenized}")

Tokens: 
['这', '是']
Tokenized: 
{'input_ids': [2, 4346, 3]}
Tokens: 
['this', 'is']
Tokenized: 
{'input_ids': [2, 183, 3]}


In [120]:
en_tokenizer.all_special_ids, en_tokenizer.all_special_tokens

([1, 3, 0, 2, 4], ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])

In [121]:
unk_g, eos_g, pad_g, bos_g, _ = en_tokenizer.all_special_ids

## Get Data

In [122]:
dataset = load_dataset("iwslt2017", "iwslt2017-en-zh")
train_data, valid_data, test_data = (
    dataset["train"],
    dataset["validation"],
    dataset["test"],
)

# first 10k rows
train_data = train_data.select(range(10000))
train_data

Dataset({
    features: ['translation'],
    num_rows: 10000
})

## Tokenise

In [123]:
unk_g, eos_g, pad_g, bos_g, _ = en_tokenizer.all_special_ids
max_length_g = 285

def tokenize_example(example):
    en_tokens = tokenize_wrap(example['translation']["en"], en_tokenizer, return_token_type_ids=False, max_length=max_length_g+2)
    zh_tokens = tokenize_wrap(example['translation']['zh'], zh_tokenizer, return_token_type_ids=False, max_length=max_length_g+2)
    return {"en_ids": en_tokens['input_ids'], "zh_ids": zh_tokens['input_ids']}

train_data = train_data.map(tokenize_example)
valid_data = valid_data.map(tokenize_example)
test_data = test_data.map(tokenize_example)

train_data

Map: 100%|██████████| 10000/10000 [00:03<00:00, 3330.47 examples/s]
Map: 100%|██████████| 879/879 [00:00<00:00, 3210.42 examples/s]
Map: 100%|██████████| 8549/8549 [00:02<00:00, 3440.37 examples/s]


Dataset({
    features: ['translation', 'en_ids', 'zh_ids'],
    num_rows: 10000
})

In [124]:
train_data[0]['zh_ids']

[2,
 4731,
 1462,
 4125,
 4125,
 5014,
 468,
 4483,
 2041,
 172,
 2986,
 3106,
 4731,
 1462,
 3721,
 1477,
 3564,
 2134,
 3245,
 257,
 2344,
 3224,
 1032,
 4346,
 214,
 737,
 194,
 2986,
 2152,
 324,
 5014,
 1734,
 3026,
 2083,
 4731,
 1462,
 1697,
 2649,
 172,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0

In [125]:
# Convert to tensors
data_type = "torch"
format_columns = ["en_ids", "zh_ids"]

train_data = train_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

valid_data = valid_data.with_format(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)

test_data = test_data.with_format(
    type=data_type,
    columns=format_columns,
    output_all_columns=True,
)

## Dataloaders

In [126]:
len(train_data[0]['en_ids'])

287

In [127]:
def get_collate_fn():
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_zh_ids = [example["zh_ids"] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, batch_first=True, padding_value=0)
        batch_zh_ids = nn.utils.rnn.pad_sequence(batch_zh_ids, batch_first=True, padding_value=0)
        batch = {
            "en_ids": batch_en_ids,
            "zh_ids": batch_zh_ids,
        }
        return batch

    return collate_fn

def get_data_loader(dataset, batch_size, shuffle=False):
    collate_fn = get_collate_fn()
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
    return data_loader

batch_size = 16

train_data_loader = get_data_loader(train_data, batch_size, shuffle=True)
valid_data_loader = get_data_loader(valid_data, batch_size)
test_data_loader = get_data_loader(test_data, batch_size)

In [128]:
# for i in train_data_loader:
#     print(i['en_ids'])
#     print(len(i['en_ids'][1]))
#     break

## Architecture

In [129]:
class Encoder(nn.Module):
    def __init__(
        self, input_dim, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, dropout
    ):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, encoder_hidden_dim, bidirectional=True)
        self.fc = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim) # encoder_hidden  * 2 if bidirectional!
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src length, batch size]
        embedded = self.dropout(self.embedding(src))
        # embedded = [src length, batch size, embedding dim]
        outputs, hidden = self.rnn(embedded)
        # outputs = [src length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        # outputs are always from the last layer
        # hidden [-2, :, : ] is the last of the forwards RNN
        # hidden [-1, :, : ] is the last of the backwards RNN
        # initial decoder hidden is final hidden state of the forwards and backwards
        # encoder RNNs fed through a linear layer
        hidden = torch.tanh(
            self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        )
        # hidden = torch.tanh(self.fc(hidden[-1, :, :])) # IF NOT BIDIRECTIONAL
        # outputs = [src length, batch size, encoder hidden dim * 2]
        # hidden = [batch size, decoder hidden dim]
        return outputs, hidden

In [130]:
class Attention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        self.attn_fc = nn.Linear(
            (encoder_hidden_dim * 2) + decoder_hidden_dim, decoder_hidden_dim # IF BIDIRECTIONAL MUST encoder_hidden_dim * 2 
        )
        self.v_fc = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden = [batch size, decoder hidden dim]
        # encoder_outputs = [src length, batch size, encoder hidden dim * 2]
        batch_size = encoder_outputs.shape[1]
        src_length = encoder_outputs.shape[0]
        # repeat decoder hidden state src_length times
        hidden = hidden.unsqueeze(1).repeat(1, src_length, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # hidden = [batch size, src length, decoder hidden dim]
        # encoder_outputs = [batch size, src length, encoder hidden dim * 2]
        energy = torch.tanh(self.attn_fc(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy = [batch size, src length, decoder hidden dim]
        attention = self.v_fc(energy).squeeze(2)
        # attention = [batch size, src length]
        return torch.softmax(attention, dim=1)

In [131]:
class Decoder(nn.Module):
    def __init__(
        self,
        output_dim,
        embedding_dim,
        encoder_hidden_dim,
        decoder_hidden_dim,
        dropout,
        attention,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.GRU((encoder_hidden_dim * 2) + embedding_dim, decoder_hidden_dim) # IF BIDIRECTIONAL MUST encoder_hidden_dim * 2 
        self.fc_out = nn.Linear(
            (encoder_hidden_dim * 2) + decoder_hidden_dim + embedding_dim, output_dim # IF BIDIRECTIONAL MUST encoder_hidden_dim * 2 
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        # input = [batch size]
        # hidden = [batch size, decoder hidden dim]
        # encoder_outputs = [src length, batch size, encoder hidden dim * 2]
        input = input.unsqueeze(0)
        # input = [1, batch size]
        embedded = self.dropout(self.embedding(input))
        # embedded = [1, batch size, embedding dim]
        a = self.attention(hidden, encoder_outputs)
        # a = [batch size, src length]
        a = a.unsqueeze(1)
        # a = [batch size, 1, src length]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # encoder_outputs = [batch size, src length, encoder hidden dim * 2]
        weighted = torch.bmm(a, encoder_outputs)
        # weighted = [batch size, 1, encoder hidden dim * 2]
        weighted = weighted.permute(1, 0, 2)
        # weighted = [1, batch size, encoder hidden dim * 2]
        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input = [1, batch size, (encoder hidden dim * 2) + embedding dim]
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # output = [seq length, batch size, decoder hid dim * n directions]
        # hidden = [n layers * n directions, batch size, decoder hid dim]
        # seq len, n layers and n directions will always be 1 in this decoder, therefore:
        # output = [1, batch size, decoder hidden dim]
        # hidden = [1, batch size, decoder hidden dim]
        # this also means that output == hidden
        assert (output == hidden).all()
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
        # prediction = [batch size, output dim]
        return prediction, hidden.squeeze(0), a.squeeze(1)

In [132]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio):
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        batch_size = src.shape[1]
        trg_length = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_length, batch_size, trg_vocab_size).to(self.device)
        # encoder_outputs is all hidden states of the input sequence, back and forwards
        # hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)
        # outputs = [src length, batch size, encoder hidden dim * 2]
        # hidden = [batch size, decoder hidden dim]
        # first input to the decoder is the <sos> tokens
        input = trg[0, :]
        for t in range(1, trg_length):
            # insert input token embedding, previous hidden state and all encoder hidden states
            # receive output tensor (predictions) and new hidden state
            output, hidden, _ = self.decoder(input, hidden, encoder_outputs)
            # output = [batch size, output dim]
            # hidden = [n layers, batch size, decoder hidden dim]
            # place predictions in a tensor holding predictions for each token
            outputs[t] = output
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            # get the highest predicted token from our predictions
            top1 = output.argmax(1)
            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            input = trg[t] if teacher_force else top1
            # input = [batch size]
        return outputs

In [133]:
input_dim = len(en_tokenizer) # 16384
output_dim = len(zh_tokenizer)
encoder_embedding_dim = 256
decoder_embedding_dim = 256
encoder_hidden_dim = 512
decoder_hidden_dim = 512
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

attention = Attention(encoder_hidden_dim, decoder_hidden_dim)

encoder = Encoder(
    input_dim,
    encoder_embedding_dim,
    encoder_hidden_dim,
    decoder_hidden_dim,
    encoder_dropout,
)

decoder = Decoder(
    output_dim,
    decoder_embedding_dim,
    encoder_hidden_dim,
    decoder_hidden_dim,
    decoder_dropout,
    attention,
)

model = Seq2Seq(encoder, decoder, device).to(device)

In [134]:
def init_weights(m):
    for name, param in m.named_parameters():
        if "weight" in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(16384, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn_fc): Linear(in_features=1536, out_features=512, bias=True)
      (v_fc): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(16384, 256)
    (rnn): GRU(1280, 512)
    (fc_out): Linear(in_features=1792, out_features=16384, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [135]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 44,198,400 trainable parameters


In [136]:
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [137]:
def train_fn(
    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(data_loader):
        src = batch["en_ids"].to(device)
        trg = batch["zh_ids"].to(device)
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio)
        # output = [trg length, batch size, trg vocab size]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        # output = [(trg length - 1) * batch size, trg vocab size]
        trg = trg[1:].view(-1)
        # trg = [(trg length - 1) * batch size]
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [138]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            src = batch["en_ids"].to(device)
            trg = batch["zh_ids"].to(device)
            # src = [src length, batch size]
            # trg = [trg length, batch size]
            output = model(src, trg, 0)  # turn off teacher forcing
            # output = [trg length, batch size, trg vocab size]
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            # output = [(trg length - 1) * batch size, trg vocab size]
            trg = trg[1:].view(-1)
            # trg = [(trg length - 1) * batch size]
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

## Training

In [139]:
n_epochs = 10
clip = 1.0
teacher_forcing_ratio = 0.5
CHECKPOINT_DIR = "./model_checkpoints"
checkpoint_path = f"{CHECKPOINT_DIR}/model_checkpoint.pt"
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

best_valid_loss = float("inf")

for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        clip,
        teacher_forcing_ratio,
        device,
    )
    valid_loss = evaluate_fn(
        model,
        valid_data_loader,
        criterion,
        device,
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': best_valid_loss}, f"{CHECKPOINT_DIR}/model_checkpoint.pt")
        print(f"Checkpoint saved at epoch {epoch+1}")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")

 10%|█         | 1/10 [06:31<58:40, 391.20s/it]

Checkpoint saved at epoch 1
	Train Loss:   6.055 | Train PPL: 426.107
	Valid Loss:   6.197 | Valid PPL: 491.137


 20%|██        | 2/10 [13:01<52:06, 390.86s/it]

Checkpoint saved at epoch 2
	Train Loss:   5.873 | Train PPL: 355.438
	Valid Loss:   6.158 | Valid PPL: 472.497


 30%|███       | 3/10 [19:32<45:34, 390.62s/it]

Checkpoint saved at epoch 3
	Train Loss:   5.844 | Train PPL: 345.287
	Valid Loss:   6.141 | Valid PPL: 464.554


 40%|████      | 4/10 [26:02<39:01, 390.33s/it]

	Train Loss:   5.834 | Train PPL: 341.668
	Valid Loss:   6.148 | Valid PPL: 467.768


 50%|█████     | 5/10 [32:31<32:30, 390.05s/it]

	Train Loss:   5.826 | Train PPL: 339.093
	Valid Loss:   6.156 | Valid PPL: 471.520


 60%|██████    | 6/10 [39:00<25:58, 389.68s/it]

	Train Loss:   5.822 | Train PPL: 337.529
	Valid Loss:   6.150 | Valid PPL: 468.897


 70%|███████   | 7/10 [45:29<19:28, 389.39s/it]

	Train Loss:   5.817 | Train PPL: 335.851
	Valid Loss:   6.156 | Valid PPL: 471.561


 80%|████████  | 8/10 [51:48<12:52, 386.07s/it]

	Train Loss:   5.810 | Train PPL: 333.517
	Valid Loss:   6.156 | Valid PPL: 471.474


 90%|█████████ | 9/10 [58:04<06:22, 382.91s/it]

	Train Loss:   5.813 | Train PPL: 334.502
	Valid Loss:   6.166 | Valid PPL: 476.379


100%|██████████| 10/10 [1:04:18<00:00, 385.82s/it]

	Train Loss:   5.805 | Train PPL: 332.075
	Valid Loss:   6.161 | Valid PPL: 473.687





## Train from best state

In [140]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time % 60)
    return elapsed_mins, elapsed_secs

In [141]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time % 60)
    return elapsed_mins, elapsed_secs

patience = 2  
no_improvement_epochs = 0

if os.path.exists(checkpoint_path):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    best_valid_loss = checkpoint['loss'] 
    print(f"Current best val loss : {best_valid_loss}")
else:
    print("Training started, no checkpoint found.")

for epoch in tqdm.tqdm(range(n_epochs)):
    start_time = time.time()
    train_loss = train_fn(model, train_data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device)
    valid_loss = evaluate_fn(model, valid_data_loader, criterion, device)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_valid_loss}, f"{CHECKPOINT_DIR}/model_checkpoint.pt")
        print(f"Checkpoint saved at epoch {epoch+1}")
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1
        
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Valid Loss: {valid_loss:.3f} | Valid PPL: {np.exp(valid_loss):7.3f}')
    # Auto stop
    if no_improvement_epochs >= patience:
        print(f"No improvement in val loss for {patience} consecutive epochs. Stopping training.")
        break
            
print("Training completed.")

Loading checkpoint...


Current best val loss : 6.141078073328192


 10%|█         | 1/10 [06:15<56:22, 375.79s/it]

Epoch: 01 | Time: 6m 15s
	Train Loss: 5.836 | Train PPL: 342.539
	 Valid Loss: 6.146 | Valid PPL: 467.079


 10%|█         | 1/10 [12:30<1:52:36, 750.73s/it]

Epoch: 02 | Time: 6m 14s
	Train Loss: 5.830 | Train PPL: 340.226
	 Valid Loss: 6.143 | Valid PPL: 465.525
No improvement in val loss for 2 consecutive epochs. Stopping training.
Training completed.





## Eval

In [142]:
checkpoint = torch.load(r"model_checkpoints\model_checkpoint.pt")
if 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])  
else:
    model.load_state_dict(checkpoint) 

test_loss = evaluate_fn(model, test_data_loader, criterion, device)
print(f"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |")

| Test Loss: 5.926 | Test PPL: 374.572 |


In [143]:
# ([1, 3, 0, 2, 4], ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'])

test = tokenize_wrap("棉花糖必须放在最上面",zh_tokenizer, return_token_type_ids=False)
# zh_tokenizer.decode(test['input_ids'])

def remove_special_ids(input):
    special_ids = [1,3,0,2,4]    
    return [num for num in input if num not in special_ids]

zh_tokenizer.decode(remove_special_ids(test['input_ids']))

'棉 花 糖 必 须 放 在 最 上 面'

In [144]:
en_tokenizer.tokenize("hi how are you")

['hi', 'how', 'are', 'you']

In [145]:
def translate_sentence(
    sentence,
    model,
    # lower,
    device,
    max_output_length=max_length_g,
):
    model.eval()
    with torch.no_grad():
        # en_tokens = en_tokenizer.tokenize(sentence)
        ids = tokenize_wrap(sentence, en_tokenizer, return_token_type_ids=False, max_length=max_length_g)['input_ids'] # tokenise
        # ids = en_vocab.lookup_indices(en_tokens)
        tensor = torch.LongTensor(ids).unsqueeze(-1).to(device)
        encoder_outputs, hidden = model.encoder(tensor)
        inputs = [2]
        attentions = torch.zeros(max_output_length, 1, len(ids))
        for i in range(max_output_length):
            inputs_tensor = torch.LongTensor([inputs[-1]]).to(device)
            output, hidden, attention = model.decoder(
                inputs_tensor, hidden, encoder_outputs
            )
            attentions[i] = attention
            predicted_token = output.argmax(-1).item()
            inputs.append(predicted_token)
            if predicted_token == 3:
                break
        
        zh_tokens = zh_tokenizer.decode(remove_special_ids(inputs))
        translated = ' '.join(zh_tokens)
    return translated, remove_special_ids(ids), attentions[: len(inputs) - 1]

In [146]:
def plot_attention(sentence, translation, attention):
    fig, ax = plt.subplots(figsize=(10, 10))
    attention = attention.squeeze(1).numpy()
    cax = ax.matshow(attention, cmap="bone")
    ax.set_xticks(ticks=np.arange(len(sentence)), labels=sentence, rotation=90, size=15)
    translation = translation[1:]
    ax.set_yticks(ticks=np.arange(len(translation)), labels=translation, size=15)
    plt.show()
    plt.close()

In [147]:
sentence = test_data[2]['translation']["en"]
expected_translation = test_data[2]['translation']["zh"]

sentence, expected_translation

('The marshmallow has to be on top.', '棉花糖必须放在最上面')

In [148]:
translation, sentence_tokens, attention = translate_sentence(
    sentence,
    model,
    device,
)
translation

'的   的'

In [149]:
sentence_tokens

[135, 8667, 5568, 13948, 387, 144, 167, 180, 1125, 18]

In [150]:
# plot_attention(sentence_tokens, translation, attention)

In [151]:
test_data[0]['translation']

{'en': 'Several years ago here at TED, Peter Skillman  introduced a design challenge  called the marshmallow challenge.',
 'zh': '几年前，在TED大会上， Peter Skillman 介绍了一个设计挑战 叫做“棉花糖挑战”'}

In [158]:
translation, sentence_tokens, attention = translate_sentence(sentence,model,device)
en = [] 
zh = [] 
pred = []

import pandas as pd

for i in range(0,8549):
    en.append(test_data[i]['translation']['en'])
    zh.append(test_data[i]['translation']['zh'])
    t,s,a = translate_sentence(test_data[i]['translation']['en'],model,device)

    pred.append(t)

df = pd.DataFrame()
df['en'] = en
df['zh'] = zh 
df['pred'] = pred

df.to_csv("060424_2000_model_checkpoint.csv")

## Ignore

In [153]:
# checkpoint_path = f"{CHECKPOINT_DIR}/060424_0000_model_checkpoint.pt"

# if os.path.exists(checkpoint_path):
#     print("Loading checkpoint...")
#     checkpoint = torch.load(checkpoint_path)
#     print(f"1600_model val loss : {checkpoint['loss']}")

