In [13]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [14]:
mnist = torchvision.datasets.MNIST(
    root = 'data/',
    download = True,
    train = True,
    transform = torchvision.transforms.ToTensor()
)

In [15]:
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.25),
            nn.Flatten(1),
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Linear(128, 10)
        )
    
    def forward(self, x, y):
        return -nn.functional.cross_entropy(self.f(x), y, reduction='none')

    def classify(self, x):
        q = self.f(torch.as_tensor(x)[None, ...])[0]
        return nn.functional.softmax(q, dim=0)

In [16]:
model = ImageClassifier()

EPOCHS = 5
BATCH_NUM = 1000

mnist_batched = torch.utils.data.DataLoader(mnist, batch_size=BATCH_NUM)
model.train(mode=True)
optimizer = optim.Adam(model.parameters())

for epoch in tqdm(range(EPOCHS)):
    for batch_num, (imgs, lbls) in enumerate(mnist_batched):
        optimizer.zero_grad()
        
        loglik = model(imgs, lbls)
        e = -torch.mean(loglik)
        
        e.backward()
        optimizer.step()


model.train(mode=False)

100%|██████████| 60/60 [01:30<00:00,  1.50s/it]
100%|██████████| 60/60 [01:12<00:00,  1.21s/it]
100%|██████████| 60/60 [00:47<00:00,  1.26it/s]
100%|██████████| 60/60 [01:09<00:00,  1.16s/it]
100%|██████████| 60/60 [01:22<00:00,  1.37s/it]
100%|██████████| 5/5 [06:02<00:00, 72.55s/it]


ImageClassifier(
  (f): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout2d(p=0.25, inplace=False)
    (5): Flatten(start_dim=1, end_dim=-1)
    (6): Linear(in_features=9216, out_features=128, bias=True)
    (7): ReLU()
    (8): Dropout2d(p=0.5, inplace=False)
    (9): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [17]:
img, _ = mnist[110]
model.classify(img)



tensor([1.3325e-09, 5.6040e-09, 8.2563e-08, 1.1641e-06, 1.0658e-04, 3.7527e-08,
        1.3281e-10, 1.8109e-05, 2.3385e-07, 9.9987e-01],
       grad_fn=<SoftmaxBackward0>)

In [18]:
mnist[110]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 