In [None]:
import matplotlib.pyplot as plt
import numpy as np
from data_utils import MIT_split_dataset, CustomTransform
from torch.utils.data import DataLoader
from model_luis import Model
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, average_precision_score, confusion_matrix, ConfusionMatrixDisplay, PrecisionRecallDisplay
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from collections import Counter

In [1]:
def accuracy(predictions, labels):
    """
    Calculates the accuracy of a set of predictions compared to the actual labels.

    Parameters:
        predictions: array-like (numpy array or list) containing the predicted values.
        labels: array-like containing the actual labels, same length as ``predictions``.

    Returns:
        A float: the fraction of positions where the prediction equals the label.
        Returns 0.0 when ``labels`` is empty.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    if labels.size == 0:
        # Avoid ZeroDivisionError on empty input (same zero-guard convention as precision/recall).
        return 0.0
    # Vectorised mean of the element-wise match mask; works for lists too once coerced.
    return float(np.mean(predictions == labels))

def precision(predictions, labels, class_label):
    """
    Calculates precision for a specific class in a classification task.

    Precision = TP / (TP + FP), where a "positive" is a prediction equal to ``class_label``.

    Parameters:
        predictions: array-like (numpy array or list) containing the predicted class labels.
        labels: array-like containing the actual class labels.
        class_label: the specific class for which precision is calculated.

    Returns:
        Precision value (float) for the specified class; 0.0 if the class is never predicted.
    """
    # Coerce so that `== class_label` is element-wise even when plain lists are passed
    # (on a list it would be a single False and silently give precision 0).
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    predicted_pos = predictions == class_label
    tp = np.sum(predicted_pos & (labels == class_label))
    fp = np.sum(predicted_pos & (labels != class_label))
    return float(tp / (tp + fp)) if (tp + fp) > 0 else 0.0

def recall(predictions, labels, class_label):
    """
    Calculates recall for a specific class in a classification task.

    Recall = TP / (TP + FN), where the relevant items are labels equal to ``class_label``.

    Parameters:
        predictions: array-like (numpy array or list) containing the predicted class labels.
        labels: array-like containing the actual class labels.
        class_label: the specific class for which recall is calculated.

    Returns:
        Recall value (float) for the specified class; 0.0 if the class never occurs in ``labels``.
    """
    # Coerce so that comparisons are element-wise even when plain lists are passed.
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    actual_pos = labels == class_label
    tp = np.sum((predictions == class_label) & actual_pos)
    fn = np.sum((predictions != class_label) & actual_pos)
    return float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0

def average_precision(predictions, labels):
    """
    Macro-averaged precision: the unweighted mean of the per-class precision.

    Only classes that occur in ``labels`` are averaged over, so a class that is
    predicted but never present in ``labels`` does not contribute. Not to be
    confused with sklearn's ``average_precision_score``, which summarises a
    precision-recall curve built from scores.

    Parameters:
        predictions: numpy array containing the predicted class labels.
        labels: numpy array containing the actual class labels.

    Returns:
        The mean of ``precision(predictions, labels, c)`` over every class ``c``
        in ``np.unique(labels)``.
    """
    per_class_scores = []
    for cls in np.unique(labels):
        per_class_scores.append(precision(predictions, labels, cls))
    return np.mean(per_class_scores)

def average_recall(predictions, labels):
    """
    Macro-averaged recall: the unweighted mean of the per-class recall.

    The average is taken over the classes that occur in ``labels``.

    Parameters:
        predictions: numpy array containing the predicted class labels.
        labels: numpy array containing the actual class labels.

    Returns:
        The mean of ``recall(predictions, labels, c)`` over every class ``c``
        in ``np.unique(labels)``.
    """
    per_class_scores = []
    for cls in np.unique(labels):
        per_class_scores.append(recall(predictions, labels, cls))
    return np.mean(per_class_scores)

def average_f1(predictions, labels):
    """
    Calculates the average (macro) F1 score across all classes in a classification task.

    The F1 score is computed for each class that occurs in ``labels`` and the
    per-class scores are then averaged with equal weight. This is the same
    "mean of per-class scores" scheme used by ``average_precision`` and
    ``average_recall``.

    Parameters:
        predictions: array-like (numpy array or list) containing the predicted class labels.
        labels: array-like containing the actual class labels.

    Returns:
        The average F1 score (float) across all classes; 0.0 if ``labels`` is empty.
    """
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    if classes.size == 0:
        return 0.0

    f1_per_class = []
    for c in classes:
        predicted_pos = predictions == c
        actual_pos = labels == c
        tp = np.sum(predicted_pos & actual_pos)
        fp = np.sum(predicted_pos & ~actual_pos)
        fn = np.sum(~predicted_pos & actual_pos)
        # 2TP / (2TP + FP + FN) equals 2PR / (P + R) but needs no separate zero guards
        # for P or R; a class with TP == 0 scores 0 instead of producing 0/0.
        denom = 2 * tp + fp + fn
        f1_per_class.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(f1_per_class))

In [None]:
config = {
    'IMG_WIDTH': 256,
    'IMG_HEIGHT': 256,
    'TEST_DATASET_DIR': 'data/MIT_split/test',
    'batch_size': 32
}
CHECKPOINT_PATH = 'pretrained/best_model.pth'
NUM_CLASSES = 8
# Prefer the GPU but fall back to CPU so the notebook can still run without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform_test = CustomTransform(config, mode='test')
dataset_test = MIT_split_dataset(config['TEST_DATASET_DIR'], transform=transform_test)
# shuffle=False keeps the prediction order aligned with the dataset order.
dataloader_test = DataLoader(dataset_test, batch_size=config['batch_size'], shuffle=False)

model = Model(num_classes=NUM_CLASSES)
# map_location places the stored tensors on DEVICE, so a checkpoint saved on a GPU
# can also be restored on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load trusted checkpoints
# (on torch >= 1.13, weights_only=True restricts loading to tensors).
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model = model.to(DEVICE)

In [None]:
# Run the model over the whole test set and collect, per sample: the predicted class id,
# the true label, and the softmax probability vector (used for the PR curves below).
model_device = next(model.parameters()).device  # send inputs wherever the model lives
model.eval()
labels_pred = []
labels_true = []
predictions = []
with torch.no_grad():
    for inputs, labels in dataloader_test:
        outputs = F.softmax(model(inputs.to(model_device)), dim=1)
        preds = outputs.argmax(dim=1)
        labels_pred.extend(preds.cpu().numpy())
        labels_true.extend(labels.numpy())
        predictions.extend(outputs.cpu().numpy())

labels_pred = np.array(labels_pred)
labels_true = np.array(labels_true)
predictions = np.array(predictions)

In [None]:
# Plot the confusion matrix, annotated with the summary metrics.
# Class index i is assumed to correspond to class_names[i].
class_names = ['coast', 'forest', 'highway', 'inside_city', 'mountain', 'Opencountry', 'street', 'tallbuilding']
# Pin the label set so the matrix is always len(class_names) x len(class_names), even when a
# class never occurs in labels_true or labels_pred; otherwise the matrix would shrink and no
# longer line up with display_labels.
cm = confusion_matrix(labels_true, labels_pred, labels=np.arange(len(class_names)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
_, ax = plt.subplots(figsize=(6, 6))
disp.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation='vertical', colorbar=False)

# Adding text for metrics
textstr = '\n'.join((
    f'Accuracy: {accuracy(labels_pred, labels_true):.3f}',
    f'Avg. Precision: {average_precision(labels_pred, labels_true):.3f}',
    f'Avg. Recall: {average_recall(labels_pred, labels_true):.3f}',
    f'Avg. F1-score: {average_f1(labels_pred, labels_true):.3f}'))

# These are matplotlib.patches.Patch properties
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

# Place the text box just to the right of the axes (x > 1 in axes coordinates)
ax.text(1.03, 0.55, textstr, transform=ax.transAxes, fontsize=12,
        verticalalignment='top', bbox=props)

plt.show()

In [None]:
# Convert the labels to one-vs-rest binary indicators, one column per class
n_classes = 8  # Number of classes
labels_binary = label_binarize(labels_true, classes=np.arange(n_classes))

# Per-class precision/recall curves and average precision (AP) from the softmax scores
precision_dict = {}
recall_dict = {}
average_precision_dict = {}
for i in range(n_classes):
    precision_dict[i], recall_dict[i], _ = precision_recall_curve(labels_binary[:, i], predictions[:, i])
    average_precision_dict[i] = average_precision_score(labels_binary[:, i], predictions[:, i])

# Micro-average: every (sample, class) pair is treated as one binary decision
precision_micro, recall_micro, _ = precision_recall_curve(labels_binary.ravel(), predictions.ravel())
average_precision_micro = average_precision_score(labels_binary, predictions, average="micro")

classes = ['"coast"', '"forest"', '"highway"', '"inside_city"', '"mountain"', '"opencountry"', '"street"', '"tallbuilding"']

# Plot PR curves for each class; one clearly distinct colour per class
_, ax = plt.subplots(figsize=(6, 6))
colors = ["tab:red", "tab:green", "tab:blue", "tab:brown", "tab:orange", "tab:purple", "tab:pink", "tab:olive"]
for i, color in zip(range(n_classes), colors):
    ax.plot(recall_dict[i], precision_dict[i], color=color, alpha=0.35,
            label=f'PR curve of class {classes[i]} (AP = {average_precision_dict[i]:0.2f})')

# Plot micro-averaged PR curve. Named pr_display so IPython's builtin display() is not shadowed.
pr_display = PrecisionRecallDisplay(
    recall=recall_micro,
    precision=precision_micro,
    average_precision=average_precision_micro,
    # Fraction of positive entries in the binarised labels (baseline for the chance line)
    prevalence_pos_label=labels_binary.mean(),
)
pr_display.plot(ax=ax, name="Micro-average", plot_chance_level=True, color='black')

# Plot iso-F1 curves: precision = f1 * recall / (2 * recall - f1)
iso_recall = np.linspace(0.01, 1)
label_x = 0.9  # recall at which each iso-F1 curve is labelled
for f1_iso in np.linspace(0.2, 0.8, num=4):
    iso_precision = f1_iso * iso_recall / (2 * iso_recall - f1_iso)
    valid = iso_precision >= 0
    ax.plot(iso_recall[valid], iso_precision[valid], color='gray', alpha=0.2)
    label_y = f1_iso * label_x / (2 * label_x - f1_iso)
    ax.annotate(f'f1={f1_iso:0.1f}', xy=(label_x, label_y + 0.02))

# Customizing the plot
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall curves (per class and micro-average)')
# Legend outside the axes so its ten entries do not hide the curves
ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))

plt.show()