In [2]:
import os, sys
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
import sklearn.preprocessing as Preprocessing
import pickle
from datetime import datetime

sys.path.insert(0, 'src')
from utils.places365_pred_utils import get_class_category_dict, get_category_class_dict
from utils.utils import ensure_dir, write_lists, informal_log
from utils.attribute_utils import get_one_hot_attributes, get_frequent_attributes, hyperparam_search, partition_paths_by_congruency, convert_sparse_to_dense_attributes


### Load features and attributes

In [3]:
# Configuration for this cell.
# n_attributes is forwarded to get_one_hot_attributes as n_attr;
# frequency_threshold is forwarded to get_frequent_attributes and names the save directory.
n_attributes = 1200
frequency_threshold = 0

# Load features
# Each '<split>_features.pth' file is read as a dict with 'features' and 'paths' entries.
features_dir = os.path.join('saved', 'ADE20K', '0501_105640')
train_features_path = os.path.join(features_dir, 'train_features.pth')
val_features_path = os.path.join(features_dir, 'val_features.pth')
test_features_path = os.path.join(features_dir, 'test_features.pth')

train_features_dict = torch.load(train_features_path)
val_features_dict = torch.load(val_features_path)
test_features_dict = torch.load(test_features_path)

train_features, train_paths = train_features_dict['features'], train_features_dict['paths']
val_features, val_paths = val_features_dict['features'], val_features_dict['paths']
test_features, test_paths = test_features_dict['features'], test_features_dict['paths']

# Split-keyed views used by the cells below
features = {
    'train': train_features,
    'val': val_features,
    'test': test_features
}
paths = {
    'train': train_paths,
    'val': val_paths,
    'test': test_paths
}

# Load data and calculate attributes
data_path = os.path.join('data', 'ade20k', 'full_ade20k_imagelabels.pth')
data = torch.load(data_path)

print("Obtaining one hot encodings of attributes")
attributes = get_one_hot_attributes(
    data=data,
    paths=paths,
    n_attr=n_attributes
)
# Existing files are never overwritten; delete them manually to refresh after a config change.
attribute_save_path = os.path.join(os.path.dirname(data_path), 'one_hot_attributes.pth')
if not os.path.exists(attribute_save_path):
    torch.save(attributes, attribute_save_path)
    print("Saved one hot attributes from train/val/test to {}".format(attribute_save_path))
else:
    print("Path {} already exists".format(attribute_save_path))

print("Obtaining frequent attributes only")
sparse_freq_attributes, metadata = get_frequent_attributes(
    attributes=attributes,
    frequency_threshold=frequency_threshold
)

# Get indices of frequent attributes
sparse_freq_attributes_one_hot = metadata['freq_attr_one_hot']
frequent_attribute_idxs = metadata['freq_attr_idxs']

print("After filtering concepts that appear in {}+ images, we have {} concepts".format(
    frequency_threshold, len(frequent_attribute_idxs)))

# Get dense attributes
dense_attributes = {}
for split in sparse_freq_attributes.keys():
    dense_attributes[split] = convert_sparse_to_dense_attributes(
        sparse_attributes=sparse_freq_attributes[split],
        used_attributes_idxs=frequent_attribute_idxs)

# Save dir for this frequency threshold
save_dir = os.path.join(os.path.dirname(data_path), 'filter_attr_{}'.format(frequency_threshold))
ensure_dir(save_dir)

# Save sparse attributes from each split
sparse_freq_attribute_save_path = os.path.join(save_dir, 'splits_sparse_one_hot_attributes.pth')
if not os.path.exists(sparse_freq_attribute_save_path):
    torch.save(sparse_freq_attributes, sparse_freq_attribute_save_path)
    print("Saved frequency filtered one hot attributes from train/val/test to {}".format(sparse_freq_attribute_save_path))
else:
    print("Path {} already exists".format(sparse_freq_attribute_save_path))

# Save dense attributes from each split
dense_freq_attributes_save_path = os.path.join(save_dir, 'splits_dense_one_hot_attributes.pth')
if not os.path.exists(dense_freq_attributes_save_path):
    torch.save(dense_attributes, dense_freq_attributes_save_path)
    print("Saved dense frequency filtered one hot attributes from train/val/test to {}".format(dense_freq_attributes_save_path))
else:
    print("Path to {} already exists".format(dense_freq_attributes_save_path))

# Save the idxs of the attributes that were kept after filtering for frequency
frequent_idx_save_path = os.path.join(save_dir, 'frequent_attribute_idxs_n-{}.pth'.format(len(frequent_attribute_idxs)))
if not os.path.exists(frequent_idx_save_path):
    torch.save(frequent_attribute_idxs, frequent_idx_save_path)
    print("Saved indices of frequent attributes from training to {}".format(frequent_idx_save_path))
else:
    print("Path {} already exists".format(frequent_idx_save_path))

# Load names of attributes (attribute index -> human-readable name)
labels_path = os.path.join('data', 'broden1_224', 'label.csv')
attribute_label_dict = pd.read_csv(labels_path, index_col=0)['name'].to_dict()
print("Loaded human-readable labels")

Obtaining one hot encodings of attributes
Processing attributes for train split


100%|██████████████████████████████████| 13326/13326 [00:00<00:00, 94009.73it/s]


Processing attributes for val split


100%|████████████████████████████████████| 4442/4442 [00:00<00:00, 79593.55it/s]


Processing attributes for test split


100%|████████████████████████████████████| 4442/4442 [00:00<00:00, 62211.70it/s]

Obtaining frequent attributes only





0 examples have no more attributes
0/13326 examples affected
0 examples have no more attributes
0/4442 examples affected
0 examples have no more attributes
2/4442 examples affected
After filtering concepts that appear in 0+ images, we have 623 concepts
Path data/ade20k/filter_attr_0/splits_sparse_one_hot_attributes.pth already exists
Path to data/ade20k/filter_attr_0/splits_dense_one_hot_attributes.pth already exists
Loaded human-readable labels)


### For each attribute, create a linear classifier (hyperparameter search)

In [None]:
# Train one concept activation vector (CAV) -- a logistic-regression classifier on the
# extracted features -- per frequent attribute, checkpointing to disk along the way.
def save_cavs_checkpoint(cavs, save_path):
    """Pickle `cavs` to `save_path` without leaving a half-written file behind.

    The dict is first written to a temporary file next to `save_path` and then moved
    into place, so interrupting the (long) training loop mid-save cannot corrupt the
    checkpoint that `resume` relies on. The file handle is always closed.
    """
    tmp_path = save_path + '.tmp'
    with open(tmp_path, 'wb') as file:
        pickle.dump(cavs, file)
    os.replace(tmp_path, save_path)


cavs = {}
train_attributes = dense_attributes['train']
val_attributes = dense_attributes['val']

# Set resume = True to continue the interrupted run stored under `timestamp` below
resume = False
print(train_attributes.shape)

cavs_save_dir = os.path.join('saved', 'ADE20K', 'cav', 'weighted', datetime.now().strftime(r'%m%d_%H%M%S'))
if not resume:
    ensure_dir(cavs_save_dir)
else:
    timestamp = '0727_111655'
    cavs_save_dir = os.path.join(os.path.dirname(cavs_save_dir), timestamp)

cavs_save_path = os.path.join(cavs_save_dir, 'cavs.pickle')

if resume:
    # NOTE: pickle can execute arbitrary code on load; only resume from checkpoints you wrote.
    with open(cavs_save_path, 'rb') as file:
        cavs = pickle.load(file)
    saved_attr_idxs = list(cavs.keys())
    # default=-1 keeps an empty checkpoint from raising
    last_saved_idx = max(saved_attr_idxs, default=-1)
else:
    last_saved_idx = -1

log_path = os.path.join(cavs_save_dir, 'log.txt')
n_frequent_attributes = len(frequent_attribute_idxs)

# Loop-invariant training settings
scaler = None # Preprocessing.StandardScaler()
logistic_regression_args = {
    'solver': 'liblinear',
    'penalty': 'l2',
    'class_weight': 'balanced'
}

# `total` lets tqdm render a real progress bar (enumerate() has no len())
for idx, attribute_idx in tqdm(enumerate(frequent_attribute_idxs), total=n_frequent_attributes):
    # Skip exactly the attributes that already have a CAV. Comparing against
    # last_saved_idx with `<` would retrain the last checkpointed attribute and
    # would rely on the attribute indices being sorted.
    if attribute_idx in cavs:
        continue
    informal_log("[{}] {}/{} Calculating CAV for {}".format(
        datetime.now().strftime(r'%m%d_%H%M%S'),
        idx+1,
        n_frequent_attributes,
        attribute_label_dict[attribute_idx]), log_path)

    # Column `idx` of the dense attribute matrices corresponds to `attribute_idx`
    cav = hyperparam_search(
        train_features=train_features,
        train_labels=train_attributes[:, idx],
        val_features=val_features,
        val_labels=val_attributes[:, idx],
        scaler=scaler,
        log_path=log_path,
        # pass a copy so every search starts from the same, unmodified settings
        logistic_regression_args=dict(logistic_regression_args))
    cavs[attribute_idx] = cav

    accuracy = cav.score(val_features, val_attributes[:, idx])
    informal_log("CAV accuracy for {} concept ({}): {:.4f}".format(
        attribute_label_dict[attribute_idx],
        attribute_idx,
        accuracy), log_path)
    # Save periodically
    if idx % 10 == 0:
        save_cavs_checkpoint(cavs, cavs_save_path)

# Final save with every CAV
save_cavs_checkpoint(cavs, cavs_save_path)


(13326, 623)


0it [00:00, ?it/s]

[0804_105359] 1/623 Calculating CAV for wall
Best accuracy: 0.8986942818550203 Regularization: 0.001
Best accuracy: 0.9004952723998199 Regularization: 0.01


1it [00:22, 22.68s/it]

CAV accuracy for wall concept (12): 0.9005
[0804_105422] 2/623 Calculating CAV for sky
Best accuracy: 0.9268347591175147 Regularization: 0.001
Best accuracy: 0.9342638451148132 Regularization: 0.005
Best accuracy: 0.9376407023863125 Regularization: 0.01
Best accuracy: 0.9398919405673121 Regularization: 0.05
Best accuracy: 0.9410175596578118 Regularization: 0.1


2it [00:43, 21.39s/it]

CAV accuracy for sky concept (13): 0.9410
[0804_105442] 3/623 Calculating CAV for floor
Best accuracy: 0.9207564160288159 Regularization: 0.001
Best accuracy: 0.9218820351193157 Regularization: 0.005
Best accuracy: 0.9225574065736155 Regularization: 0.01


3it [01:05, 21.80s/it]

CAV accuracy for floor concept (14): 0.9226
[0804_105504] 4/623 Calculating CAV for windowpane
Best accuracy: 0.7973885637100405 Regularization: 0.001
Best accuracy: 0.8057181449797388 Regularization: 0.005
Best accuracy: 0.8059432687978388 Regularization: 0.05


4it [01:30, 22.92s/it]

CAV accuracy for windowpane concept (15): 0.8059
[0804_105529] 5/623 Calculating CAV for tree
Best accuracy: 0.8818099954975236 Regularization: 0.001
Best accuracy: 0.895092300765421 Regularization: 0.005
Best accuracy: 0.8980189104007203 Regularization: 0.01
Best accuracy: 0.9022962629446195 Regularization: 0.05
Best accuracy: 0.9029716343989194 Regularization: 0.1
Best accuracy: 0.9036470058532192 Regularization: 0.5
Best accuracy: 0.9038721296713192 Regularization: 1


5it [01:53, 23.09s/it]

CAV accuracy for tree concept (16): 0.9039
[0804_105553] 6/623 Calculating CAV for wood
Best accuracy: 0.9529491220171095 Regularization: 0.001
Best accuracy: 0.975461503827105 Regularization: 0.005
Best accuracy: 0.979963980189104 Regularization: 0.01
Best accuracy: 0.9909950472760019 Regularization: 0.05
Best accuracy: 0.9925709140027015 Regularization: 0.1
Best accuracy: 0.9959477712742009 Regularization: 0.5
Best accuracy: 0.9961728950923008 Regularization: 1
Best accuracy: 0.9968482665466006 Regularization: 3


6it [02:18, 23.72s/it]

Best accuracy: 0.9970733903647006 Regularization: 5
CAV accuracy for wood concept (17): 0.9971
[0804_105617] 7/623 Calculating CAV for building
Best accuracy: 0.9099504727600181 Regularization: 0.001
Best accuracy: 0.9104007203962179 Regularization: 0.01


7it [02:39, 22.80s/it]

CAV accuracy for building concept (18): 0.9104
[0804_105638] 8/623 Calculating CAV for person
Best accuracy: 0.8354344889689329 Regularization: 0.001
Best accuracy: 0.84984241332733 Regularization: 0.005
Best accuracy: 0.85029266096353 Regularization: 0.01
Best accuracy: 0.8520936515083296 Regularization: 0.05


8it [03:04, 23.64s/it]

CAV accuracy for person concept (19): 0.8521
[0804_105704] 9/623 Calculating CAV for head


### Save concept-presence vectors predicted by the CAVs for each image in train/val/test

In [11]:
# For images in train/val/test splits, predict the presence/absence of each concept from features using CAVs
# Save as one-hot encoded concept-presence vectors
# NOTE: this cell rebinds `cavs`, `cavs_save_dir`, `frequency_threshold` and
# `frequent_attribute_idxs` from the cells above; it still needs `features`
# (and `n_attributes` when use_dense is False) from the loading cell.
# cavs_path = 'saved/ADE20K/cav/weighted/0517_151725/cavs.pickle'
cavs_path = 'saved/ADE20K/cav/weighted/all_cavs.pickle'
frequency_threshold = 1200
n_concepts = 27
frequent_attribute_idxs_path = 'data/ade20k/filter_attr_{}/frequent_attribute_idxs_n-{}.pth'.format(
    frequency_threshold, n_concepts)

# True: one column per frequent attribute. False: one column per index in range(n_attributes).
use_dense = True

cavs_save_dir = os.path.dirname(cavs_path)
# Features are standardised only when the CAV path contains 'scaled'
scale = 'scaled' in cavs_path
# NOTE: pickle can execute arbitrary code on load; only open CAV files you produced.
with open(cavs_path, 'rb') as file:
    cavs = pickle.load(file)
splits = ['train', 'val', 'test']
frequent_attribute_idxs = torch.load(frequent_attribute_idxs_path)
assert len(frequent_attribute_idxs) == len(cavs)

if scale:
    # Fit the scaler on train only and apply it to every split
    scaler = Preprocessing.StandardScaler()
    scaler.fit(features['train'])
    scaled_features = {
        'train': scaler.transform(features['train']),
        'val': scaler.transform(features['val']),
        'test': scaler.transform(features['test'])
    }
concept_vectors_dict = {}
concept_vectors_save_path = os.path.join(cavs_save_dir, '{}_cav_attributes.pth'.format('dense' if use_dense else 'sparse'))
if os.path.exists(concept_vectors_save_path):
    print("Concept presence vectors already exist at {}".format(concept_vectors_save_path))
else:
    for split in splits:
        split_features = scaled_features[split] if scale else features[split]
        n_samples = len(split_features)
        print("Obtaining concept presence vectors for {} split".format(split))

        # Attribute indices that become the columns of this split's matrix
        if use_dense:
            column_attr_idxs = frequent_attribute_idxs
        else:
            column_attr_idxs = tqdm(range(n_attributes))

        concept_presence_vectors = []
        for attr_idx in column_attr_idxs:
            if use_dense or attr_idx in cavs:
                concept_present = cavs[attr_idx].predict(split_features)
                assert len(concept_present) == n_samples
            else:
                # Sparse layout: attributes without a CAV get an all-zero column
                concept_present = np.zeros(n_samples)
            concept_presence_vectors.append(concept_present)

        # Stack the per-concept predictions as columns: one row per sample
        concept_presence_vectors = np.stack(concept_presence_vectors, axis=1)
        print(concept_presence_vectors.shape, n_samples)

        concept_vectors_dict[split] = concept_presence_vectors

    torch.save(concept_vectors_dict, concept_vectors_save_path)
    print("Saved concept present vectors from CAVs to {}".format(concept_vectors_save_path))


Obtaining concept presence vectors for train split
[0. 1. 0. ... 0. 0. 1.]
[1. 0. 1. ... 1. 0. 1.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 1. 0. 0.]
[1. 0. 0. ... 0. 1. 1.]
[1. 0. 1. ... 1. 1. 1.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 1. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 1. ... 0. 0. 1.]
[0. 0. 0. ... 1. 1. 1.]
[0. 0. 0. ... 1. 0. 1.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 1.]
[0. 0. 0. ... 0. 0. 1.]
[0. 1. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ...

### Learn Explainer from CAVs

In [6]:
# Load attributes from CAVs
# The cell that computes them saves to '<dense|sparse>_cav_attributes.pth' (chosen by
# `use_dense`), so load that file; fall back to the un-prefixed legacy name if only
# that one is present in cavs_save_dir.
cav_attributes_path = os.path.join(
    cavs_save_dir, '{}_cav_attributes.pth'.format('dense' if use_dense else 'sparse'))
if not os.path.exists(cav_attributes_path):
    cav_attributes_path = os.path.join(cavs_save_dir, 'cav_attributes.pth')
cav_attributes = torch.load(cav_attributes_path)
print("Loaded CAV attributes from {}".format(cav_attributes_path))

# Load model's predictions
# '{}' in the template is filled with the split name
prediction_path = os.path.join(
    'saved',
    'PlacesCategoryClassification',
    '0510_102912',
    'ADE20K_predictions',
    'saga',
    '{}_outputs_predictions.pth')
splits = ['train', 'val', 'test']
predictions = {}
for split in splits:
    predictions[split] = torch.load(prediction_path.format(split))['predictions']




In [8]:
# Train explainer
# Fits a logistic regression that maps the CAV attribute vectors to the model's predictions.
logistic_regression_args = {
    'solver': 'saga',
    'penalty': 'l1'
}

max_iter = 200
# saga shuffles the data while fitting; a fixed seed makes the fit (and the
# artefacts saved from it) reproducible across kernel restarts
explainer_seed = 0
# explainer = hyperparam_search(
#     train_features=cav_attributes['train'],
#     train_labels=predictions['train'], 
#     val_features=cav_attributes['val'], 
#     val_labels=predictions['val'], 
#     scaler=None,
#     Cs = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 3, 5],
#     log_path=None,
#     logistic_regression_args=logistic_regression_args)
explainer = LogisticRegression(
    C=0.1,
    multi_class='multinomial',
    max_iter=max_iter,
    random_state=explainer_seed,
    # solver / penalty come from the dict above so they are defined in one place
    **logistic_regression_args
)
explainer.fit(cav_attributes['train'], predictions['train'])
accuracy = explainer.score(cav_attributes['val'], predictions['val'])
print("Explainer {} accuracy: {}".format(explainer, accuracy))

Explainer LogisticRegression(C=0.1, max_iter=200, multi_class='multinomial', penalty='l1',
                   solver='saga') accuracy: 0.6960828455650608


In [10]:
# Persist the fitted explainer, its validation outputs and the congruent/incongruent
# image paths. Existing files are left untouched.
# Read the hyperparameters off the fitted explainer so the filenames below always
# describe the model that is actually being saved.
c = explainer.C
solver = explainer.solver
penalty = explainer.penalty
save_dir = os.path.dirname(prediction_path)

# Save CAV explainer
cav_explainer_dir = os.path.join(save_dir, 'weighted_cav_explainer')
ensure_dir(cav_explainer_dir)
cav_explainer_save_path = os.path.join(cav_explainer_dir, '{}_explainer_{}_{}.pickle'.format(solver, penalty, c))
if not os.path.exists(cav_explainer_save_path):
    # `with` closes the handle so the pickle is fully flushed to disk
    with open(cav_explainer_save_path, 'wb') as file:
        pickle.dump(explainer, file)
    print("Saved explainer trained on CAVs to {}".format(cav_explainer_save_path))
else:
    print("Explainer already exists at '{}'".format(cav_explainer_save_path))

# Report validation accuracy of the explainer against the model's predictions
accuracy = explainer.score(cav_attributes['val'], predictions['val'])
print(accuracy)

# Save CAV explainer predictions
val_cav_attributes = cav_attributes['val']
explainer_outputs = explainer.decision_function(val_cav_attributes)
explainer_probabilities = explainer.predict_proba(val_cav_attributes)
explainer_predictions = explainer.predict(val_cav_attributes)

validation_output = {
    'outputs': explainer_outputs,
    'probabilities': explainer_probabilities,
    'predictions': explainer_predictions
}
validation_output_path = os.path.join(cav_explainer_dir, '{}_explainer_{}_{}_validation.pth'.format(solver, penalty, c))
if not os.path.exists(validation_output_path):
    torch.save(validation_output, validation_output_path)
    print("Saved outputs from validation set to {}".format(validation_output_path))
else:
    print("Validation set outputs already saved to {}".format(validation_output_path))

# Save congruent/incongruent paths
congruency_paths = partition_paths_by_congruency(
    explainer_predictions=explainer_predictions,
    model_predictions=predictions['val'],
    paths=paths['val']
)
congruent_paths = congruency_paths['congruent']
incongruent_paths = congruency_paths['incongruent']
congruent_paths_path = os.path.join(cav_explainer_dir, 'congruent_paths.txt')
incongruent_paths_path = os.path.join(cav_explainer_dir, 'incongruent_paths.txt')
# Rewrite both lists if either file is missing so the pair stays in sync
if not os.path.exists(congruent_paths_path) or not os.path.exists(incongruent_paths_path):
    write_lists(congruent_paths, congruent_paths_path)
    write_lists(incongruent_paths, incongruent_paths_path)
    print("Saved {} congruent paths to {} and {} incongruent paths to {}".format(
        len(congruent_paths),
        congruent_paths_path,
        len(incongruent_paths),
        incongruent_paths_path
    ))
else:
    print("Congruent paths already saved to {} and incongruent paths already saved to {}".format(
        congruent_paths_path, incongruent_paths_path
    ))

Saved explainer trained on CAVs to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/weighted_cav_explainer/saga_explainer_l1_0.1.pickle
0.6960828455650608
Saved outputs from validation set to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/weighted_cav_explainer/saga_explainer_l1_0.1_validation.pth


4442it [00:00, 781276.40it/s]

Saved 3092 congruent paths to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/weighted_cav_explainer/congruent_paths.txt and 1350 incongruent paths to saved/PlacesCategoryClassification/0510_102912/ADE20K_predictions/saga/weighted_cav_explainer/incongruent_paths.txt



