In [1]:
import sys
sys.path.insert(0, "../..")

In [2]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD

In [3]:
torch.cuda.is_available()

True

In [4]:
val_loss_fn = nn.NLLLoss(reduction="sum")

def validation(model, loader):
    ok = 0
    loss_sum = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            out = model(x)
            loss_sum += val_loss_fn(out, y)
            preds = out.argmax(1)
            ok += (y == preds).sum()
            total += len(y)
    return ok / total, loss_sum / total

def train_epoch(loss_log):
    model.train()
    for x, y in tqdm(train_loader):
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss_log += list(loss.flatten().cpu().detach().numpy())
        loss.backward()
        optimizer.step()

In [5]:
train_ds = datasets.MNIST("../../../MNIST", download=True, train=True, transform=transforms.ToTensor())
test_ds = datasets.MNIST("../../../MNIST", download=True, train=False, transform=transforms.ToTensor())
valid_size = int(0.2 * len(train_ds))
train_ds, valid_ds = data.random_split(train_ds, [len(train_ds) - valid_size, valid_size])

train_loader = data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)
valid_loader = data.DataLoader(valid_ds, batch_size=64, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2)

In [6]:
def make_model():
    return nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 32, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.Conv2d(64, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten(),
        nn.Linear(4*4*64, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
        nn.LogSoftmax(-1),
    )

In [7]:
model = make_model()
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [8]:
log_file = open("SGD_momentum-short-cont.txt", "w")

In [9]:
optimizer = AcceleratedSGD(model.parameters(), 1e-3, k=10, momentum=0.5, weight_decay=1e-5, mode="epoch")
loss_fn = nn.NLLLoss()

## Epoch

In [10]:
epochs = 20

for epoch in range(epochs):
    print("Epoch", epoch+1)
    loss_log = []
    train_epoch(loss_log)
    print(f"Training loss: {np.mean(loss_log):.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    print("Epoch", epoch+1, 
          f"Training loss: {np.mean(loss_log):.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
          file=log_file, flush=True
         )

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.3027


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1212, validation loss: 2.3020
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.3006


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1017, validation loss: 2.2998
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.2981


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1017, validation loss: 2.2966
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.2938


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1018, validation loss: 2.2907
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.2837


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.3359, validation loss: 2.2735
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 2.2234


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.4958, validation loss: 2.0727
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 1.0449


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8308, validation loss: 0.5286
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.4413


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8741, validation loss: 0.4101
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.3488


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9030, validation loss: 0.3206
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.2878


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9218, validation loss: 0.2641
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.2443


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9343, validation loss: 0.2214
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.2097


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9287, validation loss: 0.2229
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1839


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9457, validation loss: 0.1766
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1623


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9511, validation loss: 0.1630
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1453


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9571, validation loss: 0.1384
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1319


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9604, validation loss: 0.1317
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1220


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9625, validation loss: 0.1242
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1134


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9610, validation loss: 0.1242
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.1060


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9653, validation loss: 0.1141
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0994


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9654, validation loss: 0.1098


In [11]:
train_score = validation(model, train_loader)
valid_score = validation(model, valid_loader)
print("Train:", train_score)
print("Valid:", valid_score)
print("Train:", train_score, flush=True, file=log_file)
print("Valid:", valid_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Train: (tensor(0.9690, device='cuda:0'), tensor(0.0997, device='cuda:0'))
Valid: (tensor(0.9654, device='cuda:0'), tensor(0.1098, device='cuda:0'))


In [12]:
%%time
optimizer.accelerate()

CPU times: user 97.1 ms, sys: 33 µs, total: 97.1 ms
Wall time: 33.6 ms


In [13]:
model_acc = deepcopy(model)
optimizer.store_parameters([model_acc.parameters()])
model_acc.cuda()
None

In [14]:
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
print("Train:", train_score)
print("Valid:", valid_score)
print("Train:", train_score, flush=True, file=log_file)
print("Valid:", valid_score, flush=True, file=log_file)
model_acc.cpu()
None

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Train: (tensor(0.9739, device='cuda:0'), tensor(0.0867, device='cuda:0'))
Valid: (tensor(0.9688, device='cuda:0'), tensor(0.0979, device='cuda:0'))


In [15]:
epochs = 10

for epoch in range(epochs):
    print("Epoch", epoch+1)
    loss_log = []
    train_epoch(loss_log)
    print(f"Training loss: {np.mean(loss_log):.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    print("Epoch", epoch+1, 
          f"Training loss: {np.mean(loss_log):.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
          file=log_file, flush=True
         )

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0944


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9706, validation loss: 0.0970
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0905


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9707, validation loss: 0.0942
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0853


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9707, validation loss: 0.0902
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0818


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9700, validation loss: 0.0973
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0786


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9743, validation loss: 0.0828
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0755


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9727, validation loss: 0.0888
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0738


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9751, validation loss: 0.0799
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0698


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9735, validation loss: 0.0834
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0680


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9757, validation loss: 0.0760
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0654


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9750, validation loss: 0.0780


In [16]:
train_score = validation(model, train_loader)
valid_score = validation(model, valid_loader)
print("Train:", train_score)
print("Valid:", valid_score)
print("Train:", train_score, flush=True, file=log_file)
print("Valid:", valid_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Train: (tensor(0.9796, device='cuda:0'), tensor(0.0656, device='cuda:0'))
Valid: (tensor(0.9750, device='cuda:0'), tensor(0.0780, device='cuda:0'))


In [18]:
optimizer.accelerate()
model_acc = deepcopy(model)
optimizer.store_parameters([model_acc.parameters()])
model_acc.cuda()
None

In [19]:
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
print("Train:", train_score)
print("Valid:", valid_score)
print("Train:", train_score, flush=True, file=log_file)
print("Valid:", valid_score, flush=True, file=log_file)
model_acc.cpu()
None

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Train: (tensor(0.9818, device='cuda:0'), tensor(0.0580, device='cuda:0'))
Valid: (tensor(0.9762, device='cuda:0'), tensor(0.0724, device='cuda:0'))


In [20]:
epochs = 5

for epoch in range(epochs):
    print("Epoch", epoch+1)
    loss_log = []
    train_epoch(loss_log)
    print(f"Training loss: {np.mean(loss_log):.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    print("Epoch", epoch+1, 
          f"Training loss: {np.mean(loss_log):.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
          file=log_file, flush=True
         )

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0633


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9756, validation loss: 0.0733
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0606


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9754, validation loss: 0.0760
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0589


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9782, validation loss: 0.0681
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0570


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9767, validation loss: 0.0711
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


Training loss: 0.0556


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9779, validation loss: 0.0667
