## Pruning TinyYoloV2

In [3]:
import numpy as np
import torch
from torch import nn
from torch.nn.utils import prune
import torch.nn.functional as F

from models.q_tinyyolov2 import QTinyYoloV2
from models.tinyyolov2 import TinyYoloV2

from utils.ap import ap
from utils.dataloader import VOCDataLoaderPerson
from utils.yolo import test

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training split: large shuffled batches. Test split: one image per batch,
# as consumed by test() in the pruning cell below.
loader = VOCDataLoaderPerson(batch_size=128, shuffle=True, train=True)
loader_test = VOCDataLoaderPerson(batch_size=1, train=False)

In [4]:
def check_sparsity(net):
    """Print per-layer and global sparsity of the model's Conv2d weights.

    Sparsity is the percentage of weight entries that are exactly zero.

    Layers are selected with ``isinstance(module, nn.Conv2d)``, the same test
    the pruning cells use to decide what to prune. Matching ``"conv"`` in
    state-dict keys would depend on how the layers happen to be named.
    ``module.weight`` is read directly, so the reported numbers are also
    correct while a ``torch.nn.utils.prune`` re-parametrization is still
    attached. In that state the state dict holds only ``weight_orig`` and
    ``weight_mask``, whereas ``module.weight`` is the masked tensor.

    Args:
        net: ``nn.Module`` to inspect.

    Returns:
        Global sparsity over all Conv2d weights as a fraction in ``[0, 1]``,
        or ``None`` if ``net`` contains no Conv2d weights.
    """
    total_zeros = 0
    total_elems = 0
    for name, module in net.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        weight = module.weight.detach()
        zeros = int((weight == 0).sum())
        elems = weight.numel()
        # Keep the "<layer>.weight" label format of the state-dict keys.
        label = f"{name}.weight" if name else "weight"
        layer_pct = 100.0 * zeros / elems if elems else 0.0
        print("Sparsity in {:s}: {:.2f}%".format(label, layer_pct))
        total_zeros += zeros
        total_elems += elems

    # Avoid ZeroDivisionError when the model has no Conv2d weights.
    if total_elems == 0:
        print("Global sparsity: n/a (no Conv2d weights found)")
        return None

    global_sparsity = total_zeros / total_elems
    print("Global sparsity: {:.2f}%".format(100.0 * global_sparsity))
    return global_sparsity

#### Unstructured L1 pruning

In [5]:
# Sweep the fraction of weights removed by L1 unstructured pruning in every
# Conv2d layer, evaluate each pruned model, and keep the one with the best AP.
# linspace pins both endpoints exactly: a float-step arange can gain or lose
# its last element through rounding. Adjust the bounds / num to widen the sweep.
ratios = np.linspace(0.1, 0.2, num=2)
final_net = None
final_ratio = 0.1
max_ap = 0

# Load the fine-tuned checkpoint once. load_state_dict copies these tensors
# into each fresh model, so pruning a model never modifies `sd`.
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts;
# the model parameters live on the CPU either way.
sd = torch.load("models/configs/voc_finetuned.pt", map_location="cpu")

for ratio in ratios:
    ratio = float(ratio)
    net = TinyYoloV2(num_classes=1)
    net.load_state_dict(sd)
    print(f"-------------Prune {ratio * 100:.0f}% of connections-------------")

    # Prune `ratio` of the connections (smallest |w|) in all 2D-conv layers.
    for module in net.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=ratio)

    print(dict(net.named_buffers()).keys())  # to verify that all masks exist

    # Make pruning permanent: fold the masks into the weights and drop them.
    for module in net.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.remove(module, "weight")
    print(dict(net.named_buffers()).keys())  # to verify the final modules

    check_sparsity(net)

    # The net has BatchNorm layers; evaluate with running statistics rather
    # than per-batch ones (idempotent if test() already switches to eval mode).
    net.eval()
    test_precision, test_recall = test(net, loader_test)
    avg_precision = ap(test_precision, test_recall)
    print("AP:", avg_precision)

    # Keep the best net so far.
    if avg_precision > max_ap:
        print("Saved net!")
        max_ap = avg_precision
        final_ratio = ratio
        final_net = net

# Explicit None check: nn.Module truthiness is not a reliable "exists" test
# (containers such as nn.Sequential define __len__).
if final_net is not None:
    check_sparsity(final_net)



-------------Prune 10.0% of connections-------------
dict_keys(['anchors', 'conv1.weight_mask', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'conv2.weight_mask', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'conv3.weight_mask', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'conv4.weight_mask', 'bn4.running_mean', 'bn4.running_var', 'bn4.num_batches_tracked', 'conv5.weight_mask', 'bn5.running_mean', 'bn5.running_var', 'bn5.num_batches_tracked', 'conv6.weight_mask', 'bn6.running_mean', 'bn6.running_var', 'bn6.num_batches_tracked', 'conv7.weight_mask', 'bn7.running_mean', 'bn7.running_var', 'bn7.num_batches_tracked', 'conv8.weight_mask', 'bn8.running_mean', 'bn8.running_var', 'bn8.num_batches_tracked', 'conv9.weight_mask'])
dict_keys(['anchors', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batch

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 21%|██        | 465/2232 [00:54<03:28,  8.47it/s]


KeyboardInterrupt: 

#### Random 