In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader

from minai.datasets import collate_dict, inplace

In [None]:
# Both dataset dicts are fetched through Hugging Face `datasets.load_dataset`.
mnist_ds, fashion_ds = load_dataset('mnist'), load_dataset('fashion_mnist')

In [None]:
# Feature (column) names used by `transformi` below: images under `x`, targets under `y`.
x = 'image'
y = 'label'

In [None]:
@inplace
def transformi(b):
    """Convert every image in batch `b` to a 1-D float tensor.

    Each entry of the column named by the global `x` goes through
    `TF.to_tensor` and is then flattened. The column is overwritten in place
    on `b`; the `inplace` decorator (from minai) presumably hands the mutated
    batch back to the caller -- confirm against minai.
    """
    b[x] = [TF.to_tensor(img).flatten() for img in b[x]]

In [None]:
bs = 1024  # batch size handed to DataLoaders.from_dd below
# `with_transform` returns new dataset dicts that run `transformi` on access;
# the raw `mnist_ds` / `fashion_ds` objects are left as they were.
mnist_dsd, fashion_dsd = (dsd.with_transform(transformi) for dsd in (mnist_ds, fashion_ds))

In [None]:
class DataLoaders:
    """A training and a validation `DataLoader` bundled together.

    Only the first two loaders given to the constructor are kept, as `train`
    and `valid`; any further ones are ignored.
    """

    def __init__(self, *dls):
        # Tuple unpacking raises ValueError if fewer than two loaders are supplied.
        self.train, self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, num_workers=4, as_tuple=True, **kwargs):
        """Build one `DataLoader` per split of the dataset dict `dd`.

        dd: mapping of split name -> dataset. Splits are used in `dd.values()`
            order, so the first becomes `train` and the second `valid`.
        batch_size: batch size used for every split.
        num_workers, as_tuple: NOTE(review) accepted only for signature
            compatibility and currently ignored -- the loaders use
            DataLoader's own defaults for these settings.
        **kwargs: forwarded unchanged to every `DataLoader` (e.g. `shuffle`,
            `drop_last`, `pin_memory`). They apply to all splits alike.

        Each loader collates its batches with `collate_dict(ds)`.
        """
        return cls(*[DataLoader(ds, batch_size, collate_fn=collate_dict(ds), **kwargs)
                     for ds in dd.values()])

In [None]:
# Train/valid loaders for MNIST; keep a direct handle on the training loader.
mnist_dls = DataLoaders.from_dd(mnist_dsd, bs)
mnist_dt = mnist_dls.train

# Pull the first training batch to sanity-check the flattened image shape
# and the first few labels.
xb, yb = next(iter(mnist_dt))
(xb.shape, yb[:10])

(torch.Size([1024, 784]), tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4]))

In [None]:
def train_v1(dl, epochs, lr=1e-3):
    """Train a fresh 784-50-10 MLP with plain SGD and print the mean loss per epoch.

    In this variant gradients are cleared *after* each optimizer step (train_v2
    clears them before `backward`). Either way every batch's `backward` starts
    from empty gradients.

    dl: iterable of `(xb, yb)` batches. `xb` must have 784 features per row and
        `yb` holds integer class labels (as required by `F.cross_entropy`).
    epochs: number of full passes over `dl`.
    lr: SGD learning rate (defaults to the previously hard-coded 1e-3).

    Prints one Python float per epoch: the per-sample mean training loss.
    """
    model = nn.Sequential(nn.Linear(784, 50), nn.ReLU(), nn.Linear(50, 10))
    opt = SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        total_loss, n_seen = 0.0, 0
        for xb, yb in dl:
            pred = model(xb)
            loss = F.cross_entropy(pred, yb)
            # `.item()` yields a plain float. Accumulating the loss tensor itself
            # kept a reference to every batch's autograd graph and made the
            # printed result carry a `grad_fn`.
            total_loss += loss.item() * len(xb)
            n_seen += len(xb)

            loss.backward()
            opt.step()
            opt.zero_grad()
        # cross_entropy averages within a batch, so weighting by batch size gives
        # the exact per-sample mean even when the last batch is smaller.
        print(total_loss / n_seen)

In [None]:
# One pass over the MNIST training loader; prints the epoch's mean training loss.
train_v1(mnist_dt, epochs=1)

tensor(2.31, grad_fn=<DivBackward0>)


In [None]:
def train_v2(dl, epochs, lr=1e-3):
    """Train a fresh 784-50-10 MLP with plain SGD and print the mean loss per epoch.

    In this variant gradients are cleared *before* `backward` (train_v1 clears
    them after the optimizer step). Either way every batch's `backward` starts
    from empty gradients.

    dl: iterable of `(xb, yb)` batches. `xb` must have 784 features per row and
        `yb` holds integer class labels (as required by `F.cross_entropy`).
    epochs: number of full passes over `dl`.
    lr: SGD learning rate (defaults to the previously hard-coded 1e-3).

    Prints one Python float per epoch: the per-sample mean training loss.
    """
    model = nn.Sequential(nn.Linear(784, 50), nn.ReLU(), nn.Linear(50, 10))
    opt = SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        total_loss, n_seen = 0.0, 0
        for xb, yb in dl:
            pred = model(xb)
            loss = F.cross_entropy(pred, yb)
            # `.item()` yields a plain float. Accumulating the loss tensor itself
            # kept a reference to every batch's autograd graph and made the
            # printed result carry a `grad_fn`.
            total_loss += loss.item() * len(xb)
            n_seen += len(xb)

            opt.zero_grad()
            loss.backward()
            opt.step()

        # cross_entropy averages within a batch, so weighting by batch size gives
        # the exact per-sample mean even when the last batch is smaller.
        print(total_loss / n_seen)

In [None]:
# Same single-epoch run, this time with zero_grad placed before backward.
train_v2(mnist_dt, epochs=1)

tensor(2.31, grad_fn=<DivBackward0>)
