In [2]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [3]:
# Preprocessing applied to every image: convert to a float tensor, then
# normalize with mean 0.5 and std 0.5. ToTensor yields values in [0, 1], so
# the result lies in [-1, 1].
# MNIST images are single-channel, so mean/std must be 1-tuples; 3-tuples
# fail to broadcast against a (1, H, W) tensor in current torchvision.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [41]:
# MNIST training split, downloaded into ./data when missing; each sample is
# passed through `transform` on access.
trainset = datasets.MNIST(
    './data',
    train=True,
    transform=transform,
    download=True,
)

In [11]:
# Batches of 64 training samples, reshuffled each time the loader is iterated.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
)

torch.Size([64, 1, 28, 28])

In [14]:
# Fully connected classifier: 784 -> 128 -> 64 -> 10, with ReLU between the
# linear layers. The final LogSoftmax over dim 1 makes the network emit
# log-probabilities, which is the input NLLLoss expects.
model = nn.Sequential(
    nn.Linear(in_features=784, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),
    nn.LogSoftmax(dim=1),
)

# Negative log-likelihood loss, optimized with plain SGD at lr = 0.01.
criterion = nn.NLLLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [16]:
# Train for a fixed number of passes over `trainloader`, printing the mean
# per-batch loss after each pass.
epoch = 10
for count in range(epoch):
    running_loss = 0
    for images, label in trainloader:
        # Flatten each image to one row vector: (batch, ...) -> (batch, features).
        images = images.view(images.shape[0], -1)

        # Clear gradients before the forward/backward pass so nothing left
        # over from an earlier backward() call leaks into this update.
        optimizer.zero_grad()

        # Call the module itself rather than .forward() so that any
        # registered hooks run as part of nn.Module.__call__.
        logps = model(images)
        loss = criterion(logps, label)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses for this pass.
    print(f"Training loss: {running_loss / len(trainloader)}")

Training loss: 1.0386095260823969
Training loss: 0.3809133465213181
Training loss: 0.32221768392937017
Training loss: 0.2906380556920952
Training loss: 0.2662835619025139
Training loss: 0.24450486555282494
Training loss: 0.2244886434329217
Training loss: 0.2068652132911278
Training loss: 0.1909526328860061
Training loss: 0.1771284985278588


In [43]:
# Held-out MNIST split (train=False), preprocessed with the same `transform`
# as the training data.
testset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)
# The loader must wrap `testset`; wrapping `trainset` here would make every
# "test" batch a batch of training images. Shuffling is kept so that a single
# drawn batch is a random sample of the split.
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

In [44]:
# Draw a single batch from `testloader` and list the label stored at each
# position (at most the first 64 entries).
batch_iterator = iter(testloader)
imgs, labl = next(batch_iterator)

for index, ground_truth in enumerate(labl[:64]):
    print(f"Index: {index}, Label: {ground_truth}")

Index: 0, Label: 1
Index: 1, Label: 0
Index: 2, Label: 1
Index: 3, Label: 6
Index: 4, Label: 0
Index: 5, Label: 9
Index: 6, Label: 7
Index: 7, Label: 5
Index: 8, Label: 3
Index: 9, Label: 2
Index: 10, Label: 6
Index: 11, Label: 4
Index: 12, Label: 1
Index: 13, Label: 6
Index: 14, Label: 3
Index: 15, Label: 5
Index: 16, Label: 1
Index: 17, Label: 3
Index: 18, Label: 4
Index: 19, Label: 6
Index: 20, Label: 5
Index: 21, Label: 7
Index: 22, Label: 0
Index: 23, Label: 9
Index: 24, Label: 1
Index: 25, Label: 8
Index: 26, Label: 8
Index: 27, Label: 3
Index: 28, Label: 3
Index: 29, Label: 9
Index: 30, Label: 0
Index: 31, Label: 6
Index: 32, Label: 9
Index: 33, Label: 0
Index: 34, Label: 3
Index: 35, Label: 0
Index: 36, Label: 8
Index: 37, Label: 7
Index: 38, Label: 3
Index: 39, Label: 0
Index: 40, Label: 7
Index: 41, Label: 6
Index: 42, Label: 0
Index: 43, Label: 7
Index: 44, Label: 6
Index: 45, Label: 0
Index: 46, Label: 3
Index: 47, Label: 1
Index: 48, Label: 5
Index: 49, Label: 7
Index: 50,

In [45]:
# Classify one image from the sampled batch and compare against its label.
index = 15
with torch.no_grad():  # inference only; no autograd graph needed
    print("The image is actually ", labl[index].item())
    # Flatten the image to a (1, features) row and call the module itself
    # (not .forward()) so nn.Module.__call__ handles hooks.
    prediction = model(imgs[index].view(1, -1))

    # The model ends in LogSoftmax, so exp() recovers the class probabilities.
    ps = torch.exp(prediction)

    # torch.max along dim 1 returns both the top probability and its class
    # index, avoiding a float-equality search for the maximum.
    maximum, predicted_class = torch.max(ps, dim=1)
    print(f"Model prediction: {predicted_class.item()}, probability: {maximum.item()}")

The image is actually  5
Model prediction: 5, probability: 0.9955824017524719
