In [1]:
import os
os.environ['WANDB_DISABLED'] = 'true'

In [2]:
# 1. Installation Cell
%pip install transformers>=4.40.0 torch>=2.0.0 accelerate scikit-learn pandas numpy torch-lr-finder

In [12]:
# 2. Imports and Setup Cell
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
    AutoTokenizer,
    MambaConfig,
    MambaForCausalLM,
    MambaModel,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    DataCollatorWithPadding, # Added this import
    TrainerCallback
)
from torch.utils.data import Dataset
import warnings
from tqdm.auto import tqdm
import os
import glob
import json

warnings.filterwarnings('ignore')

# Mount drive
try:
    from google.colab import drive
    drive.mount('/content/gdrive')
except ImportError:
    print("Not in Colab, skipping drive mount.")

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Using device: cuda


In [13]:
# 3. Data Loading Cell
def load_multilingual_data(data_dir, languages=None, split='train'):
    """
    Load data from multiple language files

    Args:
        data_dir: Path to directory (e.g., '/content/gdrive/MyDrive/subtask1/train')
        languages: List of language codes (e.g., ['eng', 'arb', 'deu']) or None for all
        split: 'train' or 'dev'

    Returns:
        combined_df: Combined DataFrame with all languages
        language_counts: Dict with counts per language
    """
    import glob
    import os
    import pandas as pd

    # Language mapping (expanded to include all languages from both versions)
    lang_map = {
        'amh': 'Amharic',
        'arb': 'Arabic',
        'bul': 'Bulgarian',
        'deu': 'German',
        'eng': 'English',
        'spa': 'Spanish',
        'fra': 'French',
        'hau': 'Hausa',
        'hin': 'Hindi',
        'ita': 'Italian',
        'por': 'Portuguese',
        'urd': 'Urdu',
        'zho': 'Chinese'
    }

    all_data = []
    language_counts = {}

    print(f"{'='*70}")
    print(f"LOADING {split.upper()} DATA - MULTILINGUAL")
    print(f"{'='*70}")
    print(f"Data directory: {data_dir}")
    print()

    # If no languages specified, search for all available files
    if languages is None:
        # Try both naming patterns
        pattern1 = os.path.join(data_dir, f'subtask1_*_{split}.csv')
        pattern2 = os.path.join(data_dir, '*.csv')

        files = glob.glob(pattern1)
        if not files:
            files = glob.glob(pattern2)
            print(f"Using pattern: *.csv")
        else:
            print(f"Using pattern: subtask1_*_{split}.csv")
    else:
        files = []
        # Try multiple naming conventions
        for lang in languages:
            possible_names = [
                os.path.join(data_dir, f'subtask1_{lang}_{split}.csv'),
                os.path.join(data_dir, f'{lang}.csv')
            ]
            for path in possible_names:
                if os.path.exists(path):
                    files.append(path)
                    break

    print(f"Found {len(files)} files for {split} split\\n")

    for file_path in files:
        filename = os.path.basename(file_path)

        # Extract language code - try multiple patterns
        try:
            if filename.startswith('subtask1_'):
                # Pattern: subtask1_{lang}_{split}.csv
                lang_code = filename.split('_')[1]
            else:
                # Pattern: {lang}.csv
                lang_code = filename.split('.')[0]
        except IndexError:
            print(f"⚠️  Skipping file with unexpected name: {filename}")
            continue

        if not os.path.exists(file_path):
            print(f"⚠️  Warning: File not found: {file_path}, skipping...")
            continue

        # Load CSV
        df = pd.read_csv(file_path)
        df['language'] = lang_code
        df['language_name'] = lang_map.get(lang_code, lang_code)

        all_data.append(df)
        language_counts[lang_code] = len(df)

        lang_name = lang_map.get(lang_code, lang_code)
        print(f"✓ Loaded {lang_code} ({lang_name:12s}): {len(df):5d} samples from {filename}")

    if not all_data:
        print(f"\\n⚠️  ERROR: No data loaded. Check path: {data_dir}")
        print(f"    Make sure files exist and match expected naming patterns:")
        print(f"    - Pattern 1: subtask1_{{lang}}_{{split}}.csv")
        print(f"    - Pattern 2: {{lang}}.csv")
        return pd.DataFrame(columns=['text', 'label', 'language', 'language_name']), {}

    combined_df = pd.concat(all_data, ignore_index=True)

    print(f"\\n{'='*70}")
    print(f"TOTAL: {len(combined_df)} samples across {len(language_counts)} languages")
    print(f"{'='*70}")

    # Check for both 'label' and 'polarization' columns
    if 'label' in combined_df.columns:
        print(f"\\nLabel distribution:\\n{combined_df['label'].value_counts(normalize=True)}")
    elif 'polarization' in combined_df.columns:
        print("\\nClass Distribution:")
        for lang_code, count in language_counts.items():
            lang_df = combined_df[combined_df['language'] == lang_code]
            polarized = (lang_df['polarization'] == 1).sum()
            non_polarized = (lang_df['polarization'] == 0).sum()
            print(f"  {lang_code}: Polarized={polarized}, Non-Polarized={non_polarized}")
    else:
        print("\\n⚠️  Warning: No 'label' or 'polarization' column found in loaded data.")

    return combined_df, language_counts

In [14]:
# 4. Mamba Model for Sequence Classification Cell
class MambaForSequenceClassification(nn.Module):
    """
    Mamba model adapted for sequence classification
    Uses the efficient state-space architecture instead of attention
    """
    def __init__(self, model_name="state-spaces/mamba-370m-hf", num_labels=2, dropout=0.1):
        super().__init__()
        self.num_labels = num_labels

        # Load Mamba backbone
        print(f"Loading Mamba model: {model_name}")
        self.mamba = MambaModel.from_pretrained(model_name)

        # Get hidden size from config
        self.hidden_size = self.mamba.config.hidden_size

        # Classification head
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.hidden_size, num_labels)

        # Initialize classifier weights
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass for classification

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len] (optional, may not be used by Mamba)
            labels: Ground truth labels [batch_size] (optional)

        Returns:
            SequenceClassifierOutput with loss and logits
        """
        # Get Mamba outputs
        # Note: Mamba processes sequences differently than transformers
        # It doesn't use attention_mask the same way, but we pass it for compatibility
        outputs = self.mamba(input_ids=input_ids)

        # Get last hidden state
        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Use mean pooling over sequence length (alternative to [CLS] token)
        # This is more appropriate for Mamba which doesn't have special tokens
        if attention_mask is not None:
            # Mask out padding tokens
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            sum_hidden = torch.sum(hidden_states * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
            pooled_output = sum_hidden / sum_mask
        else:
            # Simple mean pooling
            pooled_output = hidden_states.mean(dim=1)

        # Apply dropout and classification layer
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    def freeze_backbone(self, freeze=True):
        """Freeze/unfreeze the Mamba backbone"""
        for param in self.mamba.parameters():
            param.requires_grad = not freeze
        print(f"Mamba backbone {'frozen' if freeze else 'unfrozen'}")

    def unfreeze_last_n_layers(self, n=4):
        """Unfreeze last n layers of Mamba for fine-tuning"""
        # Freeze all first
        self.freeze_backbone(freeze=True)

        # Unfreeze last n layers
        total_layers = len(self.mamba.layers)
        layers_to_unfreeze = list(range(max(0, total_layers - n), total_layers))

        for idx in layers_to_unfreeze:
            for param in self.mamba.layers[idx].parameters():
                param.requires_grad = True

        print(f"Unfroze last {n} layers of Mamba (layers {layers_to_unfreeze}")

In [15]:
# 5. Dataset Class Cell
class MultilingualDataset(Dataset):
    """Dataset for multilingual text classification with Mamba"""
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [16]:
# 6. Focal Loss Cell
class FocalLoss(nn.Module):
    """
    Focal Loss for handling class imbalance
    """
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [17]:
# 7. Custom Trainer Cell
class FocalLossTrainer(Trainer):
    """Custom Trainer that uses Focal Loss"""
    def __init__(self, *args, use_focal_loss=False, focal_alpha=0.25, focal_gamma=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_focal_loss = use_focal_loss
        self.focal_loss = None
        if use_focal_loss:
            self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Added num_items_in_batch
        labels = inputs.pop("labels")
        outputs = model(**inputs, labels=labels)
        logits = outputs.logits

        if self.use_focal_loss and self.focal_loss is not None:
            loss = self.focal_loss(logits, labels)
        else:
            # Default to model's internal loss (if labels were passed)
            if outputs.loss is not None:
                loss = outputs.loss
            else:
                # Or recompute CE loss if model didn't return it
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss, outputs) if return_outputs else loss

In [18]:
# 8. Metrics Computation Cell
def compute_metrics(eval_pred):
    """Compute F1, accuracy metrics"""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    f1_macro = f1_score(labels, predictions, average='macro')
    f1_binary = f1_score(labels, predictions, average='binary')
    accuracy = accuracy_score(labels, predictions)

    return {
        'f1_macro': f1_macro,
        'f1_binary': f1_binary,
        'accuracy': accuracy
    }

In [19]:
# 9. Training Setup Cell
torch.cuda.empty_cache()
# Configuration
MODEL_NAME = "state-spaces/mamba-370m-hf"  # 370M params, fits in 6-8GB VRAM
MAX_LENGTH = 512
BATCH_SIZE = 4  # Adjusted batch size
GRADIENT_ACCUMULATION_STEPS = 8  # Adjusted to maintain effective batch size (4 * 8 = 32)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 5
WARMUP_RATIO = 0.1

# Paths (adjust to your setup)
# Ensure this path is correct in your Google Drive
BASE_DIR = '/content/gdrive/MyDrive/subtask1'
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
DEV_DIR = os.path.join(BASE_DIR, 'dev')
OUTPUT_DIR = './mamba_multilingual_output'

# Create output dir
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load data
print("Loading training data...")
train_df, train_counts = load_multilingual_data(TRAIN_DIR, split='train')
print("\nLoading validation data...")
val_df, val_counts = load_multilingual_data(DEV_DIR, split='dev')

# Check if data loaded
if train_df.empty or val_df.empty:
    print("ERROR: Data not loaded. Check paths:")
    print(f"TRAIN_DIR: {TRAIN_DIR}")
    print(f"DEV_DIR: {DEV_DIR}")
    # Stop execution if data is missing
    raise FileNotFoundError("Could not load training or validation data.")

# Initialize tokenizer (Mamba uses GPT-NeoX tokenizer)
print(f"\nLoading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set padding token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer pad_token set to eos_token")

# Create datasets
print("\nCreating datasets...")
train_dataset = MultilingualDataset(
    texts=train_df['text'].values,
    labels=train_df['polarization'].values,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH
)
val_dataset = MultilingualDataset(
    texts=val_df['text'].values,
    labels=val_df['polarization'].values,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH
)
print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Initialize model
print(f"\nInitializing Mamba model...")
model = MambaForSequenceClassification(
    model_name=MODEL_NAME,
    num_labels=2,
    dropout=0.1
)

# Optional: Freeze backbone and only train last layers + classifier
# Uncomment if you want faster training with less memory
model.unfreeze_last_n_layers(n=4)

# Move to GPU if available
# device is defined in the import cell
model = model.to(device)
print(f"Model loaded on {device}")

print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Loading training data...
LOADING TRAIN DATA - MULTILINGUAL
Data directory: /content/gdrive/MyDrive/subtask1/train

Using pattern: *.csv
Found 13 files for train split\n
✓ Loaded eng (English     ):  2676 samples from eng.csv
✓ Loaded urd (Urdu        ):  2849 samples from urd.csv
✓ Loaded arb (Arabic      ):  3380 samples from arb.csv
✓ Loaded hau (Hausa       ):  3651 samples from hau.csv
✓ Loaded ita (Italian     ):  3334 samples from ita.csv
✓ Loaded nep (nep         ):  2005 samples from nep.csv
✓ Loaded deu (German      ):  3180 samples from deu.csv
✓ Loaded fas (fas         ):  3295 samples from fas.csv
✓ Loaded zho (Chinese     ):  4280 samples from zho.csv
✓ Loaded amh (Amharic     ):  3332 samples from amh.csv
✓ Loaded tur (tur         ):  2364 samples from tur.csv
✓ Loaded hin (Hindi       ):  2744 samples from hin.csv
✓ Loaded spa (Spanish     ):  3305 samples from spa.csv
TOTAL: 40395 samples across 13 languages
\nClass Distribution:
  eng: Polarized=1002, Non-Polarized=167

In [20]:
# 10. Training Arguments and Training Cell

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_ratio=WARMUP_RATIO,
    logging_dir=f'{OUTPUT_DIR}/logs',
    logging_steps=10, # Changed from 50 to 10 for more verbose output
    eval_steps=150,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=150,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    dataloader_num_workers=0, # Changed from 2 to 0 to prevent potential hanging issues
    remove_unused_columns=False,
    report_to="none",
    disable_tqdm=False,  # Add this line - explicitly enable progress bar
    logging_first_step=True,  # Add this line - shows progress immediately
)

# Initialize trainer
trainer = FocalLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    use_focal_loss=False,  # Set to True if you want focal loss
    focal_alpha=0.25,
    focal_gamma=2.0,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

# Train
print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70)
trainer.train()

# Save final model
print("\nSaving final model...")
final_model_path = os.path.join(OUTPUT_DIR, 'final_model')
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
print(f"Model saved to {final_model_path}!")


STARTING TRAINING


KeyboardInterrupt: 

In [None]:
# 11. Threshold Finding Cell

def find_optimal_threshold(model, dataset, device, metric='f1_macro'):
    """
    Find optimal classification threshold for binary classification
    """
    print(f"\n{'='*70}")
    print(f"FINDING OPTIMAL THRESHOLD FOR {metric.upper()}")
    print(f"{'='*70}")

    model.eval()
    all_probs = []
    all_labels = []

    from torch.utils.data import DataLoader
    dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

    print(f"Validation samples: {len(dataset)}")
    print("Generating predictions...")

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = F.softmax(outputs.logits, dim=-1)[:, 1]  # Probability of class 1

            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)

    # Test different thresholds
    thresholds = np.arange(0.1, 0.9, 0.01)
    best_threshold = 0.5
    best_score = 0

    print(f"Testing {len(thresholds)} thresholds...")

    for threshold in thresholds:
        preds = (all_probs >= threshold).astype(int)

        if metric == 'f1_macro':
            score = f1_score(all_labels, preds, average='macro')
        elif metric == 'f1_binary':
            score = f1_score(all_labels, preds, average='binary')
        else:
            score = accuracy_score(all_labels, preds)

        if score > best_score:
            best_score = score
            best_threshold = threshold

    print(f"\nOptimal threshold: {best_threshold:.3f}")
    print(f"Best {metric}: {best_score:.4f}")

    # Final predictions with optimal threshold
    final_preds = (all_probs >= best_threshold).astype(int)

    print("\nClassification Report:")
    print(classification_report(all_labels, final_preds, digits=4))

    return best_threshold, best_score

# Find optimal thresholds
print("\n" + "="*70)
print("FINDING OPTIMAL THRESHOLD")
print("="*70)

threshold_f1_macro = 0.5
score_f1_macro = 0.0
threshold_f1_binary = 0.5
score_f1_binary = 0.0

# Ensure val_dataset and device are available from the previous cell
if 'val_dataset' in globals() and 'device' in globals() and 'model' in globals():
    threshold_f1_macro, score_f1_macro = find_optimal_threshold(
        model, val_dataset, device, metric='f1_macro'
    )
    threshold_f1_binary, score_f1_binary = find_optimal_threshold(
        model, val_dataset, device, metric='f1_binary'
    )

    # Save thresholds
    thresholds = {
        'f1_macro': {'threshold': float(threshold_f1_macro), 'score': float(score_f1_macro)},
        'f1_binary': {'threshold': float(threshold_f1_binary), 'score': float(score_f1_binary)}
    }

    threshold_path = os.path.join(OUTPUT_DIR, 'optimal_thresholds.json')
    with open(threshold_path, 'w') as f:
        json.dump(thresholds, f, indent=2)

    print(f"\nThresholds saved to {threshold_path}")
else:
    print("Could not find val_dataset or device. Skipping threshold finding.")

In [None]:
# 12. Inference Function Cell

def predict_with_mamba(texts, model, tokenizer, device, threshold=0.5, batch_size=16):
    """
    Make predictions on new texts using Mamba model

    Args:
        texts: List of text strings or a single text string
        model: Trained MambaForSequenceClassification model
        tokenizer: Tokenizer
        device: torch device
        threshold: Classification threshold
        batch_size: Batch size for inference

    Returns:
        predictions: List of predicted labels (0 or 1)
        probabilities: List of probabilities for class 1
    """
    model.eval()
    all_probs = []

    from torch.utils.data import DataLoader, TensorDataset

    # Handle single text input
    if isinstance(texts, str):
        texts = [texts]

    # Tokenize all texts
    encodings = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'])
    dataloader = DataLoader(dataset, batch_size=batch_size)

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = F.softmax(outputs.logits, dim=-1)[:, 1]

            all_probs.extend(probs.cpu().numpy())

    all_probs = np.array(all_probs)
    predictions = (all_probs >= threshold).astype(int)

    return predictions.tolist(), all_probs.tolist()

# Example usage
print("\n" + "="*70)
print("INFERENCE EXAMPLE")
print("="*70)

test_texts = [
    "This is a great product! I love it!",
    "Terrible experience, would not recommend."
]

# Ensure model, tokenizer, device, and threshold are available
if 'model' in globals() and 'tokenizer' in globals() and 'device' in globals() and 'threshold_f1_macro' in globals():
    predictions, probabilities = predict_with_mamba(
        test_texts, model, tokenizer, device, threshold=threshold_f1_macro
    )

    for text, pred, prob in zip(test_texts, predictions, probabilities):
        print(f"Text: {text}")
        print(f"Prediction: {pred} (prob: {prob:.4f})\n")
else:
    print("Could not run example: model, tokenizer, or threshold not found.")
    print("Please run the training cells first.")