Use KL Divergence loss on Knowledge Distillation Task. You can use any teacher and student model (prefer small models). You need to show that it works, and update README.md with proper logs


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

try:
    from torchsummary import summary
except ModuleNotFoundError:
    !pip install torchsummary
    from torchsummary import summary

from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision

import os
import time
import math

In [4]:
train_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                      #  transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
                                       transforms.RandomAffine(degrees=10, shear = 10),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # Note the difference between (0.1307) and (0.1307,)
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

train = datasets.CIFAR10(root = './data', train=True, download=True, transform=train_transforms)
test = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

# Do we have CUDA drivers for us?
cuda = torch.cuda.is_available()
print ("Cuda Available?", cuda)

dataloader_args = dict(shuffle=True, batch_size=2048, num_workers=2, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# Dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train, **dataloader_args)
test_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified
Cuda Available? True


In [5]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 64, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=64)
        self.conv13 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=64)
        self.conv14 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=64)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv21 = nn.Conv2d(64, 128, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=128)
        self.conv22 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=128)
        self.conv23 = nn.Conv2d(128,128, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=128)
        self.conv24 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=128)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(128, 128, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [4]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 24, 24]           2,320
            Conv2d-5           [-1, 16, 20, 20]           2,320
            Conv2d-6           [-1, 16, 16, 16]           2,320
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
        MaxPool2d-13           [-1, 16, 16, 16]               0
           Conv2d-14           [-1, 16,

Loss=1.6608870029449463 Batch_id=24 Accuracy=25.05: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 0, Avg Training Loss: 2.0001, Avg Time Taken = 442.69ms






Test set: Average loss: 2.3458, Accuracy: 1373/10000 (13.73%)

Epoch 2/100


Loss=1.409379005432129 Batch_id=24 Accuracy=42.47: 100%|██████████| 25/25 [00:20<00:00,  1.24it/s]

 --> EPOCH: 1, Avg Training Loss: 1.5328, Avg Time Taken = 434.34ms






Test set: Average loss: 1.5188, Accuracy: 4367/10000 (43.67%)

Epoch 3/100


Loss=1.2821158170700073 Batch_id=24 Accuracy=50.34: 100%|██████████| 25/25 [00:19<00:00,  1.31it/s]

 --> EPOCH: 2, Avg Training Loss: 1.3499, Avg Time Taken = 442.18ms






Test set: Average loss: 1.3797, Accuracy: 4987/10000 (49.87%)

Epoch 4/100


Loss=1.1426337957382202 Batch_id=24 Accuracy=56.14: 100%|██████████| 25/25 [00:19<00:00,  1.31it/s]

 --> EPOCH: 3, Avg Training Loss: 1.2080, Avg Time Taken = 456.69ms






Test set: Average loss: 1.1863, Accuracy: 5738/10000 (57.38%)

Epoch 5/100


Loss=1.1013754606246948 Batch_id=24 Accuracy=60.83: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 4, Avg Training Loss: 1.0925, Avg Time Taken = 467.66ms






Test set: Average loss: 1.1440, Accuracy: 6050/10000 (60.50%)

Epoch 6/100


Loss=1.0109670162200928 Batch_id=24 Accuracy=63.88: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

 --> EPOCH: 5, Avg Training Loss: 1.0116, Avg Time Taken = 481.71ms






Test set: Average loss: 1.0084, Accuracy: 6420/10000 (64.20%)

Epoch 7/100


Loss=0.8873741626739502 Batch_id=24 Accuracy=66.47: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 6, Avg Training Loss: 0.9331, Avg Time Taken = 469.92ms






Test set: Average loss: 1.0001, Accuracy: 6456/10000 (64.56%)

Epoch 8/100


Loss=0.8859184384346008 Batch_id=24 Accuracy=68.99: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 7, Avg Training Loss: 0.8799, Avg Time Taken = 465.46ms






Test set: Average loss: 0.9434, Accuracy: 6702/10000 (67.02%)

Epoch 9/100


Loss=0.8045757412910461 Batch_id=24 Accuracy=71.00: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 8, Avg Training Loss: 0.8151, Avg Time Taken = 470.15ms






Test set: Average loss: 0.8705, Accuracy: 7017/10000 (70.17%)

Epoch 10/100


Loss=0.7854903340339661 Batch_id=24 Accuracy=73.25: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 9, Avg Training Loss: 0.7588, Avg Time Taken = 471.01ms






Test set: Average loss: 0.8487, Accuracy: 7085/10000 (70.85%)

Epoch 11/100


Loss=0.6698622703552246 Batch_id=24 Accuracy=74.93: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

 --> EPOCH: 10, Avg Training Loss: 0.7134, Avg Time Taken = 468.97ms






Test set: Average loss: 0.8550, Accuracy: 7061/10000 (70.61%)

Epoch 12/100


Loss=0.6664386987686157 Batch_id=24 Accuracy=76.31: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 11, Avg Training Loss: 0.6699, Avg Time Taken = 470.06ms






Test set: Average loss: 0.8167, Accuracy: 7261/10000 (72.61%)

Epoch 13/100


Loss=0.5815029740333557 Batch_id=24 Accuracy=77.97: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]

 --> EPOCH: 12, Avg Training Loss: 0.6254, Avg Time Taken = 472.03ms






Test set: Average loss: 0.7687, Accuracy: 7369/10000 (73.69%)

Epoch 14/100


Loss=0.6217029690742493 Batch_id=24 Accuracy=78.91: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 13, Avg Training Loss: 0.5996, Avg Time Taken = 469.57ms






Test set: Average loss: 0.7773, Accuracy: 7391/10000 (73.91%)

Epoch 15/100


Loss=0.5859531760215759 Batch_id=24 Accuracy=80.35: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

 --> EPOCH: 14, Avg Training Loss: 0.5619, Avg Time Taken = 466.26ms






Test set: Average loss: 0.7854, Accuracy: 7337/10000 (73.37%)

Epoch 16/100


Loss=0.5812035799026489 Batch_id=24 Accuracy=81.09: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

 --> EPOCH: 15, Avg Training Loss: 0.5406, Avg Time Taken = 475.19ms






Test set: Average loss: 0.7305, Accuracy: 7539/10000 (75.39%)

Epoch 17/100


Loss=0.5477516055107117 Batch_id=24 Accuracy=82.29: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 16, Avg Training Loss: 0.5071, Avg Time Taken = 469.09ms






Test set: Average loss: 0.7564, Accuracy: 7483/10000 (74.83%)

Epoch 18/100


Loss=0.4621361196041107 Batch_id=24 Accuracy=83.12: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]

 --> EPOCH: 17, Avg Training Loss: 0.4839, Avg Time Taken = 472.00ms






Test set: Average loss: 0.7170, Accuracy: 7680/10000 (76.80%)

Epoch 19/100


Loss=0.4255615174770355 Batch_id=24 Accuracy=84.40: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 18, Avg Training Loss: 0.4501, Avg Time Taken = 470.78ms






Test set: Average loss: 0.7371, Accuracy: 7624/10000 (76.24%)

Epoch 20/100


Loss=0.4585602879524231 Batch_id=24 Accuracy=85.07: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 19, Avg Training Loss: 0.4309, Avg Time Taken = 469.88ms






Test set: Average loss: 0.7230, Accuracy: 7672/10000 (76.72%)

---------- prev = 0.45011252999305723 current = 0.43086560487747194 ---------
Epoch 21/100


Loss=0.3736545741558075 Batch_id=24 Accuracy=85.75: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 20, Avg Training Loss: 0.4051, Avg Time Taken = 469.38ms






Test set: Average loss: 0.7149, Accuracy: 7655/10000 (76.55%)

Epoch 22/100


Loss=0.412638783454895 Batch_id=24 Accuracy=86.63: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 21, Avg Training Loss: 0.3821, Avg Time Taken = 471.27ms






Test set: Average loss: 0.7607, Accuracy: 7632/10000 (76.32%)

Epoch 23/100


Loss=0.4209372401237488 Batch_id=24 Accuracy=87.34: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

 --> EPOCH: 22, Avg Training Loss: 0.3651, Avg Time Taken = 471.82ms






Test set: Average loss: 0.7522, Accuracy: 7672/10000 (76.72%)

---------- prev = 0.38208768963813783 current = 0.3650553143024445 ---------
Epoch 24/100


Loss=0.3752506971359253 Batch_id=24 Accuracy=87.81: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]

 --> EPOCH: 23, Avg Training Loss: 0.3507, Avg Time Taken = 473.14ms






Test set: Average loss: 0.6966, Accuracy: 7817/10000 (78.17%)

---------- prev = 0.3650553143024445 current = 0.3507401490211487 ---------
Epoch 25/100


Loss=0.3359489142894745 Batch_id=24 Accuracy=88.23: 100%|██████████| 25/25 [00:18<00:00,  1.32it/s]

 --> EPOCH: 24, Avg Training Loss: 0.3369, Avg Time Taken = 467.33ms






Test set: Average loss: 0.7221, Accuracy: 7759/10000 (77.59%)

---------- prev = 0.3507401490211487 current = 0.3369123363494873 ---------
Epoch 26/100


Loss=0.34969058632850647 Batch_id=24 Accuracy=89.07: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 25, Avg Training Loss: 0.3139, Avg Time Taken = 469.92ms






Test set: Average loss: 0.7374, Accuracy: 7726/10000 (77.26%)

Epoch 27/100


Loss=0.2797302305698395 Batch_id=24 Accuracy=90.01: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 26, Avg Training Loss: 0.2925, Avg Time Taken = 468.78ms






Test set: Average loss: 0.6983, Accuracy: 7830/10000 (78.30%)

Epoch 28/100


Loss=0.2557670474052429 Batch_id=24 Accuracy=90.24: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 27, Avg Training Loss: 0.2786, Avg Time Taken = 469.72ms






Test set: Average loss: 0.7558, Accuracy: 7702/10000 (77.02%)

---------- prev = 0.2925028991699219 current = 0.27855880856513976 ---------
Epoch 29/100


Loss=0.2288149744272232 Batch_id=24 Accuracy=91.24: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

 --> EPOCH: 28, Avg Training Loss: 0.2550, Avg Time Taken = 469.45ms






Test set: Average loss: 0.7281, Accuracy: 7783/10000 (77.83%)

Epoch 30/100


Loss=0.22104182839393616 Batch_id=24 Accuracy=91.36: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

 --> EPOCH: 29, Avg Training Loss: 0.2484, Avg Time Taken = 468.10ms






Test set: Average loss: 0.7941, Accuracy: 7708/10000 (77.08%)

---------- prev = 0.25495045244693754 current = 0.24840199291706086 ---------
Epoch 31/100


Loss=0.2958418130874634 Batch_id=24 Accuracy=91.63: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 30, Avg Training Loss: 0.2414, Avg Time Taken = 471.31ms






Test set: Average loss: 0.7602, Accuracy: 7795/10000 (77.95%)

---------- prev = 0.24840199291706086 current = 0.2414081174135208 ---------
Epoch 32/100


Loss=0.2600346803665161 Batch_id=24 Accuracy=92.19: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 31, Avg Training Loss: 0.2265, Avg Time Taken = 470.95ms






Test set: Average loss: 0.7412, Accuracy: 7839/10000 (78.39%)

---------- prev = 0.2414081174135208 current = 0.2265439236164093 ---------
Epoch 33/100


Loss=0.213126078248024 Batch_id=24 Accuracy=92.50: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 32, Avg Training Loss: 0.2174, Avg Time Taken = 468.34ms






Test set: Average loss: 0.7855, Accuracy: 7799/10000 (77.99%)

---------- prev = 0.2265439236164093 current = 0.2173517221212387 ---------
Epoch 34/100


Loss=0.20625358819961548 Batch_id=24 Accuracy=92.55: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

 --> EPOCH: 33, Avg Training Loss: 0.2151, Avg Time Taken = 468.82ms






Test set: Average loss: 0.7590, Accuracy: 7844/10000 (78.44%)

---------- prev = 0.2173517221212387 current = 0.2151263326406479 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth
Early stopping triggered!


In [None]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = TeacherModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_heavy_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_heavy_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

In [5]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.6850, Accuracy: 7915/10000 (79.15%)



0.6849886352539063

In [7]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 32, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=32)
        self.conv12 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=32)
        self.conv13 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=32)
        self.conv14 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=32)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(32, 32, 1, bias=False)

        self.conv21 = nn.Conv2d(32, 64, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=64)
        self.conv22 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=64)
        self.conv23 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=64)
        self.conv24 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=64)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, groups=64, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [13]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

In [None]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = StudentModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_small_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

In [9]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.7419, Accuracy: 7566/10000 (75.66%)



0.7419342895507812

# Distillation Process

In [15]:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device = {device}')

# Prepare teacher model
teacher = TeacherModel().to(device)
teacher.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth', weights_only=True))
teacher.eval()

# Prepare student model
student = StudentModel().to(device)
student.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth', weights_only=True))

print("=============================================")
print("Student model accuracy before training ")
student.eval()
test(student, device, test_loader)
print("=============================================")

# Loss functions
hard_loss = nn.CrossEntropyLoss() #Hard label loss
soft_loss = nn.KLDivLoss(reduction="batchmean")  # Distillation loss

# Temperature and alpha
T = 5.0  # Temperature
alpha = 0.5  # Weight for distillation loss

# Optimizer
optimizer = optim.Adam(student.parameters(), lr=0.001)

# Training loop
EPOCHS = 100
for epoch in range(EPOCHS):

    correct = 0
    processed = 0
    epoch_loss = 0
    pbar = tqdm(train_loader)
    student.train()
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # Teacher predictions (soft labels)
        with torch.no_grad():
            teacher_logits = teacher(data) / T

        teacher_probs = torch.softmax(teacher_logits, dim=1)

        # Student predictions
        student_logits = student(data)
        student_probs = torch.log_softmax(student_logits / T, dim=1)

        # Compute losses
        loss_soft = soft_loss(student_probs, teacher_probs) * (T ** 2)  # Scale by T^2
        loss_hard = hard_loss(student_logits, target)
        loss = alpha * loss_hard + (1 - alpha) * loss_soft

        epoch_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        pred = student_logits.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_small_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

    student.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = student(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



Using device = cuda
Student model accuracy before training 

Test set: Average loss: 0.7419, Accuracy: 7566/10000 (75.66%)



Loss=1.3155590295791626 Batch_id=24 Accuracy=75.65: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]



Test set: Average loss: 0.8766, Accuracy: 7533/10000 (75.33%)



Loss=1.1270556449890137 Batch_id=24 Accuracy=80.29: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]



Test set: Average loss: 0.7717, Accuracy: 7695/10000 (76.95%)



Loss=1.1114906072616577 Batch_id=24 Accuracy=81.71: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]



Test set: Average loss: 0.7617, Accuracy: 7803/10000 (78.03%)



Loss=1.0582565069198608 Batch_id=24 Accuracy=81.95: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]



Test set: Average loss: 0.7608, Accuracy: 7832/10000 (78.32%)



Loss=1.0600461959838867 Batch_id=24 Accuracy=82.41: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]



Test set: Average loss: 0.7835, Accuracy: 7711/10000 (77.11%)



Loss=1.0280518531799316 Batch_id=24 Accuracy=82.47: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]



Test set: Average loss: 0.7539, Accuracy: 7823/10000 (78.23%)



Loss=1.0234251022338867 Batch_id=24 Accuracy=82.91: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]



Test set: Average loss: 0.7279, Accuracy: 7889/10000 (78.89%)



Loss=0.9528346061706543 Batch_id=24 Accuracy=83.53: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]



Test set: Average loss: 0.7846, Accuracy: 7772/10000 (77.72%)



Loss=1.001296043395996 Batch_id=24 Accuracy=83.85: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]



Test set: Average loss: 0.7395, Accuracy: 7865/10000 (78.65%)



Loss=0.9289317727088928 Batch_id=24 Accuracy=84.13: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]



Test set: Average loss: 0.7789, Accuracy: 7763/10000 (77.63%)



Loss=0.9482051134109497 Batch_id=24 Accuracy=84.53: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]



Test set: Average loss: 0.7370, Accuracy: 7930/10000 (79.30%)



Loss=0.8285729289054871 Batch_id=24 Accuracy=84.97: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]



Test set: Average loss: 0.7871, Accuracy: 7793/10000 (77.93%)



Loss=0.8407131433486938 Batch_id=24 Accuracy=85.06: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]



Test set: Average loss: 0.7056, Accuracy: 7932/10000 (79.32%)



Loss=0.8000921607017517 Batch_id=24 Accuracy=85.44: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]



Test set: Average loss: 0.7124, Accuracy: 7936/10000 (79.36%)



Loss=0.818144679069519 Batch_id=24 Accuracy=85.60: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]



Test set: Average loss: 0.7284, Accuracy: 7913/10000 (79.13%)



Loss=0.8630034327507019 Batch_id=24 Accuracy=86.15: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]



Test set: Average loss: 0.7138, Accuracy: 7928/10000 (79.28%)



Loss=0.8151281476020813 Batch_id=5 Accuracy=85.91:  24%|██▍       | 6/25 [00:05<00:16,  1.12it/s]


KeyboardInterrupt: 