In [None]:
# Log in to the Hugging Face Hub through the interactive notebook widget so
# that later from_pretrained() calls can use the stored token.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AdamW
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
import os
from typing import Dict

# Per-model training settings, keyed by the short name selected via
# MODEL_TO_TRAIN. "name" is the hub id, "quantize" toggles 4-bit loading in
# train_model(), "max_length" is the tokenizer truncation length and
# "batch_size" is the DataLoader batch size.
MODEL_CONFIGS = {
    "mistral-7b": dict(
        name="mistralai/Mistral-7B-v0.1",
        quantize=True,
        max_length=1024,
        batch_size=2,
    ),
}

class MistralForMultiLabelClassification(nn.Module):
    """Causal-LM backbone with LoRA adapters and a linear multi-label head.

    The base model is loaded with ``AutoModelForCausalLM`` (optionally
    quantized), wrapped with LoRA adapters on ``q_proj``/``v_proj``, and a
    ``nn.Linear`` head maps one pooled final-layer hidden state per sequence
    to ``num_labels`` raw logits (no sigmoid is applied here).

    Args:
        model_name: Hugging Face hub id or local path of the base model.
        num_labels: Number of output logits, one per label.
        quantize_config: Optional quantization config (e.g. a
            ``BitsAndBytesConfig``); when given, the model is loaded quantized
            and prepared for k-bit training before LoRA is applied.
    """

    def __init__(self, model_name, num_labels, quantize_config=None):
        super().__init__()
        # Load base model; device_map="auto" lets accelerate place the weights.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantize_config,
            device_map="auto",
            token=os.environ.get("HF_TOKEN"),
        )

        # Inputs and the head go to the device holding the first parameter.
        self.device = next(self.model.parameters()).device

        # Classification head, created on the same device as the backbone start.
        self.classifier = nn.Linear(
            self.model.config.hidden_size,
            num_labels,
        ).to(self.device)

        if quantize_config:
            self.model = prepare_model_for_kbit_training(self.model)

        # Add LoRA adapters (only these and the head end up trainable).
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, peft_config)

    def forward(self, input_ids, attention_mask):
        """Return ``(batch, num_labels)`` logits for a batch of token ids.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq_len)``.
            attention_mask: 1 for real tokens and 0 for padding, same shape as
                ``input_ids``.
        """
        # Move inputs to model's device
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states[-1]

        # In a decoder-only model token i attends only to tokens <= i, so the
        # last real token is the only position that has seen the whole text.
        # Position 0 would be a pad token under left padding. Taking the
        # largest position whose mask is 1 works for left and right padding.
        seq_len = attention_mask.shape[1]
        positions = torch.arange(seq_len, device=attention_mask.device)
        last_idx = (positions * attention_mask.long()).argmax(dim=1)
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        pooled = hidden_states[batch_idx, last_idx.to(hidden_states.device)]

        # The head keeps its creation dtype/device, while the backbone may emit
        # reduced-precision activations (or, when sharded, on another device).
        head_weight = self.classifier.weight
        pooled = pooled.to(device=head_weight.device, dtype=head_weight.dtype)
        return self.classifier(pooled)

class PropagandaDataset(Dataset):
    """Map-style dataset pairing tokenizer output with per-example label vectors.

    Args:
        encodings: Mapping with ``'input_ids'`` and ``'attention_mask'`` entries
            indexable by example position (e.g. a tokenizer's batch output).
        labels: Per-example label vectors (e.g. MultiLabelBinarizer output);
            stored as a float tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        # Float targets so they can be fed straight into BCE-style losses.
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        # The label tensor defines how many examples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        ids = self.encodings['input_ids'][idx]
        mask = self.encodings['attention_mask'][idx]
        target = self.labels[idx]
        return {'input_ids': ids, 'attention_mask': mask, 'labels': target}

class FocalLoss(nn.Module):
    """Element-wise binary focal loss on logits, averaged over all elements.

    Each element's BCE-with-logits term is scaled by ``alpha * (1 - p_t)**gamma``,
    where ``p_t = exp(-BCE)`` is the probability assigned to the true target.
    ``alpha`` is applied uniformly to positives and negatives.

    Args:
        alpha: Constant weight on every element's loss.
        gamma: Focusing exponent; larger values down-weight easy examples more.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_t = ce.neg().exp()
        modulating = (1 - p_t).pow(self.gamma)
        return (self.alpha * modulating * ce).mean()



# Mount Google Drive in this Colab session so the dataset files are readable.
from google.colab import drive
drive.mount('/content/drive')

# Directory holding the SemEval data files; load_data() resolves relative
# file names against it.
data_dir = "/content/drive/My Drive/SEMEVAL/data/"

# Load data
def load_data(file_path: str, base_dir=None) -> list:
    """Load and parse a JSON file from the data directory.

    Args:
        file_path: File name relative to ``base_dir``; an absolute path is
            used as-is.
        base_dir: Directory to resolve ``file_path`` against. Defaults to the
            module-level ``data_dir``.

    Returns:
        The parsed JSON content. The callers here expect a list of records
        with ``"text"`` and ``"labels"`` keys.
    """
    if base_dir is None:
        base_dir = data_dir
    # os.path.join supplies the separator itself, so base_dir does not need a
    # trailing "/" (plain string concatenation silently required one).
    full_path = os.path.join(base_dir, file_path)
    with open(full_path, "r", encoding="utf-8") as f:
        return json.load(f)

def prepare_datasets(config: Dict, tokenizer, train_texts, train_labels, dev_texts, dev_labels):
    """Tokenize both splits and wrap them as ``PropagandaDataset`` objects.

    Args:
        config: Model config; only ``config["max_length"]`` (truncation
            length) is read.
        tokenizer: Callable tokenizer (a Hugging Face tokenizer here).
        train_texts, dev_texts: Raw texts for each split.
        train_labels, dev_labels: Per-example label vectors aligned with the texts.

    Returns:
        ``(train_dataset, dev_dataset)``. Each split is padded to its own
        longest sequence and tokenized into PyTorch tensors.
    """
    tok_kwargs = {
        "padding": True,
        "truncation": True,
        "max_length": config["max_length"],
        "return_tensors": "pt",
    }
    # Tokenize train first, then dev, before building either dataset.
    train_enc, dev_enc = (tokenizer(texts, **tok_kwargs) for texts in (train_texts, dev_texts))
    return PropagandaDataset(train_enc, train_labels), PropagandaDataset(dev_enc, dev_labels)

def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        logits = model(
            input_ids=batch['input_ids'].to(model.device),
            attention_mask=batch['attention_mask'].to(model.device),
        )
        labels = batch['labels'].to(logits.device)
        # Score in float32 to match the float32 targets even if the model
        # emits reduced-precision logits.
        loss = criterion(logits.float(), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Guard against an empty loader instead of dividing by zero.
    return total_loss / max(len(loader), 1)


def _predict(model, loader):
    """Return ``(sigmoid probabilities, gold labels)`` for ``loader`` as numpy arrays."""
    model.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(
                input_ids=batch['input_ids'].to(model.device),
                attention_mask=batch['attention_mask'].to(model.device),
            )
            # .float() first: numpy has no bfloat16 dtype.
            prob_batches.append(torch.sigmoid(logits.float()).cpu().numpy())
            label_batches.append(batch['labels'].cpu().numpy())
    return np.concatenate(prob_batches), np.concatenate(label_batches).astype(int)


def _best_threshold(labels, probs, thresholds=None):
    """Return ``(threshold, micro-F1)`` maximising micro-F1 over a threshold grid.

    The grid defaults to 20 evenly spaced values in [0.3, 0.7]. If no threshold
    yields a positive score, ``(0.5, 0.0)`` is returned.
    """
    if thresholds is None:
        thresholds = np.linspace(0.3, 0.7, 20)
    best_thresh, best_score = 0.5, 0.0
    for thresh in thresholds:
        score = f1_score(labels, (probs > thresh).astype(int),
                         average='micro', zero_division=0)
        if score > best_score:
            best_thresh, best_score = thresh, score
    return best_thresh, best_score


def train_model(model_config, train_data, dev_data, all_labels, num_epochs=None):
    """Fine-tune the multi-label classifier and checkpoint the best dev epoch.

    Args:
        model_config: One entry of ``MODEL_CONFIGS`` (keys ``name``,
            ``quantize``, ``max_length``, ``batch_size``).
        train_data: ``(texts, label_matrix)`` for training.
        dev_data: ``(texts, label_matrix)`` for evaluation.
        all_labels: Label names; only ``len(all_labels)`` is used, to size the head.
        num_epochs: Number of epochs; defaults to the module-level ``NUM_EPOCHS``.

    Returns:
        The best dev-set micro-F1 over all epochs. Each epoch is scored at its
        own best decision threshold. The weights from that epoch are saved to
        ``best_<model name with '/' replaced by '_'>.pth``.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    # Load tokenizer; left padding keeps the last real token at the end.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config["name"],
        use_fast=True,
        padding_side="left",
        token=os.environ.get("HF_TOKEN"),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = None
    if model_config["quantize"]:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    # The model loads with device_map="auto", which already places the weights,
    # so no .to(device) call is applied here.
    model = MistralForMultiLabelClassification(
        model_config["name"],
        num_labels=len(all_labels),
        quantize_config=bnb_config,
    )

    train_dataset, dev_dataset = prepare_datasets(
        model_config, tokenizer, *train_data, *dev_data
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=model_config["batch_size"],
        shuffle=True,
    )
    dev_loader = DataLoader(
        dev_dataset,
        batch_size=model_config["batch_size"],
        shuffle=False,
    )

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are set to the old class's defaults (1e-6, 0.0) so the
    # optimisation itself is unchanged.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-5, eps=1e-6, weight_decay=0.0)
    criterion = FocalLoss(alpha=1.0, gamma=2.0)
    checkpoint_path = f"best_{model_config['name'].replace('/', '_')}.pth"

    best_f1 = 0.0
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion)
        probs, gold = _predict(model, dev_loader)

        # The threshold is tuned on the dev set itself, so these scores are
        # optimistic. Its score lives in a separate variable; resetting best_f1
        # here would make the checkpoint condition below unreachable.
        best_thresh, _ = _best_threshold(gold, probs)
        preds_binary = (probs > best_thresh).astype(int)
        f1_micro = f1_score(gold, preds_binary, average='micro', zero_division=0)
        f1_macro = f1_score(gold, preds_binary, average='macro', zero_division=0)
        precision = precision_score(gold, preds_binary, average='micro', zero_division=0)
        recall = recall_score(gold, preds_binary, average='micro', zero_division=0)

        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f}")
        print(f"  F1-Micro: {f1_micro:.4f} | F1-Macro: {f1_macro:.4f}")
        print(f"  Precision: {precision:.4f} | Recall: {recall:.4f}")

        if f1_micro > best_f1:
            best_f1 = f1_micro
            # NOTE(review): state_dict() includes the whole base model, not just
            # the LoRA adapters and head, so each checkpoint is large.
            torch.save(model.state_dict(), checkpoint_path)

    return best_f1

# Main execution
# Logs in again; this repeats the login performed in the first notebook cell.
from huggingface_hub import notebook_login
notebook_login()

NUM_EPOCHS = 15  # used by train_model()
MODEL_TO_TRAIN = "mistral-7b"  # key into MODEL_CONFIGS

# Load the raw train/dev splits from the data directory.
train_data = load_data("training_set_task1.txt")
dev_data = load_data("dev_set_task1.txt")

# Split each record into its text and its list of label names.
train_texts, train_labels = (
    [record["text"] for record in train_data],
    [record["labels"] for record in train_data],
)
dev_texts, dev_labels = (
    [record["text"] for record in dev_data],
    [record["labels"] for record in dev_data],
)

# Label vocabulary: every label name seen in either split, sorted.
all_labels = sorted(set().union(*train_labels, *dev_labels))

# Multi-hot encode both splits against the same fixed vocabulary.
mlb = MultiLabelBinarizer(classes=all_labels)
train_labels_enc = mlb.fit_transform(train_labels)
dev_labels_enc = mlb.transform(dev_labels)

# Run training
print(f"\n=== Training {MODEL_TO_TRAIN} ===")
score = train_model(
    MODEL_CONFIGS[MODEL_TO_TRAIN],
    (train_texts, train_labels_enc),
    (dev_texts, dev_labels_enc),
    all_labels,
)
print(f"\nFinal {MODEL_TO_TRAIN} F1: {score:.4f}")