In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os

# Put the repository root (the parent of the current working directory, which
# is normally this notebook's folder) first on sys.path so the local
# `causalpruner` and `tests` packages import without being installed.
module_path = os.path.abspath(os.pardir)
sys.path.insert(0, module_path)

In [2]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from causalpruner import best_device, get_causal_pruner, get_sgd_pruner
from tests.models import get_model
from tests.datasets import get_dataset
from tests.pruner.mag_pruner import MagPruner

In [20]:
root = '../data/'
train, test = get_dataset('cifar10', root)

batch_size = 4096
num_workers = 2

# Options shared by both loaders. DataLoader rejects persistent_workers=True
# when num_workers == 0 (ValueError), so it is tied to num_workers > 0; this
# lets num_workers be set to 0 (e.g. for debugging) without editing the calls.
loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)
trainloader = DataLoader(train, shuffle=True, **loader_kwargs)
testloader = DataLoader(test, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Data already transformed and saved
Files already downloaded and verified
Data already transformed and saved


In [4]:
def get_pruned_stats(model):
    """Print per-buffer pruning statistics for every buffer of ``model``.

    Buffers are expected to be pruning masks (as registered by
    ``torch.nn.utils.prune``, i.e. named ``<module>.weight_mask``); the
    ``.weight_mask`` suffix is dropped so each line is labelled by module name.
    Entries equal to zero are counted as pruned.

    Args:
        model: a ``torch.nn.Module`` whose buffers are pruning masks.
    """
    for name, mask in model.named_buffers():
        # removesuffix drops the literal suffix. rstrip('.weight_mask') would
        # strip any trailing chars from that set and mangle e.g. 'features'.
        name = name.removesuffix('.weight_mask')
        total = mask.numel()
        non_zero = torch.count_nonzero(mask).item()
        pruned = total - non_zero
        # Guard 0/0 for an empty mask instead of printing 'nan%'.
        frac = 100 * pruned / total if total else 0.0
        print(f'Name: {name}; Total: {total}; non-zero: {non_zero}; '
              f'pruned: {pruned}; percent: {frac:.4f}%')

In [5]:
def eval_model(model, device=None, loader=None):
    """Print and return the top-1 accuracy of ``model`` over a data loader.

    Args:
        model: classifier; assumes its output has class scores along dim 1
            (e.g. shape (batch, num_classes)). It is left in eval mode.
        device: device every batch is moved to. Defaults to ``best_device()``,
            resolved at call time rather than once when the function is defined.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook's global ``testloader``.

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = best_device()
    if loader is None:
        loader = testloader
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('eval_model: loader yielded no samples')
    accuracy = correct / total
    print(f'Accuracy: {accuracy}; Correct: {correct}; Total: {total}')
    return accuracy

# SGD CausalPruner without momentum

In [7]:
# Hyper-parameters for the momentum-free causal-pruning experiment.
device = best_device()
checkpoint_dir = '../checkpoints'

# Optimizer and causal-pruner settings.
momentum = 0
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
causal_weights_batch_size = 512
causal_weights_num_epochs = 256

# Schedule: pre-prune training, iterative prune rounds, post-prune training.
num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 100
total_epochs = (
    num_pre_prune_epochs
    + num_prune_iterations * num_prune_epochs
    + num_post_prune_epochs
)

In [8]:
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)

cp1 = get_causal_pruner(model1,
                        optimizer1,
                        checkpoint_dir,
                        momentum=False,
                        pruner_lr=pruner_lr,
                        prune_threshold=prune_threshold,
                        l1_regularization_coeff=l1_regularization_coeff,
                        causal_weights_batch_size=causal_weights_batch_size,
                        causal_weights_num_epochs=causal_weights_num_epochs,
                        start_clean=True,
                        device=device)


def _train_epoch(model, optimizer, loader, device, loss_hook=None):
    """Run one epoch of cross-entropy training of ``model`` over ``loader``.

    Args:
        model: network to train; switched to train mode first.
        optimizer: optimizer stepped once per batch.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device every batch is moved to.
        loss_hook: optional callable given each batch loss before
            ``backward()`` (used to feed losses to the causal pruner).
    """
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        if loss_hook is not None:
            loss_hook(loss)
        loss.backward()
        optimizer.step()


# The context manager closes the bar even if training is interrupted, and the
# bar is advanced only after an epoch has actually finished.
with tqdm(total=total_epochs) as pbar:
    # Pre-prune training.
    for _ in range(num_pre_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

    # Iterative pruning: train while feeding every batch loss to the pruner,
    # then compute masks and reset weights at the end of each iteration.
    cp1.start_pruning()
    for _ in range(num_prune_iterations):
        cp1.start_iteration()
        for _ in range(num_prune_epochs):
            _train_epoch(model1, optimizer1, trainloader, device,
                         loss_hook=cp1.provide_loss)
            pbar.update(1)
        cp1.compute_masks()
        cp1.reset_weights()

    # Post-prune training of the masked model.
    for _ in range(num_post_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

100%|██████████| 210/210 [05:17<00:00,  1.51s/it]


In [9]:
# Per-layer sparsity left by the momentum-free causal pruner.
get_pruned_stats(model=model1)

Name: conv1; Total: 1500; non-zero: 866; pruned: 634; percent: 42.2667%
Name: conv2; Total: 25000; non-zero: 15209; pruned: 9791; percent: 39.1640%
Name: fc1; Total: 625000; non-zero: 390858; pruned: 234142; percent: 37.4627%
Name: fc2; Total: 5000; non-zero: 2948; pruned: 2052; percent: 41.0400%


In [10]:
# Test accuracy of the causally pruned (momentum-free) model.
eval_model(model=model1)

Accuracy: 0.3952; Correct: 3952; Total: 10000


In [22]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

# Dense baseline: same optimizer settings and total epoch budget as the
# pruning run, but with no pruning. The context manager closes the bar even if
# training is interrupted; the bar is advanced only once an epoch finishes.
with tqdm(total=total_epochs) as pbar:
    for _ in range(total_epochs):
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer2.zero_grad()
            loss2 = F.cross_entropy(model2(inputs), labels)
            loss2.backward()
            optimizer2.step()
        pbar.update(1)

  1%|          | 2/210 [00:01<02:09,  1.60it/s]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x76fc560a6700><function _MultiProcessingDataLoaderIter.__del__ at 0x76fc560a6700>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/aditya/miniconda3/envs/cpn/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
  File "/home/aditya/miniconda3/envs/cpn/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/home/aditya/miniconda3/envs/cpn/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
  File "/home/aditya/miniconda3/envs/cpn/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
        if w.is_alive():if w.is_alive():

              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/aditya/miniconda3/envs/cpn/li

In [23]:
# Test accuracy of the dense baseline before any magnitude pruning.
eval_model(model=model2)

Accuracy: 0.4311; Correct: 4311; Total: 10000


In [30]:
# Magnitude-pruning comparison on the trained dense baseline (model2) with
# amount=0.4. The stats printed after compute_masks() total 40% of all weights
# pruned, with uneven per-layer fractions.
mp1 = MagPruner(model2, amount=0.4)

In [31]:
# Compute the magnitude masks. Judging by the following get_pruned_stats(model2),
# they end up as buffers on model2 itself. There is no retraining afterwards.
mp1.compute_masks()

In [32]:
# Per-layer sparsity produced by MagPruner on the dense baseline.
get_pruned_stats(model=model2)

Name: conv1; Total: 1500; non-zero: 839; pruned: 661; percent: 44.0667%
Name: conv2; Total: 25000; non-zero: 10348; pruned: 14652; percent: 58.6080%
Name: fc1; Total: 625000; non-zero: 378021; pruned: 246979; percent: 39.5166%
Name: fc2; Total: 5000; non-zero: 4692; pruned: 308; percent: 6.1600%


In [33]:
# Accuracy right after magnitude pruning (no fine-tuning in between).
eval_model(model=model2)

Accuracy: 0.3198; Correct: 3198; Total: 10000


In [39]:
import torch.nn.utils.prune as prune
from causalpruner import Pruner

model = get_model('lenet', 'cifar10')

# (module, 'weight') pairs for every direct child that has a `weight`
# attribute, in the form torch.nn.utils.prune.global_unstructured expects.
parameters_to_prune = [
    (child, 'weight')
    for _, child in model.named_children()
    if hasattr(child, 'weight')
]

In [40]:
# One global L1 threshold across all collected layers (not per layer), pruning
# 40% of their weights in total. This matches the amount used with MagPruner.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

In [41]:
# Reference split from torch's own global L1 pruning, for comparison with the
# MagPruner stats above.
get_pruned_stats(model=model)

Name: conv1; Total: 1500; non-zero: 891; pruned: 609; percent: 40.6000%
Name: conv2; Total: 25000; non-zero: 10404; pruned: 14596; percent: 58.3840%
Name: fc1; Total: 625000; non-zero: 377911; pruned: 247089; percent: 39.5342%
Name: fc2; Total: 5000; non-zero: 4694; pruned: 306; percent: 6.1200%


# SGD CausalPruner with momentum

In [12]:
# Hyper-parameters for the causal-pruning experiment with SGD momentum.
device = best_device()
checkpoint_dir = '../checkpoints'

# Optimizer and causal-pruner settings.
momentum = 0.9
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
causal_weights_batch_size = 512
causal_weights_num_epochs = 256

# Schedule: pre-prune training, iterative prune rounds, post-prune training.
num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 500
total_epochs = (
    num_pre_prune_epochs
    + num_prune_iterations * num_prune_epochs
    + num_post_prune_epochs
)

In [13]:
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)


def _train_epoch(model, optimizer, loader, device, loss_hook=None):
    """Run one epoch of cross-entropy training of ``model`` over ``loader``.

    Args:
        model: network to train; switched to train mode first.
        optimizer: optimizer stepped once per batch.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device every batch is moved to.
        loss_hook: optional callable given each batch loss before
            ``backward()`` (used to feed losses to the causal pruner).
    """
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        if loss_hook is not None:
            loss_hook(loss)
        loss.backward()
        optimizer.step()


# The context manager closes the bar even if training is interrupted, and the
# bar is advanced only after an epoch has actually finished.
with tqdm(total=total_epochs) as pbar:
    # Pre-prune training; the pruner is created only afterwards in this run.
    for _ in range(num_pre_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

    cp1 = get_causal_pruner(model1,
                            optimizer1,
                            checkpoint_dir,
                            momentum=True,
                            pruner_lr=pruner_lr,
                            prune_threshold=prune_threshold,
                            l1_regularization_coeff=l1_regularization_coeff,
                            causal_weights_batch_size=causal_weights_batch_size,
                            causal_weights_num_epochs=causal_weights_num_epochs,
                            start_clean=True,
                            device=device)
    cp1.start_pruning()

    # Snapshot the optimizer state (momentum buffers) so each prune iteration
    # restarts from it. This presumably mirrors the weight rewind done by
    # cp1.reset_weights() -- confirm against the pruner implementation.
    os.makedirs(checkpoint_dir, exist_ok=True)
    optim_state_path = os.path.join(checkpoint_dir, 'optim.pth')
    torch.save(optimizer1.state_dict(), optim_state_path)

    for _ in range(num_prune_iterations):
        cp1.start_iteration()
        for _ in range(num_prune_epochs):
            _train_epoch(model1, optimizer1, trainloader, device,
                         loss_hook=cp1.provide_loss)
            pbar.update(1)
        cp1.compute_masks()
        cp1.reset_weights()
        # An optimizer state_dict holds only tensors and primitives, so the
        # safe weights_only loader suffices (and avoids torch's FutureWarning).
        optimizer1.load_state_dict(
            torch.load(optim_state_path, weights_only=True))

    # Post-prune training of the masked model.
    for _ in range(num_post_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

 78%|███████▊  | 474/610 [11:28<02:58,  1.31s/it]

KeyboardInterrupt: 

In [None]:
# Per-layer sparsity left by the causal pruner run with momentum.
get_pruned_stats(model=model1)

In [None]:
# Test accuracy of the causally pruned model trained with momentum.
eval_model(model=model1)

In [None]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

# Dense baseline: same optimizer settings and total epoch budget as the
# pruning run, but with no pruning. The context manager closes the bar even if
# training is interrupted; the bar is advanced only once an epoch finishes.
with tqdm(total=total_epochs) as pbar:
    for _ in range(total_epochs):
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer2.zero_grad()
            loss2 = F.cross_entropy(model2(inputs), labels)
            loss2.backward()
            optimizer2.step()
        pbar.update(1)

In [None]:
# Test accuracy of the dense baseline trained with momentum.
eval_model(model=model2)

# SGD CausalPruner without momentum, used with a momentum SGD optimizer

In [None]:
# Hyper-parameters: momentum SGD optimizer paired with a momentum-free pruner.
device = best_device()
checkpoint_dir = '../checkpoints'

# Optimizer and causal-pruner settings.
momentum = 0.9
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
causal_weights_batch_size = 512
causal_weights_num_epochs = 256

# Schedule: pre-prune training, iterative prune rounds, post-prune training.
num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 500
total_epochs = (
    num_pre_prune_epochs
    + num_prune_iterations * num_prune_epochs
    + num_post_prune_epochs
)

In [None]:
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)


def _train_epoch(model, optimizer, loader, device, loss_hook=None):
    """Run one epoch of cross-entropy training of ``model`` over ``loader``.

    Args:
        model: network to train; switched to train mode first.
        optimizer: optimizer stepped once per batch.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device every batch is moved to.
        loss_hook: optional callable given each batch loss before
            ``backward()`` (used to feed losses to the causal pruner).
    """
    model.train()
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        if loss_hook is not None:
            loss_hook(loss)
        loss.backward()
        optimizer.step()


# The context manager closes the bar even if training is interrupted, and the
# bar is advanced only after an epoch has actually finished.
with tqdm(total=total_epochs) as pbar:
    # Pre-prune training; the pruner is created only afterwards in this run.
    for _ in range(num_pre_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

    # The optimizer uses momentum, but the pruner is deliberately configured
    # with momentum=False for this experiment.
    cp1 = get_causal_pruner(model1,
                            optimizer1,
                            checkpoint_dir,
                            momentum=False,
                            pruner_lr=pruner_lr,
                            prune_threshold=prune_threshold,
                            l1_regularization_coeff=l1_regularization_coeff,
                            causal_weights_batch_size=causal_weights_batch_size,
                            causal_weights_num_epochs=causal_weights_num_epochs,
                            start_clean=True,
                            device=device)
    cp1.start_pruning()

    # Snapshot the optimizer state (momentum buffers) so each prune iteration
    # restarts from it. This presumably mirrors the weight rewind done by
    # cp1.reset_weights() -- confirm against the pruner implementation.
    os.makedirs(checkpoint_dir, exist_ok=True)
    optim_state_path = os.path.join(checkpoint_dir, 'optim.pth')
    torch.save(optimizer1.state_dict(), optim_state_path)

    for _ in range(num_prune_iterations):
        cp1.start_iteration()
        for _ in range(num_prune_epochs):
            _train_epoch(model1, optimizer1, trainloader, device,
                         loss_hook=cp1.provide_loss)
            pbar.update(1)
        cp1.compute_masks()
        cp1.reset_weights()
        # An optimizer state_dict holds only tensors and primitives, so the
        # safe weights_only loader suffices (and avoids torch's FutureWarning).
        optimizer1.load_state_dict(
            torch.load(optim_state_path, weights_only=True))

    # Post-prune training of the masked model.
    for _ in range(num_post_prune_epochs):
        _train_epoch(model1, optimizer1, trainloader, device)
        pbar.update(1)

In [None]:
# Per-layer sparsity: momentum-free pruner driven by a momentum optimizer.
get_pruned_stats(model=model1)

In [None]:
# Test accuracy: momentum-free pruner driven by a momentum optimizer.
eval_model(model=model1)

In [None]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

# Dense baseline: same optimizer settings and total epoch budget as the
# pruning run, but with no pruning. The context manager closes the bar even if
# training is interrupted; the bar is advanced only once an epoch finishes.
with tqdm(total=total_epochs) as pbar:
    for _ in range(total_epochs):
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer2.zero_grad()
            loss2 = F.cross_entropy(model2(inputs), labels)
            loss2.backward()
            optimizer2.step()
        pbar.update(1)

In [None]:
# Test accuracy of the matching dense baseline.
eval_model(model=model2)