<b>To show the panel, type the following in a separate Anaconda window:</b>

G:

cd "My Drive\Studie\Thesis\Viz"

panel serve Panel_interactive_viz-retry.ipynb --show

<b>Then, in a new browser tab, go to:</b>

http://localhost:5006/Panel_interactive_viz_retry

<b>Local visualisations</b>

In [1]:
import numpy as np
import json
import matplotlib.pyplot as plt
import os
from PIL import Image
from configparser import ConfigParser, ExtendedInterpolation

from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvas

import panel as pn


# Folder that holds every input of the dashboard. --> Can be turned into something more interactive
input_dir = "testing_blur2"


# Global variables
config = ConfigParser(interpolation=ExtendedInterpolation())

classes_ini_file = os.path.join(input_dir, 'classes.ini')
images_ini_file = os.path.join(input_dir, 'images.ini')

# The images.ini branch is currently switched off (its condition is hard-wired to
# False, the real check is kept in the trailing comment); only classes.ini is read.
if False:#os.path.exists(images_ini_file):
    input_file = images_ini_file
    config.read(input_file)
    class_files = config['images']
elif os.path.exists(classes_ini_file):
    input_file = classes_ini_file
    config.read(input_file)
    class_files = config['classification']
else:
    # Stop here with a clear message: only printing would leave class_files
    # undefined and crash with an unrelated NameError on the next statement.
    raise FileNotFoundError(f'Error: no classes.ini or images.ini file found in {input_dir}')


# Maps each image file name (last path component) to the name of its parent
# folder (second-to-last component); that folder name is used as key into class_dict.
img_paths = {i.split('/')[-1]: i.split('/')[-2] for i in class_files.get('img_paths').split('\n') if i}


start_val = list(img_paths.keys())[0]

w_dict_path = os.path.join(input_dir,'weights.txt')
class_dict_path = os.path.join(input_dir,'imagenet_classcode_to_index.json')
img_dir = os.path.join(input_dir, 'images')
concept_dir = os.path.join(input_dir, 'concepts')
concept_dict_path = os.path.join(input_dir, 'concept_files.txt')

# All three files contain JSON, whatever their extension.
with open(w_dict_path, 'r') as f:
    w_dict = json.load(f)                # class code -> {concept -> weight}
with open(class_dict_path, 'r') as f:
    class_dict = json.load(f)            # class code -> entry whose element [1] is the class name
with open(concept_dict_path, 'r') as f:
    concept_dict = json.load(f)          # concept -> path of its prototype image

# class code -> class name, and the reverse lookup class name -> class code.
class_dict_clean = {k: v[1] for k, v in class_dict.items()}
inverted_dict = {name: code for code, name in class_dict_clean.items()}

# Classes that have weights, as codes and as readable names (same order).
class_list = list(w_dict)
class_names = [class_dict_clean[x] for x in class_list]

# Widget for selecting the input image
inp = pn.widgets.Select(name='input file', options=list(img_paths.keys()))


'''Selecting image's variables'''
class Input_Data():
    '''Collects the paths and the loaded score data that belong to one selected input image.'''
    def __init__(self, inp):
        '''Derive all paths from the image file name inp and load its class scores and similarity dict.'''
        # The part of the file name before its first '.' names the image's files and segment folder.
        self.img_folder = inp.partition('.')[0]

        self.og_img_name = inp
        self.og_img_class = img_paths[inp]
        self.og_img_path = os.path.join(img_dir, inp)
        self.segment_dir = os.path.join(img_dir, self.img_folder)
        self.class_scores_path = os.path.join(input_dir, f"class_scores/{self.img_folder}_class_scores.txt")
        self.sim_dict_path = os.path.join(input_dir, f"sim_dicts/{self.img_folder}_sim_dict.txt")

        # Both files hold JSON; load class scores first, then the similarity dict.
        for attribute, json_path in (('class_scores', self.class_scores_path),
                                     ('sim_dict', self.sim_dict_path)):
            with open(json_path, 'r') as f:
                setattr(self, attribute, json.load(f))

            
'''Local visualisations'''
@pn.depends(inp.param.value)
def og_img(inp):
    '''Return a figure with the selected input image, its file name as title and its actual class below it.'''
    ip = Input_Data(inp)
    picture = Image.open(ip.og_img_path)

    fig = Figure()
    FigureCanvas(fig)  # give the figure an Agg canvas

    ax = fig.add_subplot(frame_on=False)
    ax.imshow(picture)
    # No ticks on either axis: only the picture, its title and the class label are shown.
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    # NOTE(review): this prints the complete class_dict entry; class_dict_clean takes
    # element [1] of that entry as the class name - confirm which one is intended here.
    ax.set_xlabel(f"actual class: {class_dict[ip.og_img_class]}")
    ax.set_title(f"{ip.og_img_name}")

    return fig

@pn.depends(inp.param.value)
def class_overview(inp):
    '''Return a bar chart with the total class score of the selected image for every class.'''
    ip = Input_Data(inp)

    # One bar per class in class_list; heights are the values of the image's class_scores.
    positions = np.arange(len(class_list))
    heights = list(ip.class_scores.values())

    fig = Figure()
    FigureCanvas(fig)

    ax = fig.add_subplot()
    ax.bar(positions, heights, align='center')
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=30, ha='right')
    ax.set_ylabel('total class scores')
    ax.set_title('predicted total scores for all the classes')
    ax.set_ylim(0,1)

    fig.set_tight_layout(True)
    return fig


@pn.depends(inp.param.value)
def prototype_heatmap(inp):
    '''Return a heatmap of weight*similarity_score: one row per class, one column for each of
    the n prototypes that are most similar to the input image.'''
    # Remove?
    n = 10
    ip = Input_Data(inp)

    # (concept, first sim_dict field, similarity score), most similar concept first.
    by_similarity = sorted(((k, v[0], v[1]) for k, v in ip.sim_dict.items()),
                           key=lambda y: y[2], reverse=True)
    top = by_similarity[:n]
    concept_names = [concept for concept, _, _ in top]

    # Cell value = weight of the concept for that class, times the concept's similarity score.
    cls_w = np.array([[w_dict[cls][concept]*score for concept, _, score in top]
                      for cls in class_list])

    fig = Figure(figsize=(24,12))
    FigureCanvas(fig)
    ax = fig.add_subplot(2,1,1)
    ax.imshow(cls_w)

    ax.set_xticks(np.arange(len(concept_names)))
    ax.set_xticklabels(concept_names, rotation=90)
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_yticklabels(class_names)

    return fig



'''Create interactive widgets to select the to-be-compared classes.'''
# Two drop-downs, both offering every class name, to pick the pair of classes to compare.
cls1 = pn.widgets.Select(options=class_names, name='class 1')
cls2 = pn.widgets.Select(options=class_names, name='class 2')



@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)##
def simxw_plot(c1, c2, inp):
    '''Compare, for two chosen classes, the n prototypes with the highest weight*similarity_score.

    Each class gets its own bar chart: a white bar for the class weight of a prototype with the
    coloured weight*similarity_score bar drawn over it (--> the most contributing prototypes to
    the prediction?). Below each chart the prototype images and the closest input segments are shown.

    Parameters:
        c1, c2: class names (keys of inverted_dict), c1 for the left chart and c2 for the right.
        inp: file name of the selected input image (key of img_paths).

    Returns:
        The matplotlib Figure with both bar charts and the image rows.
    '''
    ip = Input_Data(inp)

    n = 10
    bar_width = 0.7
    index = np.arange(n)
    # Image grid: n columns per class with int(n/4) empty spacer columns in between.
    gap = int(n/4)
    n_cols = 2*n + gap

    fig = Figure(figsize=(24,16))
    FigureCanvas(fig)

    def top_concepts(cls):
        '''Return the n tuples (concept, weight, similarity, weight*similarity) with the highest product for class code cls.'''
        rows = [(c, w_dict[cls][c], ip.sim_dict[c][1], w_dict[cls][c]*ip.sim_dict[c][1]) for c in concept_dict.keys()]
        rows.sort(key=lambda y: y[3], reverse=True)
        return rows[:n]

    def draw_bars(ax, top, name, colour):
        '''Draw the weight bars with the sim*weight bars over them and label the chart for class name.'''
        ax.bar(index, [tup[1] for tup in top], bar_width, color='w', edgecolor='black', label='weight')
        ax.bar(index, [tup[3] for tup in top], bar_width, color=colour, label='sim*weight')
        ax.set(xlabel=f'Prototypes with highest values for weight*similarity_score for class {name}',
               ylabel='Weight*similarity_score for given prototype and the input image segment closest to it',
               title=f'Total prototype score for the input image for the given class {name}'
              )
        ax.legend()
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.set_xticks([])
        ax.set_xlim([0-bar_width, (n-1)+bar_width])

    def draw_images(top, first_col):
        '''Show prototype images (5th of 8 grid rows) and closest segments (6th row), starting at grid column first_col.'''
        for i, tup in enumerate(top):
            # Only the file names are taken from the stored paths; the folders are local ones.
            prototype_path = os.path.join(concept_dir, concept_dict[tup[0]].split('/')[-1])
            segment_path = os.path.join(ip.segment_dir, ip.sim_dict[tup[0]][0].split('/')[-1])

            ax_proto = fig.add_subplot(8, n_cols, i+n_cols*4+1+first_col, frame_on=False)
            ax_proto.imshow(Image.open(prototype_path))
            ax_proto.set_xticks([])
            ax_proto.set_yticks([])
            ax_seg = fig.add_subplot(8, n_cols, i+n_cols*5+1+first_col, frame_on=False)
            ax_seg.imshow(Image.open(segment_path))
            ax_seg.set_xticks([])
            ax_seg.set_yticks([])
            if i == 0:
                ax_proto.set_ylabel('prototypes', size='large')
                ax_seg.set_ylabel('closest input\nsegments', size='large')
            ax_proto.set_xlabel(f"w = {round(tup[1], 3)}")
            ax_seg.set_xlabel(f"similarity\n = {round(tup[2], 3)}")

    simxw_cls1 = top_concepts(inverted_dict[c1])
    simxw_cls2 = top_concepts(inverted_dict[c2])

    # plot 1
    ax1 = fig.add_subplot(2,2,1)
    draw_bars(ax1, simxw_cls1, c1, 'b')

    # plot 2
    ax2 = fig.add_subplot(2,2,2, sharey=ax1)
    draw_bars(ax2, simxw_cls2, c2, 'g')

    # Shared y-range: at least [-0.5, 0.5], widened to the most extreme weight bar.
    weights = [tup[1] for tup in simxw_cls1 + simxw_cls2]
    y_min = min([-0.5] + weights)
    y_max = max([0.5] + weights)
    # y_min is never above -0.5, so multiplying by 1.1 moves the lower limit further
    # down and leaves a 10% margin below the lowest bar (just like above the highest).
    ax1.set_ylim(1.1*y_min, 1.1*y_max)
    ax2.set_ylim(1.1*y_min, 1.1*y_max)

    # images plot 1 (left half of the grid) and plot 2 (right half, after the spacer)
    draw_images(simxw_cls1, 0)
    draw_images(simxw_cls2, n+gap)

    return fig





@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)##
def simxw_plot_combined(c1, c2, inp):
    '''Compare two classes on the n prototypes with the highest weight*similarity_score of each class.

    The left chart takes the top-n prototypes of class c1, the right chart those of class c2.
    For every prototype two bar pairs are drawn next to each other: weight and
    weight*similarity_score for the chart's own class, and the same two values for the other class.
    Below each chart the prototype images and the closest input segments are shown.

    Parameters:
        c1, c2: class names (keys of inverted_dict), c1 for the left chart and c2 for the right.
        inp: file name of the selected input image (key of img_paths).

    Returns:
        The matplotlib Figure with both bar charts and the image rows.
    '''
    ip = Input_Data(inp)

    n = 10
    bar_width = 0.35
    index = np.arange(n)
    # Image grid: n columns per class with int(n/4) empty spacer columns in between.
    gap = int(n/4)
    n_cols = 2*n + gap

    fig = Figure(figsize=(24,16))
    FigureCanvas(fig)

    def top_concepts(own, other):
        '''Return the n best concepts for class code own as tuples
        (concept, own weight, similarity, own weight*similarity, other weight, other weight*similarity).'''
        rows = [(c, w_dict[own][c], ip.sim_dict[c][1], w_dict[own][c]*ip.sim_dict[c][1],
                 w_dict[other][c], w_dict[other][c]*ip.sim_dict[c][1]) for c in concept_dict.keys()]
        rows.sort(key=lambda y: y[3], reverse=True)
        return rows[:n]

    def draw_bars(ax, top, own_name, own_colour, other_name, other_colour):
        '''Draw the bar pair of the chart's own class and, shifted by one bar width, that of the other class.'''
        ax.bar(index, [tup[1] for tup in top], bar_width, color='w', edgecolor='black', label='weight')
        ax.bar(index, [tup[3] for tup in top], bar_width, color=own_colour, label=f'sim*weight class {own_name}')
        # The second weight bar gets no label, so 'weight' appears only once in the legend.
        ax.bar(index+bar_width, [tup[4] for tup in top], bar_width, color='w', edgecolor='black')
        ax.bar(index+bar_width, [tup[5] for tup in top], bar_width, color=other_colour, label=f'sim*weight class {other_name}')
        ax.set(xlabel=f'Prototypes with highest values for weight*similarity_score for class {own_name}',
               ylabel='Weight*similarity_score for given prototype and the input image segment closest to it',
               title=f'Total prototype score for the input image for the given class {own_name}'
              )
        ax.legend()
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.set_xticks([])
        ax.set_xlim([0-bar_width, (n-1)+2*bar_width])

    def draw_images(top, first_col):
        '''Show prototype images (5th of 8 grid rows) and closest segments (6th row), starting at grid column first_col.'''
        for i, tup in enumerate(top):
            # Only the file names are taken from the stored paths; the folders are local ones.
            prototype_path = os.path.join(concept_dir, concept_dict[tup[0]].split('/')[-1])
            segment_path = os.path.join(ip.segment_dir, ip.sim_dict[tup[0]][0].split('/')[-1])

            ax_proto = fig.add_subplot(8, n_cols, i+n_cols*4+1+first_col, frame_on=False)
            ax_proto.imshow(Image.open(prototype_path))
            ax_proto.set_xticks([])
            ax_proto.set_yticks([])
            ax_seg = fig.add_subplot(8, n_cols, i+n_cols*5+1+first_col, frame_on=False)
            ax_seg.imshow(Image.open(segment_path))
            ax_seg.set_xticks([])
            ax_seg.set_yticks([])
            if i == 0:
                ax_proto.set_ylabel('prototypes', size='large')
                ax_seg.set_ylabel('closest input\nsegments', size='large')
            ax_proto.set_xlabel(f"w = {round(tup[1], 3)}")
            ax_seg.set_xlabel(f"similarity\n = {round(tup[2], 3)}")

    code1 = inverted_dict[c1]
    code2 = inverted_dict[c2]
    simxw_cls1 = top_concepts(code1, code2)
    simxw_cls2 = top_concepts(code2, code1)

    # plot 1
    ax1 = fig.add_subplot(2,2,1)
    draw_bars(ax1, simxw_cls1, c1, 'b', c2, 'g')

    # plot 2
    ax2 = fig.add_subplot(2,2,2, sharey=ax1)
    draw_bars(ax2, simxw_cls2, c2, 'g', c1, 'b')

    # Shared y-range: at least [-0.5, 0.5], widened to the most extreme weight bar. The
    # weights of the comparison class (field 4) are drawn as well, so they count too.
    weights = [w for tup in simxw_cls1 + simxw_cls2 for w in (tup[1], tup[4])]
    y_min = min([-0.5] + weights)
    y_max = max([0.5] + weights)
    # y_min is never above -0.5, so multiplying by 1.1 moves the lower limit further
    # down and leaves a 10% margin below the lowest bar (just like above the highest).
    ax1.set_ylim(1.1*y_min, 1.1*y_max)
    ax2.set_ylim(1.1*y_min, 1.1*y_max)

    # images plot 1 (left half of the grid) and plot 2 (right half, after the spacer)
    draw_images(simxw_cls1, 0)
    draw_images(simxw_cls2, n+gap)

    return fig





@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)##
def simxw_plot_mirror(c1, c2, inp):
    '''Show the n prototypes with the highest weight*similarity_score of two classes in one mirrored bar chart.

    Class c1 is drawn at x = 1..n with its highest score first; class c2 is drawn at
    x = -n..-1 in ascending order, so the highest scores of both classes meet in the middle.
    White bars show the class weight, the coloured bars over them weight*similarity_score.
    Below the chart the prototype images and closest input segments are shown, those of
    class c2 in the left half and those of class c1 in the right half.

    Parameters:
        c1, c2: class names (keys of inverted_dict), c1 for the right side and c2 for the left.
        inp: file name of the selected input image (key of img_paths).

    Returns:
        The matplotlib Figure with the bar chart and the image rows.
    '''
    ip = Input_Data(inp)

    n = 10
    bar_width = 0.7
    right_positions = np.arange(1, n+1)
    left_positions = np.arange(-n, 0)
    # Image grid: n columns per class with int(n/4) empty spacer columns in between.
    gap = int(n/4)
    n_cols = 2*n + gap

    fig = Figure(figsize=(24,16))
    FigureCanvas(fig)

    def scored_concepts(cls, descending):
        '''Return (concept, weight, similarity, weight*similarity) for every concept, sorted on the product.'''
        rows = [(c, w_dict[cls][c], ip.sim_dict[c][1], w_dict[cls][c]*ip.sim_dict[c][1]) for c in concept_dict.keys()]
        rows.sort(key=lambda y: y[3], reverse=descending)
        return rows

    def draw_images(top, first_col):
        '''Show prototype images (5th of 8 grid rows) and closest segments (6th row), starting at grid column first_col.'''
        for i, tup in enumerate(top):
            # Only the file names are taken from the stored paths; the folders are local ones.
            prototype_path = os.path.join(concept_dir, concept_dict[tup[0]].split('/')[-1])
            segment_path = os.path.join(ip.segment_dir, ip.sim_dict[tup[0]][0].split('/')[-1])

            ax_proto = fig.add_subplot(8, n_cols, i+n_cols*4+1+first_col, frame_on=False)
            ax_proto.imshow(Image.open(prototype_path))
            ax_proto.set_xticks([])
            ax_proto.set_yticks([])
            ax_seg = fig.add_subplot(8, n_cols, i+n_cols*5+1+first_col, frame_on=False)
            ax_seg.imshow(Image.open(segment_path))
            ax_seg.set_xticks([])
            ax_seg.set_yticks([])
            if i == 0:
                ax_proto.set_ylabel('prototypes', size='large')
                ax_seg.set_ylabel('closest input\nsegments', size='large')
            ax_proto.set_xlabel(f"w = {round(tup[1], 3)}")
            ax_seg.set_xlabel(f"similarity\n = {round(tup[2], 3)}")

    # Class c1: best first. Class c2: sorted ascending and the last n taken, so best last.
    top_cls1 = scored_concepts(inverted_dict[c1], True)[:n]
    top_cls2 = scored_concepts(inverted_dict[c2], False)[-n:]

    top_simxw_w1_cls1 = [tup[1] for tup in top_cls1]
    top_simxw_w2_cls2 = [tup[1] for tup in top_cls2]

    # plot 1
    ax1 = fig.add_subplot(2,1,1)
    ax1.bar(right_positions, top_simxw_w1_cls1, bar_width, color='w', edgecolor='black', label='weight')
    ax1.bar(right_positions, [tup[3] for tup in top_cls1], bar_width, color='b', label=f'sim*weight class {c1}')
    # The second weight bar gets no label, so 'weight' appears only once in the legend.
    ax1.bar(left_positions, top_simxw_w2_cls2, bar_width, color='w', edgecolor='black')
    ax1.bar(left_positions, [tup[3] for tup in top_cls2], bar_width, color='g', label=f'sim*weight class {c2}')

    ax1.set(xlabel=f'Prototypes with highest values for weight*similarity_score for classes {c1} and {c2}',
            ylabel='Weight*similarity_score for given prototype and the input image segment closest to it',
            title=f'Total prototype scores for the input image for the given classes {c1} and {c2}'
           )
    ax1.legend()

    ax1.xaxis.set_tick_params(labelbottom=False)
    ax1.set_xticks([])
    ax1.set_xlim([-(n+bar_width), (n)+bar_width])

    # y-range: at least [-0.5, 0.5], widened to the most extreme weight bar.
    weights = top_simxw_w1_cls1 + top_simxw_w2_cls2
    y_min = min([-0.5] + weights)
    y_max = max([0.5] + weights)
    # y_min is never above -0.5, so multiplying by 1.1 moves the lower limit further
    # down and leaves a 10% margin below the lowest bar (just like above the highest).
    ax1.set_ylim(1.1*y_min, 1.1*y_max)

    # images plot 1: class c2 in the left half of the grid, in the order of its bars
    draw_images(top_cls2, 0)
    # images plot 2: class c1 in the right half, after the spacer columns
    draw_images(top_cls1, n+gap)

    return fig



"""

def draw_brace(ax, xspan, ylocs=None, dirr='up', text=''):
    '''Draws an annotated brace on the axes.'''
    xmin, xmax = xspan
    xspan = xmax - xmin
    ax_xmin, ax_xmax = ax.get_xlim()
    xax_span = ax_xmax - ax_xmin
    ymin, ymax = ax.get_ylim()
    yspan = ymax - ymin
#     print(yspan)
    resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven
    beta = 300./xax_span # the higher this is, the smaller the radius

    x = np.linspace(xmin, xmax, resolution)
    x_half = x[:resolution//2+1]
    y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
                    + 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
    y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
#     print(yspan)
#     print(y)
#     y = ymin + (.05*y - .01)*yspan # adjust vertical position
    if ylocs is not None:
        if dirr == 'up':
            yloc = ylocs
            y = yloc + (.05*y - .01)*yspan
            texty = y[int(len(y)/2)] + .01*yspan
            valign = 'bottom'
        elif dirr == 'down':
            yloc = ylocs # find out height of bracket
            y = yloc - (.05*y - .01)*yspan
            texty = y[int(len(y)/2)] - .01*yspan
            valign = 'top'
    else:
        if dirr == 'up':
            yloc = ymax
            y = yloc + (.05*y - .01)*yspan
            texty = y[int(len(y)/2)] + .01*yspan
            valign = 'bottom'
        elif dirr == 'down':
            yloc = ymin
            y = yloc - (.05*y - .01)*yspan
            texty = y[int(len(y)/2)] - .01*yspan
            valign = 'top'
     # adjust vertical position
#     print(y)
    
#     ax.autoscale(False)
    ax.plot(x, y, color='black', lw=1)

    ax.text((xmax+xmin)/2., texty, text, ha='center', va=valign)

# ax = plt.gca()
# ax.plot(range(10))
# draw_brace(ax, (0, 8), 'large brace')
# draw_brace(ax, (8, 9), 'small brace')


@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)##
def simxw_distr(c1, c2, inp):
    '''Shows a bar plot for the total_prototype_score=weights*similarity_score for two chosen classes.'''
    ip = Input_Data(inp)
    
    cls1 = inverted_dict[c1]
    cls2 = inverted_dict[c2]
    
    fig = Figure(figsize=(24,12))
    FigureCanvas(fig)
        
    bar_width = 0.5
    index = np.arange(len(list(concept_dict.keys())))

    # plot 1
    ax1 = fig.add_subplot(1,2,1)
    
    simxw_cls1 = [(c, w_dict[cls1][c], ip.sim_dict[c][1], w_dict[cls1][c]*ip.sim_dict[c][1]) for c in concept_dict.keys()]
    simxw_cls1.sort(key=lambda y: y[3], reverse=True)
    
    top_simxw_cls1 = [tup[3] for tup in simxw_cls1]
    
    
    
    rects1 = ax1.bar(index, top_simxw_cls1, bar_width, color='b', label=c1)
    
    ax1.set(xlabel=f'Total value (weight*similarity_score) for each prototype for class {c1}', 
           ylabel='Weight*similarity_score for given prototype and the input image segment closest to it', 
           title=f'Total prototype score for the input image for the given class {c1}'
          )
    ax1.legend()
    
    ax1.xaxis.set_tick_params(labelbottom=False)
    ax1.set_xticks([])
    ax1.set_ylim(min([-0.5, top_simxw_cls1[-1]]),max([0.5, top_simxw_cls1[0]]))
    
    
    # plot 2
    ax2 = fig.add_subplot(1,2,2, sharey=ax1)
    '''CHECK IF THIS HOLDS WHEN MANUALLY SETTING THE YLIMIT'''
    
    simxw_cls2 = [(c, w_dict[cls2][c], ip.sim_dict[c][1], w_dict[cls2][c]*ip.sim_dict[c][1]) for c in concept_dict.keys()]
    simxw_cls2.sort(key=lambda y: y[3], reverse=True)
    
    top_simxw_cls2 = [tup[3] for tup in simxw_cls2]
    
    rects2 = ax2.bar(index, top_simxw_cls2, bar_width, color='g', label=c2)
    
    ax2.set(xlabel=f'Total value (weight*similarity_score) for each prototype for class {c2}', 
           ylabel='Weight*similarity_score for given prototype and the input image segment closest to it', 
           title=f'Total prototype score for the input image for the given class {c2}'
          )
    ax2.legend()
    
    ax2.xaxis.set_tick_params(labelbottom=False)
    ax2.set_xticks([])
    ax2.set_ylim(min([-0.5, top_simxw_cls2[-1]]),max([0.5, top_simxw_cls2[0]]))
    
    # ------------------------------------------------
    # ------------------------------------------------
    # ------------------------------------------------
    # ------------------------------------------------
    # ------------------------------------------------
    
    ymin1, ymax1 = ax1.get_ylim()
    ymin2, ymax2 = ax2.get_ylim()
    ymin = min([ymin1, ymin2])
    ymax = max([ymax1, ymax2])
    ymin = ymin/2
    
    simxw_pos_cls1 = [x for x in top_simxw_cls1 if x >= 0]
    simxw_neg_cls1 = [x for x in top_simxw_cls1 if x < 0]
    sum_pos_cls1 = sum(simxw_pos_cls1)
    sum_neg_cls1 = sum(simxw_neg_cls1)
    thres_cls1 = len(simxw_pos_cls1)
    draw_brace(ax1, (0, thres_cls1), 0, 'down', f"{sum_pos_cls1}")
    draw_brace(ax1, (thres_cls1, len(list(concept_dict.keys()))), 0, 'up', f"{sum_neg_cls1}")
    draw_brace(ax1, (0, len(list(concept_dict.keys()))), ymin, 'down', f"{sum_pos_cls1+sum_neg_cls1}")
    # also give other axis to determine shared ymin/ymax ?
    
    
    simxw_pos_cls2 = [x for x in top_simxw_cls2 if x >= 0]
    simxw_neg_cls2 = [x for x in top_simxw_cls2 if x < 0]
    sum_pos_cls2 = sum(simxw_pos_cls2)
    sum_neg_cls2 = sum(simxw_neg_cls2)
    thres_cls2 = len(simxw_pos_cls2)
    draw_brace(ax2, (0, thres_cls2), 0, 'down', f"{sum_pos_cls2}")
    draw_brace(ax2, (thres_cls2, len(list(concept_dict.keys()))), 0, 'up', f"{sum_neg_cls2}")
    draw_brace(ax2, (0, len(list(concept_dict.keys()))), ymin, 'down', f"{sum_pos_cls2+sum_neg_cls2}")

    return fig



@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)##
def simfreq_plot(c1, c2, inp):
    '''Shows a bar plot of weights for two chosen classes of n prototypes with highest similarity to some input segments.
    Also shows the corresponding images.'''
    n = 10
    
    ip = Input_Data(inp)
    
    cls1 = inverted_dict[c1]
    cls2 = inverted_dict[c2]
    
    fig = Figure(figsize=(24,12))
    FigureCanvas(fig)
    ax = fig.add_subplot(2,1,1)
    
    bar_width = 0.35
    index = np.arange(n)
#     index = np.array([x-bar_width for x in index_middle])
    
    cls_all_sim = [(k, v[0], v[1]) for k, v in ip.sim_dict.items()] # for a given class, get the weights of top 5 sim scores
    
    cls_all_sim.sort(key = lambda y: y[2], reverse=True)
    cls1_w = [w_dict[cls1][c[0]] for c in cls_all_sim[:n]]
    cls2_w = [w_dict[cls2][c[0]] for c in cls_all_sim[:n]]

    
    rects1 = ax.bar(index, cls1_w, bar_width, color='b', label=c1)
    rects2 = ax.bar(index+bar_width, cls2_w, bar_width, color='g', label=c2)
    
    
    ax.set(xlabel='Prototypes with highest similarity scores to segments', 
           ylabel='Weight of prototype for given class', 
           title=f'Weights of prototypes with highest similarity scores to input segments'
          )
    ax.legend()
    
    # ax min and max: min = 0-bar_width, max = (n-1)+2*bar_width
    ax.set_xlim([0-bar_width, (n-1)+2*bar_width])
    ax.set_ylim(min([-0.5, min(cls1_w)]),max([0.5, max(cls1_w)]))
    
    ax.xaxis.set_tick_params(labelbottom=False)
    ax.set_xticks([])
    
    top_concepts = [x[0] for x in cls_all_sim[:n]]
    prototype_paths = ['/'.join(concept_dict[c].split('/')[-3:]) for c in top_concepts]
    segment_paths = [os.path.join(ip.segment_dir, x[1].split('/')[-1]) for x in cls_all_sim[:n]]
    prototypes = [Image.open(p) for p in prototype_paths]
    segments = [Image.open(s) for s in segment_paths]

    for i in range(n):
        ax2 = fig.add_subplot(4, n, i+1+(2*n), frame_on=False)
        ax2.imshow(prototypes[i])
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax3 = fig.add_subplot(4, n, i+1+(3*n), frame_on=False)
        ax3.imshow(segments[i])
        ax3.set_xticks([])
        ax3.set_yticks([])
        if i == 0:
            ax2.set_ylabel('prototype approx.', size='large')
            ax3.set_ylabel('closest input segments', size='large')
        ax2.set_xlabel(f"sim = {round(cls_all_sim[i][2],4)}")
    return fig

# TODO :
'''Make prototype images clickable to see a few more concept images --> on cluster, also output top5-10 concept images'''



# NOTE: this function sits inside a triple-double-quoted string that disables it,
# so only single-quote docstrings may be used in here.
@pn.depends(cls1.param.value, cls2.param.value, inp.param.value)
def wfreq_plot(c1, c2, inp):
    '''Shows a plot for each of two chosen classes. Each plot shows the n most important prototypes for that class, along 
    with their similarity scores to some input segments. The corresponding prototype and segment images are shown as well.

    c1, c2 -- class names as offered by the class widgets; translated to class codes via inverted_dict
    inp    -- value of the input widget, handed to Input_Data

    Returns the Matplotlib Figure holding both bar charts and the image rows.'''
    
    # Input_Data is defined outside this excerpt; below, its `sim_dict` and
    # `segment_dir` attributes are read.
    # NOTE(review): from the indexing used here, sim_dict[concept] looks like
    # (segment path, similarity score) -- confirm against Input_Data.
    ip = Input_Data(inp)
    
    # The locals cls1/cls2 (class codes) shadow the module-level widgets of the same name.
    cls1 = inverted_dict[c1]
    cls2 = inverted_dict[c2]
    
    n = 10
    
    fig = Figure(figsize=(24,16))
    FigureCanvas(fig)
    
    # plot 1
    ax1 = fig.add_subplot(2,2,1)
#     ax1.right = 0
    
    bar_width = 0.7
    index = np.arange(n)
    
    
    # Rank the concepts of class 1 by weight (highest first) and keep the top n.
    w_cls1 = [(k, v) for k, v in w_dict[cls1].items()]
    w_cls1.sort(key = lambda y: y[1], reverse=True)
    top_c_by_w = [c[0] for c in w_cls1[:n]]
    cls1_w = [ip.sim_dict[c][1] for c in top_c_by_w] # similarity scores for concepts with highest weight
    
    rects1 = ax1.bar(index, cls1_w, bar_width, color='b', label=c1)
    
    ax1.set(xlabel=f'Prototypes with highest weights for class {c1}', 
           ylabel='Similarity between the prototype and the input segment closest to it', 
           title=f'Similarity scores of prototypes with\nhighest weights for class {c1}'
          )
    ax1.legend()
    
    # No x ticks: the image rows added further down serve as the bar labels.
    ax1.xaxis.set_tick_params(labelbottom=False)
    ax1.set_xticks([])
    ax1.set_xlim([0-bar_width, (n-1)+bar_width])
    ax1.set_ylim(0, 1)
    
    # plot 2
    ax2 = fig.add_subplot(2,2,2, sharey=ax1)
#     ax2.wspace = 0
    
#     index = np.arange(n)
#     bar_width = 0.7
    
    # Same ranking as above, now for class 2.
    w_cls2 = [(k, v) for k, v in w_dict[cls2].items()]
    w_cls2.sort(key = lambda y: y[1], reverse=True)
    top_c_by_w2 = [c[0] for c in w_cls2[:n]]
    cls2_w = [ip.sim_dict[c][1] for c in top_c_by_w2] # similarity scores for concepts with highest weight
        
    rects2 = ax2.bar(index, cls2_w, bar_width, color='g', label=c2)
    
    ax2.set(xlabel=f'Prototypes with highest weights for class {c2}', 
#            ylabel='Weight of prototype for given class', 
           title=f'Similarity scores of prototypes with\nhighest weights for class {c2}'
          )
    ax2.legend()
    
    ax2.xaxis.set_tick_params(labelbottom=False)
    ax2.set_xticks([])
    ax2.set_xlim([0-bar_width, (n-1)+bar_width])
    ax2.set_ylim(0, 1)
    
    # images plot 1
    prototypes = []
    segments = []
    n1 = n
    n2 = n
    # Only the file name of each stored concept path is kept and re-rooted at concept_dir.
    ppp = [concept_dict[c].split('/')[-1] for c in top_c_by_w]
    prototype_paths = [os.path.join(concept_dir, p) for p in ppp] # chosen class, max w, prototype path
    prototypes = [Image.open(p) for p in prototype_paths]
    
    # Likewise, segment file names are re-rooted at the segment directory of this input.
    sss = [ip.sim_dict[c][0].split('/')[-1] for c in top_c_by_w]
    segment_paths = [os.path.join(ip.segment_dir, s) for s in sss]
    segments = [Image.open(s) for s in segment_paths]


    # The images live on an 8-row grid with n1+n2+int(n/4) columns. The term
    # (columns)*4+1 is the first cell of the fifth row (prototypes) and
    # (columns)*5+1 the first cell of the sixth row (segments); class 1 uses
    # the leftmost n1 columns.
    for i in range(n1):
        ax3 = fig.add_subplot(8, n1+n2+int(n/4), i+(n1+n2+int(n/4))*4+1, frame_on=False)
        ax3.imshow(prototypes[i])
        ax3.set_xticks([])
        ax3.set_yticks([])
        ax4 = fig.add_subplot(8, n1+n2+int(n/4), i+(n1+n2+int(n/4))*5+1, frame_on=False)
        ax4.imshow(segments[i])
        ax4.set_xticks([])
        ax4.set_yticks([])
        if i == 0:
            ax3.set_ylabel('prototypes', size='large')
            ax4.set_ylabel('closest input\nsegments', size='large')
        # w_cls1 is sorted by weight, so entry i belongs to prototype i.
        ax3.set_xlabel(f"w = {round(w_cls1[i][1], 3)}")
            
    # images plot 2
    ppp2 = [concept_dict[c].split('/')[-1] for c in top_c_by_w2]
    prototype_paths2 = [os.path.join(concept_dir, p) for p in ppp2] # chosen class, max w, prototype path
    prototypes2 = [Image.open(p) for p in prototype_paths2]
    
    sss2 = [ip.sim_dict[c][0].split('/')[-1] for c in top_c_by_w2]
    segment_paths2 = [os.path.join(ip.segment_dir, s) for s in sss2]
    segments2 = [Image.open(s) for s in segment_paths2]


    # Class 2 uses the same two rows, shifted right by n1+int(n/4) columns,
    # which leaves int(n/4) empty columns between the two image groups.
    for i in range(n2):
        ax5 = fig.add_subplot(8, n1+n2+int(n/4), i+(n1+n2+int(n/4))*4+1+(n1+int(n/4)), frame_on=False)
        ax5.imshow(prototypes2[i])
        ax5.set_xticks([])
        ax5.set_yticks([])
        ax6 = fig.add_subplot(8, n1+n2+int(n/4), i+(n1+n2+int(n/4))*5+1+(n1+int(n/4)), frame_on=False)
        ax6.imshow(segments2[i])
        ax6.set_xticks([])
        ax6.set_yticks([])
        if i == 0:
            ax5.set_ylabel('prototypes', size='large')
            ax6.set_ylabel('closest input\nsegments', size='large')
        ax5.set_xlabel(f"w = {round(w_cls2[i][1], 3)}")


    return fig

# TODO :
'''Make prototype images clickable to see a few more concept images --> on cluster, also output top5-10 concept images'''


'''Some more visualisation to come...'''
# add barchart of weights for closest segments for ALL classes (maybe have actual class stand out in colour or size)
"""


# Explanation texts for the simfreq/wfreq panes (those panes are disabled in the layout below).
expl_simfreq_plot = (
    '### Prototypes most present in input image \n'
    ' Below, the bar chart shows the weights for the prototypes that have the highest similarity score'
    ' with its closest input image segment. The images shown for the prototypes are not the exact'
    ' prototypes, but rather the training data segment closest to the actual prototype in latent space'
    ' (which cannot be shown as a sensical image).'
)
expl_wfreq_plot = (
    '### Most important prototypes per class \n'
    ' Below, the bar charts show the similarity scores between the prototype and its closest input image'
    ' segment, for the prototypes with the highest weight for the given class. (The similarity scores are'
    ' based on the distance between the segment and the cluster center of a concept with respect to the'
    ' cluster size. The prototypes shown are the training data segments that are closest to the cluster'
    ' center.)'
)

# Page layout for the local visualisations: input selection first, then the class comparison.
app = pn.Column(
    pn.pane.Markdown('# Local visualisations', margin=(0, 0, 0, 460), align='start'),

    # --- input image selection and overview ---
    pn.pane.Markdown('## Choose input image', margin=(0, 0, 0, 460), align='start'),
    pn.WidgetBox(inp, margin=(0, 0, 0, 460), align='start'),
    pn.Row(class_overview, og_img, margin=(0, 0, 100, 400), align='start'),

    # --- class comparison ---
    pn.pane.Markdown('# Why class x? Why not class y?', margin=(0, 0, 0, 460), align='start'),
    pn.pane.Markdown('## Choose classes to compare', margin=(0, 0, 0, 460), align='start'),
    pn.Row(cls1, cls2, margin=(0, 0, 50, 460), align='start'),

    pn.pane.Markdown('### Most positively contributing prototypes/segments', margin=(0, 0, 0, 460), align='start'),
    simxw_plot_combined,
    simxw_plot_mirror,

    # --- disabled panes, kept for reference ---
    # simxw_distr,
    # pn.pane.Markdown(expl_simfreq_plot, width=600, margin=(0,0,0,460)),
    # simfreq_plot,
    # pn.pane.Markdown(expl_wfreq_plot, width=600, margin=(0,0,0,460)),
    # wfreq_plot,

    width=1600,
    background='white',
)

# Make the panel usable outside of the Jupyter notebook by serving it from a
# background thread; `app.servable()` is the disabled alternative.
server = app.show(threaded=True)

Launching server at http://localhost:62223


In [5]:
# Stop the server object that `app.show(threaded=True)` returned.
server.stop()

<b>Global visualisations</b>

In [1]:
import numpy as np
import json
import matplotlib.pyplot as plt
import os
from PIL import Image
from configparser import ConfigParser, ExtendedInterpolation

from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvas

import panel as pn


# --- Folder to be used (can be turned into something more interactive) ---
# Later add: classification_phase.py code to upload a file and create the results that can then be shown here.
# Alternatives that were used before:
#   input_dir = 'medoid2'
#   input_dir = os.path.join(os.getcwd(), input_dir)
input_dir = 'eval_labelmodel_1'


# --- Global variables ---
input_file = os.path.join(input_dir, 'classes.ini')
config = ConfigParser(interpolation=ExtendedInterpolation())
config.read(input_file)
class_files = config['classification']

# Map every listed image file name to the name of the folder that contains it.
img_paths = {
    path.rsplit('/', 2)[-1]: path.rsplit('/', 2)[-2]
    for path in filter(None, class_files.get('img_paths').split('\n'))
}
print(img_paths)

# The first listed image is the start value.
start_val = list(img_paths)[0]

# Locations of the data files inside the input folder.
w_dict_path = os.path.join(input_dir, 'weights.txt')
class_dict_path = os.path.join(input_dir, 'imagenet_classcode_to_index.json')
img_dir = os.path.join(input_dir, 'images')
concept_dir = os.path.join(input_dir, 'concepts')
concept_dict_path = os.path.join(input_dir, 'concept_files.txt')

# All three files hold JSON, whatever their extension.
with open(w_dict_path, 'r') as f:
    w_dict = json.load(f)
with open(class_dict_path, 'r') as f:
    class_dict = json.load(f)
with open(concept_dict_path, 'r') as f:
    concept_dict = json.load(f)

# Dictionary key -> second element of its entry (used as the class name), plus the reverse lookup.
class_dict_clean = {code: entry[1] for code, entry in class_dict.items()}
inverted_dict = {name: code for code, name in class_dict_clean.items()}

# Classes in the order in which they appear in the weights file.
class_list = list(w_dict)
class_names = [class_dict_clean[code] for code in class_list]

# Evaluation plots are only drawn when evaluation results are present.
eval_dir = os.path.join(input_dir, 'eval_results')
plot_eval = os.path.exists(eval_dir)


def draw_brace(ax, xspan, ylocs=None, dirr='up', text=''):
    '''Draws an annotated brace on the axes.

    ax    -- axes to draw on; its current x and y limits set the proportions of the brace
    xspan -- (xmin, xmax) data coordinates that the brace spans
    ylocs -- y data coordinate the brace is anchored to; when None the top of the
             axes is used for 'up' and the bottom of the axes for 'down'
    dirr  -- 'up': brace bulges upwards with the text above it,
             'down': brace bulges downwards with the text below it
    text  -- annotation placed at the tip of the brace

    Raises ValueError for any other value of dirr, before anything is drawn.
    '''
    # Reject unknown directions up front; otherwise the curve would be drawn
    # unscaled and the text placement would fail on unset variables.
    if dirr not in ('up', 'down'):
        raise ValueError(f"dirr must be 'up' or 'down', got {dirr!r}")

    xmin, xmax = xspan
    width = xmax - xmin
    ax_xmin, ax_xmax = ax.get_xlim()
    xax_span = ax_xmax - ax_xmin
    ymin, ymax = ax.get_ylim()
    yspan = ymax - ymin
    resolution = int(width/xax_span*100)*2+1 # guaranteed uneven
    beta = 300./xax_span # the higher this is, the smaller the radius

    # Build the left half from two logistic curves and mirror it for the right half.
    x = np.linspace(xmin, xmax, resolution)
    x_half = x[:resolution//2+1]
    y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
                    + 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
    y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))

    pointing_up = dirr == 'up'
    sign = 1 if pointing_up else -1
    if ylocs is not None:
        yloc = ylocs
    else:
        yloc = ymax if pointing_up else ymin

    # Scale the curve relative to the axes height and offset it from the anchor.
    y = yloc + sign*(.05*y - .01)*yspan
    # The text sits just beyond the tip, which is the middle sample of the curve.
    texty = y[len(y)//2] + sign*.01*yspan
    valign = 'bottom' if pointing_up else 'top'

    ax.plot(x, y, color='black', lw=1)

    ax.text((xmax+xmin)/2., texty, text, ha='center', va=valign, fontsize=16)




'''Accuracy per class'''
def class_acc():
    '''Plots the prediction accuracy per class, read from the evaluation results.

    The left axis lists the classes, the right axis shows for every class which
    other class it was most often misclassified as and in what share of its
    evaluated images.

    Returns the Matplotlib Figure, or None when no evaluation results exist
    (plot_eval is False).
    '''
    if not plot_eval:
        return None

    eval_file = os.path.join(eval_dir, 'model_accuracy.txt')
    with open(eval_file, 'r') as f:
        eval_dict = json.load(f)

    # Row i holds the prediction counts for the images of class i.
    pred_matrix = eval_dict['pred_matrix']

    accuracies = []
    labels = []
    for i, class_preds in enumerate(pred_matrix):
        total = sum(class_preds)
        # A class without evaluated images gets accuracy 0 instead of a division by zero.
        accuracies.append(class_preds[i]/total if total else 0.0)

        # Blank out the correct predictions so that only confusions remain.
        confusions = list(class_preds)
        confusions[i] = 0
        most_conf = int(np.argmax(np.array(confusions)))
        n_conf = confusions[most_conf]

        if total and n_conf:
            # Share relative to the images of this class (not a fixed 50 per class).
            labels.append(f"{class_names[most_conf]}: {int(n_conf*100/total)}%")
        else:
            # Nothing was misclassified, so no class should be named.
            labels.append('none: 0%')

    index = np.arange(len(class_list))

    fig = Figure(figsize=(6,6))
    FigureCanvas(fig)
    ax = fig.add_subplot()
    ax.barh(index, accuracies, align='center')
    ax.set_yticks(index)
    ax.set_yticklabels(class_names, ha='right')
    ax.set(xlabel='prediction accuracy', title=f'prediction accuracy per class')

    # Second y axis on the right carries the confusion labels, aligned with the bars.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(index)
    ax2.set_yticklabels(labels)

    ax.set(ylabel='class')
    ax2.set_ylabel('most misclassified as (and how often)', rotation=270, labelpad=15)

    fig.set_tight_layout(True)

    return fig



# Interactive widgets to select the to-be-compared classes.
cls1, cls2 = (
    pn.widgets.Select(name=f'class {number}', options=class_names)
    for number in (1, 2)
)


'''Global visualisations'''
@pn.depends(cls1.param.value, cls2.param.value)
def distribution_plot(c2, c1):
    '''Shows, side by side, all concept weights of the two given classes, sorted
    from highest to lowest. Braces annotate the sum of the positive weights, the
    sum of the negative weights and the total sum.

    Returns the Matplotlib Figure.'''
    # NOTE(review): the parameters are declared as (c2, c1) while the decorator
    # lists cls1 first, so `c2` appears to receive the 'class 1' widget value.
    # That puts class 2 on the left in green -- confirm the swap is intended.
    left_code = inverted_dict[c1]
    right_code = inverted_dict[c2]

    fig = Figure(figsize=(24, 12))
    FigureCanvas(fig)

    def sorted_weights(class_code):
        # All weights of the class, highest first.
        return sorted(w_dict[class_code].values(), reverse=True)

    def weight_bars(position, weights, colour, class_name, share=None):
        # One unticked bar chart with a bar per concept.
        axes = fig.add_subplot(1, 2, position, sharey=share)
        axes.bar(np.arange(len(weights)), weights, 0.5, color=colour, label=class_name)
        axes.xaxis.set_tick_params(labelbottom=False)
        axes.set_xticks([])
        axes.set(xlabel=f'Concepts',
                 ylabel='Weight',
                 title=f'All weights for class {class_name}')
        return axes

    # plot 1
    left_weights = sorted_weights(left_code)
    ax_left = weight_bars(1, left_weights, 'g', c1)
    ax_left.legend()

    # plot 2
    right_weights = sorted_weights(right_code)
    ax_right = weight_bars(2, right_weights, 'b', c2, share=ax_left)

    # Same y range for both plots: twice the lowest weight at the bottom leaves
    # room for the braces, with 10% headroom above the highest weight.
    lowest = min(min(left_weights), min(right_weights))
    highest = max(max(left_weights), max(right_weights))
    for axes in (ax_left, ax_right):
        axes.set_ylim(2*lowest, 1.1*highest)

    ax_right.legend()

    # The brace for the total sum is anchored halfway down to the lower y limit.
    total_brace_y = min(ax_left.get_ylim()[0], ax_right.get_ylim()[0]) / 2

    n_concepts = len(concept_dict)
    for axes, weights in ((ax_left, left_weights), (ax_right, right_weights)):
        positives = [w for w in weights if w >= 0]
        negatives = [w for w in weights if w < 0]
        sum_pos = sum(positives)
        sum_neg = sum(negatives)
        sum_tot = sum_pos + sum_neg

        # The weights are sorted, so the positive ones fill the first len(positives) bars.
        split = len(positives)
        draw_brace(axes, (0, split), 0, 'down', f"{round(sum_pos*100)/100}")
        draw_brace(axes, (split, n_concepts), 0, 'up', f"{round(sum_neg*100)/100}")
        draw_brace(axes, (0, n_concepts), total_brace_y, 'down', f"{round(sum_tot*100)/100}")

    return fig

'''Top prototypes per class'''
def _add_prototype_strip(fig, ranked, n_cols, first_cell):
    '''Adds one row of prototype images, each captioned with its weight, to fig.

    ranked     -- (concept, weight) tuples in left-to-right display order
    n_cols     -- number of columns of the 8-row image grid
    first_cell -- 1-based grid position of the leftmost image
    '''
    for i, (concept, weight) in enumerate(ranked):
        # Only the file name of the stored path is kept and re-rooted at concept_dir.
        img_file = os.path.join(concept_dir, concept_dict[concept].split('/')[-1])
        thumb = fig.add_subplot(8, n_cols, first_cell + i, frame_on=False)
        thumb.imshow(Image.open(img_file))
        thumb.set_xticks([])
        thumb.set_yticks([])
        if i == 0:
            thumb.set_ylabel('prototypes', size='large')
        thumb.set_xlabel(f"w = {round(weight, 3)}")


@pn.depends(cls1.param.value, cls2.param.value)
def w_plot_mirror(c1, c2):
    '''Shows a bar plot for the n highest weights for two chosen classes.
    Also shows the corresponding images.

    Class c1 is drawn on the right (blue, highest weight next to the centre),
    class c2 is mirrored on the left (green). n is 10, or the number of
    available concepts if there are fewer.

    Returns the Matplotlib Figure (left empty when there are no concepts).'''
    code1 = inverted_dict[c1]
    code2 = inverted_dict[c2]

    fig = Figure(figsize=(24, 16))
    FigureCanvas(fig)

    # Never ask for more bars than there are concepts.
    n = min(10, len(concept_dict))
    if n == 0:
        return fig

    bar_width = 0.7
    ranks = np.arange(1, n + 1)

    ax1 = fig.add_subplot(2, 1, 1)

    # Class 1: highest weight first, drawn at x = 1..n.
    top_cls1 = sorted(((c, w_dict[code1][c]) for c in concept_dict),
                      key=lambda cw: cw[1], reverse=True)[:n]
    top_w1 = [weight for _, weight in top_cls1]
    ax1.bar(ranks, top_w1, bar_width, color='b', edgecolor='black', label=f'weight class {c1}')

    # Class 2: lowest of its top n first, drawn at x = -n..-1 so that the
    # highest weight ends up next to the centre as well.
    top_cls2 = sorted(((c, w_dict[code2][c]) for c in concept_dict),
                      key=lambda cw: cw[1])[-n:]
    top_w2 = [weight for _, weight in top_cls2]
    ax1.bar(-ranks[::-1], top_w2, bar_width, color='g', edgecolor='black', label=f'weight class {c2}')

    ax1.set(xlabel=f'Prototypes with highest weights for classes {c1} and {c2}',
            # This global view has no input image, so the label names the weight only.
            ylabel='Weight of prototype for given class',
            title=f'Most important prototypes for the given classes {c1} and {c2}'
           )
    ax1.legend()

    # No x ticks: the image strips below serve as the bar labels.
    ax1.xaxis.set_tick_params(labelbottom=False)
    ax1.set_xticks([])
    ax1.set_xlim([-(n+bar_width), n+bar_width])

    y_min = min([-0.5, min(top_w1), min(top_w2)])
    y_max = max([0.5, max(top_w1), max(top_w2)])
    # y_min is always negative, so a 10% margin means scaling it by 1.1
    # (subtracting 10% of it would move the limit towards zero and clip bars).
    ax1.set_ylim(1.1*y_min, 1.1*y_max)

    # The images sit on the fifth row of an 8-row grid: n columns per class
    # with n // 4 empty columns in between.
    gap = n // 4
    n_cols = 2*n + gap
    row_start = 4*n_cols + 1
    _add_prototype_strip(fig, top_cls2, n_cols, row_start)
    _add_prototype_strip(fig, top_cls1, n_cols, row_start + n + gap)

    return fig




# Page layout for the global visualisations: accuracy overview, class
# selection, then the two class comparison plots.
app = pn.Column(
    class_acc,
    pn.WidgetBox(cls1, cls2),
    distribution_plot,
    w_plot_mirror,
    width=1600, background='white',
)

# Make the panel usable outside of the Jupyter notebook by serving it from a
# background thread; `app.servable()` is the disabled alternative.
server = app.show(threaded=True)

{'n01641577_4373.JPEG': 'n01641577', 'n03633091_3191.JPEG': 'n03633091', 'ILSVRC2012_val_00003722.JPEG': 'n01693334', 'ILSVRC2012_val_00026626.JPEG': 'n01693334', 'ILSVRC2012_val_00004688.JPEG': 'n11879895', 'ILSVRC2012_val_00000203.JPEG': 'n02106382', 'ILSVRC2012_val_00000717.JPEG': 'n03785016', 'ILSVRC2012_val_00007332.JPEG': 'n11879895', 'ILSVRC2012_val_00003989.JPEG': 'n04044716', 'ILSVRC2012_val_00001444.JPEG': 'n09246464', 'ILSVRC2012_val_00000773.JPEG': 'n07873807', 'ILSVRC2012_val_00007029.JPEG': 'n07742313', 'ILSVRC2012_val_00000494.JPEG': 'n02777292', 'ILSVRC2012_val_00003803.JPEG': 'n02777292', 'ILSVRC2012_val_00004895.JPEG': 'n02777292', 'ILSVRC2012_val_00005105.JPEG': 'n02777292'}
Launching server at http://localhost:58092


