Use KL Divergence loss on Knowledge Distillation Task. You can use any teacher and student model (prefer small models). You need to show that it works, and update README.md with proper logs


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

try:
    from torchsummary import summary
except ModuleNotFoundError:
    !pip install torchsummary
    from torchsummary import summary

from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision

import os
import time
import math

In [2]:
train_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                      #  transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
                                       transforms.RandomAffine(degrees=10, shear = 10),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # Note the difference between (0.1307) and (0.1307,)
                                       ])

# Test Phase transformations
test_transforms = transforms.Compose([
                                      #  transforms.Resize((28, 28)),
                                      #  transforms.ColorJitter(brightness=0.10, contrast=0.1, saturation=0.10, hue=0.1),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       ])

train = datasets.CIFAR10(root = './data', train=True, download=True, transform=train_transforms)
test = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)

# Do we have CUDA drivers for us?
cuda = torch.cuda.is_available()
print ("Cuda Available?", cuda)

dataloader_args = dict(shuffle=True, batch_size=2048, num_workers=2, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)

# Dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train, **dataloader_args)
test_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 29.7MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Cuda Available? True


In [5]:
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 64, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=64)
        self.conv12 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=64)
        self.conv13 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=64)
        self.conv14 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=64)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv21 = nn.Conv2d(64, 128, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=128)
        self.conv22 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=128)
        self.conv23 = nn.Conv2d(128,128, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=128)
        self.conv24 = nn.Conv2d(128, 128, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=128)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(128, 128, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=128, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [6]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

In [7]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = TeacherModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_heavy_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_heavy_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

Device Using = cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 24, 24]           2,320
            Conv2d-5           [-1, 16, 20, 20]           2,320
            Conv2d-6           [-1, 16, 16, 16]           2,320
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
        MaxPool2d-13           [-1, 16, 16, 16]               0
           Conv2d-1

Loss=1.7569657564163208 Batch_id=24 Accuracy=24.23: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 0, Avg Training Loss: 2.0317, Avg Time Taken = 428.80ms






Test set: Average loss: 2.3616, Accuracy: 1000/10000 (10.00%)

Epoch 2/100


Loss=1.489882469177246 Batch_id=24 Accuracy=42.15: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 1, Avg Training Loss: 1.5687, Avg Time Taken = 417.93ms






Test set: Average loss: 1.6134, Accuracy: 4367/10000 (43.67%)

Epoch 3/100


Loss=1.2746210098266602 Batch_id=24 Accuracy=50.72: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 2, Avg Training Loss: 1.3426, Avg Time Taken = 421.68ms






Test set: Average loss: 1.2827, Accuracy: 5303/10000 (53.03%)

Epoch 4/100


Loss=1.1759039163589478 Batch_id=24 Accuracy=56.56: 100%|██████████| 25/25 [00:18<00:00,  1.39it/s]

 --> EPOCH: 3, Avg Training Loss: 1.2008, Avg Time Taken = 427.61ms






Test set: Average loss: 1.1646, Accuracy: 5778/10000 (57.78%)

Epoch 5/100


Loss=1.0140554904937744 Batch_id=24 Accuracy=61.13: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 4, Avg Training Loss: 1.0761, Avg Time Taken = 430.93ms






Test set: Average loss: 1.1328, Accuracy: 6031/10000 (60.31%)

Epoch 6/100


Loss=0.9696936011314392 Batch_id=24 Accuracy=64.40: 100%|██████████| 25/25 [00:19<00:00,  1.31it/s]

 --> EPOCH: 5, Avg Training Loss: 0.9969, Avg Time Taken = 439.18ms






Test set: Average loss: 1.0748, Accuracy: 6167/10000 (61.67%)

Epoch 7/100


Loss=0.9940269589424133 Batch_id=24 Accuracy=67.41: 100%|██████████| 25/25 [00:17<00:00,  1.41it/s]

 --> EPOCH: 6, Avg Training Loss: 0.9157, Avg Time Taken = 436.29ms






Test set: Average loss: 0.9721, Accuracy: 6614/10000 (66.14%)

Epoch 8/100


Loss=0.8761070370674133 Batch_id=24 Accuracy=70.04: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 7, Avg Training Loss: 0.8490, Avg Time Taken = 447.69ms






Test set: Average loss: 0.9209, Accuracy: 6768/10000 (67.68%)

Epoch 9/100


Loss=0.7703757286071777 Batch_id=24 Accuracy=72.27: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]

 --> EPOCH: 8, Avg Training Loss: 0.7875, Avg Time Taken = 446.57ms






Test set: Average loss: 0.8902, Accuracy: 6950/10000 (69.50%)

Epoch 10/100


Loss=0.7251912355422974 Batch_id=24 Accuracy=74.38: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 9, Avg Training Loss: 0.7319, Avg Time Taken = 454.60ms






Test set: Average loss: 0.8283, Accuracy: 7105/10000 (71.05%)

Epoch 11/100


Loss=0.6894527673721313 Batch_id=24 Accuracy=75.46: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

 --> EPOCH: 10, Avg Training Loss: 0.6942, Avg Time Taken = 457.98ms






Test set: Average loss: 0.8576, Accuracy: 7105/10000 (71.05%)

Epoch 12/100


Loss=0.6242570877075195 Batch_id=24 Accuracy=77.32: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 11, Avg Training Loss: 0.6459, Avg Time Taken = 457.65ms






Test set: Average loss: 0.8157, Accuracy: 7270/10000 (72.70%)

Epoch 13/100


Loss=0.6227222084999084 Batch_id=24 Accuracy=78.55: 100%|██████████| 25/25 [00:19<00:00,  1.28it/s]

 --> EPOCH: 12, Avg Training Loss: 0.6134, Avg Time Taken = 465.67ms






Test set: Average loss: 0.7652, Accuracy: 7420/10000 (74.20%)

Epoch 14/100


Loss=0.5683180093765259 Batch_id=24 Accuracy=79.65: 100%|██████████| 25/25 [00:18<00:00,  1.39it/s]

 --> EPOCH: 13, Avg Training Loss: 0.5806, Avg Time Taken = 469.83ms






Test set: Average loss: 0.7777, Accuracy: 7338/10000 (73.38%)

Epoch 15/100


Loss=0.5646330714225769 Batch_id=24 Accuracy=80.92: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 14, Avg Training Loss: 0.5502, Avg Time Taken = 466.73ms






Test set: Average loss: 0.7263, Accuracy: 7560/10000 (75.60%)

Epoch 16/100


Loss=0.532381534576416 Batch_id=24 Accuracy=82.18: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 15, Avg Training Loss: 0.5161, Avg Time Taken = 463.41ms






Test set: Average loss: 0.7111, Accuracy: 7583/10000 (75.83%)

Epoch 17/100


Loss=0.5037235021591187 Batch_id=24 Accuracy=82.92: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 16, Avg Training Loss: 0.4924, Avg Time Taken = 467.92ms






Test set: Average loss: 0.6844, Accuracy: 7706/10000 (77.06%)

Epoch 18/100


Loss=0.5057856440544128 Batch_id=24 Accuracy=83.41: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 17, Avg Training Loss: 0.4694, Avg Time Taken = 467.64ms






Test set: Average loss: 0.7211, Accuracy: 7617/10000 (76.17%)

Epoch 19/100


Loss=0.47667601704597473 Batch_id=24 Accuracy=84.72: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 18, Avg Training Loss: 0.4428, Avg Time Taken = 465.76ms






Test set: Average loss: 0.6841, Accuracy: 7704/10000 (77.04%)

Epoch 20/100


Loss=0.43096408247947693 Batch_id=24 Accuracy=85.32: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]

 --> EPOCH: 19, Avg Training Loss: 0.4209, Avg Time Taken = 465.79ms






Test set: Average loss: 0.6644, Accuracy: 7787/10000 (77.87%)

Epoch 21/100


Loss=0.41643834114074707 Batch_id=24 Accuracy=85.98: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

 --> EPOCH: 20, Avg Training Loss: 0.4022, Avg Time Taken = 467.56ms






Test set: Average loss: 0.7021, Accuracy: 7748/10000 (77.48%)

---------- prev = 0.4209310495853424 current = 0.4022049951553345 ---------
Epoch 22/100


Loss=0.43559035658836365 Batch_id=24 Accuracy=87.19: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

 --> EPOCH: 21, Avg Training Loss: 0.3750, Avg Time Taken = 466.96ms






Test set: Average loss: 0.7172, Accuracy: 7676/10000 (76.76%)

Epoch 23/100


Loss=0.380073606967926 Batch_id=24 Accuracy=87.66: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 22, Avg Training Loss: 0.3560, Avg Time Taken = 466.01ms






Test set: Average loss: 0.7064, Accuracy: 7730/10000 (77.30%)

---------- prev = 0.3750325775146484 current = 0.3559500801563263 ---------
Epoch 24/100


Loss=0.3013826310634613 Batch_id=24 Accuracy=88.52: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]

 --> EPOCH: 23, Avg Training Loss: 0.3357, Avg Time Taken = 468.97ms






Test set: Average loss: 0.7457, Accuracy: 7755/10000 (77.55%)

Epoch 25/100


Loss=0.3540545701980591 Batch_id=24 Accuracy=89.29: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

 --> EPOCH: 24, Avg Training Loss: 0.3128, Avg Time Taken = 466.98ms






Test set: Average loss: 0.7052, Accuracy: 7764/10000 (77.64%)

Epoch 26/100


Loss=0.34189483523368835 Batch_id=24 Accuracy=89.44: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 25, Avg Training Loss: 0.3084, Avg Time Taken = 468.49ms






Test set: Average loss: 0.7343, Accuracy: 7752/10000 (77.52%)

---------- prev = 0.3127818727493286 current = 0.3083853161334991 ---------
Epoch 27/100


Loss=0.3049859404563904 Batch_id=24 Accuracy=90.14: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

 --> EPOCH: 26, Avg Training Loss: 0.2893, Avg Time Taken = 466.61ms






Test set: Average loss: 0.6879, Accuracy: 7872/10000 (78.72%)

---------- prev = 0.3083853161334991 current = 0.28933569669723513 ---------
Epoch 28/100


Loss=0.26368948817253113 Batch_id=24 Accuracy=90.70: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 27, Avg Training Loss: 0.2702, Avg Time Taken = 464.59ms






Test set: Average loss: 0.7678, Accuracy: 7724/10000 (77.24%)

---------- prev = 0.28933569669723513 current = 0.27017143309116365 ---------
Epoch 29/100


Loss=0.3211618959903717 Batch_id=24 Accuracy=91.05: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]

 --> EPOCH: 28, Avg Training Loss: 0.2634, Avg Time Taken = 467.57ms






Test set: Average loss: 0.7115, Accuracy: 7846/10000 (78.46%)

---------- prev = 0.27017143309116365 current = 0.26343166947364804 ---------
Epoch 30/100


Loss=0.3437854051589966 Batch_id=24 Accuracy=91.23: 100%|██████████| 25/25 [00:18<00:00,  1.32it/s]

 --> EPOCH: 29, Avg Training Loss: 0.2569, Avg Time Taken = 465.19ms






Test set: Average loss: 0.7036, Accuracy: 7891/10000 (78.91%)

---------- prev = 0.26343166947364804 current = 0.2568626445531845 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_heavy_acc_91.pth
Early stopping triggered!


In [8]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_91.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.7036, Accuracy: 7891/10000 (78.91%)



0.7035769897460937

In [9]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.conv01 = nn.Conv2d(3, 16, 3, bias=False, padding=1)
        self.batch01 = nn.BatchNorm2d(num_features=16)

        # ---- Lets take a skip connection
        self.skip_conv1 = nn.Conv2d(16, 16, 3, padding=0, dilation=2)

        self.conv02 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch02 = nn.BatchNorm2d(num_features=16)
        self.conv03 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch03 = nn.BatchNorm2d(num_features=16)
        self.conv04 = nn.Conv2d(16, 16, 3, bias=False,padding=1)
        self.batch04 = nn.BatchNorm2d(num_features=16)
        self.pool01 = nn.MaxPool2d(2, 2)                                #O=16
        self.conv05 = nn.Conv2d(16, 16, 1, bias=False)

        self.conv11 = nn.Conv2d(16, 32, 3, bias=False, padding=1)
        self.batch11 = nn.BatchNorm2d(num_features=32)
        self.conv12 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch12 = nn.BatchNorm2d(num_features=32)
        self.conv13 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch13 = nn.BatchNorm2d(num_features=32)
        self.conv14 = nn.Conv2d(32, 32, 3, bias=False, padding=1)
        self.batch14 = nn.BatchNorm2d(num_features=32)
        self.pool11 = nn.MaxPool2d(2, 2)                                #O=8
        self.conv15 = nn.Conv2d(32, 32, 1, bias=False)

        self.conv21 = nn.Conv2d(32, 64, 3, bias=False, padding=1)
        self.batch21 = nn.BatchNorm2d(num_features=64)
        self.conv22 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch22 = nn.BatchNorm2d(num_features=64)
        self.conv23 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch23 = nn.BatchNorm2d(num_features=64)
        self.conv24 = nn.Conv2d(64, 64, 3, bias=False, padding=1)
        self.batch24 = nn.BatchNorm2d(num_features=64)
        self.pool21 = nn.MaxPool2d(2, 2)                                #O=4
        self.conv25 = nn.Conv2d(64, 64, 1, bias=False)

        self.conv31 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, groups=64, bias = False, padding = 1)
        self.convPV1= nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, bias = False, padding = 0)
        self.batch31 = nn.BatchNorm2d(num_features=128)
        self.conv32 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, groups=128, bias = False, padding = 1)
        self.convPV2= nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, bias = False, padding = 0)
        self.batch32 = nn.BatchNorm2d(num_features=256)


        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.convx3 = nn.Conv2d(256, 10, 1, bias=False, padding=0)

    def forward(self, x):
        x = self.batch01(F.relu(self.conv01(x)))

        # ---- Lets take a skip connection
        skip_channels = self.skip_conv1(self.skip_conv1(self.skip_conv1(self.skip_conv1(x))))

        x = self.batch02(F.relu(self.conv02(x)))
        x = self.batch03(F.relu(self.conv03(x)))
        x = self.batch04(F.relu(self.conv04(x)))
        x = self.pool01(x)
        x = self.conv05(x)
        # ----------------------------------------------------------

        # ---- Lets add the skip connection here
        x = skip_channels + x

        x = self.batch11(F.relu(self.conv11(x)))
        x = self.batch12(F.relu(self.conv12(x)))
        x = self.batch13(F.relu(self.conv13(x)))
        x = self.batch14(F.relu(self.conv14(x)))
        x = self.pool11(x)
        x = self.conv15(x)
        # ----------------------------------------------------------

        x = self.batch21(F.relu(self.conv21(x)))
        x = self.batch22(F.relu(self.conv22(x)))
        x = self.batch23(F.relu(self.conv23(x)))
        x = self.batch24(F.relu(self.conv24(x)))
        x = self.pool21(x)
        x = self.conv25(x)
        # ----------------------------------------------------------

        x = self.batch31(F.relu(self.convPV1(F.relu(self.conv31(x)))))
        x = self.batch32(F.relu(self.convPV2(F.relu(self.conv32(x)))))


        x = self.avg_pool(x)
        x = self.convx3(x)
        x = x.view(-1, 10)                           # Don't want 10x1x1..
        return F.log_softmax(x, dim=1)  # Added dim=1 parameter)

In [10]:
from tqdm import tqdm

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met



def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)

    correct = 0
    processed = 0
    epoch_loss = 0
    time_taken.clear()

    for batch_idx, (data, target) in enumerate(pbar):
        t0 = time.time()

        data, target = data.to(device), target.to(device)

        # Don't want history of gradients
        optimizer.zero_grad()

        y_predict = model(data)

        # Calculate loss
        loss = F.nll_loss(y_predict, target)
        epoch_loss += loss.item()

        # Backpropagate error
        loss.backward()

        # Take an optimizer step
        optimizer.step()

        torch.cuda.synchronize()
        t1 = time.time()

        time_taken.append((t1 - t0))

        pred = y_predict.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
        train_acc.append(100 * correct / processed)

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    return avg_train_loss


def test(model, device, test_loader):
    model.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    test_acc.append(100. * correct / len(test_loader.dataset))
    return test_loss

In [11]:
# Initialize model, optimizer, and early stopping
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (f'Device Using = {device}')
model = StudentModel().to(device)
summary(model, input_size=(3, 32, 32))
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
early_stopping = EarlyStopping(tolerance=5, min_delta=0.02)

EPOCHS = 100
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    avg_train_loss = train(model, device, train_loader, optimizer, epoch)
    print(f" --> EPOCH: {epoch}, Avg Training Loss: {avg_train_loss:.4f}, Avg Time Taken = {(sum(time_taken) / len(time_taken)) * 1000:.2f}ms")
    val_loss = test(model, device, test_loader)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_small_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

Device Using = cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 24, 24]           2,320
            Conv2d-5           [-1, 16, 20, 20]           2,320
            Conv2d-6           [-1, 16, 16, 16]           2,320
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
        MaxPool2d-13           [-1, 16, 16, 16]               0
           Conv2d-1

Loss=1.7643420696258545 Batch_id=24 Accuracy=22.57: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 0, Avg Training Loss: 2.0624, Avg Time Taken = 357.53ms






Test set: Average loss: 2.3772, Accuracy: 1000/10000 (10.00%)

Epoch 2/100


Loss=1.5391430854797363 Batch_id=24 Accuracy=37.13: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 1, Avg Training Loss: 1.6691, Avg Time Taken = 359.02ms






Test set: Average loss: 1.6755, Accuracy: 3922/10000 (39.22%)

Epoch 3/100


Loss=1.4422804117202759 Batch_id=24 Accuracy=45.17: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 2, Avg Training Loss: 1.4835, Avg Time Taken = 355.18ms






Test set: Average loss: 1.4565, Accuracy: 4640/10000 (46.40%)

Epoch 4/100


Loss=1.2985095977783203 Batch_id=24 Accuracy=50.64: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 3, Avg Training Loss: 1.3461, Avg Time Taken = 361.12ms






Test set: Average loss: 1.2908, Accuracy: 5257/10000 (52.57%)

Epoch 5/100


Loss=1.238024115562439 Batch_id=24 Accuracy=54.64: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

 --> EPOCH: 4, Avg Training Loss: 1.2474, Avg Time Taken = 362.41ms






Test set: Average loss: 1.2401, Accuracy: 5520/10000 (55.20%)

Epoch 6/100


Loss=1.0856209993362427 Batch_id=24 Accuracy=58.01: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

 --> EPOCH: 5, Avg Training Loss: 1.1535, Avg Time Taken = 360.57ms






Test set: Average loss: 1.1840, Accuracy: 5674/10000 (56.74%)

Epoch 7/100


Loss=1.034713625907898 Batch_id=24 Accuracy=61.23: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

 --> EPOCH: 6, Avg Training Loss: 1.0800, Avg Time Taken = 357.37ms






Test set: Average loss: 1.1479, Accuracy: 5969/10000 (59.69%)

Epoch 8/100


Loss=0.9146243929862976 Batch_id=24 Accuracy=63.05: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

 --> EPOCH: 7, Avg Training Loss: 1.0241, Avg Time Taken = 362.02ms






Test set: Average loss: 1.0619, Accuracy: 6239/10000 (62.39%)

Epoch 9/100


Loss=0.9481847286224365 Batch_id=24 Accuracy=65.20: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 8, Avg Training Loss: 0.9718, Avg Time Taken = 362.40ms






Test set: Average loss: 1.0279, Accuracy: 6376/10000 (63.76%)

Epoch 10/100


Loss=0.9405885338783264 Batch_id=24 Accuracy=66.67: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 9, Avg Training Loss: 0.9324, Avg Time Taken = 360.19ms






Test set: Average loss: 0.9473, Accuracy: 6653/10000 (66.53%)

Epoch 11/100


Loss=0.9161137342453003 Batch_id=24 Accuracy=68.36: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 10, Avg Training Loss: 0.8885, Avg Time Taken = 361.11ms






Test set: Average loss: 0.9932, Accuracy: 6495/10000 (64.95%)

Epoch 12/100


Loss=0.8603850603103638 Batch_id=24 Accuracy=69.25: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

 --> EPOCH: 11, Avg Training Loss: 0.8613, Avg Time Taken = 361.51ms






Test set: Average loss: 0.9441, Accuracy: 6640/10000 (66.40%)

Epoch 13/100


Loss=0.8587924838066101 Batch_id=24 Accuracy=70.73: 100%|██████████| 25/25 [00:17<00:00,  1.39it/s]

 --> EPOCH: 12, Avg Training Loss: 0.8264, Avg Time Taken = 362.89ms






Test set: Average loss: 0.9430, Accuracy: 6648/10000 (66.48%)

Epoch 14/100


Loss=0.874880313873291 Batch_id=24 Accuracy=71.69: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 13, Avg Training Loss: 0.7988, Avg Time Taken = 361.62ms






Test set: Average loss: 0.8697, Accuracy: 6930/10000 (69.30%)

Epoch 15/100


Loss=0.8040109276771545 Batch_id=24 Accuracy=72.89: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 14, Avg Training Loss: 0.7657, Avg Time Taken = 359.15ms






Test set: Average loss: 0.8769, Accuracy: 6948/10000 (69.48%)

Epoch 16/100


Loss=0.7048029899597168 Batch_id=24 Accuracy=73.86: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 15, Avg Training Loss: 0.7370, Avg Time Taken = 362.69ms






Test set: Average loss: 0.8686, Accuracy: 6936/10000 (69.36%)

Epoch 17/100


Loss=0.6863178014755249 Batch_id=24 Accuracy=74.64: 100%|██████████| 25/25 [00:17<00:00,  1.45it/s]

 --> EPOCH: 16, Avg Training Loss: 0.7148, Avg Time Taken = 360.20ms






Test set: Average loss: 0.8452, Accuracy: 7060/10000 (70.60%)

Epoch 18/100


Loss=0.674196720123291 Batch_id=24 Accuracy=75.55: 100%|██████████| 25/25 [00:17<00:00,  1.43it/s]

 --> EPOCH: 17, Avg Training Loss: 0.6898, Avg Time Taken = 358.37ms






Test set: Average loss: 0.8681, Accuracy: 7003/10000 (70.03%)

Epoch 19/100


Loss=0.6903563737869263 Batch_id=24 Accuracy=76.16: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 18, Avg Training Loss: 0.6730, Avg Time Taken = 361.41ms






Test set: Average loss: 0.8020, Accuracy: 7226/10000 (72.26%)

---------- prev = 0.6898008227348328 current = 0.6729649662971496 ---------
Epoch 20/100


Loss=0.6400915384292603 Batch_id=24 Accuracy=77.00: 100%|██████████| 25/25 [00:17<00:00,  1.46it/s]

 --> EPOCH: 19, Avg Training Loss: 0.6529, Avg Time Taken = 359.87ms






Test set: Average loss: 0.7977, Accuracy: 7281/10000 (72.81%)

Epoch 21/100


Loss=0.6483575105667114 Batch_id=24 Accuracy=77.58: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 20, Avg Training Loss: 0.6334, Avg Time Taken = 359.12ms






Test set: Average loss: 0.7772, Accuracy: 7288/10000 (72.88%)

---------- prev = 0.6528990292549133 current = 0.6333779048919678 ---------
Epoch 22/100


Loss=0.6078910231590271 Batch_id=24 Accuracy=78.23: 100%|██████████| 25/25 [00:17<00:00,  1.40it/s]

 --> EPOCH: 21, Avg Training Loss: 0.6116, Avg Time Taken = 360.48ms






Test set: Average loss: 0.7825, Accuracy: 7298/10000 (72.98%)

Epoch 23/100


Loss=0.5993676781654358 Batch_id=24 Accuracy=79.22: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 22, Avg Training Loss: 0.5943, Avg Time Taken = 361.84ms






Test set: Average loss: 0.7507, Accuracy: 7443/10000 (74.43%)

---------- prev = 0.6115591692924499 current = 0.59430020570755 ---------
Epoch 24/100


Loss=0.5777539014816284 Batch_id=24 Accuracy=79.86: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

 --> EPOCH: 23, Avg Training Loss: 0.5747, Avg Time Taken = 364.55ms






Test set: Average loss: 0.7822, Accuracy: 7324/10000 (73.24%)

---------- prev = 0.59430020570755 current = 0.5747307419776917 ---------
Epoch 25/100


Loss=0.5954622030258179 Batch_id=24 Accuracy=80.24: 100%|██████████| 25/25 [00:17<00:00,  1.42it/s]

 --> EPOCH: 24, Avg Training Loss: 0.5650, Avg Time Taken = 359.70ms






Test set: Average loss: 0.7520, Accuracy: 7468/10000 (74.68%)

---------- prev = 0.5747307419776917 current = 0.5649992895126342 ---------
Epoch 26/100


Loss=0.5759264230728149 Batch_id=24 Accuracy=80.44: 100%|██████████| 25/25 [00:17<00:00,  1.44it/s]

 --> EPOCH: 25, Avg Training Loss: 0.5586, Avg Time Taken = 360.99ms






Test set: Average loss: 0.7487, Accuracy: 7449/10000 (74.49%)

---------- prev = 0.5649992895126342 current = 0.5586209177970887 ---------
Epoch 27/100


Loss=0.6129828691482544 Batch_id=24 Accuracy=80.87: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]


 --> EPOCH: 26, Avg Training Loss: 0.5443, Avg Time Taken = 361.21ms

Test set: Average loss: 0.7521, Accuracy: 7460/10000 (74.60%)

---------- prev = 0.5586209177970887 current = 0.5442671251296997 ---------
Model saved at: /content/drive/MyDrive/EPAi_V5/model_small_acc_80.pth
Early stopping triggered!


In [12]:
model.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_80.pth', weights_only=True))
test(model, device, test_loader)


Test set: Average loss: 0.7521, Accuracy: 7460/10000 (74.60%)



0.7520527465820313

# Distillation Process

In [13]:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0.01):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.prev_loss = None  # Initialize as None
        self.counter = 0

    def __call__(self, train_loss):
        if self.prev_loss is None:  # First iteration
            self.prev_loss = train_loss
            return False  # Continue training

        if (abs(train_loss - self.prev_loss)) < self.min_delta:
            print(f'---------- prev = {self.prev_loss} current = {train_loss} ---------')
            self.counter += 1
        else:
            self.counter = 0  # Reset counter if loss improves

        self.prev_loss = train_loss

        return self.counter >= self.tolerance  # Return True if stopping criteria met

train_losses = []
test_losses = []
train_acc = []
test_acc = []
time_taken = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device = {device}')

# Prepare teacher model
teacher = TeacherModel().to(device)
teacher.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_heavy_acc_92.pth', weights_only=True))
teacher.eval()

# Prepare student model
student = StudentModel().to(device)
student.load_state_dict(torch.load('/content/drive/MyDrive/EPAi_V5/model_small_acc_81.pth', weights_only=True))

print("=============================================")
print("Student model accuracy before training ")
student.eval()
test(student, device, test_loader)
print("=============================================")

# Loss functions
hard_loss = nn.CrossEntropyLoss() #Hard label loss
soft_loss = nn.KLDivLoss(reduction="batchmean")  # Distillation loss

# Temperature and alpha
T = 5.0  # Temperature
alpha = 0.5  # Weight for distillation loss

# Optimizer
optimizer = optim.Adam(student.parameters(), lr=0.001)

# Training loop
EPOCHS = 100
for epoch in range(EPOCHS):

    correct = 0
    processed = 0
    epoch_loss = 0
    pbar = tqdm(train_loader)
    student.train()
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # Teacher predictions (soft labels)
        with torch.no_grad():
            teacher_logits = teacher(data) / T

        teacher_probs = torch.softmax(teacher_logits, dim=1)

        # Student predictions
        student_logits = student(data)
        student_probs = torch.log_softmax(student_logits / T, dim=1)

        # Compute losses
        loss_soft = soft_loss(student_probs, teacher_probs) * (T ** 2)  # Scale by T^2
        loss_hard = hard_loss(student_logits, target)
        loss = alpha * loss_hard + (1 - alpha) * loss_soft

        epoch_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        pred = student_logits.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item()} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Check for early stopping
    if early_stopping(avg_train_loss):
        try:
            # Ensure the directory exists
            save_dir = '/content/drive/MyDrive/EPAi_V5'
            os.makedirs(save_dir, exist_ok=True)
            PATH = os.path.join(save_dir, f'model_distil_acc_{int(train_acc[-1]):d}.pth')
        except OSError:
            # Fallback to current directory if Drive is unavailable
            PATH = f'./model_small_acc_{int(train_acc[-1]):d}.pth'

        # Save the model weights
        torch.save(model.state_dict(), PATH)
        print(f"Model saved at: {PATH}")
        print("Early stopping triggered!")
        break

    student.eval()

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = student(data)

            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



Using device = cuda
Student model accuracy before training 

Test set: Average loss: 0.7419, Accuracy: 7566/10000 (75.66%)



Loss=1.3082377910614014 Batch_id=24 Accuracy=76.71: 100%|██████████| 25/25 [00:18<00:00,  1.32it/s]



Test set: Average loss: 0.9304, Accuracy: 7478/10000 (74.78%)



Loss=1.2178696393966675 Batch_id=24 Accuracy=80.33: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]



Test set: Average loss: 0.7959, Accuracy: 7728/10000 (77.28%)



Loss=1.0896053314208984 Batch_id=24 Accuracy=81.28: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]



Test set: Average loss: 0.7732, Accuracy: 7714/10000 (77.14%)



Loss=1.0880683660507202 Batch_id=24 Accuracy=82.09: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]



Test set: Average loss: 0.7589, Accuracy: 7764/10000 (77.64%)



Loss=1.0922045707702637 Batch_id=24 Accuracy=82.32: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]



Test set: Average loss: 0.8024, Accuracy: 7725/10000 (77.25%)



Loss=1.0072500705718994 Batch_id=24 Accuracy=82.76: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]



Test set: Average loss: 0.7968, Accuracy: 7740/10000 (77.40%)



Loss=1.0438783168792725 Batch_id=24 Accuracy=83.13: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]



Test set: Average loss: 0.7820, Accuracy: 7802/10000 (78.02%)



Loss=0.9212185740470886 Batch_id=24 Accuracy=83.24: 100%|██████████| 25/25 [00:18<00:00,  1.33it/s]

---------- prev = 0.9830327153205871 current = 0.9655375790596008 ---------






Test set: Average loss: 0.7392, Accuracy: 7823/10000 (78.23%)



Loss=1.0202991962432861 Batch_id=24 Accuracy=83.73: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

---------- prev = 0.9655375790596008 current = 0.9466933465003967 ---------






Test set: Average loss: 0.7113, Accuracy: 7915/10000 (79.15%)



Loss=0.9273560047149658 Batch_id=24 Accuracy=84.27: 100%|██████████| 25/25 [00:18<00:00,  1.35it/s]



Test set: Average loss: 0.7360, Accuracy: 7865/10000 (78.65%)



Loss=0.9515639543533325 Batch_id=24 Accuracy=84.56: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]



Test set: Average loss: 0.7050, Accuracy: 7949/10000 (79.49%)



Loss=0.8512232899665833 Batch_id=24 Accuracy=85.07: 100%|██████████| 25/25 [00:18<00:00,  1.34it/s]



Test set: Average loss: 0.7860, Accuracy: 7752/10000 (77.52%)



Loss=0.8640667200088501 Batch_id=24 Accuracy=85.16: 100%|██████████| 25/25 [00:19<00:00,  1.29it/s]

---------- prev = 0.8758802199363709 current = 0.8721289110183715 ---------






Test set: Average loss: 0.7365, Accuracy: 7894/10000 (78.94%)



Loss=0.8685557246208191 Batch_id=24 Accuracy=85.54: 100%|██████████| 25/25 [00:18<00:00,  1.36it/s]

---------- prev = 0.8721289110183715 current = 0.86076979637146 ---------






Test set: Average loss: 0.7165, Accuracy: 7938/10000 (79.38%)



Loss=0.8674250841140747 Batch_id=24 Accuracy=85.67: 100%|██████████| 25/25 [00:19<00:00,  1.27it/s]

---------- prev = 0.86076979637146 current = 0.8424052166938781 ---------






Test set: Average loss: 0.7170, Accuracy: 7957/10000 (79.57%)



Loss=0.9037202000617981 Batch_id=24 Accuracy=85.83: 100%|██████████| 25/25 [00:18<00:00,  1.38it/s]

---------- prev = 0.8424052166938781 current = 0.8303352212905883 ---------






Test set: Average loss: 0.6974, Accuracy: 7936/10000 (79.36%)



Loss=0.8141599297523499 Batch_id=24 Accuracy=86.05: 100%|██████████| 25/25 [00:18<00:00,  1.37it/s]

---------- prev = 0.8303352212905883 current = 0.8178793621063233 ---------





IndexError: list index out of range