# Rewrite of `convexnn_pytorch_stepsize_fig.py` with Lottery
Borrows from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py

In [1]:
import copy
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random
import time
from tqdm.auto import tqdm, trange

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.init as init

from helperfunctions import *

In [2]:
from importlib import reload
import helperfunctions
reload(helperfunctions)
from helperfunctions import *

# Parameters and Args
I'm not using argparse in a notebook, it's gross. 

In [3]:
P = dict()
P['seed'] = 42        # Well we can tell who read Hitchhiker's Guide to the Galaxy lol
P['device'] = 'cuda'  # Or 'cpu'
P['verbose'] = True
P['P'] = 4096         # Number of hyperplane arrangements and number of neurons
P['num_neurons'] = P['P']
P["num_classes"] = 10
P["dim_in"] = 3*32*32
P['batch_size'] = 1000
P['beta'] = 1e-3      # Regularization parameter (in loss)
P['dir'] = os.path.abspath('')
P["print_freq"] = 5
P['device'] = 'cuda'

# Nonconvex (Regular) Args:
P['ncvx_solver'] = 'sgd'       # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['ncvx_schedule'] = 0         # learning rate schedule (0: Nothing, 1: ReduceLROnPlateau, 2: ExponentialLR)
P['ncvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['ncvx_num_epochs'] = 100
P["ncvx_learning_rate"] = 1e-3
P["ncvx_train_len"] = 50000
P["ncvx_test_len"] = 10000

P["ncvx_prune_epochs"] = 100 # 25
P["ncvx_prune_rounds"] = 5   # 5
P["ncvx_prune_perc"] = 80    # 20

# Convex Args:
P['cvx_solver'] = 'sgd'   # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['cvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['cvx_num_epochs'] = 100
P['cvx_learning_rate'] = 5e-7
P['cvx_rho'] = 1e-2
P['cvx_test_len'] = 10000

P["cvx_prune_epochs"] = 100 # 25
P["cvx_prune_rounds"] = 5   # 5
P["cvx_prune_perc"] = 80    # 20

# Set seed
random.seed(a=P['seed'])
np.random.seed(seed=P['seed'])
torch.manual_seed(seed=P['seed'])

<torch._C.Generator at 0x16b57e4c5f0>

# Load Data
Downloads CIFAR10 if not already downloaded.  

In [4]:
normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])

train_dataset = datasets.CIFAR10(P['dir'], train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

test_dataset = datasets.CIFAR10(P['dir'], train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

# Extract the data via a dummy loader (dumps entire dataset at once)
dummy_loader= torch.utils.data.DataLoader(train_dataset, batch_size=50000, shuffle=False, pin_memory=True, sampler=None)
for A, y in dummy_loader:
    pass
Apatch=A.detach().clone() # Detaches from graph

A = A.view(A.shape[0], -1)
n,dim_in=A.size()

P["cvx_n"] = n

print("Apatch (Detached A) Shape:",Apatch.shape)
print("A shape:", A.shape)

Files already downloaded and verified
Files already downloaded and verified
Apatch (Detached A) Shape: torch.Size([50000, 3, 32, 32])
A shape: torch.Size([50000, 3072])


# Standard Non-Convex Network
Consists of typical 2-layer network definition, training and test loss, as well as training loop. 

In [5]:
class FCNetwork(nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(FCNetwork, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(input_dim, num_neurons, bias=False), nn.ReLU())
        self.layer2 = nn.Linear(num_neurons, num_classes, bias=False)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        out = self.layer2(self.layer1(x))
        return out
    
def save_model(model,path):
    torch.save(model.state_dict(),path)
    
def load_fc_model(path,P):
    model = FCNetwork(P["num_neurons"],P["num_classes"],P["dim_in"])
    model.load_state_dict(torch.load(path))
    return model

def loss_func_primal(yhat, y, model, beta):
    loss = 0.5 * torch.norm(yhat - y)**2
    # l2 norm on first layer weights, l1 squared norm on second layer
    for layer, p in enumerate(model.parameters()):
        if layer == 0:
            loss += beta/2 * torch.norm(p)**2
        else:
            loss += beta/2 * sum([torch.norm(p[:, j], 1)**2 for j in range(p.shape[1])])
    return loss

def validation_primal(model, testloader, beta, device):
    test_loss = 0
    test_correct = 0
    for ix, (_x, _y) in enumerate(testloader):
        _x = Variable(_x).float().to(device)
        _y = Variable(_y).float().to(device)
        #output = model.forward(_x) # Does this do anything?
        yhat = model(_x).float()
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, beta)
        test_loss += loss.item()
        test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()
    return test_loss, test_correct.item()

def ncvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-6
    device = P["device"]
    for ix, (_x, _y) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device) # shape 1000,3,32,32
        _y = Variable(_y).to(device) # shape 1000
        yhat = model(_x).float()
        
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, P["beta"])/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        d_out["times"].append(time.time())
    return ix

def ncvx_train(model, ds, ds_test, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["ncvx_prune_epochs"] if prune else P["ncvx_num_epochs"]
    rounds = P["ncvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
    scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
    
    d_out = {"losses":[], "accs":[], "losses_test":[],"accs_test":[], "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["ncvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
            scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
            
            print("\nPruning Round [{:>2}/{:}]".format(p,rounds))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = ncvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at = validation_primal(model, ds_test, P["beta"], device)
            d_out["losses_test"] += [lt/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], loss: {:.3f} acc: {:.3f}, TEST loss: {:.3f} test acc: {:.3f}".format(
                       i,num_epochs,d_out["losses"][-1],d_out["accs"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))

            if P["ncvx_schedule"] > 0:
                scheduler.step(d_out["losses"][-1])
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Nonconvex (Regular) Training

In [6]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=P["batch_size"], shuffle=True, pin_memory=True, sampler=None)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)
print_params(P,True,True)

Parameter          : Value
seed               : 42
device             : cuda
verbose            : True
P                  : 4096
num_neurons        : 4096
num_classes        : 10
dim_in             : 3072
batch_size         : 1000
beta               : 0.001
dir                : C:\Users\trevo\Documents\Repos\spring22\convex_nn
print_freq         : 5
ncvx_solver        : sgd
ncvx_schedule      : 0
ncvx_LBFGS_param   : (10, 4)
ncvx_num_epochs    : 100
ncvx_learning_rate : 0.001
ncvx_train_len     : 50000
ncvx_test_len      : 10000
ncvx_prune_epochs  : 100
ncvx_prune_rounds  : 5
ncvx_prune_perc    : 80


In [7]:
ncvx_save_loc = "models/ncvx_nn{:}_solver{:}_l1e-3".format(P['num_neurons'],P['cvx_solver'])

model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
model.apply(weight_init)
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,ncvx_save_loc+"_INITIAL.pth")

# Initial training,
results_ncvx = ncvx_train(model, train_loader, test_loader, P, prune=False)
results_ncvx.to_csv(ncvx_save_loc+"_EPOCHS{:}_Results.csv".format(P["ncvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,ncvx_save_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.785 acc: 0.220, TEST loss: 0.779 test acc: 0.237
Epoch [ 5/100], loss: 0.496 acc: 0.398, TEST loss: 0.549 test acc: 0.346
Epoch [10/100], loss: 0.441 acc: 0.460, TEST loss: 0.502 test acc: 0.376
Epoch [15/100], loss: 0.412 acc: 0.483, TEST loss: 0.475 test acc: 0.405
Epoch [20/100], loss: 0.369 acc: 0.571, TEST loss: 0.459 test acc: 0.414
Epoch [25/100], loss: 0.355 acc: 0.563, TEST loss: 0.447 test acc: 0.424
Epoch [30/100], loss: 0.340 acc: 0.633, TEST loss: 0.438 test acc: 0.430
Epoch [35/100], loss: 0.321 acc: 0.644, TEST loss: 0.432 test acc: 0.439
Epoch [40/100], loss: 0.317 acc: 0.657, TEST loss: 0.425 test acc: 0.446
Epoch [45/100], loss: 0.307 acc: 0.685, TEST loss: 0.422 test acc: 0.451
Epoch [50/100], loss: 0.290 acc: 0.694, TEST loss: 0.418 test acc: 0.455
Epoch [55/100], loss: 0.287 acc: 0.720, TEST loss: 0.415 test acc: 0.459
Epoch [60/100], loss: 0.284 acc: 0.725, TEST loss: 0.413 test acc: 0.461
Epoch [65/100], loss: 0.280 acc: 0.732, TEST loss: 

### Pruning

In [7]:
ncvx_save_loc = "models/ncvx_5at80_solver{:}_l1e-3".format(P['cvx_solver'])
ncvx_load_loc = "models/ncvx_nn{:}_solver{:}_l1e-3".format(P['num_neurons'],P['cvx_solver'])
# Load from file just in case
model_init = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(ncvx_load_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(ncvx_load_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
# Use init_state_dict=initial_state_dict_post for iterative, initial_state_dict for lottery - iterative is more fair comp
prune_results_ncvx = ncvx_train(model, train_loader, test_loader, P, 
                                prune=True, re_init=False, init_state_dict=initial_state_dict_post, mask=mask)
# TODO - add pruning percentage to name here
prune_results_ncvx.to_csv(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_POST_Results.csv".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))


Pruning Round [ 0/5]
layer1.0.weight      | nonzeros = 2516583 / 12582912 ( 20.00%) | total_pruned = 10066329 | shape = (4096, 3072)
layer2.weight        | nonzeros =    8192 /   40960 ( 20.00%) | total_pruned =   32768 | shape = (10, 4096)
alive: 2524775, pruned : 10099097, total: 12623872, Compression rate :       5.00x  ( 80.00% pruned)


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.471 acc: 0.371, TEST loss: 0.518 test acc: 0.292
Epoch [ 5/100], loss: 0.421 acc: 0.429, TEST loss: 0.453 test acc: 0.365
Epoch [10/100], loss: 0.396 acc: 0.448, TEST loss: 0.437 test acc: 0.388
Epoch [15/100], loss: 0.386 acc: 0.511, TEST loss: 0.429 test acc: 0.399
Epoch [20/100], loss: 0.378 acc: 0.516, TEST loss: 0.423 test acc: 0.406
Epoch [25/100], loss: 0.378 acc: 0.505, TEST loss: 0.418 test acc: 0.414
Epoch [30/100], loss: 0.374 acc: 0.527, TEST loss: 0.415 test acc: 0.421
Epoch [35/100], loss: 0.362 acc: 0.536, TEST loss: 0.412 test acc: 0.426
Epoch [40/100], loss: 0.367 acc: 0.530, TEST loss: 0.410 test acc: 0.432
Epoch [45/100], loss: 0.361 acc: 0.524, TEST loss: 0.408 test acc: 0.435
Epoch [50/100], loss: 0.350 acc: 0.558, TEST loss: 0.405 test acc: 0.439
Epoch [55/100], loss: 0.353 acc: 0.546, TEST loss: 0.404 test acc: 0.440
Epoch [60/100], loss: 0.353 acc: 0.548, TEST loss: 0.403 test acc: 0.442
Epoch [65/100], loss: 0.352 acc: 0.558, TEST loss: 

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.443 acc: 0.240, TEST loss: 0.452 test acc: 0.226
Epoch [ 5/100], loss: 0.415 acc: 0.328, TEST loss: 0.418 test acc: 0.325
Epoch [10/100], loss: 0.404 acc: 0.359, TEST loss: 0.410 test acc: 0.347
Epoch [15/100], loss: 0.401 acc: 0.368, TEST loss: 0.405 test acc: 0.361
Epoch [20/100], loss: 0.400 acc: 0.382, TEST loss: 0.402 test acc: 0.370
Epoch [25/100], loss: 0.392 acc: 0.385, TEST loss: 0.399 test acc: 0.380
Epoch [30/100], loss: 0.396 acc: 0.374, TEST loss: 0.397 test acc: 0.386
Epoch [35/100], loss: 0.392 acc: 0.406, TEST loss: 0.396 test acc: 0.389
Epoch [40/100], loss: 0.392 acc: 0.387, TEST loss: 0.394 test acc: 0.393
Epoch [45/100], loss: 0.394 acc: 0.388, TEST loss: 0.393 test acc: 0.398
Epoch [50/100], loss: 0.389 acc: 0.397, TEST loss: 0.392 test acc: 0.399
Epoch [55/100], loss: 0.391 acc: 0.399, TEST loss: 0.391 test acc: 0.401
Epoch [60/100], loss: 0.391 acc: 0.402, TEST loss: 0.390 test acc: 0.403
Epoch [65/100], loss: 0.380 acc: 0.439, TEST loss: 

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.445 acc: 0.284, TEST loss: 0.445 test acc: 0.282
Epoch [ 5/100], loss: 0.426 acc: 0.313, TEST loss: 0.426 test acc: 0.312
Epoch [10/100], loss: 0.419 acc: 0.312, TEST loss: 0.418 test acc: 0.324
Epoch [15/100], loss: 0.416 acc: 0.334, TEST loss: 0.414 test acc: 0.335
Epoch [20/100], loss: 0.406 acc: 0.370, TEST loss: 0.411 test acc: 0.343
Epoch [25/100], loss: 0.407 acc: 0.338, TEST loss: 0.409 test acc: 0.345
Epoch [30/100], loss: 0.409 acc: 0.339, TEST loss: 0.407 test acc: 0.348
Epoch [35/100], loss: 0.402 acc: 0.367, TEST loss: 0.405 test acc: 0.351
Epoch [40/100], loss: 0.407 acc: 0.335, TEST loss: 0.404 test acc: 0.354
Epoch [45/100], loss: 0.407 acc: 0.341, TEST loss: 0.403 test acc: 0.356
Epoch [50/100], loss: 0.397 acc: 0.366, TEST loss: 0.402 test acc: 0.357
Epoch [55/100], loss: 0.398 acc: 0.375, TEST loss: 0.402 test acc: 0.358
Epoch [60/100], loss: 0.402 acc: 0.343, TEST loss: 0.401 test acc: 0.360
Epoch [65/100], loss: 0.406 acc: 0.364, TEST loss: 

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.483 acc: 0.252, TEST loss: 0.483 test acc: 0.226
Epoch [ 5/100], loss: 0.473 acc: 0.263, TEST loss: 0.473 test acc: 0.240
Epoch [10/100], loss: 0.461 acc: 0.256, TEST loss: 0.462 test acc: 0.230
Epoch [15/100], loss: 0.453 acc: 0.243, TEST loss: 0.454 test acc: 0.228
Epoch [20/100], loss: 0.444 acc: 0.261, TEST loss: 0.448 test acc: 0.234
Epoch [25/100], loss: 0.442 acc: 0.245, TEST loss: 0.444 test acc: 0.241
Epoch [30/100], loss: 0.437 acc: 0.281, TEST loss: 0.441 test acc: 0.249
Epoch [35/100], loss: 0.435 acc: 0.271, TEST loss: 0.438 test acc: 0.255
Epoch [40/100], loss: 0.436 acc: 0.269, TEST loss: 0.436 test acc: 0.263
Epoch [45/100], loss: 0.433 acc: 0.280, TEST loss: 0.435 test acc: 0.269
Epoch [50/100], loss: 0.433 acc: 0.274, TEST loss: 0.433 test acc: 0.274
Epoch [55/100], loss: 0.429 acc: 0.275, TEST loss: 0.432 test acc: 0.277
Epoch [60/100], loss: 0.430 acc: 0.264, TEST loss: 0.431 test acc: 0.279
Epoch [65/100], loss: 0.429 acc: 0.308, TEST loss: 

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.497 acc: 0.191, TEST loss: 0.496 test acc: 0.202
Epoch [ 5/100], loss: 0.493 acc: 0.211, TEST loss: 0.493 test acc: 0.204
Epoch [10/100], loss: 0.489 acc: 0.183, TEST loss: 0.488 test acc: 0.202
Epoch [15/100], loss: 0.483 acc: 0.206, TEST loss: 0.482 test acc: 0.197
Epoch [20/100], loss: 0.474 acc: 0.204, TEST loss: 0.476 test acc: 0.194
Epoch [25/100], loss: 0.473 acc: 0.195, TEST loss: 0.472 test acc: 0.199
Epoch [30/100], loss: 0.470 acc: 0.199, TEST loss: 0.469 test acc: 0.202
Epoch [35/100], loss: 0.470 acc: 0.199, TEST loss: 0.467 test acc: 0.206
Epoch [40/100], loss: 0.464 acc: 0.213, TEST loss: 0.466 test acc: 0.210
Epoch [45/100], loss: 0.466 acc: 0.201, TEST loss: 0.464 test acc: 0.214
Epoch [50/100], loss: 0.461 acc: 0.232, TEST loss: 0.463 test acc: 0.217
Epoch [55/100], loss: 0.463 acc: 0.217, TEST loss: 0.462 test acc: 0.218
Epoch [60/100], loss: 0.463 acc: 0.217, TEST loss: 0.462 test acc: 0.221
Epoch [65/100], loss: 0.462 acc: 0.239, TEST loss: 

In [8]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_POST.pth".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))

with open(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_POST_MASK.pkl".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Convex Network


In [9]:
class custom_cvx_layer(torch.nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(custom_cvx_layer, self).__init__()
        
        # (num_neurons) P x (input_dim) d x (num_classes) C
        self.weight_v = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)
        self.weight_w = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)

    def forward(self, x, sign_patterns):
        sign_patterns = sign_patterns.unsqueeze(2)
        x = x.view(x.shape[0], -1) # n x d
        
        Xv_w = torch.matmul(x, self.weight_v - self.weight_w) # P x N x C
        
        # for some reason, the permutation is necessary. not sure why
        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C
        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C
        
        return y_pred
    
def get_nonconvex_cost(y, model, _x, beta, device):
    _x = _x.view(_x.shape[0], -1)
    Xv = torch.matmul(_x, model.weight_v)
    Xw = torch.matmul(_x, model.weight_w)
    Xv_relu = torch.max(Xv, torch.Tensor([0]).to(device))
    Xw_relu = torch.max(Xw, torch.Tensor([0]).to(device))
    
    prediction_w_relu = torch.sum(Xv_relu - Xw_relu, dim=0, keepdim=False)
    prediction_cost = 0.5 * torch.norm(prediction_w_relu - y)**2
    regularization_cost = beta * (torch.sum(torch.norm(model.weight_v, dim=1)**2) + torch.sum(torch.norm(model.weight_w, p=1, dim=1)**2))
    return prediction_cost + regularization_cost

def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device):
    _x = _x.view(_x.shape[0], -1)
    # term 1
    loss = 0.5 * torch.norm(yhat - y)**2
    # term 2
    loss = loss + beta * torch.sum(torch.norm(model.weight_v, dim=1))
    loss = loss + beta * torch.sum(torch.norm(model.weight_w, dim=1))
    # term 3
    sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1
    
    Xv = torch.matmul(_x, torch.sum(model.weight_v, dim=2, keepdim=True)) # N x d times P x d x 1 -> P x N x 1
    DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1
    relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_v)
    
    Xw = torch.matmul(_x, torch.sum(model.weight_w, dim=2, keepdim=True))
    DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2))
    relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_w)
    return loss

def validation_cvxproblem(model, testloader, u_vectors, beta, rho, device):
    test_loss = 0
    test_correct = 0
    test_noncvx_cost = 0

    with torch.no_grad():
        for ix, (_x, _y) in enumerate(testloader):
            _x = Variable(_x).to(device)
            _y = Variable(_y).to(device)
            _x = _x.view(_x.shape[0], -1)
            _z = (torch.matmul(_x, torch.from_numpy(u_vectors).float().to(device)) >= 0)

            output = model.forward(_x, _z)
            yhat = model(_x, _z).float()

            loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x, _z, beta, rho, device)

            test_loss += loss.item()
            test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()

            test_noncvx_cost += get_nonconvex_cost(one_hot(_y).to(device), model, _x, beta, device)

    return test_loss, test_correct.item(), test_noncvx_cost.item()

In [10]:
def cvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-12
    device = P["device"]
    for ix, (_x, _y, _z) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device)
        _y = Variable(_y).to(device)
        _z = Variable(_z).to(device)
        yhat = model(_x, _z).float()
        
        loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x,_z, P["beta"], P["cvx_rho"], device)/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        
        noncvx_loss = get_nonconvex_cost(one_hot(_y).to(device), model, _x, P["beta"], device)/len(_y)
        d_out["noncvx_losses"].append(noncvx_loss.item())
        d_out["times"].append(time.time())
    return ix


def cvx_train(model, ds, ds_test, u_vectors, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["cvx_prune_epochs"] if prune else P["cvx_num_epochs"]
    rounds = P["cvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
    
    d_out = {"losses":[], "accs":[], "noncvx_losses": [], "losses_test":[],"accs_test":[], "noncvx_losses_test":[], 
             "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["cvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
            
            print("Pruning Round [{:>2}/{:}]".format(p,rounds))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = cvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at,nlt = validation_cvxproblem(model, ds_test, u_vectors, P["beta"], P["cvx_rho"], device)
            d_out["losses_test"] += [lt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["cvx_test_len"]]*(train_iters + 1)
            d_out["noncvx_losses_test"] += [nlt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}, TEST noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}".format(
                       i,num_epochs,d_out["noncvx_losses"][-1],d_out["losses"][-1],d_out["accs"][-1],
                       d_out["noncvx_losses_test"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))
            scheduler.step(d_out["losses"][-1])
            
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Convex Training

In [11]:
def generate_conv_sign_patterns(A2, P, verbose=False): 
    # generate convolutional sign patterns
    n, c, p1, p2 = A2.shape
    A = A2.reshape(n,int(c*p1*p2))
    fsize=9*c
    d=c*p1*p2;
    fs=int(np.sqrt(9))
    unique_sign_pattern_list = []  
    u_vector_list = []             

    for i in range(P): 
        # obtain a sign pattern
        ind1=np.random.randint(0,p1-fs+1)
        ind2=np.random.randint(0,p2-fs+1)
        u1p= np.zeros((c,p1,p2))
        u1p[:,ind1:ind1+fs,ind2:ind2+fs]=np.random.normal(0, 1, (fsize,1)).reshape(c,fs,fs)
        u1=u1p.reshape(d,1)
        sampled_sign_pattern = (np.matmul(A, u1) >= 0)[:,0]
        unique_sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(u1)

    if verbose:
        print("Number of unique sign patterns generated: " + str(len(unique_sign_pattern_list)))
    return len(unique_sign_pattern_list),unique_sign_pattern_list, u_vector_list

def generate_sign_patterns(A, P, verbose=False):
    # generate sign patterns
    n, d = A.shape
    sign_pattern_list = []  # sign patterns
    u_vector_list = []             # random vectors used to generate the sign paterns
    umat = np.random.normal(0, 1, (d,P))
    sampled_sign_pattern_mat = (np.matmul(A, umat) >= 0)
    for i in range(P):
        sampled_sign_pattern = sampled_sign_pattern_mat[:,i]
        sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(umat[:,i])
    if verbose:
        print("Number of sign patterns generated: " + str(len(sign_pattern_list)))
    return len(sign_pattern_list),sign_pattern_list, u_vector_list

# Generate sign patterns for convex network
num_neurons,sign_pattern_list, u_vector_list = generate_sign_patterns(A, P["P"], P["verbose"])
sign_patterns = np.array([sign_pattern_list[i].int().data.numpy() for i in range(num_neurons)])
u_vectors = np.asarray(u_vector_list).reshape((num_neurons, A.shape[1])).T

ds_train = PrepareData3D(X=A, y=y, z=sign_patterns.T)
ds_train = DataLoader(ds_train, batch_size=P["batch_size"], shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)

print_params(P,True,False,True)

Number of sign patterns generated: 4096
Parameter         : Value
seed              : 42
device            : cuda
verbose           : True
P                 : 4096
num_neurons       : 4096
num_classes       : 10
dim_in            : 3072
batch_size        : 1000
beta              : 0.001
dir               : C:\Users\trevo\Documents\Repos\spring22\convex_nn
print_freq        : 5
cvx_solver        : sgd
cvx_LBFGS_param   : (10, 4)
cvx_num_epochs    : 100
cvx_learning_rate : 5e-07
cvx_rho           : 0.01
cvx_test_len      : 10000
cvx_prune_epochs  : 100
cvx_prune_rounds  : 5
cvx_prune_perc    : 80
cvx_n             : 50000


In [9]:
cvx_load_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
cvx_save_loc = "models/cvx_5at80_solver{:}_lr5e-7".format(P['cvx_solver'])

model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,cvx_load_loc+"_INITIAL.pth")

# Initial training,
results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, prune=False)
results_cvx.to_csv(cvx_load_loc+"_EPOCHS{:}_Results.csv".format(P["cvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,cvx_save_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.408 loss: 0.366 acc: 0.477, TEST noncvx_loss: 0.411 loss: 0.366 acc: 0.467
Epoch [ 5/100], noncvx_loss: 0.431 loss: 0.296 acc: 0.636, TEST noncvx_loss: 0.436 loss: 0.342 acc: 0.519
Epoch [10/100], noncvx_loss: 0.436 loss: 0.258 acc: 0.765, TEST noncvx_loss: 0.443 loss: 0.338 acc: 0.529
Epoch [15/100], noncvx_loss: 0.435 loss: 0.232 acc: 0.835, TEST noncvx_loss: 0.445 loss: 0.335 acc: 0.539
Epoch [20/100], noncvx_loss: 0.438 loss: 0.215 acc: 0.867, TEST noncvx_loss: 0.450 loss: 0.338 acc: 0.544
Epoch [25/100], noncvx_loss: 0.440 loss: 0.200 acc: 0.886, TEST noncvx_loss: 0.451 loss: 0.337 acc: 0.543
Epoch [30/100], noncvx_loss: 0.439 loss: 0.184 acc: 0.934, TEST noncvx_loss: 0.452 loss: 0.339 acc: 0.541
Epoch [35/100], noncvx_loss: 0.437 loss: 0.169 acc: 0.938, TEST noncvx_loss: 0.452 loss: 0.340 acc: 0.541
Epoch [40/100], noncvx_loss: 0.440 loss: 0.159 acc: 0.952, TEST noncvx_loss: 0.452 loss: 0.341 acc: 0.541
Epoch [45/100], noncvx_loss: 0.443 loss: 0.154

### Pruning

In [12]:
cvx_load_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
cvx_save_loc = "models/cvx_5at80_solver{:}_lr5e-7".format(P['cvx_solver'])

# Load from file just in case
model_init = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(cvx_load_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(cvx_load_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
prune_results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, 
                              prune=True, re_init=False, init_state_dict=initial_state_dict_post, mask=mask)
prune_results_cvx.to_csv(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_Results.csv".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

Pruning Round [ 0/5]
weight_v             | nonzeros = 25165824 / 125829120 ( 20.00%) | total_pruned = 100663296 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 25165829 / 125829120 ( 20.00%) | total_pruned = 100663291 | shape = (4096, 3072, 10)
alive: 50331653, pruned : 201326587, total: 251658240, Compression rate :       5.00x  ( -5.33% pruned)


  print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x  ({100 * (total-nonzero) / total:6.2f}% pruned)')


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.441 loss: 0.162 acc: 0.992, TEST noncvx_loss: 0.452 loss: 0.357 acc: 0.537
Epoch [ 5/100], noncvx_loss: 0.438 loss: 0.146 acc: 0.984, TEST noncvx_loss: 0.449 loss: 0.345 acc: 0.545
Epoch [10/100], noncvx_loss: 0.434 loss: 0.143 acc: 0.985, TEST noncvx_loss: 0.447 loss: 0.343 acc: 0.547
Epoch [15/100], noncvx_loss: 0.438 loss: 0.146 acc: 0.987, TEST noncvx_loss: 0.446 loss: 0.342 acc: 0.545
Epoch [20/100], noncvx_loss: 0.433 loss: 0.142 acc: 0.987, TEST noncvx_loss: 0.445 loss: 0.341 acc: 0.545
Epoch [25/100], noncvx_loss: 0.435 loss: 0.143 acc: 0.987, TEST noncvx_loss: 0.445 loss: 0.341 acc: 0.546
Epoch [30/100], noncvx_loss: 0.434 loss: 0.141 acc: 0.986, TEST noncvx_loss: 0.443 loss: 0.340 acc: 0.546
Epoch [35/100], noncvx_loss: 0.433 loss: 0.138 acc: 0.984, TEST noncvx_loss: 0.443 loss: 0.340 acc: 0.547
Epoch [40/100], noncvx_loss: 0.430 loss: 0.138 acc: 0.982, TEST noncvx_loss: 0.443 loss: 0.340 acc: 0.548
Epoch [45/100], noncvx_loss: 0.430 loss: 0.140

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.445 loss: 0.328 acc: 0.864, TEST noncvx_loss: 0.446 loss: 0.400 acc: 0.517
Epoch [ 5/100], noncvx_loss: 0.433 loss: 0.301 acc: 0.826, TEST noncvx_loss: 0.439 loss: 0.378 acc: 0.518
Epoch [10/100], noncvx_loss: 0.429 loss: 0.297 acc: 0.801, TEST noncvx_loss: 0.435 loss: 0.368 acc: 0.518
Epoch [15/100], noncvx_loss: 0.426 loss: 0.291 acc: 0.798, TEST noncvx_loss: 0.432 loss: 0.364 acc: 0.520
Epoch [20/100], noncvx_loss: 0.423 loss: 0.290 acc: 0.774, TEST noncvx_loss: 0.430 loss: 0.361 acc: 0.521
Epoch [25/100], noncvx_loss: 0.422 loss: 0.289 acc: 0.787, TEST noncvx_loss: 0.429 loss: 0.360 acc: 0.521
Epoch [30/100], noncvx_loss: 0.420 loss: 0.289 acc: 0.791, TEST noncvx_loss: 0.428 loss: 0.359 acc: 0.523
Epoch [35/100], noncvx_loss: 0.421 loss: 0.282 acc: 0.802, TEST noncvx_loss: 0.427 loss: 0.358 acc: 0.524
Epoch [40/100], noncvx_loss: 0.418 loss: 0.284 acc: 0.793, TEST noncvx_loss: 0.426 loss: 0.357 acc: 0.525
Epoch    44: reducing learning rate of group 0

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.449 loss: 0.386 acc: 0.614, TEST noncvx_loss: 0.450 loss: 0.405 acc: 0.463
Epoch [ 5/100], noncvx_loss: 0.451 loss: 0.368 acc: 0.568, TEST noncvx_loss: 0.451 loss: 0.387 acc: 0.467
Epoch [10/100], noncvx_loss: 0.449 loss: 0.359 acc: 0.575, TEST noncvx_loss: 0.450 loss: 0.383 acc: 0.476
Epoch [15/100], noncvx_loss: 0.447 loss: 0.351 acc: 0.617, TEST noncvx_loss: 0.449 loss: 0.381 acc: 0.480
Epoch [20/100], noncvx_loss: 0.442 loss: 0.350 acc: 0.606, TEST noncvx_loss: 0.449 loss: 0.380 acc: 0.484
Epoch [25/100], noncvx_loss: 0.443 loss: 0.349 acc: 0.579, TEST noncvx_loss: 0.448 loss: 0.379 acc: 0.488
Epoch [30/100], noncvx_loss: 0.446 loss: 0.348 acc: 0.602, TEST noncvx_loss: 0.447 loss: 0.378 acc: 0.489
Epoch [35/100], noncvx_loss: 0.440 loss: 0.339 acc: 0.634, TEST noncvx_loss: 0.447 loss: 0.378 acc: 0.492
Epoch [40/100], noncvx_loss: 0.438 loss: 0.343 acc: 0.605, TEST noncvx_loss: 0.446 loss: 0.378 acc: 0.494
Epoch [45/100], noncvx_loss: 0.440 loss: 0.345

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.467 loss: 0.442 acc: 0.458, TEST noncvx_loss: 0.465 loss: 0.445 acc: 0.417
Epoch [ 5/100], noncvx_loss: 0.453 loss: 0.408 acc: 0.473, TEST noncvx_loss: 0.453 loss: 0.413 acc: 0.414
Epoch [10/100], noncvx_loss: 0.453 loss: 0.393 acc: 0.477, TEST noncvx_loss: 0.453 loss: 0.402 acc: 0.426
Epoch [15/100], noncvx_loss: 0.447 loss: 0.380 acc: 0.498, TEST noncvx_loss: 0.454 loss: 0.396 acc: 0.436
Epoch [20/100], noncvx_loss: 0.453 loss: 0.386 acc: 0.468, TEST noncvx_loss: 0.454 loss: 0.392 acc: 0.444
Epoch [25/100], noncvx_loss: 0.452 loss: 0.372 acc: 0.522, TEST noncvx_loss: 0.454 loss: 0.389 acc: 0.449
Epoch [30/100], noncvx_loss: 0.457 loss: 0.376 acc: 0.509, TEST noncvx_loss: 0.455 loss: 0.388 acc: 0.451
Epoch [35/100], noncvx_loss: 0.450 loss: 0.373 acc: 0.480, TEST noncvx_loss: 0.454 loss: 0.386 acc: 0.457
Epoch [40/100], noncvx_loss: 0.454 loss: 0.374 acc: 0.498, TEST noncvx_loss: 0.454 loss: 0.385 acc: 0.459
Epoch [45/100], noncvx_loss: 0.455 loss: 0.371

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.488 loss: 0.478 acc: 0.408, TEST noncvx_loss: 0.488 loss: 0.479 acc: 0.377
Epoch [ 5/100], noncvx_loss: 0.467 loss: 0.448 acc: 0.373, TEST noncvx_loss: 0.466 loss: 0.449 acc: 0.355
Epoch [10/100], noncvx_loss: 0.457 loss: 0.429 acc: 0.379, TEST noncvx_loss: 0.458 loss: 0.434 acc: 0.361
Epoch [15/100], noncvx_loss: 0.455 loss: 0.419 acc: 0.398, TEST noncvx_loss: 0.455 loss: 0.425 acc: 0.369
Epoch [20/100], noncvx_loss: 0.455 loss: 0.416 acc: 0.392, TEST noncvx_loss: 0.454 loss: 0.419 acc: 0.377
Epoch [25/100], noncvx_loss: 0.455 loss: 0.414 acc: 0.404, TEST noncvx_loss: 0.453 loss: 0.415 acc: 0.385
Epoch [30/100], noncvx_loss: 0.456 loss: 0.411 acc: 0.395, TEST noncvx_loss: 0.453 loss: 0.411 acc: 0.389
Epoch [35/100], noncvx_loss: 0.451 loss: 0.403 acc: 0.399, TEST noncvx_loss: 0.453 loss: 0.409 acc: 0.393
Epoch [40/100], noncvx_loss: 0.454 loss: 0.399 acc: 0.430, TEST noncvx_loss: 0.453 loss: 0.407 acc: 0.396
Epoch [45/100], noncvx_loss: 0.454 loss: 0.398

In [14]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}.pth".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

with open(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_MASK.pkl".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Save Parameters 

In [15]:
param_save_loc = 'models/5at80_100EPOCHS_solver{:}.json'.format(P['num_neurons'],P['cvx_solver'])
save_params(P,param_save_loc)