In [None]:
# Authenticate this notebook session with the Hugging Face Hub.
# NOTE(review): presumably required because the Mistral checkpoint loaded
# further down is access-gated -- confirm against the model card.
from huggingface_hub import notebook_login
notebook_login()  # This will prompt you to enter the token

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import json
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments
)
from peft import LoraConfig, get_peft_model
from sklearn.metrics import precision_score, recall_score, f1_score
from collections import defaultdict

# Per-model hyperparameters, keyed by a short alias. The quantization and
# LoRA objects built from these values are defined separately.
MODEL_CONFIGS = {
    "mistral-7b": dict(
        name="mistralai/Mistral-7B-Instruct-v0.3",
        max_length=512,   # tokens per example (longer context)
        batch_size=2,     # kept small to fit in VRAM
        lora_r=16,        # LoRA rank
        lora_alpha=32,    # LoRA scaling factor
    ),
}

# Quantization settings handed to `from_pretrained`: 4-bit NF4 weights with
# double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# LoRA adapter settings for token classification; rank and alpha come from
# the "mistral-7b" entry of MODEL_CONFIGS, and only the query/value
# projections are adapted.
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    inference_mode=False,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=MODEL_CONFIGS["mistral-7b"]["lora_alpha"],
    r=MODEL_CONFIGS["mistral-7b"]["lora_r"],
)

class MistralSpanDataset(Dataset):
    """Index-addressable view over pre-tokenized examples and their labels.

    Args:
        encodings: Mapping providing "input_ids" and "attention_mask"
            sequences, indexable by example.
        labels: Sequence of per-example label objects, indexable the same way.
    """

    _FIELDS = ("input_ids", "attention_mask", "labels")

    def __init__(self, encodings, labels):
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the number of input_ids entries.
        return len(self.input_ids)

    def __getitem__(self, idx):
        # One dict per example, with the keys the collate function expects.
        return {field: getattr(self, field)[idx] for field in self._FIELDS}

def load_model_and_tokenizer(config):
    """Load the tokenizer and the quantized, LoRA-wrapped token classifier.

    Args:
        config: Model settings dict; only ``config["name"]`` (the checkpoint
            identifier) is read here.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model.

    Relies on the module-level ``label_map``, ``bnb_config`` and
    ``lora_config`` being defined before this is called.
    """
    checkpoint = config["name"]

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        padding_side="right",
        add_prefix_space=True,
    )
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForTokenClassification.from_pretrained(
        checkpoint,
        num_labels=len(label_map),
        quantization_config=bnb_config,
        device_map="auto",
        problem_type="multi_label_classification",
    )

    peft_model = get_peft_model(base_model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model, tokenizer

def mistral_collate_fn(batch):
    """Stack per-example tensors into batch tensors.

    Args:
        batch: List of dicts, each holding 'input_ids', 'attention_mask' and
            'labels' tensors of matching shapes across examples.

    Returns:
        Dict with the same three keys, each stacked along a new first axis.
    """
    stacked = {}
    for key in ('input_ids', 'attention_mask', 'labels'):
        stacked[key] = torch.stack([example[key] for example in batch])
    return stacked

def convert_spans_to_token_labels(text, spans, offset_mapping, max_length, label_map):
    """
    Convert character-level spans to token-level multi-label targets.

    Args:
        text: Source text the spans refer to (not read; kept for interface
            compatibility).
        spans: Iterable of dicts with "technique", "start" and "end" keys.
            Start/end are character offsets compared as a half-open interval.
            May be empty or None.
        offset_mapping: Per-token ``(char_start, char_end)`` pairs.
        max_length: Number of token rows in the output; tokens beyond it are
            ignored.
        label_map: Mapping of technique name -> column index.

    Returns:
        ``np.float32`` array of shape ``(max_length, len(label_map))`` with
        1.0 wherever a token overlaps a span of that technique.

    Raises:
        KeyError: If a span's technique is in ``label_map`` neither verbatim
            nor after stripping whitespace and lowercasing.
    """
    num_labels = len(label_map)
    labels = np.zeros((max_length, num_labels), dtype=np.float32)
    if not spans:
        return labels

    # Only the first max_length tokens have a row in the output.
    token_offsets = list(offset_mapping)[:max_length]

    for span in spans:
        technique = span["technique"]
        if technique in label_map:
            label_idx = label_map[technique]
        else:
            # The label map in this file is built from stripped, lowercased
            # technique names, so retry the lookup with the same normalization.
            normalized = technique.strip().lower()
            if normalized not in label_map:
                raise KeyError(
                    f"Unknown technique {technique!r}: not present in label_map"
                )
            label_idx = label_map[normalized]

        start_char = span["start"]
        end_char = span["end"]

        # Zero-width offsets such as (0, 0) can never satisfy this overlap
        # test for non-negative span starts, so they stay unlabeled.
        for token_idx, (token_start, token_end) in enumerate(token_offsets):
            if token_start < end_char and token_end > start_char:
                labels[token_idx, label_idx] = 1.0

    return labels

def optimize_thresholds(probs, labels, label_map):
    """
    Pick, per class, the decision threshold with the highest binary F1.

    Candidates are ``np.arange(0.1, 0.9, 0.05)``. A class keeps the default
    0.5 when no candidate scores above 0.0; ties keep the earliest candidate.

    Args:
        probs: Array indexed as ``[example, token, class]`` of probabilities.
        labels: Array with the same indexing holding the binary targets.
        label_map: Mapping whose length gives the number of classes.

    Returns:
        Dict mapping class index -> chosen threshold.
    """
    candidates = np.arange(0.1, 0.9, 0.05)
    chosen = {}

    for class_idx in range(len(label_map)):
        # Target vector is the same for every candidate; compute it once.
        truth = labels[:, :, class_idx].flatten()
        class_probs = probs[:, :, class_idx]

        top_f1, top_thresh = 0.0, 0.5
        for candidate in candidates:
            guesses = (class_probs > candidate).astype(int).flatten()
            score = f1_score(truth, guesses, zero_division=0)
            if score > top_f1:
                top_f1, top_thresh = score, candidate

        chosen[class_idx] = top_thresh

    return chosen


def _tokens_are_contiguous(prev_end, next_start, text):
    """Return True if nothing but whitespace separates two token offsets."""
    if next_start <= prev_end:
        return True
    gap = text[prev_end:next_start] if text else ""
    # Without the text we cannot inspect the gap, so only touching or
    # overlapping tokens count as contiguous.
    return bool(gap) and gap.isspace()


def token_preds_to_spans(preds, offset_mapping, label_map, text):
    """
    Convert token-level multi-label predictions to character spans.

    A separate active span is tracked per technique. Consecutive tokens
    predicted with the same technique are joined into one span when they
    touch, overlap, or are separated only by whitespace in ``text``.

    Args:
        preds: Binary array of shape ``(num_tokens, num_labels)``; a leading
            batch axis of size 1 is also accepted and dropped.
        offset_mapping: Per-token ``(char_start, char_end)`` pairs. Tokens
            whose offsets are ``(0, 0)`` are skipped.
        label_map: Mapping of technique name -> class index.
        text: Text the offsets refer to, used to inspect gaps between tokens.
            May be empty or None, in which case only touching tokens merge.

    Returns:
        List of ``{"technique": str, "start": int, "end": int}`` dicts,
        sorted by start offset.

    Raises:
        ValueError: If ``preds`` has a batch axis larger than 1.
    """
    preds = np.asarray(preds)
    if preds.ndim == 3:
        if preds.shape[0] != 1:
            raise ValueError(
                "token_preds_to_spans expects predictions for a single example"
            )
        # Without this, indexing by token would walk the batch axis instead.
        preds = preds[0]

    reversed_label_map = {v: k for k, v in label_map.items()}
    spans = []
    current_spans = {}  # technique -> {"start": int, "end": int}

    # zip stops at the shorter input, so surplus offsets cannot index past preds.
    for (token_start, token_end), token_pred in zip(offset_mapping, preds):
        if token_start == 0 and token_end == 0:
            continue

        predicted_techniques = [
            reversed_label_map[int(c)] for c in np.flatnonzero(token_pred == 1)
        ]

        # Close spans whose technique is not predicted for this token.
        for technique in [t for t in current_spans if t not in predicted_techniques]:
            finished = current_spans.pop(technique)
            spans.append({
                "technique": technique,
                "start": finished["start"],
                "end": finished["end"],
            })

        for technique in predicted_techniques:
            active = current_spans.get(technique)
            if active is None:
                current_spans[technique] = {"start": token_start, "end": token_end}
            elif _tokens_are_contiguous(active["end"], token_start, text):
                active["end"] = max(active["end"], token_end)
            else:
                # Real text lies between the tokens: emit and start afresh.
                spans.append({
                    "technique": technique,
                    "start": active["start"],
                    "end": active["end"],
                })
                current_spans[technique] = {"start": token_start, "end": token_end}

    # Flush whatever is still open after the last token.
    for technique, span in current_spans.items():
        spans.append({
            "technique": technique,
            "start": span["start"],
            "end": span["end"],
        })

    return merge_overlapping_spans(spans)


def merge_overlapping_spans(spans):
    """Merge overlapping or touching spans that share a technique.

    Spans are compared against the most recent merged span of the *same*
    technique, so an intervening span of another technique does not block a
    merge. The input list and its dicts are left unmodified.

    Args:
        spans: List of dicts with "technique", "start" and "end" keys.

    Returns:
        New list of span dicts sorted by start offset.
    """
    if not spans:
        return []

    merged = []
    last_by_technique = {}

    for span in sorted(spans, key=lambda s: s["start"]):
        last = last_by_technique.get(span["technique"])
        if last is not None and span["start"] <= last["end"]:
            last["end"] = max(last["end"], span["end"])
        else:
            fresh = dict(span)  # copy so the caller's dict is never mutated
            last_by_technique[span["technique"]] = fresh
            merged.append(fresh)

    return merged

def predict_spans(model, tokenizer, text, label_map, thresholds, max_length=512):
    """Predict technique spans for a single text.

    Args:
        model: Token-classification model exposing ``.device`` and returning
            an object with ``.logits`` indexed ``[batch, token, class]``.
        tokenizer: Tokenizer that supports ``return_offsets_mapping=True``.
        text: Input text.
        label_map: Mapping of technique name -> class index.
        thresholds: Mapping of class index -> probability threshold; classes
            absent from it are never predicted.
        max_length: Padding/truncation length in tokens.

    Returns:
        List of ``{"technique", "start", "end"}`` dicts as produced by
        ``token_preds_to_spans``.
    """
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    # The offsets are tokenizer metadata, not a model input: remove them so
    # they are not forwarded to model(**encoding).
    offset_mapping = encoding.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        outputs = model(**encoding.to(model.device))
        # Upcast first: NumPy cannot represent bfloat16 tensors.
        probs = torch.sigmoid(outputs.logits.float()).cpu().numpy()

    # Work on the single example in the batch so the span conversion receives
    # a (num_tokens, num_labels) matrix.
    token_probs = probs[0]
    preds = np.zeros_like(token_probs)
    for class_idx, thresh in thresholds.items():
        class_idx = int(class_idx)
        preds[:, class_idx] = (token_probs[:, class_idx] > thresh).astype(int)

    return token_preds_to_spans(preds, offset_mapping, label_map, text)

def span_f1(true_spans, pred_spans):
    """Exact-match span precision, recall and F1.

    A predicted span counts as correct only if its technique, start and end
    all equal those of a gold span. Duplicate spans are counted once.

    Args:
        true_spans: Gold span dicts with "technique", "start", "end".
        pred_spans: Predicted span dicts with the same keys.

    Returns:
        Dict with "span_precision", "span_recall" and "span_f1"; each is 0
        when its denominator is zero.
    """
    def as_keys(span_list):
        return {(item["technique"], item["start"], item["end"]) for item in span_list}

    gold = as_keys(true_spans)
    predicted = as_keys(pred_spans)

    hits = len(gold & predicted)
    false_pos = len(predicted - gold)
    false_neg = len(gold - predicted)

    if hits + false_pos > 0:
        precision = hits / (hits + false_pos)
    else:
        precision = 0

    if hits + false_neg > 0:
        recall = hits / (hits + false_neg)
    else:
        recall = 0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0

    return {"span_precision": precision, "span_recall": recall, "span_f1": f1}

def train_model(model_config, train_data, dev_data, label_map):
    """Fine-tune the quantized LoRA model and return it with tuned thresholds.

    Args:
        model_config: Entry of ``MODEL_CONFIGS``; ``"batch_size"`` is read
            here and the dict is passed on to the loader and dataset helpers.
        train_data: Training examples, passed to ``prepare_datasets``.
        dev_data: Development examples, used for per-epoch evaluation.
        label_map: Mapping of technique name -> class index; its length is
            the number of output labels.

    Returns:
        ``(model, thresholds)`` where ``thresholds`` maps class index ->
        threshold as computed by ``compute_metrics`` on the dev set.

    Raises:
        KeyError: If the evaluation results contain no thresholds entry.

    Relies on the module-level ``NUM_EPOCHS`` being defined at call time.
    """
    import inspect  # local import: only needed for the keyword check below

    # Load model with 4-bit quantization
    model, tokenizer = load_model_and_tokenizer(model_config)

    # Prepare datasets
    train_processed = prepare_datasets(model_config, tokenizer, train_data)
    dev_processed = prepare_datasets(model_config, tokenizer, dev_data)

    train_dataset = MistralSpanDataset(
        {
            "input_ids": train_processed["encodings"]["input_ids"],
            "attention_mask": train_processed["encodings"]["attention_mask"],
        },
        labels=train_processed["labels"],
    )
    dev_dataset = MistralSpanDataset(
        {
            "input_ids": dev_processed["encodings"]["input_ids"],
            "attention_mask": dev_processed["encodings"]["attention_mask"],
        },
        labels=dev_processed["labels"],
    )

    num_labels = len(label_map)
    # Single source of truth for the accumulation factor, used both for the
    # step budget and for TrainingArguments.
    gradient_accumulation = 4
    max_steps = (
        len(train_dataset) // (model_config["batch_size"] * gradient_accumulation)
    ) * NUM_EPOCHS

    # The evaluation-strategy keyword was renamed across transformers
    # releases; use whichever name the installed version accepts.
    accepted_args = inspect.signature(TrainingArguments.__init__).parameters
    strategy_arg = (
        "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"
    )

    training_args = TrainingArguments(
        output_dir="./results",
        report_to="none",
        num_train_epochs=NUM_EPOCHS,
        # NOTE(review): a positive max_steps takes precedence over
        # num_train_epochs per the TrainingArguments docs -- confirm that the
        # floor-divided step budget is the intended training length.
        max_steps=max_steps,
        per_device_train_batch_size=model_config["batch_size"],
        gradient_accumulation_steps=gradient_accumulation,
        save_strategy="epoch",
        gradient_checkpointing=True,
        logging_steps=10,
        learning_rate=1e-4,
        weight_decay=0.01,
        optim="paged_adamw_8bit",
        bf16=True,
        fp16=False,
        **{strategy_arg: "epoch"},
    )

    class CustomTrainer(Trainer):
        """Trainer computing a masked multi-label BCE loss over real tokens."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits

            # Only positions with attention_mask == 1 contribute to the loss.
            active_loss = (inputs["attention_mask"].reshape(-1) == 1).to(logits.device)
            # Compute the loss in float32 with both operands on one device.
            active_logits = logits.reshape(-1, num_labels)[active_loss].float()
            active_labels = (
                labels.reshape(-1, num_labels).to(logits.device)[active_loss].float()
            )

            loss = nn.BCEWithLogitsLoss()(active_logits, active_labels)
            return (loss, outputs) if return_outputs else loss

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=mistral_collate_fn,
        compute_metrics=lambda eval_pred: compute_metrics(eval_pred, label_map)
    )

    trainer.train()
    eval_results = trainer.evaluate()

    # Trainer.evaluate() prefixes metric names with "eval_", so the value
    # returned by compute_metrics under "thresholds" arrives as
    # "eval_thresholds". The bare key is kept as a fallback.
    thresholds = eval_results.get("eval_thresholds", eval_results.get("thresholds"))
    if thresholds is None:
        raise KeyError("thresholds not found in evaluation results")

    return model, thresholds

def compute_metrics(eval_pred, label_map):
    """Token-level micro precision/recall/F1 with per-class tuned thresholds.

    Thresholds are tuned by ``optimize_thresholds`` on the same predictions
    they are then scored on. Every token position in the arrays is counted;
    no attention mask is applied here.

    Args:
        eval_pred: Pair ``(logits, labels)`` of arrays indexed
            ``[example, token, class]``.
        label_map: Mapping whose length gives the number of classes.

    Returns:
        Dict with "micro_precision", "micro_recall", "micro_f1" and
        "thresholds" (class index -> threshold).
    """
    raw_scores, gold = eval_pred
    probs = torch.sigmoid(torch.tensor(raw_scores)).numpy()
    gold = gold.astype(np.float32)

    best_thresholds = optimize_thresholds(probs, gold, label_map)

    # Binarize each class with its own threshold.
    preds = np.zeros_like(probs)
    for class_idx, thresh in best_thresholds.items():
        preds[:, :, class_idx] = (probs[:, :, class_idx] > thresh).astype(int)

    # Collapse the example and token axes: one row per token position.
    num_labels = len(label_map)
    flat_labels = gold.reshape(-1, num_labels)
    flat_preds = preds.reshape(-1, num_labels)

    metric_options = {"average": "micro", "zero_division": 0}
    return {
        "micro_precision": precision_score(flat_labels, flat_preds, **metric_options),
        "micro_recall": recall_score(flat_labels, flat_preds, **metric_options),
        "micro_f1": f1_score(flat_labels, flat_preds, **metric_options),
        "thresholds": best_thresholds,
    }

def load_data(file_path: str) -> list:
    """Read a UTF-8 JSON file located under the module-level ``data_dir``.

    ``file_path`` is appended to ``data_dir`` by plain string concatenation,
    so ``data_dir`` must already end with a path separator.
    """
    full_path = data_dir + file_path
    with open(full_path, "r", encoding="utf-8") as handle:
        return json.loads(handle.read())

def prepare_datasets(config: dict, tokenizer, data):
    """Tokenize every example and build its token-level label matrix.

    Args:
        config: Model settings; only ``config["max_length"]`` is read.
        tokenizer: Tokenizer called with padding/truncation to ``max_length``
            and ``return_offsets_mapping=True``.
        data: Iterable of dicts with a "text" key and an optional "labels"
            list of span dicts.

    Returns:
        Dict with:
          - "encodings": {"input_ids", "attention_mask"} stacked tensors,
          - "labels": stacked float label tensor,
          - "offset_mappings": list of per-example offset tensors.

    Uses the module-level ``label_map`` when converting spans to labels.
    """
    max_length = config["max_length"]

    input_id_rows = []
    attention_rows = []
    label_rows = []
    offset_rows = []

    for example in data:
        encoded = tokenizer(
            example["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a batch axis; index 0 is the single example.
        offsets = encoded["offset_mapping"][0]

        token_labels = convert_spans_to_token_labels(
            example["text"],
            example.get("labels", []),
            offsets.tolist(),
            max_length,
            label_map,
        )

        input_id_rows.append(encoded["input_ids"][0])
        attention_rows.append(encoded["attention_mask"][0])
        label_rows.append(torch.FloatTensor(token_labels))
        offset_rows.append(offsets)

    stacked_encodings = {
        "input_ids": torch.stack(input_id_rows),
        "attention_mask": torch.stack(attention_rows),
    }

    return {
        "encodings": stacked_encodings,
        "labels": torch.stack(label_rows),
        "offset_mappings": offset_rows,
    }

# Driver section: mount storage, load the data, build the label map, train.
# Everything below runs at import/cell-execution time and must stay in this
# order: `data_dir` is read by load_data(), and `label_map` / NUM_EPOCHS are
# read as globals by the functions defined above.
from google.colab import drive
drive.mount('/content/drive')
# load_data() concatenates this with the file name, so keep the trailing slash.
data_dir = "/content/drive/My Drive/SEMEVAL/data/"

# Load datasets (JSON content despite the .txt extension; parsed by json.load)
train_data = load_data("training_set_task2.txt")
dev_data = load_data("dev_set_task2.txt")

# Create label map: sorted unique technique names -> consecutive indices.
# NOTE(review): names are stripped and lowercased here, so any lookup into
# label_map with a raw technique string must apply the same normalization.
all_techniques = sorted({
    label["technique"].strip().lower() for item in train_data + dev_data
    for label in item.get("labels", [])
})
label_map = {tech: idx for idx, tech in enumerate(all_techniques)}





# Now run the training
NUM_EPOCHS = 10  # read as a global inside train_model()
MODEL_TO_TRAIN = "mistral-7b"  # key into MODEL_CONFIGS

print(f"Training {MODEL_TO_TRAIN}...")
# Returns the fine-tuned model and the per-class thresholds from evaluation.
best_model, thresholds = train_model(
    MODEL_CONFIGS[MODEL_TO_TRAIN],
    train_data,
    dev_data,
    label_map
)