# MIL Evaluation

In [None]:

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.nn.functional import softmax
import numpy as np
import torchvision.transforms as transforms
import time
import os
import argparse
import sys

sys.path.append("../..")
from datasets import Cell_Sampling_new
import backbones.cell as cell_model
from Utils import adjust_learning_rate, progress_bar, Logger, mkdir_p,dataset_setting,balance_val_split
from Utils import utils
from sklearn.metrics import top_k_accuracy_score,f1_score,roc_auc_score

# Specify the random selection script ID.
# The ID is substituted into both the train/val split file name and the
# checkpoint path further down, so it selects which split (and which saved
# model) this notebook evaluates.
random_script_id =3
best_acc = 0  # best test accuracy (never updated in the cells of this notebook)
start_epoch = 0  # start from epoch 0 or last checkpoint epoch (overwritten when the checkpoint is loaded)


# Data
print('==> Preparing data..')


# Evaluation-time preprocessing: the only transform is conversion to a tensor.
transform_test = transforms.Compose([
    transforms.ToTensor(),
])

'''self-supervised learning features'''

# Data location handed to the dataset class as its first argument.
data_path = "./image_masks/"
# Label/id file handed to the dataset class.
# NOTE(review): the file name suggests a 32-class label-to-batch index; confirm the format.
label_id_path_file = './train_32class_batch_idx.txt'


# Train/validation split definition for the selected random script ID.
random_train_val = './experiment_setup/batch_separated/random{}.txt'.format(random_script_id)



# The training split is not built here; its construction is left commented out.
# train_set = Cell_Sampling_new.Cell_Features_sampling(data_path,label_id_path_file,train=True,
#                                   transform=transform_test,random_train_val=random_train_val)

# Validation split (train=False). Later cells read `class_number` and
# `batch_token` from this object.
val_set  = Cell_Sampling_new.Cell_Features_sampling(data_path,label_id_path_file,train=False,
                         transform=transform_test,random_train_val=random_train_val)


In [None]:
def my_collate(batch):
    """Regroup a batch of samples into per-field lists without stacking them.

    Each sample in ``batch`` is indexed at positions 0, 1 and 2 (data, target,
    weights). The result is ``[data_list, target_list, weights_list]``, each
    list holding the corresponding field of every sample in batch order.
    Fields beyond index 2 are ignored; an empty batch yields three empty lists.
    """
    return [[sample[field] for sample in batch] for field in (0, 1, 2)]

In [None]:
#trainloader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False, num_workers=0,collate_fn=my_collate)
# One validation item per step, in dataset order; my_collate returns the
# item's fields as plain lists instead of stacked tensors.
valloader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0,collate_fn=my_collate)
train_class_num = val_set.class_number


resume = './experimental_models/batch_separated/random{}/CLANet/last_model.pth'.format(random_script_id)

net = cell_model.GatedAttention(train_class_num)

# The network is wrapped together with its loss only so that the object has
# the same structure as the one whose state dict was saved.
# NOTE(review): presumably the checkpoint keys are relative to the wrapper's
# inner `.module` — confirm against the training script.
criterion = nn.CrossEntropyLoss(reduction='none')
net = utils.DataParallel_withLoss(net,criterion,device_ids=[0],output_device=0)

cudnn.benchmark = True
net = net.module

# Evaluating without trained weights would silently report metrics for a
# randomly initialised model, so a missing checkpoint is a hard error.
if not os.path.isfile(resume):
    raise FileNotFoundError("=> no checkpoint found at '{}'".format(resume))

print('==> Resuming from checkpoint..')
checkpoint = torch.load(resume)
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']

# Keep only the wrapped `.model` attribute for the evaluation cells.
net = net.model

In [None]:
net.eval()

test_loss = 0
correct = 0
total = 0

#print(net.module)

batch_predict_results,class_labels,net_features,attention_weights =[],[],[],[]


scores, labels,A_weights = [], [],[]

# Inference only: gradients are not needed.
with torch.no_grad():
    for batch_idx, (inputs, targets,_) in enumerate(valloader):
        # The loader uses batch_size=1 and my_collate returns lists, so the
        # single item is at index 0.
        inputs, targets = inputs[0].cuda(), targets[0].cuda()

        # The network returns three values; `A` is stored as the attention
        # weights and `M` as the feature vector.
        prob,M,A = net(inputs)

        # Normalise over the class axis explicitly. Calling softmax without
        # `dim` is deprecated and makes PyTorch guess the axis; dim=-1 is the
        # axis it would pick for 1-D and 2-D logits.
        scores.append(softmax(prob, dim=-1))
        labels.append(targets)
        A_weights.append(A.cpu().numpy().squeeze())
        net_features.append(M)

# Concatenate the per-item results and move them to NumPy for the metric cells.
scores = torch.cat(scores, dim=0).cpu().numpy()
labels = torch.cat(labels, dim=0).cpu().numpy()
net_features = torch.cat(net_features, dim=0).cpu().numpy()
#A_weights = np.array(torch.cat(A_weights, dim=0).cpu().numpy())

In [None]:

class_num = train_class_num
all_classes = range(class_num)

# Sequence-level metrics: one prediction per validation item.
seq_top1_acc = top_k_accuracy_score(labels, scores, k=1, labels=all_classes)
seq_top3_acc = top_k_accuracy_score(labels, scores, k=3, labels=all_classes)
seq_top5_acc = top_k_accuracy_score(labels, scores, k=5, labels=all_classes)

seq_f1 = f1_score(labels,np.argmax(scores,axis=1),average='macro')
# NOTE(review): roc_auc_score may raise if a class never occurs in `labels`
# — confirm every class is present in the validation split.
seq_auc = roc_auc_score(labels,scores,multi_class='ovr')

print(seq_top1_acc,seq_f1)

# Batch-level metrics: items sharing a batch token are merged into one
# prediction. The comparison below must be elementwise, so the tokens are
# coerced to an ndarray (a plain list would compare as a whole and match
# nothing).
batch_token = np.asarray(val_set.batch_token)
b_scores, b_labels = [],[]
for batch_id in np.unique(batch_token):
    b_ids = np.where(batch_token == batch_id)[0]

    # Batch score = mean of the member scores; batch label = most frequent
    # member label.
    b_score = np.mean(scores[b_ids],axis=0)
    b_label = np.argmax(np.bincount(labels[b_ids]))

    b_scores.append(b_score)
    b_labels.append(b_label)

b_top1_acc = top_k_accuracy_score(b_labels, b_scores, k=1, labels=all_classes)
b_top3_acc = top_k_accuracy_score(b_labels, b_scores, k=3, labels=all_classes)
b_top5_acc = top_k_accuracy_score(b_labels, b_scores, k=5, labels=all_classes)

b_f1 = f1_score(b_labels, np.argmax(b_scores, axis=1), average='macro')
b_auc = roc_auc_score(b_labels, b_scores, multi_class='ovr')

print(b_top1_acc,b_f1)

# Show Confusion Matrix

In [None]:
# Read the label file into `lines`, one entry per line with trailing
# whitespace (including the newline) removed.
label_id_path_file = '/projects/img/cellbank/Cell_Batch_Classification/model/classification/train_32class_batch_idx.txt'

with open(label_id_path_file, 'r') as f:
    lines = [raw_line.rstrip() for raw_line in f]

In [None]:
'''Confusion matrix for sequence prediction'''
from sklearn.metrics import confusion_matrix
import math
# Size the matrix from the model's class count instead of a hard-coded 32 so
# it stays consistent with the score columns.
matrix = confusion_matrix(labels, np.argmax(scores,axis=1),labels=np.arange(train_class_num))

# Per-class accuracy: correct predictions divided by the number of true
# samples of that class. Classes with no samples give 0/0 = NaN, which is
# expected here, so the NumPy warning is silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    cm = matrix.diagonal()/matrix.sum(axis=1)

# Print the accuracy of every class that has at least one sample.
for i in range(cm.shape[0]):
    if not math.isnan(cm[i]):
        print(cm[i])
        #print(lines[i],cm[i])
    

In [None]:
'''Confusion matrix for batch prediction'''
from sklearn.metrics import confusion_matrix
import math
# Size the matrix from the model's class count instead of a hard-coded 32 so
# it stays consistent with the score columns.
matrix = confusion_matrix(b_labels, np.argmax(b_scores,axis=1),labels=np.arange(train_class_num))

# Per-class accuracy: correct predictions divided by the number of true
# batches of that class. Classes with no batches give 0/0 = NaN, which is
# expected here, so the NumPy warning is silenced.
with np.errstate(divide='ignore', invalid='ignore'):
    cm = matrix.diagonal()/matrix.sum(axis=1)

# Print the accuracy of every class that has at least one batch.
for i in range(cm.shape[0]):
    if not math.isnan(cm[i]):
        print(cm[i])
        #print(lines[i],cm[i])
    

In [None]:
# -*- coding: utf-8 -*-
"""
plot a pretty confusion matrix with seaborn
Created on Mon Jun 25 14:17:37 2018
@author: Wagner Cipriano - wagnerbhbr - gmail - CEFETMG / MMC
REFerences:
  https://www.mathworks.com/help/nnet/ref/plotconfusion.html
  https://stackoverflow.com/questions/28200786/how-to-plot-scikit-learn-classification-report
  https://stackoverflow.com/questions/5821125/how-to-plot-confusion-matrix-with-string-axis-rather-than-integer-in-python
  https://www.programcreek.com/python/example/96197/seaborn.heatmap
  https://stackoverflow.com/questions/19233771/sklearn-plot-confusion-matrix-with-labels/31720054
  http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
"""

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sn
from matplotlib.collections import QuadMesh


def get_new_fig(fn, figsize=[9, 9]):
    """Return a ``(figure, axes)`` pair for figure ``fn`` with the axes cleared.

    ``fn`` and ``figsize`` are passed positionally to ``plt.figure``. The
    figure's current axes are fetched and cleared before being returned, so
    anything previously drawn on them is discarded.
    """
    figure = plt.figure(fn, figsize)
    axes = figure.gca()
    axes.cla()
    return figure, axes


def configcell_text_and_colors(
    array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0
):
    """Restyle one heatmap cell and describe the text objects to add/remove.

    Parameters:
      array_df          2-D NumPy array of the confusion matrix including the
                        totals row and column (last row / last column)
      lin, col          row and column index of the cell
      oText             text object currently annotating the cell
      facecolors        per-cell colour array; entry ``posi`` is overwritten
                        for total cells and diagonal cells
      posi              index of this cell in ``facecolors``
      fz                font size used for the texts of total cells
      fmt               accepted for interface compatibility; not used
      show_null_values  how an empty cell is rendered: 0 -> "", 1 -> "0",
                        anything else -> "0\\n0.0%"

    Returns ``(text_add, text_del)``. ``text_add`` is a list of dicts with
    keys ``x``, ``y``, ``text`` and ``kw`` describing texts to create;
    ``text_del`` lists existing text objects to remove. Regular cells are
    updated in place through ``oText`` and produce two empty lists.
    """
    text_add = []
    text_del = []
    cell_val = array_df[lin][col]
    tot_all = array_df[-1][-1]
    # Share of the grand total. An all-zero matrix has a zero grand total;
    # treat every cell as 0% instead of dividing by zero.
    per = (float(cell_val) / tot_all) * 100 if tot_all else 0.0
    last = len(array_df[:, col]) - 1
    in_last_col = col == last
    in_last_lin = lin == last

    if in_last_col or in_last_lin:
        # Totals row/column: show the count, the share of correct
        # predictions and the share of errors.
        if cell_val != 0:
            if in_last_col and in_last_lin:
                # Grand total: every diagonal entry except the totals cell.
                tot_rig = sum(array_df[i][i] for i in range(array_df.shape[0] - 1))
            elif in_last_col:
                tot_rig = array_df[lin][lin]
            else:
                tot_rig = array_df[col][col]
            per_ok = (float(tot_rig) / cell_val) * 100
            per_err = 100 - per_ok
        else:
            per_ok = per_err = 0

        per_ok_s = "100%" if per_ok == 100 else "%.2f%%" % (per_ok)

        # The original annotation is replaced by three stacked texts.
        text_del.append(oText)

        font_prop = fm.FontProperties(weight="bold", size=fz)
        base_kwargs = dict(
            color="w",
            ha="center",
            va="center",
            gid="sum",
            fontproperties=font_prop,
        )
        # Use the public accessor rather than the private _x/_y attributes.
        x, y = oText.get_position()
        stacked = (
            ("%d" % (cell_val), "w", y - 0.3),
            (per_ok_s, "g", y),
            ("%.2f%%" % (per_err), "r", y + 0.3),
        )
        for text, color, y_pos in stacked:
            text_add.append(
                dict(x=x, y=y_pos, text=text, kw=dict(base_kwargs, color=color))
            )

        # Dark background for total cells; the grand-total cell is darker.
        if in_last_col and in_last_lin:
            facecolors[posi] = [0.17, 0.20, 0.17, 1.0]
        else:
            facecolors[posi] = [0.27, 0.30, 0.27, 1.0]

    else:
        if per > 0:
            txt = "%s\n%.2f%%" % (cell_val, per)
        elif show_null_values == 0:
            txt = ""
        elif show_null_values == 1:
            txt = "0"
        else:
            txt = "0\n0.0%"
        oText.set_text(txt)

        if col == lin:
            # Main diagonal: white text on a green background.
            oText.set_color("w")
            facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
        else:
            # Off-diagonal cells: red text.
            oText.set_color("r")

    return text_add, text_del


def insert_totals(df_cm):
    """Append a totals column and a totals row to ``df_cm`` in place.

    A ``"sum_lin"`` column holding each row's sum is added first, then a
    ``"sum_col"`` row holding each original column's sum followed by the
    grand total. Nothing is returned; the DataFrame itself is modified.
    """
    column_totals = [df_cm[name].sum() for name in df_cm.columns]
    row_totals = [row.sum() for _, row in df_cm.iterrows()]
    df_cm["sum_lin"] = row_totals
    # The last entry of the totals row is the grand total.
    df_cm.loc["sum_col"] = column_totals + [np.sum(row_totals)]


def pp_matrix(
    df_cm,
    annot=True,
    cmap="Oranges",
    fmt=".2f",
    fz=11,
    lw=0.5,
    cbar=False,
    figsize=[8, 8],
    show_null_values=0,
    pred_val_axis="y",
):
    """
    print conf matrix with default layout (like matlab)
    params:
      df_cm          dataframe (pandas) without totals; it is not modified
      annot          print text in each cell
      cmap           Oranges,Oranges_r,YlGnBu,Blues,RdBu, ... (any colormap
                     name accepted by seaborn.heatmap)
      fmt            number format passed to seaborn.heatmap
      fz             fontsize
      lw             linewidth
      cbar           whether to draw the color bar
      figsize        figure size passed to matplotlib
      show_null_values  how empty cells are rendered (0: blank, 1: "0",
                        otherwise "0" plus "0.0%")
      pred_val_axis  where to show the prediction values (x or y axis)
                      'col' or 'x': show predicted values in columns (x axis) instead lines
                      'lin' or 'y': show predicted values in lines   (y axis)
    The figure is displayed with plt.show(); nothing is returned.
    """
    if pred_val_axis in ("col", "x"):
        xlbl = "Predicted"
        ylbl = "Actual"
        oriented = df_cm
    else:
        xlbl = "Actual"
        ylbl = "Predicted"
        oriented = df_cm.T

    # insert_totals works in place. Operate on a private copy so the caller's
    # frame never gains the totals row/column (previously it did in the
    # "col"/"x" orientation, so a second call added the totals twice).
    df_cm = oriented.copy()

    # create "Total" column and row
    insert_totals(df_cm)

    # always draw into the same named figure
    fig, ax1 = get_new_fig("Conf matrix default", figsize)

    ax = sn.heatmap(
        df_cm,
        annot=annot,
        annot_kws={"size": fz},
        linewidths=lw,
        ax=ax1,
        cbar=cbar,
        cmap=cmap,
        linecolor="w",
        fmt=fmt,
    )

    # set ticklabels rotation
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=10)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=25, fontsize=10)

    # Turn off all the ticks
    for axis in (ax.xaxis, ax.yaxis):
        for tick in axis.get_major_ticks():
            tick.tick1line.set_visible(False)
            tick.tick2line.set_visible(False)

    # face colors list
    quadmesh = ax.findobj(QuadMesh)[0]
    facecolors = quadmesh.get_facecolors()

    # Walk the annotation texts; their positions are cell centres, so
    # subtracting 0.5 recovers the integer (column, row) indices.
    array_df = np.array(df_cm.to_records(index=False).tolist())
    text_add = []
    text_del = []
    for posi, t in enumerate(ax.texts):
        pos = np.array(t.get_position()) - [0.5, 0.5]
        lin = int(pos[1])
        col = int(pos[0])

        # set text
        cell_add, cell_del = configcell_text_and_colors(
            array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values
        )

        text_add.extend(cell_add)
        text_del.extend(cell_del)

    # remove the old ones (done after the loop so the text list is not
    # modified while it is being iterated)
    for item in text_del:
        item.remove()
    # append the new ones
    for item in text_add:
        ax.text(item["x"], item["y"], item["text"], **item["kw"])

    # titles and legends
    ax.set_xlabel(xlbl)
    ax.set_ylabel(ylbl)
    plt.tight_layout()  # set layout slim

    plt.show()


def pp_matrix_from_data(
    y_test,
    predictions,
    columns=None,
    annot=True,
    cmap="Oranges",
    fmt=".2f",
    fz=11,
    lw=0.5,
    cbar=False,
    figsize=[8, 8],
    show_null_values=0,
    pred_val_axis="lin",
):
    """
    Plot a confusion matrix from y_test (actual values) and predictions,
    without a precomputed confusion matrix.

    params:
      y_test       actual class labels
      predictions  predicted class labels
      columns      display names, one per class occurring in y_test or
                   predictions, in sorted label order; when None or empty,
                   names are generated ("class A", "class B", ... for up to
                   26 classes, otherwise "class 0", "class 1", ...)
      (remaining parameters are forwarded unchanged to pp_matrix)

    raises:
      ValueError   if columns is given but its length differs from the
                   number of classes in the confusion matrix
    """
    from pandas import DataFrame
    from sklearn.metrics import confusion_matrix

    # Without an explicit label list the matrix only covers the classes that
    # occur in y_test or predictions, so its size drives the names.
    confm = confusion_matrix(y_test, predictions)
    n_classes = confm.shape[0]

    # `columns` may be a list or an ndarray; avoid truth-testing it directly.
    if columns is None or len(columns) == 0:
        from string import ascii_uppercase

        if n_classes <= len(ascii_uppercase):
            columns = ["class %s" % (i) for i in ascii_uppercase[:n_classes]]
        else:
            # More classes than letters: fall back to numeric names.
            columns = ["class %d" % (i) for i in range(n_classes)]
    elif len(columns) != n_classes:
        raise ValueError(
            "got %d column names for a confusion matrix with %d classes; "
            "pass one name per class present in y_test or predictions"
            % (len(columns), n_classes)
        )

    df_cm = DataFrame(confm, index=columns, columns=columns)
    pp_matrix(
        df_cm,
        annot=annot,
        cmap=cmap,
        fmt=fmt,
        fz=fz,
        lw=lw,
        cbar=cbar,
        figsize=figsize,
        show_null_values=show_null_values,
        pred_val_axis=pred_val_axis,
    )


In [None]:
import numpy as np

# pp_matrix_from_data builds its confusion matrix without an explicit label
# list, so the matrix only covers classes that occur in the targets or the
# predictions. Pass just the names of those classes (in sorted id order);
# passing every name fails with a shape mismatch when a class is absent.
# NOTE(review): assumes lines[i] is the name of class id i — confirm against
# the label file.
pred_classes = np.argmax(scores, axis=1)
present_classes = np.union1d(labels, pred_classes)
pp_matrix_from_data(labels, pred_classes, columns=[lines[c] for c in present_classes], figsize=[25,25])
