In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_emb, n_heads = 8):
        super().__init__()
        assert d_emb%n_heads == 0
        self.d_emb = d_emb
        self.n_heads = n_heads
        self.WQ = nn.Linear(d_emb, d_emb, bias = False)
        self.WK = nn.Linear(d_emb, d_emb, bias = False)
        self.WV = nn.Linear(d_emb, d_emb, bias = False)
        self.WO = nn.Linear(d_emb, d_emb, bias = True)
    def forward(self, x1, x2 = None, x3 = None, mask = None):
        b,s,d = x1.shape
        if x2 is None:
            x2 = x1
        if x3 is None:
            x3 = x1
        s_enc = x2.shape[1]
        q = self.WQ(x1)
        k = self.WK(x2)
        v = self.WV(x3)
        Q = q.view(b, s, self.n_heads, self.d_emb//self.n_heads).permute(0,2,1,3)
        K = k.view(b, s_enc, self.n_heads, self.d_emb//self.n_heads).permute(0,2,1,3)
        V = v.view(b, s_enc, self.n_heads, self.d_emb//self.n_heads).permute(0,2,1,3)
        attn_scores = torch.matmul(Q, K.transpose(-2,-1))/((self.d_emb//self.n_heads)**0.5)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == False, -1e9)
        attn_weights = torch.softmax(attn_scores, dim = -1)
        output = torch.matmul(attn_weights, V).transpose(1,2).contiguous().view(b, s, d)
        output = self.WO(output)
        return output

In [3]:
# if decode is True:
#     dec_mask = torch.ones(s,s)
#     dec_mask = torch.triu(dec_mask, diagonal = 1).unsqueeze(0).unsqueeze(1)
#     dec_mask = ((dec_mask==0)*1).to(attn_scores.device)
#     attn_scores = attn_scores.masked_fill(dec_mask == 0, float("-inf"))

In [4]:
mha = MultiHeadAttention(512, 8)
x = torch.randn(2,30,512)
x = F.pad(x, (0,0,0,20))
mask = torch.ones(2,1,1,30)
mask = F.pad(mask, (0,20))
mha(x, mask = mask)

tensor([[[ 0.0086, -0.0743,  0.0080,  ..., -0.0151,  0.0093, -0.0310],
         [ 0.0384, -0.0666,  0.0039,  ...,  0.0290,  0.0526, -0.0432],
         [ 0.0269, -0.0577,  0.0205,  ..., -0.0297, -0.0066, -0.0101],
         ...,
         [ 0.0114, -0.0682,  0.0135,  ..., -0.0151,  0.0052, -0.0342],
         [ 0.0114, -0.0682,  0.0135,  ..., -0.0151,  0.0052, -0.0342],
         [ 0.0114, -0.0682,  0.0135,  ..., -0.0151,  0.0052, -0.0342]],

        [[-0.0423, -0.1515,  0.0534,  ...,  0.0307, -0.0512, -0.1718],
         [-0.0365, -0.1598,  0.0988,  ...,  0.0236, -0.0788, -0.1285],
         [-0.0176, -0.1429,  0.0959,  ...,  0.0194, -0.0593, -0.1364],
         ...,
         [-0.0145, -0.1519,  0.0583,  ...,  0.0187, -0.0773, -0.1305],
         [-0.0145, -0.1519,  0.0583,  ...,  0.0187, -0.0773, -0.1305],
         [-0.0145, -0.1519,  0.0583,  ...,  0.0187, -0.0773, -0.1305]]],
       grad_fn=<ViewBackward0>)

In [5]:
class FeedForwardLayer(nn.Module):
    def __init__(self, d_emb, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.upsample_factor = upsample_factor
        self.linear1 = nn.Linear(d_emb, upsample_factor*d_emb)
        self.linear2 = nn.Linear(upsample_factor*d_emb, d_emb)
        self.relu = nn.ReLU()
    def forward(self, X):
        out1 = self.relu(self.linear1(X))
        return self.linear2(out1)

In [6]:
ffn = FeedForwardLayer(512)
ffn(torch.randn(2,5,512)).shape

torch.Size([2, 5, 512])

In [7]:
class EncoderAttentionBlock(nn.Module):
    def __init__(self, d_emb, n_heads = 8, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.n_heads = n_heads
        self.upsample_factor = upsample_factor
        self.ln1 = nn.LayerNorm(d_emb)
        self.ln2 = nn.LayerNorm(d_emb)
        self.multihead = MultiHeadAttention(self.d_emb, self.n_heads)
        self.feedforward = FeedForwardLayer(self.d_emb, self.upsample_factor)
        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout2 = nn.Dropout(p=0.1)
    def forward(self, X, mask = None):
        out1 = self.dropout1(self.multihead(X, mask = mask))
        out1 = out1 + X
        ln_out1 = self.ln1(out1)
        out2 = self.dropout2(self.feedforward(ln_out1))
        out2 = out2 + ln_out1
        ln_out2 = self.ln2(out2)
        return ln_out2

In [8]:
enc_attn = EncoderAttentionBlock(512)
enc_attn(torch.randn(2,5,512)).shape

torch.Size([2, 5, 512])

In [9]:
class Encoder(nn.Module):
    def __init__(self, d_emb, n_layers = 6, n_heads = 8, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.upsample_factor = upsample_factor
        self.layers = nn.ModuleList([EncoderAttentionBlock(d_emb, n_heads=n_heads, upsample_factor=upsample_factor) 
        for _ in range(n_layers)])
    def forward(self, X, mask = None):
        for layer in self.layers:
            X = layer(X, mask = mask)
        return X

In [10]:
enc = Encoder(512)
enc(torch.randn(2,5,512)).shape

torch.Size([2, 5, 512])

In [11]:
class DecoderAttentionBlock(nn.Module):
    def __init__(self, d_emb, n_heads = 8, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.n_heads = n_heads
        self.upsample_factor = upsample_factor
        self.masked_self_attn = MultiHeadAttention(self.d_emb, self.n_heads)
        self.cross_attn = MultiHeadAttention(self.d_emb, self.n_heads)
        self.feedforward = FeedForwardLayer(self.d_emb, self.upsample_factor)
        self.ln1 = nn.LayerNorm(self.d_emb)
        self.ln2 = nn.LayerNorm(self.d_emb)
        self.ln3 = nn.LayerNorm(self.d_emb)
        self.dropout1 = nn.Dropout(p = 0.1)
        self.dropout2 = nn.Dropout(p = 0.1)
        self.dropout3 = nn.Dropout(p = 0.1)
    def forward(self, X, encoder_output, enc_mask = None, dec_mask = None):
        out1 = self.masked_self_attn(X, mask = dec_mask)
        out1 = self.dropout1(out1) + X
        ln_out1 = self.ln1(out1)
        out2 = self.cross_attn(ln_out1, x2 = encoder_output, x3 = encoder_output, mask = enc_mask)
        out2 = self.dropout2(out2) + ln_out1
        ln_out2 = self.ln2(out2)
        out3 = self.feedforward(ln_out2)
        out3 = self.dropout3(out3) + ln_out2
        ln_out3 = self.ln3(out3)
        return ln_out3

In [12]:
class Decoder(nn.Module):
    def __init__(self, d_emb, n_layers = 6, n_heads = 8, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.upsample_factor = upsample_factor
        self.layers = nn.ModuleList([DecoderAttentionBlock(d_emb, n_heads = n_heads, upsample_factor = upsample_factor)
                                    for _ in range(n_layers)])
    def forward(self, X, encoder_output, enc_mask = None, dec_mask = None):
        for layer in self.layers:
            X = layer(X, encoder_output, enc_mask = enc_mask, dec_mask = dec_mask)
        return X

In [13]:
dec = Decoder(512)
dec(torch.randn(2,5,512), torch.randn(2,5,512)).shape

torch.Size([2, 5, 512])

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_emb, max_len):
        super().__init__()
        self.d_emb = d_emb
        self.max_len = max_len
        pe = torch.zeros(self.max_len, self.d_emb)
        for pos in range(self.max_len):
            for i in range(self.d_emb):
                if i%2 == 0:
                    pe[pos][i] = torch.sin(torch.tensor(pos/((10000)**((2*i)/self.d_emb))))
                else:
                    pe[pos][i] = torch.cos(torch.tensor(pos/((10000)**((2*i)/self.d_emb))))
        self.register_buffer("pe", pe)
    def forward(self, X):
        X = X + self.pe[:X.size(1)]
        return X

In [15]:
class Transformer(nn.Module):
    def __init__(self, d_emb, vocab_size, max_len, n_layers = 6, n_heads = 8, upsample_factor = 4):
        super().__init__()
        self.d_emb = d_emb
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len
        self.upsample_factor = upsample_factor
        self.enc = Encoder(self.d_emb, n_layers = self.n_layers, n_heads = self.n_heads, upsample_factor = self.upsample_factor)
        self.dec = Decoder(self.d_emb, n_layers = self.n_layers, n_heads = self.n_heads, upsample_factor = self.upsample_factor)
        self.embed_layer = nn.Embedding(vocab_size, d_emb)
        self.pe = PositionalEncoding(d_emb, max_len)
        self.out_linear = nn.Linear(d_emb, vocab_size)
        self.embed_layer.weight = self.out_linear.weight
        self.dropout1 = nn.Dropout(p = 0.1)
        self.dropout2 = nn.Dropout(p = 0.1)
    def forward(self, X, y, temperature = 1.0, enc_mask = None, dec_mask = None):
        enc_embeds = self.embed_layer(X)*(self.d_emb**0.5)
        dec_embeds = self.embed_layer(y)*(self.d_emb**0.5)
        enc_embeds = self.dropout1(self.pe(enc_embeds))
        dec_embeds = self.dropout2(self.pe(dec_embeds))
        enc_output = self.enc(enc_embeds, mask = enc_mask)
        dec_output = self.dec(dec_embeds, enc_output, enc_mask = enc_mask, dec_mask = dec_mask)
        dec_output = self.out_linear(dec_output)
        return dec_output/temperature

In [16]:
transformer = Transformer(512, 512, 5)
X = torch.randint(0, 100, (2, 5), dtype=torch.int64)
y = torch.randint(0, 100, (2, 5), dtype=torch.int64)
output = transformer(X, y)

In [17]:
output.shape

torch.Size([2, 5, 512])

In [18]:
# total_dataset = load_dataset("wmt14", "fr-en", split="train")
total_dataset = load_dataset('bentrevett/multi30k', split = "train")
print(total_dataset[0])

{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}


In [19]:
# dataset = total_dataset.shuffle(seed = 42).select(range(100000))
dataset = total_dataset

In [20]:
dataset[0]

{'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

In [79]:
tokenizer = AutoTokenizer.from_pretrained("gpt2",
                                         errors = 'replace',
                                         unk_token = '<UNK>',
                                         bos_token = '<SOS>',
                                         eos_token = '<EOS>',
                                         pad_token = '<PAD>',
                                         )

def tokenize_function(batch):
    eng_sentences = ["<SOS>"+data+"<EOS>" for data in batch['en']]
    german_sentences = ["<SOS>"+data+"<EOS>" for data in batch['de']]
    dec_tokenized = tokenizer(eng_sentences, padding="max_length", truncation=True, max_length=50)
    enc_tokenized = tokenizer(german_sentences, padding="max_length", truncation=True, max_length=50)
    return {
        "enc_input_ids": enc_tokenized["input_ids"],
        "enc_attention_mask": enc_tokenized["attention_mask"],
        "dec_input_ids": dec_tokenized["input_ids"],
        "dec_attention_mask" : dec_tokenized["attention_mask"],
    }

tokenized_dataset = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

In [80]:
class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_dataset):
        self.enc_inputs = [tokenized_datapoint["enc_input_ids"] for tokenized_datapoint in tokenized_dataset]
        self.enc_mask = [tokenized_datapoint["enc_attention_mask"] for tokenized_datapoint in tokenized_dataset]
        self.dec_inputs = [tokenized_datapoint["dec_input_ids"] for tokenized_datapoint in tokenized_dataset]
        self.dec_mask = [tokenized_datapoint["dec_attention_mask"] for tokenized_datapoint in tokenized_dataset]
        
    def __len__(self):
        return len(self.enc_inputs)

    def __getitem__(self, idx):
        return {
            "enc_inputs": torch.tensor(self.enc_inputs[idx]),
            "dec_inputs": torch.tensor(self.dec_inputs[idx]),
            "enc_mask" : torch.tensor(self.enc_mask[idx]),
            "dec_mask" : torch.tensor(self.dec_mask[idx]),
        }

translation_dataset = TranslationDataset(tokenized_dataset)

In [81]:
train_size = int(0.9*len(translation_dataset))
val_size = len(translation_dataset) - train_size

In [82]:
train_size

26100

In [83]:
train_indices = np.random.choice(np.arange(len(translation_dataset)), train_size)
val_indices = np.array([index for index in np.arange(len(translation_dataset)) if index not in train_indices])

In [84]:
train_subset = [translation_dataset[i] for i in train_indices]
val_subset = [translation_dataset[i] for i in val_indices]

In [85]:
batch_size = 16
train_data_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

In [86]:
# !pip install torch-adopt

In [87]:
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import LambdaLR

# def setup_optimizer_and_scheduler(model, total_steps):
#     base_lr = 1e-4
#     warmup_steps = total_steps // 10 
    
#     optimizer = Adam(model.parameters(), 
#                      lr=base_lr)

#     def lr_lambda(current_step):
#         if current_step < warmup_steps:
#             return float(current_step) / float(max(1, warmup_steps))
#         return max(0.0, float(warmup_steps)**-0.5 * min(
#             float(current_step)**-0.5,
#             float(current_step) * float(warmup_steps)**-1.5
#         ))
    
#     scheduler = LambdaLR(optimizer, lr_lambda)
    
#     return optimizer, scheduler

In [88]:
import math

class CustomScheduler:
    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = (self.d_model ** -0.5) * min(
            self.step_num ** -0.5, self.step_num * (self.warmup_steps ** -1.5)
        )
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

In [89]:
# from adopt import ADOPT
INPUT_DIM = len(tokenizer)
OUTPUT_DIM = len(tokenizer)
EMB_DIM = 512
N_LAYERS = 6
N_HEADS = 8
FF_DIM = 2048
DROPOUT = 0.1
MAX_LEN = 100
warmup_steps = 4000

model = Transformer(EMB_DIM, OUTPUT_DIM, MAX_LEN, n_layers=N_LAYERS, n_heads=N_HEADS, upsample_factor=FF_DIM//EMB_DIM).to(device)
# total_steps = 5 * len(train_data_loader)
# optimizer, lr_scheduler = setup_optimizer_and_scheduler(model, total_steps)
# optimizer = optim.Adam(model.parameters(), lr = 5e-4)
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9, lr = 1e-4)
scheduler = CustomScheduler(optimizer, EMB_DIM, warmup_steps)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [90]:
def generate_mask(src, tgt):
    src_mask = (src != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
    tgt_mask = (tgt != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(3)
    seq_length = tgt.size(1)
    causal_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
    tgt_mask = tgt_mask & causal_mask
    return src_mask, tgt_mask

In [91]:
src_data = torch.randint(1, 5000, (32, 50)).to(device)
tgt_data = torch.randint(1, 5000, (32, 50)).to(device)
generate_mask(src_data, tgt_data)[1][0]

tensor([[[ True, False, False,  ..., False, False, False],
         [ True,  True, False,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [ True,  True,  True,  ...,  True, False, False],
         [ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ...,  True,  True,  True]]], device='cuda:0')

In [92]:
# def initial_cost(model, data_loader):
#     model.eval()
#     epoch_loss = 0
#     with torch.inference_mode():
#         for batch in tqdm(data_loader):
#             src = batch['enc_inputs'].to(device)
#             trg = batch['dec_inputs'].to(device)
#             trg_in = trg.clone()
#             # trg_in[trg_in == tokenizer.eos_token_id] = tokenizer.pad_token_id
#             trg_in1 = torch.tensor([token for token in trg_in[0] if token!=tokenizer.pad_token_id]).unsqueeze(0).to(device)
#             trg_out = torch.tensor([token for token in trg[0] if token!=tokenizer.pad_token_id]).unsqueeze(0)
#             trg_out = trg_out[:,1:].clone().to(device)
#             enc_mask, dec_mask = generate_mask(src, trg_in1)
#             output = model(src, trg_in1[:,:-1])
#             # print(output.shape)
#             output_dim = output.shape[-1]
#             output = output.contiguous().view(-1, output_dim)
#             trg_out = trg_out.contiguous().view(-1)
#             loss = criterion(output, trg_out)
#             epoch_loss += loss.item() / len(data_loader)
#             del batch
#     return epoch_loss

In [93]:
# initial_cost(model, data_loader)

In [94]:
for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

In [103]:
def inference(test_sentence):
    model.eval()
    src = tokenizer.encode(test_sentence, padding = 'max_length', max_length = 50, truncation = True, return_tensors = 'pt').to(device)
    enc_mask = (src!=tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
    out = [tokenizer.bos_token_id]
    with torch.inference_mode():
        src = model.embed_layer(src)
        enc_output = model.enc(src, mask = enc_mask)
        for _ in range(model.max_len):
            out_embed = (model.embed_layer(torch.tensor(out).to(device)).to(device)).unsqueeze(0)*(model.d_emb**0.5)
            tgt = torch.tensor(out).unsqueeze(0)
            seq_length = tgt.size(1)
            tgt_mask = (tgt != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(3).to(device)
            causal_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
            tgt_mask = (tgt_mask & causal_mask).to(device)
            dec_output = model.dec(out_embed, enc_output, enc_mask = enc_mask, dec_mask = tgt_mask)
            dec_output = F.softmax(model.out_linear(dec_output), dim = -1)
            next_token = dec_output[:,-1,:].argmax(dim = -1).item()
            out.append(next_token)

            if next_token == tokenizer.eos_token_id:
                break
        out_sentence = tokenizer.decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return out_sentence

In [134]:
optimizer = optim.SGD(model.parameters(), lr = 1e-3, momentum = 0.9)

In [135]:
def train_model(model, data_loader, optimizer, scheduler, criterion, clip=1):
    model.train()
    epoch_loss = 0

    for batch in tqdm(train_data_loader):
        src = batch['enc_inputs'].to(device)
        trg = batch['dec_inputs'].to(device)

        trg_in = trg.clone()    
        trg_in = trg_in[:,:-1]
        trg_out = trg[:,1:].clone()
        enc_mask, dec_mask = generate_mask(src, trg_in)
        
        optimizer.zero_grad()
        output = model(src, trg_in, enc_mask = enc_mask, dec_mask = dec_mask)

        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        trg_out = trg_out.contiguous().view(-1)

        loss = criterion(output, trg_out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = clip)

        optimizer.step()
        # scheduler.step()
        epoch_loss += loss.item()
        # batch = batch.detach()
        del batch, trg_in, trg_out, enc_mask, dec_mask, output
        torch.cuda.empty_cache()
    
    #lr_scheduler.step()
    return epoch_loss / len(train_data_loader)

In [136]:
# checkpoint = torch.load('/kaggle/input/transformer-2epoch/transformer.pth', weights_only = True)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [137]:
N_EPOCHS = 5
CLIP = 1

for epoch in tqdm(range(N_EPOCHS)):
    train_loss = train_model(model, train_data_loader, optimizer, scheduler, criterion, CLIP)
    print(f"Epoch {epoch + 1}")
    print(f"Train Loss: {train_loss:.3f}")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1632 [00:00<?, ?it/s]

Epoch 1
Train Loss: 4.306


  0%|          | 0/1632 [00:00<?, ?it/s]

Epoch 2
Train Loss: 4.287


  0%|          | 0/1632 [00:00<?, ?it/s]

Epoch 3
Train Loss: 4.279


  0%|          | 0/1632 [00:00<?, ?it/s]

Epoch 4
Train Loss: 4.272


  0%|          | 0/1632 [00:00<?, ?it/s]

Epoch 5
Train Loss: 4.264


In [72]:
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # 'scheduler_state_dict': lr_scheduler.state_dict()
}

torch.save(checkpoint, 'transformer2.pth')

In [115]:
data = next(iter(train_data_loader))

In [122]:
test_sentence = tokenizer.decode(data['enc_inputs'][1])

In [123]:
tokenizer.decode(data['dec_inputs'][1])

'<SOS>People gathered around at an outdoor event.<EOS><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>'

In [124]:
print(test_sentence)

<SOS>Mehrere Personen bei einer Freiluftveranstaltung.<EOS><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>


In [128]:
out = inference(test_sentence)

In [41]:
def validation():
    model.eval()
    out_dict = {}
    out_dict['input'] = []
    out_dict['target'] = []
    out_dict['output'] = []
    with torch.inference_mode():
        for batch in tqdm(val_data_loader):
            src = batch['enc_inputs']
            trg = batch['dec_inputs']
            trg = tokenizer.decode(trg[0])
            test_sentence = tokenizer.decode(src[0])
            out_dict['input'].append(test_sentence)
            out_dict['target'].append(trg)
            output = inference(test_sentence)
            out_dict['output'].append(output)
            del batch, src, trg, test_sentence, output
    return out_dict

In [None]:
# checkpoint = torch.load('/kaggle/input/transformer-2epoch/transformer.pth', weights_only = True)
# model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
datasets.config.clear_cache() 