## Neural Network Batch Processing - Pass Image Batch To PyTorch CNN

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

# Wider console lines so each row of a printed (batch, 10) prediction tensor fits on one line.
PRINT_LINEWIDTH = 120
torch.set_printoptions(linewidth=PRINT_LINEWIDTH)

In [2]:
# Record the library versions this notebook was executed with.
for library in (torch, torchvision):
    print(library.__version__)

1.10.0+cu111
0.11.1+cu111


In [3]:
# FashionMNIST training split; every sample is converted to a tensor by ToTensor.
# torchvision appends the dataset name to `root`, which is why the download log
# shows files landing in ./content/FashionMNIST/FashionMNIST/raw.
to_tensor = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.FashionMNIST(
    root='./content/FashionMNIST',
    train=True,
    download=True,
    transform=to_tensor,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./content/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./content/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./content/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./content/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./content/FashionMNIST/FashionMNIST/raw



In [4]:
class Network(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    The network ends in ``self.out`` with 10 output units and applies no softmax,
    so ``forward`` returns raw logits.

    ``fc1``'s input size of ``12 * 4 * 4`` assumes 1x28x28 inputs:
    28 -conv(k=5)-> 24 -pool(2)-> 12 -conv(k=5)-> 8 -pool(2)-> 4, with 12 channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images to class logits.

        Args:
            t: image tensor, expected as (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) with unnormalized scores.

        Raises:
            RuntimeError: if the spatial size does not yield 12*4*4 features per
                sample (e.g. inputs that are not 28x28).
        """
        # Feature extractor: (N, 1, 28, 28) -> (N, 6, 12, 12) -> (N, 12, 4, 4)
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

        # Flatten per sample while keeping the batch dimension explicit.
        # reshape(-1, 12*4*4) would silently regroup values into a wrong batch
        # size for non-28x28 inputs; here a mismatch raises in fc1 instead.
        # A 3-D (unbatched) feature map is treated as a batch of one, exactly
        # as the previous reshape did.
        t = t.flatten(start_dim=1) if t.dim() == 4 else t.reshape(1, -1)

        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)

In [5]:
# Inference-only demo: turn off autograd tracking for the rest of the session.
# NOTE: this is a global switch — call torch.set_grad_enabled(True) before any
# training cell. The trailing ';' suppresses the returned object's repr.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f3ec4b79410>

In [17]:
# Seed the RNG so the randomly initialised weights (and therefore the
# predictions below) are identical on every Restart & Run All.
torch.manual_seed(50)
network = Network()

In [7]:
# Serve the training set in batches of 10 samples.
data_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

In [8]:
# Pull only the first batch from the loader for inspection.
loader_iter = iter(data_loader)
batch = next(loader_iter)

In [9]:
# Split the collated batch into the image tensor and the label tensor.
images, labels = batch

In [10]:
images.size()  # (batch_size, channels, height, width)

torch.Size([10, 1, 28, 28])

In [18]:
# Forward pass over the whole batch. no_grad keeps this cell independent of the
# global grad switch set earlier, so `preds` never carries a grad_fn.
with torch.no_grad():
    preds = network(images)

In [19]:
preds.size()  # (batch_size, number of outputs of network.out)

torch.Size([10, 10])

In [20]:
# Raw scores from the final linear layer (forward applies no softmax):
# one row per image, one column per output unit of network.out.
preds

tensor([[ 0.0466,  0.0893, -0.0263,  0.1129, -0.0328, -0.0680,  0.0532,  0.0002,  0.0970,  0.0846],
        [ 0.0424,  0.0846, -0.0308,  0.1102, -0.0334, -0.0621,  0.0569, -0.0019,  0.0930,  0.0903],
        [ 0.0281,  0.0903, -0.0246,  0.1124, -0.0293, -0.0660,  0.0536,  0.0007,  0.0880,  0.0848],
        [ 0.0288,  0.0898, -0.0272,  0.1067, -0.0339, -0.0694,  0.0560,  0.0031,  0.0913,  0.0883],
        [ 0.0360,  0.0805, -0.0293,  0.1117, -0.0396, -0.0654,  0.0518,  0.0018,  0.0914,  0.0794],
        [ 0.0336,  0.0867, -0.0199,  0.1105, -0.0316, -0.0651,  0.0587,  0.0050,  0.0892,  0.0891],
        [ 0.0334,  0.0964, -0.0147,  0.1124, -0.0291, -0.0674,  0.0511, -0.0026,  0.0887,  0.0862],
        [ 0.0373,  0.0843, -0.0204,  0.1088, -0.0331, -0.0612,  0.0600,  0.0093,  0.0846,  0.0863],
        [ 0.0281,  0.0909, -0.0207,  0.1059, -0.0302, -0.0731,  0.0483,  0.0119,  0.0934,  0.0894],
        [ 0.0344,  0.0936, -0.0208,  0.1116, -0.0310, -0.0737,  0.0465,  0.0090,  0.0967,  0.0863]])

In [21]:
# Column index of the largest score in each row, i.e. the predicted class per image.
torch.argmax(preds, dim=1)

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

In [22]:
# Ground-truth class indices for the same batch of 10 images.
labels

tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

In [23]:
# Element-wise comparison: True where the predicted class matches the label.
preds.argmax(dim=1) == labels

tensor([False, False, False,  True, False, False, False, False, False, False])

In [24]:
# Count the matches (each True counts as 1).
(preds.argmax(dim=1) == labels).sum()

tensor(1)

In [27]:
def get_num_correct(preds, labels):
    """Count how many predictions in a batch match their labels.

    Args:
        preds: score tensor indexed as (sample, class); the predicted class is
            the argmax over dim 1.
        labels: tensor of target class indices, one per row of ``preds``.

    Returns:
        The number of matching rows, as a plain Python int.
    """
    predicted_classes = preds.argmax(dim=1)
    matches = predicted_classes == labels
    return int(matches.sum())

In [28]:
# Same count as the previous cell, returned as a plain int via the helper.
get_num_correct(preds=preds, labels=labels)

1