In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up on the next cell execution without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from accelerate import Accelerator
from transformers import get_cosine_schedule_with_warmup

# Use the CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bare expression so the notebook displays which device was selected.
device

In [None]:
# Run-level hyperparameters.
EPOCHS = 5        # upper bound of the epoch range handed to the trainer
BATCH_SIZE = 32   # samples per DataLoader batch

# Number of batches to accumulate so that BATCH_SIZE * GRAD_ACCUM_STEPS
# reaches the target below (floor division, so any remainder is dropped).
TARGET_BATCH_SIZE = 128
GRAD_ACCUM_STEPS = TARGET_BATCH_SIZE // BATCH_SIZE

In [None]:
from datasets import get_cifar10
from torch.utils.data import DataLoader
# Unpack the three dataset splits returned by the project-local get_cifar10().
train_set, val_set, test_set = get_cifar10()

# Settings shared by all three loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=8)
train_loader = DataLoader(dataset=train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_set, shuffle=False, **loader_kwargs)

# Complete groups of GRAD_ACCUM_STEPS batches per epoch, summed over all epochs.
# The floor division means a trailing partial group is not counted.
steps_per_epoch = len(train_loader) // GRAD_ACCUM_STEPS
NUM_TRAINING_STEPS = steps_per_epoch * EPOCHS

In [None]:
from trainer import EarlyExitTrainer
from pee_method import PEE
from resnets import ResNet18

In [None]:
# Backbone network.
# NOTE(review): the argument 10 is presumably the number of output classes — confirm in resnets.ResNet18.
backbone_model = ResNet18(10)

# Wrap the backbone in the PEE method, then move the wrapper to `device`.
# NOTE(review): the list and the 0.2 are positional arguments whose meaning is
# defined by pee_method.PEE — check its signature before changing them.
ee_method = PEE(
    backbone_model,
    [64, 64, 128, 128, 256, 256, 512, 512],
    0.2,
    device,
)
ee_method = ee_method.to(device)

In [None]:
import numpy as np

# Shape probe: push one random (1, 3, 32, 32) tensor through the backbone's
# forward generator and inspect the first value it yields.
#
# The probe runs under torch.no_grad() because it is inspection only: without
# it the forward pass records an autograd graph and leaves `x` attached to it.
with torch.no_grad():
    fg = backbone_model.forward_generator(torch.randn(1,3,32,32).to(device))
    x = None

    # Assuming forward_generator returns a regular Python generator, sending
    # None to it right after creation starts it and returns its first yield.
    x = fg.send(x)
    print(x.shape, np.prod(x.shape))

    # Shape of the same representation after the backbone's adjust_repr step.
    adjusted_shape = backbone_model.adjust_repr(x).shape

# Kept as the cell's last top-level expression so the notebook still renders it
# (an expression inside the `with` block would not be displayed).
adjusted_shape

In [None]:
from utils import configure_optimizer

# Accelerator created with its default settings.
accelerator = Accelerator()

# Cross-entropy loss, moved to the selected device.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

# A single AdamW parameter group covering every parameter exposed by ee_method.
optim = torch.optim.AdamW(ee_method.parameters(), lr=0.05, weight_decay=0.001)

# No learning-rate schedule is active; the alternative setup below is disabled.
# NOTE(review): it refers to `optim_wrapper`, which is not defined in the visible cells.
lr_scheduler = None
# optim = configure_optimizer(optim_wrapper, backbone_model, ee_method, lr_backbone=5e-4, lr_head=5e-3, weight_decay=1e-3)
# lr_scheduler = get_cosine_schedule_with_warmup(
#             optimizer=optim,
#             num_cycles=EPOCHS,
#             num_warmup_steps=int(0.2 * NUM_TRAINING_STEPS),
#             num_training_steps=NUM_TRAINING_STEPS)

In [None]:
# Hand the loaders, model, optimizer and scheduler to the accelerator and rebind
# each name to the object it returns (same order in, same order out).
# Note: re-running this cell feeds the already-prepared objects back into prepare().
prepared = accelerator.prepare(
    train_loader, val_loader, test_loader, ee_method, optim, lr_scheduler
)
train_loader, val_loader, test_loader, ee_method, optim, lr_scheduler = prepared

# Split name -> prepared loader, in the order train, test, val.
loaders = dict(train=train_loader, test=test_loader, val=val_loader)

In [None]:
# Collect everything the trainer is constructed from; kept as a dict so the
# exact arguments can be inspected afterwards.
args_trainer = dict(
    ee_method=ee_method,
    criterion=criterion,
    optim=optim,
    accelerator=accelerator,
    lr_scheduler=lr_scheduler,
    loaders=loaders,
    device=device,
)

# Build the trainer by expanding the collected arguments as keywords.
trainer = EarlyExitTrainer(**args_trainer)

In [None]:
import collections
# Per-epoch settings bundled in a named tuple whose type name is 'RE'.
RunEpochConfig = collections.namedtuple(
    'RE', ['save_interval', 'grad_accum_steps', 'running_step_mult']
)
config_run_epoch = RunEpochConfig(
    save_interval=110000,
    grad_accum_steps=GRAD_ACCUM_STEPS,
    running_step_mult=4,
)

# Keyword arguments for the whole experiment: epoch range, experiment name,
# the per-epoch settings above and the random seed.
params_run = dict(
    epoch_start=0,
    epoch_end=EPOCHS,
    exp_name='gpee',
    config_run_epoch=config_run_epoch,
    random_seed=42,
)

# Start the experiment.
trainer.run_exp(**params_run)