In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import random
import matplotlib.pyplot as plt
import numpy as np
import time
import pickle
import copy
torch.manual_seed(0)
np.random.seed(42)

In [2]:
def normalize_data(x, y, dataset):
    pairs = list(zip(x, y))
    np.random.shuffle(pairs)
    x_shuffled = tuple([x[0] for x in pairs])
    y_shuffled = tuple([x[1] for x in pairs])
    
    if dataset == 'MNIST':
        x_shuffled = np.array(x_shuffled).reshape(-1, 784)
    elif dataset == 'CIFAR10':
        x_shuffled = np.array(x_shuffled).reshape(-1, 3, 32, 32)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    mean_x = np.mean(x_shuffled, axis=0).reshape(1, -1)
    std_x = np.std(x_shuffled)

    x_shuffled = (x_shuffled - mean_x) / std_x
    return x_shuffled, y_shuffled

def reshaper(torch_dataset, dataset):
    xs = []
    ys = []
    for x, y in torch_dataset:
        if dataset == 'MNIST':
            xs.append(np.array(x, dtype=np.float32).reshape(1, -1))
        else:
            xs.append(np.array(x, dtype=np.float32).reshape(3, 32, 32))
        try:
            ys.append(y.item())
        except:
            ys.append(y)
    return xs, ys

def get_normalized_data(dataset):
    if dataset == 'MNIST':
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)
    elif dataset == 'CIFAR10':
        trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    
    train_x, train_y = reshaper(trainset, dataset)
    train_x_normalized, train_y_normalized = normalize_data(train_x, train_y, dataset)
    test_x, test_y = reshaper(testset, dataset)
    test_x_normalized, test_y_normalized = normalize_data(test_x, test_y, dataset)
    
    return train_x_normalized, train_y_normalized, test_x_normalized, test_y_normalized

# MNIST example
DATASET = 'MNIST'
train_x, train_y, test_x, test_y = get_normalized_data(DATASET)
with open('./to_reproduce/shuffled_idx.pickle', 'rb') as f:
    shuffled_idxs = pickle.load(f)
N_SAMPLES = len(train_y)

In [3]:
def get_batch_size(D_0=False, L_k=False, epsilon=False, convex=False, fast=False, alpha=False, bs=False):
    # batch size in different cases
    if fast and alpha:
        batch_size = min(max(int(D_0*alpha / (epsilon)), 1), N_SAMPLES)
    elif bs:
        batch_size = bs
    else:
        if convex:
            batch_size = min(max(int(D_0 / (L_k * epsilon)), 1), N_SAMPLES)
        else:
            batch_size = min(max(int(8 * D_0 / (epsilon**2)), 1), N_SAMPLES)
    return batch_size

def get_batch(idx, train_x, train_y):
    batch_x = []
    batch_y = []
    for i in idx:
        if DATASET == 'MNIST':
            batch_x.append(train_x[i, :])
        elif DATASET == 'CIFAR10':
            batch_x.append(train_x[i, :, :, :].reshape(1, 3, 32, 32))
        batch_y.append(train_y[i])
    return torch.FloatTensor(np.vstack(tuple(batch_x))), torch.LongTensor(batch_y)

def evaluate_function(algo, x, y, crit, reduction='mean'):
    crit = crit(reduction=reduction)
    with torch.no_grad():
        out = algo(torch.FloatTensor(x))
        if reduction != 'none':
            loss = crit(out, torch.LongTensor(y))
        else:
            return crit(out, torch.LongTensor(y))
    return loss.item()

In [4]:
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(28*28, 10, bias=True)
    
    def forward(self, x):
        x = self.linear(x)
        return x

# Algorithm with first choice of L_k and r_k after

In [6]:
n_epoch = 1000
N_SAMPLES = len(train_y)
net = LogisticRegression()
N_PARAMS = sum([x.numel() for x in net.parameters()])
net.load_state_dict(torch.load('./to_reproduce/lr_starting_points/lr_starting_point_1'))
convex = True # True is for Alg. 2 in paper, False is for Alg. 5 in paper
criterion = nn.CrossEntropyLoss
criterion_ = nn.CrossEntropyLoss(reduction='mean')
print(evaluate_function(net, train_x, train_y, criterion))
D_0, epsilon, L_k, min_L_k, tau = (0.001, 0.0001, 10000., 1., 0.000001)

iter_losses = []
iter_train_accs = []
iter_test_accs = []
iter_times = []
iter_batch_size = []
iter_steps = []

epoch_losses = []
epoch_train_accs = []
epoch_test_accs = []
epoch_times = []
epoch_iterations = []

epoch_done = False
for epoch in range(n_epoch):
    start_idx = 0
    while True:
        i_k = 0
        #print(L_k)
        while True:
            new_L_k = (2**(i_k - 1))*L_k
            if new_L_k < min_L_k:
                L_k_1 = min_L_k
            else:
                L_k_1 = new_L_k

            
            bs_to_check = get_batch_size(D_0=D_0, L_k=L_k_1, epsilon=epsilon, convex=convex)

            end_idx = start_idx + bs_to_check
            if start_idx == 0 and bs_to_check >= N_SAMPLES:
                epoch_done = True
                x_to_check = torch.FloatTensor(train_x)
                y_to_check = torch.LongTensor(train_y)
            elif start_idx != 0 and end_idx > N_SAMPLES:
                epoch_done = True
                idx_1 = shuffled_idxs[start_idx: N_SAMPLES]
                idx_2 = shuffled_idxs[: bs_to_check - (N_SAMPLES - start_idx)]
                idx_to_check = idx_1 + idx_2
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
            elif end_idx <= N_SAMPLES:
                idx_to_check = shuffled_idxs[start_idx: end_idx]
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                
            
            loss_before = evaluate_function(net, x_to_check, y_to_check, criterion)
            with torch.no_grad():
                old_params = list(net.parameters())
                
                choosen_param = np.random.choice(range(len(old_params)))
                #print(choosen_param)
                choosen_coordinate = np.random.choice(range(old_params[choosen_param].numel()))
                #print(choosen_coordinate)

                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    old_p = p.data
                    flatten = p.data.view(-1)
                    flatten[choosen_coordinate] += tau
                    p.data = flatten.view(p.size())
                    loss_for_grad_estim = evaluate_function(net, x_to_check, y_to_check, criterion)
                    #print(loss_before, loss_for_grad_estim)
                    p.data = old_p
                    
                    new_flatten = p.data.view(-1)
                    coordinate_grad = (1 / tau) * (loss_for_grad_estim - loss_before)
                    new_flatten[choosen_coordinate] -= (1/(2*L_k_1))*coordinate_grad
                    p.data = new_flatten.view(p.size())
            loss_after = evaluate_function(net, x_to_check, y_to_check, criterion)
            
            if loss_after <= loss_before - (1/(4*L_k_1))*coordinate_grad**2 + epsilon/2:
                #print(loss_after, loss_before)
                #print()
                #print(evaluate_function(net, train_x, train_y, criterion), 1/(2*L_k), bs_to_check)
                start_idx = end_idx
                L_k = L_k_1
                if epoch_done:
                    with torch.no_grad():
                        out_train = net(torch.from_numpy(train_x))
                        out_test = net(torch.from_numpy(test_x))
                        iter_loss = criterion_(out_train, torch.from_numpy(np.array(train_y))).item()

                        pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                        ground_test = np.array(test_y)
                        l = float(len(ground_test))
                        iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l

                break
            else:
                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    p.data = old_p
                i_k += 1  
        if epoch_done:
            epoch_losses.append(iter_loss)
            epoch_test_accs.append(iter_test_acc)
            print('epoch {} done'.format(epoch))
            print(iter_loss, bs_to_check, 1/(2*L_k))
            epoch_done = False
            break

2.50454044342041
epoch 0 done
1.57681405544281 10 0.5
epoch 1 done
1.1900826692581177 10 0.5
epoch 2 done
1.0925655364990234 10 0.5
epoch 3 done
1.047715663909912 10 0.5
epoch 4 done
1.0616923570632935 10 0.5
epoch 5 done
0.9943814873695374 10 0.5
epoch 6 done
0.9368994832038879 10 0.5
epoch 7 done
0.9057785272598267 10 0.5
epoch 8 done
0.8992491364479065 10 0.5
epoch 9 done
0.8852349519729614 10 0.5
epoch 10 done
0.8666549324989319 10 0.5
epoch 11 done
0.8907153606414795 10 0.5
epoch 12 done
0.8911431431770325 5 0.25
epoch 13 done
0.8765305876731873 10 0.5
epoch 14 done
0.8518108129501343 10 0.5
epoch 15 done
0.8617379069328308 10 0.5
epoch 16 done
0.8875938653945923 10 0.5
epoch 17 done
0.9044677019119263 10 0.5
epoch 18 done
0.9305834174156189 10 0.5
epoch 19 done
0.9438886046409607 10 0.5
epoch 20 done
0.9495401978492737 10 0.5
epoch 21 done
0.9689264893531799 10 0.5
epoch 22 done
0.9618505835533142 10 0.5
epoch 23 done
0.9983253479003906 10 0.5
epoch 24 done
0.9981207847595215 10 

epoch 204 done
1.4873372316360474 10 0.5
epoch 205 done
1.5027165412902832 10 0.5
epoch 206 done
1.4770244359970093 10 0.5
epoch 207 done
1.471001386642456 10 0.5
epoch 208 done
1.480560064315796 10 0.5
epoch 209 done
1.4741239547729492 10 0.5
epoch 210 done
1.4893995523452759 10 0.5
epoch 211 done
1.4685766696929932 10 0.5
epoch 212 done
1.459059238433838 10 0.5
epoch 213 done
1.468521237373352 10 0.5
epoch 214 done
1.459505319595337 10 0.5
epoch 215 done
1.4624929428100586 10 0.5
epoch 216 done
1.4498425722122192 10 0.5
epoch 217 done
1.4474676847457886 10 0.5
epoch 218 done
1.4548194408416748 10 0.5
epoch 219 done
1.4648936986923218 10 0.5
epoch 220 done
1.4772231578826904 10 0.5
epoch 221 done
1.4794502258300781 10 0.5
epoch 222 done
1.4705603122711182 10 0.5
epoch 223 done
1.4636856317520142 10 0.5
epoch 224 done
1.4425731897354126 10 0.5


KeyboardInterrupt: 

In [62]:
L_k

2.0

In [67]:
bs_to_check

50

# Algorithm with first choice of r_k and L_k after

In [31]:
n_epoch = 1000
N_SAMPLES = len(train_y)
net = LogisticRegression()
N_PARAMS = sum([x.numel() for x in net.parameters()])
net.load_state_dict(torch.load('./to_reproduce/lr_starting_points/lr_starting_point_1'))
convex = True # True is for Alg. 2 in paper, False is for Alg. 5 in paper
criterion = nn.CrossEntropyLoss
criterion_ = nn.CrossEntropyLoss(reduction='mean')
print(evaluate_function(net, train_x, train_y, criterion))
D_0, epsilon, r_k, min_r_k, tau = (0.001, 0.0001, 100., 2., 0.000001)

iter_losses = []
iter_train_accs = []
iter_test_accs = []
iter_times = []
iter_batch_size = []
iter_steps = []

epoch_losses = []
epoch_train_accs = []
epoch_test_accs = []
epoch_times = []
epoch_iterations = []

epoch_done = False
for epoch in range(n_epoch):
    start_idx = 0
    while True:
        i_k = 0
        while True:
            
            new_r_k = (2**(i_k - 1))*r_k
            if new_r_k < min_r_k:
                r_k_1 = min_r_k
            else:
                r_k_1 = new_r_k

            
            bs_to_check = int(r_k_1)

            end_idx = start_idx + bs_to_check
            if start_idx == 0 and bs_to_check >= N_SAMPLES:
                epoch_done = True
                x_to_check = torch.FloatTensor(train_x)
                y_to_check = torch.LongTensor(train_y)
            elif start_idx != 0 and end_idx > N_SAMPLES:
                epoch_done = True
                idx_1 = shuffled_idxs[start_idx: N_SAMPLES]
                idx_2 = shuffled_idxs[: bs_to_check - (N_SAMPLES - start_idx)]
                idx_to_check = idx_1 + idx_2
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
            elif end_idx <= N_SAMPLES:
                idx_to_check = shuffled_idxs[start_idx: end_idx]
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                
            
            loss_before = evaluate_function(net, x_to_check, y_to_check, criterion, reduction='none')
            
            with torch.no_grad():
                old_params = list(net.parameters())
                
                choosen_param = np.random.choice(range(len(old_params)))
                #print(choosen_param)
                choosen_coordinate = np.random.choice(range(old_params[choosen_param].numel()))
                #print(choosen_coordinate)

                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    old_p = p.data
                    flatten = p.data.view(-1)
                    flatten[choosen_coordinate] += tau
                    p.data = flatten.view(p.size())
                    loss_for_grad_estim = evaluate_function(net, x_to_check, y_to_check, criterion, reduction='none')
                    
                    p.data = old_p
                    
                    new_flatten = p.data.view(-1)
                    variance_estim = (1 / tau) * (loss_for_grad_estim - loss_before)
                    D_k_1 = (1/(r_k_1 - 1))*torch.sum((variance_estim - torch.mean(variance_estim))**2).item()
                    if D_k_1 < D_0:
                        D_k_1 = D_0
                    L_k_1 = D_k_1/(r_k_1*epsilon)
                    coordinate_grad = torch.mean(variance_estim)
                    new_flatten[choosen_coordinate] -= (1/(2*L_k_1))*coordinate_grad
                    p.data = new_flatten.view(p.size())
            loss_after = evaluate_function(net, x_to_check, y_to_check, criterion)
            
            if loss_after <= torch.mean(loss_before) - (1/(4*L_k_1))*coordinate_grad**2 + epsilon/2:
                #print(loss_after, loss_before)
                #print()
                #print(evaluate_function(net, train_x, train_y, criterion), 1/(2*L_k), bs_to_check)
                start_idx = end_idx
                r_k = r_k_1
                if epoch_done:
                    with torch.no_grad():
                        out_train = net(torch.from_numpy(train_x))
                        out_test = net(torch.from_numpy(test_x))
                        iter_loss = criterion_(out_train, torch.from_numpy(np.array(train_y))).item()

                        pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                        ground_test = np.array(test_y)
                        l = float(len(ground_test))
                        iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l

                break
            else:
                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    p.data = old_p
                i_k += 1  
        if epoch_done:
            epoch_losses.append(iter_loss)
            epoch_test_accs.append(iter_test_acc)
            print('epoch {} done'.format(epoch))
            print(iter_loss, bs_to_check, 1/(2*L_k_1))
            epoch_done = False
            break

2.5045392513275146
epoch 0 done
2.380889415740967 2 0.0035184372232947188
epoch 1 done
2.316734790802002 2 0.0035184372232947188
epoch 2 done
2.2482550144195557 2 0.1
epoch 3 done
2.1223816871643066 2 9.773436517870239e-05
epoch 4 done
2.0320193767547607 2 0.1
epoch 5 done
1.9769314527511597 2 0.1
epoch 6 done
1.8672014474868774 2 0.00021990232645591992
epoch 7 done
1.8124260902404785 2 0.1
epoch 8 done
1.7780168056488037 2 0.1
epoch 9 done
1.7117562294006348 2 0.1
epoch 10 done
1.676803469657898 2 0.1
epoch 11 done
1.7015706300735474 2 0.00021990232645591992
epoch 12 done
1.6983686685562134 2 0.1
epoch 13 done
1.6741758584976196 2 0.1
epoch 14 done
1.6275966167449951 2 0.1
epoch 15 done
1.5963523387908936 2 0.1
epoch 16 done
1.5417555570602417 2 0.1
epoch 17 done
1.5662568807601929 2 0.1
epoch 18 done
1.52008056640625 2 0.1
epoch 19 done
1.5124461650848389 2 0.1
epoch 20 done
1.487147569656372 2 0.1
epoch 21 done
1.4297336339950562 2 0.1
epoch 22 done
1.439789056777954 2 0.00021990232

KeyboardInterrupt: 