In [126]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)

In [127]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=5)

In [128]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = t.reshape(-1, 12*4*4)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)

        return t

In [129]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x22008dafa88>

the pytorch gradient calculation is being turned off as it is not required until the training begins, no loss is being generated yet

In [130]:
network = Network()

In [131]:
sample = next(iter(train_set))

In [132]:
image, label = sample

In [133]:
image.shape

torch.Size([1, 28, 28])

a single colour channel image which is 28x28

the network requires an input of a batch, therefore we can generate a batch with only one image by reshaping,

the batch tensor is as follows: (batchsize, in_channels, height, width)

In [134]:
pred = network(image.unsqueeze(0))

In [135]:
pred

tensor([[ 0.0651,  0.1349,  0.0129,  0.1204, -0.0349,  0.1324, -0.0161, -0.1004,  0.0798, -0.1418]])

In [136]:
pred.shape

torch.Size([1, 10])

In [137]:
F.softmax(pred, dim=1)

tensor([[0.1036, 0.1111, 0.0984, 0.1095, 0.0938, 0.1108, 0.0955, 0.0878, 0.1052, 0.0843]])

the shape of the prediction tensor shows that there is one image, with 10 different predictions

using the softmax function , we can see these are probabilities

In [138]:
pred.argmax(dim=1)

tensor([1])

In [139]:
label

9

the prediction shows that the highest label prediction is an 8, the real label value was 9, which is wrong

In [140]:
batch = next(iter(train_loader))

In [141]:
images, labels = batch

In [142]:
images.shape

torch.Size([5, 1, 28, 28])

In [143]:
preds = network(images)

In [144]:
preds.shape

torch.Size([5, 10])

In [145]:
preds

tensor([[ 0.0651,  0.1349,  0.0129,  0.1204, -0.0349,  0.1324, -0.0161, -0.1004,  0.0798, -0.1418],
        [ 0.0694,  0.1266,  0.0142,  0.1192, -0.0407,  0.1243, -0.0198, -0.0995,  0.0915, -0.1450],
        [ 0.0703,  0.1248,  0.0010,  0.1166, -0.0334,  0.1217, -0.0031, -0.0905,  0.0907, -0.1392],
        [ 0.0689,  0.1294,  0.0066,  0.1223, -0.0398,  0.1256, -0.0101, -0.0944,  0.0901, -0.1413],
        [ 0.0716,  0.1301,  0.0125,  0.1206, -0.0427,  0.1326, -0.0181, -0.0996,  0.0884, -0.1479]])

In [146]:
F.softmax(preds, dim=1)

tensor([[0.1036, 0.1111, 0.0984, 0.1095, 0.0938, 0.1108, 0.0955, 0.0878, 0.1052, 0.0843],
        [0.1042, 0.1103, 0.0986, 0.1095, 0.0933, 0.1101, 0.0953, 0.0880, 0.1065, 0.0841],
        [0.1041, 0.1100, 0.0972, 0.1091, 0.0939, 0.1096, 0.0968, 0.0887, 0.1063, 0.0845],
        [0.1040, 0.1105, 0.0977, 0.1097, 0.0933, 0.1100, 0.0961, 0.0883, 0.1062, 0.0843],
        [0.1043, 0.1106, 0.0984, 0.1096, 0.0931, 0.1109, 0.0954, 0.0879, 0.1061, 0.0838]])

In [147]:
preds.argmax(dim=1)

tensor([1, 1, 1, 1, 5])

In [148]:
labels

tensor([9, 0, 0, 3, 0])

In [149]:
preds.argmax(dim=1).eq(labels)

tensor([False, False, False, False, False])

no matches found in the prediction tensor and labels tensor

In [159]:
def error(images, labels):
    return preds.argmax(dim=1).eq(labels)

In [152]:
error(images, labels)

0