# In [None]:
import inspect

import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput

class WeightedLossModel(DistilBertForSequenceClassification):
    """
    DistilBERT sequence classifier trained with class-weighted cross-entropy.

    The class weights counter label imbalance. Weighting is applied inside
    ``forward`` whenever ``labels`` are supplied, because
    ``transformers.Trainer`` takes the training loss from the model's output.
    A ``compute_loss`` method defined on the *model* is never invoked by the
    Trainer.
    """

    def __init__(self, pretrained_model_name_or_path, num_labels, class_weights):
        """
        Initialize the weighted loss model from pretrained weights.

        Args:
            pretrained_model_name_or_path (str): Pretrained model name or path
            num_labels (int): Number of output labels
            class_weights (tensor or array-like): One weight per class, indexed
                by label id, used by the cross-entropy loss

        Raises:
            ValueError: If the number of class weights differs from num_labels
        """
        # ``from_pretrained`` is an alternate constructor that returns a *new*
        # model, so it cannot initialise ``self``. Load the pretrained model
        # once, build this instance from its config, then copy its weights.
        pretrained = DistilBertForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path, num_labels=num_labels
        )
        super().__init__(pretrained.config)
        self.load_state_dict(pretrained.state_dict())
        del pretrained

        weights = torch.as_tensor(class_weights, dtype=torch.float)
        if weights.numel() != self.config.num_labels:
            raise ValueError(
                f"Got {weights.numel()} class weights for "
                f"{self.config.num_labels} labels."
            )
        # Plain attribute (not a registered buffer), so it stays out of the
        # saved state dict and checkpoints remain loadable by the stock
        # DistilBERT class. Because it is not moved by ``.to(device)``, it is
        # moved to the logits' device at loss time.
        self.class_weights = weights

    def _weighted_loss(self, logits, labels):
        """Return class-weighted cross-entropy of ``logits`` against ``labels``."""
        weight = self.class_weights.to(device=logits.device, dtype=logits.dtype)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
        return loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Run the classifier and, if ``labels`` are given, the weighted loss.

        The parameter list is spelled out rather than taken as ``**kwargs``.
        ``Trainer`` inspects this signature to decide which dataset columns
        to keep.

        Returns:
            SequenceClassifierOutput, or a tuple when ``return_dict`` resolves
            to False (loss first when present, mirroring the base class).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Run the base model without labels so its unweighted loss is skipped.
        base = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        loss = None if labels is None else self._weighted_loss(base.logits, labels)

        output = SequenceClassifierOutput(
            loss=loss,
            logits=base.logits,
            hidden_states=base.hidden_states,
            attentions=base.attentions,
        )
        return output if return_dict else output.to_tuple()

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the class-weighted loss for a batch (Trainer-style helper).

        Kept for backward compatibility. ``Trainer`` does not call this method
        on the model; ``forward`` already applies the same weighting.

        Args:
            model: Model to run on ``inputs``
            inputs: Model inputs, including a "labels" entry
            return_outputs (bool): Whether to return model outputs

        Returns:
            tensor or tuple: Loss tensor or (loss, outputs) tuple
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)
        loss = self._weighted_loss(outputs.get("logits"), labels)
        return (loss, outputs) if return_outputs else loss

class PoliticalBiasTransformer:
    """
    Transformer-based model for political bias classification using DistilBERT.

    Wraps tokenization, class-weighted fine-tuning with ``transformers.Trainer``,
    evaluation, saving/loading and inference on raw texts.
    """

    def __init__(self, model_name="distilbert-base-uncased", num_labels=3, max_length=512):
        """
        Initialize the transformer wrapper and load its tokenizer.

        Args:
            model_name (str): Pretrained model name or local path
            num_labels (int): Number of output labels
            max_length (int): Maximum sequence length for tokenizer
        """
        self.model_name = model_name
        self.num_labels = num_labels
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = None
        self.trainer = None
        self.training_args = None

    def tokenize_function(self, examples):
        """
        Tokenize the "content" texts of a batch of examples.

        Args:
            examples (dict): Batch of examples with a "content" text field

        Returns:
            dict: Tokenizer output, truncated and padded to ``max_length``
        """
        return self.tokenizer(
            examples["content"],
            truncation=True,
            padding='max_length',
            max_length=self.max_length
        )

    def encode_labels(self, example):
        """
        Copy the "bias" field into the "label" field expected by the Trainer.

        Args:
            example (dict): Single example with a "bias" field

        Returns:
            dict: The same example, mutated to include "label"
        """
        example["label"] = example["bias"]
        return example

    def prepare_dataset(self, dataset):
        """
        Tokenize a dataset and attach labels.

        Args:
            dataset: Dataset with "content" and "bias" columns (presumably a
                Hugging Face ``datasets.Dataset``; it must support ``.map``)

        Returns:
            dataset: Processed dataset ready for training or prediction
        """
        tokenized_dataset = dataset.map(self.tokenize_function, batched=True)
        return tokenized_dataset.map(self.encode_labels)

    def train(self, train_dataset, eval_dataset=None, output_dir="./bert_results",
              batch_size=16, learning_rate=2e-5, num_epochs=3, weight_decay=0.01):
        """
        Fine-tune DistilBERT with class-balanced cross-entropy.

        Args:
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset, evaluated every 500 steps if given
            output_dir (str): Directory to save model outputs
            batch_size (int): Training batch size
            learning_rate (float): Learning rate
            num_epochs (int): Number of training epochs
            weight_decay (float): Weight decay for regularization

        Returns:
            self: Trained model instance

        Raises:
            ValueError: If the training labels are not exactly the integers
                0..num_labels-1 (every class must appear at least once)
        """
        train_data = self.prepare_dataset(train_dataset)
        eval_data = None if eval_dataset is None else self.prepare_dataset(eval_dataset)

        # 'balanced' weights need every class present; a missing class would
        # give a weight vector shorter than the classifier's output.
        labels = np.asarray(list(train_data['label']))
        classes = np.unique(labels)
        if not np.array_equal(classes, np.arange(self.num_labels)):
            raise ValueError(
                f"Expected training labels 0..{self.num_labels - 1} with every "
                f"class present, found {classes.tolist()}."
            )
        class_weights = compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=labels
        )
        weights_tensor = torch.tensor(class_weights, dtype=torch.float)

        self.model = WeightedLossModel(
            pretrained_model_name_or_path=self.model_name,
            num_labels=self.num_labels,
            class_weights=weights_tensor
        )

        args_kwargs = dict(
            output_dir=output_dir,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            logging_steps=100,
            save_steps=1000,
            eval_steps=500,
            seed=42,
        )
        if eval_data is not None:
            # ``eval_steps`` only takes effect with a step evaluation strategy.
            # Older transformers releases name the argument
            # ``evaluation_strategy``, so use whichever this install accepts.
            ta_params = inspect.signature(TrainingArguments.__init__).parameters
            strategy_key = (
                "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
            )
            args_kwargs[strategy_key] = "steps"
        self.training_args = TrainingArguments(**args_kwargs)

        # Newer transformers releases take the tokenizer as ``processing_class``
        # and deprecate ``tokenizer``.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = (
            "processing_class" if "processing_class" in trainer_params else "tokenizer"
        )
        self.trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=train_data,
            eval_dataset=eval_data,
            **{tokenizer_key: self.tokenizer},
        )

        self.trainer.train()

        return self

    def evaluate(self, test_dataset, target_names=None):
        """
        Evaluate the trained model and print a classification report.

        Args:
            test_dataset: Test dataset with "content" and "bias" columns
            target_names (list): Names of target classes in label-id order;
                defaults to ["Left", "Center", "Right"]

        Returns:
            dict: Predictions, labels, logits, confusion matrix, full report
                and macro-averaged metrics

        Raises:
            ValueError: If train() has not been called
        """
        if self.trainer is None:
            raise ValueError("Model has not been trained yet. Call train() first.")
        if target_names is None:
            target_names = ["Left", "Center", "Right"]

        test_data = self.prepare_dataset(test_dataset)
        predictions = self.trainer.predict(test_data)

        logits = predictions.predictions
        preds = logits.argmax(-1)
        labels = predictions.label_ids
        # Pin the class ids so a class missing from this test set does not
        # make sklearn reject target_names or shrink the confusion matrix.
        class_ids = list(range(len(target_names)))

        report = classification_report(
            labels, preds, labels=class_ids, target_names=target_names, output_dict=True
        )
        print("📊 DistilBERT Results:")
        print(classification_report(labels, preds, labels=class_ids, target_names=target_names))

        return {
            'preds': preds,
            'labels': labels,
            # Despite the key name this holds raw logits (kept for backward
            # compatibility); apply convert_logits_to_probs() for probabilities.
            'probabilities': logits,
            'logits': logits,
            'confusion_matrix': confusion_matrix(labels, preds, labels=class_ids),
            'report': report,
            'predictions': predictions,  # Full prediction object
            'accuracy': float(np.mean(preds == labels)),
            'f1': report['macro avg']['f1-score'],
            'precision': report['macro avg']['precision'],
            'recall': report['macro avg']['recall']
        }

    def save_model(self, output_dir):
        """
        Save the trained model and tokenizer.

        Args:
            output_dir (str): Directory to save model

        Returns:
            str: Path to saved model

        Raises:
            ValueError: If there is no model to save
        """
        if self.model is None:
            raise ValueError("No model to save. Train the model first.")

        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        return output_dir

    @classmethod
    def load_model(cls, model_dir, num_labels=3, max_length=512):
        """
        Load a saved model and tokenizer for inference.

        Args:
            model_dir (str): Directory containing the saved model
            num_labels (int): Number of output labels (overridden by the value
                stored in the saved model's config)
            max_length (int): Maximum sequence length for tokenizer

        Returns:
            PoliticalBiasTransformer: Loaded model instance
        """
        # Construct from model_dir so the tokenizer is loaded once, from disk,
        # instead of first fetching the base model's tokenizer.
        instance = cls(model_name=model_dir, num_labels=num_labels, max_length=max_length)

        # Plain DistilBERT head: class weights only matter for the training loss.
        instance.model = DistilBertForSequenceClassification.from_pretrained(model_dir)
        instance.num_labels = instance.model.config.num_labels

        return instance

    def predict(self, texts):
        """
        Classify new texts.

        Args:
            texts (list): List of text strings to classify

        Returns:
            tuple: (predictions, probabilities) as numpy arrays, where
                predictions are class indices and probabilities are the
                softmax of the logits

        Raises:
            ValueError: If no model has been trained or loaded
        """
        if self.model is None:
            raise ValueError("No model available. Train or load a model first.")
        if not isinstance(texts, str) and len(texts) == 0:
            return (
                np.empty(0, dtype=np.int64),
                np.empty((0, self.num_labels), dtype=np.float32),
            )

        inputs = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Disable dropout for deterministic inference (Trainer leaves the model
        # in training mode), then restore whatever mode the caller had.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model(**inputs)
        finally:
            self.model.train(was_training)

        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        preds = torch.argmax(logits, dim=-1)

        return preds.cpu().numpy(), probs.cpu().numpy()

    def convert_logits_to_probs(self, logits):
        """
        Convert model logits to probabilities with a softmax over the last axis.

        Args:
            logits (array or tensor): Raw logits from model

        Returns:
            array or tensor: Probabilities, in the same container type as input
        """
        if isinstance(logits, np.ndarray):
            logits_tensor = torch.tensor(logits)
        else:
            logits_tensor = logits

        probs = torch.nn.functional.softmax(logits_tensor, dim=-1)

        if isinstance(logits, np.ndarray):
            return probs.numpy()
        else:
            return probs