In [None]:
# Weights & Biases client used for experiment tracking; log in before any run is started.
import wandb
wandb.login()

In [None]:
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.nn.utils import prune
from torchvision import models

In [None]:
# MODEL-M1
class Model1(nn.Module):
    """Small LeNet-style CNN (model M1) that maps an image batch to 10 class scores.

    Two ``5x5 conv -> ReLU -> 2x2 max-pool`` stages are followed by three fully
    connected layers. ``fc1`` expects ``16*5*5`` flattened features, which is
    what the two conv/pool stages produce for a 3x32x32 input.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept stable: they define the
        # state_dict keys and the order in which ``modules()`` visits the layers.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        # Convolutional feature extractor: the same pooling layer is reused after each conv.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Collapse everything except the batch dimension before the dense layers.
        x = x.flatten(start_dim=1)

        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))

        # No activation on the output layer: callers receive logits.
        return self.fc3(x)

In [None]:
# ResNet18 - Uncomment the code below to initialize ResNet18 as the M1 model
# M1 = models.resnet18()
# num_ftrs = M1.fc.in_features
# M1.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# M1.fc = nn.Linear(num_ftrs, 10)

In [None]:
def accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to eval mode and evaluated without gradient tracking.
    Accuracy is computed as ``correct / total`` over every sample seen, so a
    smaller final batch is weighted by its true size instead of counting as
    much as a full batch.

    Args:
        model: ``nn.Module`` mapping an input batch to per-class scores of
            shape ``(batch, num_classes)``.
        loader: Iterable yielding ``(image, label)`` tensor pairs.

    Returns:
        float: Accuracy in the range ``[0, 100]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()

    # Evaluate on whatever device the model already lives on rather than relying
    # on a notebook-level ``device`` variable being defined before this call.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for image, label in loader:
            if target_device is not None:
                image = image.to(target_device)
                label = label.to(target_device)
            outputs = model(image)

            # Predicted class = index of the highest score along the class dimension.
            _, preds = torch.max(outputs, 1)
            num_correct += torch.sum(preds == label).item()
            num_seen += len(preds)

    if num_seen == 0:
        raise ValueError("accuracy() received a loader that yielded no samples")

    return 100.0 * num_correct / num_seen

In [None]:
def prune_model_global_unstructured(model, amount):
    """Return a globally L1-pruned copy of ``model`` together with its pruning masks.

    The input model is deep-copied and left untouched. Weights and (where
    present) biases of every ``Conv2d`` and ``Linear`` layer in the copy are
    pooled and pruned together with ``prune.L1Unstructured``. The pruning
    re-parametrisation is then removed so the returned model has the same
    parameter names as the original, with pruned entries set to zero.

    Args:
        model: ``nn.Module`` to prune.
        amount: Forwarded unchanged to ``prune.global_unstructured`` as the
            quantity of parameters to prune.

    Returns:
        tuple: ``(model_copy, weight_mask)`` where ``weight_mask`` is a list of
        ``(buffer_name, mask_tensor)`` pairs captured before the
        re-parametrisation is removed.
    """
    model_copy = copy.deepcopy(model)

    module_tups = []
    for module in model_copy.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            module_tups.append((module, 'weight'))
            # Layers built with bias=False expose ``bias`` as None, which cannot be
            # pruned; registering it would make global_unstructured raise.
            if module.bias is not None:
                module_tups.append((module, 'bias'))

    prune.global_unstructured(
        parameters=module_tups,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Snapshot the masks before prune.remove() deletes them from the modules.
    # Only ``*_mask`` buffers are kept so unrelated buffers (e.g. BatchNorm
    # running statistics) do not leak into the result.
    weight_mask = [
        (buffer_name, buffer)
        for buffer_name, buffer in model_copy.named_buffers()
        if buffer_name.endswith('_mask')
    ]

    # Make the pruning permanent: fold each mask into its parameter and drop
    # the ``*_orig`` / ``*_mask`` re-parametrisation.
    for module, name in module_tups:
        prune.remove(module, name)

    return model_copy, weight_mask

In [None]:
# CIFAR10 dataset

# Transforms: convert each image to a tensor, then normalise every RGB channel
# with fixed per-channel mean / std values.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010),
        ),
    ]
)

# Dataset: only the training split triggers a download; the test split reads the same root.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)

# Train-val split: hold out 10% of the training set for validation.
val_size = int(0.1 * len(train_dataset))
train_size = len(train_dataset) - val_size

# Seed the global RNG immediately before splitting so the partition is reproducible.
torch.manual_seed(39)
train, val = torch.utils.data.random_split(
    train_dataset,
    [train_size, val_size],
)

In [None]:
# Dataloader
# Prefer the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
batch_size = 8

# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)

train_loader = torch.utils.data.DataLoader(train, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Load Model1 from Task_1
# No cell visible in this notebook instantiates M1 unless the optional ResNet18
# cell is uncommented, so on a fresh kernel fall back to the default Model1
# architecture instead of failing with a NameError.
if 'M1' not in globals():
    M1 = Model1()
M1 = M1.to(device)
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
M1.load_state_dict(torch.load('save_model/Task_1/M1_weights.pth', map_location=device))

In [None]:
# Sweep global sparsity from 10% to 90%: prune M1, evaluate the pruned copy on
# every split, log the metrics to W&B and save the pruned weights with their masks.

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing folder after a full evaluation pass has already been spent.
os.makedirs('save_model/Task_3', exist_ok=True)

# Using the run as a context manager closes it even if an iteration raises,
# instead of leaving a dangling run behind.
with wandb.init(
    project='VCL_Tasks',
    group = 'Task_3',
):
    for amount in np.arange(1, 10, 1)/10:
        # M1 itself is never modified; each sparsity level starts from the dense weights.
        M2, weight_mask = prune_model_global_unstructured(M1, amount)

        train_accuracy = accuracy(M2, train_loader)
        val_accuracy = accuracy(M2, val_loader)
        test_accuracy = accuracy(M2,  test_loader)

        print(f'Train Accuracy:{train_accuracy}, Validation Accuracy:{val_accuracy}, Test Accuracy:{test_accuracy}, Sparsity:{amount}')
        wandb.log({'Task_3/Train Accuracy':train_accuracy, 'Task_3/Validation Accuracy':val_accuracy, 'Task_3/Test Accuracy':test_accuracy, 'Task_3/Sparsity':amount})

        # Checkpoint layout (key names and file name) is kept exactly as before.
        checkpoint = {
            f'M2_sparsity_{amount*100}_weights.pth':M2.state_dict(),
            'weight_mask':weight_mask,
        }
        torch.save(checkpoint, f'save_model/Task_3/M2_sparsity_{amount*100}.tar')