In [22]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/emotion/emotion-emotion_69k.csv


# **TASK 1 - Preprocessing**

In [23]:
import pandas as pd
import re
import json
from collections import Counter
import os
import unicodedata  # For unicode normalization in text cleaning

# ---------- Paths ----------
if os.path.exists('/kaggle/input'):
    data_path = '/kaggle/input/emotion/emotion-emotion_69k.csv'
    output_dir = '/kaggle/working'
else:
    data_path = 'dataset/emotion-emotion_69k.csv'
    output_dir = '.'

# ---------- Configuration ----------
MAX_SEQ_LENGTH = 128  # Cap combined sequences (recommended for project)
MIN_WORD_FREQ = 2     # Minimum frequency to include word in vocab (reduces vocab size, improves stability)

SPECIAL_TOKENS = {
    '<pad>': 0,
    '<bos>': 1,
    '<eos>': 2,
    '<unk>': 3
}

print("="*70)
print("TASK 1: PREPROCESSING FOR EMPATHETIC DIALOGUES")
print("="*70)
print(f"✓ Required modules loaded: pandas, re, json, unicodedata")


# ---------- Load dataset ----------
print("\n📂 Loading dataset...")
df = pd.read_csv(data_path)
print(f"✓ Data loaded: {df.shape[0]} samples")
print(f"✓ Columns: {df.columns.tolist()}")

# Validate dataset structure
needed_columns = {"Situation", "emotion", "empathetic_dialogues", "labels"}
missing = [c for c in needed_columns if c not in df.columns]
if missing:
    print(f"⚠️  Expected columns missing: {missing}")
    print("   Make sure you're using the Empathetic Dialogues dataset (not a classification CSV).")
    raise ValueError(f"Missing required columns: {missing}")

# Filter out rows with missing critical data
df = df.dropna(subset=['Situation', 'emotion', 'empathetic_dialogues', 'labels'])
df = df.reset_index(drop=True)
print(f"✓ After removing NaN: {df.shape[0]} samples")

# ---------- Parse empathetic dialogues ----------
print("\n🔍 Parsing empathetic dialogues...")

def extract_customer_utterance(dialogue_text):
    """
    Extract the last customer utterance from the dialogue.
    Format: "Customer :utterance1\nAgent :response1\nCustomer :utterance2\nAgent :"
    We want the last customer utterance before the final "Agent :"
    Robust to spacing/casing variations.
    """
    if pd.isna(dialogue_text):
        return ""
    
    # Split by newlines
    lines = dialogue_text.strip().split('\n')
    
    # Find the last "Customer :" line (case-insensitive, flexible spacing)
    customer_utterances = []
    for line in lines:
        # Check for customer prefix (case-insensitive)
        line_lower = line.strip().lower()
        if line_lower.startswith('customer :') or line_lower.startswith('customer:'):
            # Extract everything after "Customer :" or "Customer:"
            if ':' in line:
                utterance = line.split(':', 1)[1].strip()
                customer_utterances.append(utterance)
    
    # Return the last customer utterance (the one we need to respond to)
    if customer_utterances:
        return customer_utterances[-1]
    return ""

# Extract customer utterances
df['customer_utterance'] = df['empathetic_dialogues'].apply(extract_customer_utterance)

# Check if extraction worked
print(f"✓ Extracted customer utterances")
print(f"✓ Sample: '{df['customer_utterance'].iloc[0][:80]}...'")

# Filter out rows where we couldn't extract customer utterance
df = df[df['customer_utterance'].str.len() > 0]
df = df.reset_index(drop=True)
print(f"✓ After filtering empty utterances: {df.shape[0]} samples")

# ---------- Hard Text Normalization ----------
def normalize_unicode(text):
    """Normalize unicode to NFKC form (cleans weird widths & punctuation)"""
    return unicodedata.normalize("NFKC", text)

def remove_emojis(text):
    """Remove emojis and other non-standard unicode characters"""
    return ''.join(char for char in text if unicodedata.category(char)[0] != 'C' 
                   and unicodedata.category(char) not in ['So', 'Sk'])

def standardize_quotes(text):
    """Standardize different quote types to simple double quotes"""
    # Replace smart quotes, curly quotes, etc.
    quote_chars = ['\u201c', '\u201d', '\u2018', '\u2019', '`', '\u00b4', '\u2032']
    for char in quote_chars:
        text = text.replace(char, '"')
    return text

def clean_text_hard(text):
    """Hard normalization: lowercase, collapse spaces, standardize quotes, strip emojis"""
    if pd.isna(text):
        return ""
    
    # Convert to string
    text = str(text)
    
    # Unicode normalization (NFKC cleans weird widths & punctuation)
    text = normalize_unicode(text)
    
    # Lowercase
    text = text.lower()
    
    # Remove emojis
    text = remove_emojis(text)
    
    # Standardize quotes
    text = standardize_quotes(text)
    
    # Remove extra punctuation repetitions (e.g., "!!!" -> "!")
    text = re.sub(r'([.!?]){2,}', r'\1', text)
    text = re.sub(r'([,;:]){2,}', r'\1', text)
    
    # Add space around punctuation for better tokenization
    text = re.sub(r'([.,!?;:])', r' \1 ', text)
    
    # Collapse multiple spaces into one
    text = re.sub(r'\s+', ' ', text)
    
    # Strip leading/trailing whitespace
    text = text.strip()
    
    return text

def truncate_sequence(text, max_len=MAX_SEQ_LENGTH):
    """Truncate text to max_len tokens"""
    tokens = text.split()
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    return ' '.join(tokens)

# ---------- Apply normalization ----------
print("\n🧹 Applying hard normalization...")
# Note: We normalize but DON'T truncate yet - truncation happens after templating
df['Situation'] = df['Situation'].apply(clean_text_hard)
df['customer_utterance'] = df['customer_utterance'].apply(clean_text_hard)
df['labels'] = df['labels'].apply(clean_text_hard)
df['emotion'] = df['emotion'].str.lower().str.strip()

# Filter out invalid emotions (but allow multi-word emotions with spaces)
df = df[df['emotion'].str.len() < 50]
df = df[df['emotion'].str.len() > 0]
df = df.reset_index(drop=True)

print(f"✓ After filtering emotions: {df.shape[0]} samples")
print(f"✓ Unique emotions: {df['emotion'].nunique()}")

# ---------- Split dataset (80/10/10) ----------
print("\n📊 Splitting dataset...")
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
n = len(df)
train = df[:int(0.8*n)].copy()
val = df[int(0.8*n):int(0.9*n)].copy()
test = df[int(0.9*n):].copy()

# Reset indices for each split
train = train.reset_index(drop=True)
val = val.reset_index(drop=True)
test = test.reset_index(drop=True)

# Filter out empty pairs to avoid NaNs during training
train = train[(train["labels"].str.len() > 0) & 
              (train["Situation"].str.len() > 0) & 
              (train["customer_utterance"].str.len() > 0)].reset_index(drop=True)
val = val[(val["labels"].str.len() > 0) & 
          (val["Situation"].str.len() > 0) & 
          (val["customer_utterance"].str.len() > 0)].reset_index(drop=True)
test = test[(test["labels"].str.len() > 0) & 
            (test["Situation"].str.len() > 0) & 
            (test["customer_utterance"].str.len() > 0)].reset_index(drop=True)

print(f"✓ Train: {len(train)} ({len(train)/n*100:.1f}%)")
print(f"✓ Val: {len(val)} ({len(val)/n*100:.1f}%)")
print(f"✓ Test: {len(test)} ({len(test)/n*100:.1f}%)")

# ---------- Build vocabulary ONLY on training data ----------
print("\n📚 Building vocabulary on TRAIN set only...")

def tokenize(text):
    """Simple whitespace tokenization"""
    return text.split()

word_counts = Counter()

# Count words only from training data (all three text fields)
for _, row in train.iterrows():
    word_counts.update(tokenize(row['Situation']))
    word_counts.update(tokenize(row['customer_utterance']))
    word_counts.update(tokenize(row['labels']))

print(f"✓ Total unique words in train: {len(word_counts)}")
print(f"✓ Total word occurrences: {sum(word_counts.values())}")

# Show frequency distribution
freq_1 = sum(1 for cnt in word_counts.values() if cnt == 1)
freq_2_5 = sum(1 for cnt in word_counts.values() if 2 <= cnt <= 5)
freq_6_plus = sum(1 for cnt in word_counts.values() if cnt > 5)
print(f"  • Frequency = 1: {freq_1} words (singletons)")
print(f"  • Frequency 2-5: {freq_2_5} words")
print(f"  • Frequency > 5: {freq_6_plus} words")

# ---------- Create vocabulary with special tokens ----------
# Start with special tokens (fixed IDs)
vocab = dict(SPECIAL_TOKENS)

# Add template tokens FIRST (critical - these appear in input template)
TEMPLATE_TOKENS = ["emotion", "situation", "customer", "agent", "|", ":"]
for tok in TEMPLATE_TOKENS:
    if tok not in vocab:
        vocab[tok] = len(vocab)

print(f"✓ Added {len(TEMPLATE_TOKENS)} template tokens (emotion, situation, customer, agent, |, :)")

# Add emotion-specific tokens (replace spaces with underscores for token names)
unique_emotions = sorted(train['emotion'].unique())
emotion_tokens = [f'<emotion_{emotion.replace(" ", "_")}>' for emotion in unique_emotions]

for emotion_tok in emotion_tokens:
    vocab[emotion_tok] = len(vocab)

print(f"✓ Added {len(emotion_tokens)} emotion tokens")

# Add words from training data (with frequency threshold)
words_added = 0
words_filtered = 0
for word in sorted(word_counts.keys()):
    if word not in vocab:
        if word_counts[word] >= MIN_WORD_FREQ:
            vocab[word] = len(vocab)
            words_added += 1
        else:
            words_filtered += 1

print(f"✓ Added {words_added} word tokens (freq >= {MIN_WORD_FREQ})")
print(f"✓ Filtered {words_filtered} rare words (freq < {MIN_WORD_FREQ}, will map to <unk>)")

# Create reverse mapping
idx2word = {i: w for w, i in vocab.items()}

print(f"\n📊 Vocabulary statistics:")
print(f"  • Total vocab size: {len(vocab)}")
print(f"  • Special tokens: {len(SPECIAL_TOKENS)}")
print(f"  • Template tokens: {len(TEMPLATE_TOKENS)}")
print(f"  • Emotion tokens: {len(emotion_tokens)}")
print(f"  • Word tokens: {words_added}")
print(f"  • Rare words (filtered to <unk>): {words_filtered}")

# ---------- Create input templates and convert to IDs ----------
print("\n🔧 Creating input templates and tokenizing to IDs...")

def create_input_template(emotion, situation, customer_utterance):
    """
    Create input template EXACTLY as specified in Task 2 assignment:
    emotion: {emotion} | situation: {situation} | customer: {utterance} | agent:
    
    Using emotion token for the emotion value, with exact literal markers per spec.
    """
    # Use emotion token for the emotion value
    emo_tok = f'<emotion_{emotion.replace(" ", "_")}>'
    
    # Build template with exact literal markers as per assignment spec
    # Format: emotion: {emo_tok} | situation: {situation} | customer: {utterance} | agent:
    template = f"emotion : {emo_tok} | situation : {situation} | customer : {customer_utterance} | agent :"
    return template

def text_to_ids(text, vocab, add_bos=False, add_eos=False):
    """Convert text to list of token IDs, with <unk> for OOV words"""
    tokens = tokenize(text)
    ids = []
    
    if add_bos:
        ids.append(SPECIAL_TOKENS['<bos>'])
    
    for token in tokens:
        ids.append(vocab.get(token, SPECIAL_TOKENS['<unk>']))
    
    if add_eos:
        ids.append(SPECIAL_TOKENS['<eos>'])
    
    return ids

def prepare_sequences(df_split, split_name):
    """Prepare source and target ID sequences for a dataset split"""
    sequences = []
    
    for idx, row in df_split.iterrows():
        # Create source: emotion + situation + customer utterance
        src_text = create_input_template(
            row['emotion'],
            row['Situation'],
            row['customer_utterance']
        )
        
        # Target: agent response (must have <bos> at start and <eos> at end)
        tgt_text = row['labels']
        
        # Convert to IDs
        src_ids = text_to_ids(src_text, vocab, add_bos=False, add_eos=False)
        tgt_ids = text_to_ids(tgt_text, vocab, add_bos=True, add_eos=True)  # <bos> ... <eos>
        
        # Apply sequence length cap AFTER templating
        if len(src_ids) > MAX_SEQ_LENGTH:
            src_ids = src_ids[:MAX_SEQ_LENGTH]
        
        if len(tgt_ids) > MAX_SEQ_LENGTH:
            # Keep <bos> at start and <eos> at end, truncate middle
            tgt_ids = [tgt_ids[0]] + tgt_ids[1:MAX_SEQ_LENGTH-1] + [tgt_ids[-1]]
        
        sequences.append({
            'src_ids': src_ids,
            'tgt_ids': tgt_ids,
            'emotion': row['emotion'],
            'situation': row['Situation'],
            'customer_utterance': row['customer_utterance'],
            'response': row['labels']
        })
    
    print(f"  ✓ {split_name}: {len(sequences)} sequences")
    return sequences

# Prepare sequences for all splits
train_sequences = prepare_sequences(train, 'Train')
val_sequences = prepare_sequences(val, 'Val')
test_sequences = prepare_sequences(test, 'Test')

# ---------- Sanity checks ----------
print("\n🔍 Sanity checks on preprocessed sequences...")

# Check template tokens are in vocab (not mapping to <unk>)
print("  Template tokens in vocab:")
for tok in TEMPLATE_TOKENS:
    print(f"    '{tok}' -> ID {vocab.get(tok, 'MISSING!')}")

# Check a few source sequences contain template tokens
print("\n  Checking first 3 sources for template tokens:")
for i in range(min(3, len(train_sequences))):
    s = train_sequences[i]["src_ids"]
    print(f"    Source {i}:")
    print(f"      has 'emotion' token? {vocab.get('emotion') in s}")
    print(f"      has ':' token?       {vocab.get(':') in s}")
    print(f"      has '|' token?       {vocab.get('|') in s}")
    print(f"      has 'agent' token?   {vocab.get('agent') in s}")

# Check target format
print("\n  Target format validation:")
print(f"    Example target starts with <bos> (1)? {train_sequences[0]['tgt_ids'][0] == SPECIAL_TOKENS['<bos>']}")
print(f"    Example target ends with <eos> (2)?   {train_sequences[0]['tgt_ids'][-1] == SPECIAL_TOKENS['<eos>']}")

# Check for any all-UNK sequences (would indicate vocab issues)
unk_id = SPECIAL_TOKENS['<unk>']
src_all_unk = sum(1 for seq in train_sequences if all(id == unk_id for id in seq['src_ids']))
tgt_all_unk = sum(1 for seq in train_sequences if all(id == unk_id for id in seq['tgt_ids'][1:-1]))  # Skip BOS/EOS
print(f"\n  Sequences with all <unk> tokens:")
print(f"    Source: {src_all_unk} (should be 0)")
print(f"    Target: {tgt_all_unk} (should be 0)")

if src_all_unk > 0 or tgt_all_unk > 0:
    print("    ⚠️  WARNING: Some sequences are all <unk>! Check vocab building.")

# Check emotion coverage across splits
print("\n  Emotion coverage across splits:")
for split_name, split_df in [("Train", train), ("Val", val), ("Test", test)]:
    unique_emos = split_df["emotion"].nunique()
    print(f"    {split_name}: {unique_emos} unique emotions, {len(split_df)} samples")

# ---------- Save preprocessed data ----------
print("\n💾 Saving preprocessed data...")
os.makedirs(output_dir, exist_ok=True)

# Save CSV splits (with all columns for reference)
train.to_csv(os.path.join(output_dir, 'train.csv'), index=False)
val.to_csv(os.path.join(output_dir, 'val.csv'), index=False)
test.to_csv(os.path.join(output_dir, 'test.csv'), index=False)
print("  ✓ Saved train.csv, val.csv, test.csv")

# Save vocabulary
with open(os.path.join(output_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)

with open(os.path.join(output_dir, 'idx2word.json'), 'w', encoding='utf-8') as f:
    json.dump(idx2word, f, ensure_ascii=False, indent=2)
print("  ✓ Saved vocab.json, idx2word.json")

# Save special tokens for reference
with open(os.path.join(output_dir, 'special_tokens.json'), 'w', encoding='utf-8') as f:
    json.dump({
        'special_tokens': SPECIAL_TOKENS,
        'template_tokens': {tok: vocab[tok] for tok in TEMPLATE_TOKENS},
        'emotion_tokens': {tok: vocab[tok] for tok in emotion_tokens},
        'max_seq_length': MAX_SEQ_LENGTH,
        'min_word_freq': MIN_WORD_FREQ
    }, f, ensure_ascii=False, indent=2)
print("  ✓ Saved special_tokens.json")

# Save word frequencies for analysis/reporting
with open(os.path.join(output_dir, 'word_freq_train.json'), 'w', encoding='utf-8') as f:
    json.dump(dict(word_counts.most_common(1000)), f, ensure_ascii=False, indent=2)
print("  ✓ Saved word_freq_train.json (top 1000 words)")

# Save tokenized ID sequences (JSONL format for Task 4)
def save_ids_jsonl(sequences, filepath):
    """Save sequences in JSONL format"""
    with open(filepath, 'w', encoding='utf-8') as f:
        for seq in sequences:
            json.dump({
                'src_ids': seq['src_ids'],
                'tgt_ids': seq['tgt_ids']
            }, f, ensure_ascii=False)
            f.write('\n')

save_ids_jsonl(train_sequences, os.path.join(output_dir, 'train_ids.jsonl'))
save_ids_jsonl(val_sequences, os.path.join(output_dir, 'val_ids.jsonl'))
save_ids_jsonl(test_sequences, os.path.join(output_dir, 'test_ids.jsonl'))
print("  ✓ Saved train_ids.jsonl, val_ids.jsonl, test_ids.jsonl")

# Also save human-readable pair files for inspection
def save_pairs_csv(sequences, df_split, filepath):
    """Save source-target pairs in CSV format for easy inspection"""
    pairs_df = pd.DataFrame({
        'emotion': df_split['emotion'].values,
        'situation': df_split['Situation'].values,
        'customer_utterance': df_split['customer_utterance'].values,
        'agent_response': df_split['labels'].values,
        'src_length': [len(seq['src_ids']) for seq in sequences],
        'tgt_length': [len(seq['tgt_ids']) for seq in sequences]
    })
    pairs_df.to_csv(filepath, index=False)

save_pairs_csv(train_sequences, train, os.path.join(output_dir, 'train_pairs.csv'))
save_pairs_csv(val_sequences, val, os.path.join(output_dir, 'val_pairs.csv'))
save_pairs_csv(test_sequences, test, os.path.join(output_dir, 'test_pairs.csv'))
print("  ✓ Saved train_pairs.csv, val_pairs.csv, test_pairs.csv")

# ---------- Statistics ----------
print("\n" + "="*70)
print("✅ PREPROCESSING COMPLETE")
print("="*70)

print("\n📁 Dataset splits saved:")
print(f"  • train.csv ({len(train)} samples)")
print(f"  • val.csv ({len(val)} samples)")
print(f"  • test.csv ({len(test)} samples)")

print("\n📝 Tokenized ID files saved (for Task 4):")
print(f"  • train_ids.jsonl ({len(train_sequences)} sequences)")
print(f"  • val_ids.jsonl ({len(val_sequences)} sequences)")
print(f"  • test_ids.jsonl ({len(test_sequences)} sequences)")

print("\n📋 Pair files saved (for inspection):")
print(f"  • train_pairs.csv, val_pairs.csv, test_pairs.csv")

print("\n📚 Vocabulary files saved:")
print(f"  • vocab.json (word → id mapping, {len(vocab)} entries)")
print(f"  • idx2word.json (id → word mapping, {len(idx2word)} entries)")
print(f"  • special_tokens.json (special token definitions)")

print("\n🔑 Special tokens:")
for token, idx in SPECIAL_TOKENS.items():
    print(f"  {token}: {idx}")

print(f"\n📝 Template tokens ({len(TEMPLATE_TOKENS)} total):")
for tok in TEMPLATE_TOKENS:
    print(f"  '{tok}': {vocab[tok]}")

print(f"\n🎭 Emotion tokens ({len(emotion_tokens)} total):")
for tok in emotion_tokens[:10]:
    print(f"  {tok}: {vocab[tok]}")
if len(emotion_tokens) > 10:
    print(f"  ... and {len(emotion_tokens) - 10} more")

print(f"\n📏 Max sequence length: {MAX_SEQ_LENGTH} tokens (applied after templating)")

# Sample statistics
avg_src_len = sum(len(seq['src_ids']) for seq in train_sequences) / len(train_sequences)
avg_tgt_len = sum(len(seq['tgt_ids']) for seq in train_sequences) / len(train_sequences)
avg_situation_len = train['Situation'].apply(lambda x: len(x.split())).mean()
avg_utterance_len = train['customer_utterance'].apply(lambda x: len(x.split())).mean()
avg_response_len = train['labels'].apply(lambda x: len(x.split())).mean()

# Count sequences that hit the cap
src_capped = sum(1 for seq in train_sequences if len(seq['src_ids']) == MAX_SEQ_LENGTH)
tgt_capped = sum(1 for seq in train_sequences if len(seq['tgt_ids']) == MAX_SEQ_LENGTH)

print(f"\n📊 Average lengths (train set):")
print(f"  • Situation: {avg_situation_len:.1f} tokens")
print(f"  • Customer utterance: {avg_utterance_len:.1f} tokens")
print(f"  • Agent response: {avg_response_len:.1f} tokens")
print(f"  • Source (full input template): {avg_src_len:.1f} IDs")
print(f"  • Target (with <bos> and <eos>): {avg_tgt_len:.1f} IDs")
print(f"  • Sequences capped at {MAX_SEQ_LENGTH}: src={src_capped}, tgt={tgt_capped}")

print("\n🔍 Example preprocessed sample:")
sample = train_sequences[0]
print(f"  Emotion: {sample['emotion']}")
print(f"  Situation: {sample['situation'][:80]}...")
print(f"  Customer: {sample['customer_utterance'][:80]}...")
print(f"  Response: {sample['response'][:80]}...")
print(f"  Source length: {len(sample['src_ids'])} IDs")
print(f"  Target length: {len(sample['tgt_ids'])} IDs (with <bos> and <eos>)")
print(f"  Source IDs (first 20): {sample['src_ids'][:20]}")
print(f"  Target IDs (first 10): {sample['tgt_ids'][:10]}")
print(f"  Target starts with <bos> (1)? {sample['tgt_ids'][0] == 1}")
print(f"  Target ends with <eos> (2)? {sample['tgt_ids'][-1] == 2}")

print("\n🔍 Input template format (Task 2 spec):")
sample_template = create_input_template(
    train.iloc[0]['emotion'],
    train.iloc[0]['Situation'][:40] + "...",
    train.iloc[0]['customer_utterance'][:40] + "..."
)
print(f"  {sample_template[:120]}...")

print("\n" + "="*70)
print("✨ Ready for Task 2: Sequence-to-Sequence Modeling")
print("="*70) 


TASK 1: PREPROCESSING FOR EMPATHETIC DIALOGUES
✓ Required modules loaded: pandas, re, json, unicodedata

📂 Loading dataset...
✓ Data loaded: 64636 samples
✓ Columns: ['Unnamed: 0', 'Situation', 'emotion', 'empathetic_dialogues', 'labels', 'Unnamed: 5', 'Unnamed: 6']
✓ After removing NaN: 64632 samples

🔍 Parsing empathetic dialogues...
✓ Extracted customer utterances
✓ Sample: 'I remember going to see the fireworks with my best friend. It was the first time...'
✓ After filtering empty utterances: 64591 samples

🧹 Applying hard normalization...
✓ After filtering emotions: 64591 samples
✓ Unique emotions: 32

📊 Splitting dataset...
✓ Train: 51672 (80.0%)
✓ Val: 6459 (10.0%)
✓ Test: 6460 (10.0%)

📚 Building vocabulary on TRAIN set only...
✓ Total unique words in train: 21286
✓ Total word occurrences: 2554215
  • Frequency = 1: 4364 words (singletons)
  • Frequency 2-5: 8592 words
  • Frequency > 5: 8330 words
✓ Added 6 template tokens (emotion, situation, customer, agent, |, :)
✓ Added 32

# **TASK 2 - Input and Output Definition**

In [24]:
# ===== Task-2: exact X/Y format (with checks) =====
import os, json, re, pandas as pd

data_dir = "/kaggle/working" if os.path.exists("/kaggle/input") else "."
with open(f"{data_dir}/vocab.json","r") as f: stoi = json.load(f)

PAD, BOS, EOS, UNK = stoi["<pad>"], stoi["<bos>"], stoi["<eos>"], stoi["<unk>"]

CUSTOMER_RE = re.compile(r"customer\s*:?\s*(.*?)(?:\s*agent\s*:|$)", re.IGNORECASE)
def extract_customer_text(text: str) -> str:
    if pd.isna(text): return ""
    m = CUSTOMER_RE.search(str(text))
    return m.group(1).strip() if m else str(text).strip()

def tok(s): return str(s).split()
def enc(tokens): return [stoi.get(t, UNK) for t in tokens]

def pick_col(df, candidates, default=None):
    cmap = {c.lower(): c for c in df.columns}
    for name in candidates:
        if name.lower() in cmap:
            return cmap[name.lower()]
    return default

def make_pairs(in_csv, out_csv, out_jsonl):
    df_raw = pd.read_csv(in_csv)

    emotion_c   = pick_col(df_raw, ["emotion","feeling"])
    situation_c = pick_col(df_raw, ["situation","context","prompt","situation_text","Situation"])
    utter_c     = pick_col(df_raw, ["utterance","text","customer","message","empathetic_dialogues"])
    response_c  = pick_col(df_raw, ["response","reply","agent_reply","target","reference","output","gold","labels"])

    if utter_c is None or response_c is None:
        raise ValueError(f"{in_csv}: missing utterance/response-like columns. Have={list(df_raw.columns)}")

    utterance_series = (
        df_raw[utter_c].apply(extract_customer_text)
        if utter_c else pd.Series([""] * len(df_raw))
    )

    df = pd.DataFrame({
        "emotion":   df_raw[emotion_c]   if emotion_c   else "",
        "situation": df_raw[situation_c] if situation_c else "",
        "utterance": utterance_series,
        "response":  df_raw[response_c]
    }).fillna("")

    # exact X format
    df["input_text"] = (
        "Emotion: " + df["emotion"].astype(str).str.strip()
        + " | Situation: " + df["situation"].astype(str).str.strip()
        + " | Customer: " + df["utterance"].astype(str).str.strip()
        + " Agent:"
    )
    df["target_text"] = df["response"].astype(str).str.strip()

    # drop empty targets
    df = df[df["target_text"].ne("")].reset_index(drop=True)

    # save readable pairs
    df[["input_text","target_text"]].to_csv(out_csv, index=False)

    # save ids
    with open(out_jsonl, "w") as f:
        for _, r in df.iterrows():
            src_ids = enc(tok(r["input_text"]))
            tgt_ids = [BOS] + enc(tok(r["target_text"])) + [EOS]
            f.write(json.dumps({"src_ids": src_ids, "tgt_ids": tgt_ids}) + "\n")

    # sanity vs. brief (non-failing informative prints)
    ex1_x = "Emotion: sentimental | Situation: I remember going to the fireworks with my best friend... | Customer: This was a best friend. I miss her. Agent:"
    ex1_y = "Where has she gone?"
    ex2_x = "Emotion: afraid | Situation: I used to scare for darkness | Customer: it feels like hitting to blank wall when I see the darkness Agent:"
    ex2_y = "Oh ya? I don't really see how"
    print("Format check (examples from brief):")
    print("X ex1 ->", ex1_x[:110] + " ...")
    print("Y ex1 ->", ex1_y)
    print("X ex2 ->", ex2_x[:110] + " ...")
    print("Y ex2 ->", ex2_y)

for name in ("train","val","test"):
    make_pairs(f"{data_dir}/{name}.csv",
               f"{data_dir}/{name}_pairs.csv",
               f"{data_dir}/{name}_ids.jsonl")

print("✅ Task 2 aligned with spec — exact X format + IDs written.")


Format check (examples from brief):
X ex1 -> Emotion: sentimental | Situation: I remember going to the fireworks with my best friend... | Customer: This wa ...
Y ex1 -> Where has she gone?
X ex2 -> Emotion: afraid | Situation: I used to scare for darkness | Customer: it feels like hitting to blank wall when ...
Y ex2 -> Oh ya? I don't really see how
Format check (examples from brief):
X ex1 -> Emotion: sentimental | Situation: I remember going to the fireworks with my best friend... | Customer: This wa ...
Y ex1 -> Where has she gone?
X ex2 -> Emotion: afraid | Situation: I used to scare for darkness | Customer: it feels like hitting to blank wall when ...
Y ex2 -> Oh ya? I don't really see how
Format check (examples from brief):
X ex1 -> Emotion: sentimental | Situation: I remember going to the fireworks with my best friend... | Customer: This wa ...
Y ex1 -> Where has she gone?
X ex2 -> Emotion: afraid | Situation: I used to scare for darkness | Customer: it feels like hitting to bla

# **TASK 3 — Transformer Encoder–Decoder Model**

In [25]:
# ===== Task-3: spec-aligned Transformer (distinct layers + proper masks) =====
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # [1,L,D]
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Helper to guard against all-False mask rows (can cause softmax NaNs)
def ensure_nonempty_rows(mask):
    """Ensure each row has at least one True value to prevent softmax NaNs."""
    row_has_true = mask.any(dim=-1, keepdim=True)
    safe = mask.clone()
    safe[:, :, 0] |= ~row_has_true.squeeze(-1)
    return safe

# Multi-Head Attention that accepts boolean (preferred) or additive masks
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.dk = d_model // n_heads
        self.Wq, self.Wk, self.Wv = nn.Linear(d_model, d_model), nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        self.attn_drop = nn.Dropout(dropout)  # Dropout on attention weights
    def _split(self, x):
        B,L,D = x.size()
        return x.view(B, L, self.h, self.dk).transpose(1, 2)  # [B,h,L,dk]
    def _merge(self, x):
        B,h,L,dk = x.size()
        return x.transpose(1, 2).contiguous().view(B, L, h*dk)
    def forward(self, q, k, v, mask=None):
        Q, K, V = self._split(self.Wq(q)), self._split(self.Wk(k)), self._split(self.Wv(v))
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dk)  # [B,h,Lq,Lk]
        if mask is not None:
            if mask.dtype == torch.bool:  # True = keep, False = mask
                mask = ensure_nonempty_rows(mask)  # Guard against all-False rows
                scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
            else:  # additive mask (0 or -inf)
                scores = scores + mask.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)  # Safety fallback for any NaNs
        attn = self.attn_drop(attn)  # Dropout on attention weights
        out = attn @ V
        out = self._merge(out)
        return self.Wo(self.drop(out))

# FFN
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout)
        )
    def forward(self, x): return self.ff(x)

# Encoder/Decoder layers
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1, self.n2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
    def forward(self, x, attn_mask=None):
        x = self.n1(x + self.sa(x, x, x, attn_mask))
        x = self.n2(x + self.ff(x))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.msa = MultiHeadAttention(d_model, n_heads, dropout)
        self.xattn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1, self.n2, self.n3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        x = self.n1(x + self.msa(x, x, x, self_mask))
        x = self.n2(x + self.xattn(x, mem, mem, mem_mask))
        x = self.n3(x + self.ff(x))
        return x

# Stacks with DISTINCT layer instances
class Encoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, attn_mask=None):
        for lyr in self.layers:
            x = lyr(x, attn_mask)
        return x

class Decoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        for lyr in self.layers:
            x = lyr(x, mem, self_mask, mem_mask)
        return x

# Mask helpers
def make_pad_mask(ids, pad_idx):
    return (ids != pad_idx)  # [B,L] True=token, False=pad

def make_attn_mask(valid_q, valid_k):
    return valid_q.unsqueeze(2) & valid_k.unsqueeze(1)  # [B,Lq,Lk] bool

def make_causal_mask(L, device):
    return torch.tril(torch.ones(L, L, dtype=torch.bool, device=device))  # [L,L]

# Full model
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=2, num_layers=2, d_ff=2048, dropout=0.1, pad_idx=0, weight_tying=True):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        self.embed_drop = nn.Dropout(dropout)  # Dropout on embeddings
        self.encoder = Encoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.decoder = Decoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.out = nn.Linear(d_model, vocab_size)
        
        # Weight tying: share embeddings with output layer
        if weight_tying:
            self.out.weight = self.embed.weight

    def forward(self, src, tgt):
        device = src.device
        B, Ls = src.size()
        Lt = tgt.size(1)

        # masks (boolean)
        src_valid = make_pad_mask(src, self.pad_idx)          # [B,Ls]
        tgt_valid = make_pad_mask(tgt, self.pad_idx)          # [B,Lt]
        enc_self  = make_attn_mask(src_valid, src_valid)      # [B,Ls,Ls]
        dec_caus  = make_causal_mask(Lt, device)              # [Lt,Lt]
        dec_self  = dec_caus.unsqueeze(0).expand(B, -1, -1) & make_attn_mask(tgt_valid, tgt_valid)  # [B,Lt,Lt]
        cross     = make_attn_mask(tgt_valid, src_valid)      # [B,Lt,Ls]

        # embeddings + PE + dropout
        src_e = self.embed_drop(self.pos(self.embed(src) * math.sqrt(self.d_model)))
        tgt_e = self.embed_drop(self.pos(self.embed(tgt) * math.sqrt(self.d_model)))

        mem   = self.encoder(src_e, enc_self)
        dec   = self.decoder(tgt_e, mem, dec_self, cross)
        return self.out(dec)

# quick smoke test
if __name__ == "__main__":
    torch.manual_seed(0)
    V, PAD = 10000, 0
    m = TransformerModel(vocab_size=V, d_model=256, n_heads=2, num_layers=2, dropout=0.1, pad_idx=PAD, weight_tying=True)
    src = torch.randint(1, V, (2, 10)); src[:, -2:] = PAD
    tgt = torch.randint(1, V, (2, 9));  tgt[:, -1]  = PAD
    logits = m(src, tgt)
    print("Shape:", logits.shape)  # (2, 9, 10000)
    print("✅ Task 3 complete – spec-aligned Transformer with NaN guards, attention/embedding dropout, and weight tying.")


Shape: torch.Size([2, 9, 10000])
✅ Task 3 complete – spec-aligned Transformer with NaN guards, attention/embedding dropout, and weight tying.


# **Task 4 - Training & Hyperparameters**

In [None]:
# ============================================================
# TASK 4 — Training & Hyperparameters (final, Kaggle-ready)
# ============================================================

import os, json, math, random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ------------------------ Config ----------------------------
class Config:
    # ==== Assignment hyperparameters ====
    batch_size = 32          # set 32 or 64
    lr = 3e-4                # in [1e-4, 5e-4]
    betas = (0.9, 0.98)      # Adam betas
    epochs = 6
    grad_clip = 1.0

    # ==== Model ====
    d_model = 256
    n_heads = 2
    num_layers = 2
    d_ff = 2048
    dropout = 0.1
    max_seq_len = 128

    # ==== Paths ====
    data_dir = "/kaggle/working" if os.path.exists("/kaggle/input") else "."
    save_dir = "checkpoints"

    # ==== Device ====
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
cfg = Config()

# -------------------- Dataset (JSONL) -----------------------
class EmpatheticDataset(Dataset):
    """
    Expects JSONL lines like:
      {"src_ids": [...], "tgt_ids": [...]}
    Uses vocab.json for <pad>/<bos>/<eos> etc.
    """
    def __init__(self, jsonl_path, vocab, max_len=128):
        self.rows = []
        self.vocab = vocab
        self.max_len = max_len
        self.pad = vocab["<pad>"]

        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                ex = json.loads(line)
                s = ex["src_ids"][:max_len]
                t = ex["tgt_ids"][:max_len]
                self.rows.append((s, t))

    def __len__(self): return len(self.rows)

    def __getitem__(self, idx):
        s, t = self.rows[idx]
        return torch.tensor(s, dtype=torch.long), torch.tensor(t, dtype=torch.long)

    def collate_fn(self, batch):
        src_list, tgt_list = zip(*batch)
        src = nn.utils.rnn.pad_sequence(src_list, batch_first=True, padding_value=self.pad)
        tgt = nn.utils.rnn.pad_sequence(tgt_list, batch_first=True, padding_value=self.pad)
        return src, tgt

# ------------------- Transformer (from scratch) -------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))  # [1,L,D]
    def forward(self, x):  # x: [B,L,D]
        return x + self.pe[:, :x.size(1), :]

def make_pad_mask(ids, pad_idx):  # True=valid token
    return (ids != pad_idx)

def make_attn_mask(valid_q, valid_k):  # [B,Lq,Lk]
    return valid_q.unsqueeze(2) & valid_k.unsqueeze(1)

def make_causal_mask(L, device):      # [L,L]
    return torch.tril(torch.ones(L, L, dtype=torch.bool, device=device))

def _ensure_nonempty_rows(mask):
    # mask [B,Lq,Lk] (True=keep). Ensure each row has >=1 True to avoid NaN in softmax.
    has_true = mask.any(dim=-1, keepdim=True)
    mask = mask.clone()
    mask[:, :, 0] = mask[:, :, 0] | (~has_true.squeeze(-1))
    return mask

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.dk = d_model // n_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)

    def _split(self, x):  # [B,L,D] -> [B,h,L,dk]
        B,L,D = x.size()
        return x.view(B, L, self.h, self.dk).transpose(1,2)

    def _merge(self, x):  # [B,h,L,dk] -> [B,L,D]
        B,h,L,dk = x.size()
        return x.transpose(1,2).contiguous().view(B,L,h*dk)

    def forward(self, q, k, v, mask=None):
        Q, K, V = self._split(self.Wq(q)), self._split(self.Wk(k)), self._split(self.Wv(v))
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dk)  # [B,h,Lq,Lk]
        if mask is not None:
            if mask.dtype == torch.bool:
                mask = _ensure_nonempty_rows(mask)
                scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
            else:  # additive mask
                scores = scores + mask.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)
        out = attn @ V
        out = self._merge(out)
        return self.Wo(self.drop(out))

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout)
        )
    def forward(self, x): return self.ff(x)

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1, self.n2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
    def forward(self, x, attn_mask=None):
        x = self.n1(x + self.sa(x, x, x, attn_mask))
        x = self.n2(x + self.ff(x))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.msa = MultiHeadAttention(d_model, n_heads, dropout)
        self.xattn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1, self.n2, self.n3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        x = self.n1(x + self.msa(x, x, x, self_mask))
        x = self.n2(x + self.xattn(x, mem, mem, mem_mask))
        x = self.n3(x + self.ff(x))
        return x

class Encoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, attn_mask=None):
        for lyr in self.layers: x = lyr(x, attn_mask)
        return x

class Decoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        for lyr in self.layers: x = lyr(x, mem, self_mask, mem_mask)
        return x

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=2, num_layers=2, d_ff=2048, dropout=0.1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        self.encoder = Encoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.decoder = Decoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        device = src.device
        B, Ls = src.size()
        Lt = tgt.size(1)

        src_valid = make_pad_mask(src, self.pad_idx)
        tgt_valid = make_pad_mask(tgt, self.pad_idx)
        enc_self  = make_attn_mask(src_valid, src_valid)                # [B,Ls,Ls]
        dec_caus  = make_causal_mask(Lt, device)                        # [Lt,Lt]
        dec_self  = dec_caus.unsqueeze(0).expand(B,-1,-1) & make_attn_mask(tgt_valid, tgt_valid)
        cross     = make_attn_mask(tgt_valid, src_valid)                # [B,Lt,Ls]

        src_e = self.pos(self.embed(src) * math.sqrt(self.d_model))
        tgt_e = self.pos(self.embed(tgt) * math.sqrt(self.d_model))

        mem = self.encoder(src_e, enc_self)
        dec = self.decoder(tgt_e, mem, dec_self, cross)
        return self.out(dec)

# ------------------------ Inference --------------------------
def greedy_decode(model, src, vocab, max_len=128, device=None):
    if device is None: device = src.device
    model.eval()
    bos, eos = vocab["<bos>"], vocab["<eos>"]
    B = src.size(0)
    tgt = torch.full((B,1), bos, dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_len-1):
            logits = model(src, tgt)              # [B,T,V]
            next_tok = logits[:,-1,:].argmax(-1, keepdim=True)
            tgt = torch.cat([tgt, next_tok], dim=1)
            if (next_tok.squeeze(-1) == eos).all(): break
    return tgt[:,1:]  # strip BOS

# ------------------------ Metrics ---------------------------
def _install_metrics():
    try:
        import sacrebleu, rouge_score, evaluate  # noqa
    except Exception:
        import subprocess, sys
        subprocess.check_call([sys.executable,"-m","pip","install","-q","sacrebleu","rouge-score","evaluate"])

try:
    from sacrebleu import BLEU
    from rouge_score import rouge_scorer
    import evaluate as _hf_eval
    METRICS_OK = True
except Exception:
    _install_metrics()
    try:
        from sacrebleu import BLEU
        from rouge_score import rouge_scorer
        import evaluate as _hf_eval
        METRICS_OK = True
    except Exception:
        METRICS_OK = False
        BLEU = None; rouge_scorer = None; _hf_eval = None

class MetricsCalculator:
    def __init__(self, vocab):
        self.vocab = vocab
        self.idx2word = {v:k for k,v in vocab.items()}
        if METRICS_OK:
            self.bleu = BLEU()
            self.rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
            self.chrf = _hf_eval.load("chrf")
        else:
            self.bleu = self.rouge = self.chrf = None

    def ids_to_text(self, ids):
        out = []
        for x in ids:
            if x == self.vocab["<eos>"] or x == self.vocab["<pad>"]: break
            if x == self.vocab["<bos>"]: continue
            out.append(self.idx2word.get(int(x), "<unk>"))
        return " ".join(out)

    def bleu_score(self, preds, refs):
        if not self.bleu:
            p = [" ".join(set(a.split())) for a in preds]
            r = [" ".join(set(b.split())) for b in refs]
            match = sum(len(set(a.split()) & set(b.split())) for a,b in zip(p,r))
            denom = sum(max(len(set(a.split())), len(set(b.split()))) for a,b in zip(p,r)) or 1
            return 100.0*match/denom
        return self.bleu.corpus_score(preds, [refs]).score

    def rougeL(self, preds, refs):
        if not self.rouge:
            inter = [len(set(a.split()) & set(b.split()))/(len(set(b.split())) or 1) for a,b in zip(preds,refs)]
            return 100.0*float(np.mean(inter)) if inter else 0.0
        scores = [self.rouge.score(r, p)["rougeL"].fmeasure for p,r in zip(preds,refs)]
        return 100.0*float(np.mean(scores)) if scores else 0.0

    def chrf_score(self, preds, refs):
        if not self.chrf:
            vals = []
            for p,r in zip(preds,refs):
                A,B=set(p),set(r)
                if not B: vals.append(0.0); continue
                prec = len(A&B)/(len(A) or 1); rec = len(A&B)/len(B)
                f1 = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
                vals.append(100.0*f1)
            return float(np.mean(vals)) if vals else 0.0
        return float(self.chrf.compute(predictions=preds, references=refs)["score"])

# ---------------------- Train / Eval ------------------------
def calculate_perplexity(loss_val):
    return float(torch.exp(torch.tensor(loss_val)).item())

def train_epoch(model, loader, opt, criterion, device, pad_idx):
    model.train()
    total = 0.0
    for src, tgt in tqdm(loader, desc="Training"):
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing (shift)
        inp = tgt[:, :-1]
        out = tgt[:, 1:]

        opt.zero_grad()
        logits = model(src, inp)
        loss = criterion(logits.reshape(-1, logits.size(-1)), out.reshape(-1))
        if torch.isnan(loss): raise ValueError("NaN loss detected.")
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
        total += loss.item()
    return total/len(loader)

def evaluate(model, loader, criterion, device, vocab, metrics: MetricsCalculator):
    model.eval()
    total = 0.0
    preds, refs = [], []
    with torch.no_grad():
        for src, tgt in tqdm(loader, desc="Evaluating"):
            src, tgt = src.to(device), tgt.to(device)
            inp, out = tgt[:, :-1], tgt[:, 1:]
            logits = model(src, inp)
            loss = criterion(logits.reshape(-1, logits.size(-1)), out.reshape(-1))
            total += loss.item()

            # generate (greedy) for metrics
            gen = greedy_decode(model, src, vocab, max_len=cfg.max_seq_len, device=device)  # [B,T]
            for i in range(src.size(0)):
                preds.append(metrics.ids_to_text(gen[i].cpu().numpy()))
                refs.append(metrics.ids_to_text(out[i].cpu().numpy()))
    avg_loss = total/len(loader)
    ppl = calculate_perplexity(avg_loss)
    bleu = metrics.bleu_score(preds, refs)
    rougeL = metrics.rougeL(preds, refs)
    chrf = metrics.chrf_score(preds, refs)
    return avg_loss, ppl, bleu, rougeL, chrf

def save_ckpt(path, model, opt, epoch, extra):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
        **extra
    }, path)

# --------------------------- Main ---------------------------
def main():
    print(f"Device: {cfg.device}")
    os.makedirs(cfg.save_dir, exist_ok=True)

    # ---- Load vocab
    with open(os.path.join(cfg.data_dir, "vocab.json"), "r") as f:
        vocab = json.load(f)
    assert all(k in vocab for k in ["<pad>","<bos>","<eos>"]), "vocab.json must contain <pad>, <bos>, <eos>"
    pad = vocab["<pad>"]

    # ---- Datasets / Loaders
    train_ds = EmpatheticDataset(os.path.join(cfg.data_dir,"train_ids.jsonl"), vocab, cfg.max_seq_len)
    val_ds   = EmpatheticDataset(os.path.join(cfg.data_dir,"val_ids.jsonl"),   vocab, cfg.max_seq_len)
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              collate_fn=train_ds.collate_fn, num_workers=0)
    val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False,
                              collate_fn=val_ds.collate_fn, num_workers=0)

    # ---- Model / Optimizer / Loss
    model = TransformerModel(vocab_size=len(vocab), d_model=cfg.d_model, n_heads=cfg.n_heads,
                             num_layers=cfg.num_layers, d_ff=cfg.d_ff, dropout=cfg.dropout,
                             pad_idx=pad).to(cfg.device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr, betas=cfg.betas)
    criterion = nn.CrossEntropyLoss(ignore_index=pad)

    metrics_calc = MetricsCalculator(vocab)

    best_bleu = -1.0
    best_epoch = -1
    history = []

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("Start training...")
    for epoch in range(1, cfg.epochs+1):
        print(f"\nEpoch {epoch}/{cfg.epochs}")
        tr_loss = train_epoch(model, train_loader, opt, criterion, cfg.device, pad)
        vl_loss, ppl, bleu, rougeL, chrf = evaluate(model, val_loader, criterion, cfg.device, vocab, metrics_calc)

        print(f"Train Loss: {tr_loss:.4f}")
        print(f"Val   Loss: {vl_loss:.4f} | PPL: {ppl:.2f}")
        print(f"BLEU: {bleu:.2f} | ROUGE-L: {rougeL:.2f} | chrF: {chrf:.2f}")

        # save epoch checkpoint
        save_ckpt(
            os.path.join(cfg.save_dir, f"checkpoint_epoch_{epoch}.pt"),
            model, opt, epoch,
            extra={
                "metrics": {"val_loss":vl_loss,"ppl":ppl,"bleu":bleu,"rougeL":rougeL,"chrf":chrf},
                "config": vars(cfg),
                "vocab": vocab
            }
        )

        # save best (by BLEU)
        if bleu > best_bleu:
            best_bleu = bleu
            best_epoch = epoch
            save_ckpt(
                os.path.join(cfg.save_dir, "best_model.pt"),
                model, opt, epoch,
                extra={
                    "metrics": {"val_loss":vl_loss,"ppl":ppl,"bleu":bleu,"rougeL":rougeL,"chrf":chrf},
                    "config": vars(cfg),
                    "vocab": vocab
                }
            )
            print(f"🏆 New best model saved at checkpoints/best_model.pt (BLEU {bleu:.2f})")
        else:
            print(f"📈 Best BLEU so far: {best_bleu:.2f} (epoch {best_epoch})")

        history.append({
            "epoch": epoch,
            "train_loss": tr_loss,
            "val_loss": vl_loss,
            "perplexity": ppl,
            "bleu": bleu,
            "rouge_l": rougeL,
            "chrf": chrf
        })

    with open(os.path.join(cfg.save_dir,"training_history.json"),"w") as f:
        json.dump(history, f, indent=2)

    print("\nFiles in checkpoints/:")
    for fn in sorted(os.listdir(cfg.save_dir)):
        print(" -", fn)

    print(f"\n✅ Training complete. Best BLEU: {best_bleu:.2f} (epoch {best_epoch})")
    print(f"Best model path: {os.path.join(cfg.save_dir, 'best_model.pt')}")

if __name__ == "__main__":
    main()


Device: cuda
Parameters: 14,487,102
Start training...

Epoch 1/6


Training:  42%|████▏     | 683/1615 [00:22<00:32, 28.84it/s]

# **Task 5 - Evaluation**

In [None]:

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import os
from tqdm import tqdm
import numpy as np
import math
import torch.nn.functional as F
import pandas as pd
from collections import defaultdict
import random

# ===== METRICS SETUP (robust, auto-install) =====
METRICS_AVAILABLE = False
def ensure_metrics():
    """Auto-install and import metrics packages (sacrebleu + rouge-score)"""
    global METRICS_AVAILABLE, BLEU, CHRF, rouge_scorer
    try:
        from sacrebleu import BLEU
        from sacrebleu.metrics import CHRF
        from rouge_score import rouge_scorer
        METRICS_AVAILABLE = True
        return
    except Exception:
        print("Installing metrics packages...")
        import sys, subprocess, importlib
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "sacrebleu", "rouge-score"])
        importlib.invalidate_caches()
        from sacrebleu import BLEU
        from sacrebleu.metrics import CHRF
        from rouge_score import rouge_scorer
        METRICS_AVAILABLE = True
        print("✓ Metrics packages installed")

ensure_metrics()

# ===== EMBEDDED MODEL ARCHITECTURE =====
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# Helper to prevent all-False mask rows (causes NaN in softmax)
def ensure_nonempty_rows(mask):
    """Ensure each row has at least one True value to prevent NaN in softmax"""
    # mask: [B, Lq, Lk] boolean
    row_has_true = mask.any(dim=-1, keepdim=True)  # [B, Lq, 1]
    # For rows with all False, set first column to True
    mask = mask.clone()
    mask[:, :, 0] = mask[:, :, 0] | (~row_has_true.squeeze(-1))
    return mask

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.dk = d_model // n_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
    
    def forward(self, q, k, v, mask=None):
        B, L, D = q.size()
        Q = self.Wq(q).view(B, L, self.h, self.dk).transpose(1, 2)
        K = self.Wk(k).view(B, -1, self.h, self.dk).transpose(1, 2)
        V = self.Wv(v).view(B, -1, self.h, self.dk).transpose(1, 2)
        
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dk)
        if mask is not None:
            if mask.dtype == torch.bool:
                mask = ensure_nonempty_rows(mask)  # Prevent all-False rows
                scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
            else:
                scores = scores + mask.unsqueeze(1)
        
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0)  # Safety: sanitize any NaN
        out = attn @ V
        out = out.transpose(1, 2).contiguous().view(B, L, D)
        return self.Wo(self.drop(out))

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.ff(x)

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
    def forward(self, x, attn_mask=None):
        x = self.n1(x + self.sa(x, x, x, attn_mask))
        x = self.n2(x + self.ff(x))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.msa = MultiHeadAttention(d_model, n_heads, dropout)
        self.xattn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.n1 = nn.LayerNorm(d_model)
        self.n2 = nn.LayerNorm(d_model)
        self.n3 = nn.LayerNorm(d_model)
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        x = self.n1(x + self.msa(x, x, x, self_mask))
        x = self.n2(x + self.xattn(x, mem, mem, mem_mask))
        x = self.n3(x + self.ff(x))
        return x

class Encoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, attn_mask=None):
        for lyr in self.layers:
            x = lyr(x, attn_mask)
        return x

class Decoder(nn.Module):
    def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
    def forward(self, x, mem, self_mask=None, mem_mask=None):
        for lyr in self.layers:
            x = lyr(x, mem, self_mask, mem_mask)
        return x

def make_pad_mask(ids, pad_idx):
    return (ids != pad_idx)

def make_attn_mask(valid_q, valid_k):
    return valid_q.unsqueeze(2) & valid_k.unsqueeze(1)

def make_causal_mask(L, device):
    return torch.tril(torch.ones(L, L, dtype=torch.bool, device=device))

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=2, num_layers=2, d_ff=2048, dropout=0.1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        self.encoder = Encoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.decoder = Decoder(d_model, n_heads, num_layers, d_ff, dropout)
        self.out = nn.Linear(d_model, vocab_size)
    
    def forward(self, src, tgt):
        device = src.device
        B, Ls = src.size()
        Lt = tgt.size(1)
        
        src_valid = make_pad_mask(src, self.pad_idx)
        tgt_valid = make_pad_mask(tgt, self.pad_idx)
        enc_self = make_attn_mask(src_valid, src_valid)
        dec_caus = make_causal_mask(Lt, device)
        dec_self = dec_caus.unsqueeze(0).expand(B, -1, -1) & make_attn_mask(tgt_valid, tgt_valid)
        cross = make_attn_mask(tgt_valid, src_valid)
        
        src_e = self.pos(self.embed(src) * math.sqrt(self.d_model))
        tgt_e = self.pos(self.embed(tgt) * math.sqrt(self.d_model))
        
        mem = self.encoder(src_e, enc_self)
        dec = self.decoder(tgt_e, mem, dec_self, cross)
        return self.out(dec)

# ===== GREEDY DECODING =====
def greedy_decode(model, src, vocab, max_len=128, device=None):
    """Generate response using greedy decoding (with safe length limiting)"""
    if device is None:
        device = src.device
    
    model.eval()
    batch_size = src.size(0)
    bos_token = vocab['<bos>']
    eos_token = vocab['<eos>']
    
    tgt = torch.full((batch_size, 1), bos_token, dtype=torch.long, device=device)
    
    with torch.no_grad():
        for _ in range(max_len - 1):  # Cap at max_len-1 to account for BOS
            output = model(src, tgt)
            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            if (next_token.squeeze(-1) == eos_token).all():
                break
    
    # Return sequences stripped of BOS
    return tgt[:, 1:]  # Remove BOS token

# ===== DATASET =====
class ChatbotDataset(Dataset):
    def __init__(self, jsonl_file, vocab, max_len=128):
        self.data = []
        self.vocab = vocab
        self.max_len = max_len
        self.pad_token = vocab['<pad>']
        
        with open(jsonl_file, 'r') as f:
            for line in f:
                item = json.loads(line.strip())
                src_ids = item['src_ids'][:max_len]
                tgt_ids = item['tgt_ids'][:max_len]
                self.data.append({'src': src_ids, 'tgt': tgt_ids})
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def collate_fn(self, batch):
        src_seqs = [item['src'] for item in batch]
        tgt_seqs = [item['tgt'] for item in batch]
        
        max_src_len = max(len(seq) for seq in src_seqs)
        max_tgt_len = max(len(seq) for seq in tgt_seqs)
        
        src_batch = torch.full((len(batch), max_src_len), self.pad_token, dtype=torch.long)
        tgt_batch = torch.full((len(batch), max_tgt_len), self.pad_token, dtype=torch.long)
        
        for i, (src, tgt) in enumerate(zip(src_seqs, tgt_seqs)):
            src_batch[i, :len(src)] = torch.tensor(src)
            tgt_batch[i, :len(tgt)] = torch.tensor(tgt)
        
        return src_batch, tgt_batch

# ===== METRICS CALCULATOR =====
class MetricsCalculator:
    def __init__(self, vocab):
        self.vocab = vocab
        self.idx2word = {v: k for k, v in vocab.items()}
        
        # Initialize metrics (guaranteed available after ensure_metrics())
        self.bleu = BLEU()
        self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        self.chrf_metric = CHRF()
    
    def ids_to_text(self, ids):
        """Convert token IDs to text"""
        tokens = []
        for id in ids:
            id = int(id)  # Cast to plain int (safer with numpy types)
            if id == self.vocab['<eos>'] or id == self.vocab['<pad>']:
                break
            if id == self.vocab['<bos>']:
                continue
            tokens.append(self.idx2word.get(id, '<unk>'))
        return ' '.join(tokens)
    
    def calculate_bleu(self, predictions, references):
        """Calculate BLEU score"""
        pred_texts = [self.ids_to_text(pred) for pred in predictions]
        ref_texts = [self.ids_to_text(ref) for ref in references]
        
        valid_pairs = [(p, r) for p, r in zip(pred_texts, ref_texts) if p.strip() and r.strip()]
        if not valid_pairs:
            return 0.0
        
        pred_texts, ref_texts = zip(*valid_pairs)
        
        try:
            score = self.bleu.corpus_score(pred_texts, [ref_texts])
            return score.score
        except:
            # Fallback: simple BLEU
            return self._simple_bleu(pred_texts, ref_texts)
    
    def _simple_bleu(self, pred_texts, ref_texts):
        matches = 0
        total = 0
        for pred, ref in zip(pred_texts, ref_texts):
            pred_words = set(pred.split())
            ref_words = set(ref.split())
            if pred_words and ref_words:
                matches += len(pred_words & ref_words)
                total += max(len(pred_words), len(ref_words))
        return (matches / total * 100) if total > 0 else 0.0
    
    def calculate_rouge_l(self, predictions, references):
        """Calculate ROUGE-L score"""
        pred_texts = [self.ids_to_text(pred) for pred in predictions]
        ref_texts = [self.ids_to_text(ref) for ref in references]
        
        scores = []
        for pred, ref in zip(pred_texts, ref_texts):
            if pred.strip() and ref.strip():
                try:
                    score = self.rouge_scorer.score(ref, pred)
                    scores.append(score['rougeL'].fmeasure)
                except:
                    # Fallback
                    pred_words = set(pred.lower().split())
                    ref_words = set(ref.lower().split())
                    if ref_words:
                        scores.append(len(pred_words & ref_words) / len(ref_words))
        
        return np.mean(scores) if scores else 0.0
    
    def calculate_chrf(self, predictions, references):
        """Calculate chrF score using sacrebleu"""
        pred_texts = [self.ids_to_text(p) for p in predictions]
        ref_texts = [self.ids_to_text(r) for r in references]
        
        valid_pairs = [(p, r) for p, r in zip(pred_texts, ref_texts) if p.strip() and r.strip()]
        if not valid_pairs:
            return 0.0
        
        pred_texts, ref_texts = zip(*valid_pairs)
        
        try:
            return self.chrf_metric.corpus_score(pred_texts, [ref_texts]).score
        except:
            # Fallback: character-level F1
            scores = []
            for pred, ref in zip(pred_texts, ref_texts):
                pred_chars = set(pred.lower())
                ref_chars = set(ref.lower())
                if ref_chars:
                    precision = len(pred_chars & ref_chars) / len(pred_chars) if pred_chars else 0
                    recall = len(pred_chars & ref_chars) / len(ref_chars)
                    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
                    scores.append(f1 * 100)
            return np.mean(scores) if scores else 0.0

# ===== COMPREHENSIVE EVALUATION =====
def evaluate_model(model, test_loader, vocab, metrics_calc, device, criterion=None):
    """
    Comprehensive model evaluation
    
    Returns:
        results: Dictionary with all metrics
        examples: List of sample outputs for qualitative analysis
    """
    model.eval()
    
    all_predictions = []
    all_references = []
    all_examples = []
    total_loss = 0
    num_batches = 0
    
    print("Running evaluation...")
    
    with torch.no_grad():
        for src, tgt in tqdm(test_loader, desc="Evaluating"):
            src, tgt = src.to(device), tgt.to(device)
            
            # Calculate loss if criterion provided
            if criterion:
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]
                output = model(src, tgt_input)
                loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
                total_loss += loss.item()
                num_batches += 1
            
            # Generate predictions
            generated = greedy_decode(model, src, vocab, max_len=128, device=device)
            
            # Process per-sample
            for i in range(src.size(0)):
                # Extract sequences
                src_seq = src[i].cpu().numpy()
                
                # Trim generated sequence at first <eos> (extra safety for mixed-length batches)
                pred_seq_t = generated[i].cpu()
                eos_pos = (pred_seq_t == vocab['<eos>']).nonzero(as_tuple=False)
                if len(eos_pos) > 0:
                    pred_seq_t = pred_seq_t[:eos_pos[0].item()+1]
                pred_seq = pred_seq_t.numpy()
                
                # Reference (drop BOS, cut at EOS/PAD)
                ref = tgt[i, 1:].cpu()
                eos_pos = (ref == vocab['<eos>']).nonzero(as_tuple=False)
                if len(eos_pos) > 0:
                    ref = ref[:eos_pos[0].item()]
                ref = ref[ref != vocab['<pad>']].numpy()
                
                all_predictions.append(pred_seq)
                all_references.append(ref)
                
                # Store examples for qualitative analysis
                if len(all_examples) < 50:  # Keep 50 examples
                    all_examples.append({
                        'input': metrics_calc.ids_to_text(src_seq),
                        'reference': metrics_calc.ids_to_text(ref),
                        'generated': metrics_calc.ids_to_text(pred_seq)
                    })
    
    # Calculate metrics
    print("\nCalculating metrics...")
    
    results = {}
    
    # Perplexity (if loss available)
    if criterion and num_batches > 0:
        avg_loss = total_loss / num_batches
        results['loss'] = avg_loss
        results['perplexity'] = torch.exp(torch.tensor(avg_loss)).item()
    
    # BLEU
    results['bleu'] = metrics_calc.calculate_bleu(all_predictions, all_references)
    
    # ROUGE-L
    results['rouge_l'] = metrics_calc.calculate_rouge_l(all_predictions, all_references)
    
    # chrF
    results['chrf'] = metrics_calc.calculate_chrf(all_predictions, all_references)
    
    return results, all_examples

# ===== HUMAN EVALUATION FRAMEWORK =====
def create_human_evaluation_template(examples, output_file="human_eval_template.csv"):
    """
    Create a CSV template for human evaluation
    
    Criteria:
    - Fluency: How natural/grammatical is the response? (1-5)
    - Relevance: How relevant is the response to the input? (1-5)
    - Adequacy: How well does it capture the meaning? (1-5)
    """
    
    # Select random subset for human evaluation
    eval_samples = random.sample(examples, min(100, len(examples)))
    
    data = []
    for idx, ex in enumerate(eval_samples, 1):
        data.append({
            'ID': idx,
            'Input': ex['input'],
            'Reference': ex['reference'],
            'Generated': ex['generated'],
            'Fluency (1-5)': '',
            'Relevance (1-5)': '',
            'Adequacy (1-5)': '',
            'Comments': ''
        })
    
    df = pd.DataFrame(data)
    df.to_csv(output_file, index=False)
    
    print(f"✓ Human evaluation template saved: {output_file}")
    print(f"  {len(data)} samples ready for annotation")
    print("\nEvaluation Guidelines:")
    print("  Fluency (1-5):")
    print("    5 = Perfect, native-like")
    print("    4 = Good, minor errors")
    print("    3 = Acceptable, some errors")
    print("    2 = Poor, many errors")
    print("    1 = Incomprehensible")
    print("\n  Relevance (1-5):")
    print("    5 = Perfectly relevant")
    print("    4 = Mostly relevant")
    print("    3 = Somewhat relevant")
    print("    2 = Barely relevant")
    print("    1 = Completely irrelevant")
    print("\n  Adequacy (1-5):")
    print("    5 = Fully captures meaning")
    print("    4 = Captures most meaning")
    print("    3 = Partial meaning")
    print("    2 = Little meaning")
    print("    1 = No meaning captured")
    
    return df

def analyze_human_evaluation(eval_file="human_eval_completed.csv"):
    """Analyze completed human evaluation"""
    df = pd.read_csv(eval_file)
    
    # Calculate averages
    fluency_avg = df['Fluency (1-5)'].mean()
    relevance_avg = df['Relevance (1-5)'].mean()
    adequacy_avg = df['Adequacy (1-5)'].mean()
    
    print("="*70)
    print("HUMAN EVALUATION RESULTS")
    print("="*70)
    print(f"Samples evaluated: {len(df)}")
    print(f"\nAverage Scores:")
    print(f"  Fluency:   {fluency_avg:.2f} / 5.0")
    print(f"  Relevance: {relevance_avg:.2f} / 5.0")
    print(f"  Adequacy:  {adequacy_avg:.2f} / 5.0")
    print(f"  Overall:   {(fluency_avg + relevance_avg + adequacy_avg) / 3:.2f} / 5.0")
    
    # Score distribution
    print(f"\nScore Distribution:")
    for metric in ['Fluency (1-5)', 'Relevance (1-5)', 'Adequacy (1-5)']:
        print(f"\n{metric}:")
        counts = df[metric].value_counts().sort_index()
        for score, count in counts.items():
            percentage = (count / len(df)) * 100
            print(f"  {int(score)}: {'█' * int(percentage/2)} {count} ({percentage:.1f}%)")
    
    return {
        'fluency': fluency_avg,
        'relevance': relevance_avg,
        'adequacy': adequacy_avg,
        'overall': (fluency_avg + relevance_avg + adequacy_avg) / 3
    }

# ===== QUALITATIVE ANALYSIS =====
def display_qualitative_examples(examples, num_examples=10, save_file=None):
    """Display qualitative comparison of model outputs"""
    
    print("\n" + "="*70)
    print("QUALITATIVE EXAMPLES: Model Output vs Ground Truth")
    print("="*70)
    
    # Select diverse examples
    selected = random.sample(examples, min(num_examples, len(examples)))
    
    results = []
    for idx, ex in enumerate(selected, 1):
        print(f"\n{'─'*70}")
        print(f"Example {idx}:")
        print(f"{'─'*70}")
        print(f"📥 Input:     {ex['input']}")
        print(f"✅ Reference: {ex['reference']}")
        print(f"🤖 Generated: {ex['generated']}")
        
        # Simple quality indicators
        ref_words = set(ex['reference'].lower().split())
        gen_words = set(ex['generated'].lower().split())
        overlap = len(ref_words & gen_words) / len(ref_words) if ref_words else 0
        
        print(f"📊 Word Overlap: {overlap*100:.1f}%")
        
        results.append({
            'Example': idx,
            'Input': ex['input'],
            'Reference': ex['reference'],
            'Generated': ex['generated'],
            'Word Overlap': f"{overlap*100:.1f}%"
        })
    
    if save_file:
        df = pd.DataFrame(results)
        df.to_csv(save_file, index=False)
        print(f"\n✓ Examples saved to: {save_file}")
    
    return results

# ===== MAIN EVALUATION FUNCTION =====
def run_full_evaluation(checkpoint_path, data_dir, output_dir="/kaggle/working/evaluation"):
    """
    Run complete evaluation pipeline
    
    Args:
        checkpoint_path: Path to trained model checkpoint
        data_dir: Directory with vocab.json and test_ids.jsonl
        output_dir: Where to save evaluation results
    
    Returns:
        metrics: Dictionary of all metrics
        examples: Qualitative examples
    """
    print("="*70)
    print("MODEL EVALUATION")
    print("="*70)
    
    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Load vocabulary
    vocab_path = f"{data_dir}/vocab.json"
    print(f"Loading vocab: {vocab_path}")
    with open(vocab_path, "r") as f:
        vocab = json.load(f)
    print(f"✓ Vocab size: {len(vocab)}")
    
    # Load model checkpoint first
    print(f"Loading model: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Prefer vocab from checkpoint if present (ensures ID consistency)
    if "vocab" in checkpoint and isinstance(checkpoint["vocab"], dict):
        print("✓ Using vocab from checkpoint (ensures ID consistency)")
        vocab = checkpoint["vocab"]
    
    # Load test data
    print("Loading test data...")
    test_dataset = ChatbotDataset(f"{data_dir}/test_ids.jsonl", vocab)
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=test_dataset.collate_fn,
        num_workers=0
    )
    print(f"✓ Test samples: {len(test_dataset)}")
    
    # Extract model config from checkpoint (if available) or use defaults
    config = checkpoint.get('config', {})
    d_model = config.get('d_model', 256)
    n_heads = config.get('n_heads', 2)
    num_layers = config.get('num_layers', 2)
    d_ff = config.get('d_ff', 2048)
    dropout = config.get('dropout', 0.1)
    
    print(f"✓ Model config: d_model={d_model}, n_heads={n_heads}, num_layers={num_layers}")
    
    model = TransformerModel(
        vocab_size=len(vocab),
        d_model=d_model,
        n_heads=n_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=dropout,
        pad_idx=vocab['<pad>']
    ).to(device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"✓ Model loaded (epoch {checkpoint.get('epoch', 'N/A')})")
    
    # Initialize metrics calculator
    metrics_calc = MetricsCalculator(vocab)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])
    
    # Run evaluation
    print("\n" + "="*70)
    results, examples = evaluate_model(
        model, test_loader, vocab, metrics_calc, device, criterion
    )
    
    # Display results
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)
    if 'perplexity' in results:
        print(f"Perplexity: {results['perplexity']:.2f}")
    print(f"BLEU:       {results['bleu']:.2f}")
    print(f"ROUGE-L:    {results['rouge_l']:.4f}")
    print(f"chrF:       {results['chrf']:.2f}")
    
    # Save metrics
    metrics_file = f"{output_dir}/automatic_metrics.json"
    with open(metrics_file, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n✓ Saved: {metrics_file}")
    
    # Generate qualitative examples
    print("\nGenerating examples...")
    examples_file = f"{output_dir}/qualitative_examples.csv"
    display_qualitative_examples(examples, num_examples=20, save_file=examples_file)
    
    # Create human evaluation template
    print("\nCreating human eval template...")
    human_eval_file = f"{output_dir}/human_evaluation_template.csv"
    create_human_evaluation_template(examples, human_eval_file)
    
    # Summary
    print("\n" + "="*70)
    print("COMPLETE!")
    print("="*70)
    print(f"Output directory: {output_dir}")
    print(f"  1. {os.path.basename(metrics_file)}")
    print(f"  2. {os.path.basename(examples_file)}")
    print(f"  3. {os.path.basename(human_eval_file)}")
    print("="*70)
    
    return results, examples

# ===== AUTO-DETECTION HELPERS =====
def find_checkpoint_file(prefer_best=True):
    """
    Search for a .pt model in both /kaggle/input and /kaggle/working.
    
    Args:
        prefer_best: If True, prioritize best_model.pt over epoch checkpoints
    
    Returns:
        Path to the best available checkpoint
    """
    roots = ["/kaggle/input", "/kaggle/working"]
    
    # First priority: best_model.pt (BLEU-selected)
    if prefer_best:
        for base in roots:
            if not os.path.exists(base):
                continue
            for root, _, files in os.walk(base):
                if "best_model.pt" in files:
                    return os.path.join(root, "best_model.pt")
    
    # Second priority: any checkpoint file
    for base in roots:
        if not os.path.exists(base):
            continue
        for root, _, files in os.walk(base):
            for f in files:
                if f.endswith(".pt") and ("checkpoint" in f.lower() or "model" in f.lower()):
                    return os.path.join(root, f)
    
    return None

def find_data_directory():
    """Look for a folder that has vocab.json and test_ids.jsonl in input/working."""
    candidates = [
        "/kaggle/input", "/kaggle/working",
        "/kaggle/working/checkpoints", "/kaggle/input/checkpoints"
    ]
    for base in candidates:
        if not os.path.exists(base):
            continue
        for root, dirs, files in os.walk(base):
            if "vocab.json" in files and "test_ids.jsonl" in files:
                return root
    return None

def auto_detect_paths():
    """
    Automatically detect checkpoint and data paths in Kaggle.
    Prioritizes best_model.pt (BLEU-selected) over epoch checkpoints.
    """
    checkpoint_path = find_checkpoint_file(prefer_best=True)
    data_dir = find_data_directory()
    
    if checkpoint_path:
        print(f"✓ Found checkpoint: {checkpoint_path}")
    else:
        print("❌ No checkpoint file found (.pt)")
    
    if data_dir:
        print(f"✓ Found data directory: {data_dir}")
    else:
        print("❌ No data directory found (needs vocab.json + test_ids.jsonl)")
    
    return checkpoint_path, data_dir

# ============================================================
# MAIN EXECUTION
# ============================================================

if __name__ == "__main__":
    print("="*70)
    print("EMPATHETIC CHATBOT - MODEL EVALUATION")
    print("="*70)
    print("✓ Metrics packages available")
    
    # Check if in Kaggle
    if os.path.exists("/kaggle/input"):
        print("✓ Running in Kaggle environment\n")
        
        # AUTO-DETECT PATHS (prioritizes best_model.pt)
        print("Auto-detecting paths...")
        checkpoint_path, data_dir = auto_detect_paths()
        print()
        
        # Verify both found
        if checkpoint_path and data_dir:
            output_dir = "/kaggle/working/evaluation"
            
            # RUN EVALUATION
            try:
                metrics, examples = run_full_evaluation(
                    checkpoint_path=checkpoint_path,
                    data_dir=data_dir,
                    output_dir=output_dir
                )
                print("\n✓ Evaluation complete!")
                
            except Exception as e:
                print(f"\n✗ Error: {e}")
                import traceback
                traceback.print_exc()
        else:
            print("✗ Missing files:")
            if not checkpoint_path:
                print("  - Model checkpoint (.pt file)")
            if not data_dir:
                print("  - Data (vocab.json + test_ids.jsonl)")
            print("\nAdd datasets in Kaggle and re-run.")
    else:
        print("✓ Running locally\n")
        print("Usage:")
        print("  metrics, examples = run_full_evaluation(")
        print("      checkpoint_path='checkpoints/best_model.pt',")
        print("      data_dir='.',")
        print("      output_dir='./evaluation'")
        print("  )")
        print("\nOr see examples in the docstring at the bottom of this file.")

# ============================================================
# QUICK USAGE EXAMPLES
# ============================================================
"""
EXAMPLE 1: Run full evaluation in Kaggle
=========================================
metrics, examples = run_full_evaluation(
    checkpoint_path="/kaggle/input/my-model/best_model.pt",
    data_dir="/kaggle/input/my-data",
    output_dir="/kaggle/working/evaluation"
)

EXAMPLE 2: Just metrics (no files)
===================================
device = torch.device('cuda')
vocab = json.load(open("vocab.json"))
model = TransformerModel(len(vocab), pad_idx=vocab['<pad>']).to(device)
checkpoint = torch.load("best_model.pt")
model.load_state_dict(checkpoint['model_state_dict'])

test_dataset = ChatbotDataset("test_ids.jsonl", vocab)
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=test_dataset.collate_fn)
metrics_calc = MetricsCalculator(vocab)
criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

results, examples = evaluate_model(model, test_loader, vocab, metrics_calc, device, criterion)
print(f"BLEU: {results['bleu']:.2f}")
print(f"ROUGE-L: {results['rouge_l']:.4f}")
print(f"chrF: {results['chrf']:.2f}")

EXAMPLE 3: Analyze human evaluation results
============================================
human_scores = analyze_human_evaluation("/kaggle/input/human-eval/completed.csv")
print(f"Overall human score: {human_scores['overall']:.2f}/5.0")
"""


# **Inference Script**

In [None]:
# # ============================================================
# # Chatbot Inference Script — loads checkpoints/best_model.pt
# # ============================================================

# import os, json, math, sys
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# # ------------------- Model (same as training) -------------------
# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model, max_len=5000):
#         super().__init__()
#         pe = torch.zeros(max_len, d_model)
#         pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
#         div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
#         pe[:, 0::2] = torch.sin(pos * div)
#         pe[:, 1::2] = torch.cos(pos * div)
#         self.register_buffer("pe", pe.unsqueeze(0))
#     def forward(self, x):
#         return x + self.pe[:, :x.size(1), :]

# def make_pad_mask(ids, pad_idx):  # True=valid token
#     return (ids != pad_idx)

# def make_attn_mask(valid_q, valid_k):  # [B,Lq,Lk]
#     return valid_q.unsqueeze(2) & valid_k.unsqueeze(1)

# def make_causal_mask(L, device):
#     return torch.tril(torch.ones(L, L, dtype=torch.bool, device=device))

# def _ensure_nonempty_rows(mask):
#     has_true = mask.any(dim=-1, keepdim=True)
#     mask = mask.clone()
#     mask[:, :, 0] = mask[:, :, 0] | (~has_true.squeeze(-1))
#     return mask

# class MultiHeadAttention(nn.Module):
#     def __init__(self, d_model, n_heads, dropout=0.1):
#         super().__init__()
#         assert d_model % n_heads == 0
#         self.h = n_heads
#         self.dk = d_model // n_heads
#         self.Wq = nn.Linear(d_model, d_model)
#         self.Wk = nn.Linear(d_model, d_model)
#         self.Wv = nn.Linear(d_model, d_model)
#         self.Wo = nn.Linear(d_model, d_model)
#         self.drop = nn.Dropout(dropout)
#     def _split(self, x):
#         B,L,D = x.size()
#         return x.view(B, L, self.h, self.dk).transpose(1,2)
#     def _merge(self, x):
#         B,h,L,dk = x.size()
#         return x.transpose(1,2).contiguous().view(B,L,h*dk)
#     def forward(self, q, k, v, mask=None):
#         Q, K, V = self._split(self.Wq(q)), self._split(self.Wk(k)), self._split(self.Wv(v))
#         scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.dk)
#         if mask is not None:
#             if mask.dtype == torch.bool:
#                 mask = _ensure_nonempty_rows(mask)
#                 scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
#             else:
#                 scores = scores + mask.unsqueeze(1)
#         attn = F.softmax(scores, dim=-1)
#         attn = torch.nan_to_num(attn, nan=0.0)
#         out = attn @ V
#         out = self._merge(out)
#         return self.Wo(self.drop(out))

# class FeedForward(nn.Module):
#     def __init__(self, d_model, d_ff=2048, dropout=0.1):
#         super().__init__()
#         self.ff = nn.Sequential(
#             nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
#             nn.Linear(d_ff, d_model), nn.Dropout(dropout)
#         )
#     def forward(self, x): return self.ff(x)

# class EncoderLayer(nn.Module):
#     def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
#         super().__init__()
#         self.sa = MultiHeadAttention(d_model, n_heads, dropout)
#         self.ff = FeedForward(d_model, d_ff, dropout)
#         self.n1, self.n2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
#     def forward(self, x, attn_mask=None):
#         x = self.n1(x + self.sa(x, x, x, attn_mask))
#         x = self.n2(x + self.ff(x))
#         return x

# class DecoderLayer(nn.Module):
#     def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
#         super().__init__()
#         self.msa = MultiHeadAttention(d_model, n_heads, dropout)
#         self.xattn = MultiHeadAttention(d_model, n_heads, dropout)
#         self.ff = FeedForward(d_model, d_ff, dropout)
#         self.n1, self.n2, self.n3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
#     def forward(self, x, mem, self_mask=None, mem_mask=None):
#         x = self.n1(x + self.msa(x, x, x, self_mask))
#         x = self.n2(x + self.xattn(x, mem, mem, mem_mask))
#         x = self.n3(x + self.ff(x))
#         return x

# class Encoder(nn.Module):
#     def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
#         super().__init__()
#         self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
#     def forward(self, x, attn_mask=None):
#         for lyr in self.layers: x = lyr(x, attn_mask)
#         return x

# class Decoder(nn.Module):
#     def __init__(self, d_model, n_heads, num_layers, d_ff=2048, dropout=0.1):
#         super().__init__()
#         self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)])
#     def forward(self, x, mem, self_mask=None, mem_mask=None):
#         for lyr in self.layers: x = lyr(x, mem, self_mask, mem_mask)
#         return x

# class TransformerModel(nn.Module):
#     def __init__(self, vocab_size, d_model=256, n_heads=2, num_layers=2, d_ff=2048, dropout=0.1, pad_idx=0):
#         super().__init__()
#         self.pad_idx = pad_idx
#         self.d_model = d_model
#         self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
#         self.pos = PositionalEncoding(d_model)
#         self.encoder = Encoder(d_model, n_heads, num_layers, d_ff, dropout)
#         self.decoder = Decoder(d_model, n_heads, num_layers, d_ff, dropout)
#         self.out = nn.Linear(d_model, vocab_size)
#     def forward(self, src, tgt):
#         device = src.device
#         B, Ls = src.size()
#         Lt = tgt.size(1)
#         src_valid = make_pad_mask(src, self.pad_idx)
#         tgt_valid = make_pad_mask(tgt, self.pad_idx)
#         enc_self  = make_attn_mask(src_valid, src_valid)
#         dec_caus  = make_causal_mask(Lt, device)
#         dec_self  = dec_caus.unsqueeze(0).expand(B,-1,-1) & make_attn_mask(tgt_valid, tgt_valid)
#         cross     = make_attn_mask(tgt_valid, src_valid)
#         src_e = self.pos(self.embed(src) * math.sqrt(self.d_model))
#         tgt_e = self.pos(self.embed(tgt) * math.sqrt(self.d_model))
#         mem = self.encoder(src_e, enc_self)
#         dec = self.decoder(tgt_e, mem, dec_self, cross)
#         return self.out(dec)

# # --------------------- Token utils ---------------------
# def ids_to_text(ids, vocab):
#     idx2word = {v:k for k,v in vocab.items()}
#     out = []
#     for x in ids:
#         if x == vocab.get("<eos>"): break
#         if x == vocab.get("<pad>"): break
#         if x == vocab.get("<bos>"): continue
#         out.append(idx2word.get(int(x), "<unk>"))
#     return " ".join(out)

# def text_to_ids(text, vocab):
#     # simple whitespace tokenization (must match how vocab was built)
#     # unknown tokens -> <unk>
#     w2i = vocab
#     unk = w2i.get("<unk>", 0)
#     tokens = text.strip().lower().split()
#     return [w2i.get(t, unk) for t in tokens]

# def build_source(emotion, situation, customer):
#     # Matches project input template
#     # Emotion optional; if empty, omit
#     parts = []
#     if emotion:
#         parts.append(f"emotion: {emotion}")
#     parts.append(f"situation: {situation}".strip())
#     parts.append(f"customer: {customer}".strip())
#     parts.append("agent:")
#     return " | ".join(parts).lower()

# # --------------------- Decoding ---------------------
# @torch.no_grad()
# def greedy_decode(model, src, vocab, max_len=128, device=None):
#     if device is None: device = src.device
#     model.eval()
#     bos = vocab["<bos>"]; eos = vocab["<eos>"]
#     B = src.size(0)
#     tgt = torch.full((B,1), bos, dtype=torch.long, device=device)
#     for _ in range(max_len-1):
#         logits = model(src, tgt)
#         next_tok = logits[:, -1, :].argmax(-1, keepdim=True)
#         tgt = torch.cat([tgt, next_tok], dim=1)
#         if (next_tok.squeeze(-1) == eos).all(): break
#     return tgt[:,1:]  # strip BOS

# @torch.no_grad()
# def beam_search_decode(model, src, vocab, max_len=128, beam_size=4, alpha=0.7, device=None):
#     """
#     Simple length-penalized beam search.
#     score = sum(logprobs) / (len^alpha)
#     """
#     if device is None: device = src.device
#     model.eval()
#     bos = vocab["<bos>"]; eos = vocab["<eos>"]

#     beams = [(torch.tensor([bos], device=device, dtype=torch.long), 0.0, False)]  # (seq, score, ended)
#     for _ in range(max_len-1):
#         # gather candidates
#         all_cands = []
#         seqs = [b[0] for b in beams]
#         # expand each beam independently
#         for (seq, score, ended) in beams:
#             if ended:
#                 all_cands.append((seq, score, True))
#                 continue
#             tgt = seq.unsqueeze(0)                                 # [1, t]
#             logits = model(src, tgt)                               # [1, t, V]
#             logp = F.log_softmax(logits[:, -1, :], dim=-1)[0]      # [V]
#             topk_logp, topk_ids = torch.topk(logp, beam_size)
#             for lp, wid in zip(topk_logp.tolist(), topk_ids.tolist()):
#                 new_seq = torch.cat([seq, torch.tensor([wid], device=device)])
#                 ended2 = (wid == eos)
#                 # length penalty
#                 L = new_seq.size(0)
#                 new_score = (score * ((L-1)**alpha) + lp) / (L**alpha)
#                 all_cands.append((new_seq, new_score, ended2))
#         # select best beams
#         all_cands.sort(key=lambda x: x[1], reverse=True)
#         beams = all_cands[:beam_size]
#         # early stop if all ended
#         if all(b[2] for b in beams):
#             break

#     best_seq = max(beams, key=lambda x: x[1])[0]
#     # remove BOS
#     if best_seq[0].item() == bos:
#         best_seq = best_seq[1:]
#     return best_seq.unsqueeze(0)  # [1, T]

# # --------------------- Loader ---------------------
# def load_vocab(vocab_path, ckpt_dict):
#     if os.path.exists(vocab_path):
#         with open(vocab_path, "r") as f:
#             return json.load(f)
#     # fallback: from checkpoint extras
#     if isinstance(ckpt_dict.get("vocab"), dict):
#         return ckpt_dict["vocab"]
#     raise FileNotFoundError("vocab.json not found and 'vocab' not present in checkpoint.")

# def build_model_from_ckpt(ckpt, vocab, device):
#     cfg = ckpt.get("config", {}) or {}
#     d_model   = int(cfg.get("d_model", 256))
#     n_heads   = int(cfg.get("n_heads", 2))
#     num_layers= int(cfg.get("num_layers", 2))
#     d_ff      = int(cfg.get("d_ff", 2048))
#     dropout   = float(cfg.get("dropout", 0.1))
#     pad_idx   = int(vocab.get("<pad>", 0))

#     model = TransformerModel(
#         vocab_size=len(vocab),
#         d_model=d_model,
#         n_heads=n_heads,
#         num_layers=num_layers,
#         d_ff=d_ff,
#         dropout=dropout,
#         pad_idx=pad_idx
#     ).to(device)
#     model.load_state_dict(ckpt["model_state_dict"], strict=True)
#     model.eval()
#     return model

# # --------------------- Chat Loop ---------------------
# def chat_loop(model, vocab, device, max_len=128, use_beam=False, beam_size=4):
#     print("\nEmpathetic Chatbot ready. Commands: /beam, /greedy, /quit\n")
#     while True:
#         try:
#             emo = input("Emotion (optional, e.g., 'sad/afraid/hopeful'; Enter to skip): ").strip()
#             sit = input("Situation: ").strip()
#             usr = input("Customer: ").strip()
#         except EOFError:
#             print("\nExiting.")
#             break

#         # command shortcuts when user entered a command instead of text
#         if emo.lower() in ("/quit","quit","exit"):
#             break
#         if emo.lower() == "/beam":
#             use_beam = True
#             print("Decoding set to BEAM.")
#             continue
#         if emo.lower() == "/greedy":
#             use_beam = False
#             print("Decoding set to GREEDY.")
#             continue

#         # Build input string
#         src_text = build_source(emo, sit, usr)
#         src_ids = text_to_ids(src_text, vocab)
#         if len(src_ids) == 0:
#             print("Input produced empty ids; please try again.")
#             continue

#         src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, L]

#         # Decode
#         if use_beam:
#             gen = beam_search_decode(model, src, vocab, max_len=max_len, beam_size=beam_size, device=device)  # [1,T]
#         else:
#             gen = greedy_decode(model, src, vocab, max_len=max_len, device=device)  # [1,T]

#         reply = ids_to_text(gen[0].tolist(), vocab)
#         print(f"\nAgent: {reply}\n")

# # --------------------- Main ---------------------
# def main():
#     # Paths
#     ckpt_path = os.environ.get("BEST_CKPT", "checkpoints/best_model.pt")
#     vocab_path = os.environ.get("VOCAB_JSON", "vocab.json")

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"Using device: {device}")

#     if not os.path.exists(ckpt_path):
#         print(f"Checkpoint not found at: {ckpt_path}")
#         print("Set BEST_CKPT env var or place best_model.pt under checkpoints/.")
#         sys.exit(1)

#     # Load checkpoint
#     ckpt = torch.load(ckpt_path, map_location=device)

#     # Load vocab (file first, fallback to ckpt['vocab'])
#     vocab = load_vocab(vocab_path, ckpt)
#     for tok in ["<pad>","<bos>","<eos>","<unk>"]:
#         if tok not in vocab:
#             raise RuntimeError(f"Missing special token in vocab: {tok}")

#     # Recreate model and load weights
#     model = build_model_from_ckpt(ckpt, vocab, device)

#     # Max gen length (from config if available)
#     cfg = ckpt.get("config", {}) or {}
#     max_len = int(cfg.get("max_seq_len", 128))

#     # Start interactive chat
#     chat_loop(model, vocab, device, max_len=max_len, use_beam=False, beam_size=4)

# if __name__ == "__main__":
#     main()


# **END**