In [2]:
import sys

# Put the repository root (presumably two directories above the notebook's
# working directory) first on the import path so `nn_extrapolation` resolves
# to the local checkout.
sys.path[:0] = ["../.."]

In [3]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms, models
from tqdm.notebook import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from copy import deepcopy

from nn_extrapolation import AcceleratedSGD

In [4]:
# Summed (not averaged) NLL so `validation` can divide by the exact sample count,
# which stays correct when the last batch is smaller than the others.
val_loss_fn = nn.NLLLoss(reduction="sum")

def validation(model, loader, device="cuda"):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    The model is switched to eval mode and left there; ``train_epoch`` switches
    it back to train mode before the next training pass.

    Args:
        model: network whose output is scored with ``nn.NLLLoss``, so it is
            expected to return log-probabilities (e.g. via a final LogSoftmax).
        loader: iterable of ``(inputs, targets)`` batches.
        device: device every batch is moved to; must match the model's device.
            Defaults to ``"cuda"``, the original behaviour.

    Returns:
        ``(accuracy, mean_loss)``, both averaged over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    ok = 0
    loss_sum = 0.0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss_sum += val_loss_fn(out, y).item()
            ok += (out.argmax(1) == y).sum().item()
            total += len(y)
    if total == 0:
        raise ValueError("validation: loader produced no samples")
    return ok / total, loss_sum / total

def train_epoch(loss_log, device="cuda"):
    """Run one optimisation pass over ``train_loader``.

    Relies on the notebook globals ``model``, ``train_loader``, ``optimizer`` and
    ``loss_fn``. One float per batch is appended to ``loss_log``; the caller's
    list is mutated in place and averaged by the caller for reporting.

    Args:
        loss_log: list that receives each batch's training loss.
        device: device every batch is moved to; must match the model's device.
            Defaults to ``"cuda"``, the original behaviour.
    """
    model.train()
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        # backward() below requires a scalar loss, so .item() is enough to log it
        # (no host copy through numpy needed).
        loss_log.append(loss.item())
        loss.backward()
        optimizer.step()

In [6]:
# Training-only augmentation: random rotation within +/-10 degrees, 0.9-1.1x
# scaling and up to 20% translation in each direction.
augmentation = transforms.RandomAffine(10, scale=(0.9, 1.1), translate=(0.2, 0.2))
# Standard ImageNet channel statistics (not CIFAR-10's own mean/std).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# Kept for experiments at 224x224; deliberately NOT part of `transform` below.
rescale = transforms.Resize((224, 224))

transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

DATA_ROOT = "/tmp/i291318/CIFAR"
# Fixed seed so the 80/20 train/validation split is identical on every run;
# without it each kernel restart evaluated on a different validation set.
SPLIT_SEED = 0
BATCH_SIZE = 128
NUM_WORKERS = 10

# Two views of the same CIFAR-10 training split: augmented for training and
# clean for validation. The index split below decides which images go where.
train_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=True,
                            transform=transforms.Compose([augmentation, transform]))
valid_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=True, transform=transform)
test_ds = datasets.CIFAR10(DATA_ROOT, download=True, train=False, transform=transform)

train_indices, valid_indices = train_test_split(
    np.arange(len(train_ds)), test_size=0.2, random_state=SPLIT_SEED)
train_ds = data.Subset(train_ds, train_indices)
valid_ds = data.Subset(valid_ds, valid_indices)

train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valid_loader = data.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# ResNet-18 trained from scratch, with its classifier head swapped for a
# 10-way CIFAR-10 head that emits log-probabilities (paired with NLLLoss).
model = models.resnet18(pretrained=False)
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Sequential(nn.Linear(num_features, 10), nn.LogSoftmax(dim=-1))
model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
# Sanity check before training: an untrained 10-class model should sit near 10% accuracy.
validation(model=model, loader=valid_loader)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))




(0.0976, 2.4530729885101317)

In [6]:
# Plain-text run log; later cells write to it with flush=True so it survives a crash.
LOG_PATH = "resnet_log_augmentation_averaging.txt"
log_file = open(LOG_PATH, mode="w")

## Momentum

In [7]:
# Momentum SGD from nn_extrapolation with per-epoch averaging. The positional
# 1e-1 is presumably the learning rate (later cells overwrite param_groups' "lr").
optimizer = AcceleratedSGD(
    model.parameters(),
    1e-1,
    k=10,
    momentum=0.9,
    weight_decay=1e-5,
    lambda_=1e-8,
    mode="epoch_avg",
)
# The model already ends in LogSoftmax, so NLLLoss gives cross-entropy.
loss_fn = nn.NLLLoss()

In [8]:
epochs = 25

# Each epoch: train, let the optimizer close the epoch, validate, then report
# to the notebook and append one summary line to the run log.
for epoch in range(epochs):
    print("Epoch", epoch + 1)
    loss_log = []
    train_epoch(loss_log)
    train_loss = np.mean(loss_log)
    print(f"Training loss: {train_loss:.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    summary = (f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, "
               f"validation loss: {val_loss:.4f}")
    print("Epoch", epoch + 1, summary, file=log_file, flush=True)

Epoch 1


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 2.0710


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.3430, validation loss: 1.7220
Epoch 2


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.6113


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.4539, validation loss: 1.5009
Epoch 3


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.3298


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5376, validation loss: 1.3062
Epoch 4


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 1.1100


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.5886, validation loss: 1.3332
Epoch 5


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.9500


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6342, validation loss: 1.0602
Epoch 6


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.8271


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.6876, validation loss: 0.9402
Epoch 7


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.7212


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7228, validation loss: 0.8282
Epoch 8


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.6527


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7505, validation loss: 0.7912
Epoch 9


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5919


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7653, validation loss: 0.7142
Epoch 10


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5468


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7859, validation loss: 0.6676
Epoch 11


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.5053


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7957, validation loss: 0.6466
Epoch 12


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4699


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.7988, validation loss: 0.6334
Epoch 13


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4379


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8214, validation loss: 0.5571
Epoch 14


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.4056


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8257, validation loss: 0.5521
Epoch 15


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.3756


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8267, validation loss: 0.5531
Epoch 16


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.3521


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8420, validation loss: 0.5011
Epoch 17


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.3328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8132, validation loss: 0.6395
Epoch 18


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.3095


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8464, validation loss: 0.5138
Epoch 19


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2874


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8539, validation loss: 0.4981
Epoch 20


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2575


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8401, validation loss: 0.5823
Epoch 21


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8405, validation loss: 0.5806
Epoch 22


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2310


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8640, validation loss: 0.4726
Epoch 23


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2158


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8534, validation loss: 0.5380
Epoch 24


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.2022


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8593, validation loss: 0.5433
Epoch 25


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))


Training loss: 0.1878


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Validation accuracy: 0.8538, validation loss: 0.5626


In [9]:
# NOTE(review): train_loader applies `augmentation` and shuffles, so "Train" is
# accuracy on augmented images rather than on the clean training split.
train_score = validation(model, train_loader)
valid_score = validation(model, valid_loader)
# Echo to the notebook first, then append the same two lines to the run log.
for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.928875, 0.2028727339744568)
Valid: (0.8538, 0.5626181404113769)


In [10]:
# Extrapolate with RNA and log which method the following scores belong to.
method = "RNA"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): store_parameters only receives parameters, so the copy keeps
# `model`'s BatchNorm running statistics, which may not match the extrapolated
# weights. Consider re-estimating them (e.g. swa_utils.update_bn) before evaluating.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [11]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU for any further training.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=79.0), HTML(value='')))


Train: (0.135875, 4.399873812866211)
Valid: (0.2294, 3.0517888343811035)


In [12]:
# Extrapolate with RRE and log which method the following scores belong to.
method = "RRE"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): BatchNorm running statistics are not part of parameters() and
# are copied unchanged from `model`. They may not match the extrapolated weights.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU for the next training phase.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=313.0), HTML(value='')))

In [None]:
# Phase 2: continue training at a 10x lower learning rate (assuming the 1e-1
# passed to AcceleratedSGD is the lr), with the method reset to RNA. Update
# every param group (a single one today, built from model.parameters()). Also
# record the change so the run log shows where this phase starts.
new_lr = 1e-2
for group in optimizer.param_groups:
    group["method"] = "RNA"
    group["lr"] = new_lr
print(f"lr={new_lr}", flush=True, file=log_file)

In [None]:
epochs = 25

# Each epoch: train, let the optimizer close the epoch, validate, then report
# to the notebook and append one summary line to the run log.
for epoch in range(epochs):
    print("Epoch", epoch + 1)
    loss_log = []
    train_epoch(loss_log)
    train_loss = np.mean(loss_log)
    print(f"Training loss: {train_loss:.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    summary = (f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, "
               f"validation loss: {val_loss:.4f}")
    print("Epoch", epoch + 1, summary, file=log_file, flush=True)

In [None]:
# NOTE(review): train_loader applies `augmentation` and shuffles, so "Train" is
# accuracy on augmented images rather than on the clean training split.
train_score = validation(model, train_loader)
valid_score = validation(model, valid_loader)
# Echo to the notebook first, then append the same two lines to the run log.
for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# Extrapolate with RNA and log which method the following scores belong to.
method = "RNA"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): BatchNorm running statistics are not part of parameters() and
# are copied unchanged from `model`. They may not match the extrapolated weights.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# Extrapolate with RRE and log which method the following scores belong to.
method = "RRE"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): BatchNorm running statistics are not part of parameters() and
# are copied unchanged from `model`. They may not match the extrapolated weights.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU for the next training phase.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# Phase 3: continue training at a further 10x lower learning rate, with the
# method reset to RNA. Update every param group (a single one today, built
# from model.parameters()). Also record the change so the run log shows where
# this phase starts.
new_lr = 1e-3
for group in optimizer.param_groups:
    group["method"] = "RNA"
    group["lr"] = new_lr
print(f"lr={new_lr}", flush=True, file=log_file)

In [None]:
epochs = 25

# Each epoch: train, let the optimizer close the epoch, validate, then report
# to the notebook and append one summary line to the run log.
for epoch in range(epochs):
    print("Epoch", epoch + 1)
    loss_log = []
    train_epoch(loss_log)
    train_loss = np.mean(loss_log)
    print(f"Training loss: {train_loss:.4f}")
    optimizer.finish_epoch()
    val_acc, val_loss = validation(model, valid_loader)
    print(f"Validation accuracy: {val_acc:.4f}, validation loss: {val_loss:.4f}")
    summary = (f"Training loss: {train_loss:.4f}, validation accuracy: {val_acc:.4f}, "
               f"validation loss: {val_loss:.4f}")
    print("Epoch", epoch + 1, summary, file=log_file, flush=True)

In [None]:
# NOTE(review): train_loader applies `augmentation` and shuffles, so "Train" is
# accuracy on augmented images rather than on the clean training split.
train_score = validation(model, train_loader)
valid_score = validation(model, valid_loader)
# Echo to the notebook first, then append the same two lines to the run log.
for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# Extrapolate with RNA and log which method the following scores belong to.
method = "RNA"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): BatchNorm running statistics are not part of parameters() and
# are copied unchanged from `model`. They may not match the extrapolated weights.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# Extrapolate with RRE and log which method the following scores belong to.
method = "RRE"
optimizer.param_groups[0]["method"] = method
print(method, flush=True, file=log_file)
# Move the live model to the CPU and copy it; the optimizer then presumably
# writes its extrapolated parameters into the copy (`model_acc`).
# NOTE(review): BatchNorm running statistics are not part of parameters() and
# are copied unchanged from `model`. They may not match the extrapolated weights.
model.cpu()
model_acc = deepcopy(model)
optimizer.accelerate()
optimizer.store_parameters([model_acc.parameters()])

In [None]:
# Score the extrapolated copy on the GPU, then park it back on the CPU and
# return the training model to the GPU.
model_acc.cuda()
train_score = validation(model_acc, train_loader)
valid_score = validation(model_acc, valid_loader)
model_acc.cpu()
model.cuda()

for sink in (None, log_file):
    print("Train:", train_score, flush=True, file=sink)
    print("Valid:", valid_score, flush=True, file=sink)

In [None]:
# End of the experiment: close the run log, which was opened once and never
# closed (every write was already flushed, but the handle leaked). Then raise
# SystemExit, presumably to stop "Run All" here; IPython reports it rather than
# killing the kernel.
import sys

log_file.close()
sys.exit(0)