# MRDA Model Loading and Inference

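This notebook loads the fine-tuned DistilBERT model with LoRA adapters for MRDA dialogue act classification.

**Model Details:**
- Base Model: `distilbert-base-uncased`
- Fine-tuned with LoRA adapters
- Task: MRDA dialogue act classification. The active configuration uses the 5 general dialogue act classes; the 12-class fine-grained configuration is kept as commented-out code.
- Classes (5-class, active): `['S', 'Q', 'D', 'F', 'B']`
- Classes (12-class, commented out): `['%', 'b', 'fg', 'fh', 'h', 'qh', 'qo', 'qr', 'qrr', 'qw', 'qy', 's']`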

---


In [None]:
import json
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification
)
from peft import (
    PeftModel,
    LoraConfig,
    TaskType
)

import warnings
warnings.filterwarnings('ignore')

In [None]:
def detect_device():
    """
    Detect the best available torch device.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        str: "cuda", "mps", or "cpu"
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing on older PyTorch releases; look it up
    # defensively so those installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

device = detect_device()

In [None]:
CHECKPOINT_PATH = "../advanced_checkpoints/checkpoint-11730"
MODEL_NAME = "distilbert-base-uncased"

# Label set from training. The list order defines the class ids (see the
# enumerate below), so it must match the order used when the checkpoint was
# trained.
# 12-class fine-grained alternative:
# unique_labels = ['%', 'b', 'fg', 'fh', 'h', 'qh', 'qo', 'qr', 'qrr', 'qw', 'qy', 's']
# 5-class general set (active):
unique_labels = ['S', 'Q', 'D', 'F', 'B']

# Derived from the label list so the classifier head size can never disagree
# with the label mappings when switching between label sets.
NUM_LABELS = len(unique_labels)

label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}


# The tokenizer is loaded from the checkpoint directory itself.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)

# Base model whose classification head is sized and labelled from the label
# config above; the adapter weights from the checkpoint are applied on top.
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    id2label=id2label,
    label2id=label2id,
    num_labels=NUM_LABELS,
    problem_type="single_label_classification",
)

model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH).to(device)
# Eval mode (disables dropout) for inference.
model.eval()

In [None]:
# --- 12-class fine-grained label groups (inactive alternative) ---
# CONTENT_LABELS = {'s', 'qy', 'qw', 'qh', 'qrr', 'qr', 'qo'}
# NON_CONTENT_LABELS = {'b', 'fh', 'fg', '%', 'h'}
# LABEL_DESCRIPTIONS = {
#     '%': 'Interrupted/Abandoned utterance',
#     'b': 'Continuer (backchannel)',
#     'fg': 'Floor Grabber (taking the floor)',
#     'fh': 'Floor Holder (keeping the floor)',
#     'h': 'Hold Before Answer (hesitation)',
#     'qh': 'Rhetorical Question',
#     'qo': 'Open-ended Question',
#     'qr': 'Or Question',
#     'qrr': 'Or-Clause (part of question)',
#     'qw': 'Wh-Question (what, where, when, etc.)',
#     'qy': 'Yes-No Question',
#     's': 'Statement'
# }

# --- 5-class general label groups (active) ---
# Which labels map_to_binary treats as "content" vs "non-content".
CONTENT_LABELS = set("SQ")
NON_CONTENT_LABELS = set("BDF")
# Human-readable description for each active label.
LABEL_DESCRIPTIONS = dict(
    S='Statement',
    Q='Question',
    B='Backchannel',
    D='Disruptions',
    F='Floor Grabber',
)

def map_to_binary(general_da_label):
    """
    Collapse a dialogue act label into a binary content / non-content class.

    Args:
        general_da_label (str): Dialogue act label to classify

    Returns:
        tuple: (1, "content") if the label is in CONTENT_LABELS,
        (0, "non-content") if it is in NON_CONTENT_LABELS,
        otherwise (-1, "unknown")
    """
    groups = (
        (CONTENT_LABELS, (1, "content")),
        (NON_CONTENT_LABELS, (0, "non-content")),
    )
    for members, outcome in groups:
        if general_da_label in members:
            return outcome
    return -1, "unknown"

def predict_single(text, return_probabilities=False):
    """
    Classify the dialogue act of one utterance with its own forward pass.

    Args:
        text (str): Utterance to classify (truncated to 128 tokens)
        return_probabilities (bool): If True, also include the softmax
            probability of every class under 'all_probabilities'

    Returns:
        dict: 'text', 'predicted_class_id', 'predicted_label', 'confidence'
        (softmax probability of the predicted class), 'binary_label' and
        'binary_text' (from map_to_binary), plus 'all_probabilities' when
        requested
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        # Only one input, so keep just its row of probabilities.
        probs_row = F.softmax(logits, dim=-1)[0]
        class_id = torch.argmax(logits, dim=-1).item()

    label = id2label[class_id]
    binary_label, binary_text = map_to_binary(label)

    prediction = dict(
        text=text,
        predicted_class_id=class_id,
        predicted_label=label,
        confidence=probs_row[class_id].item(),
        binary_label=binary_label,
        binary_text=binary_text,
    )
    if return_probabilities:
        prediction['all_probabilities'] = {
            id2label[idx]: p.item() for idx, p in enumerate(probs_row)
        }
    return prediction

def predict_batch_real(texts, return_probabilities=False, batch_size=None):
    """
    TRUE batch prediction - processes texts with batched forward passes

    Args:
        texts (list): List of input texts
        return_probabilities (bool): Whether to return class probabilities
        batch_size (int, optional): Maximum number of texts per forward pass.
            None (default) runs all texts in a single forward pass, as before.
            Set it for long inputs (e.g. a whole transcript) to bound memory.

    Returns:
        list: One prediction dict per input text, in input order (same keys
        as predict_single)

    Raises:
        ValueError: If batch_size is given and is smaller than 1
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if not texts:
        return []

    if batch_size is None:
        batch_size = len(texts)

    results = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        results.extend(_predict_chunk(chunk, return_probabilities))
    return results

def _predict_chunk(chunk, return_probabilities):
    """Run one forward pass over ``chunk`` and build one prediction dict per text."""
    # Tokenize the chunk together - padding to a common length is what makes
    # a single batched forward pass possible.
    inputs = tokenizer(
        chunk,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        probabilities = F.softmax(logits, dim=-1)
        predicted_class_ids = torch.argmax(logits, dim=-1)

    # Copy to host once; calling .item() per element on a GPU/MPS tensor
    # would force a device synchronisation for every single value.
    prob_rows = probabilities.cpu().tolist()
    class_ids = predicted_class_ids.cpu().tolist()

    results = []
    for text, predicted_class_id, probs in zip(chunk, class_ids, prob_rows):
        predicted_label = id2label[predicted_class_id]

        # Binary classification
        binary_label, binary_text = map_to_binary(predicted_label)

        result = {
            'text': text,
            'predicted_class_id': predicted_class_id,
            'predicted_label': predicted_label,
            'confidence': probs[predicted_class_id],
            'binary_label': binary_label,
            'binary_text': binary_text
        }

        if return_probabilities:
            result['all_probabilities'] = {id2label[j]: prob for j, prob in enumerate(probs)}

        results.append(result)

    return results

def predict_batch(texts, return_probabilities=False):
    """
    FAKE batch prediction - one predict_single() call (one forward pass) per text.

    Kept for compatibility and as the baseline in the batching benchmark.
    Use predict_batch_real() for true batching.
    """
    return [predict_single(text, return_probabilities) for text in texts]

def display_prediction(result, true_label=None, top_k=5):
    """
    Display prediction results in a nice format

    Args:
        result (dict): Prediction result from predict_single or predict_batch_real
        true_label (str, optional): Ground truth label for comparison; must use
            the same label set as the model's id2label values
        top_k (int, optional): How many of the highest class probabilities to
            list when ``result`` has 'all_probabilities' (default 5; None lists all)
    """
    print(f"📝 Text: '{result['text']}'")
    print(f"🏷️ Predicted: {result['predicted_label']} (ID: {result['predicted_class_id']})")
    # Show ground truth if provided
    if true_label is not None:
        is_correct = result['predicted_label'] == true_label
        status_icon = "✅" if is_correct else "❌"
        print(f"🎯 True Label: {true_label} {status_icon}")
        if not is_correct:
            print(f"   True Description: {LABEL_DESCRIPTIONS.get(true_label, 'Unknown')}")
            true_binary = map_to_binary(true_label)[1]
            pred_binary = result['binary_text']
            binary_correct = "✅" if true_binary == pred_binary else "❌"
            print(f"   Binary: {true_binary} vs {pred_binary} {binary_correct}")

    print(f"📊 Confidence: {result['confidence']:.4f}")
    print(f"🔄 Binary Classification: {result['binary_text']}")
    # .get() like the true-label branch above: a label with no description
    # (e.g. after switching label sets) prints 'Unknown' instead of raising KeyError.
    print(f"💡 Description: {LABEL_DESCRIPTIONS.get(result['predicted_label'], 'Unknown')}")

    if 'all_probabilities' in result:
        print(f"📈 Top {top_k} Probabilities:")
        ranked = sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True)
        for label, prob in ranked[:top_k]:
            print(f"   {label}: {prob:.4f}")
    print("-" * 60)

def display_batch_predictions(results, true_labels=None):
    """
    Display batch prediction results

    Args:
        results (list): List of prediction results
        true_labels (list, optional): Ground truth labels aligned with ``results``.
            If it is shorter than ``results``, the trailing results are shown
            without a label and excluded from the accuracy summary.
    """
    labels = list(true_labels) if true_labels else []
    correct_count = 0
    labeled_count = 0

    for i, result in enumerate(results):
        true_label = labels[i] if i < len(labels) else None
        display_prediction(result, true_label)

        if true_label is None:
            continue
        labeled_count += 1
        if result['predicted_label'] == true_label:
            correct_count += 1

    # Show summary if true labels provided. Accuracy is measured over the
    # results that actually had a ground-truth label, not over all results.
    if true_labels:
        accuracy = correct_count / labeled_count if labeled_count > 0 else 0
        print(f"📊 BATCH SUMMARY: {correct_count}/{labeled_count} correct ({accuracy:.1%})")


In [None]:
# The model predicts the 5 general labels (id2label), so the ground truth
# must come from the same set: 'S' (Statement), not the fine-grained 's'.
x = predict_single("I was fine-tuning BERT", return_probabilities=True)
display_prediction(x, true_label="S")

In [None]:
# Ground truth uses the 5 general labels the model predicts; the original
# fine-grained MRDA tag is noted in parentheses.
batch = [
    "okay",  # Should be 'F' (fg, Floor Grabber) - non-content
    "i think that's a good idea",  # Should be 'S' (s, Statement) - content
    "uh-huh",  # Should be 'B' (b, Continuer) - non-content
    "what do you think about this?",  # Should be 'Q' (qy, Yes-No Question) - content
    "how are you doing?",  # Should be 'Q' (qw, Wh-Question) - content
    "um",  # Should be 'F' (fh, Floor Holder) - non-content
]
labels = [
    "F",
    "S",
    "B",
    "Q",
    "Q",
    "F",
]


predictions = predict_batch_real(batch, return_probabilities=True)
display_batch_predictions(predictions, true_labels=labels)


In [None]:
# Test the model with example utterances (no ground truth is passed, so only
# the predictions are printed). Expected labels are the 5 general labels the
# model predicts, with the fine-grained MRDA tag in parentheses.
test_utterances = [
    "okay",  # Expected 'F' (fg, Floor Grabber) - non-content
    "i think that's a good idea",  # Expected 'S' (s, Statement) - content
    "uh-huh",  # Expected 'B' (b, Continuer) - non-content
    "what do you think about this?",  # Expected 'Q' (qy, Yes-No Question) - content
    "how are you doing?",  # Expected 'Q' (qw, Wh-Question) - content
    "um",  # Expected 'F' (fh, Floor Holder) - non-content
]

print("🧪 Testing model with example utterances:")
print("=" * 60)

for utterance in test_utterances:
    result = predict_single(utterance, return_probabilities=True)
    display_prediction(result)


In [None]:
# Demonstrate TRUE batch processing vs FAKE batch processing
import time

# Test data: 10 utterances repeated 5 times (50 texts total).
# Ground truth uses the 5 general labels the model predicts
# (fine-grained tags: fg, s, b, qw, qw, fh, s, b, qy, s).
test_texts = [
    "okay let's start",
    "i think this is important",
    "uh-huh",
    "what should we do next?",
    "how about this approach?",
    "um let me think",
    "that sounds good",
    "yes exactly",
    "are you sure about that?",
    "maybe we should consider"
] * 5
test_labels = ["F", "S", "B", "Q", "Q", "F", "S", "B", "Q", "S"] * 5

print(f"🧪 Comparing batch processing methods with {len(test_texts)} texts:")
print("=" * 70)

# Warm-up call so one-time start-up costs (e.g. lazy device initialisation)
# are not billed to whichever method happens to be timed first.
predict_batch_real(test_texts[:2])

# FAKE batch processing (loops through predict_single)
start_time = time.perf_counter()
fake_results = predict_batch(test_texts, return_probabilities=False)
fake_time = time.perf_counter() - start_time

# TRUE batch processing (single forward pass)
start_time = time.perf_counter()
real_results = predict_batch_real(test_texts, return_probabilities=False)
real_time = time.perf_counter() - start_time

print(f"⚡ FAKE Batch (loops): {fake_time:.3f}s")
print(f"🚀 REAL Batch (vectorized): {real_time:.3f}s")
if real_time > 0:
    print(f"💨 Speedup: {fake_time/real_time:.1f}x faster!")

# Verify both methods agree on every predicted label
predictions_match = len(fake_results) == len(real_results) and all(
    fake['predicted_label'] == real['predicted_label']
    for fake, real in zip(fake_results, real_results)
)
print(f"✅ Results identical: {predictions_match}")

# Show accuracy for both methods
total_texts = len(test_texts)
fake_correct = sum(r['predicted_label'] == t for r, t in zip(fake_results, test_labels))
real_correct = sum(r['predicted_label'] == t for r, t in zip(real_results, test_labels))

print(f"📊 Accuracy (fake): {fake_correct}/{total_texts} = {fake_correct/total_texts:.1%}")
print(f"📊 Accuracy (real): {real_correct}/{total_texts} = {real_correct/total_texts:.1%}")
print("\n" + "="*70)
print("🎯 Always use predict_batch_real() for better performance!")


In [None]:
import time

# NOTE(review): assumes every non-empty line is ';'-separated with the
# utterance text in the second field - confirm against the transcript format.
# NOTE(review): the file is assumed to be UTF-8 (explicit instead of the
# platform-dependent default encoding).
with open("../transcript_en.txt", "r", encoding="utf-8") as file:
    raw_transcript = file.read()

# splitlines() + skipping blank lines avoids the IndexError on fields[1]
# that an empty line (e.g. from a trailing newline) would otherwise cause.
transcript = [line.split(";") for line in raw_transcript.splitlines() if line.strip()]
transcript_texts = [fields[1] for fields in transcript]

# Run inference in fixed-size chunks so a long transcript does not have to
# fit into a single forward pass.
TRANSCRIPT_BATCH_SIZE = 64
start_time = time.perf_counter()
real_results = []
for start in range(0, len(transcript_texts), TRANSCRIPT_BATCH_SIZE):
    chunk = transcript_texts[start:start + TRANSCRIPT_BATCH_SIZE]
    real_results.extend(predict_batch_real(chunk, return_probabilities=False))
real_time = time.perf_counter() - start_time

print(f"Time: {real_time:.3f}s")

In [None]:
# `transcript` holds the ';'-split fields of each line and is aligned
# one-to-one with `transcript_texts` (and so with `real_results`). Iterating
# over the fields, not the text strings, prints each line's first and second
# fields around its prediction instead of the first two characters of the text.
for result, fields in zip(real_results, transcript):
    print(fields[0])
    print(result)
    print(fields[1])