# Setup WTS-Experiment

In [None]:
from pathlib import Path

In [None]:
# Total memory budget for all stored samples: 10 MiB, in bytes.
AVAILABLE_MEM_STORAGE = 10 * 1024 ** 2
# Bytes per stored value. BYTES_FOR_FLOAT is used for encoded samples
# (encoding block >= 1); BYTES_FOR_INT for the unencoded input (block 0).
BYTES_FOR_FLOAT = 4
BYTES_FOR_INT = 1

# Path of the generated run script, plus directories passed to src/main.py.
SH_FILE = Path('..') / 'scripts' / 'wts_core50_resnet34.sh'
LOG_DIR = '/home/marwei/Code/encodedgdumb/logs/'
DATA_DIR = '/home/marwei/pytorch/'
DATASET = 'CORe50'

# Every configuration is launched once per seed.
SEEDS = list(range(3))

Recall the output sizes of the ResNet blocks for the CIFAR dataset (32x32x3 inputs):
| cut...                    | output shape | output numel |
|---------------------------|--------------|--------------|
| original (0)              | 32x32x3      |  3072        |
| after Block 1             | 8x8x64       |  4096        |
| after Block 2             | 4x4x128      |  2048        |
| after Block 3             | 2x2x256      |  1024        |
| after Block 4             | 1x1x512      |   512        |


In [None]:
# encoding_block: output_numel
# 0: 3072 represents the case, where we skip the encoding and train the whole resnet
output_sizes = {0: 3072, 1: 4096, 2: 2048, 3: 1024}


# Task split and model names for each supported dataset:
# dataset -> (num_classes_per_task, num_tasks, backbone, encoder)
_DATASET_SETTINGS = {
    'CIFAR10': (2, 5, 'resnet', 'cutr'),
    'CIFAR100': (5, 20, 'resnet34', 'cutr34'),
    'CORe50': (2, 5, 'resnet34', 'cutr34'),
}

if DATASET not in _DATASET_SETTINGS:
    raise ValueError(f'unknown dataset: {DATASET}')
num_classes_per_task, num_tasks, backbone, encoder = _DATASET_SETTINGS[DATASET]

In [None]:
# Build one `python3 src/main.py ...` command line per (seed, encoding block).
# Block 0 stores the unencoded input at BYTES_FOR_INT bytes per value and uses
# no encoder; blocks >= 1 store encoder output at BYTES_FOR_FLOAT bytes per
# value. The budget is fixed, so cheaper samples yield more memory slots.
shell_scripts = []

for seed in SEEDS:
    for block, numel in output_sizes.items():
        is_raw_input = block == 0
        bytes_per_value = BYTES_FOR_INT if is_raw_input else BYTES_FOR_FLOAT
        memory_slots = AVAILABLE_MEM_STORAGE // (numel * bytes_per_value)
        encoder_name = 'none' if is_raw_input else encoder

        exp_name = (
            f"{DATASET}_m-{memory_slots}_{encoder_name}-{block}"
            f"_none_{backbone}-{block}_s{seed}"
        )

        cli_args = [
            "python3", "src/main.py",
            "--dataset", DATASET,
            "--num_classes_per_task", num_classes_per_task,
            "--num_tasks", num_tasks,
            "--seed", seed,
            "--memory_size", memory_slots,
            "--num_passes", 128,
            "--sampler", "greedy_sampler",
            "--encoder", encoder_name,
        ]
        # Only encoded runs get an --encoding_block flag.
        if not is_raw_input:
            cli_args += ["--encoding_block", block]
        cli_args += [
            "--compressor", "none",
            "--backbone", backbone,
            "--backbone_block", block,
            "--data_dir", DATA_DIR,
            "--log_dir", LOG_DIR,
            # The experiment name is wrapped in double quotes for the shell.
            "--exp_name", f'"{exp_name}"',
        ]
        shell_scripts.append(" ".join(map(str, cli_args)))


In [None]:
# Write the generated commands to SH_FILE, one per line, without ever
# overwriting an existing script.
# Mode 'x' (exclusive creation) raises FileExistsError atomically if the file
# already exists. The previous `assert` pre-check was racy and is stripped
# entirely when Python runs with -O.
# newline='\n' keeps Unix line endings on every platform so bash can run it.
with open(SH_FILE, 'x', encoding='utf-8', newline='\n') as fp:
    fp.write('\n'.join(shell_scripts) + '\n')