In [None]:
import gc
import os
import random
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [None]:
# Select the accelerator once; every later cell moves tensors/models to `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

In [None]:
import math

def accuracy(y_pred, y_true):
    """
    Fraction of entries where the prediction equals the ground truth.

    Args:
        y_pred (np.array): model prediction of shape (batch_size,1).
        y_true (np.array): ground truth of shape (batch_size,1).

    Returns:
        float: number of exact matches divided by len(y_true).
    """
    n_total = len(y_true)
    n_correct = (y_pred == y_true).sum()
    return n_correct / n_total

def get_cosine_decay_with_warmup(total_steps=1000, warmup_steps=100, max_lr=1e-3, min_lr=1e-7):
    """Build a step -> learning-rate function: linear warmup followed by cosine decay.

    Args:
        total_steps: step at which the cosine reaches ``min_lr``.
        warmup_steps: number of steps of linear warmup from 0 to ``max_lr``.
        max_lr: peak learning rate, reached at ``step == warmup_steps``.
        min_lr: floor learning rate at (and after) ``total_steps``.

    Returns:
        Callable ``get_lr(step) -> float``. Steps beyond ``total_steps`` stay at ``min_lr``
        instead of following the cosine back up. ``total_steps <= warmup_steps`` no longer
        divides by zero.
    """
    # Guard against a zero/negative decay span (total_steps <= warmup_steps).
    decay_steps = max(total_steps - warmup_steps, 1)

    def get_lr(step):
        if step < warmup_steps:
            # Linear warmup (step 0 -> lr 0)
            return max_lr * step / warmup_steps
        # Cosine decay; progress is clamped to 1 so the schedule stays at min_lr after total_steps.
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return min_lr + (max_lr - min_lr) * cosine_decay

    return get_lr

class LRScheduler:
    """Minimal step-indexed scheduler.

    Each ``step()`` evaluates ``lr_fn`` at the current step counter, writes the result into every
    param group of ``optimizer``, advances the counter, and returns the learning rate it set.
    """
    def __init__(self, optimizer, lr_fn):
        self.optimizer, self.lr_fn = optimizer, lr_fn
        self.current_step = 0

    def step(self):
        new_lr = self.lr_fn(self.current_step)
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
        self.current_step += 1
        return new_lr

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started afterwards; the running one fixed its hash seed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode picks kernels by timing, which can vary run to run.
    torch.backends.cudnn.benchmark = False

class AverageMeter:
    """Tracks the most recent value and the count-weighted running mean of all values."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic (last value, mean, weighted sum, count)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as if it had been observed ``n`` times.

        Args:
            val (float): The new value to update.
            n (int): The number of occurrences of this value (default is 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
from transformers import GPT2Tokenizer

# Load GPT-2 tokenizer (its BPE vocabulary is shared by all models in this notebook)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Sample text
text = "The quick brown fox jumps over the lazy dog."

# Tokenize the text
tokens = tokenizer.encode(text)

# Decode tokens back to text (round-trip check)
decoded_text = tokenizer.decode(tokens)

print("Original Text:", text)
print("Token IDs:", tokens)
print("Decoded Text:", decoded_text)

In [None]:
# GPT-2 ships without a padding token: add '<PAD>' as a new special token (appended to the vocabulary).
special_tokens_dict = {'additional_special_tokens': ['<PAD>']}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# Also register it as the tokenizer's pad token so tokenizer.pad_token_id resolves to the <PAD> id
# instead of None (it is passed to generate() further below).
tokenizer.pad_token = '<PAD>'
vocab = tokenizer.get_vocab()
# Models using this tokenizer need len(tokenizer) embedding rows, e.g.
# model.resize_token_embeddings(len(tokenizer)) for the GPT-2 models later in the notebook.

In [None]:
# Id assigned to the newly added <PAD> token
vocab['<PAD>']

In [None]:
# Load the sampled WMT English→French pairs and tokenize both sides with the GPT-2 tokenizer.
df = pd.read_csv('/kaggle/input/wmt-sampled-50000-english-to-french-dataset/wmt_sample_50000.csv')
df['en_encoded'] = tokenizer(df.en.tolist())['input_ids']
df['fr_encoded'] = tokenizer(df.fr.tolist())['input_ids']
# Token counts per sentence (useful when choosing collate max_len).
df['en_len'] = df['en_encoded'].apply(len)
df['fr_len'] = df['fr_encoded'].apply(len)
df

In [None]:
from sklearn.model_selection import train_test_split
# 80/20 split, then recombine into one frame with a 'split' column used to build the datasets.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# .assign returns new frames instead of writing a column into the split outputs in place,
# which pandas may flag with SettingWithCopyWarning.
df = pd.concat([train_df.assign(split='train'), test_df.assign(split='test')]).reset_index(drop=True)

In [None]:
# Dataset class definition
class WMTDataset(Dataset):
    """Map-style dataset over pre-tokenized English→French pairs.

    Reads the ``en_encoded`` / ``fr_encoded`` columns of ``df`` (lists of token ids) and returns
    them as a ``(source_ids, target_ids)`` pair of 1-D tensors.
    """
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer  # stored for reference; __getitem__ does not use it

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return torch.tensor(row['en_encoded']), torch.tensor(row['fr_encoded'])

In [None]:
def collate_fn(batch, max_len=128, teacher_forcing=True, padding_side='right', pad_idx=vocab['<PAD>'], eos_idx=vocab['<|endoftext|>']):
    """Pad a list of ``(source_ids, target_ids)`` pairs into batch tensors.

    Both sides are truncated to ``max_len`` tokens; an EOS is appended to every target.
    Sequences are padded with ``pad_idx`` on ``padding_side`` ('right', anything else = left).

    Returns:
        teacher_forcing=True:  ``((x, y_inp), y_tar)`` where ``y_inp`` is the target rotated right by
            one position (the appended EOS moves to the front and acts as the start token).
        teacher_forcing=False: ``(x, y_tar)``.
    """
    def pad_to(seq, length):
        gap = length - len(seq)
        widths = (0, gap) if padding_side == 'right' else (gap, 0)
        return F.pad(seq, widths, value=pad_idx)

    sources, targets = zip(*batch)
    sources = [seq[:max_len] for seq in sources]
    targets = [seq[:max_len] for seq in targets]
    src_len = max(len(seq) for seq in sources)
    tgt_len = max(len(seq) for seq in targets) + 1  # +1 for the EOS appended below

    eos = torch.tensor([eos_idx])
    targets = [torch.cat([seq, eos]) for seq in targets]

    x = torch.stack([pad_to(seq, src_len) for seq in sources])
    y_tar = torch.stack([pad_to(seq, tgt_len) for seq in targets])
    if not teacher_forcing:
        return x, y_tar
    y_inp = torch.stack([pad_to(torch.roll(seq, 1), tgt_len) for seq in targets])
    return (x, y_inp), y_tar

In [None]:
# Build the train/test datasets from the 'split' column and print the first (en_ids, fr_ids) pair.
train_dataset = WMTDataset(df[df.split=='train'], tokenizer)
test_dataset = WMTDataset(df[df.split=='test'], tokenizer)
for x in train_dataset:
    print(x)
    break

In [None]:
# Training batches are shuffled and drop the last partial batch; evaluation keeps order and every example.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True, num_workers=4, pin_memory=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=False, num_workers=4, pin_memory=True, collate_fn=collate_fn)

## Seq2Seq + Attention Mechanism

In [None]:
class MaskedRNN(nn.Module):
    """Wrap a ``batch_first`` RNN so right-padded batches go through ``pack_padded_sequence``.

    The RNN skips padded time steps. The returned hidden state is taken at each sequence's
    true last step. Padded output positions are zero-filled.

    Args:
        rnn_module: an ``nn.RNN`` / ``nn.GRU`` / ``nn.LSTM`` built with ``batch_first=True``.
    """
    def __init__(self, rnn_module, **kwargs):
        super().__init__(**kwargs)
        self.rnn_module = rnn_module
        assert rnn_module.batch_first

    def forward(self, inputs, h=None, mask=None):
        """
        Args:
            inputs: (batch, seq_len, features).
            h: optional initial hidden state, passed straight to the RNN.
            mask: optional (batch, seq_len) bool tensor, True for real tokens. Padding is assumed
                to be on the right, and every row needs at least one real token.

        Returns:
            (outputs, hidden). ``outputs`` always has the same ``seq_len`` as ``inputs``.
        """
        if mask is None:
            return self.rnn_module(inputs, h)
        orig_len = inputs.size(1)
        lens = mask.to(torch.int32).sum(1).to('cpu')  # packing requires CPU lengths
        packed = nn.utils.rnn.pack_padded_sequence(inputs, lengths=lens, batch_first=True, enforce_sorted=False)
        packed_out, hidden = self.rnn_module(packed, h)
        # total_length restores the padded length; without it the output is cut to max(lens),
        # which would no longer line up with `mask` or with the other side's tensors.
        x, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=orig_len)
        return x, hidden

class Seq2SeqAttentionModel(nn.Module):
    """GRU encoder-decoder with dot-product attention over the encoder outputs.

    The encoder is a 2-layer bidirectional GRU, projected back to ``dim``. The decoder is a
    2-layer unidirectional GRU run with teacher forcing. Each decoder step attends over the
    source positions. The context vector is concatenated with the decoder state and mapped to
    vocabulary logits.

    forward input: ``(x, y)`` — source ids (B, L_x) and shifted target ids (B, L_y), both padded
    with ``pad_idx``. Returns logits of shape (B, L_y, vocab_size).
    """
    def __init__(self, dim=256, vocab_size=len(vocab), pad_idx=vocab['<PAD>']):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim)
        self.encoder_rnn = MaskedRNN(nn.GRU(dim, dim, num_layers=2, batch_first=True, bidirectional=True))
        self.decoder_rnn = MaskedRNN(nn.GRU(dim, dim, num_layers=2, batch_first=True, bidirectional=False))
        self.source_proj = nn.Linear(2*dim, dim)
        self.target_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(2*dim, dim)
        self.head = nn.Linear(dim, vocab_size)
        self.pad_idx = pad_idx

    def forward(self, inp):
        x, y = inp
        mask_x = x != self.pad_idx
        mask_y = y != self.pad_idx
        x = self.embedding(x)
        y = self.embedding(y)
        x, _ = self.encoder_rnn(x, mask=mask_x)
        y, _ = self.decoder_rnn(y, mask=mask_y)
        x = self.source_proj(x) #B,L_x,D
        y = self.target_proj(y) #B,L_y,D

        #dot-product attention
        att = y @ x.transpose(-2,-1) #B,L_y,L_x
        # Padded source positions must get no attention weight. Slice the mask to the encoder
        # output length in case the packed RNN returned fewer steps than the padded input.
        key_pad = ~mask_x[:, None, :x.size(1)]  #B,1,L_x
        att = att.masked_fill(key_pad, torch.finfo(att.dtype).min)
        score = F.softmax(att, dim=-1)
        c = score @ x #B,L_y,D

        y = torch.cat([c,y], dim=-1)
        y = self.out_proj(y)
        y = self.head(y)
        return y

In [None]:
# Peek at one training batch: x = ((src, tgt_in), tgt_out); print the target shape (batch, max_tgt_len + 1).
for x in train_loader:
    print(x[1].shape)
    # print(x[0][0])
    break

In [None]:
# Teacher-forcing decoder input of the first example (starts with the EOS id used as start token).
x[0][1][0]

In [None]:
# Decoder target of the first example (ends with EOS, followed by <PAD> padding).
x[1][0]

In [None]:
# Smoke test: run a freshly initialised model on one batch; the output is (batch, tgt_len, vocab_size) logits.
for x in train_loader:
    break
model = Seq2SeqAttentionModel()
model(x[0]).shape

In [None]:
!pip -q install torchinfo
from torchinfo import summary
summary(model)

In [None]:
class MaskedCCELoss(nn.Module):
    """Cross-entropy averaged only over the positions where ``mask`` is True.

    forward(y_pred, y_true, mask): logits (..., C), integer targets (...), bool mask (...).
    """
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true, mask):
        selected_logits = y_pred[mask]
        selected_targets = y_true[mask]
        return F.cross_entropy(selected_logits, selected_targets).mean()

def masked_accuracy(y_pred, y_true, mask):
    """Token accuracy restricted to positions where ``mask`` is True.

    Args:
        y_pred: predicted ids (e.g. ``logits.argmax(-1)``).
        y_true: target ids, same shape as ``y_pred``.
        mask: bool tensor of the same shape selecting the positions to score.

    Returns:
        0-dim tensor: matches / number of selected positions.
    """
    hits = (y_pred[mask] == y_true[mask]).sum()
    n_selected = mask.to(torch.bool).sum()
    return hits / n_selected

In [None]:
# Train the Seq2Seq + attention model: AdamW, warmup + cosine LR, gradient clipping, masked
# token-level loss/accuracy; evaluate on test_loader after every epoch.
seed_everything(seed=42)
torch.cuda.empty_cache()
gc.collect()

model = Seq2SeqAttentionModel()
model = model.to(device)

epochs = 5
clip_grad = 1.0
lr = 1e-3

loss_fn = MaskedCCELoss()#alternative: nn.CrossEntropyLoss(ignore_index=vocab['<PAD>']) — the pad id here is not 0
# Weight decay only on parameters with ndim >= 2 (weight matrices, embeddings); none on 1-D params such as biases.
optimizer = torch.optim.AdamW([
    {'params': [param for param in model.parameters() if param.ndim>=2], 'weight_decay': 0.01},
    {'params': [param for param in model.parameters() if param.ndim<2], 'weight_decay': 0.0}
], lr=lr)
accum_loss = AverageMeter()
accum_acc = AverageMeter()
total_steps = len(train_loader) * epochs
# 10% linear warmup, then cosine decay down to min_lr.
lr_fn = get_cosine_decay_with_warmup(total_steps=total_steps, warmup_steps=total_steps//10, max_lr=lr, min_lr=1e-7)
scheduler = LRScheduler(optimizer, lr_fn)
best_val_loss = float('inf') #initialize the best validation loss as infinity (NOTE(review): never updated below)

for epoch in range(1, epochs+1):
    model.train()  # training mode
    accum_loss.reset()
    accum_acc.reset()
    pbar = tqdm(train_loader, desc=f'TRAIN epoch {epoch}', total=len(train_loader))
    for x, y in pbar:
        # x = (source ids, shifted target ids for teacher forcing); y = target ids
        x, y = (x[0].to(device), x[1].to(device)), y.to(device)
        mask = y != vocab['<PAD>']  # score only real target tokens
        y_pred = model(x)
        loss = loss_fn(y_pred, y, mask)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        # Set this step's lr before optimizer.step(); warmup starts at lr_fn(0) == 0.
        lr = scheduler.step()
        optimizer.step()
        optimizer.zero_grad()

        # Weight the running averages by the number of real target tokens in the batch.
        N = mask.int().sum()
        accum_loss.update(loss.detach(), N)
        accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)

        pbar.set_postfix({'loss': f'{accum_loss.avg:.4f}',
                          'acc': f'{accum_acc.avg:.4f}',
                          'lr': f'{lr:.6f}',
                          'grad_norm': f'{norm:.4f}'})

    model.eval()  # evaluation mode
    accum_loss.reset()
    accum_acc.reset()
    for x, y in test_loader:
        with torch.no_grad():
            x, y = (x[0].to(device), x[1].to(device)), y.to(device)
            mask = y != vocab['<PAD>']
            y_pred = model(x)

            loss = loss_fn(y_pred, y, mask)
            N = mask.int().sum()
            accum_loss.update(loss.detach(), N)
            accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)
    print(f'Epoch{epoch}: val_loss {accum_loss.avg:.4f} val_acc {accum_acc.avg:.4f}')

In [None]:
# Reference numbers noted by hand for the seq2seq model: val acc = 0.3929, BLEU 0.004 (not computed by this cell)

In [None]:
def greedy_decode_att(model, x, max_len=128, pad_idx=vocab['<PAD>'], eos_idx=vocab['<|endoftext|>']):
    """Greedy autoregressive decoding for ``Seq2SeqAttentionModel``.

    Args:
        model: a Seq2SeqAttentionModel (its embedding, encoder/decoder RNNs and projections are used directly).
        x: source token ids, (batch, src_len), padded with ``pad_idx``.
        max_len: maximum number of decoding steps.
        pad_idx: padding id in ``x``; padded positions receive no attention weight.
        eos_idx: end-of-sequence id, also used as the start token (as in collate_fn's roll()).

    Returns:
        LongTensor (batch, n_steps), n_steps <= max_len, without the start token. After a
        sequence emits EOS, its remaining positions are EOS.
    """
    model.eval()
    # Inference only: don't record the decoding loop for autograd.
    with torch.no_grad():
        batch_size = x.size(0)
        device = x.device

        #get the hidden states of the input sequences
        mask_x = x != pad_idx
        x = model.embedding(x)
        x, _ = model.encoder_rnn(x, None, mask=mask_x)
        x = model.source_proj(x)  #B,L_x,D
        # Key-padding mask for attention, sliced to the encoder output length.
        key_pad = ~mask_x[:, None, :x.size(1)]  #B,1,L_x

        # Initialize y with <eos> tokens (start of decoding)
        pred = torch.full((batch_size, 1), eos_idx, dtype=torch.long, device=device)
        curr_token = pred

        h = None
        # Iteratively decode until max_len or eos token is generated
        for i in range(max_len):
            # One decoder GRU step; the hidden state h carries the prefix.
            y = model.embedding(curr_token)
            _, h = model.decoder_rnn(y, h)
            y = model.target_proj(h[-1]).unsqueeze(1)  #B,1,D (top layer state)

            #dot-product attention over non-padded source positions
            att = y @ x.transpose(-2,-1)  #B,1,L_x
            att = att.masked_fill(key_pad, torch.finfo(att.dtype).min)
            c = F.softmax(att, dim=-1) @ x  #B,1,D

            out = model.out_proj(torch.cat([c,y],dim=-1))
            logits = model.head(out)[:,0]

            # Select the most probable token (greedy approach)
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            # Keep emitting EOS once a sequence has produced it (step 0's curr_token is the start token).
            if i > 0:
                next_token = torch.where((curr_token == eos_idx), eos_idx, next_token)

            # Append the predicted token to the sequence
            pred = torch.cat([pred, next_token], dim=1)

            # Stop if all sequences in the batch have generated an <eos> token
            if (next_token == eos_idx).all():
                break

            curr_token = next_token

    # Remove the initial <eos> token from the output
    return pred[:, 1:]

In [None]:
# Translate one sentence with the trained attention model (on the notebook's selected device).
pred = greedy_decode_att(model, torch.tensor([tokenizer.encode('I love programming')]).to(device))
pred

In [None]:
tokenizer.decode(pred[0])
# Expected translation: J'adore la programmation

In [None]:
import nltk.translate.bleu_score as bleu

In [None]:
# BLEU: n-gram overlap between the predicted sentence and the reference (0 = none, 1 = identical).
# sentence_bleu(references, hypothesis): the list of references comes first, the prediction second.
# Smoothing avoids a hard 0 when a short sentence has no matching higher-order n-grams.
bleu.sentence_bleu(["J'adore la programmation".split()],
                   tokenizer.decode(pred[0], skip_special_tokens=True).split(),
                   smoothing_function=bleu.SmoothingFunction().method1)

In [None]:
# Larger batches for decoding-based evaluation (no backward pass), fixed order, no dropped examples.
eval_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False, num_workers=4, pin_memory=True, collate_fn=collate_fn)

In [None]:
# Greedy-decode the whole test set and report corpus BLEU against the reference translations.
labels = []
preds = []
max_len = 128
eos_idx = tokenizer.get_vocab()['<|endoftext|>']
pad_idx = tokenizer.get_vocab()['<PAD>']
for x, y in tqdm(eval_loader):
    with torch.no_grad():
        x, y = (x[0].to(device), x[1].to(device)), y.to(device)
        # only the source ids x[0] are needed; decoding starts from EOS, not from the teacher-forcing input
        pred = greedy_decode_att(model, x[0], max_len=max_len, pad_idx=pad_idx, eos_idx=eos_idx)
        labels.extend(y)
        preds.extend(pred)

# skip_special_tokens drops EOS/<PAD>, so trailing EOS fill and padding do not reach BLEU.
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# corpus_bleu expects hypotheses as token lists and references as lists of token lists.
preds = [x.split() for x in preds]
labels = [[x.split()] for x in labels]
print('EXAMPLE PRED:', ' '.join(preds[0]))
print()
print('EXAMPLE LABEL:', ' '.join(labels[0][0]))
print()
print(bleu.corpus_bleu(labels, preds)) #0~1, bigger is better


## Transformer

In [None]:
class MaskedSoftmax(nn.Module):
    """Softmax along ``dim`` that gives (numerically) zero weight where ``mask`` is False.

    Masked entries are filled with the dtype's most negative finite value rather than -inf,
    so a fully masked row yields a uniform distribution instead of NaN.
    """
    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim

    def forward(self, inputs, mask=None):
        if mask is None:
            return F.softmax(inputs, dim=self.dim)
        lowest = torch.finfo(inputs.dtype).min
        return F.softmax(inputs.masked_fill(~mask, lowest), dim=self.dim)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``dim`` is split into ``num_heads`` heads of width ``dim // num_heads`` (so ``dim`` should be
    divisible by ``num_heads``). Queries come from ``q``; keys and values from ``k`` and ``v``.

    forward args:
        q: (batch, q_len, dim); k, v: (batch, k_len, dim).
        mask: optional (batch, k_len) bool key-padding mask, True = key may be attended to.
        causal_mask: if True, query i may only attend to keys j <= i (for self-attention, q_len == k_len).

    Returns:
        (batch, q_len, dim) tensor after the output projection.
    """
    def __init__(self, dim=256, num_heads=4, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = self.dim // self.num_heads
        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(dim, dim, bias=True)
        self.v_proj = nn.Linear(dim, dim, bias=True)
        self.o_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, q, k, v, mask=None, causal_mask=False):
        batch, q_len, k_len = q.shape[0], q.shape[1], k.shape[1]

        def to_heads(t, length):
            # (B, L, D) -> (B, H, L, D/H)
            return t.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

        queries = to_heads(self.q_proj(q), q_len)
        keys = to_heads(self.k_proj(k), k_len)
        values = to_heads(self.v_proj(v), k_len)

        # key-padding mask broadcast over heads and query rows: (B, 1, 1, L_k)
        allowed = None if mask is None else mask[:, None, None, :]
        if causal_mask:
            strictly_upper = torch.triu(torch.ones(q_len, k_len, device=q.device), diagonal=1)
            lower_incl_diag = ~strictly_upper[None, None, :, :].bool()
            allowed = lower_incl_diag if allowed is None else allowed & lower_incl_diag

        logits = (queries @ keys.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = MaskedSoftmax(dim=-1)(logits, mask=allowed)
        merged = (weights @ values).transpose(1, 2).reshape(batch, q_len, self.dim)
        return self.o_proj(merged)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> 4*dim) -> ReLU -> Dropout -> Linear(4*dim -> dim).

    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, dim=256, dropout=0.1, **kwargs):
        super().__init__()
        hidden = 4 * dim
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()  # alternatives: nn.GELU, nn.SiLU

    def forward(self, x):
        return self.fc2(self.dropout(self.activation(self.fc1(x))))

In [None]:
#causal_mask example: True on and below the diagonal (position i may attend to positions <= i)
~torch.triu(torch.ones(5, 5), diagonal=1).bool()

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention, then a feed-forward sublayer.

    Each sublayer output goes through dropout and a residual connection.

    Args:
        dim: model width.
        num_heads: number of attention heads.
        dropout: dropout rate for the sublayer outputs (and inside the feed-forward).
        pre_norm: False (default) = post-norm ``norm(x + sublayer(x))``, as before.
            True = pre-norm ``x + sublayer(norm(x))``. This replaces the previously commented-out
            variant. Stacks of pre-norm blocks usually add a final LayerNorm after the last block;
            TransformerEncoder does not add one.
    """
    def __init__(self, dim=256, num_heads=4, dropout=0.1, pre_norm=False, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.pre_norm = pre_norm

        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim=dim,num_heads=num_heads)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)

    def _residual(self, h, norm, sublayer, dropout):
        """Apply ``sublayer`` with dropout + residual, normalising before or after per ``pre_norm``."""
        if self.pre_norm:
            return h + dropout(sublayer(norm(h)))
        return norm(h + dropout(sublayer(h)))

    def forward(self, x, mask=None):
        """x: (batch, seq_len, dim); mask: optional (batch, seq_len) bool key-padding mask."""
        x = self._residual(x, self.norm1, lambda t: self.attn(q=t, k=t, v=t, mask=mask), self.dropout1)
        x = self._residual(x, self.norm2, self.ffn, self.dropout2)
        return x

class TransformerDecoderBlock(nn.Module):
    """One Transformer decoder layer with three sublayers.

    The sublayers are causal self-attention, cross-attention over the encoder output, and a
    feed-forward MLP. Each output goes through dropout and a residual connection.

    Args:
        dim: model width.
        num_heads: number of attention heads.
        dropout: dropout rate for the sublayer outputs (and inside the feed-forward).
        pre_norm: False (default) = post-norm ``norm(y + sublayer(y))``, as before.
            True = pre-norm ``y + sublayer(norm(y))``. The three sublayers use norm1, norm2 and
            norm3 respectively; the old commented-out variant reused norm1/norm2 by mistake.
            The encoder memory ``x`` is not normalised here.
    """
    def __init__(self, dim=256, num_heads=4, dropout=0.1, pre_norm=False, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_heads = num_heads
        self.pre_norm = pre_norm

        self.norm1 = nn.LayerNorm(dim)
        self.self_attn = MultiHeadAttention(dim=dim,num_heads=num_heads)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(dim)
        self.cross_attn = MultiHeadAttention(dim=dim,num_heads=num_heads)
        self.dropout2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, dropout=dropout)
        self.dropout3 = nn.Dropout(dropout)

    def _residual(self, h, norm, sublayer, dropout):
        """Apply ``sublayer`` with dropout + residual, normalising before or after per ``pre_norm``."""
        if self.pre_norm:
            return h + dropout(sublayer(norm(h)))
        return norm(h + dropout(sublayer(h)))

    def forward(self, x, y, mask_x=None, mask_y=None):
        """x: encoder output (B, L_x, dim); y: decoder input (B, L_y, dim); masks: bool key-padding masks."""
        y = self._residual(y, self.norm1,
                           lambda t: self.self_attn(q=t, k=t, v=t, mask=mask_y, causal_mask=True),
                           self.dropout1)
        y = self._residual(y, self.norm2,
                           lambda t: self.cross_attn(q=t, k=x, v=x, mask=mask_x),
                           self.dropout2)
        y = self._residual(y, self.norm3, self.ffn, self.dropout3)
        return y

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal position table, precomputed for positions ``0 .. max_len-1``.

    Note: a later definition with the same class name (computed on the fly, no length limit)
    shadows this one when the notebook is run top to bottom.
    """
    def __init__(self, dim, max_len=256):
        """
        dim: embedding width (model hidden size); odd widths are supported.
        max_len: maximum sequence length the table covers.
        """
        super().__init__()

        # Matrix holding positional encodings, shape (max_len, dim)
        pos_encoding = torch.zeros((max_len, dim))

        position = torch.arange(max_len)[:, None]  # (max_len, 1)
        # One frequency per (sin, cos) pair: ceil(dim / 2) frequencies
        div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
        angles = position * div_term  # (max_len, ceil(dim/2))

        # Sine on even embedding indices
        pos_encoding[:, 0::2] = torch.sin(angles)
        # Cosine on odd indices (there are only dim // 2 of them when dim is odd)
        pos_encoding[:, 1::2] = torch.cos(angles)[:, : dim // 2]

        # add a batch dimension -> (1, max_len, dim)
        pos_encoding = pos_encoding.unsqueeze(0)

        # Buffer: not a trainable parameter, but moves with the module (.to(device)) and is saved in state_dict
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, seq_len, dim); only seq_len is read.
        Returns the encodings for the first seq_len positions, shape (1, seq_len, dim).
        The caller adds them to the embeddings.
        """
        seq_len = x.size(1)
        return self.pos_encoding[:, :seq_len, :]

class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal position encodings computed on the fly for the input's sequence length.

    Nothing is precomputed, so any sequence length works (extrapolation). Even columns hold sin
    and odd columns hold cos, interleaved pairwise. Odd ``dim`` is supported.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """x: (batch, seq_len, ...); returns a (seq_len, dim) tensor on x's device to be added to x."""
        position = torch.arange(x.size(1))[:, None].to(x.device)  # (seq_len, 1)
        div_term = torch.exp(torch.arange(0, self.dim, 2, device=x.device).float() * -(math.log(10000.0) / self.dim))[None, :]  # (1, ceil(dim/2))

        angles = position * div_term
        pe = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)  # (seq_len, ceil(dim/2), 2)
        pe = pe.flatten(1)  # (seq_len, 2*ceil(dim/2)), interleaved sin/cos
        # For odd dim the last cos column is surplus; trim to exactly dim columns.
        return pe[:, :self.dim]

In [None]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder blocks applied in order, all sharing one padding mask."""
    def __init__(self, dim=256, num_heads=4, num_layers=2, dropout=0.1, **kwargs):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(TransformerEncoderBlock(dim=dim, num_heads=num_heads, dropout=dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        hidden = x
        for block in self.blocks:
            hidden = block(hidden, mask=mask)
        return hidden

class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` decoder blocks; every block cross-attends to the same encoder output ``x``."""
    def __init__(self, dim=256, num_heads=4, num_layers=2, dropout=0.1, **kwargs):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(TransformerDecoderBlock(dim=dim, num_heads=num_heads, dropout=dropout))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, y, mask_x=None, mask_y=None):
        hidden = y
        for block in self.blocks:
            hidden = block(x, hidden, mask_x=mask_x, mask_y=mask_y)
        return hidden

class Transformer(nn.Module):
    """Encoder-decoder Transformer for translation.

    The source and target share one embedding table and the sinusoidal positional encoding.
    forward input: ``(x, y)`` — source ids (B, L_x) and shifted target ids (B, L_y), padded with
    ``pad_idx``. Returns logits (B, L_y, vocab_size).
    """
    def __init__(self, dim=256, num_heads=4, num_encoder_layers=2, num_decoder_layers=2, dropout=0.1,
                 vocab_size=len(vocab), pad_idx=vocab['<PAD>'], **kwargs):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=dim)
        self.pos_encoding = SinusoidalPositionalEncoding(dim)
        self.emb_dropout = nn.Dropout(dropout)
        self.encoder = TransformerEncoder(dim=dim, num_heads=num_heads, num_layers=num_encoder_layers, dropout=dropout)
        self.decoder = TransformerDecoder(dim=dim, num_heads=num_heads, num_layers=num_decoder_layers, dropout=dropout)
        self.head = nn.Linear(dim, vocab_size)

    def _embed(self, tokens):
        # token embedding + positional encoding, followed by embedding dropout
        h = self.embedding(tokens)
        return self.emb_dropout(h + self.pos_encoding(h))

    def forward(self, inp):
        src, tgt = inp
        src_mask = src != self.pad_idx
        tgt_mask = tgt != self.pad_idx
        memory = self.encoder(self._embed(src), mask=src_mask)
        out = self.decoder(memory, self._embed(tgt), mask_x=src_mask, mask_y=tgt_mask)
        return self.head(out)

In [None]:
# Smoke test: run a freshly initialised Transformer on one batch; output is (batch, tgt_len, vocab_size) logits.
for x in train_loader:
    break
model = Transformer()
model(x[0]).shape

In [None]:
# Parameter summary of the Transformer (torchinfo)
summary(model)

In [None]:
# Train the Transformer with the same recipe as the seq2seq model: AdamW, warmup + cosine LR,
# gradient clipping, masked token-level loss/accuracy; evaluate on test_loader after every epoch.
seed_everything(seed=42)
torch.cuda.empty_cache()
gc.collect()

model = Transformer()
model = model.to(device)

epochs = 5
clip_grad = 1.0
lr = 1e-3

loss_fn = MaskedCCELoss()#alternative: nn.CrossEntropyLoss(ignore_index=vocab['<PAD>']) — the pad id here is not 0
# Weight decay only on parameters with ndim >= 2 (weight matrices, embeddings); none on 1-D params such as biases/LayerNorm.
optimizer = torch.optim.AdamW([
    {'params': [param for param in model.parameters() if param.ndim>=2], 'weight_decay': 0.01},
    {'params': [param for param in model.parameters() if param.ndim<2], 'weight_decay': 0.0}
], lr=lr)
accum_loss = AverageMeter()
accum_acc = AverageMeter()
total_steps = len(train_loader) * epochs
# 10% linear warmup, then cosine decay down to min_lr.
lr_fn = get_cosine_decay_with_warmup(total_steps=total_steps, warmup_steps=total_steps//10, max_lr=lr, min_lr=1e-7)
scheduler = LRScheduler(optimizer, lr_fn)
best_val_loss = float('inf') #initialize the best validation loss as infinity (NOTE(review): never updated below)

for epoch in range(1, epochs+1):
    model.train()  # training mode
    accum_loss.reset()
    accum_acc.reset()
    pbar = tqdm(train_loader, desc=f'TRAIN epoch {epoch}', total=len(train_loader))
    for x, y in pbar:
        # x = (source ids, shifted target ids for teacher forcing); y = target ids
        x, y = (x[0].to(device), x[1].to(device)), y.to(device)
        mask = y != vocab['<PAD>']  # score only real target tokens
        y_pred = model(x)
        loss = loss_fn(y_pred, y, mask)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        # Set this step's lr before optimizer.step(); warmup starts at lr_fn(0) == 0.
        lr = scheduler.step()
        optimizer.step()
        optimizer.zero_grad()

        # Weight the running averages by the number of real target tokens in the batch.
        N = mask.int().sum()
        accum_loss.update(loss.detach(), N)
        accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)

        pbar.set_postfix({'loss': f'{accum_loss.avg:.4f}',
                          'acc': f'{accum_acc.avg:.4f}',
                          'lr': f'{lr:.6f}',
                          'grad_norm': f'{norm:.4f}'})

    model.eval()  # evaluation mode
    accum_loss.reset()
    accum_acc.reset()
    for x, y in test_loader:
        with torch.no_grad():
            x, y = (x[0].to(device), x[1].to(device)), y.to(device)
            mask = y != vocab['<PAD>']
            y_pred = model(x)

            loss = loss_fn(y_pred, y, mask)
            N = mask.int().sum()
            accum_loss.update(loss.detach(), N)
            accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)
    print(f'Epoch{epoch}: val_loss {accum_loss.avg:.4f} val_acc {accum_acc.avg:.4f}')

In [None]:
def greedy_decode(model, x, max_len=128, pad_idx=vocab['<PAD>'], eos_idx=vocab['<|endoftext|>']):
    """Greedy autoregressive decoding for the ``Transformer`` model.

    Args:
        model: a Transformer (its embedding, pos_encoding, encoder, decoder and head are used directly).
        x: source token ids, (batch, src_len), padded with ``pad_idx``.
        max_len: maximum number of decoding steps.
        pad_idx: padding id in ``x`` (masked in encoder self-attention and cross-attention).
        eos_idx: end-of-sequence id, also used as the start token (as in collate_fn's roll()).

    Returns:
        LongTensor (batch, n_steps), n_steps <= max_len, without the start token. After a
        sequence emits EOS, its remaining positions are EOS.
    """
    model.eval()
    # Inference only. Without no_grad, each of the (up to max_len) full decoder passes would be
    # recorded for autograd when this is called outside a no_grad block.
    with torch.no_grad():
        batch_size = x.size(0)
        device = x.device

        # Encode the source once; the encoder output is reused at every decoding step.
        mask_x = x != pad_idx
        x = model.embedding(x)
        x = model.pos_encoding(x) + x
        x = model.encoder(x, mask=mask_x)

        # Initialize y with <eos> tokens (start of decoding)
        y_pred = torch.full((batch_size, 1), eos_idx, dtype=torch.long, device=device)
        curr_token = y_pred

        # Iteratively decode until max_len or eos token is generated
        for i in range(max_len):
            # Re-run the decoder over the whole prefix (no KV cache); keep the last position's logits.
            y = model.embedding(y_pred)
            y = model.pos_encoding(y) + y
            y = model.decoder(x, y, mask_x=mask_x)
            logits = model.head(y)[:,-1]

            # Select the most probable token (greedy approach)
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            # Keep emitting EOS once a sequence has produced it (step 0's curr_token is the start token).
            if i > 0:
                next_token = torch.where((curr_token == eos_idx), eos_idx, next_token)

            # Append the predicted token to the sequence
            y_pred = torch.cat([y_pred, next_token], dim=1)

            # Stop if all sequences in the batch have generated an <eos> token
            if (next_token == eos_idx).all():
                break

            curr_token = next_token

    # Remove the initial <eos> token from the output
    return y_pred[:, 1:]

In [None]:
# Translate one sentence with the trained Transformer (on the notebook's selected device).
pred = greedy_decode(model, torch.tensor([tokenizer.encode('I love programming')]).to(device))
pred

In [None]:
# Decoded prediction (expected translation: J'adore la programmation)
tokenizer.decode(pred[0])

In [None]:
# sentence_bleu(references, hypothesis): reference list first, prediction second; smoothed for a short sentence.
bleu.sentence_bleu(["J'adore la programmation".split()],
                   tokenizer.decode(pred[0], skip_special_tokens=True).split(),
                   smoothing_function=bleu.SmoothingFunction().method1)

In [None]:
# Larger batches for decoding-based evaluation (no backward pass), fixed order, no dropped examples.
eval_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, drop_last=False, num_workers=4, pin_memory=True, collate_fn=collate_fn)

In [None]:
# Greedy-decode the whole test set with the Transformer and report corpus BLEU.
labels = []
preds = []
max_len = 128
eos_idx = tokenizer.get_vocab()['<|endoftext|>']
pad_idx = tokenizer.get_vocab()['<PAD>']
for x, y in tqdm(eval_loader):
    with torch.no_grad():
        x, y = (x[0].to(device), x[1].to(device)), y.to(device)
        # only the source ids x[0] are needed; decoding starts from EOS, not from the teacher-forcing input
        pred = greedy_decode(model, x[0], max_len=max_len, pad_idx=pad_idx, eos_idx=eos_idx)
        labels.extend(y)
        preds.extend(pred)

# skip_special_tokens drops EOS/<PAD>, so trailing EOS fill and padding do not reach BLEU.
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# corpus_bleu expects hypotheses as token lists and references as lists of token lists.
preds = [x.split() for x in preds]
labels = [[x.split()] for x in labels]
print('EXAMPLE PRED:', ' '.join(preds[0]))
print()
print('EXAMPLE LABEL:', ' '.join(labels[0][0]))
print()
print(bleu.corpus_bleu(labels, preds)) #0~1, bigger is better


## GPT2

In [None]:
from transformers import AutoModelForCausalLM

# Load pretrained GPT-2 (the "gpt2" checkpoint) with its language-modeling head.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

In [None]:
# Display the module structure
gpt2

In [None]:
#GPT-2 has no padding token, so the embedding needs an extra weight vector for the added <PAD> token.
# Compare the token-embedding table with the tokenizer vocabulary size (which includes <PAD>).
print(gpt2.transformer.wte, len(vocab))

In [None]:
# Grow the token-embedding table to len(tokenizer) so the <PAD> id has a row.
gpt2.resize_token_embeddings(len(tokenizer))

In [None]:
# Smoke test on a 1-D tensor of token ids; inspect the logits shape.
gpt2(torch.tensor([1,2,3,4,5])).logits.shape

In [None]:
# Parameter summary of GPT-2 (torchinfo)
summary(gpt2)

In [None]:
#TODO: train Seq2SeqAttentionModel / Transformer using only GPT-2's pretrained token embedding
#hint: model.embedding.weight = gpt2.transformer.wte.weight
#      (the model's dim must then match the width of gpt2.transformer.wte)

## Zero-shot Translation with GPT2

In [None]:
# Zero-shot translation with a freshly loaded pretrained GPT-2 (embedding not resized here).
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
prompt = "Translate the following text from English into French. English: First Nations Governance was launched, to consult with First Nations peoples on the issues of governance under the Indian Act. French: "
input_ids = tokenizer.encode(prompt, return_tensors="pt")

gen_tokens = gpt2.generate(
    input_ids,
    do_sample=True, #sampling-based (output varies between runs unless seeded)
    max_length=100,  # total length, prompt tokens included
    use_cache=True, #KV-Caching
    pad_token_id=tokenizer.eos_token_id
)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)
#Why this works poorly: this is the smallest GPT-2 (gpt2-small), and over 90% of GPT-2's pretraining data is English.

## Fine-Tuning GPT-2

In [None]:
# Dataset class definition
class WMTSFTDataset(Dataset):
    """English→French instruction prompts for supervised fine-tuning of a causal LM.

    Each item is built from the ``en`` / ``fr`` columns of ``df`` using the template
    'Translate the following text from English into French.\\nEnglish: {en} ###\\nFrench:'.
    ' ###' serves as the delimiter token separating the source from the target sentence.

    train=True  -> 1-D LongTensor: prompt + ' ' + French target, tokenized together.
    train=False -> (prompt_ids, target_ids). The prompt ends at 'French:' with no trailing
                   space, matching how the prompt is tokenized during training.
    """
    def __init__(self, df, tokenizer, train=True):
        self.df = df
        self.tokenizer = tokenizer
        self.train = train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        en = data['en']
        fr = data['fr']
        prompt = f'Translate the following text from English into French.\nEnglish: {en} ###\nFrench:'
        if self.train:
            # Same string as before ('... French: ' + target): the space merges into the first French token.
            return self.tokenizer.encode(prompt + ' ' + fr, return_tensors='pt')[0]
        # GPT-2's BPE attaches a space to the following word. A prompt ending in 'French: ' would end
        # in a lone ' ' token that never occurs in training, so stop at ':' and let the model
        # generate ' <first word>'.
        return (self.tokenizer.encode(prompt, return_tensors='pt')[0],
                self.tokenizer.encode(fr, return_tensors='pt')[0])

In [None]:
# How the delimiter and the 'French:' cue split into GPT-2 tokens ('Ġ' marks a leading space).
tokenizer.tokenize(' ###\nFrench:')

In [None]:
# The same text as token ids.
tokenizer.encode(' ###\nFrench:')

In [None]:
def collate_fn_x(batch, max_len=256, pad_idx=vocab['<PAD>'], eos_idx=vocab['<|endoftext|>'], padding_side='right'):
    """Collate single token sequences for causal-LM fine-tuning.

    ``batch`` is a list of 1-D id tensors (prompt + answer). Each is truncated to ``max_len``,
    gets an EOS appended, and is padded with ``pad_idx`` on ``padding_side`` ('right', anything
    else = left).

    Returns:
        (inp, tar): ``tar`` is the padded sequence. ``inp`` is the same sequence rotated right by
        one (EOS first, as a start token) for teacher forcing.
    """
    def pad_to(seq, length):
        gap = length - len(seq)
        widths = (0, gap) if padding_side == 'right' else (gap, 0)
        return F.pad(seq, widths, value=pad_idx)

    clipped = [seq[:max_len] for seq in batch]
    target_len = max(len(seq) for seq in clipped) + 1  # +1 for the EOS appended below

    eos = torch.tensor([eos_idx])
    with_eos = [torch.cat([seq, eos]) for seq in clipped]
    tar = torch.stack([pad_to(seq, target_len) for seq in with_eos])
    inp = torch.stack([pad_to(torch.roll(seq, 1), target_len) for seq in with_eos])
    return inp, tar

In [None]:
# SFT datasets built from the same train/test split; print the first tokenized training prompt.
train_dataset = WMTSFTDataset(df[df.split=='train'], tokenizer)
test_dataset = WMTSFTDataset(df[df.split=='test'], tokenizer)
for x in train_dataset:
    print(x)
    break

In [None]:
#padding_side == 'left': every sequence's last real token sits at the final position
# functools.partial (unlike a lambda) is picklable, so it also works with spawn-started DataLoader workers.
sft_collate = partial(collate_fn_x, padding_side='left')
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=4, pin_memory=True, collate_fn=sft_collate)
# Evaluation: fixed order and keep the last partial batch so every test example is scored.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, drop_last=False, num_workers=4, pin_memory=True, collate_fn=sft_collate)
for x in train_loader:
    print(x[0][0], x[1][0])
    break

In [None]:
# Supervised fine-tuning of GPT-2 on the translation prompts. Only the answer part (from the
# ' ###' delimiter onward) contributes to the loss.
seed_everything(seed=42)
torch.cuda.empty_cache()
gc.collect()

pad_id = vocab['<PAD>']
delim_id = vocab['Ġ###']  # id of the ' ###' delimiter token ('Ġ' = leading space in GPT-2's BPE)

model = AutoModelForCausalLM.from_pretrained("gpt2")
# transformers reads the pad id from model.config / model.generation_config; a plain
# `model.pad_token_id` attribute is ignored. generation_config supplies generate()'s default.
model.config.pad_token_id = pad_id
if getattr(model, 'generation_config', None) is not None:
    model.generation_config.pad_token_id = pad_id
# Add an embedding row for the new <PAD> id.
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)

epochs = 1
clip_grad = 1.0
lr = 2e-4

loss_fn = MaskedCCELoss()
# Weight decay only on parameters with ndim >= 2; none on 1-D params such as biases/LayerNorm.
optimizer = torch.optim.AdamW([
    {'params': [param for param in model.parameters() if param.ndim>=2], 'weight_decay': 0.01},
    {'params': [param for param in model.parameters() if param.ndim<2], 'weight_decay': 0.0}
], lr=lr)
accum_loss = AverageMeter()
accum_acc = AverageMeter()
total_steps = len(train_loader) * epochs
# 10% linear warmup, then cosine decay down to min_lr.
lr_fn = get_cosine_decay_with_warmup(total_steps=total_steps, warmup_steps=total_steps//10, max_lr=lr, min_lr=1e-7)
scheduler = LRScheduler(optimizer, lr_fn)
best_val_loss = float('inf') #initialize the best validation loss as infinity (not updated below)


def sft_loss_mask(y):
    """Select the scored target positions: non-padding tokens from the ' ###' delimiter onward.

    The instruction and English source come before the delimiter; the model conditions on them
    but is not trained to reproduce them.
    """
    not_pad = y != pad_id
    # cumsum of delimiter hits becomes nonzero at the first ' ###' and stays nonzero afterwards
    after_delim = torch.cumsum(y == delim_id, dim=-1).bool()
    return not_pad & after_delim


for epoch in range(1, epochs+1):
    model.train()  # training mode
    accum_loss.reset()
    accum_acc.reset()
    pbar = tqdm(train_loader, desc=f'TRAIN epoch {epoch}', total=len(train_loader))
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        mask = sft_loss_mask(y)
        y_pred = model(x, attention_mask=(x!=pad_id).float()).logits
        loss = loss_fn(y_pred, y, mask)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        # Set this step's lr before optimizer.step(); warmup starts at lr_fn(0) == 0.
        lr = scheduler.step()
        optimizer.step()
        optimizer.zero_grad()

        N = mask.int().sum()
        accum_loss.update(loss.detach(), N)
        accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)

        pbar.set_postfix({'loss': f'{accum_loss.avg:.4f}',
                          'acc': f'{accum_acc.avg:.4f}',
                          'lr': f'{lr:.6f}',
                          'grad_norm': f'{norm:.4f}'})

    model.eval()  # evaluation mode
    accum_loss.reset()
    accum_acc.reset()
    for x, y in test_loader:
        with torch.no_grad():
            x, y = x.to(device), y.to(device)
            mask = sft_loss_mask(y)
            y_pred = model(x, attention_mask=(x!=pad_id).float()).logits

            loss = loss_fn(y_pred, y, mask)
            N = mask.int().sum()
            accum_loss.update(loss.detach(), N)
            accum_acc.update(masked_accuracy(y_pred.detach().argmax(-1), y, mask), N)
    print(f'Epoch{epoch}: val_loss {accum_loss.avg:.4f} val_acc {accum_acc.avg:.4f}')

In [None]:
# Quick check on one sentence. Use exactly the fine-tuning template (newlines, the ' ###' delimiter,
# no trailing space after 'French:') so the prompt matches what the model was trained on.
demo_ids = tokenizer.encode("Translate the following text from English into French.\nEnglish: I love programming. ###\nFrench:", return_tensors='pt').to(device)
tokenizer.decode(model.generate(inputs=demo_ids, attention_mask=torch.ones_like(demo_ids), max_new_tokens=128, pad_token_id=vocab['<PAD>'])[0])

In [None]:
test_dataset_eval = WMTSFTDataset(df[df.split=='test'], tokenizer, train=False)
# Prompts can exceed collate_fn's default 128-token cap. Right-truncation would cut off the trailing
# "###\nFrench:" cue, so allow the same 256 tokens used in training. Left padding puts each
# prompt's last real token at the final position, where generate() continues.
test_loader_eval = DataLoader(test_dataset_eval, batch_size=64, shuffle=False, drop_last=False, num_workers=4, pin_memory=True,
                              collate_fn=partial(collate_fn, teacher_forcing=False, padding_side='left', max_len=256))

In [None]:
# First evaluation item: (prompt ids, French target ids).
for x in test_dataset_eval:
    print(x)
    break

In [None]:
# Generate translations for the whole test set with the fine-tuned GPT-2 and report corpus BLEU.
labels = []
preds = []
max_len = 256
eos_idx = tokenizer.get_vocab()['<|endoftext|>']
pad_idx = tokenizer.get_vocab()['<PAD>']
for x, y in tqdm(test_loader_eval):
    with torch.no_grad():
        x = x.to(device)
        # Pass the <PAD> id explicitly rather than relying on tokenizer.pad_token_id being set.
        pred = model.generate(inputs=x, max_new_tokens=max_len, use_cache=True, pad_token_id=pad_idx,
                              attention_mask=(x != pad_idx).float())
        labels.extend(y)
        # For a decoder-only model generate() returns prompt + continuation; keep only the new tokens.
        preds.extend(pred[:, x.size(1):].cpu())

# skip_special_tokens drops EOS/<PAD> from both sides.
preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# corpus_bleu expects hypotheses as token lists and references as lists of token lists.
preds = [x.split() for x in preds]
labels = [[x.split()] for x in labels]
print('EXAMPLE PRED:', ' '.join(preds[0]))
print()
print('EXAMPLE LABEL:', ' '.join(labels[0][0]))
print()
print(bleu.corpus_bleu(labels, preds)) #0~1, bigger is better