## 1. Setup and Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
from pathlib import Path
import sys

# Add nejm_b2txt_utils to path
sys.path.append('../nejm_b2txt_utils')
from general_utils import calculate_aggregate_error_rate, calculate_error_rate

# Global plotting defaults shared by every figure in this notebook.
sns.set_theme(style="whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})

print("✓ Libraries loaded successfully")

## 2. Load Baseline Results

Load existing baseline results from the RNN model with different language models.

In [None]:
# Locate the pretrained RNN baseline's output directory and list every
# per-trial results CSV found directly inside it.
baseline_path = Path('../data/t15_pretrained_rnn_baseline')

result_files = [*baseline_path.glob('detailed_results_*.csv')]
print(f"Found {len(result_files)} result files:")
for result_file in result_files:
    print("  - " + result_file.name)

In [None]:
# Load the latest validation and test results
def load_latest_results(split='val', results_dir=None):
    """Load the most recently modified ``detailed_results_<split>_*.csv`` file.

    Args:
        split: Split name embedded in the file name (e.g. ``'val'``, ``'test'``).
        results_dir: Directory to search. Defaults to the notebook-level
            ``baseline_path`` when omitted, preserving the original behavior.

    Returns:
        A DataFrame with the file's contents, or ``None`` (after printing a
        message) when no matching file exists.
    """
    search_dir = Path(baseline_path if results_dir is None else results_dir)
    files = list(search_dir.glob(f'detailed_results_{split}_*.csv'))
    if not files:
        print(f"No {split} results found!")
        return None

    # Use modification time, not ctime: on POSIX st_ctime is the inode-change
    # time (bumped by chmod/rename/copy metadata), not the creation time, so it
    # is not a portable "newest file" signal.
    latest = max(files, key=lambda path: path.stat().st_mtime)
    print(f"Loading {split} results: {latest.name}")
    return pd.read_csv(latest)

# Validation split: load the newest results file and preview its contents.
df_val = load_latest_results(split='val')
if df_val is not None:
    print(f"\nValidation set: {len(df_val)} trials")
    print("Columns: {}".format(list(df_val.columns)))
    print("\nFirst few rows:")
    display(df_val.head(5))

In [None]:
# Test split: load the newest results file and report how many trials it has.
df_test = load_latest_results(split='test')
if df_test is not None:
    print("\nTest set: {} trials".format(len(df_test)))

## 3. Calculate Error Rates

Calculate Word Error Rate (WER) and Phoneme Error Rate (PER) for the baseline model.

In [None]:
def _to_percent(fraction):
    """Scale a fraction by 100; ``None`` passes through unchanged."""
    return None if fraction is None else fraction * 100


def _split_words(text):
    """Tokenize a sentence on whitespace."""
    return text.split()


def _split_phonemes(text):
    """Tokenize a '|'-delimited phoneme string, dropping blank entries."""
    return [p.strip() for p in text.split('|') if p.strip()]


def _collect_pairs(df, true_col, pred_col, tokenize):
    """Return tokenized (references, hypotheses) for rows with a present reference.

    Rows whose reference is missing are skipped. A missing hypothesis (or a
    missing ``pred_col`` column) becomes an empty token list. Returns two empty
    lists when ``true_col`` is not a column of ``df``.
    """
    if true_col not in df.columns:
        return [], []
    preds = df[pred_col] if pred_col in df.columns else [None] * len(df)
    refs, hyps = [], []
    for ref, hyp in zip(df[true_col], preds):
        if pd.isna(ref):
            continue
        refs.append(tokenize(ref))
        hyps.append(tokenize(hyp) if pd.notna(hyp) else [])
    return refs, hyps


def _aggregate(refs, hyps):
    """Call calculate_aggregate_error_rate, or return (None, None, None, []) if nothing is scorable."""
    if not refs:
        return None, None, None, []
    agg, ci_low, ci_high, individual = calculate_aggregate_error_rate(refs, hyps)
    return agg, ci_low, ci_high, individual


def calculate_metrics(df):
    """Calculate aggregate WER and PER (in percent) from per-trial detailed results.

    Words are whitespace tokens of ``true_sentence`` / ``pred_sentence``;
    phonemes are '|'-separated tokens of ``true_phonemes`` / ``pred_phonemes``.
    Trials with a missing reference are skipped and a missing prediction is
    scored as an empty hypothesis. If a reference column is absent, that
    unit's metrics are ``None`` instead of raising ``KeyError``.

    Args:
        df: DataFrame of detailed results, one row per trial.

    Returns:
        dict with 'wer', 'wer_ci_low', 'wer_ci_high', 'per', 'per_ci_low',
        'per_ci_high' (percentages, or ``None`` when nothing could be scored)
        and 'wer_individual', 'per_individual' (the per-trial values returned
        by ``calculate_aggregate_error_rate``, not rescaled; ``[]`` when empty).
    """
    wer_agg, wer_ci_low, wer_ci_high, wer_ind = _aggregate(
        *_collect_pairs(df, 'true_sentence', 'pred_sentence', _split_words))
    per_agg, per_ci_low, per_ci_high, per_ind = _aggregate(
        *_collect_pairs(df, 'true_phonemes', 'pred_phonemes', _split_phonemes))

    return {
        'wer': _to_percent(wer_agg),
        'wer_ci_low': _to_percent(wer_ci_low),
        'wer_ci_high': _to_percent(wer_ci_high),
        'per': _to_percent(per_agg),
        'per_ci_low': _to_percent(per_ci_low),
        'per_ci_high': _to_percent(per_ci_high),
        'wer_individual': wer_ind,
        'per_individual': per_ind
    }

# Score the validation split and report WER/PER with their confidence intervals.
if df_val is not None:
    metrics_val = calculate_metrics(df_val)
    print("\n=== Baseline Validation Metrics ===")
    for metric_label, metric_key in (("WER", "wer"), ("PER", "per")):
        value = metrics_val[metric_key]
        if value is not None:
            low = metrics_val[f"{metric_key}_ci_low"]
            high = metrics_val[f"{metric_key}_ci_high"]
            print(f"{metric_label}: {value:.2f}% (95% CI: [{low:.2f}%, {high:.2f}%])")

## 4. Corpus-Specific Analysis

Analyze performance by corpus type (Switchboard vs other corpora).

In [None]:
def _fmt_pct(value):
    """Format a percentage with two decimals, or 'n/a' when it is ``None``."""
    return 'n/a' if value is None else f"{value:.2f}%"


def _nan_if_none(value):
    """Map ``None`` to NaN so matplotlib draws an empty bar instead of failing."""
    return np.nan if value is None else value


if df_val is not None and 'corpus' in df_val.columns:
    # Drop missing labels: NaN never compares equal to itself, so a NaN corpus
    # would select an empty slice and yield no metrics.
    corpora = df_val['corpus'].dropna().unique()
    print(f"Corpora in dataset: {corpora}\n")

    # Calculate metrics per corpus
    corpus_metrics = {}
    for corpus in corpora:
        df_corpus = df_val[df_val['corpus'] == corpus]
        metrics = calculate_metrics(df_corpus)
        corpus_metrics[corpus] = metrics
        # calculate_metrics reports None when a corpus has no scorable trials;
        # formatting None with ':.2f' would raise TypeError.
        print(f"{corpus}: WER = {_fmt_pct(metrics['wer'])}, PER = {_fmt_pct(metrics['per'])}")

    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    corpora_list = list(corpus_metrics.keys())
    wer_values = [_nan_if_none(corpus_metrics[c]['wer']) for c in corpora_list]
    per_values = [_nan_if_none(corpus_metrics[c]['per']) for c in corpora_list]

    # WER by corpus
    ax1.bar(corpora_list, wer_values, color='steelblue', alpha=0.7)
    ax1.set_ylabel('Word Error Rate (%)', fontsize=12)
    ax1.set_title('WER by Corpus', fontsize=14)
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(axis='y', alpha=0.3)

    # PER by corpus
    ax2.bar(corpora_list, per_values, color='coral', alpha=0.7)
    ax2.set_ylabel('Phoneme Error Rate (%)', fontsize=12)
    ax2.set_title('PER by Corpus', fontsize=14)
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

## 5. Language Model Comparison Plan

To run experiments with different LMs, you'll need to:

### Required LM Models
1. **1-gram** (included): `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`
2. **3-gram** (download): From Dryad → `languageModel.tar.gz` (~60GB RAM)
3. **5-gram** (download): From Dryad → `languageModel_5gram.tar.gz` (~300GB RAM)

### Experiment Configurations

| Config | LM | Neural Rescore | Alpha | Expected WER |
|--------|----|----------------|-------|--------------|
| baseline-1gram | 1-gram | No | 0.0 | Highest |
| 1gram-opt | 1-gram | OPT-6.7b | 0.55 | - |
| 3gram | 3-gram | No | 0.0 | Lower than 1-gram |
| 3gram-opt | 3-gram | OPT-6.7b | 0.55 | Lower than 3gram |
| 5gram | 5-gram | No | 0.0 | Lower than 3gram |
| 5gram-opt | 5-gram | OPT-6.7b | 0.55 | Lowest (expected best overall) |

### Commands to Run

**Terminal 1: Start Redis**
```bash
redis-server
```

**Terminal 2: Start LM Server** (example for 1-gram)
```bash
conda activate b2txt25_lm
python language_model/language-model-standalone.py \
    --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil \
    --nbest 100 \
    --acoustic_scale 0.325 \
    --blank_penalty 90 \
    --alpha 0.0 \
    --redis_ip localhost \
    --gpu_number 0
```

**Terminal 3: Run Evaluation**
```bash
conda activate b2txt25
python model_training/evaluate_model.py \
    --model_path ./data/t15_pretrained_rnn_baseline \
    --data_dir ./data/hdf5_data_final \
    --eval_type val \
    --gpu_number 1
```

## 6. Decoding Parameter Analysis

Analyze the effect of key decoding parameters on rule compliance and performance.

In [None]:
# Decoding-parameter sweep for the sensitivity analysis (placeholder until the
# experiments are run). Each configuration starts from the same baseline
# settings and overrides at most one of them.
parameter_configs = {
    config_name: dict(
        {'acoustic_scale': 0.325, 'blank_penalty': 90, 'alpha': 0.55, 'beam': 17.0},
        **overrides,
    )
    for config_name, overrides in (
        ('baseline', {}),
        ('high_acoustic', {'acoustic_scale': 0.5}),
        ('low_acoustic', {'acoustic_scale': 0.1}),
        ('high_lm', {'alpha': 0.8}),
        ('low_lm', {'alpha': 0.2}),
    )
}

print("Parameter configurations to test:")
for config_name, settings in parameter_configs.items():
    print(f"\n{config_name}:")
    print("\n".join(f"  {key}: {value}" for key, value in settings.items()))

## 7. Linguistic Improvement Analysis

Analyze specific types of linguistic improvements from different LMs.

In [None]:
def analyze_error_types(df):
    """Bucket each scorable trial into a coarse, length-based error category.

    Comparison is case-insensitive on whitespace tokens. A trial is 'correct'
    when the word sequences match exactly; otherwise it is labeled by word
    count: more predicted words -> 'insertions', fewer -> 'deletions', equal
    count -> 'substitutions'. This is a heuristic, not an alignment: a trial
    with both inserted and substituted words is counted once, as 'insertions'.

    Trials with a missing reference are skipped. A missing prediction (or a
    missing ``pred_sentence`` column) is treated as an empty hypothesis.

    Args:
        df: DataFrame with a ``true_sentence`` column and, normally, a
            ``pred_sentence`` column.

    Returns:
        ``None`` if ``df`` is ``None`` or lacks ``true_sentence``. Otherwise a
        dict where 'substitutions', 'insertions' and 'deletions' map to lists
        of ``(true_words, pred_words)`` tuples, and 'correct' maps to the list
        of matching rows (pandas Series).
    """
    if df is None or 'true_sentence' not in df.columns:
        return None

    error_analysis = {
        'substitutions': [],
        'insertions': [],
        'deletions': [],
        'correct': []
    }
    has_pred_col = 'pred_sentence' in df.columns

    for _, row in df.iterrows():
        if pd.isna(row['true_sentence']):
            continue
        pred_sentence = row['pred_sentence'] if has_pred_col else None
        true_words = row['true_sentence'].lower().split()
        # An undecoded trial is an empty hypothesis (all deletions), the same
        # convention calculate_metrics uses, rather than being silently dropped.
        pred_words = pred_sentence.lower().split() if pd.notna(pred_sentence) else []

        # Simple length-based bucketing (can be made more sophisticated)
        if true_words == pred_words:
            error_analysis['correct'].append(row)
        elif len(pred_words) > len(true_words):
            error_analysis['insertions'].append((true_words, pred_words))
        elif len(pred_words) < len(true_words):
            error_analysis['deletions'].append((true_words, pred_words))
        else:
            error_analysis['substitutions'].append((true_words, pred_words))

    return error_analysis

if df_val is not None:
    error_analysis = analyze_error_types(df_val)
    if error_analysis:
        # (display label, result key, pie colour), in plotting order.
        categories = (
            ('Correct', 'correct', 'green'),
            ('Substitutions', 'substitutions', 'orange'),
            ('Insertions', 'insertions', 'red'),
            ('Deletions', 'deletions', 'purple'),
        )
        error_counts = [len(error_analysis[key]) for _, key, _ in categories]

        print("\n=== Error Type Distribution ===")
        for (label, _, _), count in zip(categories, error_counts):
            print(f"{label}: {count}")

        # Visualize error distribution
        plt.figure(figsize=(10, 6))
        plt.pie(
            error_counts,
            labels=[label for label, _, _ in categories],
            colors=[color for _, _, color in categories],
            autopct='%1.1f%%',
            startangle=90,
        )
        plt.title('Error Type Distribution', fontsize=14)
        plt.show()

## 8. Example Predictions

Show examples of predictions to understand model behavior.

In [None]:
if df_val is not None:
    print("=== Example Predictions ===")

    # Only rows with a reference transcript can be scored. Cap the sample at the
    # size of that subset: capping at len(df_val) makes DataFrame.sample raise
    # ValueError whenever fewer than 10 rows have a reference.
    scorable_df = df_val[df_val['true_sentence'].notna()]
    sample_df = scorable_df.sample(min(10, len(scorable_df)))
    print(f"\nShowing {len(sample_df)} random examples:\n")

    for idx, row in sample_df.iterrows():
        # 'trial'/'corpus' may be absent from the CSV; fall back instead of raising KeyError.
        print(f"Trial {row.get('trial', idx)} ({row.get('corpus', 'unknown')})")
        print(f"  True:  {row['true_sentence']}")
        print(f"  Pred:  {row['pred_sentence']}")

        # Calculate WER for this trial
        true_words = row['true_sentence'].split()
        pred_words = row['pred_sentence'].split() if pd.notna(row['pred_sentence']) else []
        if true_words:
            wer = calculate_error_rate(true_words, pred_words) / len(true_words) * 100
            print(f"  WER:   {wer:.1f}%")
        else:
            # Whitespace-only reference: per-trial WER would divide by zero.
            print("  WER:   n/a (empty reference)")
        print()

## 9. Summary and Next Steps

### Current Baseline Performance
- **Model**: RNN + 1-gram LM (or as configured)
- **WER**: [To be filled after running experiments]
- **PER**: [To be filled after running experiments]

### Next Experiments to Run
1. ✓ Baseline 1-gram (already available)
2. ⏳ 3-gram LM (download required)
3. ⏳ 3-gram + OPT-6.7b
4. ⏳ 5-gram LM (download required, high RAM)
5. ⏳ Parameter tuning for rule compliance

### Expected Improvements
- **3-gram vs 1-gram**: better grammatical fluency; expected to reduce WER by ~5-10%
- **Neural LM rescoring**: better use of long-range context; expected to reduce WER by ~3-5%
- **5-gram**: Best n-gram performance, but high memory cost

### Documentation to Create
- Parameter tuning guide
- Rule compliance validation
- Linguistic improvement examples
- Speed/memory trade-off analysis