# Milking the final drops of Accuracy

So... you've trained a model.  

Now what? You could squeeze in a few more epochs, but that will likely yield only marginal gains.  

Or you could really scrape the bottom of the barrel and try the techniques in this notebook. We'll go through a few methods (e.g. stochastic weight averaging and knowledge distillation) to get better results from our model. We use the model from the Speedrunning Model Training notebook, which is itself undertrained, so this notebook should come in handy.

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet18
import pytorch_lightning as pl
import torchmetrics
from torchvision.transforms import v2

In [2]:
img_transforms = v2.Compose([
    v2.ToImage(),
])

train_ds = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=img_transforms)
test_ds = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=img_transforms)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=4096, shuffle=False, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

Our model is just a ResNet18 trained on CIFAR-10.

In [8]:
model = resnet18()
model.fc = nn.Linear(512, 10)
model.load_state_dict(torch.load("model.pth", weights_only=True))

<All keys matched successfully>

## 1. Stochastic Weight Averaging (SWA)
SWA averages a model's weights over several points along the training trajectory instead of keeping only the final weights. The averaged model tends to generalize better, so it should hopefully perform better on the test set.  

Thankfully, PyTorch ships utilities for this (`torch.optim.swa_utils`)! But since we're lazy, we can just use `pytorch_lightning`'s built-in `StochasticWeightAveraging` callback.
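
For reference, here is a minimal sketch of what doing SWA by hand with `torch.optim.swa_utils` could look like. The tiny network, the random data, and the hyperparameters (`swa_start`, `epochs`, learning rates) are toy assumptions chosen so the snippet runs in about a second on CPU; they are not the notebook's setup.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a small BatchNorm classifier and random data.
net = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 3))
X, y = torch.randn(256, 8), torch.randint(0, 3, (256,))
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True
)

optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
swa_net = AveragedModel(net)  # holds the running average of net's weights
swa_scheduler = SWALR(optimizer, swa_lr=1e-3, anneal_epochs=2)
swa_start, epochs = 2, 6  # start averaging after two ordinary epochs

for epoch in range(epochs):
    net.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(net(xb), yb)
        loss.backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_net.update_parameters(net)  # fold this epoch's weights into the average
        swa_scheduler.step()

# BatchNorm running statistics belong to the individual checkpoints, not to the
# averaged weights, so recompute them with one pass over the training data.
update_bn(loader, swa_net)
```

The final `update_bn` pass matters for BatchNorm models such as ResNet18; without it the averaged weights would be evaluated with stale normalization statistics.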

In [49]:
class LitSWAResNet(pl.LightningModule):
    def __init__(self, resnet_cls, lr):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.lr = lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

        self.img_scale = v2.ToDtype(torch.float32, scale=True)
        self.img_norm = v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
        self.img_aug = v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10)

    def training_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(self.img_aug(X)))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(X))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.lr)
        return optimizer

model = LitSWAResNet(resnet18, 1e-5)
model.resnet.load_state_dict(torch.load("model.pth", weights_only=True))

<All keys matched successfully>

After playing around for a while, I found that decreasing the batch size together with SWA gives a much larger boost. A plausible explanation: smaller batches make the gradient estimates noisier, which lets the optimizer explore more of the loss surface, and SWA then averages those noisy iterates into a good solution.
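
To put a rough number on "noisier", here is a toy simulation using only the standard library. It assumes per-sample gradients are independent Gaussian draws around a true gradient (real gradients are not), and it uses scaled-down batch sizes of 64 and 512 that keep the same 8x ratio as 512 vs. 4096. Under that assumption the standard deviation of the mini-batch mean shrinks like 1/sqrt(batch size).

```python
import random
import statistics

random.seed(0)

def minibatch_grad_std(batch_size, n_batches=2000, true_grad=1.0, per_sample_std=1.0):
    """Std of the mini-batch mean of noisy per-sample 'gradients' (toy 1-D model)."""
    means = []
    for _ in range(n_batches):
        samples = [random.gauss(true_grad, per_sample_std) for _ in range(batch_size)]
        means.append(statistics.fmean(samples))
    return statistics.stdev(means)

small_batch_noise = minibatch_grad_std(64)
large_batch_noise = minibatch_grad_std(512)

# Theory for i.i.d. samples: ratio = sqrt(512 / 64) = sqrt(8), roughly 2.83
ratio = small_batch_noise / large_batch_noise
print(f"noise(64) / noise(512) = {ratio:.2f}")
```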

In [50]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)

In [51]:
swa_callback = pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-3, swa_epoch_start=5, annealing_epochs=5)
trainer = pl.Trainer(precision="bf16-mixed", max_epochs=15, check_val_every_n_epoch=1, log_every_n_steps=5, callbacks=[swa_callback])
trainer.fit(model, train_loader, test_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | resnet    | ResNet             | 11.2 M | train
1 | accuracy  | MulticlassAccuracy | 0      | train
2 | img_scale | ToDtype            | 0      | train
3 | img_norm  | Normalize          | 0      | train
4 | img_aug   | AutoAugment        | 0      | train
---------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
72        Modules in train mode
0         Modules in eval mode



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f36143f7ac0>
Traceback (most recent call last):
  File "/venv/main/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/venv/main/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process

`Trainer.fit` stopped: `max_epochs=16` reached.


In [52]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'val_loss': 0.6302424669265747, 'val_acc': 0.792900025844574}]

That's +1.2%! Not bad.

## 2. Self-Distillation
Self-distillation is a special case of knowledge distillation: instead of distilling the knowledge into a smaller model, you distill it into a model of the same size. Here the teacher is the SWA model from section 1. I've wanted to try this for a long time, so let's do it!
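
The training cell further down divides the logits by a temperature before the softmax. As a quick illustration of what that does, here is a small standard-library sketch; the three logits are made-up values, not outputs of the actual teacher.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed stably by subtracting the max."""
    scaled = [z / temperature for z in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 2.0, 1.0]  # hypothetical teacher logits for three classes
sharp = softmax_with_temperature(logits, 1.0)  # ordinary softmax
soft = softmax_with_temperature(logits, 2.0)   # the temperature used in this notebook

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

With the higher temperature the top class loses probability mass and the other classes gain some, so the student sees more of how the teacher ranks the non-target classes.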

In [54]:
teacher_model = model.resnet
for param in teacher_model.parameters():
    param.requires_grad = False

In [81]:
class LSelfDistillStudent(pl.LightningModule):
    def __init__(self, resnet_cls, teacher_model, temperature, max_lr, min_lr, kd_balance=0.5):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.max_lr = max_lr
        self.min_lr = min_lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

        self.img_scale = v2.ToDtype(torch.float32, scale=True)
        self.img_norm = v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
        self.img_aug = v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10)

        self.teacher = teacher_model
        self.temperature = temperature
        self.kd_balance = kd_balance

    def training_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(self.img_aug(X)))

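        # The teacher only provides targets, so no gradients are needed.
        # Caveat: as a registered submodule the teacher is put in train mode during fit,
        # so its BatchNorm layers use batch statistics here unless it is set to eval().
        with torch.inference_mode():
            teacher_logits = self.teacher(X)
        
        student_logits = self.resnet(X)

        # Divide the logits by the temperature to soften both distributions, so the targets aren't so sharp and are easier to learn.
        # kl_div expects log-probabilities for the input (and for the target when log_target=True), hence log_softmax.
        soft_teacher_targets = nn.functional.log_softmax(teacher_logits / self.temperature, dim=-1)
        soft_student_probs = nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
        
        labels_loss = nn.functional.cross_entropy(student_logits, y)
        self.log("train_labels_loss", labels_loss)
        # Scaled by temperature^2, following the original distillation paper.
        # Note: the default reduction ("mean") averages over every element; reduction="batchmean" would match the per-sample KL definition.
        teacher_loss = nn.functional.kl_div(soft_student_probs, soft_teacher_targets, log_target=True) * (self.temperature**2)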
        self.log("train_teacher_loss", teacher_loss)
        loss = self.kd_balance*labels_loss + (1-self.kd_balance)*teacher_loss
        
        self.log("train_loss", loss)
        acc = self.accuracy(student_logits, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(X))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.min_lr)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.max_lr, epochs=100, steps_per_epoch=len(train_loader), div_factor=self.max_lr/self.min_lr, pct_start=0.1)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
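
The loss used in `training_step` can also be written as a standalone function, which makes it easier to poke at. This is a sketch, not the cell above verbatim: `distillation_loss` is a hypothetical helper name, the inputs are random tensors, and it uses `reduction="batchmean"` rather than the default reduction so the difference between the two can be measured.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, kd_balance=0.5):
    """Blend hard-label cross-entropy with temperature-softened KL to the teacher."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
    # "batchmean" sums over classes and averages over the batch (KL per sample).
    kd = F.kl_div(log_p_student, log_p_teacher, log_target=True, reduction="batchmean")
    kd = kd * temperature**2
    ce = F.cross_entropy(student_logits, labels)
    return kd_balance * ce + (1 - kd_balance) * kd, ce, kd

torch.manual_seed(0)
student = torch.randn(4, 10)  # random stand-ins for logits, batch of 4, 10 classes
teacher = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

loss, ce, kd = distillation_loss(student, teacher, labels)

# The default reduction ("mean") divides by batch * classes instead of batch,
# so with 10 classes it yields a KD term that is 10x smaller.
kd_default = F.kl_div(
    F.log_softmax(student / 2.0, dim=-1),
    F.log_softmax(teacher / 2.0, dim=-1),
    log_target=True,
    reduction="mean",
) * 2.0**2

# A student that already matches the teacher has zero KD loss.
_, _, kd_identical = distillation_loss(teacher, teacher, labels)
```

With the default reduction, `kd_balance=0.5` therefore weights the teacher term less than the formula suggests, which may be worth keeping in mind when tuning `kd_balance`.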

Let's get the batch size back to 4096 for training...

In [82]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=4096, shuffle=True, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=4096, shuffle=False, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)

In [83]:
model = LSelfDistillStudent(resnet18, teacher_model, 2.0, 2e-3, 1e-5)
trainer = pl.Trainer(precision="bf16-mixed", max_epochs=100, check_val_every_n_epoch=25, log_every_n_steps=5) 
trainer.fit(model, train_loader, test_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | resnet    | ResNet             | 11.2 M | train
1 | accuracy  | MulticlassAccuracy | 0      | train
2 | img_scale | ToDtype            | 0      | train
3 | img_norm  | Normalize          | 0      | train
4 | img_aug   | AutoAugment        | 0      | train
5 | teacher   | ResNet             | 11.2 M | train
---------------------------------------------------------
11.2 M    Trainable params
11.2 M    Non-trainable params
22.4 M    Total params
89.453    Total estimated model params size (MB)
140       Modules in train


`Trainer.fit` stopped: `max_epochs=100` reached.


In [84]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'val_loss': 0.7320585250854492, 'val_acc': 0.7882999777793884}]

That's +0.8% over the original result, though still slightly below the SWA-only model (78.8% vs. 79.3%). We can probably also throw SWA at this...

## 2.5. SWA + Distilled Student

In [85]:
torch.save(model.resnet.state_dict(), "model_student.pth")

In [86]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=3, pin_memory=True, prefetch_factor=2, persistent_workers=True)

Reuse the scheduler-free `LitSWAResNet` from section 1, this time loading the distilled student's weights:

In [89]:
class LitSWAResNet(pl.LightningModule):
    def __init__(self, resnet_cls, lr):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.lr = lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

        self.img_scale = v2.ToDtype(torch.float32, scale=True)
        self.img_norm = v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
        self.img_aug = v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10)

    def training_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(self.img_aug(X)))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(X))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.lr)
        return optimizer

model = LitSWAResNet(resnet18, 1e-5)
model.resnet.load_state_dict(torch.load("model_student.pth", weights_only=True))

<All keys matched successfully>

In [90]:
swa_callback = pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-3, swa_epoch_start=5, annealing_epochs=5)
trainer = pl.Trainer(precision="bf16-mixed", max_epochs=15, check_val_every_n_epoch=1, log_every_n_steps=5, callbacks=[swa_callback])
trainer.fit(model, train_loader, test_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | resnet    | ResNet             | 11.2 M | train
1 | accuracy  | MulticlassAccuracy | 0      | train
2 | img_scale | ToDtype            | 0      | train
3 | img_norm  | Normalize          | 0      | train
4 | img_aug   | AutoAugment        | 0      | train
---------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
72        Modules in train mode
0         Modules in eval mode



`Trainer.fit` stopped: `max_epochs=16` reached.


In [91]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'val_loss': 0.6451417803764343, 'val_acc': 0.795799970626831}]

79.6%!!!  

That's +0.3% over the SWA-only result (79.3%)!  

But can we do better?  

There's one last thing we can try that involves no training at all. (No, we are not ensembling, although you could definitely do that for better performance!)

## 3. Test-Time Augmentation
By augmenting each image at test time and averaging the model's outputs over the augmented copies, we can see whether accuracy improves. It sounds counter-intuitive, but when you have more chances to look at the same image, maybe you'll do better?
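
Stripped of the Lightning boilerplate, the idea fits in a few lines. The sketch below is a hypothetical standalone helper (`predict_with_hflip_tta` is not part of the notebook), and the toy model and random batch are assumptions used only to make it runnable.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def predict_with_hflip_tta(model, images):
    """Average logits over the original batch and its horizontal mirror."""
    model.eval()
    # For (N, C, H, W) tensors the last axis is width, so this is a horizontal flip.
    views = [images, torch.flip(images, dims=[-1])]
    logits = torch.stack([model(view) for view in views], dim=0)
    return logits.mean(dim=0)

torch.manual_seed(0)
toy_model = nn.Sequential(  # stand-in for the ResNet18, CIFAR-sized inputs
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 32 * 32, 10),
)
batch = torch.randn(8, 3, 32, 32)

tta_logits = predict_with_hflip_tta(toy_model, batch)
# Feeding the mirrored batch produces the same pair of views, so the result matches.
mirrored_logits = predict_with_hflip_tta(toy_model, torch.flip(batch, dims=[-1]))
```

A side effect of averaging over both orientations is that the prediction no longer depends on whether the input was mirrored.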

In [92]:
torch.save(model.resnet.state_dict(), "model_student_swa.pth")

In [157]:
class LitSWAResNet(pl.LightningModule):
    def __init__(self, resnet_cls, lr):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.lr = lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

        self.img_scale = v2.ToDtype(torch.float32, scale=True)
        self.img_norm = v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
        self.img_aug = v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10)

        self.test_augs = [
            None,
            v2.RandomHorizontalFlip(p=1.0)
        ]

    def training_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(self.img_aug(X)))
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        X = self.img_norm(self.img_scale(X))
        
        preds = []
        for aug in self.test_augs:
            if aug is not None:
                y_pred = self.resnet(aug(X))
            else:
                y_pred = self.resnet(X)
            preds.append(y_pred)
        y_pred = torch.mean(torch.stack(preds, dim=0), dim=0)
        
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.lr)
        return optimizer

model = LitSWAResNet(resnet18, 1e-5)
model.resnet.load_state_dict(torch.load("model_student_swa.pth", weights_only=True))

<All keys matched successfully>

In [158]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'val_loss': 0.5610480308532715, 'val_acc': 0.8126999735832214}]

Bam! Instant +1.7% improvement. One thing to note is which augmentations you use, though: initially I used several (vertical flip, 90-degree rotation), but they performed poorly because the model was never trained on such images.  

To be fair, the entire model was trained with very weak augmentation. It could definitely perform better if it had been trained with those augmentations.