<a href="https://colab.research.google.com/github/MarcoParola/medical_images_classification/blob/main/utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Utils**
This notebook collects helper functions (data loading, image scaling, and plotting utilities) that are used by the other notebooks in this project.

In [None]:
import numpy as np
import os
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import f1_score

In [None]:
def load_data(dir):
  """Load the train, public-test and private-test arrays stored as .npy files.

  Args:
    dir: directory (str or path-like) containing train_tensor.npy,
      train_labels.npy, public_test_tensor.npy, public_test_labels.npy and
      private_test_tensor.npy. A trailing path separator is optional.

  Returns:
    Tuple (imagesTrain, labelsTrain, imagesTestPublic, labelsTestPublic,
    imagesTestPrivate). No labels are loaded for the private test split.
  """
  def _load(stem):
    # Join as separate components so a separator is inserted when `dir` has
    # no trailing slash; `dir + name` would silently produce "<dir>name.npy".
    return np.load(os.path.join(dir, stem + '.npy'))

  imagesTrain = _load('train_tensor')
  labelsTrain = _load('train_labels')
  imagesTestPublic = _load('public_test_tensor')
  labelsTestPublic = _load('public_test_labels')
  imagesTestPrivate = _load('private_test_tensor')

  return imagesTrain, labelsTrain, imagesTestPublic, labelsTestPublic, imagesTestPrivate

In [None]:
def scaleData(image, bit_depth=16):
  """Scale raw integer intensities into the range [0, 1].

  Divides by the largest value representable with `bit_depth` unsigned bits
  (2**16 - 1 = 65535 by default), so values in [0, 2**bit_depth - 1] map to
  [0, 1]. Presumably the inputs are raw 16-bit images; confirm for new data.

  Args:
    image: scalar or NumPy array of intensities.
    bit_depth: number of bits per pixel of the source data (default 16).

  Returns:
    `image` divided by 2**bit_depth - 1 (a float result).
  """
  max_value = 2 ** bit_depth - 1
  return image / max_value

## Utility function that plots the confusion matrix

In [None]:
def plot_confusionMatrix(model_, testSet_, testLabels_, classes):
    """Plot a row-normalized confusion matrix of a classifier on a test set.

    Rows are true labels and columns are predicted labels. Each row is divided
    by the number of test samples of that true class, so cell (i, j) is the
    fraction of class-i samples predicted as class j (rounded to 2 decimals).

    Args:
        model_: fitted model exposing ``predict_classes`` (older Keras) or
            ``predict`` returning class probabilities. Presumably a Keras
            model; confirm for other model types.
        testSet_: inputs passed to the model.
        testLabels_: ground-truth class indices for ``testSet_``.
        classes: tick labels, one per class, in class-index order.
    """
    if hasattr(model_, 'predict_classes'):
        pred = model_.predict_classes(testSet_)
    else:
        # predict_classes is gone from recent Keras; apply the rule it used:
        # argmax for multi-column output, 0.5 threshold for a single column.
        proba = model_.predict(testSet_)
        if proba.shape[-1] > 1:
            pred = proba.argmax(axis=-1)
        else:
            pred = (proba > 0.5).astype('int32')
    pred = np.ravel(pred)

    # sklearn expects (y_true, y_pred): ground truth must come first so that
    # rows line up with the 'True label' axis drawn below.
    counts = metrics.confusion_matrix(testLabels_, pred)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Classes absent from the test set would give 0/0; leave those rows at 0.
    normalized = np.divide(counts, row_totals,
                           out=np.zeros(counts.shape, dtype=float),
                           where=row_totals != 0)
    cm = np.round(normalized, 2)
    print("Normalized confusion matrix")
    print(cm)

    # Colors and text both use the normalized values, on a fixed 0..1 scale.
    plt.imshow(cm, interpolation='nearest', cmap='OrRd', vmin=0, vmax=1)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    thresh = 0.6  # cells darker than this get white text for readability
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()
    plt.show()

## Utility functions that plot ROC curves

In [None]:
def plotRocCurves(models, test, labels, names=None):
  """Overlay the ROC curves of several binary classifiers on one plot.

  NOTE: this notebook defines plotRocCurves in two cells; when run top to
  bottom, the later definition replaces this one.

  Args:
    models: sequence of fitted classifiers exposing ``predict_proba``.
    test: test inputs passed to every model.
    labels: binary ground-truth labels for ``test``.
    names: optional legend names, one per model. Defaults to
      'model1', 'model2', ... in the order of ``models``.
  """
  if names is None:
    names = ['model' + str(i + 1) for i in range(len(models))]

  _, ax = plt.subplots()
  for i, model in enumerate(models):
    probs = np.asarray(model.predict_proba(test))
    # Positive-class score: the second column of a two-column output, or the
    # only column when the model returns a single probability per sample.
    if probs.ndim == 2 and probs.shape[1] > 1:
      preds = probs[:, 1]
    else:
      preds = probs.ravel()
    fpr, tpr, _ = metrics.roc_curve(labels, preds)
    roc_auc = metrics.auc(fpr, tpr)
    # No fmt string or random color: 'b' plus color= triggers a matplotlib
    # warning, and random colors were non-reproducible and advanced the
    # global NumPy RNG. The default color cycle keeps curves distinct.
    ax.plot(fpr, tpr, label=names[i] + '= %0.2f' % roc_auc)

  # Chance-level diagonal and axis decorations are drawn once, not per model.
  ax.plot([0, 1], [0, 1], 'r--')
  ax.set_xlim([0, 1])
  ax.set_ylim([0, 1])
  ax.set_title('Roc curves')
  ax.set_ylabel('True Positive Rate')
  ax.set_xlabel('False Positive Rate')
  ax.legend(loc='lower right')
  plt.show()

In [None]:
def plotRocCurves(models, test, labels, names=None):
  """Overlay the ROC curves of several binary classifiers on one plot.

  This definition supersedes the earlier cell of the same name. ``names`` is
  optional so three-argument calls written for that version keep working.

  Args:
    models: sequence of fitted classifiers exposing ``predict_proba``.
    test: test inputs passed to every model.
    labels: binary ground-truth labels for ``test``.
    names: optional legend names, one per model. Defaults to
      'model1', 'model2', ... in the order of ``models``.
  """
  if names is None:
    names = ['model' + str(i + 1) for i in range(len(models))]

  _, ax = plt.subplots()
  for i, model in enumerate(models):
    probs = np.asarray(model.predict_proba(test))
    # Positive-class score: the second column of a two-column output, or the
    # only column when the model returns a single probability per sample.
    if probs.ndim == 2 and probs.shape[1] > 1:
      preds = probs[:, 1]
    else:
      preds = probs.ravel()
    fpr, tpr, _ = metrics.roc_curve(labels, preds)
    roc_auc = metrics.auc(fpr, tpr)
    # No fmt string or random color: 'b' plus color= triggers a matplotlib
    # warning, and random colors were non-reproducible and advanced the
    # global NumPy RNG. The default color cycle keeps curves distinct.
    ax.plot(fpr, tpr, label=names[i] + '= %0.2f' % roc_auc)

  # Chance-level diagonal and axis decorations are drawn once, not per model.
  ax.plot([0, 1], [0, 1], 'r--')
  ax.set_xlim([0, 1])
  ax.set_ylim([0, 1])
  ax.set_title('Roc curves')
  ax.set_ylabel('True Positive Rate')
  ax.set_xlabel('False Positive Rate')
  ax.legend(loc='lower right')
  plt.show()

In [None]:
def plot_accurancy_loss(hist):
  """Plot training/validation accuracy and loss curves in two figures.

  Args:
    hist: Keras-style History object whose ``history`` dict holds per-epoch
      lists under 'accuracy'/'val_accuracy' (or the older 'acc'/'val_acc'
      keys) and 'loss'/'val_loss'.
  """
  history = hist.history
  # Older Keras versions record accuracy under 'acc' instead of 'accuracy'.
  acc_key = 'accuracy' if 'accuracy' in history else 'acc'
  acc_1 = history[acc_key]
  val_acc_1 = history['val_' + acc_key]
  loss_1 = history['loss']
  val_loss_1 = history['val_loss']

  # Number epochs from 1, matching how training progress is usually reported.
  epochs = range(1, len(acc_1) + 1)

  # Start a fresh figure so nothing is drawn onto a previously open plot.
  plt.figure()
  plt.plot(epochs, acc_1, 'bo', label='Training acc')
  plt.plot(epochs, val_acc_1, 'b', label='Validation acc')
  plt.ylim(0, 1)
  plt.title('Training and validation accuracy')
  plt.xlabel('Epoch')
  plt.ylabel('Accuracy')
  plt.legend()

  plt.figure()
  plt.plot(epochs, loss_1, 'bo', label='Training loss')
  plt.plot(epochs, val_loss_1, 'b', label='Validation loss')
  # No fixed y-limits: loss values (e.g. cross-entropy) can exceed 1 and
  # would be clipped by ylim(0, 1).
  plt.title('Training and validation loss')
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.legend()

  plt.show()