In [1]:

import os
import pandas as pd
import numpy as np
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from pydantic import BaseModel, Field
from openai import OpenAI
from dotenv import load_dotenv
from typing import Literal
import json
from tqdm import tqdm

# Load environment variables
load_dotenv()

dataset = load_dataset("wylupek/mrda-corpus")

print(f"Dataset splits:")
for split_name, split_data in dataset.items():
    print(f"  {split_name}: {len(split_data):,} samples")
total_samples = sum(len(split) for split in dataset.values())
print(f"  Total: {total_samples:,} samples\n")

print(f"Sample data: {dataset['test'][0]}")


# train_labels = [sample['general_da'] for sample in dataset['train']]
train_labels = [sample['basic_da'] for sample in dataset['test']]
unique_labels = list(set(train_labels))
unique_labels.sort()
print(f"Unique general_da labels: {len(unique_labels)}")
print(f"Labels: {unique_labels}\n")


label_counts = pd.Series(train_labels).value_counts().sort_index()
print(f"Label Distribution (Testing Set):")
for label, count in label_counts.items():
    percentage = (count / len(train_labels)) * 100
    print(f"  {label}: {count:,} samples ({percentage:.1f}%)")

  from .autonotebook import tqdm as notebook_tqdm


Dataset splits:
  train: 75,067 samples
  test: 16,702 samples
  validation: 16,433 samples
  Total: 108,202 samples

Sample data: {'speaker': 'mn015', 'text': 'okay.', 'basic_da': 'F', 'general_da': 'fg', 'full_da': 'fg'}
Unique general_da labels: 5
Labels: ['B', 'D', 'F', 'Q', 'S']

Label Distribution (Testing Set):
  B: 2,152 samples (12.9%)
  D: 2,339 samples (14.0%)
  F: 1,409 samples (8.4%)
  Q: 1,231 samples (7.4%)
  S: 9,571 samples (57.3%)


In [2]:
# Define Pydantic model for structured output
class DialogueActClassification(BaseModel):
    """Classification of dialogue act into one of 5 basic categories"""
    
    dialogue_act: Literal["S", "Q", "B", "D", "F"] = Field(
        description="""Classify the utterance into one of these dialogue acts:
        - S: Statement - Declarative utterances that convey information, opinions, or facts
        - Q: Question - Interrogative utterances seeking information or clarification  
        - B: BackChannel - Brief responses that show engagement (e.g., "uh-huh", "yeah", "mm-hmm")
        - D: Disruption - Interruptions, incomplete utterances, or speech repairs
        - F: FloorGrabber - Utterances attempting to gain speaking turn (e.g., "well", "so", "but")"""
    )
    reasoning: str = Field(
        description="Brief explanation for why this classification was chosen"
    )

class BatchDialogueActClassification(BaseModel):
    """Batch classification of multiple dialogue acts"""
    
    classifications: list[DialogueActClassification] = Field(
        description="List of dialogue act classifications, one for each input utterance in the same order"
    )

# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def classify_utterance(text: str, model: str = "gpt-4o-mini") -> DialogueActClassification:
    """Classify a single utterance using OpenAI with structured output"""
    
    system_prompt = """You are an expert in dialogue act classification. 
    You will classify utterances from meeting conversations into one of 5 basic dialogue acts:
    
    - S (Statement): Declarative utterances conveying information, opinions, or facts
    - Q (Question): Interrogative utterances seeking information or clarification
    - B (BackChannel): Brief responses showing engagement (e.g., "uh-huh", "yeah", "mm-hmm")  
    - D (Disruption): Interruptions, incomplete utterances, or speech repairs
    - F (FloorGrabber): Utterances attempting to gain speaking turn (e.g., "well", "so", "but")
    
    Consider the conversational context and purpose of each utterance type."""
    
    user_prompt = f"Classify this utterance: '{text}'"
    
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            response_format={"type": "json_schema", "json_schema": {
                "name": "dialogue_act_classification",
                "schema": DialogueActClassification.model_json_schema()
            }},
            temperature=0.1,
            max_tokens=500
        )
        
        result_json = json.loads(response.choices[0].message.content)
        return DialogueActClassification(**result_json)
        
    except Exception as e:
        print(f"Error classifying utterance '{text[:50]}...': {e}")
        # Raise exception to be caught by calling function
        raise e

def classify_utterances_batch(texts: list[str], model: str = "gpt-4o-mini") -> list[DialogueActClassification]:
    """Classify multiple utterances in a single API call for efficiency"""
    
    system_prompt = """You are an expert in dialogue act classification. 
    You will classify utterances from meeting conversations into one of 5 basic dialogue acts:
    
    - S (Statement): Declarative utterances conveying information, opinions, or facts
    - Q (Question): Interrogative utterances seeking information or clarification
    - B (BackChannel): Brief responses showing engagement (e.g., "uh-huh", "yeah", "mm-hmm")  
    - D (Disruption): Interruptions, incomplete utterances, or speech repairs
    - F (FloorGrabber): Utterances attempting to gain speaking turn (e.g., "well", "so", "but")
    
    Classify each utterance independently. Return classifications in the same order as the input."""
    
    # Format utterances with numbers for clarity
    numbered_utterances = [f"{i+1}. '{text}'" for i, text in enumerate(texts)]
    user_prompt = f"Classify these {len(texts)} utterances:\n\n" + "\n".join(numbered_utterances)
    
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            response_format={"type": "json_schema", "json_schema": {
                "name": "batch_dialogue_act_classification",
                "schema": BatchDialogueActClassification.model_json_schema()
            }},
            temperature=0.1,
            max_tokens=10000  # Increased for batch processing
        )
        
        result_json = json.loads(response.choices[0].message.content)
        batch_result = BatchDialogueActClassification(**result_json)
        return batch_result.classifications
        
    except Exception as e:
        print(f"Error classifying batch of {len(texts)} utterances: {e}")
        raise e

print(classify_utterance(dataset['test'][0]['text']))
print(dataset['test'][0]['basic_da'])

dialogue_act='B' reasoning="The utterance 'okay.' is a brief response that shows engagement or acknowledgment in the conversation, which fits the definition of a BackChannel."
F


In [3]:
# Main classification pipeline
def classify_test_dataset(max_samples: int = None, batch_size: int = 500, model: str = "gpt-4o-mini"):
    """
    Classify test dataset and calculate metrics using batch processing for efficiency
    
    Args:
        max_samples: Limit number of samples to classify (for testing). None = all samples
        batch_size: Number of utterances to process in each batch (default: 500)
    """
    
    test_data = dataset['test']
    
    # Limit samples if specified (useful for testing)
    if max_samples:
        test_data = test_data.select(range(max_samples))
    
    print(f"Classifying {len(test_data):,} test samples in batches of {batch_size}...")
    
    true_labels = []
    predicted_labels = []
    reasonings = []
    failed_classifications = 0
    
    # Process in batches
    total_batches = (len(test_data) + batch_size - 1) // batch_size  # Ceiling division
    
    for batch_idx in tqdm(range(total_batches), desc="Processing batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(test_data))
        
        # Extract batch data
        batch_samples = test_data.select(range(start_idx, end_idx))
        batch_texts = [sample['text'] for sample in batch_samples]
        batch_true_labels = [sample['basic_da'] for sample in batch_samples]
        
        # Classify batch with OpenAI
        try:
            batch_classifications = classify_utterances_batch(batch_texts, model=model)
            
            # Verify we got the expected number of results
            if len(batch_classifications) != len(batch_texts):
                print(f"Warning: Expected {len(batch_texts)} classifications but got {len(batch_classifications)} for batch {batch_idx + 1}")
                failed_classifications += len(batch_texts) - len(batch_classifications)
                # Use the classifications we got and skip the rest
                batch_classifications = batch_classifications[:len(batch_texts)]
                batch_true_labels = batch_true_labels[:len(batch_classifications)]
            
            # Add successful classifications to results
            true_labels.extend(batch_true_labels[:len(batch_classifications)])
            predicted_labels.extend([cls.dialogue_act for cls in batch_classifications])
            reasonings.extend([cls.reasoning for cls in batch_classifications])
            
        except Exception as e:
            print(f"Error classifying batch {batch_idx + 1} (samples {start_idx}-{end_idx-1}): {e}")
            failed_classifications += len(batch_texts)
            # Skip this entire batch
            continue
        
        # Print progress
        samples_processed = len(true_labels)
        if samples_processed > 0:
            current_accuracy = accuracy_score(true_labels, predicted_labels)
            print(f"Progress: {samples_processed:,}/{len(test_data):,} samples | Batch {batch_idx + 1}/{total_batches} | Current accuracy: {current_accuracy:.3f}")
    
    return true_labels, predicted_labels, reasonings, failed_classifications


In [4]:
# =============================================================================
# EXPERIMENT LOGGING & COMPREHENSIVE EVALUATION  
# =============================================================================
def setup_experiment_logging(experiment_name, hyperparams):
    """Setup timestamped logging directory for experiments"""
    results_dir = f"ac_results/{experiment_name}"
    os.makedirs(results_dir, exist_ok=True)
    
    # Save hyperparameters for reproducibility
    with open(f"{results_dir}/hyperparams.json", 'w') as f:
        json.dump(hyperparams, f, indent=2)
    
    return results_dir

def comprehensive_evaluation_llm(true_labels, predicted_labels, results_dir, hyperparams):
    """
    Complete evaluation for LLM classification with:
    - Overall metrics (accuracy, macro/weighted F1)
    - Per-class metrics and confusion matrix
    - Detailed class analysis with TP/FP/FN/TN breakdown
    """
    
    # ===== OVERALL METRICS =====
    metrics = {
        'accuracy': round(accuracy_score(true_labels, predicted_labels), 4),
        'macro_f1': round(f1_score(true_labels, predicted_labels, average='macro'), 4),
        'micro_f1': round(f1_score(true_labels, predicted_labels, average='micro'), 4),
        'weighted_f1': round(f1_score(true_labels, predicted_labels, average='weighted'), 4),
    }
    
    # ===== PER-CLASS METRICS =====
    unique_labels = sorted(list(set(true_labels + predicted_labels)))
    class_report = classification_report(
        true_labels, predicted_labels,
        labels=unique_labels,
        target_names=unique_labels,
        output_dict=True,
        zero_division=0
    )
    
    def round_nested_dict(d, decimals=4):
        """Recursively round all numeric values in nested dictionary"""
        if isinstance(d, dict):
            return {k: round_nested_dict(v, decimals) for k, v in d.items()}
        elif isinstance(d, (int, float)):
            return round(d, decimals) if isinstance(d, float) else d
        else:
            return d
    
    class_report = round_nested_dict(class_report)
    
    # ===== CONFUSION MATRIX =====
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
    
    # ===== DETAILED CLASS ANALYSIS =====
    print("\n" + "="*80)
    print("PER-CLASS CONFUSION ANALYSIS (5 CLASSES)")
    print("="*80)
    
    detailed_analysis = {}
    for i, class_name in enumerate(unique_labels):
        # Calculate TP/FP/FN/TN for this class
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        tn = cm.sum() - tp - fp - fn
        
        # Calculate metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        # Calculate distribution percentages
        true_count = true_labels.count(class_name)
        pred_count = predicted_labels.count(class_name)
        true_pct = (true_count / len(true_labels)) * 100
        pred_pct = (pred_count / len(predicted_labels)) * 100
        
        # Store detailed analysis
        detailed_analysis[class_name] = {
            'tp': int(tp), 'fp': int(fp), 'fn': int(fn), 'tn': int(tn),
            'precision': round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4),
            'true_pct': round(true_pct, 4), 'pred_pct': round(pred_pct, 4), 
            'diff_pct': round(pred_pct - true_pct, 4)
        }
        
        # Print formatted class analysis
        print(f"\n📊 CLASS {i}: '{class_name}' Analysis")
        print("-" * 50)
        print(f"  Confusion:    TP={tp:4d} | FP={fp:4d}")
        print(f"                FN={fn:4d} | TN={tn:4d}")
        print(f"  Metrics:      Prec={precision:.3f} | Rec={recall:.3f} | F1={f1:.3f}")
        print(f"  Distribution: True={true_pct:5.1f}% | Pred={pred_pct:5.1f}% | Diff={pred_pct-true_pct:+5.1f}%")
    
    # ===== CLASS DISTRIBUTION ANALYSIS =====
    true_dist = pd.Series(true_labels).value_counts().sort_index()
    pred_dist = pd.Series(predicted_labels).value_counts().sort_index()
    
    # ===== SAVE COMPREHENSIVE RESULTS =====
    results = {
        'overall_metrics': metrics,
        'per_class_metrics': class_report,
        'detailed_class_analysis': detailed_analysis,
        'confusion_matrix': cm.tolist(),
        'true_distribution': true_dist.to_dict(),
        'pred_distribution': pred_dist.to_dict(),
        'hyperparameters': hyperparams
    }
    
    with open(f"{results_dir}/evaluation.json", 'w') as f:
        json.dump(results, f, indent=2)
    
    # ===== SUMMARY OUTPUT =====
    print(f"\n🎯 OVERALL: Acc={metrics['accuracy']:.3f} | MacroF1={metrics['macro_f1']:.3f} | WeightedF1={metrics['weighted_f1']:.3f}")
    print(f"💾 Results saved to: {results_dir}")
    
    return results

print("Classification pipeline ready!")

Classification pipeline ready!


In [5]:
hyperparams = {
    "model": "gpt-4o",
    "max_samples": 1000,
    "batch_size": 200,
}

true_labels, predicted_labels, reasonings, failed_classifications = classify_test_dataset(
    model=hyperparams["model"], 
    max_samples=hyperparams["max_samples"], 
    batch_size=hyperparams["batch_size"]
)

results_dir = setup_experiment_logging("LLM_classifier_1", hyperparams)
comprehensive_evaluation_llm(true_labels, predicted_labels, results_dir, hyperparams)


Classifying 1,000 test samples in batches of 200...


Processing batches:  20%|██        | 1/5 [00:37<02:30, 37.52s/it]

Progress: 198/1,000 samples | Batch 1/5 | Current accuracy: 0.369


Processing batches:  40%|████      | 2/5 [02:45<04:31, 90.46s/it]

Progress: 394/1,000 samples | Batch 2/5 | Current accuracy: 0.475


Processing batches:  60%|██████    | 3/5 [04:40<03:23, 101.70s/it]

Progress: 585/1,000 samples | Batch 3/5 | Current accuracy: 0.446


Processing batches:  80%|████████  | 4/5 [06:39<01:48, 108.52s/it]

Progress: 782/1,000 samples | Batch 4/5 | Current accuracy: 0.490


Processing batches: 100%|██████████| 5/5 [08:39<00:00, 103.86s/it]

Progress: 959/1,000 samples | Batch 5/5 | Current accuracy: 0.477

PER-CLASS CONFUSION ANALYSIS (5 CLASSES)

📊 CLASS 0: 'B' Analysis
--------------------------------------------------
  Confusion:    TP=  69 | FP= 181
                FN=  63 | TN= 646
  Metrics:      Prec=0.276 | Rec=0.523 | F1=0.361
  Distribution: True= 13.8% | Pred= 26.1% | Diff=+12.3%

📊 CLASS 1: 'D' Analysis
--------------------------------------------------
  Confusion:    TP=  43 | FP=  91
                FN=  84 | TN= 741
  Metrics:      Prec=0.321 | Rec=0.339 | F1=0.330
  Distribution: True= 13.2% | Pred= 14.0% | Diff= +0.7%

📊 CLASS 2: 'F' Analysis
--------------------------------------------------
  Confusion:    TP=   4 | FP=  52
                FN=  44 | TN= 859
  Metrics:      Prec=0.071 | Rec=0.083 | F1=0.077
  Distribution: True=  5.0% | Pred=  5.8% | Diff= +0.8%

📊 CLASS 3: 'Q' Analysis
--------------------------------------------------
  Confusion:    TP=  44 | FP=  69
                FN=  41 | TN= 80




{'overall_metrics': {'accuracy': 0.4765,
  'macro_f1': 0.3645,
  'micro_f1': 0.4765,
  'weighted_f1': 0.4975},
 'per_class_metrics': {'B': {'precision': 0.276,
   'recall': 0.5227,
   'f1-score': 0.3613,
   'support': 132.0},
  'D': {'precision': 0.3209,
   'recall': 0.3386,
   'f1-score': 0.3295,
   'support': 127.0},
  'F': {'precision': 0.0714,
   'recall': 0.0833,
   'f1-score': 0.0769,
   'support': 48.0},
  'Q': {'precision': 0.3894,
   'recall': 0.5176,
   'f1-score': 0.4444,
   'support': 85.0},
  'S': {'precision': 0.7315,
   'recall': 0.5238,
   'f1-score': 0.6105,
   'support': 567.0},
  'accuracy': 0.4765,
  'macro avg': {'precision': 0.3578,
   'recall': 0.3972,
   'f1-score': 0.3645,
   'support': 959.0},
  'weighted avg': {'precision': 0.5511,
   'recall': 0.4765,
   'f1-score': 0.4975,
   'support': 959.0}},
 'detailed_class_analysis': {'B': {'tp': 69,
   'fp': 181,
   'fn': 63,
   'tn': 646,
   'precision': np.float64(0.276),
   'recall': np.float64(0.5227),
   'f1': n

In [7]:
for result in zip(true_labels, predicted_labels, reasonings):
    print(result)

('F', 'B', "The utterance 'okay.' is a brief response showing engagement, typical of a backchannel.")
('S', 'S', "The utterance 'some some introductions are in order.' is a declarative statement conveying information.")
('S', 'B', "The utterance 'oh okay.' is a brief response showing engagement, typical of a backchannel.")
('S', 'D', "The utterance 'sorry.' is an interruption or speech repair, indicating a disruption.")
('F', 'B', "The utterance 'okay.' is a brief response showing engagement, typical of a backchannel.")
('S', 'D', "The utterance 'getting ahead of myself.' is an incomplete thought or self-correction, indicating a disruption.")
('F', 'F', "The utterance 'so' is an attempt to gain the speaking turn, typical of a floor grabber.")
('D', 'F', "The utterance 'um for those who don't know' is an attempt to gain the speaking turn, typical of a floor grabber.")
('S', 'S', "The utterance 'everyone knows me.' is a declarative statement conveying information.")
('S', 'S', "The utter