In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import random
import matplotlib.pyplot as plt
import numpy as np
import time
import pickle
import copy
torch.manual_seed(0)
np.random.seed(42)

In [5]:
def normalize_data(x, y, dataset):
    pairs = list(zip(x, y))
    np.random.shuffle(pairs)
    x_shuffled = tuple([x[0] for x in pairs])
    y_shuffled = tuple([x[1] for x in pairs])
    
    if dataset == 'MNIST':
        x_shuffled = np.array(x_shuffled).reshape(-1, 784)
    elif dataset == 'CIFAR10':
        x_shuffled = np.array(x_shuffled).reshape(-1, 3, 32, 32)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    mean_x = np.mean(x_shuffled, axis=0).reshape(1, -1)
    std_x = np.std(x_shuffled)

    x_shuffled = (x_shuffled - mean_x) / std_x
    return x_shuffled, y_shuffled

def reshaper(torch_dataset, dataset):
    xs = []
    ys = []
    for x, y in torch_dataset:
        if dataset == 'MNIST':
            xs.append(np.array(x, dtype=np.float32).reshape(1, -1))
        else:
            xs.append(np.array(x, dtype=np.float32).reshape(3, 32, 32))
        try:
            ys.append(y.item())
        except:
            ys.append(y)
    return xs, ys

def get_normalized_data(dataset):
    if dataset == 'MNIST':
        trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)
    elif dataset == 'CIFAR10':
        trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True)
        testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    else:
        print('dataset parameter should be MNIST or CIFAR10')
        return
    
    train_x, train_y = reshaper(trainset, dataset)
    train_x_normalized, train_y_normalized = normalize_data(train_x, train_y, dataset)
    test_x, test_y = reshaper(testset, dataset)
    test_x_normalized, test_y_normalized = normalize_data(test_x, test_y, dataset)
    
    return train_x_normalized, train_y_normalized, test_x_normalized, test_y_normalized

# MNIST example
DATASET = 'MNIST'
train_x, train_y, test_x, test_y = get_normalized_data(DATASET)
with open('./to_reproduce/shuffled_idx.pickle', 'rb') as f:
    shuffled_idxs = pickle.load(f)
N_SAMPLES = len(train_y)

In [6]:
def get_batch_size(D_0=False, L_k=False, epsilon=False, convex=False, fast=False, alpha=False, bs=False):
    # batch size in different cases
    if fast and alpha:
        batch_size = min(max(int(D_0*alpha / (epsilon)), 1), N_SAMPLES)
    elif bs:
        batch_size = bs
    else:
        if convex:
            batch_size = min(max(int(D_0 / (L_k * epsilon)), 1), N_SAMPLES)
        else:
            batch_size = min(max(int(8 * D_0 / (epsilon**2)), 1), N_SAMPLES)
    return batch_size

def get_batch(idx, train_x, train_y):
    batch_x = []
    batch_y = []
    for i in idx:
        if DATASET == 'MNIST':
            batch_x.append(train_x[i, :])
        elif DATASET == 'CIFAR10':
            batch_x.append(train_x[i, :, :, :].reshape(1, 3, 32, 32))
        batch_y.append(train_y[i])
    return torch.FloatTensor(np.vstack(tuple(batch_x))), torch.LongTensor(batch_y)

def evaluate_function(algo, x, y, crit):
    with torch.no_grad():
        out = algo(torch.FloatTensor(x))
        loss = crit(out, torch.LongTensor(y))
    return loss.item()

In [7]:
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(28*28, 10, bias=True)
    
    def forward(self, x):
        x = self.linear(x)
        return x

In [72]:
n_epoch = 1000
N_SAMPLES = len(train_y)
net = LogisticRegression()
N_PARAMS = sum([x.numel() for x in net.parameters()])
net.load_state_dict(torch.load('./to_reproduce/lr_starting_points/lr_starting_point_1'))
convex = True # True is for Alg. 2 in paper, False is for Alg. 5 in paper
criterion = nn.CrossEntropyLoss(reduction='mean')
print(evaluate_function(net, train_x, train_y, criterion))
D_0, epsilon, L_k, min_L_k, tau = (0.001, 0.0001, 10000., 1., 0.000001)

iter_losses = []
iter_train_accs = []
iter_test_accs = []
iter_times = []
iter_batch_size = []
iter_steps = []

epoch_losses = []
epoch_train_accs = []
epoch_test_accs = []
epoch_times = []
epoch_iterations = []

epoch_done = False
for epoch in range(n_epoch):
    start_idx = 0
    while True:
        i_k = 0
        #print(L_k)
        while True:
            new_L_k = (2**(i_k - 1))*L_k
            if new_L_k < min_L_k:
                L_k_1 = min_L_k
            else:
                L_k_1 = new_L_k

            
            bs_to_check = get_batch_size(D_0=D_0, L_k=L_k_1, epsilon=epsilon, convex=convex)

            end_idx = start_idx + bs_to_check
            if start_idx == 0 and bs_to_check >= N_SAMPLES:
                epoch_done = True
                x_to_check = torch.FloatTensor(train_x)
                y_to_check = torch.LongTensor(train_y)
            elif start_idx != 0 and end_idx > N_SAMPLES:
                epoch_done = True
                idx_1 = shuffled_idxs[start_idx: N_SAMPLES]
                idx_2 = shuffled_idxs[: bs_to_check - (N_SAMPLES - start_idx)]
                idx_to_check = idx_1 + idx_2
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
            elif end_idx <= N_SAMPLES:
                idx_to_check = shuffled_idxs[start_idx: end_idx]
                x_to_check, y_to_check = get_batch(idx_to_check, train_x, train_y)
                
            
            loss_before = evaluate_function(net, x_to_check, y_to_check, criterion)
            with torch.no_grad():
                old_params = list(net.parameters())
                
                choosen_param = np.random.choice(range(len(old_params)))
                #print(choosen_param)
                choosen_coordinate = np.random.choice(range(old_params[choosen_param].numel()))
                #print(choosen_coordinate)

                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    old_p = p.data
                    flatten = p.data.view(-1)
                    flatten[choosen_coordinate] += tau
                    p.data = flatten.view(p.size())
                    loss_for_grad_estim = evaluate_function(net, x_to_check, y_to_check, criterion)
                    #print(loss_before, loss_for_grad_estim)
                    p.data = old_p
                    
                    new_flatten = p.data.view(-1)
                    coordinate_grad = (1 / tau) * (loss_for_grad_estim - loss_before)
                    new_flatten[choosen_coordinate] -= (1/(2*L_k_1))*coordinate_grad
                    p.data = new_flatten.view(p.size())
            loss_after = evaluate_function(net, x_to_check, y_to_check, criterion)
            
            if loss_after <= loss_before - (1/(4*L_k_1))*coordinate_grad**2 + epsilon/2:
                #print(loss_after, loss_before)
                #print()
                #print(evaluate_function(net, train_x, train_y, criterion), 1/(2*L_k), bs_to_check)
                start_idx = end_idx
                L_k = L_k_1
                if epoch_done:
                    with torch.no_grad():
                        out_train = net(torch.from_numpy(train_x))
                        out_test = net(torch.from_numpy(test_x))
                        iter_loss = criterion(out_train, torch.from_numpy(np.array(train_y))).item()

                        pred_test = np.argmax(out_test.detach().numpy(),axis=1)
                        ground_test = np.array(test_y)
                        l = float(len(ground_test))
                        iter_test_acc = len(np.where(pred_test == ground_test)[0]) / l

                break
            else:
                for i, p in enumerate(list(net.parameters())):
                    if i != choosen_param:
                        continue
                    p.data = old_p
                i_k += 1  
        if epoch_done:
            epoch_losses.append(iter_loss)
            epoch_test_accs.append(iter_test_acc)
            print('epoch {} done'.format(epoch))
            print(iter_loss, bs_to_check, 1/(2*L_k))
            epoch_done = False
            break

2.50454044342041
epoch 0 done
1.5877803564071655 10 0.5
epoch 1 done
1.4368382692337036 5 0.25
epoch 2 done
1.2743124961853027 10 0.5
epoch 3 done
1.1793426275253296 10 0.5
epoch 4 done
1.1036754846572876 10 0.5
epoch 5 done
1.0946258306503296 10 0.5
epoch 6 done
1.0790493488311768 10 0.5
epoch 7 done
1.0499826669692993 10 0.5
epoch 8 done
1.0385204553604126 10 0.5
epoch 9 done
0.9992709159851074 10 0.5
epoch 10 done
0.9850336313247681 10 0.5
epoch 11 done
0.9884915947914124 10 0.5
epoch 12 done
0.9892107248306274 10 0.5
epoch 13 done
0.9364581108093262 10 0.5
epoch 14 done
0.9255140423774719 10 0.5
epoch 15 done
0.9423025846481323 10 0.5
epoch 16 done
0.9702706933021545 10 0.5
epoch 17 done
0.9591780304908752 10 0.5
epoch 18 done
0.9329747557640076 10 0.5
epoch 19 done
0.9309083223342896 10 0.5
epoch 20 done
0.9382079243659973 10 0.5
epoch 21 done
0.9188967943191528 10 0.5
epoch 22 done
0.9387637376785278 10 0.5
epoch 23 done
0.9386370182037354 10 0.5
epoch 24 done
0.9217556118965149 

KeyboardInterrupt: 

In [62]:
L_k

2.0

In [67]:
bs_to_check

50