In [2]:
from torch import nn
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from torchvision import datasets
import torch 
from torch.utils.data import DataLoader
import tqdm 
import os
from torchmetrics import Accuracy
import json 

In [3]:
def customized_resnet50_for_cifar10(num_class):
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        backbone = list(model.children())[:-1]
        return nn.Sequential(
            *backbone, 
            nn.Flatten(),
            nn.Linear(2048, num_class, bias=True),
            nn.Softmax(dim=1)
        )
    
def load_dataset(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

    testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

    return trainloader, testloader

In [4]:
def train_resnet50_on_cifar10(
            model: nn.Module, 
            accuracy: Accuracy, 
            train_loader: DataLoader, 
            val_loader: DataLoader, 
            epochs: int, 
            lr: float=1e-4, 
            device: torch.device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 
            artifact_path: str="./artifacts"
    ):
        learning_history = {"train_acc": [], "test_acc": []}
        start_epoch = 1
        best_val_acc = 0
        os.makedirs(artifact_path, exist_ok=True)
        model.to(device)
        accuracy.to(device)

        
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        checkpoint_path = os.path.join(artifact_path, "checkpoints.pt")
        if os.path.isfile(checkpoint_path):
            checkpoints = torch.load(checkpoint_path)
            model.load_state_dict(checkpoints["model"])
            optimizer.load_state_dict(checkpoints["optimizer"])
            start_epoch = checkpoints["start_epoch"]
            best_val_acc = checkpoints["best_val_acc"]
            with open(os.path.join(artifact_path, "learning_history.json")) as history:
                 learning_history = json.loads(history)
                 
        loss_fn = nn.CrossEntropyLoss()

        for epoch in range(start_epoch, epochs + 1):
              
            model.train()
            accuracy.reset()
            train_acc = 0

            pbar = tqdm.tqdm(train_loader, total=len(train_loader), desc=f"Training - {epoch} / {epochs}")
            for x, y in pbar:
                
                model.zero_grad(set_to_none=True)

                x, y = x.to(device), y.to(device)
                y_hat = model(x)

                loss = loss_fn(y_hat, y)
                loss.backward()
                optimizer.step()

                accuracy.update(y_hat, y)
                acc = accuracy.compute().item()
                train_acc += acc / len(train_loader)

                pbar.set_postfix(loss=loss.item(), accuracy=acc)
            
            learning_history["train_acc"].append(train_acc)

            test_acc = evaluate(model, val_loader, accuracy, device)
            learning_history["test_acc"].append(test_acc)

            if test_acc > best_val_acc:
                 best_val_acc = test_acc
                 torch.save(model.state_dict(), os.path.join(artifact_path, "best.pt"))
            
            with open(os.path.join(artifact_path, "learning_history.json"), "w") as history:
                 json.dump(learning_history, history, indent=4)
                 
            torch.save({
                 "model": model.state_dict(), 
                 "optimizer": optimizer.state_dict(), 
                 "start_epoch": start_epoch, 
                 "best_val_acc": best_val_acc
            }, os.path.join(artifact_path, "checkpoints.pt")
            )
            
@torch.no_grad()
def evaluate(model, val_loader, metric, device):
    model.eval()
    metric.reset()
    model.to(device)
    metric.to(device)

    pbar = tqdm.tqdm(val_loader, desc="Test: ", total=len(val_loader))

    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        y_hat = model(x)
        metric.update(y_hat, y)
    
    acc = metric.compute().item()
    print(f"----------------------{metric.compute().item()}--------------------------")
    pbar.close()
    return acc


In [7]:
batch_size = 128
epochs = 20
lr = 1e-4
model = customized_resnet50_for_cifar10(num_class=10)
accuracy = Accuracy(task="multiclass", num_classes=10)

train_loader, val_loader = load_dataset(batch_size=batch_size)
train_resnet50_on_cifar10(model=model, 
                          accuracy=accuracy,
                          train_loader=train_loader, 
                          val_loader=val_loader, 
                          epochs=epochs)

Files already downloaded and verified
Files already downloaded and verified


Training - 1 / 20: 100%|██████████| 391/391 [01:40<00:00,  3.87it/s, accuracy=0.528, loss=1.81] 
Test: 100%|██████████| 79/79 [00:19<00:00,  4.13it/s]


----------------------0.7580999732017517--------------------------


Training - 2 / 20: 100%|██████████| 391/391 [01:03<00:00,  6.15it/s, accuracy=0.753, loss=1.73]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.87it/s]


----------------------0.7921000123023987--------------------------


Training - 3 / 20: 100%|██████████| 391/391 [01:02<00:00,  6.21it/s, accuracy=0.785, loss=1.69]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


----------------------0.8136000037193298--------------------------


Training - 4 / 20: 100%|██████████| 391/391 [01:04<00:00,  6.08it/s, accuracy=0.806, loss=1.62]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


----------------------0.8278999924659729--------------------------


Training - 5 / 20: 100%|██████████| 391/391 [01:04<00:00,  6.04it/s, accuracy=0.819, loss=1.68]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


----------------------0.8360000252723694--------------------------


Training - 6 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.90it/s, accuracy=0.83, loss=1.63] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.80it/s]


----------------------0.8406999707221985--------------------------


Training - 7 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.91it/s, accuracy=0.836, loss=1.69]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.83it/s]


----------------------0.8442000150680542--------------------------


Training - 8 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.81it/s, accuracy=0.846, loss=1.61]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.76it/s]


----------------------0.855400025844574--------------------------


Training - 9 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.84it/s, accuracy=0.852, loss=1.61]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


----------------------0.8532000184059143--------------------------


Training - 10 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.91it/s, accuracy=0.857, loss=1.57]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


----------------------0.8560000061988831--------------------------


Training - 11 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.92it/s, accuracy=0.864, loss=1.63]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.87it/s]


----------------------0.8604999780654907--------------------------


Training - 12 / 20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s, accuracy=0.868, loss=1.62]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.94it/s]


----------------------0.8616999983787537--------------------------


Training - 13 / 20: 100%|██████████| 391/391 [01:04<00:00,  6.05it/s, accuracy=0.873, loss=1.52]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.95it/s]


----------------------0.8671000003814697--------------------------


Training - 14 / 20: 100%|██████████| 391/391 [01:03<00:00,  6.14it/s, accuracy=0.877, loss=1.61]
Test: 100%|██████████| 79/79 [00:19<00:00,  3.95it/s]


----------------------0.8719000220298767--------------------------


Training - 15 / 20: 100%|██████████| 391/391 [01:03<00:00,  6.12it/s, accuracy=0.88, loss=1.56] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


----------------------0.8690999746322632--------------------------


Training - 16 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s, accuracy=0.881, loss=1.54]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


----------------------0.8783000111579895--------------------------


Training - 17 / 20: 100%|██████████| 391/391 [01:05<00:00,  5.96it/s, accuracy=0.886, loss=1.61]
Test: 100%|██████████| 79/79 [00:31<00:00,  2.47it/s]


----------------------0.8726999759674072--------------------------


Training - 18 / 20: 100%|██████████| 391/391 [01:09<00:00,  5.60it/s, accuracy=0.89, loss=1.54] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


----------------------0.8773999810218811--------------------------


Training - 19 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.89it/s, accuracy=0.891, loss=1.53]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8784000277519226--------------------------


Training - 20 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.71it/s, accuracy=0.895, loss=1.59]
Test: 100%|██████████| 79/79 [00:21<00:00,  3.72it/s]


----------------------0.8812000155448914--------------------------


In [5]:
from NISPPruner import NISPPruner

accuracy = Accuracy(task="multiclass", num_classes=10)
train_loader, val_loader = load_dataset(batch_size=128)

model = customized_resnet50_for_cifar10(num_class=10)
model.load_state_dict(torch.load("./artifacts/best.pt", weights_only=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pruner = NISPPruner(
    model=model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    accuracy=accuracy, 
    device=device
)
pruner.compute_importance_scores()
pruner.apply_pruning(prune_ratio=0.75)
pruner.fine_tune(learning_rate=1e-4)


Files already downloaded and verified
Files already downloaded and verified


Neural Importance Score Computation: 100%|██████████| 391/391 [00:59<00:00,  6.60it/s]


Total layers with computed importance scores: 55
Pruning applied to layer 0.weight with threshold: 0.0007268773042596877
Pruning applied to layer 4.0.conv1.weight with threshold: 0.0004790328675881028
Pruning applied to layer 4.0.conv2.weight with threshold: 5.119311754242517e-05
Pruning applied to layer 4.0.conv3.weight with threshold: 0.00020138919353485107
Pruning applied to layer 4.0.downsample.0.weight with threshold: 0.00046223984099924564
Pruning applied to layer 4.1.conv1.weight with threshold: 8.674434502609074e-05
Pruning applied to layer 4.1.conv2.weight with threshold: 7.297771662706509e-05
Pruning applied to layer 4.1.conv3.weight with threshold: 0.00020080836839042604
Pruning applied to layer 4.2.conv1.weight with threshold: 9.889837383525446e-05
Pruning applied to layer 4.2.conv2.weight with threshold: 9.610374399926513e-05
Pruning applied to layer 4.2.conv3.weight with threshold: 0.00018216246098745614
Pruning applied to layer 5.0.conv1.weight with threshold: 0.00017496

Fine tuning: 1 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.77it/s, accuracy=0.816, loss=1.6] 
Test: 100%|██████████| 79/79 [00:22<00:00,  3.51it/s]


----------------------0.8384000062942505--------------------------


Fine tuning: 2 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.71it/s, accuracy=0.809, loss=1.6] 
Test: 100%|██████████| 79/79 [00:21<00:00,  3.69it/s]


----------------------0.8166999816894531--------------------------


Fine tuning: 3 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, accuracy=0.816, loss=1.63]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


----------------------0.826200008392334--------------------------


Fine tuning: 4 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.87it/s, accuracy=0.828, loss=1.62]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.86it/s]


----------------------0.8295999765396118--------------------------


Fine tuning: 5 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.87it/s, accuracy=0.831, loss=1.6] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8367999792098999--------------------------


Fine tuning: 6 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s, accuracy=0.835, loss=1.71]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.84it/s]


----------------------0.820900022983551--------------------------


Fine tuning: 7 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.86it/s, accuracy=0.832, loss=1.66]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8324999809265137--------------------------


Fine tuning: 8 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.80it/s, accuracy=0.838, loss=1.59]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.89it/s]


----------------------0.8328999876976013--------------------------


Fine tuning: 9 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.76it/s, accuracy=0.838, loss=1.58]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.85it/s]


----------------------0.8346999883651733--------------------------


Fine tuning: 10 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.91it/s, accuracy=0.839, loss=1.67]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8263000249862671--------------------------


Fine tuning: 11 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.80it/s, accuracy=0.841, loss=1.63]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8287000060081482--------------------------


Fine tuning: 12 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s, accuracy=0.843, loss=1.63]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.86it/s]


----------------------0.849399983882904--------------------------


Fine tuning: 13 / 20: 100%|██████████| 391/391 [01:06<00:00,  5.86it/s, accuracy=0.854, loss=1.62]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.88it/s]


----------------------0.8442000150680542--------------------------


Fine tuning: 14 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.83it/s, accuracy=0.852, loss=1.61]
Test: 100%|██████████| 79/79 [00:21<00:00,  3.72it/s]


----------------------0.8492000102996826--------------------------


Fine tuning: 15 / 20: 100%|██████████| 391/391 [01:09<00:00,  5.60it/s, accuracy=0.855, loss=1.6] 
Test: 100%|██████████| 79/79 [00:21<00:00,  3.67it/s]


----------------------0.8476999998092651--------------------------


Fine tuning: 16 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s, accuracy=0.857, loss=1.6] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.89it/s]


----------------------0.8457000255584717--------------------------


Fine tuning: 17 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s, accuracy=0.856, loss=1.61]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.77it/s]


----------------------0.8353999853134155--------------------------


Fine tuning: 18 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.70it/s, accuracy=0.853, loss=1.65]
Test: 100%|██████████| 79/79 [00:20<00:00,  3.78it/s]


----------------------0.8517000079154968--------------------------


Fine tuning: 19 / 20: 100%|██████████| 391/391 [01:08<00:00,  5.68it/s, accuracy=0.861, loss=1.61]
Test: 100%|██████████| 79/79 [00:21<00:00,  3.74it/s]


----------------------0.8508999943733215--------------------------


Fine tuning: 20 / 20: 100%|██████████| 391/391 [01:07<00:00,  5.75it/s, accuracy=0.86, loss=1.55] 
Test: 100%|██████████| 79/79 [00:20<00:00,  3.79it/s]


----------------------0.8440999984741211--------------------------
Pruning reparameterization removed from 0
Pruning reparameterization removed from 4.0.conv1
Pruning reparameterization removed from 4.0.conv2
Pruning reparameterization removed from 4.0.conv3
Pruning reparameterization removed from 4.0.downsample.0
Pruning reparameterization removed from 4.1.conv1
Pruning reparameterization removed from 4.1.conv2
Pruning reparameterization removed from 4.1.conv3
Pruning reparameterization removed from 4.2.conv1
Pruning reparameterization removed from 4.2.conv2
Pruning reparameterization removed from 4.2.conv3
Pruning reparameterization removed from 5.0.conv1
Pruning reparameterization removed from 5.0.conv2
Pruning reparameterization removed from 5.0.conv3
Pruning reparameterization removed from 5.0.downsample.0
Pruning reparameterization removed from 5.1.conv1
Pruning reparameterization removed from 5.1.conv2
Pruning reparameterization removed from 5.1.conv3
Pruning reparameterization 

In [6]:
def compute_pruning_percentage(original_model: nn.Module, pruned_model: nn.Module) -> float:
    """
    Computes the pruning percentage by comparing the number of non-zero 
    parameters in the original model and the pruned model.

    Parameters:
    original_model (nn.Module): The original model before pruning.
    pruned_model (nn.Module): The model after pruning.

    Returns:
    float: The percentage of weights that have been pruned.
    """
    original_non_zero = 0
    original_total = 0
    pruned_non_zero = 0
    pruned_total = 0

    for orig_param, pruned_param in zip(original_model.parameters(), pruned_model.parameters()):
        original_non_zero += orig_param.nonzero().size(0)
        original_total += orig_param.numel()
        pruned_non_zero += pruned_param.nonzero().size(0)
        pruned_total += pruned_param.numel()

    if original_total != pruned_total:
        raise ValueError("The original and pruned models do not have the same total number of parameters.")

    pruning_percentage = 100.0 * (original_non_zero - pruned_non_zero) / original_non_zero
    return pruning_percentage

In [9]:
pruned_model = customized_resnet50_for_cifar10(num_class=10)
orig_model = customized_resnet50_for_cifar10(num_class=10)
orig_model.load_state_dict(torch.load("./artifacts/best.pt", weights_only=True))
pruned_model.load_state_dict(torch.load("./pruned_artifacts/pruned_model.pt", weights_only=True))

print(compute_pruning_percentage(orig_model, pruned_model))

77.66045398006726
