# Installing the libraries

In [None]:
# Install Hugging Face transformers (provides EncoderDecoderModel / AutoTokenizer used below)
!pip install transformers
# git-lfs installation, currently disabled
#!apt-get install git-lfs

In [None]:
# Print the version of the `python` executable on PATH (may differ from the kernel's interpreter)
!python --version

In [None]:
import torch
import torch.nn as nn # for the layers
from transformers import  EncoderDecoderModel, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import pandas as pd
from typing import List, Dict, Union
import random
from tqdm import tqdm
import os
import re

# Get the model and its tokeniser

In [None]:
# Tokenizer of the Latin RoBERTa checkpoint that the encoder-decoder model is built from
roberta_tokenizer = AutoTokenizer.from_pretrained(
    "pstroe/roberta-base-latin-cased"
)

In [None]:
# Warm-start both the encoder and the decoder from the same Latin RoBERTa checkpoint.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "pstroe/roberta-base-latin-cased", "pstroe/roberta-base-latin-cased"
)

# from_encoder_decoder_pretrained does not fill in these special-token ids, and
# the EncoderDecoderModel docs require them for training and generate().
# Training feeds the tokenized targets to the decoder as-is, and a RoBERTa
# tokenizer prepends its <s> (cls) token, so decoding must start from that token too.
model.config.decoder_start_token_id = roberta_tokenizer.cls_token_id
model.config.pad_token_id = roberta_tokenizer.pad_token_id
model.config.eos_token_id = roberta_tokenizer.eos_token_id

In [None]:
# Display the model's module tree (encoder, decoder with cross-attention, LM head)
model

In [None]:
# # freezing a subset of the layers
# Disabled alternative to full fine-tuning: freeze every parameter, then
# re-enable only the decoder's cross-attention blocks, the encoder pooler and
# the decoder LM head. To use it, uncomment this cell and skip the cell that
# makes all parameters trainable. range(12) assumes a 12-layer decoder
# (roberta-base) — confirm.

# for param in model.parameters():
#     param.requires_grad = False

# # unfreezing cross attention, pooler and the head

# for i in range(12):
#     for param in model.decoder.roberta.encoder.layer[i].crossattention.parameters():
#         param.requires_grad = True

# for param in model.encoder.pooler.parameters():
#     param.requires_grad = True

# for param in model.decoder.lm_head.parameters():
#     param.requires_grad = True

In [None]:
# Full fine-tuning: mark every parameter of the encoder-decoder as trainable
for weight in model.parameters():
    weight.requires_grad_(True)

# Load the data from the pickle file

In [None]:
import pickle
import pandas as pd

# Load the pickled pandas DataFrame; later cells read its 'input' and 'target' columns.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open('/pbs/home/s/syatsyk/syatsyk/tokenized_data.pkl', mode='rb') as pickle_file:
    data_new = pickle.load(pickle_file)

In [None]:
# Peek at the first rows of the loaded data
data_new.head()

# Preparing for training

In [None]:
class PairsDataset(torch.utils.data.Dataset):
    """Map-style dataset of (source, target) pairs built from two tokenizer outputs.

    Args:
        x: encoded source texts — a mapping from field name (``input_ids``,
            ``attention_mask``, ...) to a per-example sequence.
        y: encoded target texts; only its ``input_ids`` and ``attention_mask``
            fields are read.

    Each item holds every field of ``x`` for that example, plus the target's
    ``attention_mask`` under ``decoder_attention_mask`` and its ``input_ids``
    under ``labels``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __getitem__(self, idx):
        assert idx < len(self.x['input_ids'])
        sample = dict((field, values[idx]) for field, values in self.x.items())
        target = self.y
        sample['decoder_attention_mask'] = target['attention_mask'][idx]
        sample['labels'] = target['input_ids'][idx]
        return sample

    @property
    def n(self):
        """Number of examples (the length of the source ``input_ids``)."""
        return len(self.x['input_ids'])

    def __len__(self):
        return self.n

In [None]:
from typing import List, Dict, Union

class DataCollatorWithPadding:
    """Collate ``PairsDataset`` items into one padded batch of tensors.

    The first ``tokenizer.pad`` call pads the source fields. The target ids and
    mask (``labels`` / ``decoder_attention_mask``) are padded by a second call,
    so their length is independent of the source length. Every field is then
    converted with ``torch.tensor``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def _pad_targets(self, batch):
        """Pad the target ids/mask of ``batch``; return (labels, decoder_attention_mask)."""
        padded = self.tokenizer.pad(
            {'input_ids': batch['labels'], 'attention_mask': batch['decoder_attention_mask']},
            padding=True,
        )
        return padded['input_ids'], padded['attention_mask']

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        batch = self.tokenizer.pad(features, padding=True)
        labels, decoder_mask = self._pad_targets(batch)
        batch['labels'] = labels
        batch['decoder_attention_mask'] = decoder_mask
        return dict((field, torch.tensor(values)) for field, values in batch.items())

In [None]:
# Display the first 14,700 rows (exploratory display only)
data_new.head(14700)

In [None]:
# Hold out 20 000 rows for validation. The fixed seed makes the split
# reproducible across kernel restarts; an unseeded split would, after resuming
# from a checkpoint, validate on rows the model was already trained on.
# NOTE(review): checkpoints trained before this seed was fixed used a different split.
df_val = data_new.sample(20000, random_state=42)

In [None]:
# Start by marking every row as training data
data_new['is_train'] = True

In [None]:
# mark the rows sampled into df_val as non-training data
data_new.loc[df_val.index, 'is_train'] = False

In [None]:
# Display the validation rows
df_val

In [None]:
# Training split: every row not sampled into df_val
df_train = data_new.loc[data_new['is_train']]

# Training

In [None]:
batch_size = 3


def make_pairs_dataset(df):
    """Tokenize ``df['input']`` / ``df['target']`` into a ``PairsDataset``.

    Both sides are truncated to the tokenizer's ``model_max_length``. RoBERTa
    has a fixed number of position embeddings, and a longer sequence makes the
    forward pass fail (on CUDA typically as a device-side assert, which the
    training loop cannot recover from).
    NOTE(review): assumes ``model_max_length`` is set for this tokenizer — confirm.
    """
    return PairsDataset(
        roberta_tokenizer(df['input'].tolist(), truncation=True),
        roberta_tokenizer(df['target'].tolist(), truncation=True),
    )


train_dataset = make_pairs_dataset(df_train)
test_dataset = make_pairs_dataset(df_val)

data_collator = DataCollatorWithPadding(tokenizer=roberta_tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, drop_last=False, shuffle=True, collate_fn=data_collator)
# batch order does not change the averaged validation loss; keep it deterministic
val_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False, shuffle=False, collate_fn=data_collator)

# it takes 3 min for batch = 6, it takes 17 for batch = 24

In [None]:
# Per-token losses (reduction='none') so padding positions can be masked out.
loss_function = nn.CrossEntropyLoss(reduction='none')

# 5e-5 is in the usual range for fully fine-tuning a RoBERTa-size model;
# lr=5e-2 would be ~1000x larger and typically diverges once every parameter
# is trainable.
optimizer = AdamW(model.parameters(), lr=5e-5)
# Cosine schedule with warm restarts, stepped once per epoch: the first cycle
# lasts T_0=4 epochs and each following cycle is twice as long (T_mult=2).
lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=4, T_mult=2, eta_min=1e-6, last_epoch=-1)

num_epochs = 10

In [None]:
# Show the learning rate the scheduler last set for each param group
lr_sched.get_last_lr()

In [None]:
def error_fixing(text, model, n=None, max_length='auto', temperature=1.0, beams=3, tokenizer=None):
    """Generate corrected text(s) with an encoder-decoder ``model``.

    Args:
        text: a single string, or a list of strings encoded as one padded batch.
        model: an unwrapped model; inputs are moved to ``model.device``.
        n: number of sequences to return per input; ``None`` means one.
        max_length: generation length cap, or ``'auto'`` for 1.2x the padded
            input length plus 10 tokens.
        temperature: sampling temperature (sampling is combined with beam search).
        beams: number of beams.
        tokenizer: tokenizer for encoding/decoding; defaults to the global
            ``roberta_tokenizer``.

    Returns:
        A single string when ``text`` is a string and ``n`` is None; otherwise
        the list of all decoded outputs (``n`` per input).
    """
    tokenizer = roberta_tokenizer if tokenizer is None else tokenizer
    texts = [text] if isinstance(text, str) else text
    encoded = tokenizer(texts, return_tensors='pt', padding=True)
    inputs = encoded['input_ids'].to(model.device)
    # without the mask, the padding of shorter texts in a batch is attended to
    attention_mask = encoded['attention_mask'].to(model.device)
    if max_length == 'auto':
        max_length = int(inputs.shape[1] * 1.2) + 10
    result = model.generate(
        inputs,
        attention_mask=attention_mask,
        num_return_sequences=n or 1,
        do_sample=True,
        temperature=temperature,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=None,
        num_beams=beams,
    )
    decoded = [tokenizer.decode(r, skip_special_tokens=True) for r in result]
    # one string in (and no n) -> one string out; anything else returns every output
    if n is None and isinstance(text, str):
        return decoded[0]
    return decoded

In [None]:
# Pick 3 random validation inputs (presumably to try error_fixing on)
texts = df_val.sample(3)['input'].tolist()

In [None]:
import gc

def cleanup():
    """Run Python's garbage collector, then release PyTorch's unused cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


cleanup()

In [None]:
# for parallel training

from torch.nn import DataParallel


def shifted_token_loss(logits, target_ids, target_mask):
    """Return the mean next-token cross-entropy over real (non-padding) target tokens.

    The decoder is fed ``target_ids`` itself (teacher forcing) and is causal, so
    the logits at position ``t`` predict the token at ``t + 1``. Scoring them
    against position ``t`` would only teach the model to copy its input.

    Args:
        logits: decoder logits indexed ``[batch, position, vocab]``.
        target_ids: target token ids that were fed to the decoder, ``[batch, position]``.
        target_mask: 1 for real target tokens and 0 for padding, same shape as ``target_ids``.

    Returns:
        Scalar tensor: cross-entropy averaged over the unmasked predicted positions.
    """
    pred = logits[:, :-1, :]
    gold = target_ids[:, 1:]
    mask = target_mask[:, 1:].to(pred.dtype).reshape(-1)
    per_token = torch.nn.functional.cross_entropy(
        pred.reshape(-1, pred.size(-1)),  # the model's vocab size, which may differ from the tokenizer's
        gold.reshape(-1),
        reduction='none',
    )
    # average over real tokens only; the clamp guards against an all-padding batch
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


def _to_cuda(batch):
    """Move one collated batch to the GPU as (input_ids, input_mask, target_ids, target_mask)."""
    return (batch['input_ids'].to('cuda'),
            batch['attention_mask'].to('cuda'),
            batch['labels'].to('cuda'),
            batch['decoder_attention_mask'].to('cuda'))


# checkpoint location and the epoch checkpoint to resume from (if it exists)
checkpoint_dir = "model_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "model_epoch_3.pt")  # start from the last saved epoch
# NOTE(review): checkpoints written before the shifted loss was introduced were
# trained with an unshifted (copy) objective — consider starting from scratch.

# run on all available GPUs; wrap only once so re-running this cell does not nest wrappers
if not isinstance(model, DataParallel):
    model = DataParallel(model)
model.to('cuda')

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint '{checkpoint_path}'")
    checkpoint = torch.load(checkpoint_path)
    model.module.load_state_dict(checkpoint['model_state_dict'])  # saved from the unwrapped model
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # continue after the saved epoch
    train_loss = checkpoint['train_loss']  # per-epoch training-loss history
    val_loss = checkpoint['val_loss']  # per-epoch validation-loss history
    if 'scheduler_state_dict' in checkpoint:
        lr_sched.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # older checkpoints lack the scheduler state: replay its once-per-epoch steps
        for _ in range(start_epoch):
            lr_sched.step()
else:
    print(f"No checkpoint found at '{checkpoint_path}'. Starting from scratch.")
    start_epoch = 0
    train_loss = []
    val_loss = []

accumulation_steps = 250  # batches whose gradients are summed before each optimizer step

for epoch in range(start_epoch, num_epochs):
    model.train()
    total_train_loss = 0.0
    total_train_step = 0
    progress_bar = tqdm(train_dataloader)

    optimizer.zero_grad()

    for step, batch in enumerate(progress_bar):
        input_ids, input_attention_mask, target_ids, target_attention_mask = _to_cuda(batch)

        try:
            output = model(input_ids=input_ids,
                           attention_mask=input_attention_mask,
                           decoder_input_ids=target_ids,
                           decoder_attention_mask=target_attention_mask)
            loss = shifted_token_loss(output.logits, target_ids, target_attention_mask)
            (loss / accumulation_steps).backward()
        except RuntimeError as e:
            # skip this batch (e.g. CUDA OOM); `output` may be unset or stale here,
            # so report only the shapes that are known
            print(f"RuntimeError at step {step}: {e}")
            print(f"Input shape: {tuple(input_ids.shape)}")
            print(f"Target shape: {tuple(target_ids.shape)}")
        else:
            total_train_loss += loss.item()
            total_train_step += 1
            progress_bar.set_description(f"Epoch {epoch} - Avg Train Loss: {total_train_loss / total_train_step:.4f}")

        # outside the try, so a failed batch never skips an accumulation boundary
        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_dataloader):
            optimizer.step()
            optimizer.zero_grad()

    train_loss.append(total_train_loss / total_train_step if total_train_step else float('nan'))

    model.eval()
    total_val_loss = 0.0
    total_val_step = 0
    val_progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch} Validation")
    with torch.no_grad():
        for batch in val_progress_bar:
            input_ids, input_attention_mask, target_ids, target_attention_mask = _to_cuda(batch)
            output = model(input_ids=input_ids,
                           attention_mask=input_attention_mask,
                           decoder_input_ids=target_ids,
                           decoder_attention_mask=target_attention_mask)
            # same loss as training, so the two curves are directly comparable
            total_val_loss += shifted_token_loss(output.logits, target_ids, target_attention_mask).item()
            total_val_step += 1
            val_progress_bar.set_description(f'Current Val loss (epoch: {epoch}): {total_val_loss / total_val_step:.4f}')

    val_loss.append(total_val_loss / total_val_step if total_val_step else float('nan'))

    lr_sched.step()

    # save after validation and the scheduler step, so the checkpoint holds this
    # epoch's val loss and everything needed to resume at epoch + 1
    model_save_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch}.pt")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.module.state_dict(),  # unwrapped weights, no "module." prefix
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_sched.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, model_save_path)
    print(f"Model saved to {model_save_path}")

# Results visualisation

In [None]:
import matplotlib
import matplotlib.pyplot as plt

In [None]:
from matplotlib.ticker import MaxNLocator

# Plot per-epoch losses against their actual epoch numbers. Epoch 0 is skipped
# (presumably so its much larger initial loss does not flatten the curves).
plt.plot(range(1, len(train_loss)), train_loss[1:], label='Training Loss')
plt.plot(range(1, len(val_loss)), val_loss[1:], label='Validation Loss')

plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))  # whole-number epoch ticks
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Check the results of a model restored from a checkpoint

In [None]:
# Checkpoint to evaluate. Presumably saved like the epoch checkpoints of the
# training cell (a dict whose 'model_state_dict' holds the unwrapped model's
# weights) — confirm.
checkpoint_path = "/pbs/home/s/syatsyk/syatsyk/biblissima_spellchecker/model_checkpoints/model_best.pt"

# get the tokenizer; `model` must already exist in this kernel. In a fresh kernel, rebuild it with:
#model = EncoderDecoderModel.from_encoder_decoder_pretrained("pstroe/roberta-base-latin-cased", "pstroe/roberta-base-latin-cased")
roberta_tokenizer = AutoTokenizer.from_pretrained("pstroe/roberta-base-latin-cased")

# The saved keys carry no "module." prefix, so load into the bare model. After
# the training cell `model` is already a DataParallel wrapper; loading into it
# would fail on key names, and re-wrapping it would nest wrappers.
base_model = model.module if isinstance(model, torch.nn.DataParallel) else model

# get the checkpoint
checkpoint = torch.load(checkpoint_path, map_location='cuda')
base_model.load_state_dict(checkpoint['model_state_dict'])

# wrap exactly once in DataParallel
model = torch.nn.DataParallel(base_model).to('cuda')

In [None]:
def error_fixing(text, model, roberta_tokenizer, n=None, max_length='auto', temperature=1.0, beams=3, device='cuda'):
    """Generate corrected text(s) for ``text`` with an encoder-decoder model.

    Args:
        text: a single string, or a list of strings encoded as one padded batch.
        model: the model, optionally wrapped in ``torch.nn.DataParallel``.
        roberta_tokenizer: tokenizer used to encode the inputs and decode the outputs.
        n: number of sequences to return per input; ``None`` means one.
        max_length: generation length cap, or ``'auto'`` for 1.2x the padded
            input length plus 10 tokens.
        temperature: sampling temperature (sampling is combined with beam search).
        beams: number of beams.
        device: device the encoded inputs are moved to.

    Returns:
        A single string when ``text`` is a string and ``n`` is None; otherwise
        the list of all decoded outputs (``n`` per input).
    """
    texts = [text] if isinstance(text, str) else text
    encoded = roberta_tokenizer(texts, return_tensors='pt', padding=True)
    inputs = encoded['input_ids'].to(device)
    # without the mask, the padding of shorter texts in a batch is attended to
    attention_mask = encoded['attention_mask'].to(device)
    if max_length == 'auto':
        max_length = int(inputs.shape[1] * 1.2) + 10

    # DataParallel does not expose .generate(); call it on the wrapped model
    model_to_generate = model.module if isinstance(model, torch.nn.DataParallel) else model

    result = model_to_generate.generate(
        inputs,
        attention_mask=attention_mask,
        num_return_sequences=n or 1,
        do_sample=True,
        temperature=temperature,
        repetition_penalty=3.0,
        max_length=max_length,
        bad_words_ids=None,
        num_beams=beams,
    )

    decoded = [roberta_tokenizer.decode(r, skip_special_tokens=True) for r in result]
    # one string in (and no n) -> one string out; a batch returns every output
    return decoded[0] if n is None and isinstance(text, str) else decoded

# Smoke-test generation on the first two rows of the full dataset.
# NOTE(review): these rows may belong to the training split; df_val would give held-out examples.
input_texts = data_new['input'].head(2).tolist()
target_texts = data_new['target'].head(2).tolist()

problematic_batch = []  # collects the inputs whose generation raised an error

for inp, trg in zip(input_texts, target_texts):
    try:
        res = error_fixing(inp, model, roberta_tokenizer, temperature=2.0, beams=10, device='cuda')
        print(f"\n| true: {trg}\n|  inp: {inp}\n| pred: {res}")
    except Exception as e:
        # best-effort: log the failure, remember the input and move on
        print(f"Error processing input: {inp} with error: {e}")
        problematic_batch.append(inp)
        continue