[https://walkwithfastai.com/Segmentation]

In [1]:
import torch
from fastai.data.all import *
from fastai.vision.all import *
from patchify import patchify
from PIL import Image
import optuna
from optuna.integration import FastAIPruningCallback
import optuna.visualization as vs
import monai.losses as mdlss #(med loss)

# NON-OPTIMIZED HYPERPARAMS (fixed by hand because my GPU can't handle searching over them :P)
# Batch size: passed as `bs` when the segmentation DataLoaders are built.
BATCH_SZ = 8 # has to be small since it's a unet
# Patch size: side length, in pixels, of the square patches cut from each full image.
PATCH_SZ = 256 # the image height/width must be a multiple of this (patches are cut on a non-overlapping grid). This is appropriate for BATCH_SZ.

# Set these low for testing the hyper-optimizer setup. May become tuned hyperparams later.
FREEZE_EPOCHS = 1 # passed to `learn.fine_tune` as `freeze_epochs`
EPOCHS = 2 # passed to `learn.fine_tune` as `epochs`
NUM_TRIALS = 2 # passed to `study.optimize` as `n_trials`



### Pre-pipeline processing with patchify

In [2]:
# simpler implementation, more readable if perhaps less dry...
def extract_patches(full_arr, sz=PATCH_SZ):
    """
    Cut `full_arr` into a grid of `sz` x `sz` patches using patchify.

    The patchify step equals the window size, so consecutive windows do not
    overlap. Arrays with fewer than three dimensions (masks, no channel axis)
    use a (sz, sz) window; anything else uses a (sz, sz, channels) window,
    where `channels` is the size of the last axis.

    Returns a flat list of patches ordered row by row (first grid index
    outermost). Each entry is `patch_arr[i, j]`, i.e. it keeps whatever axes
    patchify places after the two grid axes -- callers that need a plain
    (sz, sz[, channels]) array may have to squeeze the result.
    """
    # Only the window shape depends on whether a channel axis is present.
    if len(full_arr.shape) < 3:
        window = (sz, sz)
    else:
        window = (sz, sz, full_arr.shape[-1])

    grid = patchify(full_arr, window, sz)
    n_rows, n_cols = grid.shape[0], grid.shape[1]

    # Indexing only the two grid axes returns all remaining axes untouched.
    return [grid[row, col] for row in range(n_rows) for col in range(n_cols)]

In [3]:
#TODO - screw around calculating class imbalance

def get_all_patches(path):
    """
    Load every image/target pair under `path` and cut each one into patches.

    Parameters
    ----------
    path : Path
        Data directory containing an "images" and a "targets" sub-directory.
        Files are paired purely by their sorted order, so the i-th image must
        correspond to the i-th target.

    Returns
    -------
    list of tuple
        Tuples of form (img_patch, msk_patch), each patch squeezed to drop
        any singleton axes left over from patch extraction.

    Raises
    ------
    ValueError
        If the two sub-directories hold a different number of files, or if
        the images and targets yield a different number of patches. Without
        this check `zip` would silently truncate and every later pair could
        be misaligned.
    """

    def get_arrays(folder):
        # This NEEDS to be sorted: images and targets are matched by position only.
        files = sorted(folder.glob("*"))
        return [np.array(Image.open(file)) for file in files]

    img_arrs = get_arrays(path/"images")
    msk_arrs = get_arrays(path/"targets")

    if len(img_arrs) != len(msk_arrs):
        raise ValueError(
            f"Found {len(img_arrs)} images but {len(msk_arrs)} targets under {path}; "
            "they are paired by sorted order and must match one-to-one."
        )

    # Extract, flatten and squeeze in a single pass per collection.
    img_patches = [patch.squeeze() for arr in img_arrs for patch in extract_patches(arr)]
    msk_patches = [patch.squeeze() for arr in msk_arrs for patch in extract_patches(arr)]

    if len(img_patches) != len(msk_patches):
        raise ValueError(
            f"Extracted {len(img_patches)} image patches but {len(msk_patches)} target patches; "
            "each image and its target must have the same height and width."
        )

    return list(zip(img_patches, msk_patches))

# don't optimize prematurely! >:3

def save_patches(patches, output_dir):
    """
    Write every (image, mask) patch pair to disk as PNG files.

    The pair at position `k` is saved as `<output_dir>/images/k.png` and
    `<output_dir>/targets/k.png`. Both sub-directories must already exist.
    A single-line progress message is rewritten in place while saving.
    """
    n_pairs = len(patches)
    for idx, pair in enumerate(patches):
        image_patch, mask_patch = pair
        Image.fromarray(image_patch).save(output_dir/"images"/f"{idx}.png")
        Image.fromarray(mask_patch).save(output_dir/"targets"/f"{idx}.png")
        # Carriage return keeps the progress counter on one console line.
        print(f"Saved tuple {idx}/{n_pairs}", end="\r", flush=True)
    print(end="\r", flush=True)


In [4]:

# Locations of the raw data and of the patch cache (one cache folder per patch size).
base_path = Path("../data/")
data_dir = base_path/"full"/"post-disaster"
patch_dir = base_path/f"{PATCH_SZ}_patches"

patch_img_dir = patch_dir/"images"
patch_msk_dir = patch_dir/"targets"

# Treat the cache as usable only when both sub-folders exist AND contain files.
# Checking for the top-level folder alone would skip extraction forever after a
# run that created the folders and then crashed before saving anything.
# NOTE: a run that died part-way through saving is still treated as cached;
# delete `patch_dir` by hand to force a clean re-extraction.
patches_cached = all(
    folder.is_dir() and any(folder.iterdir())
    for folder in (patch_img_dir, patch_msk_dir)
)

if not patches_cached:
    # make necessary directories (tolerate missing parents and leftovers of a failed run)
    patch_img_dir.mkdir(parents=True, exist_ok=True)
    patch_msk_dir.mkdir(parents=True, exist_ok=True)

    print("extracting patches...")
    patches = get_all_patches(data_dir)
    print("saving patches...")
    save_patches(patches, patch_dir)
    print("all patches saved!")
else:
    print("patches already extracted! skipping.")


patches already extracted! skipping.


### Make Optimizer

In [1]:
def objective(trial):
    """
    Optuna objective: train a U-Net on the patch dataset and return its score.

    Hyperparameters sampled from `trial`:
      * "loss_fn"    -- name of the loss ("cross_entropy" or "dice")
      * "pretrained" -- whether the resnet18 encoder starts from pretrained weights

    Relies on the module-level `patch_dir`, `BATCH_SZ`, `EPOCHS` and
    `FREEZE_EPOCHS`.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample hyperparameters; also handed to the pruning
        callback, which monitors "dice_multi".

    Returns
    -------
    float
        Value of the learner's first (and only) metric, `DiceMulti`, read from
        the recorder once training has finished.
    """

    # things the optimizer does...
    # Categorical choices are given as plain names and mapped to loss objects here,
    # so the sampled parameter is a simple string rather than an arbitrary object.
    loss_fn_dict = {
        "cross_entropy": CrossEntropyLossFlat(axis=1),
        "dice": DiceLoss(),
    }
    loss_fn = trial.suggest_categorical("loss_fn", list(loss_fn_dict))
    # Choices must be passed as ONE sequence, not as separate positional arguments.
    pretrained = trial.suggest_categorical("pretrained", [True, False])

    path = patch_dir
    codes = ["Background", "NoDamage", "MinorDamage", "MajorDamage", "Destroyed"]

    dls = SegmentationDataLoaders.from_label_func(
        path,
        bs=BATCH_SZ,
        fnames=get_image_files(path/"images"),
        # An image and its target share the same file name in sibling folders.
        label_func=lambda o: path/"targets"/o.name,
        codes=codes,
        # TODO: batch_tfms=[*aug_transforms(size=(360,480)), Normalize.from_stats(*imagenet_stats)]
    )

    learn = unet_learner(
        dls,
        resnet18,
        metrics=DiceMulti(axis=1),
        self_attention=True,
        act_cls=Mish,
        loss_func=loss_fn_dict[loss_fn],
        pretrained=pretrained,  # use the sampled value (was hard-coded to True)
        n_out=len(codes),  # set codes implicitly later
    )

    model_cbs = [
        # Candidates left disabled by the author:
        #   EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2)  # detect overfitting
        #   EarlyStoppingCallback(monitor='train_loss', min_delta=0.1, patience=3)  # detect stalled training
        #   ActivationStats(with_hist=True)                                         # too slow
        FastAIPruningCallback(trial, monitor="dice_multi"),
        # Set monitor to `train_loss` to purposely overfit?
        # ! Optimizer may lose information on overfitting -- make sure to log everything.
        # TODO: try fp16.
    ]

    # Take the first suggestion returned by the learning-rate finder.
    suggested_lrs = learn.lr_find()
    lr = suggested_lrs[0]

    # See https://forums.fast.ai/t/how-to-diable-progress-bar-completely/65249/3
    # to disable progress bar and logging info.
    with learn.no_bar():
        with learn.no_logging():
            learn.fine_tune(
                epochs=EPOCHS,
                base_lr=lr,
                freeze_epochs=FREEZE_EPOCHS,
                cbs=model_cbs,
            )

    return learn.recorder.metrics[0].value.item() # only one metric to worry about

### Optimize

In [None]:
# Create a study that maximizes the value returned by `objective`.
study = optuna.create_study(direction="maximize") # use default pruner
# Run `objective` for up to NUM_TRIALS trials; `timeout=600` is handed to Optuna as the
# overall time limit (presumably seconds -- confirm against the Optuna docs).
study.optimize(objective, n_trials=NUM_TRIALS, timeout=600)

In [None]:
# Summarize the study: number of trials, then the best trial's value and sampled parameters.

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial # NOTE: binds the module-level name `trial` to the best trial

print("  Value: {}".format(trial.value))

print("  Params: ")
# One line per sampled hyperparameter (name: value).
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))