# Rewrite of `convexnn_pytorch_stepsize_fig.py` with Lottery
Borrows from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py

In [1]:
import copy
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import random
import time
from tqdm.auto import tqdm, trange

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.init as init

from helperfunctions import *

In [2]:
from importlib import reload
import helperfunctions
reload(helperfunctions)
from helperfunctions import *

# Parameters and Args
I'm not using argparse in a notebook, it's gross. 

In [5]:
P = dict()
P['seed'] = 42        # Well we can tell who read Hitchhiker's Guide to the Galaxy lol
P['device'] = 'cuda'  # Or 'cpu'
P['verbose'] = True
P['P'] = 4096         # Number of hyperplane arrangements and number of neurons
P['num_neurons'] = P['P']
P["num_classes"] = 10
P["dim_in"] = 3*32*32
P['batch_size'] = 1000
P['beta'] = 1e-3      # Regularization parameter (in loss)
P['dir'] = os.path.abspath('')
P["print_freq"] = 5
P['device'] = 'cuda'

# Nonconvex (Regular) Args:
P['ncvx_solver'] = 'sgd'       # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['ncvx_schedule'] = 0         # learning rate schedule (0: Nothing, 1: ReduceLROnPlateau, 2: ExponentialLR)
P['ncvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['ncvx_num_epochs'] = 100
P["ncvx_learning_rate"] = 1e-3
P["ncvx_train_len"] = 50000
P["ncvx_test_len"] = 10000

P["ncvx_prune_epochs"] = 25
P["ncvx_prune_rounds"] = 5
P["ncvx_prune_perc"] = 20

# Convex Args:
P['cvx_solver'] = 'sgd'   # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
P['cvx_LBFGS_param'] = (10,4) # params for solver LBFGS
P['cvx_num_epochs'] = 100
P['cvx_learning_rate'] = 5e-7
P['cvx_rho'] = 1e-2
P['cvx_test_len'] = 10000

P["cvx_prune_epochs"] = 25
P["cvx_prune_rounds"] = 5
P["cvx_prune_perc"] = 20

# Set seed
random.seed(a=P['seed'])
np.random.seed(seed=P['seed'])
torch.manual_seed(seed=P['seed'])

<torch._C.Generator at 0x2190711e110>

# Load Data
Downloads CIFAR10 if not already downloaded.  

In [3]:
normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])

train_dataset = datasets.CIFAR10(P['dir'], train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

test_dataset = datasets.CIFAR10(P['dir'], train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor(), normalize,]))

# Extract the data via a dummy loader (dumps entire dataset at once)
dummy_loader= torch.utils.data.DataLoader(train_dataset, batch_size=50000, shuffle=False, pin_memory=True, sampler=None)
for A, y in dummy_loader:
    pass
Apatch=A.detach().clone() # Detaches from graph

A = A.view(A.shape[0], -1)
n,dim_in=A.size()

P["cvx_n"] = n

print("Apatch (Detached A) Shape:",Apatch.shape)
print("A shape:", A.shape)

Files already downloaded and verified
Files already downloaded and verified
Apatch (Detached A) Shape: torch.Size([50000, 3, 32, 32])
A shape: torch.Size([50000, 3072])


# Standard Non-Convex Network
Consists of typical 2-layer network definition, training and test loss, as well as training loop. 

In [10]:
class FCNetwork(nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(FCNetwork, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(input_dim, num_neurons, bias=False), nn.ReLU())
        self.layer2 = nn.Linear(num_neurons, num_classes, bias=False)

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        out = self.layer2(self.layer1(x))
        return out
    
def save_model(model,path):
    torch.save(model.state_dict(),path)
    
def load_fc_model(path,P):
    model = FCNetwork(P["num_neurons"],P["num_classes"],P["dim_in"])
    model.load_state_dict(torch.load(path))
    return model

def loss_func_primal(yhat, y, model, beta):
    loss = 0.5 * torch.norm(yhat - y)**2
    # l2 norm on first layer weights, l1 squared norm on second layer
    for layer, p in enumerate(model.parameters()):
        if layer == 0:
            loss += beta/2 * torch.norm(p)**2
        else:
            loss += beta/2 * sum([torch.norm(p[:, j], 1)**2 for j in range(p.shape[1])])
    return loss

def validation_primal(model, testloader, beta, device):
    test_loss = 0
    test_correct = 0
    for ix, (_x, _y) in enumerate(testloader):
        _x = Variable(_x).float().to(device)
        _y = Variable(_y).float().to(device)
        #output = model.forward(_x) # Does this do anything?
        yhat = model(_x).float()
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, beta)
        test_loss += loss.item()
        test_correct += torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()
    return test_loss, test_correct.item()

def ncvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-6
    device = P["device"]
    for ix, (_x, _y) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device) # shape 1000,3,32,32
        _y = Variable(_y).to(device) # shape 1000
        yhat = model(_x).float()
        
        loss = loss_func_primal(yhat, one_hot(_y).to(device), model, P["beta"])/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        d_out["times"].append(time.time())
    return ix

def ncvx_train(model, ds, ds_test, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["ncvx_prune_epochs"] if prune else P["ncvx_num_epochs"]
    rounds = P["ncvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
    scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
    
    d_out = {"losses":[], "accs":[], "losses_test":[],"accs_test":[], "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["ncvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["ncvx_solver"],P["ncvx_learning_rate"],P["ncvx_LBFGS_param"])
            scheduler = get_scheduler(P["ncvx_schedule"],optimizer,P["verbose"])
            
            print("\nPruning Round [{:>2}/{:}]".format(p,rounds))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = ncvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at = validation_primal(model, ds_test, P["beta"], device)
            d_out["losses_test"] += [lt/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["ncvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], loss: {:.3f} acc: {:.3f}, TEST loss: {:.3f} test acc: {:.3f}".format(
                       i,num_epochs,d_out["losses"][-1],d_out["accs"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))

            if P["ncvx_schedule"] > 0:
                scheduler.step(d_out["losses"][-1])
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Nonconvex (Regular) Training

In [5]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=P["batch_size"], shuffle=True, pin_memory=True, sampler=None)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)
print_params(P,True,True)

Parameter          : Value
seed               : 42
device             : cuda
verbose            : True
P                  : 4096
num_neurons        : 4096
num_classes        : 10
dim_in             : 3072
batch_size         : 1000
beta               : 0.001
dir                : C:\Users\trevo\Documents\repos\spring22\convex_nn
print_freq         : 5
ncvx_solver        : sgd
ncvx_schedule      : 0
ncvx_LBFGS_param   : (10, 4)
ncvx_num_epochs    : 100
ncvx_learning_rate : 0.001
ncvx_train_len     : 50000
ncvx_test_len      : 10000
ncvx_prune_epochs  : 25
ncvx_prune_rounds  : 5
ncvx_prune_perc    : 20


In [36]:
ncvx_save_loc = "models/ncvx_nn{:}_solver{:}_l1e-3".format(P['num_neurons'],P['cvx_solver'])
model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
model.apply(weight_init)
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,ncvx_save_loc+"_INITIAL.pth")

# Initial training,
results_ncvx = ncvx_train(model, train_loader, test_loader, P, prune=False)
results_ncvx.to_csv(ncvx_save_loc+"_EPOCHS{:}_Results".format(P["ncvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,ncvx_save_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], loss: 0.770 acc: 0.240, TEST loss: 0.787 test acc: 0.239
Epoch [ 5/100], loss: 0.503 acc: 0.397, TEST loss: 0.553 test acc: 0.344
Epoch [10/100], loss: 0.440 acc: 0.455, TEST loss: 0.504 test acc: 0.378
Epoch [15/100], loss: 0.398 acc: 0.507, TEST loss: 0.477 test acc: 0.398
Epoch [20/100], loss: 0.379 acc: 0.531, TEST loss: 0.460 test acc: 0.415
Epoch [25/100], loss: 0.342 acc: 0.629, TEST loss: 0.449 test acc: 0.425
Epoch [30/100], loss: 0.336 acc: 0.628, TEST loss: 0.439 test acc: 0.433
Epoch [35/100], loss: 0.322 acc: 0.660, TEST loss: 0.432 test acc: 0.438
Epoch [40/100], loss: 0.315 acc: 0.646, TEST loss: 0.427 test acc: 0.442
Epoch [45/100], loss: 0.308 acc: 0.678, TEST loss: 0.423 test acc: 0.449
Epoch [50/100], loss: 0.302 acc: 0.689, TEST loss: 0.420 test acc: 0.456
Epoch [55/100], loss: 0.296 acc: 0.694, TEST loss: 0.417 test acc: 0.454
Epoch [60/100], loss: 0.286 acc: 0.738, TEST loss: 0.414 test acc: 0.463
Epoch [65/100], loss: 0.273 acc: 0.749, TEST loss: 

### Pruning

In [11]:
# Load from file just in case
model_init = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(ncvx_save_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = FCNetwork(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(ncvx_save_loc+"_EPOCHS{:}.pth".format(P["ncvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
prune_results_ncvx = ncvx_train(model, train_loader, test_loader, P, 
                                prune=True, re_init=False, init_state_dict=initial_state_dict, mask=mask)
prune_results_ncvx.to_csv(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_Results.csv".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))

Pruning Round [ 0/5]
layer1.0.weight      | nonzeros = 10066329 / 12582912 ( 80.00%) | total_pruned = 2516583 | shape = (4096, 3072)
layer2.weight        | nonzeros =   32768 /   40960 ( 80.00%) | total_pruned =    8192 | shape = (10, 4096)
alive: 10099097, pruned : 2524775, total: 12623872, Compression rate :       1.25x  ( 20.00% pruned)


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], loss: 0.695 acc: 0.288, TEST loss: 0.774 test acc: 0.257
Epoch [ 5/25], loss: 0.526 acc: 0.398, TEST loss: 0.599 test acc: 0.327
Epoch [10/25], loss: 0.470 acc: 0.461, TEST loss: 0.559 test acc: 0.357
Epoch [15/25], loss: 0.447 acc: 0.475, TEST loss: 0.538 test acc: 0.373
Epoch [20/25], loss: 0.423 acc: 0.499, TEST loss: 0.524 test acc: 0.385
Epoch [24/25], loss: 0.403 acc: 0.551, TEST loss: 0.516 test acc: 0.387
Pruning Round [ 1/5]
layer1.0.weight      | nonzeros = 8053063 / 12582912 ( 64.00%) | total_pruned = 4529849 | shape = (4096, 3072)
layer2.weight        | nonzeros =   26214 /   40960 ( 64.00%) | total_pruned =   14746 | shape = (10, 4096)
alive: 8079277, pruned : 4544595, total: 12623872, Compression rate :       1.56x  ( 36.00% pruned)


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], loss: 0.712 acc: 0.262, TEST loss: 0.730 test acc: 0.256
Epoch [ 5/25], loss: 0.548 acc: 0.357, TEST loss: 0.598 test acc: 0.319
Epoch [10/25], loss: 0.503 acc: 0.435, TEST loss: 0.564 test acc: 0.348
Epoch [15/25], loss: 0.472 acc: 0.427, TEST loss: 0.543 test acc: 0.358
Epoch [20/25], loss: 0.448 acc: 0.479, TEST loss: 0.530 test acc: 0.366
Epoch [24/25], loss: 0.439 acc: 0.484, TEST loss: 0.522 test acc: 0.374
Pruning Round [ 2/5]
layer1.0.weight      | nonzeros = 6442450 / 12582912 ( 51.20%) | total_pruned = 6140462 | shape = (4096, 3072)
layer2.weight        | nonzeros =   20971 /   40960 ( 51.20%) | total_pruned =   19989 | shape = (10, 4096)
alive: 6463421, pruned : 6160451, total: 12623872, Compression rate :       1.95x  ( 48.80% pruned)


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], loss: 0.633 acc: 0.318, TEST loss: 0.686 test acc: 0.275
Epoch [ 5/25], loss: 0.547 acc: 0.367, TEST loss: 0.580 test acc: 0.326
Epoch [10/25], loss: 0.489 acc: 0.430, TEST loss: 0.551 test acc: 0.347
Epoch [15/25], loss: 0.470 acc: 0.429, TEST loss: 0.535 test acc: 0.360
Epoch [20/25], loss: 0.442 acc: 0.465, TEST loss: 0.524 test acc: 0.362
Epoch [24/25], loss: 0.431 acc: 0.477, TEST loss: 0.516 test acc: 0.375
Pruning Round [ 3/5]
layer1.0.weight      | nonzeros = 5153960 / 12582912 ( 40.96%) | total_pruned = 7428952 | shape = (4096, 3072)
layer2.weight        | nonzeros =   16777 /   40960 ( 40.96%) | total_pruned =   24183 | shape = (10, 4096)
alive: 5170737, pruned : 7453135, total: 12623872, Compression rate :       2.44x  ( 59.04% pruned)


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], loss: 0.601 acc: 0.323, TEST loss: 0.653 test acc: 0.272
Epoch [ 5/25], loss: 0.507 acc: 0.374, TEST loss: 0.564 test acc: 0.328
Epoch [10/25], loss: 0.481 acc: 0.414, TEST loss: 0.539 test acc: 0.342
Epoch [15/25], loss: 0.464 acc: 0.434, TEST loss: 0.524 test acc: 0.359
Epoch [20/25], loss: 0.456 acc: 0.433, TEST loss: 0.514 test acc: 0.365
Epoch [24/25], loss: 0.444 acc: 0.442, TEST loss: 0.508 test acc: 0.371
Pruning Round [ 4/5]
layer1.0.weight      | nonzeros = 4123169 / 12582912 ( 32.77%) | total_pruned = 8459743 | shape = (4096, 3072)
layer2.weight        | nonzeros =   13421 /   40960 ( 32.77%) | total_pruned =   27539 | shape = (10, 4096)
alive: 4136590, pruned : 8487282, total: 12623872, Compression rate :       3.05x  ( 67.23% pruned)


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [ 0/25], loss: 0.596 acc: 0.319, TEST loss: 0.630 test acc: 0.290
Epoch [ 5/25], loss: 0.511 acc: 0.370, TEST loss: 0.548 test acc: 0.337
Epoch [10/25], loss: 0.465 acc: 0.418, TEST loss: 0.525 test acc: 0.350
Epoch [15/25], loss: 0.469 acc: 0.403, TEST loss: 0.512 test acc: 0.363
Epoch [20/25], loss: 0.452 acc: 0.427, TEST loss: 0.503 test acc: 0.369
Epoch [24/25], loss: 0.432 acc: 0.464, TEST loss: 0.497 test acc: 0.378


In [17]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}.pth".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]))

with open(ncvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_MASK.pkl".format(P["ncvx_prune_rounds"],P["ncvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Convex Network


In [6]:
class custom_cvx_layer(torch.nn.Module):
    def __init__(self, num_neurons=4096, num_classes=10, input_dim=3072):
        self.num_classes = num_classes
        super(custom_cvx_layer, self).__init__()
        
        # (num_neurons) P x (input_dim) d x (num_classes) C
        self.weight_v = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)
        self.weight_w = torch.nn.Parameter(data=torch.zeros(num_neurons, input_dim, num_classes), requires_grad=True)

    def forward(self, x, sign_patterns):
        sign_patterns = sign_patterns.unsqueeze(2)
        x = x.view(x.shape[0], -1) # n x d
        
        Xv_w = torch.matmul(x, self.weight_v - self.weight_w) # P x N x C
        
        # for some reason, the permutation is necessary. not sure why
        DXv_w = torch.mul(sign_patterns, Xv_w.permute(1, 0, 2)) #  N x P x C
        y_pred = torch.sum(DXv_w, dim=1, keepdim=False) # N x C
        
        return y_pred
    
def get_nonconvex_cost(y, model, _x, beta, device):
    _x = _x.view(_x.shape[0], -1)
    Xv = torch.matmul(_x, model.weight_v)
    Xw = torch.matmul(_x, model.weight_w)
    Xv_relu = torch.max(Xv, torch.Tensor([0]).to(device))
    Xw_relu = torch.max(Xw, torch.Tensor([0]).to(device))
    
    prediction_w_relu = torch.sum(Xv_relu - Xw_relu, dim=0, keepdim=False)
    prediction_cost = 0.5 * torch.norm(prediction_w_relu - y)**2
    regularization_cost = beta * (torch.sum(torch.norm(model.weight_v, dim=1)**2) + torch.sum(torch.norm(model.weight_w, p=1, dim=1)**2))
    return prediction_cost + regularization_cost

def loss_func_cvxproblem(yhat, y, model, _x, sign_patterns, beta, rho, device):
    _x = _x.view(_x.shape[0], -1)
    # term 1
    loss = 0.5 * torch.norm(yhat - y)**2
    # term 2
    loss = loss + beta * torch.sum(torch.norm(model.weight_v, dim=1))
    loss = loss + beta * torch.sum(torch.norm(model.weight_w, dim=1))
    # term 3
    sign_patterns = sign_patterns.unsqueeze(2) # N x P x 1
    
    Xv = torch.matmul(_x, torch.sum(model.weight_v, dim=2, keepdim=True)) # N x d times P x d x 1 -> P x N x 1
    DXv = torch.mul(sign_patterns, Xv.permute(1, 0, 2)) # P x N x 1
    relu_term_v = torch.max(-2*DXv + Xv.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_v)
    
    Xw = torch.matmul(_x, torch.sum(model.weight_w, dim=2, keepdim=True))
    DXw = torch.mul(sign_patterns, Xw.permute(1, 0, 2))
    relu_term_w = torch.max(-2*DXw + Xw.permute(1, 0, 2), torch.Tensor([0]).to(device))
    loss = loss + rho * torch.sum(relu_term_w)
    return loss

def validation_cvxproblem(model, testloader, u_vectors, beta, rho, device):
    test_loss = 0
    test_correct = 0
    test_noncvx_cost = 0

    with torch.no_grad():
        for ix, (_x, _y) in enumerate(testloader):
            _x = Variable(_x).to(device)
            _y = Variable(_y).to(device)
            _x = _x.view(_x.shape[0], -1)
            _z = (torch.matmul(_x, torch.from_numpy(u_vectors).float().to(device)) >= 0)

            output = model.forward(_x, _z)
            yhat = model(_x, _z).float()

            loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x, _z, beta, rho, device)

            test_loss += loss.item()
            test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()

            test_noncvx_cost += get_nonconvex_cost(one_hot(_y).to(device), model, _x, beta, device)

    return test_loss, test_correct.item(), test_noncvx_cost.item()

In [7]:
def cvx_train_step(model, ds, optimizer, P, d_out, freeze=True):
    EPS = 1e-12
    device = P["device"]
    for ix, (_x, _y, _z) in enumerate(ds):
        optimizer.zero_grad()
        # Make input differentiable
        _x = Variable(_x).to(device)
        _y = Variable(_y).to(device)
        _z = Variable(_z).to(device)
        yhat = model(_x, _z).float()
        
        loss = loss_func_cvxproblem(yhat, one_hot(_y).to(device), model, _x,_z, P["beta"], P["cvx_rho"], device)/len(_y)
        correct = torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()/len(_y)
        
        loss.backward()
        # Freezing Pruned weights by making their gradients Zero (if zero stay zero)
        if freeze:
            for name, p in model.named_parameters():
                if 'weight' in name:
                    tensor = p.data.cpu().numpy()
                    grad_tensor = p.grad.data.cpu().numpy()
                    grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
                    p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
        d_out["losses"].append(loss.item())
        d_out["accs"].append(correct.item())
        
        noncvx_loss = get_nonconvex_cost(one_hot(_y).to(device), model, _x, P["beta"], device)/len(_y)
        d_out["noncvx_losses"].append(noncvx_loss.item())
        d_out["times"].append(time.time())
    return ix


def cvx_train(model, ds, ds_test, u_vectors, P, prune=True, re_init=False, init_state_dict=None, mask=None):
    # Runs training loop
    num_epochs = P["cvx_prune_epochs"] if prune else P["cvx_num_epochs"]
    rounds = P["cvx_prune_rounds"] if prune else 1

    device = torch.device(P["device"])
    model.to(device)
    optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
    
    d_out = {"losses":[], "accs":[], "noncvx_losses": [], "losses_test":[],"accs_test":[], "noncvx_losses_test":[], 
             "times":[time.time()], "epoch": [], "round": []}
    if prune:
        d_out["nonzero_perc"] = []

    for p in range(rounds):
        if prune:
            prune_by_percentile(model,mask,P["cvx_prune_perc"])
            _ = re_init(model,mask) if re_init else og_init(model,mask,init_state_dict)
            optimizer = get_optimizer(model,P["cvx_solver"],P["cvx_learning_rate"],P["cvx_LBFGS_param"])
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=P["verbose"], factor=0.5, eps=1e-12)
            
            print("Pruning Round [{:>2}/{:}]".format(p,num_epochs))
            nonzero_pc = print_nonzeros(model)
            
        iter_no = 0
        for i in tqdm(range(num_epochs)):
            model.train()
            train_iters = cvx_train_step(model, ds, optimizer, P, d_out, freeze=prune)

            model.eval()
            lt,at,nlt = validation_cvxproblem(model, ds_test, u_vectors, P["beta"], P["cvx_rho"], device)
            d_out["losses_test"] += [lt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["accs_test"] += [at/P["cvx_test_len"]]*(train_iters + 1)
            d_out["noncvx_losses_test"] += [nlt/P["cvx_test_len"]]*(train_iters + 1)
            d_out["epoch"] += [i]*(train_iters + 1)
            d_out["round"] += [p]*(train_iters + 1)
            
            if prune:
                d_out["nonzero_perc"] += [nonzero_pc]*(train_iters + 1)
            iter_no += train_iters + 1

            if i % P["print_freq"] == 0 or i == num_epochs - 1:
                print("Epoch [{:>2}/{:}], noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}, TEST noncvx_loss: {:.3f} loss: {:.3f} acc: {:.3f}".format(
                       i,num_epochs,d_out["noncvx_losses"][-1],d_out["losses"][-1],d_out["accs"][-1],
                       d_out["noncvx_losses_test"][-1],d_out["losses_test"][-1],d_out["accs_test"][-1]))
            scheduler.step(d_out["losses"][-1])
            
    d_out["times"] = np.diff(d_out["times"])
    return pd.DataFrame.from_dict(d_out)

# Convex Training

In [8]:
def generate_conv_sign_patterns(A2, P, verbose=False): 
    # generate convolutional sign patterns
    n, c, p1, p2 = A2.shape
    A = A2.reshape(n,int(c*p1*p2))
    fsize=9*c
    d=c*p1*p2;
    fs=int(np.sqrt(9))
    unique_sign_pattern_list = []  
    u_vector_list = []             

    for i in range(P): 
        # obtain a sign pattern
        ind1=np.random.randint(0,p1-fs+1)
        ind2=np.random.randint(0,p2-fs+1)
        u1p= np.zeros((c,p1,p2))
        u1p[:,ind1:ind1+fs,ind2:ind2+fs]=np.random.normal(0, 1, (fsize,1)).reshape(c,fs,fs)
        u1=u1p.reshape(d,1)
        sampled_sign_pattern = (np.matmul(A, u1) >= 0)[:,0]
        unique_sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(u1)

    if verbose:
        print("Number of unique sign patterns generated: " + str(len(unique_sign_pattern_list)))
    return len(unique_sign_pattern_list),unique_sign_pattern_list, u_vector_list

def generate_sign_patterns(A, P, verbose=False):
    # generate sign patterns
    n, d = A.shape
    sign_pattern_list = []  # sign patterns
    u_vector_list = []             # random vectors used to generate the sign paterns
    umat = np.random.normal(0, 1, (d,P))
    sampled_sign_pattern_mat = (np.matmul(A, umat) >= 0)
    for i in range(P):
        sampled_sign_pattern = sampled_sign_pattern_mat[:,i]
        sign_pattern_list.append(sampled_sign_pattern)
        u_vector_list.append(umat[:,i])
    if verbose:
        print("Number of sign patterns generated: " + str(len(sign_pattern_list)))
    return len(sign_pattern_list),sign_pattern_list, u_vector_list

# Generate sign patterns for convex network
num_neurons,sign_pattern_list, u_vector_list = generate_sign_patterns(A, P["P"], P["verbose"])
sign_patterns = np.array([sign_pattern_list[i].int().data.numpy() for i in range(num_neurons)])
u_vectors = np.asarray(u_vector_list).reshape((num_neurons, A.shape[1])).T

ds_train = PrepareData3D(X=A, y=y, z=sign_patterns.T)
ds_train = DataLoader(ds_train, batch_size=P["batch_size"], shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=P["batch_size"], shuffle=False, pin_memory=True)

print_params(P,True,False,True)

Number of sign patterns generated: 4096
Parameter         : Value
seed              : 42
device            : cuda
verbose           : True
P                 : 4096
num_neurons       : 4096
num_classes       : 10
dim_in            : 3072
batch_size        : 1000
beta              : 0.001
dir               : C:\Users\trevo\Documents\repos\spring22\convex_nn
print_freq        : 5
cvx_solver        : sgd
cvx_LBFGS_param   : (10, 4)
cvx_num_epochs    : 100
cvx_learning_rate : 5e-07
cvx_rho           : 0.01
cvx_test_len      : 10000
cvx_prune_epochs  : 25
cvx_prune_rounds  : 5
cvx_prune_perc    : 20


In [None]:
cvx_save_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])

# Save initial model
initial_state_dict = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict,cvx_save_loc+"_INITIAL.pth")

# Initial training,
results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, prune=False)
results_cvx.to_csv(cvx_save_loc+"_EPOCHS{:}_Results".format(P["cvx_num_epochs"]))

# Save model after 100 epochs
initial_state_dict_post = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_post,cvx_save_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"]))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [ 0/100], noncvx_loss: 0.408 loss: 0.366 acc: 0.477, TEST noncvx_loss: 0.411 loss: 0.366 acc: 0.467


### Pruning

In [None]:
cvx_save_loc = "models/cvx_nn{:}_solver{:}_lr5e-7".format(P['num_neurons'],P['cvx_solver'])
# Load from file just in case
model_init = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model_init.load_state_dict(torch.load(cvx_save_loc+"_INITIAL.pth"))
initial_state_dict = copy.deepcopy(model_init.state_dict())

model = custom_cvx_layer(P["num_neurons"], P["num_classes"], P["dim_in"])
model.load_state_dict(torch.load(cvx_save_loc+"_EPOCHS{:}.pth".format(P["cvx_num_epochs"])))
initial_state_dict_post = copy.deepcopy(model.state_dict())
                           
mask = make_mask(model)
prune_results_cvx = cvx_train(model, ds_train, test_loader, u_vectors, P, 
                              prune=True, re_init=False, init_state_dict=initial_state_dict_post, mask=mask)
prune_results_cvx.to_csv(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_Results.csv".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

In [None]:
# Save model/mask after Lottery epochs
initial_state_dict_lotto = copy.deepcopy(model.state_dict())
torch.save(initial_state_dict_lotto,cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}.pth".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]))

with open(cvx_save_loc+"_ROUNDS{:}_EPOCHS{:}_MASK.pkl".format(P["cvx_prune_rounds"],P["cvx_prune_epochs"]),'wb') as f:
    pickle.dump(mask,f)

# Save Parameters 

In [None]:
param_save_loc = 'models/nn{:}_solver{:}_lr_PARAMS.json'.format(P['num_neurons'],P['cvx_solver'])
save_params(P,param_save_loc)

# Lottery Ticket 
Borrows from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/master/main.py

In [None]:
from importlib import reload
import helperfunctions
reload(helperfunctions)
from helperfunctions import *

# TODO: 
 * Better plotting