# Subset Visualization

This notebook visualizes the selected subsets and compares them with baselines.

Features:
- Display sample images from selected subsets
- Show class distribution histograms
- Compare GA-selected vs random baseline subsets


In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directory to path
# NOTE: Path().absolute() resolves against the current working directory, so
# this prepends the parent of the directory the kernel is running in (which is
# only the notebook's parent folder if the kernel was started next to it).
# This must run before the project imports below.
sys.path.insert(0, str(Path().absolute().parent))
import config
from data.load_data import load_selection_pool, get_class_distribution

# Set style
# Use seaborn's "whitegrid" theme and make 14x8 the default figure size for
# any figure that does not pass its own figsize.
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 8)


In [None]:
# Load data
# load_selection_pool() is a project helper; its result is unpacked into the
# notebook-level `data` and `labels`, which the cells below index by sample
# position (presumably parallel arrays of images and class labels -- confirm
# against data/load_data.py).
data, labels = load_selection_pool()
print(f"Loaded selection pool: {len(data)} samples")


## Display Sample Images from Selected Subsets


In [None]:
def display_subset_samples(subset_indices, title, num_samples=20):
    """Display a random sample of images from a subset in a 5-column grid.

    Images and labels are read from the notebook-level ``data`` and
    ``labels`` loaded in an earlier cell.

    Args:
        subset_indices: Indices into ``data``/``labels`` that make up the subset.
        title: Figure-level title.
        num_samples: Maximum number of images to draw; capped at the subset
            size. Values above 20 add extra rows to the grid.

    Returns:
        The matplotlib figure holding the image grid.
    """
    n_cols = 5
    n_shown = min(num_samples, len(subset_indices))

    # Pick positions within the subset (not dataset indices), without repeats.
    display_indices = np.random.choice(len(subset_indices),
                                       size=n_shown,
                                       replace=False)

    # Keep the 4x5 / (12, 10) layout for up to 20 images and grow by whole
    # rows beyond that, so a larger num_samples cannot index past the axes.
    n_rows = max(4, (n_shown + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 2.5 * n_rows))
    axes = axes.flatten()

    for i, idx in enumerate(display_indices):
        sample_idx = subset_indices[idx]
        # squeeze() drops every singleton dimension (e.g. a size-1 channel axis).
        img = data[sample_idx].squeeze()
        label = labels[sample_idx]

        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f'Class {label}', fontsize=9)
        axes[i].axis('off')

    # Hide unused subplots
    for unused_ax in axes[len(display_indices):]:
        unused_ax.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold', y=0.995)
    fig.tight_layout()
    return fig

# Load and display GA-selected subset for k=100
k = 100
ga_subset_path = config.get_selected_subset_path(k)
if ga_subset_path.exists():
    ga_indices = np.load(ga_subset_path)
    fig = display_subset_samples(ga_indices, f'GA-Selected Subset (k={k})')
    # savefig does not create missing directories; make sure the output
    # folder exists so a fresh checkout does not fail with FileNotFoundError.
    Path('results').mkdir(parents=True, exist_ok=True)
    plt.savefig(f'results/ga_subset_samples_k{k}.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print(f"GA-selected subset for k={k} not found. Run subset selection first.")


## Class Distribution Comparison


In [None]:
# Compare class distributions for different k values.
# One panel per k: GA-selected class counts next to the mean (+/- std) class
# counts of config.NUM_RANDOM_BASELINES seeded random subsets of the same size.
k_values_to_plot = [50, 100, 200, 500]

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()

for idx, k in enumerate(k_values_to_plot):
    ax = axes[idx]

    ga_subset_path = config.get_selected_subset_path(k)
    if not ga_subset_path.exists():
        # No saved subset for this k: show a placeholder panel instead.
        ax.text(0.5, 0.5, f'Subset for k={k} not found', 
               ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title(f'k={k} (not available)', fontsize=12)
        continue

    # GA-selected subset
    ga_indices = np.load(ga_subset_path)
    ga_labels = labels[ga_indices]
    ga_dist = get_class_distribution(ga_labels, num_classes=config.NUM_CLASSES)

    # Random baseline (averaged over config.NUM_RANDOM_BASELINES runs)
    random_dists = []
    for run_num in range(1, config.NUM_RANDOM_BASELINES + 1):
        # A local RandomState seeded with BASELINE_SEED + run_num yields the
        # same draws as np.random.seed(...) + np.random.choice(...), but does
        # not reseed the process-wide RNG used by other cells.
        # NOTE(review): the original comment says this seed matches the one
        # used in training -- confirm against the training code.
        rng = np.random.RandomState(config.BASELINE_SEED + run_num)
        random_indices = rng.choice(len(labels), size=k, replace=False)
        random_labels = labels[random_indices]
        random_dists.append(get_class_distribution(random_labels, num_classes=config.NUM_CLASSES))

    random_mean_dist = np.mean(random_dists, axis=0)
    random_std_dist = np.std(random_dists, axis=0)

    # Plot: two side-by-side bars per class.
    x = np.arange(config.NUM_CLASSES)
    width = 0.35

    bars1 = ax.bar(x - width/2, ga_dist, width, label='GA-Selected', 
                   color='steelblue', alpha=0.8, edgecolor='black')
    bars2 = ax.bar(x + width/2, random_mean_dist, width, label='Random (mean)', 
                   color='orange', alpha=0.8, edgecolor='black', yerr=random_std_dist, capsize=3)

    ax.set_xlabel('Class', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'Class Distribution: k={k}', fontsize=12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([str(i) for i in range(config.NUM_CLASSES)])
    ax.legend(fontsize=10)
    ax.grid(axis='y', alpha=0.3)

    # Add value labels above the GA bars only (the random bars carry error bars).
    for bar in bars1:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{int(height)}', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
# savefig does not create missing directories; make sure the output folder exists.
Path('results').mkdir(parents=True, exist_ok=True)
plt.savefig('results/class_distribution_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
