In [None]:
# ============ 0. Imports and Setup ============
import json, os
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from tabulate import tabulate
from sklearn.metrics import f1_score
from collections import Counter
from google.colab import drive
import torchcrf
from inspect import signature
from transformers import TrainingArguments

# ============ 1. Mount Drive and Load Data ============
# Colab-only: the SemEval data files are read from the user's Google Drive.
drive.mount('/content/drive')

DATA_DIR = Path("/content/drive/My Drive/SEMEVAL/data")
train_file = DATA_DIR.joinpath("training_set_task2.txt")
dev_file = DATA_DIR.joinpath("dev_set_task2.txt")

def load_jsonl(path):
    """Load annotated examples from the UTF-8 file at *path*.

    The file is first parsed as one JSON document (e.g. a top-level list of
    examples), which is how every caller in this notebook reads its data.
    If that fails, the text is parsed as JSON Lines instead: one JSON value
    per non-blank line, returned as a list.

    Raises:
        json.JSONDecodeError: if the text is neither valid JSON nor valid
            JSON Lines.
    """
    # `with` closes the handle; the previous bare open() leaked it.
    with open(path, "r", encoding="utf-8") as fh:
        text = fh.read()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # JSON Lines fallback, matching the function's name.
        return [json.loads(line) for line in text.splitlines() if line.strip()]

train_data = load_jsonl(train_file)
dev_data   = load_jsonl(dev_file)

# ============ 2. Build Label Map ============
# Sorted, de-duplicated, whitespace-stripped technique names found in either split.
all_techniques = sorted({
    span["technique"].strip()
    for record in train_data + dev_data
    for span in record.get("labels", [])
})
# Technique <-> class id; ids follow the sorted order of all_techniques.
id2label = dict(enumerate(all_techniques))
label2id = {tech: idx for idx, tech in id2label.items()}

# BIO tagging labels: "O" is id 0, followed by a (B-, I-) pair per technique.
bio_labels = ["O"] + [tag for t in all_techniques for tag in ("B-" + t, "I-" + t)]
bio_id2label = dict(enumerate(bio_labels))
bio_label2id = {tag: idx for idx, tag in bio_id2label.items()}

# ============ 3. Tokenizer ============
MODEL_NAME = "roberta-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# ============ 4. Stage 1 - Span Detection with BIO-CRF ============
class SpanDetectorModel(nn.Module):
    """BIO span tagger: a token-classification backbone feeding a linear-chain CRF.

    Args:
        model_name: checkpoint passed to
            ``AutoModelForTokenClassification.from_pretrained``.
        num_labels: number of BIO tags; sizes both the backbone head and the CRF.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.backbone = AutoModelForTokenClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            problem_type="single_label_classification",
        )
        self.crf = torchcrf.CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the backbone, then either score *labels* with the CRF or decode.

        With ``labels``: returns ``{"loss": mean CRF negative log-likelihood,
        "logits": backbone emission scores}``.
        Without ``labels``: returns ``{"logits": crf.decode(...)}``, i.e. the
        decoded tag-id sequences, not a tensor of scores.
        """
        emissions = self.backbone(input_ids=input_ids, attention_mask=attention_mask).logits
        valid = attention_mask.bool()
        if labels is None:
            return {"logits": self.crf.decode(emissions, mask=valid)}
        # torchcrf returns the log-likelihood; negate it to get a loss to minimize.
        nll = -self.crf(emissions, labels, mask=valid, reduction='mean')
        return {"loss": nll, "logits": emissions}

# One emission/CRF output per BIO tag ("O" plus a B-/I- pair per technique).
span_detector = SpanDetectorModel(model_name=MODEL_NAME, num_labels=len(bio_labels))

def tokenize_and_bio(example):
    """Tokenize one example and attach per-token BIO tag ids under "labels".

    For each gold span [start, end) of technique T, every token whose offsets
    overlap the span is tagged "I-T", except the first such token, which gets
    "B-T". Tokens touched by no span keep "O". If spans overlap, a later span's
    tags overwrite an earlier one's on the shared tokens. The offset mapping is
    removed from the returned encoding.
    """
    enc = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
        return_offsets_mapping=True,
    )
    offsets = enc.pop("offset_mapping")
    tags = ["O"] * len(offsets)

    for span in example.get("labels", []):
        lo, hi = span["start"], span["end"]
        technique = span["technique"].strip()
        # Tokens with a non-empty intersection with [lo, hi).
        covered = [
            idx for idx, (tok_start, tok_end) in enumerate(offsets)
            if tok_start < hi and tok_end > lo
        ]
        for rank, idx in enumerate(covered):
            tags[idx] = ("I-" if rank else "B-") + technique

    enc["labels"] = [bio_label2id[tag] for tag in tags]
    return enc

raw_datasets = DatasetDict({
    split_name: Dataset.from_list(records)
    for split_name, records in (("train", train_data), ("dev", dev_data))
})
# Tokenize and align BIO tags one example at a time for both splits.
bio_tokenized = raw_datasets.map(tokenize_and_bio, batched=False)

# ============ 5. Stage 2 - Span Classification ============
# Sequence classifier with one output per technique id in label2id.
# NOTE(review): not referenced in the code visible here; it loads a second
# full copy of MODEL_NAME.
span_classifier = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=len(label2id)
)

def extract_spans_from_bio(tokens, preds, id2label=None):
    """Convert a sequence of BIO tag ids into character-level technique spans.

    Args:
        tokens: mapping with an "offset_mapping" entry giving (start, end)
            character offsets for each token, aligned with *preds*.
        preds: sequence of tag ids, one per token (it may be shorter than the
            offsets, e.g. a decoded sequence that stops at the attention mask).
        id2label: tag id -> BIO tag string; defaults to the notebook's
            ``bio_id2label``.

    Returns:
        List of ``{"technique", "start", "end"}`` dicts in text order.

    A "B-" tag, an "I-" tag with no open span, or an "I-" tag of a different
    technique all start a new span. "O" closes the open span. Tokens with
    empty offsets (start >= end, e.g. special and padding tokens) are ignored.
    """
    if id2label is None:
        id2label = bio_id2label
    offsets = tokens["offset_mapping"]
    spans = []
    current = None
    for i, tag_idx in enumerate(preds):
        start, end = offsets[i]
        if start >= end:
            # No underlying text; extending a span here would set end=0.
            continue
        tag = id2label[int(tag_idx)]  # int() so tensor / numpy ids hash like ints
        if tag == "O":
            if current is not None:
                spans.append(current)
                current = None
            continue
        prefix, tech = tag.split("-", 1)
        if prefix == "B" or current is None or current["technique"] != tech:
            # Orphan I- tags (no open span) start a span instead of crashing.
            if current is not None:
                spans.append(current)
            current = {"technique": tech, "start": start, "end": end}
        else:  # I- continuation of the open span
            current["end"] = end
    if current is not None:
        spans.append(current)
    return spans

# ============ 6. Class Weights for Imbalanced Classification ============
# "Balanced" inverse-frequency weights over BIO tags:
#     weight[c] = total_tokens / (len(bio_labels) * count[c])
# Every tag, including "O" (id 0), must be counted. Previously "O" was skipped,
# so it fell back to count 1 and received the LARGEST weight. Padding positions
# (attention_mask == 0) are excluded from both counts and total. Tags never
# seen keep a count of 1 so the division is defined.
# NOTE(review): class_weights_tensor is not used by the CRF loss in the visible
# code; confirm where (or whether) it is consumed.
label_counts = Counter()
total_tokens = 0
for ex in bio_tokenized["train"]:
    for lbl, keep in zip(ex["labels"], ex["attention_mask"]):
        if not keep:
            continue  # padding position
        label_counts[lbl] += 1
        total_tokens += 1

class_weights = [
    total_tokens / (len(bio_labels) * label_counts.get(i, 1))
    for i in range(len(bio_labels))
]

class_weights_tensor = torch.tensor(class_weights)

# ============ 7. Official Partial-Overlap Scorer ============
def compute_PRF(gold_spans, pred_spans, label):
    """Partial-overlap precision, recall and F1 for one technique *label*.

    Only spans whose "technique" equals *label* are considered. With
    ``|a ∩ b|`` the character overlap of two spans:

    * precision = mean over predicted spans s of  max_t |s ∩ t| / |s|
    * recall    = mean over gold spans t      of  max_s |s ∩ t| / |t|

    Recall is normalized by the GOLD span length. Previously both terms
    divided by the predicted length, so a tiny prediction inside a long gold
    span scored full recall. Zero-length spans contribute 0 instead of
    raising ZeroDivisionError.

    NOTE(review): each span is matched to its single best counterpart (max).
    The SemEval propaganda-task papers sum over all (s, t) pairs instead;
    confirm which is intended if this must reproduce the official scorer.

    Returns:
        (P, R, F1); all 0.0 if either side has no span of *label*.
    """
    S = [s for s in pred_spans if s["technique"] == label]
    T = [t for t in gold_spans if t["technique"] == label]
    if not S or not T:
        return 0.0, 0.0, 0.0

    def _length(span):
        return span["end"] - span["start"]

    def C(a, b, h):
        # |a ∩ b| / h, where h is the length of the span being scored.
        if h <= 0:
            return 0.0
        inter = max(0, min(a["end"], b["end"]) - max(a["start"], b["start"]))
        return inter / h

    P = sum(max(C(s, t, _length(s)) for t in T) for s in S) / len(S)
    R = sum(max(C(s, t, _length(t)) for s in S) for t in T) / len(T)
    F1 = 2 * P * R / (P + R) if (P + R) > 0 else 0.0
    return P, R, F1


def _f1(p, r):
    """Harmonic mean of p and r, or 0.0 when both are 0."""
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0


def _threshold_spans(offsets, token_probs, thresholds):
    """Turn thresholded per-token BIO-tag probabilities into merged character spans.

    A technique is active at a token when the probability of its B- or I- tag
    exceeds that tag's threshold. Consecutive active tokens are merged into one
    span. A B- hit on a technique that is already open closes it and starts a
    new span. Tokens with empty offsets (start >= end, e.g. special/padding
    tokens) are skipped.
    """
    open_spans = {}  # technique -> span currently being extended
    spans = []
    for i, (s, e) in enumerate(offsets):
        if s >= e:
            continue
        begins, inside = set(), set()
        for lbl_idx, thresh in thresholds.items():
            if not token_probs[i, lbl_idx] > thresh:
                continue
            bio_label = bio_id2label[lbl_idx]
            if bio_label == "O":
                continue  # Ignore 'O' tags
            prefix, technique = bio_label.split("-", 1)
            (begins if prefix == "B" else inside).add(technique)
        active = begins | inside
        # Close spans that stop at this token or that a fresh B- tag restarts.
        for technique in list(open_spans):
            if technique not in active or technique in begins:
                spans.append(open_spans.pop(technique))
        for technique in sorted(active):
            if technique in open_spans:
                open_spans[technique]["end"] = e
            else:
                open_spans[technique] = {"technique": technique, "start": s, "end": e}
    spans.extend(open_spans.values())
    return spans


def evaluate_spans(examples, predictions, thresholds):
    """Score thresholded token predictions as technique spans against gold spans.

    Args:
        examples: iterable of dicts with "text" and "labels" (gold spans as
            dicts with "technique", "start", "end"), aligned with *predictions*.
        predictions: raw per-token scores indexable as predictions[k][token, tag];
            presumably (num_examples, 256, len(bio_labels)) — TODO confirm.
        thresholds: BIO tag id -> threshold, applied after a sigmoid.

    Returns:
        (rows, (micro_P, micro_R, micro_F1)). ``rows`` holds
        [technique, P, R, F1] for each technique in ``all_techniques``.

    Per technique, compute_PRF's per-example span averages are re-weighted by
    span counts, so P and R are averages over ALL predicted / gold spans in
    the dataset. Examples where the technique is absent on both sides no
    longer count as errors. The micro figures pool spans across every
    technique.
    """
    probs = torch.sigmoid(torch.tensor(predictions)).numpy()
    all_gold = []
    all_pred = []
    for ex, token_probs in zip(examples, probs):
        enc = tokenizer(ex["text"], truncation=True, padding="max_length",
                        max_length=256, return_offsets_mapping=True)
        # Strip names the same way the label map does so they match predictions.
        all_gold.append([
            {"technique": g["technique"].strip(), "start": g["start"], "end": g["end"]}
            for g in (ex.get("labels") or [])
        ])
        all_pred.append(_threshold_spans(enc["offset_mapping"], token_probs, thresholds))

    rows = []
    micro_prec_sum = micro_rec_sum = 0.0
    micro_n_pred = micro_n_gold = 0
    for tech in all_techniques:
        prec_sum = rec_sum = 0.0
        n_pred = n_gold = 0
        for gold, pred in zip(all_gold, all_pred):
            k_pred = sum(1 for s in pred if s["technique"] == tech)
            k_gold = sum(1 for t in gold if t["technique"] == tech)
            p, r, _ = compute_PRF(gold, pred, tech)
            # compute_PRF averages over this example's spans; undo that to pool.
            prec_sum += p * k_pred
            rec_sum += r * k_gold
            n_pred += k_pred
            n_gold += k_gold
        P = prec_sum / n_pred if n_pred else 0.0
        R = rec_sum / n_gold if n_gold else 0.0
        rows.append([tech, P, R, _f1(P, R)])
        micro_prec_sum += prec_sum
        micro_rec_sum += rec_sum
        micro_n_pred += n_pred
        micro_n_gold += n_gold

    microP = micro_prec_sum / micro_n_pred if micro_n_pred else 0.0
    microR = micro_rec_sum / micro_n_gold if micro_n_gold else 0.0
    return rows, (microP, microR, _f1(microP, microR))
# ============ 8. Trainer ============

def compute_metrics(eval_pred):
    """Trainer metric hook: tune per-tag thresholds on the eval set, then score spans.

    For every BIO tag id, chooses the sigmoid threshold from
    ``np.arange(0.1, 0.9, 0.05)`` that maximizes token-level binary F1
    against this same eval set's gold tags (default 0.5 if none beats F1=0).
    The reported score is therefore tuned on the data it is measured on.
    Spans are then scored with ``evaluate_spans``; per-label and micro results
    are printed.

    Returns:
        ``{"micro_f1": <micro F1 from evaluate_spans>}``.
    """
    raw_logits, gold = eval_pred

    scores = torch.sigmoid(torch.tensor(raw_logits)).numpy()
    gold = np.array(gold)
    n_tags = scores.shape[2]

    if gold.ndim == 2:
        # Tag ids (batch, seq) -> 0/1 indicators with the same shape/dtype as scores.
        targets = np.zeros_like(scores)
        targets[...] = gold[..., None] == np.arange(n_tags)
    else:
        targets = gold

    grid = np.arange(0.1, 0.9, 0.05)

    def best_cut_for(tag):
        truth = targets[:, :, tag].flatten()
        best_score, best_cut = 0, 0.5
        for cut in grid:
            guess = (scores[:, :, tag] > cut).astype(int).flatten()
            score = f1_score(truth, guess, zero_division=0)
            if score > best_score:
                best_score, best_cut = score, cut
        return best_cut

    # Threshold optimization on dev batch
    thresholds = {tag: best_cut_for(tag) for tag in range(n_tags)}

    dev_subset = raw_datasets["dev"].select(range(len(scores)))
    rows, micro = evaluate_spans(dev_subset, raw_logits, thresholds)

    print("\n=== Per-label results ===")
    print(tabulate(rows, headers=["Technique", "P", "R", "F1"], floatfmt=".3f"))
    print(f"\n=== Micro Overall ===  P={micro[0]:.3f}  R={micro[1]:.3f}  F1={micro[2]:.3f}\n")

    return {"micro_f1": micro[2]}


class BIOCRFTrainer(Trainer):
    """Trainer whose loss is the "loss" entry returned by the model itself.

    The model must accept a ``labels=`` keyword and return a mapping with a
    ``"loss"`` key, as SpanDetectorModel.forward does when labels are given.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Pop rather than read: leaving "labels" in `inputs` would pass it twice.
        gold_tags = inputs.pop("labels")
        result = model(**inputs, labels=gold_tags)
        loss = result["loss"]
        if not return_outputs:
            return loss
        return loss, result

bio_training_args = TrainingArguments(
    output_dir="./bio_crf_results",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    eval_strategy="epoch",
    num_train_epochs=10,
    logging_steps=50,
    save_strategy="no",
    report_to="none"
)

# Newer Transformers releases renamed Trainer's `tokenizer` argument to
# `processing_class`. The old keyword is deprecated there and removed in v5.
# Pass the tokenizer under whichever name the installed version accepts.
_trainer_kwargs = dict(
    model=span_detector,
    args=bio_training_args,
    train_dataset=bio_tokenized["train"],
    eval_dataset=bio_tokenized["dev"],
    compute_metrics=compute_metrics,
)
if "processing_class" in signature(BIOCRFTrainer.__init__).parameters:
    _trainer_kwargs["processing_class"] = tokenizer
else:
    _trainer_kwargs["tokenizer"] = tokenizer
bio_trainer = BIOCRFTrainer(**_trainer_kwargs)



# ============ 9. Train BIO-CRF ============
bio_trainer.train()