<a href="https://colab.research.google.com/github/cv-ape/mnist-knowledge_distillation-pytorch-demo/blob/main/mnist_kd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import os

In [2]:
from google.colab import drive
drive.mount('/content/drive')
prj="knowledge_distill"
path = "/content/drive/My Drive/Colab Notebooks/"+prj

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
class config:
  batch_size=128 #input batch size for training
  test_batch_size=1000 #input batch size for testing
  epochs=5 #number of epochs to train
  lr=0.01 #learning rate
  momentum=0.9 #SGD momentum
  no_cuda=False #disables CUDA training
  seed=1 #random seed 
  log_interval=10#how many batches to wait before logging training status
  device='cuda' if (not no_cuda and torch.cuda.is_available()) else 'cpu'
cfg=config()
device=cfg.device
torch.manual_seed(cfg.seed)
if device=='cuda' :
    torch.cuda.manual_seed(cfg.seed)

dataset loader

In [4]:
kwargs = {'num_workers': 1, 'pin_memory': True} 
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data_mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=cfg.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data_mnist', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=cfg.test_batch_size, shuffle=True, **kwargs)

教师网络

In [5]:
class teacherNet(nn.Module):
    def __init__(self):
        super(teacherNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.fc3(x)
        return x

In [6]:
def train_Teacher(epoch, model):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        #data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % cfg.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [7]:
def train_evaluate(model):
    model.eval()
    train_loss = 0
    correct = 0
    for data, target in train_loader:
        with torch.no_grad():
          data, target = data.to(device), target.to(device)
          output = model(data)
        train_loss += F.cross_entropy(output, target).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))


def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        with torch.no_grad():
          data, target = data.to(device), target.to(device)
          output = model(data)
        test_loss += F.cross_entropy(output, target).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [8]:
path+"/teacher_MLP.pth.tar"

'/content/drive/My Drive/Colab Notebooks/knowledge_distill/teacher_MLP.pth.tar'

训练教师网络

In [10]:
teacher_model = teacherNet()
teacher_model.to(device)

optimizer = optim.SGD(teacher_model.parameters(), lr=cfg.lr, momentum=cfg.momentum,
                      weight_decay=5e-4)

if os.path.exists(path+"/teacher_MLP.pth.tar"):
  teacher_model.load_state_dict(torch.load(path+'/teacher_MLP.pth.tar',map_location=device))
  print("existed weight is loaded!")

for epoch in range(1, cfg.epochs + 1):
    train_Teacher(epoch, teacher_model)
    train_evaluate(teacher_model)
    test(teacher_model)
#torch.save(teacher_model.state_dict(), 'teacher_MLP.pth.tar')

existed weight is loaded!

Train set: Average loss: 38.1411, Accuracy: 58604/60000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9719/10000 (97%)


Train set: Average loss: 35.0719, Accuracy: 58667/60000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9717/10000 (97%)


Train set: Average loss: 32.1289, Accuracy: 58815/60000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9742/10000 (97%)


Train set: Average loss: 30.8555, Accuracy: 58841/60000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9736/10000 (97%)


Train set: Average loss: 29.4283, Accuracy: 58887/60000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9751/10000 (98%)



In [11]:
torch.save(teacher_model.state_dict(), path+'/teacher_MLP.pth.tar')

In [12]:
class studentNet(nn.Module):
    def __init__(self):
        super(studentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 800)
        self.fc2 = nn.Linear(800, 800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [13]:
def distill_unlabeled(outputs,labels, teacher_scores):
    alpha=0.8
    T=2
    return nn.KLDivLoss()(F.log_softmax(outputs/T), F.softmax(teacher_scores/T)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

In [14]:
def train_student(epoch, model, loss_fn):
    model.train()
    model.to(device)
    teacher_model.eval()
    teacher_model.to(device)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data).detach()
        loss = loss_fn(output,target, teacher_output)
        loss.backward()
        optimizer.step()
        if batch_idx % cfg.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

训练学生网络

In [15]:
teacher_model = teacherNet()
teacher_model.load_state_dict(torch.load('teacher_MLP.pth.tar',map_location=device))

student_model = studentNet()
student_model.to(device)

optimizer =optim.SGD(student_model.parameters(), lr=cfg.lr, momentum=cfg.momentum)

In [16]:
for epoch in range(1, cfg.epochs + 1):
    train_student(epoch, student_model, loss_fn=distill_unlabeled)
    train_evaluate(student_model)
    test(student_model)

torch.save(student_model.state_dict(), 'distill_unlabeled.pth.tar')
# the_model = Net()
# the_model.load_state_dict(torch.load('student.pth.tar'))

# test(the_model)
# for data, target in test_loader:
#     data, target = Variable(data, volatile=True), Variable(target)
#     teacher_out = the_model(data)
# print(teacher_out)
#print("--- %s seconds ---" % (time.time() - start_time))

  after removing the cwd from sys.path.
  after removing the cwd from sys.path.
  "reduction: 'mean' divides the total loss by both the batch size and the support size."



Train set: Average loss: 130.6064, Accuracy: 55023/60000 (92%)


Test set: Average loss: 0.0003, Accuracy: 9228/10000 (92%)


Train set: Average loss: 90.6212, Accuracy: 56577/60000 (94%)


Test set: Average loss: 0.0002, Accuracy: 9425/10000 (94%)


Train set: Average loss: 67.5664, Accuracy: 57390/60000 (96%)


Test set: Average loss: 0.0001, Accuracy: 9546/10000 (95%)


Train set: Average loss: 55.6160, Accuracy: 57847/60000 (96%)


Test set: Average loss: 0.0001, Accuracy: 9612/10000 (96%)


Train set: Average loss: 47.1102, Accuracy: 58172/60000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9654/10000 (97%)

