In [3]:
import sys
# Put the directory two levels above the notebook at the front of the import
# path. NOTE(review): presumably this is the repository root that provides the
# `nn_extrapolation` package imported in the next cell — confirm.
sys.path.insert(0, "../..")

In [4]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms, models
from tqdm.notebook import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD

In [5]:
# Prefer the second GPU (the notebook's original hard-coded choice), but fall
# back to the first GPU or to the CPU so the notebook still runs on machines
# that do not expose a "cuda:1" device.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Summed (not averaged) negative log-likelihood: validation() accumulates the
# per-batch sums and divides by the exact number of samples at the end.
val_loss_fn = nn.NLLLoss(reduction="sum")

def validation(model, loader):
    """Evaluate ``model`` on every batch produced by ``loader``.

    The model is switched to eval mode (and left in that mode) and gradients
    are disabled for the whole pass. Each batch is moved to the notebook-level
    ``device`` and scored with the notebook-level ``val_loss_fn``.

    Returns:
        A tuple ``(accuracy, mean_loss)``: the fraction of samples whose
        arg-max prediction along dim 1 equals the label, and the accumulated
        loss divided by the number of samples seen.
    """
    n_correct = 0
    n_seen = 0
    running_loss = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            scores = model(inputs)
            running_loss += val_loss_fn(scores, targets).item()
            n_correct += (scores.argmax(1) == targets).sum().item()
            n_seen += len(targets)
    return n_correct / n_seen, running_loss / n_seen

def train_epoch(loss_log):
    """Run one optimisation pass over the notebook-level ``train_loader``.

    Relies on the notebook globals ``model``, ``optimizer``, ``loss_fn`` and
    ``device``. The loss value of every batch is appended to ``loss_log``,
    which is extended in place.
    """
    model.train()
    for inputs, targets in tqdm(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        # Keep a detached host-side copy of the loss for the epoch statistics.
        recorded = batch_loss.flatten().cpu().detach().numpy()
        loss_log += list(recorded)
        batch_loss.backward()
        optimizer.step()

In [6]:
# Location of the CIFAR-10 copy and the seed that pins the train/validation
# split. Without a fixed seed every kernel restart produced a different split,
# so runs from different sessions were not evaluated on the same data.
DATA_ROOT = "/tmp/i291318/CIFAR"
SPLIT_SEED = 0

# Random affine jitter (rotation up to 10 degrees, 0.9-1.1 scaling, shifts of
# up to 20% of the image size); only the training dataset uses it.
augmentation = transforms.RandomAffine(10, scale=(0.9, 1.1), translate=(0.2, 0.2))
# Per-channel normalisation; the values match the standard ImageNet statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Defined but currently disabled in `transform` below (images keep their native size).
rescale = transforms.Resize((224, 224))

transform = transforms.Compose([
    transforms.ToTensor(),
#     rescale,
    normalize,
])

# train_ds and valid_ds wrap the same official training set; they differ only
# in the transform (augmented vs. plain) and are made disjoint by the index
# split below. test_ds is the official test set.
train_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=True,
                            transform=transforms.Compose([augmentation, transform]))
valid_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=True, transform=transform)
test_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=False, transform=transform)

# 80/20 split of the training indices, reproducible thanks to SPLIT_SEED.
train_indices, valid_indices = train_test_split(
    np.arange(len(train_ds)), test_size=0.2, random_state=SPLIT_SEED
)
train_ds = data.Subset(train_ds, train_indices)
valid_ds = data.Subset(valid_ds, valid_indices)

# Only the training loader shuffles; evaluation order is fixed.
train_loader = data.DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=10)
valid_loader = data.DataLoader(valid_ds, batch_size=128, shuffle=False, num_workers=10)
test_loader = data.DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=10)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Randomly initialised ResNet-34 backbone (no pretrained weights).
model = models.resnet34(pretrained=False)
# Replace the stock classifier with a 10-class head. The input width is read
# from the layer being replaced instead of hard-coding 512, so the head stays
# valid if the backbone is swapped. LogSoftmax makes the output compatible with
# the NLLLoss criteria used for training and validation.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 10),
    nn.LogSoftmax(-1)
)
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Count every scalar entry across all parameter tensors of the network.
sum(tensor.numel() for tensor in model.parameters())

21289802

In [None]:
# Sanity check before any training: (accuracy, mean loss) of the freshly
# initialised network on the validation split.
validation(model, valid_loader)

In [6]:
# Text log for the no-momentum run. Mode "w" truncates an existing file with
# this name. NOTE(review): the handle is never closed in the visible cells;
# every write below passes flush=True, so the contents still reach the file.
log_file = open("resnet34_log_augmentation_no_momentum.txt.no_resizing", "w")

## No momentum

In [7]:
# Project-specific optimizer configured without momentum or weight decay.
# NOTE(review): the positional 1e-1 is presumably the learning rate (later
# cells overwrite param_groups[0]["lr"]); k and lambda_ presumably control the
# number of stored iterates and the extrapolation regularisation — confirm
# against nn_extrapolation.AcceleratedSGD.
optimizer = AcceleratedSGD(model.parameters(), 1e-1, k=10, momentum=0, weight_decay=0, lambda_=1e-8)
# Training criterion: NLLLoss with its default (mean) reduction, applied to the
# log-probabilities produced by the model's LogSoftmax head.
loss_fn = nn.NLLLoss()

In [8]:
epochs = 25

# First training stage: one pass over the training data per epoch, followed by
# a validation pass; a summary line per epoch goes to stdout and to log_file.
for epoch in range(epochs):
    epoch_number = epoch + 1
    print("Epoch", epoch_number)

    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)
    print(f"Training loss: {mean_train_loss:.4f}")

    # NOTE(review): presumably lets the optimizer record this epoch's weights
    # for the later extrapolation step — confirm in nn_extrapolation.
    optimizer.finish_epoch()

    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    log_line = f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}"
    print("Epoch", epoch_number, log_line, file=log_file, flush=True)

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 2.0391


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.3764, validation loss: 1.8502
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.6344


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.3575, validation loss: 3.0827
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.5015


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.4537, validation loss: 1.8319
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.3960


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.4053, validation loss: 2.3077
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.3093


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5393, validation loss: 1.3663
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.2286


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5712, validation loss: 1.2252
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.1531


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6048, validation loss: 1.1132
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0920


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5662, validation loss: 1.3065
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0451


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6151, validation loss: 1.0958
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0031


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6589, validation loss: 0.9514
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.9664


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6316, validation loss: 1.0214
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.9305


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6378, validation loss: 1.0774
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.8999


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6664, validation loss: 0.9496
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.8706


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6970, validation loss: 0.8577
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.8495


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7107, validation loss: 0.8323
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.8200


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6264, validation loss: 1.1343
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7942


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7161, validation loss: 0.8261
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7805


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6118, validation loss: 1.3972
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7587


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7162, validation loss: 0.8270
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7437


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7146, validation loss: 0.8365
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7202


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7078, validation loss: 0.8819
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7113


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7122, validation loss: 0.8458
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.6897


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7561, validation loss: 0.7158
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.6733


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6978, validation loss: 0.9203
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.6628


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7167, validation loss: 0.8620


In [9]:
# Score the trained model on both splits; each result is an (accuracy, mean loss) tuple.
split_scores = [validation(model, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.734725, 0.7532561800003051)
Valid: (0.7167, 0.8619575891494751)


In [10]:
# Select the "RNA" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RNA"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [11]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.7419, 0.7232592183113098)
Valid: (0.7327, 0.7703360885620117)


In [12]:
# Select the "RRE" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RRE"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [13]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.744375, 0.7231218473434449)
Valid: (0.7327, 0.7703351764678955)


In [14]:
# Manual learning-rate schedule: set the first parameter group's learning rate
# to 1e-2 for the next training stage.
optimizer.param_groups[0]["lr"] = 1e-2

In [15]:
epochs = 25

# Second training stage (run after the learning rate was lowered): one pass
# over the training data per epoch, followed by a validation pass; a summary
# line per epoch goes to stdout and to log_file.
for epoch in range(epochs):
    epoch_number = epoch + 1
    print("Epoch", epoch_number)

    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)
    print(f"Training loss: {mean_train_loss:.4f}")

    # NOTE(review): presumably lets the optimizer record this epoch's weights
    # for the later extrapolation step — confirm in nn_extrapolation.
    optimizer.finish_epoch()

    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    log_line = f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}"
    print("Epoch", epoch_number, log_line, file=log_file, flush=True)

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5733


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7881, validation loss: 0.6241
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5467


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7913, validation loss: 0.6162
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5324


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7862, validation loss: 0.6231
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5286


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7935, validation loss: 0.6105
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5211


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7943, validation loss: 0.6093
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5185


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7926, validation loss: 0.6091
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5102


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7933, validation loss: 0.6205
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5083


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7927, validation loss: 0.6178
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5076


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7970, validation loss: 0.6178
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5007


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7981, validation loss: 0.6113
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4979


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7966, validation loss: 0.6132
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4886


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7968, validation loss: 0.6146
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4869


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7987, validation loss: 0.6163
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4846


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7961, validation loss: 0.6159
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4815


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7944, validation loss: 0.6212
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4777


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7957, validation loss: 0.6229
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4777


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8002, validation loss: 0.6104
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4699


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7969, validation loss: 0.6193
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4655


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8002, validation loss: 0.6107
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4672


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7980, validation loss: 0.6143
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4661


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7998, validation loss: 0.6203
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4616


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7991, validation loss: 0.6167
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4601


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7996, validation loss: 0.6209
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4566


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7946, validation loss: 0.6311
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4531


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8011, validation loss: 0.6217


In [16]:
# Select the "RNA" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RNA"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [17]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.8465, 0.43225871195793153)
Valid: (0.7954, 0.6391768536567688)


In [18]:
# Select the "RRE" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RRE"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [19]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.8477, 0.4333727992534637)
Valid: (0.7954, 0.6391767944335938)


In [20]:
# Manual learning-rate schedule: set the first parameter group's learning rate
# to 1e-3 for the next training stage.
optimizer.param_groups[0]["lr"] = 1e-3

In [21]:
epochs = 25

# Third training stage (run after the learning rate was lowered again): one
# pass over the training data per epoch, followed by a validation pass; a
# summary line per epoch goes to stdout and to log_file.
for epoch in range(epochs):
    epoch_number = epoch + 1
    print("Epoch", epoch_number)

    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)
    print(f"Training loss: {mean_train_loss:.4f}")

    # NOTE(review): presumably lets the optimizer record this epoch's weights
    # for the later extrapolation step — confirm in nn_extrapolation.
    optimizer.finish_epoch()

    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    log_line = f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}"
    print("Epoch", epoch_number, log_line, file=log_file, flush=True)

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4374


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8028, validation loss: 0.6144
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4334


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8016, validation loss: 0.6174
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4385


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8024, validation loss: 0.6153
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4316


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8017, validation loss: 0.6156
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4364


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8027, validation loss: 0.6130
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4308


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8024, validation loss: 0.6124
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4271


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8005, validation loss: 0.6178
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4341


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8027, validation loss: 0.6207
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4286


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8027, validation loss: 0.6157
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4311


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8037, validation loss: 0.6166
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4279


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8022, validation loss: 0.6192
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4353


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8018, validation loss: 0.6215
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4308


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8027, validation loss: 0.6243
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4326


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8017, validation loss: 0.6220
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4307


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8041, validation loss: 0.6207
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4260


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8034, validation loss: 0.6225
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4335


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8043, validation loss: 0.6180
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4302


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8018, validation loss: 0.6213
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4268


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8011, validation loss: 0.6172
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4242


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8027, validation loss: 0.6170
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4298


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8016, validation loss: 0.6232
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4274


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8024, validation loss: 0.6221
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4208


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8018, validation loss: 0.6186
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4232


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8005, validation loss: 0.6230
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4204


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8017, validation loss: 0.6187


In [22]:
# Score the trained model on both splits; each result is an (accuracy, mean loss) tuple.
split_scores = [validation(model, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.857325, 0.4013046481609345)
Valid: (0.8017, 0.6187197321891784)


In [23]:
# Select the "RNA" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RNA"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [24]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.8563, 0.40048055052757264)
Valid: (0.8018, 0.6194939519882202)


In [25]:
# Select the "RRE" method on the first parameter group and record the choice in the log.
optimizer.param_groups[0]["method"] = "RRE"
print(optimizer.param_groups[0]["method"], flush=True, file=log_file)
# Move the trained model to the CPU before duplicating it, so the copy is
# created on the CPU as well.
model.cpu()
model_acc = deepcopy(model)
# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc's parameters, leaving `model`
# itself unchanged — confirm against nn_extrapolation.AcceleratedSGD.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [26]:
# Evaluate the accelerated copy on the compute device, then park it on the CPU
# again and move the regular model back to the device.
model_acc.to(device)
split_scores = [validation(model_acc, loader) for loader in (train_loader, valid_loader)]
train_score, valid_score = split_scores
model_acc.cpu()
model.to(device)

# Report to the notebook first, then append the same two lines to the log file.
report = (("Train:", train_score), ("Valid:", valid_score))
for split_name, split_score in report:
    print(split_name, split_score)
for split_name, split_score in report:
    print(split_name, split_score, flush=True, file=log_file)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.856875, 0.40265166931152346)
Valid: (0.8018, 0.6194941596984863)


## Momentum

In [10]:
# Fresh, randomly initialised ResNet-34 for the momentum experiment (no pretrained weights).
model = models.resnet34(pretrained=False)
# Replace the stock classifier with a 10-class head. The input width is read
# from the layer being replaced instead of hard-coding 512, so the head stays
# valid if the backbone is swapped. LogSoftmax makes the output compatible with
# the NLLLoss criteria used for training and validation.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 10),
    nn.LogSoftmax(-1)
)
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
# Baseline before any training: displays the (accuracy, mean loss) pair on valid_loader.
validation(model=model, loader=valid_loader)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))




(0.0986, 3.7226751861572267)

In [12]:
# Explicitly close a still-open log handle left over from an earlier run of this cell
# (or an earlier section) before rebinding the name, rather than leaving the close to
# garbage collection.
_stale_log = globals().get("log_file")
if _stale_log is not None and not getattr(_stale_log, "closed", True):
    _stale_log.close()
del _stale_log

# Mode "w" truncates: re-running this cell starts the log from scratch.
log_file = open("resnet34_log_augmentation.txt.no_resizing", "w")

In [13]:
# Mean negative log-likelihood over the batch (default reduction).
loss_fn = nn.NLLLoss()

# NOTE(review): the second positional argument is presumably the learning rate
# (param_groups[0]["lr"] is lowered later); k and lambda_ are specific to
# AcceleratedSGD -- confirm their meaning against nn_extrapolation.
optimizer = AcceleratedSGD(
    model.parameters(),
    1e-1,
    k=10,
    momentum=0.9,
    weight_decay=1e-5,
    lambda_=1e-8,
)

In [None]:
epochs = 35

for epoch in range(epochs):
    print("Epoch", epoch+1)

    # train_epoch appends each batch loss to the list it is handed.
    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)  # computed once, reused for notebook and log
    print(f"Training loss: {mean_train_loss:.4f}")

    # Notify the optimizer that the epoch is over before validating.
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    # One summary line per epoch goes to the log file, flushed immediately.
    print(
        "Epoch",
        epoch+1,
        f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
        file=log_file,
        flush=True,
    )

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 2.5956


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.2522, validation loss: 2.2859
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.9045


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.3145, validation loss: 1.8309
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.7471


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.3803, validation loss: 1.6533
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.6266


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.4345, validation loss: 1.5411
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.5416


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.4782, validation loss: 1.4209
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.4611


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5176, validation loss: 1.3581
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.4033


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5167, validation loss: 1.3725
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.3405


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5533, validation loss: 1.2685
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.2868


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5859, validation loss: 1.2477
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.2297


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5848, validation loss: 1.2309
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.1716


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6161, validation loss: 1.0746
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.1262


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6163, validation loss: 1.1169
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0895


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6207, validation loss: 1.0869
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0772


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6361, validation loss: 1.0371
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.0126


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6545, validation loss: 0.9986
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.9895


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6768, validation loss: 0.9241
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))

In [None]:
# Score the live model; the generator evaluates the training loader first and the
# validation loader second.
train_score, valid_score = (
    validation(model, loader) for loader in (train_loader, valid_loader)
)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RNA" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RNA")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RRE" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RRE")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Switch the first parameter group back to "RNA" and set its learning rate to 1e-2
# for the next training phase.
optimizer.param_groups[0].update(method="RNA", lr=1e-2)

In [None]:
epochs = 25

for epoch in range(epochs):
    print("Epoch", epoch+1)

    # train_epoch appends each batch loss to the list it is handed.
    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)  # computed once, reused for notebook and log
    print(f"Training loss: {mean_train_loss:.4f}")

    # Notify the optimizer that the epoch is over before validating.
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    # One summary line per epoch goes to the log file, flushed immediately.
    print(
        "Epoch",
        epoch+1,
        f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
        file=log_file,
        flush=True,
    )

In [None]:
# Score the live model; the generator evaluates the training loader first and the
# validation loader second.
train_score, valid_score = (
    validation(model, loader) for loader in (train_loader, valid_loader)
)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RNA" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RNA")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RRE" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RRE")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Switch the first parameter group back to "RNA" and set its learning rate to 1e-3
# for the next training phase.
optimizer.param_groups[0].update(method="RNA", lr=1e-3)

In [None]:
epochs = 25

for epoch in range(epochs):
    print("Epoch", epoch+1)

    # train_epoch appends each batch loss to the list it is handed.
    loss_log = []
    train_epoch(loss_log)
    mean_train_loss = np.mean(loss_log)  # computed once, reused for notebook and log
    print(f"Training loss: {mean_train_loss:.4f}")

    # Notify the optimizer that the epoch is over before validating.
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")

    # One summary line per epoch goes to the log file, flushed immediately.
    print(
        "Epoch",
        epoch+1,
        f"Training loss: {mean_train_loss:.4f}, validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}",
        file=log_file,
        flush=True,
    )

In [None]:
# Score the live model; the generator evaluates the training loader first and the
# validation loader second.
train_score, valid_score = (
    validation(model, loader) for loader in (train_loader, valid_loader)
)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RNA" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RNA")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)

In [None]:
# Select the "RRE" method on the optimizer's first parameter group and log the choice.
optimizer.param_groups[0].update(method="RRE")
print(optimizer.param_groups[0]["method"], file=log_file, flush=True)

# nn.Module.cpu() returns the module itself, so the snapshot is taken from the
# CPU-resident live model.
model_acc = deepcopy(model.cpu())

# NOTE(review): presumably accelerate() computes the extrapolated weights and
# store_parameters() writes them into model_acc -- confirm against nn_extrapolation.
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the accelerated copy on the device; the generator evaluates the training
# loader first and the validation loader second.
model_acc.to(device)
train_score, valid_score = (
    validation(model_acc, loader) for loader in (train_loader, valid_loader)
)
# Park the copy on the CPU again and put the live model back on the device.
model_acc.cpu()
model.to(device)

# First pass prints to the notebook, second pass writes to the log file and flushes it.
for print_options in ({}, {"flush": True, "file": log_file}):
    print("Train:", train_score, **print_options)
    print("Valid:", valid_score, **print_options)