## Measuring disentanglement

Specify the path to an experiment directory in the format `outputs/YYYY-MM-DD/HH-MM-SS`.

In [1]:
EXPERIMENT_PATH = "outputs/2023-06-01/12-16-30"
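
Since the rest of the notebook depends on this path, it can help to check its shape before loading anything. The snippet below is a minimal sketch; `is_valid_experiment_path` is a hypothetical helper and is not part of `helpers`.

```python
import re
from datetime import datetime

# Hypothetical helper: accepts only outputs/YYYY-MM-DD/HH-MM-SS.
RUN_DIR_PATTERN = re.compile(r"outputs/(\d{4}-\d{2}-\d{2})/(\d{2}-\d{2}-\d{2})")

def is_valid_experiment_path(path):
    """Return True if `path` looks like outputs/YYYY-MM-DD/HH-MM-SS."""
    match = RUN_DIR_PATTERN.fullmatch(path)
    if match is None:
        return False
    try:
        # strptime rejects impossible values such as month 13 or hour 25.
        datetime.strptime(f"{match.group(1)} {match.group(2)}", "%Y-%m-%d %H-%M-%S")
    except ValueError:
        return False
    return True

print(is_valid_experiment_path("outputs/2023-06-01/12-16-30"))
```

The check only covers the format of the string; whether the directory exists is left to `load_experiment`.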

In [2]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso
from tqdm import tqdm

from helpers import load_experiment

In [3]:
PATH_PREFIX = "/home/danis/Projects/AlphaCaption/AutoConceptBottleneck/autoconcept"
dm, model = load_experiment(os.path.join(PATH_PREFIX, EXPERIMENT_PATH))
train_loader = dm.train_dataloader()
train_set = train_loader.dataset

Global seed set to 42


Fetching configuration...
Loading datamodule...


[nltk_data] Downloading package wordnet to /home/danis/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
100%|██████████| 15000/15000 [00:03<00:00, 4013.86it/s]


Len of vocab:  8651
Max len of caption:  378
Index for <pad>: [0]
Loading model




## DCI

In [4]:
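def prepare_data(loader, model):
    """Collect concept features (X) and ground-truth attributes (y) for every sample in `loader`."""
    X_train, y_train = list(), list()
    # Framework models expose a concept extractor; baselines return concept probabilities directly.
    is_framework = hasattr(model.main, "concept_extractor")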
    
    for batch in tqdm(loader):
        images, attributes_all = batch["image"].cuda(), batch["attributes"]
        N = images.shape[0]
        
        if is_framework:
            batch_features = model.main.inference(images)[1].cpu().detach().numpy()
        else:
            batch_features = model(images)["concept_probs"].cpu().detach().numpy()
        
        for sample_id in range(N):
            attributes = np.array(attributes_all[sample_id])
            features = batch_features[sample_id]

            X_train.append(features)
            y_train.append(attributes)
    
    X_train, y_train = np.array(X_train), np.array(y_train)

    return X_train, y_train

X_train, y_train = prepare_data(train_loader, model)

100%|██████████| 165/165 [01:14<00:00,  2.22it/s]


In [5]:
TINY = 1e-12

def norm_entropy(p):
    n = p.shape[0]
    return - p.dot(np.log(p + TINY) / np.log(n + TINY))

def entropic_scores(r):
    r = np.abs(r)
    ps = r / np.sum(r, axis=0) # 'probabilities'
    hs = [1-norm_entropy(p) for p in ps.T]
    return hs

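def fit_linear_model(X_train, y_train, fast=False, seed=42):
    """Fit one regressor per attribute and collect its feature importances.

    Despite the name, the regressor is a RandomForestRegressor, not a linear model.
    Returns R with shape (n_attributes, n_features).
    """
    n_attributes = y_train.shape[1]
    R = list()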
    for regressor_idx in tqdm(range(n_attributes)):
        kwargs = {"random_state": seed}
        if fast:
            kwargs["n_estimators"] = 20
            kwargs["max_depth"] = 10
        regressor = RandomForestRegressor(**kwargs)
        regressor.fit(X_train, y_train[:, regressor_idx])
        R.append(regressor.feature_importances_)
    return np.array(R)

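def compute_disentanglement(R):
    # R has shape (n_attributes, n_features), so entropic_scores(R.T) yields one
    # score per attribute (row of R), weighted by that row's share of importance.
    # Caveat: DCI defines disentanglement per learned feature; confirm the
    # orientation of R before comparing with other DCI implementations.
    disent_scores = entropic_scores(R.T)
    c_rel_importance = np.sum(R,1) / np.sum(R)
    disent_w_avg = np.sum(np.array(disent_scores) * c_rel_importance)
    print(f"Disentanglement: {disent_w_avg:.4f}")
    return disent_w_avg

def compute_completeness(R):
    # One score per feature (column of R). Columns with zero total importance
    # produce NaN in entropic_scores and are dropped before averaging.
    complete_scores = entropic_scores(R)
    complete_scores = [v for v in complete_scores if not math.isnan(v)]
    complete_avg = np.mean(complete_scores)
    print(f"Completeness: {complete_avg:.4f}")
    return complete_avg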
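
To get a feel for the entropic score used above, the following sketch restates `1 - norm_entropy(p)` in plain Python and evaluates it on three hand-made importance vectors. The function name `one_minus_norm_entropy` and the toy vectors are illustrative assumptions, not part of the notebook.

```python
import math

TINY = 1e-12  # same smoothing constant as in the cell above

def one_minus_norm_entropy(weights):
    """1 - H(p) / log(n) for non-negative weights, normalised to sum to 1."""
    total = sum(weights)
    probs = [w / total for w in weights]
    n = len(probs)
    entropy = -sum(p * math.log(p + TINY) for p in probs) / math.log(n + TINY)
    return 1.0 - entropy

# All importance on one entry, importance spread evenly, and something in between.
concentrated = one_minus_norm_entropy([1.0, 0.0, 0.0, 0.0])
uniform = one_minus_norm_entropy([0.25, 0.25, 0.25, 0.25])
mixed = one_minus_norm_entropy([0.7, 0.1, 0.1, 0.1])
print(f"{concentrated:.3f} {uniform:.3f} {mixed:.3f}")
```

A score near 1 means the importance mass sits on a single entry, and a score near 0 means it is spread uniformly, which is the regime the MIMIC-CXR values above fall into.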

In [6]:
R = fit_linear_model(X_train, y_train)

100%|██████████| 14/14 [02:49<00:00, 12.11s/it]


In [7]:
disentanglement = compute_disentanglement(R)

Disentanglement: 0.0164


In [8]:
completeness = compute_completeness(R)

Completeness: 0.0085


## Purity

In [25]:
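def compute_purity(loader, model):
    """For every (feature, attribute) pair, accumulate soft agreement and
    disagreement over the dataset and keep the larger of the two, averaged
    over the number of samples."""
    is_framework = hasattr(model.main, "concept_extractor")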
    n_features = model.main.feature_extractor.main.fc.out_features
    
    features_to_attributes = list()
    attribute_values = None
    
    for batch in tqdm(loader):
        images, attributes_all = batch["image"].cuda(), batch["attributes"]
        N = images.shape[0]
        n_attributes = np.array(attributes_all).shape[1]

        if attribute_values is None:
            attribute_values = [[[0, 0] for _ in range(n_attributes)] for f in range(n_features)]
        
        if is_framework:
            batch_features = model.main.inference(images)[1].cpu().detach().numpy()
        else:
            batch_features = model(images)["concept_probs"].cpu().detach().numpy()
        
        for sample_id in range(N):
            attributes = np.array(attributes_all[sample_id])
            features = batch_features[sample_id]

            for feature_id in range(n_features):

                feature = features[feature_id]

                for attribute_id, attribute in enumerate(attributes):
                    value_on = attribute * feature + (1 - attribute) * (1 - feature)
                    attribute_values[feature_id][attribute_id][0] += value_on

                    value_off = attribute * (1 - feature) + (1 - attribute) * feature
                    attribute_values[feature_id][attribute_id][1] += value_off
    
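    # Normalise by the size of the dataset behind `loader`, not a global variable.
    n_samples = len(loader.dataset)
    for a in attribute_values:
        a_ = [max(p) / n_samples for p in a]
        features_to_attributes.append(a_)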
    
    return features_to_attributes

f2a = compute_purity(train_loader, model)

100%|██████████| 29/29 [00:15<00:00,  1.84it/s]
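
The inner loop of `compute_purity` can be illustrated on a single feature and a single binary attribute. This is a minimal sketch with made-up activations; `soft_agreement` is a hypothetical name introduced only for this example.

```python
def soft_agreement(features, attributes):
    """Return max(agreement, disagreement) / n for one feature and one binary attribute."""
    value_on = value_off = 0.0
    for f, a in zip(features, attributes):
        # Agreement: feature is high when the attribute is 1, low when it is 0.
        value_on += a * f + (1 - a) * (1 - f)
        # Disagreement: the inverted relationship.
        value_off += a * (1 - f) + (1 - a) * f
    return max(value_on, value_off) / len(features)

attrs = [1, 0, 1, 0]
aligned = soft_agreement([0.9, 0.1, 0.8, 0.2], attrs)        # feature tracks the attribute
inverted = soft_agreement([0.1, 0.9, 0.2, 0.8], attrs)       # feature tracks its negation
uninformative = soft_agreement([0.5, 0.5, 0.5, 0.5], attrs)  # feature carries no signal
print(aligned, inverted, uninformative)
```

Taking the maximum of the two accumulators makes the score symmetric: a feature that fires on the absence of an attribute is treated as just as pure as one that fires on its presence, and 0.5 is the floor for an uninformative feature.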


In [29]:
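def find_best_alignment(features_to_attributes, iter_converge=20):
    """Greedily assign each feature a distinct attribute, preferring higher purity.

    Stops once every feature is assigned or the assignment has not changed
    for `iter_converge` consecutive passes.
    """
    n_features, n_attributes = np.array(features_to_attributes).shape

    # For each feature, rank (attribute index, score) pairs by descending score.
    features_to_attributes_ = list()
    for feature_to_attributes in features_to_attributes:
        feature_to_attributes_ = sorted(enumerate(feature_to_attributes), key=lambda x: x[1], reverse=True)
        features_to_attributes_.append(feature_to_attributes_)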
    
    attributes_to_features = list(list() for _ in range(n_attributes))

    for idx_feat, feature_to_attributes in enumerate(features_to_attributes_):
        for idx_attr, score in feature_to_attributes:
            attributes_to_features[idx_attr].append((idx_feat, score))
    
    attributes_to_features_ = list()
    for attr2feature in attributes_to_features:
        attributes_to_features_.append(sorted(attr2feature, key=lambda x: x[1], reverse=True))
    
    best_idx = list(None for _ in range(n_features))
    best_scores = list(None for _ in range(n_features))

    patience_left = iter_converge

    while None in best_idx and patience_left > 0:
        prev_best = list(best_idx)

        for feat_idx, f2a in enumerate(features_to_attributes_):
            
            if best_idx[feat_idx] is None:

                for att_idx, score in f2a:

                    if att_idx not in best_idx:
                        best_idx[feat_idx] = att_idx
                        best_scores[feat_idx] = score
                        break

                    else:
                        idx_other = best_idx.index(att_idx)
                        score_other = best_scores[idx_other]

                        if score > score_other:
                            best_idx[feat_idx] = att_idx
                            best_scores[feat_idx] = score

                            best_idx[idx_other] = None
                            best_scores[idx_other] = None
                            break
        
        if best_idx == prev_best:
            patience_left -= 1
        else:
            patience_left = iter_converge
        
    return list(zip(best_idx, best_scores))

    
result = find_best_alignment(f2a)

scores = [b for _, b in result if b is not None and b != 0]
print("Purity: ", np.array(scores).mean())

Purity:  0.7787873725699388
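
A greedy matching like the one above is not guaranteed to find the assignment with the highest mean purity. As a sanity check on very small score matrices, an exhaustive search can serve as a reference. The sketch below is a swapped-in technique for cross-checking, not the method used in this notebook; `brute_force_alignment` and the toy matrix are assumptions made for illustration.

```python
from itertools import permutations

def brute_force_alignment(scores):
    """Exhaustively find the one-to-one feature -> attribute map with the highest mean score.

    `scores[f][a]` is the purity of feature f against attribute a.
    Assumes n_features <= n_attributes and a tiny matrix (cost grows factorially).
    """
    n_features, n_attributes = len(scores), len(scores[0])
    best_mean, best_assignment = float("-inf"), None
    for assignment in permutations(range(n_attributes), n_features):
        mean = sum(scores[f][a] for f, a in enumerate(assignment)) / n_features
        if mean > best_mean:
            best_mean, best_assignment = mean, assignment
    return best_assignment, best_mean

toy = [
    [0.90, 0.80, 0.50],
    [0.85, 0.60, 0.50],
    [0.50, 0.50, 0.70],
]
assignment, mean_purity = brute_force_alignment(toy)
print(assignment, round(mean_purity, 4))
```

On this toy matrix the exhaustive optimum gives attribute 0 to feature 1, whereas a pass that lets feature 0 keep its top-ranked attribute 0 ends with a lower mean. For realistic sizes an assignment solver would be needed instead of enumeration.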


## Results

### 1. Shapes dataset

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | disentanglement | completeness | cluster | align | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-------:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.830247` | `0.767724` | `0.465324` | `0.481560` |  `A`  | `-`  | `outputs/2023-05-22/08-37-36` |
| Baseline | `gumbel` |   `-`   |   `-`   |  `-`   | `0.404321`  | `0.828083` | `0.316362` | `0.276083` | `A` | `-` |  `outputs/2023-05-22/08-49-23`  |
| Framework | `sigmoid` | `softmax`   |  `false`   |     `false`   | `0.969136`  | `0.636225`  | `0.437534` | `0.408124` | `B` | `D` | `outputs/2023-05-22/08-18-17` |  
| Framework | `gumbel` |  `softmax`   |  `false`   |    `false`  |   `0.848765`  |  `0.763983`    | `0.651922` | `0.611969` |   `B` |  `C`   |  `outputs/2023-05-22/08-04-48`  |  
| Framework | `gumbel` |  `entmax`   |  `false`   |    `false`  |   `0.842593`  |  `0.748309`     | `0.742202` | `0.736384` |  `A`  |  `B`  |  `outputs/2023-05-22/09-13-40`  | 
| Framework | `gumbel` |  `softmax`   |  `true`   |    `false`  |   `0.731482`  |  `0.707190`     | `0.670618` | `0.637436` |   `A` |  `B`  |  `outputs/2023-05-22/09-38-41`  | 
| Framework | `gumbel` |  `entmax`   |  `true`   |    `false`  |   `0.586420`  |  `0.691018`     | `0.582086` | `0.564636` |  `A`  |  `B`  |  `outputs/2023-05-22/11-03-11`  | 
| Framework | `gumbel` |  `softmax`   |  `false`   |    `true`  |   `0.814815`  |  `0.726690`    | `0.708535` | `0.673539` |  `A` | `D` |  `outputs/2023-05-22/09-53-54`  | 

### 2. CUB-200 

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | disentanglement | completeness | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.805452` | `0.573071` | `0.189394`| `0.196021` | `outputs/2023-05-26/07-31-18` |
| Baseline | `gumbel (0.01)` |   `-`   |   `-`   |  `-`   | `0.765`  | `0.578` | `0.193078` | `0.201281` | `outputs/2023-05-28/09-21-34`  |
| Framework | `gumbel (0.5)` |  `entmax`   |  `false`   |    `false`  |   `0.773003`  |  `0.593836`    | `0.229795` | `0.255646` |   `outputs/2023-05-27/10-34-06`  | 
| Framework | `gumbel (0.01)` |  `entmax`   |  `false`   |    `false`  |   `0.726`  |  `0.657`    | `X` | `X` |   `outputs/2023-05-27/19-41-06`  | 

### 3. MIMIC-CXR

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | disentanglement | completeness | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.768` |  `0.0164`| `0.0085` | `outputs/2023-06-01/12-16-30` |
| Framework | `gumbel (0.01)` |   `-`   |   `-`   |  `-`   | `0.749` | `0.0222` | `0.0124` | `outputs/2023-06-01/13-18-08`  |
