In [None]:
!pip install -q  datasets peft accelerate bitsandbytes
!pip install evaluate rouge_score nltk
import transformers, sys
print(transformers.__version__)
print(transformers.__file__)

In [None]:
# from __future__ import annotations
import time, nltk, numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    TrainerCallback,
)
import evaluate

# ---- Configuration -------------------------------------------------------
TRAIN_FILE = "train_data_clean.csv"  # CSV that is split into train / validation
TEST_FILE = "test_data_clean.csv"    # CSV loaded as the "test" split
VAL_FRAC = 0.1       # share of the training CSV held out for validation
MAX_SRC_LEN = 1024   # truncation length (tokens) for the source text
MAX_TGT_LEN = 256    # truncation length (tokens) for the summary / generation cap
BATCH_SIZE = 6       # per-device batch size for both training and evaluation
LR = 2e-5
EPOCHS = 3
LOG_STEPS = 10       # log every LOG_STEPS optimisation steps

# ---- NLTK resources used for sentence splitting ----------------------------
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource, quiet=True)

# ---- Data ------------------------------------------------------------------
all_data = load_dataset("csv", data_files={"train": TRAIN_FILE, "test": TEST_FILE})
# Carve a validation set out of the training CSV (fixed seed for repeatability).
_split = all_data["train"].train_test_split(test_size=VAL_FRAC, seed=42)
train_ds, val_ds = _split["train"], _split["test"]

# ---- Tokenizer -------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", use_fast=True)

def preprocess(batch):
    """Tokenise a batch of examples for sequence-to-sequence training.

    Args:
        batch: mapping with a "text" entry (source documents) and a "summary"
            entry (reference summaries), as passed by ``Dataset.map`` with
            ``batched=True``.

    Returns:
        The tokenizer output for the source texts with an extra "labels" entry
        holding the token ids of the tokenised summaries.
    """
    tok_inp = tokenizer(batch["text"], max_length=MAX_SRC_LEN, truncation=True)
    # `text_target=` is the supported way to tokenise labels; it replaces the
    # deprecated `tokenizer.as_target_tokenizer()` context manager.
    lbls = tokenizer(text_target=batch["summary"], max_length=MAX_TGT_LEN, truncation=True)
    tok_inp["labels"] = lbls["input_ids"]
    return tok_inp

# Tokenise the train and validation splits; the original columns are removed so
# that only the fields returned by `preprocess` are kept.
cols = train_ds.column_names
_map_opts = {"batched": True, "remove_columns": cols}
train_tok = train_ds.map(preprocess, desc="Tokenise train", **_map_opts)
val_tok = val_ds.map(preprocess, desc="Tokenise val", **_map_opts)

# Model, batch collator and metric.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
# Labels are padded with -100, the same sentinel that compute_metrics swaps
# back to the pad token before decoding.
collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)
rouge = evaluate.load("rouge")

def _sent_split(txts):
    return ["".join(nltk.sent_tokenize(t.strip())) for t in txts]

def _decode(seqs):
    return [tokenizer.decode([int(x) for x in seq if int(x) >= 0], skip_special_tokens=True) for seq in seqs]

def compute_metrics(pred):
    """Compute ROUGE scores (scaled to 0-100) for a batch of predictions.

    Args:
        pred: a pair ``(predictions, label_ids)``; when ``predictions`` is a
            tuple only its first element is used.

    Returns:
        Dict mapping each ROUGE metric name to its score multiplied by 100 and
        rounded to four decimals.
    """
    y_pred, y_true = pred
    if isinstance(y_pred, tuple):
        y_pred = y_pred[0]

    preds = _sent_split(_decode(y_pred))

    # Swap the -100 label padding back to the pad token id before decoding.
    label_ids = np.where(y_true != -100, y_true, tokenizer.pad_token_id)
    refs = _sent_split(_decode(label_ids))

    scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
    return {name: round(value * 100, 4) for name, value in scores.items()}

train_args = Seq2SeqTrainingArguments(
    # Output locations and checkpointing
    output_dir="./results",
    logging_dir="./logs",
    save_strategy="epoch",
    save_total_limit=3,
    # Logging
    logging_strategy="steps",
    logging_steps=LOG_STEPS,
    report_to="none",
    # Optimisation
    learning_rate=LR,
    weight_decay=0.01,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluation / generation
    eval_strategy="no",
    predict_with_generate=True,
    generation_max_length=MAX_TGT_LEN,
)

trainer = Seq2SeqTrainer(
    # What to train and how
    model=model,
    args=train_args,
    # Data
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=collator,
    # Text processing and scoring
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# if __name__ == "__main__":
    # trainer.train()

    # final_path = "./results/final_model"
    # trainer.save_model(final_path)
    # print(f"Model saved to {final_path}")

    # print("Validation:")
    # val_metrics = trainer.evaluate()
    # print(val_metrics)

In [None]:
# Make sure the NLTK sentence-tokeniser data is present before metrics are computed.
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource, quiet=True)

print("Running single validation evaluation…")
# With no explicit dataset, the trainer evaluates on the eval_dataset it was built with.
val_metrics = trainer.evaluate()
print("Validation ROUGE:", val_metrics)

In [None]:
import json, os

# Ask the trainer to save its state, then additionally dump that state to a
# standalone JSON file in the working directory (read back by the plotting cell).
trainer.save_state()

json_path = "./trainer_state_post.json"
trainer.state.save_to_json(json_path)
print("state written to", os.path.abspath(json_path))

In [None]:
import json
import matplotlib.pyplot as plt

# Read the state from the same relative path the state-saving cell writes to
# ("./trainer_state_post.json"). The previous hard-coded "/content/..." path
# only matched when the working directory happened to be /content.
with open("./trainer_state_post.json", encoding="utf-8") as f:
    history = json.load(f)["log_history"]

# Split the log into training-loss and eval-loss series, keyed by step.
train_steps, train_losses, eval_steps, eval_losses = [], [], [], []
for x in history:
    if "loss" in x:
        train_steps.append(x["step"])
        train_losses.append(x["loss"])
    if "eval_loss" in x:
        eval_steps.append(x["step"])
        eval_losses.append(x["eval_loss"])

fig, ax = plt.subplots()
ax.plot(train_steps, train_losses, label="Train Loss", marker="o")
ax.plot(eval_steps, eval_losses, label="Eval Loss", marker="x")
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
fig.tight_layout()
# Kept under /content so the archive cell (which zips /content/) picks it up.
fig.savefig("/content/loss_plot.png")

In [None]:
# Save the trainer's model to the ./my_final_model directory.
trainer.save_model("./my_final_model")

In [None]:
# Shell command: recursively archive everything under /content/ into /content.zip.
!zip -r /content.zip /content/

# Colab-only: hand the archive to google.colab's files.download helper.
from google.colab import files
files.download("/content.zip")

In [None]:
# !rm -rf /content/sample_data/

In [None]:
import torch

print("\nGenerating one example from validation set:")

# Take the first validation row as the demo input and its reference summary.
_example = val_ds[0]
sample, target = _example["text"], _example["summary"]

# Put the model in eval mode and generate without tracking gradients.
model.eval()
inputs = tokenizer(
    sample,
    return_tensors="pt",
    truncation=True,
    max_length=MAX_SRC_LEN,
).to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=MAX_TGT_LEN)
prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Only the first 500 characters of the source are printed.
print("SOURCE:\n", sample[:500], "...\n")
print("REFERENCE SUMMARY:\n", target, "\n")
print("MODEL PREDICTION:\n", prediction)

In [None]:
# Tokenise the "test" split with the same preprocess function and column list
# used for train/val, then evaluate the trainer's model on it.
test_tok = all_data["test"].map(preprocess, batched=True, remove_columns=cols, desc="Tokenise test")
test_metrics = trainer.evaluate(eval_dataset=test_tok)

In [None]:
# Show the metrics dict returned by the test-set evaluation.
print(test_metrics)