In [20]:


import torch
import json
import numpy as np
import umap
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from pathlib import Path
import webbrowser
import os
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import pandas as pd
import hdbscan
from sklearn.metrics import silhouette_score

# Configure Plotly for browser display
import plotly.io as pio
pio.renderers.default = "browser"  # This will open plots in browser by default

# Set device and paths
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
base_path = Path("/Users/ivanculo/Desktop/Projects/turn_point")
activations_dir = base_path / "activations"

# Load data (same as test_single_pattern.py)
negative_activations = torch.load(activations_dir / "activations_8ff00d963316212d.pt", map_location=device)
positive_activations = torch.load(activations_dir / "activations_e5ad16e9b3c33c9b.pt", map_location=device)
transition_activations = torch.load(activations_dir / "activations_332f24de2a3f82ff.pt", map_location=device)

with open(base_path / "data" / "final" / "enriched_metadata.json", 'r') as f:
    metadata = json.load(f)

# Create pattern indices
pattern_indices = {}
for i, entry in enumerate(metadata):
    pattern_name = entry['bad_good_narratives_match']['cognitive_pattern_name_from_bad_good']
    if pattern_name not in pattern_indices:
        pattern_indices[pattern_name] = []
    pattern_indices[pattern_name].append(i)

layer = 17
first_pattern = list(pattern_indices.keys())[0]

def prepare_data_for_umap(neg_data, pos_data, trans_data, max_samples=1000, subsample_tokens=True):
    """Prepare data for UMAP analysis with aggressive sampling"""
    if subsample_tokens:
        # Take every 8th token and limit samples
        neg_flat = neg_data[:, ::8, :].reshape(-1, neg_data.shape[-1])
        pos_flat = pos_data[:, ::8, :].reshape(-1, pos_data.shape[-1])
        trans_flat = trans_data[:, ::8, :].reshape(-1, trans_data.shape[-1])
    else:
        neg_flat = neg_data.reshape(-1, neg_data.shape[-1])
        pos_flat = pos_data.reshape(-1, pos_data.shape[-1])
        trans_flat = trans_data.reshape(-1, trans_data.shape[-1])
    
    # Subsample to max_samples per category
    if len(neg_flat) > max_samples:
        indices = torch.randperm(len(neg_flat))[:max_samples]
        neg_flat = neg_flat[indices]
    if len(pos_flat) > max_samples:
        indices = torch.randperm(len(pos_flat))[:max_samples]
        pos_flat = pos_flat[indices]
    if len(trans_flat) > max_samples:
        indices = torch.randperm(len(trans_flat))[:max_samples]
        trans_flat = trans_flat[indices]
    
    # Combine data and create labels
    combined_data = torch.cat([neg_flat, pos_flat, trans_flat], dim=0).cpu().numpy()
    labels = ['Negative'] * len(neg_flat) + ['Positive'] * len(pos_flat) + ['Transition'] * len(trans_flat)
    colors = ['red'] * len(neg_flat) + ['green'] * len(pos_flat) + ['blue'] * len(trans_flat)
    
    return combined_data, labels, colors

def plot_umap_2d_3d(data, labels, colors, title_prefix=""):
    """Create interactive 2D and 3D UMAP plots with Plotly"""
    print(f"Computing UMAP for {len(data)} samples...")
    
    # 2D UMAP with faster parameters
    umap_2d = umap.UMAP(n_components=2, random_state=42, n_neighbors=min(15, len(data)//3), min_dist=0.1, n_jobs=1)
    embedding_2d = umap_2d.fit_transform(data)
    
    # 3D UMAP with faster parameters  
    umap_3d = umap.UMAP(n_components=3, random_state=42, n_neighbors=min(15, len(data)//3), min_dist=0.1, n_jobs=1)
    embedding_3d = umap_3d.fit_transform(data)
    
    # Create color mapping
    color_map = {'Negative': 'red', 'Positive': 'green', 'Transition': 'blue'}
    
    # Create 2D plot
    fig_2d = go.Figure()
    
    for label in ['Negative', 'Positive', 'Transition']:
        mask = [l == label for l in labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            fig_2d.add_trace(go.Scatter(
                x=embedding_2d[indices, 0],
                y=embedding_2d[indices, 1],
                mode='markers',
                marker=dict(
                    color=color_map[label],
                    size=4,
                    opacity=0.7
                ),
                name=label,
                hovertemplate=f'<b>{label}</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<extra></extra>'
            ))
    
    fig_2d.update_layout(
        title=f'{title_prefix}2D UMAP - Interactive Visualization',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        width=800,
        height=600,
        showlegend=True,
        hovermode='closest'
    )
    
    # Create 3D plot
    fig_3d = go.Figure()
    
    for label in ['Negative', 'Positive', 'Transition']:
        mask = [l == label for l in labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            fig_3d.add_trace(go.Scatter3d(
                x=embedding_3d[indices, 0],
                y=embedding_3d[indices, 1],
                z=embedding_3d[indices, 2],
                mode='markers',
                marker=dict(
                    color=color_map[label],
                    size=3,
                    opacity=0.7
                ),
                name=label,
                hovertemplate=f'<b>{label}</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<br>UMAP 3: %{{z:.2f}}<extra></extra>'
            ))
    
    fig_3d.update_layout(
        title=f'{title_prefix}3D UMAP - Interactive Visualization',
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ),
        width=800,
        height=600,
        showlegend=True
    )
    
    # Save and open plots in new browser tabs
    safe_title = title_prefix.replace(" ", "_").replace("-", "").strip("_")
    
    # Save 2D plot
    filename_2d = f"umap_2d_{safe_title}_{hash(str(embedding_2d.tolist())) % 10000}.html"
    fig_2d.write_html(filename_2d, auto_open=False)
    
    # Save 3D plot  
    filename_3d = f"umap_3d_{safe_title}_{hash(str(embedding_3d.tolist())) % 10000}.html"
    fig_3d.write_html(filename_3d, auto_open=False)
    
    # Open in new browser tabs
    abs_path_2d = os.path.abspath(filename_2d)
    abs_path_3d = os.path.abspath(filename_3d)
    
    print(f"Opening 2D plot: {filename_2d}")
    webbrowser.open(f'file://{abs_path_2d}', new=2)  # new=2 opens in new tab
    
    print(f"Opening 3D plot: {filename_3d}")  
    webbrowser.open(f'file://{abs_path_3d}', new=2)  # new=2 opens in new tab
    
    # Also display inline for Jupyter
    fig_2d.show()
    fig_3d.show()
    
    return embedding_2d, embedding_3d

def plot_depressive_only_umap(layer=17, max_samples=2000):
    """Visualize only depressive (negative) tokens across all cognitive patterns"""
    print(f"\n🔴 Depressive Dataset UMAP (all patterns, negative tokens only)")
    
    all_indices = [i for indices_list in pattern_indices.values() for i in indices_list]
    neg_all = negative_activations[f'negative_layer_{layer}'][all_indices]
    
    print(f"Depressive data shape: {neg_all.shape}")
    
    # Prepare data - only negative, but we'll create dummy pos/trans for consistency
    if max_samples and neg_all.shape[0] * neg_all.shape[1] > max_samples:
        # Subsample tokens more aggressively
        neg_flat = neg_all[:, ::12, :].reshape(-1, neg_all.shape[-1])
    else:
        neg_flat = neg_all.reshape(-1, neg_all.shape[-1])
    
    if len(neg_flat) > max_samples:
        indices = torch.randperm(len(neg_flat))[:max_samples]
        neg_flat = neg_flat[indices]
    
    # Create data with only negatives
    combined_data = neg_flat.cpu().numpy()
    labels = ['Depressive'] * len(neg_flat)
    colors = ['darkred'] * len(neg_flat)
    
    print(f"Combined depressive data shape: {combined_data.shape}")
    
    # Compute UMAP
    print(f"Computing UMAP for {len(combined_data)} depressive samples...")
    umap_2d = umap.UMAP(n_components=2, random_state=42, n_neighbors=min(15, len(combined_data)//3), min_dist=0.1, n_jobs=1)
    embedding_2d = umap_2d.fit_transform(combined_data)
    
    umap_3d = umap.UMAP(n_components=3, random_state=42, n_neighbors=min(15, len(combined_data)//3), min_dist=0.1, n_jobs=1)
    embedding_3d = umap_3d.fit_transform(combined_data)
    
    # Create 2D plot
    fig_2d = go.Figure()
    fig_2d.add_trace(go.Scatter(
        x=embedding_2d[:, 0],
        y=embedding_2d[:, 1],
        mode='markers',
        marker=dict(
            color='darkred',
            size=4,
            opacity=0.7
        ),
        name='Depressive',
        hovertemplate='<b>Depressive</b><br>UMAP 1: %{x:.2f}<br>UMAP 2: %{y:.2f}<extra></extra>'
    ))
    
    fig_2d.update_layout(
        title='Depressive Dataset Only - 2D UMAP',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        width=800,
        height=600,
        showlegend=True
    )
    
    # Create 3D plot
    fig_3d = go.Figure()
    fig_3d.add_trace(go.Scatter3d(
        x=embedding_3d[:, 0],
        y=embedding_3d[:, 1],
        z=embedding_3d[:, 2],
        mode='markers',
        marker=dict(
            color='darkred',
            size=3,
            opacity=0.7
        ),
        name='Depressive',
        hovertemplate='<b>Depressive</b><br>UMAP 1: %{x:.2f}<br>UMAP 2: %{y:.2f}<br>UMAP 3: %{z:.2f}<extra></extra>'
    ))
    
    fig_3d.update_layout(
        title='Depressive Dataset Only - 3D UMAP',
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ),
        width=800,
        height=600,
        showlegend=True
    )
    
    # Save and open plots
    filename_2d = f"umap_2d_depressive_only_{hash(str(embedding_2d.tolist())) % 10000}.html"
    filename_3d = f"umap_3d_depressive_only_{hash(str(embedding_3d.tolist())) % 10000}.html"
    
    fig_2d.write_html(filename_2d, auto_open=False)
    fig_3d.write_html(filename_3d, auto_open=False)
    
    print(f"Opening depressive 2D plot: {filename_2d}")
    webbrowser.open(f'file://{os.path.abspath(filename_2d)}', new=2)
    
    print(f"Opening depressive 3D plot: {filename_3d}")
    webbrowser.open(f'file://{os.path.abspath(filename_3d)}', new=2)
    
    fig_2d.show()
    fig_3d.show()
    
    return embedding_2d, embedding_3d

def plot_single_pattern_all_examples(pattern_name, layer=17, max_samples=1000):
    """Visualize all examples (neg, pos, trans) from one specific cognitive pattern"""
    print(f"\n🎯 Single Pattern UMAP: {pattern_name} (all examples)")
    
    if pattern_name not in pattern_indices:
        print(f"Pattern '{pattern_name}' not found. Available patterns:")
        for p in list(pattern_indices.keys())[:5]:
            print(f"  - {p}")
        return None, None
    
    indices = pattern_indices[pattern_name]
    neg_single = negative_activations[f'negative_layer_{layer}'][indices]
    pos_single = positive_activations[f'positive_layer_{layer}'][indices]
    trans_single = transition_activations[f'transition_layer_{layer}'][indices]
    
    print(f"Pattern data shapes - Neg: {neg_single.shape}, Pos: {pos_single.shape}, Trans: {trans_single.shape}")
    
    data_single, labels_single, colors_single = prepare_data_for_umap(neg_single, pos_single, trans_single, max_samples=max_samples)
    embedding_2d, embedding_3d = plot_umap_2d_3d(data_single, labels_single, colors_single, f"{pattern_name}_AllExamples_")
    
    return embedding_2d, embedding_3d

def plot_single_example(pattern_name, example_idx=0, layer=17):
    """Visualize just one example from one cognitive pattern (all its tokens)"""
    print(f"\n🔬 Single Example UMAP: {pattern_name} (example {example_idx})")
    
    if pattern_name not in pattern_indices:
        print(f"Pattern '{pattern_name}' not found. Available patterns:")
        for p in list(pattern_indices.keys())[:5]:
            print(f"  - {p}")
        return None, None
    
    indices = pattern_indices[pattern_name]
    if example_idx >= len(indices):
        print(f"Example index {example_idx} out of range. Pattern has {len(indices)} examples.")
        return None, None
    
    # Get single example
    single_idx = indices[example_idx]
    neg_example = negative_activations[f'negative_layer_{layer}'][single_idx:single_idx+1]
    pos_example = positive_activations[f'positive_layer_{layer}'][single_idx:single_idx+1]
    trans_example = transition_activations[f'transition_layer_{layer}'][single_idx:single_idx+1]
    
    print(f"Single example shapes - Neg: {neg_example.shape}, Pos: {pos_example.shape}, Trans: {trans_example.shape}")
    
    # Flatten to get all tokens from this single example
    neg_flat = neg_example.reshape(-1, neg_example.shape[-1])
    pos_flat = pos_example.reshape(-1, pos_example.shape[-1])
    trans_flat = trans_example.reshape(-1, trans_example.shape[-1])
    
    # Combine data
    combined_data = torch.cat([neg_flat, pos_flat, trans_flat], dim=0).cpu().numpy()
    labels = (['Negative'] * len(neg_flat) + 
              ['Positive'] * len(pos_flat) + 
              ['Transition'] * len(trans_flat))
    
    print(f"Combined single example data shape: {combined_data.shape}")
    
    if len(combined_data) < 10:
        print("⚠️  Very few tokens in this example - UMAP may not be meaningful")
    
    # Compute UMAP with adjusted parameters for small datasets
    n_neighbors = min(5, len(combined_data)//2) if len(combined_data) < 50 else min(15, len(combined_data)//3)
    
    print(f"Computing UMAP for {len(combined_data)} tokens from single example...")
    umap_2d = umap.UMAP(n_components=2, random_state=42, n_neighbors=n_neighbors, min_dist=0.1, n_jobs=1)
    embedding_2d = umap_2d.fit_transform(combined_data)
    
    umap_3d = umap.UMAP(n_components=3, random_state=42, n_neighbors=n_neighbors, min_dist=0.1, n_jobs=1)
    embedding_3d = umap_3d.fit_transform(combined_data)
    
    # Create color mapping
    color_map = {'Negative': 'red', 'Positive': 'green', 'Transition': 'blue'}
    
    # Create 2D plot
    fig_2d = go.Figure()
    for label in ['Negative', 'Positive', 'Transition']:
        mask = [l == label for l in labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            fig_2d.add_trace(go.Scatter(
                x=embedding_2d[indices, 0],
                y=embedding_2d[indices, 1],
                mode='markers',
                marker=dict(
                    color=color_map[label],
                    size=6,
                    opacity=0.8
                ),
                name=f'{label} (tokens)',
                hovertemplate=f'<b>{label} Token</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<extra></extra>'
            ))
    
    fig_2d.update_layout(
        title=f'Single Example: {pattern_name} (Example {example_idx}) - 2D UMAP',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        width=800,
        height=600,
        showlegend=True
    )
    
    # Create 3D plot
    fig_3d = go.Figure()
    for label in ['Negative', 'Positive', 'Transition']:
        mask = [l == label for l in labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            fig_3d.add_trace(go.Scatter3d(
                x=embedding_3d[indices, 0],
                y=embedding_3d[indices, 1],
                z=embedding_3d[indices, 2],
                mode='markers',
                marker=dict(
                    color=color_map[label],
                    size=4,
                    opacity=0.8
                ),
                name=f'{label} (tokens)',
                hovertemplate=f'<b>{label} Token</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<br>UMAP 3: %{{z:.2f}}<extra></extra>'
            ))
    
    fig_3d.update_layout(
        title=f'Single Example: {pattern_name} (Example {example_idx}) - 3D UMAP',
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ),
        width=800,
        height=600,
        showlegend=True
    )
    
    # Save and open plots
    safe_pattern = pattern_name.replace(" ", "_").replace("-", "_")
    filename_2d = f"umap_2d_single_example_{safe_pattern}_{example_idx}_{hash(str(embedding_2d.tolist())) % 10000}.html"
    filename_3d = f"umap_3d_single_example_{safe_pattern}_{example_idx}_{hash(str(embedding_3d.tolist())) % 10000}.html"
    
    fig_2d.write_html(filename_2d, auto_open=False)
    fig_3d.write_html(filename_3d, auto_open=False)
    
    print(f"Opening single example 2D plot: {filename_2d}")
    webbrowser.open(f'file://{os.path.abspath(filename_2d)}', new=2)
    
    print(f"Opening single example 3D plot: {filename_3d}")
    webbrowser.open(f'file://{os.path.abspath(filename_3d)}', new=2)
    
    fig_2d.show()
    fig_3d.show()
    
    return embedding_2d, embedding_3d

def perform_clustering_analysis(neg_data, pos_data, trans_data, layer=17, n_clusters=3, max_samples=1500):
    """
    Perform K-means clustering on each cognitive state separately and create UMAP visualization
    """
    print(f"\n🔬 CLUSTERING ANALYSIS (K-means with {n_clusters} clusters per state)")
    print("="*70)
    
    # Prepare data for each state
    def prepare_state_data(data, state_name, max_samples):
        # Subsample tokens if needed
        if data.shape[0] * data.shape[1] > max_samples:
            data_flat = data[:, ::8, :].reshape(-1, data.shape[-1])
        else:
            data_flat = data.reshape(-1, data.shape[-1])
        
        if len(data_flat) > max_samples:
            indices = torch.randperm(len(data_flat))[:max_samples]
            data_flat = data_flat[indices]
            
        return data_flat.cpu().numpy()
    
    # Prepare data for each cognitive state
    neg_flat = prepare_state_data(neg_data, "Negative", max_samples)
    pos_flat = prepare_state_data(pos_data, "Positive", max_samples)
    trans_flat = prepare_state_data(trans_data, "Transition", max_samples)
    
    print(f"Data shapes - Neg: {neg_flat.shape}, Pos: {pos_flat.shape}, Trans: {trans_flat.shape}")
    
    # Standardize data for better clustering
    scaler = StandardScaler()
    
    # Perform clustering on each state separately
    states_data = {
        'Negative': neg_flat,
        'Positive': pos_flat, 
        'Transition': trans_flat
    }
    
    clustered_data = {}
    cluster_results = {}
    
    for state_name, data in states_data.items():
        print(f"\n🎯 Clustering {state_name} state ({data.shape[0]} samples)...")
        
        # Standardize the data
        data_scaled = scaler.fit_transform(data)
        
        # Perform K-means clustering
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(data_scaled)
        
        # Store results
        clustered_data[state_name] = {
            'data': data,
            'data_scaled': data_scaled,
            'cluster_labels': cluster_labels,
            'kmeans_model': kmeans
        }
        
        # Print cluster statistics
        unique_labels, counts = np.unique(cluster_labels, return_counts=True)
        print(f"  Cluster distribution: {dict(zip(unique_labels, counts))}")
        
        cluster_results[state_name] = {
            'cluster_labels': cluster_labels,
            'cluster_centers': kmeans.cluster_centers_,
            'inertia': kmeans.inertia_
        }
    
    # Compute UMAP on combined data but keep cluster info
    combined_data = np.vstack([neg_flat, pos_flat, trans_flat])
    combined_labels = (['Negative'] * len(neg_flat) + 
                      ['Positive'] * len(pos_flat) + 
                      ['Transition'] * len(trans_flat))
    
    # Create detailed cluster labels combining state and cluster
    detailed_labels = []
    cluster_colors = []
    
    # Color palettes for each state
    neg_colors = ['#8B0000', '#DC143C', '#B22222']  # Dark red shades
    pos_colors = ['#006400', '#228B22', '#32CD32']  # Dark green shades  
    trans_colors = ['#00008B', '#4169E1', '#1E90FF']  # Blue shades
    
    color_maps = {
        'Negative': neg_colors,
        'Positive': pos_colors,
        'Transition': trans_colors
    }
    
    # Build combined labels and colors
    for state_name in ['Negative', 'Positive', 'Transition']:
        state_clusters = clustered_data[state_name]['cluster_labels']
        state_colors = color_maps[state_name]
        
        for cluster_id in state_clusters:
            detailed_labels.append(f"{state_name}_Cluster_{cluster_id}")
            cluster_colors.append(state_colors[cluster_id % len(state_colors)])
    
    print(f"\n🗺️  Computing UMAP for {len(combined_data)} total samples...")
    
    # Compute UMAP
    umap_2d = umap.UMAP(n_components=2, random_state=42, 
                       n_neighbors=min(15, len(combined_data)//3), 
                       min_dist=0.1, n_jobs=1)
    embedding_2d = umap_2d.fit_transform(combined_data)
    
    umap_3d = umap.UMAP(n_components=3, random_state=42,
                       n_neighbors=min(15, len(combined_data)//3),
                       min_dist=0.1, n_jobs=1)
    embedding_3d = umap_3d.fit_transform(combined_data)
    
    return {
        'embedding_2d': embedding_2d,
        'embedding_3d': embedding_3d,
        'combined_data': combined_data,
        'detailed_labels': detailed_labels,
        'cluster_colors': cluster_colors,
        'clustered_data': clustered_data,
        'cluster_results': cluster_results,
        'state_labels': combined_labels
    }

def plot_clustered_umap(clustering_results, title_prefix="Clustered"):
    """
    Create UMAP plots showing cluster assignments for each cognitive state
    """
    embedding_2d = clustering_results['embedding_2d']
    embedding_3d = clustering_results['embedding_3d']
    detailed_labels = clustering_results['detailed_labels']
    cluster_colors = clustering_results['cluster_colors']
    state_labels = clustering_results['state_labels']
    
    print(f"\n🎨 Creating clustered UMAP visualizations...")
    
    # Create 2D plot
    fig_2d = go.Figure()
    
    # Group by detailed cluster labels for legend
    unique_detailed_labels = list(set(detailed_labels))
    
    for label in unique_detailed_labels:
        mask = [l == label for l in detailed_labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            
            # Extract state and cluster info for hover
            state = label.split('_')[0]
            cluster_id = label.split('_')[-1]
            
            fig_2d.add_trace(go.Scatter(
                x=embedding_2d[indices, 0],
                y=embedding_2d[indices, 1],
                mode='markers',
                marker=dict(
                    color=cluster_colors[indices[0]],  # Use the assigned color
                    size=4,
                    opacity=0.7,
                    line=dict(width=0.5, color='white')
                ),
                name=f"{state} C{cluster_id}",
                hovertemplate=f'<b>{state} Cluster {cluster_id}</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<extra></extra>'
            ))
    
    fig_2d.update_layout(
        title=f'{title_prefix} - 2D UMAP with K-means Clusters',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        width=900,
        height=700,
        showlegend=True,
        hovermode='closest',
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02
        )
    )
    
    # Create 3D plot
    fig_3d = go.Figure()
    
    for label in unique_detailed_labels:
        mask = [l == label for l in detailed_labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            
            state = label.split('_')[0]
            cluster_id = label.split('_')[-1]
            
            fig_3d.add_trace(go.Scatter3d(
                x=embedding_3d[indices, 0],
                y=embedding_3d[indices, 1],
                z=embedding_3d[indices, 2],
                mode='markers',
                marker=dict(
                    color=cluster_colors[indices[0]],
                    size=3,
                    opacity=0.7,
                    line=dict(width=0.5, color='white')
                ),
                name=f"{state} C{cluster_id}",
                hovertemplate=f'<b>{state} Cluster {cluster_id}</b><br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<br>UMAP 3: %{{z:.2f}}<extra></extra>'
            ))
    
    fig_3d.update_layout(
        title=f'{title_prefix} - 3D UMAP with K-means Clusters',
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ),
        width=900,
        height=700,
        showlegend=True
    )
    
    # Save and open plots
    safe_title = title_prefix.replace(" ", "_").replace("-", "").strip("_")
    filename_2d = f"umap_2d_clustered_{safe_title}_{hash(str(embedding_2d.tolist())) % 10000}.html"
    filename_3d = f"umap_3d_clustered_{safe_title}_{hash(str(embedding_3d.tolist())) % 10000}.html"
    
    fig_2d.write_html(filename_2d, auto_open=False)
    fig_3d.write_html(filename_3d, auto_open=False)
    
    print(f"Opening clustered 2D plot: {filename_2d}")
    webbrowser.open(f'file://{os.path.abspath(filename_2d)}', new=2)
    
    print(f"Opening clustered 3D plot: {filename_3d}")
    webbrowser.open(f'file://{os.path.abspath(filename_3d)}', new=2)
    
    fig_2d.show()
    fig_3d.show()
    
    return fig_2d, fig_3d

def analyze_cluster_characteristics(clustering_results):
    """
    Analyze and print characteristics of each cluster
    """
    print(f"\n📊 CLUSTER ANALYSIS SUMMARY")
    print("="*50)
    
    clustered_data = clustering_results['clustered_data']
    cluster_results = clustering_results['cluster_results']
    
    for state_name, data_info in clustered_data.items():
        print(f"\n🔍 {state_name.upper()} STATE CLUSTERS:")
        
        cluster_labels = data_info['cluster_labels']
        unique_labels, counts = np.unique(cluster_labels, return_counts=True)
        
        total_samples = len(cluster_labels)
        
        for cluster_id, count in zip(unique_labels, counts):
            percentage = (count / total_samples) * 100
            print(f"  Cluster {cluster_id}: {count:4d} samples ({percentage:5.1f}%)")
        
        print(f"  Total inertia (within-cluster sum of squares): {cluster_results[state_name]['inertia']:.2f}")
    
    print(f"\n💡 INTERPRETATION GUIDE:")
    print("• Each cognitive state (Negative/Positive/Transition) has 3 distinct clusters")
    print("• Clusters represent different 'subtypes' or 'patterns' within each state")
    print("• Lower inertia = more compact, well-separated clusters")
    print("• In UMAP plots, look for:")
    print("  - Tight clusters = consistent activation patterns")  
    print("  - Scattered points = diverse activation patterns")
    print("  - Cluster separation = how distinct the subtypes are")

def perform_hdbscan_analysis(neg_data, pos_data, trans_data, layer=17, min_cluster_size=50, max_samples=1500):
    """
    Perform HDBSCAN clustering on each cognitive state separately
    """
    print(f"\n🌳 HDBSCAN ANALYSIS (Hierarchical density-based clustering)")
    print("="*70)
    print(f"Parameters: min_cluster_size={min_cluster_size}, max_samples={max_samples}")
    
    # Prepare data for each state (same as K-means)
    def prepare_state_data(data, state_name, max_samples):
        if data.shape[0] * data.shape[1] > max_samples:
            data_flat = data[:, ::8, :].reshape(-1, data.shape[-1])
        else:
            data_flat = data.reshape(-1, data.shape[-1])
        
        if len(data_flat) > max_samples:
            indices = torch.randperm(len(data_flat))[:max_samples]
            data_flat = data_flat[indices]
            
        return data_flat.cpu().numpy()
    
    # Prepare data for each cognitive state
    neg_flat = prepare_state_data(neg_data, "Negative", max_samples)
    pos_flat = prepare_state_data(pos_data, "Positive", max_samples)
    trans_flat = prepare_state_data(trans_data, "Transition", max_samples)
    
    print(f"Data shapes - Neg: {neg_flat.shape}, Pos: {pos_flat.shape}, Trans: {trans_flat.shape}")
    
    # Standardize data
    scaler = StandardScaler()
    
    # Perform HDBSCAN on each state separately
    states_data = {
        'Negative': neg_flat,
        'Positive': pos_flat, 
        'Transition': trans_flat
    }
    
    clustered_data = {}
    cluster_results = {}
    
    for state_name, data in states_data.items():
        print(f"\n🎯 HDBSCAN clustering {state_name} state ({data.shape[0]} samples)...")
        
        # Standardize the data
        data_scaled = scaler.fit_transform(data)
        
        # Perform HDBSCAN clustering
        hdbscan_model = hdbscan.HDBSCAN(
            min_cluster_size=min_cluster_size,
            min_samples=10,
            cluster_selection_epsilon=0.0,
            metric='euclidean',
            cluster_selection_method='eom'  # Excess of Mass
        )
        
        cluster_labels = hdbscan_model.fit_predict(data_scaled)
        
        # Store results
        clustered_data[state_name] = {
            'data': data,
            'data_scaled': data_scaled,
            'cluster_labels': cluster_labels,
            'hdbscan_model': hdbscan_model
        }
        
        # Print cluster statistics
        unique_labels, counts = np.unique(cluster_labels, return_counts=True)
        n_clusters = len(unique_labels) - (1 if -1 in unique_labels else 0)
        n_noise = counts[unique_labels == -1][0] if -1 in unique_labels else 0
        
        print(f"  Found {n_clusters} clusters")
        print(f"  Noise points: {n_noise} ({n_noise/len(cluster_labels)*100:.1f}%)")
        
        for cluster_id, count in zip(unique_labels, counts):
            if cluster_id == -1:
                print(f"  Noise: {count:4d} samples ({count/len(cluster_labels)*100:5.1f}%)")
            else:
                print(f"  Cluster {cluster_id}: {count:4d} samples ({count/len(cluster_labels)*100:5.1f}%)")
        
        # Calculate silhouette score (excluding noise points)
        if n_clusters > 1:
            non_noise_mask = cluster_labels != -1
            if np.sum(non_noise_mask) > 1:
                silhouette_avg = silhouette_score(data_scaled[non_noise_mask], 
                                                cluster_labels[non_noise_mask])
                print(f"  Silhouette score: {silhouette_avg:.3f}")
        
        cluster_results[state_name] = {
            'cluster_labels': cluster_labels,
            'n_clusters': n_clusters,
            'n_noise': n_noise,
            'cluster_probabilities': hdbscan_model.probabilities_ if hasattr(hdbscan_model, 'probabilities_') else None
        }
    
    # Compute UMAP on combined data
    combined_data = np.vstack([neg_flat, pos_flat, trans_flat])
    combined_labels = (['Negative'] * len(neg_flat) + 
                      ['Positive'] * len(pos_flat) + 
                      ['Transition'] * len(trans_flat))
    
    # Create detailed cluster labels and colors for HDBSCAN
    detailed_labels = []
    cluster_colors = []
    
    # Generate colors dynamically based on number of clusters found
    def generate_colors(n_clusters, base_color):
        """Generate n_clusters different shades of base_color"""
        import colorsys
        colors = []
        if base_color == 'red':
            base_hue = 0.0
        elif base_color == 'green':
            base_hue = 0.33
        else:  # blue
            base_hue = 0.67
        
        for i in range(n_clusters):
            # Vary saturation and lightness
            sat = 0.7 + 0.3 * (i / max(1, n_clusters-1))
            light = 0.3 + 0.4 * (i / max(1, n_clusters-1))
            rgb = colorsys.hsv_to_rgb(base_hue, sat, light)
            colors.append(f'rgb({int(rgb[0]*255)},{int(rgb[1]*255)},{int(rgb[2]*255)})')
        
        # Add noise color (gray)
        colors.append('#808080')
        return colors
    
    # Build combined labels and colors
    for state_name, base_color in [('Negative', 'red'), ('Positive', 'green'), ('Transition', 'blue')]:
        state_clusters = clustered_data[state_name]['cluster_labels']
        n_clusters = cluster_results[state_name]['n_clusters']
        state_colors = generate_colors(n_clusters, base_color)
        
        for cluster_id in state_clusters:
            if cluster_id == -1:
                detailed_labels.append(f"{state_name}_Noise")
                cluster_colors.append(state_colors[-1])  # Gray for noise
            else:
                detailed_labels.append(f"{state_name}_Cluster_{cluster_id}")
                cluster_colors.append(state_colors[cluster_id % len(state_colors)])
    
    print(f"\n🗺️  Computing UMAP for {len(combined_data)} total samples...")
    
    # Compute UMAP
    umap_2d = umap.UMAP(n_components=2, random_state=42, 
                       n_neighbors=min(15, len(combined_data)//3), 
                       min_dist=0.1, n_jobs=1)
    embedding_2d = umap_2d.fit_transform(combined_data)
    
    umap_3d = umap.UMAP(n_components=3, random_state=42,
                       n_neighbors=min(15, len(combined_data)//3),
                       min_dist=0.1, n_jobs=1)
    embedding_3d = umap_3d.fit_transform(combined_data)
    
    return {
        'embedding_2d': embedding_2d,
        'embedding_3d': embedding_3d,
        'combined_data': combined_data,
        'detailed_labels': detailed_labels,
        'cluster_colors': cluster_colors,
        'clustered_data': clustered_data,
        'cluster_results': cluster_results,
        'state_labels': combined_labels,
        'algorithm': 'HDBSCAN'
    }

def plot_hdbscan_umap(clustering_results, title_prefix="HDBSCAN_Clustered"):
    """
    Create UMAP plots showing HDBSCAN cluster assignments
    """
    embedding_2d = clustering_results['embedding_2d']
    embedding_3d = clustering_results['embedding_3d']
    detailed_labels = clustering_results['detailed_labels']
    cluster_colors = clustering_results['cluster_colors']
    
    print(f"\n🎨 Creating HDBSCAN UMAP visualizations...")
    
    # Create 2D plot
    fig_2d = go.Figure()
    
    # Group by detailed cluster labels for legend
    unique_detailed_labels = list(set(detailed_labels))
    # Sort to put noise at the end
    unique_detailed_labels.sort(key=lambda x: (x.endswith('_Noise'), x))
    
    for label in unique_detailed_labels:
        mask = [l == label for l in detailed_labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            
            # Extract state and cluster info for hover
            if label.endswith('_Noise'):
                state = label.split('_')[0]
                display_name = f"{state} Noise"
                hover_text = f"<b>{state} Noise Point</b>"
                marker_size = 3
                opacity = 0.4
            else:
                state = label.split('_')[0]
                cluster_id = label.split('_')[-1]
                display_name = f"{state} C{cluster_id}"
                hover_text = f"<b>{state} Cluster {cluster_id}</b>"
                marker_size = 4
                opacity = 0.7
            
            fig_2d.add_trace(go.Scatter(
                x=embedding_2d[indices, 0],
                y=embedding_2d[indices, 1],
                mode='markers',
                marker=dict(
                    color=cluster_colors[indices[0]],
                    size=marker_size,
                    opacity=opacity,
                    line=dict(width=0.5, color='white')
                ),
                name=display_name,
                hovertemplate=f'{hover_text}<br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<extra></extra>'
            ))
    
    fig_2d.update_layout(
        title=f'{title_prefix} - 2D UMAP with HDBSCAN Clusters',
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        width=900,
        height=700,
        showlegend=True,
        hovermode='closest',
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02
        )
    )
    
    # Create 3D plot
    fig_3d = go.Figure()
    
    for label in unique_detailed_labels:
        mask = [l == label for l in detailed_labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            
            if label.endswith('_Noise'):
                state = label.split('_')[0]
                display_name = f"{state} Noise"
                hover_text = f"<b>{state} Noise Point</b>"
                marker_size = 2
                opacity = 0.4
            else:
                state = label.split('_')[0]
                cluster_id = label.split('_')[-1]
                display_name = f"{state} C{cluster_id}"
                hover_text = f"<b>{state} Cluster {cluster_id}</b>"
                marker_size = 3
                opacity = 0.7
            
            fig_3d.add_trace(go.Scatter3d(
                x=embedding_3d[indices, 0],
                y=embedding_3d[indices, 1],
                z=embedding_3d[indices, 2],
                mode='markers',
                marker=dict(
                    color=cluster_colors[indices[0]],
                    size=marker_size,
                    opacity=opacity,
                    line=dict(width=0.5, color='white')
                ),
                name=display_name,
                hovertemplate=f'{hover_text}<br>UMAP 1: %{{x:.2f}}<br>UMAP 2: %{{y:.2f}}<br>UMAP 3: %{{z:.2f}}<extra></extra>'
            ))
    
    fig_3d.update_layout(
        title=f'{title_prefix} - 3D UMAP with HDBSCAN Clusters',
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3'
        ),
        width=900,
        height=700,
        showlegend=True
    )
    
    # Save and open plots
    safe_title = title_prefix.replace(" ", "_").replace("-", "").strip("_")
    filename_2d = f"umap_2d_hdbscan_{safe_title}_{hash(str(embedding_2d.tolist())) % 10000}.html"
    filename_3d = f"umap_3d_hdbscan_{safe_title}_{hash(str(embedding_3d.tolist())) % 10000}.html"
    
    fig_2d.write_html(filename_2d, auto_open=False)
    fig_3d.write_html(filename_3d, auto_open=False)
    
    print(f"Opening HDBSCAN 2D plot: {filename_2d}")
    webbrowser.open(f'file://{os.path.abspath(filename_2d)}', new=2)
    
    print(f"Opening HDBSCAN 3D plot: {filename_3d}")
    webbrowser.open(f'file://{os.path.abspath(filename_3d)}', new=2)
    
    fig_2d.show()
    fig_3d.show()
    
    return fig_2d, fig_3d

def analyze_hdbscan_characteristics(clustering_results):
    """
    Analyze and print characteristics of HDBSCAN clusters
    """
    print(f"\n📊 HDBSCAN CLUSTER ANALYSIS SUMMARY")
    print("="*50)
    
    clustered_data = clustering_results['clustered_data']
    cluster_results = clustering_results['cluster_results']
    
    for state_name, data_info in clustered_data.items():
        print(f"\n🔍 {state_name.upper()} STATE HDBSCAN CLUSTERS:")
        
        cluster_labels = data_info['cluster_labels']
        n_clusters = cluster_results[state_name]['n_clusters']
        n_noise = cluster_results[state_name]['n_noise']
        
        print(f"  Found {n_clusters} natural clusters")
        print(f"  Noise points: {n_noise} ({n_noise/len(cluster_labels)*100:.1f}%)")
        
        unique_labels, counts = np.unique(cluster_labels, return_counts=True)
        total_samples = len(cluster_labels)
        
        for cluster_id, count in zip(unique_labels, counts):
            percentage = (count / total_samples) * 100
            if cluster_id == -1:
                print(f"    Noise: {count:4d} samples ({percentage:5.1f}%)")
            else:
                print(f"    Cluster {cluster_id}: {count:4d} samples ({percentage:5.1f}%)")
    
    print(f"\n💡 HDBSCAN INTERPRETATION GUIDE:")
    print("• HDBSCAN finds natural clusters without pre-specifying the number")
    print("• Noise points = outliers that don't belong to any cluster")
    print("• Varying cluster sizes = different cognitive patterns have different prevalence")
    print("• Hierarchical structure = clusters can have sub-clusters")
    print("• In UMAP plots, look for:")
    print("  - Dense clusters = strong, consistent patterns")
    print("  - Noise points (gray) = unique or transitional states")
    print("  - Cluster boundaries = natural separations in the data")

print("🔍 UMAP Analysis of Cognitive Transformations")
print("=" * 50)


🔍 UMAP Analysis of Cognitive Transformations


In [5]:
# 1. Full dataset UMAP (all patterns, all tokens subsampled)
print("\n1. Full Dataset UMAP (all patterns)")
all_indices = [i for indices_list in pattern_indices.values() for i in indices_list]
neg_all = negative_activations[f'negative_layer_{layer}'][all_indices]
pos_all = positive_activations[f'positive_layer_{layer}'][all_indices]
trans_all = transition_activations[f'transition_layer_{layer}'][all_indices]

print(f"Full data shapes - Neg: {neg_all.shape}, Pos: {pos_all.shape}, Trans: {trans_all.shape}")

data_all, labels_all, colors_all = prepare_data_for_umap(neg_all, pos_all, trans_all)
print(f"Combined data shape: {data_all.shape}")

embedding_2d_all, embedding_3d_all = plot_umap_2d_3d(data_all, labels_all, colors_all, "Full Dataset - ")



1. Full Dataset UMAP (all patterns)
Full data shapes - Neg: torch.Size([520, 208, 2304]), Pos: torch.Size([520, 261, 2304]), Trans: torch.Size([520, 311, 2304])
Combined data shape: (3000, 2304)
Computing UMAP for 3000 samples...
Opening 2D plot: umap_2d_Full_Dataset_9021.html
Opening 3D plot: umap_3d_Full_Dataset_7055.html


In [6]:
# 2. Single cognitive pattern UMAP
print(f"\n2. Single Pattern UMAP: {first_pattern}")
indices = pattern_indices[first_pattern]
neg_single = negative_activations[f'negative_layer_{layer}'][indices]
pos_single = positive_activations[f'positive_layer_{layer}'][indices]
trans_single = transition_activations[f'transition_layer_{layer}'][indices]

print(f"Single pattern data shapes - Neg: {neg_single.shape}, Pos: {pos_single.shape}, Trans: {trans_single.shape}")

data_single, labels_single, colors_single = prepare_data_for_umap(neg_single, pos_single, trans_single)
embedding_2d_single, embedding_3d_single = plot_umap_2d_3d(data_single, labels_single, colors_single, f"{first_pattern} - ")



2. Single Pattern UMAP: Executive Fatigue & Avolition
Single pattern data shapes - Neg: torch.Size([40, 208, 2304]), Pos: torch.Size([40, 261, 2304]), Trans: torch.Size([40, 311, 2304])
Computing UMAP for 3000 samples...
Opening 2D plot: umap_2d_Executive_Fatigue_&_Avolition_3466.html
Opening 3D plot: umap_3d_Executive_Fatigue_&_Avolition_2827.html


In [7]:
# 3. Last token only UMAP (single pattern)
print(f"\n3. Last Token Only UMAP: {first_pattern}")
neg_last = neg_single[:, -1, :].unsqueeze(1)  # Keep token dimension
pos_last = pos_single[:, -1, :].unsqueeze(1)
trans_last = trans_single[:, -1, :].unsqueeze(1)

print(f"Last token shapes - Neg: {neg_last.shape}, Pos: {pos_last.shape}, Trans: {trans_last.shape}")

data_last, labels_last, colors_last = prepare_data_for_umap(neg_last, pos_last, trans_last, subsample_tokens=False)
embedding_2d_last, embedding_3d_last = plot_umap_2d_3d(data_last, labels_last, colors_last, "Last Token Only - ")

print("\n✅ UMAP Analysis Complete!")
print(f"📊 Analyzed {len(all_indices)} total samples across {len(pattern_indices)} cognitive patterns")
print(f"🎯 Single pattern '{first_pattern}' had {len(indices)} samples")
print(f"🔬 Layer {layer} activations visualized in 2D and 3D space")


3. Last Token Only UMAP: Executive Fatigue & Avolition
Last token shapes - Neg: torch.Size([40, 1, 2304]), Pos: torch.Size([40, 1, 2304]), Trans: torch.Size([40, 1, 2304])
Computing UMAP for 120 samples...
Opening 2D plot: umap_2d_Last_Token_Only_7531.html
Opening 3D plot: umap_3d_Last_Token_Only_3571.html

✅ UMAP Analysis Complete!
📊 Analyzed 520 total samples across 13 cognitive patterns
🎯 Single pattern 'Executive Fatigue & Avolition' had 40 samples
🔬 Layer 17 activations visualized in 2D and 3D space


In [None]:
# 4. Depressive Dataset Only UMAP (all cognitive patterns, negative tokens only)
print("\n" + "="*60)
print("🔴 FILTERED VISUALIZATIONS")
print("="*60)

# Show available patterns first
print(f"\n📋 Available Cognitive Patterns ({len(pattern_indices)} total):")
for i, pattern in enumerate(list(pattern_indices.keys())[:8]):  # Show first 8
    count = len(pattern_indices[pattern])
    print(f"  {i+1}. {pattern} ({count} examples)")
if len(pattern_indices) > 8:
    print(f"  ... and {len(pattern_indices) - 8} more patterns")

# Visualize only depressive (negative) tokens across all patterns
embedding_2d_depressive, embedding_3d_depressive = plot_depressive_only_umap(layer=layer, max_samples=2000)


In [10]:
# 5. Single Cognitive Pattern UMAP (all examples: negative, positive, transition)
selected_pattern = list(pattern_indices.keys())[0]  # Use first pattern
print(f"\n🎯 Analyzing single pattern: '{selected_pattern}'")

# You can change this to any pattern name from the list above
# selected_pattern = "catastrophizing"  # Example: uncomment and change to analyze a specific pattern

embedding_2d_pattern, embedding_3d_pattern = plot_single_pattern_all_examples(
    pattern_name=selected_pattern, 
    layer=layer, 
    max_samples=1000
)



🎯 Analyzing single pattern: 'Executive Fatigue & Avolition'

🎯 Single Pattern UMAP: Executive Fatigue & Avolition (all examples)
Pattern data shapes - Neg: torch.Size([40, 208, 2304]), Pos: torch.Size([40, 261, 2304]), Trans: torch.Size([40, 311, 2304])
Computing UMAP for 3000 samples...
Opening 2D plot: umap_2d_Executive_Fatigue_&_Avolition_AllExamples_7854.html
Opening 3D plot: umap_3d_Executive_Fatigue_&_Avolition_AllExamples_4865.html


In [11]:
# 6. Single Example UMAP (just one example from one pattern - all its tokens)
example_idx = 0  # First example
print(f"\n🔬 Analyzing single example from '{selected_pattern}' (example {example_idx})")

# This shows how tokens flow through the transformation within a single narrative
# You can change example_idx to analyze different examples (0, 1, 2, etc.)

embedding_2d_example, embedding_3d_example = plot_single_example(
    pattern_name=selected_pattern,
    example_idx=example_idx,
    layer=layer
)



🔬 Analyzing single example from 'Executive Fatigue & Avolition' (example 0)

🔬 Single Example UMAP: Executive Fatigue & Avolition (example 0)
Single example shapes - Neg: torch.Size([1, 208, 2304]), Pos: torch.Size([1, 261, 2304]), Trans: torch.Size([1, 311, 2304])
Combined single example data shape: (780, 2304)
Computing UMAP for 780 tokens from single example...
Opening single example 2D plot: umap_2d_single_example_Executive_Fatigue_&_Avolition_0_5260.html
Opening single example 3D plot: umap_3d_single_example_Executive_Fatigue_&_Avolition_0_9102.html


In [8]:
# 7. Summary and Analysis Guide
print("\n" + "="*60)
print("📊 ANALYSIS COMPLETE - INTERPRETATION GUIDE")
print("="*60)

print(f"""
🔍 **What You've Visualized:**

1. **Full Dataset UMAP** - Overview of all cognitive patterns
   - Shows general clustering of negative vs positive vs transition states
   - Reveals global structure across all {len(pattern_indices)} cognitive patterns

2. **Depressive Dataset Only** - Pure negative emotional states  
   - {embedding_2d_depressive.shape[0] if 'embedding_2d_depressive' in locals() else 'N/A'} depressive tokens visualized
   - Shows internal structure of negative thought patterns
   - Look for clusters that might represent different types of depressive thinking

3. **Single Pattern Analysis** - Focus on '{selected_pattern}'
   - All examples of this specific cognitive pattern
   - Shows how this pattern manifests across negative → positive → transition

4. **Single Example Analysis** - Micro-level view
   - Individual narrative token progression  
   - Shows the transformation journey within one story

🎯 **How to Use These Visualizations:**

• **Clusters** = Similar activation patterns (similar "thoughts")
• **Distance** = How different the neural representations are  
• **Transitions** = Look for paths between negative and positive regions
• **Outliers** = Unusual or unique activation patterns

🔧 **Customization Options:**
""")

print("# To analyze a different pattern:")
print("# selected_pattern = 'your_pattern_name_here'")
print("# embedding_2d, embedding_3d = plot_single_pattern_all_examples(selected_pattern)")
print()
print("# To analyze a different example:")  
print("# embedding_2d, embedding_3d = plot_single_example('pattern_name', example_idx=5)")
print()
print("# To focus on a different layer:")
print("# plot_depressive_only_umap(layer=15)  # Try layers 10-20")

print(f"\n✅ All visualizations saved as HTML files and opened in browser tabs!")
print(f"📁 Check your current directory for the generated .html files")



📊 ANALYSIS COMPLETE - INTERPRETATION GUIDE


NameError: name 'selected_pattern' is not defined

In [13]:
# 8. K-MEANS CLUSTERING ANALYSIS
print("\n" + "="*70)
print("🔬 CLUSTERING ANALYSIS - Finding Subtypes Within Each Cognitive State")
print("="*70)

# Perform clustering analysis on full dataset
all_indices = [i for indices_list in pattern_indices.values() for i in indices_list]
neg_all_cluster = negative_activations[f'negative_layer_{layer}'][all_indices]
pos_all_cluster = positive_activations[f'positive_layer_{layer}'][all_indices]
trans_all_cluster = transition_activations[f'transition_layer_{layer}'][all_indices]

print(f"🔍 Performing K-means clustering (3 clusters per cognitive state)")
print(f"Data shapes - Neg: {neg_all_cluster.shape}, Pos: {pos_all_cluster.shape}, Trans: {trans_all_cluster.shape}")

# Perform clustering analysis
clustering_results = perform_clustering_analysis(
    neg_all_cluster, 
    pos_all_cluster, 
    trans_all_cluster,
    layer=layer,
    n_clusters=3,
    max_samples=1500
)



🔬 CLUSTERING ANALYSIS - Finding Subtypes Within Each Cognitive State
🔍 Performing K-means clustering (3 clusters per cognitive state)
Data shapes - Neg: torch.Size([520, 208, 2304]), Pos: torch.Size([520, 261, 2304]), Trans: torch.Size([520, 311, 2304])

🔬 CLUSTERING ANALYSIS (K-means with 3 clusters per state)
Data shapes - Neg: (1500, 2304), Pos: (1500, 2304), Trans: (1500, 2304)

🎯 Clustering Negative state (1500 samples)...
  Cluster distribution: {0: 1079, 1: 58, 2: 363}

🎯 Clustering Positive state (1500 samples)...
  Cluster distribution: {0: 946, 1: 534, 2: 20}

🎯 Clustering Transition state (1500 samples)...
  Cluster distribution: {0: 394, 1: 1104, 2: 2}

🗺️  Computing UMAP for 4500 total samples...


In [14]:
# 9. Visualize Clustered Results
print("\n🎨 Creating interactive UMAP plots with cluster assignments...")

# Create clustered UMAP plots
fig_2d_clustered, fig_3d_clustered = plot_clustered_umap(
    clustering_results, 
    title_prefix="Full_Dataset_Clustered"
)

# Analyze cluster characteristics
analyze_cluster_characteristics(clustering_results)



🎨 Creating interactive UMAP plots with cluster assignments...

🎨 Creating clustered UMAP visualizations...
Opening clustered 2D plot: umap_2d_clustered_Full_Dataset_Clustered_5680.html
Opening clustered 3D plot: umap_3d_clustered_Full_Dataset_Clustered_4791.html

📊 CLUSTER ANALYSIS SUMMARY

🔍 NEGATIVE STATE CLUSTERS:
  Cluster 0: 1079 samples ( 71.9%)
  Cluster 1:   58 samples (  3.9%)
  Cluster 2:  363 samples ( 24.2%)
  Total inertia (within-cluster sum of squares): 3086096.75

🔍 POSITIVE STATE CLUSTERS:
  Cluster 0:  946 samples ( 63.1%)
  Cluster 1:  534 samples ( 35.6%)
  Cluster 2:   20 samples (  1.3%)
  Total inertia (within-cluster sum of squares): 3213893.00

🔍 TRANSITION STATE CLUSTERS:
  Cluster 0:  394 samples ( 26.3%)
  Cluster 1: 1104 samples ( 73.6%)
  Cluster 2:    2 samples (  0.1%)
  Total inertia (within-cluster sum of squares): 3274455.00

💡 INTERPRETATION GUIDE:
• Each cognitive state (Negative/Positive/Transition) has 3 distinct clusters
• Clusters represent dif

In [15]:
# 10. Single Pattern Clustering Analysis
print("\n" + "="*60)
print("🎯 SINGLE PATTERN CLUSTERING ANALYSIS")
print("="*60)

# Analyze clustering within a specific cognitive pattern
selected_pattern_cluster = list(pattern_indices.keys())[0]
print(f"Analyzing clustering within pattern: '{selected_pattern_cluster}'")

indices_cluster = pattern_indices[selected_pattern_cluster]
neg_single_cluster = negative_activations[f'negative_layer_{layer}'][indices_cluster]
pos_single_cluster = positive_activations[f'positive_layer_{layer}'][indices_cluster]
trans_single_cluster = transition_activations[f'transition_layer_{layer}'][indices_cluster]

print(f"Single pattern data shapes - Neg: {neg_single_cluster.shape}, Pos: {pos_single_cluster.shape}, Trans: {trans_single_cluster.shape}")

# Perform clustering on single pattern
single_pattern_clustering = perform_clustering_analysis(
    neg_single_cluster,
    pos_single_cluster, 
    trans_single_cluster,
    layer=layer,
    n_clusters=3,
    max_samples=800
)

# Visualize single pattern clustering
fig_2d_single_clustered, fig_3d_single_clustered = plot_clustered_umap(
    single_pattern_clustering,
    title_prefix=f"{selected_pattern_cluster}_Clustered"
)

# Analyze single pattern clusters
analyze_cluster_characteristics(single_pattern_clustering)



🎯 SINGLE PATTERN CLUSTERING ANALYSIS
Analyzing clustering within pattern: 'Executive Fatigue & Avolition'
Single pattern data shapes - Neg: torch.Size([40, 208, 2304]), Pos: torch.Size([40, 261, 2304]), Trans: torch.Size([40, 311, 2304])

🔬 CLUSTERING ANALYSIS (K-means with 3 clusters per state)
Data shapes - Neg: (800, 2304), Pos: (800, 2304), Trans: (800, 2304)

🎯 Clustering Negative state (800 samples)...
  Cluster distribution: {0: 16, 1: 752, 2: 32}

🎯 Clustering Positive state (800 samples)...
  Cluster distribution: {0: 159, 1: 538, 2: 103}

🎯 Clustering Transition state (800 samples)...
  Cluster distribution: {0: 776, 1: 23, 2: 1}

🗺️  Computing UMAP for 2400 total samples...

🎨 Creating clustered UMAP visualizations...
Opening clustered 2D plot: umap_2d_clustered_Executive_Fatigue_&_Avolition_Clustered_3477.html
Opening clustered 3D plot: umap_3d_clustered_Executive_Fatigue_&_Avolition_Clustered_3949.html

📊 CLUSTER ANALYSIS SUMMARY

🔍 NEGATIVE STATE CLUSTERS:
  Cluster 0:  

In [None]:
# 11. Final Summary - Clustering Analysis
print("\n" + "="*70)
print("📊 CLUSTERING ANALYSIS COMPLETE - RESEARCH INSIGHTS")
print("="*70)

print(f"""
🔬 **What the Clustering Analysis Reveals:**

1. **Subtype Discovery**: Each cognitive state (Negative, Positive, Transition) contains 3 distinct subtypes
   - These represent different neural "flavors" or patterns within each emotional state
   - E.g., different types of negative thinking, different positive emotions, different transition mechanisms

2. **Cluster Characteristics**:
   - **Tight clusters** = Consistent, well-defined activation patterns
   - **Scattered clusters** = More variable activation patterns  
   - **Separated clusters** = Distinct neural mechanisms

3. **Research Applications**:
   - **Therapeutic targeting**: Different subtypes might need different interventions
   - **Individual differences**: People might predominantly use certain cluster patterns
   - **Transition pathways**: Understand how specific negative subtypes transform to specific positive subtypes

🎨 **Visualization Features**:
• **Color coding**: Each state has 3 shades (dark→light) for its 3 clusters
• **Interactive plots**: Hover to see which cluster each point belongs to
• **Legend**: "Negative C0", "Positive C1", etc. for easy identification
• **Browser tabs**: All plots automatically open for detailed exploration

🔧 **Customization Options:**
""")

print("# To change number of clusters:")
print("# clustering_results = perform_clustering_analysis(neg_data, pos_data, trans_data, n_clusters=5)")
print()
print("# To analyze different layer:")
print("# clustering_results = perform_clustering_analysis(neg_data, pos_data, trans_data, layer=15)")
print()
print("# To cluster only one cognitive state:")
print("# from sklearn.cluster import KMeans")
print("# neg_flat = neg_data.reshape(-1, neg_data.shape[-1]).cpu().numpy()")
print("# kmeans = KMeans(n_clusters=3).fit(neg_flat)")
print("# labels = kmeans.labels_")

print(f"\n🎯 **Next Steps for Research:**")
print("• Examine what makes each cluster unique (feature analysis)")
print("• Study transition patterns between specific clusters")
print("• Correlate clusters with therapeutic outcomes")
print("• Investigate individual differences in cluster usage")

print(f"\n✅ Clustering analysis complete! Check browser tabs for interactive exploration.")
print(f"📁 HTML files saved in current directory for future reference.")


In [21]:
# 12. HDBSCAN CLUSTERING ANALYSIS
print("\n" + "="*70)
print("🌳 HDBSCAN ANALYSIS - Hierarchical Density-Based Clustering")
print("="*70)

print("""
🔍 **HDBSCAN vs K-means Comparison:**
• **K-means**: Forces exactly 3 clusters per state (pre-specified)
• **HDBSCAN**: Discovers natural clusters automatically (data-driven)
• **HDBSCAN**: Identifies outliers as "noise points"
• **HDBSCAN**: Can find clusters of varying shapes and densities
""")

# Perform HDBSCAN analysis on full dataset
print(f"🔍 Performing HDBSCAN clustering (automatic cluster discovery)")
print(f"Data shapes - Neg: {neg_all_cluster.shape}, Pos: {pos_all_cluster.shape}, Trans: {trans_all_cluster.shape}")

# Perform HDBSCAN analysis
hdbscan_results = perform_hdbscan_analysis(
    neg_all_cluster, 
    pos_all_cluster, 
    trans_all_cluster,
    layer=layer,
    min_cluster_size=50,  # Minimum points needed to form a cluster
    max_samples=1500
)



🌳 HDBSCAN ANALYSIS - Hierarchical Density-Based Clustering

🔍 **HDBSCAN vs K-means Comparison:**
• **K-means**: Forces exactly 3 clusters per state (pre-specified)
• **HDBSCAN**: Discovers natural clusters automatically (data-driven)
• **HDBSCAN**: Identifies outliers as "noise points"
• **HDBSCAN**: Can find clusters of varying shapes and densities

🔍 Performing HDBSCAN clustering (automatic cluster discovery)
Data shapes - Neg: torch.Size([520, 208, 2304]), Pos: torch.Size([520, 261, 2304]), Trans: torch.Size([520, 311, 2304])

🌳 HDBSCAN ANALYSIS (Hierarchical density-based clustering)
Parameters: min_cluster_size=50, max_samples=1500
Data shapes - Neg: (1500, 2304), Pos: (1500, 2304), Trans: (1500, 2304)

🎯 HDBSCAN clustering Negative state (1500 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 2 clusters
  Noise points: 384 (25.6%)
  Noise:  384 samples ( 25.6%)
  Cluster 0: 1049 samples ( 69.9%)
  Cluster 1:   67 samples (  4.5%)
  Silhouette score: 1.000

🎯 HDBSCAN clustering Positive state (1500 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 0 clusters
  Noise points: 1500 (100.0%)
  Noise: 1500 samples (100.0%)

🎯 HDBSCAN clustering Transition state (1500 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 0 clusters
  Noise points: 1500 (100.0%)
  Noise: 1500 samples (100.0%)

🗺️  Computing UMAP for 4500 total samples...


In [18]:
# 13. Visualize HDBSCAN Results
print("\n🎨 Creating HDBSCAN UMAP plots with noise point detection...")

# Create HDBSCAN UMAP plots
fig_2d_hdbscan, fig_3d_hdbscan = plot_hdbscan_umap(
    hdbscan_results, 
    title_prefix="Full_Dataset_HDBSCAN"
)

# Analyze HDBSCAN cluster characteristics
analyze_hdbscan_characteristics(hdbscan_results)



🎨 Creating HDBSCAN UMAP plots with noise point detection...

🎨 Creating HDBSCAN UMAP visualizations...
Opening HDBSCAN 2D plot: umap_2d_hdbscan_Full_Dataset_HDBSCAN_9059.html
Opening HDBSCAN 3D plot: umap_3d_hdbscan_Full_Dataset_HDBSCAN_4004.html

📊 HDBSCAN CLUSTER ANALYSIS SUMMARY

🔍 NEGATIVE STATE HDBSCAN CLUSTERS:
  Found 2 natural clusters
  Noise points: 399 (26.6%)
    Noise:  399 samples ( 26.6%)
    Cluster 0: 1048 samples ( 69.9%)
    Cluster 1:   53 samples (  3.5%)

🔍 POSITIVE STATE HDBSCAN CLUSTERS:
  Found 0 natural clusters
  Noise points: 1500 (100.0%)
    Noise: 1500 samples (100.0%)

🔍 TRANSITION STATE HDBSCAN CLUSTERS:
  Found 0 natural clusters
  Noise points: 1500 (100.0%)
    Noise: 1500 samples (100.0%)

💡 HDBSCAN INTERPRETATION GUIDE:
• HDBSCAN finds natural clusters without pre-specifying the number
• Noise points = outliers that don't belong to any cluster
• Varying cluster sizes = different cognitive patterns have different prevalence
• Hierarchical structure

In [23]:
# 14. Single Pattern HDBSCAN Analysis
print("\n" + "="*60)
print("🎯 SINGLE PATTERN HDBSCAN ANALYSIS")
print("="*60)

# Analyze HDBSCAN clustering within a specific cognitive pattern
print(f"Analyzing HDBSCAN clustering within pattern: '{selected_pattern_cluster}'")

# Perform HDBSCAN on single pattern
single_pattern_hdbscan = perform_hdbscan_analysis(
    neg_single_cluster,
    pos_single_cluster, 
    trans_single_cluster,
    layer=layer,
    min_cluster_size=60,  # Smaller min_cluster_size for single pattern
    max_samples=2000
)

# Visualize single pattern HDBSCAN
fig_2d_single_hdbscan, fig_3d_single_hdbscan = plot_hdbscan_umap(
    single_pattern_hdbscan,
    title_prefix=f"{selected_pattern_cluster}_HDBSCAN"
)

# Analyze single pattern HDBSCAN clusters
analyze_hdbscan_characteristics(single_pattern_hdbscan)



🎯 SINGLE PATTERN HDBSCAN ANALYSIS
Analyzing HDBSCAN clustering within pattern: 'Executive Fatigue & Avolition'

🌳 HDBSCAN ANALYSIS (Hierarchical density-based clustering)
Parameters: min_cluster_size=60, max_samples=2000
Data shapes - Neg: (1040, 2304), Pos: (1320, 2304), Trans: (1560, 2304)

🎯 HDBSCAN clustering Negative state (1040 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 0 clusters
  Noise points: 1040 (100.0%)
  Noise: 1040 samples (100.0%)

🎯 HDBSCAN clustering Positive state (1320 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 0 clusters
  Noise points: 1320 (100.0%)
  Noise: 1320 samples (100.0%)

🎯 HDBSCAN clustering Transition state (1560 samples)...



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.


'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



  Found 0 clusters
  Noise points: 1560 (100.0%)
  Noise: 1560 samples (100.0%)

🗺️  Computing UMAP for 3920 total samples...

🎨 Creating HDBSCAN UMAP visualizations...
Opening HDBSCAN 2D plot: umap_2d_hdbscan_Executive_Fatigue_&_Avolition_HDBSCAN_7835.html
Opening HDBSCAN 3D plot: umap_3d_hdbscan_Executive_Fatigue_&_Avolition_HDBSCAN_7776.html

📊 HDBSCAN CLUSTER ANALYSIS SUMMARY

🔍 NEGATIVE STATE HDBSCAN CLUSTERS:
  Found 0 natural clusters
  Noise points: 1040 (100.0%)
    Noise: 1040 samples (100.0%)

🔍 POSITIVE STATE HDBSCAN CLUSTERS:
  Found 0 natural clusters
  Noise points: 1320 (100.0%)
    Noise: 1320 samples (100.0%)

🔍 TRANSITION STATE HDBSCAN CLUSTERS:
  Found 0 natural clusters
  Noise points: 1560 (100.0%)
    Noise: 1560 samples (100.0%)

💡 HDBSCAN INTERPRETATION GUIDE:
• HDBSCAN finds natural clusters without pre-specifying the number
• Noise points = outliers that don't belong to any cluster
• Varying cluster sizes = different cognitive patterns have different prevalen

In [None]:
# 15. K-means vs HDBSCAN Comparison
print("\n" + "="*70)
print("⚖️  CLUSTERING ALGORITHM COMPARISON")
print("="*70)

def compare_clustering_results(kmeans_results, hdbscan_results):
    """Compare K-means and HDBSCAN clustering results"""
    
    print(f"""
🔬 **ALGORITHM COMPARISON SUMMARY:**

📊 **K-MEANS RESULTS:**""")
    
    for state_name in ['Negative', 'Positive', 'Transition']:
        kmeans_labels = kmeans_results['clustered_data'][state_name]['cluster_labels']
        unique_labels, counts = np.unique(kmeans_labels, return_counts=True)
        print(f"  {state_name}: {len(unique_labels)} clusters, sizes: {counts}")
    
    print(f"""
🌳 **HDBSCAN RESULTS:**""")
    
    for state_name in ['Negative', 'Positive', 'Transition']:
        hdbscan_labels = hdbscan_results['clustered_data'][state_name]['cluster_labels']
        n_clusters = hdbscan_results['cluster_results'][state_name]['n_clusters']
        n_noise = hdbscan_results['cluster_results'][state_name]['n_noise']
        unique_labels, counts = np.unique(hdbscan_labels, return_counts=True)
        cluster_counts = counts[unique_labels != -1] if -1 in unique_labels else counts
        print(f"  {state_name}: {n_clusters} clusters, {n_noise} noise points")
        print(f"    Cluster sizes: {cluster_counts}")
    
    print(f"""
💡 **KEY DIFFERENCES:**

🎯 **Cluster Discovery:**
• K-means: Fixed 3 clusters per state (forced partitioning)
• HDBSCAN: Variable clusters per state (natural discovery)

🔍 **Outlier Handling:**
• K-means: All points assigned to clusters (even outliers)
• HDBSCAN: Outliers identified as noise points (more realistic)

📏 **Cluster Shapes:**
• K-means: Assumes spherical clusters
• HDBSCAN: Finds clusters of any shape and varying density

🎨 **Visualization Differences:**
• K-means: Uniform 3-color scheme per state
• HDBSCAN: Dynamic colors + gray noise points

🔬 **Research Implications:**
• K-means: Good for comparing fixed number of subtypes
• HDBSCAN: Better for discovering natural cognitive structure
• Noise points may represent transitional or unique mental states
• Different algorithms may reveal different aspects of cognition
""")

# Perform the comparison
compare_clustering_results(clustering_results, hdbscan_results)


In [None]:
# 16. Final Summary - Complete Clustering Analysis
print("\n" + "="*80)
print("🎯 COMPLETE CLUSTERING ANALYSIS SUMMARY")
print("="*80)

print(f"""
🔬 **What You've Discovered:**

1. **Multiple Clustering Perspectives**:
   - **K-means**: Structured 3-cluster analysis per cognitive state
   - **HDBSCAN**: Natural cluster discovery with outlier detection

2. **Cognitive State Structure**:
   - **Negative states**: Multiple subtypes of depressive thinking
   - **Positive states**: Different flavors of positive emotions  
   - **Transition states**: Various transformation mechanisms
   - **Noise points**: Unique or transitional mental states

3. **Research Insights**:
   - Not all cognitive patterns fit into neat categories
   - Some mental states are truly unique (noise points)
   - Natural clustering reveals the true structure of cognition
   - Different algorithms highlight different aspects

🎨 **Visualization Arsenal Created:**
• Original UMAP plots (cognitive state overview)
• Filtered visualizations (depressive-only, single patterns, single examples)  
• K-means clustered plots (structured subtype analysis)
• HDBSCAN clustered plots (natural cluster discovery)
• All plots interactive and saved as HTML files

🔧 **Customization Examples:**
""")

print("# Try different HDBSCAN parameters:")
print("# hdbscan_results = perform_hdbscan_analysis(neg_data, pos_data, trans_data, min_cluster_size=30)")
print()
print("# Compare different layers:")
print("# layer_15_results = perform_hdbscan_analysis(neg_data, pos_data, trans_data, layer=15)")
print()
print("# Focus on specific cognitive states:")
print("# neg_only_hdbscan = hdbscan.HDBSCAN(min_cluster_size=50).fit_predict(neg_flat)")

print(f"""
🎯 **Next Research Directions:**
• Investigate what makes noise points unique
• Study cluster transitions across layers
• Correlate clusters with clinical outcomes
• Examine individual differences in cluster membership
• Use clusters to guide therapeutic interventions

🏆 **Achievement Unlocked:**
✅ Multi-algorithm clustering analysis complete!
✅ Natural cognitive structure discovered!
✅ Outlier detection implemented!
✅ Interactive visualizations created!
✅ Research insights generated!

📁 Check your directory for all the generated HTML visualization files!
🌐 All plots are now open in your browser for detailed exploration!
""")

print(f"🎉 Total analysis complete! You now have comprehensive insights into the")
print(f"   neural structure of cognitive transformations using multiple clustering approaches.")
