# Evaluation of checkpoints
This notebook is set up specifically for MAnet.  

## Setup

In [2]:
from datetime import datetime
import os
from pathlib import Path
import numpy as np
import copy
import cv2
from tqdm import tqdm
import gc
from torch import cuda
import pandas as pd

import torch
import torchvision.transforms as tf
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.metrics.functional import accuracy as acc  


### Get data sources

In [None]:
# Root folder of the data set used for validation and evaluation after training.
# The evaluation cell reads "<ValidFolder>/images" and "<ValidFolder>/labels".
ValidFolder="./Data/trainData/Arundo4/"

# List of validation image file names. os.listdir returns entries in arbitrary order,
# so the listing is sorted to keep the evaluation order (and the rows of the result
# files) reproducible from run to run.
vListImages=sorted(os.listdir(os.path.join(ValidFolder, "images")))

### necessary parameters

In [None]:
# Side length, in pixels, of one image tile. The evaluation loop only uses these
# values to build an all-zero label when a tile has no label file.
width = 256
height = width

### declaration of necessary functions

#### Confusion Matrix function for pixel segmentation (for one class)
segmentation_models_pytorch.metrics.functional.accuracy(tp, fp, fn, tn, reduction=None, class_weights=None, zero_division=1.0)
https://smp.readthedocs.io/en/stable/metrics.html#segmentation_models_pytorch.metrics.functional.accuracy 

In [None]:
def bitmapConfMatrix(Output, Target):
    """Compute the pixel-level confusion matrix and accuracy of one prediction.

    Args:
        Output: tensor of predicted scores. It is binarised by ``get_stats`` with
            ``threshold=0.5`` (see the segmentation_models_pytorch metrics docs).
        Target: ground-truth tensor with the same shape as ``Output``; it is
            binarised with the same threshold.
            NOTE(review): the library is expected to require an integer dtype
            here -- confirm against the installed version.

    Returns:
        Tuple ``(TruePs, FalsePs, FalseNs, TrueNs, accuracy)``: four integer pixel
        counts followed by the micro-averaged accuracy as a float.
    """
    tp, fp, fn, tn = smp.metrics.get_stats(Output, Target, mode='multilabel', threshold=0.5)

    # The statistics tensors hold counts, so they are summed. count_nonzero would only
    # count how many entries are non-zero, which under-reports as soon as one entry
    # aggregates more than a single pixel (e.g. for batched N x C x H x W input).
    TruePs=int(tp.sum().item())
    FalsePs=int(fp.sum().item())
    TrueNs=int(tn.sum().item())
    FalseNs=int(fn.sum().item())

    # Named `accuracy` so that the `acc` function imported at the top of the notebook
    # is not shadowed inside this function.
    accuracy=float(smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro"))

    return TruePs,FalsePs,FalseNs,TrueNs, accuracy

In [None]:
def getEpochNum(checkpoint):
    """Extract the epoch label from a checkpoint path.

    The label is the part of the file name (the last path component) that precedes
    the first '-'. For example ``"zoo/12-MaNet.torch"`` gives ``"12"``.

    Args:
        checkpoint: checkpoint path as a string. Surrounding whitespace (such as the
            newline left by reading a line from a text file) is ignored and both
            '/' and '\\' are accepted as path separators.

    Returns:
        The epoch label as a string. When the file name contains no '-', the whole
        file name is returned; an empty input gives an empty string.
    """
    # Only the file name is inspected, so directory names (which may themselves
    # contain '-' or further separators) can never leak into the result.
    file_name=str(checkpoint).strip().replace('\\', '/').rsplit('/', 1)[-1]
    return file_name.split('-', 1)[0]

In [None]:
# Data Transformation

# Shared converter from numpy arrays / PIL images to torch tensors.
tensorise=tf.ToTensor()

def AdaptMask(Lbl):
    """Convert a label mask into a tensor of integer class values.

    The mask is cast to float32, divided by 10, truncated to integers and finally
    passed through ``tensorise``.
    Presumably the label files store each class as a multiple of 10 -- TODO confirm.

    Args:
        Lbl: numpy array holding the raw mask values.

    Returns:
        The tensor produced by ``tensorise`` from the rescaled integer mask.
    """
    rescaled=Lbl.astype(np.float32)/10
    return tensorise(rescaled.astype(int))
    

# Pipeline that adapts an image for the network:
# array -> PIL image -> float tensor -> per-channel normalisation.
# The Normalize parameters are the ones suggested by the PyTorch documentation.
transformImg=tf.Compose([
    tf.ToPILImage(),
    tf.ToTensor(),
    tf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

## Setup the model and prepare for Evaluation

`class segmentation_models_pytorch.MAnet(encoder_name='resnet34', encoder_depth=5, encoder_weights='imagenet', decoder_use_batchnorm=True, decoder_channels=(256, 128, 64, 32, 16), decoder_pab_channels=64, in_channels=3, classes=1, activation=None, aux_params=None)`

More Models: https://smp.readthedocs.io/en/stable/models.html 
Backbones to choose from:  https://smp.readthedocs.io/en/stable/encoders.html 

In [3]:
# Build the MAnet architecture whose checkpoints are evaluated in this notebook.
# No `activation` is passed: the 'softmax' option is deprecated for some models.
model = smp.MAnet(
    # encoder (backbone), e.g. mobilenet_v2 or efficientnet-b7
    encoder_name="efficientnet-b7",
    # None or `imagenet` pre-trained weights for encoder initialization
    encoder_weights="imagenet",
    # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    in_channels=3,
    # model output channels (number of classes in the dataset, +1 for background)
    classes=2,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device: ", device)

# `Net` is the name the evaluation cells use for the model placed on `device`.
Net = model.to(device)

# Title used to locate the training log that belongs to this model.
model_naming_title="MaNet-ENb7"

device:  cuda


In [4]:
# Switch the network to evaluation mode. As the last expression of the cell, the
# returned module is also displayed, which prints the full layer structure below.
Net.eval()  

MAnet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          64, 64, kernel_size=(3, 3), stride=[1, 1], groups=64, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          64, 16, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          16, 64, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
       

# Picking the best checkpoints to evaluate

**This requires access to the training log**

## pick the criteria for choosing the epochs

``0: picks the checkpoints with the best Validation Loss (default)``

``1: picks the checkpoints with the best Accuracy score on the validation set``

``2: picks the checkpoints with the best Training Loss``

``anything else: picks the last checkpoints saved regardless of how they perform``

Please enter one of the above values as the criteria in the next cell.

In [None]:
# Ranking rule for choosing checkpoints (see the list above):
# 0 = validation loss, 1 = validation accuracy, 2 = training loss,
# anything else = the most recently saved checkpoints.
criteria = 0

### Select how many checkpoints to compare in the next cell (default 10)

In [None]:
# How many checkpoints are selected from the training log for evaluation.
number_of_chkpoints = 10

In [None]:
model_naming_title="MaNet-ENb7"
log_path='LOG for '+model_naming_title+'.csv'


def pickCheckpoints(log_db, criteria=0, count=10):
    """Select the saved checkpoints of a training log according to a ranking rule.

    Args:
        log_db: training log as a DataFrame. It must contain a 'CheckPoint' column
            (rows holding "not saved" are ignored) and the column used for ranking:
            'Val-Loss' or 'Valid-Loss' for criteria 0, 'Acc' for criteria 1,
            'Train-Loss' for criteria 2.
        criteria: 0 = lowest validation loss, 1 = highest accuracy,
            2 = lowest training loss, anything else = last checkpoints in the log.
        count: maximum number of rows to return.

    Returns:
        DataFrame with at most `count` rows of `log_db`, best-ranked first (or in
        log order when the last checkpoints are requested).
    """
    saved=log_db[log_db.loc[:, ('CheckPoint')] != "not saved"]

    if criteria==0:
        # The cells of this notebook disagree on the name of the validation-loss
        # column, so both spellings are accepted.
        val_loss_column='Val-Loss' if 'Val-Loss' in saved.columns else 'Valid-Loss'
        return saved.sort_values(by=[val_loss_column], ascending=True).head(count)
    if criteria==1:
        return saved.sort_values(by=['Acc'], ascending=False).head(count)
    if criteria==2:
        return saved.sort_values(by=['Train-Loss'], ascending=True).head(count)
    # Any other value (including negative ones): the most recently logged checkpoints.
    return saved.tail(count)


if os.path.exists(log_path):
    print("A log file for ",model_naming_title," was found as: ",log_path)
    log_DB=pd.read_csv(log_path, sep=",", index_col=None)

    if criteria==0:
        print("Picking the ",number_of_chkpoints, " checkpoints with a lowest validation loss" )
    elif criteria==1:
        print("Picking the ",number_of_chkpoints, " checkpoints with a highest validation accuracy" )
    elif criteria==2:
        print("Picking the ",number_of_chkpoints, " checkpoints with a lowest training loss" )
    else:
        print("Picking the last ",number_of_chkpoints, " checkpoints" )

    selected=pickCheckpoints(log_DB, criteria, number_of_chkpoints)
    checkpoints=selected['Epoch'].values
    ToEvaluate=selected['CheckPoint']

    print("")
    print("The following checkpoints will be evaluated:")
    # One checkpoint path per line; the evaluation cell reads this file back.
    with open('checkpoints_to_evaluate.txt', 'w') as f:
        for epoch, checkpoint_path in zip(checkpoints, ToEvaluate.values):
            print("Checkpoint ",epoch," at path: ", checkpoint_path)
            f.write(str(checkpoint_path))
            f.write('\n')
    print("")
    # os.startfile only exists on Windows; elsewhere the file has to be opened manually.
    if hasattr(os, 'startfile'):
        os.startfile('checkpoints_to_evaluate.txt')
    print("To add more checkpoints manually, add them to the file: checkpoints_to_evaluate.txt")
else:
    print("A log file for ",model_naming_title," was not found")
    print("set the log_path variable at the top of the cell to the log file you want to use and run this cell again")
    


### Plot training loss graph

In [None]:
import matplotlib.pyplot as plt

# Re-read the complete training log (every epoch, not only the saved checkpoints).
log_DB=pd.read_csv(log_path, sep=",")

# The cells of this notebook disagree on the name of the validation-loss column,
# so both spellings are accepted instead of failing with a KeyError.
val_loss_column='Valid-Loss' if 'Valid-Loss' in log_DB.columns else 'Val-Loss'

xAxis=log_DB['Epoch']
yAxis=log_DB['Train-Loss']
yAxis2=log_DB[val_loss_column]

# Each curve is labelled so the two losses can be told apart in the legend.
plt.plot(xAxis,yAxis,label='Training loss')
plt.plot(xAxis,yAxis2,label='Validation loss')

plt.title('Loss Chart')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()
plt.show()

# Evaluation of the checkpoints


In [None]:
# Column titles follow the order returned by bitmapConfMatrix: TP, FP, FN, TN, accuracy.
result_titles=['ImageTile','True-Positives','False-Positives','False-Negatives','True-Negatives','Accuracy']

# Read the checkpoint list once and close the file straight away. Lines are stripped
# and blank lines skipped, so a manually edited list works with or without a trailing
# newline on its last line.
with open('checkpoints_to_evaluate.txt', 'r') as checkpoint_file:
    checkpoint_paths=[line.strip() for line in checkpoint_file if line.strip()]

for checkpoint in checkpoint_paths:
    print("Evaluating ", checkpoint)
    # Path separators are replaced so the label is always usable inside a file name.
    epochNum=getEpochNum(checkpoint).replace('/', '_').replace('\\', '_')
    # map_location lets a checkpoint saved on another device be loaded on `device`.
    Net.load_state_dict(torch.load(checkpoint, map_location=device))
    gc.collect()
    cuda.empty_cache()
    Net.eval()  #  sets network to evaluation mode
    ValACC=0
    tp=0
    fp=0
    fn=0
    tn=0
    result_rows=[]  # one row per image tile; turned into a DataFrame after the loop

    with torch.no_grad():  # no gradients are needed since we are only evaluating
        for image_name in tqdm(vListImages):

            # NOTE(review): cv2 loads images in BGR channel order -- confirm this
            # matches the preprocessing that was used during training.
            Img=cv2.imread(os.path.join(ValidFolder, "images", image_name), cv2.IMREAD_COLOR)
            Img=transformImg(Img)
            image=Img.to(device).unsqueeze(0)  # add the batch dimension

            # The label is read as a single-channel image so that its shape matches
            # the per-class output map it is compared with.
            Lbl=cv2.imread(os.path.join(ValidFolder, "labels", image_name), cv2.IMREAD_GRAYSCALE)
            # A missing label file is treated as an all-background tile.
            if Lbl is None: Lbl=np.zeros((height, width), dtype=np.int8)
            Lbl=Lbl/10
            Lbl=Lbl.astype(np.int8)
            Target=torch.from_numpy(Lbl).to(device)

            Pred=Net(image)
            Output=Pred[0][1]  # first (only) batch item, output channel 1

            a,b,c,d,e=bitmapConfMatrix(Output, Target)
            tp=tp+a
            fp=fp+b
            fn=fn+c
            tn=tn+d
            ValACC=ValACC+e

            result_rows.append([image_name, a,b,c,d,e])
            gc.collect()

    # Build the table once instead of concatenating a new frame for every tile.
    results=pd.DataFrame(result_rows, columns=result_titles)
    results_path="results for checkpoint "+epochNum+".csv"
    results.to_csv(results_path, sep=",")

    # Mean per-tile accuracy; guarded so that an empty image folder does not crash.
    ValACC=ValACC/len(vListImages) if len(vListImages) else 0.0

    # Append a summary row below the table. The two leading commas skip the index
    # and ImageTile columns so the totals line up with their column titles.
    with open(results_path, 'a') as f:
        f.write('\n\n,,'+','.join(str(value) for value in (tp, fp, fn, tn, ValACC)))
    #os.startfile(results_path)  #opens the CSV in the os' default application for csv files
    del results

    print("TP:",tp,"  FP:",fp)
    print("FN:",fn,"  TN:",tn)
    # Ratios are reported as NaN when their denominator is zero.
    print("\nTP% :",tp/(tp+fp) if (tp+fp) else float('nan'))
    print("TN% :",tn/(fn+tn) if (fn+tn) else float('nan'))

    print("\nAccuracy:", ValACC*100)

    # Net must stay alive: the next iteration loads the next checkpoint into it.
    gc.collect()
    cuda.empty_cache()
    print("FINISHED.... Results Saved As:",results_path)


    