In [1]:
import celldetection as cd
import numpy as np



## Define your training schedule

In [2]:
# Build the hyperparameter schedule. List values are expanded by `cd.Schedule`
# so that each Config receives a single element of the list: e.g.
# `crop_size=[(512, 512)]` yields `crop_size == (512, 512)` in the Config, and
# `image_std=[[...]]` yields the inner list. Settings whose value is itself a
# list/tuple are therefore wrapped in an extra single-element list so they reach
# the Config intact. Scalar values appear as 1-tuples in the Schedule repr.
schedule = cd.Schedule(
    base_lr=dict(loc=.0008, scale=.0001),  # NOTE(review): loc/scale look like a sampling spec — confirm in celldetection docs
    base_lr_sync=True,
    base_bs=16,
    base_lr_scale='sqrt',  # 'linear' or 'sqrt'
    batch_size=10,
    lr=None,  # set None to use dynamic lr

    # CPN
    checkpoint=None,
    in_channels=3,
    model='CpnResNeXt101UNet',
    pretrained=False,
    encoder_kwargs={},
    backbone_kwargs=dict(pyr_pool=dict(out_channels=256, scales=[1, 2, 3, 6], method='ppm')),
    sync_batch_norm=True,
    bg_fg_dists=[(0.75, 0.75)],  # wrapped: each Config receives the tuple (0.75, 0.75)
    order=25,
    nms_thresh=np.pi * .1,
    score_thresh=.8,
    certainty_thresh=None,
    samples=224,
    refinement=True,
    refinement_iterations=4,
    refinement_margin=3.,
    refinement_buckets=6,
    contour_features='1',
    refinement_features='0',
    contour_head_stride=1,
    order_weights=True,
    refinement_head_stride=1,
    refinement_interpolation='bilinear',
    uncertainty_head=True,
    uncertainty_factor=7.,
    uncertainty_nms=True,
    decoder_activation='LeakyReLU',
    decoder_group_norm=True,
    decoder_block='ResBlock',

    # Losses
    loss_fourier='l1',
    loss_regression='l1',
    loss_iou='giou',
    loss_classification='bce',

    # Data & Augmentation
    pseudo_labels=None,  # e.g. 'inputs/inputs/pseudo_labels0/*.h5'
    pseudo_labels_num=None,
    aug_plan='mild',
    dyn_resize_min=4,
    dyn_resize_max=128 + 64,
    dyn_resize_aspect_std=.05,
    data_norm_method='prov',  # 'prov', 'cstm', 'rand-mix'
    classes=1,
    sampler_seed=None,  # None to pick random seed
    sampler_seed_sync=True,

    # Training
    epochs=600,
    crop_size=[(512, 512)],  # wrapped: each Config receives the tuple (512, 512)
    image_std=[[0.229, 0.224, 0.225]],  # wrapped: each Config receives the inner list
    image_mean=[[0.485, 0.456, 0.406]],  # wrapped: each Config receives the inner list
    amp=False,
    prefetch_factor=4,  # only has effect if num_workers > 0
    pin_memory=True,
    shuffle=True,
    show_gpu_stats=True,
    show_progress=True,
    save_frequency=15,
    save_min_epoch=30,
    writer_kwargs=None,  # set e.g. dict(log_dir='tb') to enable
    neurips_reps=5,

    # Optimizer & Scheduler
    optimizer={'Adam': {'betas': (0.9, 0.999), 'weight_decay': 0.00002}},
    scheduler={'WarmupMultiStepLR': dict(warmup_factor=.001, warmup_steps=1000, warmup_method='linear', gamma=.666)},
    scheduler_on_step=True,
    # presumably the 1-tuple is unwrapped like the lists above, so each Config gets the milestone list — confirm
    scheduler_milestones_as_fractions=([0.1, 0.2, 0.3, 0.4, 0.6, 0.8],),
)
schedule

Schedule(
  (base_lr): ({'loc': 0.0008, 'scale': 0.0001},)
  (base_lr_sync): (True,)
  (base_bs): (16,)
  (base_lr_scale): ('sqrt',)
  (batch_size): (10,)
  (lr): (None,)
  (checkpoint): (None,)
  (in_channels): (3,)
  (model): ('CpnResNeXt101UNet',)
  (pretrained): (False,)
  (encoder_kwargs): ({},)
  (backbone_kwargs): ({'pyr_pool': {'out_channels': 256, 'scales': [1, 2, 3, 6], 'method': 'ppm'}},)
  (sync_batch_norm): (True,)
  (bg_fg_dists): [(0.75, 0.75)]
  (order): (25,)
  (nms_thresh): (0.3141592653589793,)
  (score_thresh): (0.8,)
  (certainty_thresh): (None,)
  (samples): (224,)
  (refinement): (True,)
  (refinement_iterations): (4,)
  (refinement_margin): (3.0,)
  (refinement_buckets): (6,)
  (contour_features): ('1',)
  (refinement_features): ('0',)
  (contour_head_stride): (1,)
  (order_weights): (True,)
  (refinement_head_stride): (1,)
  (refinement_interpolation): ('bilinear',)
  (uncertainty_head): (True,)
  (uncertainty_factor): (7.0,)
  (uncertainty_nms): (True,)
  (dec

## Save the schedule to disk

In [3]:
# Serialize the schedule via `Schedule.to_json` to a file in the current working directory.
schedule_json_path = 'schedule.json'
schedule.to_json(schedule_json_path)

## Inspect all configs defined by the schedule

In [4]:
# Show every Config generated by the schedule, each followed by a 115-character '=' rule.
for config in iter(schedule):
    print(config, end=f"\n{'=' * 115}\n")

Config(
  (amp): False
  (aug_plan): 'mild'
  (backbone_kwargs): {'pyr_pool': {'out_channels': 256, 'scales': [1, 2, 3, 6], 'method': 'ppm'}}
  (base_bs): 16
  (base_lr): {'loc': 0.0008, 'scale': 0.0001}
  (base_lr_scale): 'sqrt'
  (base_lr_sync): True
  (batch_size): 10
  (bg_fg_dists): (0.75, 0.75)
  (certainty_thresh): None
  (checkpoint): None
  (classes): 1
  (contour_features): '1'
  (contour_head_stride): 1
  (crop_size): (512, 512)
  (data_norm_method): 'prov'
  (decoder_activation): 'LeakyReLU'
  (decoder_block): 'ResBlock'
  (decoder_group_norm): True
  (dyn_resize_aspect_std): 0.05
  (dyn_resize_max): 192
  (dyn_resize_min): 4
  (encoder_kwargs): {}
  (epochs): 600
  (image_mean): [0.485, 0.456, 0.406]
  (image_std): [0.229, 0.224, 0.225]
  (in_channels): 3
  (loss_classification): 'bce'
  (loss_fourier): 'l1'
  (loss_iou): 'giou'
  (loss_regression): 'l1'
  (lr): None
  (model): 'CpnResNeXt101UNet'
  (neurips_reps): 5
  (nms_thresh): 0.3141592653589793
  (optimizer): {'Adam