In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
module_path = os.path.abspath(os.path.join('../'))
sys.path.insert(0, module_path)

In [2]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from causalpruner import best_device, get_causal_pruner
from tests.models import get_model
from tests.datasets import get_dataset

In [10]:
root = '../data/'
train, test = get_dataset('cifar10', root)

batch_size = 4096
num_workers = 2

# The two loaders share every option except shuffling: the training split is
# shuffled, the test split is read in a fixed order.
loader_options = dict(batch_size=batch_size, pin_memory=True,
                      num_workers=num_workers)
trainloader = DataLoader(train, shuffle=True, **loader_options)
testloader = DataLoader(test, shuffle=False, **loader_options)

Files already downloaded and verified
Data already transformed and saved
Files already downloaded and verified
Data already transformed and saved


In [4]:
def get_pruned_stats(model, suffix='.weight_mask'):
    """Print pruning statistics for every buffer registered on `model`.

    Each buffer is treated as a pruning mask: its zero entries count as pruned.
    The printed name is the buffer's qualified name with `suffix` removed,
    e.g. 'conv1.weight_mask' -> 'conv1'.

    Args:
        model: a torch.nn.Module; all of its (recursive) buffers are reported.
        suffix: exact trailing text removed from each buffer name for display.

    Note: non-mask buffers (e.g. BatchNorm running statistics) are reported
    as well; their "pruned" numbers are not meaningful.
    """
    for buffer_name, mask in model.named_buffers():
        # str.rstrip() strips a *set of characters*, not a suffix: the old
        # rstrip('.weight_mask') turned 'features.weight_mask' into 'featur'
        # and 'conv1.bias_mask' into 'conv1.b'. Remove the exact suffix.
        name = buffer_name.removesuffix(suffix)
        total = mask.numel()
        non_zero = int(torch.count_nonzero(mask))
        pruned = total - non_zero
        # Zero-element buffers would otherwise divide by zero.
        frac = 100 * pruned / total if total else 0.0
        print(f'Name: {name}; Total: {total}; non-zero: {non_zero}; '
              f'pruned: {pruned}; percent: {frac:.4f}%')

In [5]:
def eval_model(model, device=None, loader=None):
    """Print the top-1 classification accuracy of `model`.

    Args:
        model: classifier; assumes its output is (batch, num_classes) scores,
            and the prediction is the argmax over dim 1.
        device: device each batch is moved to. Defaults to `best_device()`,
            resolved at call time rather than once when this cell is defined.
        loader: iterable of `(inputs, labels)` batches. Defaults to the
            notebook-level `testloader`.

    Raises:
        ValueError: if `loader` yields no samples (accuracy is undefined).

    Leaves `model` in eval mode; call `model.train()` before training again.
    """
    if device is None:
        device = best_device()
    if loader is None:
        loader = testloader
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('eval_model: loader yielded no samples')
    accuracy = correct / total
    print(f'Accuracy: {accuracy}; Correct: {correct}; Total: {total}')

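# SGD CausalPruner without momentum

Train a LeNet on CIFAR-10 with plain SGD (`momentum = 0`) while the causal pruner prunes it. Then train an unpruned baseline with the same optimiser settings for `num_pre_prune_epochs + num_post_prune_epochs` epochs, and compare the test accuracy of the two models.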

In [11]:
# Optimiser / pruner hyper-parameters for this run (plain SGD, no momentum).
momentum = 0
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

# Epoch schedule: epochs before pruning, pruning iterations x epochs per
# iteration, and epochs after pruning.
num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 500

# NOTE(review): no RNG seed is set, so this model and the baseline trained
# further down start from different random initialisations.
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)

pruner_config = dict(
    checkpoint_dir=checkpoint_dir,
    pruner_lr=pruner_lr,
    prune_threshold=prune_threshold,
    l1_regularization_coeff=l1_regularization_coeff,
    causal_weights_batch_size=causal_weights_batch_size,
    causal_weights_num_epochs=causal_weights_num_epochs,
    start_clean=True,
    device=device,
)
cp1 = get_causal_pruner(model1, optimizer1, **pruner_config)


# One progress-bar tick per *completed* epoch across all three phases. The
# context manager closes the bar even if the cell is interrupted mid-run.
with tqdm(total=num_pre_prune_epochs
          + num_prune_iterations * num_prune_epochs
          + num_post_prune_epochs) as pbar:
    # Phase 1: ordinary training before pruning starts.
    for _ in range(num_pre_prune_epochs):
        model1.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            loss1.backward()
            optimizer1.step()
        pbar.update(1)

    # Phase 2: pruning iterations. Every training loss is handed to the pruner
    # via provide_loss() before backward(); after each iteration's epochs,
    # compute_masks() and reset_weights() are called.
    cp1.start_pruning()
    for _ in range(num_prune_iterations):
        cp1.start_iteration()
        for _ in range(num_prune_epochs):
            model1.train()
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer1.zero_grad()
                outputs1 = model1(inputs)
                loss1 = F.cross_entropy(outputs1, labels)
                cp1.provide_loss(loss1)
                loss1.backward()
                optimizer1.step()
            pbar.update(1)
        cp1.compute_masks()
        cp1.reset_weights()

    # Phase 3: ordinary training after pruning.
    for _ in range(num_post_prune_epochs):
        model1.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            loss1.backward()
            optimizer1.step()
        pbar.update(1)

100%|██████████| 610/610 [14:23<00:00,  1.42s/it]


In [13]:
# Per-layer mask statistics of the model trained with the pruner.
get_pruned_stats(model=model1)

Name: conv1; Total: 1500; non-zero: 869; pruned: 631; percent: 42.0667%
Name: conv2; Total: 25000; non-zero: 15155; pruned: 9845; percent: 39.3800%
Name: fc1; Total: 625000; non-zero: 391333; pruned: 233667; percent: 37.3867%
Name: fc2; Total: 5000; non-zero: 2955; pruned: 2045; percent: 40.9000%


In [18]:
# Test-set accuracy of the model trained with the pruner.
eval_model(model=model1)

Accuracy: 0.5009; Correct: 5009; Total: 10000


In [19]:
# Unpruned baseline: same architecture, optimiser settings and loader, trained
# for the pre-prune + post-prune epoch budget (no pruning epochs).
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

num_total_epochs = num_pre_prune_epochs + num_post_prune_epochs

# Tick after each completed epoch; the context manager closes the bar even if
# the cell is interrupted.
with tqdm(total=num_total_epochs) as pbar:
    for _ in range(num_total_epochs):
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer2.zero_grad()
            outputs2 = model2(inputs)
            loss2 = F.cross_entropy(outputs2, labels)
            loss2.backward()
            optimizer2.step()
        pbar.update(1)

100%|██████████| 510/510 [11:22<00:00,  1.34s/it]


In [20]:
# Test-set accuracy of the unpruned baseline.
eval_model(model=model2)

Accuracy: 0.5169; Correct: 5169; Total: 10000


# SGD CausalPruner with momentum

Same experiment as above, but both the pruned model and the baseline use SGD with momentum `0.9`.

In [21]:
# Optimiser / pruner hyper-parameters for this run (SGD with momentum).
momentum = 0.9
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

# Epoch schedule: epochs before pruning, pruning iterations x epochs per
# iteration, and epochs after pruning.
num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 500

# NOTE(review): no RNG seed is set, so this model and the baseline trained
# further down start from different random initialisations.
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)

pruner_config = dict(
    checkpoint_dir=checkpoint_dir,
    pruner_lr=pruner_lr,
    prune_threshold=prune_threshold,
    l1_regularization_coeff=l1_regularization_coeff,
    causal_weights_batch_size=causal_weights_batch_size,
    causal_weights_num_epochs=causal_weights_num_epochs,
    start_clean=True,
    device=device,
)
cp1 = get_causal_pruner(model1, optimizer1, **pruner_config)


# One progress-bar tick per *completed* epoch across all three phases. The
# context manager closes the bar even if the cell is interrupted mid-run.
with tqdm(total=num_pre_prune_epochs
          + num_prune_iterations * num_prune_epochs
          + num_post_prune_epochs) as pbar:
    # Phase 1: ordinary training before pruning starts.
    for _ in range(num_pre_prune_epochs):
        model1.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            loss1.backward()
            optimizer1.step()
        pbar.update(1)

    # Phase 2: pruning iterations. Every training loss is handed to the pruner
    # via provide_loss() before backward(); after each iteration's epochs,
    # compute_masks() and reset_weights() are called.
    # NOTE(review): with momentum > 0, optimizer1 keeps per-parameter momentum
    # buffers; whether reset_weights() also clears them is not visible here
    # -- TODO confirm, since stale momentum would carry across iterations.
    cp1.start_pruning()
    for _ in range(num_prune_iterations):
        cp1.start_iteration()
        for _ in range(num_prune_epochs):
            model1.train()
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer1.zero_grad()
                outputs1 = model1(inputs)
                loss1 = F.cross_entropy(outputs1, labels)
                cp1.provide_loss(loss1)
                loss1.backward()
                optimizer1.step()
            pbar.update(1)
        cp1.compute_masks()
        cp1.reset_weights()

    # Phase 3: ordinary training after pruning.
    for _ in range(num_post_prune_epochs):
        model1.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            loss1.backward()
            optimizer1.step()
        pbar.update(1)

100%|██████████| 610/610 [13:58<00:00,  1.38s/it]


In [22]:
# Per-layer mask statistics of the model trained with the pruner (momentum).
get_pruned_stats(model=model1)

Name: conv1; Total: 1500; non-zero: 1499; pruned: 1; percent: 0.0667%
Name: conv2; Total: 25000; non-zero: 24925; pruned: 75; percent: 0.3000%
Name: fc1; Total: 625000; non-zero: 622417; pruned: 2583; percent: 0.4133%
Name: fc2; Total: 5000; non-zero: 4985; pruned: 15; percent: 0.3000%


In [23]:
# Test-set accuracy of the model trained with the pruner (momentum).
eval_model(model=model1)

Accuracy: 0.6322; Correct: 6322; Total: 10000


In [24]:
# Unpruned baseline: same architecture, optimiser settings and loader, trained
# for the pre-prune + post-prune epoch budget (no pruning epochs).
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

num_total_epochs = num_pre_prune_epochs + num_post_prune_epochs

# Tick after each completed epoch; the context manager closes the bar even if
# the cell is interrupted.
with tqdm(total=num_total_epochs) as pbar:
    for _ in range(num_total_epochs):
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer2.zero_grad()
            outputs2 = model2(inputs)
            loss2 = F.cross_entropy(outputs2, labels)
            loss2.backward()
            optimizer2.step()
        pbar.update(1)

100%|██████████| 510/510 [10:34<00:00,  1.24s/it]


In [25]:
# Test-set accuracy of the unpruned baseline (momentum).
eval_model(model=model2)

Accuracy: 0.6311; Correct: 6311; Total: 10000
