In [None]:
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda

# Images are converted to tensors with ToTensor; no normalization step is added.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train/test splits, stored under ./data (downloaded on first run).
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform,
)
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform,
)

# Mini-batches of 4 with 2 worker processes; only the training set is shuffled.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2,
)

In [None]:
# Raw (untransformed) tensors held by the datasets. `data` replaces the
# deprecated `train_data` alias. The shapes are printed explicitly because
# only a cell's last expression is displayed.
print(trainloader.dataset.data.shape, testloader.dataset.data.shape)
trainloader.dataset.data[0]

In [None]:
import matplotlib.pyplot as plt

# First raw training image, drawn in grayscale and titled with its label.
# `data`/`targets` replace the deprecated `train_data` alias.
numpy_img = trainloader.dataset.data[0].numpy()
print(numpy_img.shape)  # a bare mid-cell expression would not be displayed
plt.imshow(numpy_img, cmap='gray')
plt.title('label: %d' % int(trainloader.dataset.targets[0]))
plt.show()

In [None]:
# Pull a single [images, labels] batch from the training loader and dump it.
data = next(iter(trainloader))
print(data)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class SimpleConvnet(nn.Module):
    """Small LeNet-style CNN mapping single-channel 28x28 images to 10 logits.

    Layout: conv(1->6, 5x5) -> ReLU -> 2x2 max-pool
            -> conv(6->16, 5x5) -> ReLU -> 2x2 max-pool
            -> flatten (16 * 4 * 4) -> fc 120 -> ReLU -> fc 84 -> ReLU -> fc 10.
    """

    def __init__(self):
        # Zero-argument super() avoids repeating (and misspelling) the class name.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(4 * 4 * 16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (N, 10).

        Args:
            x: image batch. The 4 * 4 * 16 input size of ``fc1`` implies
               inputs of shape (N, 1, 28, 28).
        """
        x = self.pool(F.relu(self.conv1(x)))  # (N, 6, 12, 12) for 28x28 input
        x = self.pool(F.relu(self.conv2(x)))  # (N, 16, 4, 4)
        # Keep the batch dimension fixed. view(-1, 256) could silently fold a
        # wrong spatial size into extra "samples"; with flatten, fc1 raises instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = SimpleConvnet()

# Smoke test: push one training batch through the untrained network.
# The loader already yields float tensors, so no FloatTensor conversion is
# needed. Calling net(...) (not net.forward) runs any registered module hooks.
batch = next(iter(trainloader))
with torch.no_grad():
    logits = net(batch[0])
print(logits.shape)  # expected (batch_size, 10)

# net.cuda() fails on CPU-only machines, so choose the device explicitly and
# pass the same one to summary.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(net.to(summary_device), (1, 28, 28), device=summary_device)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# `device` was previously never defined (NameError on a fresh kernel). It is
# also used by the evaluation cell below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NUM_EPOCHS = 2
LOG_EVERY = 2000  # mini-batches per loss-report window
learning_rate = 1e-3

net = SimpleConvnet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Mean training loss of each completed LOG_EVERY-batch window, in order.
losses = []

net.train()
for epoch in tqdm(range(NUM_EPOCHS)):
    running_loss = 0.0

    for i, (x_batch, y_batch) in enumerate(tqdm(trainloader)):
        optimizer.zero_grad()

        y_pred = net(x_batch.to(device))
        loss = loss_fn(y_pred, y_batch.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if i % LOG_EVERY == LOG_EVERY - 1:
            avg_loss = running_loss / LOG_EVERY
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, avg_loss))
            # Record the same averaged value that is printed, not the raw sum.
            losses.append(avg_loss)
            running_loss = 0.0

# Plot once after training. Re-drawing one figure with plt.show() inside the
# loop closes it under inline backends, so later epochs never appeared.
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(np.arange(len(losses)), losses)
ax.set(
    title='Training loss',
    xlabel='window index (%d mini-batches each)' % LOG_EVERY,
    ylabel='mean cross-entropy loss',
)
plt.show()

print('Finished Training')


In [None]:
# Per-class accuracy on the test set: correct / total counts per digit.
class_correct = [0.0] * 10
class_total = [0.0] * 10
classes = tuple(str(i) for i in range(10))

net.eval()
with torch.no_grad():
    for images, labels in tqdm(testloader):
        outputs = net(images.to(device))
        predicted = outputs.argmax(dim=1).cpu()
        correct = predicted == labels
        # Iterate over the actual batch length rather than a hard-coded 4, so a
        # smaller final batch (or a different batch_size) does not IndexError.
        for label, is_correct in zip(labels.tolist(), correct.tolist()):
            class_correct[label] += is_correct
            class_total[label] += 1

for i, name in enumerate(classes):
    if class_total[i] == 0:
        # Avoid ZeroDivisionError for a class absent from the test set.
        print('Accuracy of %5s : n/a (no samples)' % name)
        continue
    print('Accuracy of %5s : %2d %%' % (name, 100 * class_correct[i] / class_total[i]))

100%|██████████| 2500/2500 [00:06<00:00, 373.44it/s]Accuracy of     0 : 97 %
Accuracy of     1 : 99 %
Accuracy of     2 : 99 %
Accuracy of     3 : 98 %
Accuracy of     4 : 99 %
Accuracy of     5 : 99 %
Accuracy of     6 : 98 %
Accuracy of     7 : 97 %
Accuracy of     8 : 98 %
Accuracy of     9 : 95 %