In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from transformers import (
    AutoTokenizer, 
    AutoModel, 
    Trainer, 
    TrainingArguments,
    EarlyStoppingCallback,
    PreTrainedModel,
    PretrainedConfig
)
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch (CPU and, when present, every CUDA device) for reproducibility.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Choose the compute device once; the CUDA generators are only seeded when CUDA is usable.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# ---- Encoder / classification-head hyper-parameters ----
MODEL_NAME = "answerdotai/ModernBERT-base"
MAX_LENGTH = 256  # token budget; sized to fit the previous message plus the current query
HIDDEN_DIM = 256
DROPOUT = 0.3
STAGE_EMBED_DIM = 32  # width of the interview-stage embedding

# ---- Optimisation hyper-parameters ----
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 5e-5
NUM_EPOCHS = 12
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
PATIENCE = 3  # early-stopping patience

# ---- Layer 2 intent classes: list position is the integer class id ----
_INTENT_NAMES = [
    "answer_submission",
    "clarification_request",
    "process_inquiry",
    "challenge_assessment",
    "off_topic",
    "small_talk",
]
id2label = dict(enumerate(_INTENT_NAMES))
label2id = {name: class_id for class_id, name in id2label.items()}
NUM_LABELS = len(id2label)

# ---- Interview stages (4): list position is the integer stage id ----
_STAGE_NAMES = ["opening", "technical_depth", "challenge", "closing"]
stage2id = {name: stage_id for stage_id, name in enumerate(_STAGE_NAMES)}
id2stage = {stage_id: name for name, stage_id in stage2id.items()}
NUM_STAGES = len(stage2id)

print(f"Classes: {NUM_LABELS} intents, {NUM_STAGES} stages")

# Keyword arguments later unpacked into the model config.
MODEL_CONFIG = dict(
    model_name=MODEL_NAME,
    num_labels=NUM_LABELS,
    num_stages=NUM_STAGES,
    stage_embed_dim=STAGE_EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
)

In [None]:
# CSV columns: user_query, prev_agent_response, interview_stage, label
df = pd.read_csv('layer2_contextual_data.csv')

# Clean labels.
# errors='coerce' turns unparseable labels into NaN, which dropna then removes.
df['label'] = pd.to_numeric(df['label'], errors='coerce')
df = df.dropna(subset=['label'])
# Keep whole-number labels only: without this check a value such as 2.7 would be
# truncated to 2 by astype(int) and silently counted as a valid class
# (non-finite values also fail the test and are dropped).
df = df[df['label'] % 1 == 0].copy()
df['label'] = df['label'].astype(int)
# Valid class ids come from id2label so this filter cannot drift from the label map.
df = df[df['label'].isin(list(id2label.keys()))].copy()

# Validate stages: keep only rows whose stage name is a key of stage2id.
df = df[df['interview_stage'].isin(list(stage2id.keys()))].copy()

print(f"Loaded {len(df)} samples")
print(f"\nLabel distribution:")
print(df['label'].value_counts().sort_index())
print(f"\nStage distribution:")
print(df['interview_stage'].value_counts())
df.head(5)

In [None]:
# Step 1: hold out 30% of the rows (stratified by intent label); the other 70% is train.
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df['label'],
    random_state=RANDOM_SEED,
)

# Step 2: cut the hold-out in half (again stratified) -> 15% validation, 15% test.
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df['label'],
    random_state=RANDOM_SEED,
)

# Give every split a fresh 0..n-1 index.
train_df, val_df, test_df = (
    part.reset_index(drop=True) for part in (train_df, val_df, test_df)
)

# Report absolute size and share of the cleaned dataset for each split.
for _split_caption, _split_frame in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    print(f"{_split_caption} size: {len(_split_frame)} ({len(_split_frame)/len(df)*100:.1f}%)")

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

splits = [
    ('Train', train_df),
    ('Validation', val_df),
    ('Test', test_df)
]

colors_intent = ['#ff6b6b', '#4ecdc4', '#45b7d1', '#96ceb4', '#ffeaa7', '#dfe6e9']
colors_stage = ['#a29bfe', '#fd79a8', '#00b894', '#e17055']
class_names = [id2label[i] for i in range(NUM_LABELS)]
stage_names = list(stage2id.keys())


def _draw_count_bars(ax, counts, n_total, tick_labels, colors, title):
    """Draw one annotated bar chart of category counts on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        counts: pandas Series with one count per category, already in display order.
        n_total: Number of rows in the split; denominator for the percentage labels.
        tick_labels: Category names for the x ticks (same order as ``counts``).
        colors: Bar colors.
        title: Axes title.
    """
    percentages = counts / n_total * 100
    positions = range(len(counts))

    bars = ax.bar(positions, counts.values, color=colors)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel('Count')
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=9)

    # Annotate every bar with its count and share of the split.
    for bar, count, pct in zip(bars, counts.values, percentages.values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                f'{count:,}    ({pct:.1f}%)', ha='center', va='bottom', fontsize=8)


for idx, (split_name, split_df) in enumerate(splits):
    n_rows = len(split_df)

    # Row 1: intent distribution.
    # reindex(..., fill_value=0) guarantees one bar per class even when a class is
    # absent from the split; a bare value_counts() would yield fewer values than
    # bar positions and make ax.bar() fail.
    intent_counts = split_df['label'].value_counts().reindex(
        list(range(NUM_LABELS)), fill_value=0
    )
    _draw_count_bars(
        axes[0, idx], intent_counts, n_rows, class_names, colors_intent,
        f'{split_name} - Intent Distribution (n={n_rows:,})'
    )

    # Row 2: stage distribution, ordered like stage2id; a missing stage shows as 0, not NaN.
    stage_split_counts = split_df['interview_stage'].value_counts().reindex(
        stage_names, fill_value=0
    )
    _draw_count_bars(
        axes[1, idx], stage_split_counts, n_rows, stage_names, colors_stage,
        f'{split_name} - Stage Distribution'
    )

plt.suptitle(f'Data Distribution Across Splits - Total: {len(df):,} samples',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Using HuggingFace Dataset format

In [None]:
class ContextAwareLayer2Config(PretrainedConfig):
    """Hyper-parameter container for ContextAwareLayer2Classifier.

    Holds the encoder checkpoint name, the number of intent labels and interview
    stages, the stage-embedding width, and the hidden size / dropout of the
    classification head. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """
    model_type = "context_aware_layer2"

    def __init__(
        self,
        model_name: str = "answerdotai/ModernBERT-base",
        num_labels: int = 6,
        num_stages: int = 4,
        stage_embed_dim: int = 32,
        hidden_dim: int = 256,
        dropout: float = 0.3,
        **kwargs
    ):
        # Base-class initialisation runs first; the custom fields are set afterwards.
        super().__init__(**kwargs)
        custom_fields = (
            ("model_name", model_name),
            ("num_labels", num_labels),
            ("num_stages", num_stages),
            ("stage_embed_dim", stage_embed_dim),
            ("hidden_dim", hidden_dim),
            ("dropout", dropout),
        )
        for field_name, field_value in custom_fields:
            setattr(self, field_name, field_value)


class ContextAwareLayer2Classifier(PreTrainedModel):
    """
    Context-aware intent classifier that combines:
    - Text encoding (encoder loaded from ``config.model_name``; the notebook feeds it
      "[prev_msg] [SEP] [current_query]" text)
    - Interview stage embedding

    The encoder's parameters are frozen in ``__init__``, so only the stage
    embedding and the classification head receive gradients.

    Compatible with HuggingFace Trainer.
    """
    config_class = ContextAwareLayer2Config

    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        # Pretrained text encoder (ModernBERT by default).
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_hidden_size = self.bert.config.hidden_size  # 768 for the default checkpoint

        # Freeze the encoder: it is used as a fixed feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = False

        # Learned embedding for the interview stage id.
        self.stage_embedding = nn.Embedding(
            num_embeddings=config.num_stages,
            embedding_dim=config.stage_embed_dim
        )

        # Encoder hidden size + stage embedding width (768 + 32 = 800 with the defaults).
        combined_dim = self.bert_hidden_size + config.stage_embed_dim

        # Classification head: Linear -> ReLU -> Dropout -> Linear.
        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.num_labels)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        interview_stage: torch.Tensor,
        labels: torch.Tensor = None,
        **kwargs  # Accept other kwargs from Trainer
    ):
        """
        Forward pass compatible with HuggingFace Trainer.

        Args:
            input_ids: Tokenized text [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            interview_stage: Stage indices [batch_size]
            labels: Ground truth labels [batch_size] (optional)

        Returns:
            dict with 'logits', plus 'loss' (cross-entropy) only when labels are
            provided. The 'loss' key is omitted, not set to None, when there are
            no labels, so a label-less Trainer prediction does not pick up a
            ``None`` entry alongside the logits.
        """
        # 1. Encode text through the encoder.
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the first-token ([CLS]) hidden state as the text representation.
        text_embedding = bert_output.last_hidden_state[:, 0, :]  # [batch_size, hidden]

        # 2. Embed stage.
        stage_emb = self.stage_embedding(interview_stage)  # [batch_size, stage_embed_dim]

        # 3. Concatenate text and stage features.
        combined = torch.cat([text_embedding, stage_emb], dim=-1)

        # 4. Classification.
        logits = self.classifier(combined)  # [batch_size, num_labels]

        # 5. Without labels there is no loss to report.
        if labels is None:
            return {'logits': logits}

        loss_fct = nn.CrossEntropyLoss()
        return {
            'loss': loss_fct(logits, labels),
            'logits': logits
        }

    def count_parameters(self):
        """Return ``(trainable, frozen)`` parameter counts."""
        trainable = 0
        frozen = 0
        for param in self.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                frozen += param.numel()
        return trainable, frozen

In [None]:
# Tokenizer is loaded from the same checkpoint name the model config points to.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Build the classifier from the notebook-level hyper-parameters.
config = ContextAwareLayer2Config(**MODEL_CONFIG)
model = ContextAwareLayer2Classifier(config)

# Report how many parameters will be trained versus kept frozen.
param_counts = model.count_parameters()
trainable, frozen = param_counts
print(f"Parameters: {trainable:,} trainable, {frozen:,} frozen")

In [None]:
def prepare_dataset(df, sep_token=None, stage_map=None):
    """Build the model-ready frame (text, stage id, label) without touching ``df``.

    The previous version wrote ``combined_text`` and the integer-encoded
    ``interview_stage`` back into the caller's DataFrame. That destroyed the stage
    names that the stage-stratified evaluation later reads from the same frame,
    and made a second run of the cell map ids to NaN.

    Args:
        df: DataFrame with ``user_query``, ``prev_agent_response``,
            ``interview_stage`` (stage names) and ``label`` columns.
        sep_token: Separator placed between the previous response and the query.
            Defaults to the notebook tokenizer's ``sep_token``.
        stage_map: Mapping from stage name to integer id. Defaults to the
            notebook-level ``stage2id``.

    Returns:
        A new DataFrame with columns ``combined_text``, ``interview_stage``
        (integer ids) and ``label``, sharing ``df``'s index.

    Raises:
        ValueError: If a row's stage is not a key of ``stage_map``.
    """
    if sep_token is None:
        sep_token = tokenizer.sep_token
    if stage_map is None:
        stage_map = stage2id

    # A missing previous response becomes an empty string; the query is stringified as-is.
    prev_text = df['prev_agent_response'].fillna('').astype(str)
    query_text = df['user_query'].astype(str)
    combined_text = prev_text + f" {sep_token} " + query_text

    stage_ids = df['interview_stage'].map(stage_map)
    unmapped = stage_ids.isna()
    if unmapped.any():
        unknown = sorted({str(value) for value in df.loc[unmapped, 'interview_stage']})
        raise ValueError(f"Unknown interview_stage value(s): {unknown}")

    return pd.DataFrame({
        'combined_text': combined_text,
        'interview_stage': stage_ids.astype(int),
        'label': df['label'],
    })

# Turn each split into its (text, stage id, label) frame, in train / val / test order.
train_prepared, val_prepared, test_prepared = (
    prepare_dataset(split_frame) for split_frame in (train_df, val_df, test_df)
)

# Wrap the prepared frames as HuggingFace Datasets (index reset before conversion).
train_dataset, val_dataset, test_dataset = (
    Dataset.from_pandas(prepared_frame.reset_index(drop=True))
    for prepared_frame in (train_prepared, val_prepared, test_prepared)
)

def tokenize_function(examples):
    """Tokenize a batch of ``combined_text`` and carry over stage ids and labels.

    Text is truncated and padded to ``MAX_LENGTH``. The batch's ``interview_stage``
    values are copied through unchanged and its ``label`` values are stored under
    the ``labels`` key.
    """
    encoded = tokenizer(
        examples['combined_text'],
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
    )
    # (output key, input column) pairs to copy next to the token fields.
    for output_key, source_column in (('interview_stage', 'interview_stage'), ('labels', 'label')):
        encoded[output_key] = examples[source_column]
    return encoded

# Columns replaced by the tokenizer output, and columns exposed as torch tensors.
_raw_columns = ['combined_text', 'label']
_tensor_columns = ['input_ids', 'attention_mask', 'interview_stage', 'labels']

# Tokenize each split in batches.
train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=list(_raw_columns))
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=list(_raw_columns))
test_dataset = test_dataset.map(tokenize_function, batched=True, remove_columns=list(_raw_columns))

# set_format works in place, so a loop over the three datasets is enough.
for _tokenized_split in (train_dataset, val_dataset, test_dataset):
    _tokenized_split.set_format(type='torch', columns=list(_tensor_columns))

print(f"Datasets ready: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")

In [None]:
def compute_metrics_for_trainer(eval_pred):
    """
    Compute metrics for HuggingFace Trainer.

    Args:
        eval_pred: Anything carrying logits and labels: an ``EvalPrediction``, the
            ``PredictionOutput`` returned by ``Trainer.predict`` (three fields, so
            plain two-name unpacking would fail), or a ``(logits, labels)`` tuple.

    Returns:
        Dictionary of metrics: accuracy, macro precision/recall/F1, macro AUC,
        confidence statistics, and per-class F1/precision/recall/AUC keyed by the
        names in ``id2label``.
    """
    import scipy.special

    # Prefer the named fields; fall back to positional access for plain tuples.
    if hasattr(eval_pred, 'predictions') and hasattr(eval_pred, 'label_ids'):
        predictions, labels = eval_pred.predictions, eval_pred.label_ids
    else:
        predictions, labels = eval_pred[0], eval_pred[1]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # predictions are logits, convert to probabilities and predicted classes
    probs = scipy.special.softmax(predictions, axis=-1)
    preds = np.argmax(predictions, axis=-1)

    confidence = np.max(probs, axis=-1)

    # Per-class metrics (one entry per class id, including classes absent from this batch)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, preds, average=None, labels=list(range(NUM_LABELS)), zero_division=0
    )

    # Macro averages
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )

    # Per-class one-vs-rest AUC. It is undefined when a class has no positives or
    # no negatives; those classes keep the 0.0 placeholder in the per-class
    # output but are left out of the macro average so they do not drag it down.
    per_class_auc = []
    defined_aucs = []
    for i in range(NUM_LABELS):
        binary_labels = (labels == i).astype(int)
        n_positive = binary_labels.sum()
        if 0 < n_positive < len(binary_labels):
            auc = roc_auc_score(binary_labels, probs[:, i])
            defined_aucs.append(auc)
        else:
            auc = 0.0
        per_class_auc.append(auc)
    macro_auc = float(np.mean(defined_aucs)) if defined_aucs else 0.0

    # Confidence stats
    correct_mask = preds == labels

    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "macro_f1": macro_f1,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_auc": macro_auc,
        "confidence_mean": confidence.mean(),
        "confidence_correct": confidence[correct_mask].mean() if correct_mask.any() else 0.0,
        "confidence_wrong": confidence[~correct_mask].mean() if (~correct_mask).any() else 0.0,
    }

    for i, label_name in id2label.items():
        metrics[f"f1_{label_name}"] = f1[i]
        metrics[f"precision_{label_name}"] = precision[i]
        metrics[f"recall_{label_name}"] = recall[i]
        metrics[f"auc_{label_name}"] = per_class_auc[i]

    return metrics

In [None]:
def evaluate_with_stages(model, dataset, df_original, batch_size=64):
    """
    Evaluate model with stage-stratified metrics.

    Args:
        model: Trained model
        dataset: HuggingFace Dataset
        df_original: DataFrame row-aligned with ``dataset`` whose
            ``interview_stage`` column holds either stage names or the integer
            ids from ``stage2id``
        batch_size: Evaluation batch size

    Returns:
        metrics dict, labels, predictions, probabilities

    Raises:
        ValueError: If ``df_original`` and the predictions differ in length.
    """
    import scipy.special

    model.eval()

    # Create temporary trainer for evaluation
    temp_trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir='./temp',
            per_device_eval_batch_size=batch_size,
            report_to=[]
        )
    )

    # Get predictions
    predictions = temp_trainer.predict(dataset)
    logits = predictions.predictions
    labels = predictions.label_ids

    # Convert to probabilities
    probs = scipy.special.softmax(logits, axis=-1)
    preds = np.argmax(logits, axis=-1)

    # Compute base metrics. Pass a plain (logits, labels) pair: the PredictionOutput
    # from predict() has three fields and cannot be unpacked into two names.
    metrics = compute_metrics_for_trainer((logits, labels))

    # Add stage-stratified metrics
    stage_column = df_original['interview_stage']
    if len(stage_column) != len(labels):
        raise ValueError(
            f"df_original has {len(stage_column)} rows but the dataset produced "
            f"{len(labels)} predictions; they must be row-aligned"
        )
    # If the column was already encoded to ids, translate it back to names so the
    # comparisons below still match.
    if pd.api.types.is_numeric_dtype(stage_column):
        stage_column = stage_column.map({stage_id: name for name, stage_id in stage2id.items()})
    stage_values = stage_column.to_numpy(dtype=object)

    for stage_name in stage2id.keys():
        stage_mask = stage_values == stage_name
        n_in_stage = int(stage_mask.sum())
        if n_in_stage > 0:
            metrics[f"accuracy_stage_{stage_name}"] = accuracy_score(labels[stage_mask], preds[stage_mask])
        else:
            metrics[f"accuracy_stage_{stage_name}"] = 0.0
        metrics[f"n_stage_{stage_name}"] = n_in_stage

    return metrics, labels, preds, probs

In [None]:
training_args = TrainingArguments(
    # Where checkpoints go; at most two are kept
    output_dir="./layer2_contextual_model",
    save_total_limit=2,
    # Optimisation schedule
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    # Evaluate and checkpoint once per epoch; reload the checkpoint with the highest macro F1
    eval_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="macro_f1",
    greater_is_better=True,
    load_best_model_at_end=True,
    # Logging and reproducibility
    logging_steps=50,
    report_to=[],
    seed=RANDOM_SEED,
)

# Early stopping is configured with PATIENCE evaluations of tolerance.
early_stopping = EarlyStoppingCallback(early_stopping_patience=PATIENCE)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics_for_trainer,
    callbacks=[early_stopping],
)

print(f"Training: {NUM_EPOCHS} epochs, batch_size={BATCH_SIZE}, lr={LEARNING_RATE}, early_stop={PATIENCE}")

In [None]:
train_result = trainer.train()

# Collect curves from the Trainer log: training-loss entries (those with both
# 'loss' and 'epoch') and the eval_* values of evaluation entries.
history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_macro_f1': []}
_eval_key_to_series = {
    'eval_loss': 'val_loss',
    'eval_accuracy': 'val_accuracy',
    'eval_macro_f1': 'val_macro_f1',
}
for entry in trainer.state.log_history:
    if 'loss' in entry and 'epoch' in entry:
        history['train_loss'].append(entry['loss'])
    for eval_key, series_name in _eval_key_to_series.items():
        if eval_key in entry:
            history[series_name].append(entry[eval_key])

# Highest validation macro F1 seen (0.0 when no evaluation was logged).
best_val_f1 = max(history['val_macro_f1'], default=0.0)

# Persist model weights/config and the tokenizer side by side.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)

# Store the label and stage mappings next to the model.
import json
label_mappings = {
    'id2label': id2label,
    'label2id': label2id,
    'id2stage': id2stage,
    'stage2id': stage2id
}
with open(f"{training_args.output_dir}/label_mappings.json", "w") as f:
    json.dump(label_mappings, f, indent=2)

print(f"Training complete. Best F1: {best_val_f1:.4f}. Saved to {training_args.output_dir}")

In [None]:
import json
import os

model_dir = "./layer2_contextual_model"

# Read the mappings that were saved alongside the model.
mappings_path = os.path.join(model_dir, 'label_mappings.json')
with open(mappings_path, 'r') as f:
    mappings = json.load(f)

# JSON object keys are strings, so the id -> name maps need their keys cast back to int.
id2label = {int(class_id): name for class_id, name in mappings['id2label'].items()}
label2id = mappings['label2id']
id2stage = {int(stage_id): name for stage_id, name in mappings['id2stage'].items()}
stage2id = mappings['stage2id']

# Restore the model onto the active device and switch it to inference mode.
model = ContextAwareLayer2Classifier.from_pretrained(model_dir).to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_dir)
print(f"Model loaded from {model_dir}")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))


def _epoch_axis(series):
    """Return 1-based epoch positions matching the length of ``series``."""
    return range(1, len(series) + 1)


# Validation metrics are logged once per evaluation (one per epoch), while the
# training loss is logged every `logging_steps` optimizer steps. The two series
# therefore usually differ in length, and plotting both against one shared range
# raises "x and y must have same first dimension". Each series gets its own x axis.
n_epochs = len(history['val_loss'])
epochs_range = _epoch_axis(history['val_loss'])
n_train_logs = len(history['train_loss'])
if n_epochs > 0:
    # Training-loss logs are evenly spaced in steps, so spread them evenly over the
    # trained epochs (approximate: the last log may fall slightly before the end).
    train_x = np.linspace(0, n_epochs, n_train_logs + 1)[1:]
else:
    train_x = _epoch_axis(history['train_loss'])

# Loss
axes[0].plot(train_x, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
axes[0].plot(epochs_range, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(_epoch_axis(history['val_accuracy']), history['val_accuracy'], 'g-', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Validation Accuracy', fontweight='bold')
axes[1].grid(True, alpha=0.3)

# Macro F1
axes[2].plot(_epoch_axis(history['val_macro_f1']), history['val_macro_f1'], 'purple', linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Macro F1')
axes[2].set_title('Validation Macro F1', fontweight='bold')
axes[2].axhline(y=best_val_f1, color='red', linestyle='--', label=f'Best: {best_val_f1:.4f}')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle('Training History', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Run the held-out test split through the model, including per-stage accuracy.
test_metrics, test_labels, test_preds, test_probs = evaluate_with_stages(
    model, test_dataset, test_df, batch_size=EVAL_BATCH_SIZE
)

_banner = "=" * 60
print(_banner)
print("TEST SET RESULTS")
print(_banner)

# Headline metrics: (caption, key in test_metrics).
print("\nOverall Metrics:")
for _caption, _metric_key in (
    ("Accuracy", "accuracy"),
    ("Macro F1", "macro_f1"),
    ("Macro Precision", "macro_precision"),
    ("Macro Recall", "macro_recall"),
    ("Macro AUC", "macro_auc"),
):
    print(f"  {_caption}: {test_metrics[_metric_key]:.4f}")

# Mean top-class probability overall, on correct predictions, and on wrong ones.
print("\nConfidence Statistics:")
for _caption, _metric_key in (
    ("Mean Confidence", "confidence_mean"),
    ("Confidence (Correct)", "confidence_correct"),
    ("Confidence (Wrong)", "confidence_wrong"),
):
    print(f"  {_caption}: {test_metrics[_metric_key]:.4f}")

# Accuracy and sample count per interview stage.
print("\nStratified by Stage:")
for stage_name in stage2id:
    acc, n = test_metrics[f'accuracy_stage_{stage_name}'], test_metrics[f'n_stage_{stage_name}']
    print(f"  {stage_name}: {acc:.4f} (n={n})")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

class_names = [id2label[i] for i in range(NUM_LABELS)]
x = np.arange(NUM_LABELS)
bar_width = 0.6
colors_intent = ['#ff6b6b', '#4ecdc4', '#45b7d1', '#96ceb4', '#ffeaa7', '#dfe6e9']

# Per-class scores, ordered by class id.
f1_scores = [test_metrics[f"f1_{id2label[i]}"] for i in range(NUM_LABELS)]
precision_scores = [test_metrics[f"precision_{id2label[i]}"] for i in range(NUM_LABELS)]
recall_scores = [test_metrics[f"recall_{id2label[i]}"] for i in range(NUM_LABELS)]

# One panel per metric:
# (axes, per-class scores, title, y label, macro key in test_metrics, legend caption)
_panels = [
    (axes[0], f1_scores, 'F1 Score per Class', 'F1 Score', 'macro_f1', 'Macro F1'),
    (axes[1], precision_scores, 'Precision per Class', 'Precision', 'macro_precision', 'Macro'),
    (axes[2], recall_scores, 'Recall per Class', 'Recall', 'macro_recall', 'Macro'),
]

for panel_ax, panel_scores, panel_title, panel_ylabel, macro_key, macro_caption in _panels:
    macro_value = test_metrics[macro_key]

    bars = panel_ax.bar(x, panel_scores, bar_width, color=colors_intent)
    panel_ax.set_title(panel_title, fontsize=12, fontweight='bold')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.set_xticks(x)
    panel_ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=9)
    panel_ax.set_ylim(0, 1.15)  # headroom above 1.0 for the value labels

    # Dashed reference line at the macro average.
    panel_ax.axhline(y=macro_value, color='red', linestyle='--', linewidth=2,
                     label=f'{macro_caption}: {macro_value:.3f}')
    panel_ax.legend(loc='upper right', framealpha=0.9)

    # Print each score just above its bar.
    for bar, score in zip(bars, panel_scores):
        panel_ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                      f'{score:.3f}', ha='center', fontsize=9)

plt.suptitle('Per-Class Classification Metrics on Test Set', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Accuracy by Interview Stage: gather the per-stage numbers first, then draw.
stage_names = list(stage2id.keys())
stage_accs = [test_metrics[f'accuracy_stage_{s}'] for s in stage_names]
stage_counts = [test_metrics[f'n_stage_{s}'] for s in stage_names]
colors_stage = ['#a29bfe', '#fd79a8', '#00b894', '#e17055']
_stage_positions = range(len(stage_names))
_overall_accuracy = test_metrics['accuracy']

fig, ax = plt.subplots(figsize=(10, 6))

bars = ax.bar(_stage_positions, stage_accs, color=colors_stage)
ax.set_title('Accuracy by Interview Stage', fontsize=14, fontweight='bold')
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_xlabel('Interview Stage', fontsize=12)
ax.set_xticks(_stage_positions)
ax.set_xticklabels(stage_names, fontsize=11)
ax.set_ylim(0, 1.15)

# Dashed reference line at the overall test accuracy.
ax.axhline(y=_overall_accuracy, color='red', linestyle='--', linewidth=2,
           label=f'Overall: {_overall_accuracy:.3f}')
ax.legend(loc='upper right', fontsize=11)

# Label every bar with its accuracy and the number of samples behind it.
for _position, bar in enumerate(bars):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{stage_accs[_position]:.3f}\n(n={stage_counts[_position]})', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
class_names = [id2label[i] for i in range(NUM_LABELS)]

# Pin the label set so the matrix is always NUM_LABELS x NUM_LABELS and lines up
# with class_names, even if some class never appears in the labels or predictions.
cm = confusion_matrix(test_labels, test_preds, labels=list(range(NUM_LABELS)))

# Row-normalise (share of each actual class). Rows with no samples stay at 0
# instead of producing NaN from a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0
)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Raw counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=class_names, yticklabels=class_names)
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')
axes[0].set_title('Confusion Matrix (Counts)', fontweight='bold')

# Normalized (percentages)
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues', ax=axes[1],
            xticklabels=class_names, yticklabels=class_names)
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('Actual')
axes[1].set_title('Confusion Matrix (Normalized)', fontweight='bold')

plt.suptitle('Test Set Confusion Matrix', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
def predict(
    model,
    tokenizer,
    user_query: str,
    prev_agent_response: str = "",
    interview_stage: str = "opening",
    confidence_threshold: float = 0.5
):
    """
    Make a prediction with context.

    Args:
        model: Trained model
        tokenizer: Tokenizer
        user_query: Current user query
        prev_agent_response: Previous agent/interviewer message. ``None`` or NaN is
            treated as an empty string, matching the training-time preprocessing.
        interview_stage: One of the keys of ``stage2id``
            ('opening', 'technical_depth', 'challenge', 'closing')
        confidence_threshold: Threshold for flagging low confidence

    Returns:
        Dictionary with 'prediction', 'prediction_id', 'confidence', 'status'
        and 'all_probs' (probability per intent name)

    Raises:
        KeyError: If ``interview_stage`` is not a known stage name.
    """
    model.eval()

    if interview_stage not in stage2id:
        raise KeyError(
            f"Unknown interview_stage {interview_stage!r}; expected one of {list(stage2id)}"
        )

    # A missing previous response (None or float NaN) must not leak into the text as
    # "None"/"nan"; training replaces it with an empty string.
    prev_is_missing = prev_agent_response is None or (
        isinstance(prev_agent_response, float) and prev_agent_response != prev_agent_response
    )
    prev_text = "" if prev_is_missing else prev_agent_response

    combined_text = f"{prev_text} {tokenizer.sep_token} {user_query}"
    encoding = tokenizer(
        combined_text,
        truncation=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        return_tensors='pt'
    )

    # Put the inputs on whatever device the model's parameters live on, rather than
    # relying on a notebook-level `device` that may not match the model passed in.
    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        model_device = torch.device("cpu")

    input_ids = encoding['input_ids'].to(model_device)
    attention_mask = encoding['attention_mask'].to(model_device)
    stage_tensor = torch.tensor([stage2id[interview_stage]], dtype=torch.long).to(model_device)

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            interview_stage=stage_tensor
        )
        logits = outputs['logits']
        probs = torch.softmax(logits, dim=-1)[0]

    pred_class = probs.argmax().item()
    confidence = probs[pred_class].item()
    pred_label = id2label[pred_class]

    if confidence >= confidence_threshold:
        status = f"ACCEPTED - {pred_label}"
    else:
        status = f"FLAGGED (low confidence) - Predicted: {pred_label}"

    return {
        'prediction': pred_label,
        'prediction_id': pred_class,
        'confidence': confidence,
        'status': status,
        'all_probs': {id2label[i]: probs[i].item() for i in range(NUM_LABELS)}
    }

In [None]:
# Example: an office-perks question asked right after a technical prompt.
test_query = "Do you guys have free snacks in the office?"
test_prev_response = "Tell me about your most challenging technical project."
test_stage = "technical_depth"

result = predict(
    model=model,
    tokenizer=tokenizer,
    user_query=test_query,
    prev_agent_response=test_prev_response,
    interview_stage=test_stage,
)

_rule = "="*60
print(_rule)
print("PREDICTION ANALYSIS")
print(_rule)

# Echo the inputs that were fed to the model.
print("\nInput:")
print(f"  Query: '{test_query}'")
print(f"  Prev Response: '{test_prev_response}'")
print(f"  Stage: {test_stage}")

# Winning class, its confidence and the accept/flag decision.
predicted_intent = result['prediction']
print(f"\nPrediction: {predicted_intent} (label {label2id[predicted_intent]})")
print(f"Confidence: {result['confidence']:.1%}")
print(f"Status: {result['status']}")

# Three most probable intents, highest probability first.
print("\nTop 3 Predictions:")
sorted_probs = sorted(result['all_probs'].items(), key=lambda name_and_prob: -name_and_prob[1])
for rank, (intent_name, intent_prob) in enumerate(sorted_probs[:3], 1):
    marker = " <- PREDICTED" if intent_name == predicted_intent else ""
    print(f"  {rank}. {intent_name:25s} {intent_prob:6.1%}{marker}")


In [None]:
def load_model(model_dir: str, device: torch.device):
    """Load a saved context-aware model.

    Args:
        model_dir: Directory containing the saved model, the tokenizer files and
            ``label_mappings.json``.
        device: Device the model is moved to.

    Returns:
        ``(model, tokenizer, mappings)``. The model is in eval mode. ``mappings``
        is the content of ``label_mappings.json`` with the keys of ``id2label``
        and ``id2stage`` converted back to ``int``. JSON stores object keys as
        strings, so without the conversion a lookup such as
        ``mappings['id2label'][0]`` would raise ``KeyError``.
    """
    import os
    import json

    with open(os.path.join(model_dir, 'label_mappings.json'), 'r') as f:
        mappings = json.load(f)

    # Undo JSON's key stringification for the id -> name maps.
    for id_keyed in ('id2label', 'id2stage'):
        if id_keyed in mappings:
            mappings[id_keyed] = {int(key): name for key, name in mappings[id_keyed].items()}

    model = ContextAwareLayer2Classifier.from_pretrained(model_dir)
    model = model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer, mappings