In [None]:
"""
Train layer-wise convex ReLU networks on a binary CIFAR-10 task (classes < 5 vs >= 5)
using convex optimization.

The synthetic classification problem (gen_classification_data, "Figure 1") is still
configured below, but its generation call is commented out; CIFAR-10 is used instead.
"""

import sys
# Make the parent directory importable — presumably so `convex_nn` resolves from a
# source checkout without installing it (cf. the commented-out `pip install -e .`).
sys.path.append("..")

In [None]:
# !cd .. && pip install -e .

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from convex_nn.private.utils.data import gen_classification_data
from convex_nn.optimize import optimize
from convex_nn.regularizers import L2

In [None]:
# Problem-size configuration.
# n_train / n_test cap how many examples are kept from each data split below.
# d, hidden_units and kappa parameterize the realizable synthetic problem
# (i.e. Figure 1); the call that generates it is currently disabled.
n_train = 10_000
n_test = 10_000
d = 25              # input dimension of the synthetic problem
hidden_units = 100  # hidden units of the synthetic problem
kappa = 1_000       # condition number

# (X_train, y_train), (X_test, y_test) = gen_classification_data(123, n_train, n_test, d, hidden_units, kappa)

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import os

dataset = 'CIFAR10'
# NOTE(review): these mean/std values appear to be the commonly quoted CIFAR-100
# channel statistics rather than CIFAR-10's — confirm before relying on them.
normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
save_path = os.path.abspath('')

# Transforms are stateless, so train and test can share one pipeline.
image_transform = transforms.Compose([transforms.ToTensor(), normalize])


def binary_target(label):
    """Map a CIFAR-10 class index to a ±1 binary label (classes >= 5 -> +1, else -1).

    ±1 targets (rather than 0/1) are what the sign-based `accuracy` helper in
    this notebook expects: with 0/1 labels, a negative prediction could never
    be counted as a correct class-0 prediction.
    """
    return 1.0 if label >= 5 else -1.0


train_dataset = datasets.CIFAR10(save_path, train=True, download=True,
    transform=image_transform,
    target_transform=binary_target)

test_dataset = datasets.CIFAR10(save_path, train=False, download=True,
    transform=image_transform,
    target_transform=binary_target)

In [None]:
# Materialize the whole training split as one batch, then keep a random subset of
# n_train examples. A seeded generator makes that subset reproducible across runs;
# pin_memory is omitted because the tensors are consumed on the CPU.
SPLIT_SEED = 0
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=len(train_dataset),
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
X_train, y_train = next(iter(train_loader))

# Flatten each image into a single feature vector.
X_train = X_train.reshape((X_train.shape[0], -1))
X_train = X_train[:n_train]
y_train = y_train[:n_train]
print(X_train.shape, y_train.shape)

In [None]:
# Materialize the whole test split as one batch (no shuffling, so order is fixed),
# then keep the first n_test examples of BOTH features and labels so they stay aligned.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
    shuffle=False,
)
X_test, y_test = next(iter(test_loader))

# Flatten each image into a single feature vector.
X_test = X_test.reshape((X_test.shape[0], -1))
X_test = X_test[:n_test]
y_test = y_test[:n_test]
print(X_test.shape, y_test.shape)

In [None]:
def accuracy(logits, y):
    """Return the fraction of examples whose predicted sign matches the label.

    Args:
        logits: model outputs, one value per example (e.g. shape (n,) or (n, 1));
            any array-like accepted by ``np.asarray``.
        y: labels, one per example. Assumed to be ±1 — TODO confirm for each
            dataset: with 0/1 labels a negative prediction never counts as correct.

    Returns:
        Accuracy in [0, 1] as a numpy float, or ``nan`` when there are no examples.

    Raises:
        ValueError: if ``logits`` and ``y`` contain different numbers of values.
    """
    # Flatten both sides so an (n, 1) prediction column is compared element-wise
    # with an (n,) label vector instead of silently broadcasting to (n, n).
    predicted = np.sign(np.ravel(np.asarray(logits)))
    labels = np.ravel(np.asarray(y))
    if predicted.shape != labels.shape:
        raise ValueError(
            f"logits and y must have the same number of values, "
            f"got {predicted.size} and {labels.size}"
        )
    if labels.size == 0:
        return float("nan")
    return np.mean(predicted == labels)

In [None]:
# # cast data
# tX_train, ty_train, tX_test, ty_test = [torch.tensor(z, dtype=torch.float) for z in [X_train, y_train, X_test, y_test]]

# loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(tX_train, ty_train), batch_size=32, shuffle=True)

In [None]:
# Optimization hyper-parameters.
max_epochs = 1_000  # cap on passes over the training data
tol = 1e-06         # convergence threshold on the squared gradient norm
lam = 1e-3          # L2 penalty weight

## Non-Convex Model

In [None]:
# lr = 0.00001

# # create model
# nc_model = torch.nn.Sequential(
#     torch.nn.Linear(in_features=d, out_features=hidden_units, bias=False), 
#     torch.nn.ReLU(), 
#     torch.nn.Linear(in_features=hidden_units, out_features=1, bias=False))

# # Acc Before Training
# print("Test Accuracy:", accuracy(nc_model(tX_test).detach().numpy(), y_test))

# sgd = torch.optim.SGD(nc_model.parameters(), lr=lr)

# for i in range(max_epochs):
#     for X, y in loader:
#         nc_model.zero_grad()
#         l2_penalty = sum([torch.sum(param ** 2) for param in nc_model.parameters()])
#         obj = torch.sum((nc_model(X) - y) ** 2) / (2 * len(y)) + lam * l2_penalty
#         obj.backward()
        
#         sgd.step()

#     # check for convergence
    
#     nc_model.zero_grad()
#     l2_penalty = sum([torch.sum(param ** 2) for param in nc_model.parameters()])
#     obj = torch.sum((nc_model(tX_train) - ty_train) ** 2) / (2 * len(y_train)) + lam * l2_penalty
#     obj.backward()    
#     grad_norm = sum([torch.sum(param.grad ** 2) for param in nc_model.parameters()])

#     if grad_norm <= tol:
#         print(f"Converged at {i}/{max_epochs}")
#         break

#     if i % 25 == 0:
#         print(f"{i}/{max_epochs}: Obj - {obj}, Grad - {grad_norm}")

# # Acc After Training
# print("Test Accuracy:", accuracy(nc_model(tX_test).detach().numpy(), y_test))

## Convex Reformulation

In [None]:
# cvx_model, metrics = optimize("relu", 
#                               max_neurons,
#                               X_train=X_train[:10], 
#                               y_train=y_train[:10], 
#                               X_test=X_test.numpy(), 
#                               y_test=y_test.numpy(), 
#                               verbose=True,  
#                               device="cpu")

In [None]:
import tqdm

models = dict()  # num_examples -> list of trained layer weights (filled by the training loop)

In [None]:
# Layer-wise training: each stage fits a convex ReLU model on the current features,
# then that model's first-layer weights (used as (p, d_in), matching the `.T` below)
# map the features into the input of the next stage.
num_layers = 3

# number of activation patterns to use.
max_neurons = 1000

for num_examples in [100, 500, 1000]:
    layers = []

    # Work in numpy throughout: from stage 2 on the features come from np.maximum,
    # so calling `.numpy()` on them (as the test features once were) would fail.
    current_X_train = X_train[:num_examples].numpy()
    current_y_train = y_train[:num_examples].numpy()
    current_X_test = X_test[:num_examples].numpy()
    # Truncate test labels exactly like the test features so their lengths agree.
    current_y_test = y_test[:num_examples].numpy()

    for _ in tqdm.tqdm(range(num_layers), desc=f"num_examples={num_examples}"):
        if layers:
            print(current_X_train.shape, layers[-1].shape)
            # ReLU feature map given by the previous stage's first-layer weights.
            current_X_train = np.maximum(current_X_train @ layers[-1].T, 0)
            current_X_test = np.maximum(current_X_test @ layers[-1].T, 0)

        # train model
        cvx_model, metrics = optimize("relu",
                                      max_neurons,
                                      X_train=current_X_train,
                                      y_train=current_y_train,
                                      X_test=current_X_test,
                                      y_test=current_y_test,
                                      verbose=True,
                                      regularizer=L2(lam),
                                      device="cpu")
        layers.append(cvx_model.parameters[0])
    # The final stage's output-layer weights complete the stacked network.
    layers.append(cvx_model.parameters[-1])
    models[num_examples] = layers

print([x.shape for x in layers])

In [None]:
# Acc After Training
# The last `cvx_model` was fit on features produced by the earlier stages'
# first-layer weights (every entry of `layers` except the last stage's first
# layer and its output layer), so the raw test inputs are pushed through the
# same ReLU feature maps before scoring.
# NOTE(review): this evaluates the run for the last `num_examples` of the loop above.
eval_features = X_test.numpy()
for stage_weights in layers[:-2]:
    eval_features = np.maximum(eval_features @ stage_weights.T, 0)

print("\n \n")
print("Test Accuracy:", accuracy(cvx_model(eval_features), y_test.numpy()))
print(f"Hidden Layer Size: {cvx_model.parameters[0].shape[0]}")

In [None]:
import json

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that understands numpy values.

    Arrays are written as (nested) lists, and numpy scalars (e.g. ``np.float32``,
    ``np.int64``, which the stock encoder rejects) as the equivalent Python
    numbers. Any other unsupported object falls back to ``json.JSONEncoder``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Convert a numpy scalar to the matching built-in Python scalar.
            return obj.item()
        return super().default(obj)
    
# Persist each trained layer stack as JSON, one file per training-set size.
for k, layer_weights in models.items():
    out_path = f"model_{k}.json"
    with open(out_path, "w") as handle:
        json.dump(layer_weights, handle, cls=NumpyEncoder)