Research question 1: Do concept-based explanations produce more faithful explanations than feature attribution methods?

In [2]:
pip install -r requirements.txt




[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


Step 0:Setup

In [53]:
import numpy as np
import pickle
import hashlib
import pandas as pd

np.random.seed(42)

base_path = "/Users/karlgustav/Documents/GitHub/study/master-thesis/server/src/research/"
# base_path = "/Users/karl-gustav.kallasmaa/Documents/Projects/master-thesis/server/src/"
all_labels_path = f"{base_path}all_classes.txt"
masks_path = f"{base_path}data/masks.pkl"
img_path = f"{base_path}data/resized_imgs.pkl"
labels_path = f"{base_path}data/classes.pkl"
ade_path = f"{base_path}data/objectInfo150.csv"

labels = []
images = []
masks = []
unique_labels = []
with open(masks_path, 'rb') as f:
    masks = pickle.load(f)
with open(img_path, 'rb') as f:
    images = pickle.load(f)
with open(labels_path, 'rb') as f:
    labels = np.array(pickle.load(f))
with open(all_labels_path) as f:
    lines = f.read().splitlines()
    lines = [l.replace(' ', '_') for l in lines]
    unique_labels = np.array(list(set(lines)))


ade_classes = pd.read_csv(ade_path)

image_hex_index_map = {hashlib.sha1(np.array(img).view(np.uint8)).hexdigest(): i for i,img in enumerate(images)}

index_img_map = {i:img for i,img in enumerate(images)}
index_label_map = {i:label for i,label in enumerate(labels)}
index_mask_map = {i:mask for i,mask in enumerate(masks)}
index_ade_map = {i:ade for i,ade in enumerate(ade_classes)}

random_indexes = np.random.choice(list(index_img_map.keys()), int(0.1*len(index_img_map.keys())), replace=False)

random_images = [index_img_map[index] for index in random_indexes]
random_labels = np.array([index_img_map[index] for index in random_indexes])
random_masks = [index_mask_map[index] for index in random_indexes]

In [34]:
"""
def unique_labels(image_array):
    predictions = []
    preprocessed_images = np.array([preprocess_image(i) for i in image_array])
    batch_prediction = blackbox_model.predict(preprocessed_images)
    prediction_label = decode_predictions(batch_prediction, top=1)
    predictions = [p[0][1] for p in prediction_label]
    return np.array(list(set(predictions)))
uniq = unique_labels(images)
"""





In [54]:
from typing import List
from sklearn import preprocessing

def encode_categorical_values(values: List[str]):
    unique_values = sorted(list(set(values)))
    le = preprocessing.LabelEncoder()
    le.fit(unique_values)
    return le

to_be_encoded = np.concatenate((labels, unique_labels), axis=0)
label_encoder = encode_categorical_values(to_be_encoded)

Step 1: Get lime predictions

In [55]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import decode_predictions

blackbox_model = ResNet50(weights='imagenet', include_top=True, input_shape=(224, 224, 3))

def preprocess_image(image):
    image = tf.image.resize(image, (224, 224))
    image = tf.keras.applications.resnet50.preprocess_input(np.array(image))
    return image

def black_box_classify(img_array):
    predictions = []
    for i in img_array:
        img = preprocess_image(i)
        prediction = blackbox_model.predict(img[np.newaxis, ...])
        prediction_label = decode_predictions(prediction, top = 1)
        prediction_label = prediction_label[0][0][1]
        label_as_nr = label_encoder.transform([prediction_label])
        predictions.append(label_as_nr)
    return predictions

In [59]:
from lime import lime_image

def explain_with_lime(images, num_samples=50, num_features=10, hide_color=None):
    explainer = lime_image.LimeImageExplainer()
    explanations = []
    for image in images:
        explanation = explainer.explain_instance(np.array(image),
                                                 classifier_fn=black_box_classify,
                                                 top_labels=1,
                                                 hide_color=hide_color,
                                                 num_samples=num_samples,
                                                 num_features=num_features)
        #print(explanation)
        most_probable_label = explanation.top_labels[0]
        explanations.append(most_probable_label)
    return explanations

lime_predictions = explain_with_lime(random_images[0:2])

  0%|                                                                                                                                                                               | 0/50 [00:00<?, ?it/s]



 38%|███████████████████████████████████████████████████████████████                                                                                                       | 19/50 [00:00<00:01, 25.08it/s]



 50%|███████████████████████████████████████████████████████████████████████████████████                                                                                   | 25/50 [00:01<00:01, 14.75it/s]



 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                    | 39/50 [00:02<00:00, 17.59it/s]



 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                    | 44/50 [00:03<00:00, 12.83it/s]



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 13.02it/s]
  0%|                                                                                                                                                                               | 0/50 [00:00<?, ?it/s]



 20%|█████████████████████████████████▏                                                                                                                                    | 10/50 [00:00<00:02, 13.87it/s]



 40%|██████████████████████████████████████████████████████████████████▍                                                                                                   | 20/50 [00:01<00:02, 13.86it/s]



 60%|███████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                  | 30/50 [00:02<00:01, 13.93it/s]



 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                 | 40/50 [00:02<00:00, 13.88it/s]



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 13.91it/s]

[0, 0]





Step 2: Get concept-based desision tree explanations

In [None]:
from operator import itemgetter
from typing import Dict, List

def get_segments(img, mask, threshold=0.05):
    segs = np.unique(mask)
    segments = []
    total = mask.shape[0] * mask.shape[1]
    segments_classes = []
    for seg in segs:
        idxs = mask == seg
        sz = np.sum(idxs)
        if sz < threshold * total:
            continue
        segment = img * idxs[..., None]
        w, h, _ = np.nonzero(segment)
        segment = segment[np.min(w):np.max(w), np.min(h):np.max(h), :]
        segments.append(segment)
        segments_classes.append(ade_classes['Name'].loc[ade_classes['Idx'] == seg].iloc[0])
    return segments, segments_classes




def sort_dictionary(source: Dict[any, any], by_value=True, reverse=True) -> List[any]:
    if by_value:
        return sorted(source.items(), key=itemgetter(1), reverse=reverse)
    return sorted(source.items(), key=itemgetter(0), reverse=reverse)

In [None]:
from typing import Dict, List
from mpire import WorkerPool
from functools import reduce

class MostPopularConcepts:
    BATCH_SIZE = 10
    MAX_WORKER_COUNT = 8

    def __init__(self,l_labels,i_images,m_maks):
        all_labels = np.array(labels)
        chunk_size = max(1, int(all_labels.size / self.BATCH_SIZE))
        self.labels_in_chunks = np.array_split(all_labels, chunk_size)
        self.nr_of_jobs = min(self.MAX_WORKER_COUNT, len(self.labels_in_chunks))

        self.label_images_map = {}
        self.label_masks_map = {}

        self.image_most_popular_concepts = self.static_most_popular_concepts(l_labels,i_images,m_maks)

    def static_most_popular_concepts(self,l_labels,i_images,m_maks) -> Dict[str, List[any]]:
        for label, image, mask in zip(l_labels,i_images,m_maks):
            current_images = self.label_images_map.get(label, [])
            current_maks = self.label_masks_map.get(label, [])

            current_images.append(image)
            current_maks.append(mask)

            self.label_images_map[label] = current_images
            self.label_masks_map[label] = current_maks

        with WorkerPool(n_jobs=self.nr_of_jobs) as pool:
            return reduce(lambda a, b: {**a, **b},
                          pool.map(self.__extract_most_popular_concepts, self.labels_in_chunks))

    def __extract_most_popular_concepts(self, l_labels: List[str]) -> Dict[str, List[any]]:
        partial_results = {}
        for label in  l_labels:
            i_images = self.label_images_map[label]
            m_masks = self.label_masks_map[label]
            nr_of_images = len(i_images)
            partial_results[label] = self.most_popular_concepts(images,m_masks, nr_of_images)
        return partial_results

    @staticmethod
    def most_popular_concepts(i_images, m_masks, k) -> List[str]:
        segment_count = {}
        for pic, mask in zip(i_images, m_masks):
            _, seg_class = get_segments(np.array(pic), mask, threshold=0.005)
            for s in seg_class:
                segment_count[s] = segment_count.get(s, 0) + 1
        segment_count = sort_dictionary(segment_count)
        return [s for s, _ in segment_count[:k]]


MOST_POPULAR_CONCEPTS = MostPopularConcepts(random_labels,random_images,random_masks).image_most_popular_concepts

In [61]:
from typing import List,Tuple
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def get_segment_relative_size(segment: np.array, picture: np.array) -> float:
    segment_area = float(segment.shape[0] * segment.shape[1])
    picture_area = float(picture.shape[0] * picture.shape[1])
    return round(segment_area / picture_area, 2)



def get_training_row(user_selected_concepts: List[str], pic, mask) -> np.array:
    row = np.zeros(len(user_selected_concepts))
    pic_as_array = np.array(pic)
    segss, seg_class = get_segments(pic_as_array, mask, threshold=0.005)
    for index, el in enumerate(user_selected_concepts):
        if el in seg_class:
            segment = segss[seg_class.index(el)]
            row[index] = get_segment_relative_size(segment, pic_as_array)
    return row

def train_and_test_decision_tree(x, y) -> Tuple[DecisionTreeClassifier, float]:
    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    return clf, clf.score(X_test, y_test)

In [None]:
def train_concept_explainer(all_labels,all_images,all_masks):
    X, y = [], []
    for label, pic, mask in zip(all_labels,all_images, all_masks):
        most_popular_concepts_for_label = MOST_POPULAR_CONCEPTS[label]
        row = get_training_row(most_popular_concepts_for_label, pic, mask)
        label_as_nr = label_encoder.transform([label])
        X.append(row)
        y.append(label_as_nr)
    clf, accuracy = train_and_test_decision_tree(np.array(X), np.array(y))
    return clf

In [None]:
def explain_with_concepts(images,model):
    predictions = []
    for img in images:
        img_key = hashlib.sha1(np.array(img).view(np.uint8)).hexdigest()
        image_index = image_hex_index_map[img_key]
        image_label = index_label_map[image_index]

        most_popular_concepts_for_label = MOST_POPULAR_CONCEPTS[image_label]
        mask = index_mask_map[image_index]
        
        row = get_training_row(most_popular_concepts_for_label, img, mask)
        prediction_as_nr = model.predict([row])
        prediction_as_label = label_encoder.inverse_transform(prediction_as_nr)
        predictions.append(prediction_as_label)
    return predictions

concept_model = train_concept_explainer(labels,images,masks)
concept_predictions = explain_with_concepts(random_images,concept_model)

Step 3: calculate fidelity

In [None]:
def black_box_models_predictions(images):
    predictions = []
    for img in images:
        img_key = hashlib.sha1(np.array(img).view(np.uint8)).hexdigest()
        image_index = image_hex_index_map[img_key]
        image_label = index_label_map[image_index]

        predictions.append(image_label)
    return predictions

def fidelity(pred1,pred2):
    same = 0
    not_same = 0
    for p1 in pred1:
        for p2 in pred2:
            if p1 == p2:
                same += 1
            else:
                not_same += 1
    return same / not_same


black_box_pred = black_box_models_predictions(random_images)

lime_fidelity = fidelity(pred1=lime_predictions,pred2=black_box_pred)
concept_fidelity = fidelity(pred1=concept_predictions,pred2=black_box_pred)


print("LIME fidelity "+lime_fidelity)
print("Concept fidelity"+concept_fidelity)
if lime_fidelity > concept_fidelity:
    diff = lime_fidelity - concept_fidelity
    print("LIME fidelity is greater than concept fidelity by "+diff)
else:
    diff = concept_fidelity -lime_fidelity
    print("Concept fidelity is creater than LIME fidelity by "+diff)