In [8]:
import sys
# Prepend the directory two levels up (relative to the kernel's working
# directory) to the module search path.
# NOTE(review): presumably this is where nn_extrapolation and nn_utils live — confirm.
sys.path.insert(0, "../..")

In [9]:
import torch
from torch import nn
import numpy as np
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD
from nn_utils import *

In [10]:
# Report whether PyTorch can use a CUDA GPU in this environment.
torch.cuda.is_available()

True

In [11]:
# Use the GPU when PyTorch can see one and fall back to the CPU otherwise, so
# the notebook still runs on machines without CUDA instead of failing on a
# hard-coded "cuda" device.
# Training uses the batch-mean negative log-likelihood; validation uses the
# summed negative log-likelihood.
# NOTE(review): the summed validation loss is presumably normalised by the
# number of samples inside Trainer — confirm in nn_utils.
trainer = Trainer(
    device="cuda" if torch.cuda.is_available() else "cpu",
    loss_fn=nn.NLLLoss(reduction="mean"),
    val_loss_fn=nn.NLLLoss(reduction="sum"),
)

In [12]:
# Build the MNIST data loaders from a local copy of the dataset (download is
# disabled). The returned object is indexed by split name ("train" / "valid")
# in the training and evaluation cells.
# NOTE(review): validation_split=0.2 presumably holds out 20% of the training
# data for validation — confirm against load_dataset in nn_utils.
dl = load_dataset(
    dataset="mnist", 
    root="../../../MNIST", 
    download=False, 
    validation_split=0.2,
    batch_size=64, 
    num_workers=2,
)

In [13]:
def make_model():
    """Create a new, randomly initialised CNN classifier.

    The network has two convolutional stages, each made of two unpadded 3x3
    convolutions with ReLU activations followed by 2x2 max-pooling, and a
    fully connected head that ends in a log-softmax over 10 outputs.

    The first convolution takes a single input channel. The first linear
    layer takes 4*4*64 features, which is what the convolutional stages
    produce for 28x28 inputs.

    Returns:
        nn.Sequential: the 15 layers in order. The model is not moved to any
        device here.
    """
    # Layers are instantiated in the same order as they appear in the final
    # Sequential, so parameter initialisation order is unaffected by grouping.
    feature_extractor = [
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    classifier_head = [
        nn.Flatten(),
        nn.Linear(in_features=4 * 4 * 64, out_features=128),
        nn.ReLU(),
        nn.Linear(in_features=128, out_features=10),
        nn.LogSoftmax(dim=-1),
    ]
    return nn.Sequential(*feature_extractor, *classifier_head)

## Epoch

In [17]:
# Build a freshly initialised network and move it to the trainer's device.
# Module.to() moves the parameters in place and returns the module, so this
# last expression also displays the layer summary as the cell output.
model = make_model()
model.to(trainer.device)

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [18]:
# Adam over all model parameters with learning rate 1e-3 and a small L2
# weight decay of 1e-5.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)
# NOTE(review): the log file name mentions SGD momentum while the optimizer
# is Adam — confirm the name is intended.
logger = Logger("SGD_momentum_adam.txt")

In [19]:
# Number of full passes over the training split.
epochs = 30

for epoch in range(epochs):
    # One pass over the training loader, then an evaluation on the validation
    # loader; validation returns an (accuracy, loss) pair.
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    # Epochs are reported 1-based in the log.
    logger.log(
        "Epoch",
        epoch + 1,
        "|",
        f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
    )

100%|██████████| 750/750 [00:13<00:00, 54.91it/s, loss=0.1697]
Epoch 1 | Training loss: 0.1697, validation accuracy: 0.9838, validation loss: 0.0576
100%|██████████| 750/750 [00:13<00:00, 56.90it/s, loss=0.0461]
Epoch 2 | Training loss: 0.0461, validation accuracy: 0.9871, validation loss: 0.0445
100%|██████████| 750/750 [00:13<00:00, 56.62it/s, loss=0.0319]
Epoch 3 | Training loss: 0.0319, validation accuracy: 0.9892, validation loss: 0.0375
100%|██████████| 750/750 [00:13<00:00, 54.73it/s, loss=0.0252]
Epoch 4 | Training loss: 0.0252, validation accuracy: 0.9872, validation loss: 0.0402
100%|██████████| 750/750 [00:13<00:00, 56.45it/s, loss=0.0197]
Epoch 5 | Training loss: 0.0197, validation accuracy: 0.9890, validation loss: 0.0383
100%|██████████| 750/750 [00:13<00:00, 56.72it/s, loss=0.0157]
Epoch 6 | Training loss: 0.0157, validation accuracy: 0.9891, validation loss: 0.0414
100%|██████████| 750/750 [00:13<00:00, 55.29it/s, loss=0.0134]
Epoch 7 | Training loss: 0.0134, validation

In [20]:
# Final evaluation of the trained model on both splits. Each score is the
# value returned by trainer.validation, which the training loop unpacks as
# (accuracy, loss); the whole tuple is logged here.
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

Train: (0.9985833333333334, 0.004241814148521134)
Valid: (0.9895833333333334, 0.06418128799728837)
