In [None]:
# Credal Set Analysis for Uncertainty Quantification in NLG
# Implementation for analyzing uncertainty in open-ended text generation

# %% [markdown]
# # Credal Set Analysis for Uncertainty Quantification in Open-Ended Text Generation
#
# This notebook implements credal set analysis with comprehensive visualizations
# for the paper: "Disentangling Aleatoric and Epistemic Uncertainty in Open-Ended Text Generation"

# %% [markdown]
# ## 1. Setup and Imports

# %%
print("Installing required packages...")

# Core packages
!pip install -q --upgrade pip
!pip install -q pandas numpy matplotlib seaborn scipy
!pip install -q scikit-learn plotly
!pip install -q sentence-transformers umap-learn
!pip install -q shapely  # For geometric operations

import os
import sys
import json
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.spatial import ConvexHull, Delaunay
from scipy.spatial.distance import cdist, directed_hausdorff
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import plotly.graph_objects as go
import plotly.express as px
from typing import List, Dict, Tuple, Optional, Set
from collections import defaultdict, Counter
from datetime import datetime
import hashlib
from tqdm import tqdm
from itertools import combinations
import time
import pickle

# For embeddings and analysis
from sentence_transformers import SentenceTransformer
import umap

print("Package import complete")

# Publication-quality matplotlib defaults: start from the seaborn whitegrid
# style sheet, then override the individual settings in a single batch.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    # resolution: on-screen vs. saved figures
    'figure.dpi': 100,
    'savefig.dpi': 300,
    # font sizes
    'font.size': 11,
    'axes.labelsize': 11,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 14,
    'font.family': 'sans-serif',
    # drop the top/right frame lines
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Colour palette shared by all figures. The first three entries are keyed by
# data origin; the remaining entries are generic role names.
COLORS = dict(
    human='#2E86AB',       # Deep blue
    base='#A23B72',        # Burgundy
    instruct='#F18F01',    # Orange
    primary='#2E86AB',
    secondary='#A23B72',
    tertiary='#F18F01',
    quaternary='#C73E1D',
    accent='#6A994E',
)

# One fixed colour per model name (plus the human baseline).
MODEL_COLORS = dict([
    ('human', '#2E86AB'),
    ('GPT2-XL', '#A23B72'),
    ('Gemma-2B', '#F18F01'),
    ('Mistral-7B-Instruct', '#C73E1D'),
    ('Llama-3.1-8B-Instruct', '#6A994E'),
])


# %% [markdown]
# ## 2. Configure Paths and Load Data

# %%
# Storage is mounted through Google Colab; adapt this for other environments.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Directory layout - change BASE_DIR to point at your own data.
BASE_DIR = '/content/drive/MyDrive/analysis_data'
SAVE_DIR = os.path.join(BASE_DIR, 'uq_analysis')
RESULTS_DIR, FIGURES_DIR, CREDAL_DIR = [
    os.path.join(SAVE_DIR, subdir)
    for subdir in ('results', 'figures', 'credal_analysis')
]

# Only the credal-analysis output directory is created here; the other
# directories are expected to exist already.
os.makedirs(CREDAL_DIR, exist_ok=True)

print(f"Working directory: {SAVE_DIR}")
print(f"Loading results from: {RESULTS_DIR}")

# Read every input artefact; any failure is reported and then re-raised so the
# notebook stops instead of continuing with missing data.
try:
    # Uncertainty analysis table
    uncertainty_path = os.path.join(RESULTS_DIR, 'uncertainty_analysis.csv')
    uncertainty_df = pd.read_csv(uncertainty_path)
    print(f"Loaded uncertainty analysis: {uncertainty_df.shape}")

    # Comprehensive metrics table
    metrics_path = os.path.join(RESULTS_DIR, 'comprehensive_metrics.csv')
    metrics_df = pd.read_csv(metrics_path)
    print(f"Loaded comprehensive metrics: {metrics_df.shape}")

    # Full stories dataset
    stories_path = os.path.join(RESULTS_DIR, 'all_stories_complete.parquet')
    stories_df = pd.read_parquet(stories_path)
    print(f"Loaded stories dataset: {stories_df.shape}")

    # Metadata of the selected prompts
    prompts_path = os.path.join(SAVE_DIR, 'selected_prompts_verified.json')
    with open(prompts_path, 'r') as prompts_file:
        selected_prompts = json.load(prompts_file)
    print(f"Loaded {len(selected_prompts)} prompts metadata")

except Exception as load_error:
    print(f"Error loading data: {load_error}")
    raise

# %% [markdown]
# ## 3. Pre-compute Embeddings

# %%
print("Pre-computing embeddings...")

# Embeddings are cached on disk as a {story text: embedding vector} dictionary.
embeddings_cache_path = os.path.join(CREDAL_DIR, 'embeddings_cache.pkl')

unique_stories = stories_df['story'].unique()

if os.path.exists(embeddings_cache_path):
    print("Loading cached embeddings...")
    # NOTE(security): pickle.load can execute arbitrary code. Only load a cache
    # file that this notebook wrote itself.
    with open(embeddings_cache_path, 'rb') as f:
        story_embeddings = pickle.load(f)
    print(f"Loaded {len(story_embeddings)} cached embeddings")
else:
    story_embeddings = {}

# A cache written for an older version of the dataset may not cover every
# story. Stories without an embedding are silently skipped by the semantic
# diversity computation, so top the cache up instead of trusting it blindly.
missing_stories = [story for story in unique_stories if story not in story_embeddings]

if missing_stories:
    print("Computing embeddings...")
    sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

    print(f"Computing embeddings for {len(missing_stories)} unique stories...")

    # Batch encode for efficiency
    batch_size = 128

    for i in tqdm(range(0, len(missing_stories), batch_size), desc="Encoding batches"):
        batch = missing_stories[i:i+batch_size]
        batch_embeddings = sentence_model.encode(batch, show_progress_bar=False)
        story_embeddings.update(zip(batch, batch_embeddings))

    # Cache for future use
    with open(embeddings_cache_path, 'wb') as f:
        pickle.dump(story_embeddings, f)
    print(f"Computed and cached {len(story_embeddings)} embeddings")

# %% [markdown]
# ## 4. Credal Set Framework

# %%
class CredalSetAnalyzer:
    """
    Credal set framework with fast computation and uncertainty decomposition.

    A credal set is represented here as the convex hull of a collection of
    diversity vectors (one vector per prompt) belonging to a single source.
    """

    # Upper bound on the number of hull vertices compared per credal set when
    # approximating overlap and Hausdorff distance.
    _MAX_VERTICES = 50

    def __init__(self, story_embeddings: Dict, seed=42):
        """
        Initialize credal set analyzer.

        Args:
            story_embeddings: Pre-computed embeddings dictionary
                (story text -> embedding vector).
            seed: Random seed for reproducibility. Note that this seeds
                NumPy's *global* random state.
        """
        self.seed = seed
        np.random.seed(seed)
        self.story_embeddings = story_embeddings
        self._cache = {}  # Cache for repeated calibration computations

    def compute_diversity_metrics(self, stories: List[str]) -> Dict[str, np.ndarray]:
        """
        Compute diversity metrics using pre-computed embeddings.

        Args:
            stories: Story texts belonging to one prompt/source combination.

        Returns:
            Dictionary with four entries:
              * 'semantic'  - [mean, std, IQR] of pairwise cosine distances
              * 'lexical'   - [type-token ratio, mean Jaccard distance,
                               std of Jaccard distance]
              * 'syntactic' - [std, IQR, range] of sentence lengths (in words)
              * 'combined'  - concatenation of the three vectors above (9 values)
            With fewer than two stories every entry is all zeros.
        """
        if len(stories) < 2:
            return {
                'semantic': np.array([0.0, 0.0, 0.0]),
                'lexical': np.array([0.0, 0.0, 0.0]),
                'syntactic': np.array([0.0, 0.0, 0.0]),
                'combined': np.zeros(9)
            }

        diversity_metrics = {}

        # 1. Semantic Diversity
        # Stories without a cached embedding are skipped.
        embeddings = [self.story_embeddings[story]
                      for story in stories
                      if story in self.story_embeddings]

        if len(embeddings) >= 2:
            embeddings = np.array(embeddings)

            # Fast pairwise cosine similarity
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            # An all-zero embedding would give 0/0 = NaN and poison every
            # statistic below; leave such rows unscaled instead.
            norms[norms == 0] = 1.0
            normalized = embeddings / norms
            cos_sim_matrix = np.dot(normalized, normalized.T)

            # Extract upper triangle (each unordered pair once, no diagonal)
            mask = np.triu(np.ones_like(cos_sim_matrix), k=1).astype(bool)
            semantic_distances = 1 - cos_sim_matrix[mask]

            diversity_metrics['semantic'] = np.array([
                np.mean(semantic_distances),
                np.std(semantic_distances),
                np.percentile(semantic_distances, 75) - np.percentile(semantic_distances, 25)
            ])
        else:
            diversity_metrics['semantic'] = np.zeros(3)

        # 2. Lexical Diversity
        # Tokens are lower-cased, whitespace-separated words.
        vocabs = [set(story.lower().split()) for story in stories]

        all_tokens = ' '.join(stories).lower().split()
        ttr = len(set(all_tokens)) / len(all_tokens) if all_tokens else 0

        jaccard_distances = []
        for i, j in combinations(range(len(vocabs)), 2):
            intersection = len(vocabs[i] & vocabs[j])
            union = len(vocabs[i] | vocabs[j])
            if union > 0:
                jaccard_distances.append(1 - intersection / union)

        if jaccard_distances:
            diversity_metrics['lexical'] = np.array([
                ttr,
                np.mean(jaccard_distances),
                np.std(jaccard_distances)
            ])
        else:
            diversity_metrics['lexical'] = np.array([ttr, 0, 0])

        # 3. Syntactic Diversity
        # Sentences are approximated by splitting on '.'; length = word count.
        all_lengths = []
        for story in stories:
            all_lengths.extend(len(s.split()) for s in story.split('.') if s.strip())

        if all_lengths:
            diversity_metrics['syntactic'] = np.array([
                np.std(all_lengths),
                np.percentile(all_lengths, 75) - np.percentile(all_lengths, 25),
                max(all_lengths) - min(all_lengths)
            ])
        else:
            diversity_metrics['syntactic'] = np.zeros(3)

        # 4. Combined
        diversity_metrics['combined'] = np.concatenate([
            diversity_metrics['semantic'],
            diversity_metrics['lexical'],
            diversity_metrics['syntactic']
        ])

        return diversity_metrics


    def construct_credal_set(self, diversity_vectors: List[np.ndarray],
                            min_points: int = 4) -> Optional[Dict]:
        """
        Construct credal set with pre-computed statistics.

        Args:
            diversity_vectors: One diversity vector per prompt.
            min_points: Minimum number of vectors required to attempt a hull.

        Returns:
            Dictionary holding the hull, the original and standardized points,
            their centroids, the hull vertices, the hull volume, the fitted
            scaler and the point count; or None when there are fewer than
            ``min_points`` vectors or the hull cannot be computed.
        """
        if len(diversity_vectors) < min_points:
            return None

        points = np.array(diversity_vectors)

        # Standardize (per credal set; see compute_calibration_metrics for
        # why cross-set comparisons must not reuse these coordinates)
        scaler = StandardScaler()
        points_scaled = scaler.fit_transform(points)

        try:
            # Add minimal noise for numerical stability.
            # Draws from NumPy's global random state seeded in __init__.
            noise = np.random.normal(0, 1e-8, points_scaled.shape)
            points_noisy = points_scaled + noise

            # The hull is fitted on the jittered points, but the vertex
            # coordinates stored below are taken from the un-jittered ones.
            hull = ConvexHull(points_noisy)

            # Store necessary information
            result = {
                'hull': hull,
                'points_scaled': points_scaled,
                'points_original': points,
                'centroid_scaled': np.mean(points_scaled, axis=0),
                'centroid_original': np.mean(points, axis=0),
                'vertices_scaled': points_scaled[hull.vertices],
                'vertices_indices': hull.vertices,
                'volume': hull.volume,
                'scaler': scaler,
                'n_points': len(points)
            }

            return result

        except Exception as e:
            # Best effort: a degenerate or too-small point set is reported and
            # treated as "no credal set".
            print(f"Warning: Could not compute convex hull: {e}")
            return None


    def _vertices_in_common_frame(self, credal1: Dict, credal2: Dict) -> Tuple[np.ndarray, np.ndarray]:
        """Return hull vertices of both credal sets, standardized jointly.

        The mean and (population) standard deviation are taken over the union
        of both sets' original points, so location and scale differences
        between the sets survive the standardization. Constant columns are
        left unscaled. At most ``_MAX_VERTICES`` vertices are kept per set.
        """
        pts1 = np.asarray(credal1['points_original'], dtype=float)
        pts2 = np.asarray(credal2['points_original'], dtype=float)

        stacked = np.vstack([pts1, pts2])
        center = stacked.mean(axis=0)
        spread = stacked.std(axis=0)
        spread[spread == 0] = 1.0

        limit = self._MAX_VERTICES
        v1 = (pts1[credal1['vertices_indices']][:limit] - center) / spread
        v2 = (pts2[credal2['vertices_indices']][:limit] - center) / spread
        return v1, v2


    def compute_calibration_metrics(self, credal1: Dict, credal2: Dict) -> Dict[str, float]:
        """
        Compute calibration metrics between two credal sets.

        Args:
            credal1: Credal set under evaluation (as returned by
                construct_credal_set), or None.
            credal2: Reference credal set, or None.

        Returns:
            Dictionary with:
              * 'centroid_dist' - distance between original-space centroids,
                divided by the mean of the two sets' overall standard deviations
              * 'overlap'       - fraction of credal1 hull vertices lying within
                distance 1.0 of some credal2 hull vertex
              * 'hausdorff'     - symmetric Hausdorff distance between the two
                vertex sets
              * 'volume_ratio'  - credal1 volume / credal2 volume (0 when the
                reference volume is not positive)
            Overlap and Hausdorff distance are measured in a jointly
            standardized space. If either input is None, a fixed "no match"
            result is returned.
        """
        if credal1 is None or credal2 is None:
            return {
                'overlap': 0.0,
                'centroid_dist': float('inf'),
                'hausdorff': float('inf'),
                'volume_ratio': 0.0
            }

        # Use cached computation if available. The cache entry keeps references
        # to both inputs and is verified by identity: id() alone can be reused
        # by a different object once the original has been garbage collected.
        cache_key = (id(credal1), id(credal2))
        cached = self._cache.get(cache_key)
        if cached is not None and cached[0] is credal1 and cached[1] is credal2:
            return cached[2]

        metrics = {}

        # Centroid distance (normalized)
        c1 = credal1['centroid_original']
        c2 = credal2['centroid_original']
        scale = np.mean([np.std(credal1['points_original']),
                         np.std(credal2['points_original'])])
        metrics['centroid_dist'] = np.linalg.norm(c1 - c2) / scale if scale > 0 else 0

        # Each set's 'vertices_scaled' was standardized with its own scaler, so
        # those coordinates are not comparable across sets. Re-express both
        # vertex sets in one shared frame before measuring distances.
        v1, v2 = self._vertices_in_common_frame(credal1, credal2)

        if len(v1) > 0 and len(v2) > 0:
            pairwise = cdist(v1, v2)
            nearest_from_v1 = pairwise.min(axis=1)
            nearest_from_v2 = pairwise.min(axis=0)

            # Overlap approximation
            metrics['overlap'] = np.mean(nearest_from_v1 < 1.0)
            # Hausdorff distance
            metrics['hausdorff'] = max(nearest_from_v1.max(), nearest_from_v2.max())
        else:
            metrics['overlap'] = 0.0
            metrics['hausdorff'] = float('inf')

        # Volume ratio
        metrics['volume_ratio'] = credal1['volume'] / credal2['volume'] if credal2['volume'] > 0 else 0

        # Cache result
        self._cache[cache_key] = (credal1, credal2, metrics)

        return metrics


    @staticmethod
    def _empty_decomposition(n_strategies: int) -> Dict:
        """Return the all-zero decomposition used when fewer than two strategies are usable."""
        return {
            'epistemic_uncertainty': 0.0,
            'aleatoric_uncertainty': 0.0,
            'total_uncertainty': 0.0,
            'epistemic_ratio': 0.0,
            'n_strategies': n_strategies
        }


    def decompose_uncertainty(self, model_name: str,
                              credal_sets: Dict,
                              human_credal: Dict = None) -> Dict:
        """
        Decompose uncertainty into epistemic and aleatoric components.

        Args:
            model_name: Prefix selecting the model's entries in ``credal_sets``.
            credal_sets: Mapping source name -> {diversity type -> credal set}.
            human_credal: Optional {diversity type -> credal set} of the human
                baseline.

        Returns:
            Dictionary with the epistemic component (variance of the strategy
            centroids), the aleatoric component (mean within-strategy
            variance), their sum, the epistemic share and the number of
            matching sources. All variances are computed on the 'combined'
            diversity vectors after standardizing all strategies together.
            When a human baseline is available, 'aleatoric_ratio_to_human' is
            added. With fewer than two usable strategies an all-zero result is
            returned.
        """
        model_sources = [s for s in credal_sets.keys() if s.startswith(model_name)]

        if len(model_sources) < 2:
            return self._empty_decomposition(len(model_sources))

        div_type = 'combined'

        # Collect credal data (sources whose hull could not be built are dropped)
        valid_credal = [
            credal_sets[source][div_type]
            for source in model_sources
            if credal_sets[source].get(div_type) is not None
        ]

        if len(valid_credal) < 2:
            return self._empty_decomposition(len(model_sources))

        # Combine all points for proper scaling
        all_original_points = [c['points_original'] for c in valid_credal]

        combined_points = np.vstack(all_original_points)
        common_scaler = StandardScaler()
        all_scaled = common_scaler.fit_transform(combined_points)

        # Walk through the stacked array once, strategy by strategy
        centroids_scaled = []
        within_variances = []
        start_idx = 0
        for original in all_original_points:
            end_idx = start_idx + len(original)
            strategy_points = all_scaled[start_idx:end_idx]
            centroids_scaled.append(np.mean(strategy_points, axis=0))
            within_variances.append(np.mean(np.var(strategy_points, axis=0)))
            start_idx = end_idx

        # Epistemic: Variance across strategy centroids
        centroids_array = np.array(centroids_scaled)
        epistemic_uncertainty = np.mean(np.var(centroids_array, axis=0))

        # Aleatoric: Average within-strategy variance
        aleatoric_uncertainty = np.mean(within_variances)

        # Total uncertainty
        total_uncertainty = epistemic_uncertainty + aleatoric_uncertainty

        # Epistemic ratio
        epistemic_ratio = epistemic_uncertainty / total_uncertainty if total_uncertainty > 0 else 0

        result = {
            'epistemic_uncertainty': epistemic_uncertainty,
            'aleatoric_uncertainty': aleatoric_uncertainty,
            'total_uncertainty': total_uncertainty,
            'epistemic_ratio': epistemic_ratio,
            'n_strategies': len(model_sources),
            'centroid_dispersion': np.std(centroids_array),
            'mean_within_variance': aleatoric_uncertainty
        }

        # If human baseline provided, compute relative metrics.
        # The entry may exist but be None when the human hull could not be built.
        human_data = human_credal.get(div_type) if human_credal is not None else None
        if human_data is not None:
            human_points = human_data['points_original']
            # NOTE: human variance is taken in original units, while the
            # aleatoric term above is in standardized units.
            human_variance = np.mean(np.var(human_points, axis=0))
            result['aleatoric_ratio_to_human'] = aleatoric_uncertainty / human_variance if human_variance > 0 else 0

        return result

# Create the analyzer instance used by the following cells, backed by the
# embeddings computed (or loaded from cache) above.
credal_analyzer = CredalSetAnalyzer(story_embeddings=story_embeddings)
print("Credal Set Analyzer initialized")


# %% [markdown]
# ## 5. Build Credal Sets

# %%
print("\nBuilding credal sets...")

start_time = time.time()

# credal_sets:       source -> {diversity type -> credal set dict or None}
# diversity_vectors: source -> {diversity type -> list of per-prompt vectors}
credal_sets = {}
diversity_vectors = defaultdict(lambda: defaultdict(list))

# Get unique sources
sources = stories_df['source'].unique()
print(f"Found {len(sources)} sources to analyze")

# Process each source
for source in tqdm(sources, desc="Building credal sets"):

    # Filter by source once per source instead of once per (source, prompt):
    # the source mask does not depend on the prompt.
    source_stories = stories_df[stories_df['source'] == source]

    # Process each prompt
    # NOTE(review): the comparison below assumes the keys of selected_prompts
    # have the same type as stories_df['prompt_id'] - confirm (JSON keys are
    # always strings).
    for prompt_id in selected_prompts.keys():
        # Get stories for this prompt and source
        prompt_stories = source_stories.loc[
            source_stories['prompt_id'] == prompt_id, 'story'
        ].tolist()

        # Require at least 5 stories per prompt; use at most the first 10.
        if len(prompt_stories) >= 5:
            # Compute diversity vectors
            div_vectors = credal_analyzer.compute_diversity_metrics(prompt_stories[:10])

            # Store vectors
            for div_type in ['semantic', 'lexical', 'syntactic', 'combined']:
                diversity_vectors[source][div_type].append(div_vectors[div_type])

    # Construct credal sets (one per diversity type that has any vectors)
    credal_sets[source] = {}
    for div_type in ['semantic', 'lexical', 'syntactic', 'combined']:
        if diversity_vectors[source][div_type]:
            credal_data = credal_analyzer.construct_credal_set(
                diversity_vectors[source][div_type]
            )
            # The entry is stored even when construction failed (None), so
            # consumers must check for None.
            credal_sets[source][div_type] = credal_data

            if credal_data is not None:
                print(f"  {source} - {div_type}: {credal_data['n_points']} points, "
                      f"volume: {credal_data['volume']:.4f}")


elapsed = time.time() - start_time
print(f"\nBuilt all credal sets in {elapsed:.1f} seconds")


# %% [markdown]
# ## 6. Compute Calibration Metrics

# %%
print("\nComputing calibration metrics...")

start_time = time.time()

# Human baseline
human_credal_sets = credal_sets.get('human', {})
human_vectors = diversity_vectors.get('human', {})

# (substring looked for in the source name, strategy label), checked in order
_strategy_keywords = (
    ('temperature', 'temperature'),
    ('top_p', 'top_p'),
    ('top_k', 'top_k'),
    ('typical', 'typical_p'),
)

calibration_results = []

# Every non-human source is compared against the human credal sets
for source in tqdm(credal_sets.keys(), desc="Computing calibration"):
    if source == 'human':
        continue

    # Split the source name into model / strategy / strategy value.
    # Names without an underscore are treated as a bare model name.
    model, strategy, value = source, 'default', ''
    if '_' in source:
        parts = source.split('_')
        model = parts[0]
        known_strategy = next(
            (label for keyword, label in _strategy_keywords if keyword in source),
            None
        )
        if known_strategy is None:
            # Unknown strategy: keep everything after the model name
            strategy, value = '_'.join(parts[1:]), ''
        else:
            # Known strategy: the last name component is its value
            strategy, value = known_strategy, parts[-1]

    # Identification columns come first
    cal_metrics = {
        'source': source,
        'model': model,
        'strategy': strategy,
        'strategy_value': value
    }

    # One group of calibration columns per diversity type
    for div_type in ['semantic', 'lexical', 'syntactic', 'combined']:
        if div_type not in credal_sets[source] or div_type not in human_credal_sets:
            continue

        model_credal = credal_sets[source][div_type]
        human_credal = human_credal_sets[div_type]
        if model_credal is None or human_credal is None:
            continue

        metrics = credal_analyzer.compute_calibration_metrics(
            model_credal, human_credal
        )
        cal_metrics.update({
            f'{div_type}_calibration': metrics['overlap'],
            f'{div_type}_centroid_dist': metrics['centroid_dist'],
            f'{div_type}_hausdorff': metrics['hausdorff'],
            f'{div_type}_volume_ratio': metrics['volume_ratio'],
        })

    # Overall scores average the three individual diversity types
    # ('combined' is not part of the average)
    single_types = ['semantic', 'lexical', 'syntactic']
    cal_scores = [cal_metrics[f'{dt}_calibration']
                  for dt in single_types if f'{dt}_calibration' in cal_metrics]
    dist_scores = [cal_metrics[f'{dt}_centroid_dist']
                   for dt in single_types if f'{dt}_centroid_dist' in cal_metrics]

    cal_metrics['overall_calibration'] = np.mean(cal_scores) if cal_scores else 0.0
    cal_metrics['overall_distance'] = np.mean(dist_scores) if dist_scores else float('inf')

    calibration_results.append(cal_metrics)

# Tabulate, best overall calibration first
calibration_df = pd.DataFrame(calibration_results)
calibration_df = calibration_df.sort_values('overall_calibration', ascending=False)

elapsed = time.time() - start_time
print(f"\nCalibration computed in {elapsed:.1f} seconds")

print("\nBest calibrated configurations:")
print(calibration_df[['source', 'overall_calibration', 'overall_distance']].head(10).round(3))

# Persist the full table
calibration_df.to_csv(os.path.join(CREDAL_DIR, 'calibration_metrics.csv'), index=False)

# %% [markdown]
# ## 7. Uncertainty Decomposition (Epistemic vs Aleatoric)

# %%
print("\nComputing uncertainty decomposition...")

# Human credal sets serve as the reference for the aleatoric component
human_credal = credal_sets.get('human', {})

# One decomposition record per model, tagged with the model name
uncertainty_decomp = []
for model_name in ['GPT2-XL', 'Gemma-2B', 'Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']:
    uncertainty_decomp.append({
        **credal_analyzer.decompose_uncertainty(model_name, credal_sets, human_credal),
        'model': model_name,
    })

uncertainty_decomp_df = pd.DataFrame(uncertainty_decomp)

print("\nUncertainty Decomposition (Epistemic vs Aleatoric):")
summary_columns = ['model', 'epistemic_uncertainty', 'aleatoric_uncertainty',
                   'total_uncertainty', 'epistemic_ratio']
print(uncertainty_decomp_df[summary_columns].round(3))

decomposition_path = os.path.join(CREDAL_DIR, 'uncertainty_decomposition.csv')
uncertainty_decomp_df.to_csv(decomposition_path, index=False)


# %% [markdown]
# ## 8. Strategy and Model Analysis

# %%
print("\nAnalyzing best strategies and models...")

# Both "best" tables below take the first row of a filtered calibration_df;
# this relies on calibration_df being sorted by overall_calibration, best first.

# 1. Best strategy per model
print("\n=== Best Decoding Strategy per Model ===")
best_strategy_per_model = []

for model in ['GPT2-XL', 'Gemma-2B', 'Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']:
    model_df = calibration_df[calibration_df['model'] == model]
    if not len(model_df):
        continue

    best = model_df.iloc[0]
    strategy_label = f"{best['strategy']} {best['strategy_value']}".strip()
    best_strategy_per_model.append({
        'Model': model,
        'Best Strategy': strategy_label,
        'Calibration': f"{best['overall_calibration']:.3f}",
        'Distance': f"{best['overall_distance']:.3f}"
    })

strategy_table = pd.DataFrame(best_strategy_per_model)
print(strategy_table.to_string(index=False))
strategy_table.to_csv(os.path.join(CREDAL_DIR, 'best_strategy_per_model.csv'), index=False)

# 2. Best model per strategy
print("\n=== Best Model per Decoding Strategy ===")
best_model_per_strategy = []

for strategy in calibration_df['strategy'].unique():
    strategy_df = calibration_df[calibration_df['strategy'] == strategy]
    if not len(strategy_df):
        continue

    best = strategy_df.iloc[0]
    best_model_per_strategy.append({
        'Strategy': strategy,
        'Best Model': best['model'],
        'Calibration': f"{best['overall_calibration']:.3f}",
        'Distance': f"{best['overall_distance']:.3f}"
    })

model_table = pd.DataFrame(best_model_per_strategy)
print(model_table.to_string(index=False))
model_table.to_csv(os.path.join(CREDAL_DIR, 'best_model_per_strategy.csv'), index=False)


# 3. Strategy effectiveness across all models (mean and std per strategy)
print("\n=== Average Strategy Performance ===")
summary_spec = {
    'overall_calibration': ['mean', 'std'],
    'overall_distance': ['mean', 'std']
}
strategy_performance = calibration_df.groupby('strategy').agg(summary_spec).round(3)
print(strategy_performance)


# %% [markdown]
# ## 9. Visualization 1: Credal Sets in PCA Space

# %%
print("\nCreating credal set visualization...")

# Select sources for visualization: the human baseline plus, for each listed
# model, its best-calibrated configuration (first row of the sorted table).
viz_sources = ['human']
for model in ['GPT2-XL', 'Mistral-7B-Instruct']:
    model_sources = [s for s in credal_sets.keys() if s.startswith(model)]
    if model_sources:
        # Get best calibrated for this model
        model_df = calibration_df[calibration_df['model'] == model]
        if len(model_df) > 0:
            viz_sources.append(model_df.iloc[0]['source'])

print(f"Visualizing: {viz_sources}")

# Collect combined diversity vectors, remembering which source each came from
all_vectors = []
all_labels = []

for source in viz_sources:
    if source in diversity_vectors and 'combined' in diversity_vectors[source]:
        vectors = diversity_vectors[source]['combined']
        all_vectors.extend(vectors)
        all_labels.extend([source] * len(vectors))

all_vectors_array = np.array(all_vectors)
labels_array = np.array(all_labels)

# PCA cannot extract more components than there are samples or features.
n_components = min(3, *all_vectors_array.shape)

if n_components < 2:
    # A 2D projection needs at least two components.
    print("Not enough diversity vectors for a 2D projection; skipping credal set visualization")
else:
    # PCA for visualization
    pca = PCA(n_components=n_components, random_state=42)
    vectors_pca = pca.fit_transform(all_vectors_array)

    # Create figure
    fig = plt.figure(figsize=(18, 6))

    # Three 2D projections when three components exist, otherwise just one
    if vectors_pca.shape[1] >= 3:
        projections = [(0, 1), (0, 2), (1, 2)]
    else:
        projections = [(0, 1)]

    for idx, (pc1, pc2) in enumerate(projections):
        ax = plt.subplot(1, len(projections), idx + 1)

        for source in viz_sources:
            source_mask = labels_array == source
            if not np.any(source_mask):
                continue

            source_points = vectors_pca[source_mask][:, [pc1, pc2]]

            # Determine color (anything that is neither human nor GPT2 is
            # drawn as Mistral, matching the model list selected above)
            if source == 'human':
                color = COLORS['human']
                label_text = 'Human'
            elif 'GPT2' in source:
                color = MODEL_COLORS['GPT2-XL']
                label_text = 'GPT2-XL'
            else:
                color = MODEL_COLORS['Mistral-7B-Instruct']
                label_text = 'Mistral-7B'

            # Plot points
            ax.scatter(source_points[:, 0], source_points[:, 1],
                      alpha=0.6, s=60, c=color,
                      label=label_text,
                      edgecolors='white', linewidth=0.5)

            # Add convex hull outline and fill
            if len(source_points) >= 3:
                try:
                    hull = ConvexHull(source_points)
                    for simplex in hull.simplices:
                        ax.plot(source_points[simplex, 0],
                               source_points[simplex, 1],
                               color=color, alpha=0.3, linewidth=2)
                    ax.fill(source_points[hull.vertices, 0],
                           source_points[hull.vertices, 1],
                           color=color, alpha=0.1)
                except Exception as e:
                    # Best effort: keep the scatter points but say why the
                    # hull is missing instead of failing silently.
                    print(f"Warning: Could not draw convex hull for {source}: {e}")

        ax.set_xlabel(f'PC{pc1+1} ({pca.explained_variance_ratio_[pc1]:.1%})')
        ax.set_ylabel(f'PC{pc2+1} ({pca.explained_variance_ratio_[pc2]:.1%})')
        ax.set_title(f'Principal Components {pc1+1}-{pc2+1}')
        ax.grid(True, alpha=0.3, linewidth=0.5)
        ax.legend(frameon=True, fancybox=True, shadow=True)

    plt.suptitle('Credal Sets in Principal Component Space', fontsize=14, y=1.02)
    plt.tight_layout()

    fig_path = os.path.join(CREDAL_DIR, 'credal_sets_pca.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved credal set visualization")


# %% [markdown]
# ## 10. Visualization 2: Uncertainty Decomposition

# %%
print("\nCreating uncertainty decomposition visualization...")

# 2x2 figure: stacked uncertainty bars, epistemic ratio, calibration vs model
# size, and a strategy-by-model calibration heatmap.
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Plot 1: Stacked bar chart of uncertainty components
ax1 = axes[0, 0]
models = uncertainty_decomp_df['model'].tolist()
epistemic = uncertainty_decomp_df['epistemic_uncertainty'].tolist()
aleatoric = uncertainty_decomp_df['aleatoric_uncertainty'].tolist()

x_pos = np.arange(len(models))
width = 0.6

# Aleatoric at the bottom, epistemic stacked on top
bars1 = ax1.bar(x_pos, aleatoric, width, label='Aleatoric',
                color=COLORS['primary'], alpha=0.7)
bars2 = ax1.bar(x_pos, epistemic, width, bottom=aleatoric,
                label='Epistemic', color=COLORS['secondary'], alpha=0.7)

ax1.set_xlabel('Model', fontsize=11)
ax1.set_ylabel('Uncertainty', fontsize=11)
ax1.set_title('Uncertainty Decomposition', fontsize=12)
ax1.set_xticks(x_pos)
ax1.set_xticklabels([m.replace('-', '\n') for m in models], fontsize=9)
ax1.legend(frameon=True, fancybox=True)
ax1.grid(axis='y', alpha=0.3, linewidth=0.5)

# Add value labels, centred vertically inside each bar segment
for i, (e, a) in enumerate(zip(epistemic, aleatoric)):
    ax1.text(i, a/2, f'{a:.2f}', ha='center', va='center', fontsize=9, color='white')
    ax1.text(i, a + e/2, f'{e:.2f}', ha='center', va='center', fontsize=9, color='white')


# Plot 2: Epistemic ratio
ax2 = axes[0, 1]
epistemic_ratios = uncertainty_decomp_df['epistemic_ratio'].tolist()

bars = ax2.bar(x_pos, epistemic_ratios, color=COLORS['tertiary'], alpha=0.7)
ax2.set_xlabel('Model', fontsize=11)
ax2.set_ylabel('Epistemic Ratio', fontsize=11)
ax2.set_title('Proportion of Epistemic Uncertainty', fontsize=12)
ax2.set_xticks(x_pos)
ax2.set_xticklabels([m.replace('-', '\n') for m in models], fontsize=9)
ax2.grid(axis='y', alpha=0.3, linewidth=0.5)
ax2.set_ylim([0, 1])

# Add value labels just above each bar
for bar, val in zip(bars, epistemic_ratios):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 0.02,
            f'{val:.2%}', ha='center', va='bottom', fontsize=9)


# Plot 3: Calibration vs Model Size
ax3 = axes[1, 0]
# Parameter counts in billions
model_sizes = {
    'GPT2-XL': 1.5,
    'Gemma-2B': 2.0,
    'Mistral-7B-Instruct': 7.0,
    'Llama-3.1-8B-Instruct': 8.0
}

# Get best calibration per model
best_cal_by_model = calibration_df.groupby('model')['overall_calibration'].max()

# Only models that actually appear in the calibration table are plotted
sizes = []
calibrations = []
colors_list = []
for model, size in model_sizes.items():
    if model in best_cal_by_model.index:
        sizes.append(size)
        calibrations.append(best_cal_by_model[model])
        colors_list.append(MODEL_COLORS[model])


scatter = ax3.scatter(sizes, calibrations, s=200, alpha=0.7, c=colors_list, edgecolors='white', linewidth=2)

# Add trend line. A linear fit needs at least two points; with fewer models
# available the trend line is simply omitted instead of crashing the cell.
if len(sizes) >= 2:
    z = np.polyfit(sizes, calibrations, 1)
    p = np.poly1d(z)
    ax3.plot(sizes, p(sizes), "--", alpha=0.5, color='gray')

# Label each point with the first component of the model name
for model in model_sizes.keys():
    if model in best_cal_by_model.index:
        ax3.annotate(model.split('-')[0],
                    (model_sizes[model], best_cal_by_model[model]),
                    xytext=(5, 5), textcoords='offset points', fontsize=9)


ax3.set_xlabel('Model Size (Billions)', fontsize=11)
ax3.set_ylabel('Best Calibration Score', fontsize=11)
ax3.set_title('Calibration vs Model Size', fontsize=12)
ax3.grid(True, alpha=0.3, linewidth=0.5)


# Plot 4: Strategy performance heatmap
ax4 = axes[1, 1]

# Create strategy performance matrix (rows: strategies, columns: models)
strategies = calibration_df['strategy'].unique()
models_list = ['GPT2-XL', 'Gemma-2B', 'Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']

strategy_matrix = np.zeros((len(strategies), len(models_list)))
for i in range(len(strategies)):
    for j in range(len(models_list)):
        mask = (calibration_df['strategy'] == strategies[i]) & (calibration_df['model'] == models_list[j])
        if mask.any():
            # When several configurations share a (strategy, model) pair, the
            # first matching row in calibration_df is used.
            strategy_matrix[i, j] = calibration_df.loc[mask, 'overall_calibration'].values[0]
        else:
            # No data for this combination
            strategy_matrix[i, j] = np.nan


# Colour scale is fixed to [0, 0.6]
im = ax4.imshow(strategy_matrix, cmap='RdYlGn', aspect='auto', vmin=0, vmax=0.6)
ax4.set_xticks(np.arange(len(models_list)))
ax4.set_yticks(np.arange(len(strategies)))
ax4.set_xticklabels([m.split('-')[0] for m in models_list], fontsize=9)
ax4.set_yticklabels(strategies, fontsize=9)
ax4.set_xlabel('Model', fontsize=11)
ax4.set_ylabel('Strategy', fontsize=11)
ax4.set_title('Strategy Performance Across Models', fontsize=12)

# Add colorbar
cbar = plt.colorbar(im, ax=ax4)
cbar.set_label('Calibration', fontsize=10)

# Add text annotations (cells without data stay blank)
for i in range(len(strategies)):
    for j in range(len(models_list)):
        if not np.isnan(strategy_matrix[i, j]):
            text = ax4.text(j, i, f'{strategy_matrix[i, j]:.2f}',
                          ha="center", va="center", color="black", fontsize=8)


plt.suptitle('Uncertainty Analysis and Model Performance', fontsize=14, y=1.02)
plt.tight_layout()

fig_path = os.path.join(CREDAL_DIR, 'uncertainty_analysis.png')
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved uncertainty analysis visualization")

# %% [markdown]
# ## 11. Visualization 3: Epistemic vs Aleatoric Trade-off

# %%
print("\nCreating epistemic-aleatoric trade-off visualization...")

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Scatter plot of epistemic vs aleatoric
ax1 = axes[0]

# One marker per row of the decomposition table, coloured and labelled by model.
for _, row in uncertainty_decomp_df.iterrows():
    model = row['model']
    ax1.scatter(row['aleatoric_uncertainty'], row['epistemic_uncertainty'],
               s=300, alpha=0.7, c=MODEL_COLORS[model],
               edgecolors='white', linewidth=2, label=model)

    # Add text label (text before the first '-' in the model name)
    ax1.annotate(model.split('-')[0],
                (row['aleatoric_uncertainty'], row['epistemic_uncertainty']),
                xytext=(5, 5), textcoords='offset points', fontsize=9)


ax1.set_xlabel('Aleatoric Uncertainty', fontsize=11)
ax1.set_ylabel('Epistemic Uncertainty', fontsize=11)
ax1.set_title('Epistemic vs Aleatoric Uncertainty Trade-off', fontsize=12)
ax1.grid(True, alpha=0.3, linewidth=0.5)

# Add diagonal line for equal uncertainty (points above it are epistemic-dominated)
max_val = max(ax1.get_xlim()[1], ax1.get_ylim()[1])
ax1.plot([0, max_val], [0, max_val], '--', alpha=0.3, color='gray', label='Equal')

# The legend is built from the artists that exist when legend() is called, so
# it must come after the diagonal; otherwise the 'Equal' entry is left out.
ax1.legend(frameon=True, fancybox=True, loc='best')


# Plot 2: Radar chart of model characteristics

# Prepare data for radar chart
# NOTE(review): the third spoke is labelled 'Volume Variance' but is fed from the
# 'centroid_dispersion' column below - confirm which of the two is intended.
categories = ['Calibration', 'Epistemic\nRatio', 'Volume\nVariance', 'N Strategies']
N = len(categories)

# One angle per category, evenly spaced; the first angle is repeated at the end
# so each model's polygon closes.
angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]

# plt.subplots() created a cartesian Axes in the right-hand slot. Recent
# Matplotlib no longer removes it automatically when a new Axes is placed on
# top, so drop it explicitly before adding the polar one in the same position.
axes[1].remove()
ax2 = fig.add_subplot(1, 2, 2, projection='polar')

for model in ['GPT2-XL', 'Gemma-2B', 'Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']:
    model_rows = uncertainty_decomp_df[uncertainty_decomp_df['model'] == model]
    model_cal = calibration_df.loc[calibration_df['model'] == model, 'overall_calibration']
    # A model with no results would raise IndexError on .iloc[0]; skip it,
    # as the other per-model loops in this notebook do.
    if model_rows.empty or model_cal.empty:
        print(f"  Skipping {model} in radar chart: no results available")
        continue

    model_data = model_rows.iloc[0]
    best_cal = model_cal.max()

    # The radial axis is fixed to [0, 1] below; the divisions by 5 rescale the
    # last two spokes towards that range (presumably 5 is their expected
    # maximum - TODO confirm).
    values = [
        best_cal,
        model_data['epistemic_ratio'],
        model_data.get('centroid_dispersion', 0) / 5,  # Normalize
        model_data['n_strategies'] / 5  # Normalize
    ]
    values += values[:1]  # close the polygon

    ax2.plot(angles, values, 'o-', linewidth=2, label=model.split('-')[0],
            color=MODEL_COLORS[model], alpha=0.7)
    ax2.fill(angles, values, alpha=0.1, color=MODEL_COLORS[model])


ax2.set_xticks(angles[:-1])
ax2.set_xticklabels(categories, fontsize=10)
ax2.set_ylim(0, 1)
ax2.set_title('Model Characteristics Comparison', fontsize=12, pad=20)
ax2.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), frameon=True, fancybox=True)
ax2.grid(True, alpha=0.3, linewidth=0.5)


plt.suptitle('Uncertainty Components Analysis', fontsize=14, x=0.5, y=1.02)
plt.tight_layout()

fig_path = os.path.join(CREDAL_DIR, 'epistemic_aleatoric_tradeoff.png')
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved epistemic-aleatoric trade-off visualization")

# %% [markdown]
# ## 12. Statistical Analysis

# %%
print("\nStatistical Analysis...")

# Parameter count of each model, in billions
model_sizes = {
    'GPT2-XL': 1.5,
    'Gemma-2B': 2.0,
    'Mistral-7B-Instruct': 7.0,
    'Llama-3.1-8B-Instruct': 8.0
}

# Highest calibration score reached by each model over all of its rows
best_cal_per_model = calibration_df.groupby('model')['overall_calibration'].max()

# Paired samples, restricted to models that have calibration results
models_with_scores = [name for name in model_sizes if name in best_cal_per_model.index]
sizes = [model_sizes[name] for name in models_with_scores]
calibrations = [best_cal_per_model[name] for name in models_with_scores]

# Rank correlation and a straight-line fit, only with three or more models
if len(sizes) >= 3:
    corr, p_val = stats.spearmanr(sizes, calibrations)
    print("\nModel Size vs Calibration:")
    print(f"  Spearman ρ = {corr:.3f}, p = {p_val:.3f}")

    # Ordinary least-squares fit of calibration on size
    slope, intercept, r_value, p_value, std_err = stats.linregress(sizes, calibrations)
    print(f"  Linear regression: R² = {r_value**2:.3f}, p = {p_value:.3f}")


# Base vs Instruction-tuned
base_models = ['GPT2-XL', 'Gemma-2B']
instruct_models = ['Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']

# Every calibration score (all strategies) belonging to each model family
base_cal = calibration_df[calibration_df['model'].isin(base_models)]['overall_calibration'].values
instruct_cal = calibration_df[calibration_df['model'].isin(instruct_models)]['overall_calibration'].values


if len(base_cal) > 0 and len(instruct_cal) > 0:
    # Independent two-sample t-test with SciPy's default equal-variance
    # assumption, which is why the degrees of freedom printed are n1 + n2 - 2.
    t_stat, p_val = stats.ttest_ind(base_cal, instruct_cal)

    # Cohen's d with the same pooled, unbiased (ddof=1) variance that the t-test
    # uses, weighted by each group's degrees of freedom. It is undefined (NaN)
    # when a group has fewer than two values or when the pooled variance is zero.
    n_base, n_instruct = len(base_cal), len(instruct_cal)
    if n_base > 1 and n_instruct > 1:
        pooled_var = (
            (n_base - 1) * np.var(base_cal, ddof=1)
            + (n_instruct - 1) * np.var(instruct_cal, ddof=1)
        ) / (n_base + n_instruct - 2)
        if pooled_var > 0:
            cohens_d = (np.mean(base_cal) - np.mean(instruct_cal)) / np.sqrt(pooled_var)
        else:
            cohens_d = float('nan')
    else:
        cohens_d = float('nan')

    print(f"\nBase vs Instruction-tuned Models:")
    print(f"  Base: μ = {np.mean(base_cal):.3f} ± {np.std(base_cal):.3f}")
    print(f"  Instruction-tuned: μ = {np.mean(instruct_cal):.3f} ± {np.std(instruct_cal):.3f}")
    print(f"  t({n_base + n_instruct - 2}) = {t_stat:.2f}, p = {p_val:.3f}")
    print(f"  Effect size (Cohen's d): {cohens_d:.3f}")


# Epistemic vs Calibration
# Map each model to its epistemic ratio, then pair that ratio with the model's
# best calibration score wherever both are available.
epistemic_by_model = uncertainty_decomp_df.set_index('model')['epistemic_ratio'].to_dict()
paired_scores = [
    (ratio, best_cal_per_model[name])
    for name, ratio in epistemic_by_model.items()
    if name in best_cal_per_model.index
]
ep_vals = [ratio for ratio, _cal in paired_scores]
cal_vals = [cal for _ratio, cal in paired_scores]


# Rank correlation is only reported with three or more paired models
if len(ep_vals) >= 3:
    corr_ep, p_val_ep = stats.spearmanr(ep_vals, cal_vals)
    print("\nEpistemic Ratio vs Calibration:")
    print(f"  Spearman ρ = {corr_ep:.3f}, p = {p_val_ep:.3f}")


# %% [markdown]
# ## 13. Generate Publication Tables

# %%
print("\nGenerating publication-ready tables...")

# Table 1: Human Diversity Baselines
print("\n=== Table 1: Human Diversity Baselines ===")

# Diversity vectors recorded for human text, keyed by diversity type
# (empty mapping when there is no 'human' entry at all)
human_vectors = diversity_vectors['human'] if 'human' in diversity_vectors else {}

human_stats = []
for div_type in ('semantic', 'lexical', 'syntactic'):
    if div_type not in human_vectors:
        continue
    vectors = np.array(human_vectors[div_type])
    if len(vectors) == 0:
        continue

    # Assuming the first column is the mean metric
    first_metric = vectors[:, 0]
    upper_quartile = np.percentile(first_metric, 75)
    lower_quartile = np.percentile(first_metric, 25)
    human_stats.append({
        'Diversity Type': div_type.capitalize(),
        'Mean': f"{np.mean(first_metric):.3f}",
        'Std Dev': f"{np.std(first_metric):.3f}",
        'IQR': f"{upper_quartile - lower_quartile:.3f}",
        'N': len(vectors)
    })

# Show the table and keep a CSV copy alongside the other outputs
human_table = pd.DataFrame(human_stats)
print(human_table.to_string(index=False))
human_table.to_csv(os.path.join(CREDAL_DIR, 'table1_human_baselines.csv'), index=False)


# Table 2: Model Performance Summary
print("\n=== Table 2: Model Performance Summary ===")
model_summary = []
# Iterate over model_sizes (instead of a second hard-coded model list) so every
# row is guaranteed to have a size entry.
for model in model_sizes:
    model_df = calibration_df[calibration_df['model'] == model]
    decomp_rows = uncertainty_decomp_df[uncertainty_decomp_df['model'] == model]

    # A model absent from either table would raise IndexError on .iloc[0] (or
    # produce a row of NaNs); report it and leave it out of the table instead.
    if model_df.empty or decomp_rows.empty:
        print(f"  Skipping {model}: no calibration or decomposition results available")
        continue
    decomp_row = decomp_rows.iloc[0]

    # Values are pre-formatted as strings so the CSV matches the printed table.
    model_summary.append({
        'Model': model,
        'Size (B)': model_sizes[model],
        'Best Calibration': f"{model_df['overall_calibration'].max():.3f}",
        'Mean Calibration': f"{model_df['overall_calibration'].mean():.3f}",
        'Epistemic': f"{decomp_row['epistemic_uncertainty']:.3f}",
        'Aleatoric': f"{decomp_row['aleatoric_uncertainty']:.3f}",
        'Epistemic Ratio': f"{decomp_row['epistemic_ratio']:.2%}"
    })

model_summary_df = pd.DataFrame(model_summary)
print(model_summary_df.to_string(index=False))
model_summary_df.to_csv(os.path.join(CREDAL_DIR, 'table2_model_summary.csv'), index=False)


# Table 3: Strategy Performance
print("\n=== Table 3: Strategy Performance Summary ===")

# One row per strategy: mean / std / max of calibration plus mean distance,
# rounded to three decimals and ordered from best to worst mean calibration.
strategy_summary = (
    calibration_df
    .groupby('strategy')
    .agg(**{
        'Mean Cal': ('overall_calibration', 'mean'),
        'Std Cal': ('overall_calibration', 'std'),
        'Max Cal': ('overall_calibration', 'max'),
        'Mean Dist': ('overall_distance', 'mean'),
    })
    .round(3)
    .sort_values('Mean Cal', ascending=False)
)
print(strategy_summary)
strategy_summary.to_csv(os.path.join(CREDAL_DIR, 'table3_strategy_summary.csv'))


# %% [markdown]
# ## 14. Summary Report

# %%
# Helper function for JSON serialization
def make_json_serializable(obj):
    """Recursively convert NumPy values into plain Python types that json can write.

    Args:
        obj: Any value. NumPy booleans, integers and floats become ``bool``,
            ``int`` and ``float``; arrays become (nested) lists; dicts, lists,
            tuples and sets are walked recursively.

    Returns:
        An equivalent structure built from plain Python objects. Tuples and
        sets are returned as lists, and NumPy-scalar dict keys are converted
        like values (``json`` rejects e.g. ``np.int64`` keys). Values of any
        other type are returned unchanged.
    """
    # np.bool_ is checked first so booleans are never reported as integers.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        # tolist() already unwraps numeric scalars; recursing also covers
        # object arrays that hold NumPy values or containers.
        return make_json_serializable(obj.tolist())
    if isinstance(obj, dict):
        return {
            (make_json_serializable(key) if isinstance(key, np.generic) else key):
                make_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [make_json_serializable(item) for item in obj]
    return obj


print("\n" + "="*70)
print("CREDAL SET ANALYSIS COMPLETE")
print("="*70)

# start_time is set by an earlier cell of the notebook
total_time = time.time() - start_time
print(f"\nTotal analysis time: {total_time:.1f} seconds")


print("\nKEY FINDINGS:")

# Best overall configuration: the row with the highest calibration score.
# Selected explicitly rather than with .iloc[0], so the result does not depend
# on an earlier cell having sorted calibration_df (ties resolve to the first row).
best_position = calibration_df['overall_calibration'].argmax()
best_overall = calibration_df.iloc[best_position]
print(f"\n1. Best Overall Configuration:")
print(f"   Model: {best_overall['source']}")
print(f"   Calibration: {best_overall['overall_calibration']:.3f}")
print(f"   Distance: {best_overall['overall_distance']:.3f}")


# Model comparison: pool every calibration score of each model family
base_models = ['GPT2-XL', 'Gemma-2B']
instruct_models = ['Mistral-7B-Instruct', 'Llama-3.1-8B-Instruct']

base_cal = calibration_df[calibration_df['model'].isin(base_models)]['overall_calibration'].values
instruct_cal = calibration_df[calibration_df['model'].isin(instruct_models)]['overall_calibration'].values

print(f"\n2. Model Type Comparison:")
print(f"   Base models: μ = {np.mean(base_cal):.3f}")
print(f"   Instruction-tuned: μ = {np.mean(instruct_cal):.3f}")
print(f"   Finding: Base models show {'better' if np.mean(base_cal) > np.mean(instruct_cal) else 'worse'} calibration")


# Uncertainty decomposition, one entry per row of the decomposition table
print(f"\n3. Uncertainty Decomposition:")
for _, row in uncertainty_decomp_df.iterrows():
    print(f"   {row['model']}:")
    print(f"     - Epistemic: {row['epistemic_uncertainty']:.3f} ({row['epistemic_ratio']:.1%})")
    print(f"     - Aleatoric: {row['aleatoric_uncertainty']:.3f}")


# Best strategies: highest mean calibration, and lowest spread across models
print(f"\n4. Optimal Strategies:")
print(f"   Overall best: {calibration_df.groupby('strategy')['overall_calibration'].mean().idxmax()}")
print(f"   Most consistent: {calibration_df.groupby('strategy')['overall_calibration'].std().idxmin()}")


# Save summary

# The base-vs-instruct p-value is recomputed here instead of being read from a
# leftover `p_val` variable: that name is also assigned by the Spearman test in
# the statistics cell, so looking it up could store the wrong test's p-value.
# Non-finite results are stored as null because strict JSON has no NaN.
if len(base_cal) > 0 and len(instruct_cal) > 0:
    base_vs_instruct_p = float(stats.ttest_ind(base_cal, instruct_cal).pvalue)
    if not np.isfinite(base_vs_instruct_p):
        base_vs_instruct_p = None
else:
    base_vs_instruct_p = None

summary = {
    'timestamp': datetime.now().isoformat(),
    'total_time_seconds': total_time,
    'methodology': 'Continuous Credal Sets with Uncertainty Decomposition',
    'key_findings': {
        'best_configuration': best_overall['source'],
        'best_calibration': float(best_overall['overall_calibration']),
        # `corr` only exists if the statistics cell ran its size correlation
        'model_size_correlation': float(corr) if 'corr' in locals() else None,
        'base_vs_instruct': {
            'base_mean': float(np.mean(base_cal)),
            'instruct_mean': float(np.mean(instruct_cal)),
            'p_value': base_vs_instruct_p
        }
    },
    'uncertainty_decomposition': [
        {
            'model': row['model'],
            'epistemic': float(row['epistemic_uncertainty']),
            'aleatoric': float(row['aleatoric_uncertainty']),
            'epistemic_ratio': float(row['epistemic_ratio'])
        }
        for _, row in uncertainty_decomp_df.iterrows()
    ],
    # Static manifest of the output file names this notebook is expected to write
    'files_generated': {
        'tables': [
            'table1_human_baselines.csv',
            'table2_model_summary.csv',
            'table3_strategy_summary.csv',
            'best_strategy_per_model.csv',
            'best_model_per_strategy.csv'
        ],
        'figures': [
            'credal_sets_pca.png',
            'uncertainty_analysis.png',
            'epistemic_aleatoric_tradeoff.png'
        ],
        'data': [
            'calibration_metrics.csv',
            'uncertainty_decomposition.csv'
        ]
    }
}

# Clean for JSON serialization (NumPy scalars/arrays -> plain Python types)
summary_clean = make_json_serializable(summary)

summary_path = os.path.join(CREDAL_DIR, 'analysis_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary_clean, f, indent=2)


print(f"\nResults saved to: {CREDAL_DIR}")
print("\nANALYSIS COMPLETE")
print("="*70)