In [None]:
import os
import json
from typing import List

import torch
from torch.utils.data import DataLoader

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
)
import sacrebleu
from tqdm.auto import tqdm

# ----------------- Config -----------------
# Locations: input corpus directory, pretrained checkpoint, and output directory
# (model, tokenizer, metrics.json and predictions.tsv are written there).
DATA_DIR = "wikilarge"
MODEL_NAME = "t5-base"
OUTPUT_DIR = "t5_base_wikilarge_baseline"

# Token budgets: truncation length for encoder inputs, and for labels /
# generated outputs.
MAX_SOURCE_LENGTH = MAX_TARGET_LENGTH = 128

# Optimisation: epochs, per-step batch sizes, and the AdamW learning rate.
NUM_EPOCHS = 1
TRAIN_BATCH_SIZE = EVAL_BATCH_SIZE = 32
LR = 5e-5

# Optional cap on the number of training examples; None uses the whole split.
MAX_TRAIN_SAMPLES = None


# ---------- Simple SARI implementation (no easse) ----------

def _get_ngrams(tokens: List[str], n: int):
    return [" ".join(tokens[i:i+n]) for i in range(len(tokens)-n+1)]

def _sari_sentence(src: str, cand: str, refs: List[str], max_n: int = 4) -> float:
    """Compute a simplified, set-based sentence-level SARI score in [0, 1].

    SARI (Xu et al., 2016) averages three operation scores over n-gram orders
    1..max_n:

    * ADD  - F1 of n-grams the candidate introduces that are absent from the
      source, judged against n-grams any reference introduces;
    * KEEP - F1 of source n-grams the candidate retains, judged against source
      n-grams any reference retains;
    * DEL  - *precision only* of source n-grams the candidate drops, judged
      against source n-grams a reference drops. The original definition uses
      precision here because over-deleting hurts a simplification far more
      than under-deleting.

    Simplifications relative to reference implementations (e.g. EASSE): tokens
    come from ``str.split()``, and n-grams are compared as *sets*, with no
    multiplicity or per-reference count weighting. Scores are therefore close
    to, but not guaranteed identical with, published SARI numbers. With a single
    reference and no repeated n-grams the set and count formulations coincide.

    Args:
        src: Source (complex) sentence.
        cand: System output to score.
        refs: Reference simplifications (may be empty, giving 0 for ADD/KEEP).
        max_n: Highest n-gram order to include; must be >= 1.

    Returns:
        ``(mean ADD F1 + mean KEEP F1 + mean DEL precision) / 3``.

    Raises:
        ValueError: If ``max_n < 1``. Averaging over zero orders is undefined.
    """
    if max_n < 1:
        raise ValueError(f"max_n must be a positive integer, got {max_n}")

    def _safe_div(num: int, den: int) -> float:
        # An empty set on the denominator side scores 0 rather than raising.
        return num / den if den else 0.0

    def _f1(precision: float, recall: float) -> float:
        denom = precision + recall
        return 2 * precision * recall / denom if denom > 0 else 0.0

    src_toks = src.split()
    cand_toks = cand.split()
    refs_toks = [r.split() for r in refs]

    add_total = keep_total = del_total = 0.0
    for n in range(1, max_n + 1):
        src_ngrams = set(_get_ngrams(src_toks, n))
        cand_ngrams = set(_get_ngrams(cand_toks, n))
        refs_ngrams_list = [set(_get_ngrams(rt, n)) for rt in refs_toks]

        # ADD: new n-grams in the candidate vs. new n-grams in any reference.
        # set().union() of no arguments is the empty set, so no refs -> set().
        cand_add = cand_ngrams - src_ngrams
        refs_add = set().union(*(r - src_ngrams for r in refs_ngrams_list))
        overlap_add = len(cand_add & refs_add)
        add_total += _f1(
            _safe_div(overlap_add, len(cand_add)),
            _safe_div(overlap_add, len(refs_add)),
        )

        # KEEP: source n-grams retained by the candidate vs. by any reference.
        cand_keep = cand_ngrams & src_ngrams
        refs_keep = set().union(*(r & src_ngrams for r in refs_ngrams_list))
        overlap_keep = len(cand_keep & refs_keep)
        keep_total += _f1(
            _safe_div(overlap_keep, len(cand_keep)),
            _safe_div(overlap_keep, len(refs_keep)),
        )

        # DEL: precision only (see docstring). A dropped source n-gram counts
        # as a good deletion if at least one reference also lacks it.
        src_del = src_ngrams - cand_ngrams
        refs_del = set().union(*(src_ngrams - r for r in refs_ngrams_list))
        del_total += _safe_div(len(src_del & refs_del), len(src_del))

    add_avg = add_total / max_n
    keep_avg = keep_total / max_n
    del_avg = del_total / max_n

    return (add_avg + keep_avg + del_avg) / 3.0

def sari_corpus(sources: List[str], candidates: List[str], references: List[List[str]]) -> float:
    """Return the mean sentence-level SARI (on a 0-1 scale) over a corpus.

    Args:
        sources: Source sentences.
        candidates: System outputs, aligned with *sources*.
        references: For each source, a list of reference simplifications.

    Returns:
        The arithmetic mean of ``_sari_sentence`` over all aligned triples.

    Raises:
        ValueError: If the three sequences differ in length. The previous
            ``zip`` silently dropped the surplus, skewing the score. Also
            raised if the corpus is empty, where the mean is undefined.
    """
    lengths = (len(sources), len(candidates), len(references))
    if len(set(lengths)) != 1:
        raise ValueError(
            "sources, candidates and references must have the same length "
            f"(got {lengths[0]}, {lengths[1]}, {lengths[2]})"
        )
    if lengths[0] == 0:
        raise ValueError("cannot compute SARI on an empty corpus")

    scores = [
        _sari_sentence(s, c, rs)
        for s, c, rs in zip(sources, candidates, references)
    ]
    return sum(scores) / len(scores)


# ---------- Load WikiLarge ----------

def read_parallel(src_path, dst_path, filter_short=True, filter_ratio=True, max_ratio=1.0):
    """Read a line-aligned parallel corpus and optionally filter its pairs.

    Line *i* of ``src_path`` is paired with line *i* of ``dst_path``. Lines are
    stripped of surrounding whitespace. Pairs where either side is empty are
    always dropped.

    Args:
        src_path: Source (complex) sentences, one per line, UTF-8.
        dst_path: Target (simplified) sentences, one per line, UTF-8.
        filter_short: Drop pairs whose source has fewer than 5 whitespace
            tokens or whose target has fewer than 3.
        filter_ratio: Drop pairs whose target/source token-count ratio is
            ``>= max_ratio``.
        max_ratio: Exclusive upper bound on the target/source token ratio when
            ``filter_ratio`` is on. The default 1.0 keeps only pairs whose
            target is strictly shorter than the source.

    Returns:
        ``{"source": [...], "target": [...]}``: two equally long lists with
        the surviving pairs in file order.

    Raises:
        ValueError: If the two files have different numbers of lines. This
            used to be an ``assert``, which is skipped under ``python -O``.
    """
    with open(src_path, "r", encoding="utf-8") as f_src, \
         open(dst_path, "r", encoding="utf-8") as f_dst:
        src_lines = [line.strip() for line in f_src]
        dst_lines = [line.strip() for line in f_dst]

    if len(src_lines) != len(dst_lines):
        raise ValueError(
            f"parallel files are misaligned: {src_path} has {len(src_lines)} "
            f"lines but {dst_path} has {len(dst_lines)}"
        )

    src_out, dst_out = [], []
    for s, d in zip(src_lines, dst_lines):
        if not s or not d:
            continue
        # Both sides are non-empty after strip(), so each has >= 1 token.
        n_src = len(s.split())
        n_dst = len(d.split())
        if filter_short and (n_src < 5 or n_dst < 3):
            continue
        # Equivalent to n_dst / n_src >= max_ratio, without float division.
        if filter_ratio and n_dst >= max_ratio * n_src:
            continue

        src_out.append(s)
        dst_out.append(d)

    # NOTE: columns are named "source" and "target"
    return {"source": src_out, "target": dst_out}

def load_wikilarge(data_dir):
    """Load the WikiLarge train/valid/test files from *data_dir* as a DatasetDict.

    Expects ``wiki.full.aner.{train,valid,test}.{src,dst}`` inside *data_dir*.
    The train and validation splits go through ``read_parallel``'s length and
    ratio filters. The test split is only stripped of empty pairs. The
    resulting splits are keyed ``train``, ``validation`` and ``test``, each
    with ``source`` / ``target`` columns.
    """
    # (DatasetDict key, split tag used in the file names, apply filters?)
    layout = (
        ("train", "train", True),
        ("validation", "valid", True),
        ("test", "test", False),
    )

    # Read every split first, then build the Dataset objects.
    columns_by_split = {}
    for split_name, file_tag, apply_filters in layout:
        stem = f"wiki.full.aner.{file_tag}"
        columns_by_split[split_name] = read_parallel(
            os.path.join(data_dir, stem + ".src"),
            os.path.join(data_dir, stem + ".dst"),
            filter_short=apply_filters,
            filter_ratio=apply_filters,
        )

    return DatasetDict(
        {name: Dataset.from_dict(cols) for name, cols in columns_by_split.items()}
    )

# Build the train/validation/test DatasetDict from DATA_DIR (per-split
# filtering is decided inside load_wikilarge).
raw_datasets = load_wikilarge(DATA_DIR)


# ---------- Tokenizer + model ----------

# Pretrained tokenizer and seq2seq model for MODEL_NAME. Both are module-level
# globals used by preprocessing, training, generation and saving below.
print(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

def preprocess(examples):
    """Tokenize one batch of parallel examples for seq2seq training.

    ``examples["source"]`` becomes the encoder features, truncated to
    MAX_SOURCE_LENGTH tokens. ``examples["target"]`` is tokenized as target
    text, truncated to MAX_TARGET_LENGTH, and its ``input_ids`` are attached
    under the ``"labels"`` key. No padding is requested here.
    """
    features = tokenizer(
        examples["source"],
        max_length=MAX_SOURCE_LENGTH,
        truncation=True,
    )
    target_encoding = tokenizer(
        text_target=examples["target"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
    features["labels"] = target_encoding["input_ids"]
    return features

# Tokenize every split batch-wise and drop the raw text columns, so each split
# keeps only the tokenizer outputs plus "labels".
tokenized_datasets = raw_datasets.map(
    preprocess,
    remove_columns=["source", "target"],
    batched=True,
)

print(tokenized_datasets)


# ---------- Manual training loop (subset) ----------

# Train on CUDA when available, otherwise CPU. The model is moved there once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model.to(device)

# Optionally keep only the first MAX_TRAIN_SAMPLES training examples
# (None, or a cap at least as large as the split, uses the whole split).
train_dataset_tok = tokenized_datasets["train"]
if MAX_TRAIN_SAMPLES is not None and MAX_TRAIN_SAMPLES < len(train_dataset_tok):
    train_dataset_tok = train_dataset_tok.select(range(MAX_TRAIN_SAMPLES))

# Turns lists of tokenized examples into batches of tensors. It is also reused
# by the test loader. NOTE(review): padding / decoder-input details come from
# the installed transformers version; confirm if they matter.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

train_loader = DataLoader(
    train_dataset_tok,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
)

# Plain AdamW at a constant learning rate: no scheduler, warmup, or gradient
# clipping is used in this loop.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

print("Start training on subset:", len(train_dataset_tok), "examples")
model.train()
for epoch in range(NUM_EPOCHS):
    epoch_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False)
    for batch in epoch_bar:
        # Move every collated tensor, including "labels", to the training device.
        batch = {k: v.to(device) for k, v in batch.items()}
        # The batch carries the "labels" built by preprocess, so the forward
        # pass returns a loss directly.
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Shows the current batch's loss (not a running average).
        epoch_bar.set_postfix({"loss": loss.item()})


# ---------- Evaluation: BLEU + SARI (with tqdm) ----------

print("Evaluating on test set (BLEU + SARI)...")
test_dataset_tok = tokenized_datasets["test"]
test_raw = raw_datasets["test"]

# shuffle=False keeps predictions in the same order as test_raw. The metric
# alignment below relies on that ordering.
test_loader = DataLoader(
    test_dataset_tok,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator,
)

pred_texts = []
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating", leave=False):
        # Keep only the encoder inputs for generate(). Labels and any other
        # collator outputs are dropped.
        batch = {
            k: v.to(device)
            for k, v in batch.items()
            if k in ["input_ids", "attention_mask"]
        }
        # 4-beam search, with generated sequences capped at MAX_TARGET_LENGTH.
        outputs = model.generate(
            **batch,
            max_length=MAX_TARGET_LENGTH,
            num_beams=4,
        )
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        pred_texts.extend(decoded)

# Align lengths using correct raw keys
# (one prediction is produced per test example, so the slices are a guard
# that keeps sources/references the same length as pred_texts).
test_src = test_raw["source"][: len(pred_texts)]
test_ref = test_raw["target"][: len(pred_texts)]

# BLEU (0–100)
# Single reference stream: test_ref[i] is the only reference for pred_texts[i].
bleu = sacrebleu.corpus_bleu(pred_texts, [test_ref]).score

# SARI (our implementation is 0–1; convert to 0–100 for reporting)
# Each example is scored against exactly one reference (its .dst line).
sari_raw = sari_corpus(
    test_src,
    pred_texts,
    [[r] for r in test_ref],
)
sari = sari_raw * 100.0

print(f"Test BLEU: {bleu:.2f}")
print(f"Test SARI: {sari:.2f}")


# ---------- Save outputs to OUTPUT_DIR ----------

os.makedirs(OUTPUT_DIR, exist_ok=True)

# 1) Save model + tokenizer
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# 2) Save metrics (BLEU and SARI, both on a 0-100 scale)
metrics = {"bleu": float(bleu), "sari": float(sari)}
with open(os.path.join(OUTPUT_DIR, "metrics.json"), "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# 3) Save predictions for error analysis: one "src<TAB>ref<TAB>pred" row per
# test example. A tab inside a field would shift columns, and a newline or
# carriage return (possible in decoded model output) would split the row, so
# all three are replaced with a single space.
_TSV_SANITIZE = str.maketrans({"\t": " ", "\n": " ", "\r": " "})
pred_path = os.path.join(OUTPUT_DIR, "predictions.tsv")
with open(pred_path, "w", encoding="utf-8") as f:
    f.write("src\tref\tpred\n")
    for s, r, p in zip(test_src, test_ref, pred_texts):
        row = "\t".join(field.translate(_TSV_SANITIZE) for field in (s, r, p))
        f.write(row + "\n")

print("Saved model and outputs to:", OUTPUT_DIR)
