In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120)

In [3]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [12]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self, t):
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        
        t = F.relu(self.fc1(t.reshape(-1, 12 * 4 * 4)))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        
        return t

In [13]:
network = Network()

In [14]:
sample = next(iter(train_set))

In [15]:
image, label = sample
image.shape

torch.Size([1, 28, 28])

In [16]:
image.unsqueeze(0)

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039,
           0.0000, 0.0000, 0.0510, 0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.00

In [17]:
pred = network(image.unsqueeze(0))

In [18]:
pred.shape

torch.Size([1, 10])

In [19]:
pred

tensor([[ 0.0774, -0.0052, -0.1186, -0.1491, -0.1061,  0.1333,  0.0381, -0.0800, -0.0595,  0.0454]],
       grad_fn=<AddmmBackward0>)

In [20]:
label

9

In [21]:
pred.argmax(dim=1)

tensor([5])

In [22]:
F.softmax(pred, dim=1)

tensor([[0.1101, 0.1013, 0.0905, 0.0878, 0.0916, 0.1164, 0.1058, 0.0940, 0.0960, 0.1066]], grad_fn=<SoftmaxBackward0>)

In [23]:
F.softmax(pred, dim=1).sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [24]:
net1 = Network()
net1(image.unsqueeze(0))

tensor([[-0.1251,  0.0840, -0.0346, -0.0015,  0.1562,  0.0805, -0.0721,  0.1277,  0.0127, -0.0326]],
       grad_fn=<AddmmBackward0>)

In [25]:
net2 = Network()
net2(image.unsqueeze(0))

tensor([[ 0.0890,  0.0103,  0.0796, -0.0636,  0.0459, -0.0116,  0.0303, -0.0892, -0.0541, -0.0795]],
       grad_fn=<AddmmBackward0>)

In [27]:
data_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size = 10
)

In [28]:
batch = next(iter(data_loader))

In [29]:
images, labels = batch

In [30]:
images.shape

torch.Size([10, 1, 28, 28])

In [31]:
labels.shape

torch.Size([10])

In [32]:
preds = network(images)

In [33]:
preds.shape

torch.Size([10, 10])

In [35]:
preds

tensor([[ 0.0774, -0.0052, -0.1186, -0.1491, -0.1061,  0.1333,  0.0381, -0.0800, -0.0595,  0.0454],
        [ 0.0685,  0.0044, -0.1230, -0.1615, -0.1150,  0.1444,  0.0424, -0.0839, -0.0580,  0.0382],
        [ 0.0392, -0.0002, -0.1146, -0.1429, -0.0951,  0.1395,  0.0257, -0.0536, -0.0517,  0.0437],
        [ 0.0490,  0.0017, -0.1175, -0.1473, -0.1016,  0.1411,  0.0304, -0.0601, -0.0517,  0.0420],
        [ 0.0655,  0.0103, -0.1218, -0.1481, -0.1069,  0.1422,  0.0300, -0.0767, -0.0550,  0.0501],
        [ 0.0684,  0.0013, -0.1203, -0.1530, -0.1097,  0.1458,  0.0426, -0.0770, -0.0542,  0.0380],
        [ 0.0599,  0.0006, -0.1162, -0.1465, -0.1007,  0.1469,  0.0300, -0.0628, -0.0529,  0.0312],
        [ 0.0876,  0.0189, -0.1268, -0.1546, -0.1176,  0.1484,  0.0441, -0.0970, -0.0511,  0.0375],
        [ 0.0445, -0.0108, -0.1142, -0.1454, -0.0989,  0.1298,  0.0324, -0.0537, -0.0512,  0.0374],
        [ 0.0478, -0.0211, -0.1216, -0.1487, -0.0968,  0.1206,  0.0460, -0.0554, -0.0598,  0.0382]],

In [36]:
preds.argmax(dim=1)

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])

In [37]:
labels

tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

In [38]:
preds.argmax(dim=1).eq(labels)

tensor([False, False, False, False, False, False, False, False,  True,  True])

In [39]:
preds.argmax(dim=1).eq(labels).sum()

tensor(2)

In [40]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [41]:
get_num_correct(preds, labels)

2