Calculating distance transforms in 3D is computationally intensive, so it is worth precomputing the weight maps once here rather than generating them on the fly alongside the other augmentations during training. This reduces training time.

In [1]:
import os
os.chdir("/Users/ctromans/image-analysis/UNet_3D_C_elegans/")
import pandas as pd
import skimage
import pathlib
from unet.utils.data_utils import calculate_weight_map
from unet.augmentations.augmentations import edges_and_centroids

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the training manifest; the mask path for each row is taken from its
# second column (positional access, see the comprehension below).
load_data = pd.read_csv("./patch_data/load_data_training.csv")

# Destination folder for the precomputed weight maps (created if missing).
save_dir = "./patch_data/weight_maps"
os.makedirs(save_dir, exist_ok=True)


def _precompute_weight_map(mask_path):
    """Compute one mask's weight map, save it, and return the saved path.

    The label image at *mask_path* is converted by ``edges_and_centroids``
    into a one-hot background/edges/centroids array, which is passed to
    ``calculate_weight_map`` (edge class 1, centroid class 2, ``w0=10``)
    together with the original labels. The result is written into
    ``save_dir`` with zlib compression, under the mask's file name with
    ``_weight_map`` inserted before the extension.
    """
    labels = skimage.io.imread(mask_path)
    one_hot = edges_and_centroids(labels)
    weights = calculate_weight_map(
        gt_array=one_hot,
        labels=labels,
        centroid_class_index=2,
        edge_class_index=1,
        w0=10,
    )

    # os.path.join (not pathlib's "/") keeps the "./" prefix of save_dir, so
    # the path strings recorded in the CSV stay exactly as before.
    source_name = pathlib.Path(os.path.basename(mask_path))
    out_path = os.path.join(save_dir, f"{source_name.stem}_weight_map{source_name.suffix}")
    skimage.io.imsave(out_path, weights, compression=('zlib', 1))
    return out_path


# Built in manifest order so each path lines up row-for-row with load_data.
weight_map_paths = [_precompute_weight_map(path) for path in load_data.iloc[:, 1].values]
load_data["weight_maps"] = weight_map_paths



In [4]:
# Write the manifest, now carrying the "weight_maps" column, back over the
# same CSV it was read from; the DataFrame index is not written.
manifest_path = "./patch_data/load_data_training.csv"
load_data.to_csv(path_or_buf=manifest_path, index=False)