In [1]:
import os 

import torch
import torch.nn as nn 
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from helpers import load_experiment

PATH_PREFIX = "/home/danis/Projects/AlphaCaption/AutoConceptBottleneck/autoconcept"

In [2]:
experiment_path = "outputs/2023-05-26/07-31-18"

dm, model = load_experiment(os.path.join(PATH_PREFIX, experiment_path))
train_loader = dm.train_dataloader()
train_set = train_loader.dataset

Global seed set to 42


Fetching configuration...
Loading datamodule...


[nltk_data] Downloading package wordnet to /home/danis/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


                             filename  \
0  Black_Footed_Albatross_0046_18.jpg   
1  Black_Footed_Albatross_0009_34.jpg   
2  Black_Footed_Albatross_0002_55.jpg   
3  Black_Footed_Albatross_0074_59.jpg   
4  Black_Footed_Albatross_0014_89.jpg   

                                     source_captions  \
0  ['closeup bin food include broccoli bread', 'm...   
1  ['giraffe eating food top tree', 'giraffe stan...   
2  ['flower vase sitting porch stand', 'white vas...   
3  ['zebra grazing lush green grass field', 'zebr...   
4  ['woman swim suit holding parasol sunny day', ...   

                                mask_source_captions  \
0  ['coco', 'coco', 'coco', 'coco', 'coco', 'cub'...   
1  ['coco', 'coco', 'coco', 'coco', 'coco', 'cub'...   
2  ['coco', 'coco', 'coco', 'coco', 'coco', 'cub'...   
3  ['coco', 'coco', 'coco', 'coco', 'coco', 'cub'...   
4  ['coco', 'coco', 'coco', 'coco', 'coco', 'cub'...   

                                          attributes  
0  [0, 0, 0, 0, 1, 0, 0,

100%|██████████| 11788/11788 [00:07<00:00, 1539.26it/s]


Max length:  373
Index for <pad>: [0]
Loading model


  rank_zero_warn(
  rank_zero_warn(


In [3]:
train_set[0]

{'image': tensor([[[ 1.1700,  1.1700,  1.1700,  ..., -0.4739, -0.6281, -0.6965],
          [ 1.1700,  1.1700,  1.1700,  ..., -0.3541, -0.4911, -0.5596],
          [ 1.1700,  1.1700,  1.1529,  ..., -0.1828, -0.3027, -0.3541],
          ...,
          [ 1.8722,  1.8722,  1.8722,  ...,  1.8722,  1.8722,  1.8722],
          [ 0.8961,  0.8961,  0.8961,  ...,  0.8961,  0.8961,  0.8961],
          [ 0.2453,  0.2453,  0.2453,  ...,  0.2453,  0.2453,  0.2453]],
 
         [[ 1.2206,  1.2206,  1.2206,  ..., -0.6001, -0.7577, -0.8277],
          [ 1.2206,  1.2206,  1.2206,  ..., -0.5126, -0.6527, -0.7227],
          [ 1.2206,  1.2206,  1.2031,  ..., -0.3725, -0.4951, -0.5476],
          ...,
          [ 2.0434,  2.0434,  2.0434,  ...,  2.0434,  2.0434,  2.0434],
          [ 1.0455,  1.0455,  1.0455,  ...,  1.0455,  1.0455,  1.0455],
          [ 0.3803,  0.3803,  0.3803,  ...,  0.3803,  0.3803,  0.3803]],
 
         [[ 1.7685,  1.7685,  1.7685,  ..., -0.3055, -0.4624, -0.5321],
          [ 1.7685,

In [3]:
# Shapes dataset
# [red, green, blue, square, triangle, circle, ... each for class]

attribute_mapping = {
    0: [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0
        ],
    1: [0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0
        ],
    2: [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0
        ],
    3: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0
        ],
    4: [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0
        ],
    5: [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0
        ],
    6: [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0
        ],
    7: [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0
        ],
    8: [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1
        ],
}

In [4]:
def compute_purity(loader, model):
    is_framework = hasattr(model.main, "concept_extractor")
    n_features = model.main.feature_extractor.main.fc.out_features
    
    features_to_attributes = list()
    attribute_values = [[[0, 0] for _ in range(312)] for f in range(n_features)]
    
    for batch in tqdm(loader):
        images, targets, attributes_all = batch["image"].cuda(), batch["target"], batch["attributes"]
        N = images.shape[0]
        
        if is_framework:
            batch_features = model.main.inference(images)[1].cpu().detach().numpy()
        else:
            batch_features = model(images)["concept_probs"].cpu().detach().numpy()
        
        for sample_id in range(N):
            target = targets[sample_id].item()
            attributes = np.array(attributes_all[sample_id])
            features = batch_features[sample_id]

            for feature_id in range(n_features):

                feature = features[feature_id]

                for attribute_id, attribute in enumerate(attributes):
                    value_on = attribute * feature + (1 - attribute) * (1 - feature)
                    attribute_values[feature_id][attribute_id][0] += value_on

                    value_off = attribute * (1 - feature) + (1 - attribute) * feature
                    attribute_values[feature_id][attribute_id][1] += value_off
    
    for a in attribute_values:
        a_ = [max(p) / len(train_set) for p in a]
        features_to_attributes.append(a_)
    
    return features_to_attributes

f2a = compute_purity(train_loader, model)

100%|██████████| 75/75 [1:33:11<00:00, 74.55s/it]


[[0.5333321326156851,
  0.5175032345234292,
  0.533291174448071,
  0.5250748356577005,
  0.5407203081253298,
  0.5317071517640035,
  0.5076629421105598,
  0.5332108073363836,
  0.5261002664925284,
  0.5320663969540372,
  0.5271907677771028,
  0.5317306329081528,
  0.5326326533950034,
  0.5332791021394779,
  0.5042796146478096,
  0.5335700843456018,
  0.5289865350325488,
  0.5301229507290409,
  0.5333944570098356,
  0.5318578868016709,
  0.5022507955962102,
  0.5253758674400069,
  0.530538869761824,
  0.5263586434973666,
  0.5297260090227047,
  0.5312035478907904,
  0.5316772707125194,
  0.5327359905145962,
  0.5331830379923144,
  0.506105225445456,
  0.5349021361629955,
  0.528530151456812,
  0.5300971222247023,
  0.5333561121547309,
  0.5316782797744053,
  0.5023924489557059,
  0.5257248682386559,
  0.5311900892707677,
  0.5297741728542495,
  0.5356804026956976,
  0.5372620835232162,
  0.5326200738999841,
  0.5333906634756372,
  0.5325450487256175,
  0.5279836917880678,
  0.5206651229

In [4]:
def compute_purity(loader, model, attribute_mapping):
    is_framework = hasattr(model.main, "concept_extractor")
    n_features = model.main.feature_extractor.main.fc.out_features
    
    features_to_attributes = list()
    attribute_values = [[[0, 0] for _ in range(len(attribute_mapping[0]))] for f in range(n_features)]
    
    for batch in tqdm(loader):
        images, targets = batch["image"].cuda(), batch["target"]
        N = images.shape[0]
        
        if is_framework:
            batch_features = model.main.inference(images)[1].cpu().detach().numpy()
        else:
            batch_features = model(images)["concept_probs"].cpu().detach().numpy()
        
        for sample_id in range(N):
            target = targets[sample_id].item()
            attributes = np.array(attribute_mapping[target])
            features = batch_features[sample_id]

            for feature_id in range(n_features):

                feature = features[feature_id]

                for attribute_id, attribute in enumerate(attributes):
                    value_on = attribute * feature + (1 - attribute) * (1 - feature)
                    attribute_values[feature_id][attribute_id][0] += value_on

                    value_off = attribute * (1 - feature) + (1 - attribute) * feature
                    attribute_values[feature_id][attribute_id][1] += value_off
    
    for a in attribute_values:
        a_ = [max(p) / len(train_set) for p in a]
        features_to_attributes.append(a_)
    
    return features_to_attributes

f2a = compute_purity(train_loader, model, attribute_mapping)

100%|██████████| 29/29 [00:18<00:00,  1.61it/s]


In [6]:
def find_best_alignment(features_to_attributes, iter_converge=20.0):
    n_features, n_attributes = np.array(features_to_attributes).shape

    features_to_attributes_ = list()
    for feature_to_attributes in features_to_attributes:
        feature_to_attributes_ = sorted([(idx, fa) for idx, fa in enumerate(feature_to_attributes)], key=lambda x: x[1], reverse=True)
        features_to_attributes_.append(feature_to_attributes_)
    
    attributes_to_features = list(list() for _ in range(n_attributes))

    for idx_feat, feature_to_attributes in enumerate(features_to_attributes_):
        for idx_attr, score in feature_to_attributes:
            attributes_to_features[idx_attr].append((idx_feat, score))
    
    attributes_to_features_ = list()
    for attr2feature in attributes_to_features:
        attributes_to_features_.append(sorted(attr2feature, key=lambda x: x[1], reverse=True))
    
    best_idx = list(None for _ in range(n_features))
    best_scores = list(None for _ in range(n_features))

    patience_left = iter_converge

    while None in best_idx and patience_left > 0:
        prev_best = [_ for _ in best_idx]

        for feat_idx, f2a in enumerate(features_to_attributes_):
            
            if best_idx[feat_idx] is None:

                for att_idx, score in f2a:

                    if att_idx not in best_idx:
                        best_idx[feat_idx] = att_idx
                        best_scores[feat_idx] = score
                        break

                    else:
                        idx_other = best_idx.index(att_idx)
                        score_other = best_scores[idx_other]

                        if score > score_other:
                            best_idx[feat_idx] = att_idx
                            best_scores[feat_idx] = score

                            best_idx[idx_other] = None
                            best_scores[idx_other] = None
                            break
        
        if best_idx == prev_best:
            patience_left -= 1
        else:
            patience_left = iter_converge
        
    return list(zip(best_idx, best_scores))

    
result = find_best_alignment(f2a)

scores = [b for _, b in result if b is not None]
print("Purity: ", np.array(scores).mean())

Purity:  0.5732395308343736


### Results

| model | activation | norm_fn | slot_norm | reg_dist | f1-score | purity | cluster | align | directory |
|:-----------|:----:|:----:|:----:|:----:|:----:|:-------:|:-------:|:-------:|:-----------|
| Baseline | `sigmoid` |   `-`   |  `-`   |   `-`   | `0.830247` | `0.767724` | `A`  | `-`  | `outputs/2023-05-22/08-37-36` |
| Baseline | `gumbel` |   `-`   |   `-`   |  `-`   | `0.404321`  | `0.828083` | `A` | `-` |  `outputs/2023-05-22/08-49-23`  |
| Framework | `sigmoid` | `softmax`   |  `false`   |     `false`   | `0.969136`  | `0.636225` | `B` | `D` | `outputs/2023-05-22/08-18-17` |  
| Framework | `gumbel` |  `softmax`   |  `false`   |    `false`  |   `0.848765`  |  `0.763983`    |  `B` |  `C`   |  `outputs/2023-05-22/08-04-48`  |  
| Framework | `gumbel` |  `entmax`   |  `false`   |    `false`  |   `0.842593`  |  `0.748309`     | `A`  |  `B`  |  `outputs/2023-05-22/09-13-40`  | 
| Framework | `gumbel` |  `softmax`   |  `true`   |    `false`  |   `0.731482`  |  `0.707190`     |  `A` |  `B`  |  `outputs/2023-05-22/09-38-41`  | 
| Framework | `gumbel` |  `entmax`   |  `true`   |    `false`  |   `0.586420`  |  `0.691018`     | `A`  |  `B`  |  `outputs/2023-05-22/11-03-11`  | 
| Framework | `gumbel` |  `softmax`   |  `false`   |    `true`  |   `0.814815`  |  `0.726690`    | `A` | `D` |  `outputs/2023-05-22/09-53-54`  | 