## Measuring disentanglement

Specify the path to an experiment in the format `outputs/YYYY-MM-DD/HH-MM-SS`.

In [13]:
# Run directory of the experiment to evaluate; it is joined onto PATH_PREFIX
# when the experiment is loaded.
EXPERIMENT_PATH = "outputs/2023-06-02/17-54-32"

In [14]:
import os 
import math
import torch

import numpy as np
import torch.nn as nn 
from tqdm import tqdm
import matplotlib.pyplot as plt

from helpers import load_experiment
from extract import prepare_data_dci, fit_linear_model, compute_completeness, compute_disentanglement, compute_informativeness

In [15]:
# Root directory that contains the `outputs/` folder. The default is the original
# author's checkout; set the AUTOCONCEPT_ROOT environment variable to run the
# notebook on another machine without editing this cell.
PATH_PREFIX = os.environ.get(
    "AUTOCONCEPT_ROOT",
    "/home/danis/Projects/AlphaCaption/AutoConceptBottleneck/autoconcept",
)

# Restore the datamodule and the trained model of the selected experiment.
dm, model = load_experiment(os.path.join(PATH_PREFIX, EXPERIMENT_PATH))

# Loaders used by the metric cells below; `train_set` is the dataset behind the
# training loader.
train_loader = dm.train_dataloader()
test_loader = dm.test_dataloader()
train_set = train_loader.dataset

Global seed set to 42


Fetching configuration...
Loading datamodule...


100%|██████████| 2700/2700 [00:00<00:00, 9265.14it/s]


Len of vocab:  53
Max len of caption:  12
Index for <pad>: [0]
Loading model


  rank_zero_warn(
  rank_zero_warn(


## DCI

In [16]:
# Run the model over both splits to build the (X, y) pairs used by the DCI cells.
# NOTE(review): presumably X holds the learned concept representations and y the
# ground-truth factors — confirm in extract.prepare_data_dci.
X_train, y_train = prepare_data_dci(train_loader, model)
X_test, y_test = prepare_data_dci(test_loader, model)

100%|██████████| 29/29 [00:14<00:00,  2.07it/s]
100%|██████████| 9/9 [00:08<00:00,  1.11it/s]


In [17]:
# Fit on the train pair and pass the test pair for evaluation. `R` feeds the
# disentanglement/completeness scores and `errors` feeds informativeness.
# NOTE(review): R is presumably a feature-by-factor importance matrix — confirm
# in extract.fit_linear_model.
R, errors = fit_linear_model(X_train, y_train, X_test, y_test)

100%|██████████| 6/6 [00:02<00:00,  2.05it/s]


In [18]:
# Disentanglement score derived from R, printed with three decimals.
disentanglement = compute_disentanglement(R)
print(f"Disentanglement: {disentanglement:.3f}")

Disentanglement: 0.539


In [19]:
# Completeness score derived from the same R, printed with three decimals.
completeness = compute_completeness(R)
print(f"Completeness: {completeness:.3f}")

Completeness: 0.513


In [20]:
# Informativeness is computed from the fit errors and reported as an NRMSE
# (an error measure, so presumably lower is better — confirm in extract).
informativeness = compute_informativeness(errors)
print(f"Informativeness (NRMSE): {informativeness:.3f}")

Informativeness (NRMSE): 0.180


## Purity

In [None]:
def compute_purity(loader, model):
    """Score how well every concept feature tracks every attribute.

    For each (feature, attribute) pair the soft agreement
    ``a * f + (1 - a) * (1 - f)`` and the soft disagreement
    ``a * (1 - f) + (1 - a) * f`` are summed over all samples in ``loader``.
    The larger of the two totals is kept (so a feature that fires inversely to
    an attribute scores as well as one that fires with it) and divided by the
    number of samples that were actually processed.

    Args:
        loader: Iterable of batches. Each batch is a mapping with an
            ``"image"`` entry (moved to the GPU with ``.cuda()``) and an
            ``"attributes"`` entry convertible to a 2-D array of shape
            (samples, attributes). Attributes are presumably binary 0/1 —
            TODO confirm against the dataset.
        model: Model exposing ``model.main.feature_extractor.main.fc.out_features``.
            If ``model.main`` has a ``concept_extractor`` attribute, features are
            taken from ``model.main.inference(images)[1]``; otherwise from
            ``model(images)["concept_probs"]``.

    Returns:
        A list with one entry per feature; each entry is a list with one float
        per attribute.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    is_framework = hasattr(model.main, "concept_extractor")
    n_features = model.main.feature_extractor.main.fc.out_features

    # Running (n_features, n_attributes) totals, allocated once the first batch
    # reveals the number of attributes.
    agreement = None
    disagreement = None
    n_samples = 0

    for batch in tqdm(loader):
        images = batch["image"].cuda()
        # float64 keeps the running sums accurate over many samples.
        attributes = np.array(batch["attributes"]).astype(np.float64)

        # Only forward passes are needed, so skip building the autograd graph.
        with torch.no_grad():
            if is_framework:
                batch_features = model.main.inference(images)[1]
            else:
                batch_features = model(images)["concept_probs"]

        features = batch_features.cpu().detach().numpy().astype(np.float64)
        features = features[:, :n_features]

        if agreement is None:
            n_attributes = attributes.shape[1]
            agreement = np.zeros((n_features, n_attributes))
            disagreement = np.zeros((n_features, n_attributes))

        # Matrix products sum the per-sample terms for all pairs at once.
        agreement += features.T @ attributes + (1 - features).T @ (1 - attributes)
        disagreement += features.T @ (1 - attributes) + (1 - features).T @ attributes
        n_samples += features.shape[0]

    if n_samples == 0:
        raise ValueError("compute_purity received a loader that yields no samples")

    # Normalise by the samples seen in *this* loader rather than by a global
    # dataset, so the function is also correct for test/validation loaders.
    return (np.maximum(agreement, disagreement) / n_samples).tolist()

# Purity table for the training split: one row per feature, one value per attribute.
f2a = compute_purity(train_loader, model)

100%|██████████| 29/29 [00:15<00:00,  1.84it/s]


In [None]:
def find_best_alignment(features_to_attributes, iter_converge=20.0):
    """Greedily give every feature its own attribute, favouring high scores.

    Each feature walks through its attributes from best to worst score. A free
    attribute is taken immediately; an attribute already held by another
    feature is taken over only when the new score is strictly higher, in which
    case the previous holder becomes unassigned and retries on a later pass.
    Passes repeat until every feature is assigned or the assignment has stayed
    unchanged for ``iter_converge`` consecutive passes (this is what ends the
    loop when there are more features than attributes).

    Args:
        features_to_attributes: 2-D table (features x attributes) of scores.
        iter_converge: Number of consecutive unchanged passes to tolerate
            before giving up on the remaining unassigned features.

    Returns:
        One ``(attribute_index, score)`` tuple per feature, in feature order;
        ``(None, None)`` for a feature that could not be assigned.
    """
    # Unpacking doubles as a check that the input is a rectangular 2-D table.
    n_features, _ = np.array(features_to_attributes).shape

    # Per feature: (attribute_index, score) pairs, best score first. The sort is
    # stable, so equal scores keep their original attribute order.
    ranked_candidates = [
        sorted(enumerate(row), key=lambda pair: pair[1], reverse=True)
        for row in features_to_attributes
    ]

    best_idx = [None] * n_features
    best_scores = [None] * n_features
    patience_left = iter_converge

    while None in best_idx and patience_left > 0:
        prev_best = list(best_idx)

        for feat_idx, candidates in enumerate(ranked_candidates):
            if best_idx[feat_idx] is not None:
                continue

            for att_idx, score in candidates:
                if att_idx in best_idx:
                    idx_other = best_idx.index(att_idx)
                    # Only a strictly better score may displace the holder.
                    if not (score > best_scores[idx_other]):
                        continue
                    best_idx[idx_other] = None
                    best_scores[idx_other] = None

                best_idx[feat_idx] = att_idx
                best_scores[feat_idx] = score
                break

        # Count passes that changed nothing; any change restores full patience.
        if best_idx == prev_best:
            patience_left -= 1
        else:
            patience_left = iter_converge

    return list(zip(best_idx, best_scores))

    
# Pair every feature with an attribute using the purity table.
result = find_best_alignment(f2a)

# Average only over features that received a match and whose score is non-zero.
scores = [score for _, score in result if score is not None and score != 0]
print("Purity: ", np.mean(scores))

Purity:  0.7787873725699388


## Results

### 1. Shapes dataset

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | disentanglement | completeness | cluster | align | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-------:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.830247` | `0.767724` | `0.465324` | `0.481560` |  `A`  | `-`  | `outputs/2023-05-22/08-37-36` |
| Baseline | `gumbel` |   `-`   |   `-`   |  `-`   | `0.404321`  | `0.828083` | `0.316362` | `0.276083` | `A` | `-` |  `outputs/2023-05-22/08-49-23`  |
| Framework | `sigmoid` | `softmax`   |  `false`   |     `false`   | `0.969136`  | `0.636225`  | `0.437534` | `0.408124` | `B` | `D` | `outputs/2023-05-22/08-18-17` |  
| Framework | `gumbel` |  `softmax`   |  `false`   |    `false`  |   `0.848765`  |  `0.763983`    | `0.651922` | `0.611969` |   `B` |  `C`   |  `outputs/2023-05-22/08-04-48`  |  
| Framework | `gumbel` |  `entmax`   |  `false`   |    `false`  |   `0.842593`  |  `0.748309`     | `0.742202` | `0.736384` |  `A`  |  `B`  |  `outputs/2023-05-22/09-13-40`  | 
| Framework | `gumbel` |  `softmax`   |  `true`   |    `false`  |   `0.731482`  |  `0.707190`     | `0.670618` | `0.637436` |   `A` |  `B`  |  `outputs/2023-05-22/09-38-41`  | 
| Framework | `gumbel` |  `entmax`   |  `true`   |    `false`  |   `0.586420`  |  `0.691018`     | `0.582086` | `0.564636` |  `A`  |  `B`  |  `outputs/2023-05-22/11-03-11`  | 
| Framework | `gumbel` |  `softmax`   |  `false`   |    `true`  |   `0.814815`  |  `0.726690`    | `0.708535` | `0.673539` |  `A` | `D` |  `outputs/2023-05-22/09-53-54`  | 

### 2. CUB-200 

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | disentanglement | completeness | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.805452` | `0.573071` | `0.189394`| `0.196021` | `outputs/2023-05-26/07-31-18` |
| Baseline | `gumbel (0.01)` |   `-`   |   `-`   |  `-`   | `0.765`  | `0.578` | `0.193078` | `0.201281` | `outputs/2023-05-28/09-21-34`  |
| Framework | `gumbel (0.5)` |  `entmax`   |  `false`   |    `false`  |   `0.773003`  |  `0.593836`    | `0.229795` | `0.255646` |   `outputs/2023-05-27/10-34-06`  | 
| Framework | `gumbel (0.01)` |  `entmax`   |  `false`   |    `false`  |   `0.726`  |  `0.657`    | `X` | `X` |   `outputs/2023-05-27/19-41-06`  | 

### 3. MIMIC-CXR

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | disentanglement | completeness | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.768` |  `0.0164`| `0.0085` | `outputs/2023-06-01/12-16-30` |
| Framework | `gumbel (0.01)` |   `-`   |   `-`   |  `-`   | `0.749` | `0.0222` | `0.0124` | `outputs/2023-06-01/13-18-08`  |


In [1]:
# Recorded metrics per experiment id (the "-SHP" suffix presumably refers to the
# Shapes dataset — confirm). Every entry holds four rows of five values; the
# summary loop further down reads them as
#   row 0: f1, row 1: disentanglement, row 2: completeness, row 3: informativeness.
# NOTE(review): "E61-SHP" is identical to "E53-SHP" and "E62-SHP" is identical to
# "E56-SHP" — confirm these are genuine results and not copy-paste duplicates.
results = {

"E49-SHP":
[[0.981131, 0.873248, 0.973656, 0.924189, 0.971781],
[0.663345, 0.394986, 0.387052, 0.561059, 0.575503],
[0.505478, 0.364848, 0.391295, 0.460447, 0.642119],
[0.172368, 0.280508, 0.169319, 0.192731, 0.182917]],

"E50-SHP":
[[0.981165, 0.971696, 0.990583, 0.963805, 0.983096],
[0.588585, 0.573661, 0.56222, 0.553471, 0.477187],
[0.46946, 0.632319, 0.498188, 0.485164, 0.380766],
[0.143945, 0.175667, 0.154161, 0.149143, 0.161818]],

"E51-SHP":
[[0.447716, 0.389919, 0.472793, 0.665221, 0.489999],
[0.432098, 0.37835, 0.430014, 0.495095, 0.528005],
[0.419269, 0.382079, 0.424006, 0.461387, 0.431549],
[0.477292, 0.476659, 0.459475, 0.364721, 0.448022]],

"E52-SHP":
[[0.942825, 0.968106, 0.988683, 0.983082, 0.9887],
[0.472531, 0.45885, 0.638394, 0.511719, 0.509037],
[0.447773, 0.461443, 0.549162, 0.533417, 0.422075],
[0.231318, 0.14777, 0.142614, 0.159164, 0.139041]],

"E53-SHP":
[[0.881701, 0.827895, 0.879056, 0.874072, 0.682161],
[0.627323, 0.667616, 0.595665, 0.850418, 0.534167],
[0.529322, 0.517727, 0.527888, 0.692079, 0.511404],
[0.139276, 0.145541, 0.137512, 0.162982, 0.205537]],

"E54-SHP":
[[0.575949, 0.35089, 0.582664, 0.482592, 0.376687],
[0.485671, 0.460129, 0.661162, 0.617536, 0.583347],
[0.519757, 0.406376, 0.518161, 0.522811, 0.61405],
[0.177332, 0.366635, 0.146314, 0.171029, 0.265272]],

"E55-SHP":
[[0.614845, 0.42983, 0.584451, 0.477093, 0.393848],
[0.540889, 0.504733, 0.678941, 0.546654, 0.577525],
[0.425633, 0.505349, 0.588776, 0.478584, 0.560115],
[0.226377, 0.231023, 0.146473, 0.176644, 0.241622]],

"E56-SHP":
[[0.862637, 0.829518, 0.862029, 0.850979, 0.740367],
[0.618041, 0.618922, 0.505485, 0.73972, 0.538934],
[0.460178, 0.541834, 0.441986, 0.625805, 0.513224],
[0.13346, 0.175929, 0.154805, 0.153977, 0.179752]],

"E57-SHP":
[[0.494562, 0.657113, 0.52312, 0.514724, 0.508858],
[0.586914, 0.810653, 0.365879, 0.488906, 0.59746],
[0.54572, 0.579612, 0.399097, 0.43978, 0.65056],
[0.348297, 0.145581, 0.318479, 0.254663, 0.204913]],

"E58-SHP":
[[0.439655, 0.605264, 0.416931, 0.468073, 0.512713],
[0.525453, 0.775373, 0.408299, 0.5834, 0.549561],
[0.486188, 0.379191, 0.408362, 0.547149, 0.581411],
[0.319103, 0.174612, 0.328444, 0.241011, 0.177954]],

"E59-SHP":
[[0.605599, 0.6777, 0.646318, 0.581926, 0.673149],
[0.731953, 0.801716, 0.782887, 0.706306, 0.739048],
[0.776975, 0.721015, 0.769809, 0.739891, 0.747892],
[0.086086, 0.078189, 0.08505, 0.050539, 0.105751]],

"E60-SHP":
[[0.455141, 0.523521, 0.497093, 0.494261, 0.47096],
[0.564572, 0.594083, 0.54585, 0.645181, 0.679455],
[0.503765, 0.468938, 0.542489, 0.407685, 0.668507],
[0.283051, 0.184142, 0.215551, 0.276161, 0.207509]],

"E61-SHP":
[[0.881701, 0.827895, 0.879056, 0.874072, 0.682161],
[0.627323, 0.667616, 0.595665, 0.850418, 0.534167],
[0.529322, 0.517727, 0.527888, 0.692079, 0.511404],
[0.139276, 0.145541, 0.137512, 0.162982, 0.205537]],

"E62-SHP":
[[0.862637, 0.829518, 0.862029, 0.850979, 0.740367],
[0.618041, 0.618922, 0.505485, 0.73972, 0.538934],
[0.460178, 0.541834, 0.441986, 0.625805, 0.513224],
[0.13346, 0.175929, 0.154805, 0.153977, 0.179752]]

}

In [3]:
import numpy as np

def format_mean_std(values):
    """Return ``"<mean> ± <std>"`` for ``values``, both with four decimals.

    Args:
        values: Array-like of numbers.

    Returns:
        The formatted string, e.g. ``"2.0000 ± 1.0000"`` for ``[1.0, 3.0]``.
    """
    values = np.asarray(values)
    # `.4f` sets the precision; the previous `4f` only set a minimum width of 4
    # and therefore printed the mean with the default six decimals.
    return f"{values.mean():.4f} ± {values.std():.4f}"


# Metrics of a single experiment, five values each (these numbers are the same
# as the "E49-SHP" entry of `results`).
f1 = np.array([0.981131, 0.873248, 0.973656, 0.924189, 0.971781])
print(f"f1: {format_mean_std(f1)}")

D = np.array([0.663345, 0.394986, 0.387052, 0.561059, 0.575503])
print(f"disentanglement: {format_mean_std(D)}")

C = np.array([0.505478, 0.364848, 0.391295, 0.460447, 0.642119])
print(f"completeness:{format_mean_std(C)}")

I = np.array([0.172368, 0.280508, 0.169319, 0.192731, 0.182917])
print(f"informativeness: {format_mean_std(I)}")

f1: 0.944801 ± 0.0410
disentanglement: 0.516389 ± 0.1082
completeness:0.472837 ± 0.0982
informativeness: 0.199569 ± 0.0413


In [7]:
# Report mean ± std (two decimals) of each metric for experiments E49 through E62.
for i in range(49, 63):
    exp = f"E{i}-SHP"
    res = results[exp]
    # Rows are read in the order f1, disentanglement, completeness, informativeness.
    f1, D, C, I = np.array(res[0]), np.array(res[1]), np.array(res[2]), np.array(res[3])
    # One print call; the trailing empty item yields the blank separator line.
    print(
        exp,
        f"f1: {f1.mean():.2f} ± {f1.std():.2f}",
        f"disentanglement: {D.mean():.2f} ± {D.std():.2f}",
        f"completeness:{C.mean():.2f} ± {C.std():.2f}",
        f"informativeness: {I.mean():.2f} ± {I.std():.2f}",
        "",
        sep="\n",
    )

E49-SHP
f1: 0.94 ± 0.04
disentanglement: 0.52 ± 0.11
completeness:0.47 ± 0.10
informativeness: 0.20 ± 0.04

E50-SHP
f1: 0.98 ± 0.01
disentanglement: 0.55 ± 0.04
completeness:0.49 ± 0.08
informativeness: 0.16 ± 0.01

E51-SHP
f1: 0.49 ± 0.09
disentanglement: 0.45 ± 0.05
completeness:0.42 ± 0.03
informativeness: 0.45 ± 0.04

E52-SHP
f1: 0.97 ± 0.02
disentanglement: 0.52 ± 0.06
completeness:0.48 ± 0.05
informativeness: 0.16 ± 0.03

E53-SHP
f1: 0.83 ± 0.08
disentanglement: 0.66 ± 0.11
completeness:0.56 ± 0.07
informativeness: 0.16 ± 0.03

E54-SHP
f1: 0.47 ± 0.10
disentanglement: 0.56 ± 0.08
completeness:0.52 ± 0.07
informativeness: 0.23 ± 0.08

E55-SHP
f1: 0.50 ± 0.09
disentanglement: 0.57 ± 0.06
completeness:0.51 ± 0.06
informativeness: 0.20 ± 0.04

E56-SHP
f1: 0.83 ± 0.05
disentanglement: 0.60 ± 0.08
completeness:0.52 ± 0.07
informativeness: 0.16 ± 0.02

E57-SHP
f1: 0.54 ± 0.06
disentanglement: 0.57 ± 0.15
completeness:0.52 ± 0.09
informativeness: 0.25 ± 0.07

E58-SHP
f1: 0.49 ± 0.07
dise