# Training my model in a Kaggle notebook

In [None]:
import nltk

# Download the "punkt_tab" NLTK resource; presumably this backs the
# nltk.sent_tokenize call used when chunking articles — confirm against the
# installed NLTK version.
nltk.download('punkt_tab')

In [None]:
import os

# Environment switches for the Hugging Face libraries; they are set here,
# before `transformers` is imported further down.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "TRANSFORMERS_NO_ADVISORY_WARNINGS": "true",
    }
)


In [None]:
import torch
import torch.nn as nn


class MLMHead(nn.Module):
    """Transform applied to hidden states before the vocabulary projection.

    Runs a bias-free linear map, a GELU and a layer normalisation in that
    order; the feature size ``d_model`` is left unchanged.
    """

    def __init__(self, d_model = 256):
        super().__init__()
        # Attribute names double as state-dict keys, so they are kept as-is.
        self.lin = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.gelu = nn.GELU()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``norm(gelu(lin(x)))`` for ``x`` whose last dim is ``d_model``."""
        return self.norm(self.gelu(self.lin(x)))

In [None]:
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: Feature size of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights
            (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, d_model = 256, num_heads = 8, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        assert d_model % num_heads == 0, "Number of dimensions should be divisible by heads"

        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.projection = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Attend over ``x`` of shape (batch_size, seq_length, d_model).

        Args:
            x: Input hidden states.
            attention_mask: Optional (batch_size, seq_length) tensor in which
                0 marks key positions (e.g. PAD tokens) that must not be
                attended to.

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        batch_size, seq_length, d_model = x.shape

        def split_heads(t):
            # (batch_size, seq_length, d_model) -> (batch_size, num_heads, seq_length, d_k)
            return t.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))

        attention_scores = (Q @ K.transpose(2, 3)) / math.sqrt(self.d_k)

        if attention_mask is not None:
            # (batch_size, 1, 1, seq_length): broadcast over heads and query positions.
            mask = attention_mask.unsqueeze(1).unsqueeze(2).to(attention_scores.device)
            # Use the dtype's most negative finite value instead of -inf: masked
            # keys still get exactly zero weight, but a row whose keys are all
            # masked becomes uniform rather than NaN.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)

        attention_weights = torch.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        context = attention_weights @ V  # (batch_size, num_heads, seq_length, d_k)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, d_model)

        return self.projection(context)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: d_model -> 4 * d_model -> d_model.

    Args:
        d_model: Feature size of the input and of the output.
        dropout: Dropout probability applied after the GELU
            (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, d_model = 256, dropout = 0.1):
        super().__init__()
        hidden_dim = d_model * 4
        # Layer order is kept so state-dict keys ("projection.0", "projection.3") are unchanged.
        self.projection = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, d_model)
        )

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        return self.projection(x)



class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block (attention + feed-forward).

    Args:
        d_model: Feature size; forwarded to the attention and feed-forward
            sub-layers so that they match the layer norms.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model = 256, num_heads = 8):
        super().__init__()
        # Pass d_model through: the sub-layers previously always used their own
        # default width, which broke any d_model other than 256.
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = FeedForward(d_model=d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask):
        """Return the block output for ``x``; ``attn_mask`` is handed to the attention layer."""
        # Out-of-place residual additions keep autograd away from in-place edits.
        x = x + self.attn(self.norm1(x), attn_mask)
        x = x + self.ffn(self.norm2(x))
        return x

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaTokenizerFast, default_data_collator
import nltk
from datasets import Dataset as HFDataset
from nltk import sent_tokenize
from tqdm import tqdm



class PreprocessDataset(Dataset):
    """Sentence-packed dataset with dynamic masking for MLM pre-training.

    Each article is split into sentences, the sentences are packed into
    ``[CLS] ... [SEP]`` chunks of exactly ``max_length`` token ids (padded
    with PAD), and a fresh random mask is drawn every time an item is read.

    Args:
        dataset: Mapping-like object whose ``"text"`` entry is an iterable of strings.
        tokenizer: Tokenizer exposing ``encode``, ``vocab_size`` and the
            mask/pad/cls/sep token ids.
        max_length: Length of every produced chunk, special tokens included.
    """

    def __init__(self, dataset, tokenizer, max_length=128):
        self.texts = dataset["text"]
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.MASK_TOKEN_ID = tokenizer.mask_token_id
        self.PAD_TOKEN_ID = tokenizer.pad_token_id
        self.CLS_TOKEN_ID = tokenizer.cls_token_id
        self.SEP_TOKEN_ID = tokenizer.sep_token_id

        self.all_chunks = []
        for article in tqdm(self.texts, desc="Chunking dataset"):
            # Skip empty and whitespace-only rows: they would only contribute
            # an all-padding sample with nothing to predict.
            if article and article.strip():
                self.all_chunks.extend(self._chunk(article))

    def __len__(self):
        return len(self.all_chunks)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for chunk ``idx``."""
        token_ids = torch.tensor(self.all_chunks[idx], dtype=torch.long)
        # The attention mask is taken from the unmasked ids (1 = real token, 0 = PAD).
        attention_mask = (token_ids != self.PAD_TOKEN_ID).long()
        masked_input_ids, labels = self.mask_tokens(token_ids)

        return {
            "input_ids" : masked_input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

    def _close_chunk(self, chunk):
        """Append SEP to ``chunk`` and pad it with PAD up to ``max_length``."""
        chunk.append(self.SEP_TOKEN_ID)
        chunk.extend([self.PAD_TOKEN_ID] * (self.max_length - len(chunk)))
        return chunk

    def _chunk(self, article):
        """Split ``article`` into a list of ``max_length``-long token-id chunks."""
        body_budget = self.max_length - 2  # room left once CLS and SEP are accounted for
        chunks = []
        current_chunk = [self.CLS_TOKEN_ID]

        for sentence in sent_tokenize(article):
            # add_special_tokens=False: CLS/SEP are placed once per chunk here,
            # not around every sentence by the tokenizer.
            tokenized = self.tokenizer.encode(sentence, add_special_tokens=False)

            if len(tokenized) > body_budget:
                # A sentence that cannot fit in any chunk becomes its own truncated chunk.
                chunks.append([self.CLS_TOKEN_ID] + tokenized[:body_budget] + [self.SEP_TOKEN_ID])
                continue

            if len(current_chunk) + len(tokenized) + 1 > self.max_length:
                # Adding this sentence (plus SEP) would overflow: close the chunk and start a new one.
                chunks.append(self._close_chunk(current_chunk))
                current_chunk = [self.CLS_TOKEN_ID]

            current_chunk.extend(tokenized)

        if len(current_chunk) > 1:  # flush whatever is left after the last sentence
            chunks.append(self._close_chunk(current_chunk))

        if not chunks:  # nothing was tokenized: return a single all-padding chunk
            return [[self.CLS_TOKEN_ID, self.SEP_TOKEN_ID] + [self.PAD_TOKEN_ID] * (self.max_length - 2)]
        return chunks

    def mask_tokens(self, input_ids):
        """Draw a random MLM mask for a 1-D tensor of token ids.

        15% (rounded down) of the non-special tokens are selected; of those
        roughly 80% become the MASK token, 10% a random token id and 10% stay
        unchanged.

        Returns:
            ``(masked_input_ids, labels)`` where ``labels`` holds the original
            id at every selected position and -100 elsewhere. When fewer than
            one token would be selected, the ids are returned unmasked and all
            labels are -100.
        """
        labels = torch.full_like(input_ids, -100)
        masked_ids = input_ids.clone()

        is_special = (
            (input_ids == self.CLS_TOKEN_ID)
            | (input_ids == self.SEP_TOKEN_ID)
            | (input_ids == self.PAD_TOKEN_ID)
        )
        candidates = torch.nonzero(~is_special, as_tuple=True)[0]

        num_to_mask = int(len(candidates) * 0.15)
        if num_to_mask == 0:  # no candidates, or the sequence is too short
            return masked_ids, labels

        picked = torch.multinomial(torch.ones(len(candidates)), num_to_mask, replacement=False)
        # multinomial returns indices into `candidates`; map them back to real
        # sequence positions so special tokens are never selected.
        positions = candidates[picked]

        probs = torch.rand(num_to_mask)
        use_mask_token = probs < 0.8
        use_random_token = (probs >= 0.8) & (probs < 0.9)
        # The remaining selection (probs >= 0.9) keeps the original token.

        masked_ids[positions[use_mask_token]] = self.MASK_TOKEN_ID
        num_random = int(use_random_token.sum())
        # vocab_size matches the embedding table built from the same tokenizer
        # attribute and avoids materialising the whole vocab dict per item.
        masked_ids[positions[use_random_token]] = torch.randint(
            low=0, high=self.tokenizer.vocab_size, size=(num_random,), dtype=torch.long
        )

        labels[positions] = input_ids[positions]

        return masked_ids, labels

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class RoBERTa(nn.Module):
    """Transformer encoder with a masked-language-modelling output.

    Args:
        vocab_size: Number of rows of the token embedding (and of the output logits).
        padding_idx: Token id handed to ``nn.Embedding`` as ``padding_idx``.
        max_sequence_length: Number of learned position embeddings.
        d_model: Hidden size.
        layers: Number of transformer blocks.
    """

    def __init__(self, vocab_size, padding_idx, max_sequence_length = 128, d_model = 256, layers=6):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.pos_emb = nn.Embedding(max_sequence_length, d_model)
        self.trf_block = nn.Sequential(*[TransformerBlock(d_model=d_model) for _ in range(layers)])
        self.mlmHead = MLMHead(d_model)

    def forward(self, x, attn_mask):
        """Return vocabulary logits of shape (batch_size, seq_len, vocab_size).

        Args:
            x: (batch_size, seq_len) tensor of token ids.
            attn_mask: Mask forwarded unchanged to every transformer block.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position embeddings.
        """
        _, seq_len = x.shape
        if seq_len > self.pos_emb.num_embeddings:
            # Fail early with a readable message instead of an embedding index error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of {self.pos_emb.num_embeddings} positions"
            )

        positions = torch.arange(seq_len, device=x.device)
        x = self.tok_emb(x) + self.pos_emb(positions).unsqueeze(0)

        # The blocks need the mask as a second argument, so iterate explicitly
        # rather than calling the nn.Sequential container.
        for block in self.trf_block:
            x = block(x, attn_mask)

        # x is (batch_size, seq_len, d_model) at this point.
        x = self.mlmHead(x)
        # Weight tying: reuse the token-embedding matrix as the output projection
        # instead of allocating a separate (vocab_size, d_model) weight.
        return F.linear(x, self.tok_emb.weight)

In [None]:
import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.amp import GradScaler, autocast
from datasets import load_dataset
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup


class RoBERTaModule(nn.Module):
    """Bundles the roberta-base tokenizer, the RoBERTa model and the training utilities."""

    def __init__(self):
        super().__init__()
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        self.model = RoBERTa(vocab_size=self.tokenizer.vocab_size, padding_idx=self.tokenizer.pad_token_id)

    def forward(self, x, attn_mask):
        """Delegate to the wrapped model."""
        return self.model(x, attn_mask)

    def train_model(self, train_loader, validation_loader, num_epochs, lr=6e-4, optimizer=None, scheduler=None, scaler=None,
                    save_check_period=1):
        """Train with mixed precision on CUDA, early stopping and periodic checkpoints.

        Args:
            train_loader: Iterable of batches with "input_ids", "attention_mask" and "labels".
            validation_loader: Same batch format, used for validation loss/accuracy.
            num_epochs: Maximum number of epochs (at least 1).
            lr: Learning rate for the default AdamW optimizer.
            optimizer: Optional optimizer; AdamW is created when omitted.
            scheduler: Optional LR scheduler; cosine with 10% warm-up when omitted.
            scaler: Optional gradient scaler; a new ``GradScaler`` when omitted.
            save_check_period: Save "checkpoint.pt" and print sample predictions
                every this many epochs.

        Raises:
            ValueError: If ``num_epochs`` is smaller than 1.

        Writes "cpEarly.pt" on early stop, and "finishedBest.pt" /
        "finishedLast.pt" when training ends.
        """
        if num_epochs < 1:
            raise ValueError("num_epochs must be at least 1")

        device = torch.device("cuda")
        self.model.to(device)

        total_steps = len(train_loader) * num_epochs
        warmup_steps = int(0.1 * total_steps) # 10% of the total number of steps are warm-up steps

        if optimizer is None:
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, betas=(0.98, 0.999), eps=1e-6, weight_decay=0.01)

        if scheduler is None:
            scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

        if scaler is None:
            scaler = GradScaler()

        writer = SummaryWriter()

        # early stopping
        patience_counter = 0
        patience_limit = 5
        epsilon = 1e-3
        best_valid_loss = float("inf")
        best_model_state = None

        test_sentences = [
            "The <mask> barked at the girl",
            "She wore a <mask> dress to the party",
            "The <mask> is shining brightly",
            "The cat <mask> on the mat.",
            "The president gave a <mask> speech.",
            "She took a <mask> before dinner."
        ]

        for epoch in range(num_epochs):
            # train part
            self.model.train()
            total_loss_train = 0.0
            train_batches = 0

            for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                # A batch without any masked target makes cross_entropy return
                # NaN (mean over zero elements); skip it instead of poisoning the loss.
                if not (labels != -100).any():
                    continue

                with autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                    output = self.model(input_ids, attention_mask)
                    loss = F.cross_entropy(output.view(-1, output.shape[-1]), labels.view(-1), ignore_index=-100)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer) # unscale before clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) # gradient clipping
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                optimizer.zero_grad()

                total_loss_train += loss.item()
                train_batches += 1

            train_loss = total_loss_train / train_batches if train_batches else float("nan")

            # validation part
            self.model.eval()

            total_loss_valid = 0.0
            valid_batches = 0
            total_correct = 0
            total_tokens = 0

            with torch.no_grad():
                for batch in validation_loader:
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    labels = batch["labels"].to(device)

                    mask = labels != -100
                    if not mask.any():
                        continue  # nothing to score in this batch

                    with autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                        output = self.model(input_ids, attention_mask)
                        loss = F.cross_entropy(output.view(-1, output.shape[-1]), labels.view(-1), ignore_index=-100)

                    preds = output.argmax(dim=-1)
                    correct = ((preds == labels) & mask).float().sum()

                    total_loss_valid += loss.item()
                    valid_batches += 1
                    total_correct += correct.item()
                    total_tokens += mask.float().sum().item()

            validation_loss = total_loss_valid / valid_batches if valid_batches else float("nan")
            validation_accuracy = total_correct / total_tokens if total_tokens else 0.0

            if validation_loss < best_valid_loss - epsilon:
                best_valid_loss = validation_loss
                patience_counter = 0
                # Deep copy: state_dict().copy() is shallow and would keep
                # pointing at the live, still-training parameter tensors.
                best_model_state = copy.deepcopy(self.model.state_dict())
            else:
                patience_counter += 1
                if patience_counter >= patience_limit:
                    # best_model_state is None when no epoch ever improved
                    # (e.g. NaN validation loss); keep the current weights then.
                    if best_model_state is not None:
                        self.model.load_state_dict(best_model_state)
                    self.save_checkpoint(self.model.state_dict(), optimizer, scheduler, scaler, epoch, best_valid_loss, path="cpEarly.pt")
                    break

            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Loss/Validation", validation_loss, epoch)
            writer.add_scalar("Accuracy/Validation", validation_accuracy, epoch)

            if epoch % save_check_period == 0:
                for sent in test_sentences:
                    prediction = self.inference(self.model, self.tokenizer, sent, self.tokenizer.pad_token_id)
                    print(f"Inference [{sent}] → {prediction}")

                self.save_checkpoint(self.model.state_dict(), optimizer, scheduler, scaler, epoch, best_valid_loss, path="checkpoint.pt")

            print(f'Epoch {epoch + 1}, train loss: {train_loss:.4f}, valid loss: {validation_loss:.4f}')

        writer.close()

        last_model_copy = copy.deepcopy(self.model.state_dict())

        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        self.save_checkpoint(self.model.state_dict(), optimizer, scheduler, scaler, epoch, best_valid_loss, path="finishedBest.pt")

        self.save_checkpoint(last_model_copy, optimizer, scheduler, scaler, epoch, best_valid_loss, path="finishedLast.pt")

    def inference(self, model, tokenizer, sentence, pad_token_id, device="cuda"):
        """Predict the token(s) for every mask token in ``sentence``.

        Puts ``model`` in eval mode. Returns a single decoded string when the
        sentence contains one mask token, otherwise a list of decoded strings.

        Raises:
            ValueError: If the encoded sentence contains no mask token.
        """
        model.eval()

        input_ids = tokenizer.encode(sentence)
        input_ids_tensor = torch.tensor([input_ids]).to(device)

        attention_mask = (input_ids_tensor != pad_token_id).long()

        mask_token_id = tokenizer.mask_token_id

        mask_indices = [i for i, token in enumerate(input_ids) if token == mask_token_id]
        if not mask_indices:
            raise ValueError("<mask> token not found in input. Make sure tokenizer uses correct <mask> ID.")

        with torch.no_grad():
            logits = model(input_ids_tensor, attention_mask)

        predicted_tokens = [
            tokenizer.decode([logits[0, idx].argmax().item()])
            for idx in mask_indices
        ]

        return predicted_tokens if len(mask_indices) > 1 else predicted_tokens[0]


    def save_checkpoint(self, model, optimizer, scheduler, scaler, epoch, best_valid_loss, path="checkpoint.pt"):
        """Save a training checkpoint to ``path``.

        Note that ``model`` is the model *state dict* to store, not a module.
        """
        torch.save({
            "model_state_dict": model,
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "epoch": epoch,
            "best_valid_loss": best_valid_loss,
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, optimizer, scheduler, scaler, model=None, path="finished.pt"):
        """Restore model/optimizer/scheduler/scaler state from ``path``.

        The weights are loaded into ``model`` when given, otherwise into
        ``self.model``. Returns ``(epoch, best_valid_loss)`` as stored.
        """
        # Load onto the CPU first so a checkpoint written on a GPU machine can
        # be read anywhere; load_state_dict copies into the existing tensors.
        checkpoint = torch.load(path, map_location="cpu")

        # Explicit None check: module truthiness can be False (e.g. an empty container).
        target = self.model if model is None else model
        target.load_state_dict(checkpoint["model_state_dict"])

        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        scaler.load_state_dict(checkpoint["scaler_state_dict"])

        epoch = checkpoint["epoch"]
        best_valid_loss = checkpoint["best_valid_loss"]

        print(f"Checkpoint loaded from {path}")
        return epoch, best_valid_loss

In [None]:
from datasets import load_dataset

# Load the train and validation splits of the raw WikiText-103 configuration.
ds_train, ds_val = (
    load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split_name)
    for split_name in ("train", "validation")
)

In [None]:
from transformers import RobertaTokenizerFast

# Pretrained "roberta-base" tokenizer; the same object is handed to the
# dataset constructors below.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

In [None]:
# Chunk both splits up front; masking itself happens per item at read time.
trainSet, validSet = (
    PreprocessDataset(split, tokenizer)
    for split in (ds_train, ds_val)
)

In [None]:
# Training batches are shuffled and the last incomplete batch is dropped.
trainloader = DataLoader(
    trainSet,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)
# Validation keeps the dataset order and every sample.
validloader = DataLoader(
    validSet,
    batch_size=32,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

In [None]:
# Builds the tokenizer and a freshly initialised model (see RoBERTaModule.__init__).
roberta_module = RoBERTaModule()

In [None]:
# NOTE(review): this repeats the dataset construction from an earlier cell.
# The DataLoaders were created before this point and are not rebuilt here, so
# they keep using the earlier instances — this cell looks redundant; confirm
# and consider removing it.
trainSet = PreprocessDataset(ds_train, tokenizer)
validSet = PreprocessDataset(ds_val, tokenizer)

In [None]:
# Run training for up to 10 epochs with the default optimizer, scheduler and
# scaler; train_model moves the model to a CUDA device.
roberta_module.train_model(trainloader, validloader, num_epochs=10)