In [None]:
%load_ext tensorboard

In [1]:
import pandas as pd
from py.misс import load_data
from py.qa_models.dataset_model import QADataModule
from transformers import AutoTokenizer, T5Tokenizer

MODEL_NAME = "ai-forever/ruT5-base"
BATCH_SIZE = 24
N_EPOCHS = 3
tokenizer:T5Tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy = False)
MAX_TEXT_TOKEN_LEN = 396
train_df = load_data("datasets\\sberquad\\train_v1.0\\train_v1.0.json", dropDuplicates=False, tokenizer = tokenizer, max_context_token_len=MAX_TEXT_TOKEN_LEN)

context_token_len = train_df["context_token_len"].max()
question_token_len = train_df["question_token_len"].max()
answer_max_token_len = train_df["answer_max_token_len"].max()

val_df = load_data("datasets\\sberquad\\dev_v1.0\\dev_v1.0.json", dropDuplicates=False, tokenizer = tokenizer, max_context_token_len=context_token_len)

#idxmax = train_df.idxmax(axis = 0)

dataLoader = QADataModule(tokenizer, train_df, val_df, val_df, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE,summary_max_token_len=answer_max_token_len,text_max_token_len=context_token_len)

dataLoader.setup()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
print(f"max_context_token_len: {context_token_len}")
print(f"max_question_token_len: {question_token_len}")
print(f"max_answer_max_token_len: {answer_max_token_len}")

max_context_token_len: 393
max_question_token_len: 140
max_answer_max_token_len: 232


In [4]:
len(train_df)

44954

In [None]:
from py.qa_models.model import QAModule
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import (ModelCheckpoint, EarlyStopping)
from transformers import T5ForConditionalGeneration

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
t5_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
model = QAModule(t5_model, lr=1e-4, weight_decay=1e-5)
model.to(device)

torch.set_float32_matmul_precision('medium')

tb_logger = TensorBoardLogger('logs', name="t5_qa",)

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="best_model-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
    every_n_epochs=1,
    verbose=True
)

#early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    logger=tb_logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    log_every_n_steps=1,
)

trainer.fit(model, dataLoader)
#trainer.validate(model, dataLoader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\SawKing\Documents\T5\.venv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:652: Checkpoint directory C:\Users\SawKing\Documents\T5\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params | Mode
------------------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M  | eval
------------------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

c:\Users\SawKing\Documents\T5\.venv\Lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [4]:
trainer.test(datamodule=dataLoader)

Restoring states from the checkpoint path at C:\Users\SawKing\Documents\T5\checkpoints\best_model-epoch=00-val_loss=2.88.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at C:\Users\SawKing\Documents\T5\checkpoints\best_model-epoch=00-val_loss=2.88.ckpt
c:\Users\SawKing\Documents\T5\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_loss_epoch        2.7646102905273438
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 2.7646102905273438}]

In [None]:
qa_sample = val_df.iloc[0]
qa_sample

question        Вопрос|Где встречаются первые упоминания о стр...
context         Контекст|Первые упоминания о строении человече...
answer_text                                      в Древнем Египте
answer start                                                   60
answer_end                                                     76
Name: 0, dtype: object