In [None]:
# IPython magics: load the autoreload extension and, in mode 2, reload imported
# modules before each cell executes, so edits to the `ttab` sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Append the parent directory to the module search path so the `ttab` package
# imported below can be resolved; assumes this notebook lives one directory
# below the repository root — TODO confirm if the notebook is moved.
sys.path.append('..')

In [None]:
import ttab.configs.utils as configs_utils
import ttab.loads.define_dataset as define_dataset
from ttab.benchmark import Benchmark
from ttab.loads.define_model import define_model, load_pretrained_model
from ttab.model_adaptation import get_model_adaptation_method
from ttab.model_selection import get_model_selection_method
from argparse import Namespace

In [None]:
# Run configuration for a RoTTA test-time-adaptation experiment on the
# PTTA_cifar10 scenario. The entries are passed (in this exact order) as
# attributes of an argparse Namespace and handed to `configs_utils.config_hparams`.
# Entries set to None / '' are presumably completed from scenario defaults by
# that helper — TODO confirm.
cfg = Namespace(
    **{
        # Job bookkeeping and filesystem locations.
        "job_name": "rotta",
        "job_id": None,
        "timestamp": None,
        "root_path": "../logs",
        "data_path": "../datasets",
        # Path to the pretrained model, which can be downloaded from Google Drive.
        "ckpt_path": "../resnet26_bn_ssh_cifar10.pth",
        # Reproducibility and hardware.
        "seed": 2022,
        "device": "cuda:0",
        "num_cpus": 2,
        # Which adaptation / selection algorithms to run, and on what task.
        "model_adaptation_method": "rotta",
        "model_selection_method": "last_iterate",
        "task": "classification",
        "test_scenario": "PTTA_cifar10",
        "base_data_name": "cifar10",
        # Optimisation / streaming settings.
        "data_wise": "batch_wise",
        "batch_size": 64,
        "lr": 0.001,
        "n_train_steps": 1,
        "offline_pre_adapt": False,
        "episodic": False,
        # Test-stream construction (domain sampling and non-i.i.d. ordering).
        "domain_sampling_name": "uniform",
        "domain_sampling_ratio": 1.0,
        "non_iid_pattern": "class_wise_over_domain",
        "non_iid_ness": 0.1,
        "label_shift_param": None,
        "data_size": None,
        # Miscellaneous switches.
        "fishers": False,
        "record_preadapted_perf": False,
        "grad_checkpoint": False,
        "debug": False,
        "data_names": "",
        "entry_of_shared_layers": None,
        "group_norm_num_groups": None,
    }
)

In [None]:
# Resolve the full run configuration and the test scenario from `cfg`.
config, scenario = configs_utils.config_hparams(config=cfg)

# --- Data ---------------------------------------------------------------------
# Build the test-dataset wrapper, then ask it for a loader for this scenario.
test_data_cls = define_dataset.ConstructTestDataset(config=config)
test_loader = test_data_cls.construct_test_loader(scenario=scenario)

# --- Base model ---------------------------------------------------------------
# Instantiate the architecture and load the checkpoint into it. The return value
# of `load_pretrained_model` is not used here, so it presumably updates `model`
# in place — TODO confirm.
model = define_model(config=config)
load_pretrained_model(config=config, model=model)

# --- Algorithms ---------------------------------------------------------------
# Look up each algorithm class by the name stored on the scenario, then
# instantiate it. Despite the `_cls` suffix, `model_adaptation_cls` and
# `model_selection_cls` hold instances; the selection method wraps the
# adaptation method.
adaptation_method_cls = get_model_adaptation_method(
    adaptation_name=scenario.model_adaptation_method
)
model_adaptation_cls = adaptation_method_cls(meta_conf=config, model=model)

selection_method_cls = get_model_selection_method(
    selection_name=scenario.model_selection_method
)
model_selection_cls = selection_method_cls(
    meta_conf=config, model_adaptation_method=model_adaptation_cls
)

# --- Evaluate -----------------------------------------------------------------
benchmark = Benchmark(
    scenario=scenario,
    model_adaptation_cls=model_adaptation_cls,
    model_selection_cls=model_selection_cls,
    test_loader=test_loader,
    meta_conf=config,
)
benchmark.eval()