In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
module_path = os.path.abspath(os.path.join('../'))
sys.path.insert(0, module_path)

In [3]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from causalpruner import best_device, get_causal_pruner, get_sgd_pruner
from tests.models import get_model
from tests.datasets import get_dataset

In [4]:
root = '../data/'
train, test = get_dataset('cifar10', root)

batch_size = 4096
num_workers = 2

trainloader = DataLoader(
    train, batch_size=batch_size, shuffle=True, pin_memory=True,
    num_workers=num_workers)
testloader = DataLoader(
    test, batch_size=batch_size, shuffle=False, pin_memory=True,
    num_workers=num_workers)

Files already downloaded and verified
Data already transformed and saved
Files already downloaded and verified
Data already transformed and saved


In [5]:
def get_pruned_stats(model):
    for (name, param) in model.named_buffers():
        name = name.rstrip('.weight_mask')
        non_zero = torch.count_nonzero(param)
        total = torch.count_nonzero(torch.ones_like(param))
        pruned = total - non_zero
        frac = 100 * pruned / total
        print(f'Name: {name}; Total: {
            total}; non-zero: {non_zero}; pruned: {pruned}; percent: {frac:.4f}%')

In [6]:
def eval_model(model, device=best_device()):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            total += labels.size(0)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f'Accuracy: {accuracy}; Correct: {correct}; Total: {total}')

# SGD CausalPruner without momentum

In [9]:
momentum = 0
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 100
total_epochs = num_pre_prune_epochs + num_prune_iterations * num_prune_epochs + num_post_prune_epochs

In [11]:
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)

cp1 = get_causal_pruner(model1,
                        optimizer1,
                        checkpoint_dir,
                        momentum=False,
                        pruner_lr=pruner_lr,
                        prune_threshold=prune_threshold,
                        l1_regularization_coeff=l1_regularization_coeff,
                        causal_weights_batch_size=causal_weights_batch_size,
                        causal_weights_num_epochs=causal_weights_num_epochs,
                        start_clean=True,
                        device=device)


pbar = tqdm(total=total_epochs)
for _ in range(num_pre_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()
cp1.start_pruning()
for _ in range(num_prune_iterations):
    cp1.start_iteration()
    for _ in range(num_prune_epochs):
        pbar.update(1)
        model1.train()
        for data in trainloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            cp1.provide_loss(loss1)
            loss1.backward()
            optimizer1.step()
    cp1.compute_masks()
    cp1.reset_weights()
for _ in range(num_post_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()

pbar.close()

 50%|█████     | 4/8 [00:49<00:49, 12.35s/it]
100%|██████████| 8/8 [00:13<00:00,  1.64s/it]


In [12]:
get_pruned_stats(model1)

Name: conv1; Total: 1500; non-zero: 1298; pruned: 202; percent: 13.4667%
Name: conv2; Total: 25000; non-zero: 21847; pruned: 3153; percent: 12.6120%
Name: fc1; Total: 625000; non-zero: 533455; pruned: 91545; percent: 14.6472%
Name: fc2; Total: 5000; non-zero: 4374; pruned: 626; percent: 12.5200%


In [9]:
eval_model(model1)

Accuracy: 0.5014; Correct: 5014; Total: 10000


In [10]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

pbar = tqdm(total=total_epochs)
for _ in range(total_epochs):
    pbar.update(1)
    model2.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer2.zero_grad()
        outputs2 = model2(inputs)
        loss2 = F.cross_entropy(outputs2, labels)
        loss2.backward()
        optimizer2.step()

pbar.close()

100%|██████████| 610/610 [13:44<00:00,  1.35s/it]


In [11]:
eval_model(model2)

Accuracy: 0.5372; Correct: 5372; Total: 10000


# SGD CausalPruner with momentum

In [13]:
momentum = 0.9
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 100
total_epochs = num_pre_prune_epochs + num_prune_iterations * num_prune_epochs + num_post_prune_epochs

In [14]:
model1 = get_model('lenet', 'cifar10').to(device=device)

pbar = tqdm(total=total_epochs)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)
for _ in range(num_pre_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()

cp1 = get_causal_pruner(model1,
                        optimizer1,
                        checkpoint_dir,
                        momentum=True,
                        pruner_lr=pruner_lr,
                        prune_threshold=prune_threshold,
                        l1_regularization_coeff=l1_regularization_coeff,
                        causal_weights_batch_size=causal_weights_batch_size,
                        causal_weights_num_epochs=causal_weights_num_epochs,
                        start_clean=True,
                        device=device)
cp1.start_pruning()

optim_dir = os.path.join(checkpoint_dir, 'optim.pth')
torch.save(optimizer1.state_dict(), optim_dir)

for _ in range(num_prune_iterations):
    cp1.start_iteration()
    for _ in range(num_prune_epochs):
        pbar.update(1)
        model1.train()
        for data in trainloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            cp1.provide_loss(loss1)
            loss1.backward()
            optimizer1.step()
    cp1.compute_masks()
    cp1.reset_weights()
    optimizer1.load_state_dict(torch.load(optim_dir))

for _ in range(num_post_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()

pbar.close()

100%|██████████| 8/8 [00:14<00:00,  1.83s/it]


In [15]:
get_pruned_stats(model1)

Name: conv1; Total: 1500; non-zero: 1497; pruned: 3; percent: 0.2000%
Name: conv2; Total: 25000; non-zero: 24988; pruned: 12; percent: 0.0480%
Name: fc1; Total: 625000; non-zero: 624347; pruned: 653; percent: 0.1045%
Name: fc2; Total: 5000; non-zero: 4999; pruned: 1; percent: 0.0200%


In [16]:
eval_model(model1)

Accuracy: 0.2885; Correct: 2885; Total: 10000


In [39]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

pbar = tqdm(total=total_epochs)
for _ in range(total_epochs):
    pbar.update(1)
    model2.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer2.zero_grad()
        outputs2 = model2(inputs)
        loss2 = F.cross_entropy(outputs2, labels)
        loss2.backward()
        optimizer2.step()

pbar.close()

100%|██████████| 610/610 [12:48<00:00,  1.26s/it]


In [40]:
eval_model(model2)

Accuracy: 0.6369; Correct: 6369; Total: 10000


# SGD Pruner without momentum used with momentum

In [47]:
momentum = 0.9
pruner_lr = 1e-3
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

num_pre_prune_epochs = 10
num_prune_iterations = 10
num_prune_epochs = 10
num_post_prune_epochs = 500

total_epochs = num_pre_prune_epochs + num_prune_iterations * num_prune_epochs + num_post_prune_epochs

In [48]:
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), lr=1e-3, momentum=momentum)

pbar = tqdm(total=total_epochs)
for _ in range(num_pre_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()

cp1 = get_causal_pruner(model1,
                        optimizer1,
                        checkpoint_dir,
                        momentum=False,
                        pruner_lr=pruner_lr,
                        prune_threshold=prune_threshold,
                        l1_regularization_coeff=l1_regularization_coeff,
                        causal_weights_batch_size=causal_weights_batch_size,
                        causal_weights_num_epochs=causal_weights_num_epochs,
                        start_clean=True,
                        device=device)
cp1.start_pruning()

optim_dir = os.path.join(checkpoint_dir, 'optim.pth')
torch.save(optimizer1.state_dict(), optim_dir)

for _ in range(num_prune_iterations):
    cp1.start_iteration()
    for _ in range(num_prune_epochs):
        pbar.update(1)
        model1.train()
        for data in trainloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            outputs1 = model1(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            cp1.provide_loss(loss1)
            loss1.backward()
            optimizer1.step()
    cp1.compute_masks()
    cp1.reset_weights()
    optimizer1.load_state_dict(torch.load(optim_dir))

for _ in range(num_post_prune_epochs):
    pbar.update(1)
    model1.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer1.zero_grad()
        outputs1 = model1(inputs)
        loss1 = F.cross_entropy(outputs1, labels)
        loss1.backward()
        optimizer1.step()

pbar.close()

  3%|▎         | 20/610 [00:25<13:10,  1.34s/it]

FileNotFoundError: [Errno 2] No such file or directory: '../checkpoints/optim.pth'

In [None]:
get_pruned_stats(model1)

Name: conv1; Total: 1500; non-zero: 873; pruned: 627; percent: 41.8000%
Name: conv2; Total: 25000; non-zero: 15118; pruned: 9882; percent: 39.5280%
Name: fc1; Total: 625000; non-zero: 390927; pruned: 234073; percent: 37.4517%
Name: fc2; Total: 5000; non-zero: 2957; pruned: 2043; percent: 40.8600%


In [None]:
eval_model(model1)

Accuracy: 0.6437; Correct: 6437; Total: 10000


In [None]:
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), lr=1e-3, momentum=momentum)

pbar = tqdm(total=total_epochs)
for _ in range(total_epochs):
    pbar.update(1)
    model2.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer2.zero_grad()
        outputs2 = model2(inputs)
        loss2 = F.cross_entropy(outputs2, labels)
        loss2.backward()
        optimizer2.step()

pbar.close()

100%|██████████| 610/610 [12:58<00:00,  1.28s/it]


In [None]:
eval_model(model2)

Accuracy: 0.6396; Correct: 6396; Total: 10000
