In [None]:
import os
import json
from torch.utils.data import DataLoader
import torch
from kneed import KneeLocator
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import torch_pruning as pruning
from model import BasicBlock
import numpy as np

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def evaluteTop1(model, loader):
    """Compute, print and return the top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left there). Each batch is moved to
    the device holding the model's parameters instead of a hard-coded
    ``.cuda()``, so evaluation also works on CPU-only machines and on models
    placed on a non-default GPU.

    Args:
        model: a ``torch.nn.Module`` returning logits with classes along dim 1.
        loader: iterable of ``(inputs, labels)`` batches exposing ``.dataset``;
            ``len(loader.dataset)`` is used as the denominator.

    Returns:
        float: fraction of correctly classified samples (0.0 for an empty dataset).
    """
    model.eval()
    # Follow the model's placement; fall back to CPU for parameter-free modules.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
    accuracy = correct / total if total else 0.0
    print(accuracy)
    return accuracy

In [None]:
# ---- Data pipeline configuration -------------------------------------------
batch_size = 512
num_epoches = 20  # NOTE(review): not referenced by any visible cell

# Per-channel normalisation constants shared by the train and test pipelines.
# NOTE(review): keep these in sync with whatever the teacher checkpoint was
# trained with before changing them.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2023, 0.1994, 0.2010)

# Training: random 32x32 crop of a 4-pixel-padded image plus horizontal flip,
# then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

# Evaluation: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

_DATA_ROOT = '././data'


def _cifar10_split(train, transform):
    """Build the CIFAR-10 train (``train=True``) or test split under ``_DATA_ROOT``, with download=True."""
    return datasets.CIFAR10(root=_DATA_ROOT, train=train, transform=transform, download=True)


train_dataset = _cifar10_split(True, transform_train)
test_dataset = _cifar10_split(False, transform_test)

# Only shuffling differs between the two loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import math

# Teacher network for knowledge distillation (the original checkpoint).
# map_location puts it on the same device as the student and the data, and
# eval() makes BatchNorm/Dropout behave deterministically, so the teacher acts
# as a fixed reference instead of updating its running statistics while the
# student trains.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# Newer PyTorch releases default to weights_only=True, so loading a whole
# pickled model there may require weights_only=False -- confirm for your version.
teacher_model = torch.load("Resnet34_cifar10_original.pth", map_location=device)
teacher_model.eval()

def distillation(y, labels, teacher_scores, temp, alpha):
    """Hinton-style knowledge-distillation loss.

    Blends a temperature-softened KL term between the teacher and student
    distributions with ordinary cross-entropy on the hard labels::

        KL(teacher_T || student_T) * (temp**2 * 2 * alpha) + CE(y, labels) * (1 - alpha)

    The KL term is averaged over *every element* (batch x classes), which is the
    legacy ``reduction='mean'`` behaviour of ``nn.KLDivLoss``. It is computed
    explicitly because PyTorch warns that ``'mean'`` will switch to the
    ``'batchmean'`` semantics, which would silently rescale this loss.

    Args:
        y: student logits, classes along dim 1.
        labels: integer class targets accepted by ``F.cross_entropy``.
        teacher_scores: teacher logits, same shape as ``y`` (should carry no grad).
        temp: softmax temperature (> 0).
        alpha: weight of the distillation term; ``1 - alpha`` weights cross-entropy.

    Returns:
        Scalar loss tensor.
    """
    log_p_student = F.log_softmax(y / temp, dim=1)
    p_teacher = F.softmax(teacher_scores / temp, dim=1)
    # sum / numel == legacy elementwise 'mean' reduction, with no deprecation
    # warning and no dependence on the future default.
    kd_term = F.kl_div(log_p_student, p_teacher, reduction='sum') / log_p_student.numel()
    ce_term = F.cross_entropy(y, labels)
    return kd_term * (temp * temp * 2.0 * alpha) + ce_term * (1. - alpha)

def train_student_kd(model, device, train_loader, optimizer, epoch,
                     teacher=None, temp=5.0, alpha=0.7):
    """Run one epoch of knowledge-distillation training of ``model``.

    For each batch, the student output and the (frozen) teacher output are
    combined by ``distillation`` and the student is updated with ``optimizer``.
    A single-line progress bar is printed and terminated with a newline at the
    end of the epoch.

    Args:
        model: student network (trained in place).
        device: device the batches (and teacher) are moved to.
        train_loader: loader yielding ``(data, target)`` batches.
        optimizer: optimizer over the student's parameters.
        epoch: epoch number, used only in the progress message.
        teacher: teacher network; defaults to the module-level ``teacher_model``.
        temp: distillation temperature (default 5.0, as before).
        alpha: distillation weight (default 0.7, as before).
    """
    if teacher is None:
        teacher = teacher_model
    # The teacher is a fixed reference: eval mode keeps its BatchNorm layers from
    # using batch statistics / updating running stats during student training.
    teacher.to(device).eval()
    model.train()

    num_batches = len(train_loader)
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # No autograd graph is needed for the teacher (saves memory and time).
        with torch.no_grad():
            teacher_output = teacher(data)
        loss = distillation(output, target, teacher_output, temp=temp, alpha=alpha)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        # +1 so the bar reaches 100% on the last batch.
        progress = math.ceil((batch_idx + 1) / num_batches * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
    print()  # finish the carriage-return progress line
        
def student_kd_main(model, epochs=10, lr=1e-4, save_path="ResNet34_3_10_A.pth"):
    """Fine-tune ``model`` by distillation and checkpoint its best epoch.

    Trains for ``epochs`` epochs on the module-level ``train_loader`` and
    evaluates top-1 on ``test_loader`` after each one. Whenever accuracy
    improves, the whole model object is written to ``save_path`` with
    ``torch.save``. Because ``best_acc`` starts at -inf, the first epoch is
    always saved, so after at least one epoch ``save_path`` is guaranteed to
    hold a model with ``model``'s (possibly pruned) architecture. Callers
    reload from that file.

    Args:
        model: student network (moved to the training device and trained in place).
        epochs: number of epochs (default 10).
        lr: Adam learning rate (default 1e-4).
        save_path: checkpoint path overwritten on every improvement.

    Returns:
        float: best test accuracy observed (``-inf`` if ``epochs`` is 0).
    """
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # batches are sent to `device` in train_student_kd
    optimizer = torch.optim.Adam(model.parameters(), lr)
    best_acc = float("-inf")
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        acc = evaluteTop1(model, test_loader)
        if acc > best_acc:
            best_acc = acc
            print(best_acc)
            torch.save(model, save_path)
    return best_acc


In [None]:
def eval_pruning_strategy(model, dataloader_train, max_iter=100):
    """Adaptive BatchNorm: re-estimate BN statistics of a freshly pruned model.

    Runs at most ``max_iter`` forward passes over ``dataloader_train`` with the
    model in train mode and gradients disabled. No weights are updated; the only
    state that changes is train-mode buffers such as BatchNorm running mean/var.
    The model is left in train mode, so call ``model.eval()`` before evaluating
    (``evaluteTop1`` does this).

    The original loop checked its limit *after* the forward pass with ``>``, so it
    ran ``max_iter + 2`` batches. This version runs exactly ``max_iter``.

    Args:
        model: network whose BN statistics are recalibrated (in place).
        dataloader_train: loader yielding ``(inputs, labels)``; labels are ignored.
        max_iter: maximum number of batches to forward (default 100).
    """
    model.train()
    # Follow the model's placement instead of a hard-coded ``.cuda()``.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        for batch_idx, (data, _labels) in enumerate(dataloader_train):
            if batch_idx >= max_iter:
                break
            model(data.to(target_device))

def prune_conv(model, conv, pruned_prob):
    """Remove the ``pruned_prob`` fraction of ``conv``'s output filters with the smallest L1 norm.

    Filters are ranked by the L1 norm of their weights, and the lowest
    ``int(out_channels * pruned_prob)`` are pruned through a torch_pruning
    ``DependencyGraph``. The graph is used so that layers coupled to ``conv``
    are pruned consistently (see the torch_pruning docs). ``model`` is
    modified in place.

    Args:
        model: network containing ``conv``.
        conv: ``nn.Conv2d`` inside ``model`` whose output filters are pruned.
        pruned_prob: fraction of output filters to remove.

    Raises:
        ValueError: if the requested fraction would remove every filter.
    """
    weight = conv.weight.detach().cpu().numpy()
    out_channels = weight.shape[0]
    # L1 norm = sum of *absolute* values. A signed sum would rank filters with
    # large negative weights as "small" and prune them.
    l1_norm = np.abs(weight).sum(axis=(1, 2, 3))
    num_pruned = int(out_channels * pruned_prob)
    if num_pruned <= 0:
        return  # nothing to remove; skip the dependency-graph build
    if num_pruned >= out_channels:
        raise ValueError("cannot prune all %d filters (pruned_prob=%r)" % (out_channels, pruned_prob))
    prune_index = np.argsort(l1_norm)[:num_pruned].tolist()  # smallest-L1 filters first
    # NOTE(review): the example input is created on CPU while callers pass a model
    # loaded onto `device` and move it back to GPU afterwards. This appears to rely
    # on the (older) torch_pruning graph builder running the model on CPU --
    # confirm before changing the device here.
    DG = pruning.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))
    plan = DG.get_pruning_plan(conv, pruning.prune_conv, prune_index)
    plan.exec()

# Working checkpoint: every pruning stage reloads from it, and student_kd_main
# overwrites it with the best distilled model.
_PRUNE_CKPT = "ResNet34_3_10_A.pth"


def _find_basic_block(model, n):
    """Return the ``n``-th (0-based) ``BasicBlock`` in ``model.modules()`` order."""
    blocks = [m for m in model.modules() if isinstance(m, BasicBlock)]
    if not 0 <= n < len(blocks):
        raise IndexError("BasicBlock index %d out of range (model has %d)" % (n, len(blocks)))
    return blocks[n]


def _sweep_prune_rates(n, conv_name, steps):
    """Prune ``conv_name`` of block ``n`` at rates 0, 1/steps, ... on fresh copies and record top-1."""
    cut_rate, acc_rate = [], []
    for step in range(steps):
        rate = step / steps
        cut_rate.append(rate)
        candidate = torch.load(_PRUNE_CKPT, map_location=device)
        prune_conv(candidate, getattr(_find_basic_block(candidate, n), conv_name), rate)
        candidate = candidate.to(device)
        eval_pruning_strategy(candidate, train_loader)
        acc_rate.append(evaluteTop1(candidate, test_loader))
    return cut_rate, acc_rate


def _choose_prune_rate(cut_rate, acc_rate, tolerance):
    """Return ``(rate, acc)``: the largest rate if accuracy barely drops, else the curve's knee."""
    if acc_rate[0] - acc_rate[-1] <= tolerance:
        return cut_rate[-1], acc_rate[-1]
    kneedle_con_dec = KneeLocator(cut_rate,
                                  acc_rate,
                                  curve='concave',
                                  direction='decreasing',
                                  online=True)
    if kneedle_con_dec.knee is None:
        # No knee found: fall back to no pruning instead of crashing in prune_conv.
        return cut_rate[0], acc_rate[0]
    return kneedle_con_dec.knee, kneedle_con_dec.knee_y


def _plot_sweep(cut_rate, acc_rate, chosen_rate, chosen_acc):
    """Plot accuracy vs. pruning rate and mark the chosen rate."""
    plt.plot(cut_rate, acc_rate, 'g--')
    plt.scatter(x=chosen_rate, y=chosen_acc, c='b', s=200, marker='^', alpha=1)
    plt.title('concave+decreasing')
    plt.xlabel('pruning rate')
    plt.ylabel('top-1 accuracy')
    plt.show()


def _prune_stage(n, conv_name, steps, tolerance):
    """Sweep, choose a rate, prune ``conv_name`` of block ``n`` for real, then distill."""
    cut_rate, acc_rate = _sweep_prune_rates(n, conv_name, steps)
    rate, acc = _choose_prune_rate(cut_rate, acc_rate, tolerance)
    print("Pruning rate is: ", rate)
    _plot_sweep(cut_rate, acc_rate, rate, acc)

    pruned = torch.load(_PRUNE_CKPT, map_location=device)
    prune_conv(pruned, getattr(_find_basic_block(pruned, n), conv_name), rate)
    pruned = pruned.to(device)
    eval_pruning_strategy(pruned, train_loader)
    evaluteTop1(pruned, test_loader)
    student_kd_main(pruned)  # saves its best epoch back to the working checkpoint


def prune_BasicBlock(model, n, steps=20, tolerance=0.05):
    """Prune both convolutions of the ``n``-th BasicBlock, distilling after each.

    For ``conv1`` and then ``conv2`` of block ``n`` (0-based, ``model.modules()``
    order):

    1. Sweep pruning rates ``0, 1/steps, ..., (steps-1)/steps`` on fresh copies
       of the working checkpoint, recalibrating BatchNorm and recording top-1.
    2. Choose the largest swept rate if accuracy drops by at most ``tolerance``
       over the sweep. Otherwise choose the knee of the accuracy curve, or no
       pruning if no knee is found.
    3. Prune a fresh copy at that rate, recalibrate BN and distill it with
       ``student_kd_main``, which overwrites the working checkpoint.

    Finally the working checkpoint is reloaded and its accuracy (in %) printed.

    Args:
        model: unused; kept for call compatibility. Every stage reloads the
            working checkpoint file instead.
        n: 0-based index of the BasicBlock to prune.
        steps: number of pruning rates swept (default 20 -> 0.00 ... 0.95).
        tolerance: maximum accuracy drop over the sweep for which the largest
            rate is taken outright (default 0.05).

    Raises:
        IndexError: if the model has no BasicBlock with index ``n``.
    """
    for conv_name in ("conv1", "conv2"):
        _prune_stage(n, conv_name, steps, tolerance)
    student_simple_model = torch.load(_PRUNE_CKPT, map_location=device)
    print(evaluteTop1(student_simple_model, test_loader) * 100)
    
    
    
        

In [None]:
# Load the original (unpruned) network and reset the working checkpoint that
# the pruning / distillation stages read from and overwrite.
weights_path = "Resnet34_cifar10_original.pth"
if not os.path.exists(weights_path):
    raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
model = torch.load(weights_path, map_location=device)
torch.save(model, "ResNet34_3_10_A.pth")

# Prune BasicBlocks FIRST_BLOCK..LAST_BLOCK (0-based, model.modules() order) one
# at a time, capped at the number of BasicBlocks the model actually has.
# prune_BasicBlock reloads the working checkpoint itself; `model` is passed
# only to satisfy its signature.
FIRST_BLOCK, LAST_BLOCK = 7, 15
num_blocks = sum(isinstance(m, BasicBlock) for m in model.modules())
for n in range(FIRST_BLOCK, min(LAST_BLOCK, num_blocks - 1) + 1):
    prune_BasicBlock(model, n)