# Rewrite of `convexnn_pytorch_stepsize_fig.py` with Lottery
Borrows from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py

In [1]:
import copy
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random
import time
from tqdm.auto import tqdm, trange

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.init as init

from helperfunctions import *

In [2]:
from importlib import reload
import helperfunctions
reload(helperfunctions)
from helperfunctions import *

# Parameters and Args
I'm not using argparse in a notebook, it's gross. 

In [3]:
P = dict()
P['seed'] = 42        # Well we can tell who read Hitchhiker's Guide to the Galaxy lol
P['device'] = 'cuda'  # Or 'cpu'
P['verbose'] = True
P['P'] = 4096         # Number of hyperplane arrangements and number of neurons
P['num_neurons'] = P['P']
P["num_classes"] = 10
P["dim_in"] = 3*32*32
P['batch_size'] = 1000
P['beta'] = 1e-3      # Regularization parameter (in loss)
P['dir'] = os.path.abspath('')
P["print_freq"] = 5
P['device'] = 'cuda'

# Nonconvex (Regular) Args:
P['ncvx_solver'] = 'sgd'       # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['ncvx_schedule'] = 0         # learning rate schedule (0: Nothing, 1: ReduceLROnPlateau, 2: ExponentialLR)
P['ncvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['ncvx_num_epochs'] = 100
P["ncvx_learning_rate"] = 1e-3
P["ncvx_train_len"] = 50000
P["ncvx_test_len"] = 10000

P["ncvx_prune_epochs"] = 100 # 25
P["ncvx_prune_rounds"] = 1   # 5
P["ncvx_prune_perc"] = 80    # 20

# Convex Args:
P['cvx_solver'] = 'sgd'   # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['cvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['cvx_num_epochs'] = 100
P['cvx_learning_rate'] = 5e-7
P['cvx_rho'] = 1e-2
P['cvx_test_len'] = 10000

P["cvx_prune_epochs"] = 100 # 25
P["cvx_prune_rounds"] = 1   # 5
P["cvx_prune_perc"] = 80    # 20

# Set seed
random.seed(a=P['seed'])
np.random.seed(seed=P['seed'])
torch.manual_seed(seed=P['seed'])

<torch._C.Generator at 0x13303fdd110>

# Load Data
Downloads CIFAR10 if not already downloaded.  

In [4]:
normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])

train_dataset = datasets.CIFAR10(P['dir'], train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

test_dataset = datasets.CIFAR10(P['dir'], train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

# Extract the data via a dummy loader (dumps entire dataset at once)
dummy_loader= torch.utils.data.DataLoader(train_dataset, batch_size=50000, shuffle=False, pin_memory=True, sampler=None)
for A, y in dummy_loader:
    pass
Apatch=A.detach().clone() # Detaches from graph

A = A.view(A.shape[0], -1)
n,dim_in=A.size()

P["cvx_n"] = n

print("Apatch (Detached A) Shape:",Apatch.shape)
print("A shape:", A.shape)

Files already downloaded and verified
Files already downloaded and verified
Apatch (Detached A) Shape: torch.Size([50000, 3, 32, 32])
A shape: torch.Size([50000, 3072])


# Standard Non-Convex Network
Consists of typical 2-layer network definition, training and test loss, as well as training loop. 

In [5]:
class FCNetwork(nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(FCNetwork, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(input_dim, num_neurons, bias=False), nn.ReLU())
        self.layer2 = nn.Linear(num_neurons, num_classes, bias=False)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        out = self.layer2(self.layer1(x))
        return out
    
def save_model(model,path):
    torch.save(model.state_dict(),path)
    
def load_fc_model(path,P):
    model = FCNetwork(P["num_neurons"],P["num_classes"],P["dim_in"])
    model.load_state_dict(torch.load(path))
    return model

def loss_func_primal(yhat, y, model, beta):
    loss = 0.5 * torch.norm(yhat - y)**2
    # l2 norm on first layer weights, l1 squared norm on second layer
    for layer, p in enumerate(model.parameters()):
        if layer == 0:
            loss += beta/2 * torch.norm(p)**2
        else:
            loss += beta/2 * sum([torch.norm(p[:, j], 1)**2 for j in range(p.shape[1])])
    return loss

def validation_primal(model, testloader, beta, device):
    test_loss = 0
    test_correct = 0
    for ix, (_x, _y) in enumerate(testloader):
        _x = Variable(_x).float().to(device)
        _y = Variable(_y).float().to(device)
        #output = model.forward(_x) # Does this do anything?
        yhat = model(_x).float()
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, beta)
        test_loss += loss.item()
        test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()
    return test_loss, test_correct.item()

def ncvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-6
    device = P["device"]
    for ix, (_x, _y) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device) # shape 1000,3,32,32
        _y = Variable(_y).to(device) # shape 1000
        yhat = model(_x).float()
        
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, P["beta"])/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        d_out["times"].append(time.time())
    return ix

def ncvx_train(model, ds, ds_test, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["ncvx_prune_epochs"] if prune else P["ncvx_num_epochs"]
    rounds = P["ncvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
    scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
    
    d_out = {"losses":[], "accs":[], "losses_test":[],"accs_test":[], "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["ncvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
            scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
            
            print("\nPruning Round [{:>2}/{:}]".format(p,rounds))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = ncvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at = validation_primal(model, ds_test, P["beta"], device)
            d_out["losses_test"] += [lt/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], loss: {:.3f} acc: {:.3f}, TEST loss: {:.3f} test acc: {:.3f}".format(
                       i,num_epochs,d_out["losses"][-1],d_out["accs"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))

            if P["ncvx_schedule"] > 0:
                scheduler.step(d_out["losses"][-1])
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Nonconvex (Regular) Training

In [6]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=P["batch_size"], shuffle=True, pin_memory=True, sampler=None)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)
print_params(P,True,True)

Parameter          : Value
seed               : 42
device             : cuda
verbose            : True
P                  : 4096
num_neurons        : 4096
num_classes        : 10
dim_in             : 3072
batch_size         : 1000
beta               : 0.001
dir                : C:\Users\trevo\Documents\repos\spring22\convex_nn
print_freq         : 5
ncvx_solver        : sgd
ncvx_schedule      : 0
ncvx_LBFGS_param   : (10, 4)
ncvx_num_epochs    : 100
ncvx_learning_rate : 0.001
ncvx_train_len     : 50000
ncvx_test_len      : 10000
ncvx_prune_epochs  : 100
ncvx_prune_rounds  : 1
ncvx_prune_perc    : 80


In [7]:
ncvx_save_loc = "models/ncvx_nn{:}_solver{:}_l1e-3".format(P['num_neurons'],P['cvx_solver'])
model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
model.apply(weight_init)
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,ncvx_save_loc+"_INITIAL.pth")

# Initial training,
results_ncvx = ncvx_train(model, train_loader, test_loader, P, prune=False)
results_ncvx.to_csv(ncvx_save_loc+"_EPOCHS{:}_Results.csv".format(P["ncvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,ncvx_save_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.785 acc: 0.220, TEST loss: 0.779 test acc: 0.237
Epoch [ 5/100], loss: 0.496 acc: 0.398, TEST loss: 0.549 test acc: 0.346
Epoch [10/100], loss: 0.441 acc: 0.460, TEST loss: 0.502 test acc: 0.376
Epoch [15/100], loss: 0.412 acc: 0.483, TEST loss: 0.475 test acc: 0.405
Epoch [20/100], loss: 0.369 acc: 0.571, TEST loss: 0.459 test acc: 0.414
Epoch [25/100], loss: 0.355 acc: 0.563, TEST loss: 0.447 test acc: 0.424
Epoch [30/100], loss: 0.340 acc: 0.633, TEST loss: 0.438 test acc: 0.430
Epoch [35/100], loss: 0.321 acc: 0.644, TEST loss: 0.432 test acc: 0.439
Epoch [40/100], loss: 0.317 acc: 0.657, TEST loss: 0.425 test acc: 0.446
Epoch [45/100], loss: 0.307 acc: 0.685, TEST loss: 0.422 test acc: 0.451
Epoch [50/100], loss: 0.290 acc: 0.694, TEST loss: 0.418 test acc: 0.455
Epoch [55/100], loss: 0.287 acc: 0.720, TEST loss: 0.415 test acc: 0.459
Epoch [60/100], loss: 0.284 acc: 0.725, TEST loss: 0.413 test acc: 0.461
Epoch [65/100], loss: 0.280 acc: 0.732, TEST loss: 

### Pruning

In [8]:
# Load from file just in case
model_init = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(ncvx_save_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(ncvx_save_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
prune_results_ncvx = ncvx_train(model, train_loader, test_loader, P, 
                                prune=True, re_init=False, init_state_dict=initial_state_dict, mask=mask)
prune_results_ncvx.to_csv(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_Results.csv".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))


Pruning Round [ 0/1]
layer1.0.weight      | nonzeros = 2516583 / 12582912 ( 20.00%) | total_pruned = 10066329 | shape = (4096, 3072)
layer2.weight        | nonzeros =    8192 /   40960 ( 20.00%) | total_pruned =   32768 | shape = (10, 4096)
alive: 2524775, pruned : 10099097, total: 12623872, Compression rate :       5.00x  ( 80.00% pruned)


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.589 acc: 0.289, TEST loss: 0.614 test acc: 0.251
Epoch [ 5/100], loss: 0.478 acc: 0.379, TEST loss: 0.510 test acc: 0.331
Epoch [10/100], loss: 0.450 acc: 0.424, TEST loss: 0.485 test acc: 0.356
Epoch [15/100], loss: 0.426 acc: 0.456, TEST loss: 0.471 test acc: 0.372
Epoch [20/100], loss: 0.415 acc: 0.441, TEST loss: 0.462 test acc: 0.379
Epoch [25/100], loss: 0.409 acc: 0.484, TEST loss: 0.455 test acc: 0.390
Epoch [30/100], loss: 0.403 acc: 0.483, TEST loss: 0.449 test acc: 0.397
Epoch [35/100], loss: 0.394 acc: 0.498, TEST loss: 0.445 test acc: 0.402
Epoch [40/100], loss: 0.388 acc: 0.493, TEST loss: 0.441 test acc: 0.406
Epoch [45/100], loss: 0.385 acc: 0.529, TEST loss: 0.438 test acc: 0.409
Epoch [50/100], loss: 0.388 acc: 0.488, TEST loss: 0.435 test acc: 0.411
Epoch [55/100], loss: 0.384 acc: 0.509, TEST loss: 0.433 test acc: 0.417
Epoch [60/100], loss: 0.375 acc: 0.527, TEST loss: 0.431 test acc: 0.421
Epoch [65/100], loss: 0.361 acc: 0.552, TEST loss: 

In [9]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}.pth".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))

with open(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_MASK.pkl".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Convex Network


In [6]:
class custom_cvx_layer(torch.nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(custom_cvx_layer, self).__init__()
        
        # (num_neurons) P x (input_dim) d x (num_classes) C
        self.weight_v = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)
        self.weight_w = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)

    def forward(self, x, sign_patterns):
        sign_patterns = sign_patterns.unsqueeze(2)
        x = x.view(x.shape[0], -1) # n x d
        
        Xv_w = torch.matmul(x, self.weight_v - self.weight_w) # P x N x C
        
        # for some reason, the permutation is necessary. not sure why
        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C
        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C
        
        return y_pred
    
def get_nonconvex_cost(y, model, _x, beta, device):
    _x = _x.view(_x.shape[0], -1)
    Xv = torch.matmul(_x, model.weight_v)
    Xw = torch.matmul(_x, model.weight_w)
    Xv_relu = torch.max(Xv, torch.Tensor([0]).to(device))
    Xw_relu = torch.max(Xw, torch.Tensor([0]).to(device))
    
    prediction_w_relu = torch.sum(Xv_relu - Xw_relu, dim=0, keepdim=False)
    prediction_cost = 0.5 * torch.norm(prediction_w_relu - y)**2
    regularization_cost = beta * (torch.sum(torch.norm(model.weight_v, dim=1)**2) + torch.sum(torch.norm(model.weight_w, p=1, dim=1)**2))
    return prediction_cost + regularization_cost

def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device):
    _x = _x.view(_x.shape[0], -1)
    # term 1
    loss = 0.5 * torch.norm(yhat - y)**2
    # term 2
    loss = loss + beta * torch.sum(torch.norm(model.weight_v, dim=1))
    loss = loss + beta * torch.sum(torch.norm(model.weight_w, dim=1))
    # term 3
    sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1
    
    Xv = torch.matmul(_x, torch.sum(model.weight_v, dim=2, keepdim=True)) # N x d times P x d x 1 -> P x N x 1
    DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1
    relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_v)
    
    Xw = torch.matmul(_x, torch.sum(model.weight_w, dim=2, keepdim=True))
    DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2))
    relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_w)
    return loss

def validation_cvxproblem(model, testloader, u_vectors, beta, rho, device):
    test_loss = 0
    test_correct = 0
    test_noncvx_cost = 0

    with torch.no_grad():
        for ix, (_x, _y) in enumerate(testloader):
            _x = Variable(_x).to(device)
            _y = Variable(_y).to(device)
            _x = _x.view(_x.shape[0], -1)
            _z = (torch.matmul(_x, torch.from_numpy(u_vectors).float().to(device)) >= 0)

            output = model.forward(_x, _z)
            yhat = model(_x, _z).float()

            loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x, _z, beta, rho, device)

            test_loss += loss.item()
            test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()

            test_noncvx_cost += get_nonconvex_cost(one_hot(_y).to(device), model, _x, beta, device)

    return test_loss, test_correct.item(), test_noncvx_cost.item()

In [7]:
def cvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-12
    device = P["device"]
    for ix, (_x, _y, _z) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device)
        _y = Variable(_y).to(device)
        _z = Variable(_z).to(device)
        yhat = model(_x, _z).float()
        
        loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x,_z, P["beta"], P["cvx_rho"], device)/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        
        noncvx_loss = get_nonconvex_cost(one_hot(_y).to(device), model, _x, P["beta"], device)/len(_y)
        d_out["noncvx_losses"].append(noncvx_loss.item())
        d_out["times"].append(time.time())
    return ix


def cvx_train(model, ds, ds_test, u_vectors, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["cvx_prune_epochs"] if prune else P["cvx_num_epochs"]
    rounds = P["cvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
    
    d_out = {"losses":[], "accs":[], "noncvx_losses": [], "losses_test":[],"accs_test":[], "noncvx_losses_test":[], 
             "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["cvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
            
            print("Pruning Round [{:>2}/{:}]".format(p,rounds))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = cvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at,nlt = validation_cvxproblem(model, ds_test, u_vectors, P["beta"], P["cvx_rho"], device)
            d_out["losses_test"] += [lt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["cvx_test_len"]]*(train_iters + 1)
            d_out["noncvx_losses_test"] += [nlt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}, TEST noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}".format(
                       i,num_epochs,d_out["noncvx_losses"][-1],d_out["losses"][-1],d_out["accs"][-1],
                       d_out["noncvx_losses_test"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))
            scheduler.step(d_out["losses"][-1])
            
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Convex Training

In [8]:
def generate_conv_sign_patterns(A2, P, verbose=False): 
    # generate convolutional sign patterns
    n, c, p1, p2 = A2.shape
    A = A2.reshape(n,int(c*p1*p2))
    fsize=9*c
    d=c*p1*p2;
    fs=int(np.sqrt(9))
    unique_sign_pattern_list = []  
    u_vector_list = []             

    for i in range(P): 
        # obtain a sign pattern
        ind1=np.random.randint(0,p1-fs+1)
        ind2=np.random.randint(0,p2-fs+1)
        u1p= np.zeros((c,p1,p2))
        u1p[:,ind1:ind1+fs,ind2:ind2+fs]=np.random.normal(0, 1, (fsize,1)).reshape(c,fs,fs)
        u1=u1p.reshape(d,1)
        sampled_sign_pattern = (np.matmul(A, u1) >= 0)[:,0]
        unique_sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(u1)

    if verbose:
        print("Number of unique sign patterns generated: " + str(len(unique_sign_pattern_list)))
    return len(unique_sign_pattern_list),unique_sign_pattern_list, u_vector_list

def generate_sign_patterns(A, P, verbose=False):
    # generate sign patterns
    n, d = A.shape
    sign_pattern_list = []  # sign patterns
    u_vector_list = []             # random vectors used to generate the sign paterns
    umat = np.random.normal(0, 1, (d,P))
    sampled_sign_pattern_mat = (np.matmul(A, umat) >= 0)
    for i in range(P):
        sampled_sign_pattern = sampled_sign_pattern_mat[:,i]
        sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(umat[:,i])
    if verbose:
        print("Number of sign patterns generated: " + str(len(sign_pattern_list)))
    return len(sign_pattern_list),sign_pattern_list, u_vector_list

# Generate sign patterns for convex network
num_neurons,sign_pattern_list, u_vector_list = generate_sign_patterns(A, P["P"], P["verbose"])
sign_patterns = np.array([sign_pattern_list[i].int().data.numpy() for i in range(num_neurons)])
u_vectors = np.asarray(u_vector_list).reshape((num_neurons, A.shape[1])).T

ds_train = PrepareData3D(X=A, y=y, z=sign_patterns.T)
ds_train = DataLoader(ds_train, batch_size=P["batch_size"], shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)

print_params(P,True,False,True)

Number of sign patterns generated: 4096
Parameter         : Value
seed              : 42
device            : cuda
verbose           : True
P                 : 4096
num_neurons       : 4096
num_classes       : 10
dim_in            : 3072
batch_size        : 1000
beta              : 0.001
dir               : C:\Users\trevo\Documents\repos\spring22\convex_nn
print_freq        : 5
cvx_solver        : sgd
cvx_LBFGS_param   : (10, 4)
cvx_num_epochs    : 100
cvx_learning_rate : 5e-07
cvx_rho           : 0.01
cvx_test_len      : 10000
cvx_prune_epochs  : 25
cvx_prune_rounds  : 5
cvx_prune_perc    : 20


In [9]:
cvx_save_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,cvx_save_loc+"_INITIAL.pth")

# Initial training,
results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, prune=False)
results_cvx.to_csv(cvx_save_loc+"_EPOCHS{:}_Results.csv".format(P["cvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,cvx_save_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.408 loss: 0.366 acc: 0.477, TEST noncvx_loss: 0.411 loss: 0.366 acc: 0.467
Epoch [ 5/100], noncvx_loss: 0.431 loss: 0.296 acc: 0.636, TEST noncvx_loss: 0.436 loss: 0.342 acc: 0.519
Epoch [10/100], noncvx_loss: 0.436 loss: 0.258 acc: 0.765, TEST noncvx_loss: 0.443 loss: 0.338 acc: 0.529
Epoch [15/100], noncvx_loss: 0.435 loss: 0.232 acc: 0.835, TEST noncvx_loss: 0.445 loss: 0.335 acc: 0.539
Epoch [20/100], noncvx_loss: 0.438 loss: 0.215 acc: 0.867, TEST noncvx_loss: 0.450 loss: 0.338 acc: 0.544
Epoch [25/100], noncvx_loss: 0.440 loss: 0.200 acc: 0.886, TEST noncvx_loss: 0.451 loss: 0.337 acc: 0.543
Epoch [30/100], noncvx_loss: 0.439 loss: 0.184 acc: 0.934, TEST noncvx_loss: 0.452 loss: 0.339 acc: 0.541
Epoch [35/100], noncvx_loss: 0.437 loss: 0.169 acc: 0.938, TEST noncvx_loss: 0.452 loss: 0.340 acc: 0.541
Epoch [40/100], noncvx_loss: 0.440 loss: 0.159 acc: 0.952, TEST noncvx_loss: 0.452 loss: 0.341 acc: 0.541
Epoch [45/100], noncvx_loss: 0.443 loss: 0.154

### Pruning

In [10]:
cvx_save_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
# Load from file just in case
model_init = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(cvx_save_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(cvx_save_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
prune_results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, 
                              prune=True, re_init=False, init_state_dict=initial_state_dict_post, mask=mask)
prune_results_cvx.to_csv(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_Results.csv".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

Pruning Round [ 0/25]
weight_v             | nonzeros = 100663299 / 125829120 ( 80.00%) | total_pruned = 25165821 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 100663298 / 125829120 ( 80.00%) | total_pruned = 25165822 | shape = (4096, 3072, 10)
alive: 201326597, pruned : 50331643, total: 251658240, Compression rate :       1.25x  (  2.93% pruned)


  print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x  ({100 * (total-nonzero) / total:6.2f}% pruned)')


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], noncvx_loss: 0.437 loss: 0.083 acc: 0.993, TEST noncvx_loss: 0.454 loss: 0.356 acc: 0.537
Epoch [ 5/25], noncvx_loss: 0.438 loss: 0.081 acc: 0.995, TEST noncvx_loss: 0.453 loss: 0.356 acc: 0.537
Epoch [10/25], noncvx_loss: 0.439 loss: 0.082 acc: 0.997, TEST noncvx_loss: 0.453 loss: 0.356 acc: 0.538
Epoch [15/25], noncvx_loss: 0.440 loss: 0.080 acc: 0.993, TEST noncvx_loss: 0.453 loss: 0.356 acc: 0.537
Epoch    16: reducing learning rate of group 0 to 2.5000e-07.
Epoch [20/25], noncvx_loss: 0.439 loss: 0.082 acc: 0.993, TEST noncvx_loss: 0.452 loss: 0.356 acc: 0.537
Epoch [24/25], noncvx_loss: 0.433 loss: 0.079 acc: 0.995, TEST noncvx_loss: 0.452 loss: 0.355 acc: 0.538
Pruning Round [ 1/25]
weight_v             | nonzeros = 80530642 / 125829120 ( 64.00%) | total_pruned = 45298478 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 80530642 / 125829120 ( 64.00%) | total_pruned = 45298478 | shape = (4096, 3072, 10)
alive: 161061284, pruned : 90596956, total: 251658

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], noncvx_loss: 0.441 loss: 0.091 acc: 0.989, TEST noncvx_loss: 0.451 loss: 0.355 acc: 0.538
Epoch [ 5/25], noncvx_loss: 0.436 loss: 0.085 acc: 0.994, TEST noncvx_loss: 0.451 loss: 0.354 acc: 0.537
Epoch [10/25], noncvx_loss: 0.440 loss: 0.083 acc: 0.995, TEST noncvx_loss: 0.451 loss: 0.354 acc: 0.538
Epoch    15: reducing learning rate of group 0 to 2.5000e-07.
Epoch [15/25], noncvx_loss: 0.438 loss: 0.085 acc: 0.992, TEST noncvx_loss: 0.450 loss: 0.354 acc: 0.537
Epoch [20/25], noncvx_loss: 0.435 loss: 0.084 acc: 0.994, TEST noncvx_loss: 0.450 loss: 0.354 acc: 0.540
Epoch [24/25], noncvx_loss: 0.436 loss: 0.083 acc: 0.995, TEST noncvx_loss: 0.451 loss: 0.354 acc: 0.538
Pruning Round [ 2/25]
weight_v             | nonzeros = 64424516 / 125829120 ( 51.20%) | total_pruned = 61404604 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 64424516 / 125829120 ( 51.20%) | total_pruned = 61404604 | shape = (4096, 3072, 10)
alive: 128849032, pruned : 122809208, total: 25165

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], noncvx_loss: 0.436 loss: 0.089 acc: 0.996, TEST noncvx_loss: 0.450 loss: 0.352 acc: 0.538
Epoch [ 5/25], noncvx_loss: 0.438 loss: 0.089 acc: 0.998, TEST noncvx_loss: 0.449 loss: 0.351 acc: 0.538
Epoch [10/25], noncvx_loss: 0.433 loss: 0.085 acc: 0.996, TEST noncvx_loss: 0.449 loss: 0.351 acc: 0.539
Epoch [15/25], noncvx_loss: 0.432 loss: 0.093 acc: 0.988, TEST noncvx_loss: 0.448 loss: 0.351 acc: 0.540
Epoch [20/25], noncvx_loss: 0.434 loss: 0.087 acc: 0.996, TEST noncvx_loss: 0.448 loss: 0.351 acc: 0.538
Epoch    22: reducing learning rate of group 0 to 2.5000e-07.
Epoch [24/25], noncvx_loss: 0.431 loss: 0.089 acc: 0.991, TEST noncvx_loss: 0.448 loss: 0.351 acc: 0.538
Pruning Round [ 3/25]
weight_v             | nonzeros = 51539613 / 125829120 ( 40.96%) | total_pruned = 74289507 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 51539615 / 125829120 ( 40.96%) | total_pruned = 74289505 | shape = (4096, 3072, 10)
alive: 103079228, pruned : 148579012, total: 25165

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], noncvx_loss: 0.438 loss: 0.100 acc: 0.994, TEST noncvx_loss: 0.447 loss: 0.349 acc: 0.540
Epoch [ 5/25], noncvx_loss: 0.433 loss: 0.099 acc: 0.996, TEST noncvx_loss: 0.446 loss: 0.348 acc: 0.540
Epoch [10/25], noncvx_loss: 0.433 loss: 0.092 acc: 0.997, TEST noncvx_loss: 0.446 loss: 0.348 acc: 0.539
Epoch [15/25], noncvx_loss: 0.431 loss: 0.097 acc: 0.990, TEST noncvx_loss: 0.445 loss: 0.348 acc: 0.540
Epoch [20/25], noncvx_loss: 0.427 loss: 0.096 acc: 0.992, TEST noncvx_loss: 0.444 loss: 0.348 acc: 0.541
Epoch    22: reducing learning rate of group 0 to 2.5000e-07.
Epoch [24/25], noncvx_loss: 0.426 loss: 0.095 acc: 0.993, TEST noncvx_loss: 0.444 loss: 0.348 acc: 0.539
Pruning Round [ 4/25]
weight_v             | nonzeros = 41231691 / 125829120 ( 32.77%) | total_pruned = 84597429 | shape = (4096, 3072, 10)
weight_w             | nonzeros = 41231695 / 125829120 ( 32.77%) | total_pruned = 84597425 | shape = (4096, 3072, 10)
alive: 82463386, pruned : 169194854, total: 251658

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], noncvx_loss: 0.432 loss: 0.113 acc: 0.993, TEST noncvx_loss: 0.447 loss: 0.346 acc: 0.540
Epoch [ 5/25], noncvx_loss: 0.429 loss: 0.107 acc: 0.998, TEST noncvx_loss: 0.445 loss: 0.345 acc: 0.541
Epoch [10/25], noncvx_loss: 0.433 loss: 0.108 acc: 0.993, TEST noncvx_loss: 0.444 loss: 0.345 acc: 0.543
Epoch [15/25], noncvx_loss: 0.430 loss: 0.110 acc: 0.991, TEST noncvx_loss: 0.443 loss: 0.345 acc: 0.543
Epoch [20/25], noncvx_loss: 0.430 loss: 0.107 acc: 0.993, TEST noncvx_loss: 0.443 loss: 0.344 acc: 0.545
Epoch    21: reducing learning rate of group 0 to 2.5000e-07.
Epoch [24/25], noncvx_loss: 0.429 loss: 0.106 acc: 0.990, TEST noncvx_loss: 0.442 loss: 0.344 acc: 0.544


In [11]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}.pth".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

with open(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_MASK.pkl".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Save Parameters 

In [None]:
param_save_loc = 'models/nn{:}_solver{:}_lr_PARAMS.json'.format(P['num_neurons'],P['cvx_solver'])
save_params(P,param_save_loc)

# Lottery Ticket 
Borrows from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py

In [None]:
from importlib import reload
import helperfunctions
reload(helperfunctions)
from helperfunctions import *

# TODO: 
 * Better plotting