# MRDA Dialogue Act Classification Pipeline

## Multi-Stage Training Approach:
1. **Stage 1:** Train 12-class General DA classifier 
2. **Stage 2:** Map to binary content/non-content classification

**Target Model Repository:** `wylupek/distilbert-mrda-dialogue-acts`

---


In [1]:
import torch
import platform
import psutil
from datasets import load_dataset
from collections import Counter
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from huggingface_hub import login, whoami

import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(f"Platform: {platform.system()} {platform.release()}")
print(f"Architecture: {platform.machine()}")
print(f"CPU Cores: {psutil.cpu_count()}")
print(f"RAM: {psutil.virtual_memory().total / (1024**3):.1f} GB")


def detect_device():
    """Detect best available device with fallback strategy"""
    if torch.cuda.is_available():
        device = "cuda"
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"CUDA Device: {device_name}")
        print(f"GPU Memory: {memory_gb:.1f} GB")
    elif torch.backends.mps.is_available():
        device = "mps" 
        device_name = "Apple Silicon (MPS)"
        print(f"MPS Device: {device_name}")
        print(f"Unified Memory Available")
    else:
        device = "cpu"
        device_name = "CPU"
        print(f"CPU Device: {device_name}")
        print(f"Using CPU cores: {psutil.cpu_count()}")
    
    return device, device_name

device, device_name = detect_device()

# Test device
try:
    test_tensor = torch.randn(10, 10).to(device)
    result = torch.matmul(test_tensor, test_tensor.T)
    del test_tensor, result
except Exception as e:
    print(f"Device Test Failed: {e}")
    print("Falling back to CPU...")
    device = "cpu"

Platform: Linux 5.15.0-151-generic
Architecture: x86_64
CPU Cores: 24
RAM: 31.2 GB
CUDA Device: NVIDIA GeForce RTX 3060 Ti
GPU Memory: 7.8 GB


In [3]:

dataset = load_dataset("wylupek/mrda-corpus")

print(f"Dataset splits:")
for split_name, split_data in dataset.items():
    print(f"  {split_name}: {len(split_data):,} samples")
total_samples = sum(len(split) for split in dataset.values())
print(f"  Total: {total_samples:,} samples\n")

print(f"Sample data: {dataset['train'][0]}")


train_labels = [sample['general_da'] for sample in dataset['validation']]
unique_labels = list(set(train_labels))
unique_labels.sort()
print(f"Unique general_da labels: {len(unique_labels)}")
print(f"Labels: {unique_labels}\n")


label_counts = pd.Series(train_labels).value_counts().sort_index()
print(f"Label Distribution (Training Set):")
for label, count in label_counts.items():
    percentage = (count / len(train_labels)) * 100
    print(f"  {label}: {count:,} samples ({percentage:.1f}%)")

Dataset splits:
  train: 75,067 samples
  test: 16,702 samples
  validation: 16,433 samples
  Total: 108,202 samples

Sample data: {'speaker': 'fe016', 'text': 'okay.', 'basic_da': 'F', 'general_da': 'fg', 'full_da': 'fg'}
Unique general_da labels: 12
Labels: ['%', 'b', 'fg', 'fh', 'h', 'qh', 'qo', 'qr', 'qrr', 'qw', 'qy', 's']

Label Distribution (Training Set):
  %: 440 samples (2.7%)
  b: 2,342 samples (14.3%)
  fg: 527 samples (3.2%)
  fh: 1,225 samples (7.5%)
  h: 184 samples (1.1%)
  qh: 36 samples (0.2%)
  qo: 25 samples (0.2%)
  qr: 39 samples (0.2%)
  qrr: 73 samples (0.4%)
  qw: 287 samples (1.7%)
  qy: 806 samples (4.9%)
  s: 10,449 samples (63.6%)


In [4]:
CONTENT_LABELS = {
    's',      # Statement (65.2% - main content)
    'qy',     # Yes-No-question (4.4%)
    'qw',     # Wh-Question (1.5%)
    'qh',     # Rhetorical Question (0.3%)
    'qrr',    # Or-Clause (0.3%)
    'qr',     # Or Question (0.2%)
    'qo'      # Open-ended Question (0.2%)
}

NON_CONTENT_LABELS = {
    'b',      # Continuer (14.1% - backchannels)
    'fh',     # Floor Holder (7.5% - floor management)
    'fg',     # Floor Grabber (2.8% - floor management)
    '%',      # Interrupted/Abandoned (2.9% - disruptions)
    'h'       # Hold Before Answer (0.6% - hesitations)
}

def calculate_content_distribution(labels):
    """Calculate content vs non-content percentages"""
    content_count = sum(1 for label in labels if label in CONTENT_LABELS)
    non_content_count = sum(1 for label in labels if label in NON_CONTENT_LABELS)
    total = len(labels)
    
    content_pct = (content_count / total) * 100
    non_content_pct = (non_content_count / total) * 100
    
    return content_count, non_content_count, content_pct, non_content_pct

def map_to_binary(general_da_label):
    """Map general DA label to binary content/non-content"""
    if general_da_label in CONTENT_LABELS:
        return 1  # Content
    elif general_da_label in NON_CONTENT_LABELS:
        return 0  # Non-content
    else:
        raise ValueError(f"Unknown label: {general_da_label}")

def map_to_text(general_da_label):
    """Map general DA label to text description"""
    if general_da_label in CONTENT_LABELS:
        return "content"
    elif general_da_label in NON_CONTENT_LABELS:
        return "non-content"
    else:
        raise ValueError(f"Unknown label: {general_da_label}")


for split_name in ['train', 'validation', 'test']:
    split_labels = [sample['general_da'] for sample in dataset[split_name]]
    content_count, non_content_count, content_pct, non_content_pct = calculate_content_distribution(split_labels)
    
    print(f"{split_name.capitalize()} split:")
    print(f"  Content: {content_count:,} samples ({content_pct:.1f}%)")
    print(f"  Non-content: {non_content_count:,} samples ({non_content_pct:.1f}%)")

Train split:
  Content: 54,123 samples (72.1%)
  Non-content: 20,944 samples (27.9%)
Validation split:
  Content: 11,715 samples (71.3%)
  Non-content: 4,718 samples (28.7%)
Test split:
  Content: 11,848 samples (70.9%)
  Non-content: 4,854 samples (29.1%)


In [5]:
MODEL_NAME = "distilbert-base-uncased"
NUM_LABELS = 12

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    problem_type="single_label_classification"
)

model = model.to(device)

# Test tokenization
sample_text = "okay so um i was going to try to get out of here"
encoded = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
print(f"\nSample text: '{sample_text}'")
print(f"Tokenized shape: {encoded['input_ids'].shape}")
print(f"Tokens: {encoded['input_ids'][0].tolist()}\n")


encoded = encoded.to(device)
with torch.no_grad():
    outputs = model(**encoded)
    logits = outputs.logits
    predictions = torch.softmax(logits, dim=-1)
    
print(f"Logits shape: {logits.shape}")
print(f"Predictions: {predictions.tolist()[0]}")

del encoded, outputs, logits, predictions

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Sample text: 'okay so um i was going to try to get out of here'
Tokenized shape: torch.Size([1, 15])
Tokens: [101, 3100, 2061, 8529, 1045, 2001, 2183, 2000, 3046, 2000, 2131, 2041, 1997, 2182, 102]

Logits shape: torch.Size([1, 12])
Predictions: [0.07466944307088852, 0.09711909294128418, 0.08354572206735611, 0.08508571237325668, 0.08700844645500183, 0.0856059193611145, 0.07023141533136368, 0.0793866217136383, 0.07778478413820267, 0.09030162543058395, 0.08108097314834595, 0.08818024396896362]


In [6]:
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

print("Label mapping:")
for label, idx in label2id.items():
    print(f"  {label} -> {idx}")

def preprocess_function(examples):
    """Tokenize text and encode labels"""
    tokens = tokenizer(
        examples["text"],
        truncation=True,
        padding=False,  # Will pad later in DataCollator
        max_length=128,  # Actual max length is 96
        return_tensors=None
    )
    tokens["labels"] = [label2id[label] for label in examples["general_da"]]
    return tokens

# Apply preprocessing to all splits
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing"
)

# Verify no data loss
if len(dataset["train"]) != len(tokenized_datasets["train"]):
    print(f"Not all samples were processed\n")
    print(f"Original train size: {len(dataset['train']):,}")
    print(f"Processed train size: {len(tokenized_datasets['train']):,}")
if len(dataset["validation"]) != len(tokenized_datasets["validation"]):
    print(f"Not all samples were processed\n")
    print(f"Original validation size: {len(dataset['validation']):,}")
    print(f"Processed validation size: {len(tokenized_datasets['validation']):,}")
if len(dataset["test"]) != len(tokenized_datasets["test"]):
    print(f"Not all samples were processed\n")
    print(f"Original test size: {len(dataset['test']):,}")
    print(f"Processed test size: {len(tokenized_datasets['test']):,}")

# Check max length
print("\nMax length:", max(max(len(s["input_ids"]) for s in tokenized_datasets[split]) for split in ["train","validation","test"]))

# Create data collator for batching
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Label mapping:
  % -> 0
  b -> 1
  fg -> 2
  fh -> 3
  h -> 4
  qh -> 5
  qo -> 6
  qr -> 7
  qrr -> 8
  qw -> 9
  qy -> 10
  s -> 11

Max length: 96


In [7]:
MAX_WEIGHT = 10
MIN_WEIGHT = 0.3

raw_train_labels = [sample["labels"] for sample in tokenized_datasets["train"]]
total_train_examples = len(raw_train_labels)
train_label_counts = dict(sorted(Counter(raw_train_labels).items(), key=lambda x: x[0]))

# Compute weights
class_weights = [1 / x for x in train_label_counts.values()]
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# Scale weights to MAX_WEIGHT and MIN_WEIGHT
curr_max_weight = class_weights_tensor.max()
curr_min_weight = class_weights_tensor.min()
class_weights_tensor = (class_weights_tensor - curr_min_weight) / (curr_max_weight - curr_min_weight) * (MAX_WEIGHT - MIN_WEIGHT) + MIN_WEIGHT

print(f"Class weights (balanced method):")
for class_id in range(NUM_LABELS):
    print(f"  {class_id} ({id2label[class_id]}): {class_weights_tensor[class_id].item() :.3f} (count: {train_label_counts[class_id]})")

Class weights (balanced method):
  0 (%): 0.796 (count: 2171)
  1 (b): 0.383 (count: 10606)
  2 (fg): 0.820 (count: 2076)
  3 (fh): 0.478 (count: 5617)
  4 (h): 2.656 (count: 474)
  5 (qh): 4.615 (count: 260)
  6 (qo): 10.000 (count: 116)
  7 (qr): 8.887 (count: 131)
  8 (qrr): 4.899 (count: 244)
  9 (qw): 1.293 (count: 1110)
  10 (qy): 0.618 (count: 3310)
  11 (s): 0.300 (count: 48952)


In [8]:
# Focal loss with balanced cross entropy
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Custom Weighted Trainer
class AdvancedTrainer(Trainer):
    def __init__(self, loss_type="focal", class_weights=None, focal_gamma=2.0, loss_reduction="mean", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.class_weights = class_weights
        
        if loss_type == "focal":
            self.loss_fn = FocalLoss(alpha=class_weights, gamma=focal_gamma, reduction=loss_reduction)
        elif loss_type == "weighted_ce":
            self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.loss_fn = torch.nn.CrossEntropyLoss()
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        
        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    macro_f1 = f1_score(labels, predictions, average='macro')
    micro_f1 = f1_score(labels, predictions, average='micro')
    weighted_f1 = f1_score(labels, predictions, average='weighted')
    accuracy = accuracy_score(labels, predictions)
    
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "weighted_f1": weighted_f1
    }


In [9]:
# TODO the trainer isn't using stratified sampling
# 4. Stratified Batch Sampling
def create_stratified_sampler(dataset):
    # Get labels from dataset
    labels = [sample["labels"] for sample in dataset]
    
    # Calculate sample weights for stratified sampling - handle non-contiguous labels
    unique_labels_in_data = sorted(set(labels))
    class_sample_count = {label: labels.count(label) for label in unique_labels_in_data}
    
    # Create weight mapping for each class
    weight_mapping = {}
    for label, count in class_sample_count.items():
        weight_mapping[label] = 1.0 / count
    
    # Assign weight to each sample
    samples_weight = torch.tensor([weight_mapping[label] for label in labels])
    print(weight_mapping)
    # Create sampler
    sampler = WeightedRandomSampler(
        weights=samples_weight,
        num_samples=len(samples_weight),
        replacement=True
    )
    return sampler

# Create stratified sampler
train_sampler = create_stratified_sampler(tokenized_datasets["train"])

{0: 0.00046061722708429296, 1: 9.428625306430322e-05, 2: 0.0004816955684007707, 3: 0.00017803097739006588, 4: 0.002109704641350211, 5: 0.0038461538461538464, 6: 0.008620689655172414, 7: 0.007633587786259542, 8: 0.004098360655737705, 9: 0.0009009009009009009, 10: 0.00030211480362537764, 11: 2.0428174538323256e-05}


### Parameters tuning
**LoRA**
- lora_dropout – Regular dropout regularization, prevents overfitting. (0.0; 0.3)
- lora_r – The size of LoRA adapters, means how much new information the model can learn. High values cause better learning capacity, especially usefull for imbalanced data  (8; 64)
- lora_alpha – How strong the LoRA adapters influence the original model. Affects effective scaling (`lora_alpha/lora_r`) (scale between 0.5 and 8)

**Traning arguments**

NOTE: Step is calculated by `ceil(total_samples / batch_size)`

Proper Warmup Guidelines:
- Simple tasks: 5-10% of total steps
- Complex tasks: 10-15% of total steps
- Imbalanced/Difficult: 15-20% of total steps


In [None]:
LORA_DROPOUT = 0.15
LORA_R = 64
LORA_ALPHA = 128

LOSS_TYPE = "focal"
NO_EPOCHS = 6
BATCH_SIZE = 32
LEARNING_RATE = 0.0001
WARMUP_STEPS = 0.15
WEIGHT_DECAY = 0.04

FOCAL_GAMMA=2.0
LOSS_REDUCTION="mean"

### LoRA setup ###
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    target_modules=["q_lin", "v_lin", "k_lin", "out_lin"],
    modules_to_save=["pre_classifier", "classifier"],  # Keep classification head trainable!
    ### TUNABLE PARAMETERS ###
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
)
advanced_peft_model = get_peft_model(model, lora_config)

# Check which parameters are trainable
print("\n=== Trainable Parameters Check ===")
for name, param in advanced_peft_model.named_parameters():
    if param.requires_grad:
        print(f"✅ {name}: {param.shape}")
    elif "classifier" in name or "pre_classifier" in name:
        print(f"❌ FROZEN: {name}: {param.shape}")

# Print summary
print(f"\nTotal trainable parameters: {sum(p.numel() for p in advanced_peft_model.parameters() if p.requires_grad):,}")
print(f"Classifier head status: {'✅ TRAINABLE' if any('classifier' in name for name, p in advanced_peft_model.named_parameters() if p.requires_grad) else '❌ FROZEN'}")

### Training setup ###
total_steps = NO_EPOCHS * np.ceil(len(tokenized_datasets["train"]) / BATCH_SIZE)
eval_steps = total_steps // NO_EPOCHS // 6 # 6 times per epoch
save_steps = eval_steps * 3 # 3 times per epoch
logging_steps = total_steps // NO_EPOCHS // 12 # 12 times per epoch
advanced_training_args = TrainingArguments(
    output_dir="./advanced_checkpoints",
    save_strategy="steps",
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",  # Optimize for macro F1, perfect for imbalanced data
    greater_is_better=True,
    report_to=None,
    remove_unused_columns=False, # Required for custom trainer
    dataloader_num_workers=0,  # Important for MPS compatibility
    eval_steps=eval_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    ### TUNABLE PARAMETERS ###
    learning_rate=LEARNING_RATE,
    num_train_epochs=NO_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    warmup_steps=int(WARMUP_STEPS * total_steps), # Prevents overfitting, should be ~15% of total steps for imbalanced data
    weight_decay=WEIGHT_DECAY,
)
advanced_trainer = AdvancedTrainer(
    loss_type=LOSS_TYPE,  # Focal loss for imbalanced data
    class_weights=class_weights_tensor, # Weights for imbalanced data
    model=advanced_peft_model,
    args=advanced_training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    ### TUNABLE PARAMETERS ###
    focal_gamma=FOCAL_GAMMA,
    loss_reduction=LOSS_REDUCTION,
)

print(f"Total steps: {int(total_steps)}")
print(f"Steps per epoch: {int(total_steps // NO_EPOCHS)}")
print(f"Steps per eval: {int(eval_steps)}")
print(f"Steps per logging: {int(logging_steps)}")
print(f"Steps per saving: {int(save_steps)}")

Total steps: 14076
Steps per epoch: 2346
Steps per eval: 391
Steps per logging: 195
Steps per saving: 1173


In [11]:
hyperparams = {
    "max_weight": MAX_WEIGHT,
    "min_weight": MIN_WEIGHT,

    "lora_dropout": LORA_DROPOUT,
    "lora_r": LORA_R,
    "lora_alpha": LORA_ALPHA, 

    "loss_type": LOSS_TYPE,
    "no_epochs": NO_EPOCHS,
    "batch_size": BATCH_SIZE,
    "learning_rate": LEARNING_RATE,
    "warmup_steps": WARMUP_STEPS,
    "weight_decay": WEIGHT_DECAY,

    "focal_gamma": FOCAL_GAMMA,
    "loss_reduction": LOSS_REDUCTION,
}

print(f"\nStarting advanced training with:")
print(hyperparams)

# Run advanced training
advanced_result = advanced_trainer.train()

print(f"\nAdvanced training completed!")
print(f"Final train loss: {advanced_result.training_loss:.4f}")


Starting advanced training with:
{'max_weight': 10, 'min_weight': 0.3, 'lora_dropout': 0.15, 'lora_r': 64, 'lora_alpha': 128, 'loss_type': 'focal', 'no_epochs': 6, 'batch_size': 32, 'learning_rate': 0.0001, 'warmup_steps': 0.15, 'weight_decay': 0.04, 'max_grad_norm': 1.0, 'focal_gamma': 2.0, 'loss_reduction': 'mean'}


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Micro F1,Weighted F1
391,0.4485,0.456442,0.775817,0.248688,0.775817,0.741483
782,0.3037,0.305745,0.755979,0.416966,0.755979,0.763666
1173,0.1991,0.248026,0.759021,0.437196,0.759021,0.772824
1564,0.2218,0.203688,0.757744,0.476281,0.757744,0.772239
1955,0.1979,0.191236,0.769367,0.519548,0.769367,0.781363
2346,0.1917,0.19669,0.716911,0.480123,0.716911,0.743466
2737,0.1824,0.19328,0.750259,0.523584,0.750259,0.770414
3128,0.1761,0.202966,0.756587,0.51794,0.756587,0.77564
3519,0.1947,0.184291,0.776304,0.554355,0.776304,0.789998
3910,0.1655,0.19257,0.731394,0.52742,0.731394,0.758007



Advanced training completed!
Final train loss: 0.1602


In [18]:
# =============================================================================
# EXPERIMENT LOGGING & COMPREHENSIVE EVALUATION
# =============================================================================

import json
import os
from datetime import datetime


def setup_experiment_logging(experiment_name, hyperparams):
    """Setup timestamped logging directory for experiments"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results/{experiment_name}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)
    
    # Save hyperparameters for reproducibility
    with open(f"{results_dir}/hyperparams.json", 'w') as f:
        json.dump(hyperparams, f, indent=2)
    
    return results_dir


def comprehensive_evaluation(trainer, tokenized_datasets, id2label, results_dir):
    """
    Complete evaluation with:
    - Overall metrics (accuracy, macro/weighted F1)
    - Per-class metrics and confusion matrix
    - Training progress table with all checkpoints
    - Detailed 12-class analysis with TP/FP/FN/TN breakdown
    """
    
    # ===== GET PREDICTIONS =====
    predictions = trainer.predict(tokenized_datasets["validation"])
    y_true = predictions.label_ids
    y_pred = np.argmax(predictions.predictions, axis=1)
    
    # ===== OVERALL METRICS =====
    metrics = {
        'accuracy': round(accuracy_score(y_true, y_pred), 4),
        'macro_f1': round(f1_score(y_true, y_pred, average='macro'), 4),
        'weighted_f1': round(f1_score(y_true, y_pred, average='weighted'), 4),
    }
    
    # ===== PER-CLASS METRICS =====
    class_report = classification_report(
        y_true, y_pred,
        target_names=[id2label[i] for i in range(len(id2label))],
        output_dict=True
    )
    
    def round_nested_dict(d, decimals=4):
        """Recursively round all numeric values in nested dictionary"""
        if isinstance(d, dict):
            return {k: round_nested_dict(v, decimals) for k, v in d.items()}
        elif isinstance(d, (int, float)):
            return round(d, decimals) if isinstance(d, float) else d
        else:
            return d
    
    class_report = round_nested_dict(class_report)
    
    # ===== CONFUSION MATRIX =====
    cm = confusion_matrix(y_true, y_pred)
    
    # ===== TRAINING PROGRESS EXTRACTION =====
    training_progress = []
    if hasattr(trainer, 'state') and hasattr(trainer.state, 'log_history'):
        # Separate evaluation and training logs
        eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
        train_logs = [log for log in trainer.state.log_history if 'loss' in log and 'eval_loss' not in log]
        
        # Map training steps to training loss
        train_loss_map = {}
        for log in train_logs:
            step = log.get('step', 0)
            if step > 0:
                train_loss_map[step] = log.get('loss', 0)
        
        # Match evaluation checkpoints with training loss
        for log in eval_logs:
            eval_step = log.get('step', 0)
            
            # Find closest training loss (≤ eval_step)
            training_loss = 0
            if train_loss_map:
                valid_train_steps = [s for s in train_loss_map.keys() if s <= eval_step]
                if valid_train_steps:
                    closest_train_step = max(valid_train_steps)
                    training_loss = train_loss_map[closest_train_step]
            
            progress_entry = {
                'step': eval_step,
                'training_loss': round(training_loss, 4),
                'validation_loss': round(log.get('eval_loss', 0), 4),
                'accuracy': round(log.get('eval_accuracy', 0), 4),
                'macro_f1': round(log.get('eval_macro_f1', 0), 4),
                'micro_f1': round(log.get('eval_micro_f1', 0), 4),
                'weighted_f1': round(log.get('eval_weighted_f1', 0), 4)
            }
            training_progress.append(progress_entry)
    
    # ===== DETAILED 12-CLASS ANALYSIS =====
    print("\n" + "="*80)
    print("PER-CLASS CONFUSION ANALYSIS (12 TABLES)")
    print("="*80)
    
    detailed_analysis = {}
    for class_id in range(len(id2label)):
        class_name = id2label[class_id]
        
        # Calculate TP/FP/FN/TN for this class
        tp = cm[class_id, class_id]
        fp = cm[:, class_id].sum() - tp
        fn = cm[class_id, :].sum() - tp
        tn = cm.sum() - tp - fp - fn
        
        # Calculate metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        # Calculate distribution percentages
        true_count = (y_true == class_id).sum()
        pred_count = (y_pred == class_id).sum()
        true_pct = (true_count / len(y_true)) * 100
        pred_pct = (pred_count / len(y_pred)) * 100
        
        # Store detailed analysis
        detailed_analysis[class_name] = {
            'tp': int(tp), 'fp': int(fp), 'fn': int(fn), 'tn': int(tn),
            'precision': round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4),
            'true_pct': round(true_pct, 4), 'pred_pct': round(pred_pct, 4), 
            'diff_pct': round(pred_pct - true_pct, 4)
        }
        
        # Print formatted class analysis
        print(f"\n📊 CLASS {class_id}: '{class_name}' Analysis")
        print("-" * 50)
        print(f"  Confusion:    TP={tp:4d} | FP={fp:4d}")
        print(f"                FN={fn:4d} | TN={tn:4d}")
        print(f"  Metrics:      Prec={precision:.3f} | Rec={recall:.3f} | F1={f1:.3f}")
        print(f"  Distribution: True={true_pct:5.1f}% | Pred={pred_pct:5.1f}% | Diff={pred_pct-true_pct:+5.1f}%")
    
    # ===== CLASS DISTRIBUTION ANALYSIS =====
    true_dist = pd.Series(y_true).value_counts().sort_index()
    pred_dist = pd.Series(y_pred).value_counts().sort_index()
    
    # ===== SAVE COMPREHENSIVE RESULTS =====
    results = {
        'overall_metrics': metrics,
        'per_class_metrics': class_report,
        'detailed_class_analysis': detailed_analysis,
        'training_progress': training_progress,
        'confusion_matrix': cm.tolist(),
        'true_distribution': true_dist.to_dict(),
        'pred_distribution': pred_dist.to_dict()
    }
    
    with open(f"{results_dir}/evaluation.json", 'w') as f:
        json.dump(results, f, indent=2)
    
    # ===== SUMMARY OUTPUT =====
    print(f"\n🎯 OVERALL: Acc={metrics['accuracy']:.3f} | MacroF1={metrics['macro_f1']:.3f} | WeightedF1={metrics['weighted_f1']:.3f}")
    print(f"💾 Training progress saved with {len(training_progress)} checkpoints")
    
    return results


# =============================================================================
# RUN COMPREHENSIVE EVALUATION
# =============================================================================

results_dir = setup_experiment_logging("no_1", hyperparams)
results = comprehensive_evaluation(advanced_trainer, tokenized_datasets, id2label, results_dir)



PER-CLASS CONFUSION ANALYSIS (12 TABLES)

📊 CLASS 0: '%' Analysis
--------------------------------------------------
  Confusion:    TP= 343 | FP= 562
                FN=  97 | TN=15431
  Metrics:      Prec=0.379 | Rec=0.780 | F1=0.510
  Distribution: True=  2.7% | Pred=  5.5% | Diff= +2.8%

📊 CLASS 1: 'b' Analysis
--------------------------------------------------
  Confusion:    TP=2094 | FP=1097
                FN= 248 | TN=12994
  Metrics:      Prec=0.656 | Rec=0.894 | F1=0.757
  Distribution: True= 14.3% | Pred= 19.4% | Diff= +5.2%

📊 CLASS 2: 'fg' Analysis
--------------------------------------------------
  Confusion:    TP= 146 | FP= 506
                FN= 381 | TN=15400
  Metrics:      Prec=0.224 | Rec=0.277 | F1=0.248
  Distribution: True=  3.2% | Pred=  4.0% | Diff= +0.8%

📊 CLASS 3: 'fh' Analysis
--------------------------------------------------
  Confusion:    TP= 660 | FP= 413
                FN= 565 | TN=14795
  Metrics:      Prec=0.615 | Rec=0.539 | F1=0.574
  Distri