In [15]:
!pip install datasets evaluate rouge_score pytorch-lightning nlp



In [16]:
import torch
from datasets import load_dataset
import pandas as pd
import zipfile
import os
import re
from pathlib import Path
import argparse
from argparse import ArgumentParser
import glob
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation
import math
import numpy as np
from torch.optim import AdamW
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from nlp import load_metric
import string
from pathlib import Path
from transformers import (
    Adafactor,
    T5ForConditionalGeneration,
    T5Tokenizer,
    T5Config,
    get_linear_schedule_with_warmup,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from torch.utils.data import RandomSampler
import textwrap
from tqdm.auto import tqdm
from nlp import load_dataset

## Prepare data for pretraining
Given a zip archive with russian text corpus presented in .txt files preprocess text and convert it to the format convinient for training

In [6]:
def unzip_file(zip_path, extract_to=None, password=None):
    """Safe unzipping function"""
    if not os.path.exists(zip_path):
        raise FileNotFoundError(f"Zip file {zip_path} not found")

    extract_to = extract_to or os.path.dirname(zip_path)
    os.makedirs(extract_to, exist_ok=True)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            if password:
                zip_ref.setpassword(password)
            zip_ref.extractall(extract_to)
            print(f"Successfully extracted to {extract_to}")
            return True
    except zipfile.BadZipFile:
        print("Error: File is not a zip file or is corrupted")
    except RuntimeError as e:
        if 'encrypted' in str(e):
            print("Error: Password required or incorrect password")
        else:
            print("Error:", e)
    return False

unzip_file('02.zip', extract_to='.')

Successfully extracted to .


True

In [7]:
def process_russian_texts(input_dir, output_file, max_tokens=128):
    texts = []

    text_files = list(Path(input_dir).glob("*.txt"))

    for file in text_files:
        with open(file, 'r', encoding='utf-8') as f:
            text = f.read()

            text = re.sub(r'\s+', ' ', text)
            text = text.strip()

            sentences = re.split(r'(?<=[.!?])\s+', text)

            current_chunk = []
            current_length = 0

            for sent in sentences:
                sent_length = len(sent.split())
                if current_length + sent_length > max_tokens:
                    texts.append(' '.join(current_chunk))
                    current_chunk = []
                    current_length = 0
                current_chunk.append(sent)
                current_length += sent_length

            if current_chunk:
                texts.append(' '.join(current_chunk))

    df = pd.DataFrame({'context': texts})
    df.to_csv(output_file, index=False, encoding='utf-8')

process_russian_texts(
    input_dir='02',
    output_file='train_context.csv',
    max_tokens=128
)

In [8]:
df = pd.read_csv('train_context.csv')
df.dropna(inplace=True)

In [9]:
df

Unnamed: 0,context
0,Волчонок Декабрь 1994 года. Детская память сох...
1,"Приносила что-то поесть, и тогда он засыпал на..."
2,"Ночью, когда он уже засыпал, пришел этот страш..."
3,"Но она не отвечала, и он стал ее звать громко,..."
4,Потом она его отвела к себе и позвала дедушку ...
...,...
35807,"Поэтому лучшее, что ты можешь сделать-сосредот..."
35808,"Через несколько дней, когда я, как обычно, сид..."
35809,"Но всё равно мне было очень приятно, что и я п..."
35810,"Судьба, очевидно, оценила моё позитивное к ней..."


In [10]:
df.to_csv('train_context.csv')

In [11]:
print("Missing values in 'formal':", df["context"].isnull().sum())
print("Empty strings in 'formal':", (df["context"].str.strip() == "").sum())


Missing values in 'formal': 0
Empty strings in 'formal': 0


#Prepare for pretraining

Implement a T5 model fine-tuning pipeline using PyTorch Lightning, including custom data loading (with text segmentation and span corruption for pretraining), training/validation logic with Adafactor optimization, and evaluation metrics (exact/subset match). It prepares the model for sequence-to-sequence tasks by masking text spans as input and reconstructing the original as target during pretraining.

----
the code is based on the following implementation:
https://github.com/joeljang/Pretraining_T5_custom_dataset/tree/master

In [12]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def exact_match_score(prediction, ground_truth):
    return int(normalize_answer(prediction) == normalize_answer(ground_truth))

def approx_match_score(prediction, ground_truth):
    answer = normalize_answer(prediction)
    gt = normalize_answer(ground_truth)
    match = 0
    gt_words = gt.split(" ")
    for word in gt_words:
        if word in answer:
            match = 1
            return match
    return match

def calculate_scores(predictions, ground_truths):
    em_score = 0
    subset_match_score = 0

    for i in range(len(predictions)):
        ground_truth = ground_truths[i]
        prediction = predictions[i]
        em_score +=  exact_match_score(prediction, ground_truth)
        subset_match_score += approx_match_score(prediction, ground_truth)

    em_score /= len(predictions)
    subset_match_score /= len(predictions)
    return em_score*100, subset_match_score*100

class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super(T5FineTuner, self).__init__()
        self.save_hyperparameters(hparams)
        self.model = T5ForConditionalGeneration.from_pretrained(self.hparams.model_name_or_path)
        self.tokenizer = T5Tokenizer.from_pretrained(self.hparams.tokenizer_name_or_path)

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            self.freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.step_count = 0
        self.output_dir = Path(self.hparams.output_dir)

        n_observations_per_split = {
            "train": self.hparams.n_train,
            "validation": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
        self.validation_step_outputs = []
        self.training_step_outputs = []
    def freeze_params(self, model):
        for par in model.parameters():
            par.requires_grad = False


    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            self.freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            self.freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                self.freeze_params(d.embed_tokens)

    def lmap(self, f, x):
        """list(map(f, x))"""
        return list(map(f, x))


    def is_logger(self):
      return self.trainer.global_rank == 0


    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
    )

    def _step(self, batch):
        lm_labels = batch["target_ids"]
        lm_labels[lm_labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            lm_labels=lm_labels,
            decoder_attention_mask=batch['target_mask']
        )

        loss = outputs[0]

        return loss


    def ids_to_clean_text(self, generated_ids):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return self.lmap(str.strip, gen_text)


    def _generative_step(self, batch) :

        t0 = time.time()

        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            use_cache=True,
            decoder_attention_mask=batch['target_mask'],
            max_length=10,
            num_beams=2,
            early_stopping=True
        )
        preds = self.ids_to_clean_text(generated_ids)
        targets = self.ids_to_clean_text(batch["target_ids"])

        gen_time = (time.time() - t0) / batch["source_ids"].shape[0]

        loss = self._step(batch)
        base_metrics = {'val_loss': loss}
        summ_len = np.mean(self.lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=targets)
        em_score, subset_match_score = calculate_scores(preds, targets)

        return {
            "val_loss": loss,
            "em_score": torch.tensor(em_score, dtype=torch.float32),
            "subset_match_score": torch.tensor(subset_match_score, dtype=torch.float32),
            "preds": preds,
            "targets": targets
        }

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.training_step_outputs.append(loss)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        avg_loss = torch.mean(torch.stack(self.training_step_outputs))
        self.log("avg_train_loss", avg_loss, prog_bar=True)
        self.training_step_outputs.clear()


    def validation_step(self, batch, batch_idx):
        result = self._generative_step(batch)
        self.validation_step_outputs.append(result)
        return result


    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        em_scores = torch.stack([x["em_score"] for x in outputs]).mean()
        subset_scores = torch.stack([x["subset_match_score"] for x in outputs]).mean()

        self.log("val_loss", avg_loss, prog_bar=True)
        self.log("em_score", em_scores, prog_bar=True)
        self.log("subset_match_score", subset_scores, prog_bar=True)

        self.validation_step_outputs.clear()
        print('-'*10,'\nval_loss: ', avg_loss, '\nem_score: ', em_scores, '\nsubset_match_score', subset_scores)
        return {
            "val_loss": avg_loss,
            "em_score": em_scores,
            "subset_match_score": subset_scores,
        }

    def configure_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"

        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        optimizer = Adafactor(optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False,
                             relative_step=False)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def get_tqdm_dict(self):
        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

        return tqdm_dict


    def train_dataloader(self):
        n_samples = self.n_obs['train']
        train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", num_samples=n_samples, args=self.hparams)
        sampler=RandomSampler(train_dataset)

        return DataLoader(
                          train_dataset,
                          sampler=sampler,
                          batch_size=self.hparams.train_batch_size,
                          num_workers=4,
                          persistent_workers=True
                      )

    def val_dataloader(self):
        n_samples = self.n_obs['validation']

        validation_dataset = get_dataset(tokenizer=self.tokenizer, type_path="validation", num_samples=n_samples, args=self.hparams)
        sampler=RandomSampler(validation_dataset)
        return DataLoader(validation_dataset, batch_size=self.hparams.eval_batch_size, sampler =sampler, num_workers=0)


    def test_dataloader(self):
        n_samples = self.n_obs['test']
        test_dataset = get_dataset(tokenizer=self.tokenizer, type_path="test", num_samples=n_samples, args=self.hparams)

        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=0)


    def on_save_checkpoint(self, checkpoint):
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

logger = logging.getLogger(__name__)

class LoggingCallback(pl.Callback):
    def on_validation_end(self, trainer, pl_module):
        logger.info("***** Validation results *****")
        if pl_module.is_logger():
            metrics = trainer.callback_metrics
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    logger.info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer, pl_module):
        logger.info("***** Test results *****")
        if pl_module.is_logger():
            metrics = trainer.callback_metrics
            output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
            with open(output_test_results_file, "w") as writer:
                for key in sorted(metrics):
                    if key not in ["log", "progress_bar"]:
                        logger.info("{} = {}\n".format(key, str(metrics[key])))
                        writer.write("{} = {}\n".format(key, str(metrics[key])))

prefix_path='' #Path to custom training data. Name the training corpus train_context.csv
class Pretrain(Dataset):
    def __init__(self, tokenizer, type_path, num_samples, input_length, output_length, print_text=False):
      self.dataset = self.split_into_segment(pd.read_csv(prefix_path+"train_context.csv"),input_length)
      self.input_length = input_length
      self.tokenizer = tokenizer
      self.output_length = output_length
      self.print_text = print_text

    def split_into_segment(self, ds, input_length, min_length=5):
        new_rows = []

        for index, row in ds.iterrows():
            text = row['context']
            words = text.split()

            # Skip original text if it's already too short
            if len(words) < min_length:
                continue

            # If text is within limit, keep as is
            if len(words) <= input_length:
                new_rows.append({'context': text})
                continue

            # Split into sentences first using Russian-specific pattern
            sentences = re.split(r'(?<=[.!?])\s+', text)
            current_chunk = []
            current_length = 0

            for sentence in sentences:
                sentence_words = sentence.split()
                sentence_len = len(sentence_words)

                # Skip empty sentences
                if sentence_len == 0:
                    continue

                # If sentence itself is too long, split into word chunks
                if sentence_len > input_length:
                    if current_chunk:
                        chunk_text = ' '.join(current_chunk)
                        if len(current_chunk) >= min_length:
                            new_rows.append({'context': chunk_text})
                        current_chunk = []
                        current_length = 0

                    # Split the long sentence into word chunks
                    for i in range(0, sentence_len, input_length):
                        chunk = sentence_words[i:i+input_length]
                        if len(chunk) >= min_length:
                            new_rows.append({'context': ' '.join(chunk)})
                    continue

                # Normal sentence processing
                if current_length + sentence_len > input_length:
                    chunk_text = ' '.join(current_chunk)
                    if len(current_chunk) >= min_length:
                        new_rows.append({'context': chunk_text})
                    current_chunk = [sentence]
                    current_length = sentence_len
                else:
                    current_chunk.append(sentence)
                    current_length += sentence_len

            # Add remaining chunk if it meets length requirements
            if current_chunk and len(current_chunk) >= min_length:
                chunk_text = ' '.join(current_chunk)
                new_rows.append({'context': chunk_text})

        # Create new DataFrame with filtered segments
        new_ds = pd.DataFrame(new_rows)
        return pd.concat([ds, new_ds], ignore_index=True)

    def __len__(self):
        return len(self.dataset)

    def clean_text(self, text):
        text = text.replace('Example of text:', '')
        text = text.replace('Example of Summary:', '')
        text = text.replace('\n','')
        text = text.replace('``', '')
        text = text.replace('"', '')

        return text

    def span_corruption_mask(self, text, noise_span_length=3, noise_density=0.15):
        words = text.split()
        max_index = len(words)

        # Handle edge cases
        if max_index < noise_span_length * 2:
            return [0] * max_index

        mask = [0] * max_index
        max_spans = (max_index // noise_span_length) - 1
        target_num_spans = min(
            math.ceil((max_index * noise_density) / noise_span_length),
            max_spans
        )

        exclude = set()
        attempts_per_span = 0
        max_attempts_per_span = max(50, max_index // 2)  # Dynamic attempt limit
        placed_spans = 0

        for _ in range(target_num_spans):
            success = False
            attempts = 0

            while attempts < max_attempts_per_span:
                rand_num = np.random.randint(0, max_index - noise_span_length + 1)

                # Check if position is available and has enough space
                span_positions = set(range(rand_num, rand_num + noise_span_length))
                if span_positions.isdisjoint(exclude):
                    # Check remaining capacity
                    available_positions = max_index - len(exclude)
                    if available_positions < noise_span_length:
                        break

                    # Mark excluded area (current span + buffer zone)
                    buffer = 1  # Number of words to exclude around the span
                    exclude.update(range(
                        max(0, rand_num - buffer),
                        min(max_index, rand_num + noise_span_length + buffer)
                    ))

                    # Apply mask
                    for i in span_positions:
                        mask[i] = 1
                    placed_spans += 1
                    success = True
                    break

                attempts += 1

            if not success:
                break  # Exit if we can't place this span

        # Fallback: If no spans placed, use first valid position
        if placed_spans == 0:
            for i in range(0, max_index - noise_span_length + 1):
                if all(mask[j] == 0 for j in range(i, i + noise_span_length)):
                    for j in range(i, i + noise_span_length):
                        mask[j] = 1
                    break

        return mask

    def noise_span_to_unique_sentinel(self, text, mask,sentinels):
        tokens = text.split()
        text_ = []
        one_count=0
        sentinel_cnt=0
        for i in range(len(tokens)):
            if mask[i] == 1:
                one_count+=1
                if one_count==1:
                    text_.append(sentinels[sentinel_cnt])
                    sentinel_cnt+=1
                else:
                    if one_count==3:
                        one_count=0
            else:
                text_.append(tokens[i])
        text_ = ' '.join(text_)
        return text_

    def nonnoise_span_to_unique_sentinel(self, text, mask,sentinels):
        tokens = text.split()
        text_ = []
        zero_first=True
        sentinel_cnt=0
        for i in range(len(tokens)):
            if mask[i] == 0:
                if zero_first:
                    text_.append(sentinels[sentinel_cnt])
                    zero_first=False
                    sentinel_cnt+=1
            else:
                zero_first=True
                text_.append(tokens[i])
        text_ = ' '.join(text_)
        return text_

    def convert_to_features(self, example_batch):

        if self.print_text:
            print("Input Text: ", self.clean_text(example_batch['context']))
        text = self.clean_text(example_batch['context'])
        mask = self.span_corruption_mask(text)
        sentinels=[]
        for i in range(100):
            sentinels.append(f'<extra_id_{i}>')
        input_ = self.noise_span_to_unique_sentinel(text,mask,sentinels)
        target_ = self.nonnoise_span_to_unique_sentinel(text,mask,sentinels)
        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length,
                                                     padding='max_length', truncation=True, return_tensors="pt")

        targets = self.tokenizer.batch_encode_plus([target_], max_length=self.output_length,
                                                     padding='max_length', truncation=True, return_tensors="pt")

        return source, targets

    def __getitem__(self, index):
        source, targets = self.convert_to_features(self.dataset.iloc[index])

        source_ids = source["input_ids"].squeeze()
        target_ids = targets["input_ids"].squeeze()

        src_mask    = source["attention_mask"].squeeze()
        target_mask = targets["attention_mask"].squeeze()

        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}

def get_dataset(tokenizer, type_path, num_samples, args):
    return Pretrain(tokenizer=tokenizer, type_path=type_path, num_samples=num_samples,  input_length=args.max_input_length,
                        output_length=args.max_output_length)


In [13]:
from tqdm.auto import tqdm

parser = ArgumentParser()
parser.add_argument('--input_length', default=128)
parser.add_argument('--output_length', default=128)
parser.add_argument('--num_train_epochs', default=1)
parser.add_argument('--output_dir', default='t5_pretraining')
parser.add_argument('--train_batch_size', default=8)
parser.add_argument('--learning_rate', default=1e-3)
parser.add_argument('--model', default='t5-base')
hparam, unknown = parser.parse_known_args()

args_dict = dict(
    output_dir="",
    model_name_or_path=hparam.model,
    tokenizer_name_or_path=hparam.model,
    max_input_length=int(hparam.input_length),
    max_output_length=int(hparam.output_length),
    freeze_encoder=False,
    freeze_embeds=False,
    learning_rate=1e-5,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    warmup_steps=0,
    train_batch_size=4,
    eval_batch_size=4,
    num_train_epochs=2,
    gradient_accumulation_steps=1,
    n_gpu=1,
    resume_from_checkpoint=None,
    val_check_interval = 1.0,
    n_val=0,
    val_percent_check= 0,
    n_train=-1,
    n_test=-1,
    early_stop_callback=False,
    fp_16=False,
    opt_level='O1',
    max_grad_norm=1.0,
    seed=101,
)

args_dict.update({'output_dir': hparam.output_dir, 'num_train_epochs':int(hparam.num_train_epochs),
                'train_batch_size': int(hparam.train_batch_size), 'eval_batch_size': int(hparam.train_batch_size), 'learning_rate': float(hparam.learning_rate)})
args = argparse.Namespace(**args_dict)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=args.output_dir,
    filename="iter-{step}",
    every_n_train_steps=100,
    save_last=True,
    monitor=None,
    )

train_params = dict(
    accumulate_grad_batches=args.gradient_accumulation_steps,
    max_epochs=args.num_train_epochs,
    precision= 16 if args.fp_16 else 32,
    gradient_clip_val=args.max_grad_norm,
    enable_checkpointing=checkpoint_callback,
    val_check_interval=args.val_check_interval,
    callbacks=[checkpoint_callback, LoggingCallback()],
    enable_progress_bar=True,
    log_every_n_steps=30
)

set_seed(42)
trainer = pl.Trainer(**train_params)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


### Training

At this step we load the model from checkpoint (we had do have several colab sessions) and train it for better general text understanding

In [18]:
model = T5FineTuner.load_from_checkpoint(
    'iter-step=5000.ckpt',
    hparams=args,
)

In [19]:
trainer.fit(model)#, ckpt_path=checkpoint_path)

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/t5_pretraining exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params | Mode
------------------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M  | eval
------------------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)
0         Modules in train mode
541       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


---------- 
val_loss:  tensor(1.7764, device='cuda:0') 
em_score:  tensor(0.) 
subset_match_score tensor(75.)


Training: |          | 0/? [00:00<?, ?it/s]

In [17]:
from pathlib import Path
save_path = Path(args.output_dir) / "final_model"
model.model.save_pretrained(save_path)
model.tokenizer.save_pretrained(save_path)
print(f"Model saved to {save_path}")

Model saved to t5_pretraining/final_model


### Test the model on a small example

In [44]:
def generate_sample(model, input_text, max_length=10):
    model.eval()
    inputs = model.tokenizer.encode(
        input_text,
        return_tensors="pt",
        max_length=model.hparams.max_input_length,
        truncation=True
    ).to(model.device)

    outputs = model.model.generate(
        inputs,
        max_length=max_length,
        num_beams=2,
        early_stopping=True
    )

    return model.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Usage example
test_samples = [
    "Съешь еще жтих мягких французских <extra_id_0>"]

for sample in test_samples:
    print(f"Input: {sample}")
    print(f"Output: {generate_sample(model, sample)}\n")

Input: Съешь еще жтих мягких французских <extra_id_0>
Output: еловек

