# Non-IID Wrapper Demo

Este notebook prueba el `NonIIDWrapper` en los 5 datasets principales del repo.
Antes de correr, asegúrate de tener los datasets descargados en `data/` como indica el README.


In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from visgen.datasets import Cars3D, CLEVR, DSprites, IRAVEN, MPI3D, Shapes3D
from visgen.datasets.non_iid import NonIIDWrapper


In [None]:
from omegaconf import OmegaConf
from visgen.utils.general import register_resolvers

# Runs once, before any config is loaded below. Presumably this installs the
# custom OmegaConf resolvers the dataset configs rely on -- TODO confirm
# against visgen.utils.general.
register_resolvers()

# Directory holding one `<name>.yml` per dataset. The path is relative, so it
# is resolved against the current working directory of the notebook kernel.
DATASET_CONFIG_DIR = Path("configs/datasets")

# Dataset name -> dataset class. Each key doubles as the config file stem
# read by `load_dataset`. CLEVR is imported above but is not registered here.
DATASET_CLASSES = {
    'dsprites': DSprites,
    'mpi3d': MPI3D,
    'shapes3d': Shapes3D,
    'cars3d': Cars3D,
    'iraven': IRAVEN,
}

def load_dataset(name, split='training'):
    """Build a registered dataset from its YAML configuration.

    Loads ``DATASET_CONFIG_DIR / "<name>.yml"``, selects the ``data.<split>``
    section and forwards its settings as keyword arguments to the class
    registered under ``name`` in ``DATASET_CLASSES``.

    Args:
        name: Key into ``DATASET_CLASSES``; also the config file stem.
        split: Name of the section under ``data`` to read settings from.

    Returns:
        The dataset instance returned by the registered class.
    """
    config_file = DATASET_CONFIG_DIR / f"{name}.yml"
    section = OmegaConf.load(config_file).data[split]

    # `path` and `train` are read as attributes, so they must be present in
    # the config; everything fetched through `.get` is optional.
    ctor_args = dict(
        path=str(Path(section.path)),
        dataset_subset=section.get('dataset_subset'),
        train=bool(section.train),
        targets=section.get('targets'),
        split_attributes=section.get('split_attributes'),
        split=section.get('split'),
        split_difficulty=section.get('split_difficulty'),
        shuffle=bool(section.get('shuffle', True)),
        downsample=int(section.get('downsample', 0)),
    )

    if name == 'iraven':
        # Only this dataset receives `max_obj`. The config key `max_objects`
        # takes precedence, then `max_obj`, then a default of 1.
        legacy_value = section.get('max_obj', 1)
        ctor_args['max_obj'] = section.get('max_objects', legacy_value)

    dataset_cls = DATASET_CLASSES[name]
    return dataset_cls(**ctor_args)


In [None]:
def plot_quad(images, targets, title):
    """Draw the first four images in one row, each captioned with its targets.

    Args:
        images: Indexable collection with at least four entries; each entry
            is a torch tensor or an array-like exposing ``ndim``/``shape``.
        targets: Indexable collection aligned with ``images`` whose entries
            are iterables of numbers, or ``None`` to skip the captions.
        title: Text placed above the whole figure.
    """

    def as_numpy(value):
        # Tensors are detached and moved to the CPU before conversion;
        # anything else is passed through untouched.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return value

    figure, panels = plt.subplots(1, 4, figsize=(12, 3))
    figure.suptitle(title)

    for position, panel in enumerate(panels):
        picture = as_numpy(images[position])

        # A leading axis of size 1 or 3 is treated as the channel axis and
        # moved last, which is the layout imshow expects.
        channels_first = picture.ndim == 3 and picture.shape[0] in (1, 3)
        if channels_first:
            picture = np.moveaxis(picture, 0, -1)
        panel.imshow(picture.squeeze(), cmap='gray')

        if targets is not None:
            values = as_numpy(targets[position])
            caption = ', '.join(map(str, map(int, values)))
            panel.set_title(caption, fontsize=8)

        panel.axis('off')

    plt.show()


In [None]:
def describe_dataset(dataset):
    """Print the size of a dataset and how many values each attribute takes.

    Reads ``dataset.targets``. A three-dimensional array whose middle axis has
    length one is collapsed to two dimensions first. If the result is not
    two-dimensional, only a notice with the offending shape is printed.

    Args:
        dataset: Object supporting ``len()`` and exposing a ``targets`` array.
    """
    labels = dataset.targets

    # Collapse a singleton middle axis: (N, 1, A) -> (N, A).
    has_singleton_axis = labels.ndim == 3 and labels.shape[1] == 1
    if has_singleton_axis:
        labels = labels[:, 0, :]

    if labels.ndim == 2:
        print(f'  Num samples: {len(dataset)}')
        print(f'  Targets shape: {labels.shape}')
        # Number of distinct values found in each column.
        cardinalities = [
            len(np.unique(labels[:, column]))
            for column in range(labels.shape[1])
        ]
        print(f'  Attribute cardinalities: {cardinalities}')
    else:
        print(f'  Targets shape (unsupported): {labels.shape}')


In [None]:
# Single seed for the demo: it seeds NumPy's and PyTorch's global RNGs here
# and is also passed as `seed=` to every NonIIDWrapper built in the demo loop.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
def validate_quadrant(targets, shared_other_attributes):
    """Check that four target rows form a 2x2 grid over some attribute pair.

    A pair of columns qualifies when each column takes exactly two distinct
    values and the four rows cover all four value combinations. Pairs are
    tried in ascending ``(i, j)`` order and the first match is reported.

    Args:
        targets: Torch tensor or array-like of shape ``(4, num_attributes)``.
        shared_other_attributes: When true, every column outside the matched
            pair must additionally hold a single value across the four rows.

    Returns:
        A dict with a boolean ``'valid'``. Failures carry a ``'reason'``
        string; ``'pair'`` holds the matched column indices whenever a pair
        was found, including the "other attributes not shared" failure.
    """
    labels = targets.detach().cpu().numpy() if torch.is_tensor(targets) else targets
    labels = np.asarray(labels)
    if not (labels.ndim == 2 and labels.shape[0] == 4):
        return {'valid': False, 'reason': 'targets shape unexpected'}

    attr_count = labels.shape[1]
    # Distinct-value count per column, computed once and reused below.
    distinct = [len(np.unique(labels[:, col])) for col in range(attr_count)]
    two_valued = [col for col, count in enumerate(distinct) if count == 2]

    def covers_grid(first, second):
        # Four rows yielding four distinct (a, b) combinations is a full grid.
        return len(set(zip(labels[:, first], labels[:, second]))) == 4

    pair = next(
        (
            (first, second)
            for offset, first in enumerate(two_valued)
            for second in two_valued[offset + 1:]
            if covers_grid(first, second)
        ),
        None,
    )
    if pair is None:
        return {'valid': False, 'reason': 'no attribute pair forms a full quadrant'}

    if shared_other_attributes:
        others_vary = any(
            distinct[col] != 1 for col in range(attr_count) if col not in pair
        )
        if others_vary:
            return {'valid': False, 'reason': 'other attributes not shared', 'pair': pair}

    return {'valid': True, 'pair': pair}


In [None]:
# Demo driver: for every registered dataset, load it, print a summary, then
# draw the first item from a NonIIDWrapper in both modes and validate/plot it.
# All names assigned here are notebook globals; after the loop they hold the
# values from the last dataset processed.
for name in DATASET_CLASSES:
    print(f'\nDataset: {name}')
    dataset = load_dataset(name)  # uses the default 'training' split
    describe_dataset(dataset)
    # Mode 1: shared_other_attributes=True. The check below expects every
    # attribute outside the matched pair to be constant across the four rows.
    wrapper = NonIIDWrapper(dataset, shared_other_attributes=True, seed=SEED)
    # Indexing the wrapper yields an (images, targets) pair; plot_quad and
    # validate_quadrant both assume it holds four samples.
    images, targets = wrapper[0]
    print('Targets shape:', targets.shape)
    print('Quadrant check (shared):', validate_quadrant(targets, True))
    plot_quad(images, targets, f'{name} (shared other attributes)')

    # Mode 2: shared_other_attributes=False. Only the 2x2 grid over one
    # attribute pair is checked; the remaining attributes are not constrained.
    wrapper_indep = NonIIDWrapper(dataset, shared_other_attributes=False, seed=SEED)
    images_indep, targets_indep = wrapper_indep[0]
    print('Targets shape (indep):', targets_indep.shape)
    print('Quadrant check (indep):', validate_quadrant(targets_indep, False))
    plot_quad(images_indep, targets_indep, f'{name} (independent other attributes)')
