# Evaluación via Regresión Modelos de Extrapolación

In [1]:
from models import LightningRegression,LightningVersatile,create_model
from model_info import encoders, modulators, model_output_dims
from utils import set_seed, get_args, get_name_from_args
import torch
from datasets import IdSpritesEval
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split


def get_dataloader(args, indices = []):
    
    data = torch.load(f"{args.dataset}/{args.dataset}.pth", map_location="cpu")
    reps_path = None
    if args.pretrained_reps:
        reps_path = args.pretrained_reps
    elif args.pretrained_encoder:
        encoder_args = get_args(args.pretrained_encoder)
        reps_path = encoder_args.pretrained_reps
    have_reps = reps_path is not None
    if have_reps:
        print("using pretrained reps...")
        data['reps'] = torch.load(f"{args.dataset}/{args.dataset}_images_feats_{reps_path}.pth", map_location="cpu") if reps_path else None
        data['reps'] = data['reps'] - data['reps'].mean(dim=0) # center
        data['reps'] = torch.nn.functional.normalize(data['reps'], p=2.0, dim=1, eps=1e-12)
    else:
        print("using input images")
    ds = TensorDataset(
                data['images'][indices],
                data['reps'][indices] if have_reps else data['latents'][indices],
                data['latent_ids'][indices]
                )
    dl = torch.utils.data.DataLoader(ds, batch_size=1024, shuffle=False)
    return dl
    
def evaluate(model, dataloader):
    def regression_metrics(logs, targets, mode="dim"):
        if mode == "dim":
            ss_res = torch.sum((targets - logs) ** 2, dim=0)
            mu = torch.mean(targets, dim=0)
            ss_tot = torch.sum((targets - mu)**2, dim=0)
            r2 = 1- ss_res/ss_tot
            mse = ss_res/targets.shape[0]
        else:
            ss_res = torch.sum((targets - logs) ** 2)
            mu = torch.mean(targets, dim=0)
            ss_tot = torch.sum((targets - mu)**2)
            r2 = 1- ss_res/ss_tot
            mse = ss_res/targets.shape[0]
        return r2, mse
    def accuracy(logs, targets):
        preds = torch.round(logs) 
        correct = (preds == targets).float().sum(dim=0)
        acc = correct/preds.shape[0]
        return acc

    model.eval()
    results = dict()
    logs = []
    targets = []
    same_logs = []
    same_tgts = []
    predict_method = model.split_step if args.train_method == "regression" else model.predict_regression
    with torch.no_grad():
        for n_batch, batch in enumerate(tqdm(dataloader)):
            # Unpack index + batch
            imgs, gt_reps, latents = batch
            data = predict_method((imgs.cuda(), gt_reps.cuda(), latents.cuda()))
            logs.append(data['logits'])
            targets.append(data['targets'])
            if "same_logits" in data:
                same_logs.append(data['same_logits'])
                same_tgts.append(data['same_tgts'])
        logs = torch.cat(logs, dim=0)    
        targets = torch.cat(targets, dim=0).float()
       
        results['r2_reg'], results['mse_reg'] = regression_metrics(logs,targets)
        results['acc_reg'] = accuracy(logs, targets)
        if len(same_logs) > 0:
            results['r2_same'], results['mse_same'] = regression_metrics(same_logs,same_targets)
        else:
            results['r2_same'] = results['mse_same'] = -10
        return results

There was a problem when trying to write in your cache folder (/storage/cache). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


## 1. Definir Experimentos a Procesar

In [26]:
# Define experiment ids for checkpoint
exps = [
        "ze2poky9", # rep_train(SAME)+ RDM 
        "b6l0zvfq", # rep_train(SAME)+ RDM 
        "xf2a1c3k", # rep_train(SAME)+ RDM 
        "0ziv0je0", # rep_train(SAME)+ RDM 
        "6hwr35jz", # rep_train(SAME)+ RDM 
        "adm2vtti", # rep_train(SAME)+ RDM 
        "aefm16na", # rep_train(SAME)+ RDM 
        "co93f2g7", # rep_train(SAME)+ RDM 
        "i221poen" # rep_train(SAME)+ RDM 
       ]


## Creación Archivos CSV (Corren Regresión en Dataset)

In [27]:
for exp_id in tqdm(exps):
    args = get_args(exp_id)
    args.encoder['pretrain_method'] = None
    print(args)
    df = pd.DataFrame()
    for split in tqdm(['train','id','ood']):
        encoder, modulator, regressor = create_model(args)
        if args.train_method == "regression":
            model = LightningRegression.load_from_checkpoint(checkpoint_path=f"results/{args.dataset}/{exp_id}/last.ckpt", 
                                            args=args, 
                                            encoder=encoder, 
                                            modulator=modulator,
                                            regressor=regressor)
        else:
            model = LightningVersatile.load_from_checkpoint(checkpoint_path=f"results/{args.dataset}/{exp_id}/last.ckpt", 
                                            args=args, 
                                            encoder=encoder, 
                                            modulator=modulator,
                                            regressor=regressor)
        if split in ['train','id']:
            indices = torch.load(f"{args.dataset}/{args.dataset}_{args.sub_dataset}_train_indices.pth")
            train_indices, val_indices = train_test_split(indices, test_size = 0.1, random_state=42)
            indices = train_indices if split == "train" else val_indices
        elif split == "ood":
            indices = torch.load(f"{args.dataset}/{args.dataset}_{args.sub_dataset}_test_indices.pth")
        else:
            print("Split not recognized!")
            indices = []
        dl = get_dataloader(args, indices)
        results = evaluate(model, dl)
        r2 = results['r2_reg']
        mse = results['mse_reg']
        r2_same = results['r2_same']
        mse_same = results['mse_same']
        acc = results['acc_reg']
        # Store metadata
        meta = {
            'split': split,
            'dataset': args.dataset,
            'sub_dataset': args.sub_dataset,
            'model': get_name_from_args(args)
        }
        
        # Create a long-format DataFrame where each row is a (task, r2) pair
        rows = []
        for i, fov in enumerate(args.FOVS_PER_DATASET):
            rows.append({
                **meta,
                'task': fov,
                'r2': r2[i].item(),
                'mse': mse[i].item(),
                "acc": acc[i].item()*100
            })
        rows.append({**meta, 'task': "same", "r2": r2_same ,'mse': mse_same})
        # Append to df
        result_df = pd.DataFrame(rows)
        df = pd.concat([df, result_df], ignore_index=True)
    print(df)
    df.to_csv(f"results/{args.dataset}/{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv")

  0%|          | 0/9 [00:00<?, ?it/s]

{'lr': 0.001, 'wd': 0.04, 'arch': 'none', 'fovs': ['object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'h_axis', 'v_axis'], 'seed': 333, 'test': False, 'frozen': True, 'n_fovs': {'h_axis': 40, 'v_axis': 40, 'object_size': 2, 'object_color': 6, 'object_shape': 6, 'camera_height': 3, 'background_color': 3}, 'warmup': 6.666666666666667, 'dataset': 'mpi3d', 'encoder': {'arch': 'none', 'frozen': True, 'enc_dims': 16, 'pretrained': None, 'pretrain_method': None}, 'data_dir': '/mnt/nas2/GrimaRepo/araymond', 'enc_dims': 16, 'final_lr': 1e-06, 'final_wd': 0.4, 'fovs_ids': [0, 1, 2, 3, 4, 5, 6], 'mod_arch': 'mlp', 'mod_dims': 128, 'start_lr': 0.0002, 'train_bs': 256, 'ema_start': 0.996, 'ipe_scale': 1, 'modulator': {'arch': 'mlp', 'hidden_dim': 128}, 'num_steps': 200000, 'resume_id': None, 'fovs_tasks': ['object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'h_axis', 'v_axis'], 'num_epochs': 50, 'save_every': 10, 'fovs_levels': {'3dsh

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/253 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/29 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/732 [00:00<?, ?it/s]

    split dataset  sub_dataset                           model  \
0   train   mpi3d  composition  rep_train_same_vit_l_32_random   
1   train   mpi3d  composition  rep_train_same_vit_l_32_random   
2   train   mpi3d  composition  rep_train_same_vit_l_32_random   
3   train   mpi3d  composition  rep_train_same_vit_l_32_random   
4   train   mpi3d  composition  rep_train_same_vit_l_32_random   
5   train   mpi3d  composition  rep_train_same_vit_l_32_random   
6   train   mpi3d  composition  rep_train_same_vit_l_32_random   
7   train   mpi3d  composition  rep_train_same_vit_l_32_random   
8      id   mpi3d  composition  rep_train_same_vit_l_32_random   
9      id   mpi3d  composition  rep_train_same_vit_l_32_random   
10     id   mpi3d  composition  rep_train_same_vit_l_32_random   
11     id   mpi3d  composition  rep_train_same_vit_l_32_random   
12     id   mpi3d  composition  rep_train_same_vit_l_32_random   
13     id   mpi3d  composition  rep_train_same_vit_l_32_random   
14     id 

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  extrapolation  rep_train_same

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  interpolation  rep_train_same

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/253 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/29 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/732 [00:00<?, ?it/s]

    split dataset  sub_dataset                           model  \
0   train   mpi3d  composition  rep_train_same_vit_l_32_random   
1   train   mpi3d  composition  rep_train_same_vit_l_32_random   
2   train   mpi3d  composition  rep_train_same_vit_l_32_random   
3   train   mpi3d  composition  rep_train_same_vit_l_32_random   
4   train   mpi3d  composition  rep_train_same_vit_l_32_random   
5   train   mpi3d  composition  rep_train_same_vit_l_32_random   
6   train   mpi3d  composition  rep_train_same_vit_l_32_random   
7   train   mpi3d  composition  rep_train_same_vit_l_32_random   
8      id   mpi3d  composition  rep_train_same_vit_l_32_random   
9      id   mpi3d  composition  rep_train_same_vit_l_32_random   
10     id   mpi3d  composition  rep_train_same_vit_l_32_random   
11     id   mpi3d  composition  rep_train_same_vit_l_32_random   
12     id   mpi3d  composition  rep_train_same_vit_l_32_random   
13     id   mpi3d  composition  rep_train_same_vit_l_32_random   
14     id 

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  extrapolation  rep_train_same

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  interpolation  rep_train_same

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/253 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/29 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/732 [00:00<?, ?it/s]

    split dataset  sub_dataset                           model  \
0   train   mpi3d  composition  rep_train_same_vit_l_32_random   
1   train   mpi3d  composition  rep_train_same_vit_l_32_random   
2   train   mpi3d  composition  rep_train_same_vit_l_32_random   
3   train   mpi3d  composition  rep_train_same_vit_l_32_random   
4   train   mpi3d  composition  rep_train_same_vit_l_32_random   
5   train   mpi3d  composition  rep_train_same_vit_l_32_random   
6   train   mpi3d  composition  rep_train_same_vit_l_32_random   
7   train   mpi3d  composition  rep_train_same_vit_l_32_random   
8      id   mpi3d  composition  rep_train_same_vit_l_32_random   
9      id   mpi3d  composition  rep_train_same_vit_l_32_random   
10     id   mpi3d  composition  rep_train_same_vit_l_32_random   
11     id   mpi3d  composition  rep_train_same_vit_l_32_random   
12     id   mpi3d  composition  rep_train_same_vit_l_32_random   
13     id   mpi3d  composition  rep_train_same_vit_l_32_random   
14     id 

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  interpolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  interpolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  interpolation  rep_train_same

  0%|          | 0/3 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/274 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/31 [00:00<?, ?it/s]

using pretrained reps...


  0%|          | 0/709 [00:00<?, ?it/s]

    split dataset    sub_dataset                           model  \
0   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
1   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
2   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
3   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
4   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
5   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
6   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
7   train   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
8      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
9      id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
10     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
11     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
12     id   mpi3d  extrapolation  rep_train_same_vit_l_32_random   
13     id   mpi3d  extrapolation  rep_train_same

## Metadata 

In [3]:
# 
exps = [
    "qp96yw83", # regression baseline
    "tgacyk73", # regression baseline
    "wfjnpd9k", # regression baseline
    "4pa7icuy", # regression baseline
    "3g8h1nfq", # regression baseline
    "aphsrlwg", # regression baseline
    "clcf7oin", # regression baseline
    "t4nfx11h", # regression baseline
    "t9w5qz8g",  # regression baseline
    'udnuj8i3', # rep_train(SAME, LINOP) + RDM
        '3if4e740',  # rep_train(SAME, LINOP) + RDM
        '8i0422lj',  # rep_train(SAME, LINOP) + RDM
        'ddc3hg7w',  # rep_train(SAME, LINOP) + RDM
        'e5b4alno',  # rep_train(SAME, LINOP) + RDM
        'g8kvz2d8', # rep_train(SAME, LINOP) + RDM
        'mcuz1c1m', # rep_train(SAME, LINOP) + RDM
        'woda8yw6', # rep_train(SAME, LINOP) + RDM
        'xka5ng1n', # rep_train(SAME, LINOP) + RDM
        "s9dcfkpf", # rep_train(SAME, LINOP) + RDM +UN
        "o0eqrqin", # rep_train(SAME, LINOP) + RDM +UN
        "o0uqa97p", # rep_train(SAME, LINOP) + RDM +UN
        "0wsipw6c", # rep_train(SAME, LINOP) + RDM +UN
        "onxxjg6c", # rep_train(SAME, LINOP) + RDM +UN
        "7e0qfq42", # rep_train(SAME, LINOP) + RDM +UN
        "0pyeiovd", # rep_train(SAME, LINOP) + RDM +UN
        "bjdyhddm", # rep_train(SAME, LINOP) + RDM +UN
        "btqvlusd", # rep_train(SAME, LINOP) + RDM +UN
        "ze2poky9", # rep_train(SAME)+ RDM 
        "b6l0zvfq", # rep_train(SAME)+ RDM 
        "xf2a1c3k", # rep_train(SAME)+ RDM 
        "0ziv0je0", # rep_train(SAME)+ RDM 
        "6hwr35jz", # rep_train(SAME)+ RDM 
        "adm2vtti", # rep_train(SAME)+ RDM 
        "aefm16na", # rep_train(SAME)+ RDM 
        "co93f2g7", # rep_train(SAME)+ RDM 
        "i221poen", # rep_train(SAME)+ RDM 
        "5dfeqofg", # MOD REGRESSION + RDM
        "l78rkpl6", # MOD REGRESSION + RDM
        "5e6ut80u", # MOD REGRESSION + RDM
        "nvw4hweu", # MOD REGRESSION + RDM
        "2bntzmyt", # MOD REGRESSION + RDM
        "5glq8j5y", # MOD REGRESSION + RDM
        "sw308xe5", # MOD REGRESSION + RDM
        "u9otp2tk", # MOD REGRESSION + RDM
        "sfiiqnyw" # MOD REGRESSION + RDM
       ]

exps = [ # 3dshapes
    "ntlfv56t", # rep_train same + REG + RDM + FILM + UN_MOD
    "ai37d6yl", # rep_train same + REG + RDM + FILM + UN_MOD
    "cvmqh7zk", # rep_train same + REG + RDM + FILM + UN_MOD
    "gqif1rpl", # rep_train same + REG + RDM + FILM + UN_MOD
    "4qb6mbue", # rep_train same + REG + RDM + FILM + UN_MOD
    "8c5rpv6v", # rep_train same + REG + RDM + FILM + UN_MOD
    "91lbgcry", # rep_train same + REG + RDM + FILM + UN_MOD
    "jznsk0ar", # rep_train same + REG + RDM + FILM + UN_MOD
    "pz175tk9", # rep_train same + REG + RDM + FILM + UN_MOD
    "m3jzscq5", # rep_train same + REG + RDM + LINOP + UN_MOD
    "e99wn9i9", # rep_train same + REG + RDM + LINOP + UN_MOD
    "i1h2ptub", # rep_train same + REG + RDM + LINOP + UN_MOD
    "hs3twhwv", # rep_train same + REG + RDM + LINOP + UN_MOD
    "bzv1fds7", # rep_train same + REG + RDM + LINOP + UN_MOD
    "2xuk4rml", # rep_train same + REG + RDM + LINOP + UN_MOD
    "tgjidf5z", # rep_train same + REG + RDM + LINOP + UN_MOD
    "i35hbd07", # rep_train same + REG + RDM + LINOP + UN_MOD
    "20vqzi7o", # rep_train same + REG + RDM + LINOP + UN_MOD
    "96yex98j", # MODREG + RDM + LINOP + UN_MOD
    "ufplly0f", # MODREG + RDM + LINOP + UN_MOD
    "tv4m97cy", # MODREG + RDM + LINOP + UN_MOD
    "sl7cdgtr", # MODREG + RDM + LINOP + UN_MOD
    "rqibwv3l", # MODREG + RDM + LINOP + UN_MOD
    "nc422yyl", # MODREG + RDM + LINOP + UN_MOD
    "bm1mg0hh", # MODREG + RDM + LINOP + UN_MOD
    "62t4uxfa", # MODREG + RDM + LINOP + UN_MOD
    "zhjiji9k", # MODREG + RDM + LINOP + UN_MOD
    "fqd7lb6c", # rep_train same + REG + RDM + LINOP
    "07rd68b4", # rep_train same + REG + RDM + LINOP
    "0fusfr05", # rep_train same + REG + RDM + LINOP
    "46swehpw", # rep_train same + REG + RDM + LINOP
    "p9ox7fcm", # rep_train same + REG + RDM + LINOP
    "u51k0c6s", # rep_train same + REG + RDM + LINOP
    "vjtx7e81", # rep_train same + REG + RDM + LINOP
    "4qzis3s0", # rep_train same + REG + RDM + LINOP
    "ujlxzsb9", # rep_train same + REG + RDM + LINOP
    "n49kztk8", # mod regression + RDM + LATDIR
    "6g833bpr", # mod regression + RDM + LATDIR
    "42wx79f8", # mod regression + RDM + LATDIR
    "awdihwps", # mod regression + RDM + LATDIR
    "m4sf4hkk", # mod regression + RDM + LATDIR
    "6ja2ednk", # mod regression + RDM + LATDIR
    "omi3q31m", # mod regression + RDM + LATDIR
    "3f7highk", # mod regression + RDM + LATDIR
    "5e2h1l0e", # mod regression + RDM + LATDIR
    "nbtscppe", # rep_train same + REG + RDM + LATDIR
    "g1wkqai8", # rep_train same + REG + RDM + LATDIR
    "gp51jqdv", # rep_train same + REG + RDM + LATDIR
    "f73kjyvt", # rep_train same + REG + RDM + LATDIR
    "hc2fqhgt", # rep_train same + REG + RDM + LATDIR
    "j2kp78sn", # rep_train same + REG + RDM + LATDIR
    "nbtscppe", # rep_train same + REG + RDM + LATDIR
    "onvvto74", # rep_train same + REG + RDM + LATDIR
    "xj5bewsj", # rep_train same + REG + RDM + LATDIR
    
    "ceeun18x", # mod regression + RDM + LINOP
    "0tb8d0of", # mod regression + RDM + LINOP
    "0cf0y4op", # mod regression + RDM + LINOP
    "v9ku7xke", # mod regression + RDM + LINOP
    "g9fvdqs6", # mod regression + RDM + LINOP
    "i0gfw7sg", # mod regression + RDM + LINOP
    "o5vtujdk", # mod regression + RDM + LINOP
    "rx6nzxkx", # mod regression + RDM + LINOP
    "soub0i38", # mod regression + RDM + LINOP
    "i6msf1v8", # cnn modregression + random
    "rvei4qfz", # cnn modregression + random
    "s21bpu3a", # cnn modregression + random
    "1s89r7jg", # cnn modregression + random
    "uf7so6b2", # cnn modregression + random
    "onj3hzza", # cnn modregression + random
    "z0ijmy7x", # baseline regression scratch
    "yj9ixz8m", # baseline regression scratch
    "wet3xlav", # baseline regression scratch
    "rj8kmlgh", # baseline regression scratch
    "d5y0fcks", # baseline regression scratch
    "1ft0hp2e", # baseline regression scratch
    "dbe02hpj", # baseline regression scratch
    "cb0tgzzn", # baseline regression scratch
    "y1q6brv7", # baseline regression scratch
    "ec23sm1i", # CNN + rep_train(same) + reg + random
    "a7ws7cmq", # CNN + rep_train(same) + reg + random
    "zil44kpc", # CRASHED
    "5tv259yn", # CNN + rep_train(same) + reg + random
    "yr1dv4v6", # CNN + rep_train(same) + reg + random
    "zhqvr0lz", # CNN + rep_train(same) + reg + random 
    "t3gy9jk0", # CNN + rep_train(same) + reg + random
    "etui8as9", # CNN + rep_train(same) + reg + random
    "khxm7swv", # CNN + rep_train(same) + reg + random
    "c01smllr", # rep_train (residual) + same loss + regresion
    "2gk0zv3g",  # rep_train (residual) + same loss + regresion 
    "gzf0su5r",  # rep_train (residual) + same loss + regresion
    "ee2s99re",  # rep_train (residual) + same loss + regresion 
    "9gsdaqxd",  # rep_train (residual) + same loss + regresion 
    "cgrsl3av",  # rep_train (residual) + same loss + regresion 
    "crlqukzd",  # rep_train (residual) + same loss + regresion + random
    "qs7tezcp",  # rep_train (residual) + same loss + regresion + random
    "teh0lkys",  # rep_train (residual) + same loss + regresion + random
    "hz324tl1",  # rep_train (residual) + same loss + regresion + random
    "ijgamtnm",  # rep_train (residual) + same loss + regresion + random
    "wsd9ajz4", # rep_train (residual) + same loss + regresion + random
    "c3i8m2li", # rep_train + regression + random
    "le9tapxr",  # rep_train + regression + random
    "m3s9rawj",  # rep_train + regression + random
    "oykmci7n",  # rep_train + regression + random
    "13hgsljh", # rep train same + random
    "6kgq972p",  # rep train same + random
    "qr3o8g18", # rep_train same + random
    "ggr210j7", # rep_train same + random
    "btlbgiyw", # rep_train same + random
    "4qwcrwdo", # rep_train same + random
    "jhlkudqc", # mod regression + random
    "dy0g90yf", # mod regression + random
    "l7287kha", # mod regression + random
    "pdvuzeik", # mod regression + random
    "5nneg1v2", # mod regression + random
    "slhukb7t", # mod regression + random
    "6vkh3xiw", # rep_train same 
    "75t0lezc",  # rep_train same 
    "aj0sty26",  # rep_train same 
    "vwmjtne3",  # rep_train same 
    "i7crcxpp",  # rep_train same 
    "nrc8fsoq",  # rep_train same 
    "gee639pm", # baseline regression
    "kcpnt121", # baseline regression
    "z1fo3hf6", # baseline regression
    "bq4qhtxu", # baseline regression
    "f7yr1a78", # baseline regression
    "v8f31pyn", # baseline regression
    
    "c3i8m2li", # rep_train_plus extrapolation + random
    "le9tapxr", # rep_train_plus extrapolation + random
    "m3s9rawj", # rep_train_plus correct interpolation + ramdom
    "oykmci7n", # rep_train_plus correct interpolation + random
    "7l4txq1v", # rep_train_plus extrapolation
    "xg98v4xi", # rep_train_plus extrapolation
    "2r381uwd", # rep_train_plus extrapolation
    "tfa3zh2n", # rep_train_plus correct
    
    "sfl4cr53", # mod_regression extrapolation
    "91ktdndl", # mod regression
    "b36si90c", # mod regression
    "08vym6tg", # mod regression
    "irlsradb",# mod regression
    "esef85yi", # non_mod_regression composition
    "kpui77ip", # non_mod_regression composition
    "ihw7cd8h", # mod_regression composition
     "bqn7ytvr", # non mod regression interpolation
    "50jpen48", # mod regression interpolation
    "w9epqele", # non mod regression interpolation
    "np596yq5", # mod regression interpolation
     "co7i15y6", # rep_train (film) + SAME + REG + random
    "aa4xwwhq", # rep_train (film) + SAME + REG + random
    "oczdrid9", # rep_train (film) + SAME + REG + random
    "wbjo218x", # rep_train (film) + SAME + REG + random
    "2jua32og", # rep_train (film) + SAME + REG + random
    "53e9n83t", # rep_train (film) + SAME + REG + random
    "aeqigwxc", # rep_train (film) + SAME + REG + random
    "kskpufht", # rep_train (film) + SAME + REG + random
    "un7b247p", # rep_train (film) + SAME + REG + random
    "e0cn4e4f", # rep_train (transformer) + SAME + REG + random
    "fyncud9n", # rep_train (transformer) + SAME + REG + random
    "r32wp7e3", # rep_train (transformer) + SAME + REG + random
    "nmy0v96r", # rep_train (transformer) + SAME + REG + random
    "fqzsr1vf", # rep_train (transformer) + SAME + REG + random
    "mot5a4ku", # rep_train (transformer) + SAME + REG + random
    "q7fkbt0g", # rep_train (transformer) + SAME + REG + random
    "qj42vq3i", # rep_train (transformer) + SAME + REG + random
    "sxigun9n", # rep_train (transformer) + SAME + REG + random
    "7k8mxkwj", # mod_regression (transformer) + random
    "d2assvyp", # mod_regression (transformer) + random
    "bmb752uz", # mod_regression (transformer) + random
    "dax881de", # mod_regression (transformer) + random
    "87cd47fx", # mod_regression (transformer) + random
    "a2waa2j0", # mod_regression (transformer) + random
    "nsbutuwv", # mod_regression (transformer) + random
    "ntxtml07", # mod_regression (transformer) + random
    "zyzjokml", # mod_regression (transformer) + random
    "7k8mxkwj", # mod_regression (film) + random,
    "d2assvyp", # mod_regression (film) + random,
    "bmb752uz", # mod_regression (film) + random,
    "dax881de", # mod_regression (film) + random,
    "87cd47fx", # mod_regression (film) + random,
    "a2waa2j0", # mod_regression (film) + random,
    "nsbutuwv", # mod_regression (film) + random,
    "ntxtml07", # mod_regression (film) + random,
    "zyzjokml", # mod_regression (film) + random,
    "co7i15y6", # rep_train (film) + SAME + REG + random
    "aa4xwwhq", # rep_train (film) + SAME + REG + random
    "oczdrid9", # rep_train (film) + SAME + REG + random
    "wbjo218x", # rep_train (film) + SAME + REG + random
    "2jua32og", # rep_train (film) + SAME + REG + random
    "53e9n83t", # rep_train (film) + SAME + REG + random
    "aeqigwxc", # rep_train (film) + SAME + REG + random
    "kskpufht", # rep_train (film) + SAME + REG + random
    "un7b247p", # rep_train (film) + SAME + REG + random
    "e0cn4e4f", # rep_train (transformer) + SAME + REG + random
    "fyncud9n", # rep_train (transformer) + SAME + REG + random
    "r32wp7e3", # rep_train (transformer) + SAME + REG + random
    "nmy0v96r", # rep_train (transformer) + SAME + REG + random
    "fqzsr1vf", # rep_train (transformer) + SAME + REG + random
    "mot5a4ku", # rep_train (transformer) + SAME + REG + random
    "q7fkbt0g", # rep_train (transformer) + SAME + REG + random
    "qj42vq3i", # rep_train (transformer) + SAME + REG + random
    "sxigun9n", # rep_train (transformer) + SAME + REG + random
    "7k8mxkwj", # mod_regression (transformer) + random
    "d2assvyp", # mod_regression (transformer) + random
    "bmb752uz", # mod_regression (transformer) + random
    "dax881de", # mod_regression (transformer) + random
    "87cd47fx", # mod_regression (transformer) + random
    "a2waa2j0", # mod_regression (transformer) + random
    "nsbutuwv", # mod_regression (transformer) + random
    "ntxtml07", # mod_regression (transformer) + random
    "zyzjokml" # mod_regression (transformer) + random '''
#     ]
]

names = {  
           'mod_regression_latdir_vit_l_32_random': "MODREG(LATDIR)+RDM",
           'rep_train_same_linop_vit_l_32_random': "LSM(SAME,LINOP)+REG+RDM",
           'rep_train_same_linop_vit_l_32_random_un': "LSM(SAME,LINOP)+REG+RDM+UN",
           'rep_train_same_latdir_vit_l_32_random': "LSM(SAME,LATDIR)+REG+RDM",
           'mod_regression_linop_vit_l_32_random':  "MODREG(LINOP)+RDM",
           'mod_regression_linop_vit_l_32_random_un':  "MODREG(LINOP)+RDM+UN",
           'mod_regression_vit_l_32': "MODREG",
           "mod_regression_mod_regressioncnn_random":"CNN-MODREG+RDM",
           'mod_regression_vit_l_32_random': "MODREG+RDM" ,
           'non_mod_regression_vit_l_32': "NON-MODREG+RDM", 
           'regressioncnn': "CNN-REG (BASELINE)",
           'rep_train_plus_vit_l_32': "LSM(CLASS)+REG",
           'rep_train_plus_vit_b_32': "LSM(CLASS)+REG (b32)",
           'rep_train_plus_vit_l_32_random': "LSM(CLASS)+REG+RDM",
           'rep_train_same_res_vit_l_32': "LSM(SAME,RES)+REG",  
           'rep_train_same_res_vit_l_32_random': "LSM(SAME,RES)+REG+RDM",
           'rep_train_same_vit_l_32':"LSM(SAME)+REG",  
           'rep_train_same_vit_l_32_random': "LSM(SAME)+REG+RDM",
           'rep_train_same_rep_train_samecnn_random': "CNN-LSM(SAME)+REG+RDM",
           'vit_l_32': "REG (BASELINE)",
           'mod_regression_vit_b_32': "MODREG",
           'non_mod_regression_vit_b_32': "NON-MODREG",
           "rep_train_same_trans_vit_l_32_random": "LSM(SAME,TRANS)+REG+RDM",
           "rep_train_same_film_vit_l_32_random": "LSM(SAME,FILM)+REG+RDM",
           "rep_train_same_film_vit_l_32_random_un": "LSM(SAME,FILM)+REG+RDM+UN",
           "mod_regression_trans_vit_l_32_random": "MODREG(TRANS)+RDM",
           "mod_regression_film_vit_l_32_random": "MODREG(FILM)+RDM",
           "mod_regression_film_vit_l_32_random_un": "MODREG(FILM)+RDM+UN"
        }
custom_order = ['CNN-REG (BASELINE)', 
                'REG (BASELINE)',
               # 'NON-MODREG',
                #'NON-MODREG+RDM',
                'MODREG',
                "CNN-MODREG+RDM",
                'MODREG+RDM',
                "MODREG(TRANS)+RDM",
                "MODREG(FILM)+RDM",
                "MODREG(FILM)+RDM+UN",
                "MODREG(LATDIR)+RDM",
                "MODREG(LINOP)+RDM",
                "MODREG(LINOP)+RDM+UN",
                'LSM(CLASS)+REG',
                #"LSM(CLASS)+REG (b32)",
                "CNN-LSM(SAME)+REG+RDM",
                'LSM(CLASS)+REG+RDM',
                'LSM(SAME)+REG',
                'LSM(SAME)+REG+RDM',
                'LSM(SAME,RES)+REG', 
                'LSM(SAME,RES)+REG+RDM',
                "LSM(SAME,TRANS)+REG+RDM",
                "LSM(SAME,FILM)+REG+RDM",
                "LSM(SAME,FILM)+REG+RDM+UN",
                "LSM(SAME,LINOP)+REG+RDM",
                "LSM(SAME,LINOP)+REG+RDM+UN",
                "LSM(SAME,LATDIR)+REG+RDM"
     ]
exclude_models = {"NON-MODREG", "LSM(CLASS)+REG (b32)","NON-MODREG+RDM"}



## Procesamiento Archivos CSV

In [36]:
import pandas as pd
from tqdm.notebook import tqdm

dataset = "mpi3d"
final_result = pd.DataFrame()
for exp_id in tqdm(exps):
    args = get_args(exp_id)
    filename = f"results/{args.dataset}/{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv"
    exp_id, dataset, sub_dataset = filename.split("_")[:3]
    pretrained_reps = "-".join(filename.split("_")[3:]).replace(".csv","")
    df= pd.read_csv(filename)
    df.rename(columns={'correct': 'accuracy'}, inplace=True)
    df = df.pivot_table(
        index=["split","dataset",'sub_dataset', 'model'],
        columns='task',
        values=['acc',"r2","mse"]
    ).reset_index()
    final_result = pd.concat([final_result, df])


# Save as file
factors = {
        '3dshapes': ['floor_hue', 'object_hue', 'orientation', 'scale', 'shape', 'wall_hue'],
          'mpi3d': ['background_color', 'camera_height','h_axis', 'object_color', 'object_shape', 'object_size','v_axis'],
        'dsprites':  ["shape","scale","orientation","x","y"]
          }


print(final_result)
# Get average column value for high level evaluation
final_result['acc_mean'] = final_result['acc'][factors[dataset]].mean(axis=1)
final_result['r2_mean']  = final_result['r2'][factors[dataset]].mean(axis=1)
final_result['mse_mean'] = final_result['mse'][factors[dataset]].mean(axis=1)
final_result.to_csv(f"results/{dataset}/results_{dataset}.csv")

  0%|          | 0/45 [00:00<?, ?it/s]

      split dataset    sub_dataset                           model  \
task                                                                 
0        id   mpi3d    composition                        vit_l_32   
1       ood   mpi3d    composition                        vit_l_32   
2     train   mpi3d    composition                        vit_l_32   
0        id   mpi3d    composition                        vit_l_32   
1       ood   mpi3d    composition                        vit_l_32   
..      ...     ...            ...                             ...   
1       ood   mpi3d    composition  mod_regression_vit_l_32_random   
2     train   mpi3d    composition  mod_regression_vit_l_32_random   
0        id   mpi3d  interpolation  mod_regression_vit_l_32_random   
1       ood   mpi3d  interpolation  mod_regression_vit_l_32_random   
2     train   mpi3d  interpolation  mod_regression_vit_l_32_random   

                  acc                                                     \
task backgrou

In [31]:
final_result.columns

MultiIndex([(      'split',                 ''),
            (    'dataset',                 ''),
            ('sub_dataset',                 ''),
            (      'model',                 ''),
            (        'acc', 'background_color'),
            (        'acc',    'camera_height'),
            (        'acc',           'h_axis'),
            (        'acc',     'object_color'),
            (        'acc',     'object_shape'),
            (        'acc',      'object_size'),
            (        'acc',           'v_axis'),
            (        'mse', 'background_color'),
            (        'mse',    'camera_height'),
            (        'mse',           'h_axis'),
            (        'mse',     'object_color'),
            (        'mse',     'object_shape'),
            (        'mse',      'object_size'),
            (        'mse',             'same'),
            (        'mse',           'v_axis'),
            (         'r2', 'background_color'),
            (       

## Resultados Generales

In [4]:
import pandas as pd
dataset = "3dshapes"
df = pd.read_csv(f"results/{dataset}/results_{dataset}.csv")
df = df.drop(0)
df = df[['split','dataset','sub_dataset','model','acc_mean','r2_mean','mse_mean']]
# Ordenar y limpiar nombres de modelo
df['model'] = df['model'].apply(lambda x: names[x])
df['model'] = pd.Categorical(df['model'], categories=custom_order, ordered=True)
df = df[~df["model"].isin(exclude_models)]
df = df.sort_values('model')
grouped = df.groupby(['split', 'dataset','sub_dataset', 'model']).agg(['mean', 'std'])
task_cols =['r2_mean','acc_mean','mse_mean']
# Flatten MultiIndex columns: ('floor_hue', 'mean') → 'floor_hue_mean'
grouped.columns = ['_'.join(col).strip() for col in grouped.columns.values]
grouped = grouped.reset_index()

# Format each task column as "mean ± std"
for col in task_cols:
    
    mean_col = f'{col}_mean'
    std_col = f'{col}_std'
    grouped[col] = grouped.apply(
        lambda row: f"{row[mean_col]:.3f} ± {row[std_col]:.3f}", axis=1
    )

# Drop the now-redundant mean and std columns
grouped = grouped.drop(columns=[f'{col}_mean' for col in task_cols] + [f'{col}_std' for col in task_cols])

df = grouped.pivot_table(
    index=["split","dataset", 'model'],
    columns='sub_dataset',
    values=['acc_mean',"r2_mean","mse_mean"],
    aggfunc="first"   # ← Just take the first, don't try to average
)
# 2. Swap levels so that sub_dataset comes first, then metric
df.columns = df.columns.swaplevel(0, 1)

# 3. Sort columns so sub_dataset1 → [acc, r2, mse], then sub_dataset2 → ...
df = df.sort_index(axis=1, level=0)

# 4. Reset index as usual
df = df.reset_index()

  grouped = df.groupby(['split', 'dataset','sub_dataset', 'model']).agg(['mean', 'std'])
  df = grouped.pivot_table(


In [5]:
import numpy as np
df[df['split']=='ood'].replace(r"nan ± nan", np.nan, regex=True).dropna(how="any")

sub_dataset,split,dataset,model,composition,composition,composition,extrapolation,extrapolation,extrapolation,interpolation,interpolation,interpolation
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,acc_mean,mse_mean,r2_mean,acc_mean,mse_mean,r2_mean,acc_mean,mse_mean,r2_mean
24,ood,3dshapes,CNN-REG (BASELINE),99.999 ± 0.002,0.001 ± 0.000,1.000 ± 0.000,71.979 ± 0.881,10.679 ± 0.239,-0.183 ± 0.019,79.170 ± 0.092,0.191 ± 0.007,0.974 ± 0.001
26,ood,3dshapes,MODREG,82.600 ± 2.748,0.480 ± 0.046,0.939 ± 0.002,67.183 ± 0.590,5.738 ± 0.740,0.530 ± 0.036,79.479 ± 3.890,0.343 ± 0.098,0.962 ± 0.010
29,ood,3dshapes,MODREG(TRANS)+RDM,50.831 ± 25.023,3.244 ± 3.029,0.543 ± 0.452,66.143 ± 1.551,4.873 ± 0.195,0.591 ± 0.022,73.629 ± 0.376,0.381 ± 0.038,0.954 ± 0.004
32,ood,3dshapes,MODREG(LATDIR)+RDM,89.856 ± 1.519,0.306 ± 0.204,0.967 ± 0.017,70.436 ± 0.802,5.057 ± 0.097,0.575 ± 0.011,77.905 ± 1.190,0.293 ± 0.024,0.964 ± 0.004
33,ood,3dshapes,MODREG(LINOP)+RDM,90.416 ± 3.243,0.142 ± 0.028,0.983 ± 0.002,69.947 ± 0.477,4.012 ± 0.149,0.634 ± 0.018,85.694 ± 0.232,0.204 ± 0.017,0.976 ± 0.002
34,ood,3dshapes,MODREG(LINOP)+RDM+UN,88.126 ± 6.735,0.351 ± 0.344,0.968 ± 0.022,70.351 ± 0.054,4.315 ± 0.282,0.629 ± 0.029,79.779 ± 0.466,0.250 ± 0.006,0.969 ± 0.001
35,ood,3dshapes,LSM(CLASS)+REG,53.175 ± nan,0.735 ± nan,0.892 ± nan,61.492 ± nan,5.212 ± nan,0.544 ± nan,79.319 ± nan,0.275 ± nan,0.964 ± nan
36,ood,3dshapes,CNN-LSM(SAME)+REG+RDM,86.080 ± 15.128,0.186 ± 0.236,0.974 ± 0.033,70.865 ± 0.536,8.864 ± 1.355,0.036 ± 0.157,79.084 ± 1.965,0.185 ± 0.020,0.974 ± 0.002
42,ood,3dshapes,"LSM(SAME,TRANS)+REG+RDM",63.722 ± 6.259,1.313 ± 0.429,0.835 ± 0.048,67.635 ± 0.595,5.214 ± 1.027,0.570 ± 0.055,72.280 ± 2.319,0.448 ± 0.015,0.949 ± 0.001
43,ood,3dshapes,"LSM(SAME,FILM)+REG+RDM",93.855 ± 0.421,0.183 ± 0.018,0.980 ± 0.002,70.049 ± 0.224,4.945 ± 0.498,0.587 ± 0.022,81.081 ± 1.015,0.234 ± 0.014,0.972 ± 0.002


## Análisis más detallado por columna

In [140]:
df = pd.read_csv("results/3dshapes/results_3dshapes.csv")
metric = "acc"
print(df)
df = df.drop(0)

task_cols = ['floor_hue', 'object_hue', 'orientation', 'scale', 'shape', 'wall_hue']
new_task_cols = [f"{metric}.{i}" if i > 0 else metric for i, x in enumerate(task_cols)]
df[new_task_cols] = df[new_task_cols].apply(pd.to_numeric, errors='coerce')
# Ordenar y limpiar nombres de modelo
df['model'] = df['model'].apply(lambda x: names[x])
df['model'] = pd.Categorical(df['model'], categories=custom_order, ordered=True)
df = df[~df["model"].isin(exclude_models)]
df = df.sort_values('model')

models = ['train','id','ood']

# Group and compute mean and std
grouped = df.groupby(['split','dataset','sub_dataset', 'model'])[new_task_cols].agg(['mean', 'std'])

# Flatten MultiIndex columns: ('floor_hue', 'mean') → 'floor_hue_mean'
grouped.columns = ['_'.join(col).strip() for col in grouped.columns.values]
grouped = grouped.reset_index()

# Format each task column as "mean ± std"
for col in new_task_cols:
    
    mean_col = f'{col}_mean'
    std_col = f'{col}_std'
    grouped[col] = grouped.apply(
        lambda row: f"{row[mean_col]:.3f} ± {row[std_col]:.3f}", axis=1
    )

# Drop the now-redundant mean and std columns
grouped = grouped.drop(columns=[f'{col}_mean' for col in new_task_cols] + [f'{col}_std' for col in new_task_cols])
grouped.rename(columns={a:b for a, b in zip(new_task_cols, task_cols)}, inplace=True)
#result = df_model.groupby(['sub_dataset', 'model'])[['floor_hue', 'object_hue', 'orientation','scale','shape','wall_hue','r2']].mean().reset_index()
#  tables[model] = grouped



    Unnamed: 0  split   dataset    sub_dataset  \
0         task    NaN       NaN            NaN   
1            0     id  3dshapes    composition   
2            1    ood  3dshapes    composition   
3            2  train  3dshapes    composition   
4            0     id  3dshapes    composition   
..         ...    ...       ...            ...   
542          1    ood  3dshapes  interpolation   
543          2  train  3dshapes  interpolation   
544          0     id  3dshapes  extrapolation   
545          1    ood  3dshapes  extrapolation   
546          2  train  3dshapes  extrapolation   

                                    model                acc  \
0                                     NaN          floor_hue   
1    rep_train_same_linop_vit_l_32_random   99.9869704246521   
2    rep_train_same_linop_vit_l_32_random  95.54184079170228   
3    rep_train_same_linop_vit_l_32_random   99.9978244304657   
4    rep_train_same_linop_vit_l_32_random   99.9869704246521   
..             

  grouped = df.groupby(['split','dataset','sub_dataset', 'model'])[new_task_cols].agg(['mean', 'std'])


In [141]:
grouped[grouped['split']=="ood"]

Unnamed: 0,split,dataset,sub_dataset,model,floor_hue,object_hue,orientation,scale,shape,wall_hue
57,ood,3dshapes,composition,CNN-REG (BASELINE),100.000 ± 0.000,100.000 ± 0.000,99.994 ± 0.010,100.000 ± 0.000,99.999 ± 0.001,100.000 ± 0.000
58,ood,3dshapes,composition,REG (BASELINE),97.970 ± 0.608,89.530 ± 2.425,85.478 ± 0.231,91.664 ± 1.532,99.787 ± 0.041,97.107 ± 0.707
59,ood,3dshapes,composition,MODREG,70.610 ± 13.056,74.222 ± 8.971,77.933 ± 1.718,90.193 ± 0.571,99.853 ± 0.037,82.790 ± 7.107
60,ood,3dshapes,composition,CNN-MODREG+RDM,nan ± nan,nan ± nan,nan ± nan,nan ± nan,nan ± nan,nan ± nan
61,ood,3dshapes,composition,MODREG+RDM,98.004 ± 0.658,86.102 ± 3.812,78.428 ± 12.176,87.393 ± 0.420,99.828 ± 0.023,97.820 ± 0.193
62,ood,3dshapes,composition,MODREG(TRANS)+RDM,49.992 ± 29.759,37.777 ± 18.140,42.799 ± 20.118,49.755 ± 27.691,75.188 ± 36.171,49.475 ± 20.706
63,ood,3dshapes,composition,MODREG(LATDIR)+RDM,96.523 ± 4.097,82.312 ± 6.363,72.723 ± 10.345,91.975 ± 0.040,99.839 ± 0.037,95.766 ± 5.070
64,ood,3dshapes,composition,MODREG(LINOP)+RDM,98.075 ± 0.801,91.047 ± 1.376,73.094 ± 14.052,83.555 ± 5.318,99.853 ± 0.080,96.868 ± 0.877
65,ood,3dshapes,composition,LSM(CLASS)+REG,26.393 ± nan,47.386 ± nan,72.294 ± nan,34.711 ± nan,99.028 ± nan,39.240 ± nan
66,ood,3dshapes,composition,CNN-LSM(SAME)+REG+RDM,81.310 ± 20.508,77.914 ± 21.129,79.244 ± 20.926,94.293 ± 9.884,99.493 ± 0.879,84.226 ± 18.220


In [143]:
import re
import pandas as pd

def dataframe_to_tabularx(df, metric_cols=None, maximize=True):
    """
    Converts a DataFrame into LaTeX tabularx code with nice formatting:
    - Tiny stddev
    - Small font for numbers
    - Bold best value
    - Capitalized sub_dataset
    - Headers: _ replaced with space
    """

    # Infer metric columns if not provided
    if metric_cols is None:
        metric_cols = [col for col in df.columns if col not in ['sub_dataset', 'model', 'split', 'dataset']]

    # Capitalize headers and replace underscores with spaces
    header_cols = ['\\textbf{' + col.replace('_', ' ').upper() + '}' for col in ['SUB_DATASET', 'MODEL'] + metric_cols]

    # Begin LaTeX tabularx
    latex_code = []
    latex_code.append('\\begin{tabularx}{\\textwidth}{ll' + 'X' * len(metric_cols) + '}')
    latex_code.append('\\toprule')

    # Header row
    latex_code.append(' & '.join(header_cols) + ' \\\\')
    latex_code.append('\\midrule')

    last_sub_dataset = None

    # Find best per sub_dataset and metric
    best_values = {}
    for sub_dataset, group in df.groupby('sub_dataset'):
        best_values[sub_dataset] = {}
        for metric in metric_cols:
            values = group[metric].apply(lambda x: extract_main_value(x))
            if maximize:
                best_val = values.max()
            else:
                best_val = values.min()
            best_values[sub_dataset][metric] = best_val

    # For each row
    for idx, row in df.iterrows():
        current_sub_dataset = row['sub_dataset']
        if last_sub_dataset is not None and current_sub_dataset != last_sub_dataset:
            latex_code.append('\\midrule')
        last_sub_dataset = current_sub_dataset

        # Capitalize sub_dataset name
        sub_dataset_display = escape_latex(str(row['sub_dataset']).capitalize())
        model_display = escape_latex(str(row['model']))

        row_entries = [sub_dataset_display, model_display]

        for metric in metric_cols:
            value = row[metric]
            if pd.isna(value):
                formatted = ''
            else:
                main_val, std_val = extract_values(value)
                if main_val is None:
                    formatted = ''
                else:
                    # Check if this value is the best
                    is_best = (abs(main_val - best_values[current_sub_dataset][metric]) < 1e-3)
                    if std_val is not None:
                        formatted = f"\\scriptsize{{{main_val:.2f} \\tiny{{$\\pm$ {std_val:.2f}}}}}"
                    else:
                        formatted = f"\\scriptsize{{{main_val:.2f}}}"
                    if is_best:
                        formatted = f"\\textbf{{{formatted}}}"
            row_entries.append(escape_latex(formatted))

        latex_code.append(' & '.join(row_entries) + ' \\\\')

    # End LaTeX table
    latex_code.append('\\bottomrule')
    latex_code.append('\\end{tabularx}')

    return '\n'.join(latex_code)

# Helper function to extract mean and std from a string
def extract_values(value):
    if isinstance(value, str):
        match = re.match(r"([\d\.\-eE]+)\s*±\s*([\d\.\-eE]+)", value)
        if match:
            mean_val = float(match.group(1))
            std_val = float(match.group(2))
            if std_val != std_val:  # NaN check
                std_val = None
            return mean_val, std_val
        else:
            try:
                return float(value), None
            except:
                return None, None
    return None, None

# Helper function to extract only mean value
def extract_main_value(value):
    mean, _ = extract_values(value)
    return mean

# Helper function to escape LaTeX special characters
def escape_latex(s):
    if not isinstance(s, str):
        return s
    return (s.replace('\\', '\\textbackslash{}')
             .replace('&', '\\&')
             .replace('%', '\\%')
             .replace('$', '\\$')
             .replace('#', '\\#')
             .replace('_', '\\_')
             .replace('{', '\\{')
             .replace('}', '\\}')
             .replace('~', '\\textasciitilde{}')
             .replace('^', '\\textasciicircum{}'))
latex_code = dataframe_to_tabularx(
    grouped[grouped['split']=="ood"],
    metric_cols=['floor_hue', 'object_hue', 'orientation', 'scale', 'shape', 'wall_hue'],
    maximize=True
)

print(latex_code)

\begin{tabularx}{\textwidth}{llXXXXXX}
\toprule
\textbf{SUB DATASET} & \textbf{MODEL} & \textbf{FLOOR HUE} & \textbf{OBJECT HUE} & \textbf{ORIENTATION} & \textbf{SCALE} & \textbf{SHAPE} & \textbf{WALL HUE} \\
\midrule
Composition & CNN-REG (BASELINE) & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{100.00 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.00\}\}\} & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{100.00 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.00\}\}\} & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{99.99 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.01\}\}\} & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{100.00 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.00\}\}\} & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{100.00 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.00\}\}\} & \textbackslash\{\}textbf\{\textbackslash\{\}scriptsize\{100.00 \textbackslash\{\}tiny\{\$\textbackslash\{\}pm\$ 0.00\}\}\} 

In [28]:
def df_to_tabularx(df, label, caption, column_width='\\textwidth'):
    import io

    # Copy and format DataFrame
    df_fmt = df.copy()
    df_fmt.iloc[:, 0] = df_fmt.iloc[:, 0].str.capitalize()  # sub_dataset
    df_fmt.iloc[:, 1] = df_fmt.iloc[:, 1].str.upper()       # group

    # Format numeric columns as percentages
    for col in df.columns[2:]:
        df_fmt[col] = (df[col]).map(lambda x: f"{x:.2f}")

    # Build tabularx column format
    num_task_columns = df_fmt.shape[1] - 2
    column_format = 'll' + 'X' * num_task_columns

    # Prepare bold column headers
    bold_headers = [f"\\textbf{{{col.replace('_', ' ').capitalize()}}}" for col in df_fmt.columns]

    # Use to_latex to get the content, skipping header
    buf = io.StringIO()
    df_fmt.to_latex(
        buf,
        index=False,
        header=False,
        escape=False,
        column_format=column_format
    )
    lines = buf.getvalue().splitlines()

    # Insert bold header manually
    header_line = ' & '.join(bold_headers) + ' \\\\'
    table_body = '\n'.join(lines[3:-2])  # skip to_latex's \toprule, etc.

    # Final LaTeX table
    latex = (
        f"\\begin{{table}}[ht]\n"
        f"\\centering\n"
        f"\\caption{{{caption}}}\n"
        f"\\label{{{label}}}\n"
        f"\\begin{{tabularx}}{{{column_width}}}{{{column_format}}}\n"
        f"{header_line}\n"
        f"{table_body}\n"
        f"\\end{{tabularx}}\n"
        f"\\end{{table}}"
    )

    return latex




label = "tab:asda"
caption = "Hola"
print(df_to_tabularx(tables['ood'], label, caption, column_width='\\textwidth'))

\begin{table}[ht]
\centering
\caption{Hola}
\label{tab:asda}
\begin{tabularx}{\textwidth}{llXXXXXXX}
\textbf{Sub dataset} & \textbf{Model} & \textbf{Floor hue} & \textbf{Object hue} & \textbf{Orientation} & \textbf{Scale} & \textbf{Shape} & \textbf{Wall hue} & \textbf{R2} \\
Composition & MOD_REGRESSION_VIT_L_32 & 0.91 & 0.86 & 0.95 & 0.98 & 1.00 & 0.94 & 0.94 \\
Composition & MOD_REGRESSION_VIT_L_32_RANDOM & 0.99 & 0.93 & 0.97 & 0.98 & 1.00 & 0.99 & 0.98 \\
Composition & NON_MOD_REGRESSION_VIT_L_32 & 0.91 & 0.85 & 0.90 & 0.98 & 1.00 & 0.89 & 0.92 \\
Composition & REGRESSIONCNN & 1.00 & 1.00 & 1.00 & 1.00 & 1.00 & 1.00 & 1.00 \\
Composition & REP_TRAIN_PLUS_VIT_L_32 & 0.85 & 0.77 & 0.96 & 0.87 & 0.98 & 0.92 & 0.89 \\
Composition & REP_TRAIN_PLUS_VIT_L_32_RANDOM & 0.99 & 0.84 & 0.95 & 0.98 & 1.00 & 0.98 & 0.96 \\
Composition & REP_TRAIN_SAME_RES_VIT_L_32 & 0.95 & 0.84 & 0.96 & 0.99 & 1.00 & 0.95 & 0.95 \\
Composition & REP_TRAIN_SAME_RES_VIT_L_32_RANDOM & 1.00 & 0.90 & 0.96 & 0.98 & 1.0

In [None]:
    def split_step(self, batch):    
        # Batch is simple
        imgs, gt_reps, latents = batch
        mid_reps = gt_reps if self.use_reps else self.encoder(imgs.float(), gt_reps)     # Image encoding
        mid_reps = torch.nn.functional.normalize(mid_reps, p=2.0, dim=1, eps=1e-12)
        logits = self.regressor(mid_reps)
        data = dict()
        data['logits'] = logits
        data['targets'] = latents
        return data

    def predict_regression(self, batch):
        imgs, reps, latents = batch
        reps = self.encode(imgs, reps)
        zero_deltas = torch.zeros_like(latents).float()
        if self.modulator is not None:
            reps = self.modulate(reps, zero_deltas)
        preds = self.regressor(reps)
        data = dict()
        data['logits'] = preds
        data['targets'] = latents
        return data
