# Example of using STORM optimizer

## Load the libraries as usual and import the STORM optimizer

In [27]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from storm import STORM

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)

<torch._C.Generator at 0x107f56230>

## Define the model and the objective function as usual, then initialize the STORM optimizer

In [28]:
# Fully connected classifier: flattened 28 * 28 inputs -> 256 hidden units
# (ReLU) -> 10 output logits.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=28 * 28, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10),
)

# Objective function: cross-entropy between the logits and the class labels.
criterion = nn.CrossEntropyLoss()

# STORM is constructed like any other optimizer, from the model parameters
# plus its own hyper-parameters (k, w, c); see the `storm` module for their
# exact meaning.
optimizer = STORM(net.parameters(), k=1, w=1, c=30, foreach=True)

# MNIST training split, downloaded into `datapath` if not already present,
# with every sample converted to a tensor.
datapath = "datasets"
dataset = datasets.MNIST(
    root=datapath,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Reshuffled mini-batches of 100 samples.
train_loader = DataLoader(dataset, batch_size=100, shuffle=True)

## Run the STORM optimizer

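- Initialize `cur_batch` to `None`.
- Loop over the batches; each batch drawn from the loader plays the role of the *future* batch.
    1. If `cur_batch` is `None` (the very first batch), store the batch in `cur_batch` and skip to the next one.
    2. If `cur_batch` is not `None`, update the model with the current batch as described in steps 3 and 4.
    3. Compute the gradients of the future batch on the current model and store them with `optimizer.store_next_grad()`.
    4. Compute the gradients of the current batch and update the model with `optimizer.step()`.
    5. Set `cur_batch` to the future batch.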

In [29]:
# No batch has been consumed yet; the first batch drawn only primes the loop.
cur_batch = None

for epoch in range(5):
    for step, fur_batch in enumerate(train_loader):
        # Very first batch overall: there is no current batch to train on yet,
        # so keep this one as the current batch and fetch the next.
        if cur_batch is None:
            cur_batch = fur_batch
            continue

        # Future-gradient pass: gradients of the *future* batch evaluated on
        # the model as it is before this step's update.
        next_inputs, next_targets = fur_batch
        next_loss = criterion(net(next_inputs), next_targets)
        next_loss.backward()
        # Hand the freshly computed gradients to the optimizer, then clear
        # them so they do not accumulate into the main backward pass below.
        optimizer.store_next_grad()
        optimizer.zero_grad()

        # Main pass: gradients of the *current* batch, followed by the
        # actual parameter update.
        inputs, targets = cur_batch
        logits = net(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accuracy on the batch that was just trained on.
        predictions = logits.argmax(dim=1)
        accuracy = (predictions == targets).float().mean()

        # The future batch becomes the current batch of the next iteration
        # (this carries over across epoch boundaries as well).
        cur_batch = fur_batch

        # Report progress only every 100 steps.
        if step % 100 != 0:
            continue

        group = optimizer.param_groups[0]
        print(
            f"Epoch {epoch}, step {step:3d}, Loss: {loss.item():.3f}, Acc: {accuracy.item():.3f}, "
            f"lr: {group['lr']:.3f}, momentum: {group['momentum']:.3f}"
        )

Epoch 0, step 100, Loss: 0.297, Acc: 0.940, lr: 0.193, momentum: 0.000
Epoch 0, step 200, Loss: 0.421, Acc: 0.920, lr: 0.160, momentum: 0.233
Epoch 0, step 300, Loss: 0.282, Acc: 0.940, lr: 0.143, momentum: 0.390
Epoch 0, step 400, Loss: 0.258, Acc: 0.900, lr: 0.131, momentum: 0.481
Epoch 0, step 500, Loss: 0.308, Acc: 0.910, lr: 0.123, momentum: 0.543
Epoch 1, step   0, Loss: 0.273, Acc: 0.910, lr: 0.117, momentum: 0.589
Epoch 1, step 100, Loss: 0.168, Acc: 0.950, lr: 0.112, momentum: 0.624
Epoch 1, step 200, Loss: 0.126, Acc: 0.980, lr: 0.108, momentum: 0.652
Epoch 1, step 300, Loss: 0.152, Acc: 0.960, lr: 0.104, momentum: 0.674
Epoch 1, step 400, Loss: 0.282, Acc: 0.940, lr: 0.101, momentum: 0.693
Epoch 1, step 500, Loss: 0.367, Acc: 0.900, lr: 0.098, momentum: 0.710
Epoch 2, step   0, Loss: 0.124, Acc: 0.950, lr: 0.096, momentum: 0.724
Epoch 2, step 100, Loss: 0.126, Acc: 0.950, lr: 0.094, momentum: 0.736
Epoch 2, step 200, Loss: 0.133, Acc: 0.950, lr: 0.092, momentum: 0.747
Epoch 