In [1]:
import sys
sys.path.insert(0, "../..")

In [2]:
import torch
from torch import nn
import numpy as np
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD
from nn_utils import *

In [3]:
torch.cuda.is_available()

True

In [4]:
trainer = Trainer(
    device="cuda",
    loss_fn=nn.NLLLoss(reduction="mean"),
    val_loss_fn=nn.NLLLoss(reduction="sum"),
)

In [5]:
dl = load_dataset(
    dataset="mnist", 
    root="../../../MNIST", 
    download=False, 
    validation_split=0.2,
    batch_size=64, 
    num_workers=12,
)

In [6]:
model = nn.Sequential(
    nn.Conv2d(1, 32, 3),
    nn.ReLU(),
    nn.Conv2d(32, 32, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(4*4*64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(-1),
)
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [7]:
initial_state = deepcopy(model.state_dict())

In [8]:
optimizer = AcceleratedSGD(model.parameters(), 1e-3, k=10, mode="epoch", method="RNA")
logger = Logger("SGD.txt")

## Epoch

In [9]:
torch.manual_seed(2020)
epochs = 30

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    optimizer.finish_epoch()
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

100%|██████████| 750/750 [00:07<00:00, 100.08it/s, loss=2.3001]
Epoch 1 | Training loss: 2.3001, validation accuracy: 0.0973, validation loss: 2.2968
100%|██████████| 750/750 [00:07<00:00, 99.85it/s, loss=2.2890] 
Epoch 2 | Training loss: 2.2890, validation accuracy: 0.2014, validation loss: 2.2806
100%|██████████| 750/750 [00:07<00:00, 103.60it/s, loss=2.2543]
Epoch 3 | Training loss: 2.2543, validation accuracy: 0.4482, validation loss: 2.2046
100%|██████████| 750/750 [00:07<00:00, 102.61it/s, loss=1.7666]
Epoch 4 | Training loss: 1.7666, validation accuracy: 0.7831, validation loss: 0.8356
100%|██████████| 750/750 [00:07<00:00, 103.56it/s, loss=0.5810]
Epoch 5 | Training loss: 0.5810, validation accuracy: 0.8716, validation loss: 0.4288
100%|██████████| 750/750 [00:07<00:00, 102.40it/s, loss=0.4092]
Epoch 6 | Training loss: 0.4092, validation accuracy: 0.9009, validation loss: 0.3326
100%|██████████| 750/750 [00:07<00:00, 104.90it/s, loss=0.3325]
Epoch 7 | Training loss: 0.3325, val

In [10]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.98025, 0.06479943061837305)
Valid: (0.9749166666666667, 0.0810826615827779)


In [11]:
optimizer.accelerate()

In [12]:
optimizer.store_parameters()
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [13]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9812083333333333, 0.06093290462344885)
Valid: (0.97575, 0.07702405818800132)


## Epoch average

In [14]:
optimizer = AcceleratedSGD(model.parameters(), 1e-3, k=10, mode="epoch_avg", method="RNA")
logger = Logger("SGD-avg.txt")
model.load_state_dict(initial_state)

<All keys matched successfully>

In [15]:
torch.manual_seed(2020)
epochs = 30

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    optimizer.finish_epoch()
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

100%|██████████| 750/750 [00:07<00:00, 98.64it/s, loss=2.3001] 
Epoch 1 | Training loss: 2.3001, validation accuracy: 0.0973, validation loss: 2.2968
100%|██████████| 750/750 [00:07<00:00, 94.93it/s, loss=2.2890] 
Epoch 2 | Training loss: 2.2890, validation accuracy: 0.2014, validation loss: 2.2806
100%|██████████| 750/750 [00:07<00:00, 98.46it/s, loss=2.2543] 
Epoch 3 | Training loss: 2.2543, validation accuracy: 0.4482, validation loss: 2.2046
100%|██████████| 750/750 [00:07<00:00, 100.03it/s, loss=1.7666]
Epoch 4 | Training loss: 1.7666, validation accuracy: 0.7831, validation loss: 0.8357
100%|██████████| 750/750 [00:07<00:00, 99.74it/s, loss=0.5811] 
Epoch 5 | Training loss: 0.5811, validation accuracy: 0.8718, validation loss: 0.4288
100%|██████████| 750/750 [00:07<00:00, 96.74it/s, loss=0.4091] 
Epoch 6 | Training loss: 0.4091, validation accuracy: 0.9010, validation loss: 0.3324
100%|██████████| 750/750 [00:07<00:00, 99.45it/s, loss=0.3325] 
Epoch 7 | Training loss: 0.3325, val

In [16]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9803958333333334, 0.06431959083303809)
Valid: (0.97525, 0.08072126029680173)


In [17]:
optimizer.accelerate()

In [18]:
optimizer.store_parameters()
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [19]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9751041666666667, 0.08772928147763014)
Valid: (0.9718333333333333, 0.09713368124514818)


## Split + epoch

In [32]:
logger = Logger("SGD-split.txt")
model.load_state_dict(initial_state)
groups = [{"params": [param]} for param in model.parameters()]
optimizer = AcceleratedSGD(groups, 1e-3, k=10, mode="epoch")

In [33]:
torch.manual_seed(2020)
epochs = 30

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    optimizer.finish_epoch()
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

100%|██████████| 750/750 [00:07<00:00, 99.17it/s, loss=2.3001] 
Epoch 1 | Training loss: 2.3001, validation accuracy: 0.0973, validation loss: 2.2968
100%|██████████| 750/750 [00:07<00:00, 100.42it/s, loss=2.2890]
Epoch 2 | Training loss: 2.2890, validation accuracy: 0.2015, validation loss: 2.2806
100%|██████████| 750/750 [00:07<00:00, 103.58it/s, loss=2.2543]
Epoch 3 | Training loss: 2.2543, validation accuracy: 0.4482, validation loss: 2.2046
100%|██████████| 750/750 [00:07<00:00, 99.06it/s, loss=1.7664] 
Epoch 4 | Training loss: 1.7664, validation accuracy: 0.7831, validation loss: 0.8354
100%|██████████| 750/750 [00:07<00:00, 101.39it/s, loss=0.5811]
Epoch 5 | Training loss: 0.5811, validation accuracy: 0.8718, validation loss: 0.4288
100%|██████████| 750/750 [00:07<00:00, 99.29it/s, loss=0.4092] 
Epoch 6 | Training loss: 0.4092, validation accuracy: 0.9010, validation loss: 0.3324
100%|██████████| 750/750 [00:07<00:00, 100.14it/s, loss=0.3326]
Epoch 7 | Training loss: 0.3326, val

In [34]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9800625, 0.06498881517040232)
Valid: (0.975, 0.08112968947614232)


In [35]:
optimizer.accelerate()

In [36]:
optimizer.store_parameters([[param] for param in model.parameters()])
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [37]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9785, 0.06886476942043131)
Valid: (0.9743333333333334, 0.08699206533686568)


## Split + epoch average

In [26]:
logger = Logger("SGD-split-avg.txt")
model.load_state_dict(initial_state)
groups = [{"params": [param]} for param in model.parameters()]
optimizer = AcceleratedSGD(groups, 1e-3, k=10, mode="epoch_avg")

In [27]:
torch.manual_seed(2020)
epochs = 30

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    optimizer.finish_epoch()
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

100%|██████████| 750/750 [00:09<00:00, 80.04it/s, loss=2.3001]
Epoch 1 | Training loss: 2.3001, validation accuracy: 0.0973, validation loss: 2.2968
100%|██████████| 750/750 [00:09<00:00, 81.56it/s, loss=2.2890]
Epoch 2 | Training loss: 2.2890, validation accuracy: 0.2014, validation loss: 2.2806
100%|██████████| 750/750 [00:09<00:00, 80.98it/s, loss=2.2543]
Epoch 3 | Training loss: 2.2543, validation accuracy: 0.4482, validation loss: 2.2046
100%|██████████| 750/750 [00:09<00:00, 82.98it/s, loss=1.7665]
Epoch 4 | Training loss: 1.7665, validation accuracy: 0.7831, validation loss: 0.8355
100%|██████████| 750/750 [00:09<00:00, 80.27it/s, loss=0.5811]
Epoch 5 | Training loss: 0.5811, validation accuracy: 0.8721, validation loss: 0.4288
100%|██████████| 750/750 [00:09<00:00, 80.74it/s, loss=0.4092]
Epoch 6 | Training loss: 0.4092, validation accuracy: 0.9010, validation loss: 0.3326
100%|██████████| 750/750 [00:09<00:00, 81.51it/s, loss=0.3326]
Epoch 7 | Training loss: 0.3326, validation

In [28]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.98, 0.064899891360042)
Valid: (0.9750833333333333, 0.08100792063275973)


In [29]:
optimizer.accelerate()

In [30]:
optimizer.store_parameters([[param] for param in model.parameters()])
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [31]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9740625, 0.08288766444909076)
Valid: (0.9721666666666666, 0.09380900078018506)
