In [None]:
import torch
from torch.optim import AdamW, Adam
from torch.utils.data import DataLoader, Dataset
import os
from torch.nn.functional import one_hot

import torch
import torch.nn as nn

import joblib
import numpy as np
import math

# Data preparation

In [None]:
# Load the pre-tokenized corpus from the Kaggle input dataset and print how many entries it holds.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only point this at files you trust.
entokened = joblib.load("/kaggle/input/tokenized-test2/tokenized_test2.pkl")
print(len(entokened))

In [None]:
import random as rd
def load_and_tokenized(file):
    """Pass-through placeholder: hands back ``file`` exactly as it was received."""
    unchanged = file
    return unchanged

class MyDataset(Dataset):
    """Dataset of token-id sequences that yields a random-length prefix of each one.

    Every access draws a fresh prefix length, so the same index can return
    sequences of different lengths across calls.

    Args:
        files: Indexable, sliceable collection of token-id sequences. When
            omitted, the notebook-level ``entokened`` corpus is used, which is
            what ``MyDataset()`` has always done.
        min_len: Smallest prefix length that may be drawn (inclusive).
        max_len: Largest prefix length that may be drawn (inclusive).

    Raises:
        ValueError: If ``min_len`` is greater than ``max_len``.
    """

    def __init__(self, files=None, min_len=100, max_len=2000):
        if min_len > max_len:
            raise ValueError(f"min_len ({min_len}) must not exceed max_len ({max_len})")
        # Fall back to the notebook global only when no data is passed in explicitly.
        self.files = entokened if files is None else files
        self.min_len = min_len
        self.max_len = max_len

    def __len__(self):
        """Number of sequences in the corpus."""
        return len(self.files)

    def __getitem__(self, index):
        """Return the first ``k`` tokens of sequence ``index`` as a tensor,
        with ``k`` drawn uniformly from ``[min_len, max_len]``.

        Sequences shorter than ``k`` are returned whole.
        """
        prefix_len = rd.randint(self.min_len, self.max_len)
        # Slice first so only the kept prefix is copied into a tensor.
        return torch.tensor(self.files[index][:prefix_len])

def my_collate_fn(batch, padding_value=3150):
    """Right-pad a batch of 1-D token tensors to a common length and stack them.

    Args:
        batch: Non-empty list of 1-D tensors, one per sample.
        padding_value: Token id written into the padded positions. Defaults to
            3150, the ``padding_token`` the model in this notebook is created
            with, so the model's ``input == padding_token`` mask actually
            matches the padded positions. (The id 2617 that used to be
            hard-coded here never matched that mask.)

    Returns:
        Tensor of shape ``(len(batch), longest_sequence_length)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("my_collate_fn received an empty batch")

    # pad_sequence pads on the right up to the longest sequence in the batch.
    return torch.nn.utils.rnn.pad_sequence(
        batch, batch_first=True, padding_value=padding_value
    )


In [None]:
# Build the training dataset and its loader.
dataset = MyDataset()
# shuffle=True: this loader feeds the training loop, so reshuffle the sample
# order every epoch instead of always batching the same neighbours together.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=my_collate_fn)


In [None]:
# Sanity check: pull one collated batch and display its shape.
next(iter(dataloader)).shape

# Model

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

device

In [None]:
class TransformerBlock(nn.Module):
    """
    A single Transformer block using Pre-LN (Layer Normalization before attention/MLP)
    with causal self-attention and an optional boolean padding mask.

    Args:
        embedding_dimension: Width of the token representations.
        heads: Number of attention heads.
        dropout_rate: Dropout probability used inside attention, inside the MLP,
            and on the attention residual branch.
    """
    def __init__(self, embedding_dimension=128, heads=1, dropout_rate=0.1):
        super().__init__()
        self.att = nn.MultiheadAttention(
            embed_dim=embedding_dimension,
            num_heads=heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(embedding_dimension)
        self.norm2 = nn.LayerNorm(embedding_dimension)
        self.MLP = nn.Sequential(
            nn.Linear(embedding_dimension, embedding_dimension * 4),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(embedding_dimension * 4, embedding_dimension),
            nn.Dropout(dropout_rate)
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, key_padding_mask=None):
        """Apply causal self-attention and the MLP, each with a residual connection.

        Args:
            x: Input of shape ``(batch, seq_len, embedding_dimension)``.
            key_padding_mask: Optional boolean tensor of shape
                ``(batch, seq_len)``; ``True`` marks key positions that must
                be ignored. ``None`` means no position is padding.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)

        norm_x = self.norm1(x)

        # Boolean causal mask created on the input's own device (not a notebook
        # global): True above the diagonal blocks attention to future positions.
        # Using bool keeps its dtype aligned with the boolean padding mask.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        # The padding mask is optional, so only move it when one was given.
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.to(x.device)

        attn_output, _ = self.att(norm_x, norm_x, norm_x,
                                  key_padding_mask=key_padding_mask,
                                  is_causal=True,
                                  need_weights=False,
                                  attn_mask=causal_mask)

        x = x + self.dropout(attn_output)

        norm_x = self.norm2(x)
        mlp_output = self.MLP(norm_x)
        out = x + mlp_output
        return out

class T1(nn.Module):
    """Decoder-style language model: token + learned positional embeddings,
    a stack of ``TransformerBlock`` layers, a final LayerNorm and a linear
    projection back to the vocabulary.

    Args:
        vocab_size: Number of token ids (size of the embedding table and of the
            output logits).
        embedding_dimension: Width of the token representations.
        heads: Number of attention heads per block.
        num_transformer_blocks: How many ``TransformerBlock`` layers to stack.
        padding_token: Token id treated as padding; its embedding is pinned via
            ``padding_idx`` and it is masked out as an attention key.
        dropout_rate: Dropout probability forwarded to every block.
        max_positions: Size of the positional embedding table, i.e. the longest
            sequence the model accepts. Defaults to the previous fixed 3000.
    """
    def __init__(self, vocab_size, embedding_dimension=128, heads=1, num_transformer_blocks=1,
                 padding_token=0, dropout_rate=0.1, max_positions=3000):
        super().__init__()
        self.padding_token = padding_token
        self.embedding = nn.Embedding(vocab_size, embedding_dimension, padding_idx=padding_token)
        self.pos_embedder = nn.Embedding(max_positions, embedding_dimension)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding_dimension=embedding_dimension, heads=heads, dropout_rate=dropout_rate)
            for _ in range(num_transformer_blocks)
        ])

        self.final_norm = nn.LayerNorm(embedding_dimension)
        self.unembedder = nn.Linear(embedding_dimension, vocab_size)

    def forward(self, input_seq):
        """Compute next-token logits for every position.

        Args:
            input_seq: Integer tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Unnormalised logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: If ``seq_len`` exceeds the positional embedding table.
        """
        # .to() is a no-op when the tensor already lives on the model's device.
        input_seq = input_seq.to(self.embedding.weight.device)

        seq_len = input_seq.shape[1]
        max_positions = self.pos_embedder.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an opaque embedding
            # lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_len} exceeds the {max_positions} positions "
                f"this model has embeddings for"
            )

        padding_mask = input_seq == self.padding_token
        # Token embeddings are scaled by sqrt(embedding_dim) before positions are added.
        x = self.embedding(input_seq) * math.sqrt(self.embedding.embedding_dim)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = x + self.pos_embedder(positions)

        for block in self.transformer_blocks:
            x = block(x, key_padding_mask=padding_mask)

        x = self.final_norm(x)

        logits = self.unembedder(x)

        return logits

In [None]:
vocab_size = 3151
padding_token = 3150  # id the model treats as padding

# Build the model and move it to the selected device; wrap it for multi-GPU use when possible.
model = T1(vocab_size=vocab_size, embedding_dimension=250, heads=5, num_transformer_blocks=16, padding_token=padding_token)
model = model.to(device)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model).cuda()

optimizer = Adam(model.parameters(), lr=0.001)
# Targets equal to padding_token are excluded from the loss and its gradient,
# so the model is not trained to predict padding.
# NOTE(review): this only takes effect when batches are padded with this same id.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=padding_token)
epochs = 100
# Total number of parameters, displayed as the cell output.
total_params = sum(p.numel() for p in model.parameters())
total_params



In [None]:
import gc

# Manual memory reset for interactive use.
# - Names are removed with pop(..., None) so the cell also works on a fresh
#   kernel where nothing has been created yet (a bare `del` raises NameError).
# - The model and optimizer are kept by default because the training cell that
#   follows needs them. Set the flag to True to drop them as well, then re-run
#   the model setup cell before training again.
DROP_MODEL_AND_OPTIMIZER = False

_stale_names = ["losses", "batch"]
if DROP_MODEL_AND_OPTIMIZER:
    _stale_names += ["model", "optimizer"]
for _name in _stale_names:
    globals().pop(_name, None)

# Collect only after the references are gone, then release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Next-token training: feed tokens [0, T-1) and predict tokens [1, T).
for epoch in range(epochs):
    # Make sure dropout is active even if the model was put in eval mode elsewhere.
    model.train()
    losses = []
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        batch = batch.to(device)
        inputs, targets = batch[:, :-1], batch[:, 1:]
        preds = model(inputs)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) without overwriting the
        # notebook-level `vocab_size`; reshape also copes with non-contiguous slices.
        loss = loss_fn(preds.reshape(-1, preds.size(-1)), targets.reshape(-1))

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        if i % 5 == 0:
            # 1-based epoch, matching the end-of-epoch summary below.
            print(f"epoch: {epoch + 1}, loss: {loss_value:.4f}")
    print(f"Epoch {epoch+1} Mean Loss: {np.mean(losses):.4f}")

In [None]:
# Use the first dataset item (a randomly truncated prefix of sequence 0) as the generation prompt.
example = dataset[0]


In [None]:
# Greedy decoding: append the arg-max next token to `example` 60 times.
_was_training = model.training
model.eval()  # turn dropout off so the arg-max is taken over deterministic logits
try:
    with torch.no_grad():
        for _ in range(60):
            next_token_logits = model(example.unsqueeze(0).to(device))[0, -1, :]
            pred = torch.argmax(next_token_logits, dim=-1).cpu()
            example = torch.cat((example, pred.unsqueeze(0)))
finally:
    # Leave the model in whichever mode it was in before this cell ran.
    model.train(_was_training)
example

In [None]:
from torch.distributions import Categorical

# Stochastic decoding: sample the next token from the model's predicted
# distribution and append it to `example`, 40 times.

# Add the batch dimension only once so that re-running this cell stays valid.
if example.dim() == 1:
    example = example.unsqueeze(0)

_was_training = model.training
model.eval()  # turn dropout off so samples come from the model's actual distribution
try:
    with torch.no_grad():
        for _ in range(40):
            # Only the last position predicts the next token, so sample just that one.
            last_logits = model(example.to(device))[:, -1, :]
            pred = Categorical(logits=last_logits).sample().unsqueeze(1).cpu()
            example = torch.cat((example, pred), dim=1)
finally:
    # Leave the model in whichever mode it was in before this cell ran.
    model.train(_was_training)
example