# MRDA Dialogue Act Classification Pipeline

## Multi-Stage Training Approach:
1. **Stage 1:** Train 12-class General DA classifier 
2. **Stage 2:** Map to binary content/non-content classification

**Target Model Repository:** `wylupek/distilbert-mrda-dialogue-acts`

---


In [1]:
import torch
import platform
import psutil
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from huggingface_hub import login, whoami

import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(f"Platform: {platform.system()} {platform.release()}")
print(f"Architecture: {platform.machine()}")
print(f"CPU Cores: {psutil.cpu_count()}")
print(f"RAM: {psutil.virtual_memory().total / (1024**3):.1f} GB")


def detect_device():
    """Detect best available device with fallback strategy"""
    if torch.cuda.is_available():
        device = "cuda"
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"CUDA Device: {device_name}")
        print(f"GPU Memory: {memory_gb:.1f} GB")
    elif torch.backends.mps.is_available():
        device = "mps" 
        device_name = "Apple Silicon (MPS)"
        print(f"MPS Device: {device_name}")
        print(f"Unified Memory Available")
    else:
        device = "cpu"
        device_name = "CPU"
        print(f"CPU Device: {device_name}")
        print(f"Using CPU cores: {psutil.cpu_count()}")
    
    return device, device_name

device, device_name = detect_device()

# Test device
try:
    test_tensor = torch.randn(10, 10).to(device)
    result = torch.matmul(test_tensor, test_tensor.T)
    del test_tensor, result
except Exception as e:
    print(f"Device Test Failed: {e}")
    print("Falling back to CPU...")
    device = "cpu"

Platform: Linux 5.15.0-151-generic
Architecture: x86_64
CPU Cores: 24
RAM: 31.2 GB
CUDA Device: NVIDIA GeForce RTX 3060 Ti
GPU Memory: 7.8 GB


In [13]:

dataset = load_dataset("wylupek/mrda-corpus")

print(f"Dataset splits:")
for split_name, split_data in dataset.items():
    print(f"  {split_name}: {len(split_data):,} samples")
total_samples = sum(len(split) for split in dataset.values())
print(f"  Total: {total_samples:,} samples\n")

print(f"Sample data: {dataset['train'][0]}")


train_labels = [sample['general_da'] for sample in dataset['validation']]
unique_labels = list(set(train_labels))
unique_labels.sort()
print(f"Unique general_da labels: {len(unique_labels)}")
print(f"Labels: {unique_labels}\n")


label_counts = pd.Series(train_labels).value_counts().sort_index()
print(f"Label Distribution (Training Set):")
for label, count in label_counts.items():
    percentage = (count / len(train_labels)) * 100
    print(f"  {label}: {count:,} samples ({percentage:.1f}%)")

Dataset splits:
  train: 75,067 samples
  test: 16,702 samples
  validation: 16,433 samples
  Total: 108,202 samples

Sample data: {'speaker': 'fe016', 'text': 'okay.', 'basic_da': 'F', 'general_da': 'fg', 'full_da': 'fg'}
Unique general_da labels: 12
Labels: ['%', 'b', 'fg', 'fh', 'h', 'qh', 'qo', 'qr', 'qrr', 'qw', 'qy', 's']

Label Distribution (Training Set):
  %: 440 samples (2.7%)
  b: 2,342 samples (14.3%)
  fg: 527 samples (3.2%)
  fh: 1,225 samples (7.5%)
  h: 184 samples (1.1%)
  qh: 36 samples (0.2%)
  qo: 25 samples (0.2%)
  qr: 39 samples (0.2%)
  qrr: 73 samples (0.4%)
  qw: 287 samples (1.7%)
  qy: 806 samples (4.9%)
  s: 10,449 samples (63.6%)


In [4]:
CONTENT_LABELS = {
    's',      # Statement (65.2% - main content)
    'qy',     # Yes-No-question (4.4%)
    'qw',     # Wh-Question (1.5%)
    'qh',     # Rhetorical Question (0.3%)
    'qrr',    # Or-Clause (0.3%)
    'qr',     # Or Question (0.2%)
    'qo'      # Open-ended Question (0.2%)
}

NON_CONTENT_LABELS = {
    'b',      # Continuer (14.1% - backchannels)
    'fh',     # Floor Holder (7.5% - floor management)
    'fg',     # Floor Grabber (2.8% - floor management)
    '%',      # Interrupted/Abandoned (2.9% - disruptions)
    'h'       # Hold Before Answer (0.6% - hesitations)
}

def calculate_content_distribution(labels):
    """Calculate content vs non-content percentages"""
    content_count = sum(1 for label in labels if label in CONTENT_LABELS)
    non_content_count = sum(1 for label in labels if label in NON_CONTENT_LABELS)
    total = len(labels)
    
    content_pct = (content_count / total) * 100
    non_content_pct = (non_content_count / total) * 100
    
    return content_count, non_content_count, content_pct, non_content_pct

def map_to_binary(general_da_label):
    """Map general DA label to binary content/non-content"""
    if general_da_label in CONTENT_LABELS:
        return 1  # Content
    elif general_da_label in NON_CONTENT_LABELS:
        return 0  # Non-content
    else:
        raise ValueError(f"Unknown label: {general_da_label}")

def map_to_text(general_da_label):
    """Map general DA label to text description"""
    if general_da_label in CONTENT_LABELS:
        return "content"
    elif general_da_label in NON_CONTENT_LABELS:
        return "non-content"
    else:
        raise ValueError(f"Unknown label: {general_da_label}")


for split_name in ['train', 'validation', 'test']:
    split_labels = [sample['general_da'] for sample in dataset[split_name]]
    content_count, non_content_count, content_pct, non_content_pct = calculate_content_distribution(split_labels)
    
    print(f"{split_name.capitalize()} split:")
    print(f"  Content: {content_count:,} samples ({content_pct:.1f}%)")
    print(f"  Non-content: {non_content_count:,} samples ({non_content_pct:.1f}%)")

Train split:
  Content: 54,123 samples (72.1%)
  Non-content: 20,944 samples (27.9%)
Validation split:
  Content: 11,715 samples (71.3%)
  Non-content: 4,718 samples (28.7%)
Test split:
  Content: 11,848 samples (70.9%)
  Non-content: 4,854 samples (29.1%)


In [5]:
model_name = "distilbert-base-uncased"
num_labels = 12

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type="single_label_classification"
)

model = model.to(device)

# Test tokenization
sample_text = "okay so um i was going to try to get out of here"
encoded = tokenizer(sample_text, return_tensors="pt", padding=True, truncation=True)
print(f"\nSample text: '{sample_text}'")
print(f"Tokenized shape: {encoded['input_ids'].shape}")
print(f"Tokens: {encoded['input_ids'][0].tolist()}\n")


encoded = encoded.to(device)
with torch.no_grad():
    outputs = model(**encoded)
    logits = outputs.logits
    predictions = torch.softmax(logits, dim=-1)
    
print(f"Logits shape: {logits.shape}")
print(f"Predictions: {predictions.tolist()[0]}")

del encoded, outputs, logits, predictions

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Sample text: 'okay so um i was going to try to get out of here'
Tokenized shape: torch.Size([1, 15])
Tokens: [101, 3100, 2061, 8529, 1045, 2001, 2183, 2000, 3046, 2000, 2131, 2041, 1997, 2182, 102]

Logits shape: torch.Size([1, 12])
Predictions: [0.08611761033535004, 0.07528205960988998, 0.09400619566440582, 0.07494442164897919, 0.07367940247058868, 0.08856073021888733, 0.08872058242559433, 0.0765603557229042, 0.09287962317466736, 0.08037149161100388, 0.07952561974525452, 0.08935189247131348]


In [6]:
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

print("Label mapping:")
for label, idx in label2id.items():
    print(f"  {label} -> {idx}")

def preprocess_function(examples):
    """Tokenize text and encode labels"""
    tokens = tokenizer(
        examples["text"],
        truncation=True,
        padding=False,  # Will pad later in DataCollator
        max_length=128,  # Actual max length is 96
        return_tensors=None
    )
    tokens["labels"] = [label2id[label] for label in examples["general_da"]]
    return tokens

# Apply preprocessing to all splits
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing"
)

# Verify no data loss
if len(dataset["train"]) != len(tokenized_datasets["train"]):
    print(f"Not all samples were processed\n")
    print(f"Original train size: {len(dataset['train']):,}")
    print(f"Processed train size: {len(tokenized_datasets['train']):,}")
if len(dataset["validation"]) != len(tokenized_datasets["validation"]):
    print(f"Not all samples were processed\n")
    print(f"Original validation size: {len(dataset['validation']):,}")
    print(f"Processed validation size: {len(tokenized_datasets['validation']):,}")
if len(dataset["test"]) != len(tokenized_datasets["test"]):
    print(f"Not all samples were processed\n")
    print(f"Original test size: {len(dataset['test']):,}")
    print(f"Processed test size: {len(tokenized_datasets['test']):,}")

# Check max length
print("\nMax length:", max(max(len(s["input_ids"]) for s in tokenized_datasets[split]) for split in ["train","validation","test"]))

# Create data collator for batching
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Label mapping:
  % -> 0
  b -> 1
  fg -> 2
  fh -> 3
  h -> 4
  qh -> 5
  qo -> 6
  qr -> 7
  qrr -> 8
  qw -> 9
  qy -> 10
  s -> 11


Tokenizing: 100%|██████████| 75067/75067 [00:00<00:00, 88378.53 examples/s]
Tokenizing: 100%|██████████| 16702/16702 [00:00<00:00, 65932.55 examples/s]
Tokenizing: 100%|██████████| 16433/16433 [00:00<00:00, 131969.18 examples/s]



Max length: 96


In [7]:
# 1. Calculate class weights (sklearn balanced method)
train_labels = [sample["labels"] for sample in tokenized_datasets["train"]]
unique_classes_in_subset = sorted(set(train_labels))

# Compute weights only for classes present in subset
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array(unique_classes_in_subset),
    y=np.array(train_labels)
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# max weight = 50, min weight = 0.5
class_weights_tensor[class_weights_tensor > 50] = 50
class_weights_tensor[class_weights_tensor < 0.1] = 0.1

print(f"Class weights (balanced method):")
for class_id in unique_classes_in_subset:
    weight = class_weights_tensor[class_id].item()
    label_name = id2label[class_id]
    count = train_labels.count(class_id)
    print(f"  {class_id} ({label_name}): {weight:.3f} (count: {count})")

Class weights (balanced method):
  0 (%): 2.881 (count: 2171)
  1 (b): 0.590 (count: 10606)
  2 (fg): 3.013 (count: 2076)
  3 (fh): 1.114 (count: 5617)
  4 (h): 13.197 (count: 474)
  5 (qh): 24.060 (count: 260)
  6 (qo): 50.000 (count: 116)
  7 (qr): 47.753 (count: 131)
  8 (qrr): 25.638 (count: 244)
  9 (qw): 5.636 (count: 1110)
  10 (qy): 1.890 (count: 3310)
  11 (s): 0.128 (count: 48952)


In [8]:
# Focal loss with balanced cross entropy
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Custom Weighted Trainer
class AdvancedTrainer(Trainer):
    def __init__(self, loss_type="focal", class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type
        self.class_weights = class_weights
        
        if loss_type == "focal":
            self.loss_fn = FocalLoss(alpha=class_weights, gamma=2.0)
        elif loss_type == "weighted_ce":
            self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.loss_fn = torch.nn.CrossEntropyLoss()
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        
        loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Metrics
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    macro_f1 = f1_score(labels, predictions, average='macro')
    micro_f1 = f1_score(labels, predictions, average='micro')
    weighted_f1 = f1_score(labels, predictions, average='weighted')
    accuracy = accuracy_score(labels, predictions)
    
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "weighted_f1": weighted_f1
    }


In [9]:
# TODO the trainer isn't using stratified sampling
# 4. Stratified Batch Sampling
def create_stratified_sampler(dataset):
    # Get labels from dataset
    labels = [sample["labels"] for sample in dataset]
    
    # Calculate sample weights for stratified sampling - handle non-contiguous labels
    unique_labels_in_data = sorted(set(labels))
    class_sample_count = {label: labels.count(label) for label in unique_labels_in_data}
    
    # Create weight mapping for each class
    weight_mapping = {}
    for label, count in class_sample_count.items():
        weight_mapping[label] = 1.0 / count
    
    # Assign weight to each sample
    samples_weight = torch.tensor([weight_mapping[label] for label in labels])
    print(weight_mapping)
    # Create sampler
    sampler = WeightedRandomSampler(
        weights=samples_weight,
        num_samples=len(samples_weight),
        replacement=True
    )
    return sampler

# Create stratified sampler
train_sampler = create_stratified_sampler(tokenized_datasets["train"])

{0: 0.00046061722708429296, 1: 9.428625306430322e-05, 2: 0.0004816955684007707, 3: 0.00017803097739006588, 4: 0.002109704641350211, 5: 0.0038461538461538464, 6: 0.008620689655172414, 7: 0.007633587786259542, 8: 0.004098360655737705, 9: 0.0009009009009009009, 10: 0.00030211480362537764, 11: 2.0428174538323256e-05}


### Parameters tuning
**LoRA**
- lora_dropout – Regular dropout regularization, prevents overfitting. (0.0; 0.3)
- lora_r – The size of LoRA adapters, means how much new information the model can learn. High values cause better learning capacity, especially usefull for imbalanced data  (8; 64)
- lora_alpha – How strong the LoRA adapters influence the original model. Affects effective scaling (`lora_alpha/lora_r`) (scale between 0.5 and 8)

**Traning arguments**

NOTE: Step is calculated by `ceil(total_samples / batch_size)`

Proper Warmup Guidelines:
- Simple tasks: 5-10% of total steps
- Complex tasks: 10-15% of total steps
- Imbalanced/Difficult: 15-20% of total steps


In [None]:
### LoRA setup ###
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    target_modules=["q_lin", "v_lin", "k_lin", "out_lin"],
    ### TUNABLE PARAMETERS ###
    lora_dropout=0.15,
    r=32,
    lora_alpha=64,
)
advanced_peft_model = get_peft_model(model, lora_config)

### Training setup ###
no_epochs = 10
batch_size = 24
total_steps = no_epochs * np.ceil(len(tokenized_datasets["train"]) / batch_size)
advanced_training_args = TrainingArguments(
    output_dir="./advanced_checkpoints",
    save_strategy="steps",
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="macro_f1",  # Optimize for macro F1, perfect for imbalanced data
    greater_is_better=True,
    report_to=None,
    remove_unused_columns=False, # Required for custom trainer
    dataloader_num_workers=0,  # Important for MPS compatibility
    ### TUNABLE PARAMETERS ###
    learning_rate=0.0002,
    num_train_epochs=no_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=int(0.15 * total_steps), # Prevents overfitting, should be ~15% of total steps for imbalanced data
    save_steps=total_steps // no_epochs // 2, # 2 times per epoch
    eval_steps=total_steps // no_epochs // 8, # 8 times per epoch
    logging_steps=total_steps // no_epochs // 10, # 10 times per epoch
)
advanced_trainer = AdvancedTrainer(
    loss_type="focal",  # Focal loss for imbalanced data
    class_weights=class_weights_tensor, # Weights for imbalanced data
    model=advanced_peft_model,
    args=advanced_training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print(f"Total steps: {total_steps}")
print(f"Steps per epoch: {total_steps / no_epochs}")
print(f"Steps per validation: {total_steps / no_epochs / 8}")
print(f"Steps per logging: {total_steps / no_epochs / 10}")
print(f"Steps per saving: {total_steps / no_epochs / 2}")

In [11]:
print(f"\nStarting advanced training with:")
print(f"  Loss: Focal Loss (gamma=2.0, class weights)")
print(f"  Evaluation: Macro F1, Micro F1, Weighted F1, Accuracy")
print(f"  Model: DistilBERT + LoRA + Advanced Loss")

# Run advanced training
advanced_result = advanced_trainer.train()

print(f"\nAdvanced training completed!")
print(f"Final train loss: {advanced_result.training_loss:.4f}")


Starting advanced training with:
  Loss: Focal Loss (gamma=2.0, class weights)
  Evaluation: Macro F1, Micro F1, Weighted F1, Accuracy
  Model: DistilBERT + LoRA + Advanced Loss


Step,Training Loss,Validation Loss,Accuracy,Macro F1,Micro F1,Weighted F1
391,2.1071,2.326983,0.231364,0.107099,0.231364,0.126118
782,1.9855,1.76758,0.268302,0.245786,0.268302,0.202222
1173,1.507,1.398862,0.316619,0.300849,0.316619,0.26635
1564,1.1645,1.368319,0.41228,0.340447,0.41228,0.431602
1955,1.2208,1.164877,0.421895,0.381829,0.421895,0.427245
2346,1.1261,1.011669,0.480071,0.390435,0.480071,0.516566
2737,0.9763,1.144527,0.54135,0.412392,0.54135,0.565047
3128,1.0025,1.027956,0.467839,0.386774,0.467839,0.508069
3519,1.0507,1.22901,0.485243,0.397888,0.485243,0.504832
3910,1.072,1.324551,0.521755,0.389648,0.521755,0.541535



Advanced training completed!
Final train loss: 0.7731


In [None]:
### IMPROVED TRAINING CONFIGURATION ###
print("=== IMPROVED CONFIGURATION ===")

# 1. More Conservative Class Weights (cap at 10.0, minimum 0.2)
improved_class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array(unique_classes_in_subset),
    y=np.array(train_labels)
)
improved_class_weights_tensor = torch.tensor(improved_class_weights, dtype=torch.float).to(device)

# Conservative capping - less extreme
improved_class_weights_tensor[improved_class_weights_tensor > 10] = 10.0  # Reduced from 50
improved_class_weights_tensor[improved_class_weights_tensor < 0.2] = 0.2   # Increased from 0.1

print(f"Improved class weights (conservative):")
for class_id in unique_classes_in_subset:
    weight = improved_class_weights_tensor[class_id].item()
    label_name = id2label[class_id]
    count = train_labels.count(class_id)
    print(f"  {class_id} ({label_name}): {weight:.3f} (count: {count})")

# 2. Improved LoRA Configuration
improved_lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    target_modules=["q_lin", "v_lin", "k_lin", "out_lin"],
    # More conservative LoRA settings
    lora_dropout=0.2,    # Increased dropout for regularization
    r=16,                # Reduced rank - less parameters to overfit
    lora_alpha=32,       # Reduced influence (alpha/r = 2.0)
)

# 3. Fresh model for improved training
improved_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type="single_label_classification"
).to(device)

improved_peft_model = get_peft_model(improved_model, improved_lora_config)

# 4. Improved Training Arguments with Regularization
no_epochs_improved = 6
batch_size_improved = 32
total_steps_improved = no_epochs_improved * np.ceil(len(tokenized_datasets["train"]) / batch_size_improved)

improved_training_args = TrainingArguments(
    output_dir="./improved_checkpoints",
    save_strategy="steps",
    eval_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="weighted_f1",  # Better for imbalanced data than macro_f1
    greater_is_better=True,
    report_to=None,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    
    # REGULARIZATION IMPROVEMENTS
    learning_rate=0.0001,              # Reduced learning rate
    num_train_epochs=no_epochs_improved,
    per_device_train_batch_size=batch_size_improved,
    per_device_eval_batch_size=batch_size_improved,
    weight_decay=0.01,                 # Added weight decay for regularization
    warmup_steps=int(0.1 * total_steps_improved), # 10% warmup
    
    # More frequent evaluation to catch overfitting early
    save_steps=int(total_steps_improved // no_epochs_improved // 4), # 4 times per epoch
    eval_steps=int(total_steps_improved // no_epochs_improved // 8), # 8 times per epoch
    logging_steps=int(total_steps_improved // no_epochs_improved // 16), # 16 times per epoch
)

# 5. Improved Trainer with Weighted CrossEntropy (instead of Focal Loss)
improved_trainer = AdvancedTrainer(
    loss_type="weighted_ce",  # Simpler, more stable than focal loss
    class_weights=improved_class_weights_tensor,
    model=improved_peft_model,
    args=improved_training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print(f"\nImproved configuration ready:")
print(f"  Loss: Weighted CrossEntropy (more stable)")
print(f"  Class weights: Conservative (max 10.0, min 0.2)")
print(f"  LoRA: Reduced rank (r=16) with higher dropout (0.2)")
print(f"  Regularization: Weight decay 0.01, reduced LR")
print(f"  Epochs: {no_epochs_improved} (instead of 10) to prevent overfitting")
print(f"  Batch size: {batch_size_improved} (larger for stability)")
print(f"  Total steps: {int(total_steps_improved)}")


In [None]:
# Optional: Test the improved configuration
# Uncomment and run this cell to train with improved settings

# print(f"\nStarting IMPROVED training with:")
# print(f"  Loss: Weighted CrossEntropy (stable)")
# print(f"  Class weights: Conservative (0.2-10.0 range)")
# print(f"  Regularization: Weight decay + higher dropout")
# print(f"  Model: DistilBERT + Conservative LoRA")

# # Run improved training
# improved_result = improved_trainer.train()

# print(f"\nImproved training completed!")
# print(f"Final train loss: {improved_result.training_loss:.4f}")

# # Evaluate improved model
# improved_eval = improved_trainer.evaluate()
# print(f"\nImproved evaluation results:")
# print(f"  Accuracy: {improved_eval['eval_accuracy']:.4f}")
# print(f"  Macro F1: {improved_eval['eval_macro_f1']:.4f}")
# print(f"  Weighted F1: {improved_eval['eval_weighted_f1']:.4f}")

# # Check if class distribution is more balanced
# improved_predictions = improved_trainer.predict(tokenized_datasets["validation"])
# improved_predicted_classes = np.argmax(improved_predictions.predictions, axis=1)
# improved_predicted_counts = pd.Series(improved_predicted_classes).value_counts().sort_index()

# print(f"\nImproved prediction distribution:")
# for label_id, count in improved_predicted_counts.items():
#     label_name = id2label[label_id]
#     percentage = (count / len(improved_predicted_classes)) * 100
#     print(f"    {label_id} ({label_name}): {count} predictions ({percentage:.1f}%)")

print("Improved configuration ready. Uncomment above code to train with better settings.")


In [12]:
# Comprehensive evaluation
advanced_eval = advanced_trainer.evaluate()
print(f"\nAdvanced evaluation results:")
print(f"  Accuracy: {advanced_eval['eval_accuracy']:.4f}")
print(f"  Macro F1: {advanced_eval['eval_macro_f1']:.4f}")
print(f"  Micro F1: {advanced_eval['eval_micro_f1']:.4f}")
print(f"  Weighted F1: {advanced_eval['eval_weighted_f1']:.4f}")

# Check prediction diversity
print(f"\nPrediction diversity analysis:")
advanced_predictions = advanced_trainer.predict(tokenized_datasets["validation"])
predicted_classes = np.argmax(advanced_predictions.predictions, axis=1)
unique_predictions = len(set(predicted_classes))
predicted_counts = pd.Series(predicted_classes).value_counts().sort_index()

print(f"  Unique classes predicted: {unique_predictions}/12")
print(f"  Prediction distribution:")
for label_id, count in predicted_counts.items():
    label_name = id2label[label_id]
    percentage = (count / len(predicted_classes)) * 100
    print(f"    {label_id} ({label_name}): {count} predictions ({percentage:.1f}%)")

# Compare with baseline (Step 6 results)
print(f"\nComparison with baseline:")
print(f"  Baseline accuracy: 65.0% (predicting only 1 class)")
print(f"  Advanced accuracy: {advanced_eval['eval_accuracy']*100:.1f}%")
print(f"  Baseline macro F1: ~0.08 (random)")
print(f"  Advanced macro F1: {advanced_eval['eval_macro_f1']:.3f}")
print(f"  Prediction diversity: {unique_predictions}/12 classes vs 1/12 baseline")

if advanced_eval['eval_macro_f1'] > 0.3:
    print("✅ Advanced training SUCCESS! Macro F1 > 0.3")
elif unique_predictions > 5:
    print("✅ Good prediction diversity! Learning multiple classes")
else:
    print("⚠️  May need further tuning - try weighted_ce loss or adjust gamma")


Advanced evaluation results:
  Accuracy: 0.5292
  Macro F1: 0.5010
  Micro F1: 0.5292
  Weighted F1: 0.5673

Prediction diversity analysis:
  Unique classes predicted: 12/12
  Prediction distribution:
    0 (%): 1202 predictions (7.3%)
    1 (b): 3487 predictions (21.2%)
    2 (fg): 3216 predictions (19.6%)
    3 (fh): 2811 predictions (17.1%)
    4 (h): 281 predictions (1.7%)
    5 (qh): 99 predictions (0.6%)
    6 (qo): 22 predictions (0.1%)
    7 (qr): 69 predictions (0.4%)
    8 (qrr): 92 predictions (0.6%)
    9 (qw): 263 predictions (1.6%)
    10 (qy): 874 predictions (5.3%)
    11 (s): 4017 predictions (24.4%)

Comparison with baseline:
  Baseline accuracy: 65.0% (predicting only 1 class)
  Advanced accuracy: 52.9%
  Baseline macro F1: ~0.08 (random)
  Advanced macro F1: 0.501
  Prediction diversity: 12/12 classes vs 1/12 baseline
✅ Advanced training SUCCESS! Macro F1 > 0.3
