In [1]:
from models import LightningRepClassification
from model_info import encoders, modulators, model_output_dims
from utils import set_seed, get_args
import torch
from datasets import IdSpritesEval
import pandas as pd
from tqdm.notebook import tqdm

def get_dataloader(args):
    indices = list(range(480000))
    data = torch.load(f"{args.dataset}/{args.dataset}.pth", map_location="cpu")
    if args.pretrained_reps:
        data['reps'] = torch.load(f"{args.dataset}/{args.dataset}_images_feats_{args.pretrained_reps}.pth", map_location="cpu")
    ds = IdSpritesEval(args, data, indices, max_delta=14, num_samples=20, p_skip=0, test=False, return_indices=True)
    dl = torch.utils.data.DataLoader(ds, batch_size=1024, shuffle=False)
    return dl
def get_model(args):
    encoder = encoders[args.encoder.arch](args) if args.encoder.arch != "none" else None
    input_dims = model_output_dims[args.encoder.arch] if args.encoder.arch != "none" else  model_output_dims[args.pretrained_reps]
    
    print("input_dim",input_dims,
          "hidden_dim",args.modulator.hidden_dim)
    modulator = modulators[args.train_method](input_dim=input_dims,
                                              hidden_dim=args.modulator.hidden_dim,
                                              latent_dim = 5 if args.dataset == "idsprites" else 6
                                              )
    # Load from checkpoint:
    model = LightningRepClassification.load_from_checkpoint(checkpoint_path=f"results/{args.dataset}/{exp_id}/last.ckpt", 
                                                            args=args, 
                                                            encoder=encoder, 
                                                            modulator=modulator)
    return model

def evaluate_and_save(model, dataloader, groups_list, output_file):
    model.eval()
    results = []
    split = ['train','id','ood']
    with torch.no_grad():
        for n_batch, batch in enumerate(tqdm(dataloader)):
            # Unpack index + batch
            idxs, src_img, src_rep, imgs, gt_reps, latents = batch
            data = model.split_step((src_img.cuda(), src_rep.cuda(), imgs.cuda(), gt_reps.cuda(), latents.cuda()))
            # We want to store per-sample results:
            # The logits are bs*n_classes x n_classes — reshape if needed.
            bs = src_img.size(0)
            preds = data['logits'].argmax(dim=-1).cpu().numpy()
            targets = data['class_tgt'].cpu().numpy()
            tasks = data['tasks'].view(-1).cpu().numpy()
            for i, sample_idx in enumerate(idxs.cpu().numpy()):
                result = {
                    'idx': sample_idx,
                    'pred': preds[i],
                    'target': targets[i],
                    'task': tasks[i],
                    # Add other metrics or latent info if needed
                }
                results.append(result)
                
    df = pd.DataFrame(results)
    df['group'] = df['idx'].map(lambda x: split[groups_list[x].long().item()])
    df.to_csv(output_file, index=False)
    
    return df


There was a problem when trying to write in your cache folder (/storage/cache). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


In [9]:
from sklearn.model_selection import train_test_split
# Define experiment ids for checkpoint

exps = ['j97giv73', # composition
        '99fuyinh', # composition
        '8evsuasz', # interpolation
        'w8c3385v', # interpolation
        'nsfoq455', # extrapolation
        'pt7snmb1' # extrapolation
       ] # interpolation
# experimentos mejores

for exp_id in tqdm(exps):
    args = get_args(exp_id)
    args.encoder['pretrain_method'] = None
    model = get_model(args)
    dl = get_dataloader(args)
    train_indices = torch.load(f"3dshapes/shapes3d_{args.sub_dataset}_train_indices.pth")
    train_indices, val_indices = train_test_split(train_indices, test_size = 0.1, random_state=42)

    groups_list = 2*torch.ones(480000)
    groups_list[train_indices] = 0
    groups_list[val_indices] = 1
    evaluate_and_save(model, dl, groups_list, f"{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv")

  0%|          | 0/6 [00:00<?, ?it/s]

input_dim 768 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

input_dim 1024 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

input_dim 768 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

input_dim 768 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

input_dim 1024 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

input_dim 1024 hidden_dim 128
Recentering representations!
Normalizing reps!


  0%|          | 0/469 [00:00<?, ?it/s]

In [4]:
exps = ['j97giv73',
        '99fuyinh',
        '8evsuasz',
        'nsfoq455',
        'pt7snmb1',
        'w8c3385v'
       # "nn3235c9",
       # "8p8nqlqm",
       # "3j1sltd9",
       # "ycdtf0ng",
       # "57gubcbl",
       # "uu8nayjd"
        ] # experimentos mejores
#exps = ["ga9v6jrr", "q77tx64m", "4i1m0a4x", "aauu4gxw","6l733ceo","u8ql80ja"] # experimentos al reves

final_result = pd.DataFrame()
for exp_id in exps:
    args = get_args(exp_id)
    filename = f"{exp_id}_{args.dataset}_{args.sub_dataset}_{args.pretrained_reps}.csv"
    exp_id, dataset, sub_dataset = filename.split("_")[:3]
    pretrained_reps = "-".join(filename.split("_")[3:]).replace(".csv","")
    df= pd.read_csv(filename)
    df['dataset'] = dataset
    df['sub_dataset'] = sub_dataset
    df['model'] = pretrained_reps
    df['correct'] = df['pred'] == df['target']
    df['task_name'] = df['task'].apply(lambda x:  ["floor_hue", "wall_hue", "object_hue", "scale", "shape", "orientation"][x])
    result = df.groupby(['dataset','sub_dataset',"model",'group', 'task_name'])['correct'].mean().reset_index()
    result.rename(columns={'correct': 'accuracy'}, inplace=True)
    print(result)
    final_result = pd.concat([final_result, result])


     dataset  sub_dataset     model  group    task_name  accuracy
0   3dshapes  composition  vit-b-32     id    floor_hue  0.986631
1   3dshapes  composition  vit-b-32     id   object_hue  0.952473
2   3dshapes  composition  vit-b-32     id  orientation  0.871923
3   3dshapes  composition  vit-b-32     id        scale  0.918740
4   3dshapes  composition  vit-b-32     id        shape  0.993531
5   3dshapes  composition  vit-b-32     id     wall_hue  0.983175
6   3dshapes  composition  vit-b-32    ood    floor_hue  0.958850
7   3dshapes  composition  vit-b-32    ood   object_hue  0.887724
8   3dshapes  composition  vit-b-32    ood  orientation  0.764908
9   3dshapes  composition  vit-b-32    ood        scale  0.899644
10  3dshapes  composition  vit-b-32    ood        shape  0.989840
11  3dshapes  composition  vit-b-32    ood     wall_hue  0.959663
12  3dshapes  composition  vit-b-32  train    floor_hue  0.986366
13  3dshapes  composition  vit-b-32  train   object_hue  0.952742
14  3dshap

In [5]:
final_result['group'] = pd.Categorical(final_result['group'], categories=['train', 'id','ood'], ordered=True)
final_result = final_result.sort_values(by=['dataset', 'sub_dataset','model','group'], ascending=[True, True,True,True])


In [6]:
models = ['vit-b-32', 'vit-l-32']
df = final_result
tables = {}
for model in models:
    df_model = df[df['model'] == model]
    pivot = df_model.pivot_table(
        index=['sub_dataset', 'group'],
        columns='task_name',
        values='accuracy'
    ).reset_index()
    tables[model] = pivot

  pivot = df_model.pivot_table(
  pivot = df_model.pivot_table(


In [7]:
tables['vit-b-32']

task_name,sub_dataset,group,floor_hue,object_hue,orientation,scale,shape,wall_hue
0,composition,train,0.986366,0.952742,0.887531,0.92211,0.995251,0.981804
1,composition,id,0.986631,0.952473,0.871923,0.91874,0.993531,0.983175
2,composition,ood,0.95885,0.887724,0.764908,0.899644,0.98984,0.959663
3,extrapolation,train,0.888746,0.85527,0.650432,0.688759,0.98859,0.875718
4,extrapolation,id,0.884514,0.849767,0.659195,0.697283,0.986517,0.873458
5,extrapolation,ood,0.678899,0.692237,0.583349,0.662754,0.980025,0.704941
6,interpolation,train,0.746398,0.617402,0.677533,0.478523,0.999689,0.709076
7,interpolation,id,0.75301,0.610314,0.689942,0.476974,0.999446,0.708716
8,interpolation,ood,0.536196,0.586212,0.642316,0.451403,0.999194,0.52472


In [7]:
def df_to_tabularx(df, label, caption, column_width='\\textwidth'):
    import io

    # Copy and format DataFrame
    df_fmt = df.copy()
    df_fmt.iloc[:, 0] = df_fmt.iloc[:, 0].str.capitalize()  # sub_dataset
    df_fmt.iloc[:, 1] = df_fmt.iloc[:, 1].str.upper()       # group

    # Format numeric columns as percentages
    for col in df.columns[2:]:
        df_fmt[col] = (df[col] * 100).map(lambda x: f"{x:.2f}\\%")

    # Build tabularx column format
    num_task_columns = df_fmt.shape[1] - 2
    column_format = 'll' + 'X' * num_task_columns

    # Prepare bold column headers
    bold_headers = [f"\\textbf{{{col.replace('_', ' ').capitalize()}}}" for col in df_fmt.columns]

    # Use to_latex to get the content, skipping header
    buf = io.StringIO()
    df_fmt.to_latex(
        buf,
        index=False,
        header=False,
        escape=False,
        column_format=column_format
    )
    lines = buf.getvalue().splitlines()

    # Insert bold header manually
    header_line = ' & '.join(bold_headers) + ' \\\\'
    table_body = '\n'.join(lines[3:-2])  # skip to_latex's \toprule, etc.

    # Final LaTeX table
    latex = (
        f"\\begin{{table}}[ht]\n"
        f"\\centering\n"
        f"\\caption{{{caption}}}\n"
        f"\\label{{{label}}}\n"
        f"\\begin{{tabularx}}{{{column_width}}}{{{column_format}}}\n"
        f"{header_line}\n"
        f"{table_body}\n"
        f"\\end{{tabularx}}\n"
        f"\\end{{table}}"
    )

    return latex




label = "tab:asda"
caption = "Hola"
print(df_to_tabularx(tables['vit-b-32'], label, caption, column_width='\\textwidth'))

\begin{table}[ht]
\centering
\caption{Hola}
\label{tab:asda}
\begin{tabularx}{\textwidth}{llXXXXXX}
\textbf{Sub dataset} & \textbf{Group} & \textbf{Floor hue} & \textbf{Object hue} & \textbf{Orientation} & \textbf{Scale} & \textbf{Shape} & \textbf{Wall hue} \\
Composition & TRAIN & 98.12\% & 94.74\% & 82.05\% & 93.23\% & 99.62\% & 97.93\% \\
Composition & ID & 98.23\% & 95.13\% & 80.92\% & 92.77\% & 99.35\% & 97.88\% \\
Composition & OOD & 96.38\% & 88.27\% & 79.04\% & 90.19\% & 99.49\% & 94.81\% \\
Extrapolation & TRAIN & 88.38\% & 85.61\% & 65.26\% & 64.67\% & 97.49\% & 90.12\% \\
Extrapolation & ID & 87.89\% & 85.36\% & 65.42\% & 64.60\% & 97.25\% & 89.32\% \\
Extrapolation & OOD & 68.86\% & 70.67\% & 60.02\% & 63.01\% & 95.69\% & 71.33\% \\
Interpolation & TRAIN & 76.03\% & 61.38\% & 68.46\% & 46.04\% & 100.00\% & 76.66\% \\
Interpolation & ID & 75.95\% & 60.60\% & 70.49\% & 44.65\% & 100.00\% & 77.22\% \\
Interpolation & OOD & 52.75\% & 58.59\% & 65.44\% & 44.66\% & 100.00\% & 56.

1       ID
2      OOD
3    TRAIN
4       ID
5      OOD
6    TRAIN
7       ID
8      OOD
Name: group, dtype: object' has dtype incompatible with category, please explicitly cast to a compatible dtype first.
  df_fmt.iloc[:, 1] = df_fmt.iloc[:, 1].str.upper()       # group


In [None]:
result

In [36]:
 from torch.utils.data import DataLoader

dl = DataLoader(ds, batch_size=16)
images, reps, delta = next(iter(dl))

In [14]:
bs = 16
n_classes = 10
torch.tensor(bs*list(range(n_classes))).view(-1, n_classes)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [67]:
import torch.nn.functional as F
import pandas as pd


def get_metrics(data):
    criterion = F.cosine_similarity
    metrics = dict()
    loss = 0
    if True:
        same_loss = 1 - criterion(data['mid_reps'], data['rep_tgt']).mean()
        loss += same_loss
        metrics['same_loss'] = same_loss

    if True:
        class_loss = F.cross_entropy(data['logits'], data['class_tgt'].view(-1), reduction="none").mean()
        loss +=  class_loss
        metrics['class_loss'] = class_loss

        preds = logits.argmax(dim=-1).view(-1)
        correct = (preds == data['class_tgt']).view(-1).float()
        accuracy = correct.sum()/correct.numel()
        metrics['class_acc'] = accuracy
        dtype=correct.dtype
        device=correct.device
        print(correct.shape, tasks.shape)
        mean_per_group = torch.zeros(5, dtype=dtype, device=device).scatter_reduce(0, data['tasks'], correct, reduce="mean")
        for i, task in enumerate(['shape','scale','orientation','x','y']):
            metrics[f'class_{task}'] = mean_per_group[i]
            
    metrics['loss'] = loss

    for k, v in metrics.items():
        metrics[k] = v.item()
    return metrics 
    
imgs, gt_reps, latents = images, reps, delta 
zero_latents = torch.zeros_like(latents)
deltas = latents.sum(dim=-1)
bs, n_classes, *_ = imgs.shape

mid_reps = gt_reps if True else self.encoder(imgs)     # Image encoding
reps = torch.randn((bs,n_classes,128))                       # predicted reps given latents
tgt_reps = torch.randn((bs,n_classes,128))               # reps we are trying to achieve

logits = torch.matmul(reps, tgt_reps.transpose(1,2)).view(-1, n_classes) # bs x 10 x 10 --> 10bs x 10
#reps, valid_indices = expand_reps(reps, ranges) # valid indices means which lines to keep on the loss
#tgt_reps, _ = expand_reps(tgt_reps, ranges)
data['mid_reps'] = mid_reps
data['rep_tgt'] = gt_reps
data['logits'] = logits
targets = torch.tensor(bs*list(range(n_classes))).view(-1, n_classes)
tasks = latents.abs().argmax(dim=-1)
data['class_tgt'] = targets.view(-1)
data['tasks'] = tasks.view(-1)
metrics = get_metrics(data)
# altered encoded rep must be equal to rep of original image
loss = 0

torch.Size([160]) torch.Size([16, 10])


In [68]:
metrics

{'same_loss': 0.0,
 'class_loss': 18.06698226928711,
 'class_acc': 0.09375,
 'class_shape': 0.06060606241226196,
 'class_scale': 0.09090909361839294,
 'class_orientation': 0.09090909361839294,
 'class_x': 0.0882352963089943,
 'class_y': 0.125,
 'loss': 18.06698226928711}

In [42]:
delta[1]

tensor([[ 1,  0,  0,  0,  0],
        [ 0,  1,  0,  0,  0],
        [ 0,  0,  1,  0,  0],
        [ 0,  0,  0,  1,  0],
        [ 0,  0,  0,  0,  1],
        [ 0,  0,  0,  0, -1],
        [ 2,  0,  0,  0,  0],
        [ 0,  2,  0,  0,  0],
        [ 0,  0,  2,  0,  0],
        [ 0,  0,  0,  2,  0]])

In [92]:
from tqdm.notebook import tqdm
min_ = float('inf')
max_ = float('-inf')
for idx, item in tqdm(enumerate(ds)):
    min_ = min(item[-1].size(0), min_)
    max_ = max(item[-1].size(0), max_)
    if min_ < 10: 
        print('min_', idx)
        break
    if max_ > 10: 
        print('max_', idx)
        break
    print(f'min_={min_} max_={max_}', end='\r')

0it [00:00, ?it/s]

min_=10 max_=10

In [6]:
a, b = ds[0]

In [7]:
a.shape

torch.Size([10, 1, 64, 64])

In [81]:
import torch.nn as nn
class MLP(nn.Module):
    def __init__(self, input_dim=100, latent_dim=5, hidden_dim=128, n_blocks=3):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.proj = nn.Linear(latent_dim,input_dim)
        modules = [nn.Linear(2*self.input_dim, self.hidden_dim)]
        for i in range(n_blocks-1):
            modules.append(nn.ReLU())
            modules.append(nn.Linear(self.hidden_dim, self.hidden_dim))
        self.model = nn.Sequential(*modules)

    def forward(self, x, l):

        l = self.proj(l)
        x = torch.cat((x,l), dim=1)
        x = self.model(x)
        return x



model = MLP(input_dim=100, hidden_dim=32, latent_dim=5, n_blocks=4)
data = torch.randn(64,100)
l = torch.randn(64,5)
model(data, l).shape

torch.Size([64, 32])

In [77]:
encoders = {
    None: None,
    "vit": 384,
    "vit_b_16": 1,    
    "vit_b_32": 22,
    "vit_l_16": 16,
    "vit_l_32": 32
}

print(encoders["vit"])

384


In [4]:
import torch
root = 'idsprites'

data  = torch.load(f"{root}/idsprites.pth")
data['reps'] = torch.load(f"{root}/idsprites_images_feats_vit_b_16.pth")
train_indices = torch.tensor([i for i in range(len(data['images']))])


In [13]:
assert False

AssertionError: 

In [None]:
def map_latents_to_values(new_latents_idxs):
    new_latents_idxs = new_latents_idxs.long()
    new_latents_idxs[:,1:] = torch.clamp(new_latents_idxs[:,1:], min=0, max=13)
    latent_indices = f.one_hot(new_latents_idxs[:,1:],num_classes=14)
    new_latents = torch.cat((new_latents_idxs[:,0].unsqueeze(1),(latent_indices*values).sum(dim=-1)),dim=1)
    return new_latents
    
def map_detail(x):
    l = ['shape+','scale+','orientation+','x+','y+',
        'shape-','scale-','orientation-','x-','y-']
    return  l[x]
    
def index_to_latent_id(idx):
    shape = (idx // (14**4)) % 54
    scale =  (idx // (14**3)) % 14
    orientation =  (idx // (14**2)) % 14
    x =  (idx // 14) % 14
    y =  idx % 14
    return (shape,scale,orientation,x,y)

# latent_id = tensor with original starting latent_ids for all latent attributes (size = (1,5))
# delta = how many steps to move in all latent attributes

def get_delta_latents(latent_id, delta):
    
    pred_delta_latents = delta*torch.eye(5)
    pred_delta_latents = torch.cat((pred_delta_latents,-pred_delta_latents), dim=0)
    pred_delta_latents = pred_delta_latents.repeat(n,1,1) # CPU
    
    new_latents_idxs = (latent_id.repeat(1,10,1) + pred_delta_latents).to(torch.int8) # CPU
    out_of_min_range = new_latents_idxs >= 0
    out_of_max_range = new_latents_idxs < 14
    out_of_max_range[:,:,0] = new_latents_idxs[:,:,0] < 54
    out_of_range = out_of_max_range*out_of_min_range
    viable_latents = torch.all(out_of_range,dim=2).view(-1)

def latent_id_to_split(latent, ood):
    latent = {k: v for k, v in zip(['shape','scale','orientation','x','y'], latent)}
    for k, v in ood.items():
        if latent[k] in v:
            return "ood"
    else:
        return "iid"