In [None]:
# General imports.
import os
import random
import time
from pathlib import Path

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from PIL import Image

# Metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score

In [None]:
# Class names.
classes = 'classes'
# Number of classes to classify.
classesLen = 'classesLen'

# Experiment configuration. The position of a name in the class list is used
# as its class index (e.g. the wandb confusion matrix and the per-class
# accuracy keys pair list position i with label i), so do not reorder it.
config = {
    classes : [
        'Apple - Apple scab',
        'Apple - Black rot',
        'Apple - Cedar apple rust',
        'Apple - Healthy',
        'Background without leaves',
        'Blueberry - Healthy',
        'Cherry - healthy',
        'Cherry - Powdery mildew',
        'Corn - Cercospora',
        'Corn - Common rust',
        'Corn - Healthy',
        'Corn - Northern Leaf Blight',
        'Grape - Black rot',
        # NOTE(review): the trailing space looks accidental; it is kept because
        # the name is part of the logged wandb keys.
        'Grape - Esca ',
        'Grape - Healthy',
        'Grape - Leaf blight',
        'Orange - Haunglongbing',
        'Peach - Bacterial spot',
        'Peach - healthy',
        'Pepper bell - Bacterial spot',
        'Pepper bell - healthy',
        'Potato - Early blight',
        'Potato - Healthy',
        'Potato - Late blight',
        'Raspberry - healthy',
        'Soybean - Healthy',
        'Squash - Powdery mildew',
        'Strawberry - Healthy',
        'Strawberry - Leaf scorch',
        'Tomato - Bacterial spot',
        'Tomato - Early blight',
        'Tomato - Healthy',
        'Tomato - Late blight',
        'Tomato - Leaf Mold',
        'Tomato - Septoria leaf spot',
        'Tomato - Spider mites',
        'Tomato - Target Spot',
        'Tomato - Mosaic virus',
        'Tomato - Yellow Leaf Curl Virus',
    ],
}
# Derive the class count from the list itself so the two can never disagree
# (a hard-coded count can silently drift from the list, e.g. when a missing
# comma merges two adjacent names into one).
config[classesLen] = len(config[classes])

In [None]:
# Keys of the metrics dictionaries.
# Raw values accumulated batch by batch while an epoch runs.
_loss, _groundtruth, _logits = 'Loss', 'Groundtruth', 'Logits'

# Values derived once the epoch is over.
_probabilities = 'Probabilities'
_predictions = 'Predictions'
_accuracyClass = 'Accuracy class'
_accuracy, _recall, _precision, _f1, _auc = 'Accuracy', 'Recall', 'Precision', 'F1', 'AUC'

# Summary metrics, in the order they are printed and logged.
_metricsPrint = [
    _accuracy,
    _recall,
    _precision,
    _f1,
    _auc,
]

# Get a clean dictionary for the metrics.
def getMetricsDict():
    """Return a fresh, empty running-metrics dictionary.

    ``updateRunningMetrics`` accumulates into it batch by batch and
    ``processRunningMetrics`` finalizes it at the end of the epoch.

    Returns:
        dict with a zero scalar loss and empty ground-truth / logits tensors.
    """
    return {
        _loss          : torch.tensor(0.),
        # Integer dtype so the concatenated class labels stay integers; an
        # empty float tensor here lets torch.cat promote the labels to float.
        _groundtruth   : torch.tensor([], dtype=torch.long),
        _logits        : torch.tensor([]),
    }

# Function used to update the dictionary of resulting metrics.
def updateRunningMetrics(logits, groundtruth, loss, batchAmount, metricsResults):
    """Accumulate one batch into a running metrics dictionary (in place).

    Args:
        logits: model outputs for the batch; concatenated along dim 0.
        groundtruth: labels for the batch; concatenated along dim 0.
        loss: loss tensor for the batch.
        batchAmount: number of batches in the epoch. Each batch loss is divided
            by it, so after ``batchAmount`` calls the accumulated loss is the
            mean batch loss.
        metricsResults: dictionary from ``getMetricsDict``, updated in place.
    """
    # Detach before accumulating. Tensors that still require grad would make the
    # running loss and the concatenated logits keep a reference to every batch's
    # autograd graph until the epoch ends, growing memory with each batch.
    metricsResults[_loss] += loss.detach().cpu() / batchAmount
    # Accumulate the ground truth and the logits on the CPU.
    metricsResults[_groundtruth] = torch.cat((metricsResults[_groundtruth], groundtruth.detach().cpu()))
    metricsResults[_logits] = torch.cat((metricsResults[_logits], logits.detach().cpu()))

# Function used to process the dictionary of resulting metrics (make final calculations).
def processRunningMetrics(metricsResults):
    """Finalize a running metrics dictionary in place.

    Detaches the accumulated loss, ground truth and logits, then stores the
    softmax probabilities, the arg-max predictions, the per-class accuracy and
    the accuracy, macro recall, macro precision, macro F1 and one-vs-rest AUC,
    all as torch tensors.

    Args:
        metricsResults: dictionary from ``getMetricsDict`` filled by
            ``updateRunningMetrics``. The logits are expected to be 2-D
            (samples, classes); the class dimension is dim 1.

    Raises:
        ValueError: if no batches were accumulated (no 2-D logits available).
    """
    # Detach the accumulated values.
    metricsResults[_loss] = metricsResults[_loss].detach()
    metricsResults[_groundtruth] = metricsResults[_groundtruth].detach()
    metricsResults[_logits] = metricsResults[_logits].detach()

    logits = metricsResults[_logits]
    if logits.ndim != 2 or logits.shape[0] == 0:
        raise ValueError('processRunningMetrics needs accumulated logits of shape '
                         '(samples, classes); got {}'.format(tuple(logits.shape)))

    # Class probabilities (softmax over dim 1) and the predicted class per sample.
    probabilities = torch.softmax(logits, dim=1)
    predictions = torch.argmax(probabilities, dim=1)
    metricsResults[_probabilities] = probabilities
    metricsResults[_predictions] = predictions

    # NumPy views for scikit-learn.
    groundtruth = metricsResults[_groundtruth].numpy()
    probs = probabilities.numpy()
    preds = predictions.numpy()

    # Score every class index (one per logit column), including classes absent
    # from this split, so per-class results always line up with class indices.
    classLabels = list(range(logits.shape[1]))

    # Accuracy by class: correct predictions of a class over its ground-truth
    # count (confusion-matrix diagonal over row sums).
    confusionMatrix = confusion_matrix(groundtruth, preds, labels=classLabels).astype('float')
    with np.errstate(divide='ignore', invalid='ignore'):
        # Classes without ground-truth samples have an all-zero row -> NaN.
        accuracyByClass = confusionMatrix.diagonal() / confusionMatrix.sum(axis=1)
    metricsResults[_accuracyClass] = torch.tensor(accuracyByClass)

    # Global metrics.
    metricsResults[_accuracy] = torch.tensor(accuracy_score(groundtruth, preds))
    metricsResults[_recall] = torch.tensor(recall_score(groundtruth, preds, average='macro'))
    metricsResults[_precision] = torch.tensor(precision_score(groundtruth, preds, average='macro'))
    metricsResults[_f1] = torch.tensor(f1_score(groundtruth, preds, average='macro'))

    # One-vs-rest AUC is undefined when a class has no ground-truth samples
    # (scikit-learn raises in that case), so report NaN instead of aborting.
    if len(np.unique(groundtruth)) == len(classLabels):
        auc = roc_auc_score(groundtruth, probs, multi_class='ovr', labels=classLabels)
    else:
        auc = float('nan')
    metricsResults[_auc] = torch.tensor(auc)

# Pretty print the metrics dictionaries.
def printMetricsDict(metricsResults):
    """Print the loss and the summary metrics of one results dictionary on one line.

    The output has the form ``Loss: x.xxxx, Accuracy: x.xxxx, ...``, with the
    summary metrics taken in ``_metricsPrint`` order and formatted to 4 decimals.
    """
    pieces = [f'Loss: {metricsResults[_loss]:.4f}']
    pieces.extend(f'{name}: {metricsResults[name]:1.4f}' for name in _metricsPrint)
    print(', '.join(pieces))

# Build the wandb logging dictionary for one metrics-results dictionary. The
#   origin of the metrics (training or testing) must be given so keys stay unique.
def processMetricsWandb(metricsResults, training=False):
    """Convert a processed metrics dictionary into a flat dict for ``wandb.log``.

    Every key gets a ``(training)`` or ``(testing)`` suffix, so training and
    testing results can be merged into one log call without collisions.

    Args:
        metricsResults: dictionary already finalized by ``processRunningMetrics``.
        training: ``True`` for training results, ``False`` for testing results.

    Returns:
        dict with the loss, a wandb confusion-matrix plot, every metric in
        ``_metricsPrint`` and one accuracy entry per class name in ``config``.

    Raises:
        ValueError: if the per-class accuracy vector does not have exactly one
            entry per class name in ``config[classes]``.
    """
    # Get the suffix used on wandb; training and testing keys must differ.
    resultsType = 'training' if training else 'testing'
    classNames = config[classes]
    accuracyClass = metricsResults[_accuracyClass]

    # The per-class keys come from the class-name list, so its length is the
    # authoritative class count; a mismatch would misalign names and values.
    if len(accuracyClass) != len(classNames):
        raise ValueError(
            'Per-class accuracy has {} entries but config lists {} class names'
            .format(len(accuracyClass), len(classNames)))

    # Key name for the confusion matrix.
    _confusionMatrix = 'Confusion matrix'

    # All the wandb keys are based on the original metrics results keys.
    lossKey = '{} ({})'.format(_loss, resultsType)
    confusionMatrixKey = '{} ({})'.format(_confusionMatrix, resultsType)

    # Get the confusion matrix plot.
    confusionMatrix = wandb.plot.confusion_matrix(y_true=metricsResults[_groundtruth].tolist(),
        preds=metricsResults[_predictions].tolist(), class_names=classNames, title=confusionMatrixKey)

    # Make the dictionary for wandb and store the values.
    wandbDict = {
        lossKey            : metricsResults[_loss].item(),
        confusionMatrixKey : confusionMatrix,
    }
    for metric in _metricsPrint:
        wandbDict['{} ({})'.format(metric, resultsType)] = metricsResults[metric].item()

    for className, classAccuracy in zip(classNames, accuracyClass):
        wandbDict['{} accuracy ({})'.format(className, resultsType)] = classAccuracy.item()

    # Return, to log later.
    return wandbDict

# Build the wandb dictionaries for both metric sets and log them together.
def logMetricsWandb(trainMetricsResults, testMetricsResults):
    """Log training and testing metrics to wandb in a single ``wandb.log`` call.

    Both dictionaries go through ``processMetricsWandb``. Its keys carry a
    ``(training)`` / ``(testing)`` suffix, so merging the two overwrites nothing.
    """
    combined = {}
    combined.update(processMetricsWandb(trainMetricsResults, training=True))
    combined.update(processMetricsWandb(testMetricsResults, training=False))
    wandb.log(combined)

# Checkpoint dictionary keys used by saveEpochData.
# NOTE(review): these names were referenced but never defined in this notebook
# (calling saveEpochData raised NameError). The string values are new; keep
# them in sync with whatever code loads these checkpoints.
_model        = 'Model'
_optimizer    = 'Optimizer'
_epoch        = 'Epoch'
_metricsTrain = 'Metrics train'
_metricsTest  = 'Metrics test'

# Function used to save the model and the metrics.
def saveEpochData(trainMetricsResults, testMetricsResults, model, optimizer, epoch, rootPath):
    """Save a checkpoint for one epoch to ``<cwd>/<rootPath>/<epoch>/model.pth``.

    The saved dict holds both metrics dictionaries, the model and optimizer
    ``state_dict()``s and the epoch number. Missing directories are created.
    If ``rootPath`` is absolute, it is used as-is instead of relative to the
    current working directory.

    Args:
        trainMetricsResults: training metrics dictionary.
        testMetricsResults: testing metrics dictionary.
        model: object exposing ``state_dict()`` (e.g. a torch module).
        optimizer: object exposing ``state_dict()`` (e.g. a torch optimizer).
        epoch: epoch number; also used as the sub-directory name.
        rootPath: directory under which the per-epoch folders are created.
    """
    # Create a dir for the current epoch.
    runDir = Path.cwd() / rootPath / str(epoch)
    runDir.mkdir(parents=True, exist_ok=True)

    # Metrics first, then the training state (same key order as before).
    saveDict = {
        _metricsTrain : trainMetricsResults,
        _metricsTest  : testMetricsResults,
        _model        : model.state_dict(),
        _optimizer    : optimizer.state_dict(),
        _epoch        : epoch,
    }

    torch.save(saveDict, runDir / 'model.pth')

In [None]:
# Usage example.
# NOTE(review): groundtruthTest, logitsTest, lossTest, splitsAmount, model and
# optimizer are not defined in this notebook; they look like placeholders to
# be supplied by the real training code.

# Training.
# Before each epoch, get a fresh metrics dictionary: one for training and one
# for testing.
trainMetricsResults = getMetricsDict()

# This is the per-batch loop, i.e. what runs inside each epoch.
# NOTE(review): this "training" loop iterates the test data and passes lossTest;
# it looks copied from the testing loop below. Use the training batches/loss here.
for __groundtruth, __logits in zip(groundtruthTest, logitsTest):

    # Inside the epoch, pass the logits, ground truth, loss, number of batches
    # in the epoch, and the metrics dictionary.
    updateRunningMetrics(__logits, __groundtruth, lossTest, splitsAmount, trainMetricsResults)

# When the epoch ends, call this to finish building the dictionary.
processRunningMetrics(trainMetricsResults)

# Testing.
testMetricsResults = getMetricsDict()

for __groundtruth, __logits in zip(groundtruthTest, logitsTest):

    updateRunningMetrics(__logits, __groundtruth, lossTest, splitsAmount, testMetricsResults)

processRunningMetrics(testMetricsResults)

# Once both results dictionaries are ready, log them to wandb.
logMetricsWandb(trainMetricsResults, testMetricsResults)

# They can also be printed.
epoch = 1
print('**', '[', 'Epoch ', epoch, ']', '*' * 48, sep='')
print('\tTraining results:', end=' ')
printMetricsDict(trainMetricsResults)
print('\t Testing results:', end=' ')
printMetricsDict(testMetricsResults)

# Optionally save a checkpoint with the model and the metrics so far.
rootPath = 'runs/corrida1'
saveEpochData(trainMetricsResults, testMetricsResults, model, optimizer, epoch, rootPath)