# Metrics

### Import

#### Modules

In [1]:
%load_ext autoreload
%autoreload 2

import os
PROJECT_PATH = "/projects/compures/alexandre/disdiff_adapters"
os.chdir(PROJECT_PATH)
print(os.getcwd())

from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.metrics import r2_score
from xgboost import XGBRegressor
from scipy.special import digamma, gammaln
import numpy as np
from collections import Counter, defaultdict
from math import log
import seaborn as sns
from glob import glob
import json
from pathlib import Path

from torch.utils.data import DataLoader, TensorDataset
from lightning import Trainer, LightningDataModule

from tqdm import tqdm

#DataModule
from disdiff_adapters.data_module import *
#Dataset
from disdiff_adapters.dataset import *
#Module
from disdiff_adapters.arch.multi_distillme import *
#utils
from disdiff_adapters.utils import *
#loss   
from disdiff_adapters.loss import *
#metric
from disdiff_adapters.metric import FactorVAEScore

from disdiff_adapters.arch.multi_distillme.xfactors import Xfactors
BATCH_SIZE = 2**19
LATENT_DIM_S = 126
LATENT_DIM_T = 2
is_pca = False

torch.set_float32_matmul_precision('medium')

/import/pr_compures/alexandre/disdiff_adapters


  from .autonotebook import tqdm as notebook_tqdm


#### Load metrics

In [2]:
def load_json(path: Path) -> dict:
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)
metrics_x = load_json(Path("metrics_final.json"))
ckpt_path_x = load_json(Path("ckpt_path_x.json"))

## Compute metrics

### FactorVAE X

In [3]:
import torch
import numpy as np
from sklearn.decomposition import PCA
from os.path import join
from os import mkdir
from collections import Counter, defaultdict
from tqdm import tqdm


from disdiff_adapters.data_module import LatentDataModule

class FactorVAEScore :

    def __init__(self, ckpt_path: str, 
                is_pca: bool=False, 
                n_iter: int=153600, 
                batch_size: int=64, 
                pref_gpu: int=0, 
                verbose: bool=False,
                only_factors: list[int]=[]) :
        
        self.ckpt_path = ckpt_path
        self.data_name =  self.get_data_name()
        self.is_pca = is_pca
        self.n_iter = n_iter
        self.batch_size = batch_size
        self.pref_gpu = pref_gpu
        self.only_factors = only_factors

        z_s_te, z_t_te, label_te = self.load_latent(stage="test", verbose=verbose)
        z_s_val, z_t_val, label_val = self.load_latent(stage="val", verbose=verbose)                                 

        Z_te = torch.cat([z_s_te, z_t_te], dim=1).cpu().numpy()        
        Z_te = (Z_te - Z_te.mean(axis=0, keepdims=True)) / (Z_te.std(axis=0, keepdims=True) + 1e-8)
        Y_te = label_te.cpu().numpy().astype(np.int64)   

        Z_val = torch.cat([z_s_val, z_t_val], dim=1).cpu().numpy()        
        Z_val = (Z_val - Z_val.mean(axis=0, keepdims=True)) / (Z_val.std(axis=0, keepdims=True) + 1e-8)
        Y_val = label_val.cpu().numpy().astype(np.int64) 

        if only_factors:   
            print(Y_val.shape, Y_te.shape)
            Y_val = Y_val[:, only_factors]
            Y_te = Y_te[:, only_factors]       
            print(Y_val.shape, Y_te.shape)

        self.mus_test = Z_te.T                                   
        self.ys_test  = Y_te.T   
        if verbose: print("Test data formated.") 

        self.mus_val = Z_val.T                                   
        self.ys_val  = Y_val.T   
        if verbose: print("Val data formated.") 

        self.rng = np.random.default_rng(0)
    
    def set_batch_size(self, batch_size: int) :
        self.batch_size = batch_size

    def set_n_iter(self, n_iter: int):
        self.n_iter = n_iter
        
    def load_latent(self, stage: str="test", verbose: bool=False):
        latent = LatentDataModule(standard=True, 
                                batch_size=2**19,
                                pref_gpu=self.pref_gpu,
                                Model_class=Xfactors,
                                data_name=self.data_name,
                                ckpt_path=self.ckpt_path,
                                verbose=verbose)
        
        if verbose : print("Prepare data: test if .npz files exist.")
        latent.prepare_data()

        if verbose: print(f"Start loading {stage} batch.")
        latent.setup(stage)
        if verbose: print(f"Start loading {stage} batch. - end setup")
        latent_loader = latent.test_dataloader() if stage == "test" else latent.val_dataloader()
        if verbose: print(f"Start loading {stage} batch. -end dataloader")
        batch = next(iter(latent_loader))
        z_s_te, z_t_te, label_te = batch
        if verbose: print(f"Test batch shape: {z_s_te.shape, z_t_te.shape, label_te.shape}")

        self.FACTOR_NAMES = latent.Data_class.Params.FACTORS_IN_ORDER
        if self.only_factors : self.FACTOR_NAMES = list(np.asarray(self.FACTOR_NAMES, dtype=str)[self.only_factors])

        if self.is_pca:
            print("Start PCA.")
            #test
            pca_t_te = PCA(n_components=1) 
            pca_s_te = PCA(n_components=1)
            z_t_te = pca_t_te.fit_transform(z_t_te)
            z_s_te = pca_s_te.fit_transform(z_s_te)
            print(z_s_te.shape, z_t_te.shape)
            if not isinstance(z_t_te, torch.Tensor) : z_t_te = torch.tensor(z_t_te)
            if not isinstance(z_s_te, torch.Tensor) : z_s_te = torch.tensor(z_s_te)
            print("End PCA.")
        
        return z_s_te, z_t_te, label_te

    def value_index(self, ys):
        out=[]
        for k in range(ys.shape[0]):
            d={}
            for v in np.unique(ys[k]):
                d[int(v)]=np.flatnonzero(ys[k]==v)
            out.append(d)
        return out

    def collect(self, mus, ys, n_iter, batch_size=64, verbose=False):
        z_std = mus.std(axis=1, keepdims=True); z_std[z_std==0]=1.0
        v2i = self.value_index(ys)
        argmins, labels = [], []
        if verbose: print("Starting computing FactorVAE metric.")
        for _ in tqdm(range(n_iter)):
            k = self.rng.integers(0, ys.shape[0]) #Choose a factor f_k
            v = self.rng.choice(list(v2i[k].keys())) #Choose a value for f_k
            pool = v2i[k][v]
            idx = self.rng.choice(pool, size=batch_size, replace=(len(pool)<batch_size)) #Batch with f_k=v

            Z = mus[:, idx]/z_std
            d = int(Z.var(axis=1).argmin()) #get the argmin variance for this batch
            argmins.append(d); labels.append(k)
        return np.array(argmins), np.array(labels)

    def save(self):
        paths = self.ckpt_path.split("/")
        folder_path=""
        for k in range(len(paths)-2) :
            folder_path+= paths[k]+"/"
        folder_path+="metric"
        print(f"Saving at {folder_path}.")

        try: mkdir(folder_path)
        except FileExistsError: pass

        scores = {"dim_factor_score": self.dim_factor_score, "factor_dim_score": self.factor_dim_score}
        torch.save(scores, join(folder_path, "metric.pt"))
        
    def get_dicts(self, mu_s, ys, verbose: bool=True) :
        argmins, labels = self.collect(mu_s, ys, self.n_iter, self.batch_size, verbose=verbose)
        self.argmins = argmins
        self.labels = labels
        dim_factor_score = defaultdict(list)
        # Taux d'association dim->facteur
        for d in np.unique(argmins):
            dim_factor_score[str(d)] = defaultdict(float)
            cnt = Counter(labels[argmins==d]) #labels[argmins==d], How many times f_k is assigned to dimension d?
            total = sum(cnt.values()) #Number of labels assigned to dimension d
            if verbose: print(f"\nDimension {d}:")
            for k,n in cnt.most_common():
                dim_factor_score[str(d)][self.FACTOR_NAMES[k]] = n/total
                if verbose: print(f"  {self.FACTOR_NAMES[k]:12s} : {n/total:5.1%}  ({n}/{total})")

        # """Dimension 0:
        #   scale        : 73.8%  (135/183)
        #   shape        : 26.2%  (48/183)
        # """ means 73.8% of labels assigned to dimension 0 are scale.

        factor_dim_score = defaultdict(list)
        for k in np.unique(labels):
            mask = (labels == k)
            cnt = Counter(argmins[mask])                # Combien de fois la dim d "gagne" pour le facteur k ?
            total = sum(cnt.values())
            name  = self.FACTOR_NAMES[k]
            factor_dim_score[name] = defaultdict(float)
            if verbose: print(f"\nFacteur {name}:")
            for d, n in cnt.most_common():             # tri décroissant
                if verbose: print(f"  dim {d:>3} : {n/total:5.1%}  ({n}/{total})")
                factor_dim_score[name][str(d)]=n/total

        self.factor_dim_score = factor_dim_score
        self.dim_factor_score = dim_factor_score

        self.save()
        return factor_dim_score, dim_factor_score
        
    def get_data_name(self):
        root_path = self.ckpt_path.split("/")[:-2]
        data_name = str(root_path[-6])
        return data_name
    
    def compute_map(self):
        dim_factor_score = self.dim_factor_score
        #map_dim_factor = defaultdict(float)
        map_dim_factor = {}

        for dim in range(int(list(self.dim_factor_score.keys())[-1])+1) :
            factors = dim_factor_score[str(dim)]
            if type(factors) == list : first_factor = "s"
            #elif self.only_factors and (list(factors.keys())[0] not in np.asarray(self.FACTOR_NAMES)[self.only_factors]) : first_factor = "s"
            else : first_factor = list(factors.keys())[0]
            map_dim_factor[str(dim)] = first_factor
        self.map_dim_factor = map_dim_factor
        return map_dim_factor

    def compute_score(self, verbose=False):
        self.get_dicts(self.mus_val, self.ys_val, verbose=verbose)
        map = self.compute_map()
        if verbose: print(map)

        argmins, labels = self.collect(self.mus_test, self.ys_test, self.n_iter, self.batch_size)
        predictions = []
        for argmin in argmins:
            pred_str = self.map_dim_factor[str(argmin)]
            if pred_str == "s" : pred_int = -1
            else :
                try : pred_int = self.FACTOR_NAMES.index(pred_str)
                except ValueError : print(f"Factor {pred_str} does not exist.")
            predictions.append(pred_int)
        
        predictions = np.asarray(predictions)
        if verbose : 
            for prediction, label in zip(predictions, labels):
                print(prediction, label)

        return np.sum(predictions == labels)/self.n_iter

class FactorVAEScoreLight :

    def __init__(self,  
                 buff: dict, 
                 mode: str, 
                 dim_t: int, 
                 dim_s: int, 
                 select_factor: int,
                 n_iter: int=153600,
                 batch_size: int=64) :
        
        self.get_buff(mode)
        self.dim_t = dim_t
        self.dim_s = dim_s
        self.select_factor = select_factor
        self.rng = np.random.default_rng(0)
        self.n_iter = n_iter
        self.batch_size = batch_size
        
    def get_buff(self, mode):
        z_s, z_t, label = self.buff[mode]

        Z = torch.cat([z_s, z_t], dim=1).cpu().numpy()        
        Z = (Z - Z.mean(axis=0, keepdims=True)) / (Z.std(axis=0, keepdims=True) + 1e-8)
        Y = label.cpu().numpy().astype(np.int64)              

        self.mus_train = Z.T                                   
        self.ys_train  = Y.T
        print(f"{mode} data formated.")

    def value_index(self, ys):
        out=[]
        for k in range(ys.shape[0]):
            d={}
            for v in np.unique(ys[k]):
                d[int(v)]=np.flatnonzero(ys[k]==v)
            out.append(d)
        return out

    def collect(self, mus, ys, n_iter=20000, batch_size=64):
        z_std = mus.std(axis=1, keepdims=True); z_std[z_std==0]=1.0
        v2i = self.value_index(ys)
        argmins, labels = [], []
        print("Starting computing FactorVAE metric.")
        for _ in tqdm(range(n_iter)):
            k = self.rng.integers(0, ys.shape[0]) #Choose a factor f_k
            v = self.rng.choice(list(v2i[k].keys())) #Choose a value for f_k
            pool = v2i[k][v]
            idx = self.rng.choice(pool, size=batch_size, replace=(len(pool)<batch_size)) #Batch with f_k=v

            Z = mus[:, idx]/z_std
            d = int(Z.var(axis=1).argmin()) #get the argmin variance for this batch
            argmins.append(d); labels.append(k)
        return np.array(argmins), np.array(labels)
    
    def get_argmins(self, verbose: bool=True) :
        argmins, labels = self.collect(self.mus_test, self.ys_test, n_iter=153600, batch_size=64)
        self.argmins = argmins
        self.labels = labels

    def get_score(self):
        self.get_argmins()
        N = len(self.argmins)
        tp = 0
        dims_t = [self.dim_s+k for k in range(self.dim_t)]

        for dim, factor in zip(self.argmins, self.labels):

            if dim in dims_t :
                if factor == self.select_factor : tp+=1
            if dim not in dims_t :
                if factor != self.select_factor : tp+=1
        score = tp/N
        print(f"FactorVAEScore: {score}")
        return score

In [None]:
import torch
import numpy as np
from sklearn.decomposition import PCA
from os.path import join
from os import mkdir
from collections import Counter, defaultdict
from tqdm import tqdm

class FactorVAEScore:

    def __init__(self, ckpt_path: str,
                is_pca: bool=False,
                n_iter: int=153600,
                batch_size: int=64,
                pref_gpu: int=0,
                verbose: bool=False,
                only_factors: list[int]=[],
                collapse_others_to_s: bool=True):

        self.ckpt_path = ckpt_path
        self.data_name = self.get_data_name()
        self.is_pca = is_pca
        self.n_iter = n_iter
        self.batch_size = batch_size
        self.pref_gpu = pref_gpu

        # -------------------------
        # NEW: init targets AVANT load_latent()
        # -------------------------
        self.only_factors = list(only_factors) if only_factors else []
        self.collapse_others_to_s = collapse_others_to_s

        # Dédoublonne en gardant l'ordre (pas besoin que ce soit trié)
        self._target_factors = list(dict.fromkeys(self.only_factors))
        # Pour que les "if self.only_factors" restent cohérents partout
        self.only_factors = self._target_factors

        # Load latents (load_latent peut maintenant utiliser _target_factors sans crash)
        z_s_te, z_t_te, label_te = self.load_latent(stage="test", verbose=verbose)
        z_s_val, z_t_val, label_val = self.load_latent(stage="val", verbose=verbose)

        Z_te = torch.cat([z_s_te, z_t_te], dim=1).cpu().numpy()
        Z_te = (Z_te - Z_te.mean(axis=0, keepdims=True)) / (Z_te.std(axis=0, keepdims=True) + 1e-8)
        Y_te = label_te.cpu().numpy().astype(np.int64)

        Z_val = torch.cat([z_s_val, z_t_val], dim=1).cpu().numpy()
        Z_val = (Z_val - Z_val.mean(axis=0, keepdims=True)) / (Z_val.std(axis=0, keepdims=True) + 1e-8)
        Y_val = label_val.cpu().numpy().astype(np.int64)

        # -------------------------
        # NEW: configuration "targets + s"
        # -------------------------
        self.n_factors_full = Y_val.shape[1]

        if self.only_factors:
            assert max(self.only_factors) < self.n_factors_full, "Index de facteur hors-borne."

            if self.collapse_others_to_s:
                target_set = set(self._target_factors)
                self._other_factors = [i for i in range(self.n_factors_full) if i not in target_set]

                if len(self._other_factors) == 0:
                    raise ValueError(
                        "collapse_others_to_s=True mais aucun 'other factor' (tu as sélectionné tous les facteurs)."
                    )

                self._s_label = len(self._target_factors)       # ex: 6
                self._n_eval_factors = self._s_label + 1         # ex: 7
            else:
                self._other_factors = []
                self._s_label = None
                self._n_eval_factors = len(self._target_factors)
        else:
            self._other_factors = None
            self._s_label = None
            self._n_eval_factors = self.n_factors_full

        # Backward compat: ancien mode "slice" (si collapse_others_to_s=False)
        if self.only_factors and (not self.collapse_others_to_s):
            Y_val = Y_val[:, self.only_factors]
            Y_te  = Y_te[:, self.only_factors]

        self.mus_test = Z_te.T
        self.ys_test = Y_te.T
        if verbose: print("Test data formated.")

        self.mus_val = Z_val.T
        self.ys_val = Y_val.T
        if verbose: print("Val data formated.")

        self.rng = np.random.default_rng(0)

    def get_data_name(self):
        root_path = self.ckpt_path.split("/")[:-2]
        data_name = str(root_path[-6])
        return data_name

    def load_latent(self, stage: str="test", verbose: bool=False):
        latent = LatentDataModule(
            standard=True,
            batch_size=2**19,
            pref_gpu=self.pref_gpu,
            Model_class=Xfactors,
            data_name=self.data_name,
            ckpt_path=self.ckpt_path,
            verbose=verbose
        )

        latent.prepare_data()
        latent.setup(stage)
        latent_loader = latent.test_dataloader() if stage == "test" else latent.val_dataloader()
        batch = next(iter(latent_loader))
        z_s, z_t, label = batch

        # Full names (40)
        self.FACTOR_NAMES_FULL = latent.Data_class.Params.FACTORS_IN_ORDER

        # Names used by the metric output
        if self.only_factors and self.collapse_others_to_s:
            target_names = [self.FACTOR_NAMES_FULL[i] for i in self._target_factors]
            self.FACTOR_NAMES = target_names + ["s"]  # length = K+1
            self._label_to_name = {i: self.FACTOR_NAMES[i] for i in range(len(self.FACTOR_NAMES))}
        elif self.only_factors:
            self.FACTOR_NAMES = list(np.asarray(self.FACTOR_NAMES_FULL, dtype=str)[self.only_factors])
            self._label_to_name = {i: self.FACTOR_NAMES[i] for i in range(len(self.FACTOR_NAMES))}
        else:
            self.FACTOR_NAMES = self.FACTOR_NAMES_FULL
            self._label_to_name = {i: self.FACTOR_NAMES[i] for i in range(len(self.FACTOR_NAMES))}

        if self.is_pca:
            pca_t = PCA(n_components=1)
            pca_s = PCA(n_components=1)
            z_t = pca_t.fit_transform(z_t)
            z_s = pca_s.fit_transform(z_s)
            if not isinstance(z_t, torch.Tensor): z_t = torch.tensor(z_t)
            if not isinstance(z_s, torch.Tensor): z_s = torch.tensor(z_s)

        return z_s, z_t, label

    def value_index(self, ys):
        out = []
        for k in range(ys.shape[0]):
            d = {}
            for v in np.unique(ys[k]):
                d[int(v)] = np.flatnonzero(ys[k] == v)
            out.append(d)
        return out

    def collect(self, mus, ys, n_iter, batch_size=64, verbose=False):
        z_std = mus.std(axis=1, keepdims=True)
        z_std[z_std == 0] = 1.0

        v2i = self.value_index(ys)
        argmins, labels = [], []

        if verbose: print("Starting computing FactorVAE metric.")

        for _ in tqdm(range(n_iter)):
            if self.only_factors and self.collapse_others_to_s:
                # k_eval in {0..K} where K is "s"
                k_eval = self.rng.integers(0, self._n_eval_factors)

                if k_eval == self._s_label:
                    # "s" -> choose one non-target factor
                    k_src = int(self.rng.choice(self._other_factors))
                    label_k = self._s_label
                else:
                    # target k_eval -> map to the original factor index
                    k_src = int(self._target_factors[k_eval])
                    label_k = int(k_eval)

                v = self.rng.choice(list(v2i[k_src].keys()))
                pool = v2i[k_src][v]
                idx = self.rng.choice(pool, size=batch_size, replace=(len(pool) < batch_size))

                Z = mus[:, idx] / z_std
                d = int(Z.var(axis=1).argmin())
                argmins.append(d)
                labels.append(label_k)

            else:
                # original behaviour
                k = self.rng.integers(0, ys.shape[0])
                v = self.rng.choice(list(v2i[k].keys()))
                pool = v2i[k][v]
                idx = self.rng.choice(pool, size=batch_size, replace=(len(pool) < batch_size))

                Z = mus[:, idx] / z_std
                d = int(Z.var(axis=1).argmin())
                argmins.append(d)
                labels.append(k)

        return np.array(argmins), np.array(labels)

    def save(self):
        paths = self.ckpt_path.split("/")
        folder_path = ""
        for k in range(len(paths) - 2):
            folder_path += paths[k] + "/"
        folder_path += "metric"
        print(f"Saving at {folder_path}.")

        try:
            mkdir(folder_path)
        except FileExistsError:
            pass

        scores = {"dim_factor_score": self.dim_factor_score,
                "factor_dim_score": self.factor_dim_score}
        torch.save(scores, join(folder_path, "metric.pt"))

    def get_dicts(self, mu_s, ys, verbose: bool=True):
        argmins, labels = self.collect(mu_s, ys, self.n_iter, self.batch_size, verbose=verbose)
        self.argmins = argmins
        self.labels = labels

        dim_factor_score = defaultdict(list)
        for d in np.unique(argmins):
            dim_factor_score[str(d)] = defaultdict(float)
            cnt = Counter(labels[argmins == d])
            total = sum(cnt.values())
            if verbose: print(f"\nDimension {d}:")
            for k, n in cnt.most_common():
                name = self._label_to_name[int(k)]
                dim_factor_score[str(d)][name] = n / total
                if verbose: print(f"  {name:12s} : {n/total:5.1%}  ({n}/{total})")

        factor_dim_score = defaultdict(list)
        for k in np.unique(labels):
            mask = (labels == k)
            cnt = Counter(argmins[mask])
            total = sum(cnt.values())
            name = self._label_to_name[int(k)]
            factor_dim_score[name] = defaultdict(float)
            if verbose: print(f"\nFacteur {name}:")
            for d, n in cnt.most_common():
                if verbose: print(f"  dim {d:>3} : {n/total:5.1%}  ({n}/{total})")
                factor_dim_score[name][str(d)] = n / total

        self.factor_dim_score = factor_dim_score
        self.dim_factor_score = dim_factor_score

        self.save()
        return factor_dim_score, dim_factor_score

    def compute_map(self):
        dim_factor_score = self.dim_factor_score
        map_dim_factor = {}

        # plus robuste que "last key"
        if len(dim_factor_score) == 0:
            self.map_dim_factor = {}
            return self.map_dim_factor

        max_dim = max(int(k) for k in dim_factor_score.keys() if k.isdigit())
        for dim in range(max_dim + 1):
            factors = dim_factor_score[str(dim)]
            # Si aucune dim n'a été assignée -> default list
            if isinstance(factors, list) or len(factors) == 0:
                first_factor = "s"
            else:
                first_factor = list(factors.keys())[0]
            map_dim_factor[str(dim)] = first_factor

        self.map_dim_factor = map_dim_factor
        return map_dim_factor

    def compute_score(self, verbose=False):
        self.get_dicts(self.mus_val, self.ys_val, verbose=verbose)
        self.compute_map()

        argmins, labels = self.collect(self.mus_test, self.ys_test, self.n_iter, self.batch_size)

        predictions = []
        for argmin in argmins:
            pred_str = self.map_dim_factor[str(argmin)]

            if pred_str == "s":
                if self.only_factors and self.collapse_others_to_s:
                    pred_int = self._s_label
                else:
                    pred_int = -1
            else:
                try:
                    pred_int = self.FACTOR_NAMES.index(pred_str)
                except ValueError:
                    pred_int = -1

            predictions.append(pred_int)

        predictions = np.asarray(predictions)
        return np.sum(predictions == labels) / self.n_iter


##### Compute FactorVAEScore sur un dummy classifier

In [18]:
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["shapes"]["b1"]["-1"], verbose=False, only_factors=[1,2,3])
argmins, labels = xscore.collect(xscore.mus_test, xscore.ys_test, xscore.n_iter, xscore.batch_size, verbose=False)
predictions = []
for argmin in argmins:
    pred_int = np.random.randint(0,3)
    predictions.append(pred_int)

predictions = np.asarray(predictions)
np.sum(predictions == labels)/xscore.n_iter

current device is 0


100%|██████████| 47/47 [00:09<00:00,  5.20it/s]


current device is 0


100%|██████████| 38/38 [00:07<00:00,  5.25it/s]


(76800, 6) (96000, 6)
(76800, 3) (96000, 3)


100%|██████████| 153600/153600 [00:10<00:00, 14025.64it/s]


np.float64(0.3330208333333333)

##### CelebA

In [None]:
celeba_scores = []
celeba_path = ckpt_path_x["126"]["celeba"]
for bt in celeba_path.keys():
    for bs in celeba_path[bt].keys():
        for config in celeba_path[bt][bs].keys():
            path = celeba_path[bt][bs][config]
            if path != "" :
                xscore = FactorVAEScore(ckpt_path=path)
                celeba_scores.append(xscore.compute_score())
print(celeba_scores)

current device is 0


100%|██████████| 10/10 [00:01<00:00,  6.83it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.48it/s]
100%|██████████| 153600/153600 [00:08<00:00, 17319.89it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(1.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:08<00:00, 18874.25it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.82it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  6.25it/s]
100%|██████████| 153600/153600 [00:07<00:00, 19613.11it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:07<00:00, 19737.39it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.45it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.76it/s]
100%|██████████| 153600/153600 [00:07<00:00, 20897.48it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(500.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:07<00:00, 20012.21it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.85it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.96it/s]
100%|██████████| 153600/153600 [00:07<00:00, 20133.27it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(1.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:07<00:00, 20429.28it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.17it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.84it/s]
100%|██████████| 153600/153600 [00:07<00:00, 20956.60it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:06<00:00, 22381.59it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.01it/s]


current device is 0


100%|██████████| 10/10 [00:01<00:00,  7.44it/s]
100%|██████████| 153600/153600 [00:06<00:00, 23967.35it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/celeba/loss_vae_nce/factor=final/batch32/test_dim_s126/x_epoch=65_beta=(500.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=15,20,26,31,35,36/metric.


100%|██████████| 153600/153600 [00:07<00:00, 21755.60it/s]


[np.float64(0.09822916666666667), np.float64(0.09540364583333333), np.float64(0.113671875), np.float64(0.10962239583333333), np.float64(0.11260416666666667), np.float64(0.09844401041666667)]


##### Shapes

In [22]:
shapes_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x_126["shapes"]["bt100"]["bs100"]["n-1"], n_iter=10000)
for _ in range(1) : shapes_scores.append(xscore.compute_score())
np.mean(shapes_scores), np.std(shapes_scores)

current device is 0


100%|██████████| 47/47 [00:12<00:00,  3.78it/s]


current device is 0


100%|██████████| 38/38 [00:09<00:00,  3.91it/s]
100%|██████████| 10000/10000 [00:00<00:00, 11530.22it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/shapes/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 10000/10000 [00:00<00:00, 11652.04it/s]


(np.float64(1.0), np.float64(0.0))

In [4]:
shapes_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["shapes"]["b1_t3"]["s-1"])
for _ in range(5) : shapes_scores.append(xscore.compute_score())
np.mean(shapes_scores), np.std(shapes_scores)

current device is 0


100%|██████████| 47/47 [00:17<00:00,  2.62it/s]


current device is 0


100%|██████████| 38/38 [00:07<00:00,  5.04it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13516.34it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/shapes/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12143.76it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12822.51it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/shapes/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13405.64it/s]
100%|██████████| 153600/153600 [00:10<00:00, 14079.35it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/shapes/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13334.32it/s]
100%|██████████| 153600/153600 [00:10<00:00, 14194.03it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/shapes/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:10<00:00, 14113.96it/s]
100%|██████████| 153600/153600 [00:10<00:00, 14180.46it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/shapes/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13317.66it/s]


(np.float64(0.9999895833333333), np.float64(1.3405768412731297e-05))

##### MPI3D

###### **Beta 1, 1 facteur dans S**

In [9]:
mpi3d_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["mpi3d"]["b1"]["s-1"])
for _ in range(10) : mpi3d_scores.append(xscore.compute_score())
np.mean(mpi3d_scores), np.std(mpi3d_scores)

current device is 0


100%|██████████| 165/165 [00:22<00:00,  7.42it/s]


current device is 0


100%|██████████| 49/49 [00:05<00:00,  9.13it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12738.54it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12243.83it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12876.04it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12564.34it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12496.98it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12878.67it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12590.78it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12416.85it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12741.19it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12446.39it/s]
100%|██████████| 153600/153600 [00:12<00:00, 11988.36it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12559.65it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12084.87it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12696.64it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12212.33it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12097.67it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12354.35it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 11988.57it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12459.35it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12817.08it/s]


(np.float64(0.5419576822916666), np.float64(0.0012758591294291892))

###### **beta 1, 0 facteur dans s**

In [10]:
mpi3d_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["mpi3d"]["b1"]["-1"])
for _ in range(10) : mpi3d_scores.append(xscore.compute_score())
np.mean(mpi3d_scores), np.std(mpi3d_scores)

current device is 0


100%|██████████| 165/165 [00:21<00:00,  7.63it/s]


current device is 0


100%|██████████| 49/49 [00:05<00:00,  8.96it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12407.35it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12727.39it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13162.27it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12648.95it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12857.49it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12785.62it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12284.39it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12787.51it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12053.64it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12762.48it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12859.16it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12435.03it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12643.21it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13007.66it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12892.77it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12360.48it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12848.77it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12285.68it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12620.95it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/mpi3d/loss_vae_nce/factor0,1,2,3,4,5,6/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12645.27it/s]


(np.float64(0.9411783854166667), np.float64(0.00046931931572990836))

###### **beta 100, 1 facteur dans S**

In [11]:
mpi3d_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["mpi3d"]["b100"]["s-1"])
for _ in range(10) : mpi3d_scores.append(xscore.compute_score())
np.mean(mpi3d_scores), np.std(mpi3d_scores)

current device is 0


100%|██████████| 165/165 [00:22<00:00,  7.41it/s]


current device is 0


100%|██████████| 49/49 [00:05<00:00,  8.90it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12495.56it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12416.95it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12526.63it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12608.56it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12882.07it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12612.96it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12506.51it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12236.76it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12847.99it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12385.97it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12291.37it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12686.96it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12625.59it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12309.94it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12427.89it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12921.19it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12588.81it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12759.66it/s]
100%|██████████| 153600/153600 [00:12<00:00, 12735.69it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/mpi3d/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 12464.31it/s]


(np.float64(0.9737011718749999), np.float64(0.00032445463411304447))

In [5]:
mpi3d_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["mpi3d"]["b1_t3"]["s-1"])
for _ in range(5) : mpi3d_scores.append(xscore.compute_score())
np.mean(mpi3d_scores), np.std(mpi3d_scores)

current device is 0


100%|██████████| 165/165 [00:40<00:00,  4.07it/s]


current device is 0


100%|██████████| 49/49 [00:10<00:00,  4.89it/s]
100%|██████████| 153600/153600 [00:12<00:00, 11853.30it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/mpi3d/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:14<00:00, 10327.40it/s]
100%|██████████| 153600/153600 [00:19<00:00, 7955.33it/s] 


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/mpi3d/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:20<00:00, 7467.61it/s]
100%|██████████| 153600/153600 [00:15<00:00, 9679.87it/s] 


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/mpi3d/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12922.28it/s]
100%|██████████| 153600/153600 [00:13<00:00, 11332.45it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/mpi3d/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:14<00:00, 10882.86it/s]
100%|██████████| 153600/153600 [00:16<00:00, 9194.98it/s] 


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1_dim_t3/mpi3d/loss_vae_nce/factor-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,3)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:12<00:00, 11883.66it/s]


(np.float64(0.7970794270833333), np.float64(0.0005261866807132937))

##### DSprites

In [4]:
dsprites_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["dsprites"]["b1"]["-1"])
for _ in range(10): dsprites_scores.append(xscore.compute_score())
np.mean(dsprites_scores), np.std(dsprites_scores)

current device is 0


100%|██████████| 72/72 [00:05<00:00, 14.38it/s]


current device is 0


100%|██████████| 58/58 [00:03<00:00, 14.61it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13034.82it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13312.71it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13264.32it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13560.25it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13251.61it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13005.22it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13610.11it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13215.91it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12914.19it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13858.74it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13368.41it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13769.87it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13448.77it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13714.44it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13593.96it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13429.92it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13623.83it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13447.25it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13724.67it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor0,1,2,3,4/batch32/test_dim_s126/x_epoch=100_beta=(100.0,1.0)_latent=(126,2,2,2,2,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1,0.1,0.1,0.1,0.1+l_anti_nce=0.0_factor=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12976.57it/s]


(np.float64(0.9999166666666666), np.float64(2.2326078384750045e-05))

In [5]:
dsprites_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["dsprites"]["b1"]["s-1"])
for _ in range(10): dsprites_scores.append(xscore.compute_score())
np.mean(dsprites_scores), np.std(dsprites_scores)

current device is 0


100%|██████████| 72/72 [00:04<00:00, 15.26it/s]


current device is 0


100%|██████████| 58/58 [00:03<00:00, 15.24it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13578.04it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 12919.57it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13014.62it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13476.13it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13308.84it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13428.63it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13587.47it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13064.57it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13353.96it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13301.03it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13601.23it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13162.30it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13621.96it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13107.92it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13220.94it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13077.81it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13549.40it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13251.05it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13054.16it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t1/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,1.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13472.51it/s]


(np.float64(0.9963118489583334), np.float64(0.00011400667214899524))

In [6]:
dsprites_scores = []
xscore = FactorVAEScore(ckpt_path=ckpt_path_x["dsprites"]["b100"]["s-1"])
for _ in range(10): dsprites_scores.append(xscore.compute_score())
np.mean(dsprites_scores), np.std(dsprites_scores)

current device is 0


100%|██████████| 72/72 [00:04<00:00, 15.40it/s]


current device is 0


100%|██████████| 58/58 [00:03<00:00, 15.72it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13808.51it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13485.61it/s]
100%|██████████| 153600/153600 [00:11<00:00, 12873.45it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13821.86it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13361.43it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:10<00:00, 13985.87it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13221.74it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13590.84it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13863.80it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13521.63it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13592.29it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13737.88it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13338.73it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13836.12it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13299.07it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13536.43it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13613.80it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13447.53it/s]
100%|██████████| 153600/153600 [00:11<00:00, 13354.82it/s]


Saving at /projects/compures/alexandre/disdiff_adapters/disdiff_adapters/logs/x_with_beta_t100/dsprites/loss_vae_nce/factor_s=-1/batch32/test_dim_s126/x_epoch=70_beta=(100.0,100.0)_latent=(126,2)_batch=32_warm_up=False_lr=1e-05_arch=res+l_cov=0.0+l_nce=0.1+l_anti_nce=0.0_factor=s=-1/metric.


100%|██████████| 153600/153600 [00:11<00:00, 13651.63it/s]


(np.float64(0.9571783854166667), np.float64(0.00044777470194535656))

##### Fill factorVAE metrics

In [3]:
datamodules:dict[str, LightningDataModule] = {"cars3d": Cars3DDataModule, "mpi3d": MPI3DDataModule, "shapes": Shapes3DDataModule, "dsprites": DSpritesDataModule,
                "celeba": CelebADataModule}

In [4]:

def _to_bchw(images: torch.Tensor) -> torch.Tensor:
    """Accept CHW, HWC, BCHW, BHWC -> return BCHW."""
    if images.ndim == 3:
        # CHW
        if images.shape[0] in (1, 3):
            return images.unsqueeze(0)
        # HWC
        if images.shape[-1] in (1, 3):
            return images.permute(2, 0, 1).unsqueeze(0)
        raise ValueError(f"Ambiguous image shape {tuple(images.shape)} (not CHW/HWC).")

    if images.ndim == 4:
        # BCHW
        if images.shape[1] in (1, 3):
            return images
        # BHWC
        if images.shape[-1] in (1, 3):
            return images.permute(0, 3, 1, 2)
        raise ValueError(f"Ambiguous batch shape {tuple(images.shape)} (not BCHW/BHWC).")

    raise ValueError(f"Expected 3D or 4D images, got {images.ndim}D with shape {tuple(images.shape)}")

def _minmax01_per_sample(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Min-max scale each sample independently to [0, 1] over (C,H,W).
    Expects NCHW float tensor.
    """
    x = x.float()
    mn = x.amin(dim=(1, 2, 3), keepdim=True)
    mx = x.amax(dim=(1, 2, 3), keepdim=True)
    denom = (mx - mn).clamp_min(eps)
    return (x - mn) / denom


In [5]:
dataset = "cars3d"
bt = "bt100"
config = "n-1"

for dim in metrics_x.keys():
    datamodule = datamodules[dataset](batch_size=512)
    datamodule.prepare_data()
    datamodule.setup("test")
    images, _ = next(iter(datamodule.test_dataloader()))
    images = torch.as_tensor(images)
    images = _to_bchw(images)
    images = _minmax01_per_sample(images)

    for bs in metrics_x[dim][dataset][bt].keys():
        path = ckpt_path_x[dim][dataset][bt][bs][config]
        try:
            xfactors = Xfactors.load_from_checkpoint(path, map_location=set_device(0)[0])
            images = images.to(device=xfactors.device, dtype=torch.float32)
            recos = []
            for _ in range(5): recos.append(mse(xfactors(images)[2], images).item())
            recos = np.asarray(recos)
            metrics_x[dim][dataset][bt][bs][config]["reco"] = (float(recos.mean()), float(recos.std()))
        except RuntimeError as e: print(e)

tensors loaded.
current device is 0
current device is 0
current device is 0
tensors loaded.
current device is 0
current device is 0
current device is 0
tensors loaded.
current device is 0
current device is 0
current device is 0


In [6]:
metrics_x["126"]["cars3d"]["bt100"]

{'bs1': {'all': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
  'n-1': {'d': [0.7528651251490377, 0.005809209758220596],
   'c': [0.45040158579747597, 0.0036529475067087333],
   'i': [0.6199209546131046, 0.00447241092237733],
   'score': [0.9587, 0.004162210950925002],
   'reco': (0.0013923314400017262, 3.024846584284574e-06)}},
 'bs100': {'all': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
  'n-1': {'d': [0.8024727225742202, 0.004026886487319363],
   'c': [0.5048565594911556, 0.0024893874505523337],
   'i': [0.6947804317866468, 0.005239697805867635],
   'score': [0.9482799999999999, 0.0022337412562783636],
   'reco': (0.004014034382998943, 2.6999283364682603e-05)}},
 'bs500': {'all': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
  'n-1': {'d': [0.7439692773141531, 0.006368656107821697],
   'c': [0.5167055553186648, 0.002183875940072099],
   'i': [0.7072750138906334, 0.003524006324679166],
   'score': [0.9376200000000001, 0.0025701361831622915],
   'reco': (0.005830490402877331, 5.334818791475317

In [None]:
for dim in metrics_x.keys():
    if dim == "126" : continue
    print(dim)
    dataset = "celeba"
    for bt in metrics_x[dim][dataset].keys():
        for bs in metrics_x[dim][dataset][bt].keys():
            for config in metrics_x[dim][dataset][bt][bs].keys():
                path = ckpt_path_x[dim][dataset][bt][bs][config]
                if path == "" : continue
                score = FactorVAEScoreCelebA(path, only_factors=CelebA.Params.REPRESENTANT_IDX, collapse_others_to_s=True,
                                            n_iter=10000, pref_gpu=1)
                acc = score.compute_score()
                print(acc)
                metrics_x[dim][dataset][bt][bs][config]["score"] = acc

{'2': {'shapes': {'bt1': {'bs1': {'n-1': {'d': -1,
      'c': -1,
      'i': -1,
      'score': -1},
     'all': {'d': 0.9999999999216111,
      'c': 0.7979388329220564,
      'i': 0.9999995620256041,
      'score': 0.9999,
      'reco': [6.752953922841697e-05, 3.749390290243857e-07]}},
    'bs100': {'n-1': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
     'all': {'d': 0.9999999999216111,
      'c': 0.7982376982048585,
      'i': 0.9999918959642002,
      'score': 1.0,
      'reco': [8.259454916696996e-05, 1.2211433208978908e-07]}},
    'bs500': {'n-1': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
     'all': {'d': 0.9999999999216113,
      'c': 0.7700243694833294,
      'i': 0.9999996801270391,
      'score': 1.0,
      'reco': [9.813048382056877e-05, 1.1454008040782146e-07]}}},
   'bt100': {'bs1': {'all': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
     'n-1': {'d': -1, 'c': -1, 'i': -1, 'score': -1}},
    'bs100': {'all': {'d': -1, 'c': -1, 'i': -1, 'score': -1},
     'n-1': {'d': -1, 'c': 

In [8]:
with open("metrics_final_0.json", "w") as f:
    json.dump(metrics_x, f, ensure_ascii=False, indent=2)

### MI Methods

In [11]:
# --- Ross MI estimator (discrete Y vs continuous X) applied to blocks S, T, [S,T] ---
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from scipy.special import digamma, gammaln

def to_np(x): return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

# ------- Differential entropy via Kozachenko–Leonenko (k-NN) -------
def kl_entropy_knn(X, k=5, metric='chebyshev'):
    """
    H(X) ≈ psi(N) - psi(k) + ln(c_d) + (d/N)*sum ln(eps_i)
    - X: (N, d), k>=1
    - metric: 'chebyshev' (∞-norm) ou 'euclidean' (au choix, constant factor cancels in MI)
    Retourne H en nats.
    """
    X = np.asarray(X, dtype=np.float64)
    N, d = X.shape
    # k-NN (k+1 car le plus proche est lui-même)
    nbrs = NearestNeighbors(n_neighbors=k+1, metric=metric)
    nbrs.fit(X)
    # distances au (k+1)-ème voisin (index 0 = soi-même)
    dists, _ = nbrs.kneighbors(X)
    eps = dists[:, -1]  # rayon au k-ième voisin
    # volume de la boule unité (∞-norm ou 2-norm) -> constante; s'annule dans MI, mais on le laisse pour propreté
    if metric == 'chebyshev':
        # volume de la boule L∞ unité = (2)^d
        log_c_d = d * np.log(2.0)
    else:
        # L2: V_d = pi^{d/2} / Gamma(d/2+1)
        log_c_d = (d/2.0)*np.log(np.pi) - gammaln(d/2.0 + 1.0)
    H = digamma(N) - digamma(k) + log_c_d + (d / N) * np.sum(np.log(eps + 1e-12))
    return float(H)

# ------- Ross MI for discrete Y and continuous X -------
def ross_mi_continuous_discrete(X, y, k=5, metric='chebyshev'):
    """
    I(X;Y) = H(X) - sum_y p(y) H(X|Y=y)
    - X: (N,d), y: (N,) entiers
    Retourne (MI_nats, H_y_nats, NMI = MI/H_y).
    """
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.int64).ravel()
    N = X.shape[0]
    Hx = kl_entropy_knn(X, k=k, metric=metric)
    # H(X|Y)
    Hx_given = 0.0
    Hy = 0.0  # H(Y) pour NMI
    vals, counts = np.unique(y, return_counts=True)
    for v, cnt in zip(vals, counts):
        p = cnt / N
        Xv = X[y == v]
        if len(Xv) <= k:
            # pas assez de points dans cette classe -> fallback k' = max(1, len(Xv)-1)
            kv = max(1, len(Xv) - 1)
        else:
            kv = k
        Hx_given += p * kl_entropy_knn(Xv, k=kv, metric=metric)
        Hy += -p * np.log(p + 1e-12)
    I = max(0.0, Hx - Hx_given)  # clamp numérique
    NMI = I / max(Hy, 1e-12)
    return I, Hy, NMI  # nats

# ================== APPLY TO YOUR BLOCKS ==================
Zs = to_np(z_s)        # (N,126)
Zt = to_np(z_t)        # (N,2)

Y  = to_np(label)      # (N,6)  entiers 0..card-1
N, Ds, Dt, K = Zs.shape[0], Zs.shape[1], Zt.shape[1], Y.shape[1]

# Standardisation par bloc (recommandé pour k-NN)
def standardize(X):
    X = np.asarray(X, dtype=np.float64)
    mu = X.mean(axis=0, keepdims=True)
    sd = X.std(axis=0, keepdims=True) + 1e-12
    return (X - mu) / sd

Xs  = standardize(Zs)
Xt  = standardize(Zt)
Xst = standardize(np.concatenate([Zs, Zt], axis=1))

FACTOR_NAMES = ["floor_hue","wall_hue","object_hue","scale","shape","orientation"]  # adapte l'ordre si besoin

for k_nn in [3, 5, 10] :
#k_nn = 5
    metric = 'chebyshev'   # classique pour KL (∞-norm). 'euclidean' marche aussi (constante annulée).

    rows = []
    for k_idx, name in enumerate(FACTOR_NAMES):
        yk = Y[:, k_idx]
        Is, Hy, Ns = ross_mi_continuous_discrete(Xs,  yk, k=k_nn, metric=metric)
        It, _,  Nt = ross_mi_continuous_discrete(Xt,  yk, k=k_nn, metric=metric)
        Ist,_, Nst = ross_mi_continuous_discrete(Xst, yk, k=k_nn, metric=metric)
        # conversion optionnelle en bits: / np.log(2)
        rows.append((name, int(yk.max()+1), Is, It, Ist, Hy, Ns, Nt, Nst))

    # Affichage
    w = max(len(n) for n, *_ in rows)
    #print(f"{'factor':{w}s}  #cls   I(S;Y)   I(T;Y)  I([S,T];Y)   H(Y)   NMI_S  NMI_T  NMI_ST   (nats)")
    print(f"{'factor':{w}s}  NMI_S  NMI_T (nats)")
    for name, ncls, Is, It, Ist, Hy, Ns, Nt, Nst in rows:
        #print(f"{name:{w}s}  {ncls:>4d}  {Is:7.3f}  {It:7.3f}     {Ist:7.3f}   {Hy:5.3f}   {Ns:6.3f} {Nt:6.3f}  {Nst:6.3f}")
        print(f"{name:{w}s}  {Ns:6.3f} {Nt:6.3f}  ")
    print("\n\n")


factor       NMI_S  NMI_T (nats)
floor_hue     0.000  1.000  
wall_hue      0.661  0.224  
object_hue    0.591  0.079  
scale         0.771  0.129  
shape         0.992  0.109  
orientation   0.308  0.263  



factor       NMI_S  NMI_T (nats)
floor_hue     0.000  1.000  
wall_hue      0.525  0.219  
object_hue    0.461  0.073  
scale         0.665  0.128  
shape         0.986  0.103  
orientation   0.200  0.256  



factor       NMI_S  NMI_T (nats)
floor_hue     0.000  1.000  
wall_hue      0.364  0.208  
object_hue    0.322  0.061  
scale         0.529  0.119  
shape         0.965  0.093  
orientation   0.054  0.244  





In [7]:
def whiten_pca(X, n_components=None, zca=False):
    """
    PCA / ZCA whitening.
    - X : (N,d)
    - n_components : None -> garde tout
    - zca=False : PCA; True : ZCA
    Retourne X_whitened de même shape.
    """
    X = np.asarray(X, np.float64)
    mu = X.mean(0, keepdims=True)
    Xc = X - mu
    pca = PCA(n_components=n_components).fit(Xc)
    Xw = pca.transform(Xc)
    var_ratio = pca.explained_variance_ratio_[: n_components]
    cum_var = np.sum(var_ratio)
    print(cum_var)
    if zca:
        Xw = (Xw @ pca.components_)  # revient dans l’espace original
    return Xw
Zs_white = whiten_pca(Zs, n_components=2)   # ou zca=True si tu veux garder la base originale
print(Zs_white.shape)

0.2514179840753171
(32768, 2)


In [8]:
for k_factor in range(6):
    Is, Hy, NMs = ross_mi_continuous_discrete(Zs_white, label[:, k_factor], k=10, metric='chebyshev')
    print(NMs)

0.19698376967542
0.25606322803533016
0.35356016589290995
0.012561051335177377
0.030724644430047774
0.010450275756023971


In [6]:
for k_nn in [5,15,30,50, 500, 1000]:
    Is, Hy, NMs = ross_mi_continuous_discrete(Zs_white, label[:,k_factor], k=k_nn)
    print(f"k={k_nn:2d} -> NMI={NMs:.3f}")


k= 5 -> NMI=0.008
k=15 -> NMI=0.011
k=30 -> NMI=0.013
k=50 -> NMI=0.013
k=500 -> NMI=0.008
k=1000 -> NMI=0.000


### DCI

#### Automatisation

In [12]:
class DCIscore :

    def __init__(self, ckpt_path: str, is_pca: bool=False, n_samples_tr=None, n_samples_te=None) :
        self.ckpt_path = ckpt_path
        self.data_name =  self.get_data_name()
        self.is_pca = is_pca

        z_s_tr, z_t_tr, label_tr, z_s_te, z_t_te, label_te = self.load_latent()   

        z_tr = torch.cat([z_s_tr, z_t_tr], dim=1).cpu().numpy()        
        self.z_tr = (z_tr - z_tr.mean(axis=0, keepdims=True)) / (z_tr.std(axis=0, keepdims=True) + 1e-8)
        self.y_tr = label_tr.cpu().numpy().astype(np.int64)       
        self.n_samples_tr = self.y_tr.shape[0] if n_samples_te is None else n_samples_tr                                        

        z_te = torch.cat([z_s_te, z_t_te], dim=1).cpu().numpy()        
        self.z_te = (z_te - z_te.mean(axis=0, keepdims=True)) / (z_te.std(axis=0, keepdims=True) + 1e-8)
        self.y_te = label_te.cpu().numpy().astype(np.int64)
        self.n_samples_te = self.y_te.shape[0] if n_samples_te is None else n_samples_te    
        
    def load_latent(self):
        latent = LatentDataModule(standard=True, 
                                batch_size=2**19,
                                Model_class=Xfactors,
                                pref_gpu=0,
                                data_name=self.data_name,
                                ckpt_path=self.ckpt_path)
        
        latent.prepare_data()
        latent.setup("val")
        latent_loader = latent.val_dataloader()
        batch = next(iter(latent_loader))
        z_s_tr, z_t_tr, label_tr = batch
        print(z_s_tr.shape, z_t_tr.shape, label_tr.shape)

        latent.setup("test")
        latent_test_loader = latent.test_dataloader()
        batch = next(iter(latent_test_loader))
        z_s_te, z_t_te, label_te = batch
        print(z_s_te.shape, z_t_te.shape, label_te.shape)

        self.FACTOR_NAMES = latent.Data_class.Params.FACTORS_IN_ORDER

        if self.is_pca:
            #train
            pca_t_tr = PCA(n_components=1) 
            pca_s_tr = PCA(n_components=1)
            z_t_tr = pca_t_tr.fit_transform(z_t_tr)
            z_s_tr = pca_s_tr.fit_transform(z_s_tr)
            print(z_s_tr.shape, z_t_tr.shape)
            if not isinstance(z_t_tr, torch.Tensor) : z_t_tr = torch.tensor(z_t_tr)
            if not isinstance(z_s_tr, torch.Tensor) : z_s_tr = torch.tensor(z_s_tr)

            #test
            pca_t_te = PCA(n_components=1) 
            pca_s_te = PCA(n_components=1)
            z_t_te = pca_t_tr.fit_transform(z_t_te)
            z_s_te = pca_s_tr.fit_transform(z_s_te)
            print(z_s_te.shape, z_t_te.shape)
            if not isinstance(z_t_te, torch.Tensor) : z_t_te = torch.tensor(z_t_te)
            if not isinstance(z_s_te, torch.Tensor) : z_s_te = torch.tensor(z_s_te)
        
        return z_s_tr, z_t_tr, label_tr, z_s_te, z_t_te, label_te

    def get_data_name(self):
        root_path = self.ckpt_path.split("/")[:-2]
        data_name = str(root_path[-6])
        return data_name

    def train_reg(self) :
        regressors = {str(k):{} for k in range(self.y_tr.shape[1])}
        for k in tqdm(range(self.y_tr.shape[1])):

            perm = torch.randperm(len(self.z_tr))
            reg_k = RandomForestRegressor(n_estimators=20, 
                                max_depth=20,
                                n_jobs=-1)
            reg_k.fit(self.z_tr[perm][:self.n_samples_tr], self.y_tr[:, k][perm][:self.n_samples_tr])

            perm = torch.randperm(len(self.z_tr))
            score_tr = reg_k.score(self.z_tr[perm][:self.n_samples_tr], self.y_tr[:, k][perm][:self.n_samples_tr])
            perm = torch.randperm(len(self.z_te))
            score_te = reg_k.score(self.z_te[perm][:self.n_samples_te], self.y_te[:, k][perm][:self.n_samples_te])

            print(f"Reg_{k} score={score_tr}, {score_te}")
            regressors[str(k)]["model"] = reg_k
            regressors[str(k)]["score_tr"] = score_tr
            regressors[str(k)]["score_te"] = score_te
        self.regressors = regressors
    
    def compute_weights(self):
        # 1) Récupérer la matrice d'importances R (D, K) à partir de tes régressions
        D = self.z_tr.shape[1]
        K = self.y_tr.shape[1]

        R = np.zeros((D, K), dtype=float)  # feature importances pour chaque facteur k
        for k in range(K):
            model = self.regressors[str(k)]["model"]
            imp = getattr(model, "feature_importances_", None)
            if imp is None:
                raise ValueError(f"Aucune feature_importances_ pour k={k}")
            if len(imp) != D:
                raise ValueError(f"Dim mismatch: len(imp)={len(imp)} vs D={D}")
            R[:, k] = imp

        # 2) Normaliser par colonne -> P(d | k): "où vit le facteur k ?"
        col_sum = R.sum(axis=0, keepdims=True)               # (1, K)
        col_sum[col_sum == 0] = 1.0                          # éviter /0 si colonne nulle
        self.P_d_given_k = R / col_sum                             # (D, K)

        # 3) (optionnel) Normaliser par ligne -> P(k | d): "quel facteur porte la dimension d ?"
        row_sum = R.sum(axis=1, keepdims=True)               # (D, 1)
        row_sum[row_sum == 0] = 1.0
        self.P_k_given_d = R / row_sum                             # (D, K)

    def print_top_dims_per_factor(self, P, top=5, factor_names=None):
        D, K = P.shape
        if factor_names is None:
            factor_names = [f"f{k}" for k in range(K)]
        for k in range(K):
            order = np.argsort(-P[:, k])
            print(f"\nFacteur {factor_names[k]}:")
            for d in order[:top]:
                print(f"  dim {d:3d} : {P[d, k]:.3f}")

    def dci_scores(self, eps=1e-12):
        """
        P_kd : array (K, D) d'importances non-négatives (p.ex. permutation importance, gain, ou P(d|k)).
            Pas besoin d'être normalisé: on renormalisera correctement pour D et C.
        r2_per_factor : iterable de longueur K avec les R^2 (test) pour l'informativeness (optionnel).

        Retourne: D, C, I (I peut être None si r2_per_factor est None).
        """
        r2_per_factor=[reg["score_te"] for reg in self.regressors.values()]

        # self.P_d_given_k est (D, K) => on transpose pour avoir (K, D)
        P_kd = self.P_d_given_k.T        # (K, D)
        R = np.asarray(P_kd, dtype=float)
        K, D = R.shape

        R = np.clip(R, 0.0, None)
        if R.sum() == 0:
            raise ValueError("La matrice d'importances est nulle.")

        # Poids des dimensions et des facteurs (somme des importances)
        w_d = R.sum(axis=0)  # (D,)
        w_k = R.sum(axis=1)  # (K,)
        w_d = w_d / (w_d.sum() + eps)
        w_k = w_k / (w_k.sum() + eps)

        # p(k|d) : normalisation par colonnes ; p(d|k) : normalisation par lignes
        P_k_given_d = R / (R.sum(axis=0, keepdims=True) + eps)   # (K, D)
        P_d_given_k = R / (R.sum(axis=1, keepdims=True) + eps)   # (K, D)

        # Entropies
        def entropy(p, axis):
            p = np.clip(p, eps, 1.0)
            return -(p * np.log(p)).sum(axis=axis)

        H_k_given_d = entropy(P_k_given_d, axis=0)               # (D,)
        H_d_given_k = entropy(P_d_given_k, axis=1)               # (K,)

        # Disentanglement et completeness (entropies normalisées)
        D_score = float(((1.0 - H_k_given_d / (np.log(K) + eps)) * w_d).sum())
        C_score = float(((1.0 - H_d_given_k / (np.log(D) + eps)) * w_k).sum())

        # Informativeness = moyenne des R^2 test (tronqués à [0,1])
        I_score = None
        if r2_per_factor is not None:
            r2 = np.asarray(r2_per_factor, dtype=float)
            I_score = float(np.clip(r2, 0.0, 1.0).mean())

        return D_score, C_score, I_score

    def compute(self):
        self.train_reg()
        self.compute_weights()
        D, C, I = self.dci_scores()
        print(f"D:{D}")
        print(f"C={C}")
        print(f"I={I}")
        return D,C,I

In [6]:
class DCIscore:

    def __init__(self,
                ckpt_path: str,
                is_pca: bool = False,
                n_samples_tr=None,
                n_samples_te=None,
                only_factors: list[int] = [],
                collapse_others_to_s: bool = True,
                pref_gpu: int = 0,
                seed: int = 0):

        self.ckpt_path = ckpt_path
        self.data_name = self.get_data_name()
        self.is_pca = is_pca
        self.pref_gpu = pref_gpu

        self.only_factors = list(only_factors) if only_factors else []
        self.collapse_others_to_s = collapse_others_to_s

        # dedupe en gardant l'ordre
        self._target_factors = list(dict.fromkeys(self.only_factors))
        self.only_factors = self._target_factors

        self.rng = np.random.default_rng(seed)

        z_s_tr, z_t_tr, label_tr, z_s_te, z_t_te, label_te = self.load_latent()

        z_tr = torch.cat([z_s_tr, z_t_tr], dim=1).cpu().numpy()
        self.z_tr = (z_tr - z_tr.mean(axis=0, keepdims=True)) / (z_tr.std(axis=0, keepdims=True) + 1e-8)
        self.y_tr_full = label_tr.cpu().numpy().astype(np.int64)   # (N, K_full)

        z_te = torch.cat([z_s_te, z_t_te], dim=1).cpu().numpy()
        self.z_te = (z_te - z_te.mean(axis=0, keepdims=True)) / (z_te.std(axis=0, keepdims=True) + 1e-8)
        self.y_te_full = label_te.cpu().numpy().astype(np.int64)   # (N, K_full)

        # Samples
        self.n_samples_tr = self.y_tr_full.shape[0] if n_samples_tr is None else int(n_samples_tr)
        self.n_samples_te = self.y_te_full.shape[0] if n_samples_te is None else int(n_samples_te)

        # --- config facteurs évalués ---
        self.n_factors_full = self.y_tr_full.shape[1]

        if self.only_factors:
            assert max(self.only_factors) < self.n_factors_full, "Index de facteur hors-borne."

            if self.collapse_others_to_s:
                target_set = set(self._target_factors)
                self._other_factors = [i for i in range(self.n_factors_full) if i not in target_set]
                if len(self._other_factors) == 0:
                    raise ValueError("collapse_others_to_s=True mais aucun autre facteur (tu as tout sélectionné).")

                self._s_label = len(self._target_factors)          # ex: 6
                self._n_eval_factors = self._s_label + 1            # ex: 7

                # noms: 6 cibles + s
                target_names = [self.FACTOR_NAMES_FULL[i] for i in self._target_factors]
                self.FACTOR_NAMES = target_names + ["s"]
            else:
                self._other_factors = []
                self._s_label = None
                self._n_eval_factors = len(self._target_factors)

                self.FACTOR_NAMES = [self.FACTOR_NAMES_FULL[i] for i in self._target_factors]
        else:
            self._other_factors = None
            self._s_label = None
            self._n_eval_factors = self.n_factors_full
            self.FACTOR_NAMES = self.FACTOR_NAMES_FULL

    def load_latent(self):
        latent = LatentDataModule(
            standard=True,
            batch_size=2**19,
            Model_class=Xfactors,
            pref_gpu=self.pref_gpu,
            data_name=self.data_name,
            ckpt_path=self.ckpt_path
        )

        latent.prepare_data()

        latent.setup("val")
        batch = next(iter(latent.val_dataloader()))
        z_s_tr, z_t_tr, label_tr = batch

        latent.setup("test")
        batch = next(iter(latent.test_dataloader()))
        z_s_te, z_t_te, label_te = batch

        self.FACTOR_NAMES_FULL = latent.Data_class.Params.FACTORS_IN_ORDER

        if self.is_pca:
            # train fit
            pca_t = PCA(n_components=1)
            pca_s = PCA(n_components=1)
            z_t_tr_np = pca_t.fit_transform(z_t_tr)
            z_s_tr_np = pca_s.fit_transform(z_s_tr)

            # test transform (PAS fit_transform)
            z_t_te_np = pca_t.transform(z_t_te)
            z_s_te_np = pca_s.transform(z_s_te)

            z_t_tr = torch.tensor(z_t_tr_np)
            z_s_tr = torch.tensor(z_s_tr_np)
            z_t_te = torch.tensor(z_t_te_np)
            z_s_te = torch.tensor(z_s_te_np)

        return z_s_tr, z_t_tr, label_tr, z_s_te, z_t_te, label_te

    def get_data_name(self):
        root_path = self.ckpt_path.split("/")[:-2]
        return str(root_path[-6])

    def train_reg(self):
        """
        En mode collapse_others_to_s=True :
        - k_eval=0..K-1 : régression 1D sur facteur cible
        - k_eval=K      : régression multi-output sur tous les autres facteurs (le bucket s)
        """
        regressors = {str(k): {} for k in range(self._n_eval_factors)}

        for k_eval in tqdm(range(self._n_eval_factors)):
            reg = RandomForestRegressor(
                n_estimators=20,
                max_depth=20,
                n_jobs=-1
            )

            # --------- définir y_tr / y_te pour ce "facteur évalué" ----------
            if self.only_factors and self.collapse_others_to_s:
                if k_eval == self._s_label:
                    # s = tous les autres facteurs (multi-output)
                    y_tr = self.y_tr_full[:, self._other_factors]  # (N, K_other)
                    y_te = self.y_te_full[:, self._other_factors]
                else:
                    k_src = self._target_factors[k_eval]
                    y_tr = self.y_tr_full[:, k_src]               # (N,)
                    y_te = self.y_te_full[:, k_src]
            elif self.only_factors:
                # ancien mode : uniquement les facteurs sélectionnés, pas de "s"
                k_src = self._target_factors[k_eval]
                y_tr = self.y_tr_full[:, k_src]
                y_te = self.y_te_full[:, k_src]
            else:
                # full factors
                y_tr = self.y_tr_full[:, k_eval]
                y_te = self.y_te_full[:, k_eval]

            # --------- subsample ----------
            perm_tr = self.rng.permutation(len(self.z_tr))[:self.n_samples_tr]
            perm_te = self.rng.permutation(len(self.z_te))[:self.n_samples_te]

            reg.fit(self.z_tr[perm_tr], y_tr[perm_tr])

            score_tr = reg.score(self.z_tr[perm_tr], y_tr[perm_tr])
            score_te = reg.score(self.z_te[perm_te], y_te[perm_te])

            name = self.FACTOR_NAMES[k_eval] if k_eval < len(self.FACTOR_NAMES) else str(k_eval)
            print(f"Reg[{k_eval}] ({name}) score_tr={score_tr:.4f}, score_te={score_te:.4f}")

            regressors[str(k_eval)]["model"] = reg
            regressors[str(k_eval)]["score_tr"] = float(score_tr)
            regressors[str(k_eval)]["score_te"] = float(score_te)

        self.regressors = regressors

    def compute_weights(self):
        # D dims latentes, K facteurs évalués (incluant s si activé)
        D = self.z_tr.shape[1]
        K = self._n_eval_factors

        R = np.zeros((D, K), dtype=float)
        for k in range(K):
            model = self.regressors[str(k)]["model"]
            imp = getattr(model, "feature_importances_", None)
            if imp is None:
                raise ValueError(f"Aucune feature_importances_ pour k={k}")
            if len(imp) != D:
                raise ValueError(f"Dim mismatch: len(imp)={len(imp)} vs D={D}")
            R[:, k] = imp

        col_sum = R.sum(axis=0, keepdims=True)
        col_sum[col_sum == 0] = 1.0
        self.P_d_given_k = R / col_sum            # (D, K)

        row_sum = R.sum(axis=1, keepdims=True)
        row_sum[row_sum == 0] = 1.0
        self.P_k_given_d = R / row_sum            # (D, K)

    def dci_scores(self, eps=1e-12):
        r2_per_factor = [reg["score_te"] for reg in self.regressors.values()]

        # self.P_d_given_k est (D, K) => (K, D)
        R = np.asarray(self.P_d_given_k.T, dtype=float)   # (K, D)
        K, D = R.shape

        R = np.clip(R, 0.0, None)
        if R.sum() == 0:
            raise ValueError("La matrice d'importances est nulle.")

        w_d = R.sum(axis=0)
        w_k = R.sum(axis=1)
        w_d = w_d / (w_d.sum() + eps)
        w_k = w_k / (w_k.sum() + eps)

        P_k_given_d = R / (R.sum(axis=0, keepdims=True) + eps)   # (K, D)
        P_d_given_k = R / (R.sum(axis=1, keepdims=True) + eps)   # (K, D)

        def entropy(p, axis):
            p = np.clip(p, eps, 1.0)
            return -(p * np.log(p)).sum(axis=axis)

        H_k_given_d = entropy(P_k_given_d, axis=0)  # (D,)
        H_d_given_k = entropy(P_d_given_k, axis=1)  # (K,)

        D_score = float(((1.0 - H_k_given_d / (np.log(K) + eps)) * w_d).sum())
        C_score = float(((1.0 - H_d_given_k / (np.log(D) + eps)) * w_k).sum())

        I_score = None
        if r2_per_factor is not None:
            r2 = np.asarray(r2_per_factor, dtype=float)
            I_score = float(np.clip(r2, 0.0, 1.0).mean())

        return D_score, C_score, I_score

    def compute(self):
        self.train_reg()
        self.compute_weights()
        D, C, I = self.dci_scores()
        print(f"D={D}")
        print(f"C={C}")
        print(f"I={I}")
        return D, C, I


In [7]:
for dim in metrics_x.keys():
    dataset = "celeba"
    if dim == "126" : continue
    for bt in metrics_x[dim][dataset].keys():
        for bs in metrics_x[dim][dataset][bt].keys():
            for config in metrics_x[dim][dataset][bt][bs].keys():
                path = ckpt_path_x[dim][dataset][bt][bs][config]
                if path == "" : continue
                score = DCIscore(path, only_factors=CelebA.Params.REPRESENTANT_IDX, pref_gpu=1)
                d,c,i = score.compute()
                print(d,c,i)

                metrics_x[dim][dataset][bt][bs][config]["d"] = d
                metrics_x[dim][dataset][bt][bs][config]["c"] = c
                metrics_x[dim][dataset][bt][bs][config]["i"] = i

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.23it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.25it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.58it/s]

Reg[0] (Eyeglasses) score_tr=0.9872, score_te=0.9474


 29%|██▊       | 2/7 [00:01<00:02,  1.75it/s]

Reg[1] (Male) score_tr=0.9871, score_te=0.9506


 43%|████▎     | 3/7 [00:01<00:02,  1.84it/s]

Reg[2] (Pale_Skin) score_tr=0.9019, score_te=0.6544


 57%|█████▋    | 4/7 [00:02<00:01,  1.80it/s]

Reg[3] (Smiling) score_tr=0.9543, score_te=0.8457


 71%|███████▏  | 5/7 [00:02<00:01,  1.86it/s]

Reg[4] (Wearing_Hat) score_tr=0.9611, score_te=0.8529


 86%|████████▌ | 6/7 [00:03<00:00,  1.89it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9643, score_te=0.8787


100%|██████████| 7/7 [00:04<00:00,  1.55it/s]

Reg[6] (s) score_tr=0.8327, score_te=0.5080





D=0.5544760969787063
C=0.5209895933802732
I=0.8053841595179174
0.5544760969787063 0.5209895933802732 0.8053841595179174
current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.19it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.24it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.99it/s]

Reg[0] (Eyeglasses) score_tr=0.9857, score_te=0.9456


 29%|██▊       | 2/7 [00:00<00:02,  2.07it/s]

Reg[1] (Male) score_tr=0.9869, score_te=0.9478


 43%|████▎     | 3/7 [00:01<00:01,  2.22it/s]

Reg[2] (Pale_Skin) score_tr=0.9025, score_te=0.6452


 57%|█████▋    | 4/7 [00:01<00:01,  2.15it/s]

Reg[3] (Smiling) score_tr=0.9570, score_te=0.8484


 71%|███████▏  | 5/7 [00:02<00:01,  1.91it/s]

Reg[4] (Wearing_Hat) score_tr=0.9611, score_te=0.8533


 86%|████████▌ | 6/7 [00:03<00:00,  1.85it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9652, score_te=0.8789


100%|██████████| 7/7 [00:04<00:00,  1.60it/s]

Reg[6] (s) score_tr=0.8386, score_te=0.5147
D=0.5481122301506208
C=0.5129237665706362
I=0.8048558139810675
0.5481122301506208 0.5129237665706362 0.8048558139810675
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.25it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.24it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.88it/s]

Reg[0] (Eyeglasses) score_tr=0.9863, score_te=0.9510


 29%|██▊       | 2/7 [00:01<00:02,  1.81it/s]

Reg[1] (Male) score_tr=0.9877, score_te=0.9484


 43%|████▎     | 3/7 [00:01<00:02,  1.93it/s]

Reg[2] (Pale_Skin) score_tr=0.9031, score_te=0.6445


 57%|█████▋    | 4/7 [00:02<00:01,  1.90it/s]

Reg[3] (Smiling) score_tr=0.9540, score_te=0.8467


 71%|███████▏  | 5/7 [00:02<00:01,  1.94it/s]

Reg[4] (Wearing_Hat) score_tr=0.9640, score_te=0.8618


 86%|████████▌ | 6/7 [00:03<00:00,  1.97it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9671, score_te=0.8775


100%|██████████| 7/7 [00:04<00:00,  1.70it/s]

Reg[6] (s) score_tr=0.8370, score_te=0.5092
D=0.5351338199543588
C=0.48471060467134297
I=0.8055802183405696
0.5351338199543588 0.48471060467134297 0.8055802183405696
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.26it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.22it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.75it/s]

Reg[0] (Eyeglasses) score_tr=0.9867, score_te=0.9455


 29%|██▊       | 2/7 [00:01<00:02,  2.00it/s]

Reg[1] (Male) score_tr=0.9850, score_te=0.9479


 43%|████▎     | 3/7 [00:01<00:01,  2.03it/s]

Reg[2] (Pale_Skin) score_tr=0.9015, score_te=0.6394


 57%|█████▋    | 4/7 [00:02<00:01,  1.93it/s]

Reg[3] (Smiling) score_tr=0.9576, score_te=0.8550


 71%|███████▏  | 5/7 [00:02<00:01,  1.93it/s]

Reg[4] (Wearing_Hat) score_tr=0.9633, score_te=0.8585


 86%|████████▌ | 6/7 [00:03<00:00,  1.93it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9632, score_te=0.8824


100%|██████████| 7/7 [00:04<00:00,  1.61it/s]

Reg[6] (s) score_tr=0.8434, score_te=0.5249
D=0.5541727372608725
C=0.5268571266025783
I=0.8076743960216507
0.5541727372608725 0.5268571266025783 0.8076743960216507
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.22it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.20it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:02,  2.22it/s]

Reg[0] (Eyeglasses) score_tr=0.9863, score_te=0.9498


 29%|██▊       | 2/7 [00:00<00:02,  2.28it/s]

Reg[1] (Male) score_tr=0.9870, score_te=0.9523


 43%|████▎     | 3/7 [00:01<00:01,  2.11it/s]

Reg[2] (Pale_Skin) score_tr=0.9031, score_te=0.6468


 57%|█████▋    | 4/7 [00:01<00:01,  1.92it/s]

Reg[3] (Smiling) score_tr=0.9575, score_te=0.8558


 71%|███████▏  | 5/7 [00:02<00:01,  1.94it/s]

Reg[4] (Wearing_Hat) score_tr=0.9608, score_te=0.8470


 86%|████████▌ | 6/7 [00:03<00:00,  1.94it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9645, score_te=0.8883


100%|██████████| 7/7 [00:04<00:00,  1.63it/s]

Reg[6] (s) score_tr=0.8385, score_te=0.5206
D=0.548381624350247
C=0.5084466887442868
I=0.8086724311963515
0.548381624350247 0.5084466887442868 0.8086724311963515
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.23it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.20it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.78it/s]

Reg[0] (Eyeglasses) score_tr=0.9883, score_te=0.9511


 29%|██▊       | 2/7 [00:01<00:02,  1.85it/s]

Reg[1] (Male) score_tr=0.9872, score_te=0.9530


 43%|████▎     | 3/7 [00:01<00:02,  1.95it/s]

Reg[2] (Pale_Skin) score_tr=0.9085, score_te=0.6494


 57%|█████▋    | 4/7 [00:02<00:01,  1.86it/s]

Reg[3] (Smiling) score_tr=0.9566, score_te=0.8504


 71%|███████▏  | 5/7 [00:02<00:01,  1.86it/s]

Reg[4] (Wearing_Hat) score_tr=0.9632, score_te=0.8542


 86%|████████▌ | 6/7 [00:03<00:00,  1.86it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9638, score_te=0.8821


100%|██████████| 7/7 [00:04<00:00,  1.55it/s]

Reg[6] (s) score_tr=0.8319, score_te=0.5136
D=0.5513242685461772
C=0.5106709298176936
I=0.8076760133427398
0.5513242685461772 0.5106709298176936 0.8076760133427398
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.25it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.24it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.66it/s]

Reg[0] (Eyeglasses) score_tr=0.9876, score_te=0.9474


 29%|██▊       | 2/7 [00:01<00:02,  1.74it/s]

Reg[1] (Male) score_tr=0.9876, score_te=0.9532


 43%|████▎     | 3/7 [00:01<00:02,  1.67it/s]

Reg[2] (Pale_Skin) score_tr=0.9044, score_te=0.6557


 57%|█████▋    | 4/7 [00:02<00:01,  1.64it/s]

Reg[3] (Smiling) score_tr=0.9511, score_te=0.8454


 71%|███████▏  | 5/7 [00:02<00:01,  1.72it/s]

Reg[4] (Wearing_Hat) score_tr=0.9607, score_te=0.8493


 86%|████████▌ | 6/7 [00:03<00:00,  1.60it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9639, score_te=0.8844


100%|██████████| 7/7 [00:04<00:00,  1.43it/s]

Reg[6] (s) score_tr=0.8417, score_te=0.5256





D=0.5939937948928128
C=0.561841190186535
I=0.8087054445147294
0.5939937948928128 0.561841190186535 0.8087054445147294
current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.22it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.21it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.73it/s]

Reg[0] (Eyeglasses) score_tr=0.9864, score_te=0.9492


 29%|██▊       | 2/7 [00:01<00:03,  1.62it/s]

Reg[1] (Male) score_tr=0.9875, score_te=0.9511


 43%|████▎     | 3/7 [00:01<00:02,  1.51it/s]

Reg[2] (Pale_Skin) score_tr=0.9054, score_te=0.6625


 57%|█████▋    | 4/7 [00:02<00:02,  1.48it/s]

Reg[3] (Smiling) score_tr=0.9531, score_te=0.8419


 71%|███████▏  | 5/7 [00:03<00:01,  1.44it/s]

Reg[4] (Wearing_Hat) score_tr=0.9620, score_te=0.8514


 86%|████████▌ | 6/7 [00:04<00:00,  1.44it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9632, score_te=0.8803


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]

Reg[6] (s) score_tr=0.8425, score_te=0.5230
D=0.5802375711581382
C=0.5349794495724988
I=0.8084808117597281
0.5802375711581382 0.5349794495724988 0.8084808117597281
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.23it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.26it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.65it/s]

Reg[0] (Eyeglasses) score_tr=0.9885, score_te=0.9482


 29%|██▊       | 2/7 [00:01<00:03,  1.61it/s]

Reg[1] (Male) score_tr=0.9881, score_te=0.9543


 43%|████▎     | 3/7 [00:01<00:02,  1.62it/s]

Reg[2] (Pale_Skin) score_tr=0.8950, score_te=0.6372


 57%|█████▋    | 4/7 [00:02<00:01,  1.50it/s]

Reg[3] (Smiling) score_tr=0.9578, score_te=0.8505


 71%|███████▏  | 5/7 [00:03<00:01,  1.48it/s]

Reg[4] (Wearing_Hat) score_tr=0.9630, score_te=0.8624


 86%|████████▌ | 6/7 [00:03<00:00,  1.52it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9622, score_te=0.8782


100%|██████████| 7/7 [00:05<00:00,  1.29it/s]

Reg[6] (s) score_tr=0.8409, score_te=0.5173
D=0.579249437807563
C=0.536967157696029
I=0.8068693547877611
0.579249437807563 0.536967157696029 0.8068693547877611
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.21it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.25it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:03,  1.73it/s]

Reg[0] (Eyeglasses) score_tr=0.9887, score_te=0.9560


 29%|██▊       | 2/7 [00:01<00:02,  1.77it/s]

Reg[1] (Male) score_tr=0.9883, score_te=0.9569


 43%|████▎     | 3/7 [00:01<00:02,  1.69it/s]

Reg[2] (Pale_Skin) score_tr=0.9017, score_te=0.6472


 57%|█████▋    | 4/7 [00:02<00:01,  1.62it/s]

Reg[3] (Smiling) score_tr=0.9575, score_te=0.8546


 71%|███████▏  | 5/7 [00:03<00:01,  1.64it/s]

Reg[4] (Wearing_Hat) score_tr=0.9641, score_te=0.8744


 86%|████████▌ | 6/7 [00:03<00:00,  1.59it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9650, score_te=0.8888


100%|██████████| 7/7 [00:05<00:00,  1.34it/s]

Reg[6] (s) score_tr=0.8429, score_te=0.5362
D=0.6013172182636825
C=0.5683496080742747
I=0.8163048997325866
0.6013172182636825 0.5683496080742747 0.8163048997325866
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.25it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.20it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:04,  1.28it/s]

Reg[0] (Eyeglasses) score_tr=0.9869, score_te=0.9420


 29%|██▊       | 2/7 [00:01<00:03,  1.50it/s]

Reg[1] (Male) score_tr=0.9870, score_te=0.9503


 43%|████▎     | 3/7 [00:02<00:02,  1.47it/s]

Reg[2] (Pale_Skin) score_tr=0.8966, score_te=0.6668


 57%|█████▋    | 4/7 [00:02<00:02,  1.38it/s]

Reg[3] (Smiling) score_tr=0.9575, score_te=0.8574


 71%|███████▏  | 5/7 [00:03<00:01,  1.39it/s]

Reg[4] (Wearing_Hat) score_tr=0.9667, score_te=0.8639


 86%|████████▌ | 6/7 [00:04<00:00,  1.45it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9655, score_te=0.8853


100%|██████████| 7/7 [00:05<00:00,  1.18it/s]

Reg[6] (s) score_tr=0.8440, score_te=0.5386
D=0.5927225967769715
C=0.5558137564914295
I=0.8149006955130147
0.5927225967769715 0.5558137564914295 0.8149006955130147
current device is 0





load dataset val
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.20it/s]


tensors loaded.
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 10/10 [00:02<00:00,  4.17it/s]


tensors loaded.


 14%|█▍        | 1/7 [00:00<00:05,  1.16it/s]

Reg[0] (Eyeglasses) score_tr=0.9871, score_te=0.9504


 29%|██▊       | 2/7 [00:01<00:03,  1.45it/s]

Reg[1] (Male) score_tr=0.9880, score_te=0.9538


 43%|████▎     | 3/7 [00:02<00:02,  1.58it/s]

Reg[2] (Pale_Skin) score_tr=0.9041, score_te=0.6241


 57%|█████▋    | 4/7 [00:02<00:01,  1.54it/s]

Reg[3] (Smiling) score_tr=0.9567, score_te=0.8576


 71%|███████▏  | 5/7 [00:03<00:01,  1.57it/s]

Reg[4] (Wearing_Hat) score_tr=0.9596, score_te=0.8565


 86%|████████▌ | 6/7 [00:04<00:00,  1.48it/s]

Reg[5] (Wearing_Lipstick) score_tr=0.9669, score_te=0.8862


100%|██████████| 7/7 [00:05<00:00,  1.32it/s]

Reg[6] (s) score_tr=0.8448, score_te=0.5279
D=0.5740025368458662
C=0.52462348825569
I=0.8080545469102791
0.5740025368458662 0.52462348825569 0.8080545469102791





In [8]:
with open("metrics_complete.json", "w") as f:
    json.dump(metrics_x, f, ensure_ascii=False, indent=2)

##### MPI3D

In [10]:
mpi3d_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["mpi3d"]["b1"]["-1"], data_name="mpi3d")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    mpi3d_dci['d'].append(d)
    mpi3d_dci['c'].append(c)
    mpi3d_dci['i'].append(i)
print(f"D:{np.mean(mpi3d_dci['d'])}{ np.std(mpi3d_dci['d'])}")
print(f"C:{np.mean(mpi3d_dci['c'])}{ np.std(mpi3d_dci['c'])}")
print(f"I:{np.mean(mpi3d_dci['i'])}{ np.std(mpi3d_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 49/49 [00:03<00:00, 12.34it/s]


tensors loaded.
torch.Size([100000, 126]) torch.Size([100000, 14]) torch.Size([100000, 7])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 165/165 [00:18<00:00,  8.90it/s]


tensors loaded.
torch.Size([336800, 126]) torch.Size([336800, 14]) torch.Size([336800, 7])


 14%|█▍        | 1/7 [00:10<01:03, 10.58s/it]

Reg_0 score=0.9998882434500572, 0.9994361438404149


 29%|██▊       | 2/7 [00:32<01:26, 17.34s/it]

Reg_1 score=0.993489630798081, 0.9683818278148009


 43%|████▎     | 3/7 [00:40<00:52, 13.11s/it]

Reg_2 score=0.9995624997058248, 0.9967846711914224


 57%|█████▋    | 4/7 [00:43<00:27,  9.02s/it]

Reg_3 score=0.9999971959572628, 0.9999785120668458


 71%|███████▏  | 5/7 [00:49<00:16,  8.06s/it]

Reg_4 score=0.9999937401516007, 0.9999783954800145


 86%|████████▌ | 6/7 [01:09<00:11, 11.88s/it]

Reg_5 score=0.9993705208621313, 0.9962043883377152


100%|██████████| 7/7 [01:25<00:00, 12.20s/it]


Reg_6 score=0.9999122364347746, 0.999532363527368
D:0.976104425577659
C=0.8820723717579579
I=0.9943280431797975


 14%|█▍        | 1/7 [00:07<00:47,  7.89s/it]

Reg_0 score=0.9999057917877086, 0.9994929428970557


 29%|██▊       | 2/7 [00:34<01:34, 18.83s/it]

Reg_1 score=0.9935018779615042, 0.9685807696012292


 43%|████▎     | 3/7 [00:43<00:57, 14.38s/it]

Reg_2 score=0.9995268996818873, 0.9968047425011646


 57%|█████▋    | 4/7 [00:46<00:29,  9.86s/it]

Reg_3 score=0.9999950648847825, 0.9999797033667669


 71%|███████▏  | 5/7 [00:50<00:15,  7.71s/it]

Reg_4 score=0.9999981632780146, 0.9999836268681614


 86%|████████▌ | 6/7 [01:25<00:17, 17.13s/it]

Reg_5 score=0.9993511291599511, 0.9961688540770175


100%|██████████| 7/7 [01:41<00:00, 14.45s/it]


Reg_6 score=0.9999136641735792, 0.9994715536934712
D:0.9762011913356097
C=0.8823017208700228
I=0.9943545990006951


 14%|█▍        | 1/7 [00:08<00:48,  8.08s/it]

Reg_0 score=0.9998959043189006, 0.9994777502893594


 29%|██▊       | 2/7 [00:35<01:36, 19.29s/it]

Reg_1 score=0.9938580032682526, 0.9687470565008048


 43%|████▎     | 3/7 [00:43<00:57, 14.31s/it]

Reg_2 score=0.9995182996761047, 0.9968993092489882


 57%|█████▋    | 4/7 [00:46<00:29,  9.74s/it]

Reg_3 score=0.9999986914467226, 0.9999770646931101


 71%|███████▏  | 5/7 [00:51<00:16,  8.06s/it]

Reg_4 score=0.9999950146117539, 0.9999810668271534


 86%|████████▌ | 6/7 [01:15<00:13, 13.42s/it]

Reg_5 score=0.9993353756505989, 0.9960279873928509


100%|██████████| 7/7 [01:35<00:00, 13.58s/it]


Reg_6 score=0.9999263712752369, 0.9995250679239854
D:0.9762445996533037
C=0.8821896690350576
I=0.9943764718394646


 14%|█▍        | 1/7 [00:10<01:03, 10.59s/it]

Reg_0 score=0.9998882178569543, 0.9994585430019632


 29%|██▊       | 2/7 [00:31<01:24, 16.95s/it]

Reg_1 score=0.9939404159080791, 0.9687975238690361


 43%|████▎     | 3/7 [00:41<00:55, 13.75s/it]

Reg_2 score=0.9995423996923095, 0.9967706865954334


 57%|█████▋    | 4/7 [00:45<00:28,  9.54s/it]

Reg_3 score=0.9999986914467226, 0.9999769756239572


 71%|███████▏  | 5/7 [00:49<00:15,  7.59s/it]

Reg_4 score=0.9999940400245779, 0.99997888522699


 86%|████████▌ | 6/7 [01:21<00:15, 15.90s/it]

Reg_5 score=0.9993825333806292, 0.9962138530982656


100%|██████████| 7/7 [01:35<00:00, 13.71s/it]


Reg_6 score=0.9999165014459992, 0.9993969425406753
D:0.9761020131073269
C=0.8821666131148571
I=0.9943704871366174


 14%|█▍        | 1/7 [00:08<00:48,  8.01s/it]

Reg_0 score=0.9998970218844001, 0.9994738043890706


 29%|██▊       | 2/7 [00:34<01:34, 18.83s/it]

Reg_1 score=0.9935906960444322, 0.9687844043026792


 43%|████▎     | 3/7 [00:41<00:54, 13.60s/it]

Reg_2 score=0.9995324996856528, 0.9968342556400458


 57%|█████▋    | 4/7 [00:45<00:29,  9.84s/it]

Reg_3 score=0.9999892324758891, 0.9999813400124528


 71%|███████▏  | 5/7 [00:51<00:16,  8.16s/it]

Reg_4 score=0.9999956893259526, 0.9999801095944286


 86%|████████▌ | 6/7 [01:17<00:14, 14.44s/it]

Reg_5 score=0.9993785959630828, 0.9962166508204446


100%|██████████| 7/7 [01:34<00:00, 13.55s/it]

Reg_6 score=0.9999079247375299, 0.9994738571378006
D:0.975960614569697
C=0.8822240846597718
I=0.9943920602709889
D:0.97612256884871929.80346040895184e-05
C:0.88219089188753347.486360767947083e-05
I:0.99436433228551272.175191103232541e-05





In [11]:
mpi3d_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["mpi3d"]["b1"]["s-1"], data_name="mpi3d")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    mpi3d_dci['d'].append(d)
    mpi3d_dci['c'].append(c)
    mpi3d_dci['i'].append(i)
print(f"D:{np.mean(mpi3d_dci['d']), np.std(mpi3d_dci['d'])}")
print(f"C:{np.mean(mpi3d_dci['c']), np.std(mpi3d_dci['c'])}")
print(f"I:{np.mean(mpi3d_dci['i']), np.std(mpi3d_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 49/49 [00:04<00:00, 12.07it/s]


tensors loaded.
torch.Size([100000, 126]) torch.Size([100000, 12]) torch.Size([100000, 7])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 165/165 [00:18<00:00,  9.06it/s]


tensors loaded.
torch.Size([336800, 126]) torch.Size([336800, 12]) torch.Size([336800, 7])


 14%|█▍        | 1/7 [00:06<00:39,  6.53s/it]

Reg_0 score=0.999965901455716, 0.9997952742617818


 29%|██▊       | 2/7 [00:28<01:17, 15.55s/it]

Reg_1 score=0.9967404165752917, 0.9825682682615697


 43%|████▎     | 3/7 [00:47<01:08, 17.02s/it]

Reg_2 score=0.9995069996685065, 0.9970540008521785


 57%|█████▋    | 4/7 [00:55<00:40, 13.40s/it]

Reg_3 score=0.999941227264228, 0.9997503725651564


 71%|███████▏  | 5/7 [00:57<00:19,  9.60s/it]

Reg_4 score=1.0, 0.9999944346934607


 86%|████████▌ | 6/7 [01:20<00:14, 14.04s/it]

Reg_5 score=0.999034907081368, 0.9946460149808726


100%|██████████| 7/7 [01:51<00:00, 15.98s/it]


Reg_6 score=0.9880816161321825, 0.9515586160401538
D:0.8737681399313281
C=0.8210137404096478
I=0.9893381402364534


 14%|█▍        | 1/7 [00:06<00:38,  6.41s/it]

Reg_0 score=0.999957711662743, 0.999804019436534


 29%|██▊       | 2/7 [00:26<01:12, 14.51s/it]

Reg_1 score=0.9968002241452052, 0.9826883462747269


 43%|████▎     | 3/7 [00:39<00:54, 13.56s/it]

Reg_2 score=0.9995466996952008, 0.9971208062470305


 57%|█████▋    | 4/7 [00:46<00:33, 11.07s/it]

Reg_3 score=0.9999575654865768, 0.9997608604579187


 71%|███████▏  | 5/7 [00:50<00:16,  8.50s/it]

Reg_4 score=1.0, 0.9999931212811175


 86%|████████▌ | 6/7 [01:12<00:13, 13.20s/it]

Reg_5 score=0.9990564379988679, 0.9948908524852643


100%|██████████| 7/7 [01:33<00:00, 13.30s/it]


Reg_6 score=0.9880689960992374, 0.9515830179765428
D:0.8738569299465849
C=0.821062043245889
I=0.9894058605941621


 14%|█▍        | 1/7 [00:06<00:38,  6.47s/it]

Reg_0 score=0.999951117173193, 0.9997938757470894


 29%|██▊       | 2/7 [00:24<01:05, 13.02s/it]

Reg_1 score=0.9969612345369603, 0.9823098314764391


 43%|████▎     | 3/7 [00:35<00:49, 12.48s/it]

Reg_2 score=0.9995390996900906, 0.9971014178368801


 57%|█████▋    | 4/7 [00:43<00:31, 10.56s/it]

Reg_3 score=0.9999378250257068, 0.9997389717135762


 71%|███████▏  | 5/7 [00:46<00:15,  7.73s/it]

Reg_4 score=1.0, 0.9999931212811175


 86%|████████▌ | 6/7 [01:07<00:12, 12.46s/it]

Reg_5 score=0.999138966803182, 0.9947911248528849


100%|██████████| 7/7 [01:42<00:00, 14.59s/it]

Reg_6 score=0.9883294757522351, 0.9531908139686025
D:0.8727893871106622
C=0.8216525343360427
I=0.9895598795537985
D:(np.float64(0.8734714856628584), np.float64(0.0004836767121976006))
C:(np.float64(0.8212427726638598), np.float64(0.00029041552069381877))
I:(np.float64(0.9894346267948047), np.float64(9.278182989762906e-05))





In [12]:
mpi3d_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["mpi3d"]["b100"]["s-1"], data_name="mpi3d")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    mpi3d_dci['d'].append(d)
    mpi3d_dci['c'].append(c)
    mpi3d_dci['i'].append(i)
print(f"D:{np.mean(mpi3d_dci['d']), np.std(mpi3d_dci['d'])}")
print(f"C:{np.mean(mpi3d_dci['c']), np.std(mpi3d_dci['c'])}")
print(f"I:{np.mean(mpi3d_dci['i']), np.std(mpi3d_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 49/49 [00:04<00:00, 12.21it/s]


tensors loaded.
torch.Size([100000, 126]) torch.Size([100000, 12]) torch.Size([100000, 7])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 165/165 [00:17<00:00,  9.22it/s]


tensors loaded.
torch.Size([336800, 126]) torch.Size([336800, 12]) torch.Size([336800, 7])


 14%|█▍        | 1/7 [00:08<00:51,  8.52s/it]

Reg_0 score=0.9999973212552151, 0.9999432187750584


 29%|██▊       | 2/7 [00:22<01:00, 12.00s/it]

Reg_1 score=0.998920991706393, 0.9932297084563507


 43%|████▎     | 3/7 [00:32<00:43, 10.79s/it]

Reg_2 score=0.9998134998745973, 0.9989070637402233


 57%|█████▋    | 4/7 [00:36<00:24,  8.03s/it]

Reg_3 score=0.9999908401270584, 0.9999963926993047


 71%|███████▏  | 5/7 [00:42<00:15,  7.56s/it]

Reg_4 score=0.9999632655602917, 0.9998015634300355


 86%|████████▌ | 6/7 [01:05<00:12, 12.77s/it]

Reg_5 score=0.9996658940226631, 0.9983424060499847


100%|██████████| 7/7 [01:43<00:00, 14.81s/it]

Reg_6 score=0.9974040352205523, 0.9890641836571499





D:0.9645339510314215
C=0.7913538196360524
I=0.9970406481154439


 14%|█▍        | 1/7 [00:06<00:38,  6.47s/it]

Reg_0 score=0.9999860773519461, 0.9999370795757713


 29%|██▊       | 2/7 [00:21<00:58, 11.72s/it]

Reg_1 score=0.9988498715397339, 0.9929192379718151


 43%|████▎     | 3/7 [00:29<00:38,  9.73s/it]

Reg_2 score=0.999828799884885, 0.9988247298024792


 57%|█████▋    | 4/7 [00:32<00:21,  7.19s/it]

Reg_3 score=0.9999932702974307, 0.9999919726425885


 71%|███████▏  | 5/7 [00:39<00:14,  7.03s/it]

Reg_4 score=0.999986618168392, 0.9998667220389974


 86%|████████▌ | 6/7 [01:04<00:13, 13.07s/it]

Reg_5 score=0.9996210723325677, 0.9979805661948964


100%|██████████| 7/7 [01:29<00:00, 12.75s/it]


Reg_6 score=0.9972739481342414, 0.988915331091268
D:0.9664644051195797
C=0.7893200703312544
I=0.9969193770454022


 14%|█▍        | 1/7 [00:05<00:32,  5.40s/it]

Reg_0 score=0.9999866148071099, 0.9999395225185581


 29%|██▊       | 2/7 [00:22<01:01, 12.36s/it]

Reg_1 score=0.998902255927843, 0.9930859289503864


 43%|████▎     | 3/7 [00:30<00:40, 10.24s/it]

Reg_2 score=0.999823799881523, 0.9987921287697914


 57%|█████▋    | 4/7 [00:33<00:22,  7.49s/it]

Reg_3 score=0.9999892324758891, 0.9999944220442952


 71%|███████▏  | 5/7 [00:40<00:14,  7.12s/it]

Reg_4 score=0.9999830196926655, 0.9998842972770484


 86%|████████▌ | 6/7 [01:00<00:11, 11.54s/it]

Reg_5 score=0.9996599584054616, 0.9982154511855272


100%|██████████| 7/7 [01:30<00:00, 12.91s/it]

Reg_6 score=0.9971177911412121, 0.9882837999253277
D:0.9652544404278365
C=0.7908703913596283
I=0.9968850786672763
D:(np.float64(0.9654175988596125), np.float64(0.0007965043427445401))
C:(np.float64(0.7905147604423117), np.float64(0.000867521016346857))
I:(np.float64(0.9969483679427075), np.float64(6.673738280612944e-05))





##### Shapes

In [13]:
shapes_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["shapes"]["b1"]["-1"], data_name="shapes")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    shapes_dci['d'].append(d)
    shapes_dci['c'].append(c)
    shapes_dci['i'].append(i)
print(f"D:{np.mean(shapes_dci['d']), np.std(shapes_dci['d'])}")
print(f"C:{np.mean(shapes_dci['c']), np.std(shapes_dci['c'])}")
print(f"I:{np.mean(shapes_dci['i']), np.std(shapes_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 38/38 [00:03<00:00, 12.40it/s]


tensors loaded.
torch.Size([76800, 126]) torch.Size([76800, 12]) torch.Size([76800, 6])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 47/47 [00:04<00:00, 10.39it/s]


tensors loaded.
torch.Size([96000, 126]) torch.Size([96000, 12]) torch.Size([96000, 6])


 17%|█▋        | 1/6 [00:03<00:19,  3.99s/it]

Reg_0 score=0.9999999803369257, 0.9999999463642706


 33%|███▎      | 2/6 [00:08<00:16,  4.13s/it]

Reg_1 score=0.9999988590531649, 0.999996970585617


 50%|█████     | 3/6 [00:12<00:12,  4.02s/it]

Reg_2 score=0.9999946276282836, 0.9999939702622742


 67%|██████▋   | 4/6 [00:16<00:08,  4.01s/it]

Reg_3 score=0.9999945388308288, 0.9999839550056734


 83%|████████▎ | 5/6 [00:18<00:03,  3.46s/it]

Reg_4 score=0.9999999477238921, 0.9999999583964321


100%|██████████| 6/6 [00:23<00:00,  3.93s/it]

Reg_5 score=0.9999985273860362, 0.999999291690325





D:0.9999816018826406
C=0.887001454483623
I=0.9999956820507654


 17%|█▋        | 1/6 [00:03<00:17,  3.49s/it]

Reg_0 score=1.0, 1.0


 33%|███▎      | 2/6 [00:07<00:15,  3.86s/it]

Reg_1 score=0.9999975834112174, 0.9999937108095153


 50%|█████     | 3/6 [00:12<00:12,  4.17s/it]

Reg_2 score=0.999998833292716, 0.9999993465153355


 67%|██████▋   | 4/6 [00:16<00:08,  4.38s/it]

Reg_3 score=0.9999763349335912, 0.9999574564544369


 83%|████████▎ | 5/6 [00:19<00:03,  3.58s/it]

Reg_4 score=0.9999998170336225, 0.9999996671714565


100%|██████████| 6/6 [00:24<00:00,  4.11s/it]

Reg_5 score=0.9999989282158726, 0.9999984827922938





D:0.9999909179975396
C=0.8863827002175012
I=0.999991443957173


 17%|█▋        | 1/6 [00:04<00:22,  4.58s/it]

Reg_0 score=0.9999999842695405, 0.9999999179688845


 33%|███▎      | 2/6 [00:08<00:17,  4.28s/it]

Reg_1 score=0.9999956065623606, 0.9999932059071182


 50%|█████     | 3/6 [00:13<00:13,  4.42s/it]

Reg_2 score=0.9999995506600325, 0.9999987593262166


 67%|██████▋   | 4/6 [00:16<00:08,  4.06s/it]

Reg_3 score=0.9999669295866851, 0.9999516219110454


 83%|████████▎ | 5/6 [00:19<00:03,  3.46s/it]

Reg_4 score=0.9999993726867057, 0.9999992511357771


100%|██████████| 6/6 [00:23<00:00,  3.88s/it]

Reg_5 score=1.0, 1.0
D:0.9999957348330677
C=0.8874949255928567
I=0.9999904593748403
D:(np.float64(0.9999894182377492), np.float64(5.866403299860495e-06))
C:(np.float64(0.886959693431327), np.float64(0.00045502330380821864))
I:(np.float64(0.9999925284609262), np.float64(2.265862197584973e-06))





In [14]:
shapes_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["shapes"]["b1"]["s-1"], data_name="shapes")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    shapes_dci['d'].append(d)
    shapes_dci['c'].append(c)
    shapes_dci['i'].append(i)
print(f"D:{np.mean(shapes_dci['d']), np.std(shapes_dci['d'])}")
print(f"C:{np.mean(shapes_dci['c']), np.std(shapes_dci['c'])}")
print(f"I:{np.mean(shapes_dci['i']), np.std(shapes_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 38/38 [00:03<00:00, 11.46it/s]


tensors loaded.
torch.Size([76800, 126]) torch.Size([76800, 10]) torch.Size([76800, 6])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 47/47 [00:04<00:00, 11.68it/s]


tensors loaded.
torch.Size([96000, 126]) torch.Size([96000, 10]) torch.Size([96000, 6])


 17%|█▋        | 1/6 [00:04<00:20,  4.09s/it]

Reg_0 score=0.9999999882021554, 0.999998624401294


 33%|███▎      | 2/6 [00:08<00:17,  4.37s/it]

Reg_1 score=0.9999999762302743, 0.9999997254593216


 50%|█████     | 3/6 [00:12<00:12,  4.05s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:15<00:07,  3.64s/it]

Reg_3 score=0.9999993003263987, 0.9999889014680554


 83%|████████▎ | 5/6 [00:17<00:03,  3.20s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:30<00:00,  5.04s/it]


Reg_5 score=0.9995328144586737, 0.9976409419248279
D:0.9531749514938983
C=0.8502941216846032
I=0.999604698875583


 17%|█▋        | 1/6 [00:04<00:20,  4.19s/it]

Reg_0 score=0.9999999646064662, 0.999998854719426


 33%|███▎      | 2/6 [00:07<00:15,  3.90s/it]

Reg_1 score=1.0, 1.0


 50%|█████     | 3/6 [00:11<00:11,  3.89s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:14<00:07,  3.53s/it]

Reg_3 score=0.9999995479984699, 0.9999892834917299


 83%|████████▎ | 5/6 [00:17<00:03,  3.11s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:32<00:00,  5.39s/it]


Reg_5 score=0.9994592409668648, 0.9975472561036046
D:0.9534030222887127
C=0.8499270653039935
I=0.9995892323857936


 17%|█▋        | 1/6 [00:03<00:18,  3.65s/it]

Reg_0 score=0.9999999528086216, 0.9999980312532281


 33%|███▎      | 2/6 [00:07<00:14,  3.62s/it]

Reg_1 score=0.9999999841535162, 0.9999997380818815


 50%|█████     | 3/6 [00:11<00:11,  3.69s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:14<00:07,  3.65s/it]

Reg_3 score=0.9999996656427038, 0.9999888716999769


 83%|████████▎ | 5/6 [00:16<00:03,  3.11s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:32<00:00,  5.39s/it]

Reg_5 score=0.9995322945498666, 0.9975812618693948
D:0.9531907139998939
C=0.8505948362383108
I=0.9995946504840801
D:(np.float64(0.9532562292608349), np.float64(0.00010399762472911196))
C:(np.float64(0.8502720077423024), np.float64(0.00027306443075190176))
I:(np.float64(0.9995961939151522), np.float64(6.407792714475672e-06))





In [15]:
shapes_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["shapes"]["b100"]["s-1"], data_name="shapes")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    shapes_dci['d'].append(d)
    shapes_dci['c'].append(c)
    shapes_dci['i'].append(i)
print(f"D:{np.mean(shapes_dci['d']), np.std(shapes_dci['d'])}")
print(f"C:{np.mean(shapes_dci['c']), np.std(shapes_dci['c'])}")
print(f"I:{np.mean(shapes_dci['i']), np.std(shapes_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 38/38 [00:03<00:00, 12.57it/s]


tensors loaded.
torch.Size([76800, 126]) torch.Size([76800, 10]) torch.Size([76800, 6])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 47/47 [00:03<00:00, 12.22it/s]


tensors loaded.
torch.Size([96000, 126]) torch.Size([96000, 10]) torch.Size([96000, 6])


 17%|█▋        | 1/6 [00:03<00:17,  3.58s/it]

Reg_0 score=1.0, 1.0


 33%|███▎      | 2/6 [00:07<00:15,  3.89s/it]

Reg_1 score=0.9999999207675809, 0.9999999368872003


 50%|█████     | 3/6 [00:13<00:13,  4.66s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:16<00:08,  4.15s/it]

Reg_3 score=1.0, 0.9999844412176226


 83%|████████▎ | 5/6 [00:19<00:03,  3.67s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:25<00:00,  4.21s/it]


Reg_5 score=0.9999999320332017, 0.9999995138229449
D:0.9999999999165881
C=0.8835486058040425
I=0.9999973153212948


 17%|█▋        | 1/6 [00:03<00:19,  3.90s/it]

Reg_0 score=1.0, 1.0


 33%|███▎      | 2/6 [00:08<00:17,  4.26s/it]

Reg_1 score=0.9999994929125178, 0.9999997223036815


 50%|█████     | 3/6 [00:12<00:12,  4.04s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:18<00:09,  4.88s/it]

Reg_3 score=1.0, 0.9999844412176226


 83%|████████▎ | 5/6 [00:21<00:04,  4.19s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:25<00:00,  4.27s/it]

Reg_5 score=0.9999998762655723, 0.9999994202198912





D:0.9999999999158442
C=0.8826681781894924
I=0.9999972639568661


 17%|█▋        | 1/6 [00:04<00:20,  4.08s/it]

Reg_0 score=1.0, 1.0


 33%|███▎      | 2/6 [00:08<00:16,  4.09s/it]

Reg_1 score=0.9999998692665085, 0.9999998737744007


 50%|█████     | 3/6 [00:12<00:12,  4.19s/it]

Reg_2 score=1.0, 1.0


 67%|██████▋   | 4/6 [00:16<00:08,  4.06s/it]

Reg_3 score=0.9999993932034255, 0.9999749600846114


 83%|████████▎ | 5/6 [00:20<00:04,  4.08s/it]

Reg_4 score=1.0, 1.0


100%|██████████| 6/6 [00:25<00:00,  4.19s/it]

Reg_5 score=0.9999994649793055, 0.9999983263215174
D:0.9999999999162164
C=0.8853528825100186
I=0.9999955266967548
D:(np.float64(0.9999999999162162), np.float64(3.03675277025308e-13))
C:(np.float64(0.8838565555011845), np.float64(0.001117447712166877))
I:(np.float64(0.9999967019916385), np.float64(8.313234925478435e-07))





##### DSprites

In [16]:
dsprites_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["dsprites"]["b1"]["-1"], data_name="dsprites")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    dsprites_dci['d'].append(d)
    dsprites_dci['c'].append(c)
    dsprites_dci['i'].append(i)
print(f"D:{np.mean(dsprites_dci['d']), np.std(dsprites_dci['d'])}")
print(f"C:{np.mean(dsprites_dci['c']), np.std(dsprites_dci['c'])}")
print(f"I:{np.mean(dsprites_dci['i']), np.std(dsprites_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 58/58 [00:03<00:00, 16.21it/s]


tensors loaded.
torch.Size([117965, 126]) torch.Size([117965, 10]) torch.Size([117965, 5])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 72/72 [00:04<00:00, 15.66it/s]


tensors loaded.
torch.Size([147456, 126]) torch.Size([147456, 10]) torch.Size([147456, 5])


 20%|██        | 1/5 [00:12<00:49, 12.48s/it]

Reg_0 score=0.9999594274648704, 0.9997870040575041


 40%|████      | 2/5 [00:20<00:29,  9.73s/it]

Reg_1 score=0.999995054833391, 0.9999639362266423


 60%|██████    | 3/5 [00:41<00:30, 15.12s/it]

Reg_2 score=0.903979360316073, 0.6721293558419356


 80%|████████  | 4/5 [01:03<00:17, 17.53s/it]

Reg_3 score=0.9989873944109475, 0.9968687370990922


100%|██████████| 5/5 [01:24<00:00, 16.81s/it]

Reg_4 score=0.9997229603463547, 0.9986048180914914





D:0.9102506747076836
C=0.7400377686620438
I=0.9334707702633331


 20%|██        | 1/5 [00:15<01:03, 15.77s/it]

Reg_0 score=0.9999695467251204, 0.9998258766514473


 40%|████      | 2/5 [00:22<00:31, 10.59s/it]

Reg_1 score=0.9999944675042189, 0.9999609861890346


 60%|██████    | 3/5 [00:53<00:39, 19.93s/it]

Reg_2 score=0.9071071831624726, 0.6722839314313354


 80%|████████  | 4/5 [01:17<00:21, 21.33s/it]

Reg_3 score=0.999256978342518, 0.9971746823147404


100%|██████████| 5/5 [01:34<00:00, 18.82s/it]


Reg_4 score=0.9997443925540175, 0.9988138179314753
D:0.9090203563641189
C=0.7415836474761446
I=0.9336118589036067


 20%|██        | 1/5 [00:18<01:14, 18.68s/it]

Reg_0 score=0.9999686557277626, 0.9998032667922232


 40%|████      | 2/5 [00:28<00:39, 13.31s/it]

Reg_1 score=0.9999951273431653, 0.9999558890826356


 60%|██████    | 3/5 [00:56<00:40, 20.08s/it]

Reg_2 score=0.906954171817856, 0.6721721967984242


 80%|████████  | 4/5 [01:16<00:20, 20.07s/it]

Reg_3 score=0.9991380142060396, 0.9966300414708809


100%|██████████| 5/5 [01:33<00:00, 18.76s/it]

Reg_4 score=0.9997873965864753, 0.998759614719269
D:0.9086087841696945
C=0.7409990212547684
I=0.9334642017726866
D:(np.float64(0.9092932717471657), np.float64(0.0006975257411795996))
C:(np.float64(0.7408734791309856), np.float64(0.0006373151735231611))
I:(np.float64(0.9335156103132087), np.float64(6.81108392378301e-05))





In [17]:
dsprites_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["dsprites"]["b1"]["s-1"], data_name="dsprites")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    dsprites_dci['d'].append(d)
    dsprites_dci['c'].append(c)
    dsprites_dci['i'].append(i)
print(f"D:{np.mean(dsprites_dci['d']), np.std(dsprites_dci['d'])}")
print(f"C:{np.mean(dsprites_dci['c']), np.std(dsprites_dci['c'])}")
print(f"I:{np.mean(dsprites_dci['i']), np.std(dsprites_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 58/58 [00:03<00:00, 15.72it/s]


tensors loaded.
torch.Size([117965, 126]) torch.Size([117965, 8]) torch.Size([117965, 5])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 72/72 [00:04<00:00, 16.17it/s]


tensors loaded.
torch.Size([147456, 126]) torch.Size([147456, 8]) torch.Size([147456, 5])


 20%|██        | 1/5 [00:15<01:00, 15.10s/it]

Reg_0 score=0.9999486717653615, 0.9997269491938707


 40%|████      | 2/5 [00:25<00:36, 12.23s/it]

Reg_1 score=0.9999742227752274, 0.9998346407913759


 60%|██████    | 3/5 [01:04<00:49, 24.59s/it]

Reg_2 score=0.8923440597836818, 0.6756606863117509


 80%|████████  | 4/5 [01:22<00:21, 21.89s/it]

Reg_3 score=0.9999280590780435, 0.9997477087579626


100%|██████████| 5/5 [01:56<00:00, 23.29s/it]


Reg_4 score=0.9992002927927491, 0.9966653334033363
D:0.8324284318997845
C=0.7042293552843987
I=0.9343270636916593


 20%|██        | 1/5 [00:15<01:02, 15.68s/it]

Reg_0 score=0.9999459032864939, 0.9997005413675872


 40%|████      | 2/5 [00:28<00:41, 13.92s/it]

Reg_1 score=0.9999724607877113, 0.9998351644666907


 60%|██████    | 3/5 [00:52<00:36, 18.42s/it]

Reg_2 score=0.8935064177084673, 0.6780573759054453


 80%|████████  | 4/5 [01:15<00:20, 20.36s/it]

Reg_3 score=0.9999401910232215, 0.9997074153981168


100%|██████████| 5/5 [01:44<00:00, 20.90s/it]


Reg_4 score=0.999204499476818, 0.9968520313985623
D:0.8316430058261791
C=0.7033234500744844
I=0.9348305057072805


 20%|██        | 1/5 [00:14<00:59, 14.90s/it]

Reg_0 score=0.9999460305728786, 0.9997258786063188


 40%|████      | 2/5 [00:25<00:37, 12.61s/it]

Reg_1 score=0.9999743967986857, 0.9998400870146515


 60%|██████    | 3/5 [00:56<00:42, 21.03s/it]

Reg_2 score=0.890768226897568, 0.6765043427110637


 80%|████████  | 4/5 [01:17<00:20, 20.83s/it]

Reg_3 score=0.9999581348000267, 0.9997349110728746


100%|██████████| 5/5 [01:39<00:00, 19.95s/it]

Reg_4 score=0.9992029044671144, 0.996658731413595
D:0.8304273482182953
C=0.7027305095446426
I=0.9344927901637007
D:(np.float64(0.831499595314753), np.float64(0.0008232087245128988))
C:(np.float64(0.7034277716345084), np.float64(0.0006163315551501318))
I:(np.float64(0.9345501198542135), np.float64(0.00020948903843068324))





In [18]:
dsprites_dci = {'d':[], 'c':[], 'i': []}
estim_dci = DCIscore(ckpt_path=ckpt_path_x["dsprites"]["b100"]["s-1"], data_name="dsprites")
for _ in range(3): 
    d,c,i = estim_dci.compute()
    dsprites_dci['d'].append(d)
    dsprites_dci['c'].append(c)
    dsprites_dci['i'].append(i)
print(f"D:{np.mean(dsprites_dci['d']), np.std(dsprites_dci['d'])}")
print(f"C:{np.mean(dsprites_dci['c']), np.std(dsprites_dci['c'])}")
print(f"I:{np.mean(dsprites_dci['i']), np.std(dsprites_dci['i'])}")

current device is 0
load dataset val
From images to latent vectors.


100%|██████████| 58/58 [00:03<00:00, 15.62it/s]


tensors loaded.
torch.Size([117965, 126]) torch.Size([117965, 8]) torch.Size([117965, 5])
load test numpy.
loaded test numpy.
From images to latent vectors.


100%|██████████| 72/72 [00:04<00:00, 16.22it/s]


tensors loaded.
torch.Size([147456, 126]) torch.Size([147456, 8]) torch.Size([147456, 5])


 20%|██        | 1/5 [00:03<00:14,  3.55s/it]

Reg_0 score=0.9999980270610368, 0.9999944686309812


 40%|████      | 2/5 [00:10<00:15,  5.32s/it]

Reg_1 score=0.999999296655189, 0.9999940301014093


 60%|██████    | 3/5 [00:39<00:32, 16.48s/it]

Reg_2 score=0.6983749757994251, 0.2868401983391975


 80%|████████  | 4/5 [00:54<00:15, 15.64s/it]

Reg_3 score=0.9999404757563384, 0.9994039596086498


100%|██████████| 5/5 [01:21<00:00, 16.31s/it]


Reg_4 score=0.9996701933105684, 0.9983204489833037
D:0.8656052569532431
C=0.6960114785180361
I=0.8569106211327083


 20%|██        | 1/5 [00:03<00:13,  3.45s/it]

Reg_0 score=0.9999985680281719, 0.9999924294165963


 40%|████      | 2/5 [00:08<00:13,  4.54s/it]

Reg_1 score=0.9999992821532342, 0.999994483953349


 60%|██████    | 3/5 [00:47<00:40, 20.17s/it]

Reg_2 score=0.6853882915624059, 0.2887847352682641


 80%|████████  | 4/5 [01:02<00:17, 17.95s/it]

Reg_3 score=0.999942941254599, 0.9995066469160659


100%|██████████| 5/5 [01:32<00:00, 18.45s/it]


Reg_4 score=0.999665692947799, 0.998339338872178
D:0.8635249831444302
C=0.6958674912346186
I=0.8573235268852907


 20%|██        | 1/5 [00:04<00:17,  4.26s/it]

Reg_0 score=0.9999989498873261, 0.9999959470614102


 40%|████      | 2/5 [00:10<00:15,  5.18s/it]

Reg_1 score=0.999999514184512, 0.9999942512087645


 60%|██████    | 3/5 [00:35<00:28, 14.46s/it]

Reg_2 score=0.6955749532202677, 0.28756621618532774


 80%|████████  | 4/5 [00:50<00:14, 14.50s/it]

Reg_3 score=0.9999193739965715, 0.999458487381971


100%|██████████| 5/5 [01:19<00:00, 15.93s/it]

Reg_4 score=0.9996652483497896, 0.9984228801684147





D:0.8646795116391723
C=0.6977188738330753
I=0.8570875564011775
D:(np.float64(0.8646032505789485), np.float64(0.0008509784924380111))
C:(np.float64(0.6965326145285767), np.float64(0.00084086916841194))
I:(np.float64(0.8571072348063922), np.float64(0.0001691414001217981))


##### Celeba

In [57]:
estim_dci = DCIscore(ckpt_path_celeba_100, data_name="celeba")
estim_dci.compute()

Nombre de GPU : 4

[ GPU 0 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.03 Go
Mémoire réservée : 0.13 Go

[ GPU 1 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go

[ GPU 2 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.17 Go
Mémoire réservée : 2.36 Go

[ GPU 3 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go
current device is 0
No prepare data for celebA - Please ensure all path exists
load dataset - train
load dataset val
tensors loaded.
torch.Size([162770, 126]) torch.Size([162770, 2]) torch.Size([162770, 40])
tensors loaded.
torch.Size([19962, 126]) torch.Size([19962, 2]) torch.Size([19962, 40])


  2%|▎         | 1/40 [00:35<22:56, 35.29s/it]

Reg_0 score=0.7381348423570493, 0.10310709872616541


  5%|▌         | 2/40 [01:17<24:52, 39.27s/it]

Reg_1 score=0.7824787576944373, 0.151508927469876


  8%|▊         | 3/40 [01:49<22:15, 36.10s/it]

Reg_2 score=0.8346813158631319, 0.28343145269935976


 10%|█         | 4/40 [02:30<22:46, 37.95s/it]

Reg_3 score=0.7541970904684662, 0.12044964678667658


 12%|█▎        | 5/40 [03:18<24:21, 41.75s/it]

Reg_4 score=0.7438675274384656, 0.09756490705913734


 15%|█▌        | 6/40 [03:55<22:40, 40.02s/it]

Reg_5 score=0.8374291424934902, 0.39244553463230103


 18%|█▊        | 7/40 [04:30<21:10, 38.50s/it]

Reg_6 score=0.6745111133095314, -0.010233573401608131


 20%|██        | 8/40 [05:15<21:36, 40.51s/it]

Reg_7 score=0.7672319306629222, 0.10870123559336953


 22%|██▎       | 9/40 [05:47<19:34, 37.88s/it]

Reg_8 score=0.8337791824517043, 0.27822892591934967


 25%|██▌       | 10/40 [06:25<18:57, 37.91s/it]

Reg_9 score=0.8799930531183526, 0.43922205189985497


 28%|██▊       | 11/40 [07:12<19:34, 40.50s/it]

Reg_10 score=0.3748476176898249, -0.010013377307166005


 30%|███       | 12/40 [07:38<16:49, 36.05s/it]

Reg_11 score=0.768200390813683, 0.1162812959534516


 32%|███▎      | 13/40 [08:17<16:36, 36.92s/it]

Reg_12 score=0.7172987590716081, 0.08702691467001178


 35%|███▌      | 14/40 [09:02<17:10, 39.63s/it]

Reg_13 score=0.6600034599873229, 0.0736556909144106


 38%|███▊      | 15/40 [09:38<15:59, 38.38s/it]

Reg_14 score=0.6957570106401498, 0.09601391945593674


 40%|████      | 16/40 [10:12<14:53, 37.23s/it]

Reg_15 score=0.9892249947312558, 0.9262410891952451


 42%|████▎     | 17/40 [10:51<14:26, 37.66s/it]

Reg_16 score=0.685257279307034, 0.0839348281679081


 45%|████▌     | 18/40 [11:26<13:28, 36.77s/it]

Reg_17 score=0.7981962147158007, 0.11917545444751965


 48%|████▊     | 19/40 [12:02<12:50, 36.71s/it]

Reg_18 score=0.8648966743984414, 0.40287030816591074


 50%|█████     | 20/40 [12:39<12:11, 36.58s/it]

Reg_19 score=0.8520117500546847, 0.34243504419455695


 52%|█████▎    | 21/40 [13:03<10:27, 33.02s/it]

Reg_20 score=0.8899670016604324, 0.4747435499266218


 55%|█████▌    | 22/40 [13:35<09:46, 32.60s/it]

Reg_21 score=0.8362446937340708, 0.3052965488604893


 57%|█████▊    | 23/40 [14:21<10:24, 36.74s/it]

Reg_22 score=0.5888421407787319, 0.04904874746655019


 60%|██████    | 24/40 [14:52<09:19, 34.97s/it]

Reg_23 score=0.4991986299641781, -0.0033508741921401963


 62%|██████▎   | 25/40 [15:24<08:31, 34.13s/it]

Reg_24 score=0.797017050836687, 0.21353136744047574


 65%|██████▌   | 26/40 [16:03<08:16, 35.44s/it]

Reg_25 score=0.722309876517194, 0.03452021146369766


 68%|██████▊   | 27/40 [16:44<08:04, 37.24s/it]

Reg_26 score=0.728722989205749, 0.1719086753698853


 70%|███████   | 28/40 [17:31<07:59, 39.93s/it]

Reg_27 score=0.7398425419538586, 0.051920503709557586


 72%|███████▎  | 29/40 [17:58<06:38, 36.25s/it]

Reg_28 score=0.7213624581517883, 0.11137666853943617


 75%|███████▌  | 30/40 [18:41<06:21, 38.15s/it]

Reg_29 score=0.7456366712424431, 0.11836970042744366


 78%|███████▊  | 31/40 [19:20<05:47, 38.58s/it]

Reg_30 score=0.6766375128229498, 0.09809747832052473


 80%|████████  | 32/40 [20:02<05:16, 39.53s/it]

Reg_31 score=0.8892256422588408, 0.4811776254098348


 82%|████████▎ | 33/40 [20:38<04:29, 38.57s/it]

Reg_32 score=0.6283659886707824, 0.017947632363459864


 85%|████████▌ | 34/40 [21:14<03:45, 37.60s/it]

Reg_33 score=0.7803901350342033, 0.16935879392351982


 88%|████████▊ | 35/40 [21:46<02:59, 35.87s/it]

Reg_34 score=0.7149001321208424, 0.07873464005120667


 90%|█████████ | 36/40 [22:33<02:37, 39.46s/it]

Reg_35 score=0.7381746118135286, 0.33260728440111653


 92%|█████████▎| 37/40 [23:16<02:00, 40.31s/it]

Reg_36 score=0.8754294904006256, 0.45827653456251816


 95%|█████████▌| 38/40 [23:55<01:20, 40.07s/it]

Reg_37 score=0.59290028297456, 0.014043581164747021


 98%|█████████▊| 39/40 [24:34<00:39, 39.77s/it]

Reg_38 score=0.7102865706429036, 0.10591200549947888


100%|██████████| 40/40 [25:11<00:00, 37.78s/it]

Reg_39 score=0.7719885907166775, 0.18078765496674087





D:0.09310489817496176
C=0.05542524492233808
I=0.19224883824545866


(0.09310489817496176, 0.05542524492233808, 0.19224883824545866)

In [58]:
estim_dci = DCIscore(ckpt_path_celeba, data_name="celeba")
estim_dci.compute()

Nombre de GPU : 4

[ GPU 0 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.03 Go
Mémoire réservée : 0.13 Go

[ GPU 1 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go

[ GPU 2 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.17 Go
Mémoire réservée : 2.36 Go

[ GPU 3 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go
current device is 0
No prepare data for celebA - Please ensure all path exists
load dataset - train
load dataset val
tensors loaded.
torch.Size([162770, 126]) torch.Size([162770, 2]) torch.Size([162770, 40])
tensors loaded.
torch.Size([19962, 126]) torch.Size([19962, 2]) torch.Size([19962, 40])


  2%|▎         | 1/40 [00:50<32:53, 50.59s/it]

Reg_0 score=0.7061105118698918, 0.06966440880196267


  5%|▌         | 2/40 [01:28<27:29, 43.42s/it]

Reg_1 score=0.7498756818009854, 0.14269273256378712


  8%|▊         | 3/40 [02:02<23:58, 38.88s/it]

Reg_2 score=0.8256232754320627, 0.25332212556369105


 10%|█         | 4/40 [02:42<23:42, 39.51s/it]

Reg_3 score=0.7157164909731633, 0.05678672911682281


 12%|█▎        | 5/40 [03:34<25:35, 43.86s/it]

Reg_4 score=0.748207572977527, 0.11022610063744043


 15%|█▌        | 6/40 [04:12<23:47, 41.99s/it]

Reg_5 score=0.8416490795167436, 0.3854177182240843


 18%|█▊        | 7/40 [04:42<20:49, 37.87s/it]

Reg_6 score=0.6289108913607329, -0.014153602784118702


 20%|██        | 8/40 [05:15<19:24, 36.41s/it]

Reg_7 score=0.7315642177257091, 0.056005563344257325


 22%|██▎       | 9/40 [05:48<18:15, 35.34s/it]

Reg_8 score=0.8200730721934004, 0.26520970552104306


 25%|██▌       | 10/40 [06:18<16:52, 33.74s/it]

Reg_9 score=0.8796349531221733, 0.4434871806120375


 28%|██▊       | 11/40 [07:03<17:56, 37.12s/it]

Reg_10 score=0.41304181180796185, -0.023946487245803727


 30%|███       | 12/40 [07:31<15:58, 34.24s/it]

Reg_11 score=0.7289546944963663, 0.10649817724627819


 32%|███▎      | 13/40 [08:10<16:04, 35.71s/it]

Reg_12 score=0.6715525171885453, 0.05167696724989013


 35%|███▌      | 14/40 [09:14<19:14, 44.40s/it]

Reg_13 score=0.6320760078626775, 0.03739868510870192


 38%|███▊      | 15/40 [10:07<19:36, 47.06s/it]

Reg_14 score=0.6291467018369465, 0.03926751895307545


 40%|████      | 16/40 [10:47<17:56, 44.84s/it]

Reg_15 score=0.9922442591669528, 0.926047238866578


 42%|████▎     | 17/40 [11:34<17:22, 45.33s/it]

Reg_16 score=0.6563405910038919, 0.015150359970973581


 45%|████▌     | 18/40 [12:22<16:55, 46.15s/it]

Reg_17 score=0.8075249213177219, 0.09523977451766996


 48%|████▊     | 19/40 [13:05<15:54, 45.45s/it]

Reg_18 score=0.8506992033167278, 0.378998079761382


 50%|█████     | 20/40 [13:51<15:11, 45.57s/it]

Reg_19 score=0.8244442484255594, 0.3125299927970129


 52%|█████▎    | 21/40 [14:25<13:20, 42.15s/it]

Reg_20 score=0.8769400541464039, 0.4389484783293538


 55%|█████▌    | 22/40 [15:05<12:23, 41.33s/it]

Reg_21 score=0.8159215676654656, 0.2359108808672058


 57%|█████▊    | 23/40 [15:52<12:11, 43.04s/it]

Reg_22 score=0.5582029282960603, 0.020663677292613425


 60%|██████    | 24/40 [16:44<12:14, 45.90s/it]

Reg_23 score=0.43029340179274045, -0.012368983627509822


 62%|██████▎   | 25/40 [17:24<10:58, 43.92s/it]

Reg_24 score=0.7910021740147922, 0.14866904098409173


 65%|██████▌   | 26/40 [18:06<10:06, 43.31s/it]

Reg_25 score=0.6567104873727185, 0.03992151386061382


 68%|██████▊   | 27/40 [18:45<09:05, 41.96s/it]

Reg_26 score=0.7396418831748652, 0.18902234210773228


 70%|███████   | 28/40 [19:26<08:22, 41.87s/it]

Reg_27 score=0.6574249961915068, 0.04567473253177701


 72%|███████▎  | 29/40 [20:03<07:22, 40.24s/it]

Reg_28 score=0.7209363049761262, 0.10638268088873526


 75%|███████▌  | 30/40 [20:48<06:58, 41.81s/it]

Reg_29 score=0.7313190380311173, 0.12932397136801055


 78%|███████▊  | 31/40 [21:27<06:09, 41.02s/it]

Reg_30 score=0.653787121904823, 0.03904547056928109


 80%|████████  | 32/40 [21:59<05:06, 38.34s/it]

Reg_31 score=0.8545787563228763, 0.3903624456660003


 82%|████████▎ | 33/40 [22:44<04:42, 40.33s/it]

Reg_32 score=0.594173734951321, 0.010000007119086107


 85%|████████▌ | 34/40 [23:30<04:10, 41.81s/it]

Reg_33 score=0.783474475489255, 0.17685002732362354


 88%|████████▊ | 35/40 [24:09<03:26, 41.24s/it]

Reg_34 score=0.6643626141622292, 0.07219938197314824


 90%|█████████ | 36/40 [24:58<02:53, 43.45s/it]

Reg_35 score=0.7498234813990312, 0.3218886316592974


 92%|█████████▎| 37/40 [25:24<01:54, 38.29s/it]

Reg_36 score=0.8701823986392205, 0.4475102116475873


 95%|█████████▌| 38/40 [26:18<01:25, 42.98s/it]

Reg_37 score=0.6162674490411523, 0.012088836352726728


 98%|█████████▊| 39/40 [27:06<00:44, 44.46s/it]

Reg_38 score=0.7231308014964435, 0.07456542406231259


100%|██████████| 40/40 [27:51<00:00, 41.78s/it]

Reg_39 score=0.7383582253200234, 0.14928921002777296





D:0.09757300385864183
C=0.03976439143716474
I=0.16984840058719147


(0.09757300385864183, 0.03976439143716474, 0.16984840058719147)

#### Dev DCI

In [49]:
latent = LatentDataModule(standard=True, 
                          batch_size=2**19,
                          data_name="shapes",
                          ckpt_path=ckpt_path_shapes_100,
                          pref_gpu=2)
latent.prepare_data()
latent.setup("fit")
latent_loader = latent.train_dataloader()
batch = next(iter(latent_loader))
z_s_0, z_t_0, label_0 = batch
z_s_0.shape, z_t_0.shape, label_0.shape

latent.setup("test")
latent_test_loader = latent.test_dataloader()
batch = next(iter(latent_test_loader))
z_s_1, z_t_1, label_1 = batch
z_s_1.shape, z_t_1.shape, label_1.shape

z_tr = torch.cat([z_s_0, z_t_0], dim=1)
z_tr = (z_tr - z_tr.mean(axis=0, keepdims=True)) / (z_tr.std(axis=0, keepdims=True) + 1e-8)
z_te = torch.cat([z_s_1, z_t_1], dim=1)
z_te = (z_te - z_te.mean(axis=0, keepdims=True)) / (z_te.std(axis=0, keepdims=True) + 1e-8)
print(z_tr.shape, z_te.shape, label_0.shape, label_1.shape)

Nombre de GPU : 4

[ GPU 0 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.03 Go
Mémoire réservée : 0.13 Go

[ GPU 1 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go

[ GPU 2 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.04 Go
Mémoire réservée : 2.36 Go

[ GPU 3 ]
Nom : NVIDIA L40S
Mémoire totale : 47.8 Go
Mémoire utilisée : 0.0 Go
Mémoire réservée : 0.0 Go
current device is 0
load dataset - train
load dataset val
tensors loaded.
tensors loaded.
torch.Size([307200, 128]) torch.Size([96000, 128]) torch.Size([307200, 6]) torch.Size([96000, 6])


In [50]:
regressors = {str(k):{} for k in range(label_0.shape[1])}
for k in tqdm(range(label_0.shape[1])):

    perm = torch.randperm(len(z_tr))
    reg_k = RandomForestRegressor(n_estimators=20, 
                        max_depth=20,
                        n_jobs=-1)
    reg_k.fit(z_tr[perm], label_0[:, k][perm])

    perm = torch.randperm(len(z_tr))
    score_tr = reg_k.score(z_tr[perm], label_0[:, k][perm])
    perm = torch.randperm(len(z_te))
    score_te = reg_k.score(z_te[perm], label_1[:, k][perm])

    print(f"Reg_{k} score={score_tr}, {score_te}")
    regressors[str(k)]["model"] = reg_k
    regressors[str(k)]["score_tr"] = score_tr
    regressors[str(k)]["score_te"] = score_te

  0%|          | 0/6 [00:00<?, ?it/s]

 17%|█▋        | 1/6 [00:22<01:52, 22.53s/it]

Reg_0 score=1.0, 1.0


 33%|███▎      | 2/6 [01:20<02:53, 43.26s/it]

Reg_1 score=0.9998020379423094, 0.999043715427895


 50%|█████     | 3/6 [02:34<02:51, 57.28s/it]

Reg_2 score=0.9999070732971801, 0.9996646055954586


 50%|█████     | 3/6 [03:03<03:03, 61.07s/it]


KeyboardInterrupt: 

In [None]:
# 1) Récupérer la matrice d'importances R (D, K) à partir de tes régressions
D = z_tr.shape[1]
K = label_0.shape[1]

R = np.zeros((D, K), dtype=float)  # feature importances pour chaque facteur k
for k in range(K):
    model = regressors[str(k)]["model"]
    imp = getattr(model, "feature_importances_", None)
    if imp is None:
        raise ValueError(f"Aucune feature_importances_ pour k={k}")
    if len(imp) != D:
        raise ValueError(f"Dim mismatch: len(imp)={len(imp)} vs D={D}")
    R[:, k] = imp

# 2) Normaliser par colonne -> P(d | k): "où vit le facteur k ?"
col_sum = R.sum(axis=0, keepdims=True)               # (1, K)
col_sum[col_sum == 0] = 1.0                          # éviter /0 si colonne nulle
P_d_given_k = R / col_sum                             # (D, K)

# 3) (optionnel) Normaliser par ligne -> P(k | d): "quel facteur porte la dimension d ?"
row_sum = R.sum(axis=1, keepdims=True)               # (D, 1)
row_sum[row_sum == 0] = 1.0
P_k_given_d = R / row_sum                             # (D, K)

# 4) (optionnel) petit affichage: top-dims pour chaque facteur (P(d|k) décroissant)
def print_top_dims_per_factor(P, top=5, factor_names=None):
    D, K = P.shape
    if factor_names is None:
        factor_names = [f"f{k}" for k in range(K)]
    for k in range(K):
        order = np.argsort(-P[:, k])
        print(f"\nFacteur {factor_names[k]}:")
        for d in order[:top]:
            print(f"  dim {d:3d} : {P[d, k]:.3f}")

In [None]:
print_top_dims_per_factor(P_d_given_k)


Facteur f0:
  dim  28 : 0.702
  dim 126 : 0.197
  dim 127 : 0.095
  dim  71 : 0.006
  dim   4 : 0.000

Facteur f1:
  dim 111 : 0.731
  dim 105 : 0.125
  dim 112 : 0.116
  dim  43 : 0.008
  dim 106 : 0.004

Facteur f2:
  dim  58 : 0.858
  dim   6 : 0.113
  dim  15 : 0.018
  dim   7 : 0.004
  dim  96 : 0.001

Facteur f3:
  dim  54 : 0.590
  dim  26 : 0.091
  dim  70 : 0.090
  dim  43 : 0.070
  dim  82 : 0.031

Facteur f4:
  dim  74 : 0.326
  dim  91 : 0.118
  dim  50 : 0.116
  dim  62 : 0.106
  dim  46 : 0.090

Facteur f5:
  dim  56 : 0.491
  dim  60 : 0.202
  dim  72 : 0.071
  dim  97 : 0.054
  dim  10 : 0.037


In [None]:
def dci_scores(P_kd: np.ndarray, r2_per_factor=[reg["score_te"] for reg in regressors.values()], eps=1e-12):
    """
    P_kd : array (K, D) d'importances non-négatives (p.ex. permutation importance, gain, ou P(d|k)).
           Pas besoin d'être normalisé: on renormalisera correctement pour D et C.
    r2_per_factor : iterable de longueur K avec les R^2 (test) pour l'informativeness (optionnel).

    Retourne: D, C, I (I peut être None si r2_per_factor est None).
    """
    R = np.asarray(P_kd, dtype=float)
    R = np.clip(R, 0.0, None)
    K, D = R.shape
    if R.sum() == 0:
        raise ValueError("La matrice d'importances est nulle.")

    # Poids des dimensions et des facteurs (somme des importances)
    w_d = R.sum(axis=0)  # (D,)
    w_k = R.sum(axis=1)  # (K,)
    w_d = w_d / (w_d.sum() + eps)
    w_k = w_k / (w_k.sum() + eps)

    # p(k|d) : normalisation par colonnes ; p(d|k) : normalisation par lignes
    P_k_given_d = R / (R.sum(axis=0, keepdims=True) + eps)   # (K, D)
    P_d_given_k = R / (R.sum(axis=1, keepdims=True) + eps)   # (K, D)

    # Entropies
    def entropy(p, axis):
        p = np.clip(p, eps, 1.0)
        return -(p * np.log(p)).sum(axis=axis)

    H_k_given_d = entropy(P_k_given_d, axis=0)               # (D,)
    H_d_given_k = entropy(P_d_given_k, axis=1)               # (K,)

    # Disentanglement et completeness (entropies normalisées)
    D_score = float(((1.0 - H_k_given_d / (np.log(K) + eps)) * w_d).sum())
    C_score = float(((1.0 - H_d_given_k / (np.log(D) + eps)) * w_k).sum())

    # Informativeness = moyenne des R^2 test (tronqués à [0,1])
    I_score = None
    if r2_per_factor is not None:
        r2 = np.asarray(r2_per_factor, dtype=float)
        I_score = float(np.clip(r2, 0.0, 1.0).mean())

    return D_score, C_score, I_score
dci_scores(P_d_given_k)

(0.7138876947209459, 0.911502977744196, 0.9980409023406279)