In [None]:
# %% [markdown]
# # Uncertainty Quantification Pipeline for Neural Text Generation

# This notebook implements uncertainty quantification analysis for neural text generation,
# ensuring proper data selection, prompt handling, and multifaceted evaluation metrics.
#
# **Specifications:**
# - 500 prompts with exactly 10 distinct human continuations each
# - 4 models × 5 decoding strategies × 10 samples = 200 generations per prompt (100,000 in total)
# - Proper prompt removal from all model outputs
# - Comprehensive evaluation metrics

# %% [markdown]
# ## 1. Installation and Setup

# %%
print("Installing required packages...")
print("This will take a few minutes...")

# Core packages
!pip install -q --upgrade pip
!pip install -q datasets pandas numpy matplotlib seaborn
!pip install -q torch transformers accelerate sentencepiece protobuf
!pip install -q sentence-transformers bert-score nltk rouge-score
!pip install -q scikit-learn scipy tqdm
!pip install -q bitsandbytes  # For 4-bit quantization
!pip install -U bitsandbytes

# Additional packages for advanced metrics
!pip install -q textstat  # Readability metrics
!pip install -q lexical-diversity  # Vocabulary diversity
!pip install -q evaluate  # HuggingFace evaluate library

print("All packages installed")

# %% [markdown]
# ## 2. Imports and Configuration with NLTK Setup

# %%
import os
import sys
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
from datasets import load_dataset
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
import json
import time
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional, Set
from scipy import stats
import random
from datetime import datetime
import hashlib

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    BitsAndBytesConfig,
    set_seed,
    StoppingCriteria,
    StoppingCriteriaList
)

# Evaluation libraries
from sentence_transformers import SentenceTransformer
from bert_score import score as bert_score
import nltk
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer
import textstat
from lexical_diversity import lex_div as ld
import evaluate

print("All libraries imported")

# NLTK DATA DOWNLOAD WITH ERROR HANDLING
print("\nSetting up NLTK resources...")

def setup_nltk_resources():
    """Make sure every NLTK package used by the pipeline is installed.

    Each package is first looked up in the local NLTK data directories and only
    downloaded when it is missing. If any download fails, the bulk collections
    'all-corpora' and 'all-nltk' are tried as a last resort and the failed
    packages are checked again.

    Returns:
        bool: True when every package is available once the function finishes,
        False when at least one is still missing.
    """
    # NLTK files each package under a category directory. Looking everything up
    # under 'tokenizers/' reports taggers, corpora, etc. as missing on every run.
    resource_paths = {
        'punkt': 'tokenizers/punkt',
        'punkt_tab': 'tokenizers/punkt_tab',
        'averaged_perceptron_tagger': 'taggers/averaged_perceptron_tagger',
        'averaged_perceptron_tagger_eng': 'taggers/averaged_perceptron_tagger_eng',
        'wordnet': 'corpora/wordnet',
        'stopwords': 'corpora/stopwords',
        'tagsets': 'help/tagsets',
        'maxent_ne_chunker': 'chunkers/maxent_ne_chunker',
        'words': 'corpora/words',
        'brown': 'corpora/brown',
        'omw-1.4': 'corpora/omw-1.4',
    }

    def _installed(path):
        """Return True if the package exists unpacked or as a .zip archive."""
        for candidate in (path, f'{path}.zip'):
            try:
                nltk.data.find(candidate)
                return True
            except LookupError:
                continue
        return False

    failed_resources = []

    for resource, path in resource_paths.items():
        if _installed(path):
            print(f"   {resource} already present")
            continue

        try:
            print(f"   Downloading {resource}...")
            # nltk.download() signals most failures by returning False rather
            # than raising, so the return value has to be checked.
            if nltk.download(resource, quiet=True):
                print(f"   {resource} downloaded")
            else:
                print(f"   Failed to download {resource}")
                failed_resources.append(resource)
        except Exception as e:
            print(f"   Failed to download {resource}: {str(e)[:50]}")
            failed_resources.append(resource)

    still_failed = []

    # Try alternative downloads for critical resources
    if failed_resources:
        print("\n   Attempting alternative downloads...")
        try:
            nltk.download('all-corpora', quiet=True)
            nltk.download('all-nltk', quiet=True)
        except Exception as e:
            print(f"   Alternative download failed: {str(e)[:50]}")

        still_failed = [
            resource for resource in failed_resources
            if not _installed(resource_paths[resource])
        ]

        if still_failed:
            print(f"\n   Some resources unavailable: {still_failed}")
            print("   The pipeline will use fallback methods for these.")

    return not still_failed

# Download NLTK resources
nltk_success = setup_nltk_resources()

# Validate critical NLTK functionality.
# Each check below runs one NLTK call on a fixed sentence; any exception
# (including a failed assertion) marks the validation as not passed.
print("\nValidating NLTK functionality...")
validation_passed = True

# Check 1: word tokenization
try:
    test_text = "This is a test sentence."
    tokens = nltk.word_tokenize(test_text)
    assert len(tokens) > 0, "Tokenization failed"
    print("   Tokenization working")
except Exception as e:
    print(f"   Tokenization issue: {e}")
    validation_passed = False

# Check 2: part-of-speech tagging
try:
    tokens = nltk.word_tokenize("The cat sat on the mat.")
    pos_tags = nltk.pos_tag(tokens)
    assert len(pos_tags) > 0, "POS tagging failed"
    print("   POS tagging working")
except Exception as e:
    print(f"   POS tagging issue: {e}")
    print("   Installing fallback...")
    try:
        nltk.download('averaged_perceptron_tagger', quiet=True)
        nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    except Exception as download_error:
        # Report the failure instead of silently discarding it; the run continues.
        print(f"   Fallback tagger download failed: {str(download_error)[:50]}")
    # The check is not re-run after the fallback download, so it stays failed.
    validation_passed = False

# Check 3: sentence splitting (expects exactly two sentences)
try:
    sentences = nltk.sent_tokenize("This is sentence one. This is sentence two.")
    assert len(sentences) == 2, "Sentence tokenization failed"
    print("   Sentence tokenization working")
except Exception as e:
    print(f"   Sentence tokenization issue: {e}")
    validation_passed = False

if not validation_passed:
    print("\nSome NLTK functions may not work properly. The pipeline will use fallback methods.")
else:
    print("\nAll NLTK functions validated successfully")

# Seed every random number generator the notebook uses, in a fixed order
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, set_seed):
    seed_fn(SEED)

# Allow TF32 in both matmul and cuDNN kernels (A100 optimization)
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Pick the compute device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"\nUsing device: {device}")

# CPU defaults; both flags are overwritten below when a GPU is present
USE_FLASH_ATTENTION = False
BATCH_SIZE_MULTIPLIER = 0.5

if device.type == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"   Memory: {gpu_memory:.2f} GB")

    gpu_name = torch.cuda.get_device_name(0).lower()
    is_a100 = 'a100' in gpu_name

    if is_a100:
        print("   A100 detected - enabling optimizations")
    elif 'v100' in gpu_name or 'p100' in gpu_name:
        print("   V100/P100 detected - using standard settings")
    else:
        print(f"   GPU type: {torch.cuda.get_device_name(0)}")

    # Only the A100 gets flash attention and the doubled batch sizes
    USE_FLASH_ATTENTION = is_a100
    BATCH_SIZE_MULTIPLIER = 2 if is_a100 else 1

# %% [markdown]
# ## 3. HuggingFace Authentication

# %%
def get_hf_token():
    """Return a HuggingFace access token.

    The Colab secret named ``HF_TOKEN`` is tried first. When Colab is not
    available, the lookup fails, or the secret is empty, the user is prompted
    for the token instead.

    Returns:
        str: A non-empty token.

    Raises:
        ValueError: If no token could be obtained from either source.
    """
    # Local import: only the interactive fallback needs it.
    import getpass

    try:
        from google.colab import userdata
        token = userdata.get('HF_TOKEN')
        if token:
            print("Token loaded from Colab secrets")
            return token
        # An empty/None secret must not be returned: the caller stores the
        # result in os.environ, which rejects non-string values.
        print("Could not load token from secrets: HF_TOKEN is empty")
    except Exception as e:
        print(f"Could not load token from secrets: {e}")

    # getpass keeps the token out of the notebook output, unlike input().
    token = getpass.getpass("Please enter your HuggingFace token: ").strip()
    if not token:
        raise ValueError("No HuggingFace token provided")
    return token

# Resolve the token once and keep it in a module-level constant; it is also
# exported as the HF_TOKEN environment variable (presumably so libraries that
# read that variable pick it up - verify if this matters).
HF_TOKEN = get_hf_token()
os.environ["HF_TOKEN"] = HF_TOKEN

from huggingface_hub import login
try:
    # Authenticate this session; add_to_git_credential is explicitly disabled.
    login(token=HF_TOKEN, add_to_git_credential=False)
    print("Logged in to HuggingFace")
except Exception as e:
    # Report the reason, then re-raise so the notebook stops at this cell.
    print(f"Failed to login to HuggingFace: {e}")
    raise

# %% [markdown]
# ## 4. Setup Output Directory

# %%
# Try to use Google Drive for persistent storage
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    SAVE_DIR = '/content/drive/MyDrive/UQ_Analysis'
    print("Google Drive mounted")
except Exception:
    # Not running in Colab, or the mount was refused
    SAVE_DIR = '/content/UQ_Analysis'
    print("Using local storage (will be lost on session end)")


def _prepare_save_dir(base_dir):
    """Create the output tree under ``base_dir`` and verify it is writable.

    Creates ``base_dir`` plus its 'checkpoints', 'results' and 'figures'
    subdirectories, then writes and deletes a small probe file.

    Raises:
        OSError: If a directory cannot be created or the probe cannot be written.
    """
    os.makedirs(base_dir, exist_ok=True)
    for subdir in ('checkpoints', 'results', 'figures'):
        os.makedirs(os.path.join(base_dir, subdir), exist_ok=True)

    probe_file = os.path.join(base_dir, 'test.txt')
    with open(probe_file, 'w') as f:
        f.write('test')
    os.remove(probe_file)


# Directory creation is part of the guarded section, so an unusable primary
# location leads to the fallback instead of an uncaught error, and the fallback
# gets the same subdirectories as the primary location.
try:
    _prepare_save_dir(SAVE_DIR)
    print(f"Save directory ready: {SAVE_DIR}")
except OSError:
    print(f"Cannot write to {SAVE_DIR}")
    SAVE_DIR = './UQ_Analysis'
    _prepare_save_dir(SAVE_DIR)
    print(f"Using fallback directory: {SAVE_DIR}")

# %% [markdown]
# ## 5. Load and Select High-Quality Prompts (500 with 10 Unique Stories Each)

# %%
print("\nLoading WritingPrompts dataset...")

# Load dataset
dataset = load_dataset("euclaise/writingprompts", token=HF_TOKEN)
df_train = pd.DataFrame(dataset['train'])
print(f"Dataset loaded. Total entries: {len(df_train):,}")

# Initialize tokenizer for length analysis
# (GPT2Tokenizer is already imported in the imports cell)
length_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
length_tokenizer.pad_token = length_tokenizer.eos_token

print("\nFinding prompts with exactly 10 unique human stories...")

# Group stories by prompt and filter for uniqueness
prompt_to_stories = defaultdict(list)

# Iterating the two columns directly avoids building a Series per row.
for prompt, story in tqdm(zip(df_train['prompt'], df_train['story']),
                          total=len(df_train), desc="Grouping stories"):
    # Quality checks: both fields must be non-empty text within the length limits
    if not isinstance(prompt, str) or not isinstance(story, str):
        continue
    if not prompt or not story:
        continue
    if len(prompt) < 20 or len(prompt) > 500:
        continue
    if len(story) < 200 or len(story) > 10000:
        continue

    prompt_to_stories[prompt].append(story)

print(f"\nFound {len(prompt_to_stories)} unique prompts")

# Find prompts with at least 10 UNIQUE stories
prompts_with_10_unique = {}
story_length_stats = []

for prompt, stories in tqdm(prompt_to_stories.items(), desc="Finding unique stories"):
    # dict.fromkeys de-duplicates while keeping dataset order. A set would make
    # "the first 10" depend on string hash randomisation and therefore differ
    # between kernel sessions.
    unique_stories = list(dict.fromkeys(stories))

    if len(unique_stories) >= 10:
        story_data = []
        for story in unique_stories[:10]:
            try:
                tokens = length_tokenizer.encode(story, truncation=False, add_special_tokens=False)
                story_data.append({
                    'text': story,
                    'length': len(tokens),
                    'hash': hashlib.md5(story.encode()).hexdigest()
                })
            except Exception:
                # A story that cannot be tokenized/hashed disqualifies the
                # prompt below because fewer than 10 entries remain.
                continue

        if len(story_data) == 10:
            prompts_with_10_unique[prompt] = story_data
            story_length_stats.extend([s['length'] for s in story_data])

print(f"\nFound {len(prompts_with_10_unique)} prompts with 10+ unique stories")

# Select 500 best prompts (prioritize by diversity of story lengths)
def calculate_prompt_quality(prompt, stories):
    """Return a quality score for one prompt from its stories' token lengths.

    The score is the standard deviation of the ``'length'`` values, kept as-is
    when their mean lies in [100, 500] and halved otherwise.

    Args:
        prompt: The prompt text (not used by the computation).
        stories: Sequence of dicts, each with a numeric ``'length'`` entry.

    Returns:
        The weighted standard deviation of the lengths.
    """
    token_counts = [entry['length'] for entry in stories]

    # Full weight only when the average length falls inside the target band
    in_target_band = 100 <= np.mean(token_counts) <= 500
    weight = 1.0 if in_target_band else 0.5

    return np.std(token_counts) * weight

# Size of the final dataset; every count below derives from these two values
NUM_PROMPTS = 500
STORIES_PER_PROMPT = 10

# Score and rank prompts
prompt_scores = []
for prompt, stories in prompts_with_10_unique.items():
    score = calculate_prompt_quality(prompt, stories)
    prompt_scores.append((prompt, stories, score))

# list.sort is stable, so prompts with equal scores keep their dataset order
prompt_scores.sort(key=lambda x: x[2], reverse=True)
selected_prompts = prompt_scores[:NUM_PROMPTS]

print(f"\nSelected {len(selected_prompts)} highest quality prompts")

# Fail here with an accurate message when the dataset did not yield enough
# prompts, rather than tripping the duplicate-story check further down.
assert len(selected_prompts) == NUM_PROMPTS, (
    f"Only {len(selected_prompts)} prompts have {STORIES_PER_PROMPT} unique stories; "
    f"{NUM_PROMPTS} are required"
)

# Prepare final dataset
selected_data = {}
all_story_lengths = []
hash_verification = set()

for i, (prompt, stories, score) in enumerate(selected_prompts):
    story_hashes = [s['hash'] for s in stories]
    story_lengths = [s['length'] for s in stories]
    assert len(story_hashes) == len(set(story_hashes)), f"Duplicate stories found in prompt {i}"

    selected_data[str(i)] = {
        'prompt': prompt,
        'human_stories': [s['text'] for s in stories],
        'human_story_lengths': story_lengths,
        'human_story_hashes': story_hashes,
        'quality_score': score,
        'mean_length': np.mean(story_lengths),
        'std_length': np.std(story_lengths)
    }

    all_story_lengths.extend(story_lengths)
    hash_verification.update(story_hashes)

# Verify global uniqueness: no story text may appear under two different prompts
expected_total = NUM_PROMPTS * STORIES_PER_PROMPT
print(f"\nVerification:")
print(f"   Total unique story hashes: {len(hash_verification)}")
print(f"   Expected ({NUM_PROMPTS} * {STORIES_PER_PROMPT}): {expected_total}")
assert len(hash_verification) == expected_total, "Some stories are duplicated across prompts"

# Save selected prompts
save_path = os.path.join(SAVE_DIR, 'selected_prompts_verified.json')
with open(save_path, 'w') as f:
    json.dump(selected_data, f, indent=2)

print(f"\nSaved {len(selected_data)} prompts with verified unique stories")

# Statistics
print(f"\nDataset Statistics:")
print(f"   Total prompts: {len(selected_data)}")
print(f"   Total unique human stories: {len(selected_data) * STORIES_PER_PROMPT}")
print(f"   Mean story length: {np.mean(all_story_lengths):.1f} ± {np.std(all_story_lengths):.1f} tokens")
print(f"   Median length: {np.median(all_story_lengths):.1f} tokens")
print(f"   Length range: {min(all_story_lengths)}-{max(all_story_lengths)} tokens")
print(f"   All stories verified unique")

# %% [markdown]
# ## 6. Model Configurations (4 Models)

# %%
print("\nConfiguring models...")

# Number of generations drawn per (prompt, model, decoding strategy) combination
NUM_SAMPLES_PER_CONFIG = 10
# Rough per-generation cost in seconds, used only for the time estimate printed below
SECONDS_PER_GENERATION = 0.5

# 4-bit quantization configuration for large models
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Model configurations - 4 models.
# Per-model keys:
#   name                - HuggingFace model id
#   model_class         - class whose from_pretrained() loads the weights
#   tokenizer_class     - class whose from_pretrained() loads the tokenizer
#   quantization_config - BitsAndBytesConfig, or None to load unquantized
#   torch_dtype         - dtype passed to from_pretrained()
#   requires_auth       - whether HF_TOKEN is sent when downloading
#   batch_size          - prompts per generation batch, scaled by BATCH_SIZE_MULTIPLIER
#   remove_prompt       - whether outputs go through clean_generated_text()
#   max_length          - token budget; half is used for the prompt, half for new tokens
MODEL_CONFIGS = {
    'GPT2-XL': {
        'name': 'gpt2-xl',
        'model_class': GPT2LMHeadModel,
        'tokenizer_class': GPT2Tokenizer,
        'quantization_config': None,
        'torch_dtype': torch.float16,
        'requires_auth': False,
        'batch_size': int(20 * BATCH_SIZE_MULTIPLIER),
        'remove_prompt': False,
        'max_length': 512
    },
    'Mistral-7B-Instruct': {
        'name': 'mistralai/Mistral-7B-Instruct-v0.2',
        'model_class': AutoModelForCausalLM,
        'tokenizer_class': AutoTokenizer,
        'quantization_config': bnb_config_4bit,
        'torch_dtype': torch.float16,
        'requires_auth': True,
        'batch_size': int(10 * BATCH_SIZE_MULTIPLIER),
        'remove_prompt': True,
        'max_length': 512
    },
    'Llama-3.1-8B-Instruct': {
        'name': 'meta-llama/Llama-3.1-8B-Instruct',
        'model_class': AutoModelForCausalLM,
        'tokenizer_class': AutoTokenizer,
        'quantization_config': bnb_config_4bit,
        'torch_dtype': torch.float16,
        'requires_auth': True,
        'batch_size': int(8 * BATCH_SIZE_MULTIPLIER),
        'remove_prompt': True,
        'max_length': 512
    },
    'Gemma-2B': {
        'name': 'google/gemma-2b',
        'model_class': AutoModelForCausalLM,
        'tokenizer_class': AutoTokenizer,
        'quantization_config': None,
        'torch_dtype': torch.float16,
        'requires_auth': True,
        'batch_size': int(25 * BATCH_SIZE_MULTIPLIER),
        'remove_prompt': True,
        'max_length': 512
    }
}

# Decoding strategies - 5 strategies.
# Every strategy samples (do_sample=True) with the same repetition penalty and
# varies exactly one of temperature / top_p / top_k / typical_p.
DECODING_STRATEGIES = {
    'temperature_0.7': {
        'do_sample': True,
        'temperature': 0.7,
        'top_k': 0,
        'top_p': 1.0,
        'repetition_penalty': 1.1
    },
    'temperature_1.2': {
        'do_sample': True,
        'temperature': 1.2,
        'top_k': 0,
        'top_p': 1.0,
        'repetition_penalty': 1.1
    },
    'top_p_0.9': {
        'do_sample': True,
        'temperature': 1.0,
        'top_p': 0.9,
        'top_k': 0,
        'repetition_penalty': 1.1
    },
    'top_k_40': {
        'do_sample': True,
        'temperature': 1.0,
        'top_k': 40,
        'top_p': 1.0,
        'repetition_penalty': 1.1
    },
    'typical_0.95': {
        'do_sample': True,
        'temperature': 1.0,
        'typical_p': 0.95,
        'top_k': 0,
        'top_p': 1.0,
        'repetition_penalty': 1.1
    }
}

total_generations = (
    len(selected_data) * len(MODEL_CONFIGS) * len(DECODING_STRATEGIES) * NUM_SAMPLES_PER_CONFIG
)
print(f"\nGeneration Plan:")
print(f"   Prompts: {len(selected_data)}")
print(f"   Models: {len(MODEL_CONFIGS)} ({', '.join(MODEL_CONFIGS.keys())})")
print(f"   Strategies: {len(DECODING_STRATEGIES)} ({', '.join(DECODING_STRATEGIES.keys())})")
print(f"   Samples per config: {NUM_SAMPLES_PER_CONFIG}")
print(f"   Total generations: {total_generations:,}")

# estimated_time is expressed in minutes
estimated_time = total_generations * SECONDS_PER_GENERATION / 60
print(f"   Estimated time: {estimated_time:.1f} minutes ({estimated_time/60:.1f} hours)")

# %% [markdown]
# ## 7. Generation Functions with Proper Prompt Handling

# %%
def load_model_and_tokenizer(model_config: Dict):
    """Load the model and tokenizer described by one MODEL_CONFIGS entry.

    Args:
        model_config: Dict with the keys 'name', 'tokenizer_class',
            'model_class' and 'torch_dtype', and optionally 'requires_auth'
            and 'quantization_config'.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` when loading
        fails for any reason (the error is printed, not raised).
    """
    model_name = model_config['name']
    auth_token = HF_TOKEN if model_config.get('requires_auth', False) else None

    try:
        print(f"   Loading {model_name}...")

        # NOTE(review): trust_remote_code=True lets the checkpoint repository
        # run its own Python code; keep it only for repositories you trust.
        tokenizer = model_config['tokenizer_class'].from_pretrained(
            model_name,
            token=auth_token,
            trust_remote_code=True
        )

        # Set padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

        # Decoder-only models continue from the last position of each row, so
        # batched generation needs the padding on the left, not the right.
        tokenizer.padding_side = 'left'

        # Model loading arguments
        model_class = model_config['model_class']
        model_kwargs = {
            'token': auth_token,
            'torch_dtype': model_config['torch_dtype'],
            'device_map': "auto",
            'trust_remote_code': True
        }

        # Add quantization if specified
        if model_config.get('quantization_config'):
            model_kwargs['quantization_config'] = model_config['quantization_config']
            print(f"      Using 4-bit quantization")

        model = None
        tried_flash_attention = USE_FLASH_ATTENTION and 'gpt2' not in model_name.lower()

        # Try Flash Attention 2 for supported models
        if tried_flash_attention:
            try:
                model = model_class.from_pretrained(
                    model_name,
                    attn_implementation='flash_attention_2',
                    **model_kwargs
                )
                print(f"      Loaded with Flash Attention 2")
            except Exception as flash_error:
                print(f"      Flash Attention 2 unavailable: {str(flash_error)[:100]}")
                # Release whatever the failed attempt allocated before retrying
                gc.collect()
                torch.cuda.empty_cache()

        if model is None:
            model = model_class.from_pretrained(
                model_name,
                **model_kwargs
            )
            if tried_flash_attention:
                print(f"      Loaded without Flash Attention")
            else:
                print(f"      Model loaded successfully")

        # Try torch.compile for speedup. Checking for the attribute avoids
        # comparing version strings lexicographically ("10.0" < "2.0.0").
        if hasattr(torch, 'compile') and 'gpt2' in model_name.lower():
            try:
                model = torch.compile(model, mode="reduce-overhead")
                print(f"      Model compiled with torch.compile")
            except Exception as compile_error:
                print(f"      torch.compile skipped: {str(compile_error)[:100]}")

        return model, tokenizer

    except Exception as e:
        print(f"   Error loading {model_name}: {str(e)[:200]}")
        return None, None

def format_prompt_for_generation(prompt: str, model_name: str) -> str:
    """Wrap a writing prompt in the input template used for ``model_name``.

    Model names containing 'instruct' get the chat-style template of their
    family (mistral, llama or gemma, matched case-insensitively). Every other
    model, including an 'instruct' model of an unknown family, gets the plain
    "Story prompt: ... Story:" template. The prompt is stripped of surrounding
    whitespace first.
    """
    prompt = prompt.strip()
    family = model_name.lower()
    plain_template = f"Story prompt: {prompt}\n\nStory:"

    # Guard clause: non-instruct models never use a chat template
    if 'instruct' not in family:
        return plain_template

    if 'mistral' in family:
        return f"[INST] Write a creative story based on this prompt: {prompt} [/INST]"
    if 'llama' in family:
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWrite a creative story based on this prompt: {prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    if 'gemma' in family:
        return f"<start_of_turn>user\nWrite a creative story based on this prompt: {prompt}<end_of_turn>\n<start_of_turn>model\n"

    return plain_template

def clean_generated_text(generated_text: str, prompt: str, formatted_prompt: str, model_name: str) -> str:
    """Strip the prompt and chat-template residue from a decoded generation.

    Removal happens in this order:
      1. every literal occurrence of ``formatted_prompt``;
      2. leading prompt variants, including the formatted prompt as it looks
         after the decoder has dropped its special tokens;
      3. model-family markers (Mistral ``[INST]``, Llama header tokens,
         Gemma turn tokens);
      4. a leading copy of the bare prompt and leading colons.

    Args:
        generated_text: Decoded model output, possibly still containing the prompt.
        prompt: The original writing prompt.
        formatted_prompt: The templated prompt that was fed to the model.
        model_name: Model identifier; matched case-insensitively for
            'mistral', 'llama' and 'gemma'.

    Returns:
        The cleaned text with surrounding whitespace removed.
    """
    model_key = model_name.lower()
    # The templates are built from the stripped prompt, so compare against that.
    prompt = prompt.strip()

    # Decoding with skip_special_tokens=True drops these markers but keeps the
    # role names between them ("user", "assistant", "model"). Build the prompt
    # as it appears in that case so it can still be recognised and removed.
    special_markers = (
        '<|begin_of_text|>', '<|start_header_id|>', '<|end_header_id|>', '<|eot_id|>',
        '<start_of_turn>', '<end_of_turn>',
    )
    decoded_prompt = formatted_prompt
    for marker in special_markers:
        decoded_prompt = decoded_prompt.replace(marker, '')

    if formatted_prompt and formatted_prompt in generated_text:
        generated_text = generated_text.replace(formatted_prompt, "")
    generated_text = generated_text.strip()

    patterns_to_remove = [
        formatted_prompt,
        decoded_prompt,
        decoded_prompt.strip(),
        f"Write a creative story based on this prompt: {prompt}",
        f"Story prompt: {prompt}\n\nStory:",
        f"Story prompt: {prompt}",
        prompt
    ]

    for pattern in patterns_to_remove:
        # Empty patterns match everything and remove nothing; skip them
        if pattern and generated_text.startswith(pattern):
            generated_text = generated_text[len(pattern):].strip()

    if 'mistral' in model_key:
        inst_end = '[/INST]'
        if '[INST]' in generated_text and inst_end in generated_text:
            cut = generated_text.find(inst_end) + len(inst_end)
            generated_text = generated_text[cut:].strip()

    elif 'llama' in model_key:
        for marker in ('<|eot_id|>', '<|start_header_id|>', '<|end_header_id|>'):
            generated_text = generated_text.replace(marker, '')

    elif 'gemma' in model_key:
        model_turn = '<start_of_turn>model'
        turn_start = generated_text.find(model_turn)
        if turn_start != -1:
            # Cut exactly after the marker; the following newline, if any, is
            # removed by strip() so no story character is lost.
            generated_text = generated_text[turn_start + len(model_turn):].strip()
        generated_text = generated_text.replace('<end_of_turn>', '').strip()

    if prompt and generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):].strip()

    return generated_text.lstrip(':').strip()

def _story_from_continuation(
    generated_text: str,
    original_prompt: str,
    formatted_prompt: str,
    model_name: str,
    remove_prompt: bool
) -> str:
    """Turn one decoded continuation into a story with no prompt text in front."""
    if remove_prompt:
        return clean_generated_text(generated_text, original_prompt, formatted_prompt, model_name)

    story = generated_text
    if story.startswith(formatted_prompt):
        story = story[len(formatted_prompt):]
    return story.strip()


def generate_stories_batch(
    prompts: List[str],
    prompt_ids: List[str],
    model,
    tokenizer,
    model_config: Dict,
    decoding_strategy: Dict,
    num_samples: int = 10
) -> Dict[str, List[str]]:
    """Generate ``num_samples`` stories for every prompt.

    Prompts are processed in batches of ``model_config['batch_size']``; each
    batch is sampled in passes of at most 5 sequences per prompt. Only the
    tokens produced after the (padded) prompt are decoded, so prompt removal
    does not depend on the decoded prompt matching the input string. If a batch
    runs out of GPU memory, that pass is retried one sequence at a time.
    Stories of 10 words or fewer are discarded; a prompt that ends up with too
    few stories is filled by cycling through the ones it has, or with a fixed
    placeholder sentence when it has none.

    Args:
        prompts: Raw writing prompts.
        prompt_ids: Identifiers parallel to ``prompts`` (accepted for interface
            compatibility; results are keyed by prompt text).
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``. Its ``padding_side`` is set to
            'left' by this function.
        model_config: Entry of MODEL_CONFIGS ('name', 'batch_size',
            'max_length', 'remove_prompt').
        decoding_strategy: Keyword arguments forwarded to ``model.generate``.
        num_samples: Stories to return per prompt.

    Returns:
        Dict mapping each prompt text to a list of ``num_samples`` stories.
    """
    # Nothing to sample; also avoids a division by zero in the pass computation
    if num_samples <= 0:
        return {prompt: [] for prompt in prompts}

    results = {}
    model_name = model_config['name']
    batch_size = max(1, model_config.get('batch_size', 10))
    max_length = model_config.get('max_length', 512)
    remove_prompt = model_config.get('remove_prompt', True)
    token_budget = max_length // 2  # same budget for the prompt and for new tokens

    # Batched generation with a decoder-only model requires left padding
    tokenizer.padding_side = 'left'

    # Independent of the batch, and must exist before the try block so the
    # out-of-memory handler can always use it.
    gen_params = {
        **decoding_strategy,
        'max_new_tokens': token_budget,
        'min_new_tokens': 50,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id,
        'return_dict_in_generate': False,
    }

    samples_per_pass = min(5, num_samples)
    num_passes = (num_samples + samples_per_pass - 1) // samples_per_pass

    for batch_start in tqdm(range(0, len(prompts), batch_size),
                            desc=f"      Generating", leave=False):
        batch_prompts = prompts[batch_start:batch_start + batch_size]
        formatted_prompts = [format_prompt_for_generation(p, model_name) for p in batch_prompts]

        batch_results = {p: [] for p in batch_prompts}

        for pass_idx in range(num_passes):
            samples_this_pass = min(samples_per_pass, num_samples - pass_idx * samples_per_pass)

            try:
                inputs = tokenizer(
                    formatted_prompts,
                    return_tensors="pt",
                    truncation=True,
                    max_length=token_budget,
                    padding=True
                )
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
                prompt_width = inputs['input_ids'].shape[1]

                set_seed(SEED + batch_start + pass_idx * 1000)

                with torch.no_grad():
                    with torch.autocast(device_type='cuda', dtype=torch.float16,
                                        enabled=torch.cuda.is_available()):
                        outputs = model.generate(
                            **inputs,
                            **gen_params,
                            num_return_sequences=samples_this_pass
                        )

                # Every row starts with the padded prompt; decode only what follows
                generated_texts = tokenizer.batch_decode(
                    outputs[:, prompt_width:], skip_special_tokens=True
                )

                # generate() returns the sequences grouped by prompt
                for i, original_prompt in enumerate(batch_prompts):
                    for j in range(samples_this_pass):
                        output_idx = i * samples_this_pass + j
                        if output_idx >= len(generated_texts):
                            continue

                        story = _story_from_continuation(
                            generated_texts[output_idx],
                            original_prompt,
                            formatted_prompts[i],
                            model_name,
                            remove_prompt
                        )
                        if len(story.split()) > 10:
                            batch_results[original_prompt].append(story)

                del outputs
                torch.cuda.empty_cache()

            except torch.cuda.OutOfMemoryError:
                print(f"         OOM, reducing batch size")
                torch.cuda.empty_cache()

                # Retry this pass one prompt and one sequence at a time
                for i, original_prompt in enumerate(batch_prompts):
                    formatted_prompt = formatted_prompts[i]

                    for sample_idx in range(samples_this_pass):
                        try:
                            single_input = tokenizer(
                                formatted_prompt,
                                return_tensors="pt",
                                truncation=True,
                                max_length=token_budget
                            )
                            single_input = {k: v.to(model.device) for k, v in single_input.items()}
                            single_width = single_input['input_ids'].shape[1]

                            with torch.no_grad():
                                output = model.generate(
                                    **single_input,
                                    **gen_params,
                                    num_return_sequences=1
                                )

                            generated_text = tokenizer.decode(
                                output[0][single_width:], skip_special_tokens=True
                            )
                            story = _story_from_continuation(
                                generated_text,
                                original_prompt,
                                formatted_prompt,
                                model_name,
                                remove_prompt
                            )
                            if len(story.split()) > 10:
                                batch_results[original_prompt].append(story)

                        except Exception as e:
                            print(f"         Error generating sample: {str(e)[:100]}")
                            continue

                torch.cuda.empty_cache()

            except Exception as e:
                print(f"         Generation error: {str(e)[:100]}")
                continue

        for prompt in batch_prompts:
            stories = batch_results[prompt]

            if len(stories) < num_samples:
                # Repeated or placeholder stories lower the measured diversity,
                # so make the shortfall visible in the log.
                print(f"         Only {len(stories)}/{num_samples} usable stories for a prompt; padding with repeats")

            if not stories:
                stories.append("Once upon a time, there was a story that began with this prompt.")

            # Cycle through the stories that were actually generated
            available = len(stories)
            fill_idx = 0
            while len(stories) < num_samples:
                stories.append(stories[fill_idx % available])
                fill_idx += 1

            results[prompt] = stories[:num_samples]

    return results

# %% [markdown]
# ## 8. Comprehensive Evaluation Metrics

# %%
class ComprehensiveMetrics:
    """Corpus-level text metrics: lexical, syntactic, semantic, readability, repetition.

    Every ``calculate_*`` method takes a list of texts and returns a flat dict
    mapping metric name to a number. Each method catches its own errors, prints
    a short message and returns neutral fallback values, so a single failing
    metric family does not abort a long evaluation run.
    """

    def __init__(self, device='cuda'):
        """Load the embedding model and scorers used by the metric methods.

        Args:
            device: Device the sentence-embedding model is moved to.
        """
        self.device = device

        print("   Loading evaluation models...")
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # BLEURT is optional. Always define the attribute so callers can test
        # `self.bleurt is None` instead of hitting an AttributeError.
        self.bleurt = None
        try:
            self.bleurt = evaluate.load("bleurt", "BLEURT-20")
            self.has_bleurt = True
        except Exception:
            self.has_bleurt = False
            print("      BLEURT not available, skipping")

    @staticmethod
    def _mean_and_std(values) -> Tuple[float, float]:
        """Return ``(mean, population std)`` of ``values``, or ``(0, 0)`` when empty."""
        if not values:
            return 0, 0
        return np.mean(values), np.std(values)

    def safe_nltk_tokenize(self, text: str) -> List[str]:
        """Word-tokenize with NLTK; fall back to whitespace splitting on any error."""
        try:
            return nltk.word_tokenize(text)
        except Exception:
            return text.split()

    def safe_nltk_sent_tokenize(self, text: str) -> List[str]:
        """Sentence-tokenize with NLTK; fall back to splitting on '.' on any error."""
        try:
            return nltk.sent_tokenize(text)
        except Exception:
            sentences = text.split('.')
            return [s.strip() for s in sentences if s.strip()]

    def safe_pos_tag(self, tokens: List[str]) -> List[Tuple[str, str]]:
        """POS-tag with NLTK; on any error tag every token as 'NN'."""
        try:
            return nltk.pos_tag(tokens)
        except Exception:
            return [(token, 'NN') for token in tokens]

    def calculate_lexical_diversity(self, texts: List[str]) -> Dict:
        """Compute vocabulary-diversity metrics over all texts joined together.

        Returns:
            Dict with ``type_token_ratio``, ``mtld``, ``hdd``, ``vocab_size`` and
            ``hapax_ratio`` (words occurring once, divided by total tokens).
            All values are 0 if the computation fails.
        """
        metrics = {}

        try:
            tokens = ' '.join(texts).split()
            word_freq = Counter(tokens)
            vocab_size = len(word_freq)

            metrics['type_token_ratio'] = vocab_size / len(tokens) if tokens else 0

            # mtld/hdd are reported together: if either fails, both are zeroed.
            try:
                metrics['mtld'] = ld.mtld(tokens)
                metrics['hdd'] = ld.hdd(tokens)
            except Exception:
                metrics['mtld'] = 0
                metrics['hdd'] = 0

            metrics['vocab_size'] = vocab_size

            hapax_count = sum(1 for count in word_freq.values() if count == 1)
            metrics['hapax_ratio'] = hapax_count / len(tokens) if tokens else 0

        except Exception as e:
            print(f"      Error in lexical diversity: {str(e)[:50]}")
            metrics = {
                'type_token_ratio': 0,
                'mtld': 0,
                'hdd': 0,
                'vocab_size': 0,
                'hapax_ratio': 0
            }

        return metrics

    def calculate_syntactic_diversity(self, texts: List[str]) -> Dict:
        """Compute POS-pattern and sentence-length statistics.

        Only the first 10 texts are inspected. For each, the first 500
        characters are tagged and the first 50 tags form its POS pattern; up to
        10 sentences per text contribute to the sentence-length statistics.

        Returns:
            Dict with ``unique_pos_patterns``, ``sentence_length_mean`` and
            ``sentence_length_std``. All values are 0 if the computation fails.
        """
        metrics = {}

        try:
            pos_patterns = []
            sentence_lengths = []
            for text in texts[:10]:
                tokens = self.safe_nltk_tokenize(text[:500])
                pos_tags = self.safe_pos_tag(tokens)
                pos_patterns.append(' '.join([tag for _, tag in pos_tags[:50]]))

                sentences = self.safe_nltk_sent_tokenize(text)
                sentence_lengths.extend([len(s.split()) for s in sentences[:10]])

            metrics['unique_pos_patterns'] = len(set(pos_patterns))

            length_mean, length_std = self._mean_and_std(sentence_lengths)
            metrics['sentence_length_mean'] = length_mean
            metrics['sentence_length_std'] = length_std

        except Exception as e:
            print(f"      Error in syntactic diversity: {str(e)[:50]}")
            metrics = {
                'unique_pos_patterns': 0,
                'sentence_length_mean': 0,
                'sentence_length_std': 0
            }

        return metrics

    def calculate_semantic_coherence(self, texts: List[str]) -> Dict:
        """Compute pairwise cosine-similarity statistics of sentence embeddings.

        ``semantic_coherence`` is the mean similarity over all distinct pairs
        and ``semantic_diversity`` is ``1 - semantic_coherence``.

        Returns:
            Dict with ``semantic_coherence``, ``semantic_diversity``,
            ``semantic_coherence_std``, ``min_similarity`` and
            ``max_similarity``. All values are 0 when fewer than two texts are
            given (no pair exists) or when the computation fails.
        """
        no_pairs = {
            'semantic_coherence': 0,
            'semantic_diversity': 0,
            'semantic_coherence_std': 0,
            'min_similarity': 0,
            'max_similarity': 0
        }

        # With fewer than two texts there is nothing to compare. Without this
        # guard a single text would be reported as maximally diverse (1 - 0).
        if len(texts) < 2:
            return no_pairs

        metrics = {}

        try:
            with torch.no_grad():
                embeddings = self.sentence_model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
                embeddings_np = embeddings.cpu().numpy()

            similarities = cosine_similarity(embeddings_np)

            # Strict upper triangle = each unordered pair exactly once.
            upper_tri = np.triu_indices(len(texts), k=1)
            pairwise_sims = similarities[upper_tri]

            metrics['semantic_coherence'] = float(np.mean(pairwise_sims))
            metrics['semantic_diversity'] = 1 - metrics['semantic_coherence']
            metrics['semantic_coherence_std'] = float(np.std(pairwise_sims))
            metrics['min_similarity'] = float(np.min(pairwise_sims))
            metrics['max_similarity'] = float(np.max(pairwise_sims))

        except Exception as e:
            print(f"      Error in semantic coherence: {str(e)[:50]}")
            metrics = no_pairs

        return metrics

    def calculate_readability(self, texts: List[str]) -> Dict:
        """Compute Flesch reading-ease and Flesch-Kincaid grade statistics.

        Only the first 10 texts are scored. A text that textstat cannot score
        for one measure is left out of that measure only.

        Returns:
            Dict with ``flesch_reading_ease_mean``/``_std`` and
            ``grade_level_mean``/``_std``. All values are 0 if nothing could be
            scored or the computation fails.
        """
        metrics = {}

        try:
            ease_scores = []
            grade_levels = []
            for text in texts[:10]:
                # Best effort by design: an unscorable text is simply skipped.
                try:
                    ease_scores.append(textstat.flesch_reading_ease(text))
                except Exception:
                    pass
                try:
                    grade_levels.append(textstat.flesch_kincaid_grade(text))
                except Exception:
                    pass

            ease_mean, ease_std = self._mean_and_std(ease_scores)
            metrics['flesch_reading_ease_mean'] = ease_mean
            metrics['flesch_reading_ease_std'] = ease_std

            grade_mean, grade_std = self._mean_and_std(grade_levels)
            metrics['grade_level_mean'] = grade_mean
            metrics['grade_level_std'] = grade_std

        except Exception as e:
            print(f"      Error in readability: {str(e)[:50]}")
            metrics = {
                'flesch_reading_ease_mean': 0,
                'flesch_reading_ease_std': 0,
                'grade_level_mean': 0,
                'grade_level_std': 0
            }

        return metrics

    def calculate_repetition(self, texts: List[str]) -> Dict:
        """Compute the share of unique n-grams for n = 2, 3 and 4.

        N-grams are built from the first 100 whitespace tokens of each of the
        first 10 texts and pooled; the ratio is unique / total.

        Returns:
            Dict with ``2gram_unique_ratio``, ``3gram_unique_ratio`` and
            ``4gram_unique_ratio``. A ratio is 1.0 when no n-gram of that size
            exists or the computation fails.
        """
        metrics = {}

        try:
            for n in [2, 3, 4]:
                all_ngrams = []
                for text in texts[:10]:
                    tokens = text.split()[:100]
                    all_ngrams.extend(
                        ' '.join(tokens[i:i+n]) for i in range(len(tokens)-n+1)
                    )

                if all_ngrams:
                    metrics[f'{n}gram_unique_ratio'] = len(set(all_ngrams)) / len(all_ngrams)
                else:
                    metrics[f'{n}gram_unique_ratio'] = 1.0

        except Exception as e:
            print(f"      Error in repetition: {str(e)[:50]}")
            metrics = {
                '2gram_unique_ratio': 1.0,
                '3gram_unique_ratio': 1.0,
                '4gram_unique_ratio': 1.0
            }

        return metrics

    def calculate_all_metrics(self, texts: List[str]) -> Dict:
        """Run every metric family on ``texts`` and merge the results.

        Returns:
            Dict with ``num_texts``, ``avg_length_words``, ``std_length_words``
            plus every key produced by the ``calculate_*`` methods. If this
            method itself fails, a reduced dict with ``num_texts``, zeroed
            length statistics and ``semantic_diversity`` 0 is returned.
        """
        all_metrics = {}

        try:
            word_counts = [len(t.split()) for t in texts]

            all_metrics['num_texts'] = len(texts)
            # np.mean/np.std of an empty list yield NaN; report 0 instead.
            all_metrics['avg_length_words'] = np.mean(word_counts) if word_counts else 0
            all_metrics['std_length_words'] = np.std(word_counts) if word_counts else 0

            all_metrics.update(self.calculate_lexical_diversity(texts))
            all_metrics.update(self.calculate_syntactic_diversity(texts))
            all_metrics.update(self.calculate_semantic_coherence(texts))
            all_metrics.update(self.calculate_readability(texts))
            all_metrics.update(self.calculate_repetition(texts))

        except Exception as e:
            print(f"      Error in calculate_all_metrics: {str(e)[:50]}")
            all_metrics = {
                'num_texts': len(texts),
                'avg_length_words': 0,
                'std_length_words': 0,
                'semantic_diversity': 0
            }

        return all_metrics

# Initialize metrics calculator
print("\nInitializing comprehensive metrics calculator...")
metrics_calculator = ComprehensiveMetrics(device=device)
print("Metrics calculator ready")

# Smoke-test the calculator on two short sentences.
# 'num_texts' alone proves nothing: calculate_all_metrics sets it even in its
# error fallback. Also require values that only a real computation produces:
# the fallbacks report 0 for word length, type/token ratio and similarity.
print("\nTesting metrics calculator...")
test_texts = ["This is a test sentence.", "Another test sentence here."]
test_metrics = metrics_calculator.calculate_all_metrics(test_texts)
smoke_test_passed = (
    test_metrics.get('num_texts') == 2
    and test_metrics.get('avg_length_words', 0) > 0
    and test_metrics.get('type_token_ratio', 0) > 0
    and test_metrics.get('max_similarity', 0) != 0
)
if smoke_test_passed:
    print("Metrics calculator test passed")
else:
    print("Metrics calculator may have issues, but continuing...")

# %% [markdown]
# ## 9. Main Generation Pipeline (Metrics Calculated After Generation)

# %%
print("\nStarting main generation pipeline...")
print(f"   This will take several hours. Starting at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Continuations requested per prompt; used for generation, the progress
# message and the completeness check below so the three cannot drift apart.
samples_per_prompt = 10

# Initialize results storage
all_results = []
generation_metadata = []

# Track start time
pipeline_start_time = time.time()

# The prompt list is the same for every model and strategy: build it once.
prompt_ids = list(selected_data.keys())
prompts = [selected_data[pid]['prompt'] for pid in prompt_ids]

# Make sure the checkpoint directory exists (no-op if it already does), so
# hours of generation are not lost to a failed save.
checkpoint_dir = os.path.join(SAVE_DIR, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Process each model
for model_idx, (model_key, model_config) in enumerate(MODEL_CONFIGS.items()):
    print(f"\n{'='*60}")
    print(f"[{model_idx+1}/{len(MODEL_CONFIGS)}] Processing {model_key}")
    print(f"{'='*60}")

    model_start = time.time()

    # Load model
    model, tokenizer = load_model_and_tokenizer(model_config)

    if model is None:
        print(f"   Could not load {model_key}, skipping...")
        continue

    # Process each decoding strategy
    for strategy_idx, (strategy_name, strategy_params) in enumerate(DECODING_STRATEGIES.items()):
        print(f"\n   Strategy [{strategy_idx+1}/{len(DECODING_STRATEGIES)}]: {strategy_name}")
        source_key = f"{model_key}_{strategy_name}"
        strategy_start = time.time()

        # Generate stories in batches
        print(f"      Generating {len(prompts)} prompts × {samples_per_prompt} samples = {len(prompts)*samples_per_prompt} stories")
        generated_stories = generate_stories_batch(
            prompts,
            prompt_ids,
            model,
            tokenizer,
            model_config,
            strategy_params,
            num_samples=samples_per_prompt
        )

        # Store results (generated_stories is keyed by prompt text)
        stories_generated = 0
        for prompt_id, prompt in zip(prompt_ids, prompts):
            stories = generated_stories.get(prompt, [])

            assert len(stories) == samples_per_prompt, f"Expected {samples_per_prompt} stories, got {len(stories)} for prompt {prompt_id}"

            for story_idx, story in enumerate(stories):
                story_length = len(tokenizer.encode(story, add_special_tokens=False))

                result = {
                    'prompt_id': prompt_id,
                    'prompt': prompt,
                    'source': source_key,
                    'model': model_key,
                    'strategy': strategy_name,
                    'story_index': story_idx,
                    'story': story,
                    'story_length_tokens': story_length,
                    'story_length_words': len(story.split()),
                    'timestamp': datetime.now().isoformat()
                }
                all_results.append(result)
                stories_generated += 1

        # Store basic metadata
        generation_metadata.append({
            'source': source_key,
            'model': model_key,
            'strategy': strategy_name,
            'stories_generated': stories_generated,
            'generation_time': time.time() - strategy_start
        })

        # Save checkpoint. Each checkpoint holds everything generated so far,
        # not just this strategy, so the latest file alone is a full snapshot.
        checkpoint_df = pd.DataFrame(all_results)
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{source_key}.parquet')
        checkpoint_df.to_parquet(checkpoint_path)
        print(f"      Checkpoint saved: {checkpoint_path}")

        strategy_time = time.time() - strategy_start
        print(f"      Completed in {strategy_time/60:.1f} minutes")
        print(f"      Generated {stories_generated} stories")

    # Model completion
    model_time = time.time() - model_start
    print(f"\n   {model_key} completed in {model_time/60:.1f} minutes")

    # Clear model from memory. Collect garbage first: empty_cache() can only
    # return blocks whose tensors are already freed, and objects caught in
    # reference cycles stay alive until the collector runs.
    del model
    del tokenizer
    gc.collect()
    torch.cuda.empty_cache()

    # Estimate remaining time
    elapsed = time.time() - pipeline_start_time
    models_done = model_idx + 1
    if models_done < len(MODEL_CONFIGS):
        avg_per_model = elapsed / models_done
        remaining = avg_per_model * (len(MODEL_CONFIGS) - models_done)
        print(f"   Estimated remaining: {remaining/60:.1f} minutes ({remaining/3600:.1f} hours)")

print(f"\nAll models completed. Total generation time: {(time.time() - pipeline_start_time)/60:.1f} minutes")

# %% [markdown]
# ## 10. Add Human Stories to Results

# %%
print("\nAdding human stories to results...")

# Keep the cell idempotent: drop human rows left by an earlier run of this
# cell, otherwise re-running it would append every human story a second time.
all_results = [row for row in all_results if row['source'] != 'human']

human_stories_added = 0
for prompt_id, prompt_data in selected_data.items():
    prompt_text = prompt_data['prompt']
    human_stories = prompt_data['human_stories']
    human_lengths = prompt_data['human_story_lengths']

    # NOTE(review): zip() stops at the shorter list, so a story without a
    # matching length entry would be dropped silently — confirm both lists
    # always have the same length upstream.
    for story_idx, (story, length) in enumerate(zip(human_stories, human_lengths)):
        result = {
            'prompt_id': prompt_id,
            'prompt': prompt_text,
            'source': 'human',
            'model': 'human',
            'strategy': 'human',
            'story_index': story_idx,
            'story': story,
            'story_length_tokens': length,
            'story_length_words': len(story.split()),
            'timestamp': datetime.now().isoformat()
        }
        all_results.append(result)
        human_stories_added += 1

print(f"Added {human_stories_added} human stories")
print(f"   Total stories in dataset: {len(all_results)}")

# Convert to DataFrame
results_df = pd.DataFrame(all_results)

# Make sure the output directory exists (no-op if it already does).
results_dir = os.path.join(SAVE_DIR, 'results')
os.makedirs(results_dir, exist_ok=True)

# Save complete results
results_path = os.path.join(results_dir, 'all_stories_complete.parquet')
results_df.to_parquet(results_path)
print(f"Saved complete results to: {results_path}")

# Also save as CSV for easier inspection
csv_path = os.path.join(results_dir, 'all_stories_complete.csv')
results_df.to_csv(csv_path, index=False)
print(f"Saved CSV version to: {csv_path}")

# %% [markdown]
# ## 11. Calculate Comprehensive Metrics for All Sources

# %%
print("\nCalculating comprehensive metrics for all sources...")

# Metrics are computed on the first N stories of each source, in row order.
max_stories_per_source = 100

metrics_results = []

for source in tqdm(results_df['source'].unique(), desc="Calculating metrics"):
    source_rows = results_df[results_df['source'] == source]
    source_stories = source_rows['story'].tolist()

    sample_stories = source_stories[:max_stories_per_source]

    # Read model/strategy from the columns written at generation time rather
    # than splitting `source` on '_': a model key that itself contains an
    # underscore would be split in the wrong place. Human rows carry
    # 'human' in both columns.
    source_model = source_rows['model'].iloc[0]
    source_strategy = source_rows['strategy'].iloc[0]

    print(f"   Processing {source}: {len(sample_stories)} stories")

    try:
        source_metrics = metrics_calculator.calculate_all_metrics(sample_stories)
        source_metrics['source'] = source
        source_metrics['model'] = source_model
        source_metrics['strategy'] = source_strategy
        metrics_results.append(source_metrics)
    except Exception as e:
        # Keep a placeholder row so the source still shows up in the output.
        print(f"   Error calculating metrics for {source}: {str(e)[:100]}")
        metrics_results.append({
            'source': source,
            'model': source_model,
            'strategy': source_strategy,
            'semantic_diversity': 0,
            'num_texts': len(sample_stories)
        })

# Convert to DataFrame
metrics_df = pd.DataFrame(metrics_results)

# Save metrics
metrics_path = os.path.join(SAVE_DIR, 'results', 'comprehensive_metrics.csv')
metrics_df.to_csv(metrics_path, index=False)
print(f"Saved comprehensive metrics to: {metrics_path}")

# %% [markdown]
# ## 12. Calculate Uncertainty Metrics

# %%
print("\nCalculating uncertainty metrics...")

def calculate_diversity_for_source_safe(stories: List[str], metric_type: str = 'semantic') -> float:
    """Score how much a set of stories differ from one another.

    Args:
        stories: Stories to compare.
        metric_type: One of
            'semantic'  - 1 minus the mean pairwise cosine similarity of the
                          sentence embeddings of the first 50 stories;
            'lexical'   - 1 minus the mean pairwise Jaccard overlap of the
                          lowercased vocabularies (first 200 words each) of
                          the first 50 stories;
            'syntactic' - share of distinct POS-tag patterns (first 30 tags of
                          the first 500 characters) among the first 20 stories.

    Returns:
        The diversity score as a plain Python float. 0.0 is returned for fewer
        than two stories, for an unknown ``metric_type``, and when the
        computation raises (a short message is printed in that case).
    """
    if len(stories) < 2:
        return 0.0

    try:
        if metric_type == 'semantic':
            with torch.no_grad():
                embeddings = metrics_calculator.sentence_model.encode(
                    stories[:50],
                    convert_to_tensor=True,
                    show_progress_bar=False
                )
                embeddings_np = embeddings.cpu().numpy()
                similarities = cosine_similarity(embeddings_np)
                # Strict upper triangle = each unordered pair exactly once.
                upper_tri = np.triu_indices(len(embeddings_np), k=1)
                # float(): hand back a built-in float, not a numpy scalar.
                return float(1 - np.mean(similarities[upper_tri]))

        elif metric_type == 'lexical':
            vocabs = [set(story.lower().split()[:200]) for story in stories[:50]]
            overlaps = []
            for i, left in enumerate(vocabs):
                for right in vocabs[i + 1:]:
                    union = len(left | right)
                    # Two empty vocabularies share nothing: overlap 0.
                    overlaps.append(len(left & right) / union if union > 0 else 0)
            # At least two stories are guaranteed here, so overlaps is non-empty.
            return float(1 - np.mean(overlaps))

        elif metric_type == 'syntactic':
            pos_patterns = []
            for story in stories[:20]:
                try:
                    tokens = metrics_calculator.safe_nltk_tokenize(story[:500])
                    pos_tags = metrics_calculator.safe_pos_tag(tokens[:50])
                    pattern = ' '.join([tag for _, tag in pos_tags[:30]])
                    pos_patterns.append(pattern)
                except Exception:
                    # Untaggable story: count it as the empty pattern.
                    pos_patterns.append("")

            unique_patterns = len(set(pos_patterns))
            return float(unique_patterns / len(pos_patterns)) if pos_patterns else 0.0

    except Exception as e:
        print(f"      Error in diversity calculation ({metric_type}): {str(e)[:50]}")
        return 0.0

    # Unknown metric_type.
    return 0.0

# Calculate uncertainty for each prompt and source.
# A "calibration error" here is the absolute gap between the diversity of a
# model's stories and the diversity of the human stories for the same prompt.
uncertainty_results = []

# Only prompt/source groups with exactly this many stories are compared.
stories_per_source = 10

for prompt_id in tqdm(selected_data.keys(), desc="Calculating uncertainties"):
    prompt_rows = results_df[results_df['prompt_id'] == prompt_id]

    human_stories = prompt_rows[prompt_rows['source'] == 'human']['story'].tolist()

    if len(human_stories) != stories_per_source:
        continue

    human_semantic_div = calculate_diversity_for_source_safe(human_stories, 'semantic')
    human_lexical_div = calculate_diversity_for_source_safe(human_stories, 'lexical')
    human_syntactic_div = calculate_diversity_for_source_safe(human_stories, 'syntactic')

    for source in prompt_rows['source'].unique():
        if source == 'human':
            continue

        source_rows = prompt_rows[prompt_rows['source'] == source]
        model_stories = source_rows['story'].tolist()

        if len(model_stories) != stories_per_source:
            continue

        model_semantic_div = calculate_diversity_for_source_safe(model_stories, 'semantic')
        model_lexical_div = calculate_diversity_for_source_safe(model_stories, 'lexical')
        model_syntactic_div = calculate_diversity_for_source_safe(model_stories, 'syntactic')

        # Diversity of a mixed pool: first 5 human + first 5 model stories.
        combined_stories = human_stories[:5] + model_stories[:5]
        cross_semantic_div = calculate_diversity_for_source_safe(combined_stories, 'semantic')

        semantic_calibration = abs(model_semantic_div - human_semantic_div)
        lexical_calibration = abs(model_lexical_div - human_lexical_div)
        syntactic_calibration = abs(model_syntactic_div - human_syntactic_div)

        uncertainty_results.append({
            'prompt_id': prompt_id,
            'source': source,
            # Take model/strategy from the stored columns instead of splitting
            # `source` on '_', which breaks for model keys with underscores.
            'model': source_rows['model'].iloc[0],
            'strategy': source_rows['strategy'].iloc[0],
            'human_semantic_diversity': human_semantic_div,
            'human_lexical_diversity': human_lexical_div,
            'human_syntactic_diversity': human_syntactic_div,
            'model_semantic_diversity': model_semantic_div,
            'model_lexical_diversity': model_lexical_div,
            'model_syntactic_diversity': model_syntactic_div,
            'cross_semantic_diversity': cross_semantic_div,
            'semantic_calibration_error': semantic_calibration,
            'lexical_calibration_error': lexical_calibration,
            'syntactic_calibration_error': syntactic_calibration,
            'overall_calibration_error': (semantic_calibration + lexical_calibration + syntactic_calibration) / 3
        })

uncertainty_df = pd.DataFrame(uncertainty_results)

# Save uncertainty analysis
uncertainty_path = os.path.join(SAVE_DIR, 'results', 'uncertainty_analysis.csv')
uncertainty_df.to_csv(uncertainty_path, index=False)
print(f"Saved uncertainty analysis to: {uncertainty_path}")

# %% [markdown]
# ## 13. Statistical Analysis and Key Findings

# %%
print("\nSTATISTICAL ANALYSIS")
print("="*70)

# Aggregate by model and strategy
model_performance = uncertainty_df.groupby(['model', 'strategy']).agg({
    'semantic_calibration_error': ['mean', 'std'],
    'lexical_calibration_error': ['mean', 'std'],
    'syntactic_calibration_error': ['mean', 'std'],
    'overall_calibration_error': ['mean', 'std'],
    'model_semantic_diversity': ['mean', 'std'],
    'model_lexical_diversity': ['mean', 'std'],
    'model_syntactic_diversity': ['mean', 'std']
}).round(4)

print("\nMODEL PERFORMANCE RANKING (by overall calibration error):")
print("-"*70)

# Best configurations (lowest mean calibration error first)
best_configs = uncertainty_df.groupby('source')['overall_calibration_error'].mean().sort_values()

# Look model/strategy up from the columns already stored per source instead
# of re-splitting the source string on '_'.
source_labels = (
    uncertainty_df
    .drop_duplicates('source')
    .set_index('source')[['model', 'strategy']]
)
for rank, (source, error) in enumerate(best_configs.head(10).items(), 1):
    config_model = source_labels.at[source, 'model']
    config_strategy = source_labels.at[source, 'strategy']
    print(f"{rank:2d}. {config_model:20s} + {config_strategy:15s} : {error:.4f}")

# Human baseline
human_data = results_df[results_df['source'] == 'human']
human_diversity_stats = {
    'semantic': [],
    'lexical': [],
    'syntactic': []
}

for prompt_id in selected_data.keys():
    human_stories = human_data[human_data['prompt_id'] == prompt_id]['story'].tolist()
    if len(human_stories) == 10:
        human_diversity_stats['semantic'].append(calculate_diversity_for_source_safe(human_stories, 'semantic'))
        human_diversity_stats['lexical'].append(calculate_diversity_for_source_safe(human_stories, 'lexical'))
        human_diversity_stats['syntactic'].append(calculate_diversity_for_source_safe(human_stories, 'syntactic'))

print(f"\nHUMAN BASELINE DIVERSITY:")
print(f"   Semantic:  {np.mean(human_diversity_stats['semantic']):.4f} ± {np.std(human_diversity_stats['semantic']):.4f}")
print(f"   Lexical:   {np.mean(human_diversity_stats['lexical']):.4f} ± {np.std(human_diversity_stats['lexical']):.4f}")
print(f"   Syntactic: {np.mean(human_diversity_stats['syntactic']):.4f} ± {np.std(human_diversity_stats['syntactic']):.4f}")

# Model comparison
print(f"\nMODEL COMPARISON:")
for model_name in MODEL_CONFIGS.keys():
    model_data = uncertainty_df[uncertainty_df['model'] == model_name]
    if len(model_data) > 0:
        print(f"\n   {model_name}:")
        print(f"      Semantic diversity:  {model_data['model_semantic_diversity'].mean():.4f}")
        print(f"      Calibration error:   {model_data['overall_calibration_error'].mean():.4f}")

        model_stories_all = results_df[results_df['model'] == model_name]
        if len(model_stories_all) > 0:
            # NOTE(review): both samples are simply the first 100 values in
            # row order, so they need not cover the same prompts — confirm
            # this sampling is intended before reporting the p-values.
            model_divs = model_data['model_semantic_diversity'].values[:100]
            human_divs = human_diversity_stats['semantic'][:100]

            # A t-test needs at least two observations per group.
            if len(model_divs) >= 2 and len(human_divs) >= 2:
                t_stat, p_value = stats.ttest_ind(model_divs, human_divs)
                if np.isnan(p_value):
                    # `nan > 0.05` is False, so an undefined result would
                    # otherwise be labelled significant.
                    print("      vs Human (p-value): n/a (t-test undefined for these samples)")
                else:
                    print(f"      vs Human (p-value): {p_value:.6f} {'(ns)' if p_value > 0.05 else '(*)'}")
            else:
                print("      vs Human (p-value): n/a (fewer than 2 samples)")

