In [1]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

### **Building a Confusion Matrix** ###

In [None]:
@torch.no_grad()
def get_all_preds(model, loader):
    """Run ``model`` over every batch in ``loader`` and return all outputs.

    Gradients are disabled via ``torch.no_grad`` because this is inference only.

    Args:
        model: Callable applied to each batch's inputs (e.g. an ``nn.Module``).
        loader: Iterable yielding ``(inputs, labels)`` pairs, such as a
            ``DataLoader``. The labels are ignored.

    Returns:
        The per-batch outputs concatenated along dim 0, in loader order, or an
        empty 1-D tensor if ``loader`` yields no batches.
    """
    # NOTE(review): the model is not switched to eval() here; if it contains
    # dropout / batch-norm layers, callers should call model.eval() first.
    # Collect batch outputs and concatenate once: calling torch.cat inside the
    # loop re-copies the accumulated tensor on every batch (quadratic), and a
    # CPU torch.tensor([]) seed can clash with outputs living on another device.
    batch_preds = [model(images) for images, _labels in loader]
    if not batch_preds:
        return torch.tensor([])
    return torch.cat(batch_preds, dim=0)

In [None]:
# Use a dedicated, explicitly non-shuffled loader so the predictions come back
# in dataset order and line up row-for-row with train_set.targets below.
# (Passing train_loader here was a bug: prediction_loader was never used, and a
# shuffled training loader would scramble the prediction/target pairing.)
prediction_loader = DataLoader(train_set, batch_size=10000, shuffle=False)
train_preds = get_all_preds(network, prediction_loader)

In [None]:
# Pair every ground-truth label with its predicted class (arg-max over the
# model outputs): column 0 holds the true label, column 1 the prediction.
stacked = torch.stack([train_set.targets, train_preds.argmax(dim=1)], dim=1)

In [None]:
# Build the confusion matrix (rows = true label, columns = predicted label) in
# one vectorized pass instead of a Python loop over every sample: each
# (true, pred) pair is flattened to the index true * num_classes + pred,
# counted with bincount, and reshaped back to a square matrix.
num_classes = len(train_set.classes)
conf_mt = (
    torch.bincount(
        stacked[:, 0] * num_classes + stacked[:, 1],
        minlength=num_classes * num_classes,
    )
    .reshape(num_classes, num_classes)
    .to(torch.int32)
)

##### **Alternative: build the confusion matrix with scikit-learn** #####

In [None]:
# from sklearn.metrics import confusion_matrix

# conf_mt = confusion_matrix(train_set.targets, train_preds.argmax(dim=1))

### **Plotting a Confusion Matrix** ###

In [None]:
def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heatmap on the current figure.

    Args:
        cm: Square confusion matrix, rows = true labels and columns =
            predicted labels (matching the axis labels drawn below). Anything
            ``np.asarray`` accepts works, e.g. a NumPy array or a CPU torch
            tensor.
        classes: Tick labels, one per row/column of ``cm``.
        normalize: If True, divide each row by its sum so cells show the
            fraction of each true class. Rows summing to zero are shown as 0
            instead of NaN.
        title: Figure title.
        cmap: Matplotlib colormap used for the heatmap.
    """
    # Convert up front: torch tensors have no .astype, so normalize=True used
    # to crash on the tensor this notebook passes in.
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Flip annotation colour at half the maximum so text stays readable on dark cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run last so the layout accounts for the axis labels set just above.
    plt.tight_layout()

In [None]:
# Open a 10 x 10 inch canvas, then draw the raw (un-normalized) counts with
# the dataset's class names as tick labels.
plt.figure(figsize=(10, 10))
plot_confusion_matrix(cm=conf_mt, classes=train_set.classes)