In [None]:
!pip install shap


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting shap
  Downloading shap-0.41.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (572 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m572.6/572.6 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting slicer==0.0.7 (from shap)
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.41.0 slicer-0.0.7


In [None]:
from google.colab import drive, files
import pandas as pd
import numpy as np
import pickle
from IPython.display import display
import shap

In [None]:
# Mount Google Drive under /content/drive/ so files stored there can be read.
# force_remount=True presumably re-mounts even when a mount already exists — confirm in the Colab docs.
drive.mount('/content/drive/', force_remount=True)


Mounted at /content/drive/


# Load required files

In [None]:
# Base directories used by the loading cell below: `dataset_dir` holds the CSV
# files and `pickle_dir` holds the pickled model inputs/labels.
# Change both paths to match where you saved the repo.
dataset_dir = '/peak/dataset_dir'
pickle_dir = '/peak/pickle_dir'


In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE(security): ``pickle.load`` can execute arbitrary code, so only use
    this on pickle files you produced yourself or otherwise trust.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Tabular train/test metadata.
df_train = pd.read_csv(dataset_dir + "/df_train.csv")
df_test = pd.read_csv(dataset_dir + "/df_test.csv")

# Pickled model inputs and labels for both splits.
train_input = _load_pickle(pickle_dir + "/train_input.pickle")
test_input = _load_pickle(pickle_dir + "/test_input.pickle")
train_label = _load_pickle(pickle_dir + "/train_label.pickle")
test_label = _load_pickle(pickle_dir + "/test_label.pickle")

# Per-image Shapley values for both splits.
df_train_shapley = pd.read_csv(dataset_dir + "/df_train_shapley.csv")
df_test_shapley = pd.read_csv(dataset_dir + "/df_test_shapley.csv")


In [None]:
# Number of topic columns used to label the model inputs.
topics = 20

# Column labels "topic 0" ... "topic 19": a common prefix plus the stringified topic id.
my_list = list(map(str, np.arange(topics)))
string = 'topic '
input_columns = [string + topic_id for topic_id in my_list]

# Wrap the raw train arrays in labelled DataFrames.
df_train_input = pd.DataFrame(data=train_input, columns=input_columns)
df_train_label = pd.DataFrame(data=train_label, columns=['label'])

# Same wrapping for the test arrays.
df_test_input = pd.DataFrame(data=test_input, columns=input_columns)
df_test_label = pd.DataFrame(data=test_label, columns=['label'])


In [None]:
def prep_df(df_shapley, df_input, n_topics=20):
    """Mask, rectify and row-normalise per-topic Shapley values.

    Rows of ``df_shapley`` and ``df_input`` are matched by position, so the
    index labels of ``df_shapley`` do not matter (the previous implementation
    wrote through ``.at[i, ...]`` and silently appended rows whenever the
    index was not ``0..n-1``).

    Parameters
    ----------
    df_shapley : pandas.DataFrame
        Shapley values; only the first ``n_topics`` columns are used.
    df_input : pandas.DataFrame
        Model inputs with columns named ``"topic 0"`` ... ``"topic n-1"``.
        A topic counts as present for a row when its value is > 0.
    n_topics : int, default 20
        Number of topic columns to keep.

    Returns
    -------
    (df_2, df_3, df_4) : tuple of pandas.DataFrame
        ``df_2``: Shapley values with every topic that is not present in the
        input set to 0. ``df_3``: absolute values of ``df_2``. ``df_4``:
        ``df_3`` divided by its row sum; a row whose sum is 0 yields NaN.
    """
    topic_cols = [f"topic {k}" for k in range(n_topics)]

    # Copy after slicing so the writes below never touch the caller's frame.
    df_2 = df_shapley.iloc[:, :n_topics].copy()

    # (rows x topics) boolean matrix: True where the input activates the topic.
    # Topic columns missing from df_input become NaN and compare as False.
    active = df_input.reindex(columns=topic_cols).to_numpy() > 0
    active = active[:len(df_2)]

    # Zero the Shapley value of every inactive topic, one column at a time.
    # A boolean numpy mask selects rows by position, not by label.
    for j, col in enumerate(topic_cols):
        df_2.loc[~active[:, j], col] = 0

    df_3 = df_2.abs()
    df_4 = df_3.div(df_3.sum(axis=1), axis=0)

    return df_2, df_3, df_4


In [None]:
# For each split, prep_df returns three frames: the masked Shapley values (df2),
# their absolute values (df3) and the row-normalised shares (df4).
train_df2, train_df3, train_df4 = prep_df(df_train_shapley, df_train_input)
test_df2, test_df3, test_df4 = prep_df(df_test_shapley, df_test_input)


In [None]:
# Index labels of the test rows, split on the 'normalizedpublic' column:
# rows equal to 1.0 go to the public list, every other row to the private list.
indexes_public = df_test.index[(df_test['normalizedpublic'] == 1.0).to_numpy()].tolist()
indexes_private = df_test.index[(df_test['normalizedpublic'] != 1.0).to_numpy()].tolist()
len(indexes_public), len(indexes_private)


(2500, 2500)

# Dominant category

In [None]:
def cat_dominant(dominant_private_ub, dominant_public_ub, df_4, private_idx=None, public_idx=None):
    """Select the images whose explanation is dominated by a single topic.

    An image is kept when the largest normalised share in its row is at
    least the threshold of its class.

    Parameters
    ----------
    dominant_private_ub, dominant_public_ub : float
        Minimum value the row maximum must reach for private / public
        images (despite the ``_ub`` suffix the comparison is ``>=``).
    df_4 : pandas.DataFrame
        Row-normalised absolute Shapley values, one row per image.
    private_idx, public_idx : sequence of index labels, optional
        Rows of ``df_4`` belonging to each class. When omitted, the
        notebook-level ``indexes_private`` / ``indexes_public`` are used.

    Returns
    -------
    pandas.DataFrame
        The selected rows of ``df_4``: private rows first, then public rows.
    """
    # Fall back to the notebook globals only when the caller gave nothing.
    if private_idx is None:
        private_idx = indexes_private
    if public_idx is None:
        public_idx = indexes_public

    private_shares = df_4.loc[private_idx]
    public_shares = df_4.loc[public_idx]

    df_dominant_private = private_shares[private_shares.max(axis=1) >= dominant_private_ub]
    df_dominant_public = public_shares[public_shares.max(axis=1) >= dominant_public_ub]

    return pd.concat([df_dominant_private, df_dominant_public])


# Opponent category

In [None]:
def cat_opponent(opponent_private_ub, opponent_public_ub, df_4, df_2, opponent_ub_2, df_dominant, paired_ub,
                 private_idx=None, public_idx=None):
    """Select the images explained by strong topics of opposite sign.

    A topic is "strong" for an image when its normalised share in ``df_4``
    reaches the class threshold. An image is an opponent when it has at
    least two strong topics, at least ``opponent_ub_2`` of them with a
    negative Shapley value and at least ``opponent_ub_2`` with a positive
    one, and it is not already in ``df_dominant``.

    Parameters
    ----------
    opponent_private_ub, opponent_public_ub : float
        Minimum share for a topic to count as strong (private / public).
    df_4 : pandas.DataFrame
        Row-normalised absolute Shapley values.
    df_2 : pandas.DataFrame
        Signed (masked) Shapley values with the same index and columns.
    opponent_ub_2 : int
        Minimum number of strong negative AND of strong positive topics.
    df_dominant : pandas.DataFrame
        Rows already categorised as dominant; they are excluded.
    paired_ub : float
        An opponent image is also reported in the first return value when
        ``abs(row minimum + row maximum)`` of its ``df_2`` row is below this.
    private_idx, public_idx : sequence of index labels, optional
        Rows belonging to each class. When omitted, the notebook-level
        ``indexes_private`` / ``indexes_public`` are used.

    Returns
    -------
    (indexes_filtered_opponent, df_opponent) : (list, pandas.DataFrame)
        The labels passing the ``paired_ub`` test, and the ``df_4`` rows of
        all opponent images in ``df_2`` row order.
    """
    if private_idx is None:
        private_idx = indexes_private
    if public_idx is None:
        public_idx = indexes_public

    # Boolean "strong topic" matrix, thresholded per class.
    strong = pd.concat([
        df_4.loc[private_idx] >= opponent_private_ub,
        df_4.loc[public_idx] >= opponent_public_ub,
    ])
    # Keep only the images with at least two strong topics.
    strong = strong[strong.sum(axis=1) >= 2]

    # Signed values of the strong topics only; images without any are dropped.
    strong = strong.reindex(index=df_2.index, columns=df_2.columns, fill_value=False)
    strong_values = df_2.where(strong).dropna(how='all')

    n_negative = (strong_values < 0).sum(axis=1)
    n_positive = (strong_values > 0).sum(axis=1)
    is_opponent = ((n_negative >= opponent_ub_2) & (n_positive >= opponent_ub_2)).to_numpy()

    # Dominant images keep their category.
    keep = is_opponent & ~strong_values.index.isin(df_dominant.index)

    # Label-based lookup: the previous .iloc treated index labels as positions.
    df_opponent = df_4.loc[strong_values.index[keep]]

    # Pairs whose most negative and most positive values nearly cancel out.
    opponent_shap = df_2.loc[df_opponent.index]
    nearly_balanced = (opponent_shap.min(axis=1) + opponent_shap.max(axis=1)).abs() < paired_ub
    indexes_filtered_opponent = list(df_opponent.index[nearly_balanced.to_numpy()])

    return indexes_filtered_opponent, df_opponent


# Collaborative category

In [None]:
def cat_collab(df_2, collaborative_private_ub, collaborative_public_ub, df_dominant, df_conflict,
               private_idx=None, public_idx=None):
    """Select the images whose topics mostly push in the same direction.

    For each image the absolute sum of its negative Shapley values and the
    sum of its positive ones are each divided by their total. The image is
    collaborative when either ratio reaches the threshold of its class and
    it is not already in ``df_dominant`` or ``df_conflict``.

    Parameters
    ----------
    df_2 : pandas.DataFrame
        Signed (masked) Shapley values, one row per image.
    collaborative_private_ub, collaborative_public_ub : float
        Minimum same-sign ratio for private / public images.
    df_dominant, df_conflict : pandas.DataFrame
        Rows already assigned to another category; they are excluded.
    private_idx, public_idx : sequence of index labels, optional
        Rows belonging to each class. When omitted, the notebook-level
        ``indexes_private`` / ``indexes_public`` are used.

    Returns
    -------
    pandas.DataFrame
        The selected rows of ``df_2``: private rows (sorted by index) first,
        then public rows (sorted by index). Each image appears once, even
        when both ratios reach the threshold (possible for thresholds <= 0.5,
        where the earlier version returned the row twice).
    """
    if private_idx is None:
        private_idx = indexes_private
    if public_idx is None:
        public_idx = indexes_public

    negative_mass = df_2.where(df_2 < 0).sum(axis=1).abs()
    positive_mass = df_2.where(df_2 > 0).sum(axis=1).abs()
    total_mass = negative_mass + positive_mass

    # Rows with no contribution at all give NaN ratios and never qualify.
    res_neg = negative_mass / total_mass
    res_pos = positive_mass / total_mass

    def _select(class_idx, threshold):
        # One boolean test per row, so a row can be picked at most once.
        hit = (res_neg.loc[class_idx] >= threshold) | (res_pos.loc[class_idx] >= threshold)
        return df_2.loc[hit.index[hit.to_numpy()]].sort_index()

    df_collaborative = pd.concat([
        _select(private_idx, collaborative_private_ub),
        _select(public_idx, collaborative_public_ub),
    ])

    # Dominant and opponent/conflict images keep their own category.
    already_assigned = (df_collaborative.index.isin(df_dominant.index)
                        | df_collaborative.index.isin(df_conflict.index))
    return df_collaborative[~already_assigned]


# Observing misclassified images

In [None]:
# Category Dominant
test_df_dominant = cat_dominant(0.7, 0.7, test_df4)
indexes_cat_dominant = np.array(test_df_dominant.index)
print("dominant: ", len(indexes_cat_dominant))

# Category Opponent
indexes_filtered_opponent, test_df_opponent = cat_opponent(0.2, 0.2, test_df4, test_df2, 1, test_df_dominant, 0.1)
indexes_cat_opponent = np.array(test_df_opponent.index)
print("opponent: ", len(test_df_opponent.index), len(indexes_filtered_opponent))

# # Category Collaborative
test_df_collaborative = cat_collab(test_df2, 0.8, 0.8, test_df_dominant, test_df_opponent)
indexes_cat_collaborative = np.array(test_df_collaborative.index)
print("collab: ", len(indexes_cat_collaborative))

# # Category Weak
test_df_weak = test_df2[~test_df2.index.isin(list(test_df_dominant.index)+list(test_df_opponent.index)+list(test_df_collaborative.index))]
indexes_cat_weak = np.array(test_df_weak.index)
print("weak: ", len(indexes_cat_weak))


dominant:  563
conflicting:  848 750
collab:  2776
vague:  813


# Generate Explanations

### Preprocessing

In [None]:
def str_to_comma_list(df):
    """Add a ``cleaned_tags_w_comma`` column of normalised tag names to ``df``.

    For every row, each entry of ``df['extracted_tags']`` is cut at its first
    ':' and only the first 20 names are kept. Every name is then lower-cased,
    stripped, freed of '(' and ')', and has '-' and ' ' turned into '_'.

    The frame is modified in place and nothing is returned.
    Assumes each ``extracted_tags`` cell is an iterable of strings — TODO confirm.
    """
    tag_column = df['extracted_tags']

    def _clean(raw_name):
        name = raw_name.lower().strip()
        for bracket in ('(', ')'):
            name = name.replace(bracket, '')
        # hyphens become spaces first, then every space becomes an underscore
        return name.replace('-', ' ').replace(' ', '_')

    cleaned_rows = []
    for label in df.index:
        names = [entry.split(':', 1)[0] for entry in tag_column[label]][:20]
        cleaned_rows.append([_clean(name) for name in names])

    df['cleaned_tags_w_comma'] = cleaned_rows


### Get the tags that are in the intersection set of a selected image and topic <br>
**Warning:** `df_train_topic_tag` must be available before running the cells below. Save the topic-tag DataFrame after running NMF topic modelling (for example in `pickle_dir`) and load it here.

In [None]:
import ast

def tag_intersect_list(image_id, topic_id, df, df_train_topic_tag):
    """Return the tags shared by one image and one topic, in topic order.

    The image's tags are parsed from the string stored in
    ``df.cleaned_tags_w_comma`` for the single row whose ``image`` equals
    ``image_id`` (``.item()`` raises when there is not exactly one such row).
    The result lists every tag of ``df_train_topic_tag[topic_id]`` that the
    image also carries, once each, ordered by first appearance in the topic.
    """
    image_tags_repr = df[df.image == image_id].cleaned_tags_w_comma.item()
    image_tags = set(ast.literal_eval(image_tags_repr))

    # dict.fromkeys keeps the first occurrence of each topic tag, in order
    topic_tags = dict.fromkeys(list(df_train_topic_tag[topic_id]))
    return [tag for tag in topic_tags if tag in image_tags]


# Create a dictionary in which we will store the explanation category name and topics to display in the explanation of each image.

In [None]:
# Dominant: the single topic with the largest share for each dominant image.
dominant_dict = test_df_dominant.idxmax(axis=1).to_dict()

# Opposing: (topic with the smallest signed value, topic with the largest one).
# The opponent rows come from `test_df_opponent`, the frame created by the
# categorisation cell (the name `test_df_conflict` was never defined).
opponent_shap = test_df2.loc[list(test_df_opponent.index)]
opposing_dict = {
    index: (topic_min, topic_max)
    for index, topic_min, topic_max in zip(
        opponent_shap.index,
        opponent_shap.idxmin(axis=1).values,
        opponent_shap.idxmax(axis=1).values,
    )
}

# Collaborative: the three topics with the largest share. The index labels come
# from the TEST split, so the shares must be read from test_df4, not train_df4.
# A plain comprehension also yields {} (not a dict of columns) when empty.
collaborative_shares = test_df4.loc[list(test_df_collaborative.index)]
collaborative_dict = {
    index: row.nlargest(3).index.tolist()
    for index, row in collaborative_shares.iterrows()
}

# Weak: rows left uncategorised (`test_df_weak`; `test_df_others` was never defined).
top_weak = test_df4.loc[test_df_weak.index].apply(lambda row: row.nlargest(3).index.tolist(), axis=1)

# NOTE(review): `top_negative` holds the three LARGEST signed values and
# `top_positive` the three SMALLEST, the reverse of `opposing_dict`, where the
# minimum fills the "negative" slot. Kept as is; confirm the intended sign convention.
top_negative = test_df_weak.apply(lambda row: row.nlargest(3).index.tolist(), axis=1)
top_positive = test_df_weak.apply(lambda row: row.nsmallest(3).index.tolist(), axis=1)

# Topics that are both among the top-3 shares and among the top-3 signed values.
# Ordered like `top_weak`, so the result does not depend on set iteration order.
intersection_negative = {}
intersection_positive = {}
for idx in top_weak.index:
    intersection_negative[idx] = [topic for topic in top_weak[idx] if topic in top_negative[idx]]
    intersection_positive[idx] = [topic for topic in top_weak[idx] if topic in top_positive[idx]]

weak_dict = {key: [intersection_negative.get(key, []), intersection_positive.get(key, [])] for key in set(intersection_negative) | set(intersection_positive)}

def merge_dict_category_topic(dominant_dict, opposing_dict, collaborative_dict, weak_dict):
    """Merge the four per-category dictionaries into one, sorted by key.

    Every value becomes ``[category_name, original_value]``. The categories
    are applied in the order dominant, opposing, collaborative, weak, so a
    key present in several inputs ends up with the last of those categories.
    """
    labelled_sources = (
        ('dominant', dominant_dict),
        ("opposing", opposing_dict),
        ('collaborative', collaborative_dict),
        ("weak", weak_dict),
    )

    merged = {}
    for category, source in labelled_sources:
        merged.update({key: [category, topics] for key, topics in source.items()})

    return {key: merged[key] for key in sorted(merged)}

# One lookup table for all test images: index -> [category name, topic(s) to display].
cat_top_dict = merge_dict_category_topic(dominant_dict, opposing_dict, collaborative_dict, weak_dict)


# Plot explanations

In [None]:
from wordcloud import WordCloud
import numpy as np
import matplotlib.pyplot as plt
import ast

def _format_text(text, max_width):
    """Greedily wrap ``text`` on whitespace so lines stay within ``max_width`` characters.

    A single word longer than ``max_width`` is kept whole on its own line.
    """
    lines = []
    current_line = []

    for word in text.split():
        if not current_line or len(' '.join(current_line + [word])) <= max_width:
            current_line.append(word)
        else:
            lines.append(' '.join(current_line))
            current_line = [word]

    if current_line:
        lines.append(' '.join(current_line))

    return '\n'.join(lines)


def plot_explanations(idx, cat_top_dict, df, predictions=None):
    """Plot the textual explanation and one tag word cloud per topic for an image.

    Parameters
    ----------
    idx : index label
        Label of the image in ``df``; also the key into ``cat_top_dict`` and
        the position looked up in ``predictions``.
    cat_top_dict : dict
        Maps ``idx`` to ``[category_name, topics]``. ``topics`` is a single
        topic name for "dominant", a ``(negative, positive)`` pair for
        "opposing", a list of names for "collaborative" and
        ``[negative_list, positive_list]`` for any other category.
        Topic names must end with the integer topic id (e.g. ``"topic 7"``).
    df : pandas.DataFrame
        Needs the columns ``normalizedpublic``, ``image`` and
        ``cleaned_tags_w_comma``.
    predictions : array-like, optional
        Predicted labels (0 means private, anything else public). When
        omitted, a random forest is fitted on the notebook-level
        ``train_input`` / ``train_label`` and applied to ``test_input`` on
        every call; pass precomputed predictions to avoid that cost.

    The figure is shown with ``plt.show()``; nothing is returned. Also relies
    on the notebook-level ``df_train_topic_tag`` and ``tag_intersect_list``.
    """
    private_color, public_color, wrong_color = "darkviolet", "darkorange", "black"

    entry = cat_top_dict[idx]
    category_name, category_topics = entry[0], entry[1]
    truth_label = 'public' if df["normalizedpublic"].get(idx) == 1.0 else "private"

    # First row whose index equals idx. The earlier `[...]["image"][0]` looked
    # up the LABEL 0 and therefore failed for every idx other than 0.
    image_id = df.loc[df.index == idx, "image"].iloc[0]

    # Predictions
    if predictions is None:
        # NOTE(review): RandomForestClassifier is not imported anywhere in the
        # visible notebook; `from sklearn.ensemble import RandomForestClassifier`
        # has to be run before this fallback can work.
        random_state, max_depth = 333, 13
        classifier_rf = RandomForestClassifier(random_state=random_state, max_depth=max_depth)
        predictions = classifier_rf.fit(train_input, train_label).predict(test_input)
    predicted_label = 'private' if predictions[idx] == 0 else "public"

    is_correct = truth_label == predicted_label
    correct_private = is_correct and truth_label == "private"
    correct_public = is_correct and truth_label == "public"

    # Build the text, the topics to draw and one contour colour per topic.
    # Feel free to rename topics based on your topic modelling result.
    if category_name == "dominant":
        topic = category_topics
        text = (f"The generated explanation for this image being assigned to the {predicted_label} class "
                f"is that it is related to the topic {topic} with these specific tags.")

        # A list with one name: the bare string would be iterated per character.
        word_cloud_topics = [topic]

        if correct_private:
            contour_colors = [private_color]
        elif correct_public:
            contour_colors = [public_color]
        else:  # misclassification
            contour_colors = [wrong_color]

    elif category_name == "opposing":
        topic_negative, topic_positive = category_topics[0], category_topics[1]
        text = (f"Even though the image is related to the topic {topic_negative} with the specific tags below "
                f"(which signals the {truth_label} class), it is also related to the topic {topic_positive} and "
                f"for that reason, it is classified as {predicted_label}.")

        word_cloud_topics = [topic_positive, topic_negative]

        if correct_private:
            contour_colors = [private_color, public_color]
        elif correct_public:
            contour_colors = [public_color, private_color]
        else:  # misclassification
            contour_colors = [wrong_color, wrong_color]

    elif category_name == "collaborative":
        topics = category_topics
        text = (f"The generated explanation for this image being assigned to the {predicted_label} class "
                f"is that it is related to the topics {', '.join(topics)} with these specific tags.")

        word_cloud_topics = list(topics)

        if correct_private:
            contour_colors = [private_color] * len(word_cloud_topics)
        elif correct_public:
            contour_colors = [public_color] * len(word_cloud_topics)
        else:  # misclassification
            contour_colors = [wrong_color] * len(word_cloud_topics)

    else:
        topic_negative, topic_positive = category_topics[0], category_topics[1]
        text = (f"Even though the image is related to the topic {', '.join(topic_negative)} with the specific tags below "
                f"(which signals the {truth_label} class), it is also related to the topic {', '.join(topic_positive)} and "
                f"for that reason, it is classified as {predicted_label}.")

        # Flatten both lists (positive topics first) so every entry is a name.
        word_cloud_topics = [*topic_positive, *topic_negative]

        # NOTE(review): as before, only correctly classified private images get
        # coloured contours here; every other case is drawn in black.
        if topic_positive and correct_private:
            contour_colors = ([private_color] * len(topic_positive)
                              + [public_color] * len(topic_negative))
        else:
            contour_colors = [wrong_color] * len(word_cloud_topics)

    formatted_text = _format_text(text, max_width=80)

    # Circular mask for the word clouds: 255 outside a radius-130 disc centred
    # in a 300x300 grid, 0 inside it.
    grid_x, grid_y = np.ogrid[:300, :300]
    mask = (grid_x - 150) ** 2 + (grid_y - 150) ** 2 > 130 ** 2
    mask = 255 * mask.astype(int)

    # One column for the text plus one per word cloud. squeeze=False keeps the
    # axes in an array even when there is no word cloud at all.
    num_circles = len(word_cloud_topics)
    num_cols = num_circles + 1
    fig, axs = plt.subplots(1, num_cols, figsize=(5 * num_cols, 10), squeeze=False)
    axs = axs[0]

    axs[0].text(0.5, 0.5, formatted_text, ha='center', va='center', fontsize=12)
    axs[0].axis("off")

    for i, topic_name in enumerate(word_cloud_topics):
        ax = axs[i + 1]
        topic_id = int(topic_name.split()[-1])
        tags = tag_intersect_list(image_id, topic_id, df, df_train_topic_tag)

        if tags:
            wordcloud = WordCloud(width=500, height=300, margin=3, prefer_horizontal=0.7, scale=1,
                                  background_color="white", mask=mask, contour_width=0.1,
                                  contour_color=contour_colors[i], relative_scaling=0)
            wordcloud.generate(" ".join(tags))
            ax.imshow(wordcloud)
        else:
            # WordCloud.generate raises on empty input, so show a placeholder.
            ax.text(0.5, 0.5, "no shared tags", ha='center', va='center', fontsize=12)

        ax.title.set_text(topic_name)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
