In [1]:
import os 

import torch
import torch.nn as nn 
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from helpers import load_experiment

# Root directory of the autoconcept checkout that experiment paths are resolved
# against. Set the AUTOCONCEPT_ROOT environment variable to run the notebook on
# another machine; the original author's location is kept as the default.
PATH_PREFIX = os.environ.get(
    "AUTOCONCEPT_ROOT",
    "/home/danis/Projects/AlphaCaption/AutoConceptBottleneck/autoconcept",
)

In [2]:
# Run directory of the experiment being evaluated, relative to PATH_PREFIX.
experiment_path = "outputs/2023-05-22/09-53-54"
experiment_dir = os.path.join(PATH_PREFIX, experiment_path)

# load_experiment hands back a pair that is unpacked as (datamodule, model).
dm, model = load_experiment(experiment_dir)

# Keep both the training loader and the dataset behind it for later cells.
train_loader = dm.train_dataloader()
train_set = train_loader.dataset

Global seed set to 42


Fetching configuration...
Loading datamodule...


[nltk_data] Downloading package wordnet to /home/danis/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
100%|██████████| 2700/2700 [00:00<00:00, 9915.98it/s]


Len of vocab:  53
Max len of caption:  12
Index for <pad>: [0]
Loading model


  rank_zero_warn(
  rank_zero_warn(


In [3]:
# Shapes dataset: class id -> 15 binary attribute flags.
# Flag layout: [red, green, blue, square, triangle, circle, <one flag per class 0..8>]
#   * slots 0-2  : colour, cycling red/green/blue with the class id (class_id % 3)
#   * slots 3-5  : shape, changing every three classes (class_id // 3)
#   * slots 6-14 : one-hot indicator of the class itself
attribute_mapping = {
    class_id: [
        int(
            slot == class_id % 3
            or slot == 3 + class_id // 3
            or slot == 6 + class_id
        )
        for slot in range(15)
    ]
    for class_id in range(9)
}

In [4]:
def _model_device(model):
    """Return the device of the model's parameters (CUDA if available when unknown)."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_purity(loader, model, attribute_mapping):
    """Score how well each learned feature tracks each ground-truth attribute.

    For every (feature, attribute) pair two agreement totals are accumulated
    over all samples yielded by ``loader`` (``f`` is the feature value, ``a``
    the attribute flag of the sample's class):

    * "on":  ``a * f + (1 - a) * (1 - f)`` -- the feature follows the attribute;
    * "off": ``a * (1 - f) + (1 - a) * f`` -- the feature follows its negation.

    The purity of the pair is ``max(on, off)`` divided by the number of samples
    that were actually processed. The count is taken from the batches
    themselves, so the result is correct for any loader and does not depend on
    notebook globals.

    Args:
        loader: Iterable of batches. Each batch is indexed with ``"image"``
            (a tensor whose first dimension is the batch size) and
            ``"target"`` (indexable per sample, each entry exposing ``.item()``).
        model: Model exposing ``model.main.feature_extractor.main.fc.out_features``.
            If ``model.main`` has a ``concept_extractor`` attribute, features
            are read from ``model.main.inference(images)[1]``; otherwise from
            ``model(images)["concept_probs"]``.
        attribute_mapping: Dict mapping a class id to its list of attribute
            flags; all lists must have the same length.

    Returns:
        A list with one entry per feature, each a list with one purity value
        (float) per attribute. All values are ``0.0`` if the loader is empty.

    Raises:
        ValueError: If ``attribute_mapping`` is empty.
    """
    # NOTE(review): feature values are presumably probabilities in [0, 1];
    # the formulas above are only a meaningful agreement measure under that
    # assumption -- confirm for new model types.
    # NOTE(review): the model's train/eval mode is left untouched here;
    # confirm that the caller has already put it in the desired mode.
    if not attribute_mapping:
        raise ValueError("attribute_mapping must contain at least one class")

    is_framework = hasattr(model.main, "concept_extractor")
    n_features = model.main.feature_extractor.main.fc.out_features
    n_attributes = len(next(iter(attribute_mapping.values())))
    device = _model_device(model)

    # Float64 accumulators indexed as [feature_id, attribute_id].
    agree_on = np.zeros((n_features, n_attributes), dtype=np.float64)
    agree_off = np.zeros_like(agree_on)
    n_seen = 0

    for batch in tqdm(loader):
        images, targets = batch["image"].to(device), batch["target"]
        n_batch = images.shape[0]

        # Inference only: no autograd graph is needed for the statistics.
        with torch.no_grad():
            if is_framework:
                batch_features = model.main.inference(images)[1]
            else:
                batch_features = model(images)["concept_probs"]

        feats = batch_features.detach().cpu().numpy().astype(np.float64)
        feats = feats[:n_batch, :n_features]
        attrs = np.array(
            [attribute_mapping[targets[i].item()] for i in range(n_batch)],
            dtype=np.float64,
        )

        # Summing the per-sample agreement terms over the batch is a matrix
        # product between the transposed features and the attributes.
        agree_on += feats.T @ attrs + (1.0 - feats).T @ (1.0 - attrs)
        agree_off += (1.0 - feats).T @ attrs + feats.T @ (1.0 - attrs)
        n_seen += n_batch

    # max(..., 1) keeps an empty loader from dividing by zero.
    purity = np.maximum(agree_on, agree_off) / max(n_seen, 1)
    return purity.tolist()

# Feature-to-attribute purity table for the training split:
# one row per learned feature, one column per attribute.
f2a = compute_purity(
    loader=train_loader,
    model=model,
    attribute_mapping=attribute_mapping,
)

100%|██████████| 29/29 [00:18<00:00,  1.58it/s]


In [5]:
def find_best_alignment(features_to_attributes, iter_converge=20.0):
    """Greedily assign each feature a distinct attribute, favouring high scores.

    Every unassigned feature walks its attributes from highest to lowest
    score. It claims the first attribute that is free, or takes over an
    attribute from its current holder when its own score for that attribute
    is strictly higher; the displaced feature becomes unassigned and is
    retried on a later pass. Passes repeat until every feature is assigned or
    ``iter_converge`` consecutive passes leave the assignment unchanged.

    Args:
        features_to_attributes: Sequence with one row per feature; each row
            holds that feature's score for every attribute.
        iter_converge: Number of consecutive unchanged passes tolerated before
            giving up on features that are still unassigned.

    Returns:
        A list with one ``(attribute_index, score)`` tuple per feature, in
        feature order. Features that could not be assigned (for example when
        there are more features than attributes) yield ``(None, None)``.
        An empty input yields an empty list.
    """
    if len(features_to_attributes) == 0:
        return []

    # Per-feature preference list: (attribute_index, score), best score first.
    # sorted() is stable, so ties keep ascending attribute order.
    ranked = [
        sorted(enumerate(row), key=lambda pair: pair[1], reverse=True)
        for row in features_to_attributes
    ]
    n_features = len(ranked)

    best_idx = [None] * n_features
    best_scores = [None] * n_features

    patience_left = iter_converge

    while None in best_idx and patience_left > 0:
        prev_best = list(best_idx)

        for feat_idx, candidates in enumerate(ranked):
            if best_idx[feat_idx] is not None:
                continue

            for att_idx, score in candidates:
                if att_idx not in best_idx:
                    # Attribute is free: claim it.
                    best_idx[feat_idx] = att_idx
                    best_scores[feat_idx] = score
                    break

                # Attribute is held by another feature; take it over only on
                # a strictly better score, otherwise try the next candidate.
                idx_other = best_idx.index(att_idx)
                if score > best_scores[idx_other]:
                    best_idx[feat_idx] = att_idx
                    best_scores[feat_idx] = score

                    best_idx[idx_other] = None
                    best_scores[idx_other] = None
                    break

        if best_idx == prev_best:
            patience_left -= 1
        else:
            patience_left = iter_converge

    return list(zip(best_idx, best_scores))

    
# Align features with attributes, then average the scores of the features
# that received an attribute (unassigned features carry a score of None).
result = find_best_alignment(f2a)

scores = [score for _, score in result if score is not None]
print("Purity: ", np.mean(scores))

Purity:  0.7266901688760646


### Results

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-----------|
| Baseline | `sigmoid` |   -   |  -   |   -   | `0.830247` | `0.767724`  | `outputs/2023-05-22/08-37-36` |
| Baseline | `gumbel` |   -   |   -   |  -   | `0.404321`  | `0.828083` | `outputs/2023-05-22/08-49-23`  |
| Framework | `sigmoid` | `softmax`   |  `false`   |     `false`   | `0.969136`  | `0.636225` | `outputs/2023-05-22/08-18-17` |  
| Framework | `gumbel` |  `softmax`   |  `false`   |    `false`  |   `0.848765`  |  `0.763983`        |  `outputs/2023-05-22/08-04-48`  |  
| Framework | `gumbel` |  `entmax`   |  `false`   |    `false`  |   `0.842593`  |  `0.748309`        |  `outputs/2023-05-22/09-13-40`  | 
| Framework | `gumbel` |  `softmax`   |  `true`   |    `false`  |   `0.731482`  |  `0.707190`        |  `outputs/2023-05-22/09-38-41`  | 
| Framework | `gumbel` |  `softmax`   |  `false`   |    `true`  |   `0.814815`  |  `0.726690`        |  `outputs/2023-05-22/09-53-54`  | 