In [33]:
%load_ext autoreload
%autoreload 2
import os
import sys
# Put the parent of the current working directory at the front of sys.path.
# Presumably this is what lets the `causalpruner` and `tests` packages imported
# in the next cell resolve without installing the project -- TODO confirm.
module_path = os.path.abspath(os.path.join('../'))
sys.path.insert(0, module_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from causalpruner import best_device, get_causal_pruner
from tests.models import get_model
from tests.datasets import get_dataset

In [35]:
# Loader settings used for both splits.
batch_size = 4096
num_workers = 0

# Fetch the train/test splits of the 'cifar10' dataset, stored under `root`.
root = '../data/'
train, test = get_dataset('cifar10', root)

# Everything except shuffling is shared by the two loaders.
_loader_options = dict(
    batch_size=batch_size, pin_memory=True, num_workers=num_workers)

# Only the training split is reshuffled; the test split keeps its order.
trainloader = DataLoader(train, shuffle=True, **_loader_options)
testloader = DataLoader(test, shuffle=False, **_loader_options)

Files already downloaded and verified
Data already transformed and saved
Files already downloaded and verified
Data already transformed and saved


# Online CausalPruner

In [6]:
# from torch.profiler import profile, record_function, ProfilerActivity

In [11]:
# Number of epochs run() spends in each of its three phases.
num_pre_prune_epochs = 1
num_prune_epochs = 1
num_post_prune_epochs = 1

# Optimizer and pruner settings for the online run. The pruner values are
# forwarded unchanged to get_causal_pruner below.
momentum = 0
prune_threshold = 1e-5
l1_regularization_coeff = 1e-5
num_epochs_batched = 16
causal_weights_num_epochs = 10
device = best_device()

# Both models are built the same way and get identically configured SGD
# optimizers; only model1/optimizer1 are handed to the pruner.
_sgd_options = dict(lr=1e-3, momentum=momentum)
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), **_sgd_options)
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), **_sgd_options)

_pruner_options = dict(
    prune_threshold=prune_threshold,
    l1_regularization_coeff=l1_regularization_coeff,
    num_epochs_batched=num_epochs_batched,
    causal_weights_num_epochs=causal_weights_num_epochs,
    device=device,
)
cp1 = get_causal_pruner(model1, optimizer1, **_pruner_options)


def run():
    """Train model1 and model2 side by side through three phases.

    The phases run in this order:

    1. ``num_pre_prune_epochs`` plain training epochs.
    2. ``num_prune_epochs`` epochs in which every batch's ``loss1`` is also
       passed to ``cp1.provide_loss``.
    3. One call to ``cp1.compute_masks()``.
    4. ``num_post_prune_epochs`` plain training epochs.

    All inputs are read from notebook globals: ``model1``, ``model2``,
    ``optimizer1``, ``optimizer2``, ``cp1``, ``trainloader``, ``device`` and
    the three epoch counts. Returns ``None``.
    """

    def train_epoch(record_loss):
        """Run one pass over ``trainloader``, stepping both optimizers per batch.

        When ``record_loss`` is true, ``loss1`` is handed to
        ``cp1.provide_loss`` before any backward pass or optimizer step.
        """
        model1.train()
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            outputs1 = model1(inputs)
            outputs2 = model2(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            if record_loss:
                # Same position as in the original loop: right after loss1 is
                # computed and before backward()/step().
                cp1.provide_loss(loss1)
            loss2 = F.cross_entropy(outputs2, labels)
            loss1.backward()
            loss2.backward()
            optimizer1.step()
            optimizer2.step()

    total_epochs = (num_pre_prune_epochs + num_prune_epochs +
                    num_post_prune_epochs)
    pbar = tqdm(total=total_epochs)
    try:
        # The bar is advanced after an epoch finishes, so it reports completed
        # work instead of work that has only just started.
        for _ in range(num_pre_prune_epochs):
            train_epoch(record_loss=False)
            pbar.update(1)
        for _ in range(num_prune_epochs):
            train_epoch(record_loss=True)
            pbar.update(1)
        cp1.compute_masks()
        for _ in range(num_post_prune_epochs):
            train_epoch(record_loss=False)
            pbar.update(1)
    finally:
        # Close the bar even when training or mask computation raises, so a
        # failed run does not leave a dangling progress bar behind.
        pbar.close()

# Execute the training schedule defined by run() above.
run()
# Profiled alternative to the plain call. It needs `profile`,
# `record_function` and `ProfilerActivity`, whose import is commented out in
# the cell just before the configuration.
# with profile(activities=[
#         ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
#     with record_function('cp'):
#         run()

100%|██████████| 3/3 [00:06<00:00,  2.01s/it]


In [12]:
# For every buffer registered on model1, report how many of its entries are
# zero (counted as "pruned") versus non-zero.
mask_suffix = '.weight_mask'
for (name, param) in model1.named_buffers():
    # str.rstrip() removes a *set of characters*, not a suffix, so it would
    # also eat the tail of layer names ending in any of ".weight_mask"'s
    # letters (e.g. 'features' -> 'featur'). Drop the exact suffix instead.
    if name.endswith(mask_suffix):
        name = name[:-len(mask_suffix)]
    non_zero = torch.count_nonzero(param)
    total = param.numel()
    pruned = total - non_zero
    frac = 100 * pruned / total
    # Adjacent literals keep the line short without breaking inside a
    # replacement field, which is a SyntaxError before Python 3.12.
    print(f'Name: {name}; Total: {total}; non-zero: {non_zero}; '
          f'pruned: {pruned}; percent: {frac:.4f}%')

Name: conv1; Total: 1500; non-zero: 1498; pruned: 2; percent: 0.1333%
Name: conv2; Total: 25000; non-zero: 24708; pruned: 292; percent: 1.1680%
Name: fc1; Total: 625000; non-zero: 611260; pruned: 13740; percent: 2.1984%
Name: fc2; Total: 5000; non-zero: 4969; pruned: 31; percent: 0.6200%


In [13]:
# Measure top-1 accuracy on the test loader for model1 (the model given to the
# pruner, reported as "Pruned") and model2 (reported as "Non-pruned").
model1.eval()
model2.eval()
total = 0
correct1 = 0
correct2 = 0

with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        total += labels.size(0)
        outputs1 = model1(inputs)
        outputs2 = model2(inputs)
        # Gradients are already disabled by no_grad(), so the legacy `.data`
        # detour is unnecessary; take the arg-max over the class dimension.
        _, predicted1 = torch.max(outputs1, 1)
        _, predicted2 = torch.max(outputs2, 1)
        correct1 += (predicted1 == labels).sum().item()
        correct2 += (predicted2 == labels).sum().item()

pruned_accuracy = correct1 / total
non_pruned_accuracy = correct2 / total
# Split across adjacent literals rather than inside a replacement field, which
# is a SyntaxError before Python 3.12.
print(f'Total: {total}; Pruned: {pruned_accuracy}; '
      f'Non-pruned: {non_pruned_accuracy}')

Total: 10000; Pruned: 0.1792; Non-pruned: 0.1808


In [28]:
# print(prof.key_averages().table())

In [29]:
# prof.export_chrome_trace('trace.json')

# Checkpoint CausalPruner

In [36]:
from torch.profiler import profile, record_function, ProfilerActivity

In [56]:
# Number of epochs run() spends in each of its three phases.
num_pre_prune_epochs = 10
num_prune_epochs = 20
num_post_prune_epochs = 20

# Optimizer and pruner settings for the checkpoint-based run. The pruner
# values are forwarded unchanged to get_causal_pruner below.
momentum = 0.9
prune_threshold = 1e-5
l1_regularization_coeff = 5e-3
checkpoint_dir = '../checkpoints'
causal_weights_batch_size = 512
causal_weights_num_epochs = 256
device = best_device()

# Both models are built the same way and get identically configured SGD
# optimizers; only model1/optimizer1 are handed to the pruner.
_sgd_options = dict(lr=1e-3, momentum=momentum)
model1 = get_model('lenet', 'cifar10').to(device=device)
optimizer1 = optim.SGD(model1.parameters(), **_sgd_options)
model2 = get_model('lenet', 'cifar10').to(device=device)
optimizer2 = optim.SGD(model2.parameters(), **_sgd_options)

_pruner_options = dict(
    checkpoint_dir=checkpoint_dir,
    prune_threshold=prune_threshold,
    l1_regularization_coeff=l1_regularization_coeff,
    causal_weights_batch_size=causal_weights_batch_size,
    causal_weights_num_epochs=causal_weights_num_epochs,
    start_clean=True,
    device=device,
)
cp1 = get_causal_pruner(model1, optimizer1, **_pruner_options)


def run():
    """Train model1 and model2 side by side through three phases.

    The phases run in this order:

    1. ``num_pre_prune_epochs`` plain training epochs.
    2. ``num_prune_epochs`` epochs in which every batch's ``loss1`` is also
       passed to ``cp1.provide_loss``.
    3. One call to ``cp1.compute_masks()``.
    4. ``num_post_prune_epochs`` plain training epochs.

    All inputs are read from notebook globals: ``model1``, ``model2``,
    ``optimizer1``, ``optimizer2``, ``cp1``, ``trainloader``, ``device`` and
    the three epoch counts. Returns ``None``.
    """

    def train_epoch(record_loss):
        """Run one pass over ``trainloader``, stepping both optimizers per batch.

        When ``record_loss`` is true, ``loss1`` is handed to
        ``cp1.provide_loss`` before any backward pass or optimizer step.
        """
        model1.train()
        model2.train()
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()
            outputs1 = model1(inputs)
            outputs2 = model2(inputs)
            loss1 = F.cross_entropy(outputs1, labels)
            if record_loss:
                # Same position as in the original loop: right after loss1 is
                # computed and before backward()/step().
                cp1.provide_loss(loss1)
            loss2 = F.cross_entropy(outputs2, labels)
            loss1.backward()
            loss2.backward()
            optimizer1.step()
            optimizer2.step()

    total_epochs = (num_pre_prune_epochs + num_prune_epochs +
                    num_post_prune_epochs)
    pbar = tqdm(total=total_epochs)
    try:
        # The bar is advanced after an epoch finishes, so it reports completed
        # work instead of work that has only just started.
        for _ in range(num_pre_prune_epochs):
            train_epoch(record_loss=False)
            pbar.update(1)
        for _ in range(num_prune_epochs):
            train_epoch(record_loss=True)
            pbar.update(1)
        cp1.compute_masks()
        for _ in range(num_post_prune_epochs):
            train_epoch(record_loss=False)
            pbar.update(1)
    finally:
        # Close the bar even when training or mask computation raises, so a
        # failed run does not leave a dangling progress bar behind.
        pbar.close()

# Execute the training schedule defined by run() above.
run()
# Profiled alternative to the plain call; it relies on the `profile`,
# `record_function` and `ProfilerActivity` names imported from torch.profiler
# at the start of this section.
# with profile(activities=[
#         ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
#     with record_function('cp'):
#         run()

100%|██████████| 50/50 [02:12<00:00,  2.65s/it]


In [58]:
# For every buffer registered on model1, report how many of its entries are
# zero (counted as "pruned") versus non-zero.
mask_suffix = '.weight_mask'
for (name, param) in model1.named_buffers():
    # str.rstrip() removes a *set of characters*, not a suffix, so it would
    # also eat the tail of layer names ending in any of ".weight_mask"'s
    # letters (e.g. 'features' -> 'featur'). Drop the exact suffix instead.
    if name.endswith(mask_suffix):
        name = name[:-len(mask_suffix)]
    non_zero = torch.count_nonzero(param)
    total = param.numel()
    pruned = total - non_zero
    frac = 100 * pruned / total
    # Adjacent literals keep the line short without breaking inside a
    # replacement field, which is a SyntaxError before Python 3.12.
    print(f'Name: {name}; Total: {total}; non-zero: {non_zero}; '
          f'pruned: {pruned}; percent: {frac:.4f}%')

Name: conv1; Total: 1500; non-zero: 1396; pruned: 104; percent: 6.9333%
Name: conv2; Total: 25000; non-zero: 23348; pruned: 1652; percent: 6.6080%
Name: fc1; Total: 625000; non-zero: 571803; pruned: 53197; percent: 8.5115%
Name: fc2; Total: 5000; non-zero: 4662; pruned: 338; percent: 6.7600%


In [59]:
# Measure top-1 accuracy on the test loader for model1 (the model given to the
# pruner, reported as "Pruned") and model2 (reported as "Non-pruned").
model1.eval()
model2.eval()
total = 0
correct1 = 0
correct2 = 0

with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        total += labels.size(0)
        outputs1 = model1(inputs)
        outputs2 = model2(inputs)
        # Gradients are already disabled by no_grad(), so the legacy `.data`
        # detour is unnecessary; take the arg-max over the class dimension.
        _, predicted1 = torch.max(outputs1, 1)
        _, predicted2 = torch.max(outputs2, 1)
        correct1 += (predicted1 == labels).sum().item()
        correct2 += (predicted2 == labels).sum().item()

pruned_accuracy = correct1 / total
non_pruned_accuracy = correct2 / total
# Split across adjacent literals rather than inside a replacement field, which
# is a SyntaxError before Python 3.12.
print(f'Total: {total}; Pruned: {pruned_accuracy}; '
      f'Non-pruned: {non_pruned_accuracy}')

Total: 10000; Pruned: 0.4574; Non-pruned: 0.4661


In [47]:
# print(prof.key_averages().table())

In [46]:
# prof.export_chrome_trace('trace.json')