In [5]:
#Device check and load model into device
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [6]:
import pandas as pd
import os
import torch
import time
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import torchvision.models as models
import matplotlib.pyplot as plt
from sklearn.metrics import *


##HYPER-PARAM
batch_size =16
epochs = 120
max_lr = 0.001
grad_clip = 0.01
weight_decay =0.001
opt_func = torch.optim.Adam

##DOWNLOAD dataset
train_data = torchvision.datasets.CIFAR100('./', train=True, download=True)
# Stick all the images together to form a 1600000 X 32 X 3 array
x = np.concatenate([np.asarray(train_data[i][0]) for i in range(len(train_data))])
# calculate the mean and std along the (0, 1) axes
mean = np.mean(x, axis=(0, 1))/255
std = np.std(x, axis=(0, 1))/255
# the the mean and std
mean=mean.tolist()
std=std.tolist()

##TRANSFORM
transform_train = tt.Compose([tt.RandomCrop(32, padding=4,padding_mode='reflect'), 
                         tt.RandomHorizontalFlip(), 
                         tt.ToTensor(), 
                         tt.Normalize(mean,std,inplace=True)])
transform_test = tt.Compose([tt.ToTensor(), tt.Normalize(mean,std)])
##DATASET and DATALOADER
trainset = torchvision.datasets.CIFAR100("./",
                                         train=True,
                                         download=True,
                                         transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size, shuffle=True, num_workers=2,pin_memory=True)

testset = torchvision.datasets.CIFAR100("./",
                                        train=False,
                                        download=True,
                                        transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size*2,pin_memory=True, num_workers=2)
#LOADER
device = get_default_device()
trainloader = DeviceDataLoader(trainloader, device)
testloader = DeviceDataLoader(testloader, device)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
#from torchvision.models.resnet import resnet9
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn
import time

resnet50_pretrain_weight = "teacher_model.pth"
resnet18_pretrain_weight = "student_model.pth"

In [7]:
class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch 
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss
    
    def validation_step(self, batch):
        images, labels = batch 
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))


In [15]:
class ResNet9(ImageClassificationBase):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True) 
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128)) 
        
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True) 
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512)) 
        self.conv5 = conv_block(512, 1028, pool=True) 
        self.res3 = nn.Sequential(conv_block(1028, 1028), conv_block(1028, 1028))  
        
        self.classifier = nn.Sequential(nn.MaxPool2d(2), # 1028 x 1 x 1
                                        nn.Flatten(), # 1028 
                                        nn.Linear(1028, num_classes)) # 1028 -> 100
        
    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.conv5(out)
        out = self.res3(out) + out
        out = self.classifier(out)
        return out

In [11]:
class ResNet9_2(ImageClassificationBase):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.conv1 = conv_block(in_channels, 8)
        self.conv2 = conv_block(8, 8, pool=True) 
        self.res1 = nn.Sequential(conv_block(8, 8), conv_block(8, 8)) 
        
        self.conv3 = conv_block(8, 16, pool=True)
        self.conv4 = conv_block(16, 32, pool=True) 
        self.res2 = nn.Sequential(conv_block(32, 32), conv_block(32, 32)) 
        self.conv5 = conv_block(32, 128, pool=True) 
        self.res3 = nn.Sequential(conv_block(128,128), conv_block(128, 128))  
        
        self.classifier = nn.Sequential(nn.MaxPool2d(2), # 1028 x 1 x 1
                                        nn.Flatten(), # 1028 
                                        nn.Linear(128, num_classes)) # 1028 -> 100
        
    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.conv5(out)
        out = self.res3(out) + out
        out = self.classifier(out)
        return out

In [None]:
class ResNet9_3(ImageClassificationBase):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv1 = conv_block(in_channels, 8)
        self.conv2 = conv_block(8, 8, pool=True) 
        self.res1 = nn.Sequential(conv_block(8, 8), conv_block(8, 8))  
        self.classifier = nn.Sequential(nn.MaxPool2d(2), # 1028 x 1 x 1
                                        nn.Flatten(), # 1028 
                                        nn.Linear(8, num_classes)) # 1028 -> 100
        
    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.classifier(out)
        return out

In [17]:

'''
img_dir = "/data/cifar10/"


def create_data(img_dir):
    dataset = dst.CIFAR100
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=512, shuffle=True, num_workers=4, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=512, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader
'''

def load_checkpoint(net, pth_file, exclude_fc=False):
    if exclude_fc:
        model_dict = net.state_dict()
        pretrain_dict = torch.load(pth_file)
        new_dict = {k: v for k, v in pretrain_dict.items() if 'fc' not in k}
        model_dict.update(new_dict)
        net.load_state_dict(model_dict, strict=False)
    else:
        pretrain_dict = torch.load(pth_file)
        net.load_state_dict(pretrain_dict, strict=False)
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 
              nn.BatchNorm2d(out_channels), 
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    #res = res.contiguous().view(res.size()[0], -1)  #解决报错
    for k in topk:
        
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class KD_loss(nn.Module):
    def __init__(self, T):
        super(KD_loss, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss


def test(net, test_loader):
    prec1_sum = 0
    prec5_sum = 0
    net.eval()
    for i, (img, target) in enumerate(test_loader, start=1):
        # print(f"batch: {i}")
        img = img.cuda()
        target = target.cuda()

        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
        # print(f"batch: {i}, acc1:{prec1}, acc5:{prec5}")
    print(f"Acc_student:{prec1_sum / (i + 1)/100}, Acc_teacher: {prec5_sum / (i + 1)/100}")


def train(net_s, net_t, train_loader, test_loader):
    # opt = Adam(filter(lambda p: p.requires_grad,net.parameters()), lr=0.0001)
    opt = Adam(net_s.parameters(), lr=0.0001)
    net_s.train()
    net_t.eval()
    for epoch in range(100):
        for step, batch in enumerate(train_loader):
            opt.zero_grad()
            image, target = batch
            image = image.cuda()
            target = target.cuda()
            out_s, out_t = net_s(image), net_t(image)
            loss_init = CrossEntropyLoss()(out_s, target)
            loss_kd = KD_loss(T=4)(out_s, out_t)
            loss = loss_init + loss_kd
            # prec1, prec5 = accuracy(predict, target, topk=(1, 5))
            # print(f"epoch:{epoch}, step:{step}, loss:{loss.item()}, acc1: {prec1},acc5:{prec5}")
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)

    torch.save(net_s.state_dict(), 'resnet9_cifar100_kd.pth')


def main():
    net_t = ResNet9(3, 100)
    net_s = ResNet9_2(3, 100)
    net_t = net_t.cuda()
    net_s = net_s.cuda()
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)
    # for name, value in net.named_parameters():
    #     if 'fc' not in name:
    #         value.requires_grad = False

  
    train(net_s, net_t, trainloader, testloader)


if __name__ == "__main__":
    current_time=time.time()
    main()
    time_train = time.time() - current_time
    print('Training time: {:.2f} s'.format(time_train))


epoch:0, loss_init: 1.574031114578247, loss_kd: 5.030104637145996, loss_all:6.604135513305664
Acc_student:0.46089285612106323, Acc_teacher: 0.7384821772575378
epoch:1, loss_init: 1.6289645433425903, loss_kd: 4.643418312072754, loss_all:6.272382736206055
Acc_student:0.45750004053115845, Acc_teacher: 0.7378571629524231
epoch:2, loss_init: 1.6621828079223633, loss_kd: 4.825350284576416, loss_all:6.487533092498779
Acc_student:0.4616071581840515, Acc_teacher: 0.7383928298950195
epoch:3, loss_init: 1.517826795578003, loss_kd: 4.611663818359375, loss_all:6.129490852355957
Acc_student:0.4641071557998657, Acc_teacher: 0.7397322058677673
epoch:4, loss_init: 1.6665574312210083, loss_kd: 4.572325229644775, loss_all:6.238882541656494
Acc_student:0.46223217248916626, Acc_teacher: 0.7382143139839172
epoch:5, loss_init: 1.5219961404800415, loss_kd: 4.449599742889404, loss_all:5.971595764160156
Acc_student:0.4613392949104309, Acc_teacher: 0.7377678751945496
epoch:6, loss_init: 1.5658502578735352, loss_

Acc_student:0.4794642925262451, Acc_teacher: 0.7466964721679688
epoch:52, loss_init: 1.2838106155395508, loss_kd: 3.6270580291748047, loss_all:4.9108686447143555
Acc_student:0.4853571653366089, Acc_teacher: 0.7477678656578064
epoch:53, loss_init: 1.5162208080291748, loss_kd: 4.031732559204102, loss_all:5.5479536056518555
Acc_student:0.4784821569919586, Acc_teacher: 0.7475000023841858
epoch:54, loss_init: 1.4713636636734009, loss_kd: 3.7979698181152344, loss_all:5.269333362579346
Acc_student:0.4843750298023224, Acc_teacher: 0.7495535612106323
epoch:55, loss_init: 1.3179320096969604, loss_kd: 3.624600410461426, loss_all:4.942532539367676
Acc_student:0.47830358147621155, Acc_teacher: 0.7461607456207275
epoch:56, loss_init: 1.2564806938171387, loss_kd: 3.6080408096313477, loss_all:4.864521503448486
Acc_student:0.4794642925262451, Acc_teacher: 0.7446428537368774
epoch:57, loss_init: 1.2513408660888672, loss_kd: 3.8414080142974854, loss_all:5.092748641967773
Acc_student:0.4806250333786011, A