In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the parent directory importable so the `src.*` imports further down resolve.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
import os

# Expose only device index 0 to CUDA, with devices enumerated in PCI bus order.
# Set here, ahead of the `import torch` further down in the notebook.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    }
)

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer

In [None]:
import random

import numpy as np
import torch


def seed_everything(seed=10):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG
            and PyTorch (CPU and all CUDA devices). It is also exported as the
            ``PYTHONHASHSEED`` environment variable.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched after this point, not the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Disable cuDNN autotuning: benchmarking may select different kernels
    # from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# [pin]
file_path = "../data/data-hard.csv"
root_path = "../data/"


# Read the CSV, attach the constant "clsorg" prefix column, and rename the
# raw `message` / `label` columns to `input_text` / `target_text`.
df = (
    pd.read_csv(file_path)
    .assign(prefix="clsorg")
    .rename(columns={"message": "input_text", "label": "target_text"})
)
df.sample(20)

Unnamed: 0,input_text,target_text,prefix
5222,🇷🇺#TCSG #отчетность КОНСЕНСУС: TCS Group во I...,225-4,clsorg
4135,⏰ Доброе утро! 16 марта 🌍 Ночное дежурство (з...,228-3;251-3,clsorg
3130,СПБ Биржа начнет торги ценными бумагами восьми...,255-4,clsorg
1629,#SMLT Откуп в акциях Самолёта: гэп на открыти...,56-3,clsorg
3917,​​🟢 ИТОГИ ДНЯ. Российские акции немного подрос...,90-2;152-2,clsorg
4392,⛔️ Россия запрещает экспорт бензина. Какие ком...,25-2,clsorg
5193,🇷🇺#SPBE #отчетность Итоги торгов на СПБ Бирже...,255-3,clsorg
2996,Просто вспомните как просрался Севка на ожидан...,90-5,clsorg
5391,🇷🇺#авиа #россия Россия может открыть прямые а...,32-4,clsorg
4306,⚡️ Сбер (SBER) отыграл всё падение на СВО. #хв...,150-3,clsorg


In [None]:
# m_name = "t5-small"
# Model identifier the tokenizer is loaded from.
# NOTE(review): `m_name` is reassigned in several later cells; the tokenizer
# keeps whatever value was active when this cell ran.
m_name = "cointegrated/rut5-small"
tokenizer = T5Tokenizer.from_pretrained(m_name)

In [None]:
from src.t5.dataset import NERDataModel

BATCH_SIZE = 128
EPOCHS = 10  # NOTE(review): not referenced in the visible cells — confirm it is still needed
num_workers = 12
# Hold out 25% of the rows; the fixed random_state keeps the split stable across runs.
train_df, test_df = train_test_split(df, test_size=0.25, random_state=42)
# Wrap both splits and the tokenizer in the project's data module.
data_module = NERDataModel(
    train_df, test_df, tokenizer, batch_size=BATCH_SIZE, num_workers=num_workers
)
data_module.setup()

In [None]:
from transformers import T5ForConditionalGeneration

In [None]:
# Convert a training checkpoint into a Hugging Face `save_pretrained` directory.
# NOTE(review): the tokenizer earlier in this notebook comes from
# "cointegrated/rut5-small"; confirm "t5-small" is the architecture that
# ner-v8.ckpt was trained with.
m_name = "t5-small"
trained_model = T5ForConditionalGeneration.from_pretrained(m_name, return_dict=True)
# map_location="cpu" lets the checkpoint load regardless of the device it was
# saved from; the freshly created model lives on the CPU anyway.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load("./checkpoints/ner-v8.ckpt", map_location="cpu")["state_dict"]


def _strip_wrapper_prefix(key, marker="model."):
    """Return the part of `key` after the first `marker`, or `key` itself if absent."""
    _, found, remainder = key.partition(marker)
    # Without the `found` check a key lacking the marker would collapse to "".
    return remainder if found else key


state_dict = {_strip_wrapper_prefix(k): v for k, v in state_dict.items()}
trained_model.load_state_dict(state_dict)
trained_model.save_pretrained("./pretrained")

In [None]:
# Re-load the saved model with float16 weights and save it to a separate directory.
# NOTE(review): the preceding cell writes to "./pretrained" while this one reads
# "../pretrained-rut5-2" — confirm which directory is the intended source.
m_name = "../pretrained-rut5-2"
trained_model = T5ForConditionalGeneration.from_pretrained(
    m_name, return_dict=True, torch_dtype=torch.float16
)
trained_model.save_pretrained("../pretrained-rut5-2-fp16")

In [None]:
# Load the float16 export and move it to the GPU for the inference cells below.
m_name = "../pretrained-rut5-2-fp16"
trained_model = T5ForConditionalGeneration.from_pretrained(
    m_name, return_dict=True, torch_dtype=torch.float16
)
trained_model.cuda();  # trailing ";" suppresses the cell's repr output

In [None]:
import torch

from src.t5.utils import evaluate_metric, generate_answer_batched

In [None]:
# [pin]
# Arguments for the batched generation helper, gathered in one place.
# num_beams=1 — presumably greedy decoding; verify against generate_answer_batched.
generation_kwargs = dict(
    trained_model=trained_model,
    tokenizer=tokenizer,
    data=test_df[:],
    batch_size=128,
    num_beams=1,
    max_source_length=396,
    max_target_length=40,
    verbose=False,
)

# Run generation without autograd tracking and under CUDA autocast.
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        predictions = generate_answer_batched(**generation_kwargs)

100%|██████████| 15/15 [00:15<00:00,  1.02s/it]


In [None]:
# Work on an independent copy so the columns added below never touch test_df.
# A plain .copy() is enough; slicing the copy again with [:] would make `ldf`
# a slice and can trigger SettingWithCopyWarning on the assignment below.
ldf = test_df.copy()
# Keep only the first ";"-separated entry of each label and split it on "-"
# into the company part (`tcomp`) and the sentiment part (`tsent`).
# Both columns hold strings, not numbers.
ldf[["tcomp", "tsent"]] = (
    ldf["target_text"].str.split(";", expand=True)[0].str.split("-", expand=True)
)

In [None]:
from src.t5.utils import postprocess_predictions

# Keep the first parsed item of every prediction; when nothing was parsed,
# fall back to the placeholder pair [0, 1].
orgsent = [
    parsed[0] if len(parsed) else [0, 1]
    for parsed in postprocess_predictions(predictions)
]
# Split the pairs into parallel lists: first element -> org, second -> sent.
org = [pair[0] for pair in orgsent]
sent = [pair[1] for pair in orgsent]

In [None]:
# Sanity check: the number of predictions should equal the number of labels.
len(org), len(ldf["tcomp"].tolist())

In [None]:
# [pin]

# Ground-truth company / sentiment labels taken from the split label columns.
company_labels = ldf["tcomp"].tolist()
sentiment_labels = ldf["tsent"].tolist()

# Score the predictions against the labels; the result is displayed as the cell output.
evaluate_metric(
    company_predictions=org,
    company_labels=company_labels,
    sentiment_predictions=sent,
    sentiment_labels=sentiment_labels,
)

{'total': 74.53947501804372,
 'f1': 0.7940727502217538,
 'accuracy': 0.6967167501391207}