In [1]:
import numpy as np
from fibsem_tools.io import read_xarray
from napari import Viewer
from xarray_ome_ngff import read_multiscale_group
from tifffile import imwrite
from skimage.transform import resize
import glob
from pathlib import Path
import csv
import os
import re
from tifffile import imread, imwrite

### collect 3D crop from cellmap annotated data

In [2]:
#go through all metadata and find closest res
#assuming data is saved as z, y, x
def find_closest_res_match_x_y(em_parent_path, resolution_2D):
    res_dict = {}
    em_parent_path = Path(em_parent_path)
    em_xarray = read_xarray(em_parent_path, storage_options={'anon': True})
    for dict in em_xarray.attrs["multiscales"][0]["datasets"]:
        resolution_level = dict["path"]
        scale = dict["coordinateTransformations"][0]["scale"]
        scale_2D = (scale[1], scale[2])
        res_diff = abs(np.array(resolution_2D) - scale_2D)
        res_dict[resolution_level] = res_diff
    smallest_diff = min(res_dict, key=lambda k: res_dict[k].sum())
    return smallest_diff

In [3]:
def get_em_crops_of_specific_resolution_2d(resolution, base_path, em_out_path, gt_out_path, organelle, count = 1):
    em_parent_path = f'{base_path}/em/fibsem-uint8/'
    print(em_parent_path)
    #find closest match in resolution and create desired path to em
    resolution_s = find_closest_res_match_x_y(em_parent_path, resolution)
    em_path = f'{base_path}/em/fibsem-uint8/{resolution_s}'
    
    # Read the EM data as an xarray DataArray
    em_parent = read_xarray(em_parent_path, storage_options={'anon': True})
    em_data = read_xarray(em_path, storage_options={'anon': True})
    
    # get ground truth
    path_to_gt = f'{base_path}/labels/groundtruth/'
    path = Path(path_to_gt)
    subfolders = [p for p in path.iterdir() if p.is_dir()]
    sorted_subfolders = sorted(subfolders, key=lambda x: int(re.findall(r'\d+', str(x))[-1]))
    print(path_to_gt)

    
    # iterate through gt crops
    crop_dict = {}
    for crop_path in sorted_subfolders:
        crop_number = str(crop_path).split("/")[-1]
        print(f'processing {crop_number}')
        labels_parent_path = f'{crop_path}/{organelle}/'
        labels_path = f'{crop_path}/{organelle}/{resolution_s}'
    
        # get metadata info (except if there is no gt for the organelle of choice)
        try:
            labels_parent = read_xarray(labels_parent_path, storage_options={'anon': True})
        except:
            print(f'{organelle} not in ground truth for this dataset')
            continue
        res_index = int(resolution_s.split("s")[-1])
        print(f"{res_index=}")
        crop_offset_world = np.array(labels_parent.attrs["multiscales"][0]["datasets"][res_index]['coordinateTransformations'][1]['translation'])
        em_offset_world = np.array(em_parent.attrs["multiscales"][0]["datasets"][res_index]['coordinateTransformations'][1]['translation'])
        crop_resolution = np.array(labels_parent.attrs["multiscales"][0]["datasets"][res_index]['coordinateTransformations'][0]['scale'])
        em_resolution = np.array(em_parent.attrs["multiscales"][0]["datasets"][res_index]['coordinateTransformations'][0]['scale'])
        em_resolution_str = "[" + ",".join(map(str, em_resolution)) + "]"
        # Get label data and shape
        labels_data = read_xarray(labels_path, storage_options={'anon': True})
        labels_shape = labels_data.shape  # e.g., (200, 200, 200)

        # check if there are any labels of interest in this crop
        labels = np.unique(labels_data)
        if len(list(labels)) <= 1 and np.unique(labels_data) == 0:
            print(f"{crop_number} has no nuclear annotation")
            print(count)
            continue
        print(count)
        #calculate difference in resolution between EM iage and annotations
        ratio_resolution = crop_resolution/em_resolution
            
        # relative_offset_world = crop_offset_world #- em_offset_world
        offset_voxels = np.round(crop_offset_world / em_resolution).astype(int)
        
        # starting point of the crop
        z0, y0, x0 = map(int, offset_voxels)
        #size of the crop (resolution difference taken into account)
        dz, dy, dx = map(int,(labels_shape * ratio_resolution))
        #ez, ey, ex = em_data.shape
        
        # Extract matching crop from EM starting from top left
        em_crop = em_data[z0:z0+dz, y0:y0+dy, x0:x0+dx]
    
        #if resolution is different, scale up gt data to  match em data
        if not em_crop.shape[0]*em_crop.shape[1]*em_crop.shape[2] > 32000000000:
            print("smaller than 32GB --> processing image crop")
            if (ratio_resolution != np.zeros_like(ratio_resolution)).all():
                print("different resolutions - rescaling gt data")
                resized_mask = resize(
                labels_data,
                output_shape=em_crop.shape,  # height, width
                order=0,  # nearest-neighbor for masks
                preserve_range=True,
                anti_aliasing=False
                ).astype(labels_data.dtype)
    
            else:
                print("same resolution")
        
            # make sure the mask is binary
            resized_mask[resized_mask > 0] = 1
            print(type(em_crop))
            print(em_crop.dtype)
            print(em_crop.shape)
            
            #save EM crop and annotation as tiff in seperate folders
            imwrite(f'{em_out_path}/em_image_sample{count}.tif' , em_crop)
            
            imwrite(f'{gt_out_path}/ground_truth_sample{count}.tif', resized_mask)
            crop_dict[count] = (crop_path, resolution_s, em_resolution_str)
            count += 1

    return count, crop_dict

In [6]:
#define resolution level and path to EM data
#resolution_s = 2
organelle = 'nuc'
em_out_path = '/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/em_161225_s_adapted_13_4nm/'
gt_out_path = '/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/gt_161225_s_adapted_13_4nm/'
if not os.path.exists(em_out_path):
    os.mkdir(em_out_path)
    os.mkdir(gt_out_path)
    
path_to_data = '/Users/gloof/Desktop/code/cellmap-segmentation-challenge/data/'
path = Path(path_to_data)
subfolders = [p for p in path.iterdir() if p.is_dir()]

csv_file = '/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/crop_summary_161225_s_adapted_13_4.csv'
fieldnames = ['count', 'path_to_crop', 'resolution_level', 'scale']

# Write header only if file doesn't exist
if not os.path.exists(csv_file):
    with open(csv_file, mode='w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

resolution = (13.4,13.4)   # x and y      
count = 1
for folder in subfolders:
    name = str(folder).split('/')[-1]
    base_path = f'{path_to_data}{name}/{name}.zarr/recon-1'
    count, crop_dict = get_em_crops_of_specific_resolution_2d(resolution, base_path, em_out_path, gt_out_path, organelle, count)
    with open(csv_file, mode='a', newline='') as f:
        for key in crop_dict.keys():
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writerow({'count': key, 'path_to_crop': crop_dict[key][0], "resolution_level": crop_dict[key][1], "scale": crop_dict[key][2]})

/Users/gloof/Desktop/code/cellmap-segmentation-challenge/data/jrc_hela-3/jrc_hela-3.zarr/recon-1/em/fibsem-uint8/
/Users/gloof/Desktop/code/cellmap-segmentation-challenge/data/jrc_hela-3/jrc_hela-3.zarr/recon-1/labels/groundtruth/
processing crop27
res_index=2
1
smaller than 32GB --> processing image crop
different resolutions - rescaling gt data
<class 'xarray.core.dataarray.DataArray'>
uint8
(50, 50, 50)
processing crop33
res_index=2
2
smaller than 32GB --> processing image crop
different resolutions - rescaling gt data
<class 'xarray.core.dataarray.DataArray'>
uint8
(50, 50, 50)
processing crop34
res_index=2
3
smaller than 32GB --> processing image crop
different resolutions - rescaling gt data
<class 'xarray.core.dataarray.DataArray'>
uint8
(50, 50, 50)
processing crop50
res_index=2
crop50 has no nuclear annotation
4
processing crop51
res_index=2
crop51 has no nuclear annotation
4
processing crop60
res_index=2
crop60 has no nuclear annotation
4
processing crop61
res_index=2
crop61 

### split 3D crops and annotations into 2D images for training

In [7]:
def split_into_planes(input_path):
    """
    Loads a 3D .tif and returns list of 2D planes.
    """
    volume = imread(input_path)
    if volume.ndim != 3:
        raise ValueError("The input image is not 3D (Z, Y, X).")
    return [volume[z, :, :] for z in range(volume.shape[0])]


In [8]:
count = 1

em_dir = "/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/em_161225_s_adapted_13_4nm/"
gt_dir = "/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/gt_161225_s_adapted_13_4nm/"
em_out_dir = "/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/em_2D_161225_s_adapted_13_4nm/"
gt_out_dir = "/Users/gloof/Desktop/data/cellmap_2d_training_data_nuc/gt_2D_161225_s_adapted_13_4nm/"

os.makedirs(em_out_dir, exist_ok=True)
os.makedirs(gt_out_dir, exist_ok=True)

for em_file in glob.glob(os.path.join(em_dir, "*.tif")):
    name = em_file.split("_")[-1].split(".")[0]
    print(f"Processing {name=}")

    gt_files = glob.glob(os.path.join(gt_dir, f"*{name}.tif"))
    if not gt_files:
        print(f"No GT match for {name}, skipping.")
        continue

    gt_file = gt_files[0]

    em_planes = split_into_planes(em_file)
    gt_planes = split_into_planes(gt_file)

    if len(em_planes) != len(gt_planes):
        print(f"Mismatch in planes for {name}: EM={len(em_planes)}, GT={len(gt_planes)}. Skipping.")
        continue

    for em_plane, gt_plane in zip(em_planes, gt_planes):
        if np.all(em_plane == 0) or np.all(gt_plane == 0):
            print(f"Skipping image_{count}.tif (all-zero detected)")
        else:
            em_out_path = os.path.join(em_out_dir, f"image_{count}.tif")
            gt_out_path = os.path.join(gt_out_dir, f"image_{count}.tif")
            imwrite(em_out_path, em_plane)
            imwrite(gt_out_path, gt_plane)
        count += 1

print("Done.")


Processing name='sample38'
Skipping image_1.tif (all-zero detected)
Skipping image_2.tif (all-zero detected)
Skipping image_3.tif (all-zero detected)
Skipping image_4.tif (all-zero detected)
Skipping image_5.tif (all-zero detected)
Skipping image_6.tif (all-zero detected)
Skipping image_7.tif (all-zero detected)
Skipping image_8.tif (all-zero detected)
Skipping image_9.tif (all-zero detected)
Skipping image_10.tif (all-zero detected)
Skipping image_11.tif (all-zero detected)
Skipping image_12.tif (all-zero detected)
Skipping image_13.tif (all-zero detected)
Skipping image_14.tif (all-zero detected)
Skipping image_15.tif (all-zero detected)
Skipping image_16.tif (all-zero detected)
Skipping image_17.tif (all-zero detected)
Skipping image_18.tif (all-zero detected)
Processing name='sample10'
Skipping image_387.tif (all-zero detected)
Skipping image_388.tif (all-zero detected)
Skipping image_389.tif (all-zero detected)
Skipping image_390.tif (all-zero detected)
Skipping image_391.tif (all