In [None]:
#! pip install datasets transformers rouge-score nltk

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import torch
from datasets import load_dataset
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import Seq2SeqTrainer
from transformers import TrainerCallback
import time
import sys

In [None]:
# Load the prepared table; later cells read its `text_combined` and
# `hierarchical_label` columns.
df = pd.read_csv('df-model-with-hierarchical-labels/df_for_model_with_hierarchical.csv')

# Optional: uncomment to subsample rows (with a fixed seed) for a quicker run
#df = df.sample(n=1000000, random_state=42)

In [None]:
# Convert the label to string.
# Missing labels are filled *before* the cast: on pandas versions where
# `astype(str)` stringifies missing values, NaN would otherwise become the
# literal text "nan", and the dataset class's `fillna("Unknown")` could no
# longer detect it (the model would then be trained to emit "nan").
df['hierarchical_label'] = df['hierarchical_label'].fillna("Unknown").astype(str)

In [None]:
# define a data set class for the seq2seq model

class Seq2SeqSpecimenDataset(Dataset):
    """Map-style dataset yielding tokenized (input, label) pairs for a seq2seq model.

    Each item is the ``text_combined`` value prefixed with a task prompt, paired
    with the tokenized ``hierarchical_label`` value. Sequences are truncated but
    not padded, so batches must be padded by the collator.

    Args:
        dataframe: Frame providing ``text_combined`` and ``hierarchical_label``
            columns. Missing values in either column become ``"Unknown"``;
            all other values are coerced to ``str``.
        tokenizer: Callable tokenizer that accepts ``text_target=`` and
            ``return_tensors='pt'`` and returns ``input_ids`` /
            ``attention_mask`` tensors with a leading batch dimension.
        max_input_length: Truncation length applied to the prompt + input text.
        max_target_length: Truncation length applied to the label text.
    """

    # Task prompt prepended to every input text
    _PROMPT_PREFIX = "predict cell token: "

    def __init__(self, dataframe, tokenizer, max_input_length=512, max_target_length=16):
        self.tokenizer = tokenizer
        # Coerce to str after filling gaps so numeric or otherwise non-string
        # cells cannot break the prompt concatenation or the tokenizer call.
        self.inputs = [str(value) for value in dataframe['text_combined'].fillna("Unknown")]
        self.targets = [str(value) for value in dataframe['hierarchical_label'].fillna("Unknown")]
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of (input, label) pairs."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Tokenize the pair at ``idx``.

        Returns:
            dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``
            tensors (the tokenizer's batch dimension is squeezed away).
        """
        input_text = self._PROMPT_PREFIX + self.inputs[idx]
        target_text = self.targets[idx]

        # No fixed padding: only truncate here, padding happens per batch
        model_input = self.tokenizer(
            input_text,
            max_length=self.max_input_length,
            truncation=True,
            return_tensors='pt'
        )

        # `text_target=` makes the tokenizer treat the string as a target sequence
        target = self.tokenizer(
            text_target=target_text,
            max_length=self.max_target_length,
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': model_input['input_ids'].squeeze(0),
            'attention_mask': model_input['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }




In [None]:
# model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoint name; the tokenizer and the model are both loaded from it
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)




In [None]:
# Hold out 20% of the rows for validation (fixed seed so the split is repeatable)

df_train, df_val = train_test_split(df, test_size=0.2, random_state=42)

# Wrap each split in the tokenizing dataset class
train_dataset = Seq2SeqSpecimenDataset(df_train, tokenizer)
val_dataset = Seq2SeqSpecimenDataset(df_val, tokenizer)





In [None]:
# Sanity check: number of training examples after the split
len(train_dataset)

In [None]:
# Training configuration and batch collation

# T5-family checkpoints are prone to overflow (NaN / zero loss) under fp16 mixed
# precision, and `fp16=True` is rejected outright when no CUDA device is present.
# Prefer bf16 when the GPU supports it and fall back to full precision otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",      # evaluate, log and checkpoint once per epoch
    logging_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=False,
    logging_dir="./logs",
    num_train_epochs=5,
    report_to=[],               # no external experiment trackers
    bf16=use_bf16,
    fp16=False,
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,  # important so it pads `labels` with -100 for loss masking
    padding=True  # enables dynamic padding
)

In [None]:


# Assemble the trainer from the model, the arguments, both datasets and the collator
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer  # NOTE(review): recent transformers releases prefer `processing_class=` here — confirm against the installed version
)

In [None]:
# logger to obtain good logs output when running
class EpochLoggerCallback(TrainerCallback):
    """Print a line when each epoch starts and ends, with its wall-clock duration."""

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Record the start time and announce the epoch that is about to run."""
        self.epoch_start_time = time.time()
        # `state.epoch` is the progress so far (0 when the first epoch starts,
        # and possibly None before training sets it), so add 1 to get the
        # number of the epoch being started and match the "Finished" line.
        epoch_number = round(state.epoch or 0) + 1
        print(f"🔄 Starting epoch {epoch_number}")
        sys.stdout.flush()

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report how many minutes the epoch that just completed took."""
        duration = (time.time() - self.epoch_start_time) / 60
        print(f"✅ Finished epoch {state.epoch:.0f} in {duration:.2f} minutes")
        sys.stdout.flush()


In [None]:
# Register the epoch logger so its messages are printed while training
trainer.add_callback(EpochLoggerCallback())

In [None]:
# Run fine-tuning; the prints bracket the training call
print("📢 Starting training...")
trainer.train()
print("🎉 Training complete.")

In [None]:
# Save the model and tokenizer into the same directory
save_path = "/kaggle/working/fine_tuned_model_t5_flan_less_cells_gt5"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

print(f"Model saved locally to {save_path}")