# Análisis de dSprites y MPI3D

Este notebook sirve para inspeccionar los datasets **dSprites** y **MPI3D** en los splits de IDG. Incluye estadísticas básicas y una comparación visual de ejemplos de **train** vs **test**.

**Requisitos previos**: descarga los `.npz` del benchmark y colócalos en `DATA_ROOT` con la estructura esperada (por ejemplo `data/dsprites/dsprites_random_train_images.npz`, etc.).

In [None]:
from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt

from dataset import IDGBenchmarkDataset


In [None]:
# Configuration
DATA_ROOT = Path("../data")  # Adjust this path to wherever your datasets live
SPLIT = "random"  # random | composition | interpolation | extrapolation
SAMPLE_SEED = 0  # seed for the NumPy generator used for all random sampling
EXAMPLES_PER_SPLIT = 8  # number of example images plotted for train and for test
STATS_SAMPLE = 10000  # number of samples used for the quick image statistics


In [None]:
# Single seeded generator shared by every sampling helper below; because the
# state is shared, results depend on the order in which the helpers are called.
rng = np.random.default_rng(SAMPLE_SEED)


def _sample_indices(n_total: int, sample_size: int) -> np.ndarray:
    """Draw distinct random indices from ``range(n_total)``.

    The request is capped at ``n_total``, so asking for more indices than
    exist simply yields each index once. Draws come from the module-level
    ``rng``, so the result depends on how many draws preceded this call.

    Args:
        n_total: Size of the population to draw from.
        sample_size: Requested number of indices.

    Returns:
        Array of ``min(sample_size, n_total)`` indices without repetition.
    """
    n_draw = sample_size if sample_size < n_total else n_total
    return rng.choice(n_total, replace=False, size=n_draw)


def compute_image_stats(images: np.ndarray, sample_size: int = 10000) -> dict:
    """Summarise a random subset of ``images``.

    At most ``sample_size`` entries along the first axis are picked at random
    (without repetition) and every statistic is computed over all elements of
    that subset.

    Args:
        images: Array whose first axis enumerates the images.
        sample_size: Upper bound on the number of images inspected.

    Returns:
        Dict with the keys ``sample_size``, ``dtype``, ``shape`` (shape of one
        image), ``min``, ``max``, ``mean`` and ``std``.
    """
    chosen = _sample_indices(len(images), sample_size)
    subset = images[chosen]
    return dict(
        sample_size=int(len(subset)),
        dtype=str(subset.dtype),
        shape=tuple(subset.shape[1:]),
        min=float(subset.min()),
        max=float(subset.max()),
        mean=float(subset.mean()),
        std=float(subset.std()),
    )


def compute_factor_stats(labels: np.ndarray) -> list[dict]:
    """Count how often each distinct value occurs in every label column.

    Args:
        labels: Label array with one row per sample and one column per
            factor. A 1-D array is treated as a single column.

    Returns:
        One dict per column, in column order, holding the sorted distinct
        values under ``'values'`` and their frequencies under ``'counts'``.
    """
    table = labels.reshape(-1, 1) if labels.ndim == 1 else labels

    def _summarise(column: np.ndarray) -> dict:
        distinct, freq = np.unique(column, return_counts=True)
        return {'values': distinct, 'counts': freq}

    return [_summarise(table[:, col]) for col in range(table.shape[1])]


def plot_factor_histograms(factor_stats: list[dict], factor_names: list[str], title: str) -> None:
    """Draw one bar chart per factor, side by side in a single row.

    Args:
        factor_stats: Per-factor dicts with ``'values'`` and ``'counts'``, as
            produced by ``compute_factor_stats``.
        factor_names: Panel titles, indexed in the same order as
            ``factor_stats``.
        title: Figure-level title.
    """
    n_panels = len(factor_stats)
    # squeeze=False keeps a 2-D axes array even when there is a single panel.
    fig, grid = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), squeeze=False)
    for pos, (panel, entry) in enumerate(zip(grid[0], factor_stats)):
        xs = entry['values']
        heights = entry['counts']
        panel.bar(xs, heights, width=0.8, color='#4c72b0')
        panel.set_title(factor_names[pos])
        if len(xs) <= 10:
            ticks = xs
        else:
            # Thin the ticks with a stride of len // 10 so labels stay legible.
            ticks = xs[:: max(1, len(xs) // 10)]
        panel.set_xticks(ticks)
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


def plot_examples(ds: IDGBenchmarkDataset, title: str, n: int = 8) -> None:
    """Show up to ``n`` randomly chosen samples of ``ds`` in a single row.

    Each panel displays the image and a caption listing every factor name
    with its latent value truncated to an integer.

    Args:
        ds: Dataset whose items unpack as ``(image, latents, names)``; the
            image must support ``permute(1, 2, 0).numpy()``.
        title: Figure-level title.
        n: Maximum number of samples to show. Fewer panels are drawn when the
            dataset holds fewer than ``n`` items.
    """
    indices = _sample_indices(len(ds), n)
    n_shown = len(indices)
    if n_shown == 0:
        # A grid with zero columns cannot be created, and there is nothing to draw.
        return

    # Size the grid by the number of samples actually drawn so no empty panels
    # are left over, and keep a 2-D axes array (squeeze=False) so the loop
    # below also works for a single panel.
    fig, axes = plt.subplots(1, n_shown, figsize=(2.2 * n_shown, 2.2), squeeze=False)
    for ax, idx in zip(axes[0], indices):
        img, latents, names = ds[idx]
        # Move the leading axis last, the layout imshow expects for colour images.
        img_np = img.permute(1, 2, 0).numpy()
        ax.imshow(img_np)
        caption = ', '.join(
            f'{name}:{int(value)}' for name, value in zip(names, latents.tolist())
        )
        ax.set_title(caption, fontsize=7)
        ax.axis('off')
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


def describe_split(ds: IDGBenchmarkDataset, split_name: str) -> dict:
    """Collect a summary of one dataset split.

    Image statistics are computed on a random subset of at most
    ``STATS_SAMPLE`` images taken from the dataset's ``_images`` array.

    Args:
        ds: Dataset exposing ``_images``, ``_labels`` and ``factor_names``.
        split_name: Label stored under the ``'split'`` key.

    Returns:
        Dict with the keys ``split``, ``num_samples``, ``factor_names``,
        ``label_shape`` and ``image_stats``.
    """
    image_stats = compute_image_stats(ds._images, sample_size=STATS_SAMPLE)
    return dict(
        split=split_name,
        num_samples=len(ds),
        factor_names=ds.factor_names,
        label_shape=tuple(ds._labels.shape),
        image_stats=image_stats,
    )


In [None]:
def analyze_dataset(dataset_name: str) -> None:
    """Print summaries and draw plots for the train and test splits.

    Loads both splits of ``dataset_name`` for the configured ``SPLIT``, prints
    each split's summary, then shows the per-factor histograms followed by
    random example images. Train is always processed before test.

    Args:
        dataset_name: Dataset identifier passed to ``IDGBenchmarkDataset``.
    """
    print(f'\n==== {dataset_name.upper()} ====')

    train_ds, test_ds = (
        IDGBenchmarkDataset(DATA_ROOT, dataset_name, SPLIT, part)
        for part in ('train', 'test')
    )

    # Both summaries are computed before anything is printed, which keeps the
    # order of random draws the same regardless of the printing below.
    reports = (
        ('Train info:', describe_split(train_ds, 'train')),
        ('\nTest info:', describe_split(test_ds, 'test')),
    )
    for heading, report in reports:
        print(heading)
        for key, value in report.items():
            print(f'  {key}: {value}')

    train_hist = compute_factor_stats(train_ds._labels)
    test_hist = compute_factor_stats(test_ds._labels)
    for split_ds, hist, caption in (
        (train_ds, train_hist, f'{dataset_name} - train distribuciones'),
        (test_ds, test_hist, f'{dataset_name} - test distribuciones'),
    ):
        plot_factor_histograms(hist, split_ds.factor_names, caption)

    for split_ds, caption in (
        (train_ds, f'{dataset_name} - train ejemplos'),
        (test_ds, f'{dataset_name} - test ejemplos'),
    ):
        plot_examples(split_ds, caption, n=EXAMPLES_PER_SPLIT)


# Run the full analysis (printed summaries, histograms, example images) for
# each dataset in turn.
for dataset_name in ['dsprites', 'mpi3d']:
    analyze_dataset(dataset_name)
