# Import

In [1]:
import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.ao.quantization as quant
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx as prepare_qat_fx_

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader

import time
import numpy as np
import pandas as pd
import random

device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
SAVE_PATH="/content/drive/MyDrive/DemoQuantDistillation"

In [5]:
RANDOM_SEED = 0

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

if torch.cuda.is_available():
  torch.cuda.manual_seed(RANDOM_SEED)
  torch.cuda.manual_seed_all(RANDOM_SEED)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False # Disable the non-deterministic algorithms

# Dataset

In [6]:
transforms_train = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transforms_test = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_train)
test_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_test)

100%|██████████| 170M/170M [00:15<00:00, 11.2MB/s]


In [7]:
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)
test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)

# Models

In [8]:
class DeepNN(nn.Module):
  def __init__(self, num_classes=10):
    super().__init__()
    self.features = nn.Sequential(
      nn.Conv2d(3, 128, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(128, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(64, 64, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.Conv2d(64, 32, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )

    self.classifier = nn.Sequential(
      nn.Linear(2048, 512),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(512, num_classes)
    )

  def get_features_map(self, x):
    x = self.features(x)

    return x

  def get_emb(self, x):
    features = self.get_features_map(x)
    x = torch.flatten(features, 1)

    return features, x

  def classify(self, emb):
    x = self.classifier(emb)

    return x

  def forward(self, x):
    _, x = self.get_emb(x)
    x = self.classify(x)

    return x

In [9]:
class LightNN(nn.Module):
  def __init__(self, num_classes=10):
    super().__init__()

    self.features = nn.Sequential(
      nn.Conv2d(3, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
      nn.Conv2d(16, 16, kernel_size=3, padding=1),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )

    self.classifier = nn.Sequential(
      nn.Linear(1024, 256),
      nn.ReLU(),
      nn.Dropout(0.1),
      nn.Linear(256, num_classes)
    )

  def get_features_map(self, x):
    x = self.features(x)

    return x

  def get_emb(self, x):
    features = self.get_features_map(x)
    x = torch.flatten(features, 1)

    return features, x

  def classify(self, emb):
    x = self.classifier(emb)

    return x

  def forward(self, x):
    _, x = self.get_emb(x)
    x = self.classify(x)

    return x

# Utils

## Quantisation Utils

In [10]:
def fuse_model(model):
	"""
	Fuse conv+relu and linear+relu in simple Sequential modules.
	In-place on the given model.
	"""
	fused_model = copy.deepcopy(model).cpu().eval()

	modules_to_fuse = []

	# features is a Sequential
	try:
		features = model.features
		for i in range(len(features) - 1):
			if isinstance(features[i], nn.Conv2d) and isinstance(features[i+1], nn.ReLU):
				modules_to_fuse.append([f"features.{i}", f"features.{i+1}"])
	except Exception:
		pass

	# classifier sequential (linear + relu)
	try:
		classifier = model.classifier
		for i in range(len(classifier) - 1):
			if isinstance(classifier[i], nn.Linear) and isinstance(classifier[i+1], nn.ReLU):
				modules_to_fuse.append([f"classifier.{i}", f"classifier.{i+1}"])
	except Exception:
		pass

	if modules_to_fuse:
		quant.fuse_modules(fused_model, modules_to_fuse, inplace=True)

	return fused_model

In [11]:
def prepare_example_input_from_loader(dl, device):
	batch = next(iter(dl))[0]
	example_input = batch[:1].cpu()

	return (example_input,)

In [12]:
def prepare_post_static_quantize_fx(float_model, calib_dl, input_size=(1, 3, 32, 32)):
	quant_model = copy.deepcopy(float_model).cpu().eval()
	fuse_model(quant_model)

	qconfig = quant.get_default_qconfig("fbgemm")
	qconfig_dict = {"": qconfig}

	example_inputs = torch.rand(size=input_size).cpu()
	prepared = prepare_fx(quant_model, qconfig_dict, example_inputs=example_inputs)

	# calibration: run a batch through prepared model
	with torch.no_grad():
		for inputs, _ in calib_dl:
			prepared(inputs.cpu())
			break

	return prepared

In [13]:
def prepare_qat_fx(model, input_size=(1, 3, 32, 32)):
	qconfig = quant.get_default_qat_qconfig('fbgemm')
	qconfig_dict = {"": qconfig}

	example_inputs = torch.rand(size=input_size).cpu()
	prepared_qat = prepare_qat_fx_(model, qconfig_dict, example_inputs=example_inputs)

	return prepared_qat

## Base Training and Test Utils

In [14]:
def train(model, train_loader, epochs, learning_rate, device, use_qat=False):
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

  model.to(device)

  if use_qat:
    model = prepare_model_for_qat(model)

  model.train()

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()
      outputs = model(inputs)

      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}")

In [15]:
def test(model, test_loader, device):
  model.to(device)
  model.eval()

  correct = 0
  total = 0

  with torch.no_grad():
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)

      outputs = model(inputs)
      _, predicted = torch.max(outputs.data, 1)

      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f"Test Accuracy: {accuracy:.2f}%")
  return accuracy

## Knowledge Distilation Training

In [16]:
def train_knowledge_distillation(teacher, student, train_dl, epochs, lr, T, soft_target_loss_weight, ce_loss_weight, device):
  ce_loss = nn.CrossEntropyLoss()
  optimizer = optim.AdamW(student.parameters(), lr=lr)

  teacher.to(device)
  student.to(device)

  teacher.eval()
  student.train()

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_dl:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      with torch.no_grad():
        teacher_logits = teacher(inputs)

      student_logits = student(inputs)

      soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
      soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

      soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

      label_loss = ce_loss(student_logits, labels)

      loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

In [17]:
def train_cosine_loss(teacher, student, train_dl, epochs, lr, hidden_rep_loss_weight, ce_loss_weight, device):
  ce_loss = nn.CrossEntropyLoss()
  cosine_loss = nn.CosineEmbeddingLoss()
  optimizer = optim.AdamW(student.parameters(), lr=lr)

  teacher.to(device)
  student.to(device)

  teacher.eval()
  student.train()

  with torch.no_grad():
    _, teacher_dummy_emb = teacher.get_emb(torch.randn(1, 3, 32, 32).to(device))
    _, student_dummy_emb = student.get_emb(torch.randn(1, 3, 32, 32).to(device))

  teacher_emb_size = teacher_dummy_emb.size(1)
  student_emb_size = student_dummy_emb.size(1)

  if teacher_emb_size % student_emb_size != 0:
    raise ValueError("Teacher embedding size must be a multiple of the student embedding size for dynamic pooling.")

  pooling_kernel_size = teacher_emb_size // student_emb_size

  print(f"Teacher embedding size: {teacher_emb_size}")
  print(f"Student embedding size: {student_emb_size}")
  print(f"Pooling kernel size: {pooling_kernel_size}")

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_dl:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      with torch.no_grad():
        _, teacher_emb = teacher.get_emb(inputs)

      _, student_emb = student.get_emb(inputs)
      student_logits = student.classify(student_emb)

      teacher_emb_pooled = F.avg_pool1d(
        teacher_emb.unsqueeze(1),
        kernel_size=pooling_kernel_size
      ).squeeze(1)

      hidden_rep_loss = cosine_loss(student_emb, teacher_emb_pooled, target=torch.ones(inputs.size(0)).to(device))
      label_loss = ce_loss(student_logits, labels)

      loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

In [18]:
def train_cosine_loss_proj(teacher, student, train_dl, epochs, lr, hidden_rep_loss_weight, ce_loss_weight, device):
  ce_loss = nn.CrossEntropyLoss()
  cosine_loss = nn.CosineEmbeddingLoss()
  optimizer = optim.AdamW(student.parameters(), lr=lr)

  teacher.to(device)
  student.to(device)

  teacher.eval()
  student.train()

  with torch.no_grad():
    _, teacher_dummy_emb = teacher.get_emb(torch.randn(1, 3, 32, 32).to(device))
    _, student_dummy_emb = student.get_emb(torch.randn(1, 3, 32, 32).to(device))

  teacher_emb_size = teacher_dummy_emb.size(1)
  student_emb_size = student_dummy_emb.size(1)

  if teacher_emb_size != student_emb_size:
    proj_layer = nn.Linear(teacher_emb_size, student_emb_size).to(device)
  else:
    proj_layer = nn.Identity().to(device)

  proj_layer.train()

  print(f"Teacher embedding size: {teacher_emb_size}")
  print(f"Student embedding size: {student_emb_size}")

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_dl:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      with torch.no_grad():
        _, teacher_emb = teacher.get_emb(inputs)

      _, student_emb = student.get_emb(inputs)
      student_logits = student.classify(student_emb)

      teacher_emb_projed = proj_layer(teacher_emb)

      hidden_rep_loss = cosine_loss(student_emb, teacher_emb_projed, target=torch.ones(inputs.size(0)).to(device))
      label_loss = ce_loss(student_logits, labels)

      loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

In [19]:
def train_mse_loss(teacher, student, train_dl, epochs, lr, hidden_rep_loss_weight, ce_loss_weight, device):
  ce_loss = nn.CrossEntropyLoss()
  mse_loss = nn.MSELoss()
  optimizer = optim.AdamW(student.parameters(), lr=lr)

  teacher.to(device)
  student.to(device)

  teacher.eval()
  student.train()

  proj_layer = nn.Conv2d(16, 32, kernel_size=3, padding=1).to(device)
  proj_layer.train()

  for epoch in range(epochs):
    running_loss = 0.0
    for inputs, labels in train_dl:
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      with torch.no_grad():
        teacher_features_map, _ = teacher.get_emb(inputs)

      student_features_map, student_emb = student.get_emb(inputs)
      student_logits = student.classify(student_emb)

      student_features_map_projed = proj_layer(student_features_map)

      hidden_rep_loss = mse_loss(student_features_map_projed, teacher_features_map)
      label_loss = ce_loss(student_logits, labels)

      loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss

      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_dl)}")

# Training

## Teacher Model

In [20]:
nn_deep = DeepNN(num_classes=10).to(device)
train(nn_deep, train_dl, epochs=10, learning_rate=0.001, device=device)
test_accuracy_deep = test(nn_deep, test_dl, device)

Epoch 1/10, Loss: 1.576228148492096
Epoch 2/10, Loss: 1.164011489247422
Epoch 3/10, Loss: 0.9765752381680871
Epoch 4/10, Loss: 0.871798596418727
Epoch 5/10, Loss: 0.7950609061114319
Epoch 6/10, Loss: 0.7505644590348539
Epoch 7/10, Loss: 0.7044291852990074
Epoch 8/10, Loss: 0.6760034995615635
Epoch 9/10, Loss: 0.6492420825202142
Epoch 10/10, Loss: 0.624915267195543
Test Accuracy: 79.25%


In [21]:
torch.save(nn_deep.state_dict(), f"{SAVE_PATH}/nn_deep.pth")

## Base Student Model

In [22]:
nn_light = LightNN(num_classes=10).to(device)

In [23]:
new_nn_light = LightNN(num_classes=10).to(device)

In [24]:
print("Norm of 1st layer of nn_light:", torch.norm(nn_light.features[0].weight).item())
print("Norm of 1st layer of new_nn_light:", torch.norm(new_nn_light.features[0].weight).item())

Norm of 1st layer of nn_light: 2.3297154903411865
Norm of 1st layer of new_nn_light: 2.323422908782959


In [25]:
train(nn_light, train_dl, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_ce = test(nn_light, test_dl, device)

Epoch 1/10, Loss: 1.6446998732169267
Epoch 2/10, Loss: 1.3801570719160388
Epoch 3/10, Loss: 1.2711946558769402
Epoch 4/10, Loss: 1.1947882338558011
Epoch 5/10, Loss: 1.130501159469185
Epoch 6/10, Loss: 1.0893920783496573
Epoch 7/10, Loss: 1.0567099930685195
Epoch 8/10, Loss: 1.0298952349006671
Epoch 9/10, Loss: 0.9990920134822426
Epoch 10/10, Loss: 0.9827327857846799
Test Accuracy: 69.66%


In [26]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy: {test_accuracy_light_ce:.2f}%")

Teacher accuracy: 79.25%
Student accuracy: 69.66%


In [29]:
torch.save(nn_light.state_dict(), f"{SAVE_PATH}/nn_light.pth")

## Distillated Student Model

In [27]:
train_knowledge_distillation(nn_deep, new_nn_light, train_dl, epochs=10, lr=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd = test(new_nn_light, test_dl, device)

Epoch 1/10, Loss: 1.9301316713738015
Epoch 2/10, Loss: 1.5408039059480438
Epoch 3/10, Loss: 1.3700120830170028
Epoch 4/10, Loss: 1.2478109688100303
Epoch 5/10, Loss: 1.167782247981147
Epoch 6/10, Loss: 1.1021317158208783
Epoch 7/10, Loss: 1.0567719520205427
Epoch 8/10, Loss: 1.011583151719759
Epoch 9/10, Loss: 0.9823337624140103
Epoch 10/10, Loss: 0.9556450740150784
Test Accuracy: 70.86%


In [30]:
torch.save(new_nn_light.state_dict(), f"{SAVE_PATH}/new_nn_light.pth")

In [28]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")

Teacher accuracy: 79.25%
Student accuracy without teacher: 69.66%
Student accuracy with CE + KD: 70.86%


## COS Pool Distllated Student Model

In [27]:
cos_pool_nn_light = LightNN(num_classes=10).to(device)

In [38]:
train_cosine_loss(nn_deep, cos_pool_nn_light, train_dl, epochs=10, lr=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd_cos_pool = test(cos_pool_nn_light, test_dl, device)

Teacher embedding size: 2048
Student embedding size: 1024
Pooling kernel size: 2
Epoch 1/10, Loss: 1.43088225513468
Epoch 2/10, Loss: 1.2199425557080437
Epoch 3/10, Loss: 1.1315733409293778
Epoch 4/10, Loss: 1.0706597118426466
Epoch 5/10, Loss: 1.0230807174197243
Epoch 6/10, Loss: 0.9930553835676149
Epoch 7/10, Loss: 0.9646160815987745
Epoch 8/10, Loss: 0.9489797294292304
Epoch 9/10, Loss: 0.9283358602572584
Epoch 10/10, Loss: 0.9102163734033589
Test Accuracy: 68.45%


## COS Proj Distllated Student Model

In [39]:
cos_proj_nn_light = LightNN(num_classes=10).to(device)

In [40]:
train_cosine_loss_proj(nn_deep, cos_proj_nn_light, train_dl, 10, 0.001, 0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd_cos_proj = test(cos_proj_nn_light, test_dl, device)

Teacher embedding size: 2048
Student embedding size: 1024
Epoch 1/10, Loss: 1.4890597607473584
Epoch 2/10, Loss: 1.2910036577288146
Epoch 3/10, Loss: 1.1974697675546417
Epoch 4/10, Loss: 1.1266944438905058
Epoch 5/10, Loss: 1.079360692397408
Epoch 6/10, Loss: 1.0356844101110687
Epoch 7/10, Loss: 1.0136581846820119
Epoch 8/10, Loss: 0.9868874568158709
Epoch 9/10, Loss: 0.9689526736278973
Epoch 10/10, Loss: 0.9526536088160542
Test Accuracy: 70.55%


In [41]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD COS Pool: {test_accuracy_light_ce_and_kd_cos_pool:.2f}%")
print(f"Student accuracy with CE + KD COS Proj: {test_accuracy_light_ce_and_kd_cos_proj:.2f}%")

Teacher accuracy: 79.25%
Student accuracy without teacher: 69.66%
Student accuracy with CE + KD COS Pool: 68.45%
Student accuracy with CE + KD COS Proj: 70.55%


## MSE Distllated Student Model

In [42]:
mse_nn_light = LightNN(num_classes=10).to(device)

In [50]:
train_mse_loss(nn_deep, mse_nn_light, train_dl, 10, 0.001, 0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_ce_and_kd_mse = test(mse_nn_light, test_dl, device)

Epoch 1/10, Loss: 1.2881577461576827
Epoch 2/10, Loss: 1.071609301792691
Epoch 3/10, Loss: 0.988097681413831
Epoch 4/10, Loss: 0.937765251949925
Epoch 5/10, Loss: 0.9004530242032103
Epoch 6/10, Loss: 0.869272967281244
Epoch 7/10, Loss: 0.8422944977155427
Epoch 8/10, Loss: 0.8251956388773516
Epoch 9/10, Loss: 0.8051633097021781
Epoch 10/10, Loss: 0.788013038580375
Test Accuracy: 69.77%


In [51]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD MSE: {test_accuracy_light_ce_and_kd_mse:.2f}%")

Teacher accuracy: 79.25%
Student accuracy without teacher: 69.66%
Student accuracy with CE + KD MSE: 69.77%


## QAT Base Student Model

In [31]:
nn_light_qat = LightNN(num_classes=10).to(device)
nn_light_qat = prepare_qat_fx(nn_light_qat)

  prepared = prepare(


In [32]:
new_nn_light_qat = LightNN(num_classes=10).to(device)
new_nn_light_qat = prepare_qat_fx(new_nn_light_qat)

  prepared = prepare(


In [56]:
cos_pool_nn_light_qat = LightNN(num_classes=10).to(device)
cos_pool_nn_light_qat = prepare_qat_fx(cos_pool_nn_light_qat)

  prepared = prepare(


In [57]:
cos_proj_nn_light_qat = LightNN(num_classes=10).to(device)
cos_proj_nn_light_qat = prepare_qat_fx(cos_proj_nn_light_qat)

  prepared = prepare(


In [58]:
mse_nn_light_qat = LightNN(num_classes=10).to(device)
mse_nn_light_qat = prepare_qat_fx(mse_nn_light_qat)

  prepared = prepare(


In [33]:
train(nn_light_qat, train_dl, epochs=10, learning_rate=0.001, device=device)
test_accuracy_light_qat_ce = test(nn_light_qat, test_dl, device)

Epoch 1/10, Loss: 1.705483193592647
Epoch 2/10, Loss: 1.3798409557098623
Epoch 3/10, Loss: 1.245113938665756
Epoch 4/10, Loss: 1.157856786647416
Epoch 5/10, Loss: 1.0918839682093666
Epoch 6/10, Loss: 1.044563437514293
Epoch 7/10, Loss: 1.0187929763513452
Epoch 8/10, Loss: 0.9892782256426409
Epoch 9/10, Loss: 0.9624277299932201
Epoch 10/10, Loss: 0.9512704305941492
Test Accuracy: 69.55%


In [35]:
torch.save(nn_light_qat.state_dict(), f"{SAVE_PATH}/nn_light_qat.pth")

## QAT Distillated Student Model

In [34]:
train_knowledge_distillation(nn_deep, new_nn_light_qat, train_dl, epochs=10, lr=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_qat_ce_and_kd = test(new_nn_light_qat, test_dl, device)

Epoch 1/10, Loss: 2.016908584958147
Epoch 2/10, Loss: 1.5323560283616986
Epoch 3/10, Loss: 1.3478092696050854
Epoch 4/10, Loss: 1.2342640627985415
Epoch 5/10, Loss: 1.1556020444616333
Epoch 6/10, Loss: 1.1010504811621078
Epoch 7/10, Loss: 1.05998781331055
Epoch 8/10, Loss: 1.022975729097186
Epoch 9/10, Loss: 0.9903042057286138
Epoch 10/10, Loss: 0.9632899747480212
Test Accuracy: 69.16%


In [36]:
torch.save(new_nn_light_qat.state_dict(), f"{SAVE_PATH}/new_nn_light_qat.pth")

## QAT COS Pool Distllated Student Model

In [61]:
train_cosine_loss(nn_deep, cos_pool_nn_light_qat, train_dl, 10, 0.001, 0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_qat_ce_and_kd_cos_pool = test(cos_pool_nn_light_qat, test_dl, device)

AttributeError: 'GraphModule' object has no attribute 'get_emb'

## QAT COS Proj Distllated Student Model

In [None]:
train_cosine_loss_proj(nn_deep, cos_proj_nn_light_qat, train_dl, 10, 0.001, 0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_qat_ce_and_kd_cos_proj = test(cos_proj_nn_light_qat, test_dl, device)

In [None]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_qat_ce:.2f}%")
print(f"Student accuracy with CE + KD COS Pool: {test_accuracy_light_qat_ce_and_kd_cos_pool:.2f}%")
print(f"Student accuracy with CE + KD COS Proj: {test_accuracy_light_qat_ce_and_kd_cos_proj:.2f}%")

## QAT MSE Distllated Student Model

In [None]:
train_mse_loss(nn_deep, mse_nn_light_qat, train_dl, 10, 0.001, 0.25, ce_loss_weight=0.75, device=device)
test_accuracy_light_qat_ce_and_kd_mse = test(mse_nn_light_qat, test_dl, device)

In [None]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_qat_ce:.2f}%")
print(f"Student accuracy with CE + KD MSE: {test_accuracy_light_qat_ce_and_kd_mse:.2f}%")

# Testing

In [None]:
nn_light_quantised = prepare_post_static_quantize_fx(nn_light, test_dl)
nn_light_quantised = convert_fx(nn_light_quantised)

new_nn_light_quantised = prepare_post_static_quantize_fx(new_nn_light, test_dl)
new_nn_light_quantised = convert_fx(new_nn_light_quantised)

test_accuracy_light_quantised_ce = test(nn_light_quantised, test_dl, device)
test_accuracy_light_quantised_ce_and_kd = test(new_nn_light_quantised, test_dl, device)

  prepared = prepare(
  prepared = prepare(


In [None]:
nn_light_qat_quantised = convert_fx(nn_light_qat)
new_nn_light_qat_quantised = convert_fx(new_nn_light_qat)

test_accuracy_light_qat_quantised_ce = test(nn_light_qat_quantised, test_dl, device)
test_accuracy_light_qat_quantised_ce_and_kd = test(new_nn_light_qat_quantised, test_dl, device)

In [None]:
print(f"Teacher accuracy: {test_accuracy_deep:.2f}%")
print(f"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%")
print(f"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%")
print(f"Student QAT accuracy without teacher: {test_accuracy_light_qat_ce:.2f}%")
print(f"Student QAT accuracy with CE + KD: {test_accuracy_light_qat_ce_and_kd:.2f}%")

In [None]:
print(f"Student Quantised accuracy without teacher: {test_accuracy_light_quantised_ce:.2f}%")
print(f"Student Quantised accuracy with CE + KD: {test_accuracy_light_quantised_ce_and_kd:.2f}%")
print(f"Student Quantised QAT accuracy without teacher: {test_accuracy_light_qat_quantised_ce:.2f}%")
print(f"Student Quantised QAT accuracy with CE + KD: {test_accuracy_light_qat_quantised_ce_and_kd:.2f}%")