# Estadísticas de datasets (c y vector D)

Este notebook carga un dataset por nombre y calcula estadísticas básicas usando el split **general\_composition**,
siguiendo la lógica del parámetro **c** y del vector **D** ("attr\_difficulty") usados en los scripts
`orthotopic_runner.sh`/`compositional_orth.sh`.

Antes de correrlo, asegúrate de descargar los datasets en `data/` (ver README).


In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from omegaconf import OmegaConf
from visgen.utils.general import register_resolvers
from visgen.datasets import Cars3D, CLEVR, DSprites, IRAVEN, MPI3D, Shapes3D

register_resolvers()


In [None]:
# Directory that holds one "<dataset name>.yml" config file per dataset.
# The path is relative, so it is resolved against the current working directory.
DATASET_CONFIG_DIR = Path("configs/datasets")

# Supported dataset names mapped to the visgen dataset class that implements each one.
DATASET_CLASSES = {
    'dsprites': DSprites,
    'mpi3d': MPI3D,
    'shapes3d': Shapes3D,
    'cars3d': Cars3D,
    'iraven': IRAVEN,
    'clevr': CLEVR,
}

def parse_attr_difficulty(d_vector):
    """Turn a D vector ("attr_difficulty") specification into a list.

    ``None`` is returned unchanged. A string such as ``"[2, 3, 14, 14]"`` has
    its square brackets removed and is split on commas; every non-blank piece
    is converted with ``int``. Any other value is materialized with ``list()``
    and its elements are left as they are.
    """
    if d_vector is None:
        return None
    if not isinstance(d_vector, str):
        return list(d_vector)
    # Drop every '[' and ']' in one pass, then tokenize on commas.
    unbracketed = d_vector.strip().translate(str.maketrans('', '', '[]'))
    tokens = unbracketed.split(',')
    # Blank tokens (empty input, trailing or doubled commas) are skipped.
    return [int(token) for token in tokens if token.strip()]

def get_attribute_names(dataset, num_attributes):
    """Return the dataset's attribute names ordered by their index.

    The name-to-index mapping is taken from ``dataset._attribute_indices`` if
    that attribute exists, otherwise from ``dataset._ATTRIBUTE_INDICES``. When
    neither exists, generic names ``attr_0 .. attr_{num_attributes - 1}`` are
    returned.
    """
    missing = object()  # sentinel: tells "attribute absent" apart from any real value
    for candidate in ('_attribute_indices', '_ATTRIBUTE_INDICES'):
        index_map = getattr(dataset, candidate, missing)
        if index_map is missing:
            continue
        by_index = sorted(index_map.items(), key=lambda pair: pair[1])
        return [label for label, _ in by_index]
    return [f'attr_{position}' for position in range(num_attributes)]

def load_dataset(name, split='training', c=1, attr_difficulty=None, split_override='general_composition'):
    """Build the dataset object registered under ``name``.

    The settings are read from ``DATASET_CONFIG_DIR / "<name>.yml"``, section
    ``data[split]``. ``path``, ``train``, ``shuffle`` and ``downsample`` are
    read directly from that section; ``dataset_subset``, ``targets``,
    ``split_attributes`` and ``split_difficulty`` are read with ``.get()``.

    Args:
        name: Key of ``DATASET_CLASSES``.
        split: Name of the section under ``data`` in the config file.
        c: Forwarded to the dataset constructor as ``int(c)``.
        attr_difficulty: D vector; normalized with ``parse_attr_difficulty``.
        split_override: Passed as the constructor's ``split`` argument; no
            split name is read from the config section.

    Returns:
        An instance of ``DATASET_CLASSES[name]``.

    Raises:
        ValueError: If ``name`` is not a key of ``DATASET_CLASSES``.
    """
    if name not in DATASET_CLASSES:
        raise ValueError(f'Unknown dataset: {name}')
    dataset_cls = DATASET_CLASSES[name]

    config_file = DATASET_CONFIG_DIR / f"{name}.yml"
    section = OmegaConf.load(config_file).data[split]

    return dataset_cls(
        path=str(Path(section.path)),
        dataset_subset=section.get('dataset_subset'),
        train=bool(section.train),
        targets=section.get('targets'),
        split_attributes=section.get('split_attributes'),
        split=split_override,
        split_difficulty=section.get('split_difficulty'),
        shuffle=bool(section.shuffle),
        downsample=int(section.downsample),
        # This notebook always requests test_complement=True.
        test_complement=True,
        c=int(c),
        attr_difficulty=parse_attr_difficulty(attr_difficulty),
    )


In [None]:
def normalize_targets(targets):
    """Return ``targets`` as a 2-D array.

    A 2-D input is returned as is. A 3-D input whose middle axis has length 1
    has that axis dropped.

    Raises:
        ValueError: For any other shape.
    """
    if targets.ndim == 2:
        return targets
    if targets.ndim == 3 and targets.shape[1] == 1:
        # Drop the singleton middle axis.
        return targets[:, 0, :]
    raise ValueError(f'Unsupported targets shape: {targets.shape}')

def compute_included_combinations(attribute_values, split_indices, c, threshold_values):
    """Enumerate split-attribute combinations and keep those within the c bound.

    Args:
        attribute_values: Indexable collection; entry ``i`` holds the values of
            attribute ``i``.
        split_indices: Indices of the attributes that take part in the split.
        c: Maximum number of attributes per combination allowed to be at or
            above their threshold.
        threshold_values: One threshold per entry of ``split_indices``.

    Returns:
        ``(cartesian, included)``: every combination of the selected attribute
        values as rows of an array, and the subset of rows in which at most
        ``c`` entries are ``>=`` their threshold.
    """
    from itertools import product

    per_attribute = (attribute_values[i] for i in split_indices)
    cartesian = np.array(list(product(*per_attribute)))
    thresholds = np.array(threshold_values)
    # Row-wise count of entries that reach their threshold.
    at_or_above = cartesian >= thresholds
    reached_count = np.sum(at_or_above, axis=1)
    return cartesian, cartesian[reached_count <= c]

def analyze_dataset(name, c, d_vector, split='training'):
    """Load a dataset with the general_composition split and summarize it.

    Args:
        name: Dataset name, as accepted by ``load_dataset``.
        c: Value of c, forwarded to the dataset and used as the inclusion bound
            when counting combinations.
        d_vector: D vector ("attr_difficulty"); a list or a string such as
            ``"[2, 3, 14, 14]"``, one entry per split attribute and in the same
            order.
        split: Config section passed to ``load_dataset``.

    Returns:
        ``(summary_df, attr_df, combo_df)``:
        a one-row summary (sizes, c, D, cartesian/included counts and their
        ratio), one row of value statistics per attribute, and the observed
        split-attribute combinations with their sample counts, sorted by count
        in descending order.

    Raises:
        ValueError: If ``d_vector`` is ``None`` or its length differs from the
            number of split attributes.
    """
    # Parse D once and reject a missing vector before constructing the dataset,
    # so an incomplete call fails immediately instead of after the load.
    threshold_values = parse_attr_difficulty(d_vector)
    if threshold_values is None:
        raise ValueError('Se requiere un vector D (attr_difficulty) para el split general_composition.')

    # Forward the parsed list rather than the raw input: a one-shot iterable
    # would otherwise be exhausted by the first parse.
    dataset = load_dataset(name, split=split, c=c, attr_difficulty=threshold_values)
    targets = normalize_targets(dataset._dataset_targets)
    num_samples, num_attributes = targets.shape
    attribute_values = getattr(dataset, '_attribute_values', None)
    if attribute_values is None:
        attribute_values = dataset._get_attribute_values()
    attribute_names = get_attribute_names(dataset, num_attributes)

    # Without explicit split attributes, every attribute takes part in the split.
    split_attributes = getattr(dataset, '_split_attributes', None)
    if not split_attributes:
        split_attributes = attribute_names
    split_indices = [attribute_names.index(attr) for attr in split_attributes]
    if len(threshold_values) != len(split_attributes):
        raise ValueError(
            f'Longitud de D ({len(threshold_values)}) no coincide con split_attributes ({len(split_attributes)}).'
        )

    cartesian, included = compute_included_combinations(
        attribute_values, split_indices, c, threshold_values
    )
    # Fraction of the full combination grid that satisfies the c bound.
    volume = included.shape[0] / cartesian.shape[0]

    summary = {
        'dataset': name,
        'num_samples': num_samples,
        'num_attributes': num_attributes,
        'split_attributes': split_attributes,
        'c': c,
        'D (attr_difficulty)': threshold_values,
        'cartesian_size': int(cartesian.shape[0]),
        'included_combinations': int(included.shape[0]),
        'volume (included/cartesian)': volume,
    }

    summary_df = pd.DataFrame([summary])

    # Per-attribute value statistics over the loaded samples.
    # NOTE(review): assumes column idx of targets corresponds to attribute_names[idx] - confirm.
    attr_stats = []
    for idx, name_attr in enumerate(attribute_names):
        values, counts = np.unique(targets[:, idx], return_counts=True)
        attr_stats.append({
            'attribute': name_attr,
            'num_values': len(values),
            'min_value': int(values.min()),
            'max_value': int(values.max()),
            'most_common_value': int(values[np.argmax(counts)]),
            'most_common_count': int(counts.max()),
        })
    attr_df = pd.DataFrame(attr_stats)

    # Count how many samples fall on each observed split-attribute combination.
    split_targets = targets[:, split_indices]
    split_combo_values, split_combo_counts = np.unique(split_targets, axis=0, return_counts=True)
    combo_df = pd.DataFrame({
        'combination': [tuple(row) for row in split_combo_values],
        'count': split_combo_counts,
    }).sort_values('count', ascending=False)

    return summary_df, attr_df, combo_df


## Parámetros de entrada

Define el nombre del dataset, el valor de **c** y el vector **D**.
Los valores de **D** deben estar en el mismo orden que `split_attributes` del dataset.


In [None]:
# Dataset to analyze; analyze_dataset forwards this name to load_dataset.
dataset_name = 'dsprites'
# Value of c: forwarded to the dataset and used as the inclusion bound for combinations.
c_value = 1
d_vector = [2, 3, 14, 14]  # example for dsprites (c=1)

# Run the analysis; the trailing bare expression makes the notebook display the summary table.
summary_df, attr_df, combo_df = analyze_dataset(dataset_name, c_value, d_vector)
summary_df


In [None]:
# Per-attribute statistics table returned by analyze_dataset.
attr_df


In [None]:
# First ten rows of combo_df, which analyze_dataset sorts by count in descending order.
combo_df.head(10)


In [None]:
# Per-attribute histogram (only the first 6 attributes, to keep the figure readable)
plot_attrs = attr_df['attribute'].head(6).tolist()
fig, axes = plt.subplots(len(plot_attrs), 1, figsize=(8, 3 * len(plot_attrs)))
# With a single row plt.subplots returns one Axes object; wrap it so it can be zipped below.
axes = [axes] if len(plot_attrs) == 1 else axes

# analyze_dataset does not return the targets, so the dataset is loaded again here.
dataset = load_dataset(dataset_name, split='training', c=c_value, attr_difficulty=d_vector)
targets = normalize_targets(dataset._dataset_targets)
attribute_names = get_attribute_names(dataset, targets.shape[1])

for ax, attr in zip(axes, plot_attrs):
    idx = attribute_names.index(attr)
    values, counts = np.unique(targets[:, idx], return_counts=True)
    ax.bar(values, counts)
    ax.set(title=f'{attr} (n={len(values)})', xlabel='value', ylabel='count')

plt.tight_layout()
