In [2]:
from delay_optimizer.delays.utils import ParamHistoryBuffer
from delay_optimizer.delays.distributions import Stochastic
import torch

# Notebook-level fixtures: `shape` is read by every timing cell in this notebook,
# while `delay` and `buffer` are used directly by the first two timing cells.
shape = (10,20,30)  # shape of every random tensor drawn in the benchmarks
max_L = 3  # fed to both Stochastic(max_L=...) and ParamHistoryBuffer(buffer_size=...)
# NOTE(review): max_L presumably is the maximum delay length -- confirm against delay_optimizer.

delay = Stochastic(max_L=max_L)
param = torch.rand(shape)
buffer = ParamHistoryBuffer(param, buffer_size=max_L)  # constructed from `param`; a later cell calls buffer.update(...) on it

In [7]:
%%timeit
# Times 1000 consecutive ParamHistoryBuffer.update calls on the notebook-level `buffer`.
# The torch.rand draw that produces each new tensor is inside the timed loop too.
# NOTE(review): execution counts in this notebook are out of order, so which
# `buffer`/`shape` definition was live when the recorded timing was taken is not certain.
num_iters = 1000
for _ in range(num_iters):
    new_param = torch.rand(shape)
    buffer.update(new_param)

54.1 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit
# Times 1000 calls to Stochastic.sample on their own; the return value is discarded.
num_iters = 1000
for i in range(num_iters):
    # The loop index is passed as the second argument -- presumably the iteration number; confirm.
    delay.sample(shape, i)

92.6 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%%timeit
# Times ParamHistoryBuffer.update with buffer_size=100 instead of the notebook-level 3.
# The buffer construction below is part of the timed statement, not untimed setup.
num_iters = 1000
max_L = 100
temp_buffer = ParamHistoryBuffer(torch.rand(shape), buffer_size=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    temp_buffer.update(new_param)

33.8 ms ± 790 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%%timeit
# Baseline without ParamHistoryBuffer: the history is one stacked tensor that is
# rebuilt with torch.cat on every step. Building the initial stack is timed as well.
num_iters = 1000
max_L = 100
# Initial history: max_L random tensors stacked along a new leading dimension.
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
for i in range(num_iters):
    new_param = torch.rand(shape)
    # Prepend the newest tensor, giving max_L + 1 entries along dim 0 ...
    full_param_state = torch.cat([new_param.detach().unsqueeze(0), temp_buffer], dim=0)
    # ... then drop the last entry (the one held longest) so the history stays at max_L.
    temp_buffer = full_param_state[:-1]

125 ms ± 25.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%timeit
# Hand-written delay step: torch.cat history + Stochastic.sample + gather.
# Construction of the history and of the Stochastic object is inside the timed statement.
num_iters = 1000
max_L = 3
# Initial history: max_L random tensors stacked along a new leading dimension.
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    # Newest tensor goes first, so dim 0 has max_L + 1 entries.
    full_param_state = torch.cat([new_param.detach().unsqueeze(0), temp_buffer], dim=0)
    D = delay.sample(shape, i)
    # Per-element selection along dim 0 with D as the index.
    # Assumes D is an integer tensor of `shape` with values in [0, max_L] -- TODO confirm.
    delayed_param = full_param_state.gather(0, D.unsqueeze(0)).squeeze(0)
    # Drop the last entry so the history length returns to max_L.
    temp_buffer = full_param_state[:-1]

122 ms ± 4.36 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
# End-to-end path through the library: the Stochastic object is called directly with
# the new tensor, the ParamHistoryBuffer and the iteration index.
# In the saved outputs this cell is reported in seconds per loop, while the
# hand-written cat/gather cells are reported at roughly 0.1 s per loop.
num_iters = 1000
max_L = 3
temp_param = torch.rand(shape)  # tensor used to construct the buffer
temp_buffer = ParamHistoryBuffer(temp_param, buffer_size=max_L)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    # The call returns a pair that replaces both the current tensor and the buffer.
    # NOTE(review): presumably (delayed parameter, updated buffer) -- confirm against Stochastic.__call__.
    temp_param, temp_buffer = delay(new_param, temp_buffer, i)

15.4 s ± 1.02 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
# Variant of the hand-written delay step that rebuilds the history with torch.stack
# over a Python list instead of torch.cat.
num_iters = 1000
max_L = 3
# Initial history: max_L random tensors stacked along a new leading dimension.
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    # list(temp_buffer) splits the history along dim 0; the newest tensor is placed first.
    full_param_state = torch.stack([new_param.detach()] + list(temp_buffer), dim=0)
    D = delay.sample(shape, i)
    # Per-element selection along dim 0 with D as the index.
    # Assumes D is an integer tensor of `shape` with values in [0, max_L] -- TODO confirm.
    delayed_param = full_param_state.gather(0, D.unsqueeze(0)).squeeze(0)
    # Drop the last entry so the history length returns to max_L.
    temp_buffer = full_param_state[:-1]

130 ms ± 391 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [1]:
import torch
from torch import nn

class SimpleModel(nn.Module):
    """Single ``nn.Linear(5, 6)`` layer with deterministic initial parameters.

    Every weight starts at 1.0 and every bias at 0.0, so each of the 6 outputs
    initially equals the sum of the 5 input features.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 6)
        # Use the nn.init helpers rather than writing through the legacy `.data`
        # attribute; the resulting parameter values are the same.
        nn.init.ones_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Apply the linear layer to ``x``, whose last dimension must be 5."""
        return self.linear(x)
    

In [2]:
from delay_optimizer import DelayedOptimizer
from torch.optim import Adam

model = SimpleModel()
# Wrap the Adam class in DelayedOptimizer over the model's parameters.
# NOTE(review): lr is presumably forwarded to Adam, and delay=1 presumably sets the
# delay length -- confirm both against DelayedOptimizer's signature.
optimizer = DelayedOptimizer(model.parameters(), Adam, lr=0.01, delay=1)

In [3]:
# Called once, before the first forward pass.
# NOTE(review): presumably swaps delayed parameter values into the model -- confirm
# against DelayedOptimizer.apply_delays.
optimizer.apply_delays()

In [4]:
# One forward/backward/step on a random batch of 2 samples with 5 features.
# The input is drawn without a fixed seed, so the optimizer state shown afterwards
# differs from run to run.
out = model(torch.rand(2,5))
loss = torch.abs(out).sum()  # scalar loss: sum of absolute outputs
loss.backward()
# No zero_grad() precedes this; this is the first backward pass visible in the notebook.
optimizer.step()

In [5]:
# Display the per-parameter optimizer state after the single step; the saved output
# shows 'step', 'exp_avg' and 'exp_avg_sq' entries keyed by Parameter.
optimizer.state

defaultdict(dict,
            {Parameter containing:
             tensor([[0.9900, 0.9900, 0.9900, 0.9900, 0.9900],
                     [0.9900, 0.9900, 0.9900, 0.9900, 0.9900],
                     [0.9900, 0.9900, 0.9900, 0.9900, 0.9900],
                     [0.9900, 0.9900, 0.9900, 0.9900, 0.9900],
                     [0.9900, 0.9900, 0.9900, 0.9900, 0.9900],
                     [0.9900, 0.9900, 0.9900, 0.9900, 0.9900]], requires_grad=True): {'step': tensor(1.),
              'exp_avg': tensor([[0.1078, 0.1204, 0.1545, 0.0990, 0.1386],
                      [0.1078, 0.1204, 0.1545, 0.0990, 0.1386],
                      [0.1078, 0.1204, 0.1545, 0.0990, 0.1386],
                      [0.1078, 0.1204, 0.1545, 0.0990, 0.1386],
                      [0.1078, 0.1204, 0.1545, 0.0990, 0.1386],
                      [0.1078, 0.1204, 0.1545, 0.0990, 0.1386]]),
              'exp_avg_sq': tensor([[0.0012, 0.0014, 0.0024, 0.0010, 0.0019],
                      [0.0012, 0.0014, 0.0024, 0.00