Research question 3: How meaningful are the extracted concepts?


In [1]:
pip install -r -q requirements.txt 


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing.image import img_to_array
from skimage.transform import resize

# Black-box classifier under explanation: ImageNet-pretrained ResNet50 including its
# classification head, built for 224x224 inputs with 3 channels.
blackbox_model = ResNet50(include_top=True, weights='imagenet', input_shape=(224, 224, 3))

def preprocess_images(img_array):
    """Resize every image to 224x224 and apply ResNet50's ``preprocess_input``.

    Returns a single stacked numpy batch ready for ``blackbox_model.predict``.
    """
    resized = []
    for picture in img_array:
        resized.append(tf.image.resize(img_to_array(picture), (224, 224)))
    batch = np.array(resized)
    return preprocess_input(batch)

def black_box_classify(img_array, convert_to_nr=True):
    """Classify images with the ResNet50 black box and return the top-1 class per image.

    Parameters
    ----------
    img_array : sequence of images
        Passed through ``preprocess_images`` before prediction.
    convert_to_nr : bool
        If True, encode class names with the global ``label_encoder``; otherwise return
        the class-name strings as produced by ``decode_predictions``.

    Returns
    -------
    list
        One single-element list per image: ``[[label], [label], ...]``.
    """
    batch = preprocess_images(img_array)
    top1 = decode_predictions(blackbox_model.predict(batch), top=1)
    # Each entry is a list of (class_name, class_description, score) tuples; keep the description.
    names = [entry[0][1] for entry in top1]
    # label_encoder is only touched when encoding is requested.
    result = label_encoder.transform(names) if convert_to_nr else names
    return [[value] for value in result]

def black_box_lime(temp):
    """Classify one image with the black box and return its encoded top-1 label.

    Parameters
    ----------
    temp : array-like
        A single image, resized here to 224x224 (trailing channel dim kept).
        NOTE(review): LIME classifier functions usually receive a *batch* of images;
        confirm how the caller invokes this.

    Returns
    -------
    list[list]
        ``[[encoded_label]]``, using the global ``label_encoder``.
    """
    resized_temp = resize(temp, (224, 224), mode='reflect', preserve_range=True).astype(np.uint8)
    resized_temp = np.expand_dims(resized_temp, axis=0)
    # ResNet50 expects the same preprocess_input transform used by black_box_classify;
    # raw 0-255 pixels would give predictions inconsistent with the rest of the notebook.
    model_input = preprocess_input(resized_temp.astype(np.float32))
    predictions = blackbox_model.predict(model_input)
    prediction_labels = decode_predictions(predictions, top=1)
    labels_as_str = [row[0][1] for row in prediction_labels]
    label_as_nr = label_encoder.transform(labels_as_str)
    return [[l] for l in label_as_nr]

In [None]:
import hashlib
import os
import pickle
import random

import numpy as np
import pandas as pd

# Seed both RNGs: numpy drives the test-image sample below, while later cells use the
# stdlib ``random`` module (random.choice / random.shuffle) when building the quiz items.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Data root; override with the RESEARCH_BASE_PATH environment variable on other machines
# (e.g. "/Users/karl-gustav.kallasmaa/Documents/Projects/master-thesis/server/src/").
base_path = os.environ.get(
    "RESEARCH_BASE_PATH",
    "/Users/karlgustav/Documents/GitHub/study/master-thesis/server/src/research/",
)
all_labels_path = os.path.join(base_path, "all_classes.txt")
all_concepts_path = os.path.join(base_path, "all_concepts.txt")
masks_path = os.path.join(base_path, "data", "masks.pkl")
img_path = os.path.join(base_path, "data", "resized_imgs.pkl")
labels_path = os.path.join(base_path, "data", "classes.pkl")
ade_path = os.path.join(base_path, "data", "objectInfo150.csv")

ade_classes = pd.read_csv(ade_path)

# NOTE: pickle.load can execute arbitrary code - only load these files from a trusted source.
with open(masks_path, 'rb') as f:
    masks = pickle.load(f)
with open(img_path, 'rb') as f:
    images = pickle.load(f)
with open(all_labels_path) as f:
    lines = f.read().splitlines()
    # Presumably to match the underscore-separated class names from decode_predictions.
    lines = [l.replace(' ', '_') for l in lines]
    unique_labels = np.array(list(set(lines)))

# Black-box (ResNet50) top-1 class name for every image.
labels = [row[0] for row in black_box_classify(images, False)]

all_concept_values = ade_classes['Name'].tolist()
UNIQUE_CONCEPT_VALUES = sorted(set(all_concept_values))
NR_OF_UNIQUE_CONCEPTS = len(UNIQUE_CONCEPT_VALUES)

# SHA-1 of each image's raw bytes -> its position in ``images``.
image_hex_index_map = {hashlib.sha1(np.array(img).view(np.uint8)).hexdigest(): i for i, img in enumerate(images)}

index_img_map = dict(enumerate(images))
index_label_map = dict(enumerate(labels))
index_mask_map = dict(enumerate(masks))
# Row position -> that ADE class row as a dict (iterating a DataFrame directly yields column names).
index_ade_map = dict(enumerate(ade_classes.to_dict("records")))

# Never request more test images than exist: np.random.choice(..., replace=False) would raise.
test_image_count = min(10, len(images))

random_indexes = np.random.choice(list(index_img_map.keys()), test_image_count, replace=False)

random_images = [index_img_map[index] for index in random_indexes]
random_labels = np.array([index_label_map[index] for index in random_indexes])
random_masks = [index_mask_map[index] for index in random_indexes]

print("Total number of images "+str(len(images)))
print("Number of images used "+str(len(random_images)))

 8/50 [===>..........................] - ETA: 23s

Propose concepts

In [None]:
from operator import itemgetter
from typing import Dict, List

def get_segments(img, mask, threshold=0.05, classes=None):
    """Crop the bounding box of every sufficiently large labelled region in ``mask``.

    Parameters
    ----------
    img : np.ndarray
        Image indexed as ``img[row, col, channel]``; its first two dims must match ``mask``.
    mask : np.ndarray
        2-D array with one class id per pixel.
    threshold : float
        Minimum fraction of the mask's pixels a region must cover to be kept.
    classes : pd.DataFrame, optional
        Lookup table with ``Idx`` (class id) and ``Name`` columns.
        Defaults to the global ``ade_classes``.

    Returns
    -------
    (list[np.ndarray], list[str])
        Bounding-box crops and their class names, in ascending class-id order. Ids that
        do not appear in ``classes['Idx']`` (e.g. an "unlabelled" id) are skipped instead
        of raising ``IndexError``.
    """
    if classes is None:
        classes = ade_classes
    total = mask.shape[0] * mask.shape[1]
    segments, segments_classes = [], []

    for seg in np.unique(mask):
        idxs = mask == seg
        if np.sum(idxs) < threshold * total:
            continue

        names = classes['Name'].loc[classes['Idx'] == seg]
        if names.empty:
            # No class name for this id; nothing meaningful to report.
            continue

        coords = np.argwhere(idxs)
        row_min, col_min = coords.min(axis=0)
        row_max, col_max = coords.max(axis=0)

        # The crop is the region's bounding box, so it may include pixels of other classes.
        segments.append(img[row_min:row_max + 1, col_min:col_max + 1, :])
        segments_classes.append(names.iloc[0])

    return segments, segments_classes

def sort_dictionary(source: dict, by_value=True, reverse=True) -> List[tuple]:
    """Return the ``(key, value)`` pairs of ``source`` as a sorted list.

    Parameters
    ----------
    source : dict
        Mapping to sort.
    by_value : bool
        Sort by value when True (default), otherwise by key.
    reverse : bool
        Descending order when True (default). Ties keep their insertion order.
    """
    # Position 1 of each (key, value) pair is the value, position 0 the key.
    return sorted(source.items(), key=itemgetter(1 if by_value else 0), reverse=reverse)

In [None]:
from typing import Dict, List
from mpire import WorkerPool
from functools import reduce

class MostPopularConcepts:
    """Per-label ranking of the concepts that appear most often in that label's images.

    The ranking is computed in a ``WorkerPool`` over chunks of distinct labels and stored
    in ``image_most_popular_concepts``: label -> up to ``n`` concept names, most frequent
    first, where ``n`` is the number of images carrying that label.
    """
    BATCH_SIZE = 10        # roughly how many distinct labels go into one worker chunk
    MAX_WORKER_COUNT = 8   # upper bound on the number of worker processes

    def __init__(self, l_labels, i_images, m_maks):
        """Group samples by label and compute the popular-concept ranking.

        Parameters
        ----------
        l_labels, i_images, m_maks : parallel sequences
            Label, image and segmentation mask of each sample.
        """
        # Order-preserving de-duplication so each label is ranked exactly once, instead of
        # once per sample carrying it.
        distinct_labels = np.array(list(dict.fromkeys(l_labels)))
        chunk_count = max(1, int(distinct_labels.size / self.BATCH_SIZE))
        self.labels_in_chunks = np.array_split(distinct_labels, chunk_count)
        self.nr_of_jobs = min(self.MAX_WORKER_COUNT, len(self.labels_in_chunks))

        self.label_images_map = {}
        self.label_masks_map = {}

        self.image_most_popular_concepts = self.static_most_popular_concepts(l_labels, i_images, m_maks)

    def static_most_popular_concepts(self, l_labels, i_images, m_maks) -> Dict[str, List[str]]:
        """Bucket samples by label, then rank concepts for every label in a worker pool.

        Returns
        -------
        dict
            label -> list of concept names (see ``most_popular_concepts``).
        """
        for label, image, mask in zip(l_labels, i_images, m_maks):
            self.label_images_map.setdefault(label, []).append(image)
            self.label_masks_map.setdefault(label, []).append(mask)

        with WorkerPool(n_jobs=self.nr_of_jobs) as pool:
            partial_results = pool.map(self.__extract_most_popular_concepts, self.labels_in_chunks)
        return reduce(lambda merged, part: {**merged, **part}, partial_results, {})

    def __extract_most_popular_concepts(self, l_labels: List[str]) -> Dict[str, List[str]]:
        """Rank concepts for each label in one chunk (executed inside a pool worker)."""
        partial_results = {}
        for label in l_labels:
            i_images = self.label_images_map[label]
            m_masks = self.label_masks_map[label]
            # Each mask must be paired with its own image from the same label bucket.
            partial_results[label] = self.most_popular_concepts(i_images, m_masks, len(i_images))
        return partial_results

    @staticmethod
    def most_popular_concepts(i_images, m_masks, k) -> List[str]:
        """Return up to ``k`` concept names, most frequent first, across the given images.

        One occurrence is counted per segment returned by ``get_segments`` with
        ``threshold=0.005`` (regions covering at least 0.5% of the mask).
        """
        segment_count = {}
        for pic, mask in zip(i_images, m_masks):
            _, seg_class = get_segments(np.array(pic), mask, threshold=0.005)
            for s in seg_class:
                segment_count[s] = segment_count.get(s, 0) + 1
        ranked = sort_dictionary(segment_count)
        return [s for s, _ in ranked[:k]]

In [None]:
mpc_service = MostPopularConcepts(labels,images,masks)
MOST_POPULAR_CONCEPTS = mpc_service.image_most_popular_concepts

Calculate random (distractor) concepts

In [None]:
import random

def random_concepts(correct_concepts, label, image, mask, random_concept_count=4):
    """Pick distractor concepts that are not among ``correct_concepts``.

    Preference order:
      1. concepts found in this image (ranked by ``mpc_service.most_popular_concepts``);
      2. popular concepts of the other labels in ``MOST_POPULAR_CONCEPTS``, visited in a
         random order (via the stdlib ``random`` module).

    Returns
    -------
    list[str]
        At most ``random_concept_count`` concepts. Fewer are returned only when every
        other label's popular concepts have been exhausted, rather than looping forever.
    """
    excluded = set(correct_concepts)
    candidates = [c for c in mpc_service.most_popular_concepts([image], [mask], 1000)
                  if c not in excluded]
    if len(candidates) >= random_concept_count:
        return candidates[:random_concept_count]

    other_labels = [l for l in MOST_POPULAR_CONCEPTS if l != label]
    random.shuffle(other_labels)
    for other_label in other_labels:
        if len(candidates) >= random_concept_count:
            break
        seen = set(candidates)
        candidates += [c for c in MOST_POPULAR_CONCEPTS[other_label]
                       if c not in excluded and c not in seen]
    return candidates[:random_concept_count]

Calculate non-random (framework-proposed) concepts

In [None]:
from typing import List
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def get_segment_relative_size(segment: np.ndarray, picture: np.ndarray) -> float:
    """Area of ``segment`` divided by area of ``picture`` (first two dims), rounded to 2 decimals."""
    seg_rows, seg_cols = segment.shape[:2]
    pic_rows, pic_cols = picture.shape[:2]
    ratio = float(seg_rows * seg_cols) / float(pic_rows * pic_cols)
    return round(ratio, 2)


def get_training_row(top_concepts_for_label: List[str], pic, mask) -> np.ndarray:
    """Build one feature vector of relative bounding-box sizes, one column per concept.

    Column ``j`` corresponds to ``UNIQUE_CONCEPT_VALUES[j]``. A column is non-zero only
    when that concept is in ``top_concepts_for_label`` and also detected in ``mask``
    (segments covering at least 0.5% of the mask); every other column stays 0.
    """
    picture = np.array(pic)
    crops, crop_names = get_segments(picture, mask, threshold=0.005)
    # Keep the first crop seen for each concept name.
    first_crop = {}
    for name, crop in zip(crop_names, crops):
        first_crop.setdefault(name, crop)

    features = np.zeros(NR_OF_UNIQUE_CONCEPTS)
    for column, concept in enumerate(UNIQUE_CONCEPT_VALUES):
        if concept in top_concepts_for_label and concept in first_crop:
            features[column] = get_segment_relative_size(first_crop[concept], picture)
    return features

def train_decision_tree(x, y, test_size=0.1, random_state=42) -> DecisionTreeClassifier:
    """Fit a decision tree on a random ``1 - test_size`` share of ``(x, y)``.

    Parameters
    ----------
    x, y : array-like
        Feature rows and targets.
    test_size : float
        Fraction held out by ``train_test_split``. The held-out part is not evaluated here.
    random_state : int
        Seeds both the split and the tree. The tree permutes features randomly at each
        split, so without a seed ``feature_importances_`` can change between runs.

    Returns
    -------
    DecisionTreeClassifier
        The fitted classifier.
    """
    X_train, _X_test, y_train, _y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state)
    clf = DecisionTreeClassifier(random_state=random_state)
    clf.fit(X_train, y_train)
    return clf

def train_concept_explainer(all_labels, all_images, all_masks):
    """Train a decision tree that maps concept-size features to encoded labels.

    Each sample's feature row keeps only the concepts in ``MOST_POPULAR_CONCEPTS`` for that
    sample's label (see ``get_training_row``); targets are encoded with ``label_encoder``.
    """
    def encoded_samples():
        # Yields (feature_row, encoded_label) per sample, in input order.
        for label, pic, mask in zip(all_labels, all_images, all_masks):
            feature_row = get_training_row(MOST_POPULAR_CONCEPTS[label], pic, mask)
            yield feature_row, label_encoder.transform([label])[0]

    pairs = list(encoded_samples())
    features = [feature_row for feature_row, _ in pairs]
    targets = np.array([target for _, target in pairs])
    return train_decision_tree(features, targets)

In [None]:
def encode_categorical_values(values: List[str]):
    """Return a ``LabelEncoder`` fitted on the distinct entries of ``values``."""
    distinct_sorted = sorted(set(values))
    encoder = preprocessing.LabelEncoder()
    encoder.fit(distinct_sorted)
    return encoder
def encode_categorical_features():
    """Fit a ``LabelEncoder`` on the concept names listed one per line in ``all_concepts_path``."""
    with open(all_concepts_path) as concepts_file:
        concept_names = concepts_file.read().splitlines()
    return encode_categorical_values(concept_names)

In [None]:
# Encoder for concept names (read from all_concepts_path) ...
feature_encoder = encode_categorical_features()
# ... and for the class labels listed in all_classes.txt.
label_encoder = encode_categorical_values(unique_labels)

In [None]:
def most_predictive_concepts(target_label):
    """Rank all concepts by how well they separate ``target_label`` from the other labels.

    A one-vs-rest decision tree is trained on the global ``labels``/``images``/``masks``.
    Samples labelled ``target_label`` keep their label; all others are relabelled to a
    single stand-in (the first other label in ``labels``), so the tree solves a binary task.

    Returns
    -------
    list[str]
        Every name in ``UNIQUE_CONCEPT_VALUES``, ordered by decreasing feature importance
        (ties keep the list's sorted order).

    Raises
    ------
    ValueError
        If every label equals ``target_label``, so there is no negative class.
    """
    other_label = next((l for l in labels if l != target_label), None)
    if other_label is None:
        raise ValueError(f"No label other than {target_label!r} to contrast with")

    one_vs_rest_labels = [target_label if l == target_label else other_label for l in labels]
    estimator = train_concept_explainer(one_vs_rest_labels, images, masks)

    # Feature columns are built in UNIQUE_CONCEPT_VALUES order (see get_training_row),
    # so the importances must be named from that same list.
    ranked = sorted(zip(UNIQUE_CONCEPT_VALUES, estimator.feature_importances_),
                    key=lambda pair: pair[1], reverse=True)
    return [name for name, _ in ranked]
    

In [None]:
import itertools
from typing import List, Set

# Caps on how many concepts each source may contribute to one combined suggestion list.
max_predictive_concept_count = 2
max_intuitive_concept_count = 2

def combine_concepts(most_predictive_concepts: List[str], most_intuitive_concepts: List[str],
                     concept_suggestion_limit=4) -> List[str]:
    """Interleave predictive and intuitive concepts into one de-duplicated suggestion list.

    Concepts are taken pairwise from both rankings, predictive first. A concept is added only
    if it is not ``None`` (padding from ``zip_longest``), has not been proposed yet, and its
    source is still below its cap (``max_predictive_concept_count`` for predictive concepts,
    ``max_intuitive_concept_count`` for intuitive ones).

    Returns
    -------
    list[str]
        At most ``concept_suggestion_limit`` concepts, in the order they were added.
    """
    proposed_concepts = []
    added_predictive_concepts, added_intuitive_concepts = set(), set()

    for predictive_concept, intuitive_concept in itertools.zip_longest(
            most_predictive_concepts, most_intuitive_concepts):
        if len(proposed_concepts) >= concept_suggestion_limit:
            break

        if __should_append_predictive_concept(predictive_concept, proposed_concepts, added_predictive_concepts):
            proposed_concepts.append(predictive_concept)
            added_predictive_concepts.add(predictive_concept)
        if __should_append_intuitive_concept(intuitive_concept, proposed_concepts, added_intuitive_concepts):
            proposed_concepts.append(intuitive_concept)
            added_intuitive_concepts.add(intuitive_concept)

    return proposed_concepts[:concept_suggestion_limit]


def __should_append_intuitive_concept(intuitive_concept: str, proposed_concepts: List[str],
                                      added_intuitive_concepts: Set[str]):
    """True if ``intuitive_concept`` is new and the intuitive cap has not been reached."""
    return (intuitive_concept is not None
            and intuitive_concept not in proposed_concepts
            and len(added_intuitive_concepts) < max_intuitive_concept_count)


def __should_append_predictive_concept(predictive_concept: str, proposed_concepts: List[str],
                                       added_predictive_concepts: Set[str]):
    """True if ``predictive_concept`` is new and the predictive cap has not been reached."""
    return (predictive_concept is not None
            and predictive_concept not in proposed_concepts
            and len(added_predictive_concepts) < max_predictive_concept_count)


In [None]:
def experiment_solultions(l_labels, i_images, m_masks, correct_concept_count=4):
    """Build one quiz item per sample: framework-proposed concepts mixed with distractors.

    Parameters
    ----------
    l_labels, i_images, m_masks : parallel sequences
        Label, image and segmentation mask of each sample.
    correct_concept_count : int
        Maximum number of framework-proposed ("correct") concepts per item.

    Returns
    -------
    list[dict]
        One dict per sample with keys ``nr`` (1-based), ``label``, ``correct``, ``random`` and
        ``combined`` (correct + random, shuffled in place with ``random.shuffle``).
    """
    solutions = []
    for nr, (label, image, mask) in enumerate(zip(l_labels, i_images, m_masks), start=1):
        most_predictive = most_predictive_concepts(label)
        most_popular = MOST_POPULAR_CONCEPTS[label]
        framework_proposed_concepts = combine_concepts(
            most_predictive, most_popular, concept_suggestion_limit=correct_concept_count)

        chosen_random_concepts = random_concepts(framework_proposed_concepts, label, image, mask)

        combined = framework_proposed_concepts + chosen_random_concepts
        random.shuffle(combined)

        solutions.append({
            "nr": nr,
            "label": label,
            "correct": framework_proposed_concepts,
            "random": chosen_random_concepts,
            "combined": combined,
        })
    return solutions

In [None]:
# One quiz item per randomly sampled test image.
solutions = experiment_solultions(l_labels=random_labels, i_images=random_images, m_masks=random_masks)

In [None]:
import uuid
import os
from PIL import Image
import matplotlib.pyplot as plt


def temp_img(index):
    """Save ``random_images[index]`` as a UUID-named JPEG in the working directory; return its path."""
    temp_image_path = f"{uuid.uuid4()}.jpg"
    random_images[index].save(temp_image_path)
    return temp_image_path

def vizualise_img(path):
    """Display the image at ``path`` without axes, then delete the file.

    The image file handle is closed before deletion, and the file is removed even if
    displaying fails.
    """
    try:
        with Image.open(path) as open_image:
            # imshow converts the PIL image to an array immediately, so closing is safe.
            plt.imshow(open_image)
        plt.axis('off')
        plt.show()
    finally:
        os.remove(path)
    
def vizualise_explanation(index):
    """Show test image ``index``, then print its label, correct concepts and shuffled combined concepts."""
    vizualise_img(temp_img(index))

    solution = solutions[index]
    print("Label: "+solution["label"])

    print("Correct concepts")
    print(solution["correct"])

    print("Combined concepts")
    print(solution["combined"])
    

In [None]:
# Render every generated explanation (one per sampled test image).
for i in range(len(solutions)):
    print("Image "+str(i))
    vizualise_explanation(i)