In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import tqdm
import copy
from tinyyolov2NoBN import TinyYoloV2NoBN
from pruned_tinyyolov2NoBN import PrunedTinyYoloV2NoBN 
from typing import Dict, List

from utils.ap import precision_recall_levels, ap, display_roc
from utils.yolo import nms, filter_boxes
from utils.loss import YoloLoss

from utils.dataloader import VOCDataLoaderPerson

# Device selection. The automatic CUDA/CPU choice on the next line is disabled,
# so every tensor and the model are pinned to the CPU.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [6]:
# Training data: batches of 32, shuffled.
loader = VOCDataLoaderPerson(
    train=True,
    batch_size=32,
    shuffle=True,
)

# Evaluation data: one sample per batch.
# NOTE(review): shuffle=True presumably means a partial pass over this loader
# sees a different subset each time -- confirm against VOCDataLoaderPerson.
loader_test = VOCDataLoaderPerson(
    train=False,
    batch_size=1,
    shuffle=True,
)

In [7]:
def l1_structured_pruning(state_dict: Dict, prune_ratio: float) -> Dict:
    """Zero the output channels with the smallest L1 norm in each prunable conv layer.

    The function works on a deep copy, so the caller's ``state_dict`` is left
    untouched. An entry is prunable when its key contains both ``conv`` and
    ``weight`` but neither ``conv1`` nor ``conv9``. For every such entry,
    ``int(out_channels * prune_ratio)`` output channels (dimension 0) with the
    lowest sum of absolute weights are set to zero, together with the matching
    elements of the sibling ``bias`` entry when one exists.

    Args:
        state_dict: Mapping from parameter name to tensor. Prunable weights are
            indexed as 4-D tensors with the output channel first.
        prune_ratio: Fraction of output channels to zero per layer. Values that
            yield fewer than one channel for a layer leave that layer as is.

    Returns:
        A new state dict with the selected channels zeroed; tensor shapes are
        unchanged.

    Raises:
        ValueError: If ``prune_ratio`` is greater than 1.
    """
    if prune_ratio > 1:
        raise ValueError(f"prune_ratio must not exceed 1, got {prune_ratio}")

    pruned = copy.deepcopy(state_dict)

    for name in list(pruned.keys()):
        # Only prune conv layers, excluding conv1 and conv9
        is_conv_weight = 'conv' in name and 'weight' in name
        if not is_conv_weight or 'conv1' in name or 'conv9' in name:
            continue

        weight = pruned[name]
        num_channels_to_prune = int(weight.shape[0] * prune_ratio)
        if num_channels_to_prune < 1:
            continue

        # Rank output channels by L1 norm and pick the weakest ones.
        l1 = weight.abs().sum(dim=(1, 2, 3))
        prune_idx = torch.argsort(l1)[:num_channels_to_prune]

        # In-place writes land in the deep-copied tensors held by `pruned`.
        weight[prune_idx] = 0

        # A conv layer without a bias entry is valid; skip it instead of failing.
        bias = pruned.get(name.replace('weight', 'bias'))
        if bias is not None:
            bias[prune_idx] = 0

    return pruned

In [8]:
def densify_state_dict(state_dict: Dict) -> Dict:
    """Physically remove all-zero output channels from the conv1..conv9 chain.

    Layers ``conv1`` to ``conv9`` are visited in order. For each layer the
    input channels (dimension 1) are first restricted to the output channels
    kept in the previous layer; then every output channel (dimension 0) whose
    weights are all zero is dropped from both ``weight`` and ``bias``. The
    input dict is deep-copied and not modified; entries other than
    ``conv{i}.weight`` / ``conv{i}.bias`` are returned unchanged.

    Args:
        state_dict: Mapping that contains ``conv{i}.weight`` (4-D, output
            channel first) and ``conv{i}.bias`` for ``i`` in 1..9.

    Returns:
        A new state dict whose conv tensors only hold the surviving channels.

    Raises:
        KeyError: If one of the expected ``conv{i}`` entries is missing.
    """
    dense = copy.deepcopy(state_dict)

    # Indices of the output channels kept in the previous layer; they select
    # the input channels of the current layer. None for the first layer.
    kept = None

    for layer in range(1, 10):
        weight_key = f"conv{layer}.weight"
        bias_key = f"conv{layer}.bias"

        weights = dense[weight_key]
        biases = dense[bias_key]
        if kept is not None:
            weights = weights[:, kept, :, :]

        # Take absolute values before summing: a live channel whose signed
        # weights happen to cancel out must not be mistaken for a pruned one.
        # NOTE(review): conv9 is treated like every other layer, so an all-zero
        # output channel there would be dropped as well -- confirm this is
        # intended for the final layer.
        kept = torch.nonzero(weights.abs().sum(dim=[1, 2, 3]), as_tuple=True)[0]

        dense[weight_key] = weights[kept, :, :, :]
        dense[bias_key] = biases[kept]

    return dense

In [5]:
# Optimizer hyper-parameters and the name fragments of layers to keep fixed.
lr = 0.001
weight_decay = 0.005
frozedLayers = []

# Build the single-class model and initialise it from the checkpoint file.
net = PrunedTinyYoloV2NoBN(num_classes=1)
state_dict = torch.load("fusedyolov2_0717.pt", map_location=device)
# strict=False: keys that do not line up between checkpoint and model are tolerated.
net.load_state_dict(state_dict, strict=False)
net.to(device)

# The loss is configured with the anchors stored on the model.
criterion = YoloLoss(anchors=net.anchors)

# Turn off gradients for every parameter whose name contains a frozen fragment.
for key, param in net.named_parameters():
    if any(fragment in key for fragment in frozedLayers):
        param.requires_grad = False

# Adam only receives the parameters that still require gradients.
optimizer = optim.Adam(
    [p for p in net.parameters() if p.requires_grad],
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
NUM_TEST_SAMPLES = 30
NUM_EPOCHS = 20
test_AP = []
num_iterations = 20
ratio = 0.1

for iteration in range(num_iterations):
    print(f"-----Pruning iteration {iteration+1}/{num_iterations}-----")
    # Zero the weakest channels on a CPU copy of the weights, then physically
    # drop them so the model actually shrinks.
    state_dict = l1_structured_pruning(net.cpu().state_dict(), ratio)
    state_dict = densify_state_dict(state_dict)
    net.load_state_dict(state_dict)
    net.to(device)

    # The densified tensors have new shapes, so the parameters (and Adam moment
    # buffers) tracked by the previous optimizer no longer match the model.
    # Build a fresh optimizer over the current parameters before fine-tuning.
    # NOTE(review): relies on PrunedTinyYoloV2NoBN.load_state_dict accepting
    # resized tensors -- its implementation is not visible here.
    optimizer = torch.optim.Adam(
        [p for p in net.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=weight_decay,
    )
    print("Pruning done.")

    for epoch in range(NUM_EPOCHS):
        # ---- fine-tune for one epoch ----
        net.train()
        print("Training started.")
        for images, target in tqdm.tqdm(loader, total=len(loader)):
            images, target = images.to(device), target.to(device)
            optimizer.zero_grad()
            #Yolo head is implemented in the loss for training, therefore yolo=False
            output = net(images, yolo=False)
            loss, _ = criterion(output, target)
            loss.backward()
            optimizer.step()

        # ---- evaluate on the first NUM_TEST_SAMPLES test batches ----
        test_precision = []
        test_recall = []
        net.eval()
        print("Validation started.")
        with torch.no_grad():
            for idx, (images, target) in tqdm.tqdm(enumerate(loader_test), total=NUM_TEST_SAMPLES):
                # Stop before processing so exactly NUM_TEST_SAMPLES samples are scored.
                if idx >= NUM_TEST_SAMPLES:
                    break
                images, target = images.to(device), target.to(device)
                output = net(images, yolo=True)

                #The right threshold values can be adjusted for the target application
                output = filter_boxes(output, 0.0)
                output = nms(output, 0.5)

                # batch_size is 1 for loader_test, hence element 0 of each batch.
                precision, recall = precision_recall_levels(target[0], output[0])
                test_precision.append(precision)
                test_recall.append(recall)

        #Calculation of average precision with collected samples
        test_AP.append(ap(test_precision, test_recall))
        print('average precision', test_AP)

        #plot ROC
        display_roc(test_precision, test_recall)
    

-----Pruning iteration 1/20-----
Pruning done.
Training started.


  0%|          | 0/67 [00:00<?, ?it/s]

In [None]:
# Save the weights currently held by the model. The module-level `state_dict`
# variable is the snapshot taken right after the last pruning step, i.e. before
# the fine-tuning epochs that followed, so it must not be used here.
torch.save(net.state_dict(), 'pruned_tinyyolov2NoBN.pt')