# In [None]:  (Jupyter/Colab cell marker left over from the notebook export)
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from pathlib import Path
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer
from sklearn.metrics import f1_score
from collections import defaultdict
from google.colab import drive
from textattack.augmentation import WordNetAugmenter
from typing import List, Dict

# Configuration
# Maps a short model alias to the settings used for that model: the checkpoint
# name, the tokenizer max_length, and the DataLoader batch_size.
MODEL_CONFIGS = {
    "deberta-v3": dict(
        name="microsoft/deberta-v3-base",
        max_length=256,
        batch_size=4,
    ),
}

class SpanDataset(Dataset):
    """Index-aligned view over encodings, labels and offset mappings.

    ``encodings`` must support ``encodings['input_ids'][idx]`` and
    ``encodings['attention_mask'][idx]``; ``labels`` and ``offset_mappings``
    are indexed with the same ``idx``. The dataset length is ``len(labels)``.
    """

    def __init__(self, encodings, labels, offset_mappings):
        self.encodings, self.labels, self.offset_mappings = (
            encodings,
            labels,
            offset_mappings,
        )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Pull the two model inputs first, then attach the targets and offsets.
        sample = {
            key: self.encodings[key][idx]
            for key in ('input_ids', 'attention_mask')
        }
        sample['labels'] = self.labels[idx]
        sample['offset_mapping'] = self.offset_mappings[idx]
        return sample

class TokenFocalLoss(nn.Module):
    """Focal loss over independent per-token, per-class binary logits.

    Args:
        alpha: Constant scale applied to every element of the loss.
        gamma: Focusing exponent; larger values down-weight elements the model
            already predicts confidently.
        class_weights: Optional tensor broadcastable against the loss tensor
            (typically one weight per class on the last dimension).
    """

    def __init__(self, alpha=0.25, gamma=2, class_weights=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights

    def forward(self, inputs, targets):
        """Return the mean focal loss for ``inputs`` (logits) against ``targets``.

        ``targets`` must have the same shape as ``inputs`` and hold 0/1 floats.
        """
        # The BCE must stay unweighted here: pt = exp(-bce) is only the
        # probability assigned to the true label when no weight is folded in.
        bce_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        pt = torch.exp(-bce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * bce_loss

        # Class weights rescale the finished focal term instead of distorting
        # the modulating factor (1 - pt) ** gamma.
        if self.class_weights is not None:
            focal = focal * self.class_weights

        return focal.mean()

# Initialize text augmenter
# Created once at module level and reused by load_data() when augment=True.
# NOTE(review): going by the keyword names, this presumably swaps about 10% of
# the words and produces two variants per input text — confirm against the
# TextAttack WordNetAugmenter documentation.
augmenter = WordNetAugmenter(pct_words_to_swap=0.1, transformations_per_example=2)

# Modified data loading with augmentation
def load_data(file_path: str, augment: bool = False) -> list:
    """Load a JSON dataset from ``data_dir`` and optionally add augmented copies.

    Args:
        file_path: File name, resolved relative to the module-level ``data_dir``.
        augment: When true, every item is followed by copies whose ``"text"``
            is replaced by each output of the module-level ``augmenter``.

    Returns:
        The parsed JSON content unchanged when ``augment`` is false; otherwise
        a new list containing each original item followed by its variants.
    """
    with open(Path(data_dir) / file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not augment:
        return data

    augmented_data = []
    failures = 0
    for item in data:
        # Original item
        augmented_data.append(item)

        # Augmentation is best-effort: a failure keeps the original item only.
        # Catch Exception (not a bare except) so Ctrl-C still interrupts.
        try:
            variants = augmenter.augment(item["text"])
        except Exception:
            failures += 1
            continue

        # NOTE(review): every other field, including the character-offset
        # "labels", is copied unchanged. If the augmenter changes word lengths
        # the copied offsets may no longer line up with the new text — verify.
        augmented_data.extend({**item, "text": text} for text in variants)

    if failures:
        print(f"Warning: augmentation failed for {failures} of {len(data)} items")

    return augmented_data

# Mount Google Drive
# Requires the Google Colab runtime (uses google.colab.drive).
drive.mount('/content/drive')
# Directory holding the task data files; load_data() resolves the file names it
# is given relative to this location.
data_dir = "/content/drive/My Drive/SEMEVAL/data/"


def convert_spans_to_token_labels(text: str, spans: List[dict], offset_mapping, max_length: int,
                                  label_map: dict, context_window: int = 2):
    """Turn character-level spans into a (max_length, num_classes) 0/1 label matrix.

    A token is labelled for a class when its character range overlaps a span of
    that class, or when it lies within ``context_window`` tokens of such a token.

    Args:
        text: The source text. Unused; kept for interface compatibility.
        spans: Dicts with ``"technique"``, ``"start"`` and ``"end"`` keys.
            Spans whose technique is not in ``label_map`` are skipped.
        offset_mapping: Sequence of ``(start, end)`` character offsets, one per token.
        max_length: Number of rows in the output; tokens beyond it are ignored.
        label_map: Maps technique name to column index.
        context_window: Number of neighbouring tokens on each side that are
            also labelled. Defaults to 2.

    Returns:
        An integer ``numpy`` array of shape ``(max_length, len(label_map))``.

    Raises:
        ValueError: If ``context_window`` is negative.
    """
    if context_window < 0:
        raise ValueError("context_window must be non-negative")

    labels = np.zeros((max_length, len(label_map)), dtype=int)
    num_tokens = min(len(offset_mapping), max_length)

    for span in spans:
        technique = span["technique"]
        if technique not in label_map:
            continue

        class_idx = label_map[technique]
        span_start = span["start"]
        span_end = span["end"]

        for token_idx in range(num_tokens):
            start, end = offset_mapping[token_idx]
            # Strict overlap test: zero-width (0, 0) special/padding tokens never match.
            if max(start, span_start) < min(end, span_end):
                first = max(0, token_idx - context_window)
                # Exclusive upper bound clamped to the token count, so the
                # overlapping token itself is always labelled, even when it is
                # the last token in the sequence.
                stop = min(num_tokens, token_idx + context_window + 1)
                labels[first:stop, class_idx] = 1

    return labels

def optimize_thresholds(probs, labels, label_map):
    """Pick, per class, the probability cut-off with the highest binary F1.

    ``probs`` and ``labels`` are indexed as ``[:, :, class_idx]`` and flattened,
    so both must be 3-D arrays whose last axis is the class axis. Twenty evenly
    spaced candidates in [0.1, 0.9] are tried; the first candidate reaching the
    best F1 wins, and 0.4 is kept when no candidate scores above zero.

    Returns:
        Dict mapping class index to the chosen threshold.
    """
    candidates = np.linspace(0.1, 0.9, 20)
    chosen = {}

    for class_idx in range(len(label_map)):
        truth = labels[:, :, class_idx].flatten()
        scores = probs[:, :, class_idx].flatten()

        top_f1, top_thresh = 0, 0.4
        for candidate in candidates:
            hard = (scores > candidate).astype(int)
            score = f1_score(truth, hard, zero_division=0)
            # Strict comparison keeps the earliest (lowest) best threshold.
            if score > top_f1:
                top_f1, top_thresh = score, candidate

        chosen[class_idx] = top_thresh

    return chosen

def prepare_datasets(config: dict, tokenizer, data, label_map=None):
    """Tokenize every item and build aligned input, label and offset tensors.

    Args:
        config: Model settings; only ``config["max_length"]`` is read.
        tokenizer: Callable tokenizer that supports ``return_offsets_mapping``
            (presumably a Hugging Face fast tokenizer — confirm at the call site).
        data: Sequence of dicts with a ``"text"`` key and an optional
            ``"labels"`` list of span dicts.
        label_map: Maps technique name to class index. When omitted, the
            module-level ``label_map`` is used, which is what this function
            previously always read.

    Returns:
        Dict with stacked ``input_ids``, ``attention_mask`` and ``labels``
        tensors, plus ``offset_mappings`` as a list of per-item tensors.

    Raises:
        NameError: If ``label_map`` is not passed and no module-level
            ``label_map`` exists.
        ValueError: If ``data`` is empty.
    """
    if label_map is None:
        # Backward compatibility with callers that rely on the global mapping.
        try:
            label_map = globals()["label_map"]
        except KeyError:
            raise NameError("name 'label_map' is not defined") from None

    if not data:
        # torch.stack() on an empty list fails with a far less helpful message.
        raise ValueError("prepare_datasets() received no data items")

    processed = {'input_ids': [], 'attention_mask': [], 'labels': [], 'offset_mappings': []}
    max_length = config["max_length"]

    for item in data:
        # Pad/truncate to max_length so every item yields same-sized tensors.
        encoding = tokenizer(
            item["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_offsets_mapping=True,
            return_tensors="pt"
        )
        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        offsets = encoding["offset_mapping"][0]

        # Labels are derived from the offsets of this exact tokenization.
        labels = convert_spans_to_token_labels(
            item["text"],
            item.get("labels", []),
            offsets.tolist(),
            max_length,
            label_map
        )

        processed['input_ids'].append(encoding["input_ids"][0])
        processed['attention_mask'].append(encoding["attention_mask"][0])
        processed['labels'].append(torch.FloatTensor(labels))
        processed['offset_mappings'].append(offsets)

    return {
        'input_ids': torch.stack(processed['input_ids']),
        'attention_mask': torch.stack(processed['attention_mask']),
        'labels': torch.stack(processed['labels']),
        'offset_mappings': processed['offset_mappings']
    }


def _stack_batches(values):
    """Return ``values`` as one 3-D array with samples on the first axis.

    Accepts an already-concatenated array or a sequence of per-batch arrays.
    """
    if isinstance(values, np.ndarray):
        stacked = values
    else:
        stacked = np.concatenate([np.asarray(batch) for batch in values], axis=0)
    if stacked.ndim != 3:
        raise ValueError(
            f"expected a 3-D array or a sequence of 3-D batches, got shape {stacked.shape}"
        )
    return stacked


def compute_metrics(preds, labels, offset_mappings, label_map, thresholds=None):
    """Compute token-level micro F1 and span-level F1 from class probabilities.

    Args:
        preds: Class probabilities, either one ``(samples, tokens, classes)``
            array or a sequence of per-batch arrays of that layout.
        labels: Gold 0/1 labels in the same layout as ``preds``.
        offset_mappings: Character offsets, ``(samples, tokens, 2)`` or a
            sequence of per-batch arrays of that layout.
        label_map: Maps technique name to class index.
        thresholds: Dict of class index to probability cut-off. When ``None``
            the cut-offs are chosen with ``optimize_thresholds``.

    Returns:
        Dict with ``"token_micro_f1"``, ``"span_f1"`` and ``"thresholds"``.

    Raises:
        ValueError: If an input cannot be arranged as a 3-D array.
    """
    # Normalise both accepted layouts so every step below sees the same shape.
    probs = _stack_batches(preds)
    gold = _stack_batches(labels)
    offsets = _stack_batches(offset_mappings)

    if thresholds is None:
        thresholds = optimize_thresholds(probs, gold, label_map)

    # Apply thresholds to get binary predictions
    binarized = np.zeros_like(probs)
    for class_idx, thresh in thresholds.items():
        binarized[:, :, class_idx] = probs[:, :, class_idx] > thresh

    # Keep tokens with a non-zero end offset. Special and padding tokens carry
    # (0, 0); testing the start offset instead would also drop the first real
    # token, whose start is 0.
    token_mask = offsets[:, :, 1] > 0

    # Token-level micro F1
    if token_mask.any():
        micro_f1 = f1_score(
            gold[token_mask].astype(int),
            binarized[token_mask].astype(int),
            average='micro',
            zero_division=0,
        )
    else:
        micro_f1 = 0.0

    # calculate_span_f1 iterates batches, then samples, so the stacked arrays
    # are handed over as single-element batch lists.
    span_f1 = calculate_span_f1([binarized], [gold], [offsets], label_map)

    return {
        "token_micro_f1": micro_f1,
        "span_f1": span_f1,
        "thresholds": thresholds
    }


def _train_one_epoch(model, loader, optimizer, criterion, device, num_labels):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0

    for batch in loader:
        optimizer.zero_grad()

        attention_mask = batch['attention_mask'].to(device)
        # Only the input tensors are passed to the model; the loss is computed here.
        outputs = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=attention_mask,
        )
        labels = batch['labels'].to(device)

        # Flatten to (batch * seq_len, classes) and keep attended positions only.
        active = attention_mask.reshape(-1).bool()
        active_logits = outputs.logits.reshape(-1, num_labels)[active]
        active_labels = labels.reshape(-1, num_labels)[active]

        loss = criterion(active_logits, active_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / max(len(loader), 1)


def _predict(model, loader, device):
    """Return per-batch lists of sigmoid probabilities, labels and offsets."""
    model.eval()
    all_probs, all_labels, all_offsets = [], [], []
    with torch.no_grad():
        for batch in loader:
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            all_probs.append(torch.sigmoid(outputs.logits).cpu().numpy())
            all_labels.append(batch['labels'].numpy())
            all_offsets.append(batch['offset_mapping'].numpy())
    return all_probs, all_labels, all_offsets


def train_model(model_config: dict, train_data, dev_data, label_map):
    """Fine-tune a token-classification model and keep the best checkpoint.

    Trains for the module-level ``NUM_EPOCHS`` epochs, evaluates on the dev
    data after every epoch, and saves the model state whenever the span-level
    F1 improves.

    Args:
        model_config: Dict with ``"name"``, ``"max_length"`` and ``"batch_size"``.
        train_data: Currently unused; the training split is re-read from
            ``training_set_task2.txt`` with augmentation enabled.
        dev_data: Items used for evaluation and threshold selection.
        label_map: Maps technique name to class index.

    Returns:
        The best span-level F1 seen on the dev data (0 if none exceeded 0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_config["name"], add_prefix_space=True)
    num_labels = len(label_map)

    # Prepare datasets with augmented training data
    train_processed = prepare_datasets(model_config, tokenizer,
                                       load_data("training_set_task2.txt", augment=True))
    dev_processed = prepare_datasets(model_config, tokenizer, dev_data)

    model = AutoModelForTokenClassification.from_pretrained(
        model_config["name"],
        num_labels=num_labels,
        problem_type="multi_label_classification"
    ).to(device)

    train_dataset = SpanDataset(
        {'input_ids': train_processed['input_ids'], 'attention_mask': train_processed['attention_mask']},
        train_processed['labels'],
        train_processed['offset_mappings']
    )
    dev_dataset = SpanDataset(
        {'input_ids': dev_processed['input_ids'], 'attention_mask': dev_processed['attention_mask']},
        dev_processed['labels'],
        dev_processed['offset_mappings']
    )

    train_loader = DataLoader(train_dataset, batch_size=model_config["batch_size"], shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=model_config["batch_size"], shuffle=False)

    # Class weighting: inverse positive-token frequency, taken straight from
    # the stacked label tensor rather than by iterating the loader. Counts are
    # clamped to 1 so a class absent from training does not get a ~1e6 weight,
    # and the weights are normalised to mean 1 to keep the loss scale stable.
    class_counts = train_processed['labels'].sum(dim=(0, 1))
    inverse_freq = 1.0 / class_counts.clamp(min=1.0)
    class_weights = (inverse_freq / inverse_freq.mean()).to(device)

    optimizer = AdamW(model.parameters(), lr=2e-5)
    criterion = TokenFocalLoss(alpha=0.25, gamma=2, class_weights=class_weights)

    best_f1 = 0
    best_thresholds = None
    for epoch in range(NUM_EPOCHS):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device, num_labels)

        # Evaluation
        all_probs, all_labels, all_offsets = _predict(model, dev_loader, device)

        # Optimize thresholds on first epoch and every 5 epochs
        if epoch == 0 or (epoch + 1) % 5 == 0:
            best_thresholds = optimize_thresholds(
                np.concatenate(all_probs),
                np.concatenate(all_labels),
                label_map
            )

        # Pass the per-batch lists: compute_metrics and calculate_span_f1
        # iterate batch by batch, then sample by sample.
        metrics = compute_metrics(all_probs, all_labels, all_offsets, label_map, best_thresholds)
        micro_f1 = metrics["token_micro_f1"]
        span_f1 = metrics["span_f1"]

        print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - loss: {avg_loss:.4f} - "
              f"token micro F1: {micro_f1:.4f} - span F1: {span_f1:.4f}")

        # Checkpointing runs every epoch so the best epoch is saved, not just
        # the final one.
        if span_f1 > best_f1:
            best_f1 = span_f1
            model_name = model_config['name'].replace('/', '_')
            save_path = f"best_{model_name}.pth"
            Path(save_path).parent.mkdir(exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

    return best_f1

def _runs_to_char_spans(flags, offsets):
    """Return ``(start_char, end_char)`` for each maximal run of tokens equal to 1."""
    runs = []
    run_start = None
    for pos, flag in enumerate(flags):
        if flag == 1:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            runs.append((run_start, pos - 1))
            run_start = None
    if run_start is not None:  # run that reaches the end of the sequence
        runs.append((run_start, len(flags) - 1))

    spans = []
    for first, last in runs:
        # int() accepts Python ints, NumPy scalars and 0-d tensors alike, so
        # offsets need not expose ``.item()``.
        try:
            spans.append((int(offsets[first][0]), int(offsets[last][1])))
        except (IndexError, TypeError, ValueError):
            # Best effort: skip a run whose offsets are missing or malformed.
            continue
    return spans


def calculate_span_f1(preds, labels, offset_mappings, label_map):
    """Compute macro-averaged exact-match span F1 over all classes.

    Consecutive tokens flagged 1 for a class form one span, identified by the
    start offset of its first token and the end offset of its last token. A
    predicted span counts as correct only when it equals a gold span of the
    same sample and class.

    Args:
        preds: Sequence of batches; each batch is indexed ``[sample][token, class]``
            and holds binarized predictions.
        labels: Gold labels in the same layout as ``preds``.
        offset_mappings: Sequence of batches indexed ``[sample][token]`` giving
            ``(start, end)`` character offsets.
        label_map: Maps technique name to class index.

    Returns:
        Mean of the per-class F1 scores; a class with no matching spans
        contributes 0. Returns 0.0 when ``label_map`` is empty.
    """
    true_spans = defaultdict(set)
    pred_spans = defaultdict(set)

    batches = zip(preds, labels, offset_mappings)
    for batch_idx, (batch_preds, batch_labels, batch_offsets) in enumerate(batches):
        samples = zip(batch_preds, batch_labels, batch_offsets)
        for sample_idx, (pred, label, offsets) in enumerate(samples):
            pred = np.asarray(pred)
            label = np.asarray(label)

            for class_idx in label_map.values():
                # The (batch, sample) prefix keeps identical offsets from
                # different samples apart.
                for start, end in _runs_to_char_spans(label[:, class_idx], offsets):
                    true_spans[class_idx].add((batch_idx, sample_idx, start, end))
                for start, end in _runs_to_char_spans(pred[:, class_idx], offsets):
                    pred_spans[class_idx].add((batch_idx, sample_idx, start, end))

    # Calculate F1 per class
    f1_scores = []
    for class_idx in label_map.values():
        gold, predicted = true_spans[class_idx], pred_spans[class_idx]
        tp = len(gold & predicted)
        fp = len(predicted - gold)
        fn = len(gold - predicted)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        f1_scores.append(f1)

    # np.mean([]) would return NaN with a warning.
    if not f1_scores:
        return 0.0
    return float(np.mean(f1_scores))
# Execution
# Script entry: load both splits, derive the label vocabulary from every span
# annotation found in them, then train the selected model.
NUM_EPOCHS = 10
MODEL_TO_TRAIN = "deberta-v3"

train_data = load_data("training_set_task2.txt")
dev_data = load_data("dev_set_task2.txt")

# Sorted so the technique -> index assignment is deterministic across runs.
all_techniques = sorted(set(
    span["technique"]
    for example in train_data + dev_data
    for span in example.get("labels", [])
))
label_map = dict(zip(all_techniques, range(len(all_techniques))))

print(f"Training {MODEL_TO_TRAIN}...")
best_f1 = train_model(
    MODEL_CONFIGS[MODEL_TO_TRAIN],
    train_data,
    dev_data,
    label_map,
)
print(f"\nBest Span F1: {best_f1:.4f}")