In [None]:
import json
import torch

from matplotlib import pyplot as plt
from torcheval.metrics.functional import binary_auprc

import utils
import metrics
import data_utils

## Test parameters

In [None]:
# Which experiment to run, following Appendix F.1 (the next cell maps it to a configuration).
setting = 3

# Run-wide test parameters.
device = "cuda"                        # device passed to the model loader and used for metric tensors
batch_size = 256                       # overridden to 16 by the LLM settings in the next cell
activations_dir = 'saved_activations'  # root directory for the saved layer activations
epsilon = 0.001                        # a normalized score must drop by more than this to count as a decrease
n_samples = 1                          # how many different samples for c_t^+/c_t^- we take

In [None]:
# Map `setting` to one experiment configuration: dataset, target model (and layer for the
# image datasets), the activation function applied to the recorded outputs, whether
# superclass concepts/neurons are used, and how alpha is chosen ("best" = tuned on the
# validation neurons, otherwise a fixed value). The CUB and LLM settings load activations
# through dedicated helpers and define no target_layer.
if setting == 1:
    dataset_name = "imagenet_val"
    target_name = "vit_b_16_imagenet"
    target_layer = "heads"
    activation_fn = "softmax"
    superclass_concepts = True
    superclass_neurons = True
    final_layer = True
    test_alpha = "best"

elif setting == 2:
    dataset_name = "imagenet_val"
    target_name = "resnet50_imagenet"
    target_layer = "layer4"
    activation_fn = None
    superclass_concepts = True
    superclass_neurons = False
    final_layer = False
    test_alpha = 0.005

elif setting == 3:
    dataset_name = "places365_val"
    target_name = "resnet18_places365"
    target_layer = "fc"
    activation_fn = "softmax"
    superclass_concepts = False
    superclass_neurons = False
    final_layer = True
    test_alpha = "best"

elif setting == 4:
    dataset_name = "places365_val"
    target_name = "resnet18_places365"
    target_layer = "layer4"
    activation_fn = None
    superclass_concepts = False
    superclass_neurons = False
    final_layer = False
    test_alpha = 0.005

elif setting == 5:
    dataset_name = "cub_test"
    target_name = "cub_cbm"
    activation_fn = "sigmoid"
    superclass_concepts = False
    superclass_neurons = False
    final_layer = True
    test_alpha = "best"

elif setting == 6:
    dataset_name = "cub_test"
    target_name = "cub_linear_probe"
    activation_fn = "sigmoid"
    superclass_concepts = False
    superclass_neurons = False
    final_layer = True
    test_alpha = "best"

elif setting == 7:
    dataset_name = "openwebtext_subset"
    target_name = "gpt2-small"
    activation_fn = None #running softmax inside the model since only including subset of output toks
    superclass_concepts = False
    superclass_neurons = False
    final_layer = True
    test_alpha = "best"
    batch_size = 16 #overriding since llm uses more memory

elif setting == 8:
    dataset_name = "openwebtext_subset"
    target_name = "gpt2-xl"
    activation_fn = None #running softmax inside the model since only including subset of output toks
    superclass_concepts = False
    superclass_neurons = False
    final_layer = True
    test_alpha = "best"
    batch_size = 16 #overriding since llm uses more memory

else:
    # Without this, an unknown setting leaves the variables above undefined (or stale from a
    # previously run setting) and the failure surfaces several cells later.
    raise ValueError("Unknown setting {!r}; expected an integer from 1 to 8".format(setting))

In [None]:
# Load the target model and what is presumably its input transform, then the evaluation
# dataset with that transform applied (NOTE(review): semantics live in data_utils — confirm).
model, preprocess = data_utils.get_target_model(target_name, device=device)
dataset = data_utils.get_data(dataset_name, preprocess)
if dataset_name != "openwebtext_subset":
    # A second copy loaded without `preprocess`. It is only used by the visual sanity-check
    # cell (plt.imshow), and the text dataset skips it. Presumably raw PIL images — verify.
    pil_data = data_utils.get_data(dataset_name)

In [None]:
# Fail fast on a misspelled activation_fn, before the (potentially slow) activation pass below;
# previously an unknown value was silently ignored and raw outputs were used.
if activation_fn not in (None, "softmax", "sigmoid"):
    raise ValueError("Unsupported activation_fn {!r}; expected None, 'softmax' or 'sigmoid'".format(activation_fn))

# Collect the concept labels (concept_activations) and the target neurons' outputs
# (neuron_activations). Image datasets also return `text`, the concept names printed by the
# sanity-check cell.
if dataset_name == "cub_test":
    concept_activations, text = utils.get_cub_concept_labels(dataset, device, batch_size)
    neuron_activations = utils.get_cub_concept_preds(model, dataset, device, batch_size)
elif dataset_name == "openwebtext_subset":
    concept_activations, neuron_activations = utils.get_llm_ct_ak(model, dataset, device, batch_size)
else:
    concept_activations, text = utils.get_onehot_labels(dataset_name, device, superclass_concepts)
    layer_save_path = '{}/{}_{}/{}/'.format(activations_dir, target_name, dataset_name, target_layer)
    neuron_activations = utils.save_summary_activations(model, dataset, device, target_layer, batch_size, layer_save_path)


# Optionally squash raw outputs; softmax normalizes across neurons (dim=1). None leaves them as-is.
if activation_fn == "softmax":
    neuron_activations = torch.nn.functional.softmax(neuron_activations, dim=1)
elif activation_fn == "sigmoid":
    # torch.sigmoid is the canonical API (torch.nn.functional.sigmoid is the legacy alias)
    neuron_activations = torch.sigmoid(neuron_activations)
torch.cuda.empty_cache()

print(concept_activations.shape, neuron_activations.shape)

In [None]:
#check prediction accuracy to make sure loading works, only works for some settings
#print(torch.mean((neuron_activations > 0.5).float() == concept_activations, dtype=float))

In [None]:
if dataset_name != "openwebtext_subset":
    # Sanity check: show one input and print its five highest-valued concept labels, to confirm
    # that inputs and concept labels were loaded in matching order.
    img_id = 2500
    # assumes pil_data[i] is an (image, label) pair — verify against data_utils.get_data
    plt.imshow(pil_data[img_id][0])
    top_ids = torch.argsort(concept_activations[img_id], descending=True)[:5].tolist()
    print("Top concepts:")
    # `concept_id` rather than `id`: a loop variable in a notebook cell becomes a global and
    # would shadow the builtin `id` for every later cell.
    for concept_id in top_ids:
        print(text[concept_id], concept_activations[img_id, concept_id].item())

In [None]:
# Creating superclass neurons, only for the ImageNet final layer. Each superclass gets one
# extra neuron whose activation is the sum of its member class neurons' activations.
# NOTE: this reassigns neuron_activations with the new columns appended, so re-running this
# cell without first re-running the activation-loading cell appends them a second time.
if superclass_neurons:
    assert final_layer, "superclass neurons are only supported for final-layer runs"
    assert dataset_name == "imagenet_val", "superclass neurons are only supported for imagenet_val"
    with open('data/imagenet_superclass_to_ids.json', 'r') as f:
        superclass_to_id = json.load(f)

    # Indexing with a superclass's list of member ids gathers those columns in one op; summing
    # over dim=1 yields one column per superclass, in the JSON's key order.
    new_activations = torch.stack(
        [neuron_activations[:, subclasses].sum(dim=1) for subclasses in superclass_to_id.values()],
        dim=1,
    )
    print(neuron_activations.shape, new_activations.shape)
    neuron_activations = torch.cat([neuron_activations, new_activations], dim=1)

if final_layer:
    # Final layer: neuron i is taken to be the detector for concept i. The later
    # one_hot(correct, num_classes=<number of concepts>) needs every neuron index to be a valid concept.
    assert neuron_activations.shape[1] <= concept_activations.shape[1], (
        "final-layer runs need at least as many concepts as neurons"
    )
    correct = torch.arange(neuron_activations.shape[1])
else:
    # Hidden layer: the reference explanation is the concept that maximizes IoU at the fixed test_alpha.
    assert test_alpha != "best", "non-final-layer runs need a numeric test_alpha to pick reference concepts"
    similarities = metrics.iou(neuron_activations, concept_activations, alpha=test_alpha)
    correct = torch.argmax(similarities, dim=1)

In [None]:
# Fix torch's random state so that the neuron split and the label-corruption masks drawn below
# are reproducible. CUDA generators are seeded explicitly as well.
seed = 0
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)

### Split neurons into 5% validation and 95% test set

In [None]:
# Randomly split neurons into ~5% validation (used to tune alpha/lam) and the rest for testing.
# Each split's reference concepts are one-hot encoded over all concepts.
num_classes = concept_activations.shape[1]
neurons = torch.randperm(neuron_activations.shape[1])
# At least one validation neuron, so hyperparameter tuning never runs on an empty set
# (int(0.05 * n) is 0 for fewer than 20 neurons).
n_val = max(1, int(0.05 * len(neurons)))

val_neurons = neurons[:n_val].sort().values
val_correct = torch.nn.functional.one_hot(correct[val_neurons], num_classes=num_classes).to(device)

test_neurons = neurons[n_val:].sort().values
test_correct = torch.nn.functional.one_hot(correct[test_neurons], num_classes=num_classes).to(device)

In [None]:
# Build n_samples corrupted copies of the concept labels, moved to CPU:
#  - concepts_missing: every label entry is zeroed with probability 0.5 (for 0/1 labels this
#    drops about half of the positives).
#  - concepts_extra: each entry of concept j is switched on with probability
#    cutoff_j = pos_j / (N - pos_j), where pos_j is the label sum of concept j over the N inputs.
#    For 0/1 labels this adds about as many false positives as there were true positives;
#    when pos_j >= N - pos_j every negative flips. clamp keeps values <= 1.
# NOTE(review): the "positives" reading assumes binary concept labels — confirm for soft labels.
concepts_missing = []
concepts_extra = []
# The flip probabilities depend only on the clean labels, so compute them once.
n_pos = torch.sum(concept_activations, dim=0, keepdim=True)
cutoff = n_pos / (concept_activations.shape[0] - n_pos)
for _ in range(n_samples):
    mask = torch.rand(concept_activations.shape, device=device) > 0.5
    concepts_missing.append((concept_activations*mask).cpu())

    extra = torch.rand(concept_activations.shape, device=device) < cutoff
    concepts_extra.append(torch.clamp(concept_activations + extra, max=1).cpu())

## Testing different methods

In [None]:
def print_results(missing_diff, extra_diff):
    """Print summary statistics for the missing-labels and extra-labels tests.

    Args:
        missing_diff: per-neuron change in the correct-concept score after labels were removed
            (as produced by run_test / run_test_fast, already normalized by the score range).
        extra_diff: the same, after extra labels were added.

    For each test it prints the mean difference and the percentage of neurons whose score
    fell by more than the notebook-level ``epsilon``.
    """
    def percent_decreased(diff):
        # share of neurons (in %) whose normalized score dropped by more than epsilon
        return (diff < -epsilon).float().mean() * 100

    missing_pct = percent_decreased(missing_diff)
    print("Missing Labels Test: Avg Score Diff:{:.4f}, Decrease Acc: {:.2f}%".format(missing_diff.mean(), missing_pct))
    extra_pct = percent_decreased(extra_diff)
    print("Extra Labels Test: Avg Score Diff:{:.4f}, Decrease Acc: {:.2f}%".format(extra_diff.mean(), extra_pct))

def run_test(explanation_fn, min_val=None, max_val=None):
    """Evaluate an explanation function on the test neurons and run both label-corruption tests.

    Prints the test AUPRC for recovering each test neuron's reference concept, then the average
    score of the reference concept. Finally, ``print_results`` reports how that score changes
    when labels are removed (``concepts_missing``) or added (``concepts_extra``).

    Args:
        explanation_fn: callable ``(neuron_acts, concept_acts) -> similarities``. The result is
            multiplied elementwise with ``test_correct``, so it is expected to be indexed
            [test neuron, concept].
        min_val, max_val: bounds of the score, used to normalize the score differences.
            Each bound that is ``None`` is replaced by the observed min/max over the clean and
            all corrupted similarity matrices.

    Reads notebook globals: neuron_activations, test_neurons, concept_activations,
    test_correct, concepts_missing, concepts_extra, device, binary_auprc, print_results.
    """
    test_acts = neuron_activations[:, test_neurons]

    def correct_concept_score(sims):
        # similarity of each test neuron to its own reference concept
        return torch.sum(sims * test_correct, dim=1)

    similarities = explanation_fn(test_acts, concept_activations)
    auc = binary_auprc(similarities.flatten(), test_correct.flatten())
    print("Test AUPRC: {:.7f}".format(auc))
    correct_sims = correct_concept_score(similarities)

    missing_c_sims = [explanation_fn(test_acts, c_missing.to(device)) for c_missing in concepts_missing]
    extra_c_sims = [explanation_fn(test_acts, c_extra.to(device)) for c_extra in concepts_extra]

    if min_val is None or max_val is None:
        # build the (potentially large) concatenation once, even when both bounds are missing
        all_sims = torch.cat([similarities] + missing_c_sims + extra_c_sims, dim=0)
        if min_val is None:
            min_val = torch.min(all_sims)
        if max_val is None:
            max_val = torch.max(all_sims)

    print("Original avg:{:.4f}".format(torch.mean(correct_sims)))
    # average the reference-concept score across the random corruption samples
    corr_miss_c_sims = torch.mean(torch.stack([correct_concept_score(s) for s in missing_c_sims], dim=0), dim=0)
    corr_extra_c_sims = torch.mean(torch.stack([correct_concept_score(s) for s in extra_c_sims], dim=0), dim=0)

    score_range = max_val - min_val
    missing_diff = (corr_miss_c_sims - correct_sims) / score_range
    extra_diff = (corr_extra_c_sims - correct_sims) / score_range
    print_results(missing_diff, extra_diff)
    
def fast_sims(explanation_fn, concept_acts):
    """Score each test neuron against its reference concept only.

    Instead of a full neuron x concept matrix, ``explanation_fn`` is called once per test neuron
    on a single activation column and a single concept column. The [0, 0] entry of the result
    is kept.

    Args:
        explanation_fn: callable ``(neuron_col, concept_col) -> 2-D scores``.
        concept_acts: concept label matrix whose columns are concepts.

    Returns:
        1-D tensor with one score per entry of the notebook-level ``test_neurons``.
    """
    concept_ids = torch.argmax(test_correct, dim=1).tolist()
    scores = [
        explanation_fn(neuron_activations[:, n:n + 1], concept_acts[:, c:c + 1])[0, 0]
        for n, c in zip(test_neurons.tolist(), concept_ids)
    ]
    return torch.stack(scores, dim=0)

def run_test_fast(explanation_fn, min_val, max_val):
    """Run the missing/extra label tests using only reference-concept scores (see fast_sims).

    Args:
        explanation_fn: callable accepted by ``fast_sims``.
        min_val, max_val: score bounds; the score differences are divided by ``max_val - min_val``.

    Unlike run_test, no AUPRC is computed and both bounds are required.
    """
    def mean_over_samples(label_sets):
        # average each test neuron's reference-concept score across the corruption samples
        per_sample = [fast_sims(explanation_fn, labels.to(device)) for labels in label_sets]
        return torch.mean(torch.stack(per_sample, dim=0), dim=0)

    baseline = fast_sims(explanation_fn, concept_activations)
    scale = max_val - min_val
    missing_diff = (mean_over_samples(concepts_missing) - baseline) / scale
    extra_diff = (mean_over_samples(concepts_extra) - baseline) / scale
    print_results(missing_diff, extra_diff)

def find_best_alpha(explanation_fn, min_val=None, max_val=None, use_fast=False, test_alpha=None):
    """Choose ``alpha`` for a metric whose only hyperparameter is ``alpha``, then run the test.

    Args:
        explanation_fn: callable ``(neuron_acts, concept_acts, alpha=...) -> similarities``.
        min_val, max_val: score bounds forwarded to run_test / run_test_fast.
        use_fast: use run_test_fast (reference-concept scores only) instead of run_test.
        test_alpha: ``"best"`` to pick the alpha with the highest validation AUPRC from a fixed
            grid, or a number to use directly. ``None`` (default) means the notebook-level
            ``test_alpha``, looked up at call time.
    """
    if test_alpha is None:
        # Resolve at call time: the previous default `test_alpha=test_alpha` froze whatever value
        # existed when this cell was last run, so switching `setting` without re-running this
        # cell silently used a stale alpha.
        test_alpha = globals()["test_alpha"]

    if test_alpha == "best":
        val_acts = neuron_activations[:, val_neurons]  # loop-invariant
        best_auc = -1
        best_alpha = 0
        for alpha in [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]:
            similarities = explanation_fn(val_acts, concept_activations, alpha=alpha)
            auc = binary_auprc(similarities.flatten(), val_correct.flatten())
            if auc > best_auc:
                best_auc = auc
                best_alpha = alpha
        print("Best Alpha: {}".format(best_alpha))
    else:
        best_alpha = test_alpha
        print("Using Alpha = {}".format(best_alpha))

    runner = run_test_fast if use_fast else run_test
    runner(explanation_fn=lambda x, y: explanation_fn(x, y, alpha=best_alpha),
           min_val=min_val, max_val=max_val)

def find_best_alpha_lam(explanation_fn, min_val=None, max_val=None, test_alpha=None):
    """Choose ``alpha`` and ``lam`` by validation AUPRC, then run the full test (run_test).

    Args:
        explanation_fn: callable ``(neuron_acts, concept_acts, alpha=..., lam=...) -> similarities``.
        min_val, max_val: score bounds forwarded to run_test.
        test_alpha: ``"best"`` to search alpha over a fixed grid jointly with lam, or a number to
            fix alpha and search only lam. ``None`` (default) means the notebook-level
            ``test_alpha``, looked up at call time.
    """
    if test_alpha is None:
        # Resolve at call time rather than freezing the value present when this cell last ran.
        test_alpha = globals()["test_alpha"]

    tune_alpha = test_alpha == "best"
    alphas = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5] if tune_alpha else [test_alpha]
    # i / 100 gives the same grid as 0.01 * i but without float noise (0.07, not 0.07000000000000001)
    lams = [i / 100 for i in range(101)]
    val_acts = neuron_activations[:, val_neurons]  # loop-invariant

    best_auc = -1
    best_alpha = 0 if tune_alpha else test_alpha
    best_lam = 0
    for alpha in alphas:
        for lam in lams:
            similarities = explanation_fn(val_acts, concept_activations, alpha=alpha, lam=lam)
            auc = binary_auprc(similarities.flatten(), val_correct.flatten())
            if auc > best_auc:
                best_auc = auc
                best_alpha = alpha
                best_lam = lam
    print("Best alpha={}, Best lam={}".format(best_alpha, best_lam))

    run_test(explanation_fn=lambda x, y: explanation_fn(x, y, alpha=best_alpha, lam=best_lam),
             min_val=min_val, max_val=max_val)


### Recall

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.recall, min_val=0, max_val=1)

### Precision

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.precision, min_val=0, max_val=1)

### F1-score

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.f1_score, min_val=0, max_val=1)

### IoU

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.iou, min_val=0, max_val=1)

### Accuracy

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.accuracy, min_val=0, max_val=1)

### Balanced Accuracy

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.balanced_accuracy, min_val=0, max_val=1)

### Inverse Balanced Accuracy

In [None]:
# alpha picked on the validation neurons when test_alpha == "best"; diffs normalized by the range [0, 1]
find_best_alpha(explanation_fn=metrics.inverse_balanced_accuracy, min_val=0, max_val=1)

### AUC

In [None]:
# Uses the per-neuron fast path (run_test_fast): no AUPRC, only the label-corruption tests.
find_best_alpha(explanation_fn=metrics.auc, min_val=0, max_val=1, use_fast=True)

### Inverse AUC (Classification)

In [None]:
# No alpha to tune: goes straight to the per-neuron fast path with scores bounded by [0, 1].
run_test_fast(explanation_fn=metrics.inverse_auc, min_val=0, max_val=1)

### Correlation

In [None]:
# No hyperparameters; diffs normalized by the range [-1, 1] (width 2).
run_test(explanation_fn=metrics.correlation, min_val=-1, max_val=1)

### Correlation top-and-random

In [None]:
# No hyperparameters; diffs normalized by the range [-1, 1] (width 2).
run_test(explanation_fn=metrics.correlation_top_and_random, min_val=-1, max_val=1)

### Spearman Correlation

In [None]:
# No hyperparameters; diffs normalized by the range [-1, 1] (width 2).
run_test(explanation_fn=metrics.spearman_correlation, min_val=-1, max_val=1)

### Spearman Correlation top-and-random

In [None]:
# No hyperparameters; diffs normalized by the range [-1, 1] (width 2).
run_test(explanation_fn=metrics.spearman_correlation_top_and_random, min_val=-1, max_val=1)

### Cosine

In [None]:
# No hyperparameters; diffs normalized by the range [-1, 1] (width 2).
run_test(explanation_fn=metrics.cos_sim, min_val=-1, max_val=1)

### WPMI

In [None]:
# Tunes lam (and alpha when test_alpha == "best"); no fixed bounds, so run_test normalizes by the observed score range.
find_best_alpha_lam(explanation_fn=metrics.wpmi)

### MAD (Mean Activation Difference)

In [None]:
# No fixed bounds: run_test normalizes the diffs by the observed min/max of the scores.
run_test(explanation_fn=metrics.mad)

### AUPRC

In [None]:
# Noticeably slower than the other metrics; evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.auprc, min_val=0, max_val=1, use_fast=True)

### Inverse AUPRC

In [None]:
# No alpha to tune: goes straight to the per-neuron fast path with scores bounded by [0, 1].
run_test_fast(explanation_fn=metrics.inverse_auprc, min_val=0, max_val=1)

## Appendix: Combination Metrics

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.combined_auc, min_val=0, max_val=1, use_fast=True)

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.combined_balanced_acc, min_val=0, max_val=1, use_fast=True)

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.recall_auc, min_val=0, max_val=1, use_fast=True)

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.recall_inv_auc, min_val=0, max_val=1, use_fast=True)

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.precision_bal_acc, min_val=0, max_val=1, use_fast=True)

In [None]:
# Combination metric; alpha tuned as above, evaluated through the per-neuron fast path.
find_best_alpha(explanation_fn=metrics.precision_inverse_bal_acc, min_val=0, max_val=1, use_fast=True)