<a href="https://colab.research.google.com/github/panda1230/pytorch-cnn-visualizations/blob/master/EnergyConstrainedCompression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import argparse
import datetime
import numpy as np
import os
import math
import time
import torch
import random
import sys
import copy
from models import get_net_model
from proj_utils import fill_model_weights, layers_stat, model_sparsity, filtered_parameters, \
    l0proj, round_model_weights, clamp_model_weights
from sa_energy_model import build_energy_info, energy_eval2, energy_eval2_relax, energy_proj2, \
    reset_Xenergy_cache
from utils import get_data_loaders, joint_loss, eval_loss_acc1_acc5, model_snapshot

In [0]:
# ---------------------------------------------------------------------------
# Experiment configuration.
# These module-level names stand in for the command-line options of the
# original script; the remaining cells read them directly.
# ---------------------------------------------------------------------------
net = 'lenet-5'
dataset = 'mnist-32'
batch_size = 64
val_batch_size = 64
num_workers = 8
epochs = 30

# Optimisation hyper-parameters.
# NOTE: `lr_m` used to be written as `0.001,` -- the trailing comma silently
# turned it into the tuple `(0.001,)`. It is a plain float now.
lr_m = 0.001     # learning rate
xlr = 1e-4       # learning rate for input mask
l2wd = 1e-4      # l2 weight decay
xl2wd = 1e-5     # l2 weight decay (for input mask)
momentum = 0.9   # momentum
proj_int = 10    # how many batches for each projection

# Boolean switches, kept as 0/1 integers (originally `store_true` flags).
nodp = 1         # turn off dropout
input_mask = 1   # enable input mask
randinit = 1     # use random init
# NOTE(review): this name shadows the builtin `eval`; it is kept because the
# initial-evaluation cell reads it under this name.
eval = 1         # evaluate testset in the beginning

# Options of the original script that are not exposed in this notebook:
#   --pretrain  file to load pretrained model (default None)
#   --seed      random seed (default 117; the seeding cell hard-codes 117)

log_interval = 100        # how many batches to wait before logging training status
test_interval = 1         # how many epochs to wait before another test
save_interval = 10        # how many epochs to wait before save a model
logdir = './sample_data/' # folder to save to the log
distill = 0.5             # distill loss weight
budget = 0.2              # energy budget (relative)
exp_bdecay = 1            # exponential budget decay
mgpu = 1                  # enable using multiple gpus
skip1 = 1                 # skip the first W update
cuda = torch.cuda.is_available()

In [0]:
# Seed every random number generator used by the run so that a fresh
# "Restart & Run All" reproduces the same results.
#
# The NumPy and `random` seeds used to sit inside the `if cuda:` branch, so
# CPU-only runs left them unseeded. They are now set unconditionally.
SEED = 117
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Documented as safe to call without CUDA (silently ignored in that case),
# so no availability guard is needed here.
torch.cuda.manual_seed(SEED)

In [0]:
# get training and validation data loaders
import torchvision
import torchvision.transforms as transforms
# Pre-processing applied to every image: resize to 32, convert to a tensor,
# then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST train / test splits, downloaded into ./sample_data when missing.
trainset = torchvision.datasets.MNIST(
    root='./sample_data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(
    root='./sample_data', train=False, download=True, transform=transform)

# Loader that feeds the optimisation loops.
tr_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Loader over the test split, used for validation metrics.
val_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=val_batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Second loader over the training split, used only to report training metrics.
train_loader4eval = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [0]:
# Instantiate the network to be trained (`model`) together with `teacher_model`.
# Dropout is enabled unless `nodp` is set, and pretrained weights are requested
# unless `randinit` is set.
use_dropout = not nodp
use_pretrained = not randinit
model, teacher_model = get_net_model(
    net=net,
    pretrained_dataset=dataset,
    dropout=use_dropout,
    pretrained=use_pretrained,
    input_mask=input_mask,
)

In [6]:
# Display the module structure of `model` as the cell output.
model

MyLeNet5(
  (features): Sequential(
    (0): SparseConv2d(32, 32, 1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): SparseConv2d(14, 14, 6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [7]:
# Display the module structure of `teacher_model` as the cell output.
teacher_model

MyLeNet5(
  (features): Sequential(
    (0): FixHWConv2d(32, 32, 1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): FixHWConv2d(14, 14, 6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [8]:
 # for energy estimate
print('================model energy summary================')
energy_info = build_energy_info(model)


def energy_estimator(m):
    """Sum the values of the mapping `energy_eval2` returns for model `m`."""
    return sum(energy_eval2(m, energy_info, verbose=False).values())


def energy_estimator_relaxed(m):
    """Sum the values of the mapping `energy_eval2_relax` returns for model `m`."""
    return sum(energy_eval2_relax(m, energy_info, verbose=False).values())


def proj_func(m, budget, grad=False, in_place=True):
    """Forward to `energy_proj2` on the 'weight' parameters with the given budget."""
    return energy_proj2(m, energy_info, budget, grad=grad,
                        in_place=in_place, param_name='weight')


# Energy of the model as it currently stands (verbose call prints per-layer lines).
reset_Xenergy_cache(energy_info)
cur_energy = sum(energy_eval2(model, energy_info, verbose=True).values())
cur_energy_relaxed = energy_estimator_relaxed(model)

# Relaxed energy with every weight filled with 1.0 / 0.0; these are used as the
# upper / lower reference values for the budget.
dense_model = fill_model_weights(copy.deepcopy(model), 1.0)
budget_ub = energy_estimator_relaxed(dense_model)
zero_model = fill_model_weights(copy.deepcopy(model), 0.0)
budget_lb = energy_estimator_relaxed(zero_model)
del zero_model, dense_model

# The relative budget is never allowed below the normalised lower bound.
normalized_lb = budget_lb / budget_ub
budget = max(budget, normalized_lb)

dense_zero_msg = 'energy on dense DNN:{:.4e}, on zero DNN:{:.4e}, normalized_lb={:.4e}'
print(dense_zero_msg.format(budget_ub, budget_lb, normalized_lb))
current_msg = 'energy on current DNN:{:.4e}, normalized={:.4e}'
print(current_msg.format(cur_energy, cur_energy / budget_ub))
print('====================================================')
relaxed_msg = 'current energy {:.4e}, relaxed: {:.4e}'
print(relaxed_msg.format(cur_energy, cur_energy_relaxed))
netl2wd = l2wd

Layer: features.0, input shape: (1, 32, 32), output shape: (6, 28, 28), weight shape: torch.Size([6, 1, 5, 5])
Layer: features.0, W_energy=2.07e+05, C_energy=3.53e+05, X_energy=1.38e+06
Layer: features.3, input shape: (6, 14, 14), output shape: (16, 10, 10), weight shape: torch.Size([16, 6, 5, 5])
Layer: features.3, W_energy=8.50e+05, C_energy=7.20e+05, X_energy=9.75e+05
Layer: classifier.0, W_energy=9.94e+06, C_energy=1.44e+05, X_energy=1.74e+05
Layer: classifier.2, W_energy=2.09e+06, C_energy=3.02e+04, X_energy=5.52e+04
Layer: classifier.4, W_energy=1.74e+05, C_energy=2.52e+03, X_energy=2.01e+04
energy on dense DNN:1.7108e+07, on zero DNN:2.6049e+06, normalized_lb=1.5227e-01
energy on current DNN:1.7108e+07, normalized=1.0000e+00
current energy 1.7108e+07, relaxed: 1.7108e+07


In [0]:
# Move the networks to the GPU when CUDA is available.
# `model.cuda()` used to be nested under the `distill > 0.0` check, so with
# distillation disabled the model stayed on the CPU while the training loop
# still moved every batch to the GPU. The model is now moved whenever `cuda`
# is set; only the teacher move stays conditional on distillation.
if cuda:
    if distill > 0.0:
        teacher_model.cuda()
    model.cuda()


def loss_func(m, x, y):
    """Return `joint_loss` for model `m` on batch (`x`, `y`).

    `teacher_model` and `distill` are taken from the notebook globals at call
    time and forwarded unchanged.
    """
    return joint_loss(model=m, data=x, target=y, teacher_model=teacher_model, distill=distill)

In [10]:
# Baseline metrics before any training. The evaluation is skipped only when it
# was not requested (`eval` unset) AND the dataset is 'imagenet'.
skip_initial_eval = (not eval) and dataset == 'imagenet'
if skip_initial_eval:
    val_acc1 = 0.0
    print('For imagenet, skip the first validation evaluation.')
else:
    val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, cuda)
    val_msg = '**Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'
    print(val_msg.format(val_loss, val_acc1, val_acc5))

    # also evaluate training data
    tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, cuda)
    tr_msg = '###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'
    print(tr_msg.format(tr_loss, tr_acc1, tr_acc5))

**Validation loss:1.2966e+02, top-1 accuracy:0.05770, top-5 accuracy:0.49600
###Training loss:1.2851e+02, top-1 accuracy:0.06018, top-5 accuracy:0.49382


In [0]:
old_file = None

# Quantities shared by both budget schedules: the absolute energy target and
# the value (batches * epochs / proj_int) the schedules are spread over.
target_energy = budget * budget_ub
num_proj_steps = (len(tr_loader) * epochs) / proj_int

# Linear schedule: fixed amount subtracted from the budget per step.
energy_step = math.ceil(max(0.0, cur_energy - target_energy) / num_proj_steps)

# Exponential schedule: fixed factor the budget is multiplied by per step.
energy_decay_factor = min(1.0, target_energy / cur_energy) ** (1.0 / num_proj_steps)

# SGD over the parameters selected with inverse=True for 'input_mask'
# (presumably everything except the input masks -- confirm in proj_utils).
# NOTE(review): lr and momentum are hard-coded here, not read from `lr_m` / `momentum`.
non_mask_params = filtered_parameters(model, param_name='input_mask', inverse=True)
optimizer = torch.optim.SGD(non_mask_params, lr=0.1, momentum=0.9, weight_decay=netl2wd)

if input_mask:
    # Adam over the parameters selected with inverse=False for 'input_mask'.
    # The initial lr of 0.01 is overwritten with `xlr` on every batch of the mask update.
    mask_params = filtered_parameters(model, param_name='input_mask', inverse=False)
    Xoptimizer = torch.optim.Adam(mask_params, lr=0.01, weight_decay=xl2wd)

In [0]:
# Mutable state carried across iterations of the alternating training loop.
# (The former `xlr = xlr` self-assignment was a no-op and has been removed.)
cur_budget = cur_energy_relaxed        # energy budget handed to the next weight projection
lr = lr_m
cur_sparsity = model_sparsity(model)

best_acc_pruned = None                 # validation top-1 recorded before the previous mask update
Xbudget = 0.9                          # budget passed to l0proj for the input mask; lowered by 0.1 per outer iteration
iter_idx = 0                           # index of the current outer iteration

# Accumulated weight-projection time and call count. The count starts at a
# tiny positive value so the logged average never divides by zero.
W_proj_time = 0.0
W_proj_time_cnt = 1e-15

In [0]:
# Alternating optimisation. Each outer iteration
#   1. updates the weights with SGD, projecting them with `proj_func` every
#      `proj_int` batches while the budget is decayed towards
#      `budget * budget_ub` (skipped on the first iteration when `skip1` is set),
#   2. updates the input masks with Adam, clamping them to [0, 1] and applying
#      `l0proj` with `Xbudget`,
#   3. rounds the masks, refreshes the cached energy and lowers `Xbudget` by 0.1.
# The loop ends when input masks are disabled or when validation top-1 accuracy
# stops improving between outer iterations.
#
# Changes relative to the previous version of this cell:
#   * the projection test read `args.proj_int`, but this notebook has no `args`
#     object (NameError on the first weight update); it now uses `proj_int`;
#   * the leftover per-batch `print(batch_idx)` debug output is removed;
#   * the ETA counts only the epochs that are still to run.
while True:
    # update W
    if not (skip1 and iter_idx == 0):
        t_begin = time.time()
        log_tic = t_begin
        for epoch in range(epochs):
            for batch_idx, (data, target) in enumerate(tr_loader):
                model.train()
                if cuda:
                    data, target = data.cuda(), target.cuda()

                loss = loss_func(model, data, target)
                # update network weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Project every `proj_int` batches and always on the last batch of the epoch.
                if proj_int == 1 or (batch_idx > 0 and batch_idx % proj_int == 0) or batch_idx == len(tr_loader) - 1:
                    temp_tic = time.time()
                    proj_func(model, cur_budget)
                    W_proj_time += time.time() - temp_tic
                    W_proj_time_cnt += 1
                    if epoch == epochs - 1 and batch_idx >= len(tr_loader) - 1 - proj_int:
                        # Near the very end of the schedule, snap to the final budget.
                        cur_budget = budget * budget_ub
                    else:
                        if exp_bdecay:
                            cur_budget = max(cur_budget * energy_decay_factor, budget * budget_ub)
                        else:
                            cur_budget = max(cur_budget - energy_step, budget * budget_ub)

                if batch_idx % log_interval == 0:
                    print('======================================================')
                    print('+-------------- epoch {}, batch {}/{} ----------------+'.format(epoch, batch_idx,
                                                                                           len(tr_loader)))
                    log_toc = time.time()
                    print(
                        'primal update: net loss={:.4e}, lr={:.4e}, current normalized budget: {:.4e}, time_elapsed={:.3f}s, averaged projection_time {}'.format(
                            loss.item(), optimizer.param_groups[0]['lr'], cur_budget / budget_ub, log_toc - log_tic, W_proj_time / W_proj_time_cnt))
                    log_tic = time.time()
                    if batch_idx % proj_int == 0:
                        cur_sparsity = model_sparsity(model)
                    print('sparsity:{}'.format(cur_sparsity))
                    print(layers_stat(model, param_names='weight', param_filter=lambda p: p.dim() > 1))
                    print('+-----------------------------------------------------+')

            cur_energy = energy_estimator(model)
            cur_energy_relaxed = energy_estimator_relaxed(model)
            cur_sparsity = model_sparsity(model)
            if epoch % test_interval == 0:
                val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, cuda)

                # also evaluate training data
                tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, cuda)
                print('###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1,
                                                                                                     tr_acc5))

                print(
                    '***Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}, current normalized energy:{:.4e}, {:.4e}(relaxed), sparsity: {:.4e}'.format(
                        val_loss, val_acc1,
                        val_acc5, cur_energy / budget_ub, cur_energy_relaxed / budget_ub, cur_sparsity))
                # save current model
                model_snapshot(model, os.path.join(logdir, 'primal_model_latest.pkl'))

            if save_interval > 0 and epoch % save_interval == 0:
                model_snapshot(model, os.path.join(logdir, 'Wprimal_model_epoch{}_{}.pkl'.format(iter_idx, epoch)))

            elapse_time = time.time() - t_begin
            speed_epoch = elapse_time / (1 + epoch)
            # `epoch` is 0-based and has just finished, so epochs - epoch - 1 remain.
            eta = speed_epoch * (epochs - epoch - 1)
            print("Updating Weights, Elapsed {:.2f}s, ets {:.2f}s".format(elapse_time, eta))

    if not input_mask:
        print("Complete weights training.")
        break
    else:
        print("Continue to train input mask.")

    # Stop once the accuracy after the weight update no longer beats the
    # accuracy recorded at the same point of the previous outer iteration.
    if best_acc_pruned is not None and val_acc1 <= best_acc_pruned:
        print("Pruned accuracy does not improve, stop here!")
        break
    best_acc_pruned = val_acc1

    # update X
    t_begin = time.time()
    log_tic = t_begin
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(tr_loader):
            model.train()
            # The mask learning rate is re-applied from `xlr` on every batch.
            Xoptimizer.param_groups[0]['lr'] = xlr
            if cuda:
                data, target = data.cuda(), target.cuda()

            loss = loss_func(model, data, target)
            # update input masks
            Xoptimizer.zero_grad()
            loss.backward()
            Xoptimizer.step()
            clamp_model_weights(model, min=0.0, max=1.0, param_name='input_mask')

            if (batch_idx > 0 and batch_idx % proj_int == 0) or batch_idx == len(tr_loader) - 1:
                l0proj(model, Xbudget, param_name='input_mask')

            if batch_idx % log_interval == 0:
                print('======================================================')
                print('+-------------- epoch {}, batch {}/{} ----------------+'.format(epoch, batch_idx,
                                                                                       len(tr_loader)))
                log_toc = time.time()
                print('primal update: net loss={:.4e}, xlr={:.4e}, time_elapsed={:.3f}s'.format(
                        loss.item(), Xoptimizer.param_groups[0]['lr'], log_toc - log_tic))
                log_tic = time.time()
                if batch_idx % proj_int == 0:
                    cur_sparsity = model_sparsity(model, param_name='input_mask')
                print('sparsity:{}'.format(cur_sparsity))
                print(layers_stat(model, param_names='input_mask'))
                print('+-----------------------------------------------------+')

        cur_energy = energy_estimator(model)
        cur_energy_relaxed = energy_estimator_relaxed(model)
        cur_sparsity = model_sparsity(model, param_name='input_mask')
        if epoch % test_interval == 0:

            val_loss, val_acc1, val_acc5 = eval_loss_acc1_acc5(model, val_loader, loss_func, cuda)

            # also evaluate training data
            tr_loss, tr_acc1, tr_acc5 = eval_loss_acc1_acc5(model, train_loader4eval, loss_func, cuda)
            print(
                '###Training loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}'.format(tr_loss, tr_acc1,
                                                                                               tr_acc5))

            print(
                '***Validation loss:{:.4e}, top-1 accuracy:{:.5f}, top-5 accuracy:{:.5f}, current normalized energy:{:.4e}, {:.4e}(relaxed), sparsity: {:.4e}'.format(
                    val_loss, val_acc1,
                    val_acc5, cur_energy / budget_ub, cur_energy_relaxed / budget_ub, cur_sparsity))
            # save current model
            model_snapshot(model, os.path.join(logdir, 'primal_model_latest.pkl'))

        if save_interval > 0 and epoch % save_interval == 0:
            model_snapshot(model, os.path.join(logdir, 'Xprimal_model_epoch{}_{}.pkl'.format(iter_idx, epoch)))

        elapse_time = time.time() - t_begin
        speed_epoch = elapse_time / (1 + epoch)
        # `epoch` is 0-based and has just finished, so epochs - epoch - 1 remain.
        eta = speed_epoch * (epochs - epoch - 1)
        print("Updating input mask, Elapsed {:.2f}s, ets {:.2f}s".format(elapse_time, eta))

    round_model_weights(model, param_name='input_mask')
    # refresh X_energy_cache
    reset_Xenergy_cache(energy_info)
    cur_energy = energy_estimator(model)
    cur_energy_relaxed = energy_estimator_relaxed(model)

    iter_idx += 1
    Xbudget -= 0.1

Continue to train input mask.
+-------------- epoch 0, batch 0/938 ----------------+
primal update: net loss=1.2599e+02, xlr=1.0000e-04, time_elapsed=0.390s
sparsity:1.0
########### layer stat ###########
          features.0abs(W): min=9.9990e-01, mean=9.9995e-01, max=1.0000e+00, nnz=1.0000
          features.3abs(W): min=9.9990e-01, mean=9.9995e-01, max=1.0000e+00, nnz=1.0000
########### layer stat ###########
+-----------------------------------------------------+
+-------------- epoch 0, batch 100/938 ----------------+
primal update: net loss=1.2957e+02, xlr=1.0000e-04, time_elapsed=1.254s
sparsity:0.9
########### layer stat ###########
          features.0abs(W): min=0.0000e+00, mean=9.3005e-01, max=1.0000e+00, nnz=0.9316
          features.3abs(W): min=0.0000e+00, mean=8.7073e-01, max=1.0000e+00, nnz=0.8724
########### layer stat ###########
+-----------------------------------------------------+
+-------------- epoch 0, batch 200/938 ----------------+
primal update: net loss=1.2