In [None]:
# Enable IPython's autoreload extension in mode 2 so edited project modules are
# re-imported without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Put the parent directory on the import path; the `src.t5.*` imports in later
# cells are presumably resolved from there.
sys.path.append("../")

In [None]:
import os

# Expose only GPU 0 to CUDA and number devices by PCI bus ID. This cell runs
# before the cell that imports pytorch_lightning / torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

In [None]:
import pandas as pd
import pytorch_lightning as pl
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration, T5Tokenizer


class DistillNERModel(pl.LightningModule):
    """Lightning module that trains a T5 student next to a frozen T5 teacher.

    The teacher is stored with gradients disabled but is not consulted in
    ``forward`` yet: the loss returned by every step is the loss reported by
    the student model itself for the supplied ``labels``.

    Args:
        teacher: Reference model. ``requires_grad_(False)`` is applied to it
            in place, so the object passed in is frozen as well.
        student: Model whose parameters are optimized.
        lr: Learning rate handed to AdamW. Defaults to 0.001, the value that
            used to be hard-coded in ``configure_optimizers``.
    """

    def __init__(
        self,
        teacher: T5ForConditionalGeneration,
        student: T5ForConditionalGeneration,
        lr: float = 0.001,
    ):
        super().__init__()
        self.teacher_model = teacher.requires_grad_(False)
        self.student_model = student
        self.lr = lr

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the student model and return ``(loss, logits)`` from its output."""
        # TODO: the teacher forward pass (same inputs, output_hidden_states=True)
        # is not wired in yet; only the student is evaluated here.
        student_output = self.student_model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

        return student_output.loss, student_output.logits

    def _shared_step(self, batch, stage: str):
        """Compute the loss for one batch and log it as ``"<stage>_loss"``.

        ``batch`` must provide the keys ``input_ids``, ``attention_mask`` and
        ``labels``.
        """
        loss, _ = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the batch loss, logged as ``train_loss``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the batch loss, logged as ``val_loss``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Return the batch loss, logged as ``test_loss``."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Create AdamW over the trainable parameters only.

        ``self.parameters()`` also yields the teacher's weights, which were
        frozen in ``__init__``; they never receive gradients, so they are left
        out of the optimizer instead of being tracked for nothing.
        """
        trainable = [param for param in self.parameters() if param.requires_grad]
        return AdamW(trainable, lr=self.lr)

In [None]:
# Data directory and the labelled CSV inside it.
root_path = "../data/"
file_path = "../data/data-hard.csv"

# Read the CSV, add a constant "prefix" column with the value "clsorg", and
# rename the raw columns to input_text / target_text.
df = (
    pd.read_csv(file_path)
    .assign(prefix="clsorg")
    .rename(columns={"message": "input_text", "label": "target_text"})
)

# Display 20 random rows as a quick look at the data.
df.sample(20)

In [None]:
from copy import deepcopy


def _drop_blocks(blocks, positions):
    """Delete the entries of ``blocks`` at ``positions`` (indices into the list as given)."""
    size = len(blocks)
    targets = set(positions)
    for position in targets:
        if not 0 <= position < size:
            raise IndexError(
                f"block index {position} is out of range for {size} blocks"
            )
    # Delete from the back so an earlier deletion never shifts a later target.
    for position in sorted(targets, reverse=True):
        del blocks[position]


def _sync_layer_count(config, field, count):
    """Write ``count`` into ``config.<field>`` when that attribute exists."""
    if config is not None and hasattr(config, field):
        setattr(config, field, count)


def get_student_t5_model(model, encoder_drop=(3, 1), decoder_drop=(3, 1)):
    """Build a shallower student by copying ``model`` and deleting blocks.

    The input model is not modified; all edits happen on a deep copy. The
    block count of each stack is printed before pruning.

    Args:
        model: Object exposing ``encoder.block`` and ``decoder.block`` as
            list-like containers (e.g. a ``T5ForConditionalGeneration``).
        encoder_drop: Indices, relative to the unpruned encoder, of the blocks
            to remove. Order and duplicates do not matter. Defaults to
            ``(3, 1)``, the previously hard-coded choice.
        decoder_drop: Same, for the decoder.

    Returns:
        The pruned deep copy.

    Raises:
        IndexError: If an index lies outside ``0 .. len(block) - 1``.

    NOTE(review): in Hugging Face T5 the first block of each stack is believed
    to own the relative-position bias -- confirm before dropping index 0.
    """
    new_model = deepcopy(model)
    model_config = getattr(new_model, "config", None)

    plan = (
        (new_model.encoder, encoder_drop, "num_layers"),
        (new_model.decoder, decoder_drop, "num_decoder_layers"),
    )
    for stack, drop, model_field in plan:
        print(len(stack.block))
        _drop_blocks(stack.block, drop)
        remaining = len(stack.block)

        # Keep the recorded layer counts in step with the pruned stacks.
        # NOTE(review): presumably these counts are what a save/reload cycle
        # uses to rebuild the architecture -- confirm against the library.
        stack_config = getattr(stack, "config", None)
        if stack_config is not model_config:
            _sync_layer_count(stack_config, "num_layers", remaining)
        _sync_layer_count(model_config, model_field, remaining)

    return new_model


def count_parameters(model):
    """Return the summed ``numel()`` of the parameters of ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# The tokenizer is loaded by the "t5-small" identifier, while the teacher
# weights are read from the local ./pretrained/ directory.
# NOTE(review): presumably ./pretrained/ holds a t5-small based checkpoint so
# that tokenizer and model vocabularies agree -- confirm.
m_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(m_name)
teacher_model = T5ForConditionalGeneration.from_pretrained("./pretrained/")
# The student starts as a pruned copy of the teacher.
student_model = get_student_t5_model(teacher_model)

In [None]:
# [pin]
# Teacher-to-student ratio of trainable parameter counts.
count_parameters(teacher_model) / count_parameters(student_model)

1.3204871780546399

In [None]:
from sklearn.model_selection import train_test_split

from src.t5.dataset import NERDataModel

BATCH_SIZE = 64
EPOCHS = 10  # not referenced by any other cell visible in this notebook
# Hold out 25% of the rows as the test split; the seed is fixed at 42.
train_df, test_df = train_test_split(df, test_size=0.25, random_state=42)
# source_max_token_length=396 is the same number passed as max_length to
# generate_answer_batched in the generation cells.
data_module = NERDataModel(
    train_df, test_df, tokenizer, batch_size=BATCH_SIZE, source_max_token_length=396
)
data_module.setup()

In [None]:
# Pull the first batch from the training loader and show which keys it carries.
loader = data_module.train_dataloader()
encoded_batch = next(iter(loader))
encoded_batch.keys()

In [None]:
# Forward pass of the teacher on the batch fetched above. The output is only
# inspected, never back-propagated, so run it under no_grad to avoid building
# an autograd graph for the whole model.
import torch  # imported here because the notebook's other torch import sits in a later cell

with torch.no_grad():
    teacher_out = teacher_model(
        input_ids=encoded_batch["input_ids"],
        attention_mask=encoded_batch["attention_mask"],
        labels=encoded_batch["labels"],
    )

In [None]:
# [pin]
# Number of trainable parameters in the pruned student.
count_parameters(student_model)

45821440

In [None]:
# with torch.no_grad():
#     teacher_out = teacher_model(
#         input_ids=encoded_batch["input_ids"],
#         attention_mask=encoded_batch["attention_mask"],
#         labels=encoded_batch["labels"],
#         output_hidden_states=True,
#     )
#     student_out = student_model(
#         input_ids=encoded_batch["input_ids"].cuda(),
#         attention_mask=encoded_batch["attention_mask"].cuda(),
#         labels=encoded_batch["labels"].cuda(),
#         output_hidden_states=True,
#     )

In [None]:
# TODO: unfinished sketch of a hidden-state distillation loss -- not runnable as written.
# kl_loss = torch.nn.KLDivLoss()
# def distillation_loss(teacher_output, student_output):
#     loss = 0
#     for i, tdhs in enumerate(teacher_output):
#     decoder_loss = teacher_output['decoder_hidden_states'][]

In [None]:
# 0,2,4,5,6
# teacher_out["encoder_hidden_states"]

In [None]:
# Move both models to the GPU. The trailing semicolon keeps the notebook from
# echoing the model repr as cell output.
student_model.cuda()
teacher_model.cuda();

In [None]:
import torch

from src.t5.utils import evaluate_metric, generate_answer_batched

# Generate answers with the student for the first 1200 held-out rows, under
# inference mode and CUDA autocast.
# The result is kept under its own name because a later cell assigns
# `predictions` again for the teacher, which would otherwise discard it.
with torch.inference_mode(), torch.cuda.amp.autocast():
    student_predictions = generate_answer_batched(
        trained_model=student_model,
        tokenizer=tokenizer,
        data=test_df[:1200],
        batch_size=64,
        max_length=396,
    )

# Backward-compatible alias for cells that read `predictions`.
predictions = student_predictions

In [None]:
import torch

from src.t5.utils import generate_answer_batched

# Generate answers with the teacher for the same first 1200 held-out rows,
# under inference mode and CUDA autocast.
with torch.inference_mode(), torch.cuda.amp.autocast():
    teacher_predictions = generate_answer_batched(
        trained_model=teacher_model,
        tokenizer=tokenizer,
        data=test_df[:1200],
        batch_size=64,
        max_length=396,
    )

# The evaluation cells below read `predictions`, so when the notebook is run
# top to bottom they score the teacher's output.
predictions = teacher_predictions

In [None]:
def _split_company_sentiment(texts):
    """Parse the first ";"-separated entry of each string into two columns.

    The entry is cut at its first "-" only and the result is forced to exactly
    two columns (0 and 1), so an entry with several dashes or a column with no
    dash at all can no longer break the two-column assignment below.
    Presumably the strings look like "<company>-<sentiment>;..." -- confirm.
    """
    first_entry = texts.str.split(";").str[0]
    return first_entry.str.split("-", n=1, expand=True).reindex(columns=[0, 1])


# Work on the same 1200-row slice that was passed to generate_answer_batched.
# Slice first, then copy: this copies only the rows needed and makes `ldf` an
# independent frame before new columns are assigned to it.
ldf = test_df[:1200].copy()
ldf["predictions"] = predictions

# tcomp / tsent come from the labels, pcomp / psent from the model output.
ldf[["tcomp", "tsent"]] = _split_company_sentiment(ldf["target_text"])
ldf[["pcomp", "psent"]] = _split_company_sentiment(ldf["predictions"])

# Discard rows whose predicted company contains any non-digit character.
# NOTE(review): these rows are excluded from the metric rather than counted as
# errors -- confirm that this is intended.
has_non_digit = ldf["pcomp"].str.contains(r"[^\d]", na=False)
ldf = ldf.loc[~has_non_digit].copy()

# A missing or "0" predicted sentiment falls back to "1". The fill value is a
# string so the column does not mix int and str values.
ldf["psent"] = ldf["psent"].fillna("1").replace("0", "1")

In [None]:
# [pin]

# Score the parsed company / sentiment predictions against the parsed labels.
# The inputs are plain Python lists taken from the columns of `ldf`.
evaluate_metric(
    company_predictions=ldf["pcomp"].tolist(),
    company_labels=ldf["tcomp"].tolist(),
    sentiment_predictions=ldf["psent"].tolist(),
    sentiment_labels=ldf["tsent"].tolist(),
)

{'total': 32.873398668342645,
 'f1': 0.25963644709495964,
 'accuracy': 0.3978315262718932}