In [1]:
import sys
sys.path.append("../..")

In [2]:
import torch
from torch import nn
from torchvision import models
from copy import deepcopy
import os

from nn_extrapolation import AcceleratedSGD
from nn_utils import *

In [3]:
trainer = Trainer(
    device="cuda",
    loss_fn=nn.NLLLoss(reduction="mean"),
    val_loss_fn=nn.NLLLoss(reduction="sum"),
)

In [4]:
dl = load_dataset(
    dataset="CIFAR10",
    root=os.path.join("/tmp", os.environ["USER"], "CIFAR"),
    augmentation=transforms.RandomAffine(10, scale=(0.9, 1.1), translate=(0.2, 0.2)),
    validation_split=0.2,
    batch_size=128,
    num_workers=10,
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
l = nn.Conv2d(3, 5, (3, 3))

In [6]:
model = models.vgg16(pretrained=False)
model.classifier[6] = nn.Linear(4096, 10)
model.classifier.add_module("7", nn.LogSoftmax(-1))

for name, mod in model.features.named_children():
    idx = int(name)
    if isinstance(mod, nn.Conv2d):
        model.features.add_module(str(idx + 1), nn.BatchNorm2d(mod.out_channels))

model.to(trainer.device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

In [7]:
sum(param.numel() for param in model.parameters())

134309962

In [8]:
trainer.validation(model, dl["valid"])

(0.0952, 2.572225959777832)

## Momentum

In [9]:
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
logger = Logger("vgg_log_augmentation_adam_bn.txt.no_resizing")

In [10]:
epochs = 45

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

100%|██████████████████████████████████████████████████████████| 313/313 [00:27<00:00, 11.56it/s, loss=1.8243]
Epoch 1 | Training loss: 1.8243, validation accuracy: 0.4576, validation loss: 1.4929
100%|██████████████████████████████████████████████████████████| 313/313 [00:27<00:00, 11.55it/s, loss=1.4246]
Epoch 2 | Training loss: 1.4246, validation accuracy: 0.5654, validation loss: 1.2182
100%|██████████████████████████████████████████████████████████| 313/313 [00:27<00:00, 11.53it/s, loss=1.2898]
Epoch 3 | Training loss: 1.2898, validation accuracy: 0.6025, validation loss: 1.1355
100%|██████████████████████████████████████████████████████████| 313/313 [00:27<00:00, 11.51it/s, loss=1.1922]
Epoch 4 | Training loss: 1.1922, validation accuracy: 0.6515, validation loss: 0.9769
  6%|███▊                                                       | 20/313 [00:02<00:30,  9.71it/s, loss=1.1727]


KeyboardInterrupt: 

In [None]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)

In [None]:
optimizer.param_groups[0]["lr"] = 1e-4

In [None]:
epochs = 30

for epoch in range(epochs):
    train_loss = trainer.train_epoch(model, optimizer, dl["train"])
    optimizer.finish_epoch()
    val_acc, val_loss = trainer.validation(model, dl["valid"])
    logger.log("Epoch", epoch+1, "|", 
          f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

In [None]:
train_score = trainer.validation(model, dl["train"])
valid_score = trainer.validation(model, dl["valid"])
logger.log("Train:", train_score)
logger.log("Valid:", valid_score)