# A simple model to illustrate the specific principles and logic of knowledge distillation.



    dataset : MNIST
    data size : 60K trainset & 10K testset
    teacher model : 2 covolutional layers and 2 fully connected layers
    student model: 2 fully connected layers
    input : 28,28,1
    output: 10,
    act_func: Adadelta with 0.001 lr
    loss: cross_entropy + kl_div

    -- default hyperparameter

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms
import torchvision
import torchvision.models as models
import torch.utils.data
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [None]:
def distillation(y,target,teacher_scores,temp,alpha):
  kl_loss = nn.KLDivLoss()(F.log_softmax(y/temp,dim=1), F.softmax(teacher_scores/temp,dim=1)*(temp * temp * 2 * alpha)) + F.cross_entropy(y,target) * (1. - alpha)
  return kl_loss

## Teacher Model

In [None]:
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.3)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')


def test_teacher(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)
def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())

    teacher_history = []

    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)
        loss, acc = test_teacher(model, device, test_loader)

        teacher_history.append((loss, acc))

    torch.save(model.state_dict(), "teacher.pt")
    return model, teacher_history

teacher_model, teacher_history = teacher_main()

100%|██████████| 9.91M/9.91M [00:00<00:00, 79.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 35.9MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 45.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.10MB/s]


Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0510, accuracy: 9843/10000 (98%)
Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0412, accuracy: 9868/10000 (99%)
Train epoch 3: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0329, accuracy: 9889/10000 (99%)
Train epoch 4: 14976/60000, [------------->                                     ] 26%

## Student Model

In [None]:
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output
def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data)
        teacher_output = teacher_output.detach()  
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')


def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  
            pred = output.argmax(dim=1, keepdim=True)  
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)
def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())

    student_history = []
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history
student_kd_model, student_kd_history = student_kd_main()

# Fine-tuning Teacher Model

    dataset : CIFAR10
    data size : 50K trainset & 10K testset
    teacher model : resnet50
    student model : resnet18
    input : 32,32,3
    output: 10,
    act_func: Adadelta & SGD with 0.001 lr
    loss: cross_entropy + kl_div

    -- default hyperparameter

## Teacher Model

In [1]:
class TeacherNet_resp(nn.Module):
  def __init__(self):
    super(TeacherNet_resp,self).__init__()
    self.resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    # self.resnet50.conv1 = nn.Conv2d(3,64,3,1)
    # self.resnet50.maxpool = nn.Identity()
    in_features = self.resnet50.fc.in_features
    self.resnet50.fc = nn.Linear(in_features,10)

  def forward(self,x):
    output = self.resnet50(x)

    return output

def train_teacher_resp(model,device,train_loader,optimizer,epoch):
  model.train()
  trained_samples = 0

  for batch_idx,(data,target) in enumerate(train_loader):
    data,target = data.to(device),target.to(device)
    output = model(data)
    optimizer.zero_grad()
    loss = F.cross_entropy(output,target)
    loss.backward()
    optimizer.step()

    trained_samples += len(data)
    progress = math.ceil((batch_idx) / len(train_loader) * 50)
    print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
            (epoch, trained_samples,len(train_loader.dataset),'-' * progress + ">", progress * 2),end="")

def test_teacher_resp(model,device,test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data,target in test_loader:
      data,target = data.to(device),target.to(device)
      output = model(data)
      test_loss += F.cross_entropy(output,target,reduction="sum").item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= len(test_loader.dataset)



  print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
  return test_loss, correct / len(test_loader.dataset)


def main_teacher_resp():
  batch_size = 64
  epochs = 5
  torch.manual_seed(0)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  transform_list = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))
  ])

  train_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=True,download=True,transform=transform_list)
  train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
  test_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=False,download=True,transform=transform_list)
  test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
  model = TeacherNet_resp().to(device)
  #adadelta ==> very very very very slow
  # optimizer = torch.optim.Adadelta(model.parameters())
  lr = 0.001
  optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
  teacher_history = []
  for epoch in range(1,epochs+1):
    train_teacher_resp(model,device,train_loader,optimizer,epoch)
    loss,acc = test_teacher_resp(model,device,test_loader)

  torch.save(model.state_dict(), "teacher.pt")
  return model, teacher_history

teacher_resp_model,teacher_resp_history = main_teacher_resp()

NameError: name 'nn' is not defined

## Student Model

In [None]:
class StudentNet_resp(nn.Module):
  def __init__(self):
    super(StudentNet_resp,self).__init__()
    self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # self.resnet18.conv1 = nn.Conv2d(3,64,3,1)
    # self.resnet18.maxpool = nn.Identity()
    in_features = self.resnet18.fc.in_features
    self.resnet18.fc = nn.Linear(in_features,10)

  def forward(self,x):
    output = self.resnet18(x)
    return output

def train_student(model,device,train_loader,optimizer,epoch):
  model.train()
  trained_samples = 0

  for batch_idx,(data,target) in enumerate(train_loader):
    data,target = data.to(device),target.to(device)

    optimizer.zero_grad()


    output = model(data)
    teacher_output = teacher_resp_model(data)
    teacher_output = teacher_output.detach()


    loss = distillation(output,target,teacher_output,10,0.7)
    loss.backward()
    optimizer.step()

    trained_samples += len(data)
    progress = math.ceil((batch_idx) / len(train_loader) * 50)
    print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
            (epoch, trained_samples,len(train_loader.dataset),'-' * progress + ">", progress * 2),end="")

def test_student(model,device,test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data,target in test_loader:
      data,target = data.to(device),target.to(device)
      output = model(data)
      test_loss += F.cross_entropy(output,target,reduction="sum").item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= len(test_loader.dataset)
  print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
  return test_loss, correct / len(test_loader.dataset)


def main_student():
  batch_size = 64
  epochs = 5
  torch.manual_seed(0)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  transform_list = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))
  ])

  train_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=True,download=True,transform=transform_list)
  train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
  test_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=False,download=True,transform=transform_list)
  test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
  model = StudentNet_resp().to(device)
  #adadelta ==> very very very very slow
  # optimizer = torch.optim.Adadelta(model.parameters())
  lr = 0.001
  optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
  student_history = []
  for epoch in range(1,epochs+1):
    train_student(model,device,train_loader,optimizer,epoch)
    loss,acc = test_student(model,device,test_loader)

  torch.save(model.state_dict(), "student.pt")
  return model, student_history

student_model,student_history = main_student()

# Feature-Based Model

    dataset : CIFAR10
    data size : 50K trainset & 10K testset
    teacher model : resnet50
    student model : resnet18
    input : 32,32,3
    output: 10,
    act_func: Adadelta & SGD with 0.001 lr
    loss: cross_entropy + kl_div + mse(avgMaxPooling)

    -- default hyperparameter

## Teacher Model

In [None]:
class TeacherNet_fe(nn.Module):
  def __init__(self):
    super(TeacherNet_fe,self).__init__()
    self.resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    # self.resnet50.conv1 = nn.Conv2d(3,64,3,1)
    # self.resnet50.maxpool = nn.Identity()
    in_features = self.resnet50.fc.in_features
    self.resnet50.fc = nn.Linear(in_features,10)

  def forward(self,x):
      features = self.resnet50.avgpool(self.resnet50.layer4(self.resnet50.layer3(self.resnet50.layer2(self.resnet50.layer1(self.resnet50.relu(self.resnet50.bn1(self.resnet50.conv1(x))))))))
      feature_vector = torch.flatten(features, 1)
      output = self.resnet50.fc(feature_vector)

      return feature_vector, output
def train_teacher_fe(model,device,train_loader,optimizer,epoch):
  model.train()
  trained_samples = 0

  for batch_idx,(data,target) in enumerate(train_loader):
    data,target = data.to(device),target.to(device)
    _,output = model(data)
    optimizer.zero_grad()
    loss = F.cross_entropy(output,target)
    loss.backward()
    optimizer.step()

    trained_samples += len(data)
    progress = math.ceil((batch_idx) / len(train_loader) * 50)
    print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
            (epoch, trained_samples,len(train_loader.dataset),'-' * progress + ">", progress * 2),end="")

def test_teacher_fe(model,device,test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data,target in test_loader:
      data,target = data.to(device),target.to(device)
      _,output = model(data)
      test_loss += F.cross_entropy(output,target,reduction="sum").item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= len(test_loader.dataset)



  print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
  return test_loss, correct / len(test_loader.dataset)


def main_teacher_fe():
  batch_size = 64
  epochs = 5
  torch.manual_seed(0)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  transform_list = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))
  ])

  train_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=True,download=True,transform=transform_list)
  train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
  test_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=False,download=True,transform=transform_list)
  test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
  model = TeacherNet_fe().to(device)
  #adadelta ==> very very very very slow
  # optimizer = torch.optim.Adadelta(model.parameters())
  lr = 0.001
  optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
  teacher_history = []
  for epoch in range(1,epochs+1):
    train_teacher_fe(model,device,train_loader,optimizer,epoch)
    loss,acc = test_teacher_fe(model,device,test_loader)
    teacher_history.append((loss, acc))
  torch.save(model.state_dict(), "teacher.pt")
  return model, teacher_history

teacher_fe_model,teacher_fe_history = main_teacher_fe()

## Student Model

In [None]:
class StudentNet_fe(nn.Module):
  def __init__(self):
    super(StudentNet_fe,self).__init__()
    self.resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # self.resnet18.conv1 = nn.Conv2d(3,64,3,1)
    # self.resnet18.maxpool = nn.Identity()
    in_features = self.resnet18.fc.in_features
    self.resnet18.fc = nn.Linear(in_features,10)

  def forward(self,x):
      features = self.resnet18.avgpool(self.resnet18.layer4(self.resnet18.layer3(self.resnet18.layer2(self.resnet18.layer1(self.resnet18.relu(self.resnet18.bn1(self.resnet18.conv1(x))))))))
      feature_vector = torch.flatten(features, 1)
      output = self.resnet18.fc(feature_vector)

def train_student(model,device,train_loader,optimizer,epoch):
  model.train()
  trained_samples = 0

  for batch_idx,(data,target) in enumerate(train_loader):
    data,target = data.to(device),target.to(device)

    optimizer.zero_grad()


    student_features,output = model(data)
    teacher_features,teacher_output = teacher_fe_model(data)
    teacher_output = teacher_output.detach()

    mse = nn.MSELoss()(student_features,teacher_features)
    kl = distillation(output,target,teacher_output,10,0.7)
    loss = kl + mse
    loss.backward()
    optimizer.step()

    trained_samples += len(data)
    progress = math.ceil((batch_idx) / len(train_loader) * 50)
    print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
            (epoch, trained_samples,len(train_loader.dataset),'-' * progress + ">", progress * 2),end="")

def test_student(model,device,test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data,target in test_loader:
      data,target = data.to(device),target.to(device)
      _,output = model(data)
      test_loss += F.cross_entropy(output,target,reduction="sum").item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
  test_loss /= len(test_loader.dataset)
  print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
  return test_loss, correct / len(test_loader.dataset)


def main_student():
  batch_size = 64
  epochs = 5
  torch.manual_seed(0)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  transform_list = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2470, 0.2435, 0.2616))
  ])

  train_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=True,download=True,transform=transform_list)
  train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
  test_dataset = torchvision.datasets.CIFAR10("./data/CIFAR10",train=False,download=True,transform=transform_list)
  test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True)
  model = StudentNet_fe().to(device)
  #adadelta ==> very very very very slow
  # optimizer = torch.optim.Adadelta(model.parameters())
  lr = 0.001
  optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
  student_history = []
  for epoch in range(1,epochs+1):
    train_student(model,device,train_loader,optimizer,epoch)
    loss,acc = test_student(model,device,test_loader)
    student_history.append((loss, acc))

  torch.save(model.state_dict(), "student.pt")
  return model, student_history

student_model,student_history = main_student()