In [None]:
import torch
import torch.nn as nn
import torch_pruning as pruning
import torch.nn.functional as F
import math
from kneed import KneeLocator
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import numpy as np
import random as rd
import matplotlib.pyplot as plt

In [None]:
# Mini-batch size shared by the training and the test loader.
batch_size = 512
# NOTE(review): not referenced anywhere else in the visible notebook.
num_epoches = 10

# CIFAR-10 train/test splits stored under ./data (download=True lets torchvision
# fetch them). ToTensor() is the only transform: no augmentation or normalisation.
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Only the training loader shuffles; both use two worker processes.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
class VGG16(nn.Module):
    """VGG-16 with batch normalisation and a three-layer fully connected head.

    ``features`` holds thirteen 3x3 convolutions, each followed by BatchNorm
    and an in-place ReLU, grouped into five stages that each end with a 2x2
    max-pool. ``classifier`` expects 512 flattened features, which is what
    the five pooling steps leave for a 32x32 input.

    Args:
        num_classes: Width of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Output channels of every convolution in order; "M" marks a max-pool.
        layout = (64, 64, "M",
                  128, 128, "M",
                  256, 256, 256, "M",
                  512, 512, 512, "M",
                  512, 512, 512, "M")

        feature_layers = []
        in_channels = 3
        for entry in layout:
            if entry == "M":
                feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            feature_layers.append(nn.Conv2d(in_channels, entry, kernel_size=3, padding=1))
            feature_layers.append(nn.BatchNorm2d(entry))
            feature_layers.append(nn.ReLU(True))
            in_channels = entry
        # 1x1 / stride-1 average pool: kept so the module indices inside
        # ``features`` stay exactly as before.
        feature_layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        self.features = nn.Sequential(*feature_layers)

        hidden_units = 4096
        self.classifier = nn.Sequential(
            nn.Linear(512, hidden_units),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension before the head.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

In [None]:
def evaluteTop1(model, loader):
    """Compute, print and return the top-1 accuracy of ``model`` on ``loader``.

    The model is switched to eval mode and left in that mode. Each batch is
    moved to the device that holds the model's parameters (instead of being
    sent to CUDA unconditionally), so the function also works for a model
    that lives on the CPU.

    Args:
        model: Module mapping a batch of inputs to per-class logits of shape
            ``(batch, classes)``.
        loader: Iterable of ``(inputs, labels)`` batches that also exposes a
            ``dataset`` attribute; ``len(loader.dataset)`` is the denominator.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    model.eval()

    # Follow the model's own device; a parameter-free model leaves the
    # batches wherever the loader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            if target_device is not None:
                x, y = x.to(target_device), y.to(target_device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()

    accuracy = correct / total
    print(accuracy)
    return accuracy

# Adaptive BN
def eval_pruning_strategy(model, dataloader_train, max_iter=100):
    """Re-estimate BatchNorm running statistics by forwarding training batches.

    The model is put in train mode (and left there) so that normalisation
    layers update their running statistics, while ``torch.no_grad()`` keeps
    the weights untouched. At most ``max_iter`` batches are forwarded; the
    previous implementation tested ``> max_iter`` after the forward pass and
    therefore processed up to ``max_iter + 2`` batches.

    Args:
        model: Network to recalibrate.
        dataloader_train: Iterable of ``(inputs, labels)`` batches; the labels
            are ignored.
        max_iter: Upper bound on the number of batches to forward.

    Returns:
        None. The model is modified in place.
    """
    model.train()

    # Send inputs to wherever the model lives rather than forcing CUDA.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    with torch.no_grad():
        for iter_in_epoch, (data1, _) in enumerate(dataloader_train):
            if iter_in_epoch >= max_iter:
                break
            if target_device is not None:
                data1 = data1.to(target_device)
            # Call the module (not .forward) so registered hooks still run.
            model(data1)

# Return the filter ranking of one convolution
def calculate_conv_weight_idx(model, layer_number):
    """Rank the filters of one ``nn.Conv2d`` layer by L1 norm, smallest first.

    Convolutions are numbered in ``model.modules()`` traversal order. The
    norm of a filter is the sum of the absolute values of all its weights.
    It is computed with one vectorised reduction rather than the builtin
    ``sum`` over individual elements, which issued one tensor operation per
    weight. Summation order therefore differs from the old loop, so filters
    whose norms are equal up to float rounding may swap places.

    Args:
        model: Module containing the convolution.
        layer_number: Position of the convolution among the model's
            ``nn.Conv2d`` modules (regular Python list indexing).

    Returns:
        List of output-channel indices ordered by ascending L1 norm; ties
        keep their original channel order.

    Raises:
        IndexError: If ``layer_number`` is out of range.
    """
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    conv_weight = conv_layers[layer_number].weight

    # One |w| sum per output filter: (out, in, kh, kw) -> (out, in*kh*kw) -> (out,)
    filter_norms = conv_weight.detach().abs().flatten(1).sum(dim=1).cpu().tolist()

    # Python's sort is stable, so equal norms stay in channel order.
    return sorted(range(len(filter_norms)), key=filter_norms.__getitem__)

In [None]:
# Knowledge distillation: load the network whose outputs serve as soft targets.
# NOTE(review): the file is presumably a pickled full model (it is called like
# a module in train_student_kd), so its class must be defined before this cell
# runs and the file must come from a trusted source. No map_location is given,
# so the tensors are restored to the device they were saved from.
teacher_model = torch.load("VGGNet16_cifar10_original.pth")

def distillation(y, labels, teacher_scores, temp, alpha):
    """Blend a temperature-softened KL term with the hard-label cross-entropy.

    Args:
        y: Student logits, classes along dim 1.
        labels: Ground-truth class indices for ``F.cross_entropy``.
        teacher_scores: Teacher logits with the same layout as ``y``.
        temp: Temperature dividing both sets of logits before the softmax.
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``.

    Returns:
        Scalar loss ``KL * (temp * temp * 2.0 * alpha) + CE * (1 - alpha)``.
    """
    student_log_probs = F.log_softmax(y / temp, dim=1)
    teacher_probs = F.softmax(teacher_scores / temp, dim=1)

    # NOTE(review): "mean" (the nn.KLDivLoss default) averages over every
    # element, not per sample; PyTorch recommends "batchmean" for a true KL
    # value. Kept as is because switching would rescale the loss.
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction="mean")
    hard_term = F.cross_entropy(y, labels)

    return soft_term * (temp * temp * 2.0 * alpha) + hard_term * (1. - alpha)

def train_student_kd(model, device, train_loader, optimizer, epoch):
    """Train the student for one epoch against the global ``teacher_model``.

    Each batch is scored by both networks and the student is updated with
    the ``distillation`` loss (``temp=5.0``, ``alpha=0.7``). The teacher is
    put in eval mode and run under ``torch.no_grad()``: it only supplies
    targets, so it should not apply dropout, update its own BatchNorm
    statistics or build an autograd graph.

    Args:
        model: Student network; switched to train mode.
        device: Device the input batches are moved to.
        train_loader: Sized iterable of ``(data, target)`` batches that
            exposes a ``dataset`` attribute (used for the progress line).
        optimizer: Optimizer bound to the student's parameters.
        epoch: Epoch number, used only in the progress line.

    Returns:
        None. A single-line progress bar is written to stdout.
    """
    model.train()
    # Side effect: the shared teacher stays in eval mode afterwards.
    teacher_model.eval()

    trained_samples = 0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        with torch.no_grad():
            teacher_output = teacher_model(data)
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        # Count the batch that just finished, so the bar ends at exactly 100%.
        progress = math.ceil((batch_idx + 1) / num_batches * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

def student_kd_main():
    """Retrain the pruned checkpoint with knowledge distillation.

    Loads ``model_pruned.pth``, runs five distillation epochs over
    ``trainloader`` with Adam (learning rate 1e-4), evaluates on
    ``testloader`` after each epoch, and writes the result back to
    ``model_pruned.pth``.

    The student and the global ``teacher_model`` are both placed on the
    selected device. Previously only the data was moved there, which left
    the networks on whatever device their checkpoints were saved from.

    Returns:
        Tuple ``(model, student_history)``: the retrained model and the list
        of per-epoch top-1 accuracies.
    """
    epochs = 5
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): this unpickles a full model object; the file must be
    # trusted and the model class must be importable in this session.
    model = torch.load("model_pruned.pth", map_location=device)
    model = model.to(device)
    # nn.Module.to moves the shared teacher in place.
    teacher_model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), 1e-4)
    student_history = []
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, trainloader, optimizer, epoch)
        acc = evaluteTop1(model, testloader)
        student_history.append(acc)
    torch.save(model, "model_pruned.pth")
    return model, student_history

# Return the convolution at a given position
def chooselayer(model, layer):
    """Return the ``layer``-th ``nn.Conv2d`` module of ``model``.

    Convolutions are numbered in ``model.modules()`` traversal order and
    ``layer`` is used as a regular list index, so an out-of-range value
    raises ``IndexError``.
    """
    conv_layers = list(filter(lambda m: isinstance(m, nn.Conv2d), model.modules()))
    return conv_layers[layer]

In [None]:
# Layer-by-layer pruning driver.
#
# For every convolution: (1) sweep pruning rates 0%..95% on a scratch copy of
# the current checkpoint and record the accuracy after adaptive BN, (2) pick a
# rate from that curve, (3) prune the working model at that rate, and
# (4) retrain it with knowledge distillation.
#
# NOTE(review): the rate is selected from accuracy measured on the test split
# (testloader); a held-out validation split would avoid selection bias.

# Defined here because no other cell creates a module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the unpruned network and seed the working checkpoint with it.
weights_path = "VGGNet16_cifar10_original.pth"
model = torch.load(weights_path, map_location=device)
torch.save(model, "model_pruned.pth")

for layer in range(13):
    rate_conv = []
    acc_cov = []
    # Filter ranking (ascending L1 norm) of this layer in the working model.
    idx_now = calculate_conv_weight_idx(model, layer)
    for step in range(20):
        print("Layer is:",layer,", epoch is :",step)
        # Load the scratch copy first, then build the graph on that very
        # object; building it beforehand referenced a model that did not
        # exist yet (or the previous, already pruned, scratch copy).
        weights_path = "model_pruned.pth"
        model_temp = torch.load(weights_path, map_location=device)
        DG = pruning.DependencyGraph()
        # The example input must sit on the same device as the model.
        DG.build_dependency(model_temp, example_inputs=torch.randn(1,3,32,32).to(device))
        layer_to_calculate = chooselayer(model_temp, layer)
        ch = pruning.utils.count_prunable_channels(layer_to_calculate)
        prune_acc = step*5/100
        idx_number = int(prune_acc*ch)
        weight_idx = idx_now[0:idx_number]
        if weight_idx:  # step 0 is the unpruned baseline: nothing to remove
            pruning_plan = DG.get_pruning_plan(layer_to_calculate, pruning.prune_conv, idxs=weight_idx)
            pruning_plan.exec()
        model_temp = model_temp.to(device)
        # Adaptive BN: refresh the running statistics before measuring.
        eval_pruning_strategy(model_temp, trainloader)
        temp_rate = step / 20
        rate_conv.append(temp_rate*100)
        acc_cov.append(evaluteTop1(model_temp, testloader)*100)

    # Choose the pruning rate. knee_x / knee_y hold the point marked on the
    # plot and are set in every branch, so the plot never reads an undefined
    # or stale KneeLocator.
    if acc_cov[0]-acc_cov[-1] <= 0.05:
        # Accuracy is practically flat over the sweep: take the largest rate.
        cut_rate = rate_conv[-1]/100
        knee_x, knee_y = rate_conv[-1], acc_cov[-1]
    else:
        kneedle_con_dec = KneeLocator(rate_conv,
                                      acc_cov,
                                      curve='concave',
                                      direction='decreasing',
                                      online=True)
        if kneedle_con_dec.knee is None:
            # No knee detected: stay conservative and leave the layer unpruned.
            print("No knee found for layer", layer, "- layer left unpruned")
            cut_rate = 0.0
            knee_x, knee_y = None, None
        else:
            cut_rate = kneedle_con_dec.knee/100
            knee_x, knee_y = kneedle_con_dec.knee, kneedle_con_dec.knee_y
    plt.plot(rate_conv, acc_cov, 'k--')
    if knee_x is not None:
        plt.scatter(x=knee_x, y=knee_y, c='b', s=200, marker='^', alpha=1)
    plt.title('concave+decreasing')
    plt.show()

    # Real pruning of the working model at the selected rate.
    DG = pruning.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1,3,32,32).to(device))
    layer_to_calculate = chooselayer(model, layer)
    ch = pruning.utils.count_prunable_channels(layer_to_calculate)
    idx_number = int(cut_rate*ch)
    weight_idx = idx_now[0:idx_number]
    if weight_idx:
        pruning_plan = DG.get_pruning_plan(layer_to_calculate, pruning.prune_conv, idxs=weight_idx)
        pruning_plan.exec()
    model = model.to(device)
    eval_pruning_strategy(model, trainloader)
    print("Top-1 accuracy after prune is: ",evaluteTop1(model, testloader)*100)

    torch.save(model, "model_pruned.pth")

    # Distillation retraining reads and rewrites model_pruned.pth.
    student_simple_model, student_simple_history = student_kd_main()

    print("Top-1 accuracy after retraning is:",evaluteTop1(student_simple_model, testloader)*100)

    # Continue from the retrained weights. Without this the next layer would
    # prune the stale, un-retrained `model` and overwrite the retrained
    # checkpoint, discarding the distillation work.
    model = student_simple_model


In [None]:
# A 120-epoch overall fine-tuning run is still needed.