In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import random
import matplotlib.pyplot as plt
import numpy as np
import time
import copy
# Seed every RNG the notebook draws from: numpy (get_batch samples mini-batches
# with np.random.choice), Python's `random`, and torch (weight initialization).
# torch is seeded last so the cell still displays its Generator as before.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x114267b70>

In [None]:
# download datasets and normalize
dataset = 'MNIST'


def _mnist_to_array(img):
    """Flatten a 28x28 MNIST PIL image into a float32 vector of 784 pixels."""
    return np.array(img, dtype=np.float32).reshape(784)


def _cifar_to_array(img):
    """Convert a CIFAR10 PIL image into a float32 (3, 32, 32) channel-first array.

    np.array on an RGB PIL image yields (H, W, C) = (32, 32, 3). Transposing (not
    reshaping, which would interleave pixels across channels) gives the (C, H, W)
    layout that nn.Conv2d expects.
    """
    return np.array(img, dtype=np.float32).transpose(2, 0, 1)


def _split_to_numpy(split, to_array):
    """Materialize a torchvision split as (stacked float32 images, list of int labels)."""
    images, labels = [], []
    for img, label in split:
        images.append(to_array(img))
        # Labels are normally plain ints; unwrap tensor/NumPy scalars if present
        # (replaces the former bare `except:` around `y.item()`).
        labels.append(label.item() if hasattr(label, 'item') else label)
    return np.stack(images), labels


if dataset == 'MNIST':
    mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
    mnist_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True)
    print(len(mnist_trainset))
    print(len(mnist_testset))

    train_x, train_y = _split_to_numpy(mnist_trainset, _mnist_to_array)
    test_x, test_y = _split_to_numpy(mnist_testset, _mnist_to_array)
    mean_x = np.mean(train_x, axis=0).reshape(1, -1)
elif dataset == 'CIFAR10':
    cifar_trainset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True)
    cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
    print(len(cifar_trainset))
    print(len(cifar_testset))

    train_x, train_y = _split_to_numpy(cifar_trainset, _cifar_to_array)
    test_x, test_y = _split_to_numpy(cifar_testset, _cifar_to_array)
    mean_x = np.mean(train_x, axis=0)
else:
    raise ValueError("dataset must be 'MNIST' or 'CIFAR10', got {!r}".format(dataset))

# Standardize with TRAINING statistics only (per-pixel mean, one global std) and
# apply the identical transform to the test split. Normalizing the test set with
# its own statistics would evaluate the model on differently scaled inputs.
std_x = np.std(train_x)
train_x = (train_x - mean_x) / std_x
test_x = (test_x - mean_x) / std_x

In [7]:
# get batch based on current L_k
def get_batch(dataset, train_x, train_y, up_const, L_k, convex, fast=False, alpha=None, bs=False):
    """Sample a mini-batch without replacement from the indices left in this epoch.

    Batch-size rule (first match wins):
      * ``bs`` truthy          -> ``bs`` (fixed size, e.g. for the Adam baseline);
      * ``fast`` and ``alpha`` -> ``up_const * D_0 * alpha / epsilon``;
      * ``convex``             -> ``D_0 / (L_k * epsilon)``;
      * otherwise              -> ``8 * D_0 / epsilon**2``;
    each adaptive size is truncated to an int and floored at 1. ``D_0`` and
    ``epsilon`` are read from module-level globals and are only needed when
    ``bs`` is not given.

    Indices are drawn from the global ``available_idx`` and removed from it. If
    fewer indices remain than requested, all of them form the batch. Whenever the
    pool ends up empty, the global ``epoch_done`` is set to True and the pool is
    refilled with ``range(len(train_y))``.

    Args:
        dataset: dataset name; kept for call compatibility (rows are gathered by
            indexing the first axis of ``train_x`` for any layout).
        train_x: array whose first axis indexes samples.
        train_y: sequence of integer labels aligned with ``train_x``.
        up_const: constant of the fast-method batch-size rule.
        L_k: current smoothness estimate; only read when ``convex`` is true and
            ``bs`` / ``fast`` do not apply.
        convex: selects the convex (Alg. 2) batch-size rule.
        fast: together with ``alpha``, selects the fast (Alg. 3) rule.
        alpha: step weight used by the fast rule.
        bs: optional fixed batch size overriding the adaptive rules.

    Returns:
        (batch_size, batch_x, batch_y): the requested batch size (this can exceed
        the number of rows actually returned at the end of an epoch), a float32
        tensor of the selected rows, and an int64 tensor of their labels.
    """
    global epoch_done
    global available_idx
    if bs:
        batch_size = bs
    elif fast and alpha:
        batch_size = max(int(up_const*D_0*alpha / (epsilon)), 1)
    elif convex:
        batch_size = max(int(D_0 / (L_k * epsilon)), 1)
    else:
        batch_size = max(int(8 * D_0 / (epsilon**2)), 1)

    try:
        ids = np.random.choice(available_idx, size=batch_size, replace=False)
    except ValueError:
        # Fewer indices left than requested: the tail of the epoch is the batch.
        ids = available_idx[:]
        epoch_done = True
    ids = np.asarray(ids, dtype=np.int64)

    # Fancy indexing gathers exactly the rows the old per-row stacking produced.
    batch_x = np.asarray(train_x)[ids]
    batch_y = [train_y[i] for i in ids]

    # Rebuild the pool in one pass instead of list.remove per index (each remove
    # is O(len(pool))). Order is preserved so np.random.choice draws are unchanged.
    chosen = set(ids.tolist())
    available_idx[:] = [i for i in available_idx if i not in chosen]
    if not available_idx:
        epoch_done = True
        available_idx = list(range(len(train_y)))
    return (batch_size,
            torch.as_tensor(batch_x, dtype=torch.float32),
            torch.as_tensor(batch_y, dtype=torch.long))

In [4]:
# all tasks
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: a single affine map from features to class scores.

    Args:
        input_size: number of input features.
        num_classes: number of output scores (one per class).
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes, bias=True)

    def forward(self, x):
        """Return the unnormalized class scores ``x @ W.T + b``."""
        return self.linear(x)
    
class FullyConnected(nn.Module):
    """One-hidden-layer perceptron with sigmoid activations.

    The defaults reproduce the original 784 -> 1000 -> 10 MNIST architecture. The
    sizes are parameters (as in LogisticRegression) so the net can be reused for
    other input/output dimensions.

    NOTE(review): the sigmoid is also applied to the output layer, so the scores
    passed to nn.CrossEntropyLoss (which expects unnormalized logits) are squashed
    into (0, 1). Kept as-is to preserve existing behavior.

    Args:
        input_size: number of input features (784 for flattened MNIST).
        hidden_size: width of the hidden layer.
        num_classes: number of output scores.
    """

    def __init__(self, input_size=28*28, hidden_size=1000, num_classes=10):
        super(FullyConnected, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size, bias=True)
        self.linear2 = nn.Linear(hidden_size, num_classes, bias=True)

    def forward(self, x):
        """Map inputs with last dimension ``input_size`` to ``num_classes`` scores in (0, 1)."""
        x = torch.sigmoid(self.linear1(x))
        return torch.sigmoid(self.linear2(x))
    
class FullyConnectedRELU(nn.Module):
    """One-hidden-layer perceptron with ReLU activations.

    The defaults reproduce the original 784 -> 1000 -> 10 MNIST architecture. The
    sizes are parameters (as in LogisticRegression) so the net can be reused for
    other input/output dimensions.

    NOTE(review): ReLU is also applied to the output layer, so the scores given to
    nn.CrossEntropyLoss in the training cells are never negative. An output whose
    pre-activation is <= 0 contributes no gradient through linear2. Kept as-is to
    preserve the reported experiments.

    Args:
        input_size: number of input features (784 for flattened MNIST).
        hidden_size: width of the hidden layer.
        num_classes: number of output scores.
    """

    def __init__(self, input_size=28*28, hidden_size=1000, num_classes=10):
        super(FullyConnectedRELU, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size, bias=True)
        self.linear2 = nn.Linear(hidden_size, num_classes, bias=True)

    def forward(self, x):
        """Map inputs with last dimension ``input_size`` to ``num_classes`` non-negative scores."""
        x = F.relu(self.linear1(x))
        return F.relu(self.linear2(x))
    
class ConvNet(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images producing 10 class scores.

    Two (conv 5x5 -> ReLU -> 2x2 max-pool) stages reduce the image to 16 feature
    maps of 5x5. These are flattened and passed through a 400 -> 120 -> 84 -> 10
    stack of linear layers (ReLU between them, raw scores at the end).
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept so parameter init and registration match.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

# Example algorithms

## Adaptive method (Alg. 2 / Alg. 5 in the paper) for FullyConnectedRELU on MNIST

In [None]:
# Alg. 2 (convex=True) / Alg. 5 (convex=False) in the paper: gradient steps of size
# 1/(2 L_k), with a backtracking search on L_k and an L_k-dependent batch size.
available_idx = list(range(len(train_x)))
epoch_done = False
n_epoch = 1000
net = FullyConnectedRELU()
convex = True # True is for Alg. 2 in paper, False is for Alg. 5 in paper
criterion = nn.CrossEntropyLoss(reduction='mean')

D_0 = 0.1
epsilon = 0.002
L_k = 1.

iter_losses = []
iter_train_accs = []
iter_test_accs = []

# Full-dataset tensors/arrays for per-step logging, built once instead of every step.
train_x_t = torch.from_numpy(train_x)
test_x_t = torch.from_numpy(test_x)
train_y_t = torch.from_numpy(np.array(train_y))
ground_train = np.array(train_y)
ground_test = np.array(test_y)

for epoch in range(n_epoch):
    print(len(available_idx))
    print(epoch_done)

    while True:
        bs_to_check, x_to_check, y_to_check = get_batch(dataset, train_x, train_y, 2, L_k, convex)

        net.zero_grad()
        loss = criterion(net(x_to_check), y_to_check)
        loss.backward()
        loss_before = loss.detach()

        # Exact copy of the current point. Every trial step starts from it, so a
        # rejected step never has to be undone by re-adding (which drifts numerically).
        saved_params = [p.detach().clone() for p in net.parameters()]
        grad_sq_norm = sum(p.grad.norm(2)**2 for p in net.parameters())

        # Backtracking: L_k = 2**(i_k - 1) * L_prev for i_k = 0, 1, 2, ..., i.e. try
        # L_prev/2, L_prev, 2*L_prev, ... around the last accepted estimate (same
        # schedule as the fast method). Scaling the already-updated L_k instead would
        # repeat the first trial at i_k = 1 and grow L super-exponentially.
        L_prev = L_k
        i_k = 0
        while True:
            L_k = (2**(i_k - 1))*L_prev

            with torch.no_grad():
                for p, p_start in zip(net.parameters(), saved_params):
                    p.copy_(p_start - p.grad / (2.*L_k))
                loss_after = criterion(net(x_to_check), y_to_check)

            # Sufficient decrease: f(x - g/(2L)) <= f(x) - ||g||^2/(8L) + tolerance.
            right = -grad_sq_norm / (8.*L_k)
            if convex:
                condition = loss_before + right + epsilon/(2.)
            else:
                condition = loss_before + right + epsilon**2/(32*L_k)
            if loss_after <= condition:
                with torch.no_grad():
                    out_train = net(train_x_t)
                    out_test = net(test_x_t)
                    iter_losses.append(criterion(out_train, train_y_t).item())

                    pred_train = np.argmax(out_train.numpy(), axis=1)
                    iter_train_accs.append(len(np.where(pred_train == ground_train)[0]) / float(len(ground_train)))

                    pred_test = np.argmax(out_test.numpy(), axis=1)
                    iter_test_accs.append(len(np.where(pred_test == ground_test)[0]) / float(len(ground_test)))
                break

            i_k += 1

        if epoch_done:
            print('epoch {} done'.format(epoch))
            print(len(iter_losses))
            print(iter_losses[-10:])
            epoch_done = False
            break

## Fast adaptive method (Alg. 3 in the paper) for FullyConnectedRELU on MNIST

In [None]:
# alg 3 in paper: accelerated adaptive method with backtracking on L.
available_idx = list(range(len(train_x)))
epoch_done = False
n_epoch = 1000
net = FullyConnectedRELU()
convex = False
criterion = nn.CrossEntropyLoss(reduction='mean')

iter_losses = []
iter_train_accs = []
iter_test_accs = []

D_0 = 0.1
epsilon = 0.002
L_old = 1.

# Independent copies of the starting point. list(net.parameters()) would alias the
# live Parameters, so loading y_k / x_k into the net would silently overwrite
# x_old and u_old as well (corrupting the first step and the rollback on rejection).
x_old = [p.detach().clone() for p in net.parameters()]
y_old = [p.detach().clone() for p in net.parameters()]
u_old = [p.detach().clone() for p in net.parameters()]
alpha_old = 0.
A_old = 0.

# Full-dataset tensors/arrays for per-step logging, built once instead of every step.
train_x_t = torch.from_numpy(train_x)
test_x_t = torch.from_numpy(test_x)
train_y_t = torch.from_numpy(np.array(train_y))
ground_train = np.array(train_y)
ground_test = np.array(test_y)

for epoch in range(n_epoch):
    print(len(available_idx))
    print(epoch_done)

    while True:

        i_k = 0
        alpha_batch = (1 + np.sqrt(1 + 4*A_old*L_old)) / (2*L_old)
        bs_to_check, x_to_check, y_to_check = get_batch(dataset, train_x, train_y, 3., False, convex, True, alpha_batch)

        while True:
            # Trial constant L_k = 2**(i_k - 1) * L_old. alpha_k is the positive root
            # of L_k * alpha**2 = A_old + alpha.
            L_k = (2**(i_k - 1))*L_old
            alpha_k = float((1 + np.sqrt(1 + 4*A_old*L_k)) / (2*L_k))
            A_k = A_old + alpha_k

            # y_k = (alpha_k * u_old + A_old * x_old) / A_k, loaded into the net.
            with torch.no_grad():
                y_k = [(alpha_k*u + A_old*x) / A_k for u, x in zip(u_old, x_old)]
                for param, y_k_value in zip(net.parameters(), y_k):
                    param.copy_(y_k_value)

            net.zero_grad()
            loss_y_k = criterion(net(x_to_check), y_to_check)
            loss_y_k.backward()

            # Gradient step on u, new point x_k, and the terms of the acceptance test.
            # This runs under no_grad so u_k / x_k carry no autograd history; otherwise
            # the graph would keep growing from one accepted step to the next.
            with torch.no_grad():
                dot_prods = []
                norms = []
                u_k = []
                x_k = []
                for x, u, y_k_value, param in zip(x_old, u_old, y_k, net.parameters()):
                    u_k_value = u - alpha_k*param.grad
                    x_k_value = (alpha_k*u_k_value + A_old*x) / A_k
                    u_k.append(u_k_value)
                    x_k.append(x_k_value)

                    step = x_k_value - y_k_value
                    dot_prods.append(torch.sum(param.grad*step))
                    norms.append(torch.sum(step*step))

                    param.copy_(x_k_value)

                loss_x_k = criterion(net(x_to_check), y_to_check)

            i_k += 1
            if loss_x_k <= loss_y_k.detach() + sum(dot_prods) + (L_k/2.)*sum(norms) + epsilon/(L_k*alpha_k):
                alpha_old = alpha_k
                A_old = A_k
                u_old = u_k
                x_old = x_k
                L_old = L_k

                with torch.no_grad():
                    out_train = net(train_x_t)
                    out_test = net(test_x_t)
                    iter_losses.append(criterion(out_train, train_y_t).item())

                    pred_train = np.argmax(out_train.numpy(), axis=1)
                    iter_train_accs.append(len(np.where(pred_train == ground_train)[0]) / float(len(ground_train)))

                    pred_test = np.argmax(out_test.numpy(), axis=1)
                    iter_test_accs.append(len(np.where(pred_test == ground_test)[0]) / float(len(ground_test)))

                break
            else:
                # Rejected: roll the net back to x_old and retry with a larger L.
                with torch.no_grad():
                    for x, param in zip(x_old, net.parameters()):
                        param.copy_(x)

        if epoch_done:
            print('epoch {} done'.format(epoch))
            print(len(iter_losses))
            epoch_done = False
            break

# Baseline: Adam (replace `torch.optim.Adam` with `torch.optim.Adagrad` for the Adagrad baseline)

In [None]:
# Baseline optimizer with a fixed batch size of 128.
available_idx = list(range(len(train_x)))
epoch_done = False
n_epoch = 10
net = FullyConnectedRELU()
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(net.parameters())

iter_losses = []
iter_train_accs = []
iter_test_accs = []

# Full-dataset tensors/arrays for per-step logging, built once instead of every step.
train_x_t = torch.from_numpy(train_x)
test_x_t = torch.from_numpy(test_x)
train_y_t = torch.from_numpy(np.array(train_y))
ground_train = np.array(train_y)
ground_test = np.array(test_y)

for epoch in range(n_epoch):
    print(len(available_idx))
    print(epoch_done)

    while True:
        optimizer.zero_grad()
        # convex=False, so get_batch never reads L_k. Pass None rather than depend
        # on an L_k left over from another cell (which fails on a fresh kernel).
        bs_to_check, x_to_check, y_to_check = get_batch(dataset, train_x, train_y, 2, None, False, bs=128)
        out = net(x_to_check)
        loss = criterion(out, y_to_check)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            out_train = net(train_x_t)
            out_test = net(test_x_t)
            iter_losses.append(criterion(out_train, train_y_t).item())

            pred_train = np.argmax(out_train.numpy(), axis=1)
            iter_train_accs.append(len(np.where(pred_train == ground_train)[0]) / float(len(ground_train)))

            pred_test = np.argmax(out_test.numpy(), axis=1)
            iter_test_accs.append(len(np.where(pred_test == ground_test)[0]) / float(len(ground_test)))

        if epoch_done:
            print('epoch {} done'.format(epoch))
            print(len(iter_losses))
            epoch_done = False
            break