In [1]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn import Transformer
import torch.nn.functional as F
from torch import Tensor
from torch import nn
import warnings
import random
import torch
import math
import yaml
import json
import os
import string
import matplotlib.pyplot as plt
from pathlib import Path
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
import argparse
warnings.filterwarnings("ignore")

In [2]:
# Hyper parameters and global configuration shared by every later cell.

# Model size. These values are passed positionally to TransformerAutoEncoder:
#   embedding_num -> num_emb  (rows of the embedding tables and width of the output layer;
#                              presumably the number of entries in tokenizer.yaml - TODO confirm)
#   embedding_dim -> hid_dim  (d_model of the encoder/decoder layers)
#   num_layers    -> n_layers, num_heads -> n_heads, ff_dim -> dim_feedforward
embedding_num = 1030713
embedding_dim = 16
num_layers = 8
num_heads = 8
ff_dim = 512
dropout = 0.1
# train_mode gates the training loop; display_mode gates plt.show() in train_visualize;
# islogout is the default of validation()'s `logout` argument.
train_mode = True
display_mode = False
islogout = True

# Index used to pad sequences; also the index ignored by the cross-entropy loss.
PAD_token = 0

BATCH_SIZE = 4
# Every input/target sequence is padded or truncated to this many tokens; it is also
# handed to the model as the positional-encoding table length.
MAX_LENGTH = 25

# Relative directories for the dataset, checkpoints, training curves and evaluation logs.
data_root = Path('my_data')
model_root = Path('model')
image_root = Path('image')
log_root = Path('log')
# Used in the checkpoint, figure and log file names.
model_name = "transformer_model-sgd"

In [3]:
# Convert token indices back to text; also exposes the reverse (song id -> index) table.
class index2char():
    """Decode sequences of token indices into a single string.

    The tokenizer is either supplied by the caller or loaded from
    ``root / 'tokenizer.yaml'``. Decoding reads its ``'index_2_song_id'``
    entry; ``char2index`` returns its ``'song_id_2_index'`` entry.
    """

    def __init__(self, root, tokenizer=None):
        """
        root: directory (path-like supporting ``/``) that contains ``tokenizer.yaml``;
              only used when ``tokenizer`` is None
        tokenizer: optional pre-loaded tokenizer object
        """
        if tokenizer is None:
            # CLoader only exists when PyYAML was built with LibYAML; fall back to the
            # pure-Python loader, which parses the same documents.
            # NOTE(security): Loader/CLoader can construct arbitrary Python objects, so
            # only load tokenizer files you trust (CSafeLoader/SafeLoader is the safe
            # choice if the file holds plain data only).
            loader = getattr(yaml, 'CLoader', yaml.Loader)
            with open(root / 'tokenizer.yaml', 'r') as f:
                self.tokenizer = yaml.load(f, Loader=loader)
        else:
            self.tokenizer = tokenizer

    def __call__(self, indices:list, without_token=True):
        """Return the concatenated tokens for ``indices`` with every '[pad]' removed.

        indices: list of ints or a tensor of indices
        without_token: currently unused; kept for interface compatibility
        """
        # isinstance also covers Tensor subclasses (e.g. nn.Parameter); iterating such a
        # tensor directly would yield 0-d tensors, which are not valid dictionary keys.
        if isinstance(indices, Tensor):
            indices = indices.tolist()
        index_to_token = self.tokenizer['index_2_song_id']
        result = ''.join(index_to_token[i] for i in indices)
        return result.replace('[pad]', '')

    def char2index(self):
        """Return the token -> index table (or the tokenizer itself if it is not a plain dict)."""
        if type(self.tokenizer) is dict:
            return self.tokenizer['song_id_2_index']

        return self.tokenizer
    
# compute accuracy score
def metrics(pred:list, target:list) -> float:
    """
    Exact-match accuracy between two equally long lists of strings.

    pred: list of strings
    target: list of strings

    return: accuracy(%) in [0, 100]; 0.0 when both lists are empty
    raises: ValueError if the two lists differ in length
    """
    if len(pred) != len(target):
        raise ValueError('length of pred and target must be the same')
    # An empty evaluation set has no matches; avoid dividing by zero.
    if len(pred) == 0:
        return 0.0
    correct = sum(1 for p, t in zip(pred, target) if p == t)
    return correct / len(pred) * 100

# compute BLEU-4 score
def compute_bleu(output, reference):
    """Smoothed sentence-level BLEU of ``output`` against a single ``reference``.

    A reference of length exactly 3 is scored with three n-gram weights of 0.33;
    every other reference is scored with four weights of 0.25. NLTK smoothing
    method 1 is applied in both cases.
    """
    short_weights = (0.33,0.33,0.33)
    default_weights = (0.25,0.25,0.25,0.25)
    weights = short_weights if len(reference) == 3 else default_weights
    smoothing = SmoothingFunction().method1
    return sentence_bleu([reference], output, weights=weights, smoothing_function=smoothing)

def compute_bleu_list(output:list, reference:list):
    """Average BLEU score over paired predictions and references.

    output: list of predicted sequences
    reference: list of reference sequences, paired with ``output`` by position

    return: sum of the pairwise scores divided by ``len(output)``; 0.0 when
            ``output`` is empty (previously a ZeroDivisionError)
    """
    if len(output) == 0:
        return 0.0
    bleu_score = sum(compute_bleu(out, ref) for out, ref in zip(output, reference))
    return bleu_score / len(output)

# Visualize the training progress
def train_visualize(train, valid, datatype="Loss"):
    """Plot the per-epoch train/valid curves and save them under ``image_root``.

    train: sequence of per-epoch training values
    valid: sequence of per-epoch validation values
    datatype: name of the plotted quantity; used for the y-label, the title and
              the output file name

    The figure is written to
    ``image_root / f'transformer_{datatype}-{model_name}.jpg'`` and, when
    ``display_mode`` is set, also shown on screen.
    """
    plt.plot(train, label='Train')
    plt.plot(valid, label='Valid')

    plt.xlabel('Epochs')
    plt.ylabel(datatype)
    plt.title(f'Transformer training {datatype}')
    plt.legend()

    # savefig does not create missing directories.
    image_root.mkdir(parents=True, exist_ok=True)
    plt.savefig(image_root / f'transformer_{datatype}-{model_name}.jpg')

    # Show before clearing: clearing first would display an empty figure.
    if display_mode:
        plt.show()
    plt.clf()

def evaluate_logger(input:list, target:list, predict:list, split='test'):
    """Write every (input, target, prediction) triple plus summary metrics to a log file.

    input: list of decoded input strings
    target: list of decoded target strings
    predict: list of decoded prediction strings
    split: name used as the prefix of the log file

    The log is written to ``log_root / f'{split}-{model_name}_log.txt'`` and is
    overwritten on every call.
    """
    # open() does not create missing directories.
    log_root.mkdir(parents=True, exist_ok=True)

    # The context manager closes the file even if a metric computation raises.
    with open(log_root / f'{split}-{model_name}_log.txt', 'w') as f:
        for s, t, p in zip(input, target, predict):
            f.write("="*20+"\n")
            f.write(f"input:  {s}\n")
            f.write(f"target: {t}\n")
            f.write(f"pred:   {p}\n")

        # ':4f' (minimum width 4) was a typo for '.4f' (four decimals), matching Accuracy.
        f.write(f"Bleu-4 score: {compute_bleu_list(predict, target):.4f}, Accuracy: {metrics(predict, target):.4f}")


In [4]:
# Dataset
class SpellCorrectionDataset(Dataset):
    """Dataset of (input, target) token sequences read from ``<root>/<split>.json``.

    Each record of the JSON file is indexed with the keys ``'input'`` and
    ``'target'``; both are tokenized, then padded or truncated to ``padding``
    tokens and returned as tensors.
    """

    def __init__(self, root, split:str = 'train', tokenizer=None, padding:int=0):
        super().__init__()
        json_path = os.path.join(root, f"{split}.json")
        self.data = self.load_data(json_path)
        self.tokenizer = tokenizer
        self.padding = padding

    def load_data(self, file_name):
        """Parse and return the JSON content of ``file_name``."""
        with open(file_name) as f:
            return json.load(f)

    def pad_sequence(self, sequence):
        """Return ``sequence`` right-padded with PAD_token, or cut, to ``self.padding`` items."""
        missing = self.padding - len(sequence)
        if missing > 0:
            return sequence + [PAD_token] * missing
        return sequence[:self.padding]

    def tokenize(self, song_ids:list):
        """Look up every element of ``song_ids`` in the tokenizer and return the list of ids."""
        ids = []
        for song in song_ids:
            ids.append(self.tokenizer[song])
        return ids

    def __len__(self):
        """Number of records in the split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(input_ids, target_ids)`` for record ``index`` as two tensors."""
        record = self.data[index]
        source_tokens = self.pad_sequence(self.tokenize(record['input']))
        input_ids = torch.tensor(source_tokens)
        target_tokens = self.pad_sequence(self.tokenize(record['target']))
        target_ids = torch.tensor(target_tokens)
        return input_ids, target_ids
    
# Transformer
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    d_model: embedding width (odd widths are supported)
    dropout: dropout probability applied to the sum
    max_len: number of positions precomputed in the table
    batch_first: if True the input is (batch, seq, d_model), otherwise (seq, batch, d_model)
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so trim the
        # cosine terms to fit; for even d_model the slice keeps every column.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: moves with the module and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe)
        self.batch_first = batch_first

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe)`` with the same layout as ``x``."""
        if self.batch_first:
            # The table is laid out (seq, 1, d_model); go seq-first to add it, then back.
            x = x.transpose(0, 1)
            x = x + self.pe[:x.size(0)]
            return self.dropout(x.transpose(0, 1))
        else:
            x = x + self.pe[:x.size(0)]
            return self.dropout(x)


class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of Transformer encoder layers.

    All sub-modules are built with ``batch_first=True``.
    """

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super().__init__()
        self.tok_embedding = nn.Embedding(num_embeddings=num_emb, embedding_dim=hid_dim)
        self.pos_embedding = PositionalEncoding(d_model=hid_dim, dropout=dropout, max_len=max_length, batch_first=True)
        layer_kwargs = dict(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.encoder = nn.TransformerEncoder(self.layer, num_layers=n_layers)

    def forward(self, src, src_mask):
        """Encode ``src`` token indices; ``src_mask`` is passed as ``src_key_padding_mask``."""
        embedded = self.pos_embedding(self.tok_embedding(src))
        return self.encoder(embedded, src_key_padding_mask=src_mask)


class Decoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of Transformer decoder layers.

    All sub-modules are built with ``batch_first=True``.
    """

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100):
        super().__init__()
        self.tok_embedding = nn.Embedding(num_embeddings=num_emb, embedding_dim=hid_dim)
        self.pos_embedding = PositionalEncoding(d_model=hid_dim, dropout=dropout, max_len=max_length, batch_first=True)
        layer_kwargs = dict(
            d_model=hid_dim,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.layer = nn.TransformerDecoderLayer(**layer_kwargs)
        self.decoder = nn.TransformerDecoder(self.layer, num_layers=n_layers)

    def forward(self, tgt, memory, src_pad_mask, tgt_mask, tgt_pad_mask):
        """Decode ``tgt`` token indices against the encoder output ``memory``.

        src_pad_mask is forwarded as ``memory_key_padding_mask``, tgt_mask as the
        attention mask over target positions and tgt_pad_mask as
        ``tgt_key_padding_mask``.
        """
        embedded = self.pos_embedding(self.tok_embedding(tgt))
        return self.decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=src_pad_mask,
        )


class TransformerAutoEncoder(nn.Module):
    """Encoder-decoder Transformer with a linear projection onto the vocabulary.

    A pre-built ``encoder`` may be supplied; otherwise a fresh ``Encoder`` is
    created with the same hyper-parameters as the decoder.
    """

    def __init__(self, num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length=100, encoder=None):
        super().__init__()
        shared_args = (num_emb, hid_dim, n_layers, n_heads, ff_dim, dropout, max_length)
        self.encoder = Encoder(*shared_args) if encoder is None else encoder
        self.decoder = Decoder(*shared_args)
        self.fc = nn.Linear(hid_dim, num_emb)

    def forward(self, src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask):
        """Return the ``fc`` projection of the decoder output for every target position."""
        memory = self.encoder(src, src_pad_mask)
        decoded = self.decoder(tgt, memory, src_pad_mask, tgt_mask, tgt_pad_mask)
        return self.fc(decoded)


In [5]:
# Mask Process
def gen_padding_mask(src, pad_idx):
    """Return a boolean tensor shaped like ``src`` that is True where ``src`` equals ``pad_idx``."""
    is_padding = torch.eq(src, pad_idx)
    return is_padding

def gen_mask(seq):
    """Build the square look-ahead mask for a sequence tensor.

    The side length is the size of the last dimension of ``seq``. Entry
    ``[i, j]`` is True exactly when ``j > i`` (strictly above the main diagonal).
    """
    positions = torch.arange(seq.shape[-1])
    # Column index greater than row index <=> strictly upper triangle.
    return positions.unsqueeze(0) > positions.unsqueeze(1)

def get_index(pred, dim=2):
    """Return the indices of the maximum values of ``pred`` along ``dim``.

    argmax neither modifies its input nor takes part in autograd, so the tensor
    is used directly; the former ``clone()`` copied the whole tensor for no effect.
    """
    return pred.argmax(dim=dim)

def random_change_idx(data: torch.Tensor, prob: float = 0.2):
    """Randomly replace elements of ``data`` with other elements drawn from ``data``.

    data: tensor of any dtype and shape (integer index tensors included)
    prob: probability with which each element is replaced

    return: new tensor with the shape and dtype of ``data``; each element is
            replaced, with probability ``prob``, by an element picked uniformly
            at random from ``data``
    """
    # Nothing to sample from; randint would reject an empty range.
    if data.numel() == 0:
        return data.clone()

    # torch.rand works for every input dtype; rand_like fails on integer tensors.
    mask = torch.rand(data.shape, device=data.device) < prob

    # Draw positions in the flattened tensor so the result keeps the shape of ``data``
    # for any number of dimensions (for 1-D input this is the same as indexing ``data``).
    flat = data.reshape(-1)
    random_indices = torch.randint(0, flat.numel(), data.shape, device=data.device)

    return torch.where(mask, flat[random_indices], data)

def random_masked(data: torch.Tensor, prob: float = 0.2, mask_idx: int = 3):
    """Randomly overwrite elements of ``data`` with ``mask_idx``.

    data: tensor of any dtype and shape (integer index tensors included)
    prob: probability with which each element is replaced
    mask_idx: value written into the selected positions

    return: new tensor with the shape, dtype and device of ``data``
    """
    # torch.rand works for every input dtype; rand_like fails on integer tensors.
    mask = torch.rand(data.shape, device=data.device) < prob

    # masked_fill returns a copy and keeps the dtype/device of ``data``.
    return data.masked_fill(mask, mask_idx)


In [6]:
# Data loading: build the tokenizer, the datasets/loaders and the loss shared by later cells.
from tqdm import tqdm

# i2c decodes index sequences back to strings; `tokenizer` is the reverse lookup
# used by the datasets to turn tokens into indices.
i2c = index2char(data_root)
tokenizer = i2c.char2index()

# Training data comes from the default split ('train'); every sequence is padded or
# truncated to MAX_LENGTH tokens.
trainset = SpellCorrectionDataset(data_root, tokenizer=tokenizer, padding=MAX_LENGTH)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): the validation loader reads the 'test' split, so checkpoint selection
# in the training loop is driven by test data - confirm this is intended.
valset = SpellCorrectionDataset(data_root, tokenizer=tokenizer, split='test', padding=MAX_LENGTH)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

# Model Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Padding positions in the target do not contribute to the loss.
ce_loss = nn.CrossEntropyLoss(ignore_index=PAD_token)

# Lowest validation loss seen so far; updated by the training loop when it saves a checkpoint.
best_valid_loss = float('inf')

In [8]:
# Training configuration
# CLIP is the max gradient norm passed to clip_grad_norm_ in train().
CLIP = 1
N_EPOCHS = 400
LR = 0.005

# Start training
# Build the model from the hyper-parameter cell and move it to the selected device.
model = TransformerAutoEncoder(embedding_num, embedding_dim, num_layers, num_heads, ff_dim, dropout, MAX_LENGTH).to(device)
# train() reads this optimizer as a global.
optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=LR)#choose your optimizer
# Learning-rate schedule is currently disabled (see the matching scheduler.step() in the loop).
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

In [9]:
def train(dataloader, model, device, eps):
    """Run one training epoch with teacher forcing.

    dataloader: yields ``(src, tgt)`` batches of token-index tensors
    model: called as ``model(src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask)``
    device: device the batches are moved to
    eps: epoch number, only used in the progress-bar label

    Uses the module-level ``optimizer``, ``ce_loss``, ``CLIP``, ``PAD_token`` and ``i2c``.

    return: (mean batch loss, mean BLEU between decoded predictions and targets)
    raises: ValueError if the dataloader yields no batches
    """
    # Running total instead of a list: re-summing a growing list on every iteration
    # is quadratic in the number of batches.
    loss_total = 0.0
    n_batches = 0
    pred_str_list = []
    tgt_str_list = []

    i_bar = tqdm(dataloader, unit='iter', desc=f'epoch{eps}')
    for src, tgt in i_bar:
        src, tgt = src.to(device), tgt.to(device)
        # padding masks for both sides and the look-ahead mask for the decoder
        src_pad_mask = gen_padding_mask(src, PAD_token)
        tgt_pad_mask = gen_padding_mask(tgt, PAD_token)
        tgt_mask = gen_mask(tgt).to(device)

        optimizer.zero_grad()
        pred = model(src, tgt, src_pad_mask, tgt_mask, tgt_pad_mask)
        # The output at position i is scored against the target token at position i + 1.
        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()

        loss_total += loss.item()
        n_batches += 1
        i_bar.set_postfix_str(f"loss: {loss_total / n_batches:.3f}")

        # Shift the predictions one step right (keeping the given first token) so that
        # position i of the decoded prediction lines up with position i of the target,
        # the same layout validation() produces. argmax on the detached logits avoids
        # copying the full logits tensor.
        pred_idx = pred.detach().argmax(dim=2)
        aligned_pred = torch.cat([tgt[:, :1], pred_idx[:, :-1]], dim=1)
        for i in range(tgt.shape[0]):
            pred_str_list.append(i2c(aligned_pred[i].tolist()))
            tgt_str_list.append(i2c(tgt[i].tolist()))

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches')

    return loss_total / n_batches, compute_bleu_list(pred_str_list, tgt_str_list)

# Validation function
def validation(dataloader, model, device, logout=islogout, split="test"):
    """Greedy autoregressive evaluation of ``model`` on ``dataloader``.

    dataloader: yields ``(src, tgt)`` batches of token-index tensors
    model: called as ``model(src, tgt_input, src_pad_mask, tgt_mask, tgt_pad_mask)``
    device: device the batches are moved to
    logout: when True, the decoded triples are written with ``evaluate_logger``
    split: label used in the printed summary and in the log file name

    The decoder input starts as the first target token followed by padding and is
    filled in one position at a time with the model's own predictions. The loss
    is computed on the logits of the final decoding step.

    return: (mean batch loss, mean BLEU between decoded predictions and targets)
    raises: ValueError if the dataloader yields no batches or a target has fewer
            than two positions
    """
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)
        if tgt.shape[1] < 2:
            # Without a decoding step there are no logits to compute the loss from.
            raise ValueError('target sequences must have at least two positions')

        # These masks do not change while decoding, so build them once per batch.
        src_pad_mask = gen_padding_mask(src, PAD_token)
        tgt_mask = gen_mask(tgt).to(device)

        # All-padding tensor shaped like tgt whose first token is copied from tgt.
        tgt_input = torch.full_like(tgt, PAD_token)
        tgt_input[:, 0] = tgt[:, 0]
        for i in range(tgt.shape[1]-1):
            tgt_pad_mask = gen_padding_mask(tgt_input, PAD_token)
            pred = model(src, tgt_input, src_pad_mask, tgt_mask, tgt_pad_mask)
            # Only position i is needed at this step; feed its argmax back as token i + 1.
            tgt_input[:, i + 1] = pred[:, i].argmax(dim=-1)

        for i in range(tgt.shape[0]):
            pred_str_list.append(i2c(tgt_input[i].tolist()))
            tgt_str_list.append(i2c(tgt[i].tolist()))
            input_str_list.append(i2c(src[i].tolist()))

        loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
        losses.append(loss.item())

    if not losses:
        raise ValueError('dataloader yielded no batches')

    # Compute each summary value once and reuse it for printing and the return value.
    mean_loss = sum(losses)/len(losses)
    accuracy = metrics(pred_str_list, tgt_str_list)
    bleu = compute_bleu_list(pred_str_list, tgt_str_list)

    print(f"{split}_acc: {accuracy:.2f}", f"{split}_loss: {mean_loss:.2f}", end=' | ')
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]", f"{split}_bleu4: {bleu:.4f}")

    if logout:
        evaluate_logger(input_str_list, tgt_str_list, pred_str_list, split)

    return mean_loss, bleu


In [10]:
# Per-epoch history used for the training curves.
train_losses = []
valid_losses = []
train_bleus = []
valid_bleus = []

if train_mode:
    # torch.save does not create missing directories; make sure the checkpoint folder
    # exists up front rather than failing after the first full epoch.
    model_root.mkdir(parents=True, exist_ok=True)

    for eps in range(N_EPOCHS):
        # train
        model.train()
        
        train_loss, train_bleu = train(trainloader, model, device, eps)
        
        # eval (no log file during training; logout=False)
        model.eval()
        with torch.no_grad():
            valid_loss, valid_bleu = validation(valloader, model, device, False, "test")

        # scheduler.step()

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), model_root / f'{model_name}.pt')

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        train_bleus.append(train_bleu)
        valid_bleus.append(valid_bleu)

        # Refresh the saved curves after every epoch.
        train_visualize(train_losses, valid_losses, "Loss")
        train_visualize(train_bleus, valid_bleus, "Bleu")

epoch0:   0%|          | 0/114452 [00:01<?, ?iter/s]


RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

In [None]:
# Load checkpoint
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load(model_root / f'{model_name}.pt', map_location=device))
# test (logout=True also writes the per-sample log file)
model.eval()
with torch.no_grad():
    validation(valloader, model, device, True, "test")

new_test_acc: 50.00 new_test_loss: 2.60 | [pred: appreciate target: appreciate] new_test_bleu4: 0.6450
test_acc: 94.00 test_loss: 0.08 | [pred: contented target: contented] test_bleu4: 0.9628


In [5]:
import json
from pathlib import Path

# Read the raw test split and keep its first 100 records for manual inspection.
data_root = Path('my_data')
with (data_root / 'test.json').open() as f:
    data = json.load(f)[:100]

In [11]:
# Preview a couple of records only; displaying all 100 floods the notebook output.
data[:2]

[{'input': ['f6f06a71bb8bc38af6c0b7dae9cab00d',
   '7b48a87effd31c9c07b68ed212062854',
   '61c46d6401aab1dde7c7de23dc55c037',
   '7e54c9199aad70e35fe256d23701bad0',
   '6178580fa01b62e9b52787902c0d8ae6',
   'ab694649c65477d0bc574bf391a3f4a0',
   '5b3387fa195672dcfe979d17e4a62c9e',
   '2790c612d8d301e2f35550c75aea8c75',
   'd36c6cf30154e18e6c972704206d6b1e',
   '1cbcc681ecf7acef4948bff2eb8e39d7',
   '95eb6b55a0b6d049aadd729aaabd63de',
   '7035839edc259dcad4b1632d10eded74',
   '420af27f7145b4eebeec566c0fa7a4c1',
   'ea3083c238e4fbb01a8035816ad7101f',
   'e060390d1cfa800bd6032cfa524c57f6',
   '43fdd8f154e5c522eef60f6edfb38896',
   '8d1ee4d9df7226fd8af3c4466a48afdc',
   '2ad3043e1a7e459ddb09c5ba27e475f8',
   '7bb8fadfc8f2bf145f4b29a0325fe79a',
   '824c159701c8553b0e38f0d36ddd6197'],
  'target': ['f6f06a71bb8bc38af6c0b7dae9cab00d',
   '7b48a87effd31c9c07b68ed212062854',
   '61c46d6401aab1dde7c7de23dc55c037',
   '7e54c9199aad70e35fe256d23701bad0',
   '6178580fa01b62e9b52787902c0d8ae6',
   'a