In [None]:
import neurite_sandbox as nes
import numpy as np
from ionpy.experiment.util import fix_seed

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml synth_gen_cfg 

gen_opts:
    seed: 42
    # Controls the number of images to generate.
    num_to_gen: 6 
    # Controls the spatial shape of each generated image (passed as in_shape).
    img_res: 
        - 256 
        - 256
    # Controls the number of "regions" in the image, each is a label.
    num_labels_range: 
        - 10 
        - 10 
    # Controls the relative size of the regions in the image. Big = Bigger regions
    shapes_im_scales: 
        - 50 
    # TODO: document the effect; forwarded unchanged to perlin_nshot_task as shapes_warp_scales.
    shapes_warp_scales: 
        - 16 
        - 32 
        - 64 
        - 128
    warp_res: 
        - 8 
        - 16 
        - 32 
        - 64

augmentations:
    shapes_im_max_std_range: 
        - 2 
        - 2 
    shapes_warp_max_std_range: 
        - 8.0 
        - 8.0
    std_min_range: 
        - 0.1 
        - 0.1
    std_max_range: 
        - 0.2
        - 0.2 
    lab_int_interimage_std_range: 
        - 0.1
        - 0.1
    warp_std_range:
        - 8 
        - 8 
    bias_res_range: 
        - 32 
        - 32 
    bias_std_range: 
        - 0.1 
        - 0.1
    blur_std_range: 
        - 0.5 
        - 0.5

In [None]:
def _sample_range(bounds, integer=False):
    """Pick one value from a two-element ``[low, high]`` range.

    Args:
        bounds: Sequence of exactly two numbers, ``[low, high]``.
        integer: When True, draw an integer with ``high`` included;
            otherwise draw a float with ``np.random.uniform``.

    Returns:
        ``low`` itself (unchanged type) when both bounds are equal, otherwise
        a freshly sampled value.

    Raises:
        ValueError: If ``bounds`` does not unpack into exactly two items.
    """
    low, high = bounds
    if low == high:
        # Degenerate range: hand back the configured value as-is and leave the
        # global NumPy random state untouched.
        return low
    if integer:
        # np.random.randint excludes `high`; add one so the configured upper
        # bound can actually be drawn.
        return int(np.random.randint(low=low, high=high + 1))
    return np.random.uniform(low, high)


def perlin_generation(
    synth_cfg: dict
):
    """Generate synthetic images and label maps from a config dictionary.

    Seeds the random state via ``fix_seed``, resolves every ``*_range`` entry
    of the config into a single value, and forwards everything to
    ``nes.tf.utils.synth.perlin_nshot_task``.

    Args:
        synth_cfg: Dictionary with two sub-dictionaries:
            ``'gen_opts'`` providing ``seed``, ``num_to_gen``, ``img_res``,
            ``num_labels_range``, ``shapes_im_scales``, ``shapes_warp_scales``
            and ``warp_res``; and ``'augmentations'`` providing one
            ``<name>_range`` pair for each sampled augmentation parameter.

    Returns:
        The three values returned by ``perlin_nshot_task``: the images, the
        label maps, and a third output that is passed through untouched.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If a ``*_range`` entry does not hold exactly two values.
    """
    gen_opts_cfg = synth_cfg['gen_opts']
    aug_cfg = synth_cfg['augmentations']

    fix_seed(gen_opts_cfg["seed"])

    # Gen parameters: the label count is an integer, upper bound inclusive.
    num_labels = _sample_range(gen_opts_cfg['num_labels_range'], integer=True)

    # Augmentation parameters. Each name maps to the config key
    # '<name>_range' and to the keyword argument '<name>'. The tuple order
    # fixes the order of random draws, so results stay reproducible per seed.
    # NOTE(review): bias_res is drawn as a float when its bounds differ —
    # confirm perlin_nshot_task accepts a non-integer resolution.
    aug_param_names = (
        'shapes_im_max_std',
        'shapes_warp_max_std',
        'std_min',
        'std_max',
        'lab_int_interimage_std',
        'warp_std',
        'bias_res',
        'bias_std',
        'blur_std',
    )
    sampled_augs = {
        name: _sample_range(aug_cfg[f'{name}_range'])
        for name in aug_param_names
    }

    # Gen tasks
    images, label_maps, extra_output = nes.tf.utils.synth.perlin_nshot_task(
        in_shape=gen_opts_cfg['img_res'],
        num_gen=gen_opts_cfg['num_to_gen'],
        num_label=num_labels,
        shapes_im_scales=gen_opts_cfg['shapes_im_scales'],
        shapes_warp_scales=gen_opts_cfg['shapes_warp_scales'],
        min_int=0,
        max_int=1,
        warp_res=gen_opts_cfg['warp_res'],
        **sampled_augs,
    )

    return images, label_maps, extra_output

In [None]:
# Run the generator with the config captured by the `%%yaml synth_gen_cfg` cell.
# `lab_og` holds the third value returned by perlin_generation; it is not used
# by the stacking or plotting cells that follow.
images, label_maps, lab_og = perlin_generation(synth_gen_cfg)

In [None]:
# Stack the per-sample outputs along a new leading axis.
img_tensor = np.stack(images)
# Reduce the last axis of the stacked label maps to the index of its maximum.
# NOTE(review): presumably that axis holds one channel per label — confirm.
lab_tensor = np.argmax(np.stack(label_maps), axis=-1)

In [None]:
import matplotlib.pyplot as plt

# Take the column count from the data instead of hard-coding 6, so the figure
# follows `num_to_gen` in the config without an IndexError or dropped samples.
num_shown = len(img_tensor)

# Four inches per column reproduces the original (24, 8) size for six samples.
# squeeze=False keeps `axarr` two-dimensional even when only one sample exists.
f, axarr = plt.subplots(2, num_shown, figsize=(4 * num_shown, 8), squeeze=False)
for idx in range(num_shown):
    # Top row: generated image. Bottom row: its label map.
    im = axarr[0, idx].imshow(img_tensor[idx], cmap='gray', interpolation='none')
    lab = axarr[1, idx].imshow(lab_tensor[idx], cmap='tab10', interpolation='none')
    f.colorbar(im, ax=axarr[0, idx], shrink=0.6)
    f.colorbar(lab, ax=axarr[1, idx], shrink=0.6)
    # Turn off axis lines and labels.
    axarr[0, idx].axis('off')
    axarr[1, idx].axis('off')
plt.show()