In [76]:
import torch, torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn, optim

In [None]:
# Fetch the MNIST archive with wget, then uncompress it in the working directory.
# -nc (no-clobber) skips the download when MNIST.tar.gz is already present, so
# re-running this cell does not pile up MNIST.tar.gz.1, .2, ... copies while
# tar keeps extracting the stale original.
!wget -nc www.di.ens.fr/~lelarge/MNIST.tar.gz
# No -v flag: avoids dumping the full list of extracted paths into the output.
!tar -zxf MNIST.tar.gz

In [77]:
# Seed the global RNG so that the stochastic steps that follow (weight
# initialisation, batch shuffling) repeat on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# ToTensor turns each PIL image into a float tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split of MNIST, read from (or downloaded into) the current directory.
trainset = torchvision.datasets.MNIST(
    root='./', train=True, transform=transform, download=True
)

# Mini-batches of 64 images, reshuffled at the start of every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [82]:
class Net(nn.Module):
    """Single-hidden-layer classifier: 784 -> 512 -> 10, emitting log-probabilities.

    Layers, in the order they are applied by ``forward``:

    * ``hidden``  -- ``Linear(28*28, 512)``
    * ``sigmoid`` -- the hidden activation; despite the attribute name this
      is a ``ReLU``, not a sigmoid
    * ``output``  -- ``Linear(512, 10)``
    * ``softmax`` -- ``LogSoftmax(dim=1)``, so the result pairs with ``NLLLoss``
    """

    def __init__(self):
        super().__init__()

        # Trainable layers.
        self.hidden = nn.Linear(28*28, 512)
        self.output = nn.Linear(512, 10)

        # Parameter-free layers. NOTE: `sigmoid` is a misnomer kept for
        # compatibility -- the module stored here is a ReLU.
        self.sigmoid = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of flattened images.

        Args:
            x: tensor whose last dimension is 784. It should be 2-D
               (batch, 784) so that ``dim=1`` of the result is the class axis.

        Returns:
            Tensor of log-probabilities with 10 entries along ``dim=1``.
        """
        activated = self.sigmoid(self.hidden(x))
        return self.softmax(self.output(activated))

In [83]:
# Instantiate the network; its linear layers start from random weights.
model = Net()

In [84]:
# Display the module tree to confirm the layer sizes.
model

Net(
  (hidden): Linear(in_features=784, out_features=512, bias=True)
  (output): Linear(in_features=512, out_features=10, bias=True)
  (sigmoid): ReLU()
  (softmax): LogSoftmax(dim=1)
)

In [85]:
# The network ends in LogSoftmax, so negative log-likelihood is the matching loss.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(5):
    running_loss = 0

    for images, labels in trainloader:
        # Flatten every image of the batch into one row for the first linear layer.
        images = images.view(images.size(0), -1)

        # Clear the gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        log_probs = model(images)
        loss = criterion(log_probs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss for the epoch that just finished.
    mean_loss = running_loss / len(trainloader)
    print('The running loss is: {}'.format(mean_loss))
        

The running loss is: 1.1886519565701739
The running loss is: 0.47946929210411715
The running loss is: 0.3845183381647952
The running loss is: 0.344596299837266
The running loss is: 0.31974164313916714


In [86]:
# Pull a single batch from the training loader to try the trained model on.
images, labels = next(iter(trainloader))

In [None]:
# Take the first image of the batch and flatten it into one row (a batch of one),
# which is the 2-D layout the model's first linear layer is fed during training.
img = images[0].view(1, -1)

In [88]:
# Inference only: skip gradient tracking for this forward pass.
with torch.no_grad():
    # The model's last layer is LogSoftmax, so these are log-probabilities.
    logprobs = model(img)

In [89]:
# Show the per-class log-probabilities for the sample image.
logprobs

tensor([[-11.5440, -14.8617, -10.5375, -10.3530,  -5.2734,  -8.8099, -11.1371,
          -5.0983,  -4.9283,  -0.0189]])

In [93]:
# Exponentiate the log-probabilities to get class probabilities back.
probs = logprobs.exp()
probs

tensor([[9.6945e-06, 3.5127e-07, 2.6523e-05, 3.1898e-05, 5.1259e-03, 1.4925e-04,
         1.4563e-05, 6.1071e-03, 7.2391e-03, 9.8130e-01]])

In [94]:
# Index of the largest probability, i.e. the predicted digit for the sample image.
probs.argmax()

tensor(9)