In [3]:
import torch
from torch.optim.optimizer import Optimizer

In [4]:
def closure(size_params,mu):
    grad_est = []
    
    u = torch.normal(mean = torch.zeros(size_params),std = 1)
    u = torch.div(u,torch.norm(u,"fro"))
    
    # initial evaluation
    output = model(input)
    loss_init = criterion.forward(output)
    
    # save the state of the model 
    model_init = dict(model.state_dict())
    
    start_ind = 0
    for param_tensor in model.state_dict():
        end_ind = start + model.state_dict()[param_tensor].view(-1).size()[0]
        model.state_dict()[param_tensor].add_(u[start_ind:end_ind].view(model.state_dict()[param_tensor].size()), value = mu)
    
    # random evaluation
    output2 = model(input)
    loss_random = criterion.forward(output2)
    
    
    # load initial state
    model.load_state_dict(model_init)
    
    # compute the gradient
    
    grad_norm = size_params*(loss_random-loss_init)/mu
    grad_est = []
    
    start_ind = 0
    for param_tensor in model_init:
        end_ind = start + model_init[param_tensor].view(-1).size()[0]
        grad_est.apppend(grad_norm*u[start_ind:end_ind].view(model_init[param_tensor].size()))
    
    return grad_est
     

In [5]:
class ZO_AdaMM(Optimizer):
    
    def __init__(self,params,lr = 1e-03,betas = (0.9,0.999), mu = 1e-05, eps = 1e-12):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: (} - should be >= 0.0". format (lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError ("Invalid beta parameter: (} - should be in [0.0, 1.0[". format (betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0l". format (betas [1]))
        if not 0.0 <= mu < 1.0:
            raise ValueError("Invalid mu parameter: {} - should be in [0.0, 1.0l". format (mu))
            
        defaults = dict(lr=lr, betas=betas, mu=mu, eps = eps)
        super(ZO_AdaMM,self).__init__(params,defaults)
        
    def step(self, closure):
        
         for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']
            
            size_params = 0
            
            for p in group['params']:
                size_params += p.view(-1).size()[0]
            
            # closure return the approximation for the gradient, we have to add some "option" to this function 
            grad_est = closure(size_params,group["mu"])
            
            i = 0
            for p in group['params']:    
                #grads.append(grad_est[i])
                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    #if group['amsgrad']:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                #if group['amsgrad']:
                max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])
                    
                
                beta1, beta2 = group['betas']
                state['exp_avg'].mul_(beta1).add_(grad_est[i],alpha = (1.0 - beta1))
                state['exp_avg_sq'].mul_(beta2).addcmul_(grad_est[i], grad_est[i],value = (1.0 - beta2))
                state['max_exp_avg_sq'] = torch.max(state['max_exp_avg_sq'],state['exp_avg_sq'])# vérifier max ou maximum
                
                p.data.addcdiv_(state['exp_avg'], state['exp_avg_sq'].sqrt().add_(group['eps']),value = (-group['lr']))
                i +=1
           
        

In [6]:
import torch.nn.functional as F
import torchvision


import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn

import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (14,12)

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [67]:
class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(3, 9, 3)
        #self.fc1 = nn.Linear(9 * 5 * 5, 15)
        #self.fc2 = nn.Linear(15, 13)
        self.fc3 = nn.Linear(9 * 5 * 5, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        #x = F.relu(self.fc1(x))
        #x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [42]:
transform = transforms.Compose([transforms.ToTensor()])

mnist_dataset_train = torchvision.datasets.MNIST('data/mnist/', download=True, train=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_dataset_train, batch_size=1)

mnist_dataset_test = torchvision.datasets.MNIST('data/mnist/', download=True, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_dataset_train, batch_size=1)

criterion = nn.CrossEntropyLoss()

nb_epochs = 4


In [76]:
def train(model, optimizer, criterion, nb_epochs, train_loader, test_loader):
    # Heavily inspired from PyTorch tutorial
    train_losses = []
    test_accuracies = []

    running_loss = 0

    for e in range(nb_epochs):

        for i, data in enumerate(train_loader):
            inputs, labels = data
            
            #optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            
                
            def closure(size_params,mu):
                # initial evaluation
                output = outputs
                loss_init = loss 
                # save the state of the model 
                model_init = dict(model.state_dict())
                #u = torch.normal(mean = torch.zeros(size_params),std = 1)
                u = 2*(torch.rand(size_params)-0.5)
                u.div_(torch.norm(u,"fro"))

                start_ind = 0
                for param_tensor in model.state_dict():
                    end_ind = start_ind + model.state_dict()[param_tensor].view(-1).size()[0]
                    model.state_dict()[param_tensor].add_(u[start_ind:end_ind].view(model.state_dict()[param_tensor].size()), alpha = mu)
                    start_ind = end_ind
                print(type(model.state_dict()))
                # random evaluation
                output2 = model(inputs)
                loss_random = criterion(output2,labels)


                # load initial state
                model.load_state_dict(model_init)
                # compute the gradient
                
                # when u is uniform random variable
                grad_norm = size_params*(loss_random-loss_init)/mu
                # when u is Gaussian random variable
                #grad_norm = (loss_random-loss_init)/mu
                grad_est = []

                start_ind = 0
                for param_tensor in model_init:
                    end_ind = start_ind + model_init[param_tensor].view(-1).size()[0]
                    grad_est.append(grad_norm*u[start_ind:end_ind].view(model_init[param_tensor].size()))
                    start_ind = end_ind
                return grad_est
     
            
            optimizer.step(closure)

            if i % 2000 == 1999:
                train_losses.append(running_loss / 2000)
                print(f'epoch : {e + 1}/{nb_epochs} | train loss : {train_losses[-1]:.4f}')
                running_loss = 0.0

        with torch.no_grad():
            correct_preds = 0
            total_preds = 0

            for inputs, labels in test_loader:
                outputs = model(inputs)

                predictions = torch.argmax(outputs, 1)
                total_preds += labels.size(0)
                correct_preds += (predictions == labels).sum().item()

            test_accuracies.append(correct_preds / total_preds)


    return train_losses, test_accuracies

In [81]:
for i in range(5):
    for param_tensor in model.state_dict().values(): 
        print(param_tensor.size())

torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([9, 3, 3, 3])
torch.Size([9])
torch.Size([10, 225])
torch.Size([10])
torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([9, 3, 3, 3])
torch.Size([9])
torch.Size([10, 225])
torch.Size([10])
torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([9, 3, 3, 3])
torch.Size([9])
torch.Size([10, 225])
torch.Size([10])
torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([9, 3, 3, 3])
torch.Size([9])
torch.Size([10, 225])
torch.Size([10])
torch.Size([3, 1, 5, 5])
torch.Size([3])
torch.Size([9, 3, 3, 3])
torch.Size([9])
torch.Size([10, 225])
torch.Size([10])


In [78]:
model = SmallModel()
with torch.no_grad():
    optimizer_pt = ZO_AdaMM(model.parameters(), lr=1e-3, betas=(0.3, 0.9),mu =1e-03, eps=1e-10)

    train_losses_pt, test_acc_pt = train(model, optimizer_pt, criterion, nb_epochs, train_loader, test_loader)

<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collec

<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collections.OrderedDict'>
<class 'collec

KeyboardInterrupt: 

In [73]:
def train(model, optimizer, criterion, nb_epochs, train_loader, test_loader):
    # Heavily inspired from PyTorch tutorial
    train_losses = []
    test_accuracies = []

    running_loss = 0

    for e in range(nb_epochs):

        for i, data in enumerate(train_loader):
            inputs, labels = data

            optimizer.zero_grad()
            outputs = model(inputs)
            #print(outputs)
            loss = criterion(outputs, labels)
            #print(loss)
            loss.backward()

            optimizer.step()
            #print(loss)
            #print('-----------------')

            running_loss += loss.item()

            if i % 2000 == 1999:
                train_losses.append(running_loss / 1000)
                print(f'epoch : {e + 1}/{nb_epochs} | train loss : {train_losses[-1]:.4f}')
                running_loss = 0.0

        with torch.no_grad():
            correct_preds = 0
            total_preds = 0
            for inputs, labels in test_loader:
                outputs = model(inputs)

                predictions = torch.argmax(outputs, 1)
                total_preds += labels.size(0)
                correct_preds += (predictions == labels).sum().item()

            test_accuracies.append(correct_preds / total_preds)

    return train_losses, test_accuracies

In [74]:
small_model = SmallModel()

optimizer_pt = optim.Adam(small_model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, amsgrad=True)

train_losses_pt, test_acc_pt = train(small_model, optimizer_pt, criterion, nb_epochs, train_loader, test_loader)

epoch : 1/4 | train loss : 1.4947
epoch : 1/4 | train loss : 0.5242
epoch : 1/4 | train loss : 0.4342
epoch : 1/4 | train loss : 0.3548
epoch : 1/4 | train loss : 0.3857
epoch : 1/4 | train loss : 0.2623
epoch : 1/4 | train loss : 0.3025
epoch : 1/4 | train loss : 0.2995
epoch : 1/4 | train loss : 0.2553
epoch : 1/4 | train loss : 0.2043
epoch : 1/4 | train loss : 0.2413
epoch : 1/4 | train loss : 0.2041
epoch : 1/4 | train loss : 0.1924
epoch : 1/4 | train loss : 0.2067
epoch : 1/4 | train loss : 0.1955
epoch : 1/4 | train loss : 0.2156
epoch : 1/4 | train loss : 0.2010
epoch : 1/4 | train loss : 0.1986
epoch : 1/4 | train loss : 0.1985
epoch : 1/4 | train loss : 0.1803
epoch : 1/4 | train loss : 0.2019
epoch : 1/4 | train loss : 0.2035
epoch : 1/4 | train loss : 0.1844
epoch : 1/4 | train loss : 0.2311
epoch : 1/4 | train loss : 0.1943
epoch : 1/4 | train loss : 0.1611
epoch : 1/4 | train loss : 0.2041
epoch : 1/4 | train loss : 0.1393
epoch : 1/4 | train loss : 0.1210
epoch : 1/4 | 

epoch : 3/4 | train loss : 0.1726
epoch : 3/4 | train loss : 0.1304
epoch : 3/4 | train loss : 0.1199
epoch : 3/4 | train loss : 0.1297
epoch : 3/4 | train loss : 0.1362
epoch : 3/4 | train loss : 0.1054
epoch : 3/4 | train loss : 0.1185
epoch : 3/4 | train loss : 0.1301
epoch : 3/4 | train loss : 0.1107
epoch : 3/4 | train loss : 0.0904
epoch : 3/4 | train loss : 0.1172
epoch : 3/4 | train loss : 0.1214
epoch : 3/4 | train loss : 0.1180
epoch : 3/4 | train loss : 0.1560
epoch : 3/4 | train loss : 0.0894
epoch : 3/4 | train loss : 0.1153
epoch : 3/4 | train loss : 0.1170
epoch : 3/4 | train loss : 0.1265
epoch : 3/4 | train loss : 0.1283
epoch : 3/4 | train loss : 0.1150
epoch : 3/4 | train loss : 0.1178
epoch : 3/4 | train loss : 0.1413
epoch : 3/4 | train loss : 0.1132
epoch : 3/4 | train loss : 0.1359
epoch : 3/4 | train loss : 0.1341
epoch : 3/4 | train loss : 0.1023
epoch : 3/4 | train loss : 0.1182
epoch : 3/4 | train loss : 0.0929
epoch : 3/4 | train loss : 0.0802
epoch : 3/4 | 