In [None]:
# Reload edited project modules (e.g. under src/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Enable the %tensorboard magic used further down to monitor training logs.
%load_ext tensorboard

In [None]:
import torch

# Seed torch's global RNG before any module is constructed, so parameters that are
# randomly initialised in later cells (e.g. the early-exit wrapper) are reproducible.
# Keep in sync with config['random_seed'] passed to the trainer; run_exp presumably
# reseeds as well, but it runs only after the wrapper has already been built.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; everything below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

# Model

In [None]:
# Pretrained ResNet-34 backbone that the early-exit heads are attached to.
from src.models.resnets_eff import ResNet34

NUM_CLASSES = 10
model = ResNet34(num_classes=NUM_CLASSES).to(device)

PATH = 'models/model_2023-01-10 15:25:15.316290_step_0.pth'
# map_location remaps stored tensors onto `device`; without it, a checkpoint saved
# on a GPU cannot be loaded on a CPU-only machine (device falls back to "cpu" above).
model.load_state_dict(torch.load(PATH, map_location=device))

# Freeze every backbone parameter so that only parameters still requiring grad
# (i.e. ones the wrapper adds) are picked up by the optimizer's requires_grad filter.
# NOTE(review): requires_grad=False does not stop BatchNorm running statistics from
# updating in train mode — confirm the trainer keeps the backbone in eval() when
# is_model_frozen=True.
model.requires_grad_(False);  # trailing ';' suppresses the module repr in Jupyter

# Early Exit Wrapper

In [None]:
from src.models.methods.sdn import SDN

# One width entry per residual block; this matches the standard ResNet-34 stage
# layout of 3, 4, 6 and 3 blocks at 64, 128, 256 and 512 channels.
# Presumably one early-exit head is attached per entry — confirm in SDN.
LD = [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3

# Hyperparameters forwarded to the SDN early-exit wrapper.
config_ee = dict(
    confidence_threshold=0.02,
    is_model_frozen=True,
    act_name='relu',
    dropout_prob=0.1,
    reduction_layer_weight_std=0.02,
    device=device,
)

ee_wrapper = SDN(
    model=model,
    layers_dim=LD,
    num_classes=NUM_CLASSES,
    config_ee=config_ee,
).to(device)

# Optimizer

In [None]:
from src.common.common import OPTIMIZER_NAME_MAP, LOSS_NAME_MAP

# Optimise only parameters that still require gradients. The backbone was frozen
# above, so these should be the parameters introduced by the early-exit wrapper.
trainable_params = [p for p in ee_wrapper.parameters() if p.requires_grad]
optimizer_cls = OPTIMIZER_NAME_MAP['adamw']
optim = optimizer_cls(trainable_params, lr=0.01, weight_decay=0.001)

# No learning-rate schedule is used; the trainer receives None.
lr_scheduler = None

# Dataset

In [None]:
from torch.utils.data import DataLoader
from src.data.datasets import get_cifar10

# get_cifar10 returns three splits; the middle one is not used in this notebook.
train_dataset, _, test_dataset = get_cifar10('data/')

BATCH_SIZE = 64

# Settings shared by both loaders; only the training split is shuffled.
common_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_kwargs)

loaders = dict(train=train_loader, test=test_loader)

In [None]:
# Launch TensorBoard inline. The log dir presumably corresponds to the trainer
# config's base_path ('reports') / exp_name ('sdn') — verify if either changes.
%tensorboard --logdir=reports/sdn

# Trainer

In [None]:
from src.trainer.trainer_ee import EarlyExitTrainer

# Everything the trainer needs: the wrapped model, data loaders, optimizer and
# (optional) learning-rate scheduler.
params_trainer = dict(
    ee_wrapper=ee_wrapper,
    loaders=loaders,
    optim=optim,
    lr_scheduler=lr_scheduler,
)
trainer = EarlyExitTrainer(**params_trainer)

In [None]:
import os

from src.common.utils import AttrDict

EXP_NAME = 'sdn'

# Experiment configuration handed to EarlyExitTrainer.run_exp (wrapped in AttrDict).
# SECURITY: a logger API token used to be hardcoded here. Credentials must never be
# committed to a notebook — read it from the environment instead, and revoke the
# previously committed key. The env-var name is a placeholder; align it with the
# logger backend actually in use. With logger_name='tensorboard' the token is
# presumably unused, in which case None is harmless — verify in the trainer.
config = {
    'epoch_start_at': 0,
    'epoch_end_at': 10,
    'grad_accum_steps': 1,
    'step_multi': 50,
    'whether_clip': False,
    'clip_value': 2.0,
    'base_path': 'reports',
    'exp_name': EXP_NAME,
    'logger_name': 'tensorboard',
    'logger_config': {
        'api_token': os.environ.get('LOGGER_API_TOKEN'),
        'project': 'early_exit',
        'hyperparameters': {},
    },
    'random_seed': 42,
    'device': device,
}
config = AttrDict(config)

trainer.run_exp(config)

In [None]:
# Sanity check: list every wrapper parameter and whether it is trainable.
# (Previously referenced the undefined name `ee_method`, which raised NameError
# on a fresh kernel; the wrapper built above is `ee_wrapper`.)
for name, p in ee_wrapper.named_parameters():
    print(name, p.requires_grad)