In [10]:
import torch
import os
import random
import utils
import data_utils
import json

import cbm
import plots

In [11]:
# change this to the correct model dir, everything else should be taken care of
load_dir = "saved_models/cub_lf_cbm"
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed for the random class/image selection in later cells. None reseeds from OS entropy
# (non-reproducible, same as not seeding); set an int to get the same examples on every run.
RANDOM_SEED = None
random.seed(RANDOM_SEED)

# args.txt stores the saved run arguments as JSON; the dataset name and backbone are read from it.
with open(os.path.join(load_dir, "args.txt"), "r") as f:
    args = json.load(f)
dataset = args["dataset"]
# NOTE: get_target_model may download backbone weights on first use (see the download log in this
# cell's saved output), so it needs network access or an already-populated weights cache.
_, target_preprocess = data_utils.get_target_model(args["backbone"], device)
model = cbm.load_cbm(load_dir, device)

Downloading /home/gridsan/vyuan/.torch/models/resnet18_cub-2333-200d8b9c.pth.zip from https://github.com/osmr/imgclsmob/releases/download/v0.0.344/resnet18_cub-2333-200d8b9c.pth.zip...
download failed, retrying, 4 attempts left
Downloading /home/gridsan/vyuan/.torch/models/resnet18_cub-2333-200d8b9c.pth.zip from https://github.com/osmr/imgclsmob/releases/download/v0.0.344/resnet18_cub-2333-200d8b9c.pth.zip...
download failed, retrying, 3 attempts left
Downloading /home/gridsan/vyuan/.torch/models/resnet18_cub-2333-200d8b9c.pth.zip from https://github.com/osmr/imgclsmob/releases/download/v0.0.344/resnet18_cub-2333-200d8b9c.pth.zip...
download failed, retrying, 2 attempts left
Downloading /home/gridsan/vyuan/.torch/models/resnet18_cub-2333-200d8b9c.pth.zip from https://github.com/osmr/imgclsmob/releases/download/v0.0.344/resnet18_cub-2333-200d8b9c.pth.zip...
download failed, retrying, 1 attempt left
Downloading /home/gridsan/vyuan/.torch/models/resnet18_cub-2333-200d8b9c.pth.zip from htt

ConnectionError: HTTPSConnectionPool(host='github.com', port=443): Max retries exceeded with url: /osmr/imgclsmob/releases/download/v0.0.344/resnet18_cub-2333-200d8b9c.pth.zip (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f3b70271ee0>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

In [None]:
# Label file for this dataset (one class name per line, read in the next cell).
cls_file = data_utils.LABEL_FILES[dataset]
val_d_probe = dataset + "_val"

# The validation split is loaded twice: once with the backbone's preprocessing (model input),
# and once without a preprocess argument (raw samples, shown as images further down).
val_data_t = data_utils.get_data(val_d_probe, preprocess=target_preprocess)
val_pil_data = data_utils.get_data(val_d_probe)

In [None]:
# Read one class name / one concept per line.
# splitlines() is used instead of split("\n") so that a trailing newline does not add a phantom ""
# entry and "\r" from Windows line endings is stripped. Later cells iterate range(len(classes)) /
# range(len(concepts)) and index model.final.weight[i, j] with those indices, so an extra entry
# could index past the end of the weight matrix.
with open(cls_file, "r") as f:
    classes = f.read().splitlines()

with open(os.path.join(load_dir, "concepts.txt"), "r") as f:
    concepts = f.read().splitlines()

## Measure accuracy

In [None]:
# Evaluate the loaded CBM on the preprocessed validation split.
# NOTE(review): the *100 assumes get_accuracy_cbm returns a fraction in [0, 1] — confirm in utils.
accuracy = utils.get_accuracy_cbm(model, val_data_t, device)
print(f"Accuracy: {accuracy * 100:.2f}%")

## Show final layer weights for some classes

You can build a Sankey diagram of these weights by pasting the incoming weights printed below (one `concept [weight] class` line each) into https://sankeymatic.com/build/.

In [None]:
# Print the concepts feeding into one randomly chosen class together with their final-layer weights.
# Each line is "<concept> [<weight>] <class>", the "source [amount] target" form sankeymatic accepts.
INCOMING_WEIGHT_THRESHOLD = 0.05  # only weights with |w| strictly above this are printed

to_show = random.choices(range(len(classes)), k=1)

for i in to_show:
    print("Output class:{} - {}".format(i, classes[i]))
    print("Incoming weights:")
    # One row of the final layer, moved to CPU once instead of indexing the device tensor per concept.
    class_weights = model.final.weight[i].detach().cpu()
    # nonzero() yields indices in ascending order, so concepts print in their original order.
    for j in torch.nonzero(class_weights.abs() > INCOMING_WEIGHT_THRESHOLD).flatten().tolist():
        print("{} [{:.4f}] {}".format(concepts[j], class_weights[j], classes[i]))

In [None]:
# For a few random classes, list the concepts with the largest (most positive) and the smallest
# (most negative) final-layer weights.
NUM_CLASSES_TO_SHOW = 2
TOP_K_WEIGHTS = 5


def format_weights(prefix, weight_row, id_row):
    """Return `prefix` followed by "<concept>:<weight>, " for every (weight, concept index) pair.

    weight_row / id_row are one row of the values / indices returned by torch.topk.
    Concept indices are mapped to names via the notebook-level `concepts` list.
    """
    parts = ["{}:{:.3f}, ".format(concepts[idx], value)
             for value, idx in zip(weight_row.tolist(), id_row.tolist())]
    return prefix + "".join(parts)


to_show = random.choices(range(len(classes)), k=NUM_CLASSES_TO_SHOW)

final_weights = model.final.weight.detach()
# torch.topk raises if k exceeds the size of the dimension, so clamp to the number of concepts.
k_weights = min(TOP_K_WEIGHTS, final_weights.shape[1])
top_weights, top_weight_ids = torch.topk(final_weights, k=k_weights, dim=1)
bottom_weights, bottom_weight_ids = torch.topk(final_weights, k=k_weights, dim=1, largest=False)

for i in to_show:
    print("Class {} - {}".format(i, classes[i]))
    print(format_weights("Highest weights: ", top_weights[i], top_weight_ids[i]))
    print(format_weights("Lowest weights: ", bottom_weights[i], bottom_weight_ids[i]) + "\n")

In [None]:
# A concept whose outgoing final-layer weights have total absolute value <= 1e-5 is effectively
# unused by every class; such concepts are candidates for removal to make the model leaner.
weight_contribs = model.final.weight.abs().sum(dim=0)
print("Num concepts with outgoing weights:{}/{}".format(
    weight_contribs.gt(1e-5).sum(), len(weight_contribs)))

## Explain model reasoning for random inputs

In [None]:
# For a few random validation images: display the image, print the ground truth and the top-2
# predicted classes with their logits, then plot per-concept contributions
# (concept activation * final-layer weight) to the explained prediction(s).
NUM_IMAGES_TO_EXPLAIN = 4
NUM_PREDS_TO_EXPLAIN = 1          # explain the top-N predictions (1 = top prediction only)
CONTRIB_COUNT_THRESHOLD = 0.005   # |contribution| above this counts toward the number of bars shown
MAX_BARS = 8                      # upper bound on bars per plot

# random.sample raises if k exceeds the population size, so clamp for small validation sets.
to_display = random.sample(range(len(val_pil_data)),
                           k=min(NUM_IMAGES_TO_EXPLAIN, len(val_pil_data)))

with torch.no_grad():
    for i in to_display:
        # Same index into both copies of the split: raw image for display, preprocessed tensor for the model.
        image, label = val_pil_data[i]
        x, _ = val_data_t[i]
        x = x.unsqueeze(0).to(device)
        display(image.resize([320, 320]))

        outputs, concept_act = model(x)

        # At least 2 for the printout below; more if explaining more than two predictions.
        top_logit_vals, top_classes = torch.topk(outputs[0], dim=0, k=max(2, NUM_PREDS_TO_EXPLAIN))
        conf = torch.nn.functional.softmax(outputs[0], dim=0)
        print("Image:{} Gt:{}, 1st Pred:{}, {:.3f}, 2nd Pred:{}, {:.3f}".format(i, classes[int(label)], classes[top_classes[0]], top_logit_vals[0],
                                                                      classes[top_classes[1]], top_logit_vals[1]))

        # Loop-invariant per image: prefix "NOT " to concepts whose activation is negative.
        # The sign mask is computed in one tensor op rather than one comparison per concept.
        is_negative = (concept_act[0] < 0).tolist()
        feature_names = [("NOT " if neg else "") + name for name, neg in zip(concepts, is_negative)]

        for k in range(NUM_PREDS_TO_EXPLAIN):
            pred = top_classes[k]
            contributions = concept_act[0] * model.final.weight[pred, :]
            values = contributions.cpu().numpy()
            # One more bar than the number of non-negligible contributions, capped at MAX_BARS.
            max_display = min(int((abs(values) > CONTRIB_COUNT_THRESHOLD).sum()) + 1, MAX_BARS)
            title = "Pred:{} - Conf: {:.3f} - Logit:{:.2f} - Bias:{:.2f}".format(classes[pred],
                             conf[pred], top_logit_vals[k], model.final.bias[pred])
            plots.bar(values, feature_names, max_display=max_display, title=title, fontsize=16)