In [1]:
%load_ext autoreload
%autoreload 2

import nntools
from nntools.utils import Config
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

import sys
sys.path.append('../')

from experiment import OCTClassification


In [2]:
# Load the project configuration and select the run this notebook refers to.
config_path = '../configs/config.yaml'
config = Config(config_path)
# NOTE(review): presumably names the run/model tracked by the experiment
# manager — confirm against nntools / OCTClassification.
config['Manager']['run'] = 'Wide ResNet-101-2'
# Build the experiment object; in the visible cells only its test split
# (`experiment.test_dataset`) is used afterwards.
experiment = OCTClassification(config)



In [3]:
# Test split; the keys of its `map_class` are used as the class labels of the plot.
test_dataset = experiment.test_dataset

In [4]:
# Confusion matrices saved as artifacts of three earlier runs, stored in an
# MLflow-style tree: <RUNS_DIR>/<run_id>/artifacts/<file>.npy.
# Set the MLRUNS_DIR environment variable to run on another machine instead
# of editing every hard-coded path.
import os

RUNS_DIR = os.environ.get('MLRUNS_DIR', '/home/clement/Documents/Clement/runs/mlruns/1')


def _load_confmat(run_id, filename='confMat.npy'):
    """Load the ``.npy`` confusion-matrix artifact of run ``run_id`` from RUNS_DIR."""
    return np.load(os.path.join(RUNS_DIR, run_id, 'artifacts', filename))


confMat_resNet = _load_confmat('d90aa317c79140a587451d6e30c848b0')
confMat_vit = _load_confmat('8cde2d9e96fb475681128bce91f90c07')
# This run stored its matrix under a different artifact file name.
confMat_OpticNet = _load_confmat('e3c337f9fe6647cf9d61e81109fd6644', 'test_confMat.npy')

# Display name -> matrix; insertion order is the order of the model switcher.
confs = {'WResNet-101(2)': confMat_resNet, 'ViT_L_32_384': confMat_vit, 'OpticNet71': confMat_OpticNet}


In [5]:
from bokeh.plotting import figure
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.io import output_notebook
from bokeh.plotting import figure, show, output_file
from collections import OrderedDict
from bokeh.layouts import layout, widgetbox, column
from bokeh.models import CustomJS, ColumnDataSource, Slider

In [6]:
def disp(confMat, labels):
    """Build an interactive Bokeh heat map of one or several confusion matrices.

    Parameters
    ----------
    confMat : np.ndarray or dict[str, np.ndarray]
        A square matrix whose first index is drawn on the "Predicted" x axis
        and whose second index is drawn on the "Groundtruth" y axis, or a
        dict ``{name: matrix}``. For a dict, a RadioButtonGroup placed above
        the plot switches between the matrices in the browser (no server).
        NOTE(review): sklearn/torchmetrics matrices are laid out
        ``[true, predicted]``; confirm the loaded artifacts use the
        ``[predicted, true]`` layout assumed here, otherwise the axis titles
        are swapped.
    labels : iterable of str
        Class names, in matrix-index order.

    Returns
    -------
    The Bokeh figure, or ``column(radio_button_group, figure)`` for a dict.

    Raises
    ------
    ValueError
        If ``confMat`` is an empty dict.

    Each cell shows its raw count. Its opacity and its hover "ratio" are the
    count normalised by ``cm.sum(0)``, i.e. divided by the sum over the first
    index for the same value of the second index.
    """
    labels = list(labels)  # labels[::-1] below needs a real sequence
    n_classes = len(labels)
    cell_color = '#00cc66'

    def get_list(cm):
        """Flatten ``cm`` into parallel per-cell lists for a ColumnDataSource."""
        cm = np.asarray(cm)
        totals = cm.sum(0)
        # A column summing to 0 (class absent) gets ratio 0 instead of the
        # NaN a plain 0/0 would produce (which made the cell's alpha NaN).
        normalized = np.divide(cm, totals, out=np.zeros(cm.shape, dtype=float),
                               where=totals != 0)
        predicted, actual, count, color, alpha, ratios = [], [], [], [], [], []
        for coli in range(n_classes):
            for rowi in range(n_classes):
                ratio = normalized[coli, rowi]
                predicted.append(labels[coli])
                actual.append(labels[rowi])
                count.append(cm[coli, rowi])
                ratios.append(ratio)
                # +0.01 keeps empty cells faintly visible; alpha is capped at 1.
                alpha.append(min(ratio + 0.01, 1))
                color.append(cell_color)
        return predicted, actual, count, color, alpha, ratios

    is_multi = isinstance(confMat, dict)
    matrices = list(confMat.values()) if is_multi else [confMat]
    if not matrices:
        raise ValueError("confMat dict is empty; nothing to plot")

    source_data = {}
    for i, cm in enumerate(matrices):
        predicted, actual, count, color, alpha, ratios = get_list(cm)
        if is_multi:
            # Per-matrix columns that the JS callback copies from on switch.
            source_data[str(i)+'count'] = count
            source_data[str(i)+'alphas'] = alpha
            source_data[str(i)+'ratios'] = ratios
        if i == 0:
            # Columns actually drawn, initialised with the first matrix.
            source_data['predicted'] = predicted
            source_data['groundtruth'] = actual
            source_data['count'] = count
            source_data['colors'] = color
            source_data['alphas'] = alpha
            source_data['ratios'] = ratios

    source = ColumnDataSource(data=source_data)
    # Size is given to figure() instead of assigning p.plot_width /
    # p.plot_height, attributes that no longer exist in Bokeh 3.
    p = figure(title='Confusion Matrix',
               x_axis_location="above", tools="hover,save",
               y_range=labels[::-1], x_range=labels,
               width=600, height=600)
    rectwidth = 0.9

    p.rect('predicted', 'groundtruth', rectwidth, rectwidth, source=source,
           color='colors', alpha='alphas', line_width=1)

    p.text(x='predicted', y='groundtruth', text='count', source=source, text_align='center',
           text_baseline='middle')

    p.axis.major_label_text_font_size = "12pt"
    p.axis.major_label_standoff = 1
    p.xgrid.visible = False
    p.ygrid.visible = False
    p.xaxis.axis_label = 'Predicted'
    p.yaxis.axis_label = 'Groundtruth'

    hover = p.select(dict(type=HoverTool))
    hover.tooltips = OrderedDict([
        ('predicted', '@predicted'),
        ('groundtruth', '@groundtruth'),
        ('ratio', '@ratios'),
    ])

    if is_multi:
        from bokeh.models.widgets import RadioButtonGroup

        callback = CustomJS(args=dict(source=source),
                            code=
                            """
                            var data = source.data;
                            var f = cb_obj.active
                            var count = data['count']
                            var alphas = data['alphas']
                            var ratios = data['ratios']

                            for (var i = 0; i < count.length; i++) {
                                count[i] = data[f.toString()+'count'][i]
                                alphas[i] = data[f.toString()+'alphas'][i]
                                ratios[i] = data[f.toString()+'ratios'][i]

                            }
                            source.change.emit();
                            """
                           )
        # Register the callback only through js_on_change: the widget
        # `callback=` keyword was removed in Bokeh 2.0, and on 1.x passing
        # both made the JS run twice per click.
        radio_button_group = RadioButtonGroup(labels=list(confMat.keys()), active=0)
        radio_button_group.js_on_change('active', callback)

        p = column(radio_button_group, p)
    return p
# Route Bokeh's show() output into this notebook.
output_notebook()
# Class names in the iteration order of the dataset's class mapping.
# NOTE(review): assumes this order matches the confusion-matrix indices.
labels = test_dataset.map_class.keys()
# Build the switchable confusion-matrix view for all models and render it.
p = disp(confs, list(labels))
show(p)

In [None]:
from bokeh.embed import file_html
from bokeh.resources import CDN
# Render the figure/layout `p` to a standalone HTML document titled "ConfMat"
# that loads BokehJS from the Bokeh CDN rather than inlining it.
html = file_html(p, CDN, "ConfMat")

In [None]:
# Save the standalone page. Bokeh's HTML template declares charset="utf-8",
# so encode explicitly as UTF-8 instead of the platform's default encoding
# (which would garble non-ASCII class names, e.g. on Windows).
with open("ConfMat.html", "w", encoding="utf-8") as text_file:
    text_file.write(html)