In [None]:
# Mount Google Drive so the Drive project folder (data files, saved model archives) is reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Work from the project folder on Drive so the relative data paths used below
# (e.g. 'subtask3/train/swa.csv') and the saved plot files resolve there.
%cd /content/drive/MyDrive/NLP

/content/drive/MyDrive/NLP


In [None]:
import pandas as pd

from sklearn.metrics import recall_score, precision_score, f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold
import numpy as np

import torch

import matplotlib.pyplot as plt
import seaborn as sns

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    TrainerCallback,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset

import os
import zipfile
import json


In [None]:
import wandb

# Disable wandb logging for this script by starting the run in "disabled" mode.
# NOTE(review): the per-fold TrainingArguments already pass report_to="none", but the
# prediction Trainers are built with default arguments; presumably this global disable
# is what keeps those from reporting to wandb — confirm.
wandb.init(mode="disabled")

  | |_| | '_ \/ _` / _` |  _/ -_)


In [None]:
# Categorical palette used by seaborn; hex codes are turned into '#rrggbb' colour strings.
my_pal = [
    f"#{hex_code}"
    for hex_code in ('1f77b4', 'ff8c1a', '2ca02c', 'd62728',
                     '9467bd', 'c5b300', 'e377c2', '17becf')
]

# Start from clean defaults so re-running this cell does not stack style tweaks.
sns.reset_defaults()
plt.rcParams['font.family'] = ['sans-serif']

paper_rc = {
    "pdf.fonttype": 42,          # Type 42 (TrueType) fonts in PDF output
    "svg.fonttype": "none",      # keep SVG text as text rather than paths
    "figure.facecolor": "white",
    "figure.dpi": 150,
    "axes.facecolor": "None",    # no fill behind the axes
    "axes.spines.left": True,
    "axes.spines.bottom": True,
    "axes.spines.right": False,  # hide the right/top frame lines
    "axes.spines.top": False,
}
sns.set_theme(context="paper", style="ticks", palette=my_pal, rc=paper_rc)

In [None]:
# Subtask 3: Swahili train split, with the dev split kept aside as `test`.
train = pd.read_csv('subtask3/train/swa.csv')
test = pd.read_csv('subtask3/dev/swa.csv')

# Label columns = everything except id/text (and any saved-index "Unnamed: ..." column).
# Subtask 3 has several 0/1 labels (political, racial/ethnic, ...) and no 'polarization'
# column, so the label set is read from the file instead of being hard-coded.
subtask_label_cols = [
    col for col in train.columns
    if col not in ('id', 'text') and not col.startswith('Unnamed')
]
assert subtask_label_cols, "no label columns found in the training file"

# Joint label combination (e.g. '0_1_0_0_0') used to stratify the split.
train['stratify_key'] = train[subtask_label_cols].apply(
    lambda row: '_'.join(row.astype(str)), axis=1
)

# Stratification needs at least 2 rows per combination; fall back to a plain
# shuffled split if some combination occurs only once.
stratify_on = (
    train['stratify_key']
    if train['stratify_key'].value_counts().min() >= 2
    else None
)

train, val = train_test_split(
    train,
    test_size=0.2,
    stratify=stratify_on,
    random_state=42,  # reproducible split
)

train.head()

Unnamed: 0,id,text,political,racial/ethnic,religious,gender/sexual,other,stratify_key
2356,swa_a0a8a7d3fadf4d25b8e732b78dd1f4f3,baba tosha pongeziazimio uchaguzi mkuu Kenya,0,0,0,0,0,0_0_0_0_0
2891,swa_c2782f9e935e68e486062ffa71919077,matunda ya punda wa ghetto yalikuja kwenye bou...,0,0,0,0,0,0_0_0_0_0
2833,swa_ac33ea5f12a175f4e6bbf0333a610d06,kwa hivyo nyani huyu anaweza kupata mcm lakini...,0,0,0,0,0,0_0_0_0_0
67,swa_806334758786daada0bef3b6406e28c7,wee kuna rais wa marekani mhaya oscar award wi...,0,1,0,0,0,0_0_0_1_0
1235,swa_c11f130a0b352280bbfdaa336256879c,bubu bubu achana nami,0,0,0,0,1,0_0_0_0_1


# new large


In [None]:
import os
import zipfile
import json
# ============================================
# Configuration
# ============================================
# Per-language settings: checkpoint to fine-tune, subtask-1 data files, and the short
# tag used in output file names (plots, model archives).
CONFIG = {
    'english': {
        'model_name': 'cardiffnlp/twitter-roberta-base-hate-multiclass-latest',
        # NOTE(review): previously pointed at the Swahili swa.csv files. English files are
        # assumed to follow the same '<code>.csv' naming (code = output_name) — confirm.
        'train_file': 'subtask1/train/eng.csv',
        'test_file': 'subtask1/dev/eng.csv',
        'output_name': 'eng'
    },
    'swahili': {
        'model_name': 'metabloit/swahBERT',
        'train_file': 'subtask1/train/swa.csv',
        'test_file': 'subtask1/dev/swa.csv',
        'output_name': 'swa'
    }
}
# Candidate checkpoints:
#   cardiffnlp/twitter-roberta-base-hate-multiclass-latest
#   metabloit/swahBERT
#   Tadesse/AfroXLMR-Social
#   Davlan/afro-xlmr-base
#   FacebookAI/xlm-roberta-base
#   castorini/afriteva_base
#   microsoft/deberta-v3-base
# English-only candidate:
#   cardiffnlp/twitter-roberta-base-2022-154m

# Subtask 1 has a single target column.
LABEL_COLUMNS = ["polarization"]

# Select language
LANGUAGE = 'swahili'
config = CONFIG[LANGUAGE]


In [None]:
class MultiLabelDataset(torch.utils.data.Dataset):
    """Tokenize texts on the fly and pair each with a float label vector.

    Args:
        texts: Position-indexable sequence of raw texts (NaN entries become "").
        labels: Position-indexable sequence of per-example label vectors.
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Truncation length in tokens.

    Items are unpadded; padding is left to the data collator.
    """
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = self.texts[idx]
        # Missing texts (NaN from pandas) would make the tokenizer fail; treat them as
        # empty, matching MultiLabelTestDataset.
        text = str(raw_text) if pd.notna(raw_text) else ""
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the batch dimension added by return_tensors='pt'; a bare squeeze()
        # would also collapse a 1-token sequence into a 0-d tensor.
        item = {key: value.squeeze(0) for key, value in encoding.items()}
        # Float targets, as used with problem_type="multi_label_classification".
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

class MultiLabelTestDataset(torch.utils.data.Dataset):
    """Unlabelled dataset for prediction: tokenizes each text on the fly.

    Args:
        texts: Position-indexable sequence of raw texts; NaN entries become "".
        tokenizer: Hugging Face-style tokenizer callable.
        max_length: Truncation length in tokens.

    Items are unpadded; padding is left to the data collator.
    """
    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = self.texts[idx]
        text = str(raw_text) if pd.notna(raw_text) else ""
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding=False,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Remove only the leading batch dimension; squeeze() would turn a 1-token
        # sequence into a 0-d tensor.
        return {key: value.squeeze(0) for key, value in encoding.items()}

# ============================================
# Custom Trainer with Combined Metric
# ============================================
class CombinedMetricTrainer(Trainer):
    """Trainer that also reports a combined "F1 up, loss down" score.

    combined_score = f1_weight * <prefix>_f1_macro - loss_weight * min(<prefix>_loss / 2, 1)

    The score is stored as ``<prefix>_combined_score``, together with
    ``<prefix>_train_loss_recent`` (mean of the last 100 recorded training-step
    losses), in the evaluation metrics.

    Args:
        f1_weight: Weight on the ``f1_macro`` metric (compute_metrics must produce it).
        loss_weight: Weight on the normalized evaluation loss.
    """
    def __init__(self, *args, f1_weight=0.7, loss_weight=0.3, **kwargs):
        super().__init__(*args, **kwargs)
        self.f1_weight = f1_weight
        self.loss_weight = loss_weight
        self.train_losses = []

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the model loss, recording it only when called from the training loop.

        compute_loss is also used while evaluating; recording those losses would mix
        eval values into ``train_losses``.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        # model.training is False in eval mode; loss is None when no labels were given.
        if model.training and loss is not None:
            self.train_losses.append(loss.item())
        return (loss, outputs) if return_outputs else loss

    def _add_combined_metrics(self, metrics, metric_key_prefix):
        """Add the combined score (and recent train loss) to ``metrics`` in place; idempotent."""
        f1 = metrics.get(f'{metric_key_prefix}_f1_macro', 0)
        loss = metrics.get(f'{metric_key_prefix}_loss', 0)

        # Map the loss onto [0, 1] (losses >= 2 saturate) so it is on F1's scale.
        normalized_loss = min(loss / 2.0, 1.0)
        metrics[f'{metric_key_prefix}_combined_score'] = (
            self.f1_weight * f1 - self.loss_weight * normalized_loss
        )

        if self.train_losses:
            metrics[f'{metric_key_prefix}_train_loss_recent'] = float(np.mean(self.train_losses[-100:]))
        return metrics

    def evaluation_loop(self, *args, **kwargs):
        """Run the standard evaluation loop, then add the combined score to its metrics.

        In current transformers versions Trainer.evaluate() logs and fires the
        ``on_evaluate`` callbacks before it returns, so the score must be added here
        for callbacks (e.g. CombinedEarlyStoppingCallback) to see it.
        """
        output = super().evaluation_loop(*args, **kwargs)
        prefix = kwargs.get('metric_key_prefix', args[4] if len(args) > 4 else 'eval')
        if output.metrics is not None:
            self._add_combined_metrics(output.metrics, prefix)
        return output

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Evaluate and guarantee the combined score is in the returned metrics.

        Recomputing is idempotent and covers evaluation paths that bypass evaluation_loop.
        """
        metrics = super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
        return self._add_combined_metrics(metrics, metric_key_prefix)

# ============================================
# Enhanced Monitoring Callback
# ============================================
class EnhancedMonitoringCallback(TrainerCallback):
    """Record and print F1, eval loss, train loss and combined score at every evaluation.

    ``history`` gets one entry per evaluation under 'epoch', 'train_loss', 'eval_loss',
    'eval_f1' and 'combined_score' (plot_all_folds draws from it). The ``best_*``
    attributes hold the best value seen so far for each quantity.
    """
    def __init__(self):
        self.best_f1 = -float('inf')
        self.best_eval_loss = float('inf')
        self.best_train_loss = float('inf')
        self.best_combined = -float('inf')
        self.history = {
            'epoch': [],
            'train_loss': [],
            'eval_loss': [],
            'eval_f1': [],
            'combined_score': []
        }

    @staticmethod
    def _latest_train_loss(state):
        """Return the most recent training 'loss' entry in state.log_history (NaN if none yet)."""
        for entry in reversed(state.log_history):
            if 'loss' in entry:
                return entry['loss']
        return float('nan')

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Track the lowest logged training loss (training logs carry 'loss', eval logs do not)."""
        if logs and 'loss' in logs:
            if logs['loss'] < self.best_train_loss:
                self.best_train_loss = logs['loss']

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Append this evaluation to ``history``, update the bests and print a summary."""
        if metrics is None:
            return

        # state.epoch is None when evaluate() runs before any training step.
        epoch = state.epoch if state.epoch is not None else 0.0
        # In current transformers versions the eval metrics are logged before
        # on_evaluate fires, so state.log_history[-1] is this evaluation's own entry
        # (no 'loss' key). Walk back to the latest training log instead.
        train_loss = metrics.get('train_loss', self._latest_train_loss(state))
        eval_loss = metrics.get('eval_loss', 0)
        eval_f1 = metrics.get('eval_f1_macro', 0)
        combined_score = metrics.get('eval_combined_score', 0)

        self.history['epoch'].append(epoch)
        self.history['train_loss'].append(train_loss)
        self.history['eval_loss'].append(eval_loss)
        self.history['eval_f1'].append(eval_f1)
        self.history['combined_score'].append(combined_score)

        # Strict improvements only, so a tie with an earlier best is not announced as new.
        new_best_f1 = eval_f1 > self.best_f1
        new_best_eval_loss = eval_loss < self.best_eval_loss
        new_best_combined = combined_score > self.best_combined
        if new_best_f1:
            self.best_f1 = eval_f1
        if new_best_eval_loss:
            self.best_eval_loss = eval_loss
        if new_best_combined:
            self.best_combined = combined_score

        print(f"\n{'='*70}")
        print(f"Epoch {epoch:.1f} Metrics:")
        print(f"{'='*70}")
        print(f"  F1 Score:        {eval_f1:.4f} (Best: {self.best_f1:.4f})")
        print(f"  Eval Loss:       {eval_loss:.4f} (Best: {self.best_eval_loss:.4f})")
        print(f"  Train Loss:      {train_loss:.4f} (Best: {self.best_train_loss:.4f})")
        print(f"  Combined Score:  {combined_score:.4f} (Best: {self.best_combined:.4f})")

        if new_best_f1:
            print(f"  ✓ NEW BEST F1!")
        if new_best_eval_loss:
            print(f"  ✓ NEW BEST EVAL LOSS!")
        if new_best_combined:
            print(f"  ✓ NEW BEST COMBINED SCORE!")
        print(f"{'='*70}\n")

# ============================================
# Early Stopping with Combined Metric
# ============================================
class CombinedEarlyStoppingCallback(TrainerCallback):
    """Early stopping on ``eval_combined_score`` (as produced by CombinedMetricTrainer).

    Training stops once the score has failed to improve for ``patience`` consecutive
    evaluations. Evaluations without the metric (e.g. with a plain Trainer) are
    skipped with a one-time warning instead of being counted as "no improvement".

    Args:
        patience: Number of evaluations without improvement before stopping.
    """
    def __init__(self, patience=5):
        self.patience = patience
        self.best_combined_score = -float('inf')
        self.patience_counter = 0
        self._warned_missing_metric = False

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Update the patience counter and request a stop when it reaches ``patience``."""
        if metrics is None:
            return

        combined_score = metrics.get('eval_combined_score')
        if combined_score is None:
            # Counting a missing metric as "no improvement" would silently stop training
            # after `patience` evaluations regardless of model quality.
            if not self._warned_missing_metric:
                print("⚠ CombinedEarlyStoppingCallback: 'eval_combined_score' missing from metrics; "
                      "early stopping is inactive for these evaluations")
                self._warned_missing_metric = True
            return

        if combined_score > self.best_combined_score:
            self.best_combined_score = combined_score
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        if self.patience_counter >= self.patience:
            print(f"\n⚠ Early stopping triggered after {self.patience} epochs without improvement")
            print(f"Best combined score: {self.best_combined_score:.4f}")
            control.should_training_stop = True

# ============================================
# Compute Metrics for Multi-Label
# ============================================
def compute_metrics_multilabel(p, threshold=0.5):
    """Compute macro F1 for multi-label classification.

    Args:
        p: EvalPrediction-like object with ``predictions`` (logits, possibly wrapped in
           a tuple/list when the model returns extra outputs) and ``label_ids`` (0/1 targets).
        threshold: Probability cut-off applied after the sigmoid.

    Returns:
        ``{'f1_macro': float}``; labels with no true and no predicted positives score 0.
    """
    logits = p.predictions
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    # Cast to float32 first: with fp16 training the logits arrive as float16, which
    # some torch CPU ops may not support.
    probs = torch.sigmoid(torch.from_numpy(np.asarray(logits, dtype=np.float32)))
    preds = (probs > threshold).int().numpy()
    labels = np.asarray(p.label_ids).astype(int)
    return {'f1_macro': float(f1_score(labels, preds, average='macro', zero_division=0))}

# ============================================
# Create Stratification Key for Multi-Label
# ============================================
def create_stratify_key(df, label_columns):
    """Build one string per row by joining its label values with '_' (e.g. '0_1_0').

    Rows sharing a label combination get the same key, so the result can serve as
    the class vector for stratified splitting of multi-label data.
    """
    def _row_signature(label_row):
        return '_'.join(label_row.astype(str))

    label_block = df[label_columns]
    return label_block.apply(_row_signature, axis=1)

# ============================================
# K-Fold Cross-Validation with Enhanced Training
# ============================================
def _build_fold_model(model_name, num_labels):
    """Load a fresh copy of ``model_name`` set up for multi-label classification with light dropout."""
    model_config = AutoConfig.from_pretrained(model_name)
    # Basic, fast-training regularisation.
    model_config.hidden_dropout_prob = 0.1
    model_config.attention_probs_dropout_prob = 0.1
    model_config.classifier_dropout = 0.1
    model_config.num_labels = num_labels
    model_config.problem_type = "multi_label_classification"

    # The checkpoint's own classification head may have a different output size;
    # ignore_mismatched_sizes lets it be re-initialised for num_labels.
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=model_config,
        ignore_mismatched_sizes=True
    )


def _fold_training_args(fold):
    """Basic per-fold TrainingArguments: 4 epochs, eval/save each epoch, keep the best-F1 checkpoint."""
    return TrainingArguments(
        output_dir=f"/content/outputs/fold_{fold+1}",
        num_train_epochs=4,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=50,
        # With load_best_model_at_end the best checkpoint is retained in addition to the latest.
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        fp16=True,  # mixed precision (needs a GPU)
        disable_tqdm=False,
        report_to="none"
    )


def _print_cv_summary(fold_metrics):
    """Print the per-fold results table and mean ± std of F1, eval loss and combined score."""
    f1_values = [m['f1'] for m in fold_metrics]
    eval_losses = [m['eval_loss'] for m in fold_metrics]
    combined_values = [m['combined_score'] for m in fold_metrics]

    print(f"\n{'='*70}")
    print(f"CROSS-VALIDATION SUMMARY")
    print(f"{'='*70}")
    print(f"\nPer-Fold Results:")
    print(f"{'Fold':<8} {'F1 Score':<12} {'Eval Loss':<12} {'Combined':<12}")
    print(f"{'-'*50}")
    for metrics in fold_metrics:
        print(f"{metrics['fold']:<8} {metrics['f1']:<12.4f} {metrics['eval_loss']:<12.4f} {metrics['combined_score']:<12.4f}")

    print(f"\nAggregate Statistics:")
    print(f"  Mean F1:       {np.mean(f1_values):.4f} ± {np.std(f1_values):.4f}")
    print(f"  Mean Eval Loss:{np.mean(eval_losses):.4f} ± {np.std(eval_losses):.4f}")
    print(f"  Mean Combined: {np.mean(combined_values):.4f} ± {np.std(combined_values):.4f}")
    print(f"{'='*70}\n")


def train_kfold_cv(train_df, model_name, label_columns, n_splits=5, seed=42):
    """Fine-tune one model per stratified fold for multi-label classification.

    Args:
        train_df: DataFrame with a 'text' column and one 0/1 column per label.
        model_name: Hugging Face checkpoint fine-tuned in every fold.
        label_columns: Names of the label columns in ``train_df``.
        n_splits: Number of folds.
        seed: Shuffle seed for StratifiedKFold.

    Returns:
        (fold_models, fold_metrics): per-fold dicts holding the best model of the fold,
        tokenizer, scores and per-epoch history; and per-fold metric summaries.
    """
    # Stratify on the joint label combination. The key stays in a local Series so the
    # caller's DataFrame is not modified.
    stratify_key = create_stratify_key(train_df, label_columns)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    fold_models = []
    fold_metrics = []

    # Every fold fine-tunes the same checkpoint, so one tokenizer is shared.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print(f"\n{'='*70}")
    print(f"Starting {n_splits}-Fold Cross-Validation")
    print(f"Model: {model_name}")
    print(f"Total samples: {len(train_df)}")
    print(f"Task: Multi-Label Classification ({len(label_columns)} labels)")
    # Checkpoint selection uses eval macro F1 only (metric_for_best_model="f1_macro").
    print(f"Optimization: Maximize macro F1 (best checkpoint by eval F1)")
    print(f"{'='*70}\n")

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, stratify_key)):
        print(f"\n{'='*70}")
        print(f"FOLD {fold + 1}/{n_splits}")
        print(f"{'='*70}")

        train_fold = train_df.iloc[train_idx]
        val_fold = train_df.iloc[val_idx]
        print(f"Train size: {len(train_fold)}, Val size: {len(val_fold)}")

        train_dataset = MultiLabelDataset(
            train_fold['text'].tolist(),
            train_fold[label_columns].values.tolist(),
            tokenizer
        )
        val_dataset = MultiLabelDataset(
            val_fold['text'].tolist(),
            val_fold[label_columns].values.tolist(),
            tokenizer
        )

        monitoring_callback = EnhancedMonitoringCallback()
        trainer = Trainer(
            model=_build_fold_model(model_name, len(label_columns)),
            args=_fold_training_args(fold),
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics_multilabel,
            data_collator=DataCollatorWithPadding(tokenizer),
            callbacks=[monitoring_callback]
        )

        print(f"\nTraining Fold {fold + 1}...")
        trainer.train()

        # Detach the monitor before the final evaluate(): otherwise it records an extra
        # history point (the reloaded best model at the last epoch), which duplicates
        # the final epoch in the plotted curves.
        trainer.remove_callback(monitoring_callback)
        eval_results = trainer.evaluate()
        fold_score = eval_results['eval_f1_macro']
        fold_eval_loss = eval_results['eval_loss']

        fold_metrics.append({
            'fold': fold + 1,
            'f1': fold_score,
            'eval_loss': fold_eval_loss,
            'combined_score': fold_score,  # F1 doubles as the combined score here
            'best_f1': monitoring_callback.best_f1,
            'best_eval_loss': monitoring_callback.best_eval_loss,
            'best_train_loss': monitoring_callback.best_train_loss
        })

        print(f"\n{'='*70}")
        print(f"Fold {fold + 1} Final Results:")
        print(f"{'='*70}")
        print(f"  F1 Score:       {fold_score:.4f}")
        print(f"  Eval Loss:      {fold_eval_loss:.4f}")
        print(f"{'='*70}")

        fold_models.append({
            # load_best_model_at_end=True loads the best checkpoint into trainer.model.
            'model': trainer.model,
            'tokenizer': tokenizer,
            'score': fold_score,
            'eval_loss': fold_eval_loss,
            'combined_score': fold_score,
            'fold': fold + 1,
            'history': monitoring_callback.history
        })

    _print_cv_summary(fold_metrics)

    # Label the plot with the checkpoint actually trained here.
    plot_all_folds(fold_models, model_name, LANGUAGE)

    return fold_models, fold_metrics

# ============================================
# Plotting Function
# ============================================
def plot_all_folds(fold_models, model_name, language):
    """Plot per-fold training curves and save them as a PDF.

    Left panel: train (dashed) and validation (solid) loss per epoch. Right panel:
    eval macro F1 per epoch. Each fold gets its own colour. The figure is saved to
    ``multilabel_<output_name>_<short model name>.pdf`` in the working directory,
    then shown.

    Args:
        fold_models: Fold dicts with 'fold' and 'history' (lists under 'epoch',
            'train_loss', 'eval_loss', 'eval_f1').
        model_name: Checkpoint name for titles and the file name. Falls back to
            CONFIG[language]['model_name'] when None/empty.
        language: Key into CONFIG (provides the output-name prefix).
    """
    cfg = CONFIG[language]
    # Use the checkpoint passed in; the config entry is only a fallback.
    short_model_name = (model_name or cfg['model_name']).split('/')[-1]

    colors = plt.cm.tab10(np.linspace(0, 0.9, len(fold_models)))

    fig, axes = plt.subplot_mosaic([['losses', 'f1']], figsize=(16, 6))
    loss_ax, f1_ax = axes['losses'], axes['f1']

    for color, fold_model in zip(colors, fold_models):
        history = fold_model['history']
        fold_num = fold_model['fold']
        epochs = history['epoch']

        loss_ax.plot(epochs, history['train_loss'],
                     label=f'Fold {fold_num} Train',
                     alpha=0.7,
                     linestyle='--',
                     color=color,
                     linewidth=2)
        loss_ax.plot(epochs, history['eval_loss'],
                     label=f'Fold {fold_num} Val',
                     alpha=0.9,
                     color=color,
                     linewidth=2)
        f1_ax.plot(epochs, history['eval_f1'],
                   label=f'Fold {fold_num}',
                   marker='o',
                   alpha=0.85,
                   color=color,
                   linewidth=2,
                   markersize=6)

    loss_ax.set_xlabel('Epoch', fontsize=11)
    loss_ax.set_ylabel('Loss', fontsize=11)
    loss_ax.set_title(f'Training & Validation Loss - {short_model_name}', fontsize=12, fontweight='bold')
    loss_ax.legend(fontsize=9, framealpha=0.9)
    loss_ax.grid(True, alpha=0.3, linestyle='--')

    f1_ax.set_xlabel('Epoch', fontsize=11)
    f1_ax.set_ylabel('F1 Score', fontsize=11)
    f1_ax.set_title(f'F1 Macro Score - {short_model_name}', fontsize=12, fontweight='bold')
    f1_ax.legend(fontsize=9, framealpha=0.9)
    f1_ax.grid(True, alpha=0.3, linestyle='--')

    fig.tight_layout()
    plot_filename = f'multilabel_{cfg["output_name"]}_{short_model_name}.pdf'
    fig.savefig(plot_filename, dpi=200, format='pdf')
    print(f"\n✓ Plot saved to: {plot_filename}")
    plt.show()

# ============================================
# Ensemble Prediction for Multi-Label
# ============================================
def ensemble_predict_multilabel(fold_models, test_dataset, method='weighted', threshold=0.5):
    """Make ensemble predictions for multi-label classification"""
    all_predictions = []
    weights = []

    for i, fold_model in enumerate(fold_models):
        print(f"Getting predictions from Fold {i+1} (F1: {fold_model['score']:.4f})...")

        model = fold_model['model']
        tokenizer = fold_model['tokenizer']

        trainer = Trainer(
            model=model,
            data_collator=DataCollatorWithPadding(tokenizer)
        )

        predictions = trainer.predict(test_dataset)
        probs = torch.sigmoid(torch.tensor(predictions.predictions))
        all_predictions.append(probs.numpy())

        if method == 'weighted':
            weights.append(fold_model['score'])
        elif method == 'weighted_combined':
            weights.append(fold_model['combined_score'])
        else:
            weights.append(1.0)

    # Normalize weights
    weights = np.array(weights)
    weights = weights / weights.sum()

    print(f"\nEnsemble weights: {weights}")

    # Weighted average of probabilities
    weighted_probs = np.zeros_like(all_predictions[0])
    for pred, weight in zip(all_predictions, weights):
        weighted_probs += pred * weight

    # Apply threshold
    final_predictions = (weighted_probs > threshold).astype(int)

    return final_predictions

# ============================================
# Save Models in Compressed Format
# ============================================
def save_fold_models(fold_models, fold_metrics, language, output_dir="/content/drive/MyDrive/NLP"):
    """Save all fold models and metrics in a single compressed file.

    Archive contents: ``best_model/`` (copy of the fold with the highest
    'combined_score'), ``fold_<k>/`` for every fold (model + tokenizer via
    save_pretrained), ``fold_metrics.json`` and ``model_info.json``. The archive is
    ``<output_dir>/multilabel_<output_name>_<short model name>.zip``; with the default
    output_dir this is the path load_saved_model / load_all_fold_models read from.

    Args:
        fold_models: Per-fold dicts with 'model', 'tokenizer', 'fold', 'score', 'combined_score'.
        fold_metrics: Per-fold metric dicts (each with 'combined_score'), same order as fold_models.
        language: Key into CONFIG (model name and output prefix).
        output_dir: Directory the zip file is written to.
    """
    import shutil

    def _json_default(obj):
        # numpy scalars (e.g. float32 metrics) are not JSON-serialisable by default.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    cfg = CONFIG[language]
    short_model_name = cfg['model_name'].split('/')[-1]

    temp_dir = f"/content/temp_models_{cfg['output_name']}"
    # Start from an empty staging dir so leftovers from an earlier (interrupted or
    # more-fold) run are not zipped into this archive.
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir, exist_ok=True)

    try:
        best_fold_idx = int(np.argmax([m['combined_score'] for m in fold_metrics]))
        best_fold_model = fold_models[best_fold_idx]

        best_model_dir = os.path.join(temp_dir, "best_model")
        os.makedirs(best_model_dir, exist_ok=True)
        best_fold_model['model'].save_pretrained(best_model_dir)
        best_fold_model['tokenizer'].save_pretrained(best_model_dir)

        print(f"\n✓ Best model prepared (Fold {best_fold_model['fold']})")
        print(f"✓ F1 Score: {best_fold_model['score']:.4f}")
        print(f"✓ Combined Score: {best_fold_model['combined_score']:.4f}")

        for fold_model in fold_models:
            fold_dir = os.path.join(temp_dir, f"fold_{fold_model['fold']}")
            os.makedirs(fold_dir, exist_ok=True)
            fold_model['model'].save_pretrained(fold_dir)
            fold_model['tokenizer'].save_pretrained(fold_dir)
            print(f"✓ Prepared Fold {fold_model['fold']}")

        with open(os.path.join(temp_dir, "fold_metrics.json"), 'w') as f:
            json.dump(fold_metrics, f, indent=2, default=_json_default)

        model_info = {
            'language': language,
            'model_name': cfg['model_name'],
            'short_model_name': short_model_name,
            'best_fold': best_fold_model['fold'],
            'best_f1': best_fold_model['score'],
            'best_combined_score': best_fold_model['combined_score'],
            'n_folds': len(fold_models),
            'label_columns': LABEL_COLUMNS
        }
        with open(os.path.join(temp_dir, "model_info.json"), 'w') as f:
            json.dump(model_info, f, indent=2, default=_json_default)

        zip_filename = f"multilabel_{cfg['output_name']}_{short_model_name}.zip"
        final_path = os.path.join(output_dir, zip_filename)

        print(f"\nCompressing all models...")
        with zipfile.ZipFile(final_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, dirs, files in os.walk(temp_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, temp_dir)
                    zipf.write(file_path, arcname)
    finally:
        # Remove the staging copy even if saving or zipping fails.
        shutil.rmtree(temp_dir, ignore_errors=True)

    print(f"\n{'='*70}")
    print(f"✓ All models saved to compressed file:")
    print(f"  {final_path}")
    print(f"  Model: {short_model_name}")
    print(f"  Language: {language}")
    print(f"{'='*70}\n")

# ============================================
# LOAD SAVED MODEL FOR PREDICTION
# ============================================
def load_saved_model(language='english', model_name=None):
    """
    Load the best saved model from compressed file

    Args:
        language: 'english' or 'swahili'
        model_name: Optional specific model name (e.g., 'xlm-roberta-base').
                   If None, uses the model_name from CONFIG

    Returns:
        model, tokenizer, model_info

    Raises:
        FileNotFoundError: if the archive does not exist.
    """
    import shutil

    cfg = CONFIG[language]

    # split('/')[-1] leaves names without '/' unchanged, so one expression covers both forms.
    source_name = cfg['model_name'] if model_name is None else model_name
    short_model_name = source_name.split('/')[-1]

    zip_path = f"/content/drive/MyDrive/NLP/multilabel_{cfg['output_name']}_{short_model_name}.zip"

    print(f"Loading model from: {zip_path}")

    temp_extract_dir = f"/content/temp_extract_{cfg['output_name']}"
    with zipfile.ZipFile(zip_path, 'r') as zipf:
        # The extraction dir is shared by every archive of this language (and by
        # load_all_fold_models). Clear it first so files left by another archive
        # (e.g. a different checkpoint's tokenizer or weights file) cannot be picked
        # up by from_pretrained alongside this archive's files.
        shutil.rmtree(temp_extract_dir, ignore_errors=True)
        zipf.extractall(temp_extract_dir)

    with open(os.path.join(temp_extract_dir, "model_info.json"), 'r') as f:
        model_info = json.load(f)

    best_model_dir = os.path.join(temp_extract_dir, "best_model")
    model = AutoModelForSequenceClassification.from_pretrained(best_model_dir)
    tokenizer = AutoTokenizer.from_pretrained(best_model_dir)

    print(f"✓ Model loaded successfully!")
    print(f"  Model: {model_info['short_model_name']}")
    print(f"  Best Fold: {model_info['best_fold']}")
    print(f"  F1 Score: {model_info['best_f1']:.4f}")

    return model, tokenizer, model_info

# ============================================
# LOAD ALL FOLD MODELS FOR ENSEMBLE
# ============================================
def load_all_fold_models(language='english', model_name=None):
    """
    Load all saved fold models from compressed file for ensemble prediction

    Args:
        language: 'english' or 'swahili'
        model_name: Optional specific model name (e.g., 'xlm-roberta-base').
                   If None, uses the model_name from CONFIG

    Returns:
        fold_models: List of dicts with 'model', 'tokenizer', 'fold', 'score',
                     'eval_loss' and 'combined_score', ordered by fold number.

    Raises:
        FileNotFoundError: if the archive does not exist.
    """
    import shutil

    cfg = CONFIG[language]

    # split('/')[-1] leaves names without '/' unchanged, so one expression covers both forms.
    source_name = cfg['model_name'] if model_name is None else model_name
    short_model_name = source_name.split('/')[-1]

    zip_path = f"/content/drive/MyDrive/NLP/multilabel_{cfg['output_name']}_{short_model_name}.zip"

    print(f"\nLoading all fold models from: {zip_path}")

    temp_extract_dir = f"/content/temp_extract_{cfg['output_name']}"
    with zipfile.ZipFile(zip_path, 'r') as zipf:
        # Shared per-language extraction dir (also used by load_saved_model): clear
        # leftovers from a previously extracted archive so from_pretrained only sees
        # this archive's files.
        shutil.rmtree(temp_extract_dir, ignore_errors=True)
        zipf.extractall(temp_extract_dir)

    with open(os.path.join(temp_extract_dir, "model_info.json"), 'r') as f:
        model_info = json.load(f)

    with open(os.path.join(temp_extract_dir, "fold_metrics.json"), 'r') as f:
        fold_metrics = json.load(f)

    n_folds = model_info['n_folds']
    fold_models = []

    print(f"Loading {n_folds} fold models for {language} ({short_model_name})...")

    for fold_num in range(1, n_folds + 1):
        fold_dir = os.path.join(temp_extract_dir, f"fold_{fold_num}")

        print(f"  Loading Fold {fold_num}...")

        model = AutoModelForSequenceClassification.from_pretrained(fold_dir)
        tokenizer = AutoTokenizer.from_pretrained(fold_dir)

        # fold_metrics.json is written in fold order, so fold k sits at index k-1.
        fold_metric = fold_metrics[fold_num - 1]

        fold_models.append({
            'model': model,
            'tokenizer': tokenizer,
            'fold': fold_num,
            'score': fold_metric['f1'],
            'eval_loss': fold_metric['eval_loss'],
            'combined_score': fold_metric['combined_score']
        })

        print(f"  ✓ Fold {fold_num} loaded (F1: {fold_metric['f1']:.4f}, Loss: {fold_metric['eval_loss']:.4f})")

    print(f"\n✓ All {n_folds} models loaded successfully!")

    return fold_models

# ============================================
# PREDICT WITH SAVED MODEL
# ============================================
def predict_with_saved_model(test_csv_path, language='english', model_name=None,
                             output_csv='predictions.csv', threshold=0.5):
    """Predict labels for a CSV with the saved best model and write them to ``output_csv``.

    Args:
        test_csv_path: CSV with 'id' and 'text' columns.
        language: 'english' or 'swahili' (key into CONFIG).
        model_name: Optional specific model name; None means the CONFIG model_name.
        output_csv: Destination CSV for the predictions.
        threshold: Sigmoid-probability cut-off for a positive label.

    Returns:
        DataFrame with 'id' plus one 0/1 column per saved label column.
    """
    model, tokenizer, model_info = load_saved_model(language, model_name)

    test_frame = pd.read_csv(test_csv_path)
    print(f"Test size: {len(test_frame)}")

    unlabelled = MultiLabelTestDataset(test_frame['text'].tolist(), tokenizer)
    predictor = Trainer(
        model=model,
        data_collator=DataCollatorWithPadding(tokenizer)
    )

    print("Making predictions...")
    raw_output = predictor.predict(unlabelled)
    positive_probs = torch.sigmoid(torch.tensor(raw_output.predictions))
    label_matrix = (positive_probs > threshold).int().numpy()

    label_cols = model_info['label_columns']
    results_df = pd.DataFrame({'id': test_frame['id']}).assign(
        **{col: label_matrix[:, pos] for pos, col in enumerate(label_cols)}
    )

    results_df.to_csv(output_csv, index=False)

    print(f"✓ Predictions saved to: {output_csv}")
    print(f"\nSample predictions:")
    print(results_df.head())
    print(f"\nPrediction distribution:")
    print(results_df[label_cols].sum())

    return results_df

# ============================================
# ENSEMBLE PREDICT WITH SAVED MODELS
# ============================================
def ensemble_predict_with_saved_models(test_csv_path, language='english', model_name=None,
                                       method='weighted_combined',
                                       output_csv='ensemble_predictions.csv',
                                       threshold=0.5):
    """
    Make ensemble predictions using all saved fold models

    Args:
        test_csv_path: Path to test CSV file (needs 'id' and 'text' columns)
        language: 'english' or 'swahili'
        model_name: Optional specific model name. If None, uses CONFIG model_name
        method: 'average', 'weighted', or 'weighted_combined'
        output_csv: Output CSV filename
        threshold: Prediction threshold for multi-label

    Returns:
        DataFrame with 'id' plus one 0/1 column per entry of LABEL_COLUMNS.

    Raises:
        ValueError: if the ensemble output width does not match LABEL_COLUMNS.
    """
    fold_models = load_all_fold_models(language, model_name)

    test = pd.read_csv(test_csv_path)
    print(f"\nTest size: {len(test)}")

    # All folds were trained with one shared tokenizer, so fold 0's fits the dataset.
    tokenizer = fold_models[0]['tokenizer']
    test_dataset = MultiLabelTestDataset(test['text'].tolist(), tokenizer)

    print(f"\nMaking ensemble predictions with method: {method}")
    final_predictions = ensemble_predict_multilabel(fold_models, test_dataset, method=method, threshold=threshold)

    # One output column per label, in LABEL_COLUMNS order.
    if final_predictions.shape[1] != len(LABEL_COLUMNS):
        raise ValueError(
            f"ensemble produced {final_predictions.shape[1]} label columns, "
            f"expected {len(LABEL_COLUMNS)} ({LABEL_COLUMNS})"
        )
    results_df = pd.DataFrame({'id': test['id']})
    for i, col in enumerate(LABEL_COLUMNS):
        results_df[col] = final_predictions[:, i]

    results_df.to_csv(output_csv, index=False)

    print(f"\n✓ Ensemble predictions saved to: {output_csv}")
    print(f"\nSample predictions:")
    print(results_df.head(10))
    print(f"\nPrediction distribution:")
    print(results_df[LABEL_COLUMNS].sum())

    return results_df

# ============================================
# CONTINUE TRAINING SAVED MODEL
# ============================================
def continue_training(train_csv_path, language='english', model_name=None, additional_epochs=5):
    """
    Continue training a saved multi-label model and archive the updated weights.

    Loads the model/tokenizer via ``load_saved_model`` and re-splits the
    training CSV 80/20. The split is stratified on ``create_stratify_key``
    over ``LABEL_COLUMNS`` with ``random_state=42``. The model is then
    fine-tuned for ``additional_epochs`` more epochs, keeping the best
    checkpoint by ``f1_macro``. Finally, a zip of the updated model and
    tokenizer is written into ``/content/drive/MyDrive/NLP/``.

    Args:
        train_csv_path: Path to training CSV (needs ``text`` and LABEL_COLUMNS)
        language: 'english' or 'swahili'
        model_name: Optional specific model name. If None, uses CONFIG model_name
        additional_epochs: Number of additional epochs to train

    Returns:
        The ``Trainer`` used for the additional training.
    """
    from sklearn.model_selection import train_test_split

    # Load saved model
    model, tokenizer, model_info = load_saved_model(language, model_name)

    # Load training data
    train = pd.read_csv(train_csv_path)
    print(f"Train size: {len(train)}")

    # Create stratification key
    train['stratify_key'] = create_stratify_key(train, LABEL_COLUMNS)

    # Split into train/val
    train_split, val_split = train_test_split(
        train,
        test_size=0.2,
        stratify=train['stratify_key'],
        random_state=42
    )

    # Create datasets
    train_dataset = MultiLabelDataset(
        train_split['text'].tolist(),
        train_split[LABEL_COLUMNS].values.tolist(),
        tokenizer
    )
    val_dataset = MultiLabelDataset(
        val_split['text'].tolist(),
        val_split[LABEL_COLUMNS].values.tolist(),
        tokenizer
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="/content/outputs/continued_training",
        num_train_epochs=additional_epochs,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        lr_scheduler_type="linear",
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=25,
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        # fp16 mixed precision needs a CUDA device; requesting it on a
        # CPU-only runtime can make TrainingArguments/Trainer fail.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_multilabel,
        data_collator=DataCollatorWithPadding(tokenizer)
    )

    # Continue training
    print(f"\nContinuing training for {additional_epochs} epochs...")
    trainer.train()

    # Evaluate (best checkpoint is already loaded because load_best_model_at_end=True)
    eval_results = trainer.evaluate()
    print(f"\n✓ Final F1 Score: {eval_results['eval_f1_macro']:.4f}")

    # Save updated model to a local staging directory before zipping
    cfg = CONFIG[language]
    short_model_name = model_info['short_model_name']
    updated_model_dir = "/content/temp_updated_model"
    os.makedirs(updated_model_dir, exist_ok=True)
    trainer.save_model(updated_model_dir)
    tokenizer.save_pretrained(updated_model_dir)

    # Compress updated model. os.path.join supplies the directory separator
    # that a plain string concatenation would otherwise drop.
    zip_filename = f"multilabel_{cfg['output_name']}_{short_model_name}_updated.zip"
    final_path = os.path.join("/content/drive/MyDrive/NLP", zip_filename)

    with zipfile.ZipFile(final_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(updated_model_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Store paths relative to the staging dir so the zip unpacks flat
                arcname = os.path.relpath(file_path, updated_model_dir)
                zipf.write(file_path, arcname)

    print(f"✓ Updated model saved to: {final_path}")

    return trainer

# ============================================
# Main Training Pipeline
# ============================================
def main():
    """
    End-to-end pipeline: K-fold training, model saving and ensemble test predictions.

    Reads the module-level ``config`` (``train_file``, ``test_file``,
    ``model_name``, ``output_name``), ``LANGUAGE`` and ``LABEL_COLUMNS``.
    Predictions go to ``subtask_1/pred_<output_name>.csv``, with an ``id``
    column plus one column per label. That same directory is then zipped
    into ``subtask_1_<output_name>_ensemble.zip``.
    """
    # Single source of truth for where predictions are written AND zipped from.
    output_dir = 'subtask_1'

    print(f"Loading data for {LANGUAGE}...")
    train = pd.read_csv(config['train_file'])
    test = pd.read_csv(config['test_file'])

    print(f"Train size: {len(train)}")
    print(f"Test size: {len(test)}")
    print("\nLabel distribution:")
    print(train[LABEL_COLUMNS].sum())

    # Train with K-Fold CV
    fold_models, fold_metrics = train_kfold_cv(
        train,
        config['model_name'],
        LABEL_COLUMNS,
        n_splits=5,
        seed=42
    )

    # Save models
    save_fold_models(fold_models, fold_metrics, LANGUAGE)

    # Create test dataset with the tokenizer of the checkpoint passed to train_kfold_cv
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    test_dataset = MultiLabelTestDataset(test['text'].tolist(), tokenizer)

    # Ensemble predictions
    print("\n" + "="*70)
    print("Making Weighted Ensemble Predictions")
    print("="*70 + "\n")

    final_predictions = ensemble_predict_multilabel(
        fold_models,
        test_dataset,
        method='weighted_combined',
        threshold=0.5
    )

    # Map prediction columns back to label names (folds were trained on
    # LABEL_COLUMNS in this order) so the distribution printout below works.
    predictions = np.asarray(final_predictions)
    if predictions.ndim != 2 or predictions.shape[1] != len(LABEL_COLUMNS):
        raise ValueError(
            f"Expected ensemble predictions of shape (n_samples, {len(LABEL_COLUMNS)}), "
            f"got {predictions.shape}"
        )

    results_df = pd.DataFrame({
        'id': test['id'],
        **{label: predictions[:, idx] for idx, label in enumerate(LABEL_COLUMNS)},
    })

    # Save predictions
    os.makedirs(output_dir, exist_ok=True)
    csv_path = f'{output_dir}/pred_{config["output_name"]}.csv'
    results_df.to_csv(csv_path, index=False)

    print(f"\n✓ Saved predictions to {csv_path}")
    print("\nSample predictions:")
    print(results_df.head(10))
    print("\nPrediction distribution:")
    print(results_df[LABEL_COLUMNS].sum())

    # Compress the same directory the predictions were written to
    zip_filename = f'subtask_1_{config["output_name"]}_ensemble.zip'
    with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(output_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Keep the top-level folder name inside the archive (e.g. subtask_1/pred_x.csv)
                arcname = os.path.join(os.path.basename(root), file)
                zipf.write(file_path, arcname)

    print(f"✓ Created compressed file: {zip_filename}")

if __name__ == "__main__":
    main()

In [None]:
# Delete the local Trainer output/checkpoint tree (e.g. /content/outputs/continued_training) to free disk space.
!rm -rf /content/outputs