# Milking the final drops of Accuracy

So... you've trained a model.  

Now what? You could squeeze a few more epochs, but that'll only lead to small, negligible gains.  

Or you could really scrape the bottom of the barrel and try the techniques in this notebook...? We'll try to go through some methods (e.g. SWA, KD) to get better results from our model. We use the model from the Speedrunning Model Training notebook, which itself is undertrained. Thus, this notebook will come in very handy.

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet18
import pytorch_lightning as pl
import torchmetrics
from torchvision.transforms import v2

In [2]:
train_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize((32, 32)),
    v2.AutoAugment(v2.AutoAugmentPolicy.CIFAR10),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
])
test_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize((32, 32)),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.48, 0.44, 0.40], [0.22, 0.22, 0.22])
])

train_ds = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transforms)
test_ds = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transforms)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=3, prefetch_factor=1, persistent_workers=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=3, prefetch_factor=1, persistent_workers=True)

In [3]:
torch.set_float32_matmul_precision('high')

Our model is just a ResNet18 trained on CIFAR-10.

In [4]:
model = resnet18()
model.fc = nn.Linear(512, 10)
model.load_state_dict(torch.load("model.pth", weights_only=True))

<All keys matched successfully>

## 1. Stochastic Weight Averaging (SWA)
SWA is a method where you average the weights of a model over several training iterations. This way, you get a more generalized model, which should hopefully perform better.  

Both PyTorch and PyTorch Lightning have in-built methods for performing SWA. I'm a big fan of being lazy, so let's use Lightning. Below is a simple template class that is responsible for handling the training steps and calculate metrics.

In [12]:
class LitSWAResNet(pl.LightningModule):
    def __init__(self, resnet_cls, lr):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.lr = lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.lr)
        return optimizer

model = LitSWAResNet(resnet18, 2e-5) # Our min_lr from the previous notebook - go check it out!
model.resnet.load_state_dict(torch.load("model.pth", weights_only=True))

<All keys matched successfully>

In PyTorch Lightning, the `StochasticWeightAveraging` callback manages SWA. It accepts a `swa_lrs`, or the SWA Learning Rate. It's typically moderately high, such as `1e-3` in this case, as it has to explore the areas around the current local minima.  

We average the weights over 5 epochs.

In [13]:
swa_callback = pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-3, swa_epoch_start=0.0)
trainer = pl.Trainer(precision="bf16-mixed", max_epochs=5, check_val_every_n_epoch=1, log_every_n_steps=5, callbacks=[swa_callback])
trainer.fit(model, train_loader, test_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | resnet   | ResNet             | 11.2 M | train
1 | accuracy | MulticlassAccuracy | 0      | train
--------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=6` reached.


In [14]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.608993411064148, 'val_acc': 0.8374999761581421}]

Alright... a negligible accuracy increase. To be fair, it's all within the margin of error. Besides, SWA is known to give only incremental improvements!

## 2: Self-Distillation
If you've heard of *knowledge distillation* this should be easier to grasp. If you haven't, knowledge distillation (KD) is where a smaller model (called a *student*) is trained to mimic the outputs of a larger, trained model (called a *teacher*). This is because the larger model's outputs will not be discrete and fixed like the target objective, but rather *soft targets* that shift with the input data, and provide subtle biases in the non-target distributions that may guide the student model in training.  

For example, when classifying a truck, the target objective may have a `1` for the truck class, and `0` everywhere else. But a teacher model may put some weight on classes like bus, car, lorry - which helps the model learn more quickly the associations between classes.  

Self-distillation, then is a *special type of knowledge distillation*. Instead of distilling the knowledge into a smaller model, you distill into a model of the **same size**. Papers suggest that this is because providing a teacher model leads to smoother gradients towards better minima - check out https://arxiv.org/pdf/2206.08491 for a rather enjoyable read on this topic!

In [15]:
teacher_model = model.resnet
for param in teacher_model.parameters():
    param.requires_grad = False

We take the class we defined for training the teacher model (from the previous notebook) and modify it slightly!

In [17]:
class LSelfDistillStudent(pl.LightningModule):
    def __init__(self, resnet_cls, teacher_model, temperature, max_lr, min_lr, epochs, wd, kd_balance=0.5):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.epochs = epochs
        self.wd = wd

        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)
        
        self.apply(self._init_weights)
        nn.init.xavier_uniform_(self.resnet.fc.weight)
        nn.init.zeros_(self.resnet.fc.bias)

        #######################################
        # NEW! Add self-distillation parameters
        self.teacher = teacher_model
        self.temperature = temperature
        self.kd_balance = kd_balance
        #######################################
        
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu", mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)

        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def training_step(self, batch, batch_idx):
        X, y = batch
        
        #########################
        # NEW: Distillation logic
        with torch.inference_mode():
            teacher_logits = self.teacher(X)
        
        student_logits = self.resnet(X)

        # We soften everything by dividing by temperature so that the targets arent so sharp, and easier to learn
        # Also log_softmax for KL divergence for numerical stability
        soft_teacher_targets = nn.functional.log_softmax(teacher_logits / self.temperature, dim=-1)
        soft_student_probs = nn.functional.log_softmax(student_logits / self.temperature, dim=-1)
        
        labels_loss = nn.functional.cross_entropy(student_logits, y)
        self.log("train_labels_loss", labels_loss)
        teacher_loss = nn.functional.kl_div(soft_student_probs, soft_teacher_targets, log_target=True) * (self.temperature**2) # scaled according to original distillation paper
        self.log("train_teacher_loss", teacher_loss)
        loss = self.kd_balance*labels_loss + (1-self.kd_balance)*teacher_loss
        #########################
        
        self.log("train_loss", loss)
        acc = self.accuracy(student_logits, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.min_lr, weight_decay=self.wd)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.max_lr, epochs=self.epochs, steps_per_epoch=len(train_loader), div_factor=self.max_lr/self.min_lr)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [18]:
%%time
model = torch.compile(LSelfDistillStudent(resnet18, teacher_model, 2.0, 5e-3, 2e-5, 100, 1e-2))
trainer = pl.Trainer(precision="bf16-mixed", max_epochs=100, check_val_every_n_epoch=5, log_every_n_steps=5) 
trainer.fit(model, train_loader, test_loader)

Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params | Mode 
--------------------------------------------------------
0 | resnet   | ResNet             | 11.2 M | train
1 | accuracy | MulticlassAccuracy | 0      | train
2 | teacher  | ResNet             | 11.2 M | train
--------------------------------------------------------
11.2 M    Trainable params
11.2 M    Non-trainable params
22.4 M    Total params
89.453    Total estimated model params size (MB)
137       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

W0620 07:18:11.961000 17128 torch/_inductor/utils.py:1250] [5/0_1] Not enough SMs to use max_autotune_gemm mode
W0620 07:18:19.905000 17128 torch/_dynamo/convert_frame.py:964] [1/8] torch._dynamo hit config.recompile_limit (8)
W0620 07:18:19.905000 17128 torch/_dynamo/convert_frame.py:964] [1/8]    function: 'log' (/venv/main/lib/python3.12/site-packages/pytorch_lightning/core/module.py:376)
W0620 07:18:19.905000 17128 torch/_dynamo/convert_frame.py:964] [1/8]    last reason: 1/7: name == 'train_acc'                                    
W0620 07:18:19.905000 17128 torch/_dynamo/convert_frame.py:964] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0620 07:18:19.905000 17128 torch/_dynamo/convert_frame.py:964] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


CPU times: user 11min 24s, sys: 39.2 s, total: 12min 3s
Wall time: 11min 45s


In [19]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.5827531218528748, 'val_acc': 0.8440999984741211}]

And bam! +0.7% from the original result!  

We're almost at the end - there's one last thing we can try that involves no training at all. (No we are not doing ensembling although you could definitely do that for better performance!)

## 3. Test-Time Augmentation
By augmenting the images during test time and averaging the outputs, we can see if there is an increase in accuracy. Sounds counter-intuitive, but when the model has more chances to look at the same image from different angles, the model can make a better guess.

In [60]:
class LitSWAResNet(pl.LightningModule):
    def __init__(self, resnet_cls, lr):
        super().__init__()
        self.resnet = resnet_cls()
        self.resnet.fc = nn.Linear(512, 10)
        self.lr = lr
        
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=10)

        ###################################
        # NEW! Add test-time augmentations.
        self.test_augs = [
            None,
            v2.RandomHorizontalFlip(p=1.0)
        ]
        ###################################
        

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self.resnet(X)
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("train_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        #####################################
        # NEW! Test-time augmentation itself.
        preds = []
        for aug in self.test_augs:
            if aug is not None:
                y_pred = self.resnet(aug(X))
            else:
                y_pred = self.resnet(X)
            preds.append(y_pred)
        y_pred = torch.mean(torch.stack(preds, dim=0), dim=0)
        #####################################
        loss = nn.functional.cross_entropy(y_pred, y)
        self.log("val_loss", loss)
        acc = self.accuracy(y_pred, y)
        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.resnet.parameters(), lr=self.lr)
        return optimizer

model = LitSWAResNet(resnet18, 2e-5)
model.resnet.load_state_dict(torch.load("model_student.pth", weights_only=True))

<All keys matched successfully>

In [61]:
trainer.validate(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.45302221179008484, 'val_acc': 0.8666999936103821}]

Bam! Instant +2.2% improvement. One thing to note is what augmentations you use though - initially I used multiple (vertical flip, rotate 90 degrees) but I realized they were performing poorly because the model was not trained with those augmentations.  

To be fair the entire model was trained with very weak augmentation. It could defintely perform better with those.