# Training

For training, we will store the relative path information for each image and its masks within a dataframe with the following structure:

|   images                                                                    |   masks                                               |   train  |
|-----------------------------------------------------------------------------|-------------------------------------------------------|----------|
|   ./data/images_stacked_channels/mx85-nd-acqusition-0-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-0-ground-truth.tif  |   TRUE   |
|   ./data/images_stacked_channels/mx85-nd-acqusition-1-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-1-ground-truth.tif  |   TRUE   |
|   ./data/images_stacked_channels/mx85-nd-acqusition-8-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-8-ground-truth.tif  |   FALSE  |

In [1]:
import os
os.chdir("../") # Root of repo
import unet.utils.data_utils as utils
from unet.utils.load_data import CElegansDataset
from unet.networks.unet3d import UNet3D
from unet.networks.unet3d import SingleConv
import sklearn.model_selection
from torch.utils.data import DataLoader
from unet.utils.loss import WeightedBCELoss, WeightedBCEDiceLoss
import unet.augmentations.augmentations as aug
import pandas as pd
import torch
from unet.utils.trainer import RunTraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set parameters
# Every data setting and model hyperparameter lives in the single `params`
# dictionary below, so an experiment can be reconfigured in one place.

# Keyword arguments for the augmentations, keyed by the augmentation class name.
_augmentation_kwargs = {
    "Normalize": {"per_channel": True},
    "RandomContrastBrightness": {"p": 0.5},
    "Flip": {"p": 0.5},
    "RandomRot90": {"p": 0.5, "channel_axis": 0},
    "RandomGuassianBlur": {"p": 0.5},  # spelling matches the class name in `aug`
    "RandomGaussianNoise": {"p": 0.5},
    "RandomPoissonNoise": {"p": 0.5},
    "ElasticDeform": {
        "sigma": 10,
        "p": 0.5,
        "channel_axis": 0,
        "mode": "mirror",
    },
    "LabelsToEdges": {"connectivity": 2, "mode": "thick"},
    "EdgeMaskWmap": {
        "edge_multiplier": 2,
        "wmap_multiplier": 1,
        "invert_wmap": True,
    },
    "BlurMasks": {"sigma": 2},
    "ToTensor": {},
}

# Data handling, optimiser, scheduler and network settings.
_run_settings = {
    "batch_size": 1,
    "epochs": 100,
    "val_split": 0.2,  # fraction of patches held out for validation
    "patch_size": (24, 200, 200),
    "create_wmap": True,
    "lr": 1e-2,
    "weight_decay": 1e-5,
    "in_channels": 2,
    "out_channels": 1,
    "scheduler_factor": 0.2,
    "scheduler_patience": 20,
    "scheduler_mode": "min",
    "loss_function": WeightedBCEDiceLoss,  # the class itself; instantiated later
    # Alternative without a weight-map target:
    # "targets": [["image"], ["mask"]]
    "targets": [["image"], ["mask"], ["weight_map"]],
}

# Merge the two groups; key order is the same as a single literal would give.
params = {**_augmentation_kwargs, **_run_settings}

In [None]:
# Training cell: builds the patch dataset, the train/val loaders, the model,
# the optimiser/scheduler and then runs the training loop.

# Only the first two rows of the source table are used here.
source_data = pd.read_csv("./data/data_stacked_channels_training.csv")[0:2]

# Create patches of the dataset
# This is performed prior to training since weight-map generation online at
# train time is computationally slow
utils.create_patch_dataset(
    source_data,
    patch_size=params["patch_size"],
    create_wmap=params["create_wmap"],
)

# Load the patch dataframe
# (presumably written to the working directory by create_patch_dataset above)
training_data = pd.read_csv("training_data.csv")

# Create the train/val split
train_dataset, val_dataset = sklearn.model_selection.train_test_split(
    training_data, test_size=params["val_split"]
)

# Define the augmentations for training and validation.
# Training uses the random augmentations; validation only applies the
# deterministic normalisation and target-generation steps.
train_transforms = [
    aug.Normalize(**params["Normalize"]),
    aug.RandomContrastBrightness(**params["RandomContrastBrightness"]),
    aug.Flip(**params["Flip"]),
    aug.RandomRot90(**params["RandomRot90"]),
    aug.RandomGuassianBlur(**params["RandomGuassianBlur"]),
    aug.RandomGaussianNoise(**params["RandomGaussianNoise"]),
    aug.RandomPoissonNoise(**params["RandomPoissonNoise"]),
    aug.ElasticDeform(**params["ElasticDeform"]),
    aug.LabelsToEdges(**params["LabelsToEdges"]),
    aug.EdgeMaskWmap(**params["EdgeMaskWmap"]),
    aug.BlurMasks(**params["BlurMasks"]),
    aug.ToTensor(),
]
val_transforms = [
    aug.Normalize(**params["Normalize"]),
    aug.LabelsToEdges(**params["LabelsToEdges"]),
    aug.EdgeMaskWmap(**params["EdgeMaskWmap"]),
    aug.BlurMasks(**params["BlurMasks"]),
    aug.ToTensor(),
]

train_ds = CElegansDataset(
    data_csv=train_dataset,
    transforms=train_transforms,
    targets=params["targets"],
    train_val="train",
)

val_ds = CElegansDataset(
    data_csv=val_dataset,
    transforms=val_transforms,
    targets=params["targets"],
    train_val="val",
)

if torch.cuda.is_available():
    # Let cuDNN benchmark and pick the fastest convolution algorithm
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pinned host memory is only useful when batches are copied to a GPU.
# `device` is a torch.device, which never compares equal to the plain string
# "cuda", so the decision is made on its `.type` attribute instead.
pin_memory = device.type == "cuda"

# Create the train and validation data loaders
train_loader = DataLoader(
    train_ds,
    batch_size=params["batch_size"],
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=1,
)

# Don't shuffle validation so you can see how predictions improve over time
val_loader = DataLoader(
    val_ds,
    batch_size=params["batch_size"],
    shuffle=False,
    pin_memory=pin_memory,
    num_workers=1,
)

data_loader = {"train": train_loader, "val": val_loader}

# Load the model
model = UNet3D(
    in_channels=params["in_channels"], out_channels=params["out_channels"], f_maps=32
)

model = utils.load_weights(
    model,
    weights_path="./data/pretrained_model/best_checkpoint.pytorch",
    device="cpu",  # Load to CPU and convert to GPU later
    dict_key="state_dict",
)

# If you are intending to change the number of input or output channels, or wish to
# freeze the UNet encoder/decoder
model = utils.set_parameter_requires_grad(
    model, trainable=True, trainable_layer_name=None  # could be eg. "encoder"
)

# Change the number of input channels by replacing the first convolution of the
# first encoder. The replacement layer is newly constructed, so it does not
# carry the pretrained weights loaded above.
model.encoders[0].basic_module.SingleConv1 = SingleConv(params["in_channels"], 16)

# Collected after the layer swap so the new layer's parameters are included
params_to_update = utils.find_parameter_requires_grad(model)

# Send the model to the device
model.to(device)

# Get the loss function, optimizer and scheduler
loss_function = params["loss_function"]()
optimizer = torch.optim.Adam(
    params_to_update, lr=params["lr"], weight_decay=params["weight_decay"]
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=params["scheduler_mode"],
    factor=params["scheduler_factor"],
    patience=params["scheduler_patience"],
)

# Instantiate the trainer class
trainer = RunTraining(
    model,
    device,
    data_loader,
    loss_function,
    optimizer,
    scheduler,
    num_epochs=params["epochs"],
)

# Run training/validation
trainer.fit()

# Inference

For running inference we rely on MONAI's `sliding_window_inference`, which enables a model that has been trained on patches to generate predictions for a larger image.

The inference dataframe should have the following structure:

|   images                                                                    |   masks                                               |   train  |
|-----------------------------------------------------------------------------|-------------------------------------------------------|----------|
|   ./data/images_stacked_channels/mx85-nd-acqusition-x-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-x-ground-truth.tif  |   TRUE   |
|   ./data/images_stacked_channels/mx85-nd-acqusition-y-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-y-ground-truth.tif  |   TRUE   |
|   ./data/images_stacked_channels/mx85-nd-acqusition-z-stacked-channels.tif  |   ./data/masks/mx85-nd-acqusition-z-ground-truth.tif  |   FALSE  |
|   ./data/images_stacked_channels/mx85-nd-acqusition-a-stacked-channels.tif  |                                                       |          |
|   ./data/images_stacked_channels/mx85-nd-acqusition-b-stacked-channels.tif  |                                                       |          |

Only images marked with `train` set to FALSE will be used for inference. If an image has an associated mask file, inference statistics will be calculated (e.g. F1-IoU at multiple thresholds).

Below, `predict_from_csv` will get the edge prediction and then perform instance segmentation. The output images will be saved into a folder named `output`. Paths to these images will also be saved in `./output/inference_data.csv`, along with performance test statistics if required.

In [None]:
# Inference cell: loads a trained checkpoint and runs prediction on every
# image listed in the inference table.

# Table of images to run inference on (layout described in the text above)
load_data_inference = pd.read_csv("data/data_test_stacked_channels.csv")

model = UNet3D(
    in_channels=2, out_channels=1, f_maps=32
)

model = utils.load_weights(
    model,
    weights_path="best_checkpoint.pytorch",
    device="cpu",  # Load to CPU and convert to GPU later
    dict_key="state_dict",
)

# Use the GPU when one is available and fall back to the CPU otherwise,
# rather than hard-coding "cuda" (which raises on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Instantiate the inferer class
# NOTE(review): `Inferer` is not imported in the visible import cell; import it
# from the project's inference module before running this cell.
infer = Inferer(
    model=model,
    patch_size=params["patch_size"],
)

# Run inference on csv images
infer.predict_from_csv(load_data_inference)