# Sampling Methods Visualizer (Interactive)

This notebook provides an interactive visualization of different statistical sampling methods, ported from the React application.

## Supported Methods:
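1. **Simple Random Sampling**: Every item is equally likely to be chosen; items are drawn without replacement.
2. **Systematic Sampling**: Order the population (by x-coordinate in Scatter mode, by row and then column in Phantom mode) and take every k-th item from a random starting point; in Phantom mode this resembles acquiring scanlines.
3. **Stratified Sampling**: Sample each category (tissue type in Phantom mode) in proportion to its share of the population.
4. **Cluster Sampling**: Randomly choose whole spatial clusters (patches) and keep every item inside them.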

## Modes:
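- **Scatter**: 200 points spread around four spatial clusters, each labelled with one of three categories (A, B, C) — a general population-statistics example.
- **Phantom**: A 20 × 40 pixel grid that simulates a medical image; each pixel's signal value depends on its tissue type (Background, Cell or Membrane), mimicking signal acquisition.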

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd

# Reproducibility: uncomment the seed call below to get identical populations
# and samples on every run of the demo.
# np.random.seed(42)

# Size of the drawing area in canvas units; the plots use these as axis limits.
CANVAS_WIDTH, CANVAS_HEIGHT = 500, 300

# Marker colour for each scatter-mode category.
COLORS = dict(
    A='#EF4444',  # red
    B='#3B82F6',  # blue
    C='#10B981',  # emerald
)

def generate_scatter_population(n=200):
    """Generate a synthetic scatter population of ``n`` points.

    Every point is placed around one of four fixed spatial cluster centres
    and assigned one of three categories with probabilities A = 0.5,
    B = 0.3 and C = 0.2.

    Args:
        n: Number of points to generate.

    Returns:
        pandas.DataFrame with one row per point and the columns ``id``,
        ``x``, ``y`` (canvas coordinates, kept at least 20 units inside the
        canvas), ``value`` (random integer in [0, 100)), ``category``,
        ``clusterId`` (1-4) and ``type`` (always ``'point'``). The columns
        are present even when ``n`` is 0.
    """
    clusters = [
        {'x': 100, 'y': 100, 'id': 1},
        {'x': 350, 'y': 100, 'id': 2},
        {'x': 100, 'y': 220, 'id': 3},
        {'x': 350, 'y': 220, 'id': 4},
    ]
    columns = ['id', 'x', 'y', 'value', 'category', 'clusterId', 'type']

    data = []
    for i in range(n):
        # One uniform draw decides the category: [0, 0.5) -> A, [0.5, 0.8) -> B, else C.
        rand = np.random.random()
        category = 'A' if rand < 0.5 else ('B' if rand < 0.8 else 'C')

        # Pick the cluster by index; np.random.choice on a list of dicts only
        # works through numpy's implicit conversion to an object array.
        cluster = clusters[np.random.randint(len(clusters))]

        # Uniform jitter of +/-75 (x) and +/-60 (y) around the centre, clipped
        # to a 20-unit margin inside the canvas.
        x = np.clip(cluster['x'] + (np.random.random() - 0.5) * 150, 20, CANVAS_WIDTH - 20)
        y = np.clip(cluster['y'] + (np.random.random() - 0.5) * 120, 20, CANVAS_HEIGHT - 20)

        data.append({
            'id': i,
            'x': x,
            'y': y,
            'value': np.random.randint(0, 100),
            'category': category,
            'clusterId': cluster['id'],
            'type': 'point',
        })
    # Explicit columns keep the schema stable for an empty population.
    return pd.DataFrame(data, columns=columns)

def generate_phantom_population(rows=20, cols=40):
    """Generate a pixel-grid "phantom" that imitates tissue structure.

    The canvas is divided into ``rows`` x ``cols`` pixels. Pixels inside one
    of three fixed circles are labelled ``'Cell'``, pixels in a thin ring
    just outside a circle are ``'Membrane'``, and all others are
    ``'Background'``. Each pixel's signal is a tissue base value
    (Background 80, Cell 40, Membrane 10) plus uniform noise of +/-10,
    clipped to [0, 100].

    Pixels are also grouped into rectangular tiles of 5 rows x 8 columns;
    the tile index is stored in ``clusterId`` for cluster sampling.

    Args:
        rows: Number of pixel rows (default 20).
        cols: Number of pixel columns (default 40).

    Returns:
        pandas.DataFrame with one row per pixel in row-major order and the
        columns ``id``, ``r``, ``c``, ``x``, ``y`` (pixel centre in canvas
        units), ``width``, ``height``, ``value``, ``category``,
        ``clusterId`` and ``type`` (always ``'pixel'``).
    """
    cell_width = CANVAS_WIDTH / cols
    cell_height = CANVAS_HEIGHT / rows

    # Circular structures in grid coordinates: cx = column, cy = row, r = radius.
    cells = [
        {'cx': 10, 'cy': 5, 'r': 3},
        {'cx': 30, 'cy': 10, 'r': 5},
        {'cx': 15, 'cy': 15, 'r': 4},
    ]
    base_values = {'Background': 80, 'Cell': 40, 'Membrane': 10}

    # Spatial cluster tiles. The row stride of the cluster id is derived from
    # the grid width rather than hard-coded, so it stays correct for any cols.
    block_rows, block_cols = 5, 8
    blocks_per_row = -(-cols // block_cols)  # ceiling division

    columns = ['id', 'r', 'c', 'x', 'y', 'width', 'height',
               'value', 'category', 'clusterId', 'type']
    data = []
    for r in range(rows):
        for c in range(cols):
            tissue = 'Background'
            # A pixel falling in a later cell's region overrides an earlier
            # classification; pixels outside every region stay Background.
            for cell in cells:
                dist = np.sqrt((c - cell['cx'])**2 + (r - cell['cy'])**2)
                if dist < cell['r']:
                    tissue = 'Cell'
                elif abs(dist - cell['r']) < 0.8:
                    # Only reached when dist >= r: a ring just outside the circle.
                    tissue = 'Membrane'

            noise = (np.random.random() - 0.5) * 20
            value = np.clip(base_values[tissue] + noise, 0, 100)

            data.append({
                'id': r * cols + c,  # row-major running index
                'r': r, 'c': c,
                'x': c * cell_width + cell_width / 2,
                'y': r * cell_height + cell_height / 2,
                'width': cell_width,
                'height': cell_height,
                'value': value,
                'category': tissue,
                'clusterId': (r // block_rows) * blocks_per_row + c // block_cols,
                'type': 'pixel',
            })
    # Explicit columns keep the schema stable for a degenerate (empty) grid.
    return pd.DataFrame(data, columns=columns)

In [None]:
def perform_sampling(df, method, sample_size, mode='scatter'):
    """Draw a sample from ``df`` with the named sampling method.

    Args:
        df: Population DataFrame. Systematic sampling orders by ``x``
            (scatter) or ``r``/``c`` (phantom); stratified sampling uses
            ``category``; cluster sampling uses ``clusterId``.
        method: ``'Simple Random'``, ``'Systematic'``, ``'Stratified'`` or
            ``'Cluster'``. Any other value falls back to simple random
            sampling.
        sample_size: Requested number of rows, capped at ``len(df)``.
            Cluster sampling returns whole clusters, so its size is only
            approximately ``sample_size``.
        mode: ``'phantom'`` orders systematic samples by row then column;
            any other value orders them by ``x``.

    Returns:
        DataFrame of sampled rows with their original index labels. An
        empty frame is returned for an empty population or a
        non-positive ``sample_size``.
    """
    sample_size = min(sample_size, len(df))
    if sample_size <= 0:
        # Nothing to draw; also avoids division by zero in the systematic
        # and cluster branches.
        return df.iloc[0:0]

    if method == 'Systematic':
        return _systematic_sample(df, sample_size, mode)
    if method == 'Stratified':
        return _stratified_sample(df, sample_size)
    if method == 'Cluster':
        return _cluster_sample(df, sample_size)
    # 'Simple Random' and any unrecognised method.
    return df.sample(n=sample_size)


def _systematic_sample(df, sample_size, mode):
    """Take ``sample_size`` evenly spaced rows of the ordered population."""
    if mode == 'phantom':
        ordered = df.sort_values(by=['r', 'c'])
    else:
        ordered = df.sort_values(by='x')

    # A fractional interval spreads the picks over the whole ordering. An
    # integer interval len // n plus truncation never reaches the tail when
    # len(df) is not a multiple of sample_size.
    step = len(ordered) / sample_size
    start = np.random.uniform(0, step)
    positions = (start + step * np.arange(sample_size)).astype(int)
    # Guard against float rounding pushing the last position past the end.
    positions = np.minimum(positions, len(ordered) - 1)
    return ordered.iloc[positions]


def _stratified_sample(df, sample_size):
    """Sample each ``category`` stratum in proportion to its size.

    Quotas use largest-remainder allocation, so they add up to exactly
    ``sample_size`` without dropping whole strata. When ``sample_size`` is
    at least the number of strata, every stratum gets at least one row.
    """
    sizes = df['category'].value_counts().sort_index()
    total = int(sizes.sum())
    if total == 0:
        return df.iloc[0:0]
    target = min(sample_size, total)

    quotas = sizes * (target / total)
    alloc = np.floor(quotas).astype(int)
    # Hand the rows lost to flooring to the largest fractional parts.
    shortfall = target - int(alloc.sum())
    if shortfall > 0:
        remainders = (quotas - alloc).sort_values(ascending=False, kind='mergesort')
        alloc.loc[remainders.index[:shortfall]] += 1

    # Keep small strata represented by taking one row from the currently
    # largest allocation (which is always >= 2 when this is needed).
    if target >= len(alloc):
        for name in alloc.index[alloc == 0]:
            alloc.loc[alloc.idxmax()] -= 1
            alloc.loc[name] = 1

    parts = [
        group.sample(n=min(int(alloc[name]), len(group)))
        for name, group in df.groupby('category')
        if alloc[name] > 0
    ]
    return pd.concat(parts)


def _cluster_sample(df, sample_size):
    """Select whole ``clusterId`` groups totalling roughly ``sample_size`` rows."""
    cluster_ids = df['clusterId'].unique()
    mean_size = len(df) / len(cluster_ids)
    n_clusters = max(1, int(round(sample_size / mean_size)))
    n_clusters = min(n_clusters, len(cluster_ids))
    chosen = np.random.choice(cluster_ids, size=n_clusters, replace=False)
    return df[df['clusterId'].isin(chosen)]

In [None]:
def plot_simulation(mode, method, sample_size):
    """Generate a population, sample it, and plot the result.

    A fresh population is generated on every call. The left panel shows the
    population with the sampled items highlighted; the right panel compares
    population and sample means and category proportions.

    Args:
        mode: ``'Scatter'`` for the point population; any other value uses
            the phantom pixel grid.
        method: Sampling method name passed to ``perform_sampling``.
        sample_size: Requested sample size.
    """
    if mode == 'Scatter':
        df = generate_scatter_population()
    else:
        df = generate_phantom_population()
    sample_df = perform_sampling(df, method, sample_size, mode.lower())

    fig = plt.figure(figsize=(14, 6))
    gs = fig.add_gridspec(1, 2, width_ratios=[2, 1])
    ax_viz = fig.add_subplot(gs[0])
    ax_stats = fig.add_subplot(gs[1])

    ax_viz.set_title(f"{mode} View - {method} Sampling (n={len(sample_df)})", fontsize=14)
    _draw_population(ax_viz, mode, df, sample_df)
    _draw_statistics(ax_stats, df, sample_df)

    plt.tight_layout()
    plt.show()


def _draw_population(ax, mode, df, sample_df):
    """Draw the population faintly and the sampled items prominently on ``ax``."""
    ax.set_xlim(0, CANVAS_WIDTH)
    ax.set_ylim(CANVAS_HEIGHT, 0)  # flip y so the origin is top-left, like a canvas

    if mode == 'Scatter':
        ax.scatter(df['x'], df['y'], c=[COLORS[c] for c in df['category']], alpha=0.1, s=30)
        ax.scatter(sample_df['x'], sample_df['y'], c=[COLORS[c] for c in sample_df['category']],
                   edgecolors='black', linewidth=1.5, alpha=1.0, s=50, label='Sampled')
        ax.legend()
        return

    # Phantom mode: size the image from the pixel coordinates rather than
    # assuming a fixed 20 x 40 grid.
    n_rows = int(df['r'].max()) + 1
    n_cols = int(df['c'].max()) + 1
    grid_img = np.empty((n_rows, n_cols, 3))
    grid_img[:] = (0.9, 0.9, 0.95)  # unsampled pixels: pale "missing data" colour

    # Sampled pixels: intensity mapped onto a blue-grey ramp (MRI-like look).
    rows = sample_df['r'].to_numpy(dtype=int)
    cols = sample_df['c'].to_numpy(dtype=int)
    norm = sample_df['value'].to_numpy(dtype=float) / 100.0
    grid_img[rows, cols] = np.column_stack([norm * 0.2, norm * 0.5, norm * 0.8])

    ax.imshow(grid_img, extent=[0, CANVAS_WIDTH, CANVAS_HEIGHT, 0], aspect='auto')


def _draw_statistics(ax, df, sample_df):
    """Show population vs. sample means and category proportions on ``ax``."""
    ax.set_title("Statistics Comparison", fontsize=14)

    pop_mean = df['value'].mean()
    sample_mean = sample_df['value'].mean()
    # Axes-fraction coordinates pin the text to the top-left corner instead of
    # data coordinates, where it would overlap the bars.
    ax.text(0.03, 0.95, f"Population Mean: {pop_mean:.2f}", fontsize=12,
            transform=ax.transAxes, va='top')
    ax.text(0.03, 0.87, f"Sample Mean: {sample_mean:.2f}", fontsize=12,
            fontweight='bold', color='blue', transform=ax.transAxes, va='top')
    ax.text(0.03, 0.79, f"Error: {abs(pop_mean - sample_mean):.2f}", fontsize=10,
            color='gray', transform=ax.transAxes, va='top')

    cats = sorted(df['category'].unique())
    pop_share = df['category'].value_counts(normalize=True)
    sample_share = sample_df['category'].value_counts(normalize=True)
    # .get(..., 0) aligns categories missing from the sample.
    p_vals = [pop_share.get(c, 0) for c in cats]
    s_vals = [sample_share.get(c, 0) for c in cats]

    x = np.arange(len(cats))
    width = 0.35
    ax.bar(x - width / 2, p_vals, width, label='Pop', color='gray', alpha=0.5)
    ax.bar(x + width / 2, s_vals, width, label='Sample', color='blue', alpha=0.8)

    ax.set_xticks(x)
    ax.set_xticklabels(cats)
    ax.set_ylabel("Proportion")
    # Headroom above 1.0 keeps even full-height bars below the text block;
    # ticks stay on the meaningful 0-1 range.
    ax.set_ylim(0, 1.4)
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.legend(loc='upper right')

# --- Interactive Widgets ---
style = {'description_width': 'initial'}

# ipywidgets exposes this as lower-case ``interactive``; ``widgets.Interactive``
# does not exist and raises AttributeError. It pairs each keyword's control
# with plot_simulation's parameter of the same name and re-runs the function
# when a control changes.
ui = widgets.interactive(
    plot_simulation,
    mode=widgets.Dropdown(options=['Scatter', 'Phantom'], value='Phantom',
                          description='Data Source:', style=style),
    method=widgets.Dropdown(options=['Simple Random', 'Systematic', 'Stratified', 'Cluster'],
                            value='Simple Random', description='Sampling Method:', style=style),
    # Redraw only when the slider is released: every redraw regenerates the
    # population and rebuilds the whole figure.
    sample_size=widgets.IntSlider(min=10, max=200, step=10, value=50,
                                  description='Sample Size (n):', style=style,
                                  continuous_update=False),
)

display(ui)