In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)

In [37]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)

In [38]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)

        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

        t = F.relu(self.fc1(t.reshape(-1, 12*4*4)))
        t = F.relu(self.fc2(t))
        t = self.out(t)

        return t

network = Network()

In [39]:
def num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [40]:
network = Network()
batch = next(iter(train_loader))
optimiser = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(1):

    total_loss = 0
    total_correct = 0

    for batch in train_loader: #looping each image in the batch
        images, labels = batch

        preds = network(images)
        loss = F.cross_entropy(preds, labels)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        total_loss += loss.item()
        total_correct += num_correct(preds, labels)

    print('epoch:', epoch , 'loss:', total_loss, 'total correct:', total_correct)

epoch: 0 loss: 329.2417930960655 total correct: 47723


In [41]:
@torch.no_grad()
def get_preds(model, loader):
    all_preds = torch.tensor([])

    for batch in loader:
        images, labels = batch
        preds = network(images)

        all_preds = torch.cat((all_preds, preds), dim=0)
    return all_preds

In [42]:
pred_loader = torch.utils.data.DataLoader(train_set, batch_size=5000)
train_preds = get_preds(network, pred_loader)

we have initliased a function to recieve all the prediction tensor values for the entire training set

this has been done in order to graph the predictions for each label, here we can identify areas in which the model is mistaking items with each other

the @torch.no_grad() decorator has been added in order to prevent pytorch for tracking the gradient of the values within the prediction tensor, as it is not necessary for this task

In [43]:
train_preds.shape

torch.Size([60000, 10])

In [45]:
preds_correct = num_correct(train_preds, train_set.targets)

print('total correct:', preds_correct)
print('total accuracy:', preds_correct / len(train_set))

total correct: 49898
total accuracy: 0.8316333333333333


train_set.targets represents all the labels for the whole training dataset

In [50]:
stacked = torch.stack((train_set.targets, train_preds.argmax(dim=1)), dim=1)

this stacked tensor is a pairing of the dataset labels and the predicted labels, allowing us to perform an operation finding incorrect pairing 

In [55]:
confusion_matrix = torch.zeros(10, 10, dtype=torch.int32)

In [56]:
confusion_matrix

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)

In [64]:
for n in stacked:
    tl, pl = n.tolist()
    confusion_matrix[tl, pl] = confusion_matrix[tl, pl] +1

In [65]:
confusion_matrix

tensor([[26659,     5,    41,  2431,    40,   178,   374,  1327,    96,   410],
        [   24,  5628,     7,   291,    16,     2,    22,     0,    10,     0],
        [   85,     3, 15437,    61,  2081,   310,   795,     0,   110,     6],
        [ 1319,     9,    21, 15098,   590,    53,    86,   511,    10,   303],
        [   11,     6,   324,    88, 10969,   139,   424,     0,    39,     0],
        [   85,     0,   424,    72,   402, 16379,     0,   923,    27,   160],
        [ 1406,     4,   489,   146,   804,     1,  2933,     0,   216,     1],
        [ 1443,     0,     1,   950,     0,   222,     0,  8790,    10,  1534],
        [   17,     1,     9,    21,    41,    12,    36,    25,  5837,     1],
        [  268,     0,     8,   239,     0,   106,     1,   751,     9, 10849]], dtype=torch.int32)

this confusion matrix shows the prediction tensors number of estimates for each label, moving down the diagonal, there are higher values as these are the correctly predicted images