# Imports and setup

In [None]:
!pip install mistral-common

In [None]:
import os
import json
import pickle
import logging
import argparse
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from transformers import (
    AutoTokenizer, 
    AutoModel, 
    AutoConfig,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import classification_report, f1_score
from tqdm import tqdm
import numpy as np

from huggingface_hub import login


# Configure logging first so the authentication step below can report through it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# Take the Hugging Face token from the HF_TOKEN environment variable rather
# than hard-coding it, so the secret never has to live in the notebook source.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # An empty token cannot authenticate, so skip the call instead of passing "".
    logger.warning(
        "HF_TOKEN is not set; skipping Hugging Face login "
        "(downloads that require authentication may fail)."
    )


# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Configuration

In [None]:
@dataclass
class TrainingConfig:
    """Hyperparameters and paths for the token-classification run.

    Every field has a default, so ``TrainingConfig()`` yields the settings
    this notebook uses on Kaggle.
    """

    # --- model and batching ---
    model_name: str = "mistralai/Ministral-8B-Instruct-2410"  # checkpoint id
    max_length: int = 96  # length sequences are padded/truncated to
    batch_size: int = 12  # examples per DataLoader batch

    # --- optimisation ---
    learning_rate: float = 1e-4  # AdamW learning rate
    num_epochs: int = 1  # passes over the training loader
    warmup_steps: int = 150  # warm-up steps of the linear schedule
    weight_decay: float = 0.01  # AdamW weight decay
    gradient_accumulation_steps: int = 3  # batches per optimizer step
    max_grad_norm: float = 0.5  # gradient-clipping threshold

    # --- bookkeeping ---
    save_steps: int = 500  # not referenced by the visible training code
    eval_steps: int = 250  # not referenced by the visible training code
    logging_steps: int = 50  # batch interval between loss printouts
    output_dir: str = "./ministral_token_classifier"  # checkpoint directory
    seed: int = 42  # seed for torch and numpy


# Build the default configuration and echo the settings that matter most.
config = TrainingConfig()
config_summary = [
    "📋 Configuration:",
    f"   Model: {config.model_name}",
    f"   Batch size: {config.batch_size}",
    f"   Learning rate: {config.learning_rate}",
    f"   Max length: {config.max_length}",
]
for summary_line in config_summary:
    print(summary_line)

# Simplified Dataset class

In [6]:
class FastPIITokenDataset(Dataset):
    """Dataset over pre-tokenized PII examples (no tokenization at access time).

    Args:
        dataset: Mapping providing the keys ``'texts'``, ``'token_ids'``,
            ``'label_ids'``, ``'label_to_id'`` and ``'id_to_label'``.
        max_length: Fixed length every example is padded or truncated to.
    """

    # Value used to fill ``input_ids`` past the end of the real sequence.
    PAD_TOKEN_ID = 0
    # Value used to fill ``labels`` past the end of the real sequence; it
    # matches the ``ignore_index`` of the loss in MinistralTokenClassifier.
    IGNORE_LABEL_ID = -100

    def __init__(self, dataset: Dict[str, Any], max_length: int = 96):
        self.texts = dataset['texts']
        self.token_ids = dataset['token_ids']
        self.label_ids = dataset['label_ids']
        self.label_to_id = dataset['label_to_id']
        self.id_to_label = dataset['id_to_label']
        self.max_length = max_length
        print(f"FastPIITokenDataset: {len(self.texts)} examples")

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` tensors.

        All three are ``torch.long`` tensors of length ``max_length``.
        """
        token_ids = self.token_ids[idx]
        label_ids = self.label_ids[idx]

        input_ids = self._pad_or_truncate(token_ids, self.max_length, self.PAD_TOKEN_ID)
        labels = self._pad_or_truncate(label_ids, self.max_length, self.IGNORE_LABEL_ID)

        # Derive the mask from the sequence length rather than from the token
        # values: a genuine token whose id equals the pad id must stay visible.
        num_real_tokens = min(len(token_ids), self.max_length)
        attention_mask = [1] * num_real_tokens + [0] * (self.max_length - num_real_tokens)

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

    def _pad_or_truncate(self, seq, max_len, pad_val):
        """Return ``seq`` as a list of exactly ``max_len`` items.

        Longer sequences are cut; shorter ones are right-padded with
        ``pad_val``. Any sliceable sequence (list, tuple, array) is accepted.
        """
        clipped = list(seq[:max_len])
        return clipped + [pad_val] * (max_len - len(clipped))

# Model definition

In [None]:
class MinistralTokenClassifier(nn.Module):
    """Pre-trained backbone followed by a linear per-token classification head.

    The backbone is loaded through ``AutoModel`` in bfloat16 with
    ``device_map="auto"``; the head is a dropout layer plus one ``nn.Linear``
    that maps each hidden state to ``num_labels`` logits.
    """

    def __init__(self, model_name: str, num_labels: int, freeze_backbone: bool = True):
        """
        Initialize the model.

        Args:
            model_name: Pre-trained model name
            num_labels: Number of classification labels
            freeze_backbone: Whether to freeze the backbone model
        """
        super().__init__()

        print(f"Loading model: {model_name}")

        self.config = AutoConfig.from_pretrained(model_name)
        # NOTE(review): trust_remote_code=True lets code shipped with the
        # checkpoint run locally; only use checkpoints you trust.
        self.backbone = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )

        if freeze_backbone:
            # Only the classification head is left trainable.
            for param in self.backbone.parameters():
                param.requires_grad = False
            print("Backbone model frozen")

        hidden_size = self.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

        # Keep the head in the same dtype the backbone is loaded in.
        self.classifier = self.classifier.to(dtype=torch.bfloat16)

        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

        self.num_labels = num_labels

        print(f"Model initialized with {num_labels} labels")
        print(f"Backbone dtype: {next(self.backbone.parameters()).dtype}")
        print(f"Classifier dtype: {self.classifier.weight.dtype}")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone and the classification head.

        Args:
            input_ids: Token ids, ``[batch_size, seq_len]``.
            attention_mask: Optional mask forwarded to the backbone.
            labels: Optional label ids aligned with ``input_ids``; positions
                equal to ``-100`` are ignored by the loss.

        Returns:
            Dict with ``'loss'`` (``None`` when ``labels`` is not given) and
            ``'logits'`` of shape ``[batch_size, seq_len, num_labels]``.
        """
        # Only the final hidden state is consumed, so the backbone is not
        # asked to keep every layer's activations (output_hidden_states).
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False
        )

        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]

        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states.to(self.classifier.weight.dtype)

        logits = self.classifier(hidden_states)  # [batch_size, seq_len, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # Compute the loss in float32: the softmax/log reduction is
            # noticeably less precise in bfloat16. The returned logits keep
            # their original dtype.
            flat_logits = logits.float().reshape(-1, self.num_labels)
            # Labels may live on a different device than the logits when the
            # backbone is placed by device_map.
            flat_labels = labels.reshape(-1).to(flat_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        return {
            'loss': loss,
            'logits': logits
        }

# Confirmation message printed once the class definition above has executed.
print("Optimized model class defined")

# Load dataset


In [None]:

train_dataset_path = "/kaggle/input/mistral-token-classif-english/train_dataset.pkl"
val_dataset_path = "/kaggle/input/mistral-token-classif-english/val_dataset.pkl"

print("Loading datasets...")

# Settings shared by the training and validation loaders.
loader_options = {'batch_size': config.batch_size, 'num_workers': 2}

# NOTE(review): pickle.load can execute code embedded in the file; only point
# these paths at dataset files from a trusted source.
with open(train_dataset_path, 'rb') as train_file:
    train_data = pickle.load(train_file)

train_dataset = FastPIITokenDataset(train_data, config.max_length)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)

# Validation data is optional: without the file, val_dataloader stays None.
val_dataloader = None
if os.path.exists(val_dataset_path):
    with open(val_dataset_path, 'rb') as val_file:
        val_data = pickle.load(val_file)

    val_dataset = FastPIITokenDataset(val_data, config.max_length)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)
    print(f"Loaded {len(val_dataset)} validation examples")

# Label metadata comes from the training pickle.
label_to_id, id_to_label, num_labels = (
    train_data[key] for key in ('label_to_id', 'id_to_label', 'num_labels')
)

print(f"Loaded {len(train_dataset)} training examples")
print(f"Number of labels: {num_labels}")
print(f"Labels: {list(label_to_id.keys())}")

# Reduce dataset

In [None]:
import random
from collections import Counter

def reduce_dataset_size(train_data, reduction_strategy="random", target_size=10000,
                        min_examples_per_label=50, seed=42):
    """
    Reduce training dataset size with different strategies.

    Args:
        train_data: Dictionary with 'texts', 'token_ids', 'label_ids', etc.
            The per-example lists are subsampled; every other key is passed
            through unchanged.
        reduction_strategy: One of
            'random'     - uniform sample of ``target_size`` examples;
            'balanced'   - first take up to ``min_examples_per_label``
                           examples for every non-'O' label, then fill the
                           remaining slots at random;
            'stratified' - sort examples by (text length, number of distinct
                           non-'O' labels) and pick ``target_size`` examples
                           evenly spaced across that ordering.
        target_size: Target number of examples
        min_examples_per_label: Minimum examples per label (for balanced)
        seed: Random seed for reproducibility

    Returns:
        Reduced train_data dictionary. The input dictionary itself is returned
        when ``target_size`` is not smaller than the dataset.

    Raises:
        ValueError: If ``target_size`` is negative or the strategy is unknown.
    """
    if target_size < 0:
        raise ValueError(f"target_size must be non-negative, got {target_size}")

    # A private generator gives the same draws as seeding the module-level
    # RNG, without disturbing random state that other code may rely on.
    rng = random.Random(seed)

    original_size = len(train_data['texts'])
    print(f"Original dataset size: {original_size:,} examples")

    if target_size >= original_size:
        print("Target size >= original size, no reduction needed")
        return train_data

    all_indices = list(range(original_size))
    # Id of the 'O' label (-1 when the mapping has none); it and the -100
    # padding value are not counted as entity labels.
    outside_label = train_data.get('label_to_id', {}).get('O', -1)

    def entity_labels(example_idx):
        """Return the distinct non-'O', non-padding labels of one example."""
        return {
            label for label in train_data['label_ids'][example_idx]
            if label != outside_label and label != -100
        }

    if reduction_strategy == "random":
        selected_indices = rng.sample(all_indices, target_size)
        print(f"Random sampling: {target_size:,} examples")

    elif reduction_strategy == "balanced":
        print(f"Balanced sampling with min {min_examples_per_label} per label...")

        # Map every entity label to the examples that contain it.
        label_to_indices = {}
        for idx in all_indices:
            for label in entity_labels(idx):
                label_to_indices.setdefault(label, []).append(idx)

        print(f"Found {len(label_to_indices)} unique labels")

        chosen = set()
        for indices in label_to_indices.values():
            if len(indices) >= min_examples_per_label:
                chosen.update(rng.sample(indices, min_examples_per_label))
            else:
                chosen.update(indices)

        remaining_slots = target_size - len(chosen)
        if remaining_slots > 0:
            # Top up with random examples that were not picked yet.
            remaining_indices = [idx for idx in all_indices if idx not in chosen]
            chosen.update(
                rng.sample(remaining_indices, min(remaining_slots, len(remaining_indices)))
            )
            selected_indices = sorted(chosen)
        elif remaining_slots < 0:
            # The per-label guarantee alone exceeds the budget: thin it out
            # at random instead of keeping whichever ids a set yields first.
            selected_indices = sorted(rng.sample(sorted(chosen), target_size))
        else:
            selected_indices = sorted(chosen)

    elif reduction_strategy == "stratified":
        print(f"📊 Stratified sampling by text length and label diversity...")

        features = [
            (idx, len(train_data['texts'][idx]), len(entity_labels(idx)))
            for idx in all_indices
        ]
        features.sort(key=lambda item: (item[1], item[2]))

        # Pick evenly spaced positions over the whole sorted range. An
        # integer stride (len // target_size) degenerates to "the first
        # target_size items" whenever the ratio is below 2, which would keep
        # only the shortest texts.
        total = len(features)
        selected_indices = [
            features[(position * total) // target_size][0]
            for position in range(target_size)
        ]

    else:
        raise ValueError(f"Unknown reduction strategy: {reduction_strategy}")

    # Subsample the per-example lists; copy every other entry by reference.
    reduced_data = {}
    for key in train_data.keys():
        if key in ['texts', 'token_ids', 'label_ids']:
            reduced_data[key] = [train_data[key][idx] for idx in selected_indices]
        else:
            reduced_data[key] = train_data[key]

    print(f"Reduced dataset size: {len(reduced_data['texts']):,} examples")
    print(f"Reduction ratio: {len(reduced_data['texts'])/original_size:.1%}")

    if 'label_to_id' in train_data:
        # Token-level label frequencies of the reduced set, for a sanity check.
        reduced_label_counts = Counter(
            label
            for label_seq in reduced_data['label_ids']
            for label in label_seq
            if label != outside_label and label != -100
        )

        print(f"Reduced dataset has {len(reduced_label_counts)} active labels")

        if reduced_label_counts:
            id_to_label = train_data['id_to_label']
            print("🔝 Top 10 labels in reduced dataset:")
            for label_id, count in reduced_label_counts.most_common(10):
                label_name = id_to_label.get(label_id, f"ID_{label_id}")
                print(f"   {label_name}: {count:,} tokens")

    return reduced_data

# Settings for shrinking the training set before fine-tuning.
REDUCTION_CONFIG = {
    'strategy': 'balanced',
    'target_size': 30000,
    'min_examples_per_label': 100,
    'seed': 42
}

print("Reducing training dataset size...")
print(f"Strategy: {REDUCTION_CONFIG['strategy']}")
print(f"Target size: {REDUCTION_CONFIG['target_size']:,}")

train_data = reduce_dataset_size(
    train_data, 
    reduction_strategy=REDUCTION_CONFIG['strategy'],
    target_size=REDUCTION_CONFIG['target_size'],
    min_examples_per_label=REDUCTION_CONFIG['min_examples_per_label'],
    seed=REDUCTION_CONFIG['seed']
)

print("Recreating dataset and dataloader...")

# Rebuild the dataset and loader so they reflect the reduced data.
train_dataset = FastPIITokenDataset(train_data, config.max_length)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=config.batch_size, 
    shuffle=True,
    num_workers=2
)

print(f"New training dataset: {len(train_dataset):,} examples")
print(f"New dataloader: {len(train_dataloader):,} batches")

# One optimizer step is taken per `gradient_accumulation_steps` batches.
total_steps = len(train_dataloader) * config.num_epochs // config.gradient_accumulation_steps
print(f"Updated total training steps: {total_steps:,}")

# The optimizer is created in the "Initialize model and optimizer" section,
# which also builds its own scheduler from the current dataloader. Rebuild the
# scheduler here only when that section has already been run; on a
# top-to-bottom run `optimizer` does not exist yet, and referencing it would
# raise NameError.
if 'optimizer' in globals():
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )

print("Ready to train with reduced dataset!")

# Initialize model and optimizer

In [None]:

# Seed torch and numpy so initialisation of the head is repeatable.
torch.manual_seed(config.seed)
np.random.seed(config.seed)

# Size the classification head from the label count loaded with the training
# data rather than a hard-coded constant, so the head always matches the
# label ids the loss will see.
model = MinistralTokenClassifier(
    model_name=config.model_name,
    num_labels=num_labels,
    freeze_backbone=True
)

# NOTE(review): the backbone is loaded with device_map="auto"; this call also
# moves the freshly created head onto `device`.
model.to(device)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())

print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable ratio: {trainable_params/total_params:.2%}")

# Only parameters that still require gradients are handed to the optimizer.
optimizer_params = [p for p in model.parameters() if p.requires_grad]

optimizer = AdamW(
    optimizer_params,
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# One optimizer step is taken per `gradient_accumulation_steps` batches.
total_steps = len(train_dataloader) * config.num_epochs // config.gradient_accumulation_steps

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.warmup_steps,
    num_training_steps=total_steps
)

print(f"Total training steps: {total_steps}")
print(f"Ready to train!")

# Training functions


In [None]:
def train_epoch(model, train_dataloader, optimizer, scheduler, epoch, config, device):
    """Train for one epoch with gradient accumulation.

    Args:
        model: Module whose forward returns a dict containing ``'loss'``.
        train_dataloader: Iterable of batches (dicts of tensors) with a length.
        optimizer: Optimizer stepped once per accumulation group.
        scheduler: LR scheduler stepped together with the optimizer.
        epoch: Zero-based epoch index (used for progress output only).
        config: Object providing ``gradient_accumulation_steps``,
            ``max_grad_norm`` and ``logging_steps``.
        device: Device the batch tensors are moved to.

    Returns:
        Mean per-batch loss over the epoch (``0.0`` for an empty loader).
    """
    model.train()
    accumulation_steps = max(1, config.gradient_accumulation_steps)
    # Frozen parameters never receive gradients, so clip only the trainable ones.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    total_batches = len(train_dataloader)

    total_loss = 0.0
    num_batches = 0

    # Start from clean gradients in case a previous epoch left some behind.
    optimizer.zero_grad()

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")

    for step, batch in enumerate(progress_bar):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs['loss']

        # Track the unscaled loss so the reported average is the true
        # per-batch loss rather than loss / accumulation_steps.
        total_loss += loss.item()
        num_batches += 1

        # Scale only the value that is back-propagated.
        (loss / accumulation_steps).backward()

        # Step on every full accumulation group, and also on the final batch
        # so a trailing partial group is applied instead of being carried
        # into the next epoch.
        is_last_batch = (step + 1) == total_batches
        if (step + 1) % accumulation_steps == 0 or is_last_batch:
            torch.nn.utils.clip_grad_norm_(trainable_params, config.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss = total_loss / num_batches
        progress_bar.set_postfix({'loss': f'{avg_loss:.4f}'})

        if step % config.logging_steps == 0:
            print(f"Epoch {epoch+1}, Step {step}, Loss: {avg_loss:.4f}")

    return total_loss / num_batches if num_batches else 0.0

def evaluate(model, val_dataloader, id_to_label, device):
    """Evaluate on validation set.

    Args:
        model: Module whose forward returns a dict with ``'loss'`` and
            ``'logits'``.
        val_dataloader: Validation batches, or a falsy value when there is
            no validation data.
        id_to_label: Mapping from the label ids ``0..len-1`` to label names.
        device: Device the batch tensors are moved to.

    Returns:
        ``{}`` when no validation data is available; otherwise a dict with
        ``'loss'`` (mean batch loss), ``'f1_score'`` (weighted F1 over
        positions whose label is not -100) and ``'classification_report'``
        (the scikit-learn report as a dict).
    """
    if not val_dataloader:
        print("No validation dataset available")
        return {}

    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = outputs['loss']
            logits = outputs['logits']

            total_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)

            batch_labels = batch['labels'].cpu().numpy().flatten()
            batch_predictions = predictions.cpu().numpy().flatten()

            # Score only real positions; -100 marks padding.
            mask = batch_labels != -100
            all_labels.extend(batch_labels[mask].tolist())
            all_predictions.extend(batch_predictions[mask].tolist())

    avg_loss = total_loss / len(val_dataloader)
    f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # Pass the full label id list explicitly: without `labels`, the report
    # infers the classes from the data and rejects `target_names` whenever a
    # class does not occur in the validation labels or predictions.
    label_ids = list(range(len(id_to_label)))
    target_names = [id_to_label[i] for i in label_ids]
    report = classification_report(
        all_labels, 
        all_predictions, 
        labels=label_ids,
        target_names=target_names,
        output_dict=True,
        zero_division=0
    )

    print(f"Validation Loss: {avg_loss:.4f}")
    print(f"Validation F1-Score: {f1:.4f}")

    return {
        'loss': avg_loss,
        'f1_score': f1,
        'classification_report': report
    }

def save_model(model, config, label_to_id, id_to_label, output_dir):
    """Write the classification head and its metadata to ``output_dir``.

    Two files are produced: ``pytorch_model.bin`` (head weights, config
    values and label mappings, via ``torch.save``) and
    ``training_config.json`` (the config values as JSON). The directory is
    created when missing. Backbone weights are not written.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_values = config.__dict__

    # Everything needed to rebuild the head on top of the named backbone.
    checkpoint = dict(
        classifier_state_dict=model.classifier.state_dict(),
        config=config_values,
        label_to_id=label_to_id,
        id_to_label=id_to_label,
        num_labels=len(label_to_id),
        model_name=config.model_name,
    )
    torch.save(checkpoint, target_dir / "pytorch_model.bin")

    # Human-readable copy of the training configuration.
    with open(target_dir / "training_config.json", 'w') as config_file:
        json.dump(config_values, config_file, indent=2)

    print(f"💾 Model saved to {target_dir}")

# Confirmation message printed once the function definitions above have executed.
print("✅ Training functions defined")

# Training loop

In [None]:
print("Starting training...")

best_f1 = 0
output_dir = config.output_dir
# Tracks whether a checkpoint was actually written, so the final message
# never claims a save that did not happen.
model_saved = False

for epoch in range(config.num_epochs):
    print(f"\nStarting epoch {epoch + 1}/{config.num_epochs}")

    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, epoch, config, device)
    print(f"Epoch {epoch + 1} - Training Loss: {train_loss:.4f}")

    eval_results = evaluate(model, val_dataloader, id_to_label, device)

    # Save on the first evaluated epoch (even if its F1 is 0) and afterwards
    # whenever the validation F1 improves.
    if eval_results and (not model_saved or eval_results['f1_score'] > best_f1):
        best_f1 = eval_results['f1_score']
        save_model(model, config, label_to_id, id_to_label, output_dir)
        model_saved = True
        print(f"New best F1-Score: {best_f1:.4f} - Model saved")

# Without validation data there is no F1 to trigger a save; keep the final
# weights so the training run is not lost.
if not model_saved and config.num_epochs > 0:
    save_model(model, config, label_to_id, id_to_label, output_dir)
    model_saved = True
    print("No validation results available - final model saved")

print(f"\nTraining completed!")
print(f"Best F1-Score: {best_f1:.4f}")
if model_saved:
    print(f"Model saved to: {output_dir}")

# Evaluation and cleanup

In [None]:
# Run one last validation pass and print per-label precision/recall/F1.
if val_dataloader:
    print("\nFinal evaluation:")
    final_results = evaluate(model, val_dataloader, id_to_label, device)

    if final_results:
        print("\nClassification Report:")
        report = final_results['classification_report']
        # Keep only the dict entries that carry precision/recall/F1; scalar
        # entries in the report are skipped.
        metric_rows = (
            (row_name, row_stats)
            for row_name, row_stats in report.items()
            if isinstance(row_stats, dict) and 'precision' in row_stats
        )
        for row_name, row_stats in metric_rows:
            precision = row_stats['precision']
            recall = row_stats['recall']
            f1_value = row_stats['f1-score']
            print(f"{row_name:15} - P: {precision:.3f}, R: {recall:.3f}, F1: {f1_value:.3f}")

In [None]:
# Release cached GPU memory. The CUDA queries are skipped on CPU-only hosts,
# where asking for the properties of device 0 would raise.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print("\n🧹 Memory cleaned up")