In [2]:
from delay_optimizer.delays.utils import ParamHistoryBuffer
from delay_optimizer.delays.distributions import Stochastic
import torch

shape = (10,20,30)
max_L = 3

delay = Stochastic(max_L=max_L)
param = torch.rand(shape)
buffer = ParamHistoryBuffer(param, buffer_size=max_L)

In [7]:
%%timeit
num_iters = 1000
for _ in range(num_iters):
    new_param = torch.rand(shape)
    buffer.update(new_param)

54.1 ms ± 2.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit
num_iters = 1000
for i in range(num_iters):
    delay.sample(shape, i)

92.6 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%%timeit
num_iters = 1000
max_L = 100
temp_buffer = ParamHistoryBuffer(torch.rand(shape), buffer_size=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    temp_buffer.update(new_param)

33.8 ms ± 790 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%%timeit
num_iters = 1000
max_L = 100
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
for i in range(num_iters):
    new_param = torch.rand(shape)
    full_param_state = torch.cat([new_param.detach().unsqueeze(0), temp_buffer], dim=0)
    temp_buffer = full_param_state[:-1]

125 ms ± 25.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%%timeit
num_iters = 1000
max_L = 3
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    full_param_state = torch.cat([new_param.detach().unsqueeze(0), temp_buffer], dim=0)
    D = delay.sample(shape, i)
    delayed_param = full_param_state.gather(0, D.unsqueeze(0)).squeeze(0)
    temp_buffer = full_param_state[:-1]

122 ms ± 4.36 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
num_iters = 1000
max_L = 3
temp_param = torch.rand(shape)  
temp_buffer = ParamHistoryBuffer(temp_param, buffer_size=max_L)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    temp_param, temp_buffer = delay(new_param, temp_buffer, i)

15.4 s ± 1.02 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
num_iters = 1000
max_L = 3
temp_buffer = torch.stack([torch.rand(shape).clone().detach() for _ in range(max_L)], dim=0)
delay = Stochastic(max_L=max_L)
for i in range(num_iters):
    new_param = torch.rand(shape)
    full_param_state = torch.stack([new_param.detach()] + list(temp_buffer), dim=0)
    D = delay.sample(shape, i)
    delayed_param = full_param_state.gather(0, D.unsqueeze(0)).squeeze(0)
    temp_buffer = full_param_state[:-1]

130 ms ± 391 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [40]:
import torch
from torch import nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(5, 6)
        self.linear.weight.data.fill_(1.0)
        self.linear.bias.data.fill_(0.0)
    
    def forward(self, x):
        print(self.linear.weight)
        print(self.linear.bias)
        return self.linear(x)
    

In [41]:
class ParamHistoryBuffer:
    """Holds the parameter history and manages delays for a single parameter."""
    def __init__(self, param, buffer_size):
        self.buffer_size = buffer_size      # Should be max_L+1
        self._buffer = param.repeat(self.buffer_size, *(1,)*param.ndim).detach()
        self._current_idx = 0

    @property
    def parameter(self):
        return self._buffer[self._current_idx]

    def __getattr__(self, name):
        return getattr(self._buffer, name)

    def __repr__(self):
        return ("ParamHistoryBuffer(\n"
                f"{self._buffer}, current_idx={self._current_idx})")

    def _delay_to_idx(self, delay):
        if any(delay < 0) or any(delay > (self.buffer_size-1)):
            raise IndexError(f"Delay must be in [0, {self.buffer_size-1}]")
        return (self._current_idx + delay) % self.buffer_size

    def __getitem__(self, args):
        delay, *args = args if isinstance(args, tuple) else (args,)
        if isinstance(delay, slice):
             raise ValueError("Slicing is not supported over the delay dimension.")
        return self._buffer[self._delay_to_idx(delay), *args]

    def update(self, new_param):
        self._current_idx = (self._current_idx - 1) % self.buffer_size
        self._buffer[self._current_idx] = new_param
        
        

In [42]:
model = SimpleModel()
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)
Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)


In [None]:
b

for i in range(3):
    random_input = torch.randn(1,5)
    random_out = model(random_input)
    for buffer in buffers:
        buffer.update(torch.randint(0, 4, buffer.parameter.shape))

        print(buffer)

    

# # leaf variable with requires_grad=True can not used inplace operation
# for name, params in model.named_parameters():
#     params.data.copy_(old_params[name])

NameError: name 'old_params' is not defined