In [59]:
from torch.nn import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerDecoder,
    TransformerDecoderLayer,
)
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.nn import Transformer
from torch import Tensor
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
import warnings
import random
import torch
import math
import yaml
import json
import os

warnings.filterwarnings("ignore")

embedding_num = 29
embedding_dim = 256
num_layers = 8
num_heads = 8
ff_dim = 1024
dropout = 0.1

sos = 27
eos = 28
pad = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATAPATH = "/home/zonghan/Documents/Courses/DL/DL_Lab3/data/"

In [None]:
# bleu-4 score function
def bleu4_score(output, reference):
    cc = SmoothingFunction()
    if len(reference) == 3:
        weights = (0.33, 0.33, 0.33)
    else:
        weights = (0.25, 0.25, 0.25, 0.25)
    return sentence_bleu(
        [reference], output, weights=weights, smoothing_function=cc.method1
    )


def metrics(pred: list, target: list) -> float:
    """
    pred: list of strings
    target: list of strings

    return: accuracy(%)
    """
    if len(pred) != len(target):
        raise ValueError("length of pred and target must be the same")
    correct = 0
    for i in range(len(pred)):
        if pred[i] == target[i]:
            correct += 1
    return correct / len(pred) * 100


class index2char:
    def __init__(self, root, tokenizer=None):
        if tokenizer is None:
            with open(root + "tokenizer.yaml", "r") as f:
                self.tokenizer = yaml.load(f, Loader=yaml.CLoader)
        else:
            self.tokenizer = tokenizer

    def __call__(self, indices: list, without_token=True):
        if type(indices) == Tensor:
            indices = indices.tolist()
        result = "".join([self.tokenizer["index_2_char"][i] for i in indices])
        if without_token:
            result = result.split("[eos]")[0]
            result = (
                result.replace("[sos]", "").replace("[eos]", "").replace("[pad]", "")
            )
        return result

    def char2index(self, text):
        # Convert a string to a list of indices
        indices = [self.tokenizer["char_2_index"].get(char) for char in text]
        return indices

In [None]:
class SpellCorrectionDataset(Dataset):
    def __init__(self, root, split: str = "train", tokenizer=None, padding: int = 0):
        super(SpellCorrectionDataset, self).__init__()

        self.tokenizer = index2char(root, tokenizer=tokenizer)
        self.inputs = []
        self.targets = []

        # Load your JSON data from a file
        with open(root + split + ".json", "r") as f:
            jsonData = json.load(f)

        # Torkenize the JSON data and add the sos, eos
        for dict in jsonData:
            torkenized_target = self.tokenize(dict["target"])
            target = (
                [sos]
                + torkenized_target
                + [eos]
                + [pad] * (padding - len(torkenized_target) - 2)
            )
            inputWords = dict["input"]

            for word in inputWords:
                torkenized_input = self.tokenize(word)
                input = (
                    [sos]
                    + torkenized_input
                    + [eos]
                    + [pad] * (padding - len(torkenized_input) - 2)
                )
                self.inputs.append(input)
                self.targets.append(target)

        # Convert the list to a PyTorch tensor
        self.inputs = torch.tensor(self.inputs)
        self.targets = torch.tensor(self.targets)

    def tokenize(self, text: str):
        # tokenize your text here
        # ex: "data" -> [4, 1, 20, 1]
        return self.tokenizer.char2index(text)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        # Get the data by index here
        input_ids = self.inputs[index]
        target_ids = self.targets[index]

        return input_ids, target_ids


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        batch_first: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)
        self.batch_first = batch_first

    def forward(self, x: Tensor) -> Tensor:
        if self.batch_first:
            x = x.transpose(0, 1)
            x = x + self.pe[: x.size(0)]
            return self.dropout(x.transpose(0, 1))
        else:
            x = x + self.pe[: x.size(0)]
            return self.dropout(x)

In [None]:
class Encoder(nn.Module):
    def __init__(
        self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100
    ):
        super(Encoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim, padding_idx=pad)
        self.pos_embedding = PositionalEncoding(
            hid_dim, dropout, max_len=max_length, batch_first=True
        )
        self.layer = nn.TransformerEncoderLayer(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.encoders = nn.TransformerEncoder(self.layer, n_layers)
        self.d_model = hid_dim

    def forward(self, src, src_padding_mask):
        _ = self.tok_embedding(src) * math.sqrt(self.d_model)
        src = self.pos_embedding(_)
        output = self.encoders(src, src_key_padding_mask=src_padding_mask)
        return output


class Decoder(nn.Module):
    def __init__(
        self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100
    ):
        super(Decoder, self).__init__()
        self.tok_embedding = nn.Embedding(num_emb, hid_dim, padding_idx=pad)
        self.pos_embedding = PositionalEncoding(
            hid_dim, dropout, max_len=max_length, batch_first=True
        )
        self.layer = nn.TransformerDecoderLayer(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.d_model = hid_dim
        self.decoder = nn.TransformerDecoder(self.layer, n_layers)
        self.fc = nn.Linear(hid_dim, num_emb)

    def forward(self, tgt, memory, tgt_mask, tgt_padding_mask, src_pad_mask):
        _ = self.tok_embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_embedding(_)
        output = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_pad_mask,
        )
        output = self.fc(output)
        return output


class TransformerAutoEncoder(nn.Module):
    def __init__(
        self,
        num_emb,
        hid_dim,
        n_layers,
        n_heads,
        ff_dim,
        dropout,
        max_length=100,
        encoder=None,
    ):
        super(TransformerAutoEncoder, self).__init__()
        if encoder is None:
            self.encoder = Encoder(
                num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length
            )
        else:
            self.encoder = encoder
        self.decoder = Decoder(
            num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length
        )

    def forward(self, src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask):
        memory = self.encoder(src, src_pad_mask)
        dec_out = self.decoder(tgt, memory, tgt_mask, tgt_pad_mask, src_pad_mask)
        return dec_out

In [None]:
# detect where the padding value is
def gen_padding_mask(src, pad_idx):
    # pad_mask = (src == pad_idx
    return src.eq(pad_idx)


# triu mask for decoder
def gen_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask


def get_index(pred, dim=2):
    return pred.clone().argmax(dim=dim)


def random_change_idx(data: torch.Tensor, prob: float = 0.2):
    # randomly change the index of the input data
    mask = torch.rand(data.size()) < prob
    new_data = torch.randint_like(data, low=0, high=data.max() + 1)
    return torch.where(mask, new_data, data)


def random_masked(data: torch.Tensor, prob: float = 0.2, mask_idx: int = 3):
    # randomly mask the input data
    mask = torch.rand(data.size()) < prob
    masked_data = torch.full_like(data, fill_value=mask_idx)
    return torch.where(mask, masked_data, data)


def validation(dataloader, model, device, logout=False, dataset="test"):
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    bleu_scores = []
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        # An all pad token tensor with the same shape as tgt and the first token is <sos>
        tgt_input = torch.full_like(tgt, fill_value=pad)
        tgt_input[:, 0] = sos
        for i in range(tgt.shape[1] - 1):
            src_pad_mask = gen_padding_mask(src, pad_idx=0).to(device)
            tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx=0).to(device)
            tgt_mask = gen_mask(tgt_input).to(device)
            pred = model(
                src=src,
                tgt=tgt_input,
                src_pad_mask=src_pad_mask,
                tgt_mask=tgt_mask,
                tgt_pad_mask=tgt_pad_mask,
            )
            pred_idx = get_index(pred)
            tgt_input[:, i + 1] = pred_idx[:, i]
        for i in range(tgt.shape[0]):
            pred_str_list.append(i2c(tgt_input[i].tolist()))
            tgt_str_list.append(i2c(tgt[i].tolist()))
            input_str_list.append(i2c(src[i].tolist()))
            if logout:
                print("=" * 30)
                print(f"input: {input_str_list[-1]}")
                print(f"pred: {pred_str_list[-1]}")
                print(f"target: {tgt_str_list[-1]}")
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        losses.append(loss.item())

        # Calculate BLEU-4 score
        for pred_str, tgt_str in zip(pred_str_list, tgt_str_list):
            bleu_score = bleu4_score(pred_str, tgt_str)
            bleu_scores.append(bleu_score)

    avg_bleu_score = sum(bleu_scores) / len(bleu_scores)
    avg_loss = sum(losses) / len(losses)

    print(
        f"{dataset}_acc: {metrics(pred_str_list, tgt_str_list):.2f}",
        f"{dataset}_loss: {avg_loss:.2f}",
        f"BLEU-4: {avg_bleu_score:.4f}",
        end=" | ",
    )
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]")
    return avg_bleu_score, avg_loss

In [None]:
i2c = index2char(DATAPATH)

trainset = SpellCorrectionDataset(DATAPATH, padding=22)
testset = SpellCorrectionDataset(DATAPATH, split="new_test", padding=22)
valset = SpellCorrectionDataset(DATAPATH, split="test", padding=22)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
valloader = DataLoader(valset, batch_size=64, shuffle=False)


ce_loss = nn.CrossEntropyLoss(ignore_index=pad)
learning_rate = 0.00005  # Adjust this based on your problem
# learning_rate = 0.0001
model = TransformerAutoEncoder(
    embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout, max_length=22
).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = optim.SGD(model.parameters(), lr=learning_rate)

test_bleu_scores = []
validation_bleu_scores = []
train_losses = []
validation_losses = []
best_bleu_score = 0.0
best_epoch = 0

for eps in range(1000):
    # train
    losses = []
    model.train()
    i_bar = tqdm(trainloader, unit="iter", desc=f"epoch{eps}")
    for src, tgt in i_bar:
        src, tgt = src.to(device), tgt.to(device)

        src_pad_mask = gen_padding_mask(src, pad_idx=0).to(device)
        tgt_pad_mask = gen_padding_mask(tgt, pad_idx=0).to(device)
        tgt_mask = gen_mask(tgt).to(device)

        optimizer.zero_grad()
        pred = model(src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask)
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        i_bar.set_postfix_str(f"loss: {sum(losses)/len(losses):.3f}")

    # Store the average training loss for the epoch
    train_losses.append(sum(losses) / len(losses))

    # test
    model.eval()
    with torch.no_grad():
        # print(f"epoch: {eps}")
        avg_bleu_score, _ = validation(testloader, model, device, dataset="test")
        test_bleu_scores.append(avg_bleu_score)
    # eval
    model.eval()
    with torch.no_grad():
        avg_bleu_score, avg_loss = validation(valloader, model, device, dataset="vali")
        validation_bleu_scores.append(avg_bleu_score)
        validation_losses.append(avg_loss)

        # Check if the current BLEU score is better than the best so far
        if avg_bleu_score > best_bleu_score:
            best_bleu_score = avg_bleu_score
            best_epoch = eps

            # Save the model state dictionary
            torch.save(
                model.state_dict(),
                "/home/zonghan/Documents/Courses/DL/DL_Lab3/best_model.pth",
            )

    # Plotting the training loss and BLEU-4 scores
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Training Loss")
    plt.plot(validation_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("CrossEntropy Loss")
    plt.title("Training Loss Curve")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(test_bleu_scores, label="Test BLEU-4 Score")
    plt.plot(validation_bleu_scores, label="Validation BLEU-4 Score")
    plt.xlabel("Epoch")
    plt.ylabel("BLEU-4 Score")
    plt.title("BLEU-4 Score Curve")
    plt.legend()

    plt.tight_layout()
    plt.savefig("/home/zonghan/Documents/Courses/DL/DL_Lab3/result.png")

In [60]:
def Inference(dataloader, model, device, logout=False, dataset="test"):
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    bleu_scores = []
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        # An all pad token tensor with the same shape as tgt and the first token is <sos>
        tgt_input = torch.full_like(tgt, fill_value=pad)
        tgt_input[:, 0] = sos
        for i in range(tgt.shape[1] - 1):
            src_pad_mask = gen_padding_mask(src, pad_idx=0).to(device)
            tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx=0).to(device)
            tgt_mask = gen_mask(tgt_input).to(device)
            pred = model(
                src=src,
                tgt=tgt_input,
                src_pad_mask=src_pad_mask,
                tgt_mask=tgt_mask,
                tgt_pad_mask=tgt_pad_mask,
            )
            pred_idx = get_index(pred)
            tgt_input[:, i + 1] = pred_idx[:, i]
        for i in range(tgt.shape[0]):
            pred_str_list.append(i2c(tgt_input[i].tolist()))
            tgt_str_list.append(i2c(tgt[i].tolist()))
            input_str_list.append(i2c(src[i].tolist()))
            if logout:
                print("=" * 30)
                print(f"input: {input_str_list[-1]}")
                print(f"pred: {pred_str_list[-1]}")
                print(f"target: {tgt_str_list[-1]}")
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        losses.append(loss.item())

        # Calculate BLEU-4 score
        for pred_str, tgt_str in zip(pred_str_list, tgt_str_list):
            bleu_score = bleu4_score(pred_str, tgt_str)
            bleu_scores.append(bleu_score)

    avg_bleu_score = sum(bleu_scores) / len(bleu_scores)
    avg_loss = sum(losses) / len(losses)

    print(
        f"{dataset}_acc: {metrics(pred_str_list, tgt_str_list):.2f}",
        f"{dataset}_loss: {avg_loss:.2f}",
        f"BLEU-4: {avg_bleu_score:.4f}",
    )
    for i in range(len(pred_str_list)):
        print(
            "Ground Truth: %-20s || Predicted word: %-20s"
            % (pred_str_list[i], tgt_str_list[i])
        )
    return avg_bleu_score, avg_loss

In [61]:
# Load the saved model weights
model.load_state_dict(
    torch.load("/home/zonghan/Documents/Courses/DL/DL_Lab3/best_model_transformer.pth")
)

<All keys matched successfully>

In [62]:
model.eval()
with torch.no_grad():
    # print(f"epoch: {eps}")
    avg_bleu_score, _ = Inference(testloader, model, device, dataset="test")
    print(f"Test BLEU-4 Score: {avg_bleu_score:.4f}")

test_acc: 52.00 test_loss: 4.37 BLEU-4: 0.6209
Ground Truth: appreciate           || Predicted word: appreciate          
Ground Truth: appreciate           || Predicted word: appreciate          
Ground Truth: appreciate           || Predicted word: appreciate          
Ground Truth: appreciate           || Predicted word: appreciate          
Ground Truth: appreciate           || Predicted word: appreciate          
Ground Truth: love                 || Predicted word: love                
Ground Truth: cloud                || Predicted word: cold                
Ground Truth: earths               || Predicted word: heart               
Ground Truth: television           || Predicted word: television          
Ground Truth: phone                || Predicted word: phone               
Ground Truth: chap                 || Predicted word: phase               
Ground Truth: approm               || Predicted word: poem                
Ground Truth: tomorraw             || Predicted word: