<a href="https://colab.research.google.com/github/alejandrodgb/fastai/blob/main/clean/04_MNIST_full.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only setup: when running inside Google Colab, install fastbook and nbdev via
# shell escapes, then import fastbook and run its setup_book() helper.
import sys

if 'google.colab' in sys.modules:
    # The `[ -e /content ]` test limits the install to environments that have a /content directory.
    ! [ -e /content ] && pip install -Uqq fastbook
    # NOTE(review): unpinned install — pin a version (e.g. `%pip install nbdev==X.Y`) for reproducibility.
    !pip install nbdev
    
import fastbook
fastbook.setup_book()

In [2]:
from fastai.vision.all import *
from fastbook import *

# Data

In [3]:
# Download and extract the full MNIST dataset (PNG images); `path` is the extracted root folder.
path = untar_data(URLs.MNIST)
path

Path('/Users/adgb/.fastai/data/mnist_png')

# Fastai classifier

In [None]:
# Build fastai DataLoaders from the folder layout: `training/` is the training split
# and `testing/` serves as the validation split.
dls = ImageDataLoaders.from_folder(path, valid='testing', train='training')

In [None]:
# ResNet-18 trained from scratch (pretrained=False), cross-entropy loss,
# accuracy reported as the metric.
learn = vision_learner(
    dls,
    resnet18,
    pretrained=False,
    loss_func=F.cross_entropy,
    metrics=accuracy,
)

In [None]:
# Train for 2 epochs with fastai's one-cycle schedule; 0.1 is passed as the maximum learning rate.
learn.fit_one_cycle(2,0.1)

# Custom trainer

## Data

In [7]:
# One sub-directory per digit under each split. Sorting makes list index i match digit
# folder i (lexicographic order equals numeric order for the single-character names '0'-'9').
training_path, valid_path = ((path/split).ls().sorted() for split in ('training', 'testing'))

training_path, valid_path

((#10) [Path('/Users/adgb/.fastai/data/mnist_png/training/0'),Path('/Users/adgb/.fastai/data/mnist_png/training/1'),Path('/Users/adgb/.fastai/data/mnist_png/training/2'),Path('/Users/adgb/.fastai/data/mnist_png/training/3'),Path('/Users/adgb/.fastai/data/mnist_png/training/4'),Path('/Users/adgb/.fastai/data/mnist_png/training/5'),Path('/Users/adgb/.fastai/data/mnist_png/training/6'),Path('/Users/adgb/.fastai/data/mnist_png/training/7'),Path('/Users/adgb/.fastai/data/mnist_png/training/8'),Path('/Users/adgb/.fastai/data/mnist_png/training/9')],
 (#10) [Path('/Users/adgb/.fastai/data/mnist_png/testing/0'),Path('/Users/adgb/.fastai/data/mnist_png/testing/1'),Path('/Users/adgb/.fastai/data/mnist_png/testing/2'),Path('/Users/adgb/.fastai/data/mnist_png/testing/3'),Path('/Users/adgb/.fastai/data/mnist_png/testing/4'),Path('/Users/adgb/.fastai/data/mnist_png/testing/5'),Path('/Users/adgb/.fastai/data/mnist_png/testing/6'),Path('/Users/adgb/.fastai/data/mnist_png/testing/7'),Path('/Users/adgb/

In [8]:
# Count the images in each training digit folder, plus the overall total.
av_imgs = list(map(lambda digit_dir: len(digit_dir.ls()), training_path))
av_imgs, sum(av_imgs)

([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949], 60000)

In [9]:
# Count the images in each testing (validation) digit folder, plus the overall total.
av_imgs = list(map(lambda digit_dir: len(digit_dir.ls()), valid_path))
av_imgs, sum(av_imgs)

([980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009], 10000)

In [10]:
# For each split: one list of image paths per digit folder, kept in folder (digit) order.
train_nums, valid_nums = ([digit_dir.ls() for digit_dir in split_dirs]
                          for split_dirs in (training_path, valid_path))

In [11]:
# Create a list of image tensors for each digit.
def _load_img_tensor(fn):
    "Open image file `fn`, copy its pixels into a tensor, and close the file before returning."
    # `with` makes the file-handle lifetime explicit instead of relying on PIL/GC to close
    # it; tensor() copies the pixel data, so the result stays valid after the file closes.
    with Image.open(fn) as img:
        return tensor(img)

train_nums_lists = [[_load_img_tensor(im) for im in num] for num in train_nums]
valid_nums_lists = [[_load_img_tensor(im) for im in num] for num in valid_nums]

In [13]:
# Stack each digit's images into a single float tensor and divide by 255
# (pixels assumed 8-bit, so values end up in [0, 1]).
train_nums_tensors = [torch.stack(imgs).float().div(255) for imgs in train_nums_lists]
valid_nums_tensors = [torch.stack(imgs).float().div(255) for imgs in valid_nums_lists]

In [14]:
# Sanity-check the training tensors: per-digit shapes, per-digit image counts and their total.
av_t = [stack.shape for stack in train_nums_tensors]
av_t, [shape[0] for shape in av_t], sum(shape[0] for shape in av_t)

([torch.Size([5923, 28, 28]),
  torch.Size([6742, 28, 28]),
  torch.Size([5958, 28, 28]),
  torch.Size([6131, 28, 28]),
  torch.Size([5842, 28, 28]),
  torch.Size([5421, 28, 28]),
  torch.Size([5918, 28, 28]),
  torch.Size([6265, 28, 28]),
  torch.Size([5851, 28, 28]),
  torch.Size([5949, 28, 28])],
 [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949],
 60000)

In [15]:
# Sanity-check the validation tensors: per-digit shapes, per-digit image counts and their total.
av_t = [stack.shape for stack in valid_nums_tensors]
av_t, [shape[0] for shape in av_t], sum(shape[0] for shape in av_t)

([torch.Size([980, 28, 28]),
  torch.Size([1135, 28, 28]),
  torch.Size([1032, 28, 28]),
  torch.Size([1010, 28, 28]),
  torch.Size([982, 28, 28]),
  torch.Size([892, 28, 28]),
  torch.Size([958, 28, 28]),
  torch.Size([1028, 28, 28]),
  torch.Size([974, 28, 28]),
  torch.Size([1009, 28, 28])],
 [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009],
 10000)

In [16]:
# Concatenate the per-digit stacks (digit 0 first) and flatten every 28x28 image into
# one 784-long row, turning rank-3 stacks into a rank-2 (n_images, 784) matrix.
train_x, valid_x = (torch.cat(per_digit).view(-1, 28 * 28)
                    for per_digit in (train_nums_tensors, valid_nums_tensors))

train_x.shape, valid_x.shape

(torch.Size([60000, 784]), torch.Size([10000, 784]))

In [17]:
# Target labels: the stacks were concatenated in digit order, so digit d contributes
# stack.shape[0] consecutive rows labelled d. Result is an int64 column of shape (n_images, 1).
train_y = torch.cat([torch.full((stack.shape[0], 1), digit, dtype=torch.long)
                     for digit, stack in enumerate(train_nums_tensors)])
valid_y = torch.cat([torch.full((stack.shape[0], 1), digit, dtype=torch.long)
                     for digit, stack in enumerate(valid_nums_tensors)])

train_y.shape, valid_y.shape

(torch.Size([60000, 1]), torch.Size([10000, 1]))

In [18]:
# Training and validation datasets: plain lists of (image_vector, label) pairs.
train_dset = [(x, y) for x, y in zip(train_x, train_y)]
valid_dset = [(x, y) for x, y in zip(valid_x, valid_y)]

In [46]:
# Shuffled mini-batches of 64 (x, y) pairs from the training set (fastai DataLoader, so `bs`).
train_dl = DataLoader(train_dset, shuffle=True, bs=64)

## Simple net model

In [197]:
# Multiclass (sparse categorical) cross-entropy loss
def mnist_loss(preds, targs):
    """Mean cross-entropy of raw logits `preds` against integer class labels `targs`.

    preds: (batch, n_classes) logits.
    targs: integer class indices holding `batch` elements, e.g. shape (batch,) or (batch, 1).
    Returns a 0-d tensor.

    Uses log_softmax rather than log(softmax(...)). The two are mathematically equal, but
    log_softmax cannot underflow to log(0) = -inf when one logit dominates.
    """
    log_probs = preds.log_softmax(dim=1)
    # (batch, 1), (batch,) or a 0-d single label -> (batch,), without squeezing away the batch axis
    targs = targs.reshape(-1)
    return -log_probs[torch.arange(len(log_probs)), targs].mean()


In [30]:
# Create optimiser
class BasicOptim:
    """Minimal SGD optimiser: p <- p - lr * p.grad for every parameter.

    parameters: iterable of tensors (e.g. `model.parameters()`). It is materialised into a
        list, so a one-shot generator can be passed and reused on every step.
    lr: learning rate.
    """
    def __init__(self,parameters,lr):
        self.params = list(parameters)
        self.lr = lr

    def step(self, *args, **kwargs):
        "Apply one SGD update. Parameters without a gradient (grad is None) are left untouched."
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p in self.params:
                if p.grad is None:
                    continue
                p -= self.lr * p.grad

    def zero_grad(self, *args, **kwargs):
        "Drop accumulated gradients by setting them to None."
        for p in self.params:
            p.grad = None

In [31]:
def batch_accuracy(preds, y):
    """Fraction of predictions whose highest-scoring class equals the label.

    preds: scores with the class axis last; y: labels, squeezed before comparing.
    Returns a 0-d float tensor.
    """
    predicted = preds.argmax(dim=-1)
    hits = predicted == y.squeeze()
    return hits.float().mean()

In [32]:
def calc_grads(model, x, y, loss_fn):
    """Forward `x` through `model`, backprop `loss_fn(preds, labels)`, and return the loss as a float.

    Gradients accumulate into the parameters' `.grad`; stepping and zeroing is up to the caller.
    `y` may be (batch, 1) or (batch,). It is flattened to (batch,) rather than squeezed:
    squeeze() would turn a batch of one into a 0-d tensor, which cross-entropy rejects
    for 2-d logits.
    """
    preds = model(x)
    loss = loss_fn(preds, y.reshape(-1))
    loss.backward()
    return loss.item()

In [33]:
def validate_epoch(model, valid_dl):
    """Accuracy of `model` over every sample yielded by `valid_dl`, rounded to 4 decimals.

    valid_dl: iterable of (xb, yb) pairs. It can be a batched DataLoader or a plain list of
        single (x, y) samples.
    Accuracy is total correct / total samples, so a short final batch is weighted by its
    size instead of counting as much as a full batch.
    Runs under torch.no_grad(), so no autograd graph is built during evaluation.
    Raises ValueError if `valid_dl` yields no samples.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in valid_dl:
            # Same hit test as batch_accuracy: argmax over the last (class) axis vs squeezed labels.
            hits = model(xb).argmax(dim=-1) == yb.squeeze()
            correct += hits.sum().item()
            total += hits.numel()
    if total == 0:
        raise ValueError("valid_dl yielded no samples")
    return round(correct / total, 4)

In [56]:
def train_epoch(dl, model, loss_fn, opt):
    """One pass over `dl`: per batch, backprop via calc_grads, step `opt`, then zero the grads.

    Prints the mean of the per-batch losses when the pass is done.
    """
    losses = []
    for inputs, labels in dl:
        step_loss = calc_grads(model, inputs, labels, loss_fn)
        opt.step()
        opt.zero_grad()
        losses.append(step_loss)
    mean_loss = tensor(losses).mean()
    print(f'\tAvg batch loss: {mean_loss}')

In [193]:
# Train model
def train_model(dl, model, epochs, loss_fn, valid_dl, opt):
    """Train `model` for `epochs` epochs on `dl`, printing validation accuracy and timings.

    dl: training batches of (xb, yb). valid_dl: iterable of (xb, yb) passed to validate_epoch.
    loss_fn(preds, targets) -> scalar loss. opt: optimiser with step() and zero_grad().
    Timings use time.perf_counter(), a monotonic clock meant for measuring intervals
    (time.time() can jump if the system clock is adjusted). The per-epoch time
    includes validation.
    """
    model_start = time.perf_counter()
    for e in range(epochs):
        epoch_start = time.perf_counter()
        train_epoch(dl, model, loss_fn, opt)
        print(f'Epoch {e}: accuracy = {validate_epoch(model, valid_dl)}, time {time.perf_counter()-epoch_start:.2f}s')
    print(f'Training time: {time.perf_counter()-model_start:.2f}s')

### Testing the cross-entropy implementations

In [99]:
# Simple neural network: 784 -> 100 -> 30 -> 10 outputs, ReLU between the linear layers.
# Seed first so the random weight initialisation is reproducible across re-runs.
torch.manual_seed(42)
simple_net = nn.Sequential(
    nn.Linear(28*28,100),
    nn.ReLU(),
    nn.Linear(100,30),
    nn.ReLU(),
    nn.Linear(30,10)
)
opt = BasicOptim(simple_net.parameters(), 0.1)

In [100]:
# Take the first training batch (xb: inputs, yb: labels) to experiment with the losses.
xb, yb = next(iter(train_dl))

#### Categorical $L_{CE}$

In [175]:
# Forward the sample batch through the untrained network: one row of 10 outputs per image.
preds = simple_net(xb)
preds.shape

torch.Size([64, 10])

In [176]:
# Number of classes: size of the class axis (the only axis for a single 1-d prediction).
num_classes = preds.shape[0] if preds.dim() == 1 else preds.shape[1]
num_classes

10

In [177]:
# One-hot encode the labels. yb is (batch, 1), so one_hot yields (batch, 1, C); view flattens it to (batch, C).
targs = F.one_hot(yb,num_classes=num_classes).view(-1,num_classes)
targs.shape

torch.Size([64, 10])

In [178]:
# Softmax over the class axis: each row becomes a probability distribution.
s_preds = preds.softmax(dim=1)

In [179]:
# The one-hot targets and the probabilities must have the same shape for the element-wise product.
targs.shape, s_preds.shape

(torch.Size([64, 10]), torch.Size([64, 10]))

In [180]:
# Categorical cross-entropy: -sum_c onehot_c * log(p_c) per row, averaged over the batch.
# NOTE(review): log(softmax(x)) can hit log(0) = -inf for extreme logits; preds.log_softmax(dim=1) is the stable equivalent.
-(targs*torch.log(s_preds)).sum(dim=1).mean()

tensor(2.3130, grad_fn=<NegBackward0>)

#### Sparse $L_{CE}$

In [181]:
# Recompute the class probabilities so this section does not depend on the previous one.
s_preds = preds.softmax(dim=1)

In [182]:
# Sparse cross-entropy: pick each row's probability at its true label (no one-hot needed), take -log, average.
# NOTE(review): same log(softmax) underflow caveat as the categorical version; log_softmax is the stable form.
-(torch.log(s_preds[np.arange(len(s_preds)),yb.squeeze()])).mean()

tensor(2.3130, grad_fn=<NegBackward0>)

#### PyTorch $L_{CE}$

In [183]:
# PyTorch's built-in loss takes raw logits and (batch,) integer labels and applies log-softmax internally.
tcel = nn.CrossEntropyLoss()
tcel(preds,yb.squeeze())

tensor(2.3130, grad_fn=<NllLossBackward0>)

### Model training

In [199]:
# Simple neural network: 784 -> 100 -> 30 -> 10 outputs, ReLU between the linear layers.
# Seed first so the initial weights are reproducible; use the same seed for any run you
# want to compare against this one.
torch.manual_seed(42)
simple_net = nn.Sequential(
    nn.Linear(28*28,100),
    nn.ReLU(),
    nn.Linear(100,30),
    nn.ReLU(),
    nn.Linear(30,10)
)
opt = BasicOptim(simple_net.parameters(), 0.1)

#### PyTorch $L_{CE}$

In [200]:
# Validate on batches instead of iterating the raw (x, y) list one sample at a time.
# bs=1000 splits the 10,000 validation samples into equal batches, so a mean of
# per-batch accuracies equals the overall accuracy.
train_model(dl=train_dl, model=simple_net, epochs=10, loss_fn=nn.CrossEntropyLoss(),
            valid_dl=DataLoader(valid_dset, bs=1000), opt=opt)

	Avg batch loss: 0.5137635469436646
Epoch 0: accuracy = 0.922, time 1.70s
	Avg batch loss: 0.19117991626262665
Epoch 1: accuracy = 0.9567, time 1.63s
	Avg batch loss: 0.1321459710597992
Epoch 2: accuracy = 0.9637, time 2.18s
	Avg batch loss: 0.1020367369055748
Epoch 3: accuracy = 0.966, time 1.69s
	Avg batch loss: 0.08401274681091309
Epoch 4: accuracy = 0.9728, time 1.63s
	Avg batch loss: 0.0699848011136055
Epoch 5: accuracy = 0.9747, time 1.64s
	Avg batch loss: 0.06066516414284706
Epoch 6: accuracy = 0.9752, time 1.65s
	Avg batch loss: 0.05132593959569931
Epoch 7: accuracy = 0.9765, time 1.63s
	Avg batch loss: 0.04449652507901192
Epoch 8: accuracy = 0.9727, time 1.65s
	Avg batch loss: 0.03736179322004318
Epoch 9: accuracy = 0.9752, time 1.62s
Training time: 17.04s


#### mnist_loss $L_{CE}$

In [201]:
# Simple neural network: 784 -> 100 -> 30 -> 10 outputs, ReLU between the linear layers.
# Seed first so the initial weights are reproducible; use the same seed for any run you
# want to compare against this one.
torch.manual_seed(42)
simple_net = nn.Sequential(
    nn.Linear(28*28,100),
    nn.ReLU(),
    nn.Linear(100,30),
    nn.ReLU(),
    nn.Linear(30,10)
)
opt = BasicOptim(simple_net.parameters(), 0.1)

In [202]:
# Validate on batches instead of iterating the raw (x, y) list one sample at a time.
# bs=1000 splits the 10,000 validation samples into equal batches, so a mean of
# per-batch accuracies equals the overall accuracy.
train_model(dl=train_dl, model=simple_net, epochs=10, loss_fn=mnist_loss,
            valid_dl=DataLoader(valid_dset, bs=1000), opt=opt)

	Avg batch loss: 0.5181543231010437
Epoch 0: accuracy = 0.9271, time 1.77s
	Avg batch loss: 0.19730697572231293
Epoch 1: accuracy = 0.9566, time 1.72s
	Avg batch loss: 0.13519515097141266
Epoch 2: accuracy = 0.9632, time 1.72s
	Avg batch loss: 0.10361114144325256
Epoch 3: accuracy = 0.9693, time 1.71s
	Avg batch loss: 0.08366364240646362
Epoch 4: accuracy = 0.9634, time 1.71s
	Avg batch loss: 0.07132456451654434
Epoch 5: accuracy = 0.9727, time 1.74s
	Avg batch loss: 0.06049911677837372
Epoch 6: accuracy = 0.9728, time 1.71s
	Avg batch loss: 0.051981858909130096
Epoch 7: accuracy = 0.975, time 1.81s
	Avg batch loss: 0.04509030282497406
Epoch 8: accuracy = 0.9763, time 1.72s
	Avg batch loss: 0.03924497961997986
Epoch 9: accuracy = 0.9781, time 1.79s
Training time: 17.39s
