# Interactive Exploratory Data Analysis for Lattice Gaussian MCMC

This notebook provides interactive tools for exploring experimental results from lattice Gaussian sampling experiments.

## Contents
1. **Data Loading and Preprocessing**
2. **Visual Inspection Tools**
3. **Diagnostic Analysis**
4. **Rapid Prototyping**
5. **Documentation and Insights**
6. **Export and Reproducibility**

In [ ]:
# Standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import pickle
from datetime import datetime

# Interactive widgets
import ipywidgets as widgets
from IPython.display import display, HTML, Markdown

# Plotly for interactive plots
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Project imports
import sys
sys.path.append('..')
from src.visualization.plots import PlottingTools
from src.diagnostics.convergence import ConvergenceDiagnostics
from src.diagnostics.spectral import SpectralAnalysis
from src.diagnostics.mcmc import MCMCDiagnostics

# Configure plotting
# NOTE(review): the 'seaborn-v0_8-*' style names presumably need Matplotlib >= 3.6
# — confirm against the environment's pinned version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
# IPython magics: these two lines are only valid inside a Jupyter/IPython kernel,
# not in a plain Python interpreter.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Timestamp the kernel start so re-runs are easy to tell apart in the output.
print(f"Notebook initialized at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

## 1. Data Loading and Preprocessing

The `DataLoader` class lists result files under the `samples`, `diagnostics`, and `logs` directories and loads `.npy`, `.npz`, `.pkl`, and `.json` files, choosing the reader from the file extension.

In [ ]:
class DataLoader:
    """Interactive data loader for experimental results.

    Result files are resolved relative to ``base_path``. Every object returned
    by :meth:`load_samples` is kept in ``data_cache`` (keyed by the path that
    was passed in), so loading the same path twice reads the file only once.
    """

    # Sub-directories of ``base_path`` scanned by ``list_experiments``.
    EXPERIMENT_TYPES = ('samples', 'diagnostics', 'logs')

    # Columns of the listing; fixed up front so an empty listing keeps them.
    LISTING_COLUMNS = ('type', 'path', 'size', 'modified')

    def __init__(self, base_path='../results'):
        """Remember the results directory and start with an empty cache."""
        self.base_path = Path(base_path)
        self.data_cache = {}

    def list_experiments(self):
        """List all available experiment results.

        Returns:
            A DataFrame with one row per file found under the known
            experiment sub-directories, with the columns ``type``, ``path``
            (relative to ``base_path``), ``size`` and ``modified``. The
            columns are present even when no file is found, so callers can
            index them unconditionally.
        """
        records = []
        for exp_type in self.EXPERIMENT_TYPES:
            exp_dir = self.base_path / exp_type
            if not exp_dir.exists():
                continue
            # Sorted so the listing order is stable between runs.
            for file in sorted(exp_dir.glob('**/*')):
                if not file.is_file():
                    continue
                stat = file.stat()
                records.append({
                    'type': exp_type,
                    'path': str(file.relative_to(self.base_path)),
                    'size': f"{stat.st_size / 1024:.1f} KB",
                    'modified': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M'),
                })
        return pd.DataFrame(records, columns=list(self.LISTING_COLUMNS))

    def load_samples(self, path):
        """Load sample data from various formats.

        Args:
            path: File path relative to ``base_path``. The reader is chosen
                from the suffix: ``.npy``, ``.npz``, ``.pkl`` or ``.json``.

        Returns:
            The loaded object: an ndarray for ``.npy``, a plain dict of
            arrays for ``.npz``, and whatever was serialised for ``.pkl`` and
            ``.json``.

        Raises:
            ValueError: If the suffix is not one of the supported formats.
        """
        if path in self.data_cache:
            return self.data_cache[path]

        full_path = self.base_path / path
        suffix = full_path.suffix

        if suffix == '.npy':
            data = np.load(full_path)
        elif suffix == '.npz':
            # Copy the arrays out and close the archive instead of leaking
            # the open file handle held by the lazy NpzFile object.
            with np.load(full_path) as archive:
                data = {name: archive[name] for name in archive.files}
        elif suffix == '.pkl':
            # SECURITY: unpickling executes arbitrary code. Only load pickle
            # files that were produced by trusted experiment runs.
            with open(full_path, 'rb') as f:
                data = pickle.load(f)
        elif suffix == '.json':
            with open(full_path, 'r') as f:
                data = json.load(f)
        else:
            raise ValueError(f"Unsupported file format: {suffix}")

        self.data_cache[path] = data
        return data

    def create_interactive_loader(self):
        """Create interactive file browser widget.

        Shows a type filter, a search box, the filtered file listing, a
        'Load Selected' button that reveals a path entry row, and a status
        line describing the most recently loaded object.

        Returns:
            ``self``, so the call can be chained.
        """
        from html import escape

        experiments_df = self.list_experiments()

        # Filter widgets
        type_filter = widgets.Dropdown(
            options=['All'] + list(experiments_df['type'].unique()),
            value='All',
            description='Type:'
        )

        search_box = widgets.Text(
            placeholder='Search experiments...',
            description='Search:'
        )

        # Output areas: widgets shown from a callback must be displayed
        # inside an Output to appear in the cell.
        output = widgets.Output()
        load_area = widgets.Output()
        file_info = widgets.HTML()
        load_selected_btn = widgets.Button(description='Load Selected')

        def update_list(*args):
            """Re-render the listing using the current filter values."""
            with output:
                output.clear_output()

                df = experiments_df
                if type_filter.value != 'All':
                    df = df[df['type'] == type_filter.value]
                if search_box.value:
                    # Plain substring match: typed text such as "(" must not
                    # be interpreted as a regular expression.
                    mask = df['path'].str.contains(
                        search_box.value, case=False, regex=False
                    )
                    df = df[mask]

                display(df)

        def load_selected(_button=None):
            """Show the path entry row used to load a file."""
            selection = widgets.Text(
                placeholder='Enter path from above',
                description='Path:'
            )
            load_btn = widgets.Button(description='Load')

            def on_load(_b):
                # Broad catch on purpose: this is the UI boundary, and any
                # failure should be reported in the status line.
                try:
                    data = self.load_samples(selection.value)
                except Exception as e:
                    file_info.value = (
                        f"<b style='color:red'>Error:</b> {escape(str(e))}"
                    )
                    return

                # Escape everything that originates from user input or from
                # the loaded file before putting it into the HTML widget.
                lines = [
                    f"<b>Loaded:</b> {escape(selection.value)}",
                    f"<b>Type:</b> {type(data).__name__}",
                ]
                if isinstance(data, np.ndarray):
                    lines.append(f"<b>Shape:</b> {data.shape}")
                    lines.append(f"<b>Dtype:</b> {data.dtype}")
                elif isinstance(data, dict):
                    lines.append(f"<b>Keys:</b> {escape(str(list(data.keys())))}")
                file_info.value = "<br>".join(lines)

            load_btn.on_click(on_load)
            with load_area:
                # Replace any previous entry row instead of stacking them.
                load_area.clear_output()
                display(widgets.HBox([selection, load_btn]))

        # Set up event handlers. Button callbacks have to be registered with
        # on_click(); Button has no ``on_click`` constructor argument.
        type_filter.observe(update_list, 'value')
        search_box.observe(update_list, 'value')
        load_selected_btn.on_click(load_selected)

        # Initial display
        update_list()

        # Layout
        display(widgets.VBox([
            widgets.HBox([type_filter, search_box]),
            output,
            load_selected_btn,
            load_area,
            file_info
        ]))

        return self

# Initialize data loader
# No argument is passed, so the loader uses its default base path '../results'.
loader = DataLoader()
# Renders the file-browser widget; the method returns the loader itself, which
# the notebook echoes because this is the last expression of the cell.
loader.create_interactive_loader()

## 2. Visual Inspection Tools

The `InteractiveVisualizer` provides dynamic exploration of MCMC chains, distributions, and convergence behavior.

In [ ]:
class InteractiveVisualizer:
    """Interactive visualization tools for MCMC analysis.

    ``samples`` may be an array-like or a dict holding the draws under the
    key ``'samples'`` (preferred) or ``'chain'``. The plotting methods accept
    1-D ``(n_samples,)``, 2-D ``(n_samples, n_dims)`` and 3-D
    ``(n_chains, n_samples, n_dims)`` data.
    """

    def __init__(self, samples=None):
        """Store the optional initial samples."""
        self.samples = samples
        # Most recently rendered Plotly figure (None until a plot is drawn).
        self.current_plot = None

    def set_samples(self, samples):
        """Update the samples for visualization."""
        self.samples = samples
        if isinstance(samples, dict):
            print(f"Loaded samples with keys: {list(samples.keys())}")
        elif isinstance(samples, np.ndarray):
            print(f"Loaded array with shape: {samples.shape}")

    def _get_chains(self):
        """Return the stored draws as an ndarray, or None if there are none.

        For dict input the ``'samples'`` entry is used, falling back to
        ``'chain'``. Lists (e.g. data loaded from JSON) are converted so the
        callers can rely on ``ndim`` and ``shape``.
        """
        raw = self.samples
        if isinstance(raw, dict):
            raw = raw.get('samples', raw.get('chain', None))
        if raw is None:
            return None
        return np.asarray(raw)

    def trace_plot_interactive(self):
        """Interactive trace plot with dimension selection.

        Shows a 2x2 dashboard (trace, running mean, autocorrelation and
        density) for the selected dimension, chains and burn-in. Prints a
        message and returns None when no usable samples are available.
        """
        if self.samples is None:
            print("No samples loaded. Use set_samples() first.")
            return

        chains = self._get_chains()
        if chains is None:
            print("Could not find sample chains in the data.")
            return

        # Ensure 3D shape: (n_chains, n_samples, n_dims)
        if chains.ndim == 1:
            chains = chains[np.newaxis, :, np.newaxis]
        elif chains.ndim == 2:
            chains = chains[np.newaxis, :]
        elif chains.ndim != 3:
            print(f"Expected 1-D, 2-D or 3-D samples, got {chains.ndim}-D.")
            return

        if chains.size == 0:
            # Sliders cannot be built with max < min.
            print("Sample array is empty.")
            return

        n_chains, n_samples, n_dims = chains.shape

        # Interactive controls
        dim_slider = widgets.IntSlider(
            value=0, min=0, max=n_dims-1,
            description='Dimension:', continuous_update=False
        )

        chain_select = widgets.SelectMultiple(
            options=[(f'Chain {i}', i) for i in range(n_chains)],
            value=tuple(range(min(4, n_chains))),
            description='Chains:'
        )

        burnin_slider = widgets.IntSlider(
            value=0, min=0, max=n_samples//2,
            description='Burn-in:', continuous_update=False
        )

        plot_output = widgets.Output()

        def update_plot(*args):
            with plot_output:
                plot_output.clear_output(wait=True)

                fig = make_subplots(
                    rows=2, cols=2,
                    subplot_titles=('Trace Plot', 'Running Mean',
                                    'Autocorrelation', 'Density'),
                    specs=[[{'secondary_y': False}, {'secondary_y': False}],
                           [{'secondary_y': False}, {'secondary_y': False}]]
                )

                dim = dim_slider.value
                burnin = burnin_slider.value

                colors = px.colors.qualitative.Plotly

                for i, chain_idx in enumerate(chain_select.value):
                    color = colors[i % len(colors)]
                    chain_data = chains[chain_idx, burnin:, dim]
                    # Built once and reused for the running mean and the ACF.
                    series = pd.Series(chain_data)

                    # Trace plot
                    fig.add_trace(
                        go.Scatter(y=chain_data, mode='lines',
                                   name=f'Chain {chain_idx}',
                                   line=dict(color=color, width=1),
                                   opacity=0.7),
                        row=1, col=1
                    )

                    # Running mean
                    fig.add_trace(
                        go.Scatter(y=series.expanding().mean(), mode='lines',
                                   name=f'Chain {chain_idx}',
                                   line=dict(color=color, width=2),
                                   showlegend=False),
                        row=1, col=2
                    )

                    # Autocorrelation for lags 0 .. max_lag-1
                    max_lag = min(100, len(chain_data) // 4)
                    acf = [series.autocorr(lag=k) for k in range(max_lag)]
                    fig.add_trace(
                        go.Scatter(y=acf, mode='lines',
                                   name=f'Chain {chain_idx}',
                                   line=dict(color=color, width=2),
                                   showlegend=False),
                        row=2, col=1
                    )

                    # Density
                    fig.add_trace(
                        go.Histogram(x=chain_data, name=f'Chain {chain_idx}',
                                     opacity=0.6, histnorm='probability density',
                                     marker_color=color, showlegend=False),
                        row=2, col=2
                    )

                # Update layout
                fig.update_xaxes(title_text="Iteration", row=1, col=1)
                fig.update_xaxes(title_text="Iteration", row=1, col=2)
                fig.update_xaxes(title_text="Lag", row=2, col=1)
                fig.update_xaxes(title_text="Value", row=2, col=2)

                fig.update_yaxes(title_text="Value", row=1, col=1)
                fig.update_yaxes(title_text="Running Mean", row=1, col=2)
                fig.update_yaxes(title_text="ACF", row=2, col=1)
                fig.update_yaxes(title_text="Density", row=2, col=2)

                fig.update_layout(
                    height=800,
                    title_text=f"MCMC Diagnostics - Dimension {dim}",
                    showlegend=True
                )

                self.current_plot = fig
                fig.show()

        # Connect controls
        dim_slider.observe(update_plot, 'value')
        chain_select.observe(update_plot, 'value')
        burnin_slider.observe(update_plot, 'value')

        # Initial plot
        update_plot()

        # Display
        display(widgets.VBox([
            widgets.HBox([dim_slider, burnin_slider]),
            chain_select,
            plot_output
        ]))

    def pairwise_scatter(self):
        """Interactive pairwise scatter plots.

        Chains are pooled, a random subsample is drawn and a scatter matrix
        of the selected dimensions is shown (histograms on the diagonal).
        Prints a message and returns None when no usable samples are
        available or the data has fewer than two dimensions.
        """
        if self.samples is None:
            print("No samples loaded.")
            return

        data = self._get_chains()
        if data is None:
            print("Could not find sample chains in the data.")
            return

        if data.ndim == 3:
            # Flatten chains
            data = data.reshape(-1, data.shape[-1])
        elif data.ndim == 1:
            data = data[:, np.newaxis]
        elif data.ndim != 2:
            print(f"Expected 1-D, 2-D or 3-D samples, got {data.ndim}-D.")
            return

        n_samples, n_dims = data.shape
        if n_samples == 0 or n_dims < 2:
            print("Pairwise scatter needs at least 2 dimensions and 1 sample.")
            return

        # Dimension selection
        dim_select = widgets.SelectMultiple(
            options=[(f'Dim {i}', i) for i in range(n_dims)],
            value=tuple(range(min(5, n_dims))),
            description='Dimensions:',
            rows=min(10, n_dims)
        )

        # Keep min <= value <= max even when fewer than 100 draws exist;
        # a slider with min > max cannot be constructed.
        slider_max = min(5000, n_samples)
        slider_min = min(100, slider_max)
        subsample_slider = widgets.IntSlider(
            value=min(1000, slider_max),
            min=slider_min, max=slider_max,
            description='Subsample:'
        )

        plot_output = widgets.Output()

        def update_plot(*args):
            with plot_output:
                plot_output.clear_output(wait=True)

                selected_dims = list(dim_select.value)
                n_selected = len(selected_dims)

                if n_selected < 2:
                    print("Select at least 2 dimensions.")
                    return

                # Subsample data (without replacement)
                idx = np.random.choice(n_samples, subsample_slider.value, replace=False)
                plot_data = data[idx][:, selected_dims]

                # Create pairwise scatter matrix
                fig = make_subplots(
                    rows=n_selected, cols=n_selected,
                    shared_xaxes=True, shared_yaxes=True,
                    horizontal_spacing=0.02, vertical_spacing=0.02
                )

                for i in range(n_selected):
                    for j in range(n_selected):
                        if i == j:
                            # Diagonal: histogram
                            fig.add_trace(
                                go.Histogram(x=plot_data[:, i],
                                             histnorm='probability density',
                                             showlegend=False),
                                row=i+1, col=j+1
                            )
                        else:
                            # Off-diagonal: scatter
                            fig.add_trace(
                                go.Scatter(x=plot_data[:, j], y=plot_data[:, i],
                                           mode='markers', marker=dict(size=3, opacity=0.5),
                                           showlegend=False),
                                row=i+1, col=j+1
                            )

                # Labels on the bottom row and the left column only
                for i in range(n_selected):
                    fig.update_xaxes(title_text=f"Dim {selected_dims[i]}",
                                     row=n_selected, col=i+1)
                    fig.update_yaxes(title_text=f"Dim {selected_dims[i]}",
                                     row=i+1, col=1)

                fig.update_layout(
                    height=150*n_selected,
                    width=150*n_selected,
                    title_text="Pairwise Scatter Matrix"
                )

                self.current_plot = fig
                fig.show()

        # Connect and display
        dim_select.observe(update_plot, 'value')
        subsample_slider.observe(update_plot, 'value')

        update_plot()

        display(widgets.VBox([
            widgets.HBox([dim_select, subsample_slider]),
            plot_output
        ]))

    def spectral_density_plot(self):
        """Interactive spectral density visualization.

        Plots the periodogram of one chain and dimension on log-log axes
        and prints a rough effective-sample-size estimate. Prints a message
        and returns None when no usable samples are available.
        """
        if self.samples is None:
            print("No samples loaded.")
            return

        data = self._get_chains()
        if data is None:
            print("Could not find sample chains in the data.")
            return

        # Ensure 3D shape: (n_chains, n_samples, n_dims)
        if data.ndim == 1:
            data = data[np.newaxis, :, np.newaxis]
        elif data.ndim == 2:
            data = data[np.newaxis, :]
        elif data.ndim != 3:
            print(f"Expected 1-D, 2-D or 3-D samples, got {data.ndim}-D.")
            return

        if data.size == 0:
            print("Sample array is empty.")
            return

        from scipy import signal

        n_chains, n_samples, n_dims = data.shape

        # Controls
        dim_slider = widgets.IntSlider(
            value=0, min=0, max=n_dims-1,
            description='Dimension:'
        )

        chain_slider = widgets.IntSlider(
            value=0, min=0, max=n_chains-1,
            description='Chain:'
        )

        window_dropdown = widgets.Dropdown(
            options=['hann', 'hamming', 'blackman', 'bartlett'],
            value='hann',
            description='Window:'
        )

        plot_output = widgets.Output()

        def update_plot(*args):
            with plot_output:
                plot_output.clear_output(wait=True)

                chain_data = data[chain_slider.value, :, dim_slider.value]

                # Compute spectral density
                freqs, psd = signal.periodogram(chain_data,
                                                window=window_dropdown.value)

                # Create plot
                fig = go.Figure()

                fig.add_trace(
                    go.Scatter(x=freqs, y=psd, mode='lines',
                               name='Spectral Density')
                )

                fig.update_xaxes(title_text="Frequency", type="log")
                fig.update_yaxes(title_text="Power Spectral Density", type="log")

                fig.update_layout(
                    title_text=f"Spectral Density - Chain {chain_slider.value}, Dim {dim_slider.value}",
                    height=500
                )

                self.current_plot = fig
                fig.show()

                # Additional info: ESS = n / (1 + 2 * sum of autocorrelations).
                # The sum stops at the first lag whose autocorrelation is no
                # longer positive (or is NaN), so the denominator stays >= 1
                # and the estimate can never be negative or infinite.
                series = pd.Series(chain_data)
                acf_sum = 0.0
                for lag in range(1, min(50, len(chain_data))):
                    rho = series.autocorr(lag=lag)
                    if not rho > 0:
                        break
                    acf_sum += rho
                ess = len(chain_data) / (1 + 2*acf_sum)
                print(f"Effective sample size estimate: {ess:.1f}")

        # Connect and display
        dim_slider.observe(update_plot, 'value')
        chain_slider.observe(update_plot, 'value')
        window_dropdown.observe(update_plot, 'value')

        update_plot()

        display(widgets.VBox([
            widgets.HBox([dim_slider, chain_slider, window_dropdown]),
            plot_output
        ]))

# Create visualizer instance
# Starts without samples; they are attached below via set_samples().
viz = InteractiveVisualizer()

# Example: Load and visualize synthetic data
# Independent standard-normal draws from NumPy's global RNG (unseeded here, so
# the values differ between runs).
synthetic_data = np.random.randn(2, 1000, 5)  # 2 chains, 1000 samples, 5 dimensions
viz.set_samples(synthetic_data)
viz.trace_plot_interactive()

## 3. Diagnostic Analysis

The `DiagnosticAnalyzer` provides comprehensive MCMC diagnostics with interactive controls.

In [ ]:
class DiagnosticAnalyzer:
    """Comprehensive diagnostic analysis for MCMC chains.

    ``results`` holds the numbers computed by the most recent successful
    :meth:`analyze_convergence` call (empty until then).
    """

    def __init__(self):
        """Start with an empty result store."""
        self.results = {}

    @staticmethod
    def _gelman_rubin(chain_data):
        """Return the Gelman-Rubin R-hat along the draw axis.

        Args:
            chain_data: Array shaped ``(n_chains, n_draws)`` or
                ``(n_chains, n_draws, n_dims)``.

        Returns:
            A scalar for 2-D input, or one value per dimension for 3-D
            input. NaN is returned when there are fewer than two chains or
            fewer than two draws.
        """
        chain_data = np.asarray(chain_data, dtype=float)
        n = chain_data.shape[1]
        if chain_data.shape[0] < 2 or n < 2:
            return np.full(chain_data.shape[2:], np.nan)

        # W: mean within-chain variance; B: between-chain variance.
        W = np.mean(np.var(chain_data, axis=1, ddof=1), axis=0)
        B = n * np.var(np.mean(chain_data, axis=1), axis=0, ddof=1)
        var_plus = ((n-1)*W + B) / n
        # Chains with zero within-chain variance give inf/NaN; silence the
        # floating-point warning and let the value speak for itself.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.sqrt(var_plus / W)

    def analyze_convergence(self, samples, burn_in=0.1):
        """Run comprehensive convergence diagnostics.

        Displays a tabbed dashboard with a summary, the Gelman-Rubin
        diagnostic, effective sample sizes and Geweke z-scores.

        Args:
            samples: Array-like of shape ``(n_samples,)``,
                ``(n_samples, n_dims)`` or ``(n_chains, n_samples, n_dims)``,
                or a dict holding such data under ``'samples'`` or
                ``'chain'``.
            burn_in: Fraction of each chain discarded from the start.

        Returns:
            ``self.results``. After a successful run it contains
            ``n_chains``, ``n_samples``, ``n_dims``, ``burn_in_samples``,
            ``ess`` and ``mean_ess`` (chain 0, first 10 dimensions), and with
            two or more chains ``r_hat`` (one value per dimension) plus
            ``min_r_hat`` / ``max_r_hat`` when any value is finite. When no
            usable chains are found a message is printed and the previous
            results are returned unchanged.
        """
        if isinstance(samples, dict):
            chains = samples.get('samples', samples.get('chain', None))
        else:
            chains = samples

        if chains is None:
            print("Could not find sample chains in the data.")
            return self.results

        # Lists (e.g. JSON-loaded data) have no ndim/shape.
        chains = np.asarray(chains)

        # Ensure 3D shape: (n_chains, n_samples, n_dims)
        if chains.ndim == 1:
            chains = chains[np.newaxis, :, np.newaxis]
        elif chains.ndim == 2:
            chains = chains[np.newaxis, :]
        elif chains.ndim != 3:
            print(f"Expected 1-D, 2-D or 3-D samples, got {chains.ndim}-D.")
            return self.results

        if chains.size == 0:
            print("Sample array is empty.")
            return self.results

        n_chains, n_samples, n_dims = chains.shape
        burn_in_idx = int(n_samples * burn_in)
        post_burnin = chains[:, burn_in_idx:, :]

        results = {
            'n_chains': n_chains,
            'n_samples': n_samples,
            'n_dims': n_dims,
            'burn_in_samples': burn_in_idx,
        }
        if n_chains >= 2:
            r_hat_all = np.atleast_1d(self._gelman_rubin(post_burnin))
            results['r_hat'] = r_hat_all.tolist()
            finite = r_hat_all[np.isfinite(r_hat_all)]
            if finite.size:
                results['min_r_hat'] = float(finite.min())
                results['max_r_hat'] = float(finite.max())
        self.results = results

        # Create diagnostic dashboard
        tab_contents = []
        tab_titles = []

        # 1. Summary statistics
        summary_output = widgets.Output()
        with summary_output:
            print("=== MCMC Chain Summary ===")
            print(f"Number of chains: {n_chains}")
            print(f"Number of samples per chain: {n_samples}")
            print(f"Number of dimensions: {n_dims}")
            print(f"Burn-in samples: {burn_in_idx}")
            print()

            # Basic statistics, pooled over chains (first 5 dimensions shown)
            print("Post burn-in statistics:")
            print(f"Mean: {np.mean(post_burnin, axis=(0,1))[:5]}...")
            print(f"Std: {np.std(post_burnin, axis=(0,1))[:5]}...")

        tab_contents.append(summary_output)
        tab_titles.append("Summary")

        # 2. Gelman-Rubin diagnostic
        gr_output = widgets.Output()
        dim_select_gr = widgets.IntSlider(
            value=0, min=0, max=n_dims-1,
            description='Dimension:'
        )

        def update_gr(*args):
            with gr_output:
                gr_output.clear_output(wait=True)

                if n_chains < 2:
                    print("Need at least 2 chains for Gelman-Rubin diagnostic.")
                    return

                dim = dim_select_gr.value
                R_hat = float(self._gelman_rubin(chains[:, burn_in_idx:, dim]))

                print(f"Gelman-Rubin R-hat for dimension {dim}: {R_hat:.4f}")
                print("(Values close to 1.0 indicate convergence)")

                # R-hat recomputed on growing prefixes of the post-burn-in
                # draws. Prefixes of 10 draws or fewer are left as gaps.
                split_points = np.linspace(burn_in_idx, n_samples-1, 50, dtype=int)
                r_hats = [
                    float(self._gelman_rubin(chains[:, burn_in_idx:sp, dim]))
                    if sp - burn_in_idx > 10 else np.nan
                    for sp in split_points
                ]

                fig = go.Figure()
                fig.add_trace(go.Scatter(x=split_points, y=r_hats, mode='lines'))
                fig.add_hline(y=1.1, line_dash="dash", line_color="red",
                              annotation_text="R-hat = 1.1")
                fig.update_layout(
                    title="R-hat vs Chain Length",
                    xaxis_title="Iteration",
                    yaxis_title="R-hat",
                    height=400
                )
                fig.show()

        dim_select_gr.observe(update_gr, 'value')
        update_gr()

        gr_tab = widgets.VBox([dim_select_gr, gr_output])
        tab_contents.append(gr_tab)
        tab_titles.append("Gelman-Rubin")

        # 3. Effective Sample Size (chain 0 only, first 10 dimensions)
        ess_output = widgets.Output()
        ess_values = []

        with ess_output:
            print("=== Effective Sample Size ===")

            for dim in range(min(10, n_dims)):
                # Simple ESS estimate: sum autocorrelations until the first
                # lag that drops below 0.05.
                chain_data = chains[0, burn_in_idx:, dim]
                series = pd.Series(chain_data)
                acf_sum = 0
                for lag in range(1, min(100, len(chain_data)//4)):
                    acf = series.autocorr(lag=lag)
                    if acf < 0.05:
                        break
                    acf_sum += acf

                ess = len(chain_data) / (1 + 2*acf_sum)
                ess_values.append(float(ess))
                print(f"Dimension {dim}: ESS = {ess:.1f} ({ess/len(chain_data)*100:.1f}%)")

            # Plot ESS
            fig = go.Figure()
            fig.add_trace(go.Bar(x=list(range(len(ess_values))), y=ess_values))
            fig.update_layout(
                title="Effective Sample Size by Dimension",
                xaxis_title="Dimension",
                yaxis_title="ESS",
                height=400
            )
            fig.show()

        results['ess'] = ess_values
        if ess_values:
            results['mean_ess'] = float(np.mean(ess_values))

        tab_contents.append(ess_output)
        tab_titles.append("ESS")

        # 4. Geweke diagnostic
        geweke_output = widgets.Output()
        chain_select = widgets.IntSlider(
            value=0, min=0, max=n_chains-1,
            description='Chain:'
        )

        def update_geweke(*args):
            with geweke_output:
                geweke_output.clear_output(wait=True)

                chain_idx = chain_select.value
                chain_data = chains[chain_idx, burn_in_idx:, :]

                # Geweke z-scores: compare the mean of the first 10% of the
                # draws with the mean of the last 50%.
                n = chain_data.shape[0]
                first_prop = 0.1
                last_prop = 0.5

                first_n = int(n * first_prop)
                last_n = int(n * last_prop)

                z_scores = []
                for dim in range(min(20, n_dims)):
                    first_seg = chain_data[:first_n, dim]
                    # n - last_n rather than -last_n: a -0 start would select
                    # the whole chain when last_n is 0.
                    last_seg = chain_data[n - last_n:, dim]

                    # Plain sample variances stand in for the spectral
                    # density at zero, so autocorrelation is ignored and |z|
                    # is overstated for positively autocorrelated chains.
                    se = np.sqrt(np.var(first_seg)/first_n + np.var(last_seg)/last_n) \
                        if first_n and last_n else np.nan
                    if se > 0:
                        z = (np.mean(first_seg) - np.mean(last_seg)) / se
                    else:
                        # Zero or undefined standard error: no usable score.
                        z = np.nan
                    z_scores.append(z)

                # Plot
                fig = go.Figure()
                fig.add_trace(go.Scatter(
                    x=list(range(len(z_scores))),
                    y=z_scores,
                    mode='markers+lines'
                ))
                fig.add_hline(y=1.96, line_dash="dash", line_color="red")
                fig.add_hline(y=-1.96, line_dash="dash", line_color="red")
                fig.update_layout(
                    title=f"Geweke Z-scores - Chain {chain_idx}",
                    xaxis_title="Dimension",
                    yaxis_title="Z-score",
                    height=400
                )
                fig.show()

                print(f"Dimensions outside 95% CI: {sum(abs(z) > 1.96 for z in z_scores)}/{len(z_scores)}")

        chain_select.observe(update_geweke, 'value')
        update_geweke()

        geweke_tab = widgets.VBox([chain_select, geweke_output])
        tab_contents.append(geweke_tab)
        tab_titles.append("Geweke")

        # Create tabs
        tabs = widgets.Tab(children=tab_contents)
        for i, title in enumerate(tab_titles):
            tabs.set_title(i, title)

        display(tabs)

        return self.results

    def compare_samplers(self, results_dict):
        """Compare diagnostics across different samplers.

        Args:
            results_dict: Mapping of sampler name to a result dict. Entries
                are only plotted and tabulated when they contain a
                ``'diagnostics'`` dict; the keys read from it are
                ``acceptance_rate``, ``ess_per_second``, ``r_hat``,
                ``mean_ess``, ``min_r_hat``, ``max_r_hat`` and ``runtime``.
                An optional ``'samples'`` array feeds the autocorrelation
                panel.
        """
        output = widgets.Output()

        with output:
            # Create comparison plots
            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Acceptance Rate', 'ESS per Second',
                                'R-hat Distribution', 'Autocorrelation')
            )

            colors = px.colors.qualitative.Plotly
            summary_data = []

            for i, (name, result) in enumerate(results_dict.items()):
                if 'diagnostics' not in result:
                    continue

                color = colors[i % len(colors)]
                sampler_diag = result['diagnostics']

                # Acceptance rate
                if 'acceptance_rate' in sampler_diag:
                    fig.add_trace(
                        go.Bar(x=[name], y=[sampler_diag['acceptance_rate']],
                               name=name, marker_color=color),
                        row=1, col=1
                    )

                # ESS per second
                if 'ess_per_second' in sampler_diag:
                    fig.add_trace(
                        go.Bar(x=[name], y=[sampler_diag['ess_per_second']],
                               name=name, marker_color=color,
                               showlegend=False),
                        row=1, col=2
                    )

                # R-hat
                if 'r_hat' in sampler_diag:
                    fig.add_trace(
                        go.Box(y=sampler_diag['r_hat'], name=name,
                               marker_color=color, showlegend=False),
                        row=2, col=1
                    )

                # Autocorrelation of the first chain / first dimension.
                # NOTE(review): this panel is skipped for entries without
                # 'diagnostics' even though it only needs 'samples' —
                # confirm that this gating is intended.
                if 'samples' in result:
                    samples = result['samples']
                    if samples.ndim == 3:
                        chain_data = samples[0, :, 0]
                    else:
                        chain_data = samples[:, 0]

                    series = pd.Series(chain_data)
                    acf = [series.autocorr(lag=k)
                           for k in range(min(50, len(chain_data)//4))]
                    fig.add_trace(
                        go.Scatter(y=acf, mode='lines', name=name,
                                   line=dict(color=color), showlegend=False),
                        row=2, col=2
                    )

                # Summary table row
                summary_data.append({
                    'Sampler': name,
                    'Acceptance Rate': sampler_diag.get('acceptance_rate', 'N/A'),
                    'Mean ESS': sampler_diag.get('mean_ess', 'N/A'),
                    'Min R-hat': sampler_diag.get('min_r_hat', 'N/A'),
                    'Max R-hat': sampler_diag.get('max_r_hat', 'N/A'),
                    'Runtime (s)': sampler_diag.get('runtime', 'N/A')
                })

            fig.update_layout(height=800, showlegend=True)
            fig.show()

            display(pd.DataFrame(summary_data))

        display(output)

# Create diagnostic analyzer
diag = DiagnosticAnalyzer()

# Example analysis
# Reuses synthetic_data from the visualization cell above, so that cell must
# have been run first. burn_in is left at its default of 0.1.
print("Running diagnostic analysis on synthetic data...")
diag.analyze_convergence(synthetic_data)

## 4. Rapid Prototyping

The `AnalysisSandbox` provides a flexible environment for quick experiments and custom analyses.

In [ ]:
class AnalysisSandbox:
    """Sandbox environment for rapid prototyping and experimentation.

    Code typed into the widget is run with ``exec`` inside the notebook
    kernel, so it has the same privileges as any other cell: the "sandbox"
    is a scratchpad, not an isolation boundary.

    Attributes:
        data_store: Dict shared with executed code under the name
            ``data_store``; values placed there persist between runs.
        plot_store: List initialised empty (not written to by this class).
        code_history: One dict per run with ``timestamp``, ``code`` and
            ``success``, plus ``error`` when the run failed.
    """

    # Ready-made code offered through the "Snippets" dropdown, in display
    # order. Each snippet only assumes the names that run_code() injects.
    SNIPPETS = {
        'Basic Statistics': '''# Calculate basic statistics
if samples is not None:
    print("Mean:", np.mean(samples))
    print("Std:", np.std(samples))
    print("Shape:", samples.shape)
''',
        'Quick Histogram': '''# Create histogram
if samples is not None:
    data = samples.flatten() if samples.ndim > 1 else samples
    plt.figure(figsize=(8, 6))
    plt.hist(data[:1000], bins=50, alpha=0.7)
    plt.title('Sample Distribution')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.show()
''',
        'Correlation Matrix': '''# Compute and plot correlation matrix
if samples is not None:
    if samples.ndim == 3:
        data = samples[0]  # First chain
    else:
        data = samples
    
    if data.ndim == 2 and data.shape[1] > 1:
        corr = np.corrcoef(data.T)
        plt.figure(figsize=(8, 6))
        sns.heatmap(corr, cmap='coolwarm', center=0)
        plt.title('Correlation Matrix')
        plt.show()
''',
        'Custom Analysis': '''# Custom analysis template
# Store results in data_store for later use

result = {}

# Your analysis here
if samples is not None:
    result['mean'] = np.mean(samples, axis=0)
    result['std'] = np.std(samples, axis=0)
    
    # Store in data_store
    data_store['my_analysis'] = result
    
    print("Analysis complete!")
    print("Results stored in data_store['my_analysis']")
''',
        'Interactive Plot': '''# Create interactive plot with Plotly
if samples is not None:
    if samples.ndim >= 2:
        data = samples[:1000, 0] if samples.ndim == 2 else samples[0, :1000, 0]
    else:
        data = samples[:1000]
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=data, mode='lines', name='Samples'))
    fig.update_layout(
        title='Sample Trace',
        xaxis_title='Iteration',
        yaxis_title='Value',
        height=400
    )
    fig.show()
'''
    }

    def __init__(self):
        self.data_store = {}
        self.plot_store = []
        self.code_history = []
        # Set by create_sandbox() so that methods outside the widget closure
        # can refresh the variable inspector; None until the UI exists.
        self._refresh_variables = None

    @staticmethod
    def _describe_value(value):
        """Return the one-line summary shown in the variable inspector.

        Arrays and DataFrames are summarised with their shape; every other
        value is summarised by its type name only.
        """
        if isinstance(value, np.ndarray):
            return f"ndarray {value.shape}"
        if isinstance(value, pd.DataFrame):
            return f"DataFrame {value.shape}"
        return type(value).__name__

    def create_sandbox(self):
        """Create and display the interactive sandbox interface.

        The "Run Code" button executes the text area's content with ``exec``.
        The executed code sees ``np``, ``pd``, ``plt``, ``sns``, ``go``,
        ``px``, ``data_store``, ``samples`` (``viz.samples``) and the
        notebook-level globals ``loader``, ``viz`` and ``diag``, which must
        therefore exist when the button is clicked.
        """
        # Code editor
        code_area = widgets.Textarea(
            value='# Enter your analysis code here\n# Available variables: data_store, samples\n\n',
            layout=widgets.Layout(width='100%', height='200px'),
            description='Code:'
        )

        # Separate output areas: variable inspector, plots, console text
        var_output = widgets.Output()
        plot_output = widgets.Output()
        console_output = widgets.Output()

        # Control buttons
        run_btn = widgets.Button(description='Run Code', button_style='primary')
        clear_btn = widgets.Button(description='Clear Output', button_style='warning')
        save_btn = widgets.Button(description='Save Code', button_style='success')

        def update_variables():
            """Redraw the variable inspector from the current data store."""
            with var_output:
                var_output.clear_output()
                print("=== Data Store ===")
                for key, value in self.data_store.items():
                    print(f"{key}: {self._describe_value(value)}")

        def record_run(success, error=None):
            """Append one entry describing the run to the code history."""
            entry = {
                'timestamp': datetime.now(),
                'code': code_area.value,
                'success': success
            }
            if error is not None:
                entry['error'] = error
            self.code_history.append(entry)

        def run_code(b):
            """Execute the text area's code and report the outcome."""
            with console_output:
                console_output.clear_output()
                plot_output.clear_output()

                # Names made available to the executed code
                namespace = {
                    'np': np,
                    'pd': pd,
                    'plt': plt,
                    'sns': sns,
                    'go': go,
                    'px': px,
                    'data_store': self.data_store,
                    'samples': viz.samples,
                    'loader': loader,
                    'viz': viz,
                    'diag': diag
                }

                try:
                    # NOTE: exec runs arbitrary code with full kernel
                    # privileges; this is intended for the notebook's own
                    # user only. Output produced by the code lands in the
                    # plot area.
                    with plot_output:
                        exec(code_area.value, namespace)

                    # Pick up a rebound data_store, but refuse anything the
                    # inspector could not iterate with .items().
                    new_store = namespace['data_store']
                    if not isinstance(new_store, dict):
                        raise TypeError(
                            "data_store must remain a dict, got "
                            f"{type(new_store).__name__}"
                        )
                    self.data_store = new_store

                except Exception as e:
                    print(f"❌ Error: {str(e)}")
                    import traceback
                    traceback.print_exc()
                    record_run(False, error=str(e))

                else:
                    record_run(True)
                    print("✅ Code executed successfully")

                finally:
                    # In-place changes to data_store survive a failed run,
                    # so the inspector is refreshed on both paths.
                    update_variables()

        def clear_outputs(b):
            """Clear the console and plot areas."""
            console_output.clear_output()
            plot_output.clear_output()

        def save_code(b):
            """Announce the file name a saved snippet would get (demo only)."""
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f'sandbox_code_{timestamp}.py'

            with console_output:
                print(f"Code saved to: {filename}")
                # In practice, save to file
                print("(File saving not implemented in demo)")

        # Connect buttons
        run_btn.on_click(run_code)
        clear_btn.on_click(clear_outputs)
        save_btn.on_click(save_code)

        # Code snippets dropdown
        placeholder = 'Select snippet...'
        snippet_dropdown = widgets.Dropdown(
            options=[placeholder] + list(self.SNIPPETS),
            description='Snippets:'
        )

        def load_snippet(change):
            """Copy the chosen snippet into the code editor."""
            choice = change['new']
            if choice != placeholder and choice in self.SNIPPETS:
                code_area.value = self.SNIPPETS[choice]

        snippet_dropdown.observe(load_snippet, 'value')

        # Let load_experiment_data() refresh the inspector, then draw it once
        self._refresh_variables = update_variables
        update_variables()

        # Layout
        display(widgets.VBox([
            widgets.HBox([snippet_dropdown]),
            code_area,
            widgets.HBox([run_btn, clear_btn, save_btn]),
            widgets.HBox([
                widgets.VBox([widgets.HTML('<b>Variables</b>'), var_output],
                             layout=widgets.Layout(width='30%')),
                widgets.VBox([widgets.HTML('<b>Console</b>'), console_output],
                             layout=widgets.Layout(width='70%'))
            ]),
            widgets.VBox([widgets.HTML('<b>Plots</b>'), plot_output])
        ]))

    def load_experiment_data(self, experiment_path):
        """Load samples through the global ``loader`` into the data store.

        Args:
            experiment_path: Path handed unchanged to ``loader.load_samples``.

        Returns:
            The loaded data, also stored as ``data_store['loaded_data']``,
            or ``None`` when loading raised (the error is printed instead of
            propagated so an interactive session keeps going).
        """
        try:
            data = loader.load_samples(experiment_path)
            self.data_store['loaded_data'] = data
            print(f"✅ Loaded data from {experiment_path}")
        except Exception as e:
            print(f"❌ Error loading data: {str(e)}")
            return None

        # Show the new entry in the inspector if the UI has been created
        if self._refresh_variables is not None:
            self._refresh_variables()
        return data

# Create sandbox and render its widget UI in this cell's output.
# "Run Code" reads the notebook globals `viz`, `loader` and `diag`, so the
# cells that define them must have been run before the button is used.
sandbox = AnalysisSandbox()
sandbox.create_sandbox()

## 5. Documentation and Insights

The `InsightRecorder` helps track findings and generate documentation during analysis.

In [ ]:
class InsightRecorder:
    """Tool for recording insights and generating documentation.

    Attributes:
        insights: Every insight recorded through this object; unlike the
            current session it is not emptied by "Clear Current".
        figures: List initialised empty (not written to by this class).
        current_session: Dict with ``start_time``, ``insights`` and
            ``figures`` for the session currently being recorded.
    """

    def __init__(self):
        self.insights = []
        self.figures = []
        self.current_session = self._new_session()
        # Set by create_recorder() so programmatic notes can refresh the
        # widget list; None until the UI exists.
        self._refresh_display = None

    @staticmethod
    def _new_session():
        """Return an empty session record stamped with the current time."""
        return {
            'start_time': datetime.now(),
            'insights': [],
            'figures': []
        }

    def _record(self, text, category, priority):
        """Create an insight, add it to the session and the full log."""
        insight = {
            'timestamp': datetime.now(),
            'text': text,
            'category': category,
            'priority': priority
        }
        self.current_session['insights'].append(insight)
        self.insights.append(insight)
        return insight

    def _attach_figure(self, insight, name=None):
        """Register the current matplotlib figure and link it to *insight*.

        Returns the name under which the figure was registered.
        """
        figures = self.current_session['figures']
        if name is None:
            name = f"fig_{len(figures)}.png"
        figures.append({'name': name, 'figure': plt.gcf()})
        insight['figure'] = name
        return name

    def _group_by_category(self):
        """Group the session's insights by category.

        Categories appear in the order they were first used; within a
        category, insights are ordered by descending priority, with ties
        kept in recording order.
        """
        from collections import defaultdict

        grouped = defaultdict(list)
        for insight in self.current_session['insights']:
            grouped[insight['category']].append(insight)
        return {
            category: sorted(items, key=lambda it: it['priority'], reverse=True)
            for category, items in grouped.items()
        }

    def _build_report_lines(self):
        """Return the markdown report for the current session as lines."""
        now = datetime.now()
        session = self.current_session
        lines = [
            "# Analysis Report",
            f"Generated: {now.strftime('%Y-%m-%d %H:%M:%S')}",
            f"Session Duration: {now - session['start_time']}",
            "",
            "## Summary",
            f"- Total insights: {len(session['insights'])}",
            f"- Figures generated: {len(session['figures'])}",
            "",
        ]

        for category, items in self._group_by_category().items():
            lines.append(f"## {category}")
            for item in items:
                timestamp = item['timestamp'].strftime('%Y-%m-%d %H:%M:%S')
                priority = "Priority: " + "⭐" * item['priority']
                lines.append(f"### {timestamp} - {priority}")
                lines.append(item['text'])
                if 'figure' in item:
                    lines.append(f"\n![{item['figure']}](./{item['figure']})")
                lines.append("")

        return lines

    def create_recorder(self):
        """Create and display the interactive insight recording interface."""
        # Insight entry
        insight_text = widgets.Textarea(
            placeholder='Enter your insight or observation...',
            layout=widgets.Layout(width='100%', height='100px')
        )

        # Category selection
        category_dropdown = widgets.Dropdown(
            options=['General', 'Convergence', 'Performance', 'Anomaly', 'Hypothesis', 'TODO'],
            value='General',
            description='Category:'
        )

        # Priority
        priority_slider = widgets.IntSlider(
            value=3, min=1, max=5,
            description='Priority:'
        )

        # Insight list, and a dedicated area for the exported report so that
        # output produced inside a button callback has somewhere to appear.
        insights_output = widgets.Output()
        report_output = widgets.Output()

        # Buttons
        add_btn = widgets.Button(description='Add Insight', button_style='primary')
        export_btn = widgets.Button(description='Export Report', button_style='success')
        clear_btn = widgets.Button(description='Clear Current', button_style='warning')

        def update_insights_display():
            """Redraw the list of insights recorded in this session."""
            with insights_output:
                insights_output.clear_output()

                if not self.current_session['insights']:
                    print("No insights recorded yet.")
                    return

                for category, items in self._group_by_category().items():
                    print(f"\n### {category}")
                    print("-" * 40)

                    for item in items:
                        timestamp = item['timestamp'].strftime('%H:%M:%S')
                        priority_stars = "⭐" * item['priority']
                        print(f"[{timestamp}] {priority_stars}")
                        print(f"  {item['text']}")
                        if 'figure' in item:
                            print(f"  📊 Attached figure: {item['figure']}")
                        print()

        def add_insight(b):
            """Record the text area's content; blank input is ignored."""
            if not insight_text.value.strip():
                return

            insight = self._record(
                insight_text.value,
                category_dropdown.value,
                priority_slider.value
            )

            # Link the current figure when one is open
            if plt.get_fignums():
                self._attach_figure(insight)

            # Clear input
            insight_text.value = ''

            update_insights_display()

        def export_report(b):
            """Render the session's insights as a markdown report."""
            with report_output:
                report_output.clear_output()
                display(Markdown('\n'.join(self._build_report_lines())))

                # Writing the file is not implemented; only the name is shown
                filename = f"analysis_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
                print(f"\nReport would be saved to: {filename}")

        def clear_current(b):
            """Start a fresh session; the full insight log is kept."""
            self.current_session = self._new_session()
            update_insights_display()

        # Connect buttons
        add_btn.on_click(add_insight)
        export_btn.on_click(export_report)
        clear_btn.on_click(clear_current)

        # Let quick_note() refresh the list, then draw it once
        self._refresh_display = update_insights_display
        update_insights_display()

        # Layout
        display(widgets.VBox([
            widgets.HTML('<h3>Record Analysis Insights</h3>'),
            widgets.HBox([category_dropdown, priority_slider]),
            insight_text,
            widgets.HBox([add_btn, export_btn, clear_btn]),
            widgets.HTML('<h4>Current Session Insights</h4>'),
            insights_output,
            report_output
        ]))

    def quick_note(self, text, category='General', priority=3):
        """Add an insight programmatically and print a confirmation.

        Args:
            text: The insight itself.
            category: Heading the insight is grouped under.
            priority: Integer rank; higher values are listed first and are
                rendered as that many stars.
        """
        self._record(text, category, priority)

        # Only mark the preview as truncated when it actually is
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        print(f"✅ Insight recorded: {preview}")

        # Keep an already displayed widget list in sync
        refresh = getattr(self, '_refresh_display', None)
        if refresh is not None:
            refresh()

    def attach_current_figure(self, name=None):
        """Attach the current matplotlib figure to the last insight.

        Args:
            name: Name to register the figure under; defaults to
                ``fig_<n>.png`` where ``n`` is the number of figures already
                registered in the session.

        Prints a notice and changes nothing when no figure is open or no
        insight has been recorded in the current session.
        """
        session_insights = self.current_session['insights']
        if not plt.get_fignums() or not session_insights:
            print("❌ Nothing attached: need an open figure and at least one insight.")
            return

        name = self._attach_figure(session_insights[-1], name)
        print(f"📊 Figure attached: {name}")

# Create insight recorder and render its widget UI in this cell's output
recorder = InsightRecorder()
recorder.create_recorder()

# Example usage: record a note programmatically (text, category, priority)
recorder.quick_note("Synthetic data shows expected normal distribution", "General", 3)

## 6. Export and Reproducibility

The `ExportManager` handles exporting results in various formats for publications and reproducibility.

In [ ]:
class ExportManager:
    """Manage export of figures and results for publications.

    Attributes:
        export_dir: Directory that receives exported files; created
            (including parents) when the manager is constructed.
        exported_items: One dict per written file with ``name``,
            ``format``, ``dpi``, ``path`` and ``timestamp``.
    """

    # Preset label -> (formats, dpi), in dropdown order after 'Custom'
    _PRESETS = {
        'Journal (300dpi, PDF+PNG)': (['png', 'pdf'], 300),
        'Presentation (150dpi, PNG)': (['png'], 150),
        'Web (72dpi, PNG+SVG)': (['png', 'svg'], 72),
        'Print (600dpi, PDF+EPS)': (['pdf', 'eps'], 600),
    }

    # Characters that LaTeX treats specially in running text such as captions
    _LATEX_SPECIALS = {
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    }

    def __init__(self, export_dir='../results/figures'):
        self.export_dir = Path(export_dir)
        self.export_dir.mkdir(parents=True, exist_ok=True)
        self.exported_items = []

    @staticmethod
    def _latex_escape(text):
        """Return *text* with LaTeX special characters escaped.

        Works character by character so that the backslashes introduced by
        one replacement are never escaped again by another.
        """
        specials = ExportManager._LATEX_SPECIALS
        return ''.join(specials.get(ch, ch) for ch in text)

    def _export_figure(self, fig, name, formats, dpi):
        """Save *fig* once per format and log every written file.

        Args:
            fig: Object providing ``savefig`` (a matplotlib figure).
            name: File stem; the format is appended as the extension.
            formats: Iterable of format names such as ``'png'`` or ``'pdf'``.
            dpi: Resolution passed to ``savefig``.

        Returns:
            List of the paths written, in the order of *formats*.
        """
        exported_files = []

        for fmt in formats:
            filepath = self.export_dir / f"{name}.{fmt}"

            save_kwargs = {
                'format': fmt,
                'dpi': dpi,
                'bbox_inches': 'tight',
                'pad_inches': 0.1
            }
            # Everything except PDF/EPS gets an explicit white background
            if fmt not in ('pdf', 'eps'):
                save_kwargs.update(facecolor='white', edgecolor='none')
            fig.savefig(filepath, **save_kwargs)

            exported_files.append(filepath)
            self.exported_items.append({
                'name': name,
                'format': fmt,
                'dpi': dpi,
                'path': filepath,
                'timestamp': datetime.now()
            })

        return exported_files

    def _latex_lines(self):
        """Return LaTeX figure code for every figure exported as PDF.

        Only PDF exports are considered. Each distinct name yields one
        ``figure`` environment; when there are several names, a comparison
        figure with up to four subfigures (two per row) is appended. Names
        are escaped inside captions, where characters such as ``_`` are not
        valid as plain text, and left untouched in file names and labels.
        """
        lines = [
            "% LaTeX code for figures",
            "% Add to preamble: \\usepackage{graphicx}",
            "% Add to preamble: \\usepackage{subcaption} % for subfigures",
            "",
        ]

        # dict.fromkeys keeps first-export order and drops repeated names
        names = list(dict.fromkeys(
            item['name'] for item in self.exported_items
            if item['format'] == 'pdf'  # Prefer PDF for LaTeX
        ))

        if not names:
            lines.append("% No PDF exports recorded yet; export a figure as PDF first.")
            return lines

        for name in names:
            caption = self._latex_escape(name)
            lines += [
                f"% Figure: {name}",
                "\\begin{figure}[htbp]",
                "    \\centering",
                f"    \\includegraphics[width=0.8\\textwidth]{{{name}.pdf}}",
                f"    \\caption{{Caption for {caption}}}",
                f"    \\label{{fig:{name}}}",
                "\\end{figure}",
                "",
            ]

        # Also generate a comparison figure
        if len(names) > 1:
            lines += [
                "% Comparison figure with subfigures",
                "\\begin{figure}[htbp]",
                "    \\centering",
            ]

            shown = names[:4]  # Max 4 subfigures, two per row
            for i, name in enumerate(shown):
                starts_row = i % 2 == 0
                if starts_row and i > 0:
                    lines.append("    \\\\")
                lines += [
                    "    \\begin{subfigure}[b]{0.45\\textwidth}",
                    f"        \\includegraphics[width=\\textwidth]{{{name}.pdf}}",
                    f"        \\caption{{{self._latex_escape(name)}}}",
                    "    \\end{subfigure}",
                ]
                if starts_row and i < len(shown) - 1:
                    lines.append("    \\hfill")

            lines += [
                "    \\caption{Comparison of results}",
                "    \\label{fig:comparison}",
                "\\end{figure}",
            ]

        return lines

    def create_export_interface(self):
        """Create and display the interactive export interface."""
        # Figure selection
        fig_name = widgets.Text(
            placeholder='figure_name',
            description='Name:'
        )

        # Format selection
        format_select = widgets.SelectMultiple(
            options=['png', 'pdf', 'svg', 'eps'],
            value=['png', 'pdf'],
            description='Formats:',
            rows=4
        )

        # Resolution. step=1 keeps every preset DPI (72, 150, 300, 600) on
        # the slider grid; a step of 50 counted from min=72 skips 150, 300
        # and 600.
        dpi_slider = widgets.IntSlider(
            value=300, min=72, max=600, step=1,
            description='DPI:'
        )

        # Publication presets
        preset_dropdown = widgets.Dropdown(
            options=['Custom'] + list(self._PRESETS),
            value='Custom',
            description='Preset:'
        )

        # Export log
        export_log = widgets.Output()

        # Buttons
        export_current_btn = widgets.Button(
            description='Export Current Figure',
            button_style='primary'
        )

        export_all_btn = widgets.Button(
            description='Export All Figures',
            button_style='warning'
        )

        generate_latex_btn = widgets.Button(
            description='Generate LaTeX',
            button_style='success'
        )

        def apply_preset(change):
            """Copy a preset's formats and DPI into the controls.

            'Custom' has no entry in the table and leaves them unchanged.
            """
            preset = self._PRESETS.get(change['new'])
            if preset is not None:
                formats, dpi = preset
                format_select.value = formats
                dpi_slider.value = dpi

        preset_dropdown.observe(apply_preset, 'value')

        def export_current(b):
            """Export current matplotlib figure."""
            with export_log:
                if not plt.get_fignums():
                    print("❌ No active figure found.")
                    return

                if not fig_name.value:
                    print("❌ Please enter a figure name.")
                    return

                exported = self._export_figure(
                    plt.gcf(), fig_name.value,
                    list(format_select.value),
                    dpi_slider.value
                )

                print(f"✅ Exported {fig_name.value}:")
                for path in exported:
                    print(f"   - {path}")

        def export_all(b):
            """Export all open figures as ``<name>_1``, ``<name>_2``, ..."""
            with export_log:
                figs = [plt.figure(num) for num in plt.get_fignums()]

                if not figs:
                    print("❌ No figures to export.")
                    return

                print(f"Exporting {len(figs)} figures...")

                formats = list(format_select.value)
                total = 0
                for i, fig in enumerate(figs, start=1):
                    name = f"{fig_name.value or 'figure'}_{i}"
                    total += len(self._export_figure(
                        fig, name, formats, dpi_slider.value
                    ))
                    print(f"✅ Exported {name}")

                print(f"\nTotal files exported: {total}")

        def generate_latex(b):
            """Print LaTeX code for the figures exported so far."""
            with export_log:
                print("\n".join(self._latex_lines()))

        # Connect buttons
        export_current_btn.on_click(export_current)
        export_all_btn.on_click(export_all)
        generate_latex_btn.on_click(generate_latex)

        # Display interface
        display(widgets.VBox([
            widgets.HTML('<h3>Export Manager</h3>'),
            widgets.HBox([preset_dropdown]),
            widgets.HBox([fig_name, format_select]),
            widgets.HBox([dpi_slider]),
            widgets.HBox([export_current_btn, export_all_btn, generate_latex_btn]),
            widgets.HTML('<h4>Export Log</h4>'),
            export_log
        ]))

    def create_publication_figure(self, width=6, height=4, style='paper'):
        """Create a figure with publication-ready settings.

        Args:
            width: Figure width handed to ``figsize``.
            height: Figure height handed to ``figsize``.
            style: ``'paper'`` or ``'presentation'`` selects a set of font
                sizes; any other value leaves the font sizes untouched.

        Returns:
            ``(fig, ax)`` with the top and right spines hidden.

        Note:
            The font sizes are written to the global ``plt.rcParams`` and
            are not restored afterwards.
        """
        font_sizes = {
            'paper': {
                'font.size': 10,
                'axes.titlesize': 11,
                'axes.labelsize': 10,
                'xtick.labelsize': 9,
                'ytick.labelsize': 9,
                'legend.fontsize': 9,
                'figure.titlesize': 12
            },
            'presentation': {
                'font.size': 14,
                'axes.titlesize': 16,
                'axes.labelsize': 14,
                'xtick.labelsize': 12,
                'ytick.labelsize': 12,
                'legend.fontsize': 12,
                'figure.titlesize': 18
            }
        }
        if style in font_sizes:
            plt.rcParams.update(font_sizes[style])

        # Create figure
        fig, ax = plt.subplots(figsize=(width, height))

        # Set properties
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)

        return fig, ax

    def save_session_metadata(self):
        """Write a JSON summary of the export session and return it.

        The file ``export_metadata.json`` is written into ``export_dir``.
        Values JSON cannot represent natively (the logged paths and
        timestamps) are stored as their ``str()`` form.

        Returns:
            The metadata dict as built in memory (not stringified).
        """
        metadata = {
            'timestamp': datetime.now().isoformat(),
            'exported_items': self.exported_items,
            'export_dir': str(self.export_dir),
            'total_exports': len(self.exported_items)
        }

        metadata_path = self.export_dir / 'export_metadata.json'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, default=str)

        print(f"✅ Metadata saved to {metadata_path}")
        return metadata

# Build the export manager and show its widget UI in this cell's output
exporter = ExportManager()
exporter.create_export_interface()

# Demo: draw a random walk on a figure created with the 'paper' font sizes
fig, ax = exporter.create_publication_figure(width=6, height=4, style='paper')
ax.plot(np.cumsum(np.random.randn(100)))
ax.set(
    xlabel='Iteration',
    ylabel='Value',
    title='Example Publication Figure',
)
plt.show()

## 7. Summary

This notebook provides a comprehensive toolkit for interactive exploratory analysis of lattice Gaussian MCMC experiments:

### Key Components:

1. **DataLoader**: Flexible data loading with caching and format detection
2. **InteractiveVisualizer**: Dynamic visualization tools for chains, distributions, and diagnostics
3. **DiagnosticAnalyzer**: Comprehensive MCMC diagnostics (R-hat, ESS, Geweke, etc.)
4. **AnalysisSandbox**: Rapid prototyping environment with code execution
5. **InsightRecorder**: Documentation and insight tracking system
6. **ExportManager**: Publication-ready figure export with LaTeX generation

### Usage Tips:

- Start by loading your experimental data with the DataLoader
- Use the InteractiveVisualizer for initial exploration
- Run diagnostic analysis to verify convergence
- Experiment with custom analyses in the Sandbox
- Record important findings with the InsightRecorder
- Export final figures with the ExportManager

### Next Steps:

1. Load actual experimental results
2. Compare different sampling algorithms
3. Generate publication figures
4. Export insights and create reports

In [ ]:
# Cheat sheet of the notebook's tools, emitted with a single print call.
# Adjacent string literals are concatenated, so each source line below is
# exactly one line of output (a leading "\n" yields the blank separator).
print(
    "=== Quick Reference ===\n"
    "\n1. Load data:\n"
    "   loader.create_interactive_loader()\n"
    "   data = loader.load_samples('path/to/data.npz')\n"
    "\n2. Visualize:\n"
    "   viz.set_samples(data)\n"
    "   viz.trace_plot_interactive()\n"
    "   viz.pairwise_scatter()\n"
    "   viz.spectral_density_plot()\n"
    "\n3. Diagnose:\n"
    "   diag.analyze_convergence(data)\n"
    "   diag.compare_samplers(results_dict)\n"
    "\n4. Sandbox:\n"
    "   sandbox.create_sandbox()\n"
    "   sandbox.load_experiment_data('path/to/data')\n"
    "\n5. Record insights:\n"
    "   recorder.quick_note('My insight', 'Category', priority=5)\n"
    "   recorder.attach_current_figure()\n"
    "\n6. Export:\n"
    "   exporter.create_export_interface()\n"
    "   exporter.save_session_metadata()"
)