In [1]:
import sys
sys.path.insert(0, "../..")

In [2]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD

In [3]:
torch.cuda.is_available()

True

In [4]:
val_loss_fn = nn.NLLLoss(reduction="sum")

def validation(model, loader):
    ok = 0
    loss_sum = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            out = model(x)
            loss_sum += val_loss_fn(out, y)
            preds = out.argmax(1)
            ok += (y == preds).sum()
            total += len(y)
    return ok / total, loss_sum / total

def train_epoch():
    model.train()
    param_hist = []
    for i, (x, y) in enumerate(tqdm(train_loader)):
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
    return torch.tensor(np.mean(param_hist, 0))

In [5]:
train_ds = datasets.MNIST("../../../MNIST", download=True, train=True, transform=transforms.ToTensor())
test_ds = datasets.MNIST("../../../MNIST", download=True, train=False, transform=transforms.ToTensor())
valid_size = int(0.2 * len(train_ds))
train_ds, valid_ds = data.random_split(train_ds, [len(train_ds) - valid_size, valid_size])

train_loader = data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)
valid_loader = data.DataLoader(valid_ds, batch_size=64, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2)

In [6]:
model = nn.Sequential(
    nn.Conv2d(1, 32, 3),
    nn.ReLU(),
    nn.Conv2d(32, 32, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(4*4*64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(-1),
)
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [7]:
initial_state = deepcopy(model.state_dict())

In [8]:
optimizer = AcceleratedSGD(model.parameters(), 1e-3, k=10, momentum=0.5, weight_decay=1e-5, mode="epoch")
loss_fn = nn.NLLLoss()

## Epoch

In [9]:
epochs = 30

for epoch in range(epochs):
    print("Epoch", epoch+1)
    train_epoch()
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.0999, validation loss: 2.3000
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.2002, validation loss: 2.2963
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1841, validation loss: 2.2886
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.3984, validation loss: 2.2606
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.7039, validation loss: 1.2695
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8671, validation loss: 0.4441
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8942, validation loss: 0.3493
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9143, validation loss: 0.2826
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9250, validation loss: 0.2467
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9394, validation loss: 0.2051
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9442, validation loss: 0.1832
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9538, validation loss: 0.1552
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9580, validation loss: 0.1439
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9597, validation loss: 0.1348
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9630, validation loss: 0.1288
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9632, validation loss: 0.1191
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9672, validation loss: 0.1139
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9701, validation loss: 0.1031
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9731, validation loss: 0.0940
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9692, validation loss: 0.1074
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9734, validation loss: 0.0918
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9754, validation loss: 0.0874
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9751, validation loss: 0.0881
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9753, validation loss: 0.0840
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9765, validation loss: 0.0792
Epoch 26


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9770, validation loss: 0.0812
Epoch 27


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9776, validation loss: 0.0792
Epoch 28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9784, validation loss: 0.0756
Epoch 29


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9761, validation loss: 0.0805
Epoch 30


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9757, validation loss: 0.0832


In [10]:
print("Train:", validation(model, train_loader))
print("Valid:", validation(model, valid_loader))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Train: (tensor(0.9815, device='cuda:0'), tensor(0.0598, device='cuda:0'))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Valid: (tensor(0.9757, device='cuda:0'), tensor(0.0832, device='cuda:0'))


In [11]:
optimizer.accelerate()

In [12]:
optimizer.store_parameters()

In [13]:
print("Train:", validation(model, train_loader))
print("Valid:", validation(model, valid_loader))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Train: (tensor(0.9856, device='cuda:0'), tensor(0.0491, device='cuda:0'))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Valid: (tensor(0.9794, device='cuda:0'), tensor(0.0719, device='cuda:0'))


## Epoch average

In [14]:
model.load_state_dict(initial_state)
optimizer = AcceleratedSGD(model.parameters(), 1e-3, k=10, momentum=0.5, weight_decay=1e-5, mode="epoch_avg")

In [15]:
epochs = 30

for epoch in range(epochs):
    print("Epoch", epoch+1)
    train_epoch()
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.0999, validation loss: 2.3000
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.2008, validation loss: 2.2963
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1816, validation loss: 2.2887
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.4075, validation loss: 2.2608
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.7058, validation loss: 1.2917
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8565, validation loss: 0.4638
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8925, validation loss: 0.3501
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9116, validation loss: 0.2893
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9296, validation loss: 0.2350
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9403, validation loss: 0.2015
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9441, validation loss: 0.1816
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9536, validation loss: 0.1590
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9606, validation loss: 0.1406
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9608, validation loss: 0.1338
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9613, validation loss: 0.1279
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9646, validation loss: 0.1176
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9691, validation loss: 0.1082
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9697, validation loss: 0.1014
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9719, validation loss: 0.0957
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9703, validation loss: 0.1032
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9703, validation loss: 0.1029
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9737, validation loss: 0.0896
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9729, validation loss: 0.0933
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9742, validation loss: 0.0855
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9762, validation loss: 0.0808
Epoch 26


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9758, validation loss: 0.0816
Epoch 27


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9770, validation loss: 0.0779
Epoch 28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9768, validation loss: 0.0769
Epoch 29


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9778, validation loss: 0.0769
Epoch 30


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9778, validation loss: 0.0772


In [16]:
print("Train:", validation(model, train_loader))
print("Valid:", validation(model, valid_loader))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Train: (tensor(0.9829, device='cuda:0'), tensor(0.0580, device='cuda:0'))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Valid: (tensor(0.9778, device='cuda:0'), tensor(0.0772, device='cuda:0'))


In [17]:
optimizer.accelerate()

In [18]:
optimizer.store_parameters()

In [19]:
print("Train:", validation(model, train_loader))
print("Valid:", validation(model, valid_loader))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Train: (tensor(0.9866, device='cuda:0'), tensor(0.0464, device='cuda:0'))


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Valid: (tensor(0.9803, device='cuda:0'), tensor(0.0682, device='cuda:0'))
