In [1]:
import os
import pickle
import itertools
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(0)
np.random.seed(42)

In [2]:
def normalize_data(x, y, dataset):
    pairs = list(zip(x, y))
    np.random.shuffle(pairs)
    x_shuffled = tuple([x[0] for x in pairs])
    y_shuffled = tuple([x[1] for x in pairs])
    
    if dataset == 'MNIST':
        x_shuffled = np.array(x_shuffled).reshape(-1, 784)
    elif dataset == 'CIFAR10':
        x_shuffled = np.array(x_shuffled).reshape(-1, 3, 32, 32)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    mean_x = np.mean(x_shuffled, axis=0).reshape(1, -1)
    std_x = np.std(x_shuffled)

    x_shuffled = (x_shuffled - mean_x) / std_x
    return x_shuffled, y_shuffled

def reshaper(torch_dataset, dataset):
    xs = []
    ys = []
    for x, y in torch_dataset:
        if dataset == 'MNIST':
            xs.append(np.array(x, dtype=np.float32).reshape(1, -1))
        else:
            xs.append(np.array(x, dtype=np.float32).reshape(3, 32, 32))
        try:
            ys.append(y.item())
        except:
            ys.append(y)
    return xs, ys

def get_normalized_data(dataset):
    if dataset == 'MNIST':
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)
    elif dataset == 'CIFAR10':
        trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    
    train_x, train_y = reshaper(trainset, dataset)
    train_x_normalized, train_y_normalized = normalize_data(train_x, train_y, dataset)
    test_x, test_y = reshaper(testset, dataset)
    test_x_normalized, test_y_normalized = normalize_data(test_x, test_y, dataset)
    
    return train_x_normalized, train_y_normalized, test_x_normalized, test_y_normalized

# MNIST example
DATASET = 'MNIST'
train_x, train_y, test_x, test_y = get_normalized_data(DATASET)
with open('./to_reproduce/shuffled_idx.pickle', 'rb') as f:
    shuffled_idxs = pickle.load(f)
N_SAMPLES = len(train_y)

In [3]:
# all functions to optimize

class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(28*28, 10, bias=True)
    
    def forward(self, x):
        x = self.linear(x)
        return x
    
class FullyConnected(nn.Module):
    def __init__(self):
        super(FullyConnected, self).__init__()
        self.linear1 = nn.Linear(28*28, 1000, bias=True)
        self.linear2 = nn.Linear(1000, 10, bias=True)
        
    def forward(self, x):
        x = self.linear1(x)
        x = torch.sigmoid(x)
        x = self.linear2(x)
        x = torch.sigmoid(x)
        return x
    
class FullyConnectedRELU(nn.Module):
    def __init__(self):
        super(FullyConnectedRELU, self).__init__()
        self.linear1 = nn.Linear(28*28, 1000, bias=True)
        self.linear2 = nn.Linear(1000, 10, bias=True)
        
    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        return x

In [4]:
# batch-related methods

def get_batch_size(D_0=False, L_k=False, epsilon=False, convex=False, fast=False, alpha=False, bs=False):
    # batch size in different cases
    if fast and alpha:
        batch_size = min(max(int(D_0*alpha / (epsilon)), 1), N_SAMPLES)
    elif bs:
        batch_size = bs
    else:
        if convex:
            batch_size = min(max(int(D_0 / (L_k * epsilon)), 1), N_SAMPLES)
        else:
            batch_size = min(max(int(8 * D_0 / (epsilon**2)), 1), N_SAMPLES)
    return batch_size

def get_batch(idx, train_x, train_y):
    batch_x = []
    batch_y = []
    for i in idx:
        if DATASET == 'MNIST':
            batch_x.append(train_x[i, :])
        elif DATASET == 'CIFAR10':
            batch_x.append(train_x[i, :, :, :].reshape(1, 3, 32, 32))
        batch_y.append(train_y[i])
    return torch.FloatTensor(np.vstack(tuple(batch_x))), torch.LongTensor(batch_y)

def evaluate_function(algo, x, y, crit):
    with torch.no_grad():
        out = algo(torch.FloatTensor(x))
        loss = crit(out, torch.LongTensor(y))
    return loss.item()

# Evaluate average performance of the algorithms

In [5]:
# Investigate search in logarithmic parameters space (i. e. params are vary in 2-10 times)
# Our methods grid (size of meaningful combinations min_L_k*epsilon <= D_0 is 38)
D_0s = [0.1, 0.01, 0.001, 0.0001]
epsilons = [0.01, 0.001, 0.0001, 0.00001]
L_ks = [1000, 10000]
min_L_ks = [101, 11, 2]
# Adam and Adagrad grid (size = 30)
learning_rates = [0.00001, 0.0001, 0.001, 0.01, 0.1]
batch_sizes = [32, 64, 128, 256, 512, 1024]
functions = [LogisticRegression, FullyConnected]

#to investigate iteration performance simply evaluate iteration loss not after epoch, but after each iteration (delete if epoch_done condition)

In [None]:
# we evaluate each algorithm starting from 5 different starting points during 10 epochs
n_epochs = 10
opt_methods = ['accelerated_adaptive', 'adaptive_convex', 'adaptive_nonconvex']
criterion = nn.CrossEntropyLoss(reduction='mean')
params_grid = list(itertools.product(D_0s, epsilons, L_ks, min_L_ks))

for opt_method in opt_methods:
    print(opt_method)
    for func in functions:
        print(func.__name__)
        for starting_point_num in range(1, 6):
            print(starting_point_num)
            
            if opt_method == 'adaptive_convex':
                convex = True
            else:
                convex = False
                
            for params in params_grid:
                
                D_0, epsilon, L_k, min_L_k = params
                if min_L_k*epsilon > D_0:
                    continue
                print(params)
                
                if func == LogisticRegression:
                    net = LogisticRegression()
                    net.load_state_dict(torch.load('./to_reproduce/lr_starting_points/lr_starting_point_{}'.format(starting_point_num)))
                else:
                    net = func()
                    net.load_state_dict(torch.load('./to_reproduce/fc_starting_points/fc_starting_point_{}'.format(starting_point_num)))
                
                epoch_losses = []    
                epoch_test_accs = []
                
                if opt_method == 'accelerated_adaptive':
                    x_k = list(net.parameters())
                    y_k = list(net.parameters())
                    u_k = list(net.parameters())
                    alpha_k = 0.
                    A_k = 0.
                
                epoch_done = False
                for epoch in range(n_epochs):
                    start_idx = 0
                    while True:
                        i_k = 0
                        while True:
                            new_L_k = (2**(i_k - 1))*L_k
                            if new_L_k < min_L_k:
                                L_k_1 = min_L_k
                            else:
                                L_k_1 = new_L_k
                                
                            if opt_method == 'accelerated_adaptive':
                                alpha_k_1 = (1 + np.sqrt(1 + 4*A_k*L_k_1)) / (2*L_k_1)
                                A_k_1 = A_k + alpha_k_1
                                bs_to_check = get_batch_size(D_0=D_0, L_k=L_k_1, epsilon=epsilon, fast=True, alpha=alpha_k_1)
                            else:
                                bs_to_check = get_batch_size(D_0=D_0, L_k=L_k_1, epsilon=epsilon, convex=convex)

                            end_idx = start_idx + bs_to_check
                            if start_idx == 0 and bs_to_check >= N_SAMPLES:
                                epoch_done = True
                                x_to_check = torch.FloatTensor(train_x)
                                y_to_check = torch.LongTensor(train_y)
                            elif start_idx != 0 and end_idx > N_SAMPLES:
                                epoch_done = True
                                idx_1 = shuffled_idxs[start_idx: N_SAMPLES]
                                idx_2 = shuffled_idxs[: bs_to_check - (N_SAMPLES - start_idx)]
                                idx_to_check = idx_1 + idx_2
                                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                            elif end_idx <= N_SAMPLES:
                                idx_to_check = shuffled_idxs[start_idx: end_idx]
                                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                            
                            if opt_method == 'accelerated_adaptive':
                                #with torch.no_grad():
                                y_k_1 = []
                                for u, x, param in zip(u_k, x_k, list(net.parameters())):
                                    new_value = (alpha_k_1*u + A_k*x) / A_k_1
                                    y_k_1.append(new_value)
                                    param.data = new_value
                                
                                out = net(x_to_check)
                                loss_y_k_1 = criterion(out, y_to_check)
                                net.zero_grad()
                                loss_y_k_1.backward()
                                
                                dot_prods = []
                                norms = []
                                u_k_1 = []
                                x_k_1 = []
                                for x, u, y_k_1_value, param in zip(x_k, u_k, y_k_1, list(net.parameters())):
                                    u_k_1_value = u - alpha_k_1*param.grad.data
                                    u_k_1.append(u_k_1_value)
                                    x_k_1_value = (alpha_k_1*u_k_1_value + A_k*x) / A_k_1
                                    x_k_1.append(x_k_1_value)

                                    dot_prods.append(torch.sum((param.grad.data*(x_k_1_value-y_k_1_value)).view(-1)))
                                    norms.append(torch.sum(((x_k_1_value-y_k_1_value)*(x_k_1_value-y_k_1_value)).view(-1)))

                                    param.data = x_k_1_value
                                
                                with torch.no_grad():
                                    out = net(x_to_check)
                                    loss_x_k_1 = criterion(out, y_to_check)
                                
                                if loss_x_k_1 < loss_y_k_1 + sum(dot_prods) + (L_k_1)*sum(norms) + (alpha_k_1*epsilon)/(2*A_k_1):
                                    alpha_k = alpha_k_1
                                    A_k = A_k_1
                                    u_k = u_k_1
                                    x_k = x_k_1
                                    L_k = L_k_1

                                    start_idx = end_idx
                                    if epoch_done:
                                        with torch.no_grad():
                                            out_train = net(torch.from_numpy(train_x))
                                            out_test = net(torch.from_numpy(test_x))
                                            iter_loss = criterion(out_train, torch.from_numpy(np.array(train_y))).item()

                                            pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                                            ground_test = np.array(test_y)
                                            l = float(len(ground_test))
                                            iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l

                                    break
                                else:
                                    for x, f in zip(x_k, list(net.parameters())):
                                        f.data = x    
                                    i_k += 1  
                                    
                            elif opt_method in ['adaptive_convex', 'adaptive_nonconvex']:
                                out = net(x_to_check)
                                loss = criterion(out, y_to_check)
                                loss_before = loss.item()
                                net.zero_grad()
                                loss.backward()
                                
                                right_hands = []
                                for f in net.parameters():
                                    right_hands.append((-1./(8.*L_k_1))*f.grad.data.norm(2)**2)
                                    f.data.sub_(f.grad.data / (2.*L_k_1))
                                right = sum(right_hands)
                                loss_after = evaluate_function(net, x_to_check, y_to_check, criterion)
                                if convex:
                                    condition = loss_before + right + epsilon/(2.)
                                else:
                                    condition = loss_before + right + epsilon**2/(32*L_k)
                                
                                if loss_after < condition:
                                    start_idx = end_idx
                                    L_k = L_k_1
                                    if epoch_done:
                                        with torch.no_grad():
                                            out_train = net(torch.from_numpy(train_x))
                                            out_test = net(torch.from_numpy(test_x))
                                            iter_loss = criterion(out_train, torch.from_numpy(np.array(train_y))).item()

                                            pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                                            ground_test = np.array(test_y)
                                            l = float(len(ground_test))
                                            iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l

                                    break
                                else:
                                    for f in net.parameters():
                                        f.data.add_(f.grad.data / (2.*L_k))   
                                    i_k += 1  
                        if epoch_done:
                            epoch_losses.append(iter_loss)
                            epoch_test_accs.append(iter_test_acc)
                            print('epoch {} done'.format(epoch))
                            print(iter_loss, bs_to_check, 1/(2*L_k))
                            epoch_done = False
                            break
                                    
                                    
                                    
                                    
                dir_name = './to_reproduce/results/' + '_'.join([opt_method, func.__name__, str(starting_point_num)] + [str(x) for x in params])
                os.system('mkdir ' + dir_name)
                with open(dir_name + '/epoch_losses.pickle', 'wb') as out:
                    pickle.dump(epoch_losses, out)
                with open(dir_name + '/epoch_test_accs.pickle', 'wb') as out:
                    pickle.dump(epoch_test_accs, out)

# Adam and Adagrad

In [6]:
# we evaluate each algorithm starting from 5 different starting points during 10 epochs
n_epochs = 10
opt_methods = ['adam', 'adagrad']
criterion = nn.CrossEntropyLoss(reduction='mean')
params_grid = list(itertools.product(learning_rates, batch_sizes))

for opt_method in opt_methods:
    print(opt_method)
    for func in functions:
        print(func.__name__)
        for starting_point_num in range(1, 6):
            print(starting_point_num)
                
            for params in params_grid:
                if func == LogisticRegression:
                    net = LogisticRegression()
                    net.load_state_dict(torch.load('./to_reproduce/lr_starting_points/lr_starting_point_{}'.format(starting_point_num)))
                else:
                    net = func()
                    net.load_state_dict(torch.load('./to_reproduce/fc_starting_points/fc_starting_point_{}'.format(starting_point_num)))
                
                lr, bs = params
                print(params)
                if opt_method == 'adam':
                    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
                elif opt_method == 'adagrad':
                    optimizer = torch.optim.Adagrad(net.parameters(), lr=lr)
                
                epoch_losses = []    
                epoch_test_accs = []
                
                epoch_done = False
                for epoch in range(n_epochs):
                    start_idx = 0
                    while True:
                        bs_to_check = get_batch_size(bs=bs)

                        end_idx = start_idx + bs_to_check
                        if start_idx == 0 and bs_to_check >= N_SAMPLES:
                            epoch_done = True
                            x_to_check = torch.FloatTensor(train_x)
                            y_to_check = torch.LongTensor(train_y)
                        elif start_idx != 0 and end_idx > N_SAMPLES:
                            epoch_done = True
                            idx_1 = shuffled_idxs[start_idx: N_SAMPLES]
                            idx_2 = shuffled_idxs[: bs_to_check - (N_SAMPLES - start_idx)]
                            idx_to_check = idx_1 + idx_2
                            x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                        elif end_idx <= N_SAMPLES:
                            idx_to_check = shuffled_idxs[start_idx: end_idx]
                            x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                            
                        optimizer.zero_grad()
                        out = net(x_to_check)
                        loss = criterion(out, y_to_check)
                        loss.backward()
                        optimizer.step()
                        
                        start_idx = end_idx
                        if epoch_done:
                            with torch.no_grad():
                                out_train = net(torch.from_numpy(train_x))
                                out_test = net(torch.from_numpy(test_x))
                                iter_loss = criterion(out_train, torch.from_numpy(np.array(train_y))).item()

                                pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                                ground_test = np.array(test_y)
                                l = float(len(ground_test))
                                iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l
                        if epoch_done:
                            epoch_losses.append(iter_loss)
                            epoch_test_accs.append(iter_test_acc)
                            print('epoch {} done'.format(epoch))
                            print(iter_loss)
                            epoch_done = False
                            break
                dir_name = './to_reproduce/results/' + '_'.join([opt_method, func.__name__, str(starting_point_num)] + [str(x) for x in params])
                os.system('mkdir ' + dir_name)
                with open(dir_name + '/epoch_losses.pickle', 'wb') as out:
                    pickle.dump(epoch_losses, out)
                with open(dir_name + '/epoch_test_accs.pickle', 'wb') as out:
                    pickle.dump(epoch_test_accs, out)

adam
FullyConnected
4
(1e-05, 32)
epoch 0 done
1.869897484779358
epoch 1 done
1.7382358312606812
epoch 2 done
1.6773236989974976
epoch 3 done
1.6422603130340576
epoch 4 done
1.6197983026504517
epoch 5 done
1.60427987575531
epoch 6 done
1.593002438545227
epoch 7 done
1.5844480991363525
epoch 8 done
1.5776914358139038
epoch 9 done
1.5722122192382812
(1e-05, 64)
epoch 0 done
1.982283592224121
epoch 1 done
1.832443118095398
epoch 2 done
1.7547630071640015
epoch 3 done
1.706875205039978
epoch 4 done
1.6741596460342407
epoch 5 done
1.6504762172698975
epoch 6 done
1.6326842308044434
epoch 7 done
1.6189221143722534
epoch 8 done
1.6080466508865356
epoch 9 done
1.5992746353149414
(1e-05, 128)
epoch 0 done
2.0941474437713623
epoch 1 done
1.9518237113952637
epoch 2 done
1.8614999055862427
epoch 3 done
1.8014895915985107
epoch 4 done
1.758839726448059
epoch 5 done
1.7267564535140991
epoch 6 done
1.7016111612319946
epoch 7 done
1.6813561916351318
epoch 8 done
1.6646794080734253
epoch 9 done
1.650759

epoch 8 done
2.4208102226257324
epoch 9 done
2.4208102226257324
(0.1, 64)
epoch 0 done
2.4256632328033447
epoch 1 done
2.4256632328033447
epoch 2 done
2.4256632328033447
epoch 3 done
2.4256632328033447
epoch 4 done
2.4256632328033447
epoch 5 done
2.4256632328033447
epoch 6 done
2.4256632328033447
epoch 7 done
2.4256632328033447
epoch 8 done
2.4256632328033447
epoch 9 done
2.4256632328033447
(0.1, 128)
epoch 0 done
2.4256632328033447
epoch 1 done
2.4256632328033447
epoch 2 done
2.4256632328033447
epoch 3 done
2.4256632328033447
epoch 4 done
2.4256632328033447
epoch 5 done
2.4256632328033447
epoch 6 done
2.4256632328033447
epoch 7 done
2.4256632328033447
epoch 8 done
2.4256632328033447
epoch 9 done
2.4256632328033447
(0.1, 256)
epoch 0 done
2.267894744873047
epoch 1 done
2.2646007537841797
epoch 2 done
2.2611262798309326
epoch 3 done
2.274233818054199
epoch 4 done
2.261445999145508
epoch 5 done
2.266448974609375
epoch 6 done
2.2629690170288086
epoch 7 done
2.266322374343872
epoch 8 done


epoch 7 done
2.0604143142700195
epoch 8 done
2.0604982376098633
epoch 9 done
2.0619447231292725
(0.01, 128)
epoch 0 done
1.8184527158737183
epoch 1 done
1.8100998401641846
epoch 2 done
1.807093858718872
epoch 3 done
1.803218126296997
epoch 4 done
1.8027589321136475
epoch 5 done
1.801184892654419
epoch 6 done
1.703366756439209
epoch 7 done
1.6962116956710815
epoch 8 done
1.6984140872955322
epoch 9 done
1.692416787147522
(0.01, 256)
epoch 0 done
1.5061640739440918
epoch 1 done
1.4931225776672363
epoch 2 done
1.4855453968048096
epoch 3 done
1.483593225479126
epoch 4 done
1.480364203453064
epoch 5 done
1.4779953956604004
epoch 6 done
1.4766696691513062
epoch 7 done
1.4761041402816772
epoch 8 done
1.4793243408203125
epoch 9 done
1.476371169090271
(0.01, 512)
epoch 0 done
1.5130863189697266
epoch 1 done
1.496537208557129
epoch 2 done
1.486517310142517
epoch 3 done
1.481703758239746
epoch 4 done
1.4786068201065063
epoch 5 done
1.4762316942214966
epoch 6 done
1.4740430116653442
epoch 7 done
1.

epoch 6 done
1.5795615911483765
epoch 7 done
1.57565176486969
epoch 8 done
1.5723921060562134
epoch 9 done
1.5695945024490356
(0.001, 256)
epoch 0 done
1.703014612197876
epoch 1 done
1.6530821323394775
epoch 2 done
1.6293087005615234
epoch 3 done
1.6146695613861084
epoch 4 done
1.6045317649841309
epoch 5 done
1.596990942955017
epoch 6 done
1.5910738706588745
epoch 7 done
1.586303949356079
epoch 8 done
1.5823321342468262
epoch 9 done
1.5789530277252197
(0.001, 512)
epoch 0 done
1.7476491928100586
epoch 1 done
1.6865534782409668
epoch 2 done
1.6570100784301758
epoch 3 done
1.6386377811431885
epoch 4 done
1.6258350610733032
epoch 5 done
1.6162381172180176
epoch 6 done
1.6087137460708618
epoch 7 done
1.6026312112808228
epoch 8 done
1.5975865125656128
epoch 9 done
1.5933040380477905
(0.001, 1024)
epoch 0 done
1.8088618516921997
epoch 1 done
1.7344714403152466
epoch 2 done
1.6976531744003296
epoch 3 done
1.6745086908340454
epoch 4 done
1.6581997871398926
epoch 5 done
1.6458946466445923
epoch

epoch 6 done
1.9636536836624146
epoch 7 done
1.9478932619094849
epoch 8 done
1.9339513778686523
epoch 9 done
1.9214669466018677
(0.0001, 512)
epoch 0 done
2.1912901401519775
epoch 1 done
2.1450278759002686
epoch 2 done
2.1113898754119873
epoch 3 done
2.0845203399658203
epoch 4 done
2.0620503425598145
epoch 5 done
2.042675018310547
epoch 6 done
2.0256831645965576
epoch 7 done
2.0105350017547607
epoch 8 done
1.9969053268432617
epoch 9 done
1.9845446348190308
(0.0001, 1024)
epoch 0 done
2.221885919570923
epoch 1 done
2.1863973140716553
epoch 2 done
2.159839630126953
epoch 3 done
2.1380462646484375
epoch 4 done
2.1193971633911133
epoch 5 done
2.1030025482177734
epoch 6 done
2.088341236114502
epoch 7 done
2.0750577449798584
epoch 8 done
2.0629079341888428
epoch 9 done
2.051699161529541
(0.001, 32)
epoch 0 done
1.6379541158676147
epoch 1 done
1.6071749925613403
epoch 2 done
1.5925859212875366
epoch 3 done
1.5836272239685059
epoch 4 done
1.5773502588272095
epoch 5 done
1.5726118087768555
epoc

KeyboardInterrupt: 