In [None]:
# Reload edited modules before every cell executes, so changes to the project
# code imported below (src.*, final_solution.*) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the parent directory importable (presumably where src/ and final_solution/ live).
# The membership guard keeps this cell idempotent: re-running it does not append "../" again.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
import os

# Configure CUDA device visibility before anything initialises CUDA (torch is only
# imported in later cells): enumerate GPUs by PCI bus id and expose only GPU 0.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer

In [None]:
import random

import numpy as np
import torch


def seed_everything(seed: int = 10) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and every
    visible CUDA device), and switches cuDNN to deterministic kernels with
    benchmark autotuning disabled.

    Args:
        seed: Seed applied to every generator. Defaults to 10.
    """
    random.seed(seed)
    # Only interpreters launched after this point see it; the hash seed of the
    # already-running kernel was fixed when it started.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing them, which can differ between runs.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# [pin]
file_path = "../data/data-hard2.csv"
root_path = "../data/"

# Rename the raw "message"/"label" columns to "input_text"/"target_text" and tag
# every row with the same task prefix "clsorg" (presumably consumed by the
# T5 data module / generation helper used below).
df = (
    pd.read_csv(file_path)
    .rename(columns={"message": "input_text", "label": "target_text"})
    .assign(prefix="clsorg")
)
df.sample(20)

Unnamed: 0,input_text,target_text,prefix
6626,🛢🇪🇺#газ #европа #запасы По данным Gas Infrast...,48-3,clsorg
4531,🇷🇺 $LKOH #buyback ЛУКОЙЛ ПРОСИТ РАЗРЕШЕНИЯ У ...,111-4,clsorg
4064,☎️ Ростелеком: брать или не брать — вот в чем ...,142-4,clsorg
1326,#CHMF Cеверсталь изучает возможность строитель...,152-4,clsorg
4884,🇷🇺#NVTK #расписки «НОВАТЭК» сообщает о заверш...,115-3,clsorg
1128,"""📢Banking news 🔻ЦБ оштрафовал Тинькофф Банк и...",7-4,clsorg
3721,​​Эталон - как спрятать слабость за сделкой? ...,56-3;218-2,clsorg
6649,🛢🇷🇺#NVTK #GAZP #спг #газ Перспективный проект ...,115-3,clsorg
3854,​​🟢 ИТОГИ ДНЯ. Российский рынок акций вновь ра...,99-3;227-4;89-3;90-3,clsorg
2081,Банк из северной столицы идет на юг Вышел о...,33-4,clsorg


In [None]:
# Tokenizer for the Russian T5 checkpoint; it is handed to both the data module and
# the generation helper below. ("t5-small" was tried earlier as an alternative.)
m_name = "cointegrated/rut5-small"
tokenizer = T5Tokenizer.from_pretrained(m_name)

In [None]:
from src.t5.dataset import NERDataModel

# Loader / training settings. EPOCHS is not read in this part of the notebook.
BATCH_SIZE = 128
EPOCHS = 10
num_workers = 12

# Fixed random_state so the 25% hold-out (test_df) is identical on every run.
train_df, test_df = train_test_split(
    df,
    test_size=0.25,
    random_state=42,
)
data_module = NERDataModel(
    train_df,
    test_df,
    tokenizer,
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
)
data_module.setup()

In [None]:
from transformers import T5ForConditionalGeneration

In [None]:
# Convert the Lightning checkpoint ner-v8.ckpt into a Hugging Face `save_pretrained` directory.
# NOTE(review): the skeleton here is "t5-small" while the tokenizer above is
# "cointegrated/rut5-small", and "./pretrained" is not the directory the following cells
# load ("../pretrained-rut5-2"). Confirm which base model ner-v8 was trained from.
m_name = "t5-small"
trained_model = T5ForConditionalGeneration.from_pretrained(m_name, return_dict=True)
# map_location="cpu" lets this run even if the checkpoint holds GPU tensors and no GPU is
# visible; load_state_dict copies the weights into the (CPU) model either way.
# torch.load unpickles the file, so only use it on trusted checkpoints.
checkpoint = torch.load("./checkpoints/ner-v8.ckpt", map_location="cpu")
# Model weights are stored under the "model." prefix (presumably the attribute name of the
# wrapped HF model in the Lightning module). Keep only those keys and strip the prefix;
# any other entries are skipped instead of being collapsed onto an empty key name.
ckpt_prefix = "model."
state_dict = {
    k[len(ckpt_prefix):]: v
    for k, v in checkpoint["state_dict"].items()
    if k.startswith(ckpt_prefix)
}
# Drop the rest of the checkpoint (e.g. optimizer state, if present) to free memory.
del checkpoint
trained_model.load_state_dict(state_dict)
trained_model.save_pretrained("./pretrained")

In [None]:
# Re-save the fine-tuned model with float16 weights.
m_name = "../pretrained-rut5-2"
load_kwargs = dict(return_dict=True, torch_dtype=torch.float16)
trained_model = T5ForConditionalGeneration.from_pretrained(m_name, **load_kwargs)
trained_model.save_pretrained("../pretrained-rut5-2-fp16")

In [None]:
# Load the float16 export and move it to the GPU (only GPU 0 is visible) for inference.
m_name = "../pretrained-rut5-2-fp16"
trained_model = T5ForConditionalGeneration.from_pretrained(
    m_name,
    return_dict=True,
    torch_dtype=torch.float16,
).cuda()

In [None]:
import torch

from final_solution.utils import generate_answer_batched, postprocess_predictions
from src.t5.utils import evaluate_metric

In [None]:
# [pin]
# Greedy decoding (num_beams=1) over the held-out split. inference_mode disables autograd
# tracking; autocast runs eligible CUDA ops in reduced precision.
# torch.autocast(device_type="cuda") replaces the deprecated torch.cuda.amp.autocast().
# NOTE(review): max_source_length=396 is an unexplained constant — confirm it matches training.
with torch.inference_mode(), torch.autocast(device_type="cuda"):
    predictions = generate_answer_batched(
        trained_model=trained_model,
        tokenizer=tokenizer,
        data=test_df[:],
        batch_size=128,
        num_beams=1,
        max_source_length=396,
        max_target_length=40,
        verbose=False,
    )

100%|██████████| 14/14 [00:12<00:00,  1.11it/s]


In [None]:
# Gold labels: target_text is a ";"-separated list of "company-sentiment" pairs; only the
# FIRST pair of each row is scored (the predictions are reduced to their first pair too).
ldf = test_df.copy()
# rsplit with n=1 always yields at most two columns (company, sentiment), splitting on the
# last "-", so a stray extra "-" cannot break the two-column assignment.
ldf[["tcomp", "tsent"]] = (
    ldf["target_text"]
    .str.split(";", n=1)
    .str[0]
    .str.rsplit("-", n=1, expand=True)
)

In [None]:
# Keep the first (company, sentiment) pair parsed for each row; rows with no parsed
# prediction fall back to the placeholder pair [0, 1].
orgsent = [
    pairs[0] if len(pairs) else [0, 1]
    for pairs in postprocess_predictions(predictions)
]
org = [pair[0] for pair in orgsent]
sent = [pair[1] for pair in orgsent]

In [None]:
# Predictions and gold labels must align row-for-row before scoring.
assert len(org) == len(sent) == len(ldf), (len(org), len(sent), len(ldf))
len(org), len(ldf["tcomp"].tolist())

In [None]:
# [pin]

# Score company and sentiment predictions against the first gold pair of each row.
evaluate_metric(
    company_labels=ldf["tcomp"].to_list(),
    company_predictions=org,
    sentiment_labels=ldf["tsent"].to_list(),
    sentiment_predictions=sent,
)

{'total': 66.29664090017164,
 'f1': 0.7122964543670691,
 'accuracy': 0.6136363636363636}