
## Topic Modeling for TMCL: LDA Implementation

This notebook implements Latent Dirichlet Allocation (LDA) topic modeling for the Topic-Modeled Curriculum Learning (TMCL) framework. We'll process both vision and NLP datasets to extract topic distributions that will serve as unsupervised difficulty metrics for curriculum learning.

### Mathematical Foundation of LDA

Latent Dirichlet Allocation is a generative probabilistic model that represents documents as mixtures of latent topics, where each topic is characterized by a distribution over words (or features). The generative process for a document $d$ is:

1. Choose topic proportions $\theta_d \sim \text{Dirichlet}(\alpha)$
2. For each word position $n$ in document $d$:
    - Choose a topic $z_{dn} \sim \text{Multinomial}(\theta_d)$
    - Choose a word $w_{dn} \sim \text{Multinomial}(\phi_{z_{dn}})$
 
 Where:
 - $\alpha$ is the Dirichlet prior parameter for document-topic distributions
 - $\phi_k$ is the word distribution for topic $k$
 - $\theta_d$ is the topic distribution for document $d$
 
 The inference goal is to estimate the posterior distribution $P(\theta, z | w, \alpha, \beta)$, which is typically done using variational inference or Gibbs sampling.
 
 For TMCL, we're interested in the document-topic distributions $\theta_d$ which represent how much each topic contributes to a sample. The entropy of this distribution will serve as our difficulty metric.



In [1]:

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle
import json
from pathlib import Path

# Plotting 
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx

# Scikit-learn imports
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.decomposition import NMF, LatentDirichletAllocation
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler 
from sklearn.metrics.pairwise import cosine_similarity

# Deep learning imports
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

# Hugging Face imports
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel


# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Base directory where datasets are stored
BASE_DIR = "data/experiment/101"
OUTPUT_DIR = "results/tmcl/topic_models"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Base directory: {BASE_DIR}")
print(f"Output directory: {OUTPUT_DIR}")



Base directory: data/experiment/101
Output directory: results/tmcl/topic_models



### 1. Data Loading and Preprocessing
 
We need to handle two different types of data:
 - **Text data (AG News, IMDb)**: Requires tokenization and vectorization
 - **Image data (CIFAR-10/100, MNIST, Fashion-MNIST)**: Requires feature extraction using pre-trained models
 
#### Let's create utility functions to load and preprocess each dataset type.


In [2]:

class DatasetLoader:
    """Utility class to load and preprocess datasets for TMCL."""
    
    def __init__(self, base_dir):
        self.base_dir = base_dir
        self.datasets = {}
        
    def load_text_dataset(self, dataset_name):
        """Load and preprocess text datasets (AG News, IMDb)."""
        print(f"\nLoading {dataset_name} text dataset...")
        
        # Load dataset from cache
        dataset = load_dataset(dataset_name, cache_dir=os.path.join(self.base_dir, dataset_name))
        
        # Convert to pandas DataFrame for easier processing
        train_df = pd.DataFrame(dataset['train'])
        test_df = pd.DataFrame(dataset['test'])
        
        print(f"Training samples: {len(train_df)}, Test samples: {len(test_df)}")
        
        return {
            'train': train_df,
            'test': test_df,
            'name': dataset_name
        }
    
    def load_image_dataset(self, dataset_name):
        """Load and preprocess image datasets (CIFAR, MNIST, Fashion-MNIST)."""
        print(f"\nLoading {dataset_name} image dataset...")
        
        # Load dataset from cache
        dataset = load_dataset(dataset_name, cache_dir=os.path.join(self.base_dir, dataset_name))
        
        # Convert to format suitable for feature extraction
        train_data = dataset['train']
        test_data = dataset['test']
        
        print(f"Training samples: {len(train_data)}, Test samples: {len(test_data)}")
        
        return {
            'train': train_data,
            'test': test_data,
            'name': dataset_name
        }
    
    def get_all_datasets(self):
        """Load all datasets for TMCL experiments."""
        text_datasets = ['ag_news', 'imdb']
        image_datasets = ['cifar10', 'cifar100', 'mnist', 'fashion_mnist']
        
        for name in text_datasets:
            self.datasets[name] = self.load_text_dataset(name)
        
        for name in image_datasets:
            self.datasets[name] = self.load_image_dataset(name)
        
        return self.datasets

# Load all datasets
loader = DatasetLoader(BASE_DIR)
datasets = loader.get_all_datasets()



Loading ag_news text dataset...
Training samples: 120000, Test samples: 7600

Loading imdb text dataset...
Training samples: 25000, Test samples: 25000

Loading cifar10 image dataset...
Training samples: 50000, Test samples: 10000

Loading cifar100 image dataset...
Training samples: 50000, Test samples: 10000

Loading mnist image dataset...
Training samples: 60000, Test samples: 10000

Loading fashion_mnist image dataset...
Training samples: 60000, Test samples: 10000



### 2. Feature Extraction Pipeline
#### 2.1 Text Feature Extraction

For text datasets, we'll use two approaches:
1. **Bag-of-Words (BoW)**: Simple count-based representation
2. **TF-IDF**: Term Frequency-Inverse Document Frequency weighted representation

The mathematical transformation for TF-IDF is:

$$
 \text{tfidf}(t, d) = \text{tf}(t, d) \times \log\left(\frac{N}{\text{df}(t)}\right)
$$

Where:
- $\text{tf}(t, d)$ is the term frequency of term $t$ in document $d$
- $\text{df}(t)$ is the document frequency of term $t$ (number of documents containing $t$)
- $N$ is the total number of documents

This transformation gives higher weight to terms that are frequent in a document but rare across the corpus, helping to identify distinctive topics.


In [3]:

def preprocess_text_data(text_data, max_features=10000, use_tfidf=True):
    """
    Preprocess text data for topic modeling.
    
    Args:
        text_data: List of text samples
        max_features: Maximum number of features to extract
        use_tfidf: Whether to use TF-IDF or simple count vectorization
    
    Returns:
        feature_matrix: Document-term matrix
        vectorizer: Fitted vectorizer object
    """
    print("Preprocessing text data...")
    
    # Clean and normalize text
    def clean_text(text):
        if isinstance(text, str):
            return text.lower().strip()
        return ""
    
    cleaned_texts = [clean_text(text) for text in text_data]
    
    # Choose vectorizer based on configuration
    if use_tfidf:
        vectorizer = TfidfVectorizer(
            max_features=max_features,
            stop_words='english',
            min_df=5,  # Ignore terms that appear in fewer than 5 documents
            max_df=0.95,  # Ignore terms that appear in more than 95% of documents
            ngram_range=(1, 2)  # Use both unigrams and bigrams
        )
    else:
        vectorizer = CountVectorizer(
            max_features=max_features,
            stop_words='english',
            min_df=5,
            max_df=0.95,
            ngram_range=(1, 2)
        )
    
    # Transform text to feature matrix
    feature_matrix = vectorizer.fit_transform(cleaned_texts)
    
    print(f"Feature matrix shape: {feature_matrix.shape}")
    print(f"Vocabulary size: {len(vectorizer.vocabulary_)}")
    
    return feature_matrix, vectorizer



### 2.2 Image Feature Extraction

For image datasets, we need to extract meaningful features using pre-trained deep learning models. The mathematical foundation involves:

1. **Feature extraction**: Using a pre-trained CNN (ResNet-18) to extract high-level features
2. **Dimensionality reduction**: Applying PCA or direct use of penultimate layer features
 
The transformation can be represented as:
 
$$
 \mathbf{f}_i = \text{ResNet-18}_{\text{penultimate}}(\mathbf{x}_i)
$$
 
Where:
 - $\mathbf{x}_i$ is the input image
 - $\mathbf{f}_i \in \mathbb{R}^{512}$ is the 512-dimensional feature vector
 - $\text{ResNet-18}_{\text{penultimate}}$ represents the ResNet-18 model up to the penultimate layer
 
These features capture semantic information about the image content, which can then be clustered into topics.



In [4]:

class ImageFeatureExtractor(nn.Module):
    """Feature extractor using pre-trained ResNet-18."""
    
    def __init__(self, pretrained=True):
        super().__init__()
        # Use modern weights parameter instead of deprecated pretrained
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet18(weights=weights)
        
        # Remove the final classification layer
        self.features = nn.Sequential(*list(self.resnet.children())[:-1])
        
        # Freeze the model parameters
        for param in self.features.parameters():
            param.requires_grad = False
        
        print("Initialized ResNet-18 feature extractor (frozen)")
    
    def forward(self, x):
        """Extract features from input images."""
        # Handle grayscale images (1 channel) by converting to 3 channels
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        
        features = self.features(x)
        features = torch.flatten(features, 1)
        return features

def extract_image_features(dataset, batch_size=128, max_samples=None):
    """
    Extract features from image dataset using pre-trained ResNet-18.
    
    Args:
        dataset: Hugging Face dataset object
        batch_size: Batch size for feature extraction
        max_samples: Maximum number of samples to process
    
    Returns:
        feature_matrix: numpy array of extracted features
        labels: corresponding labels
    """
    print("Extracting image features...")
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize feature extractor
    feature_extractor = ImageFeatureExtractor(pretrained=True).to(device)
    feature_extractor.eval()
    
    # Setup data loader
    class HFImageDataset(Dataset):
        def __init__(self, hf_dataset, transform=None):
            self.dataset = hf_dataset
            self.transform = transform
            self.image_key = 'image' if 'image' in hf_dataset.features else 'img'
            self.label_key = 'label' if 'label' in hf_dataset.features else 'fine_label'
        
        def __len__(self):
            return len(self.dataset)
        
        def __getitem__(self, idx):
            item = self.dataset[idx]
            image = item[self.image_key]
            label = item[self.label_key]
            
            # Convert PIL image to tensor
            if self.transform:
                image = self.transform(image)
            else:
                image = transforms.ToTensor()(image)
            
            return image, label
    
    # Define transformations
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create dataset and dataloader
    full_dataset = HFImageDataset(dataset, transform=transform)
    
    if max_samples and max_samples < len(full_dataset):
        indices = torch.randperm(len(full_dataset))[:max_samples]
        subset_dataset = torch.utils.data.Subset(full_dataset, indices)
        dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)
        print(f"Using {max_samples} samples out of {len(full_dataset)}")
    else:
        dataloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)
    
    # Extract features
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            images, labels = batch
            images = images.to(device)
            
            features = feature_extractor(images)
            all_features.append(features.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    # Concatenate all features
    feature_matrix = np.vstack(all_features)
    labels = np.array(all_labels)
    
    print(f"Feature matrix shape: {feature_matrix.shape}")
    print(f"Feature dimension: {feature_matrix.shape[1]}")
    
    return feature_matrix, labels




### 3. LDA Topic Modeling Implementation
 
Now we implement the core LDA topic modeling. The mathematical formulation involves:
 
**Dirichlet Prior**: 
$$
 \theta_d \sim \text{Dirichlet}(\alpha)
$$
 
**Topic-Word Distribution**:
$$
 \phi_k \sim \text{Dirichlet}(\beta)
$$
 
**Log-likelihood maximization**:
$$
\mathcal{L}(\theta, \phi) = \sum_{d=1}^D \sum_{n=1}^{N_d} \log \sum_{k=1}^K \theta_{dk} \phi_{kw_{dn}}
$$
 
Where:
 - $D$ is the number of documents
 - $N_d$ is the number of words in document $d$
 - $K$ is the number of topics
 - $\theta_{dk}$ is the probability of topic $k$ in document $d$
 - $\phi_{kw}$ is the probability of word $w$ in topic $k$
 
We use scikit-learn's implementation which uses variational inference to approximate the posterior distribution.


In [5]:
def train_lda_model(feature_matrix, n_topics=50, random_state=42):
    """
    Train LDA model on feature matrix.
    
    Args:
        feature_matrix: Document-term matrix or feature matrix
        n_topics: Number of topics to extract
        random_state: Random seed for reproducibility
    
    Returns:
        lda_model: Trained LDA model
        topic_distributions: Topic distributions for each document
    """
    print(f"\nTraining LDA model with {n_topics} topics...")
    
    # Initialize LDA model
    lda_model = LatentDirichletAllocation(
        n_components=n_topics,
        max_iter=100,
        learning_method='online',
        batch_size=128,
        evaluate_every=10,
        random_state=random_state,
        n_jobs=-1,  # Use all available cores
        verbose=1
    )
    
    # Fit the model
    topic_distributions = lda_model.fit_transform(feature_matrix)
    
    # Normalize topic distributions to sum to 1
    topic_distributions = topic_distributions / topic_distributions.sum(axis=1, keepdims=True)
    
    print(f"LDA model trained successfully")
    print(f"Topic distributions shape: {topic_distributions.shape}")
    
    return lda_model, topic_distributions

def analyze_lda_topics(lda_model, vectorizer=None, top_n=10):
    """
    Analyze and display the top words for each topic.
    
    Args:
        lda_model: Trained LDA model
        vectorizer: Vectorizer used for text data (optional)
        top_n: Number of top words to display per topic
    
    Returns:
        topic_words: Dictionary mapping topic indices to top words
    """
    print("\nAnalyzing LDA topics...")
    
    topic_words = {}
    
    if vectorizer is not None:
        # Get feature names (words)
        feature_names = vectorizer.get_feature_names_out()
        
        for topic_idx, topic in enumerate(lda_model.components_):
            top_indices = topic.argsort()[::-1][:top_n]
            top_words = [feature_names[i] for i in top_indices]
            topic_words[topic_idx] = top_words
            
            print(f"Topic {topic_idx}: {', '.join(top_words)}")
    else:
        print("No vectorizer provided - skipping word analysis for image features")
    
    return topic_words




### 4. Difficulty Score Calculation
 
The core innovation of TMCL is using topic distributions to compute sample difficulty. The primary metric is **Shannon entropy**:
 
$$
 D_{\text{entropy}}(x_i) = H(P) = -\sum_{t=1}^{T} P(t \mid x_i) \log P(t \mid x_i)
$$
 
Where:
 - $P(t \mid x_i)$ is the probability of topic $t$ for sample $x_i$
 - $T$ is the total number of topics
 
**Interpretation**:
 - **Low entropy** ($H(P) \approx 0$): Sample belongs predominantly to one topic → **Easy sample**
 - **High entropy** ($H(P) \approx \log T$): Sample has uniform topic distribution → **Hard sample**
 
 We also implement alternative difficulty metrics :
 
**Max Probability (Purity)**:
$$
 D_{\text{max}}(x_i) = 1 - \max_t P(t \mid x_i)
$$

**Composite Score**:
$$
 D_{\text{comp}}(x_i) = \lambda H(P) + (1 - \lambda) D_{\text{max}}(x_i)
$$


In [6]:

def compute_difficulty_scores(topic_distributions, metric='entropy', lambda_comp=0.5):
    """
    Compute difficulty scores from topic distributions.
    
    Args:
        topic_distributions: Array of topic distributions [n_samples, n_topics]
        metric: Difficulty metric to use ('entropy', 'max_prob', 'composite')
        lambda_comp: Weight for composite metric (only used if metric='composite')
    
    Returns:
        difficulty_scores: Array of difficulty scores [n_samples]
    """
    print(f"\nComputing difficulty scores using {metric} metric...")
    
    if metric == 'entropy':
        # Shannon entropy
        epsilon = 1e-10  # Small constant to avoid log(0)
        entropy = -np.sum(topic_distributions * np.log(topic_distributions + epsilon), axis=1)
        difficulty_scores = entropy
        
    elif metric == 'max_prob':
        # 1 - max probability (higher = more ambiguous)
        max_prob = np.max(topic_distributions, axis=1)
        difficulty_scores = 1 - max_prob
        
    elif metric == 'composite':
        # Composite of entropy and max probability
        epsilon = 1e-10
        entropy = -np.sum(topic_distributions * np.log(topic_distributions + epsilon), axis=1)
        max_prob = np.max(topic_distributions, axis=1)
        max_prob_diff = 1 - max_prob
        
        difficulty_scores = lambda_comp * entropy + (1 - lambda_comp) * max_prob_diff
        
    else:
        raise ValueError(f"Unknown metric: {metric}")
    
    # Normalize scores to [0, 1] range
    difficulty_scores = (difficulty_scores - difficulty_scores.min()) / (difficulty_scores.max() - difficulty_scores.min() + 1e-10)
    
    print(f"Difficulty scores computed. Range: [{difficulty_scores.min():.4f}, {difficulty_scores.max():.4f}]")
    print(f"Mean difficulty: {difficulty_scores.mean():.4f}, Std: {difficulty_scores.std():.4f}")
    
    return difficulty_scores

def visualize_difficulty_distribution(difficulty_scores, dataset_name, metric_name):
    """
    Visualize the distribution of difficulty scores.
    
    Args:
        difficulty_scores: Array of difficulty scores
        dataset_name: Name of the dataset
        metric_name: Name of the difficulty metric
    """
    plt.figure(figsize=(10, 6))
    
    # Plot histogram
    sns.histplot(difficulty_scores, bins=50, kde=True)
    
    plt.title(f'Difficulty Score Distribution - {dataset_name} ({metric_name})')
    plt.xlabel('Difficulty Score')
    plt.ylabel('Frequency')
    plt.grid(True, alpha=0.3)
    
    # Save plot
    output_path = os.path.join(OUTPUT_DIR, f'{dataset_name}_difficulty_{metric_name}.png')
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
    
    print(f"Difficulty distribution plot saved to: {output_path}")




### 5. Complete TMCL Pipeline Implementation

Now we integrate all components into a complete pipeline for each dataset type.


In [7]:

def process_text_dataset(dataset_info, n_topics=50, max_features=10000):
    """
    Complete pipeline for processing text datasets for TMCL.
    
    Args:
        dataset_info: Dictionary containing dataset information
        n_topics: Number of topics for LDA
        max_features: Maximum number of text features
    
    Returns:
        results: Dictionary containing all processing results
    """
    dataset_name = dataset_info['name']
    print(f"\n{'='*50}")
    print(f"Processing text dataset: {dataset_name}")
    print(f"{'='*50}")
    
    # Combine train and test for topic modeling (unsupervised)
    all_texts = pd.concat([dataset_info['train'], dataset_info['test']])
    
    # Determine text column based on dataset
    text_column = 'text' if 'text' in all_texts.columns else 'sentence'
    label_column = 'label' if 'label' in all_texts.columns else 'target'
    
    print(f"Text column: {text_column}, Label column: {label_column}")
    
    # 1. Preprocess text data
    feature_matrix, vectorizer = preprocess_text_data(
        all_texts[text_column].tolist(),
        max_features=max_features,
        use_tfidf=True
    )
    
    # 2. Train LDA model
    lda_model, topic_distributions = train_lda_model(
        feature_matrix, 
        n_topics=n_topics,
        random_state=42
    )
    
    # 3. Analyze topics
    topic_words = analyze_lda_topics(lda_model, vectorizer, top_n=10)
    
    # 4. Compute difficulty scores
    difficulty_scores_entropy = compute_difficulty_scores(topic_distributions, metric='entropy')
    difficulty_scores_max = compute_difficulty_scores(topic_distributions, metric='max_prob')
    difficulty_scores_comp = compute_difficulty_scores(topic_distributions, metric='composite', lambda_comp=0.7)
    
    # 5. Visualize distributions
    visualize_difficulty_distribution(difficulty_scores_entropy, dataset_name, 'entropy')
    visualize_difficulty_distribution(difficulty_scores_max, dataset_name, 'max_prob')
    
    # 6. Save results
    results = {
        'dataset_name': dataset_name,
        'feature_matrix': feature_matrix,
        'vectorizer': vectorizer,
        'lda_model': lda_model,
        'topic_distributions': topic_distributions,
        'topic_words': topic_words,
        'difficulty_scores': {
            'entropy': difficulty_scores_entropy,
            'max_prob': difficulty_scores_max,
            'composite': difficulty_scores_comp
        },
        'labels': all_texts[label_column].values if label_column in all_texts.columns else None,
        'sample_ids': range(len(all_texts))
    }
    
    # Save to disk
    output_file = os.path.join(OUTPUT_DIR, f'{dataset_name}_tmcl_results.pkl')
    with open(output_file, 'wb') as f:
        pickle.dump(results, f)
    
    print(f"Results saved to: {output_file}")
    
    return results

def process_image_dataset(dataset_info, n_topics=15, max_samples=None):
    """
    Complete pipeline for processing image datasets for TMCL.
    
    Args:
        dataset_info: Dictionary containing dataset information
        n_topics: Number of topics for LDA
        max_samples: Maximum number of samples to process
    
    Returns:
        results: Dictionary containing all processing results
    """
    dataset_name = dataset_info['name']
    print(f"\n{'='*50}")
    print(f"Processing image dataset: {dataset_name}")
    print(f"{'='*50}")
    
    # 1. Extract features from training data
    train_features, train_labels = extract_image_features(
        dataset_info['train'], 
        batch_size=128,
        max_samples=max_samples
    )
    
    # 2. Extract features from test data
    test_features, test_labels = extract_image_features(
        dataset_info['test'], 
        batch_size=128,
        max_samples=max_samples
    )
    
    # 3. Combine features for topic modeling
    all_features = np.vstack([train_features, test_features])
    all_labels = np.concatenate([train_labels, test_labels])
    
    print(f"Combined feature matrix shape: {all_features.shape}")
    
    # 4. Normalize features using MinMaxScaler to ensure non-negative values
    # LDA requires non-negative input data
    scaler = MinMaxScaler()
    all_features_normalized = scaler.fit_transform(all_features)
    print(f"Feature range after scaling: [{all_features_normalized.min():.4f}, {all_features_normalized.max():.4f}]")
    
    # 5. Train LDA model
    lda_model, topic_distributions = train_lda_model(
        all_features_normalized, 
        n_topics=n_topics,
        random_state=42
    )
    
    # 6. Compute difficulty scores
    difficulty_scores_entropy = compute_difficulty_scores(topic_distributions, metric='entropy')
    difficulty_scores_max = compute_difficulty_scores(topic_distributions, metric='max_prob')
    difficulty_scores_comp = compute_difficulty_scores(topic_distributions, metric='composite', lambda_comp=0.7)
    
    # 7. Visualize distributions
    visualize_difficulty_distribution(difficulty_scores_entropy, dataset_name, 'entropy')
    visualize_difficulty_distribution(difficulty_scores_max, dataset_name, 'max_prob')
    
    # 8. Save results
    results = {
        'dataset_name': dataset_name,
        'feature_matrix': all_features_normalized,
        'scaler': scaler,
        'lda_model': lda_model,
        'topic_distributions': topic_distributions,
        'difficulty_scores': {
            'entropy': difficulty_scores_entropy,
            'max_prob': difficulty_scores_max,
            'composite': difficulty_scores_comp
        },
        'labels': all_labels,
        'sample_ids': range(len(all_features))
    }
    
    # Save to disk
    output_file = os.path.join(OUTPUT_DIR, f'{dataset_name}_tmcl_results.pkl')
    with open(output_file, 'wb') as f:
        pickle.dump(results, f)
    
    print(f"Results saved to: {output_file}")
    
    return results




### 6. Execute TMCL Pipeline for All Datasets
 
Now we run the complete pipeline for all datasets. This will generate the topic distributions and difficulty scores needed for the curriculum learning experiments.

## Process text datasets

In [8]:
class HFImageDataset(Dataset):
    """Dataset wrapper for Hugging Face image datasets."""
    
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        
        # Determine correct keys for different datasets
        if 'img' in hf_dataset.features:
            self.image_key = 'img'
        elif 'image' in hf_dataset.features:
            self.image_key = 'image'
        else:
            raise ValueError("Could not find image key in dataset features")
        
        if 'label' in hf_dataset.features:
            self.label_key = 'label'
        elif 'fine_label' in hf_dataset.features:
            self.label_key = 'fine_label'
        elif 'coarse_label' in hf_dataset.features:
            self.label_key = 'coarse_label'
        else:
            raise ValueError("Could not find label key in dataset features")
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[int(idx)]  # Ensure idx is integer
        image = item[self.image_key]
        label = item[self.label_key]
        
        # Convert grayscale images to RGB by repeating the channel
        if image.mode == 'L':  # Grayscale image
            image = image.convert('RGB')
        
        # Convert PIL image to tensor
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)
        
        return image, label

def extract_image_features(dataset, batch_size=128, max_samples=None):
    """
    Extract features from image dataset using pre-trained ResNet-18.
    
    Args:
        dataset: Hugging Face dataset object
        batch_size: Batch size for feature extraction
        max_samples: Maximum number of samples to process
    
    Returns:
        feature_matrix: numpy array of extracted features
        labels: corresponding labels
    """
    print("Extracting image features...")
    
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize feature extractor
    feature_extractor = ImageFeatureExtractor(pretrained=True).to(device)
    feature_extractor.eval()
    
    # Define transformations - handle grayscale conversion here as backup
    transform = transforms.Compose([
        transforms.Lambda(lambda x: x.convert('RGB') if x.mode == 'L' else x),  # Ensure RGB
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create dataset and dataloader
    full_dataset = HFImageDataset(dataset, transform=transform)
    
    if max_samples and max_samples < len(full_dataset):
        # Use random subset
        indices = torch.randperm(len(full_dataset))[:max_samples]
        subset_dataset = torch.utils.data.Subset(full_dataset, indices.numpy())
        dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False)
        print(f"Using {max_samples} samples out of {len(full_dataset)}")
    else:
        dataloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)
    
    # Extract features
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting features"):
            images, labels = batch
            images = images.to(device)
            
            features = feature_extractor(images)
            all_features.append(features.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    # Concatenate all features
    if all_features:
        feature_matrix = np.vstack(all_features)
        labels = np.array(all_labels)
    else:
        raise ValueError("No features extracted - dataset might be empty")
    
    print(f"Feature matrix shape: {feature_matrix.shape}")
    print(f"Feature dimension: {feature_matrix.shape[1]}")
    
    return feature_matrix, labels

In [9]:
# Process text datasets with optimizations (same as before)
text_datasets = ['ag_news', 'imdb']
text_results = {}

# Use smaller sample sizes and reduced features for faster processing
SAMPLE_SIZE_LIMIT = 20000
MAX_FEATURES_TEXT = 2000
N_TOPICS_TEXT = 20

for dataset_name in text_datasets:
    if dataset_name in datasets:
        print(f"\nProcessing {dataset_name} with optimizations...")
        dataset_info = datasets[dataset_name]
        train_df = dataset_info['train']
        test_df = dataset_info['test']
        all_texts = pd.concat([train_df, test_df])
        if len(all_texts) > SAMPLE_SIZE_LIMIT:
            all_texts = all_texts.sample(n=SAMPLE_SIZE_LIMIT, random_state=42)
            print(f"✓ Limited to {SAMPLE_SIZE_LIMIT} samples")
        dataset_info['train'] = all_texts[:len(train_df)]
        dataset_info['test'] = all_texts[len(train_df):]
        results = process_text_dataset(
            dataset_info,
            n_topics=N_TOPICS_TEXT,
            max_features=MAX_FEATURES_TEXT
        )
        text_results[dataset_name] = results

# Process image datasets with optimizations
image_datasets = ['cifar10', 'cifar100', 'mnist', 'fashion_mnist']
image_results = {}

MAX_SAMPLES_IMAGE = 5000
N_TOPICS_IMAGE = 15

for dataset_name in image_datasets:
    if dataset_name in datasets:
        print(f"\nProcessing {dataset_name} with optimizations...")
        max_samples = MAX_SAMPLES_IMAGE if dataset_name in ['cifar10', 'cifar100'] else None
        results = process_image_dataset(
            datasets[dataset_name],
            n_topics=N_TOPICS_IMAGE,
            max_samples=max_samples
        )
        image_results[dataset_name] = results


Processing ag_news with optimizations...
✓ Limited to 20000 samples

Processing text dataset: ag_news
Text column: text, Label column: label
Preprocessing text data...
Feature matrix shape: (20000, 2000)
Vocabulary size: 2000

Training LDA model with 20 topics...
iteration: 1 of max_iter: 100
iteration: 2 of max_iter: 100
iteration: 3 of max_iter: 100
iteration: 4 of max_iter: 100
iteration: 5 of max_iter: 100
iteration: 6 of max_iter: 100
iteration: 7 of max_iter: 100
iteration: 8 of max_iter: 100
iteration: 9 of max_iter: 100
iteration: 10 of max_iter: 100, perplexity: 8630.4607
iteration: 11 of max_iter: 100
iteration: 12 of max_iter: 100
iteration: 13 of max_iter: 100
iteration: 14 of max_iter: 100
iteration: 15 of max_iter: 100
iteration: 16 of max_iter: 100
iteration: 17 of max_iter: 100
iteration: 18 of max_iter: 100
iteration: 19 of max_iter: 100
iteration: 20 of max_iter: 100, perplexity: 8620.9762
iteration: 21 of max_iter: 100
iteration: 22 of max_iter: 100
iteration: 23 of

Extracting features: 100%|██████████| 40/40 [01:12<00:00,  1.82s/it]


Feature matrix shape: (5000, 512)
Feature dimension: 512
Extracting image features...
Using device: cpu
Initialized ResNet-18 feature extractor (frozen)
Using 5000 samples out of 10000


Extracting features: 100%|██████████| 40/40 [01:13<00:00,  1.85s/it]


Feature matrix shape: (5000, 512)
Feature dimension: 512
Combined feature matrix shape: (10000, 512)
Feature range after scaling: [0.0000, 1.0000]

Training LDA model with 15 topics...
iteration: 1 of max_iter: 100
iteration: 2 of max_iter: 100
iteration: 3 of max_iter: 100
iteration: 4 of max_iter: 100
iteration: 5 of max_iter: 100
iteration: 6 of max_iter: 100
iteration: 7 of max_iter: 100
iteration: 8 of max_iter: 100
iteration: 9 of max_iter: 100
iteration: 10 of max_iter: 100, perplexity: 522.3162
iteration: 11 of max_iter: 100
iteration: 12 of max_iter: 100
iteration: 13 of max_iter: 100
iteration: 14 of max_iter: 100
iteration: 15 of max_iter: 100
iteration: 16 of max_iter: 100
iteration: 17 of max_iter: 100
iteration: 18 of max_iter: 100
iteration: 19 of max_iter: 100
iteration: 20 of max_iter: 100, perplexity: 522.3400
LDA model trained successfully
Topic distributions shape: (10000, 15)

Computing difficulty scores using entropy metric...
Difficulty scores computed. Range: [0

Extracting features: 100%|██████████| 40/40 [01:15<00:00,  1.88s/it]


Feature matrix shape: (5000, 512)
Feature dimension: 512
Extracting image features...
Using device: cpu
Initialized ResNet-18 feature extractor (frozen)
Using 5000 samples out of 10000


Extracting features: 100%|██████████| 40/40 [01:10<00:00,  1.77s/it]


Feature matrix shape: (5000, 512)
Feature dimension: 512
Combined feature matrix shape: (10000, 512)
Feature range after scaling: [0.0000, 1.0000]

Training LDA model with 15 topics...
iteration: 1 of max_iter: 100
iteration: 2 of max_iter: 100
iteration: 3 of max_iter: 100
iteration: 4 of max_iter: 100
iteration: 5 of max_iter: 100
iteration: 6 of max_iter: 100
iteration: 7 of max_iter: 100
iteration: 8 of max_iter: 100
iteration: 9 of max_iter: 100
iteration: 10 of max_iter: 100, perplexity: 536.0614
iteration: 11 of max_iter: 100
iteration: 12 of max_iter: 100
iteration: 13 of max_iter: 100
iteration: 14 of max_iter: 100
iteration: 15 of max_iter: 100
iteration: 16 of max_iter: 100
iteration: 17 of max_iter: 100
iteration: 18 of max_iter: 100
iteration: 19 of max_iter: 100
iteration: 20 of max_iter: 100, perplexity: 535.7940
iteration: 21 of max_iter: 100
iteration: 22 of max_iter: 100
iteration: 23 of max_iter: 100
iteration: 24 of max_iter: 100
iteration: 25 of max_iter: 100
itera

Extracting features: 100%|██████████| 469/469 [14:04<00:00,  1.80s/it]


Feature matrix shape: (60000, 512)
Feature dimension: 512
Extracting image features...
Using device: cpu
Initialized ResNet-18 feature extractor (frozen)


Extracting features: 100%|██████████| 79/79 [02:20<00:00,  1.78s/it]

Feature matrix shape: (10000, 512)
Feature dimension: 512
Combined feature matrix shape: (70000, 512)
Feature range after scaling: [0.0000, 1.0000]

Training LDA model with 15 topics...





iteration: 1 of max_iter: 100
iteration: 2 of max_iter: 100
iteration: 3 of max_iter: 100
iteration: 4 of max_iter: 100
iteration: 5 of max_iter: 100
iteration: 6 of max_iter: 100
iteration: 7 of max_iter: 100
iteration: 8 of max_iter: 100
iteration: 9 of max_iter: 100
iteration: 10 of max_iter: 100, perplexity: 443.1813
iteration: 11 of max_iter: 100
iteration: 12 of max_iter: 100
iteration: 13 of max_iter: 100
iteration: 14 of max_iter: 100
iteration: 15 of max_iter: 100
iteration: 16 of max_iter: 100
iteration: 17 of max_iter: 100
iteration: 18 of max_iter: 100
iteration: 19 of max_iter: 100
iteration: 20 of max_iter: 100, perplexity: 443.2361
LDA model trained successfully
Topic distributions shape: (70000, 15)

Computing difficulty scores using entropy metric...
Difficulty scores computed. Range: [0.0000, 1.0000]
Mean difficulty: 0.4061, Std: 0.2187

Computing difficulty scores using max_prob metric...
Difficulty scores computed. Range: [0.0000, 1.0000]
Mean difficulty: 0.4283, St

Extracting features: 100%|██████████| 469/469 [14:03<00:00,  1.80s/it]


Feature matrix shape: (60000, 512)
Feature dimension: 512
Extracting image features...
Using device: cpu
Initialized ResNet-18 feature extractor (frozen)


Extracting features: 100%|██████████| 79/79 [02:20<00:00,  1.78s/it]

Feature matrix shape: (10000, 512)
Feature dimension: 512
Combined feature matrix shape: (70000, 512)
Feature range after scaling: [0.0000, 1.0000]

Training LDA model with 15 topics...





iteration: 1 of max_iter: 100
iteration: 2 of max_iter: 100
iteration: 3 of max_iter: 100
iteration: 4 of max_iter: 100
iteration: 5 of max_iter: 100
iteration: 6 of max_iter: 100
iteration: 7 of max_iter: 100
iteration: 8 of max_iter: 100
iteration: 9 of max_iter: 100
iteration: 10 of max_iter: 100, perplexity: 452.1410
iteration: 11 of max_iter: 100
iteration: 12 of max_iter: 100
iteration: 13 of max_iter: 100
iteration: 14 of max_iter: 100
iteration: 15 of max_iter: 100
iteration: 16 of max_iter: 100
iteration: 17 of max_iter: 100
iteration: 18 of max_iter: 100
iteration: 19 of max_iter: 100
iteration: 20 of max_iter: 100, perplexity: 452.0477
LDA model trained successfully
Topic distributions shape: (70000, 15)

Computing difficulty scores using entropy metric...
Difficulty scores computed. Range: [0.0000, 1.0000]
Mean difficulty: 0.2874, Std: 0.2172

Computing difficulty scores using max_prob metric...
Difficulty scores computed. Range: [0.0000, 1.0000]
Mean difficulty: 0.3122, St

### 7. Analysis and Validation
 
Let's perform some basic validation of the difficulty scores by checking their correlation with initial training loss. This validates our hypothesis that topic entropy correlates with sample difficulty.


In [10]:

def validate_difficulty_scores(dataset_name, results, sample_size=1000):
    """
    Validate difficulty scores by correlating with initial training loss.
    
    Args:
        dataset_name: Name of the dataset
        results: Processing results containing difficulty scores
        sample_size: Number of samples to use for validation
    
    Returns:
        correlation_results: Dictionary of correlation results
    """
    print(f"\nValidating difficulty scores for {dataset_name}...")
    
    # Get difficulty scores and labels
    difficulty_entropy = results['difficulty_scores']['entropy']
    labels = results['labels']
    
    if labels is None:
        print("No labels available for validation")
        return None
    
    # Sample a subset for validation
    if len(difficulty_entropy) > sample_size:
        indices = np.random.choice(len(difficulty_entropy), sample_size, replace=False)
        difficulty_entropy = difficulty_entropy[indices]
        labels = labels[indices]
    
    # Create simple validation: check if high difficulty samples have more diverse labels
    # This is a proxy validation since we don't have actual training loss yet
    
    # Bin samples by difficulty
    n_bins = 5
    bins = np.linspace(0, 1, n_bins + 1)
    bin_indices = np.digitize(difficulty_entropy, bins) - 1
    
    # Calculate label diversity per bin
    bin_diversity = []
    for i in range(n_bins):
        bin_mask = (bin_indices == i)
        if np.sum(bin_mask) > 0:
            bin_labels = labels[bin_mask]
            # Calculate number of unique labels / total samples
            diversity = len(np.unique(bin_labels)) / len(bin_labels)
            bin_diversity.append(diversity)
        else:
            bin_diversity.append(0)
    
    # Plot label diversity vs difficulty
    plt.figure(figsize=(10, 6))
    plt.plot(range(n_bins), bin_diversity, 'o-', linewidth=2, markersize=8)
    plt.title(f'Label Diversity vs Difficulty - {dataset_name}')
    plt.xlabel('Difficulty Bin (0=easy, 4=hard)')
    plt.ylabel('Label Diversity (Unique labels / Total samples)')
    plt.grid(True, alpha=0.3)
    
    output_path = os.path.join(OUTPUT_DIR, f'{dataset_name}_diversity_validation.png')
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
    
    print(f"Validation plot saved to: {output_path}")
    print(f"Label diversity by difficulty bin: {bin_diversity}")
    
    return {
        'dataset_name': dataset_name,
        'bin_diversity': bin_diversity,
        'correlation_plot': output_path
    }

# Run validation for a few datasets
validation_results = {}
for dataset_name in ['ag_news', 'cifar10' , 'imdb' , 'cifar100', 'mnist', 'fashion_mnist' ]:
    if dataset_name in text_results:
        validation_results[dataset_name] = validate_difficulty_scores(dataset_name, text_results[dataset_name])
    elif dataset_name in image_results:
        validation_results[dataset_name] = validate_difficulty_scores(dataset_name, image_results[dataset_name])




Validating difficulty scores for ag_news...
Validation plot saved to: results/tmcl/topic_models/ag_news_diversity_validation.png
Label diversity by difficulty bin: [0.3, 0.05194805194805195, 0.015209125475285171, 0.0073937153419593345, 0.03669724770642202]

Validating difficulty scores for cifar10...
Validation plot saved to: results/tmcl/topic_models/cifar10_diversity_validation.png
Label diversity by difficulty bin: [0.04784688995215311, 0.03076923076923077, 0.03184713375796178, 0.06622516556291391, 1.0]

Validating difficulty scores for imdb...
Validation plot saved to: results/tmcl/topic_models/imdb_diversity_validation.png
Label diversity by difficulty bin: [0.005277044854881266, 0.0038684719535783366, 0.024096385542168676, 0.1, 1.0]

Validating difficulty scores for cifar100...
Validation plot saved to: results/tmcl/topic_models/cifar100_diversity_validation.png
Label diversity by difficulty bin: [0.3106796116504854, 0.33584905660377357, 0.304635761589404, 0.3619047619047619, 1.

In [None]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
import os
import pickle
from sklearn.feature_extraction.text import CountVectorizer

def create_interactive_topic_network(dataset_name, results, feature_names=None, top_n_words=10):
    """
    Create interactive topic network visualization with actual topic labels
    
    Args:
        dataset_name: Name of the dataset
        results: Processing results containing LDA model and topic distributions
        feature_names: Vocabulary for text datasets
        top_n_words: Number of top words to show for each topic
    """
    
    print(f"\nCreating interactive topic network for {dataset_name}...")
    
    # Extract components from results
    lda_model = results['lda_model']
    topic_distributions = results['topic_distributions']
    difficulty_scores = results['difficulty_scores']
    
    # Get topic components
    topic_vectors = lda_model.components_  # Shape: n_topics x n_features
    topic_similarity = cosine_similarity(topic_vectors)
    
    # Calculate topic weights
    topic_weights = np.sum(topic_distributions, axis=0)
    topic_weights_norm = topic_weights / np.sum(topic_weights)
    
    # Create topic relationship graph
    G = nx.Graph()
    
    # Get actual topic labels (top words for each topic)
    topic_labels = []
    topic_full_labels = []  # Full label with all top words
    topic_words_list = []   # List of top words for each topic
    
    for i in range(topic_vectors.shape[0]):
        # Get top words for this topic
        if feature_names is not None and len(feature_names) > 0:
            topic_vector = topic_vectors[i]
            # Get indices of top words for this topic
            top_indices = np.argsort(topic_vector)[-top_n_words:][::-1]
            top_words = []
            for idx in top_indices:
                if idx < len(feature_names):
                    if isinstance(feature_names[idx], (int, np.integer)):
                        # If it's an integer, it might be an index
                        top_words.append(f"word_{idx}")
                    else:
                        top_words.append(str(feature_names[idx]))
                else:
                    top_words.append(f"word_{idx}")
            
            # Create short label (first 3 words)
            short_label = f"T{i}: {', '.join(top_words[:3])}"
            # Create full label
            full_label = f"Topic {i}<br>Top words: {', '.join(top_words)}"
            topic_labels.append(short_label)
            topic_full_labels.append(full_label)
            topic_words_list.append(top_words)
        else:
            # For image datasets or when no feature names
            topic_labels.append(f"Topic {i}")
            topic_full_labels.append(f"Topic {i}")
            topic_words_list.append([])
    
    # Add nodes with topic information
    for i in range(topic_vectors.shape[0]):
        # Calculate average difficulty for documents where this topic is dominant
        topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
        if np.sum(topic_doc_mask) > 0:
            avg_difficulty = np.mean(difficulty_scores['entropy'][topic_doc_mask])
            avg_composite = np.mean(difficulty_scores['composite'][topic_doc_mask])
        else:
            avg_difficulty = 0.5
            avg_composite = 0.5
        
        # Node size based on topic weight
        node_size = topic_weights_norm[i] * 100 + 10
        
        # Number of documents for this topic
        n_docs = np.sum(topic_doc_mask)
        
        # Add to graph
        G.add_node(i,
                   label=topic_labels[i],
                   full_label=topic_full_labels[i],
                   top_words=topic_words_list[i],
                   size=node_size,
                   avg_difficulty=avg_difficulty,
                   avg_composite=avg_composite,
                   weight=topic_weights_norm[i],
                   n_docs=n_docs,
                   topic_id=i)
    
    # Add edges based on similarity
    if len(topic_similarity) > 1:
        # Flatten similarity matrix excluding diagonal
        flat_similarities = topic_similarity[np.triu_indices_from(topic_similarity, k=1)]
        if len(flat_similarities) > 0:
            threshold = np.percentile(flat_similarities, 75)
        else:
            threshold = 0.5
    else:
        threshold = 0.5
    
    edges = []
    edge_weights = []
    
    for i in range(topic_vectors.shape[0]):
        for j in range(i+1, topic_vectors.shape[0]):
            if topic_similarity[i, j] > threshold:
                G.add_edge(i, j, weight=topic_similarity[i, j])
                edges.append((i, j))
                edge_weights.append(topic_similarity[i, j])
    
    # Use spring layout for node positions
    if len(G.nodes()) > 0:
        pos = nx.spring_layout(G, k=1.5, iterations=50, seed=42)
    else:
        print(f"  Warning: No nodes in graph for {dataset_name}")
        return None, None
    
    # Create edge traces
    edge_x = []
    edge_y = []
    edge_hovertext = []
    
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        
        # Get edge weight
        weight = G[edge[0]][edge[1]]['weight']
        
        # Edge hover text
        edge_hovertext.append(
            f"Topic {edge[0]} ↔ Topic {edge[1]}<br>"
            f"Similarity: {weight:.3f}"
        )
    
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1.5, color='#888'),
        hoverinfo='none',
        mode='lines',
        showlegend=False
    )
    
    # Create node traces
    node_x = []
    node_y = []
    node_text = []
    node_hovertext = []
    node_sizes = []
    node_colors = []
    
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        
        # Get node attributes
        node_data = G.nodes[node]
        node_text.append(node_data['label'])
        
        # Create hover text with topic words
        if node_data['top_words']:
            words_str = "<br>".join([f"• {word}" for word in node_data['top_words']])
        else:
            words_str = "No words available"
        
        hover_text = (
            f"<b>Topic {node_data['topic_id']}</b><br>"
            f"{words_str}<br><br>"
            f"Average Difficulty (Entropy): {node_data['avg_difficulty']:.3f}<br>"
            f"Average Difficulty (Composite): {node_data['avg_composite']:.3f}<br>"
            f"Topic Weight: {node_data['weight']:.3f}<br>"
            f"Number of Documents: {node_data['n_docs']}"
        )
        node_hovertext.append(hover_text)
        
        # Node size based on topic weight
        node_sizes.append(node_data['size'])
        
        # Node color based on average difficulty
        node_colors.append(node_data['avg_composite'])
    
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=node_text,
        textposition="top center",
        hoverinfo='text',
        hovertext=node_hovertext,
        marker=dict(
            showscale=True,
            colorscale='RdYlBu_r',  
            color=node_colors,
            size=node_sizes,
            colorbar=dict(
                thickness=15,
                title='Average Difficulty<br>(Composite)',
                xanchor='left',
                title_side='right'
            ),
            line=dict(width=2, color='white')
        ),
        showlegend=False
    )
    
    fig = go.Figure(data=[edge_trace, node_trace],
                   layout=go.Layout(
                       title=dict(
                           text=f'Interactive Topic Network - {dataset_name}<br>'
                                f'Node size ∝ Topic weight, Color ∝ Average difficulty',
                           font=dict(size=16)
                       ),
                       showlegend=False,
                       hovermode='closest',
                       margin=dict(b=20, l=5, r=5, t=40),
                       annotations=[dict(
                           text=f"Network Statistics:<br>"
                               f"Nodes: {G.number_of_nodes()}<br>"
                               f"Edges: {G.number_of_edges()}<br>"
                               f"Edge threshold: {threshold:.3f}",
                           showarrow=False,
                           xref="paper", yref="paper",
                           x=0.02, y=0.02,
                           bgcolor="white",
                           bordercolor="black",
                           borderwidth=1,
                           borderpad=4
                       )],
                       xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                       yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
                   ))
    
    output_path = f"results/tmcl/topic_models/{dataset_name}_interactive_network.html"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.write_html(output_path)
    print(f"  Interactive network saved to: {output_path}")
    
    # create a static version if kaleido is available
    try:
        fig.update_layout(
            width=1200,
            height=800,
            font=dict(size=12)
        )
        static_output_path = f"results/tmcl/topic_models/{dataset_name}_interactive_network_static.png"
        fig.write_image(static_output_path)
        print(f"  Static version saved to: {static_output_path}")
    except Exception as e:
        print(f"  Note: Could not save static image. Install kaleido: pip install kaleido")
    
    return fig, G

def create_interactive_topic_dashboard(dataset_name, results, feature_names=None):
    """
    Create comprehensive interactive dashboard for topic analysis
    
    Args:
        dataset_name: Name of the dataset
        results: Processing results
        feature_names: Vocabulary for text datasets
    """
    
    print(f"\nCreating interactive dashboard for {dataset_name}...")
    
    # Extract components
    lda_model = results['lda_model']
    topic_distributions = results['topic_distributions']
    difficulty_scores = results['difficulty_scores']
    labels = results.get('labels', None)
    
    # Calculate topic statistics
    topic_vectors = lda_model.components_
    topic_weights = np.sum(topic_distributions, axis=0)
    topic_weights_norm = topic_weights / np.sum(topic_weights)
    
    # Create DataFrame for analysis
    analysis_data = []
    for i in range(topic_vectors.shape[0]):
        topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
        n_docs = np.sum(topic_doc_mask)
        
        if n_docs > 0:
            avg_entropy = np.mean(difficulty_scores['entropy'][topic_doc_mask])
            avg_max_prob = np.mean(difficulty_scores['max_prob'][topic_doc_mask])
            avg_composite = np.mean(difficulty_scores['composite'][topic_doc_mask])
            topic_strength = np.mean(topic_distributions[topic_doc_mask, i])
            
            # Get top words
            top_words = []
            if feature_names is not None and len(feature_names) > 0:
                topic_vector = topic_vectors[i]
                top_indices = np.argsort(topic_vector)[-10:][::-1]
                for idx in top_indices:
                    if idx < len(feature_names):
                        if isinstance(feature_names[idx], (int, np.integer)):
                            top_words.append(f"word_{idx}")
                        else:
                            top_words.append(str(feature_names[idx]))
                    else:
                        top_words.append(f"word_{idx}")
            
            analysis_data.append({
                'Topic': f'T{i}',
                'Topic_ID': i,
                'N_Docs': n_docs,
                'Doc_Proportion': n_docs / len(topic_distributions),
                'Avg_Entropy_Difficulty': avg_entropy,
                'Avg_MaxProb_Difficulty': avg_max_prob,
                'Avg_Composite_Difficulty': avg_composite,
                'Topic_Strength': topic_strength,
                'Topic_Weight': topic_weights_norm[i],
                'Top_Words': ', '.join(top_words) if top_words else 'N/A'
            })
    
    analysis_df = pd.DataFrame(analysis_data)
    
    # Create interactive dashboard with multiple plots
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            'Topic Difficulty Distribution',
            'Topic Weight vs Difficulty',
            'Difficulty Metrics by Topic',
            'Topic Document Distribution'
        ),
        specs=[
            [{'type': 'scatter'}, {'type': 'scatter'}],
            [{'type': 'bar'}, {'type': 'pie'}]
        ],
        vertical_spacing=0.12,
        horizontal_spacing=0.1
    )
    
    # Plot 1: Topic Difficulty Distribution (scatter plot)
    fig.add_trace(
        go.Scatter(
            x=analysis_df['Topic'],
            y=analysis_df['Avg_Composite_Difficulty'],
            mode='markers+text',
            text=analysis_df['Topic'],
            textposition='top center',
            marker=dict(
                size=analysis_df['Topic_Weight'] * 100,
                color=analysis_df['Avg_Composite_Difficulty'],
                colorscale='RdYlBu_r',
                showscale=True,
                colorbar=dict(
                    title='Difficulty',
                    x=0.45,
                    y=0.95,
                    len=0.3
                ),
                line=dict(width=2, color='DarkSlateGrey')
            ),
            customdata=analysis_df[['Top_Words', 'N_Docs', 'Topic_Weight']].values,
            hovertemplate=(
                '<b>Topic: %{x}</b><br>'
                'Difficulty: %{y:.3f}<br>'
                'Top Words: %{customdata[0]}<br>'
                'Documents: %{customdata[1]}<br>'
                'Weight: %{customdata[2]:.3f}<br>'
                '<extra></extra>'
            ),
            name='Topics'
        ),
        row=1, col=1
    )
    
    # Add difficulty thresholds
    if len(analysis_df) > 0:
        mean_diff = analysis_df['Avg_Composite_Difficulty'].mean()
        std_diff = analysis_df['Avg_Composite_Difficulty'].std()
        
        fig.add_hline(y=mean_diff, line_dash="dash", line_color="gray", 
                      annotation_text=f"Mean: {mean_diff:.3f}", 
                      annotation_position="top right", row=1, col=1)
        fig.add_hline(y=mean_diff + std_diff, line_dash="dot", line_color="red", 
                      annotation_text=f"+1 std", annotation_position="top right", row=1, col=1)
        fig.add_hline(y=mean_diff - std_diff, line_dash="dot", line_color="green", 
                      annotation_text=f"-1 std", annotation_position="top right", row=1, col=1)
    
    # Plot 2: Topic Weight vs Difficulty (bubble chart)
    if len(analysis_df) > 0:
        fig.add_trace(
            go.Scatter(
                x=analysis_df['Topic_Weight'],
                y=analysis_df['Avg_Composite_Difficulty'],
                mode='markers+text',
                text=analysis_df['Topic'],
                textposition='top center',
                marker=dict(
                    size=analysis_df['N_Docs'] / max(analysis_df['N_Docs']) * 50 + 10 if max(analysis_df['N_Docs']) > 0 else 20,
                    color=analysis_df['Avg_Composite_Difficulty'],
                    colorscale='Viridis',
                    showscale=False
                ),
                customdata=analysis_df[['Top_Words', 'N_Docs', 'Topic_ID']].values,
                hovertemplate=(
                    '<b>Topic: %{text}</b><br>'
                    'Weight: %{x:.3f}<br>'
                    'Difficulty: %{y:.3f}<br>'
                    'Top Words: %{customdata[0]}<br>'
                    'Documents: %{customdata[1]}<br>'
                    'ID: %{customdata[2]}<br>'
                    '<extra></extra>'
                ),
                name='Weight vs Difficulty'
            ),
            row=1, col=2
        )
        
        # Add trend line
        if len(analysis_df) > 1:
            z = np.polyfit(analysis_df['Topic_Weight'], analysis_df['Avg_Composite_Difficulty'], 1)
            p = np.poly1d(z)
            x_range = np.linspace(analysis_df['Topic_Weight'].min(), analysis_df['Topic_Weight'].max(), 100)
            fig.add_trace(
                go.Scatter(
                    x=x_range,
                    y=p(x_range),
                    mode='lines',
                    line=dict(color='red', dash='dash'),
                    name='Trend',
                    showlegend=False
                ),
                row=1, col=2
            )
    
    # Plot 3: Difficulty Metrics by Topic (grouped bar chart)
    if len(analysis_df) > 0:
        for i, metric in enumerate(['Avg_Entropy_Difficulty', 'Avg_MaxProb_Difficulty', 'Avg_Composite_Difficulty']):
            fig.add_trace(
                go.Bar(
                    x=analysis_df['Topic'],
                    y=analysis_df[metric],
                    name=metric.replace('Avg_', '').replace('_', ' '),
                    marker_color=px.colors.qualitative.Set2[i],
                    hovertemplate=(
                        '<b>Topic: %{x}</b><br>'
                        'Metric: %{fullData.name}<br>'
                        'Value: %{y:.3f}<br>'
                        '<extra></extra>'
                    ),
                    showlegend=True if i == 0 else False
                ),
                row=2, col=1
            )
    
    # Plot 4: Topic Document Distribution (pie chart)
    if len(analysis_df) > 0:
        fig.add_trace(
            go.Pie(
                labels=analysis_df['Topic'],
                values=analysis_df['N_Docs'],
                textinfo='label+percent',
                textposition='inside',
                hole=0.4,
                marker=dict(
                    colors=px.colors.qualitative.Plotly,
                    line=dict(color='white', width=2)
                ),
                hovertemplate=(
                    '<b>Topic: %{label}</b><br>'
                    'Documents: %{value}<br>'
                    'Percentage: %{percent}<br>'
                    '<extra></extra>'
                ),
                name='Document Distribution'
            ),
            row=2, col=2
        )
    
    # Update layout with corrected title property
    fig.update_layout(
        title=dict(
            text=f'Interactive Topic Analysis Dashboard - {dataset_name}',
            font=dict(size=20)
        ),
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.02,
            bgcolor='rgba(255, 255, 255, 0.8)',
            bordercolor='black',
            borderwidth=1
        ),
        hovermode='closest',
        height=1000,
        width=1400,
        template='plotly_white'
    )
    
    # Update axes labels
    fig.update_xaxes(title_text="Topic", row=1, col=1)
    fig.update_yaxes(title_text="Composite Difficulty", row=1, col=1)
    fig.update_xaxes(title_text="Topic Weight", row=1, col=2)
    fig.update_yaxes(title_text="Composite Difficulty", row=1, col=2)
    fig.update_xaxes(title_text="Topic", row=2, col=1)
    fig.update_yaxes(title_text="Difficulty Score", row=2, col=1)
    
    # Save dashboard
    dashboard_path = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.html"
    os.makedirs(os.path.dirname(dashboard_path), exist_ok=True)
    fig.write_html(dashboard_path)
    print(f"  Interactive dashboard saved to: {dashboard_path}")
    
    # Also save static version if kaleido is available
    try:
        static_dashboard_path = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.png"
        fig.write_image(static_dashboard_path, width=1400, height=1000)
        print(f"  Static dashboard saved to: {static_dashboard_path}")
    except Exception as e:
        print(f"  Note: Could not save static image. Install kaleido: pip install kaleido")
    
    return fig

def create_topic_evolution_timeline(dataset_name, results, feature_names=None, n_samples_per_topic=100):
    """
    Create interactive timeline showing topic evolution and difficulty
    
    Args:
        dataset_name: Name of the dataset
        results: Processing results
        feature_names: Vocabulary for text datasets
        n_samples_per_topic: Number of samples to show per topic
    """
    
    print(f"\nCreating topic evolution timeline for {dataset_name}...")
    
    # Extract components
    lda_model = results['lda_model']
    topic_distributions = results['topic_distributions']
    difficulty_scores = results['difficulty_scores']
    
    # Get topic components
    topic_vectors = lda_model.components_
    n_topics = topic_vectors.shape[0]
    
    # Create figure
    fig = go.Figure()
    
    # Get top words for each topic for hover text
    topic_words = []
    for i in range(n_topics):
        if feature_names is not None and len(feature_names) > 0:
            topic_vector = topic_vectors[i]
            top_indices = np.argsort(topic_vector)[-5:][::-1]
            words = []
            for idx in top_indices:
                if idx < len(feature_names):
                    if isinstance(feature_names[idx], (int, np.integer)):
                        words.append(f"word_{idx}")
                    else:
                        words.append(str(feature_names[idx]))
                else:
                    words.append(f"word_{idx}")
            topic_words.append(words)
        else:
            topic_words.append([])
    
    # For each topic, create a timeline trace
    colors = px.colors.qualitative.Plotly
    
    for i in range(n_topics):
        # Get indices of documents where this topic is dominant
        topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
        
        if np.sum(topic_doc_mask) > 0:
            # Get subset of documents for this topic
            topic_doc_indices = np.where(topic_doc_mask)[0]
            
            if len(topic_doc_indices) > n_samples_per_topic:
                # Sample documents
                sampled_indices = np.random.choice(topic_doc_indices, n_samples_per_topic, replace=False)
            else:
                sampled_indices = topic_doc_indices
            
            # Get topic strengths and difficulties for these documents
            topic_strengths = topic_distributions[sampled_indices, i]
            doc_difficulties = difficulty_scores['composite'][sampled_indices]
            
            # Sort by topic strength
            sort_idx = np.argsort(topic_strengths)
            topic_strengths = topic_strengths[sort_idx]
            doc_difficulties = doc_difficulties[sort_idx]
            
            # Create hover text
            hover_texts = []
            for j, idx in enumerate(sampled_indices[sort_idx]):
                hover_text = (
                    f"<b>Topic {i}</b><br>"
                    f"Document {idx}<br>"
                    f"Topic Strength: {topic_strengths[j]:.3f}<br>"
                    f"Difficulty: {doc_difficulties[j]:.3f}"
                )
                if topic_words[i]:
                    hover_text += f"<br>Top words: {', '.join(topic_words[i])}"
                hover_texts.append(hover_text)
            
            # Add trace for this topic
            fig.add_trace(go.Scatter(
                x=topic_strengths,
                y=doc_difficulties,
                mode='markers',
                name=f'Topic {i}',
                text=hover_texts,
                hoverinfo='text',
                marker=dict(
                    size=8,
                    opacity=0.6,
                    color=colors[i % len(colors)],
                    line=dict(width=1, color='white')
                ),
                showlegend=True
            ))
    
    # Update layout with corrected title property
    fig.update_layout(
        title=dict(
            text=f'Topic Evolution Timeline - {dataset_name}',
            font=dict(size=16)
        ),
        xaxis_title='Topic Strength (How dominant the topic is)',
        yaxis_title='Document Difficulty (Composite)',
        legend_title='Topics',
        hovermode='closest',
        height=800,
        width=1200,
        template='plotly_white'
    )
    
    # Add trend line for each topic
    for i, trace in enumerate(fig.data):
        topic_strengths = trace.x
        difficulties = trace.y
        
        if len(topic_strengths) > 1:
            # Calculate trend line
            z = np.polyfit(topic_strengths, difficulties, 1)
            p = np.poly1d(z)
            x_range = np.linspace(min(topic_strengths), max(topic_strengths), 100)
            
            fig.add_trace(go.Scatter(
                x=x_range,
                y=p(x_range),
                mode='lines',
                line=dict(color=trace.marker.color, dash='dash', width=1),
                showlegend=False,
                hoverinfo='skip',
                name=f'Trend Topic {i}'
            ))
    
    # Save timeline
    timeline_path = f"results/tmcl/topic_models/{dataset_name}_topic_timeline.html"
    os.makedirs(os.path.dirname(timeline_path), exist_ok=True)
    fig.write_html(timeline_path)
    print(f"  Topic timeline saved to: {timeline_path}")
    
    return fig

def create_interactive_topic_comparison(all_analysis_results):
    """
    Create interactive comparison of all datasets
    
    Args:
        all_analysis_results: Dictionary containing analysis results for all datasets
    """
    
    print(f"\nCreating interactive comparison across all datasets...")
    
    # Collect comparison data
    comparison_data = []
    
    for dataset_name, analysis in all_analysis_results.items():
        if analysis is not None:
            topic_analysis = analysis.get('topic_analysis', pd.DataFrame())
            difficulty_stats = analysis.get('difficulty_stats', {})
            network_stats = analysis.get('network_stats', {})
            
            if not topic_analysis.empty:
                # Get difficulty range
                if 'Avg_Composite_Difficulty' in topic_analysis.columns:
                    difficulty_range = topic_analysis['Avg_Composite_Difficulty'].max() - topic_analysis['Avg_Composite_Difficulty'].min()
                else:
                    difficulty_range = 0
                
                comparison_data.append({
                    'Dataset': dataset_name,
                    'Type': 'Text' if dataset_name in ['ag_news', 'imdb'] else 'Image',
                    'N_Topics': analysis.get('n_topics', 0),
                    'N_Documents': analysis.get('n_documents', 0),
                    'Perplexity': analysis.get('perplexity', 0),
                    'Avg_Difficulty': difficulty_stats.get('overall_mean_composite', 0),
                    'Std_Difficulty': difficulty_stats.get('overall_std_composite', 0),
                    'Network_Density': network_stats.get('density', 0),
                    'Difficulty_Range': difficulty_range
                })
    
    if not comparison_data:
        print("  No comparison data available")
        return None
    
    comparison_df = pd.DataFrame(comparison_data)
    
    # Create interactive comparison dashboard
    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=(
            'Average Difficulty by Dataset',
            'Number of Topics vs Documents',
            'Network Density vs Difficulty',
            'Perplexity Comparison',
            'Difficulty Range Comparison',
            'Dataset Clustering'
        ),
        specs=[
            [{'type': 'bar'}, {'type': 'scatter'}, {'type': 'scatter'}],
            [{'type': 'bar'}, {'type': 'bar'}, {'type': 'scatter'}]
        ],
        vertical_spacing=0.15,
        horizontal_spacing=0.1
    )
    
    # Plot 1: Average difficulty by dataset
    colors = px.colors.qualitative.Set2
    for i, (_, row) in enumerate(comparison_df.iterrows()):
        fig.add_trace(
            go.Bar(
                x=[row['Dataset']],
                y=[row['Avg_Difficulty']],
                name=row['Dataset'],
                marker_color=colors[i % len(colors)],
                error_y=dict(
                    type='data',
                    array=[row['Std_Difficulty']],
                    visible=True
                ),
                customdata=[[
                    row['Type'],
                    row['N_Topics'],
                    row['N_Documents'],
                    row['Perplexity']
                ]],
                hovertemplate=(
                    '<b>%{x}</b><br>'
                    'Type: %{customdata[0]}<br>'
                    'Difficulty: %{y:.3f} ± %{error_y.array[0]:.3f}<br>'
                    'Topics: %{customdata[1]}<br>'
                    'Documents: %{customdata[2]:,}<br>'
                    'Perplexity: %{customdata[3]:.2f}<br>'
                    '<extra></extra>'
                ),
                showlegend=False
            ),
            row=1, col=1
        )
    
    # Plot 2: Number of topics vs documents
    if len(comparison_df) > 0:
        fig.add_trace(
            go.Scatter(
                x=comparison_df['N_Topics'],
                y=comparison_df['N_Documents'] / 1000,  # Convert to thousands
                mode='markers+text',
                text=comparison_df['Dataset'],
                textposition='top center',
                marker=dict(
                    size=comparison_df['Avg_Difficulty'] * 50,
                    color=comparison_df['Avg_Difficulty'],
                    colorscale='RdYlBu_r',
                    showscale=True,
                    colorbar=dict(title='Avg Difficulty', x=0.47, y=0.95, len=0.25),
                    line=dict(width=2, color='black')
                ),
                customdata=comparison_df[['Type', 'Perplexity', 'Network_Density']].values,
                hovertemplate=(
                    '<b>%{text}</b><br>'
                    'Topics: %{x}<br>'
                    'Documents (thousands): %{y:.1f}<br>'
                    'Type: %{customdata[0]}<br>'
                    'Perplexity: %{customdata[1]:.2f}<br>'
                    'Network Density: %{customdata[2]:.3f}<br>'
                    '<extra></extra>'
                ),
                name='Topics vs Documents'
            ),
            row=1, col=2
        )
    
    # Plot 3: Network density vs difficulty
    if len(comparison_df) > 0:
        fig.add_trace(
            go.Scatter(
                x=comparison_df['Network_Density'],
                y=comparison_df['Avg_Difficulty'],
                mode='markers+text',
                text=comparison_df['Dataset'],
                textposition='top center',
                marker=dict(
                    size=comparison_df['N_Topics'] * 10,
                    color=comparison_df['Type'].map({'Text': 'blue', 'Image': 'red'}),
                    symbol=comparison_df['Type'].map({'Text': 'circle', 'Image': 'square'}),
                    line=dict(width=2, color='black')
                ),
                customdata=comparison_df[['N_Documents', 'Perplexity', 'Type']].values,
                hovertemplate=(
                    '<b>%{text}</b><br>'
                    'Network Density: %{x:.3f}<br>'
                    'Avg Difficulty: %{y:.3f}<br>'
                    'Type: %{customdata[2]}<br>'
                    'Documents: %{customdata[0]:,}<br>'
                    'Perplexity: %{customdata[1]:.2f}<br>'
                    '<extra></extra>'
                ),
                name='Network vs Difficulty'
            ),
            row=1, col=3
        )
    
    # Plot 4: Perplexity comparison
    if len(comparison_df) > 0:
        fig.add_trace(
            go.Bar(
                x=comparison_df['Dataset'],
                y=comparison_df['Perplexity'],
                marker_color=comparison_df['Type'].map({'Text': 'lightblue', 'Image': 'lightcoral'}),
                customdata=comparison_df[['Type', 'N_Topics', 'Avg_Difficulty']].values,
                hovertemplate=(
                    '<b>%{x}</b><br>'
                    'Perplexity: %{y:.2f}<br>'
                    'Type: %{customdata[0]}<br>'
                    'Topics: %{customdata[1]}<br>'
                    'Avg Difficulty: %{customdata[2]:.3f}<br>'
                    '<extra></extra>'
                ),
                name='Perplexity',
                showlegend=False
            ),
            row=2, col=1
        )
    
    # Plot 5: Difficulty range comparison
    if len(comparison_df) > 0:
        # Normalize difficulty range for coloring
        if comparison_df['Difficulty_Range'].max() > comparison_df['Difficulty_Range'].min():
            norm_range = (comparison_df['Difficulty_Range'] - comparison_df['Difficulty_Range'].min()) / \
                        (comparison_df['Difficulty_Range'].max() - comparison_df['Difficulty_Range'].min())
        else:
            norm_range = pd.Series([0.5] * len(comparison_df))
        
        # Use Plotly colorscale
        colorscale = px.colors.sequential.Viridis
        colors = [colorscale[int(val * (len(colorscale)-1))] for val in norm_range]
        
        fig.add_trace(
            go.Bar(
                x=comparison_df['Dataset'],
                y=comparison_df['Difficulty_Range'],
                marker_color=colors,
                customdata=comparison_df[['Type', 'Avg_Difficulty', 'Std_Difficulty']].values,
                hovertemplate=(
                    '<b>%{x}</b><br>'
                    'Difficulty Range: %{y:.3f}<br>'
                    'Type: %{customdata[0]}<br>'
                    'Avg Difficulty: %{customdata[1]:.3f}<br>'
                    'Std Difficulty: %{customdata[2]:.3f}<br>'
                    '<extra></extra>'
                ),
                name='Difficulty Range',
                showlegend=False
            ),
            row=2, col=2
        )
    
    # Plot 6: Dataset clustering (simplified)
    if len(comparison_df) > 1:
        try:
            from sklearn.preprocessing import StandardScaler
            from sklearn.decomposition import PCA
            
            # Prepare data for PCA
            cluster_data = comparison_df[['Avg_Difficulty', 'Std_Difficulty', 
                                         'Network_Density', 'Difficulty_Range', 
                                         'Perplexity']].values
            cluster_data = StandardScaler().fit_transform(cluster_data)
            
            # Apply PCA
            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(cluster_data)
            
            fig.add_trace(
                go.Scatter(
                    x=pca_result[:, 0],
                    y=pca_result[:, 1],
                    mode='markers+text',
                    text=comparison_df['Dataset'],
                    textposition='top center',
                    marker=dict(
                        size=20,
                        color=comparison_df['Type'].map({'Text': 'blue', 'Image': 'red'}),
                        symbol=comparison_df['Type'].map({'Text': 'circle', 'Image': 'square'}),
                        line=dict(width=2, color='black')
                    ),
                    customdata=comparison_df[['Type', 'N_Topics', 'N_Documents']].values,
                    hovertemplate=(
                        '<b>%{text}</b><br>'
                        'PCA 1: %{x:.2f}<br>'
                        'PCA 2: %{y:.2f}<br>'
                        'Type: %{customdata[0]}<br>'
                        'Topics: %{customdata[1]}<br>'
                        'Documents: %{customdata[2]:,}<br>'
                        '<extra></extra>'
                    ),
                    name='Dataset Clusters'
                ),
                row=2, col=3
            )
            
            # Add variance explained to axis labels
            fig.update_xaxes(
                title_text=f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)',
                row=2, col=3
            )
            fig.update_yaxes(
                title_text=f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)',
                row=2, col=3
            )
        except Exception as e:
            print(f"  Could not create clustering plot: {e}")
            fig.add_annotation(
                text="Clustering not available",
                xref="x domain", yref="y domain",
                x=0.5, y=0.5,
                showarrow=False,
                font=dict(size=14),
                row=2, col=3
            )
    else:
        fig.add_annotation(
            text="Need at least 2 datasets for clustering",
            xref="x domain", yref="y domain",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=14),
            row=2, col=3
        )
    
    # Update layout with corrected title property
    fig.update_layout(
        title=dict(
            text='Interactive Dataset Comparison Dashboard',
            font=dict(size=20)
        ),
        showlegend=False,
        height=1000,
        width=1600,
        template='plotly_white'
    )
    
    # Update axes labels
    fig.update_xaxes(title_text="Dataset", row=1, col=1)
    fig.update_yaxes(title_text="Average Difficulty", row=1, col=1)
    fig.update_xaxes(title_text="Number of Topics", row=1, col=2)
    fig.update_yaxes(title_text="Documents (thousands)", row=1, col=2)
    fig.update_xaxes(title_text="Network Density", row=1, col=3)
    fig.update_yaxes(title_text="Average Difficulty", row=1, col=3)
    fig.update_xaxes(title_text="Dataset", row=2, col=1)
    fig.update_yaxes(title_text="Perplexity", row=2, col=1)
    fig.update_xaxes(title_text="Dataset", row=2, col=2)
    fig.update_yaxes(title_text="Difficulty Range", row=2, col=2)
    
    # Save comparison results for the dashboard
    comparison_path = "results/tmcl/topic_models/interactive_dataset_comparison.html"
    os.makedirs(os.path.dirname(comparison_path), exist_ok=True)
    fig.write_html(comparison_path)
    print(f"  Interactive comparison dashboard saved to: {comparison_path}")
    
    return fig

def create_all_interactive_visualizations():
    """
    Create all interactive visualizations for the analyzed datasets
    """
    
    print("\n" + "="*80)
    print("CREATING INTERACTIVE TOPIC VISUALIZATIONS")
    print("="*80)
    
    # Try to import plotly for image export
    try:
        import plotly.io as pio
        pio.kaleido.scope.default_format = "png"
    except:
        pass
    
    # Load all results
    datasets = ['ag_news', 'imdb', 'cifar10', 'cifar100', 'mnist', 'fashion_mnist']
    all_analysis_results = {}
    
    # Create visualizations for each dataset
    for dataset_name in datasets:
        print(f"\nProcessing {dataset_name}...")
        
        # Load results
        results_path = f"results/tmcl/topic_models/{dataset_name}_tmcl_results.pkl"
        if os.path.exists(results_path):
            try:
                with open(results_path, 'rb') as f:
                    results = pickle.load(f)
                
                print(f"  Successfully loaded results from {results_path}")
                
                # Get feature names for the text datasets
                feature_names = None
                if dataset_name in ['ag_news', 'imdb']:
                    if 'vocabulary' in results:
                        feature_names = results['vocabulary']
                        print(f"  Using saved vocabulary of size {len(feature_names)}")
                    elif 'vectorizer' in results:
                        # Try to get vocabulary from the vectorizer
                        vectorizer = results['vectorizer']
                        if hasattr(vectorizer, 'get_feature_names_out'):
                            feature_names = vectorizer.get_feature_names_out()
                            print(f"  Using vectorizer vocabulary of size {len(feature_names)}")
                        elif hasattr(vectorizer, 'get_feature_names'):
                            feature_names = vectorizer.get_feature_names()
                            print(f"  Using vectorizer vocabulary of size {len(feature_names)}")
                    else:
                        # Try to load from a vocabulary file
                        vocab_path = f"results/tmcl/topic_models/{dataset_name}_vocabulary.pkl"
                        if os.path.exists(vocab_path):
                            with open(vocab_path, 'rb') as f:
                                feature_names = pickle.load(f)
                            print(f"  Loaded vocabulary from file of size {len(feature_names)}")
                        else:
                            # Create a sample vocabulary
                            if 'feature_matrix' in results:
                                n_features = results['feature_matrix'].shape[1]
                            else:
                                n_features = 1000
                            feature_names = [f'feature_{i}' for i in range(n_features)]
                            print(f"  WARNING: Using placeholder feature names (n={n_features})")
                
                # 1. Create interactive topic network
                try:
                    network_fig, network_graph = create_interactive_topic_network(
                        dataset_name, results, feature_names
                    )
                    print(f"  ✓ Created interactive topic network")
                except Exception as e:
                    print(f"  ✗ Error creating network: {e}")
                
                # 2. Create interactive dashboard
                try:
                    dashboard_fig = create_interactive_topic_dashboard(
                        dataset_name, results, feature_names
                    )
                    print(f"  ✓ Created interactive dashboard")
                except Exception as e:
                    print(f"  ✗ Error creating dashboard: {e}")
                
                # 3. Create topic evolution timeline
                try:
                    timeline_fig = create_topic_evolution_timeline(
                        dataset_name, results, feature_names
                    )
                    print(f"  ✓ Created topic evolution timeline")
                except Exception as e:
                    print(f"  ✗ Error creating timeline: {e}")
                
                enhanced_path = f"results/tmcl/topic_models/{dataset_name}_enhanced_analysis.pkl"
                if os.path.exists(enhanced_path):
                    try:
                        with open(enhanced_path, 'rb') as f:
                            all_analysis_results[dataset_name] = pickle.load(f)
                        print(f"  ✓ Loaded enhanced analysis results")
                    except:
                        print(f"  ✗ Could not load enhanced analysis results")
                else:
                    print(f"  Note: Enhanced analysis results not found at {enhanced_path}")
                    
            except Exception as e:
                print(f"  ✗ Error loading results for {dataset_name}: {e}")
        else:
            print(f"  ✗ Results file not found for {dataset_name}")
    
    # 4. Create interactive comparison across all the datasets 
    if all_analysis_results:
        try:
            comparison_fig = create_interactive_topic_comparison(all_analysis_results)
            if comparison_fig:
                print(f"  ✓ Created interactive comparison across datasets")
        except Exception as e:
            print(f"  ✗ Error creating comparison: {e}")
    
    print("\n" + "="*80)
    print("INTERACTIVE VISUALIZATIONS COMPLETE!")
    print("="*80)
    
    # Print summary of the created files
    print("\nCreated interactive visualizations:")
    print("-" * 40)
    for dataset_name in datasets:
        print(f"\n{dataset_name}:")
        network_file = f"results/tmcl/topic_models/{dataset_name}_interactive_network.html"
        if os.path.exists(network_file):
            print(f"  ✓ Network: {network_file}")
        
        dashboard_file = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.html"
        if os.path.exists(dashboard_file):
            print(f"  ✓ Dashboard: {dashboard_file}")
        
        timeline_file = f"results/tmcl/topic_models/{dataset_name}_topic_timeline.html"
        if os.path.exists(timeline_file):
            print(f"  ✓ Timeline: {timeline_file}")
    
    comparison_file = "results/tmcl/topic_models/interactive_dataset_comparison.html"
    if os.path.exists(comparison_file):
        print(f"\nCross-dataset comparison: {comparison_file}")
    
    print("\n" + "="*80)
    print("To view interactive visualizations:")
    print("1. Open the HTML files in a web browser")
    print("2. Hover over nodes/points to see details")
    print("3. Use zoom and pan to explore the visualizations")
    print("="*80)
    
    # Additional instructions for text datasets
    print("\nFor text datasets (ag_news, imdb):")
    print("• Topic labels show T0, T1, etc. with top words")
    print("• Hover over nodes to see all top words for each topic")
    print("• Word indices are shown if actual words are not available")
    print("="*80)

# Execute the interactive visualizations
if __name__ == "__main__":
    create_all_interactive_visualizations()


CREATING INTERACTIVE TOPIC VISUALIZATIONS

Processing ag_news...
  Successfully loaded results from results/tmcl/topic_models/ag_news_tmcl_results.pkl
  Using vectorizer vocabulary of size 2000

Creating interactive topic network for ag_news...
  Interactive network saved to: results/tmcl/topic_models/ag_news_interactive_network.html




Use of plotly.io.kaleido.scope.default_format is deprecated and support will be removed after September 2025.
Please use plotly.io.defaults.default_format instead.




  Static version saved to: results/tmcl/topic_models/ag_news_interactive_network_static.png
  ✓ Created interactive topic network

Creating interactive dashboard for ag_news...
  Interactive dashboard saved to: results/tmcl/topic_models/ag_news_interactive_dashboard.html
  Static dashboard saved to: results/tmcl/topic_models/ag_news_interactive_dashboard.png
  ✓ Created interactive dashboard

Creating topic evolution timeline for ag_news...
  Topic timeline saved to: results/tmcl/topic_models/ag_news_topic_timeline.html
  ✓ Created topic evolution timeline
  Note: Enhanced analysis results not found at results/tmcl/topic_models/ag_news_enhanced_analysis.pkl

Processing imdb...
  Successfully loaded results from results/tmcl/topic_models/imdb_tmcl_results.pkl
  Using vectorizer vocabulary of size 2000

Creating interactive topic network for imdb...
  Interactive network saved to: results/tmcl/topic_models/imdb_interactive_network.html
  Static version saved to: results/tmcl/topic_models


### 8. Summary and Next Steps
 
#### What We've Accomplished
 
1. **Data Loading**: Successfully loaded all 6 datasets (4 vision, 2 NLP)
2. **Feature Extraction**:
    - Text: TF-IDF vectorization with n-grams and stopword removal
    - Images: Deep feature extraction using pre-trained ResNet-18
3. **Topic Modeling**: Applied LDA to learn latent topic structures
4. **Difficulty Scoring**: Computed multiple difficulty metrics based on topic distributions
5. **Validation**: Performed preliminary validation of difficulty scores
 
#### Mathematical Transformations Applied
 
1. **Text Processing**:

$$
  \mathbf{x}_i \rightarrow \text{TF-IDF}(\mathbf{x}_i) \rightarrow \text{LDA}(\text{TF-IDF}(\mathbf{x}_i)) \rightarrow P(t|\mathbf{x}_i)
$$

2. **Image Processing**:

$$
  \mathbf{x}_i \rightarrow \text{ResNet-18}(\mathbf{x}_i) \rightarrow \text{Normalize}(\mathbf{f}_i) \rightarrow \text{LDA}(\mathbf{f}_i) \rightarrow P(t|\mathbf{x}_i)
$$

3. **Difficulty Calculation**:

$$
  D_{\text{entropy}}(\mathbf{x}_i) = -\sum_{t=1}^T P(t|\mathbf{x}_i) \log P(t|\mathbf{x}_i)
$$

#### Convergence Toward TMCL Experiments

The results generated in this notebook provide the foundation for the TMCL experiments described in the research plan:
 
  1. **Curriculum Construction**: The difficulty scores can now be used to sort samples and create easy-to-hard curricula
  2. **Training Comparison**: These topic-modeled difficulties can be compared against heuristic baselines
  3. **Ablation Studies**: Different difficulty metrics (entropy, max_prob, composite) can be evaluated
  4. **Cross-domain Analysis**: Results across vision and NLP domains can be compared
 
#### Next Steps for Complete TMCL Implementation
 
  1. **Neural Network Training**: Implement training loops with curriculum scheduling
  2. **Schedule Functions**: Implement linear, root, and exponential curriculum schedules
  3. **Baseline Comparisons**: Implement heuristic CL and self-paced learning baselines
  4. **Comprehensive Evaluation**: Run full experiments measuring convergence speed, generalization gap, and training stability
  5. **Hyperparameter Analysis**: Study the effect of number of topics, schedule parameters, etc.
 
The topic modeling results saved in the `results/tmcl/topic_models/` directory contain all the necessary information to proceed with these next steps. Each pickle file contains:
  - Topic distributions for every sample
  - Multiple difficulty scores
  - Feature matrices and models for reproducibility
  - Label information for analysis
 
This completes the first major phase of the TMCL research pipeline - creating the unsupervised, data-driven difficulty metric that will guide the curriculum learning process.

#### Unused Code

In [12]:
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# import networkx as nx
# from sklearn.metrics.pairwise import cosine_similarity
# from sklearn.manifold import TSNE
# import warnings
# import os
# import pickle
# from scipy import stats
# warnings.filterwarnings('ignore')

# # For better visualization
# plt.style.use('seaborn-v0_8-darkgrid')
# sns.set_palette("husl")

# def enhanced_topic_analysis(dataset_name, results, feature_names=None, original_texts=None):
#     """
#     Enhanced topic analysis with comprehensive visualizations
    
#     Args:
#         dataset_name: Name of the dataset
#         results: Processing results containing LDA model and topic distributions
#         feature_names: Vocabulary for text datasets
#         original_texts: Original text samples (for text datasets)
#     """
    
#     print(f"\n{'='*60}")
#     print(f"ENHANCED TOPIC ANALYSIS: {dataset_name.upper()}")
#     print(f"{'='*60}")
    
#     # Extract components from results
#     lda_model = results['lda_model']
#     topic_distributions = results['topic_distributions']
#     difficulty_scores = results['difficulty_scores']
#     labels = results.get('labels', None)
    
#     # 1. TOPIC QUALITY ANALYSIS
#     print("\n1. TOPIC QUALITY ANALYSIS")
#     print("-" * 40)
    
#     # Try to get perplexity from the model or results
#     perplexity = None
#     if 'perplexity' in results:
#         perplexity = results['perplexity']
#         print(f"Perplexity: {perplexity:.4f}")
#     elif hasattr(lda_model, 'perplexity'):
#         # Try to calculate perplexity if possible
#         try:
#             if hasattr(lda_model, 'perplexity') and callable(lda_model.perplexity):
#                 if 'feature_matrix' in results:
#                     dtm = results['feature_matrix']
#                     perplexity = lda_model.perplexity(dtm)
#                     print(f"Perplexity: {perplexity:.4f}")
#         except:
#             pass
    
#     # Topic significance (based on topic distribution)
#     topic_weights = np.sum(topic_distributions, axis=0)
#     topic_weights_norm = topic_weights / np.sum(topic_weights)
#     print(f"\nTopic Weight Distribution:")
#     print(f"  Mean weight: {np.mean(topic_weights_norm):.4f}")
#     print(f"  Std weight: {np.std(topic_weights_norm):.4f}")
#     print(f"  Min weight: {np.min(topic_weights_norm):.4f}")
#     print(f"  Max weight: {np.max(topic_weights_norm):.4f}")
    
#     # 2. TOPIC VISUALIZATION WITH pyLDAvis (for text datasets)
#     if feature_names is not None:
#         print("\n2. GENERATING INTERACTIVE TOPIC VISUALIZATION...")
#         try:
#             # Try to import pyLDAvis
#             import pyLDAvis
#             import pyLDAvis.sklearn as ldavis
            
#             # Prepare data for pyLDAvis
#             dtm = results.get('feature_matrix')
#             if dtm is not None and hasattr(lda_model, 'components_'):
#                 # Create visualization data using sklearn's prepare function
#                 vis_data = ldavis.prepare(
#                     lda_model, 
#                     dtm, 
#                     feature_names,
#                     mds='tsne'
#                 )
                
#                 # Save as HTML for interactive visualization
#                 output_path = f"results/tmcl/topic_models/{dataset_name}_lda_vis.html"
#                 pyLDAvis.save_html(vis_data, output_path)
#                 print(f"  Interactive visualization saved to: {output_path}")
                
#                 # Extract topic coordinates for static visualization
#                 topic_coordinates = vis_data.topic_coordinates
                
#                 # Create static visualization of topics
#                 plt.figure(figsize=(12, 10))
#                 plt.scatter(topic_coordinates['x'], topic_coordinates['y'], 
#                           s=topic_weights_norm*5000, alpha=0.7, 
#                           c=range(len(topic_coordinates)), cmap='tab20')
                
#                 # Add topic labels
#                 for i, (x, y) in enumerate(zip(topic_coordinates['x'], topic_coordinates['y'])):
#                     plt.text(x, y, f'T{i}', fontsize=12, ha='center', va='center',
#                            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
                
#                 plt.title(f'Topic Visualization - {dataset_name}', fontsize=16)
#                 plt.xlabel('Dimension 1', fontsize=12)
#                 plt.ylabel('Dimension 2', fontsize=12)
#                 plt.colorbar(label='Topic Index')
#                 plt.grid(True, alpha=0.3)
                
#                 output_path = f"results/tmcl/topic_models/{dataset_name}_topic_scatter.png"
#                 plt.savefig(output_path, bbox_inches='tight', dpi=300)
#                 plt.close()
#                 print(f"  Topic scatter plot saved to: {output_path}")
                
#         except ImportError:
#             print("  pyLDAvis not installed. Skipping interactive visualization.")
#             print("  Install with: pip install pyLDAvis")
#         except Exception as e:
#             print(f"  Could not generate pyLDAvis visualization: {type(e).__name__}: {str(e)}")
    
#     # 3. TOPIC RELATIONSHIP NETWORK
#     print("\n3. ANALYZING TOPIC RELATIONSHIPS...")
    
#     # Calculate topic-topic similarity matrix
#     topic_vectors = lda_model.components_  # Shape: n_topics x n_features
#     topic_similarity = cosine_similarity(topic_vectors)
    
#     # Create topic relationship graph
#     G = nx.Graph()
    
#     # Add nodes with topic information
#     for i in range(topic_vectors.shape[0]):
#         # Calculate average difficulty for documents where this topic is dominant
#         topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
#         if np.sum(topic_doc_mask) > 0:
#             avg_difficulty = np.mean(difficulty_scores['entropy'][topic_doc_mask])
#         else:
#             avg_difficulty = 0.5
        
#         # Node size based on topic weight
#         node_size = topic_weights_norm[i] * 1000 + 300
        
#         G.add_node(f'T{i}', 
#                    size=node_size,
#                    avg_difficulty=avg_difficulty,
#                    weight=topic_weights_norm[i])
    
#     # Add edges based on similarity (only if we have more than 1 topic)
#     if topic_vectors.shape[0] > 1:
#         # Calculate threshold for edges (only consider similarities less than 1 to exclude self-similarity)
#         similarity_values = topic_similarity[topic_similarity < 0.999]  # Exclude self-similarity
#         if len(similarity_values) > 0:
#             threshold = np.percentile(similarity_values, 75)
            
#             for i in range(topic_vectors.shape[0]):
#                 for j in range(i+1, topic_vectors.shape[0]):
#                     if topic_similarity[i, j] > threshold:
#                         G.add_edge(f'T{i}', f'T{j}', 
#                                   weight=topic_similarity[i, j],
#                                   width=topic_similarity[i, j] * 3)
#         else:
#             threshold = 0
#     else:
#         threshold = 0
    
#     # Visualize the network
#     if G.number_of_nodes() > 0:
#         fig, ax = plt.subplots(figsize=(14, 12))
        
#         # Get node positions using spring layout
#         pos = nx.spring_layout(G, k=2, iterations=50)
        
#         # Extract node attributes
#         node_sizes = [G.nodes[n]['size'] for n in G.nodes()]
#         avg_difficulties = [G.nodes[n]['avg_difficulty'] for n in G.nodes()]
        
#         # Normalize difficulties for coloring
#         if len(set(avg_difficulties)) > 1:
#             normalized_difficulties = (avg_difficulties - np.min(avg_difficulties)) / (np.max(avg_difficulties) - np.min(avg_difficulties))
#         else:
#             normalized_difficulties = [0.5] * len(avg_difficulties)
        
#         # Create colormap for difficulties
#         cmap = plt.cm.RdYlBu
#         node_colors = [cmap(d) for d in normalized_difficulties]
        
#         # Extract edge attributes
#         if G.number_of_edges() > 0:
#             edge_weights = [G[u][v]['width'] for u, v in G.edges()]
#         else:
#             edge_weights = []
        
#         # Draw network
#         nx.draw_networkx_nodes(G, pos, 
#                               node_size=node_sizes,
#                               node_color=node_colors,
#                               alpha=0.8,
#                               edgecolors='black',
#                               linewidths=1,
#                               ax=ax)
        
#         if G.number_of_edges() > 0:
#             nx.draw_networkx_edges(G, pos, 
#                                   width=edge_weights,
#                                   alpha=0.3,
#                                   edge_color='gray',
#                                   ax=ax)
        
#         # Add labels
#         nx.draw_networkx_labels(G, pos, 
#                                font_size=10,
#                                font_weight='bold',
#                                ax=ax)
        
#         # Add colorbar for difficulty
#         sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=min(avg_difficulties), vmax=max(avg_difficulties)))
#         sm.set_array([])
#         cbar = plt.colorbar(sm, ax=ax, shrink=0.8)
#         cbar.set_label('Average Difficulty (Entropy)', fontsize=12)
        
#         ax.set_title(f'Topic Relationship Network - {dataset_name}\n'
#                      f'Node size ∝ Topic weight, Color ∝ Average difficulty',
#                      fontsize=16, pad=20)
        
#         # Add statistics to plot
#         ax.text(0.02, 0.02, 
#                 f"Nodes: {G.number_of_nodes()}\n"
#                 f"Edges: {G.number_of_edges()}\n"
#                 f"Edge threshold: {threshold:.3f}",
#                 fontsize=10,
#                 transform=ax.transAxes,
#                 bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
        
#         output_path = f"results/tmcl/topic_models/{dataset_name}_topic_network.png"
#         plt.savefig(output_path, bbox_inches='tight', dpi=300)
#         plt.close()
        
#         print(f"  Topic network saved to: {output_path}")
#         print(f"  Network stats: {G.number_of_nodes()} topics, {G.number_of_edges()} connections")
#     else:
#         print("  Not enough topics to create network visualization")
    
#     # 4. TOPIC-DIFFICULTY ANALYSIS
#     print("\n4. TOPIC-DIFFICULTY ANALYSIS")
    
#     # Create a DataFrame for analysis
#     analysis_data = []
#     for i in range(topic_vectors.shape[0]):
#         topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
#         n_docs = np.sum(topic_doc_mask)
        
#         if n_docs > 0:
#             avg_entropy = np.mean(difficulty_scores['entropy'][topic_doc_mask])
#             avg_max_prob = np.mean(difficulty_scores['max_prob'][topic_doc_mask])
#             avg_composite = np.mean(difficulty_scores['composite'][topic_doc_mask])
            
#             # Calculate topic purity (how dominant is the main topic)
#             topic_strength = np.mean(topic_distributions[topic_doc_mask, i])
            
#             # Get top words for this topic (if available)
#             top_words = []
#             if feature_names is not None and hasattr(lda_model, 'components_'):
#                 # Get indices of top words for this topic
#                 topic_vector = lda_model.components_[i]
#                 top_indices = np.argsort(topic_vector)[-5:][::-1]  # Top 5 words
#                 top_words = [feature_names[idx] for idx in top_indices]
            
#             analysis_data.append({
#                 'Topic': f'T{i}',
#                 'N_Docs': n_docs,
#                 'Doc_Proportion': n_docs / len(topic_distributions),
#                 'Avg_Entropy_Difficulty': avg_entropy,
#                 'Avg_MaxProb_Difficulty': avg_max_prob,
#                 'Avg_Composite_Difficulty': avg_composite,
#                 'Topic_Strength': topic_strength,
#                 'Topic_Weight': topic_weights_norm[i],
#                 'Top_Words': ', '.join(top_words) if top_words else ''
#             })
    
#     if analysis_data:  # Check if we have data
#         analysis_df = pd.DataFrame(analysis_data)
        
#         # Create visualization of topic difficulty distribution
#         fig, axes = plt.subplots(2, 3, figsize=(18, 12))
#         fig.suptitle(f'Topic Analysis - {dataset_name}', fontsize=20, y=1.02)
        
#         # Plot 1: Topic weights
#         axes[0, 0].bar(analysis_df['Topic'], analysis_df['Topic_Weight'])
#         axes[0, 0].set_title('Topic Weights', fontsize=14)
#         axes[0, 0].set_xlabel('Topic')
#         axes[0, 0].set_ylabel('Weight')
#         axes[0, 0].tick_params(axis='x', rotation=45)
        
#         # Plot 2: Number of documents per topic
#         axes[0, 1].bar(analysis_df['Topic'], analysis_df['N_Docs'])
#         axes[0, 1].set_title('Documents per Topic', fontsize=14)
#         axes[0, 1].set_xlabel('Topic')
#         axes[0, 1].set_ylabel('Number of Documents')
#         axes[0, 1].tick_params(axis='x', rotation=45)
        
#         # Plot 3: Average difficulty per topic (entropy)
#         if len(analysis_df) > 1:
#             axes[0, 2].scatter(analysis_df['Topic_Weight'], analysis_df['Avg_Entropy_Difficulty'],
#                               s=analysis_df['N_Docs']/10, alpha=0.6)
#             for idx, row in analysis_df.iterrows():
#                 axes[0, 2].annotate(row['Topic'], 
#                                   (row['Topic_Weight'], row['Avg_Entropy_Difficulty']),
#                                   fontsize=9, alpha=0.8)
#         axes[0, 2].set_title('Topic Weight vs Difficulty', fontsize=14)
#         axes[0, 2].set_xlabel('Topic Weight')
#         axes[0, 2].set_ylabel('Average Entropy Difficulty')
        
#         # Plot 4: Topic strength distribution
#         axes[1, 0].hist(analysis_df['Topic_Strength'], bins=20, alpha=0.7)
#         axes[1, 0].axvline(analysis_df['Topic_Strength'].mean(), color='red', 
#                           linestyle='--', label=f'Mean: {analysis_df["Topic_Strength"].mean():.3f}')
#         axes[1, 0].set_title('Topic Strength Distribution', fontsize=14)
#         axes[1, 0].set_xlabel('Topic Strength')
#         axes[1, 0].set_ylabel('Frequency')
#         axes[1, 0].legend()
        
#         # Plot 5: Difficulty metrics comparison
#         x = np.arange(len(analysis_df))
#         width = 0.25
#         axes[1, 1].bar(x - width, analysis_df['Avg_Entropy_Difficulty'], width, label='Entropy')
#         axes[1, 1].bar(x, analysis_df['Avg_MaxProb_Difficulty'], width, label='Max Prob')
#         axes[1, 1].bar(x + width, analysis_df['Avg_Composite_Difficulty'], width, label='Composite')
#         axes[1, 1].set_title('Difficulty Metrics by Topic', fontsize=14)
#         axes[1, 1].set_xlabel('Topic')
#         axes[1, 1].set_ylabel('Difficulty Score')
#         axes[1, 1].set_xticks(x)
#         axes[1, 1].set_xticklabels(analysis_df['Topic'], rotation=45)
#         axes[1, 1].legend()
        
#         # Plot 6: Topic-difficulty heatmap
#         difficulty_matrix = np.array([
#             analysis_df['Avg_Entropy_Difficulty'].values,
#             analysis_df['Avg_MaxProb_Difficulty'].values,
#             analysis_df['Avg_Composite_Difficulty'].values
#         ])
#         im = axes[1, 2].imshow(difficulty_matrix, aspect='auto', cmap='YlOrRd')
#         axes[1, 2].set_title('Difficulty Metrics Heatmap', fontsize=14)
#         axes[1, 2].set_xlabel('Topic')
#         axes[1, 2].set_ylabel('Difficulty Metric')
#         axes[1, 2].set_xticks(range(len(analysis_df)))
#         axes[1, 2].set_xticklabels(analysis_df['Topic'], rotation=45)
#         axes[1, 2].set_yticks(range(3))
#         axes[1, 2].set_yticklabels(['Entropy', 'MaxProb', 'Composite'])
#         plt.colorbar(im, ax=axes[1, 2])
        
#         plt.tight_layout()
#         output_path = f"results/tmcl/topic_models/{dataset_name}_topic_analysis.png"
#         plt.savefig(output_path, bbox_inches='tight', dpi=300)
#         plt.close()
        
#         print(f"  Topic analysis plot saved to: {output_path}")
        
#         # 5. TOPIC EVOLUTION IN DIFFICULTY SPACE
#         print("\n5. TOPIC EVOLUTION IN DIFFICULTY SPACE")
        
#         # Use t-SNE to visualize topics in 2D based on difficulty patterns
#         if len(analysis_df) > 2:
#             # Prepare data for t-SNE
#             tsne_data = analysis_df[['Avg_Entropy_Difficulty', 'Avg_MaxProb_Difficulty', 
#                                     'Avg_Composite_Difficulty', 'Topic_Strength', 'Topic_Weight']].values
            
#             # Normalize
#             tsne_data = (tsne_data - tsne_data.mean(axis=0)) / (tsne_data.std(axis=0) + 1e-8)
            
#             # Apply t-SNE
#             tsne = TSNE(n_components=2, random_state=42, perplexity=min(5, len(analysis_df)-1))
#             tsne_results = tsne.fit_transform(tsne_data)
            
#             # Plot
#             plt.figure(figsize=(12, 10))
#             scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], 
#                                 c=analysis_df['Avg_Composite_Difficulty'], 
#                                 s=analysis_df['Topic_Weight'] * 2000,
#                                 alpha=0.7, cmap='coolwarm', edgecolors='black')
            
#             # Add labels
#             for i, (x, y) in enumerate(tsne_results):
#                 plt.text(x, y, f"T{i}", fontsize=11, ha='center', va='center',
#                        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
            
#             plt.colorbar(scatter, label='Composite Difficulty')
#             plt.title(f'Topic Clustering by Difficulty Patterns - {dataset_name}', fontsize=16)
#             plt.xlabel('t-SNE Component 1')
#             plt.ylabel('t-SNE Component 2')
#             plt.grid(True, alpha=0.3)
            
#             # Add statistics
#             plt.figtext(0.02, 0.98, 
#                        f"Topics: {len(analysis_df)}\n"
#                        f"Difficulty range: [{analysis_df['Avg_Composite_Difficulty'].min():.3f}, "
#                        f"{analysis_df['Avg_Composite_Difficulty'].max():.3f}]",
#                        fontsize=10,
#                        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
            
#             output_path = f"results/tmcl/topic_models/{dataset_name}_topic_clustering.png"
#             plt.savefig(output_path, bbox_inches='tight', dpi=300)
#             plt.close()
            
#             print(f"  Topic clustering plot saved to: {output_path}")
#         else:
#             print("  Not enough topics for t-SNE clustering (need at least 3)")
        
#         # 6. COMPREHENSIVE STATISTICAL ANALYSIS
#         print("\n6. COMPREHENSIVE STATISTICAL ANALYSIS")
#         print("-" * 40)
        
#         # Calculate correlations
#         if len(analysis_df) > 2:
#             correlation_matrix = analysis_df[['Avg_Entropy_Difficulty', 'Avg_MaxProb_Difficulty',
#                                              'Avg_Composite_Difficulty', 'Topic_Strength',
#                                              'Topic_Weight', 'N_Docs']].corr()
            
#             print("\nCorrelation Matrix:")
#             print(correlation_matrix.round(3))
            
#             # Plot correlation heatmap
#             plt.figure(figsize=(10, 8))
#             sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
#                        square=True, linewidths=1, cbar_kws={"shrink": 0.8})
#             plt.title(f'Feature Correlations - {dataset_name}', fontsize=16)
#             plt.tight_layout()
            
#             output_path = f"results/tmcl/topic_models/{dataset_name}_correlation_heatmap.png"
#             plt.savefig(output_path, bbox_inches='tight', dpi=300)
#             plt.close()
            
#             print(f"  Correlation heatmap saved to: {output_path}")
        
#         # Calculate difficulty statistics by topic
#         print("\nDifficulty Statistics by Topic (Top 5 most difficult):")
#         difficult_topics = analysis_df.sort_values('Avg_Composite_Difficulty', ascending=False).head(5)
#         for _, row in difficult_topics.iterrows():
#             print(f"  {row['Topic']}: {row['Avg_Composite_Difficulty']:.3f} "
#                   f"(Entropy: {row['Avg_Entropy_Difficulty']:.3f}, "
#                   f"MaxProb: {row['Avg_MaxProb_Difficulty']:.3f})")
#             if row['Top_Words']:
#                 print(f"     Top words: {row['Top_Words']}")
        
#         print("\nDifficulty Statistics by Topic (Top 5 least difficult):")
#         easy_topics = analysis_df.sort_values('Avg_Composite_Difficulty').head(5)
#         for _, row in easy_topics.iterrows():
#             print(f"  {row['Topic']}: {row['Avg_Composite_Difficulty']:.3f} "
#                   f"(Entropy: {row['Avg_Entropy_Difficulty']:.3f}, "
#                   f"MaxProb: {row['Avg_MaxProb_Difficulty']:.3f})")
#             if row['Top_Words']:
#                 print(f"     Top words: {row['Top_Words']}")
#     else:
#         print("  No valid topic data for analysis")
#         analysis_df = pd.DataFrame()
    
#     # 7. SAVE COMPREHENSIVE ANALYSIS REPORT
#     print("\n7. GENERATING ANALYSIS REPORT...")
    
#     if not analysis_df.empty:
#         # Calculate network density
#         if G.number_of_nodes() > 0:
#             density = nx.density(G)
#         else:
#             density = 0
        
#         report = {
#             'dataset_name': dataset_name,
#             'n_topics': topic_vectors.shape[0],
#             'n_documents': len(topic_distributions),
#             'perplexity': perplexity,
#             'topic_analysis': analysis_df,
#             'network_stats': {
#                 'n_nodes': G.number_of_nodes(),
#                 'n_edges': G.number_of_edges(),
#                 'density': density
#             },
#             'difficulty_stats': {
#                 'overall_mean_entropy': np.mean(difficulty_scores['entropy']),
#                 'overall_std_entropy': np.std(difficulty_scores['entropy']),
#                 'topic_mean_entropy': analysis_df['Avg_Entropy_Difficulty'].mean() if not analysis_df.empty else 0,
#                 'topic_std_entropy': analysis_df['Avg_Entropy_Difficulty'].std() if not analysis_df.empty else 0,
#                 'overall_mean_composite': np.mean(difficulty_scores['composite']),
#                 'overall_std_composite': np.std(difficulty_scores['composite'])
#             }
#         }
#     else:
#         report = {
#             'dataset_name': dataset_name,
#             'n_topics': topic_vectors.shape[0],
#             'n_documents': len(topic_distributions),
#             'perplexity': perplexity,
#             'topic_analysis': pd.DataFrame(),
#             'network_stats': {
#                 'n_nodes': 0,
#                 'n_edges': 0,
#                 'density': 0
#             },
#             'difficulty_stats': {
#                 'overall_mean_entropy': np.mean(difficulty_scores['entropy']),
#                 'overall_std_entropy': np.std(difficulty_scores['entropy']),
#                 'topic_mean_entropy': 0,
#                 'topic_std_entropy': 0,
#                 'overall_mean_composite': np.mean(difficulty_scores['composite']),
#                 'overall_std_composite': np.std(difficulty_scores['composite'])
#             }
#         }
    
#     # Save report
#     output_path = f"results/tmcl/topic_models/{dataset_name}_enhanced_analysis.pkl"
#     with open(output_path, 'wb') as f:
#         pickle.dump(report, f)
    
#     print(f"  Analysis report saved to: {output_path}")
    
#     # Generate summary text file
#     summary_path = f"results/tmcl/topic_models/{dataset_name}_analysis_summary.txt"
#     with open(summary_path, 'w') as f:
#         f.write(f"ENHANCED TOPIC ANALYSIS REPORT - {dataset_name.upper()}\n")
#         f.write("=" * 60 + "\n\n")
        
#         f.write("1. DATASET STATISTICS\n")
#         f.write("-" * 40 + "\n")
#         f.write(f"   Number of topics: {topic_vectors.shape[0]}\n")
#         f.write(f"   Number of documents: {len(topic_distributions)}\n")
#         if perplexity:
#             f.write(f"   Perplexity: {perplexity:.4f}\n")
        
#         if not analysis_df.empty:
#             f.write(f"\n   Topic weight distribution:\n")
#             f.write(f"     Mean: {analysis_df['Topic_Weight'].mean():.4f}\n")
#             f.write(f"     Std: {analysis_df['Topic_Weight'].std():.4f}\n")
#             f.write(f"     Min: {analysis_df['Topic_Weight'].min():.4f}\n")
#             f.write(f"     Max: {analysis_df['Topic_Weight'].max():.4f}\n")
        
#         f.write("\n2. DIFFICULTY ANALYSIS\n")
#         f.write("-" * 40 + "\n")
#         f.write(f"   Overall difficulty (entropy): {np.mean(difficulty_scores['entropy']):.4f} "
#                 f"± {np.std(difficulty_scores['entropy']):.4f}\n")
#         f.write(f"   Overall difficulty (composite): {np.mean(difficulty_scores['composite']):.4f} "
#                 f"± {np.std(difficulty_scores['composite']):.4f}\n")
        
#         if not analysis_df.empty:
#             f.write(f"   Topic-level difficulty (composite): {analysis_df['Avg_Composite_Difficulty'].mean():.4f} "
#                     f"± {analysis_df['Avg_Composite_Difficulty'].std():.4f}\n")
        
#         if not analysis_df.empty:
#             f.write("\n3. MOST DIFFICULT TOPICS\n")
#             f.write("-" * 40 + "\n")
#             for _, row in difficult_topics.iterrows():
#                 f.write(f"   {row['Topic']}: {row['Avg_Composite_Difficulty']:.3f} "
#                        f"(docs: {row['N_Docs']}, weight: {row['Topic_Weight']:.3f})\n")
#                 if row['Top_Words']:
#                     f.write(f"     Top words: {row['Top_Words']}\n")
            
#             f.write("\n4. LEAST DIFFICULT TOPICS\n")
#             f.write("-" * 40 + "\n")
#             for _, row in easy_topics.iterrows():
#                 f.write(f"   {row['Topic']}: {row['Avg_Composite_Difficulty']:.3f} "
#                        f"(docs: {row['N_Docs']}, weight: {row['Topic_Weight']:.3f})\n")
#                 if row['Top_Words']:
#                     f.write(f"     Top words: {row['Top_Words']}\n")
        
#         f.write("\n5. TOPIC NETWORK STATISTICS\n")
#         f.write("-" * 40 + "\n")
#         f.write(f"   Number of topic nodes: {G.number_of_nodes()}\n")
#         f.write(f"   Number of connections: {G.number_of_edges()}\n")
#         f.write(f"   Network density: {nx.density(G) if G.number_of_nodes() > 0 else 0:.4f}\n")
        
#         # Calculate degree centrality
#         if G.number_of_nodes() > 0:
#             degree_centrality = nx.degree_centrality(G)
#             f.write(f"   Average degree centrality: {np.mean(list(degree_centrality.values())):.4f}\n")
        
#         f.write("\n6. KEY INSIGHTS\n")
#         f.write("-" * 40 + "\n")
#         if not analysis_df.empty and 'Topic_Weight' in analysis_df.columns and 'Avg_Composite_Difficulty' in analysis_df.columns:
#             try:
#                 corr = analysis_df['Topic_Weight'].corr(analysis_df['Avg_Composite_Difficulty'])
#                 trend = "higher" if corr > 0 else "lower" if corr < 0 else "no clear relationship with"
#                 f.write(f"   - Topics with higher weights tend to have {trend} difficulty (correlation: {corr:.3f})\n")
#             except:
#                 f.write(f"   - Could not calculate correlation between topic weight and difficulty\n")
        
#         if not analysis_df.empty:
#             f.write(f"   - Document distribution across topics: {analysis_df['N_Docs'].min()} to {analysis_df['N_Docs'].max()} documents\n")
#             f.write(f"   - Topic strength varies from {analysis_df['Topic_Strength'].min():.3f} to {analysis_df['Topic_Strength'].max():.3f}\n")
        
#         # Add dataset-specific insights
#         if dataset_name == 'ag_news':
#             f.write(f"   - AG News has relatively high average difficulty ({np.mean(difficulty_scores['composite']):.3f}), suggesting complex topic structure\n")
#         elif dataset_name == 'imdb':
#             f.write(f"   - IMDB has relatively low average difficulty ({np.mean(difficulty_scores['composite']):.3f}), suggesting clearer topic separation\n")
    
#     print(f"  Analysis summary saved to: {summary_path}")
    
#     print(f"\n{'='*60}")
#     print(f"ANALYSIS COMPLETE FOR {dataset_name.upper()}")
#     print(f"{'='*60}")
    
#     return report

# def enhanced_validation_analysis(dataset_name, results, sample_size=2000):
#     """
#     Enhanced validation analysis with comprehensive metrics
    
#     Args:
#         dataset_name: Name of the dataset
#         results: Processing results containing difficulty scores
#         sample_size: Number of samples to use for validation
#     """
    
#     print(f"\n{'='*60}")
#     print(f"ENHANCED VALIDATION: {dataset_name.upper()}")
#     print(f"{'='*60}")
    
#     # Get difficulty scores and labels
#     difficulty_entropy = results['difficulty_scores']['entropy']
#     difficulty_max_prob = results['difficulty_scores']['max_prob']
#     difficulty_composite = results['difficulty_scores']['composite']
#     labels = results.get('labels', None)
#     topic_distributions = results.get('topic_distributions', None)
    
#     if labels is None:
#         print("No labels available for validation")
#         return None
    
#     # Sample a subset for validation
#     if len(difficulty_entropy) > sample_size:
#         indices = np.random.choice(len(difficulty_entropy), sample_size, replace=False)
#         difficulty_entropy = difficulty_entropy[indices]
#         difficulty_max_prob = difficulty_max_prob[indices]
#         difficulty_composite = difficulty_composite[indices]
#         labels = labels[indices]
#         if topic_distributions is not None:
#             topic_distributions = topic_distributions[indices]
    
#     # Create comprehensive validation analysis
#     fig, axes = plt.subplots(2, 3, figsize=(18, 12))
#     fig.suptitle(f'Enhanced Validation Analysis - {dataset_name}', fontsize=20, y=1.02)
    
#     # 1. Label diversity vs difficulty (original)
#     n_bins = 10
#     bins = np.linspace(0, 1, n_bins + 1)
#     bin_indices = np.digitize(difficulty_composite, bins) - 1
    
#     bin_diversity = []
#     bin_samples = []
#     for i in range(n_bins):
#         bin_mask = (bin_indices == i)
#         if np.sum(bin_mask) > 10:  # Minimum samples threshold
#             bin_labels = labels[bin_mask]
#             diversity = len(np.unique(bin_labels)) / len(bin_labels)
#             bin_diversity.append(diversity)
#             bin_samples.append(np.sum(bin_mask))
#         else:
#             bin_diversity.append(0)
#             bin_samples.append(0)
    
#     # Plot 1: Label diversity vs difficulty
#     axes[0, 0].bar(range(n_bins), bin_diversity, alpha=0.7, color='steelblue')
#     axes[0, 0].set_title('Label Diversity vs Difficulty', fontsize=14)
#     axes[0, 0].set_xlabel('Difficulty Bin (0=easy, 9=hard)')
#     axes[0, 0].set_ylabel('Label Diversity')
#     axes[0, 0].set_xticks(range(n_bins))
#     axes[0, 0].grid(True, alpha=0.3)
    
#     # Add sample counts
#     for i, (div, count) in enumerate(zip(bin_diversity, bin_samples)):
#         if count > 0:
#             axes[0, 0].text(i, div + 0.02, str(count), ha='center', fontsize=8)
    
#     # 2. Difficulty distribution by label
#     unique_labels = np.unique(labels)
#     if len(unique_labels) <= 20:  # Only plot if reasonable number of labels
#         label_difficulties = []
#         for label in unique_labels:
#             mask = labels == label
#             if np.sum(mask) > 0:
#                 label_difficulties.append(np.mean(difficulty_composite[mask]))
#             else:
#                 label_difficulties.append(0)
        
#         axes[0, 1].bar(range(len(unique_labels)), label_difficulties, alpha=0.7)
#         axes[0, 1].set_title('Average Difficulty by Label', fontsize=14)
#         axes[0, 1].set_xlabel('Label')
#         axes[0, 1].set_ylabel('Average Difficulty')
#         axes[0, 1].set_xticks(range(len(unique_labels)))
#         axes[0, 1].set_xticklabels([f'L{i}' for i in unique_labels], rotation=45)
#         axes[0, 1].grid(True, alpha=0.3)
#     else:
#         axes[0, 1].text(0.5, 0.5, f'Too many labels ({len(unique_labels)})', 
#                        ha='center', va='center', fontsize=12)
#         axes[0, 1].set_title('Average Difficulty by Label', fontsize=14)
    
#     # 3. Topic concentration vs difficulty
#     if topic_distributions is not None:
#         # Calculate topic concentration (max probability in topic distribution)
#         topic_concentration = np.max(topic_distributions, axis=1)
        
#         # Scatter plot
#         sc = axes[0, 2].scatter(topic_concentration, difficulty_composite, 
#                                c=labels, alpha=0.5, cmap='tab20')
#         axes[0, 2].set_title('Topic Concentration vs Difficulty', fontsize=14)
#         axes[0, 2].set_xlabel('Topic Concentration (Max Probability)')
#         axes[0, 2].set_ylabel('Difficulty Score')
#         axes[0, 2].grid(True, alpha=0.3)
        
#         # Add correlation line
#         if len(topic_concentration) > 1:
#             z = np.polyfit(topic_concentration, difficulty_composite, 1)
#             p = np.poly1d(z)
#             axes[0, 2].plot(np.sort(topic_concentration), p(np.sort(topic_concentration)), 
#                            "r--", alpha=0.8, label=f'Corr: {np.corrcoef(topic_concentration, difficulty_composite)[0,1]:.3f}')
#             axes[0, 2].legend()
#     else:
#         axes[0, 2].text(0.5, 0.5, 'No topic distributions available', 
#                        ha='center', va='center', fontsize=12)
#         axes[0, 2].set_title('Topic Concentration vs Difficulty', fontsize=14)
    
#     # 4. Difficulty metrics comparison
#     scatter = axes[1, 0].scatter(difficulty_entropy, difficulty_max_prob, 
#                                 c=difficulty_composite, alpha=0.5, cmap='viridis')
#     axes[1, 0].set_title('Entropy vs MaxProb Difficulty', fontsize=14)
#     axes[1, 0].set_xlabel('Entropy Difficulty')
#     axes[1, 0].set_ylabel('MaxProb Difficulty')
#     axes[1, 0].grid(True, alpha=0.3)
    
#     # Add identity line
#     axes[1, 0].plot([0, 1], [0, 1], 'k--', alpha=0.5)
    
#     # Add colorbar
#     plt.colorbar(scatter, ax=axes[1, 0], label='Composite Difficulty')
    
#     # 5. Cumulative difficulty distribution
#     sorted_difficulty = np.sort(difficulty_composite)
#     cumulative = np.arange(1, len(sorted_difficulty) + 1) / len(sorted_difficulty)
    
#     axes[1, 1].plot(sorted_difficulty, cumulative, linewidth=2, color='darkgreen')
#     axes[1, 1].set_title('Cumulative Difficulty Distribution', fontsize=14)
#     axes[1, 1].set_xlabel('Difficulty Score')
#     axes[1, 1].set_ylabel('Cumulative Proportion')
#     axes[1, 1].grid(True, alpha=0.3)
    
#     # Add quartile lines
#     for q in [0.25, 0.5, 0.75]:
#         q_value = np.percentile(sorted_difficulty, q * 100)
#         axes[1, 1].axvline(q_value, color='red', linestyle='--', alpha=0.7)
#         axes[1, 1].text(q_value, 0.5, f'Q{q*100}%', rotation=90, va='center')
    
#     # 6. Difficulty histogram with KDE
#     axes[1, 2].hist(difficulty_composite, bins=30, alpha=0.7, density=True, 
#                    color='skyblue', edgecolor='black')
    
#     # Add KDE
#     kde = stats.gaussian_kde(difficulty_composite)
#     x_range = np.linspace(0, 1, 100)
#     axes[1, 2].plot(x_range, kde(x_range), color='darkblue', linewidth=2)
    
#     axes[1, 2].set_title('Difficulty Distribution with KDE', fontsize=14)
#     axes[1, 2].set_xlabel('Difficulty Score')
#     axes[1, 2].set_ylabel('Density')
#     axes[1, 2].grid(True, alpha=0.3)
    
#     # Add statistics
#     stats_text = (f'Mean: {np.mean(difficulty_composite):.3f}\n'
#                   f'Std: {np.std(difficulty_composite):.3f}\n'
#                   f'Skew: {stats.skew(difficulty_composite):.3f}\n'
#                   f'Kurtosis: {stats.kurtosis(difficulty_composite):.3f}')
#     axes[1, 2].text(0.02, 0.98, stats_text, transform=axes[1, 2].transAxes,
#                    fontsize=10, verticalalignment='top',
#                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    
#     plt.tight_layout()
    
#     # Save the comprehensive validation plot
#     output_path = f"results/tmcl/topic_models/{dataset_name}_enhanced_validation.png"
#     plt.savefig(output_path, bbox_inches='tight', dpi=300)
#     plt.close()
    
#     print(f"Enhanced validation plot saved to: {output_path}")
    
#     # Calculate and print validation statistics
#     print("\nVALIDATION STATISTICS:")
#     print("-" * 40)
    
#     # Calculate correlation between difficulty metrics
#     corr_entropy_maxprob = np.corrcoef(difficulty_entropy, difficulty_max_prob)[0, 1]
#     corr_entropy_composite = np.corrcoef(difficulty_entropy, difficulty_composite)[0, 1]
#     corr_maxprob_composite = np.corrcoef(difficulty_max_prob, difficulty_composite)[0, 1]
    
#     print(f"Correlation between difficulty metrics:")
#     print(f"  Entropy vs MaxProb: {corr_entropy_maxprob:.4f}")
#     print(f"  Entropy vs Composite: {corr_entropy_composite:.4f}")
#     print(f"  MaxProb vs Composite: {corr_maxprob_composite:.4f}")
    
#     # Calculate difficulty by label (if not too many labels)
#     if len(unique_labels) <= 20:
#         print(f"\nDifficulty by label (top 5 most difficult):")
#         label_stats = []
#         for label in unique_labels:
#             mask = labels == label
#             if np.sum(mask) > 0:
#                 avg_diff = np.mean(difficulty_composite[mask])
#                 label_stats.append((label, avg_diff, np.sum(mask)))
        
#         # Sort by difficulty
#         label_stats.sort(key=lambda x: x[1], reverse=True)
        
#         for label, avg_diff, count in label_stats[:5]:
#             print(f"  Label {label}: {avg_diff:.3f} (n={count})")
    
#     # Calculate topic concentration statistics
#     if topic_distributions is not None:
#         print(f"\nTopic concentration statistics:")
#         print(f"  Mean concentration: {np.mean(topic_concentration):.4f}")
#         print(f"  Std concentration: {np.std(topic_concentration):.4f}")
#         corr_concentration = np.corrcoef(topic_concentration, difficulty_composite)[0,1] if len(topic_concentration) > 1 else 0
#         print(f"  Correlation with difficulty: {corr_concentration:.4f}")
    
#     # Calculate label diversity statistics
#     valid_bins = [d for d, s in zip(bin_diversity, bin_samples) if s > 0]
#     if valid_bins:
#         print(f"\nLabel diversity statistics:")
#         print(f"  Mean diversity: {np.mean(valid_bins):.4f}")
#         print(f"  Std diversity: {np.std(valid_bins):.4f}")
        
#         # Calculate correlation between bin index and diversity
#         valid_indices = [i for i, s in enumerate(bin_samples) if s > 0]
#         if len(valid_indices) > 1:
#             corr_diversity_difficulty = np.corrcoef(valid_indices, 
#                                                    [bin_diversity[i] for i in valid_indices])[0, 1]
#             print(f"  Correlation with difficulty bin: {corr_diversity_difficulty:.4f}")
    
#     # Generate validation report
#     validation_report = {
#         'dataset_name': dataset_name,
#         'n_samples': len(difficulty_composite),
#         'difficulty_stats': {
#             'mean_entropy': float(np.mean(difficulty_entropy)),
#             'std_entropy': float(np.std(difficulty_entropy)),
#             'mean_maxprob': float(np.mean(difficulty_max_prob)),
#             'std_maxprob': float(np.std(difficulty_max_prob)),
#             'mean_composite': float(np.mean(difficulty_composite)),
#             'std_composite': float(np.std(difficulty_composite)),
#         },
#         'correlations': {
#             'entropy_maxprob': float(corr_entropy_maxprob),
#             'entropy_composite': float(corr_entropy_composite),
#             'maxprob_composite': float(corr_maxprob_composite),
#         },
#         'label_diversity_by_bin': [float(d) for d in bin_diversity],
#         'validation_plot': output_path
#     }
    
#     if topic_distributions is not None:
#         validation_report['topic_concentration'] = {
#             'mean': float(np.mean(topic_concentration)),
#             'std': float(np.std(topic_concentration)),
#             'corr_with_difficulty': float(corr_concentration)
#         }
    
#     # Save validation report
#     report_path = f"results/tmcl/topic_models/{dataset_name}_validation_report.pkl"
#     with open(report_path, 'wb') as f:
#         pickle.dump(validation_report, f)
    
#     print(f"\nValidation report saved to: {report_path}")
#     print(f"{'='*60}")
    
#     return validation_report

# # Main execution
# print("\n" + "="*80)
# print("COMPREHENSIVE TOPIC MODELING ANALYSIS AND VALIDATION")
# print("="*80)

# # Define datasets to analyze
# datasets = ['ag_news', 'imdb', 'cifar10', 'cifar100', 'mnist', 'fashion_mnist']

# # Store all analysis results
# all_analysis_results = {}
# all_validation_results = {}

# # First, analyze text datasets with enhanced topic analysis
# print("\nANALYZING TEXT DATASETS WITH TOPIC VISUALIZATION...")

# for dataset_name in ['ag_news', 'imdb']:
#     print(f"\n{'='*60}")
#     print(f"Processing: {dataset_name}")
#     print(f"{'='*60}")
    
#     # Load results
#     results_path = f"results/tmcl/topic_models/{dataset_name}_tmcl_results.pkl"
#     if os.path.exists(results_path):
#         with open(results_path, 'rb') as f:
#             results = pickle.load(f)
        
#         # Get feature names for text datasets
#         if 'vocabulary' in results:
#             feature_names = results['vocabulary']
#             print(f"  Found vocabulary with {len(feature_names)} words")
#         else:
#             # Try to get from feature matrix shape
#             n_features = results.get('feature_matrix', np.zeros((1, 2000))).shape[1]
#             feature_names = [f'word_{i}' for i in range(min(n_features, 2000))]
#             print(f"  Created dummy feature names (n={len(feature_names)})")
        
#         # Run enhanced topic analysis
#         analysis_result = enhanced_topic_analysis(
#             dataset_name, 
#             results, 
#             feature_names=feature_names
#         )
        
#         all_analysis_results[dataset_name] = analysis_result
        
#         # Run enhanced validation
#         validation_result = enhanced_validation_analysis(dataset_name, results)
#         all_validation_results[dataset_name] = validation_result
        
#     else:
#         print(f"Results file not found for {dataset_name}")

# # Now analyze image datasets
# print("\nANALYZING IMAGE DATASETS...")

# for dataset_name in ['cifar10', 'cifar100', 'mnist', 'fashion_mnist']:
#     print(f"\n{'='*60}")
#     print(f"Processing: {dataset_name}")
#     print(f"{'='*60}")
    
#     # Load results
#     results_path = f"results/tmcl/topic_models/{dataset_name}_tmcl_results.pkl"
#     if os.path.exists(results_path):
#         with open(results_path, 'rb') as f:
#             results = pickle.load(f)
        
#         # For image datasets, we don't have feature names
#         # Run analysis without feature names
#         analysis_result = enhanced_topic_analysis(dataset_name, results)
#         all_analysis_results[dataset_name] = analysis_result
        
#         # Run enhanced validation
#         validation_result = enhanced_validation_analysis(dataset_name, results)
#         all_validation_results[dataset_name] = validation_result
        
#     else:
#         print(f"Results file not found for {dataset_name}")

# # 8. CROSS-DATASET COMPARATIVE ANALYSIS
# print("\n" + "="*80)
# print("CROSS-DATASET COMPARATIVE ANALYSIS")
# print("="*80)

# # Create comparative analysis
# comparison_data = []

# for dataset_name, analysis in all_analysis_results.items():
#     if analysis is not None:
#         # Safely extract values with defaults
#         topic_analysis = analysis.get('topic_analysis', pd.DataFrame())
#         difficulty_stats = analysis.get('difficulty_stats', {})
#         network_stats = analysis.get('network_stats', {})
        
#         # Get Avg_Composite_Difficulty with fallback
#         if not topic_analysis.empty and 'Avg_Composite_Difficulty' in topic_analysis.columns:
#             difficulty_range = topic_analysis['Avg_Composite_Difficulty'].max() - topic_analysis['Avg_Composite_Difficulty'].min()
#         else:
#             difficulty_range = 0
        
#         comparison_data.append({
#             'Dataset': dataset_name,
#             'N_Topics': analysis.get('n_topics', 0),
#             'N_Documents': analysis.get('n_documents', 0),
#             'Perplexity': analysis.get('perplexity', 0),
#             'Avg_Topic_Weight': topic_analysis['Topic_Weight'].mean() if not topic_analysis.empty and 'Topic_Weight' in topic_analysis.columns else 0,
#             'Std_Topic_Weight': topic_analysis['Topic_Weight'].std() if not topic_analysis.empty and 'Topic_Weight' in topic_analysis.columns else 0,
#             'Avg_Difficulty': difficulty_stats.get('overall_mean_composite', 0),
#             'Std_Difficulty': difficulty_stats.get('overall_std_composite', 0),
#             'Network_Density': network_stats.get('density', 0),
#             'Difficulty_Range': difficulty_range
#         })

# if comparison_data:
#     comparison_df = pd.DataFrame(comparison_data)
    
#     # Create comparative visualizations
#     fig, axes = plt.subplots(2, 3, figsize=(18, 12))
#     fig.suptitle('Cross-Dataset Comparative Analysis', fontsize=20, y=1.02)
    
#     # Plot 1: Average difficulty by dataset
#     if len(comparison_df) > 0:
#         axes[0, 0].bar(comparison_df['Dataset'], comparison_df['Avg_Difficulty'], 
#                        alpha=0.7, color='steelblue')
#         axes[0, 0].set_title('Average Difficulty by Dataset', fontsize=14)
#         axes[0, 0].set_xlabel('Dataset')
#         axes[0, 0].set_ylabel('Average Difficulty')
#         axes[0, 0].tick_params(axis='x', rotation=45)
#         axes[0, 0].grid(True, alpha=0.3, axis='y')
        
#         # Add error bars
#         axes[0, 0].errorbar(comparison_df['Dataset'], comparison_df['Avg_Difficulty'], 
#                            yerr=comparison_df['Std_Difficulty'], 
#                            fmt='none', color='black', capsize=5)
    
#     # Plot 2: Topic statistics
#     if len(comparison_df) > 0:
#         x = np.arange(len(comparison_df))
#         width = 0.35
#         axes[0, 1].bar(x - width/2, comparison_df['N_Topics'], width, label='Number of Topics')
#         axes[0, 1].bar(x + width/2, comparison_df['N_Documents']/1000, width, label='Documents (thousands)')
#         axes[0, 1].set_title('Dataset Scale Comparison', fontsize=14)
#         axes[0, 1].set_xlabel('Dataset')
#         axes[0, 1].set_ylabel('Count')
#         axes[0, 1].set_xticks(x)
#         axes[0, 1].set_xticklabels(comparison_df['Dataset'], rotation=45)
#         axes[0, 1].legend()
#         axes[0, 1].grid(True, alpha=0.3, axis='y')
    
#     # Plot 3: Network density vs difficulty
#     if len(comparison_df) > 0:
#         scatter = axes[0, 2].scatter(comparison_df['Network_Density'], comparison_df['Avg_Difficulty'],
#                                     s=comparison_df['N_Topics']*50, alpha=0.7,
#                                     c=range(len(comparison_df)), cmap='tab20')
#         axes[0, 2].set_title('Network Density vs Average Difficulty', fontsize=14)
#         axes[0, 2].set_xlabel('Network Density')
#         axes[0, 2].set_ylabel('Average Difficulty')
#         axes[0, 2].grid(True, alpha=0.3)
        
#         # Add dataset labels
#         for i, row in comparison_df.iterrows():
#             axes[0, 2].annotate(row['Dataset'], 
#                                (row['Network_Density'], row['Avg_Difficulty']),
#                                fontsize=9, alpha=0.8)
    
#     # Plot 4: Perplexity comparison
#     if len(comparison_df) > 0:
#         axes[1, 0].bar(comparison_df['Dataset'], comparison_df['Perplexity'], 
#                        alpha=0.7, color='purple')
#         axes[1, 0].set_title('Model Perplexity by Dataset', fontsize=14)
#         axes[1, 0].set_xlabel('Dataset')
#         axes[1, 0].set_ylabel('Perplexity')
#         axes[1, 0].tick_params(axis='x', rotation=45)
#         axes[1, 0].grid(True, alpha=0.3, axis='y')
    
#     # Plot 5: Difficulty range comparison
#     if len(comparison_df) > 0:
#         axes[1, 1].bar(comparison_df['Dataset'], comparison_df['Difficulty_Range'], 
#                        alpha=0.7, color='coral')
#         axes[1, 1].set_title('Difficulty Range by Dataset', fontsize=14)
#         axes[1, 1].set_xlabel('Dataset')
#         axes[1, 1].set_ylabel('Difficulty Range (Max - Min)')
#         axes[1, 1].tick_params(axis='x', rotation=45)
#         axes[1, 1].grid(True, alpha=0.3, axis='y')
    
#     # Plot 6: Dataset clustering based on metrics
#     if len(comparison_df) > 1:
#         try:
#             from sklearn.preprocessing import StandardScaler
#             from sklearn.cluster import KMeans
#             from sklearn.decomposition import PCA
            
#             # Prepare data for clustering
#             cluster_data = comparison_df[['Avg_Difficulty', 'Std_Difficulty', 
#                                          'Network_Density', 'Difficulty_Range', 'Perplexity']].values
            
#             # Handle NaN values
#             cluster_data = np.nan_to_num(cluster_data)
#             cluster_data = StandardScaler().fit_transform(cluster_data)
            
#             # Apply K-means clustering
#             kmeans = KMeans(n_clusters=min(3, len(cluster_data)), random_state=42)
#             clusters = kmeans.fit_predict(cluster_data)
            
#             # Use PCA for 2D visualization
#             pca = PCA(n_components=2)
#             pca_result = pca.fit_transform(cluster_data)
            
#             scatter = axes[1, 2].scatter(pca_result[:, 0], pca_result[:, 1], 
#                                         c=clusters, s=100, alpha=0.7, cmap='Set2')
            
#             # Add dataset labels
#             for i, (x, y) in enumerate(pca_result):
#                 axes[1, 2].annotate(comparison_df.iloc[i]['Dataset'], 
#                                    (x, y), fontsize=10, alpha=0.8)
            
#             axes[1, 2].set_title('Dataset Clustering Based on Metrics', fontsize=14)
#             axes[1, 2].set_xlabel(f'PCA Component 1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
#             axes[1, 2].set_ylabel(f'PCA Component 2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
#             axes[1, 2].grid(True, alpha=0.3)
            
#             # Add legend for clusters
#             from matplotlib.lines import Line2D
#             legend_elements = [Line2D([0], [0], marker='o', color='w', 
#                                      markerfacecolor=plt.cm.Set2(i), markersize=10,
#                                      label=f'Cluster {i}') 
#                               for i in range(len(np.unique(clusters)))]
#             axes[1, 2].legend(handles=legend_elements, loc='best')
#         except Exception as e:
#             axes[1, 2].text(0.5, 0.5, f'Clustering error: {str(e)}', 
#                            ha='center', va='center', fontsize=10)
#             axes[1, 2].set_title('Dataset Clustering Based on Metrics', fontsize=14)
#     else:
#         axes[1, 2].text(0.5, 0.5, 'Insufficient data for clustering', 
#                        ha='center', va='center', fontsize=12)
#         axes[1, 2].set_title('Dataset Clustering Based on Metrics', fontsize=14)
    
#     plt.tight_layout()
    
#     # Save comparative analysis
#     output_path = "results/tmcl/topic_models/cross_dataset_comparison.png"
#     plt.savefig(output_path, bbox_inches='tight', dpi=300)
#     plt.close()
    
#     print(f"Cross-dataset comparison plot saved to: {output_path}")
    
#     # Create comprehensive summary report
#     print("\n" + "="*80)
#     print("COMPREHENSIVE SUMMARY REPORT")
#     print("="*80)
    
#     print("\n1. DATASET OVERVIEW:")
#     print("-" * 40)
#     for i, row in comparison_df.iterrows():
#         print(f"{row['Dataset']}:")
#         print(f"  Topics: {row['N_Topics']}, Documents: {row['N_Documents']:,}")
#         print(f"  Perplexity: {row['Perplexity']:.2f}")
#         print(f"  Avg Difficulty: {row['Avg_Difficulty']:.3f} ± {row['Std_Difficulty']:.3f}")
#         print(f"  Network Density: {row['Network_Density']:.3f}")
#         print()
    
#     print("\n2. KEY FINDINGS:")
#     print("-" * 40)
    
#     # Find hardest and easiest datasets
#     if not comparison_df.empty:
#         hardest_dataset = comparison_df.loc[comparison_df['Avg_Difficulty'].idxmax()]
#         easiest_dataset = comparison_df.loc[comparison_df['Avg_Difficulty'].idxmin()]
        
#         print(f"Hardest Dataset: {hardest_dataset['Dataset']} "
#               f"(Difficulty: {hardest_dataset['Avg_Difficulty']:.3f})")
#         print(f"Easiest Dataset: {easiest_dataset['Dataset']} "
#               f"(Difficulty: {easiest_dataset['Avg_Difficulty']:.3f})")
#         print(f"Difficulty Range Across Datasets: {comparison_df['Avg_Difficulty'].max() - comparison_df['Avg_Difficulty'].min():.3f}")
        
#         # Calculate average network density
#         avg_density = comparison_df['Network_Density'].mean()
#         print(f"\nAverage Network Density: {avg_density:.3f}")
#         print("  Higher density indicates more interconnected topics")
#         print("  Lower density indicates more distinct, separate topics")
        
#         # Text vs Image dataset comparison
#         text_datasets = [d for d in comparison_df['Dataset'] if d in ['ag_news', 'imdb']]
#         image_datasets = [d for d in comparison_df['Dataset'] if d not in ['ag_news', 'imdb']]
        
#         if text_datasets and image_datasets:
#             text_avg_diff = comparison_df[comparison_df['Dataset'].isin(text_datasets)]['Avg_Difficulty'].mean()
#             image_avg_diff = comparison_df[comparison_df['Dataset'].isin(image_datasets)]['Avg_Difficulty'].mean()
            
#             print(f"\nText vs Image Dataset Comparison:")
#             print(f"  Text datasets avg difficulty: {text_avg_diff:.3f}")
#             print(f"  Image datasets avg difficulty: {image_avg_diff:.3f}")
#             print(f"  Difference: {abs(text_avg_diff - image_avg_diff):.3f}")
    
#     print("\n3. RECOMMENDATIONS FOR TMCL:")
#     print("-" * 40)
#     print("Based on the analysis, consider the following:")
#     print("1. For datasets with high average difficulty:")
#     print("   - Implement stronger regularization")
#     print("   - Consider curriculum learning with gradual difficulty increase")
#     print("   - Use more sophisticated difficulty metrics")
    
#     print("\n2. For datasets with low network density:")
#     print("   - Topics are more distinct, which is good for interpretability")
#     print("   - Consider topic-guided data augmentation")
    
#     print("\n3. For datasets with high difficulty variance:")
#     print("   - Implement adaptive sampling strategies")
#     print("   - Consider multi-stage training approaches")
    
#     print("\n4. GENERAL RECOMMENDATIONS:")
#     print("   - Use composite difficulty metric for balanced assessment")
#     print("   - Monitor topic evolution during training")
#     print("   - Validate difficulty scores with downstream task performance")
    
#     # Save comprehensive summary
#     summary_path = "results/tmcl/topic_models/comprehensive_analysis_summary.txt"
#     with open(summary_path, 'w') as f:
#         f.write("="*80 + "\n")
#         f.write("COMPREHENSIVE TOPIC MODELING ANALYSIS SUMMARY\n")
#         f.write("="*80 + "\n\n")
        
#         f.write("1. DATASET OVERVIEW\n")
#         f.write("-" * 40 + "\n")
#         for i, row in comparison_df.iterrows():
#             f.write(f"{row['Dataset'].upper()}:\n")
#             f.write(f"  Topics: {row['N_Topics']}\n")
#             f.write(f"  Documents: {row['N_Documents']:,}\n")
#             f.write(f"  Perplexity: {row['Perplexity']:.2f}\n")
#             f.write(f"  Average Difficulty: {row['Avg_Difficulty']:.3f} ± {row['Std_Difficulty']:.3f}\n")
#             f.write(f"  Difficulty Range: {row['Difficulty_Range']:.3f}\n")
#             f.write(f"  Network Density: {row['Network_Density']:.3f}\n")
#             f.write(f"  Average Topic Weight: {row['Avg_Topic_Weight']:.3f}\n")
#             f.write(f"  Topic Weight Std: {row['Std_Topic_Weight']:.3f}\n\n")
        
#         f.write("2. KEY FINDINGS\n")
#         f.write("-" * 40 + "\n")
#         if not comparison_df.empty:
#             f.write(f"Hardest Dataset: {hardest_dataset['Dataset']} "
#                    f"(Difficulty: {hardest_dataset['Avg_Difficulty']:.3f})\n")
#             f.write(f"Easiest Dataset: {easiest_dataset['Dataset']} "
#                    f"(Difficulty: {easiest_dataset['Avg_Difficulty']:.3f})\n")
#             f.write(f"Overall Difficulty Range: {comparison_df['Avg_Difficulty'].max() - comparison_df['Avg_Difficulty'].min():.3f}\n")
#             f.write(f"Average Network Density: {avg_density:.3f}\n")
            
#             if text_datasets and image_datasets:
#                 f.write(f"\nText vs Image Comparison:\n")
#                 f.write(f"  Text datasets: {text_avg_diff:.3f}\n")
#                 f.write(f"  Image datasets: {image_avg_diff:.3f}\n")
#                 f.write(f"  Difference: {abs(text_avg_diff - image_avg_diff):.3f}\n")
        
#         f.write("\n3. RECOMMENDATIONS FOR TMCL\n")
#         f.write("-" * 40 + "\n")
#         f.write("For Topic Modeling Curriculum Learning implementation:\n")
#         f.write("1. Use adaptive difficulty sampling based on topic distributions\n")
#         f.write("2. Monitor topic evolution during training\n")
#         f.write("3. Validate with multiple difficulty metrics (entropy, max_prob, composite)\n")
#         f.write("4. Consider dataset-specific difficulty thresholds\n")
#         f.write("5. Implement curriculum learning for high-difficulty datasets\n")
#         f.write("6. Use topic networks to understand relationships between concepts\n")
#         f.write("7. For text data, leverage topic-word distributions for interpretability\n")
#         f.write("8. For image data, consider feature-space clustering for topic assignment\n")
    
#     print(f"\nComprehensive summary saved to: {summary_path}")
    
#     # Save comparison dataframe for future analysis
#     comparison_df_path = "results/tmcl/topic_models/dataset_comparison.csv"
#     comparison_df.to_csv(comparison_df_path, index=False)
#     print(f"Comparison data saved to: {comparison_df_path}")
# else:
#     print("No valid analysis data available for cross-dataset comparison")

# print("\n" + "="*80)
# print("ANALYSIS COMPLETE!")
# print(f"All visualizations and reports saved to: results/tmcl/topic_models/")
# print("="*80)

# # Print final summary
# print("\nFINAL SUMMARY:")
# print("-" * 40)
# for dataset_name in datasets:
#     if dataset_name in all_analysis_results:
#         analysis = all_analysis_results[dataset_name]
#         if analysis and 'difficulty_stats' in analysis:
#             stats = analysis['difficulty_stats']
#             print(f"{dataset_name}:")
#             print(f"  Composite Difficulty: {stats.get('overall_mean_composite', 0):.3f} ± {stats.get('overall_std_composite', 0):.3f}")
#             print(f"  Topics: {analysis.get('n_topics', 0)}, Documents: {analysis.get('n_documents', 0):,}")
#             print()

In [None]:
# def create_interactive_topic_network(dataset_name, results, feature_names=None, top_n_words=10):
#     """
#     Create interactive topic network visualization with actual topic labels
    
#     Args:
#         dataset_name: Name of the dataset
#         results: Processing results containing LDA model and topic distributions
#         feature_names: Vocabulary for text datasets
#         top_n_words: Number of top words to show for each topic
#     """
    
#     print(f"\nCreating interactive topic network for {dataset_name}...")
    
#     # Extract components from results
#     lda_model = results['lda_model']
#     topic_distributions = results['topic_distributions']
#     difficulty_scores = results['difficulty_scores']
    
#     # Get topic components
#     topic_vectors = lda_model.components_  # Shape: n_topics x n_features
#     topic_similarity = cosine_similarity(topic_vectors)
    
#     # Calculate topic weights
#     topic_weights = np.sum(topic_distributions, axis=0)
#     topic_weights_norm = topic_weights / np.sum(topic_weights)
    
#     # Create topic relationship graph
#     G = nx.Graph()
    
#     # Get actual topic labels (top words for each topic)
#     topic_labels = []
#     topic_full_labels = []  # Full label with all top words
#     topic_words_list = []   # List of top words for each topic
    
#     for i in range(topic_vectors.shape[0]):
#         # Get top words for this topic
#         if feature_names is not None and len(feature_names) > 0:
#             topic_vector = topic_vectors[i]
#             # Get indices of top words for this topic
#             top_indices = np.argsort(topic_vector)[-top_n_words:][::-1]
#             top_words = []
#             for idx in top_indices:
#                 if idx < len(feature_names):
#                     if isinstance(feature_names[idx], (int, np.integer)):
#                         # If it's an integer, it might be an index
#                         top_words.append(f"word_{idx}")
#                     else:
#                         top_words.append(str(feature_names[idx]))
#                 else:
#                     top_words.append(f"word_{idx}")
            
#             # Create short label (first 3 words)
#             short_label = f"T{i}: {', '.join(top_words[:3])}"
#             # Create full label
#             full_label = f"Topic {i}<br>Top words: {', '.join(top_words)}"
#             topic_labels.append(short_label)
#             topic_full_labels.append(full_label)
#             topic_words_list.append(top_words)
#         else:
#             # For image datasets or when no feature names
#             topic_labels.append(f"Topic {i}")
#             topic_full_labels.append(f"Topic {i}")
#             topic_words_list.append([])
    
#     # Add nodes with topic information
#     for i in range(topic_vectors.shape[0]):
#         # Calculate average difficulty for documents where this topic is dominant
#         topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
#         if np.sum(topic_doc_mask) > 0:
#             avg_difficulty = np.mean(difficulty_scores['entropy'][topic_doc_mask])
#             avg_composite = np.mean(difficulty_scores['composite'][topic_doc_mask])
#         else:
#             avg_difficulty = 0.5
#             avg_composite = 0.5
        
#         # Node size based on topic weight
#         node_size = topic_weights_norm[i] * 100 + 10
        
#         # Number of documents for this topic
#         n_docs = np.sum(topic_doc_mask)
        
#         # Add to graph
#         G.add_node(i,
#                    label=topic_labels[i],
#                    full_label=topic_full_labels[i],
#                    top_words=topic_words_list[i],
#                    size=node_size,
#                    avg_difficulty=avg_difficulty,
#                    avg_composite=avg_composite,
#                    weight=topic_weights_norm[i],
#                    n_docs=n_docs,
#                    topic_id=i)
    
#     # Add edges based on similarity
#     if len(topic_similarity) > 1:
#         # Flatten similarity matrix excluding diagonal
#         flat_similarities = topic_similarity[np.triu_indices_from(topic_similarity, k=1)]
#         if len(flat_similarities) > 0:
#             threshold = np.percentile(flat_similarities, 75)
#         else:
#             threshold = 0.5
#     else:
#         threshold = 0.5
    
#     edges = []
#     edge_weights = []
    
#     for i in range(topic_vectors.shape[0]):
#         for j in range(i+1, topic_vectors.shape[0]):
#             if topic_similarity[i, j] > threshold:
#                 G.add_edge(i, j, weight=topic_similarity[i, j])
#                 edges.append((i, j))
#                 edge_weights.append(topic_similarity[i, j])
    
#     # Use spring layout for node positions
#     if len(G.nodes()) > 0:
#         pos = nx.spring_layout(G, k=1.5, iterations=50, seed=42)
#     else:
#         print(f"  Warning: No nodes in graph for {dataset_name}")
#         return None, None
    
#     # Create edge traces
#     edge_x = []
#     edge_y = []
#     edge_hovertext = []
    
#     for edge in G.edges():
#         x0, y0 = pos[edge[0]]
#         x1, y1 = pos[edge[1]]
#         edge_x.extend([x0, x1, None])
#         edge_y.extend([y0, y1, None])
        
#         # Get edge weight
#         weight = G[edge[0]][edge[1]]['weight']
        
#         # Edge hover text
#         edge_hovertext.append(
#             f"Topic {edge[0]} ↔ Topic {edge[1]}<br>"
#             f"Similarity: {weight:.3f}"
#         )
    
#     edge_trace = go.Scatter(
#         x=edge_x, y=edge_y,
#         line=dict(width=1.5, color='#888'),
#         hoverinfo='none',
#         mode='lines',
#         showlegend=False
#     )
    
#     # Create node traces
#     node_x = []
#     node_y = []
#     node_text = []
#     node_hovertext = []
#     node_sizes = []
#     node_colors = []
    
#     for node in G.nodes():
#         x, y = pos[node]
#         node_x.append(x)
#         node_y.append(y)
        
#         # Get node attributes
#         node_data = G.nodes[node]
#         node_text.append(node_data['label'])
        
#         # Create hover text with topic words
#         if node_data['top_words']:
#             words_str = "<br>".join([f"• {word}" for word in node_data['top_words']])
#         else:
#             words_str = "No words available"
        
#         hover_text = (
#             f"<b>Topic {node_data['topic_id']}</b><br>"
#             f"{words_str}<br><br>"
#             f"Average Difficulty (Entropy): {node_data['avg_difficulty']:.3f}<br>"
#             f"Average Difficulty (Composite): {node_data['avg_composite']:.3f}<br>"
#             f"Topic Weight: {node_data['weight']:.3f}<br>"
#             f"Number of Documents: {node_data['n_docs']}"
#         )
#         node_hovertext.append(hover_text)
        
#         # Node size based on topic weight
#         node_sizes.append(node_data['size'])
        
#         # Node color based on average difficulty
#         node_colors.append(node_data['avg_composite'])
    
#     node_trace = go.Scatter(
#         x=node_x, y=node_y,
#         mode='markers+text',
#         text=node_text,
#         textposition="top center",
#         hoverinfo='text',
#         hovertext=node_hovertext,
#         marker=dict(
#             showscale=True,
#             colorscale='RdYlBu_r',  # Reversed: red=hard, blue=easy
#             color=node_colors,
#             size=node_sizes,
#             colorbar=dict(
#                 thickness=15,
#                 title='Average Difficulty<br>(Composite)',
#                 xanchor='left',
#                 title_side='right'  # Fixed: changed from titleside to title_side
#             ),
#             line=dict(width=2, color='white')
#         ),
#         showlegend=False
#     )
    
#     # Create the figure
#     fig = go.Figure(data=[edge_trace, node_trace],
#                    layout=go.Layout(
#                        title=f'Interactive Topic Network - {dataset_name}<br>'
#                             f'Node size ∝ Topic weight, Color ∝ Average difficulty',
#                        titlefont_size=16,
#                        showlegend=False,
#                        hovermode='closest',
#                        margin=dict(b=20, l=5, r=5, t=40),
#                        annotations=[dict(
#                            text=f"Network Statistics:<br>"
#                                f"Nodes: {G.number_of_nodes()}<br>"
#                                f"Edges: {G.number_of_edges()}<br>"
#                                f"Edge threshold: {threshold:.3f}",
#                            showarrow=False,
#                            xref="paper", yref="paper",
#                            x=0.02, y=0.02,
#                            bgcolor="white",
#                            bordercolor="black",
#                            borderwidth=1,
#                            borderpad=4
#                        )],
#                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
#                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
#                    ))
    
#     # Save as HTML
#     output_path = f"results/tmcl/topic_models/{dataset_name}_interactive_network.html"
#     os.makedirs(os.path.dirname(output_path), exist_ok=True)
#     fig.write_html(output_path)
#     print(f"  Interactive network saved to: {output_path}")
    
#     # Also create a static version if kaleido is available
#     try:
#         fig.update_layout(
#             width=1200,
#             height=800,
#             font=dict(size=12)
#         )
#         static_output_path = f"results/tmcl/topic_models/{dataset_name}_interactive_network_static.png"
#         fig.write_image(static_output_path)
#         print(f"  Static version saved to: {static_output_path}")
#     except Exception as e:
#         print(f"  Note: Could not save static image. Install kaleido: pip install kaleido")
    
#     return fig, G

# def create_interactive_topic_dashboard(dataset_name, results, feature_names=None):
#     """
#     Create comprehensive interactive dashboard for topic analysis
    
#     Args:
#         dataset_name: Name of the dataset
#         results: Processing results
#         feature_names: Vocabulary for text datasets
#     """
    
#     print(f"\nCreating interactive dashboard for {dataset_name}...")
    
#     # Extract components
#     lda_model = results['lda_model']
#     topic_distributions = results['topic_distributions']
#     difficulty_scores = results['difficulty_scores']
#     labels = results.get('labels', None)
    
#     # Calculate topic statistics
#     topic_vectors = lda_model.components_
#     topic_weights = np.sum(topic_distributions, axis=0)
#     topic_weights_norm = topic_weights / np.sum(topic_weights)
    
#     # Create DataFrame for analysis
#     analysis_data = []
#     for i in range(topic_vectors.shape[0]):
#         topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
#         n_docs = np.sum(topic_doc_mask)
        
#         if n_docs > 0:
#             avg_entropy = np.mean(difficulty_scores['entropy'][topic_doc_mask])
#             avg_max_prob = np.mean(difficulty_scores['max_prob'][topic_doc_mask])
#             avg_composite = np.mean(difficulty_scores['composite'][topic_doc_mask])
#             topic_strength = np.mean(topic_distributions[topic_doc_mask, i])
            
#             # Get top words
#             top_words = []
#             if feature_names is not None and len(feature_names) > 0:
#                 topic_vector = topic_vectors[i]
#                 top_indices = np.argsort(topic_vector)[-10:][::-1]
#                 for idx in top_indices:
#                     if idx < len(feature_names):
#                         if isinstance(feature_names[idx], (int, np.integer)):
#                             top_words.append(f"word_{idx}")
#                         else:
#                             top_words.append(str(feature_names[idx]))
#                     else:
#                         top_words.append(f"word_{idx}")
            
#             analysis_data.append({
#                 'Topic': f'T{i}',
#                 'Topic_ID': i,
#                 'N_Docs': n_docs,
#                 'Doc_Proportion': n_docs / len(topic_distributions),
#                 'Avg_Entropy_Difficulty': avg_entropy,
#                 'Avg_MaxProb_Difficulty': avg_max_prob,
#                 'Avg_Composite_Difficulty': avg_composite,
#                 'Topic_Strength': topic_strength,
#                 'Topic_Weight': topic_weights_norm[i],
#                 'Top_Words': ', '.join(top_words) if top_words else 'N/A'
#             })
    
#     analysis_df = pd.DataFrame(analysis_data)
    
#     # Create interactive dashboard with multiple plots
#     fig = make_subplots(
#         rows=2, cols=2,
#         subplot_titles=(
#             'Topic Difficulty Distribution',
#             'Topic Weight vs Difficulty',
#             'Difficulty Metrics by Topic',
#             'Topic Document Distribution'
#         ),
#         specs=[
#             [{'type': 'scatter'}, {'type': 'scatter'}],
#             [{'type': 'bar'}, {'type': 'pie'}]
#         ],
#         vertical_spacing=0.12,
#         horizontal_spacing=0.1
#     )
    
#     # Plot 1: Topic Difficulty Distribution (scatter plot)
#     fig.add_trace(
#         go.Scatter(
#             x=analysis_df['Topic'],
#             y=analysis_df['Avg_Composite_Difficulty'],
#             mode='markers+text',
#             text=analysis_df['Topic'],
#             textposition='top center',
#             marker=dict(
#                 size=analysis_df['Topic_Weight'] * 100,
#                 color=analysis_df['Avg_Composite_Difficulty'],
#                 colorscale='RdYlBu_r',
#                 showscale=True,
#                 colorbar=dict(
#                     title='Difficulty',
#                     x=0.45,
#                     y=0.95,
#                     len=0.3
#                 ),
#                 line=dict(width=2, color='DarkSlateGrey')
#             ),
#             customdata=analysis_df[['Top_Words', 'N_Docs', 'Topic_Weight']].values,
#             hovertemplate=(
#                 '<b>Topic: %{x}</b><br>'
#                 'Difficulty: %{y:.3f}<br>'
#                 'Top Words: %{customdata[0]}<br>'
#                 'Documents: %{customdata[1]}<br>'
#                 'Weight: %{customdata[2]:.3f}<br>'
#                 '<extra></extra>'
#             ),
#             name='Topics'
#         ),
#         row=1, col=1
#     )
    
#     # Add difficulty thresholds
#     if len(analysis_df) > 0:
#         mean_diff = analysis_df['Avg_Composite_Difficulty'].mean()
#         std_diff = analysis_df['Avg_Composite_Difficulty'].std()
        
#         fig.add_hline(y=mean_diff, line_dash="dash", line_color="gray", 
#                       annotation_text=f"Mean: {mean_diff:.3f}", 
#                       annotation_position="top right", row=1, col=1)
#         fig.add_hline(y=mean_diff + std_diff, line_dash="dot", line_color="red", 
#                       annotation_text=f"+1 std", annotation_position="top right", row=1, col=1)
#         fig.add_hline(y=mean_diff - std_diff, line_dash="dot", line_color="green", 
#                       annotation_text=f"-1 std", annotation_position="top right", row=1, col=1)
    
#     # Plot 2: Topic Weight vs Difficulty (bubble chart)
#     if len(analysis_df) > 0:
#         fig.add_trace(
#             go.Scatter(
#                 x=analysis_df['Topic_Weight'],
#                 y=analysis_df['Avg_Composite_Difficulty'],
#                 mode='markers+text',
#                 text=analysis_df['Topic'],
#                 textposition='top center',
#                 marker=dict(
#                     size=analysis_df['N_Docs'] / max(analysis_df['N_Docs']) * 50 + 10 if max(analysis_df['N_Docs']) > 0 else 20,
#                     color=analysis_df['Avg_Composite_Difficulty'],
#                     colorscale='Viridis',
#                     showscale=False
#                 ),
#                 customdata=analysis_df[['Top_Words', 'N_Docs', 'Topic_ID']].values,
#                 hovertemplate=(
#                     '<b>Topic: %{text}</b><br>'
#                     'Weight: %{x:.3f}<br>'
#                     'Difficulty: %{y:.3f}<br>'
#                     'Top Words: %{customdata[0]}<br>'
#                     'Documents: %{customdata[1]}<br>'
#                     'ID: %{customdata[2]}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Weight vs Difficulty'
#             ),
#             row=1, col=2
#         )
        
#         # Add trend line
#         if len(analysis_df) > 1:
#             z = np.polyfit(analysis_df['Topic_Weight'], analysis_df['Avg_Composite_Difficulty'], 1)
#             p = np.poly1d(z)
#             x_range = np.linspace(analysis_df['Topic_Weight'].min(), analysis_df['Topic_Weight'].max(), 100)
#             fig.add_trace(
#                 go.Scatter(
#                     x=x_range,
#                     y=p(x_range),
#                     mode='lines',
#                     line=dict(color='red', dash='dash'),
#                     name='Trend',
#                     showlegend=False
#                 ),
#                 row=1, col=2
#             )
    
#     # Plot 3: Difficulty Metrics by Topic (grouped bar chart)
#     if len(analysis_df) > 0:
#         for i, metric in enumerate(['Avg_Entropy_Difficulty', 'Avg_MaxProb_Difficulty', 'Avg_Composite_Difficulty']):
#             fig.add_trace(
#                 go.Bar(
#                     x=analysis_df['Topic'],
#                     y=analysis_df[metric],
#                     name=metric.replace('Avg_', '').replace('_', ' '),
#                     marker_color=px.colors.qualitative.Set2[i],
#                     hovertemplate=(
#                         '<b>Topic: %{x}</b><br>'
#                         'Metric: %{fullData.name}<br>'
#                         'Value: %{y:.3f}<br>'
#                         '<extra></extra>'
#                     ),
#                     showlegend=True if i == 0 else False
#                 ),
#                 row=2, col=1
#             )
    
#     # Plot 4: Topic Document Distribution (pie chart)
#     if len(analysis_df) > 0:
#         fig.add_trace(
#             go.Pie(
#                 labels=analysis_df['Topic'],
#                 values=analysis_df['N_Docs'],
#                 textinfo='label+percent',
#                 textposition='inside',
#                 hole=0.4,
#                 marker=dict(
#                     colors=px.colors.qualitative.Plotly,
#                     line=dict(color='white', width=2)
#                 ),
#                 hovertemplate=(
#                     '<b>Topic: %{label}</b><br>'
#                     'Documents: %{value}<br>'
#                     'Percentage: %{percent}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Document Distribution'
#             ),
#             row=2, col=2
#         )
    
#     # Update layout
#     fig.update_layout(
#         title=f'Interactive Topic Analysis Dashboard - {dataset_name}',
#         title_font_size=20,
#         showlegend=True,
#         legend=dict(
#             yanchor="top",
#             y=0.99,
#             xanchor="left",
#             x=1.02,
#             bgcolor='rgba(255, 255, 255, 0.8)',
#             bordercolor='black',
#             borderwidth=1
#         ),
#         hovermode='closest',
#         height=1000,
#         width=1400,
#         template='plotly_white'
#     )
    
#     # Update axes labels
#     fig.update_xaxes(title_text="Topic", row=1, col=1)
#     fig.update_yaxes(title_text="Composite Difficulty", row=1, col=1)
#     fig.update_xaxes(title_text="Topic Weight", row=1, col=2)
#     fig.update_yaxes(title_text="Composite Difficulty", row=1, col=2)
#     fig.update_xaxes(title_text="Topic", row=2, col=1)
#     fig.update_yaxes(title_text="Difficulty Score", row=2, col=1)
    
#     # Save dashboard
#     dashboard_path = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.html"
#     os.makedirs(os.path.dirname(dashboard_path), exist_ok=True)
#     fig.write_html(dashboard_path)
#     print(f"  Interactive dashboard saved to: {dashboard_path}")
    
#     # Also save static version if kaleido is available
#     try:
#         static_dashboard_path = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.png"
#         fig.write_image(static_dashboard_path, width=1400, height=1000)
#         print(f"  Static dashboard saved to: {static_dashboard_path}")
#     except Exception as e:
#         print(f"  Note: Could not save static image. Install kaleido: pip install kaleido")
    
#     return fig

# def create_topic_evolution_timeline(dataset_name, results, feature_names=None, n_samples_per_topic=100):
#     """
#     Create interactive timeline showing topic evolution and difficulty
    
#     Args:
#         dataset_name: Name of the dataset
#         results: Processing results
#         feature_names: Vocabulary for text datasets
#         n_samples_per_topic: Number of samples to show per topic
#     """
    
#     print(f"\nCreating topic evolution timeline for {dataset_name}...")
    
#     # Extract components
#     lda_model = results['lda_model']
#     topic_distributions = results['topic_distributions']
#     difficulty_scores = results['difficulty_scores']
    
#     # Get topic components
#     topic_vectors = lda_model.components_
#     n_topics = topic_vectors.shape[0]
    
#     # Create figure
#     fig = go.Figure()
    
#     # Get top words for each topic for hover text
#     topic_words = []
#     for i in range(n_topics):
#         if feature_names is not None and len(feature_names) > 0:
#             topic_vector = topic_vectors[i]
#             top_indices = np.argsort(topic_vector)[-5:][::-1]
#             words = []
#             for idx in top_indices:
#                 if idx < len(feature_names):
#                     if isinstance(feature_names[idx], (int, np.integer)):
#                         words.append(f"word_{idx}")
#                     else:
#                         words.append(str(feature_names[idx]))
#                 else:
#                     words.append(f"word_{idx}")
#             topic_words.append(words)
#         else:
#             topic_words.append([])
    
#     # For each topic, create a timeline trace
#     colors = px.colors.qualitative.Plotly
    
#     for i in range(n_topics):
#         # Get indices of documents where this topic is dominant
#         topic_doc_mask = np.argmax(topic_distributions, axis=1) == i
        
#         if np.sum(topic_doc_mask) > 0:
#             # Get subset of documents for this topic
#             topic_doc_indices = np.where(topic_doc_mask)[0]
            
#             if len(topic_doc_indices) > n_samples_per_topic:
#                 # Sample documents
#                 sampled_indices = np.random.choice(topic_doc_indices, n_samples_per_topic, replace=False)
#             else:
#                 sampled_indices = topic_doc_indices
            
#             # Get topic strengths and difficulties for these documents
#             topic_strengths = topic_distributions[sampled_indices, i]
#             doc_difficulties = difficulty_scores['composite'][sampled_indices]
            
#             # Sort by topic strength
#             sort_idx = np.argsort(topic_strengths)
#             topic_strengths = topic_strengths[sort_idx]
#             doc_difficulties = doc_difficulties[sort_idx]
            
#             # Create hover text
#             hover_texts = []
#             for j, idx in enumerate(sampled_indices[sort_idx]):
#                 hover_text = (
#                     f"<b>Topic {i}</b><br>"
#                     f"Document {idx}<br>"
#                     f"Topic Strength: {topic_strengths[j]:.3f}<br>"
#                     f"Difficulty: {doc_difficulties[j]:.3f}"
#                 )
#                 if topic_words[i]:
#                     hover_text += f"<br>Top words: {', '.join(topic_words[i])}"
#                 hover_texts.append(hover_text)
            
#             # Add trace for this topic
#             fig.add_trace(go.Scatter(
#                 x=topic_strengths,
#                 y=doc_difficulties,
#                 mode='markers',
#                 name=f'Topic {i}',
#                 text=hover_texts,
#                 hoverinfo='text',
#                 marker=dict(
#                     size=8,
#                     opacity=0.6,
#                     color=colors[i % len(colors)],
#                     line=dict(width=1, color='white')
#                 ),
#                 showlegend=True
#             ))
    
#     # Update layout
#     fig.update_layout(
#         title=f'Topic Evolution Timeline - {dataset_name}',
#         xaxis_title='Topic Strength (How dominant the topic is)',
#         yaxis_title='Document Difficulty (Composite)',
#         legend_title='Topics',
#         hovermode='closest',
#         height=800,
#         width=1200,
#         template='plotly_white'
#     )
    
#     # Add trend line for each topic
#     for i, trace in enumerate(fig.data):
#         topic_strengths = trace.x
#         difficulties = trace.y
        
#         if len(topic_strengths) > 1:
#             # Calculate trend line
#             z = np.polyfit(topic_strengths, difficulties, 1)
#             p = np.poly1d(z)
#             x_range = np.linspace(min(topic_strengths), max(topic_strengths), 100)
            
#             fig.add_trace(go.Scatter(
#                 x=x_range,
#                 y=p(x_range),
#                 mode='lines',
#                 line=dict(color=trace.marker.color, dash='dash', width=1),
#                 showlegend=False,
#                 hoverinfo='skip',
#                 name=f'Trend Topic {i}'
#             ))
    
#     # Save timeline
#     timeline_path = f"results/tmcl/topic_models/{dataset_name}_topic_timeline.html"
#     os.makedirs(os.path.dirname(timeline_path), exist_ok=True)
#     fig.write_html(timeline_path)
#     print(f"  Topic timeline saved to: {timeline_path}")
    
#     return fig

# def create_interactive_topic_comparison(all_analysis_results):
#     """
#     Create interactive comparison of all datasets
    
#     Args:
#         all_analysis_results: Dictionary containing analysis results for all datasets
#     """
    
#     print(f"\nCreating interactive comparison across all datasets...")
    
#     # Collect comparison data
#     comparison_data = []
    
#     for dataset_name, analysis in all_analysis_results.items():
#         if analysis is not None:
#             topic_analysis = analysis.get('topic_analysis', pd.DataFrame())
#             difficulty_stats = analysis.get('difficulty_stats', {})
#             network_stats = analysis.get('network_stats', {})
            
#             if not topic_analysis.empty:
#                 # Get difficulty range
#                 if 'Avg_Composite_Difficulty' in topic_analysis.columns:
#                     difficulty_range = topic_analysis['Avg_Composite_Difficulty'].max() - topic_analysis['Avg_Composite_Difficulty'].min()
#                 else:
#                     difficulty_range = 0
                
#                 comparison_data.append({
#                     'Dataset': dataset_name,
#                     'Type': 'Text' if dataset_name in ['ag_news', 'imdb'] else 'Image',
#                     'N_Topics': analysis.get('n_topics', 0),
#                     'N_Documents': analysis.get('n_documents', 0),
#                     'Perplexity': analysis.get('perplexity', 0),
#                     'Avg_Difficulty': difficulty_stats.get('overall_mean_composite', 0),
#                     'Std_Difficulty': difficulty_stats.get('overall_std_composite', 0),
#                     'Network_Density': network_stats.get('density', 0),
#                     'Difficulty_Range': difficulty_range
#                 })
    
#     if not comparison_data:
#         print("  No comparison data available")
#         return None
    
#     comparison_df = pd.DataFrame(comparison_data)
    
#     # Create interactive comparison dashboard
#     fig = make_subplots(
#         rows=2, cols=3,
#         subplot_titles=(
#             'Average Difficulty by Dataset',
#             'Number of Topics vs Documents',
#             'Network Density vs Difficulty',
#             'Perplexity Comparison',
#             'Difficulty Range Comparison',
#             'Dataset Clustering'
#         ),
#         specs=[
#             [{'type': 'bar'}, {'type': 'scatter'}, {'type': 'scatter'}],
#             [{'type': 'bar'}, {'type': 'bar'}, {'type': 'scatter'}]
#         ],
#         vertical_spacing=0.15,
#         horizontal_spacing=0.1
#     )
    
#     # Plot 1: Average difficulty by dataset
#     colors = px.colors.qualitative.Set2
#     for i, (_, row) in enumerate(comparison_df.iterrows()):
#         fig.add_trace(
#             go.Bar(
#                 x=[row['Dataset']],
#                 y=[row['Avg_Difficulty']],
#                 name=row['Dataset'],
#                 marker_color=colors[i % len(colors)],
#                 error_y=dict(
#                     type='data',
#                     array=[row['Std_Difficulty']],
#                     visible=True
#                 ),
#                 customdata=[[
#                     row['Type'],
#                     row['N_Topics'],
#                     row['N_Documents'],
#                     row['Perplexity']
#                 ]],
#                 hovertemplate=(
#                     '<b>%{x}</b><br>'
#                     'Type: %{customdata[0]}<br>'
#                     'Difficulty: %{y:.3f} ± %{error_y.array[0]:.3f}<br>'
#                     'Topics: %{customdata[1]}<br>'
#                     'Documents: %{customdata[2]:,}<br>'
#                     'Perplexity: %{customdata[3]:.2f}<br>'
#                     '<extra></extra>'
#                 ),
#                 showlegend=False
#             ),
#             row=1, col=1
#         )
    
#     # Plot 2: Number of topics vs documents
#     if len(comparison_df) > 0:
#         fig.add_trace(
#             go.Scatter(
#                 x=comparison_df['N_Topics'],
#                 y=comparison_df['N_Documents'] / 1000,  # Convert to thousands
#                 mode='markers+text',
#                 text=comparison_df['Dataset'],
#                 textposition='top center',
#                 marker=dict(
#                     size=comparison_df['Avg_Difficulty'] * 50,
#                     color=comparison_df['Avg_Difficulty'],
#                     colorscale='RdYlBu_r',
#                     showscale=True,
#                     colorbar=dict(title='Avg Difficulty', x=0.47, y=0.95, len=0.25),
#                     line=dict(width=2, color='black')
#                 ),
#                 customdata=comparison_df[['Type', 'Perplexity', 'Network_Density']].values,
#                 hovertemplate=(
#                     '<b>%{text}</b><br>'
#                     'Topics: %{x}<br>'
#                     'Documents (thousands): %{y:.1f}<br>'
#                     'Type: %{customdata[0]}<br>'
#                     'Perplexity: %{customdata[1]:.2f}<br>'
#                     'Network Density: %{customdata[2]:.3f}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Topics vs Documents'
#             ),
#             row=1, col=2
#         )
    
#     # Plot 3: Network density vs difficulty
#     if len(comparison_df) > 0:
#         fig.add_trace(
#             go.Scatter(
#                 x=comparison_df['Network_Density'],
#                 y=comparison_df['Avg_Difficulty'],
#                 mode='markers+text',
#                 text=comparison_df['Dataset'],
#                 textposition='top center',
#                 marker=dict(
#                     size=comparison_df['N_Topics'] * 10,
#                     color=comparison_df['Type'].map({'Text': 'blue', 'Image': 'red'}),
#                     symbol=comparison_df['Type'].map({'Text': 'circle', 'Image': 'square'}),
#                     line=dict(width=2, color='black')
#                 ),
#                 customdata=comparison_df[['N_Documents', 'Perplexity', 'Type']].values,
#                 hovertemplate=(
#                     '<b>%{text}</b><br>'
#                     'Network Density: %{x:.3f}<br>'
#                     'Avg Difficulty: %{y:.3f}<br>'
#                     'Type: %{customdata[2]}<br>'
#                     'Documents: %{customdata[0]:,}<br>'
#                     'Perplexity: %{customdata[1]:.2f}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Network vs Difficulty'
#             ),
#             row=1, col=3
#         )
    
#     # Plot 4: Perplexity comparison
#     if len(comparison_df) > 0:
#         fig.add_trace(
#             go.Bar(
#                 x=comparison_df['Dataset'],
#                 y=comparison_df['Perplexity'],
#                 marker_color=comparison_df['Type'].map({'Text': 'lightblue', 'Image': 'lightcoral'}),
#                 customdata=comparison_df[['Type', 'N_Topics', 'Avg_Difficulty']].values,
#                 hovertemplate=(
#                     '<b>%{x}</b><br>'
#                     'Perplexity: %{y:.2f}<br>'
#                     'Type: %{customdata[0]}<br>'
#                     'Topics: %{customdata[1]}<br>'
#                     'Avg Difficulty: %{customdata[2]:.3f}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Perplexity',
#                 showlegend=False
#             ),
#             row=2, col=1
#         )
    
#     # Plot 5: Difficulty range comparison - FIXED: removed colorscale from Bar trace
#     if len(comparison_df) > 0:
#         # Normalize difficulty range for coloring
#         if comparison_df['Difficulty_Range'].max() > comparison_df['Difficulty_Range'].min():
#             norm_range = (comparison_df['Difficulty_Range'] - comparison_df['Difficulty_Range'].min()) / \
#                         (comparison_df['Difficulty_Range'].max() - comparison_df['Difficulty_Range'].min())
#         else:
#             norm_range = pd.Series([0.5] * len(comparison_df))
        
#         # Use Plotly colorscale
#         colorscale = px.colors.sequential.Viridis
#         colors = [colorscale[int(val * (len(colorscale)-1))] for val in norm_range]
        
#         fig.add_trace(
#             go.Bar(
#                 x=comparison_df['Dataset'],
#                 y=comparison_df['Difficulty_Range'],
#                 marker_color=colors,
#                 customdata=comparison_df[['Type', 'Avg_Difficulty', 'Std_Difficulty']].values,
#                 hovertemplate=(
#                     '<b>%{x}</b><br>'
#                     'Difficulty Range: %{y:.3f}<br>'
#                     'Type: %{customdata[0]}<br>'
#                     'Avg Difficulty: %{customdata[1]:.3f}<br>'
#                     'Std Difficulty: %{customdata[2]:.3f}<br>'
#                     '<extra></extra>'
#                 ),
#                 name='Difficulty Range',
#                 showlegend=False
#             ),
#             row=2, col=2
#         )
    
#     # Plot 6: Dataset clustering (simplified)
#     if len(comparison_df) > 1:
#         try:
#             from sklearn.preprocessing import StandardScaler
#             from sklearn.decomposition import PCA
            
#             # Prepare data for PCA
#             cluster_data = comparison_df[['Avg_Difficulty', 'Std_Difficulty', 
#                                          'Network_Density', 'Difficulty_Range', 
#                                          'Perplexity']].values
#             cluster_data = StandardScaler().fit_transform(cluster_data)
            
#             # Apply PCA
#             pca = PCA(n_components=2)
#             pca_result = pca.fit_transform(cluster_data)
            
#             fig.add_trace(
#                 go.Scatter(
#                     x=pca_result[:, 0],
#                     y=pca_result[:, 1],
#                     mode='markers+text',
#                     text=comparison_df['Dataset'],
#                     textposition='top center',
#                     marker=dict(
#                         size=20,
#                         color=comparison_df['Type'].map({'Text': 'blue', 'Image': 'red'}),
#                         symbol=comparison_df['Type'].map({'Text': 'circle', 'Image': 'square'}),
#                         line=dict(width=2, color='black')
#                     ),
#                     customdata=comparison_df[['Type', 'N_Topics', 'N_Documents']].values,
#                     hovertemplate=(
#                         '<b>%{text}</b><br>'
#                         'PCA 1: %{x:.2f}<br>'
#                         'PCA 2: %{y:.2f}<br>'
#                         'Type: %{customdata[0]}<br>'
#                         'Topics: %{customdata[1]}<br>'
#                         'Documents: %{customdata[2]:,}<br>'
#                         '<extra></extra>'
#                     ),
#                     name='Dataset Clusters'
#                 ),
#                 row=2, col=3
#             )
            
#             # Add variance explained to axis labels
#             fig.update_xaxes(
#                 title_text=f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% variance)',
#                 row=2, col=3
#             )
#             fig.update_yaxes(
#                 title_text=f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% variance)',
#                 row=2, col=3
#             )
#         except Exception as e:
#             print(f"  Could not create clustering plot: {e}")
#             fig.add_annotation(
#                 text="Clustering not available",
#                 xref="x domain", yref="y domain",
#                 x=0.5, y=0.5,
#                 showarrow=False,
#                 font=dict(size=14),
#                 row=2, col=3
#             )
#     else:
#         fig.add_annotation(
#             text="Need at least 2 datasets for clustering",
#             xref="x domain", yref="y domain",
#             x=0.5, y=0.5,
#             showarrow=False,
#             font=dict(size=14),
#             row=2, col=3
#         )
    
#     # Update layout
#     fig.update_layout(
#         title='Interactive Dataset Comparison Dashboard',
#         title_font_size=20,
#         showlegend=False,
#         height=1000,
#         width=1600,
#         template='plotly_white'
#     )
    
#     # Update axes labels
#     fig.update_xaxes(title_text="Dataset", row=1, col=1)
#     fig.update_yaxes(title_text="Average Difficulty", row=1, col=1)
#     fig.update_xaxes(title_text="Number of Topics", row=1, col=2)
#     fig.update_yaxes(title_text="Documents (thousands)", row=1, col=2)
#     fig.update_xaxes(title_text="Network Density", row=1, col=3)
#     fig.update_yaxes(title_text="Average Difficulty", row=1, col=3)
#     fig.update_xaxes(title_text="Dataset", row=2, col=1)
#     fig.update_yaxes(title_text="Perplexity", row=2, col=1)
#     fig.update_xaxes(title_text="Dataset", row=2, col=2)
#     fig.update_yaxes(title_text="Difficulty Range", row=2, col=2)
    
#     # Save comparison dashboard
#     comparison_path = "results/tmcl/topic_models/interactive_dataset_comparison.html"
#     os.makedirs(os.path.dirname(comparison_path), exist_ok=True)
#     fig.write_html(comparison_path)
#     print(f"  Interactive comparison dashboard saved to: {comparison_path}")
    
#     return fig

# def create_all_interactive_visualizations():
#     """
#     Create all interactive visualizations for the analyzed datasets
#     """
    
#     print("\n" + "="*80)
#     print("CREATING INTERACTIVE TOPIC VISUALIZATIONS")
#     print("="*80)
    
#     # Try to import plotly for image export
#     try:
#         import plotly.io as pio
#         pio.kaleido.scope.default_format = "png"
#     except:
#         pass
    
#     # Load all results
#     datasets = ['ag_news', 'imdb', 'cifar10', 'cifar100', 'mnist', 'fashion_mnist']
#     all_analysis_results = {}
    
#     # Create visualizations for each dataset
#     for dataset_name in datasets:
#         print(f"\nProcessing {dataset_name}...")
        
#         # Load results
#         results_path = f"results/tmcl/topic_models/{dataset_name}_tmcl_results.pkl"
#         if os.path.exists(results_path):
#             try:
#                 with open(results_path, 'rb') as f:
#                     results = pickle.load(f)
                
#                 print(f"  Successfully loaded results from {results_path}")
                
#                 # Get feature names for text datasets
#                 feature_names = None
#                 if dataset_name in ['ag_news', 'imdb']:
#                     if 'vocabulary' in results:
#                         feature_names = results['vocabulary']
#                         print(f"  Using saved vocabulary of size {len(feature_names)}")
#                     elif 'vectorizer' in results:
#                         # Try to get vocabulary from vectorizer
#                         vectorizer = results['vectorizer']
#                         if hasattr(vectorizer, 'get_feature_names_out'):
#                             feature_names = vectorizer.get_feature_names_out()
#                             print(f"  Using vectorizer vocabulary of size {len(feature_names)}")
#                         elif hasattr(vectorizer, 'get_feature_names'):
#                             feature_names = vectorizer.get_feature_names()
#                             print(f"  Using vectorizer vocabulary of size {len(feature_names)}")
#                     else:
#                         # Try to load from a vocabulary file
#                         vocab_path = f"results/tmcl/topic_models/{dataset_name}_vocabulary.pkl"
#                         if os.path.exists(vocab_path):
#                             with open(vocab_path, 'rb') as f:
#                                 feature_names = pickle.load(f)
#                             print(f"  Loaded vocabulary from file of size {len(feature_names)}")
#                         else:
#                             # Create placeholder vocabulary
#                             if 'feature_matrix' in results:
#                                 n_features = results['feature_matrix'].shape[1]
#                             else:
#                                 n_features = 1000
#                             feature_names = [f'feature_{i}' for i in range(n_features)]
#                             print(f"  WARNING: Using placeholder feature names (n={n_features})")
                
#                 # 1. Create interactive topic network
#                 try:
#                     network_fig, network_graph = create_interactive_topic_network(
#                         dataset_name, results, feature_names
#                     )
#                     print(f"  ✓ Created interactive topic network")
#                 except Exception as e:
#                     print(f"  ✗ Error creating network: {e}")
                
#                 # 2. Create interactive dashboard
#                 try:
#                     dashboard_fig = create_interactive_topic_dashboard(
#                         dataset_name, results, feature_names
#                     )
#                     print(f"  ✓ Created interactive dashboard")
#                 except Exception as e:
#                     print(f"  ✗ Error creating dashboard: {e}")
                
#                 # 3. Create topic evolution timeline
#                 try:
#                     timeline_fig = create_topic_evolution_timeline(
#                         dataset_name, results, feature_names
#                     )
#                     print(f"  ✓ Created topic evolution timeline")
#                 except Exception as e:
#                     print(f"  ✗ Error creating timeline: {e}")
                
#                 # Load enhanced analysis results for comparison
#                 enhanced_path = f"results/tmcl/topic_models/{dataset_name}_enhanced_analysis.pkl"
#                 if os.path.exists(enhanced_path):
#                     try:
#                         with open(enhanced_path, 'rb') as f:
#                             all_analysis_results[dataset_name] = pickle.load(f)
#                         print(f"  ✓ Loaded enhanced analysis results")
#                     except:
#                         print(f"  ✗ Could not load enhanced analysis results")
#                 else:
#                     print(f"  Note: Enhanced analysis results not found at {enhanced_path}")
                    
#             except Exception as e:
#                 print(f"  ✗ Error loading results for {dataset_name}: {e}")
#         else:
#             print(f"  ✗ Results file not found for {dataset_name}")
    
#     # 4. Create interactive comparison across all datasets
#     if all_analysis_results:
#         try:
#             comparison_fig = create_interactive_topic_comparison(all_analysis_results)
#             if comparison_fig:
#                 print(f"  ✓ Created interactive comparison across datasets")
#         except Exception as e:
#             print(f"  ✗ Error creating comparison: {e}")
    
#     print("\n" + "="*80)
#     print("INTERACTIVE VISUALIZATIONS COMPLETE!")
#     print("="*80)
    
#     # Print summary of created files
#     print("\nCreated interactive visualizations:")
#     print("-" * 40)
#     for dataset_name in datasets:
#         print(f"\n{dataset_name}:")
#         network_file = f"results/tmcl/topic_models/{dataset_name}_interactive_network.html"
#         if os.path.exists(network_file):
#             print(f"  ✓ Network: {network_file}")
        
#         dashboard_file = f"results/tmcl/topic_models/{dataset_name}_interactive_dashboard.html"
#         if os.path.exists(dashboard_file):
#             print(f"  ✓ Dashboard: {dashboard_file}")
        
#         timeline_file = f"results/tmcl/topic_models/{dataset_name}_topic_timeline.html"
#         if os.path.exists(timeline_file):
#             print(f"  ✓ Timeline: {timeline_file}")
    
#     comparison_file = "results/tmcl/topic_models/interactive_dataset_comparison.html"
#     if os.path.exists(comparison_file):
#         print(f"\nCross-dataset comparison: {comparison_file}")
    
#     print("\n" + "="*80)
#     print("To view interactive visualizations:")
#     print("1. Open the HTML files in a web browser")
#     print("2. Hover over nodes/points to see details")
#     print("3. Use zoom and pan to explore the visualizations")
#     print("="*80)
    
#     # Additional instructions for text datasets
#     print("\nFor text datasets (ag_news, imdb):")
#     print("• Topic labels show T0, T1, etc. with top words")
#     print("• Hover over nodes to see all top words for each topic")
#     print("• Word indices are shown if actual words are not available")
#     print("="*80)

# # Execute the interactive visualizations
# if __name__ == "__main__":
#     create_all_interactive_visualizations()