In [5]:
from models import LightningRepClassification, LightningTransformRegression, LightningRegression,create_model
from model_info import encoders, modulators, model_output_dims
from utils import set_seed, get_args
import torch
from datasets import IdSpritesEval
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split

def get_dataset(args):
    indices = list(range(480000))
    data = torch.load(f"{args.dataset}/{args.dataset}.pth", map_location="cpu")
    if args.pretrained_reps:
        data['reps'] = torch.load(f"{args.dataset}/{args.dataset}_images_feats_{args.pretrained_reps}.pth", map_location="cpu")
    ds = IdSpritesEval(args, data, indices, max_delta=14, num_samples=20, p_skip=0, test=False, return_indices=True)
    return ds
    
def get_dataloader(args, ds, indices):
    sampler = SubsetRandomSampler(indices=indices)
    dl = DataLoader(ds, batch_size=1024, sampler=sampler)
    return dl

def evaluate(model, dataloader):
    model.eval()
    results = []
    split = ['train','id','ood']

    sss = []
    tasks = []
    _ , _ , src_rep, _, _, latents = next(iter(dataloader))
    n_attrs = latents.shape[-1]
    dims = src_rep.shape[-1]

    device = "cuda"
    y_squared = torch.zeros(n_attrs, dims).to(device)
    ys = torch.zeros(n_attrs, dims).to(device)
    
    with torch.no_grad():
        n_batches = 0
        for n_batch, batch in enumerate(tqdm(dataloader)):
            # Unpack index + batch
            idxs, src_img, src_rep, imgs, gt_reps, latents = batch
            idxs, src_img, src_rep, imgs, gt_reps, latents = idxs.cuda(),src_img.cuda(), src_rep.cuda(), imgs.cuda(), gt_reps.cuda(), latents.cuda()

            n_batches+=1
            data = model.split_step((src_img, src_rep, imgs, gt_reps, latents))
            
            sss.append(torch.sum((data['targets'] - data['logits']) ** 2, dim=1))
            tasks.append(data['tasks'])

            # rolling stats
            for i in range(n_attrs):
                y_squared[i] += (data['targets'][data['tasks'] == i]**2).sum(dim=0)
                ys[i] += data['targets'][data['tasks'] == i].sum(dim=0)

        ss_res = torch.cat(sss, dim=0).cuda()
        tasks = torch.cat(tasks, dim=0).long().cuda()
        
        
        n_attrs = 6
        dtype = ss_res.dtype
        device = ss_res.device
        ss_res = torch.zeros(n_attrs, dtype=dtype, device=device).scatter_reduce(0,
                                                                            tasks,
                                                                            ss_res,
                                                                            reduce="sum")

        counts = torch.zeros(n_attrs, dtype=tasks.dtype, device=device).scatter_reduce(0,tasks, torch.ones_like(tasks).cuda(), reduce="sum").to(device)
        mus = torch.empty(n_attrs, dims).to(device)
        ss_tot = torch.empty(n_attrs).to(device)
        for i in range(n_attrs):
            mus[i] = ys[i]/counts[i]
            ss_tot[i] = (y_squared[i] - 2*ys[i]*mus[i] + mus[i]**2).sum()  # ==> sum_{i=1 in T} (y_i - mu_T)^2

        r2 = 1- ss_res/ss_tot                            # TODO: get this per task
    return r2

In [6]:

# Define experiment ids for checkpoint

exps = [#'0hdoi4lw', # composition
       #'6dlybv9s', # composition
       # "dbnxlnfv", # interpolation
       # "tt466w4y", # interpolation
        "te4t8cr4", # extrapolation
        "n0vdpha1" # extrapolation
       ]

for exp_id in tqdm(exps):
    args = get_args(exp_id)
    args.encoder['pretrain_method'] = None
    print(args)

    ds = get_dataset(args)
    df = pd.DataFrame()
    for split in tqdm(['train','id','ood']):
        encoder, modulator = create_model(args)
        model = LightningTransformRegression.load_from_checkpoint(checkpoint_path=f"results/{args.dataset}/{exp_id}/last.ckpt", 
                                            args=args, 
                                            encoder=encoder, 
                                            modulator=modulator)
        if split in ['train','id']:
            indices = torch.load(f"3dshapes/shapes3d_{args.sub_dataset}_train_indices.pth")
            train_indices, val_indices = train_test_split(indices, test_size = 0.1, random_state=42)
            indices = train_indices if split == "train" else val_indices
        elif split == "ood":
            indices = torch.load(f"3dshapes/shapes3d_{args.sub_dataset}_test_indices.pth")
        else:
            print("Split not recognized!")
            indices = torch.tensor([]).long()
            
        dl = get_dataloader(args, ds, indices)
        r2 = evaluate(model, dl)
        
        # Store metadata
        model_name = args.pretrained_reps
        if args.pretrained_reps is None:
            if args.pretrained_encoder is not None:
                enc_args = get_args(args.pretrained_encoder)
                model_name = enc_args.pretrained_reps 
        meta = {
            'split': split,
            'dataset': args.dataset,
            'sub_dataset': args.sub_dataset,
            'model': f"{model_name}"
        }
        
        # Create a long-format DataFrame where each row is a (task, r2) pair
        rows = []
        for i, fov in enumerate(args.FOVS_PER_DATASET):
            rows.append({
                **meta,
                'task': fov,
                'r2': r2[i].item()
            })

        # Append to df
        result_df = pd.DataFrame(rows)
        df = pd.concat([df, result_df], ignore_index=True)

    df.to_csv(f"transform_{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv")

  0%|          | 0/2 [00:00<?, ?it/s]

{'lr': 0.001, 'wd': 0.04, 'arch': 'none', 'fovs': ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation'], 'seed': 111, 'test': False, 'frozen': True, 'losses': 'class', 'n_fovs': {'scale': 8, 'shape': 4, 'wall_hue': 10, 'floor_hue': 10, 'object_hue': 10, 'orientation': 15}, 'warmup': 6.666666666666667, 'dataset': '3dshapes', 'encoder': {'arch': 'none', 'frozen': True, 'enc_dims': 16, 'pretrained': None, 'pretrain_method': None}, 'data_dir': '/mnt/nas2/GrimaRepo/araymond', 'enc_dims': 16, 'final_lr': 1e-06, 'final_wd': 0.4, 'fovs_ids': [0, 1, 2, 3, 4, 5], 'mod_arch': 'mlp', 'mod_dims': 16, 'start_lr': 0.0002, 'train_bs': 256, 'ema_start': 0.996, 'ipe_scale': 1, 'modulator': {'arch': 'mlp', 'hidden_dim': 16}, 'num_steps': 400000, 'resume_id': None, 'fovs_tasks': ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation'], 'num_epochs': 50, 'save_every': 10, 'fovs_levels': {'3dshapes': {'scale': 3, 'shape': 2, 'wall_hue': 2, 'floor_hue': 2, 'object_hue': 2, 

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/130 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/325 [00:00<?, ?it/s]

{'lr': 0.001, 'wd': 0.04, 'arch': 'none', 'fovs': ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation'], 'seed': 111, 'test': False, 'frozen': True, 'losses': 'class', 'n_fovs': {'scale': 8, 'shape': 4, 'wall_hue': 10, 'floor_hue': 10, 'object_hue': 10, 'orientation': 15}, 'warmup': 6.666666666666667, 'dataset': '3dshapes', 'encoder': {'arch': 'none', 'frozen': True, 'enc_dims': 16, 'pretrained': None, 'pretrain_method': None}, 'data_dir': '/mnt/nas2/GrimaRepo/araymond', 'enc_dims': 16, 'final_lr': 1e-06, 'final_wd': 0.4, 'fovs_ids': [0, 1, 2, 3, 4, 5], 'mod_arch': 'mlp', 'mod_dims': 16, 'start_lr': 0.0002, 'train_bs': 256, 'ema_start': 0.996, 'ipe_scale': 1, 'modulator': {'arch': 'mlp', 'hidden_dim': 16}, 'num_steps': 400000, 'resume_id': None, 'fovs_tasks': ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation'], 'num_epochs': 50, 'save_every': 10, 'fovs_levels': {'3dshapes': {'scale': 3, 'shape': 2, 'wall_hue': 2, 'floor_hue': 2, 'object_hue': 2, 

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/130 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/325 [00:00<?, ?it/s]

In [7]:
exps = ['0hdoi4lw', # composition
       '6dlybv9s', # composition
        "dbnxlnfv", # interpolation
        "tt466w4y", # interpolation
        "te4t8cr4", # extrapolation
        "n0vdpha1" # extrapolation
       ]


In [8]:
final_result = pd.DataFrame()
for exp_id in exps:
    args = get_args(exp_id)
    filename = f"transform_{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv"
    exp_id, dataset, sub_dataset = filename.split("_")[:3]
    pretrained_reps = "-".join(filename.split("_")[3:]).replace(".csv","")
    df= pd.read_csv(filename)
    result = df.groupby(['dataset','sub_dataset',"model", 'split', 'task'])['r2'].mean().reset_index()
    result.rename(columns={'correct': 'accuracy'}, inplace=True)
    print(result)
    final_result = pd.concat([final_result, result])

     dataset  sub_dataset     model  split         task        r2
0   3dshapes  composition  vit_l_32     id    floor_hue  0.820189
1   3dshapes  composition  vit_l_32     id   object_hue  0.817173
2   3dshapes  composition  vit_l_32     id  orientation  0.844980
3   3dshapes  composition  vit_l_32     id        scale  0.856303
4   3dshapes  composition  vit_l_32     id        shape  0.837392
5   3dshapes  composition  vit_l_32     id     wall_hue  0.810326
6   3dshapes  composition  vit_l_32    ood    floor_hue  0.755349
7   3dshapes  composition  vit_l_32    ood   object_hue  0.748533
8   3dshapes  composition  vit_l_32    ood  orientation  0.782774
9   3dshapes  composition  vit_l_32    ood        scale  0.789272
10  3dshapes  composition  vit_l_32    ood        shape  0.755803
11  3dshapes  composition  vit_l_32    ood     wall_hue  0.746183
12  3dshapes  composition  vit_l_32  train    floor_hue  0.821755
13  3dshapes  composition  vit_l_32  train   object_hue  0.819229
14  3dshap

In [9]:
final_result['split'] = pd.Categorical(final_result['split'], categories=['train', 'id','ood'], ordered=True)
final_result = final_result.sort_values(by=['dataset', 'sub_dataset','model','split'], ascending=[True, True,True,True])

models = ["train","id","ood"]
df = final_result
tables = {}
for model in models:
    df_model = df[df['split'] == model]
    pivot = df_model.pivot_table(
        index=['sub_dataset', 'model'],
        columns='task',
        values='r2'
    ).reset_index()
    tables[model] = pivot

In [15]:
def df_to_tabularx(df, label, caption, column_width='\\textwidth'):
    import io

    # Copy and format DataFrame
    df_fmt = df.copy()
    df_fmt.iloc[:, 0] = df_fmt.iloc[:, 0].str.capitalize()  # sub_dataset
    df_fmt.iloc[:, 1] = df_fmt.iloc[:, 1].str.upper()       # group

    # Format numeric columns as percentages
    for col in df.columns[2:]:
        df_fmt[col] = (df[col]).map(lambda x: f"{x:.2f}")

    # Build tabularx column format
    num_task_columns = df_fmt.shape[1] - 2
    column_format = 'll' + 'X' * num_task_columns

    # Prepare bold column headers
    bold_headers = [f"\\textbf{{{col.replace('_', ' ').capitalize()}}}" for col in df_fmt.columns]

    # Use to_latex to get the content, skipping header
    buf = io.StringIO()
    df_fmt.to_latex(
        buf,
        index=False,
        header=False,
        escape=False,
        column_format=column_format
    )
    lines = buf.getvalue().splitlines()

    # Insert bold header manually
    header_line = ' & '.join(bold_headers) + ' \\\\'
    table_body = '\n'.join(lines[3:-2])  # skip to_latex's \toprule, etc.

    # Final LaTeX table
    latex = (
        f"\\begin{{table}}[ht]\n"
        f"\\centering\n"
        f"\\caption{{{caption}}}\n"
        f"\\label{{{label}}}\n"
        f"\\begin{{tabularx}}{{{column_width}}}{{{column_format}}}\n"
        f"{header_line}\n"
        f"{table_body}\n"
        f"\\end{{tabularx}}\n"
        f"\\end{{table}}"
    )

    return latex




label = "tab:asda"
caption = "Hola"
print(df_to_tabularx(tables['id'], label, caption, column_width='\\textwidth'))

\begin{table}[ht]
\centering
\caption{Hola}
\label{tab:asda}
\begin{tabularx}{\textwidth}{llXXXXXX}
\textbf{Sub dataset} & \textbf{Model} & \textbf{Floor hue} & \textbf{Object hue} & \textbf{Orientation} & \textbf{Scale} & \textbf{Shape} & \textbf{Wall hue} \\
Composition & VIT_B_32 & 0.83 & 0.83 & 0.87 & 0.87 & 0.85 & 0.83 \\
Composition & VIT_L_32 & 0.82 & 0.82 & 0.84 & 0.86 & 0.84 & 0.81 \\
Extrapolation & VIT_B_32 & 0.82 & 0.83 & 0.85 & 0.86 & 0.85 & 0.82 \\
Extrapolation & VIT_L_32 & 0.80 & 0.80 & 0.82 & 0.83 & 0.82 & 0.80 \\
Interpolation & VIT_B_32 & 0.84 & 0.83 & 0.87 & 0.87 & 0.86 & 0.83 \\
Interpolation & VIT_L_32 & 0.82 & 0.81 & 0.85 & 0.85 & 0.84 & 0.80 \\
\end{tabularx}
\end{table}


In [32]:
sum_per_group

tensor([59549.7578, 60457.2773, 74756.4141, 50006.6914, 74390.8359, 71710.8906])

In [14]:
tables['id']

task,sub_dataset,model,floor_hue,object_hue,orientation,scale,shape,wall_hue
0,composition,vit_b_32,0.830173,0.83331,0.866473,0.874353,0.851716,0.827378
1,composition,vit_l_32,0.820189,0.817173,0.84498,0.856303,0.837392,0.810326
2,extrapolation,vit_b_32,0.823776,0.830696,0.84665,0.859351,0.849552,0.822614
3,extrapolation,vit_l_32,0.802029,0.798602,0.821899,0.832293,0.819893,0.795141
4,interpolation,vit_b_32,0.836922,0.834801,0.874824,0.872758,0.864299,0.829349
5,interpolation,vit_l_32,0.819529,0.813959,0.849522,0.847024,0.842896,0.802319
