In [None]:
import gc
import os
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch_pruning as tp
import torchvision
import torchvision.transforms as transforms
import wandb

import common_utils as utils
# Show which device common_utils selected; models and inputs below are moved there via .to(utils.device).
print(utils.device)

In [None]:
# Base-model checkpoint that every pruning experiment below starts from.
# Change CHECKPOINT_INDEX if you pick a different checkpoint!
CHECKPOINT_INDEX = 6
checkpoint_path = f"base_model/checkpoint_{CHECKPOINT_INDEX}.pth"

In [None]:
BATCH_SIZE = 16
NUM_WORKERS = 8

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them
# to [-1, 1]. One value per statistic means a single-channel (grayscale) input.
# The statistics are written as 1-tuples: ``(0.5)`` is just a parenthesised float.
transform = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# torchvision's FER2013 dataset has no download option: the data files must
# already exist under root (presumably ./fer2013/ -- TODO confirm on a fresh machine).
trainset = torchvision.datasets.FER2013(root='./', split="train",
                                        transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.FER2013(root='./', split="test",
                                       transform=transform)
# Test order is fixed (shuffle=False) so evaluations are comparable across runs.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

In [None]:
def compute_sparsity(model):
    """Print, and return, the percentage of exactly-zero entries in each parameter.

    For every ``(name, param)`` in ``model.named_parameters()`` this prints the
    share of elements equal to 0 (two decimals) followed by the element count.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        dict mapping parameter name -> sparsity in percent (0-100). A parameter
        with no elements is reported as 0.0 instead of raising ZeroDivisionError.
    """
    sparsity = {}
    for name, param in model.named_parameters():
        numel = param.nelement()
        num_zeros = int(torch.sum(param == 0))
        # Guard the empty-parameter case; the original divided by zero here.
        percent = 100.0 * num_zeros / numel if numel else 0.0
        sparsity[name] = percent
        print(
            "Sparsity in {0}: {1:.2f}%".format(name, percent),
            "\nParam Element Count: {0}\n".format(numel),
        )
    return sparsity

In [None]:
# Shared state for the per-layer sensitivity sweep.
reserved_bytes = torch.cuda.memory_reserved()
print("CUDA Reserved Memory: ", reserved_bytes)
# Module name -> list of (loss, accuracy) pairs, one per tested sparsity level.
layer_sensitivity = defaultdict(list)
# Dummy input that torch_pruning traces to build its dependency graph and count
# MACs/params. Presumably NCHW: one grayscale 96x96 image, matching the transform.
input_shape = (1, 1, 96, 96)
example_inputs = torch.randn(*input_shape).to(utils.device)
# Score channel groups by the L2 norm (p=2) of their grouped weights.
imp = tp.importance.GroupMagnitudeImportance(p=2)

def get_prunable_modules(model, num_classes=7):
    """Return the direct children of ``model`` that are worth pruning.

    A child (from ``model.named_children()``) is skipped when:
      * it has no parameters at all (e.g. ``nn.MaxPool2d`` or an activation),
        so there is nothing to prune and a sweep over it would be a no-op; or
      * it is the final classifier, recognised as an ``nn.Linear`` whose
        ``out_features == num_classes``; its outputs are the class logits.

    Args:
        model: module whose top-level children are inspected.
        num_classes: output width of the classifier to exclude (7 for FER2013).

    Returns:
        List of ``(name, module)`` tuples in registration order.
    """
    prunable_modules = []
    for name, module in model.named_children():
        has_params = next(module.parameters(), None) is not None
        is_classifier = isinstance(module, nn.Linear) and module.out_features == num_classes
        if has_params and not is_classifier:
            prunable_modules.append((name, module))
    return prunable_modules
# Count candidate modules once on a throwaway instance. Each trial below builds a
# fresh model, so indices into get_prunable_modules(...) refer to the same module.
pruned_model = utils.BaseModel()
num_prunable_params = len(get_prunable_modules(pruned_model))  # number of prunable *modules*

# Read the checkpoint once instead of 9x per module. map_location="cpu" lets a
# GPU-saved checkpoint load on any machine; each trial model is moved to
# utils.device after load_state_dict copies these weights into it.
base_state_dict = torch.load(checkpoint_path, map_location="cpu")

for param_ind in range(num_prunable_params):
    for i in range(10, 100, 10):
        print("CUDA Reserved Memory: ", torch.cuda.memory_reserved())
        sparsity = i / 100
        pruned_model = utils.BaseModel()
        pruned_model.load_state_dict(base_state_dict)
        pruned_model.to(utils.device)
        module_name, module = get_prunable_modules(pruned_model)[param_ind]
        print("Pruning {0} with sparsity {1}".format(module_name, sparsity))
        # Never prune the outputs of the final 7-way classifier -- same rule as the
        # final pruning cell; matters if the classifier is nested inside `module`.
        ignored_layers = [
            m for m in pruned_model.modules()
            if isinstance(m, nn.Linear) and m.out_features == 7
        ]
        pruner = tp.pruner.BasePruner(  # BasePruner suffices when sparse training is not required.
            pruned_model,
            example_inputs,
            importance=imp,
            pruning_ratio=0.0,  # default ratio for every layer not named below
            pruning_ratio_dict={module: sparsity},  # prune only the module under test
            ignored_layers=ignored_layers,
            round_to=1,
        )
        base_macs, base_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
        pruner.step()
        macs, nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
        print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

        loss, acc = utils.test(pruned_model, testloader)
        layer_sensitivity[module_name].append((loss, acc))
        print("Post Pruning Loss: {0} Accuracy: {1}\n".format(loss, acc))
        # Drop every reference into this trial's model -- the pruner was built from
        # it and `module` / `ignored_layers` point at its submodules -- so gc and
        # empty_cache can actually release it before the next trial is allocated.
        del pruner, module, ignored_layers, pruned_model
        gc.collect()
        torch.cuda.empty_cache()

del base_state_dict
with open("sensitivity.pkl", "wb") as file:
    pickle.dump(layer_sensitivity, file)

In [None]:
# sensitivity.pkl is written by the sweep cell; only unpickle files you trust
# (pickle.load can execute arbitrary code).
with open("sensitivity.pkl", "rb") as file:
    layer_sensitivity = pickle.load(file)
fig, axs = plt.subplots(1, 2, figsize=(30, 12))

for key, values in layer_sensitivity.items():
    losses = [loss for loss, _ in values]
    acc = [accuracy for _, accuracy in values]
    # values[k] was measured at sparsity (k + 1) * 10 % (the sweep uses range(10, 100, 10)).
    # Deriving x from len(values) keeps a partially finished sweep plottable.
    ratios = [10 * (k + 1) for k in range(len(values))]
    axs[0].plot(ratios, losses, label=key)
    axs[1].plot(ratios, acc, label=key)

axs[0].set_title("Loss Vs Pruning Ratio")
axs[0].set_ylabel("Loss")
axs[0].set_xlabel("Pruning Ratio (%)")

axs[1].set_title("Accuracy Vs Pruning Ratio")
axs[1].set_ylabel("Accuracy")
axs[1].set_xlabel("Pruning Ratio (%)")

axs[0].legend()
axs[1].legend()
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("images", exist_ok=True)
fig.savefig("images/sensitivity_plot.png")

In [None]:
# For every module, keep the largest tested sparsity at which pruning that module
# alone kept test accuracy above the threshold. The threshold is in the same units
# as the accuracy returned by utils.test (presumably percent, given 50 -- TODO confirm).
accuracy_threshold = 50
layer_to_pruning_threshold = {}

for key, values in layer_sensitivity.items():
    layer_to_pruning_threshold[key] = 0  # prune nothing if no level passes
    for idx, (_, acc) in enumerate(values):
        if acc > accuracy_threshold:
            # values[idx] was measured at sparsity (idx + 1) / 10: the sweep starts
            # at 10 %, not 0 %, so idx / 10 would under-report by one level.
            layer_to_pruning_threshold[key] = (idx + 1) / 10
print(layer_to_pruning_threshold)


In [None]:
pruned_model = utils.BaseModel()
# Start from the same checkpoint as the sensitivity sweep (this used to be a second,
# hard-coded copy of the path that silently diverged if checkpoint_path changed).
# map_location="cpu" lets a GPU-saved checkpoint load anywhere; the model is moved next.
pruned_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
pruned_model.to(utils.device)
example_inputs = torch.randn(1, 1, 96, 96).to(utils.device)


# 1. Importance criterion: the L2 norm of grouped weights is the importance score.
imp = tp.importance.GroupMagnitudeImportance(p=2)

# 2. Initialize a pruner with the model and the importance criterion.
# DO NOT prune the final (7-way) classifier: its outputs are the class logits.
ignored_layers = [
    m for m in pruned_model.modules()
    if isinstance(m, torch.nn.Linear) and m.out_features == 7
]

# Per-layer ratios chosen by the threshold cell, keyed by module object.
# get_submodule resolves the (possibly dotted) name without eval(): no code is
# executed from names read back out of sensitivity.pkl, and names that are not
# valid attribute syntax (e.g. "0" for nn.Sequential children) still work.
pruning_ratio_dict = {
    pruned_model.get_submodule(key): value
    for key, value in layer_to_pruning_threshold.items()
}

pruner = tp.pruner.BasePruner(  # BasePruner suffices when sparse training is not required.
    pruned_model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.0,  # default ratio for layers not listed in pruning_ratio_dict
    pruning_ratio_dict=pruning_ratio_dict,
    ignored_layers=ignored_layers,
    round_to=1,
)

# 3. Prune the model and report the MAC / parameter reduction.
base_macs, base_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
tp.utils.print_tool.before_pruning(pruned_model)  # or print(model)
pruner.step()
tp.utils.print_tool.after_pruning(pruned_model)  # shows the difference before and after pruning
macs, nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

In [None]:
# Evaluate the pruned model *before* finetuning, as a baseline for the finetuning run below.
loss, acc = utils.test(pruned_model, testloader)
print("Post Pruning (Pre-Finetuning) Loss: {0} Accuracy: {1}\n".format(loss, acc))

In [None]:
# Last positional argument of utils.train; presumably the number of epochs -- TODO confirm.
NUM_FINETUNE_EPOCHS = 10

run = wandb.init(project="hpml-final", name="Pruned Model Finetuning")
try:
    params = utils.TrainingParams()
    params.lr = 0.00001
    params.checkpoint = True
    params.dir_name = "pruned_model"
    # NOTE(review): presumably False makes utils save the whole module rather than a
    # state_dict -- needed because the pruned architecture no longer matches
    # utils.BaseModel(). Verify against common_utils.
    params.save_state_dict = False
    utils.train(run, pruned_model, params, trainloader, testloader, NUM_FINETUNE_EPOCHS)
finally:
    # Close the W&B run even if training raises or is interrupted, so a re-run
    # of this cell starts a fresh run instead of leaving one dangling.
    run.finish()