In [2]:
import torch
from torchvision import datasets, transforms

In [3]:
# Define a transform to normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    #We are normalizing the input tensor with a mean of 0.5 and a standard deviation of 0.5. This scales the pixel values to be between -1 and 1, which can help with training stability.
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

print(trainset[0])

# Download and load the test data
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

(tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
          -1.0000, -1.0000, -1.0000, -1.0000, -

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 32*7*7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()

In [5]:
import torch.optim as optim
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}"), 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 200 == 199:    # print every 200 mini-batches
            print(f"[Epoch {epoch+1}, batch {i+1:5d}] loss: {running_loss/200:.3f}")
            running_loss = 0.0

print('Finished Training')


Epoch 1:  22%|██▏       | 203/938 [00:06<00:21, 33.91it/s]

[Epoch 1, batch   200] loss: 1.006


Epoch 1:  43%|████▎     | 407/938 [00:12<00:14, 35.41it/s]

[Epoch 1, batch   400] loss: 0.171


Epoch 1:  64%|██████▍   | 603/938 [00:17<00:09, 35.49it/s]

[Epoch 1, batch   600] loss: 0.108


Epoch 1:  86%|████████▌ | 807/938 [00:23<00:03, 36.44it/s]

[Epoch 1, batch   800] loss: 0.103


Epoch 1: 100%|██████████| 938/938 [00:27<00:00, 34.01it/s]
Epoch 2:  22%|██▏       | 204/938 [00:05<00:20, 35.77it/s]

[Epoch 2, batch   200] loss: 0.070


Epoch 2:  43%|████▎     | 404/938 [00:11<00:14, 36.41it/s]

[Epoch 2, batch   400] loss: 0.063


Epoch 2:  64%|██████▍   | 604/938 [00:17<00:09, 35.72it/s]

[Epoch 2, batch   600] loss: 0.059


Epoch 2:  86%|████████▌ | 804/938 [00:22<00:04, 32.43it/s]

[Epoch 2, batch   800] loss: 0.053


Epoch 2: 100%|██████████| 938/938 [00:27<00:00, 34.71it/s]
Epoch 3:  22%|██▏       | 206/938 [00:05<00:20, 36.30it/s]

[Epoch 3, batch   200] loss: 0.043


Epoch 3:  43%|████▎     | 406/938 [00:11<00:14, 36.55it/s]

[Epoch 3, batch   400] loss: 0.042


Epoch 3:  65%|██████▍   | 606/938 [00:17<00:09, 35.20it/s]

[Epoch 3, batch   600] loss: 0.043


Epoch 3:  86%|████████▌ | 806/938 [00:22<00:03, 35.28it/s]

[Epoch 3, batch   800] loss: 0.042


Epoch 3: 100%|██████████| 938/938 [00:26<00:00, 35.20it/s]
Epoch 4:  22%|██▏       | 204/938 [00:05<00:20, 36.09it/s]

[Epoch 4, batch   200] loss: 0.030


Epoch 4:  43%|████▎     | 404/938 [00:11<00:15, 33.63it/s]

[Epoch 4, batch   400] loss: 0.034


Epoch 4:  64%|██████▍   | 604/938 [00:18<00:12, 27.18it/s]

[Epoch 4, batch   600] loss: 0.035


Epoch 4:  86%|████████▌ | 803/938 [00:25<00:04, 28.69it/s]

[Epoch 4, batch   800] loss: 0.029


Epoch 4: 100%|██████████| 938/938 [00:30<00:00, 30.29it/s]
Epoch 5:  22%|██▏       | 206/938 [00:07<00:24, 30.48it/s]

[Epoch 5, batch   200] loss: 0.025


Epoch 5:  43%|████▎     | 403/938 [00:14<00:18, 28.44it/s]

[Epoch 5, batch   400] loss: 0.027


Epoch 5:  64%|██████▍   | 605/938 [00:21<00:10, 30.52it/s]

[Epoch 5, batch   600] loss: 0.021


Epoch 5:  86%|████████▌ | 805/938 [00:27<00:04, 31.27it/s]

[Epoch 5, batch   800] loss: 0.021


Epoch 5: 100%|██████████| 938/938 [00:31<00:00, 29.35it/s]
Epoch 6:  22%|██▏       | 204/938 [00:07<00:35, 20.70it/s]

[Epoch 6, batch   200] loss: 0.019


Epoch 6:  43%|████▎     | 405/938 [00:14<00:18, 28.48it/s]

[Epoch 6, batch   400] loss: 0.021


Epoch 6:  64%|██████▍   | 605/938 [00:21<00:11, 29.21it/s]

[Epoch 6, batch   600] loss: 0.018


Epoch 6:  86%|████████▌ | 805/938 [00:28<00:04, 30.24it/s]

[Epoch 6, batch   800] loss: 0.022


Epoch 6: 100%|██████████| 938/938 [00:32<00:00, 28.60it/s]
Epoch 7:  22%|██▏       | 206/938 [00:07<00:23, 30.57it/s]

[Epoch 7, batch   200] loss: 0.015


Epoch 7:  43%|████▎     | 403/938 [00:13<00:17, 29.86it/s]

[Epoch 7, batch   400] loss: 0.017


Epoch 7:  64%|██████▍   | 605/938 [00:20<00:10, 30.40it/s]

[Epoch 7, batch   600] loss: 0.014


Epoch 7:  86%|████████▌ | 803/938 [00:26<00:04, 30.31it/s]

[Epoch 7, batch   800] loss: 0.016


Epoch 7: 100%|██████████| 938/938 [00:31<00:00, 29.71it/s]
Epoch 8:  22%|██▏       | 203/938 [00:06<00:24, 30.40it/s]

[Epoch 8, batch   200] loss: 0.012


Epoch 8:  43%|████▎     | 405/938 [00:13<00:17, 30.74it/s]

[Epoch 8, batch   400] loss: 0.013


Epoch 8:  64%|██████▍   | 605/938 [00:20<00:10, 30.29it/s]

[Epoch 8, batch   600] loss: 0.015


Epoch 8:  86%|████████▌ | 804/938 [00:26<00:04, 30.38it/s]

[Epoch 8, batch   800] loss: 0.016


Epoch 8: 100%|██████████| 938/938 [00:31<00:00, 29.79it/s]
Epoch 9:  22%|██▏       | 202/938 [00:06<00:23, 30.92it/s]

[Epoch 9, batch   200] loss: 0.011


Epoch 9:  43%|████▎     | 403/938 [00:13<00:17, 29.93it/s]

[Epoch 9, batch   400] loss: 0.011


Epoch 9:  65%|██████▍   | 606/938 [00:20<00:10, 30.47it/s]

[Epoch 9, batch   600] loss: 0.013


Epoch 9:  86%|████████▌ | 805/938 [00:27<00:04, 30.05it/s]

[Epoch 9, batch   800] loss: 0.013


Epoch 9: 100%|██████████| 938/938 [00:31<00:00, 29.43it/s]
Epoch 10:  22%|██▏       | 203/938 [00:06<00:24, 30.18it/s]

[Epoch 10, batch   200] loss: 0.008


Epoch 10:  43%|████▎     | 405/938 [00:13<00:18, 28.80it/s]

[Epoch 10, batch   400] loss: 0.010


Epoch 10:  64%|██████▍   | 604/938 [00:20<00:11, 29.12it/s]

[Epoch 10, batch   600] loss: 0.012


Epoch 10:  86%|████████▌ | 806/938 [00:27<00:04, 30.17it/s]

[Epoch 10, batch   800] loss: 0.007


Epoch 10: 100%|██████████| 938/938 [00:32<00:00, 29.21it/s]

Finished Training





In [6]:
correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on the test set: {100 * correct / total:.2f}%")


Accuracy on the test set: 99.02%
