In [1]:
import torch
from torch import tensor, nn # no optim! we're making it today!
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from torchvision import transforms

In [2]:
# Load in the MNIST dataset
from datasets import load_dataset

ds = load_dataset("ylecun/mnist")

In [3]:
class MNIST(Dataset):
    def __init__(self, dsd):
        self.dsd = dsd
        
    def __getitem__(self, i):
        return transforms.ToTensor()(self.dsd['image'][i]), self.dsd['label'][i]
    
    def __len__(self):
        return len(self.dsd['image'])

In [4]:
split = 0.8
idx = int(len(ds['train']) * split)
ds_train, ds_valid = MNIST(ds['train'][:idx]), MNIST(ds['train'][idx:])

In [5]:
dl_train = DataLoader(ds_train, batch_size=1000, shuffle=True)
dl_val = DataLoader(ds_valid, batch_size=1000)

In [6]:
class ConvModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_batchnorm=True, activation=nn.ReLU):
        super(ConvModule, self).__init__()
        
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        ]
        
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        
        if activation is not None:
            layers.append(activation())
        
        self.conv_block = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.conv_block(x)

In [7]:
class GlobalAvgPool(nn.Module):
    def forward(self, x): return x.mean((-2, -1))

In [8]:
class MNIST_CNN(nn.Module):
    def __init__(self, in_channels, out_channels, nfs=[64, 128, 256]):
        super().__init__()
        channels = [in_channels] + nfs
        self.net = nn.Sequential(*[ConvModule(channels[i], channels[i+1]) for i in range(len(channels) - 1)])
        self.pool = GlobalAvgPool()
        self.last = nn.Linear(channels[-1], out_channels, bias=False)
        
    def forward(self, x):
        return self.last(self.pool(self.net(x)))

In [9]:
class Adam():
    def __init__(self, params, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        for p in self.params:
            p.mom = torch.zeros_like(p.data)
            p.sqr_mom = torch.zeros_like(p.data)\

        self.t = 0
        
    def step(self):
        self.t += 1
        for p in self.params:
            p.mom = self.beta1 * p.mom + (1 - self.beta1) * p.grad
            p.sqr_mom = self.beta2 * p.sqr_mom + (1 - self.beta2) * p.grad ** 2
            mom_corr = p.mom / (1 - self.beta1 ** self.t)
            sqr_mom_corr = p.sqr_mom / (1 - self.beta2 ** self.t)
            
            p.data -= self.lr * mom_corr / (sqr_mom_corr.sqrt() + self.eps)
            
    def zero_grad(self):
        for p in self.params:
            p.grad = None

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MNIST_CNN(1, 10).to(device)
epochs = 10
optim = Adam(model.parameters(), 0.01)

In [11]:
for i in range(epochs):
    for xb, yb in dl_train:
        xb, yb = xb.to(device), yb.to(device)
        optim.zero_grad()
        out = model(xb)
        loss = F.cross_entropy(out, yb)
        loss.backward()
        print(loss)
        optim.step()

tensor(2.3018, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.3378, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.2233, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.1167, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.0866, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.0032, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.9524, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.9092, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.8800, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.8251, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.7785, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.7090, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.6498, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.6574, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.5746, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.6128, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.5108, device='cuda:0', grad_fn=

In [19]:
acc = 0
for xb, yb in dl_val:
    xb, yb = xb.to(device), yb.to(device)
    probs = F.softmax(model(xb), dim=-1)
    preds = probs.max(dim=1).indices
    acc += (preds == yb).to(int).sum()

acc

tensor(11778, device='cuda:0')

In [21]:
len(ds_valid)

12000