In [1]:
from torch import nn

class TestModel(nn.Module):
    def __init__(self, input_size=3, intermediate_size=5, output_size=2):
        super(TestModel, self).__init__()
        self.fc1 = nn.Linear(input_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, output_size)

    def forward(self, x):
        return self.fc2(self.fc1(x))

model = TestModel()

In [6]:
from torch.optim import Adam
import torch
from pytorch_optimizer import DelayedOptimizationWrapper
from delay_optimizer.delays.distributions import Stochastic, Undelayed

bias_params = [p for n, p in model.named_parameters() if "bias" in n]
non_bias_params = [p for n, p in model.named_parameters() if "bias" not in n]
optimizer = DelayedOptimizationWrapper(Adam([{"params": bias_params, "delay": Undelayed()},{"params": non_bias_params, "delay":Stochastic(3, 1000)}], lr=0.01))

In [7]:
optimizer

DelayedOptimizationWrapper (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    delay: Undelayed
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    history: [tensor([], size=(0, 5)), tensor([], size=(0, 2))]
    lr: 0.01
    max_L: 0
    maximize: False
    weight_decay: 0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    delay: Stochastic
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    history: [tensor([[[ 5.3105e-01,  6.8636e-02, -5.2133e-01],
         [ 2.3278e-01,  5.1172e-01,  3.4246e-01],
         [-6.9846e-02,  1.7337e-05,  2.4849e-01],
         [-2.2138e-01,  1.1057e-01,  4.9824e-01],
         [-5.3136e-02,  4.3500e-01, -5.1018e-01]],

        [[ 5.3105e-01,  6.8636e-02, -5.2133e-01],
         [ 2.3278e-01,  5.1172e-01,  3.4246e-01],
         [-6.9846e-02,  1.7337e-05,  2.4849e-01],
         [-2.2138e-01,  1.1057e-01,  4.9824e-01],
         [-5.31

In [26]:
optimizer.zero_grad()
optimizer.apply_delays()

x = torch.randn(1, 3)
y = model(x)
loss = torch.square(y).sum()
loss.backward()
optimizer.step()

In [27]:
optimizer.state[bias_params[0]].get("step", torch.tensor(0.0))

tensor(15.)

In [9]:
param_history = torch.randint(0, 100, (1, 3))
param = torch.tensor([1.0, 2.0, 3.0])
L =1

In [27]:
import torch
from collections import deque

shape = (4,3)
L = 2
param_state = torch.full(shape, 1.)
# param_history = deque([torch.full(shape, i+L, dtype=float) 
#                         for i in range(L)], maxlen=L)
param_history = torch.stack([torch.full(shape, i+L, dtype=float) for i in range(L)], dim=0)
D = torch.randint(0, L+1, shape)

print(param_state)
print(param_history)
print(D)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]], dtype=torch.float64)
tensor([[2, 1, 2],
        [1, 1, 0],
        [0, 0, 1],
        [2, 1, 1]])


In [25]:
# Create a mask where D > 0
mask = D > 0
print(mask)

# Directly extract only the required delayed values from param_history
initial_param_state = param_state.clone().detach()
print(initial_param_state)

# delayed_values = [param_history[d][i, j].item() for d, i, j in 
                    # zip(D[mask]-1, *mask.nonzero(as_tuple=True))]
delayed_values = param_history[D[mask]-1, *mask.nonzero(as_tuple=True)]
print(delayed_values)

# Perform the update in-place on param_state
# param_state[mask] = torch.tensor(delayed_values, dtype=param_state.dtype)
print(param_state)

# Prepend the current param_state to param_history and drop the oldest
# param_history.appendleft(initial_param_state)   # maxlen=L so oldest is dropped
print(param_history)


tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False, False, False],
        [ True,  True, False]])
tensor([[2., 2., 2.],
        [3., 3., 3.],
        [1., 1., 1.],
        [3., 3., 1.]])
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [3., 3., 3.],
        [3., 3., 3.]], dtype=torch.float64)
tensor([[2., 2., 2.],
        [3., 3., 3.],
        [1., 1., 1.],
        [3., 3., 1.]])
tensor([[[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]]], dtype=torch.float64)


In [169]:
class ParamHistoryBuffer:
    """Holds the parameter history and manages delays for a single parameter."""
    def __init__(
        self,
        param,
        buffer_size,
        device = None,
    ):
        self.buffer_size = buffer_size
        self.device = device or param.device

        self._initialize_history(param)

    def __repr__(self):
        return (f"ParamHistoryBuffer({self._buffer}, "
                f"current_idx={self._current_idx})")

    def _initialize_history(self, param):
        """Initialize the history with L copies of the current parameter."""
        self._buffer = param.repeat(self.buffer_size, *(1,)*param.ndim)
        self._current_idx = 0
        
    def _delay_to_idx(self, delay):
        if delay <= 0 or delay > self.buffer_size:
            raise IndexError(f"Delay must be in [1, {self.buffer_size}]")
        return (self._current_idx + delay - 1) % self.buffer_size

    def __getitem__(self, args):
        delay, *args = args if isinstance(args, tuple) else (args,)
        if isinstance(delay, slice):
             raise ValueError("Slicing over delays in parameter history "
                                "is not supported.")
        return self._buffer[self._delay_to_idx(delay), *args[1:]]

    def update(self, new_param):
        self._current_idx = (self._current_idx - 1) % self.buffer_size
        self._buffer[self._current_idx] = new_param

In [172]:
param_history = ParamHistoryBuffer(param_state, 4)
param_history.update(torch.full(param_state.shape, 4.0))  # delay of 3
param_history.update(torch.full(param_state.shape, 3.0))  # delay of 2
param_history.update(torch.full(param_state.shape, 2.0))  # delay of 1

In [173]:
print(param_history)

ParamHistoryBuffer(tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.],
         [3., 3., 3.]],

        [[4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.],
         [4., 4., 4.]]]), current_idx=1)


In [40]:
# Create a mask where D > 0
nonzero_mask = D > 0
print(mask)

# Directly extract only the required delayed values from param_history
initial_param_state = param_state.clone().detach()
print(initial_param_state)

# delayed_values = [param_history[d][i, j].item() for d, i, j in 
                    # zip(D[mask]-1, *mask.nonzero(as_tuple=True))]
delayed_values = param_history[D[mask], *nonzero_mask.nonzero(as_tuple=True)]
print(delayed_values)

# Perform the update in-place on param_state
param_state[mask] = torch.tensor(delayed_values, dtype=param_state.dtype)
print(param_state)

# Prepend the current param_state to param_history and drop the oldest
# param_history.appendleft(initial_param_state)   # maxlen=L so oldest is dropped
print(param_history)

tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False, False, False],
        [ True,  True, False]])
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
(tensor([2, 1, 2, 1, 1, 0, 2, 1]), tensor([0, 0, 0, 1, 1, 2, 3, 3, 3]), tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]))
((tensor([2, 1, 2, 1, 1, 0, 2, 1]), tensor([0, 0, 0, 1, 1, 2, 3, 3, 3]), tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])),)


TypeError: '<=' not supported between instances of 'tuple' and 'int'

In [8]:
mask = D > 0
indices = mask.nonzero(as_tuple=True)  # Get the indices where the mask is True

# Extract the indices for the first dimension of param_history
time_indices = (D[mask] - 1).long()

# Use gather to extract the delayed values from param_history
# param_history is assumed to have shape [L, *param.shape]
delayed_values = param_history.gather(0, time_indices.unsqueeze(1).unsqueeze(2).expand(-1, *param_state.shape[1:]))

# Assign the delayed values back to param_state
param_state[mask] = delayed_values.to(param_state.dtype)


AttributeError: 'collections.deque' object has no attribute 'gather'

In [20]:
param_state

tensor([[3., 3., 1.],
        [2., 1., 3.],
        [1., 2., 3.],
        [2., 1., 2.]])

In [129]:
def delay_old(param, param_history, D):
    full_param_state = torch.cat([param.detach().unsqueeze(0), param_history], dim=0)
    param.data.copy_(full_param_state.gather(0, D.unsqueeze(0)).squeeze(0))
    param_history.data.copy_(full_param_state[:-1])
    return param, param_history

def delay_inplace(param, param_history, D):
    full_param_state = torch.cat([param.detach().unsqueeze(0), param_history], dim=0)
    param.copy_(full_param_state.gather(0, D).squeeze(0))
    param_history.copy_(full_param_state[:-1])
    return param, param_history


In [130]:
shape = (32,64)
L = 5

In [131]:
%%timeit
param = torch.rand(shape)
param_history = deque([torch.rand(shape) for i in range(L)], maxlen=L)
param_history_tensor = torch.stack(list(param_history))

for i in range(50):
    D = torch.randint(0, L, shape)
    param, param_history_tensor = delay_old(param, param_history_tensor, D)

5.1 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [132]:
%%timeit
param = torch.rand(shape)
param_history = deque([torch.rand(shape) for i in range(L)], maxlen=L)
param_history_tensor = torch.stack(list(param_history))

for i in range(50):
    D = torch.randint(0, L, shape)
    param, param_history_tensor = delay_inplace(param, param_history_tensor, D)

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [124]:
shape = (2,3)
L = 2

param = torch.rand(shape)
param_history = deque([torch.rand(shape) for i in range(L)], maxlen=L)
param_history_tensor = torch.stack(list(param_history))
D = torch.tensor([[1, 2, 0], 
                  [0, 1, 0]])

print(param)
print(param_history_tensor)

delay_inplace(param, param_history_tensor, D)

print(param)
print(param_history_tensor)

tensor([[0.0847, 0.0033, 0.5999],
        [0.6128, 0.8855, 0.4476]])
tensor([[[0.9649, 0.7076, 0.5187],
         [0.0799, 0.9483, 0.1927]],

        [[0.6909, 0.8019, 0.1471],
         [0.9378, 0.4082, 0.7502]]])
tensor([[0.9649, 0.8019, 0.5999],
        [0.6128, 0.9483, 0.4476]])
tensor([[[0.0847, 0.0033, 0.5999],
         [0.6128, 0.8855, 0.4476]],

        [[0.9649, 0.7076, 0.5187],
         [0.0799, 0.9483, 0.1927]]])


In [7]:
import os
from torch import nn

def get_test_params(filename="test_params.pt"):
    class XOR:
        def __init__(self):
            self.fc1 = nn.Linear(2,10)
            self.fc2 = nn.Linear(10,2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    if filename in os.listdir():
        return torch.load(filename)
    else:
        
        model.

In [8]:
get_test_params()

AttributeError: 'Sequential' object has no attribute 'save'