In [1]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from copy import deepcopy
from collections import deque

from extrapolation import *
from experiments import difference_matrix

In [2]:
torch.cuda.is_available()

True

In [3]:
train_ds = datasets.MNIST("../MNIST", download=True, train=True, transform=transforms.ToTensor())
test_ds = datasets.MNIST("../MNIST", download=True, train=False, transform=transforms.ToTensor())
valid_size = int(0.2 * len(train_ds))
train_ds, valid_ds = data.random_split(train_ds, [len(train_ds) - valid_size, valid_size])

train_loader = data.DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)
valid_loader = data.DataLoader(valid_ds, batch_size=64, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2)

In [4]:
model = nn.Sequential(
    nn.Conv2d(1, 32, 3),
    nn.ReLU(),
    nn.Conv2d(32, 32, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(4*4*64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(-1),
)
model.cuda()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Flatten(start_dim=1, end_dim=-1)
  (11): Linear(in_features=1024, out_features=128, bias=True)
  (12): ReLU()
  (13): Linear(in_features=128, out_features=10, bias=True)
  (14): LogSoftmax(dim=-1)
)

In [5]:
val_loss_fn = nn.NLLLoss(reduction="sum")

def validation(model, loader):
    ok = 0
    loss_sum = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.cuda()
            y = y.cuda()
            out = model(x)
            loss_sum += val_loss_fn(out, y)
            preds = out.argmax(1)
            ok += (y == preds).sum()
            total += len(y)
    return ok / total, loss_sum / total

def train_epoch(loss_log):
    model.train()
    for x, y in tqdm(train_loader):
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        loss_log.append(loss.item())
        optimizer.step()

In [6]:
validation(model, valid_loader)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))




(tensor(0.0660, device='cuda:0'), tensor(2.3024, device='cuda:0'))

## Without acceleration

In [7]:
initial_state = deepcopy(model.state_dict())
optimizer = torch.optim.SGD(model.parameters(), 3e-3)
loss_fn = nn.NLLLoss()

In [8]:
epochs = 30
without_acc = {"train_loss": [], "val_loss": [], "val_acc": []}

for epoch in range(epochs):
    print("Epoch", epoch+1)
    train_epoch(without_acc["train_loss"])
    val_acc, val_loss = validation(model, valid_loader)
    without_acc["val_loss"].append(val_loss)
    without_acc["val_acc"].append(val_acc)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.3003
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2979
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2926
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.3968, validation loss: 2.2704
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.7433, validation loss: 0.9036
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.8701, validation loss: 0.4250
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9067, validation loss: 0.3104
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9255, validation loss: 0.2424
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9421, validation loss: 0.1855
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9492, validation loss: 0.1672
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9557, validation loss: 0.1426
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9588, validation loss: 0.1308
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9599, validation loss: 0.1265
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9670, validation loss: 0.1083
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9688, validation loss: 0.0981
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9695, validation loss: 0.0992
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9721, validation loss: 0.0894
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9743, validation loss: 0.0852
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9747, validation loss: 0.0829
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9764, validation loss: 0.0743
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9763, validation loss: 0.0765
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9766, validation loss: 0.0753
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9787, validation loss: 0.0709
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9793, validation loss: 0.0674
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9787, validation loss: 0.0723
Epoch 26


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9785, validation loss: 0.0689
Epoch 27


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9806, validation loss: 0.0652
Epoch 28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9815, validation loss: 0.0637
Epoch 29


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9803, validation loss: 0.0682
Epoch 30


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9812, validation loss: 0.0621


In [9]:
final_without_acc = deepcopy(model.state_dict())

## With acceleration, online scheme

In [10]:
def params_to_vector(parameters):
    param_vectors = []
    for param in parameters:
        param_vectors.append(param.data.flatten().cpu())
    return torch.hstack(param_vectors)

def params_from_vector(parameters, x):
    idx = 0
    for param in parameters:
        n = param.data.numel()
        param.data[:] = x[idx:idx+n].view(param.data.shape)
        idx += n

In [11]:
model.load_state_dict(initial_state)
optimizer = torch.optim.SGD(model.parameters(), 3e-3)

In [12]:
epochs = 30
k = 5
with_acc = {"train_loss": [], "val_loss": [], "val_acc": []}
model_hist = deque(maxlen=k)

for epoch in range(epochs):
    print("Epoch", epoch+1)
    train_epoch(with_acc["train_loss"])

    x = params_to_vector(model.parameters())
    if len(model_hist) >= k:
        U = difference_matrix(list(model_hist) + [x])
        x = regularized_RRE(torch.vstack(list(model_hist)), U, 1e-5)
        params_from_vector(model.parameters(), x)
    model_hist.append(x)
    
    val_acc, val_loss = validation(model, valid_loader)
    with_acc["val_loss"].append(val_loss)
    with_acc["val_acc"].append(val_acc)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.3003
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2979
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2926
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.4131, validation loss: 2.2701
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.7546, validation loss: 0.8733
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.3012
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.4416, validation loss: 2.2642
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5419, validation loss: 2.2266
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5213, validation loss: 2.2414
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5068, validation loss: 2.2483
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5165, validation loss: 2.2449
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5165, validation loss: 2.2448
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5168, validation loss: 2.2448
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5149, validation loss: 2.2452
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5146, validation loss: 2.2454
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5162, validation loss: 2.2449
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2450
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5154, validation loss: 2.2451
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5151, validation loss: 2.2451
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5151, validation loss: 2.2451
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5153, validation loss: 2.2450
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 26


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 27


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 29


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451
Epoch 30


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.5152, validation loss: 2.2451


## With acceleration, offline scheme

In [15]:
model.load_state_dict(initial_state)
optimizer = torch.optim.SGD(model.parameters(), 3e-3)

In [16]:
epochs = 30
k = 5
with_acc_offline = {"train_loss": [], "val_loss": [], "val_acc": []}
model_hist = deque(maxlen=k)

for epoch in range(epochs):
    print("Epoch", epoch+1)
    train_epoch(with_acc_offline["train_loss"])
    
    model_acc = None
    x = params_to_vector(model.parameters())
    if len(model_hist) >= k:
        U = difference_matrix(list(model_hist) + [x])
        y = regularized_RRE(torch.vstack(list(model_hist)), U, 1e-5)
        model_acc = deepcopy(model)
        params_from_vector(model_acc.parameters(), y)
    model_hist.append(x)
        
    if model_acc is None:
        val_acc, val_loss = validation(model, valid_loader)
    else:
        val_acc, val_loss = validation(model_acc, valid_loader)
    with_acc_offline["val_loss"].append(val_loss)
    with_acc_offline["val_acc"].append(val_acc)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.3003
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2979
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2926
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.4125, validation loss: 2.2704
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.7243, validation loss: 0.9239
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.3012
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.1144, validation loss: 2.2989
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.2669, validation loss: 2.2818
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9163, validation loss: 0.2655
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9369, validation loss: 0.1991
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9467, validation loss: 0.1724
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9631, validation loss: 0.1234
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9632, validation loss: 0.1167
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9657, validation loss: 0.1088
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9667, validation loss: 0.1008
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9682, validation loss: 0.0994
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9723, validation loss: 0.0880
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9743, validation loss: 0.0831
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9748, validation loss: 0.0828
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9756, validation loss: 0.0806
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9767, validation loss: 0.0761
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9778, validation loss: 0.0715
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9790, validation loss: 0.0694
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9787, validation loss: 0.0688
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9791, validation loss: 0.0677
Epoch 26


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9795, validation loss: 0.0657
Epoch 27


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9799, validation loss: 0.0645
Epoch 28


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9807, validation loss: 0.0637
Epoch 29


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9816, validation loss: 0.0627
Epoch 30


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=188.0), HTML(value='')))


Validation accuracy: 0.9812, validation loss: 0.0618


In [17]:
results = {
    "without_acc": without_acc,
    "with_acc_online": with_acc,
    "with_acc_offline": without_acc,
}
torch.save(results, "first_nn_results.p")