In [None]:
import numpy as np
import torch
import torch.nn as nn
from tedeous.model import Model

In [None]:
device = 'cuda'

In [4]:
x_grid = np.linspace(0, 1, 51)
t_grid = np.linspace(0, 1, 51)

x = torch.from_numpy(x_grid)
t = torch.from_numpy(t_grid)

grid = torch.cartesian_prod(x, t).float().to(device)

def nn_autograd_simple(model, points, order,axis=0):
    points.requires_grad=True
    f = model(points).sum()
    for i in range(order):
        grads, = torch.autograd.grad(f, points, create_graph=True)
        f = grads[:,axis].sum()
    return grads[:,axis]

func_bnd1 = lambda x: 10 ** 4 * torch.sin((1 / 10) * x * (x - 1)) ** 2
bnd1 = torch.cartesian_prod(x, torch.from_numpy(np.array([0], dtype=np.float64))).float().to(device)
bndval1 = func_bnd1(bnd1[:, 0])

# du/dx (x,0) = 1e3*sin^2(x(x-1)/10)
func_bnd2 = lambda x: 10 ** 3 * torch.sin((1 / 10) * x * (x - 1)) ** 2
bnd2 = torch.cartesian_prod(x, torch.from_numpy(np.array([0], dtype=np.float64))).float().to(device)
bop2 = {
    'du/dt':
        {
            'coeff': 1,
            'du/dt': [1],
            'pow': 1,
            'var': 0
        }
}
bndval2 = func_bnd2(bnd2[:, 0])

# u(0,t) = u(1,t)
bnd3_left = torch.cartesian_prod(torch.from_numpy(np.array([0], dtype=np.float64)), t).float().to(device)
bnd3_right = torch.cartesian_prod(torch.from_numpy(np.array([1], dtype=np.float64)), t).float().to(device)
bnd3 = [bnd3_left, bnd3_right]

# du/dt(0,t) = du/dt(1,t)
bnd4_left = torch.cartesian_prod(torch.from_numpy(np.array([0], dtype=np.float64)), t).float().to(device)
bnd4_right = torch.cartesian_prod(torch.from_numpy(np.array([1], dtype=np.float64)), t).float().to(device)
bnd4 = [bnd4_left, bnd4_right]

bop4 = {
    'du/dx':
        {
            'coeff': 1,
            'du/dx': [0],
            'pow': 1,
            'var': 0
        }
}
bcond_type = 'periodic'

bconds = [[bnd1, bndval1, 'dirichlet'],
          [bnd2, bop2, bndval2, 'operator'],
          [bnd3, bcond_type],
          [bnd4, bop4, bcond_type]]

def wave_op(model, grid):
    u_xx = nn_autograd_simple(model, grid, order=2, axis=0)
    u_tt = nn_autograd_simple(model, grid, order=2, axis=1)
    a = -(1 / 4)

    op = u_tt + a * u_xx

    return op

def op_loss(operator):
    return torch.mean(torch.square(operator))

def bcs_loss(model):
    bc1 = model(bnd1)
    bc2 = nn_autograd_simple(model, bnd2, order=1, axis=1)
    bc3 = model(bnd3_left) - model(bnd3_right)
    bc4 = nn_autograd_simple(model, bnd4_left, order=1, axis=0) - nn_autograd_simple(model, bnd4_right, order=1, axis=0)
    
    loss_bc1 = torch.mean(torch.square(bc1.reshape(-1) - bndval1))
    loss_bc2 = torch.mean(torch.square(bc2.reshape(-1) - bndval2))
    loss_bc3 = torch.mean(torch.square(bc3))
    loss_bc4 = torch.mean(torch.square(bc4))
    
    loss = loss_bc1 + loss_bc2 + loss_bc3 + loss_bc4
    return loss



In [5]:
def loss_fn(model):
    # model.load_state_dict(params)
    operator = wave_op(model, grid)
    loss = op_loss(operator) + 1000 * bcs_loss(model)
    return loss

In [6]:
model = torch.nn.Sequential(
        nn.Linear(2, 100),
        nn.Tanh(),
        nn.Linear(100, 100),
        nn.Tanh(),
        nn.Linear(100, 100),
        nn.Tanh(),
        nn.Linear(100, 1)).to(device)

In [7]:
from copy import deepcopy

In [8]:
N_samples = 10
input_size = len(grid)
d = 2

In [9]:
gradient_mode = 'central'

In [10]:
class ZO_AdaMM(torch.optim.Optimizer):
    def __init__(self, params, input_size, gradient_mode = 'central', sampler = 'uniform', N_samples = 10, dim = 2, lr=1e-03, betas=(0.9, 0.999), mu=1e-03, eps=1e-12):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: (} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: (} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0l".format(betas[1]))
        if not 0.0 <= mu < 1.0:
            raise ValueError("Invalid mu parameter: {} - should be in [0.0, 1.0l".format(mu))

        defaults = dict(lr=lr, betas=betas, mu=mu, eps=eps)
        super().__init__(params, defaults)
        self.input_size = input_size
        self.gradient_mode = gradient_mode
        self.sampler = sampler
        self.N_samples = N_samples
        self.dim = dim
        
        self.size_params = 0
        for group in self.param_groups:
            for p in group['params']:
                self.size_params += torch.numel(p)
                
                
    # def step(self, closure):
    #     for group in self.param_groups:
    #         lr = group['lr']
    #         
    #         # grad_est = self._grads(group['params'], fd_eps)
    #         # for i, param in enumerate(group['params']):
    #         #     param.data.add_(-lr * torch.sign(grad_est[i]))
    #         for i, param in enumerate(group['params']):
    #             grad_est = closure(self.size_params, group["mu"], N_samples, input_size, d)
    #             param.data.add_(-lr * torch.sign(grad_est[i]))
                
    def step(self, closure):

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            # Closure return the approximation for the gradient
            grad_est = closure(self.size_params, group["mu"], N_samples, input_size, d)
            print(grad_est)
            for p, grad in zip(group['params'], grad_est):
                state = self.state[p]

                # Lazy state initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                # Do the AdaMM updates
                state['exp_avg'].mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2))
                state['max_exp_avg_sq'] = torch.maximum(state['max_exp_avg_sq'],
                                                        state['exp_avg_sq'])

                p.data.addcdiv_(state['exp_avg'], state['exp_avg_sq'].sqrt().add_(group['eps']), value=(-group['lr']))


In [11]:
def closure(size_params, mu, N_samples, input_size, d):
    init_model_parameters = deepcopy(dict(model.state_dict()))
    model_parameters = dict(model.state_dict()).values()
    
    def parameter_perturbation(eps):
        start_idx = 0            
        for param in model_parameters:
            end_idx = start_idx + param.view(-1).size()[0]
            param.add_(eps[start_idx : end_idx].view(param.size()).float(), alpha=np.sqrt(mu))
            start_idx = end_idx
    
    def grads_multiplication(grads, u):
        start_idx = 0
        grad_est = []
        for param in model_parameters:
            end_idx = start_idx + param.view(-1).size()[0]
            grad_est.append(grads * u[start_idx:end_idx].view(param.size()))
            start_idx = end_idx
        return grad_est
    
    grads = [torch.zeros_like(param) for param in model_parameters]
    loss = loss_fn(model)
    
    for _ in range(N_samples):
        with torch.no_grad():
            if self.sampler == 'uniform':
                u = 2 * (torch.rand(size_params) - 0.5)
                u.div_(torch.norm(u, "fro"))
                u = u.to(device)
            elif self.sampler == 'normal':
                u = torch.randn(size_params)
                u.to(device)
                
        # param + mu * eps
            parameter_perturbation(u) 
        loss_add = loss_fn(model)
        
        # param - mu * eps
        with torch.no_grad():
            parameter_perturbation(-2 * u)
        loss_sub = loss_fn(model)
        
        
        with torch.no_grad():
            if gradient_mode == 'central':
                # (1/ inp_size * q) * d * [f(x+mu*eps) - f(x-mu*eps)] / 2*mu
                grad_coeff = (1 / (input_size * N_samples)) * d * (loss_add - loss_sub) / (2 * mu)
            elif gradient_mode == 'forward':
                # d * [f(x+mu*eps) - f(x)] / mu
                grad_coeff = (1 / (input_size * N_samples)) * d * (loss_add - loss) / mu
            elif gradient_mode == 'backward':
                # d * [f(x) - f(x-mu*eps)] / mu
                grad_coeff = (1 / (input_size * N_samples)) * d * (loss - loss_sub) / mu
            
            # coeff * u, i.e. constant multiplied by infinitely small perturbation.
            current_grad = grads_multiplication(grad_coeff, u)
        
            grads = [grad_past + cur_grad for grad_past, cur_grad in zip(grads, current_grad)]
            
        #load initial model parameters
        model.load_state_dict(init_model_parameters)
      
        assert loss == loss_fn(model)
    
    return grads
    
    
    

In [12]:
opt = ZO_AdaMM(model.parameters(), len(grid), lr = 1e-3, N_samples=1)

In [13]:
for i in range(1000):
    opt.zero_grad()
    opt.step(closure)
    loss = loss_fn(model)
    print(loss.item())

[tensor([[ 2.6331e-03, -3.7557e-03],
        [-1.0465e-03, -3.7178e-03],
        [-4.4855e-03,  1.5280e-03],
        [ 2.6061e-03,  3.6839e-03],
        [-2.2262e-03,  1.4313e-03],
        [ 3.6663e-03,  2.8909e-03],
        [-1.9063e-03, -2.1543e-03],
        [ 1.8190e-04,  4.6285e-03],
        [-4.1299e-03,  5.4359e-04],
        [ 8.4455e-05, -1.2584e-04],
        [-1.0551e-04,  1.2600e-03],
        [-4.0652e-04, -9.2010e-04],
        [ 3.2965e-03, -6.2381e-04],
        [ 1.0024e-03,  5.4078e-03],
        [ 1.9707e-03, -2.4203e-03],
        [ 6.0872e-04,  5.2855e-04],
        [-1.7187e-03,  1.0463e-03],
        [ 9.6192e-04,  1.9741e-03],
        [-3.2916e-03, -1.2429e-04],
        [ 1.4155e-03,  1.8269e-03],
        [-4.0269e-03, -3.0151e-03],
        [ 5.4094e-04, -1.3637e-03],
        [ 9.9771e-04, -8.5297e-04],
        [ 3.8085e-03,  1.9054e-03],
        [-1.6993e-03, -3.8599e-03],
        [ 2.0115e-03, -9.1846e-04],
        [-1.0263e-03,  1.3000e-03],
        [ 1.6488e-03,  3.98

KeyboardInterrupt: 

In [14]:
opt.size_params

751

In [38]:
u = torch.randn(opt.size_params) * np.sqrt(1e-3)


In [39]:
torch.var(u)

tensor(0.0010)

In [34]:
torch.sqrt(torch.tensor(1e-3))

tensor(0.0316)

In [37]:
np.sqrt(1e-3)

0.03162277660168379

In [7]:
opti = torch.optim.Adam(model.parameters(), lr= 1e-3)

In [13]:
opti.param_groups[0]['lr']

0.001