In [None]:
import torch
import os
import FISHPainter
import numpy as np
from pathlib import Path 
from FISHPainter.src.preprocess import get_cell_background
from FISHPainter.src.signals import create_FISH
from FISHPainter.src.utils import create_dataset
from CellPatchExtraction import extract_patches

#debug
from cellplot.patches import gridPlot, draw_boxes_on_patch

# Show the FISHPainter version as the cell output, for provenance.
FISHPainter.__version__

In [None]:
# Load the test image with normalize=False and keep at most the first 700
# entries along each of the first two axes for a quick demo.
background = get_cell_background("./testdata/IF_RGB.TIFF", normalize=False)[
    slice(None, 700), slice(None, 700)
]
print(background.shape)

In [None]:
# Pick the compute device. The notebook was written for "cuda:3"; fall back to
# the first GPU or to the CPU so the cell does not fail on machines that expose
# fewer than four CUDA devices.
# NOTE(review): assumes extract_patches accepts any torch.device — TODO confirm.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# return_all=True yields a dict; its keys are displayed as the cell output.
return_dict = extract_patches(background, "CP_TU", patch_size=128, return_all=True, device=device)
return_dict.keys()

In [None]:
# Pull the image patches and their matching mask patches out of the result dict.
patches, masks = (return_dict[key] for key in ("image_patches", "mask_patches"))

# Display how many entries each collection holds.
tuple(map(len, (patches, masks)))

In [None]:
# Paint synthetic FISH signals onto every patch and collect, per patch, both
# the full create_FISH result and a preview image with the boxes drawn on it.
patches_with_boxes, bbox_dataset = [], []
for cell_patch, cell_mask in zip(patches, masks):
    # Random draws happen in this order: green signal count, then cluster count.
    green_count = np.random.randint(2, 8, 1)[0]  # integer in [2, 8)
    cluster_count = np.random.choice([0, 1, 2, 3], p=[0.7, 0.15, 0.1, 0.05])

    # Copies are passed so create_FISH never receives the original arrays.
    fish_dict = create_FISH(
        cell_patch.copy(),
        cell_mask.copy(),
        num_red=2,
        num_green=green_count,
        num_green_cluster=cluster_count,
        green_cluster_size=4,
        return_as_dict=True,
    )
    bbox_dataset.append(fish_dict)

    # Draw on a copy of the generated patch so the stored patch stays untouched.
    patches_with_boxes.append(
        draw_boxes_on_patch(fish_dict["patch"].copy(), fish_dict["bboxes"], fish_dict["labels"])
    )

gridPlot(patches_with_boxes, grid_size=(8, 8), plot_size=(12, 12), hspace=0.05, vspace=0.05)

In [None]:
# Create the whole dataset: load the image again with normalize=False and keep
# at most the first 2000 entries along each of the first two axes.
# Note that this rebinds `background` from the earlier, smaller crop.
background = get_cell_background("./testdata/IF_RGB.TIFF", normalize=False)[
    slice(None, 2000), slice(None, 2000)
]

In [None]:
# Pick the compute device. The notebook was written for "cuda:3"; fall back to
# the first GPU or to the CPU so the cell does not fail on machines that expose
# fewer than four CUDA devices.
# NOTE(review): assumes extract_patches accepts any torch.device — TODO confirm.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# return_all=True yields a dict; its keys are displayed as the cell output.
return_dict = extract_patches(background, "CP_TU", patch_size=128, return_all=True, device=device)
return_dict.keys()

In [None]:
# Image patches and mask patches extracted from the larger crop.
patches, masks = return_dict["image_patches"], return_dict["mask_patches"]

# Display the size of both collections.
(len(patches), len(masks))

In [101]:
from tqdm import tqdm

def create_fish_data(patches, masks, total_iterations):
    """Paint randomised synthetic FISH signals onto patch/mask pairs.

    Pairs are consumed in order from ``zip(patches, masks)``. For every pair
    the following values are sampled with ``np.random`` and passed to
    ``create_FISH`` together with copies of the patch and the mask:

    * ``num_green``: integer in ``[2, 10)``
    * ``num_green_cluster``: 0, 1, 2 or 3 with probabilities
      0.5 / 0.25 / 0.15 / 0.1
    * ``signal_size``: float drawn uniformly from ``[1.5, 2.5)``

    ``num_red`` is always 2.

    Parameters
    ----------
    patches, masks : iterable
        Paired inputs; every element must provide a ``copy()`` method.
    total_iterations : int or None
        Maximum number of samples to create. ``None`` means "no limit"
        (all pairs are processed); values ``<= 0`` produce an empty result.

    Returns
    -------
    dict
        ``{"patches": [...], "labels": [...], "bboxes": [...]}`` holding the
        ``"patch"``, ``"labels"`` and ``"bboxes"`` entries of every
        ``create_FISH`` result, in input order. Contains fewer than
        ``total_iterations`` samples when the inputs run out first.
    """
    created = {
        "patches": [],
        "labels": [],
        "bboxes": []
    }

    unlimited = total_iterations is None
    if not unlimited and total_iterations <= 0:
        # Nothing requested: do not touch the inputs or open a progress bar.
        return created

    with tqdm(total=total_iterations) as pbar:
        for patch, mask in zip(patches, masks):
            # Keep the draw order (count, clusters, size) unchanged.
            n_green = np.random.randint(2, 10, 1)[0]
            n_green_cluster = np.random.choice([0, 1, 2, 3], p=[0.5, 0.25, 0.15, 0.1])
            signal_size = np.random.uniform(1.5, 2.5)
            FISH_dict = create_FISH(patch.copy(), mask.copy(), num_red=2, num_green=n_green, num_green_cluster=n_green_cluster, signal_size=signal_size, return_as_dict=True)
            created["patches"].append(FISH_dict["patch"])
            created["labels"].append(FISH_dict["labels"])
            created["bboxes"].append(FISH_dict["bboxes"])

            # Count generated samples, not dict keys (len(created) is always 3).
            n_done = len(created["patches"])
            pbar.update(1)
            pbar.set_description(f"Processing {n_done}/{total_iterations}")

            # ">=" instead of "==" so non-integer limits terminate as well.
            if not unlimited and n_done >= total_iterations:
                break

    return created

# Generate up to 250 synthetic samples from the extracted patch/mask pairs.
created = create_fish_data(patches=patches, masks=masks, total_iterations=250)

Processing 3/250: 100%|██████████| 250/250 [00:09<00:00, 25.70it/s]


In [102]:
# NOTE(review): absolute, user-specific output location; consider making this
# configurable before sharing the notebook.
out_dir = Path("/home/simon_g/src/FISH-Painter_package/created")

# Create the directory and any missing parents in a single call. exist_ok=True
# avoids the check-then-create race of os.path.exists() followed by os.makedirs().
out_dir.mkdir(parents=True, exist_ok=True)

# Write the generated samples to an HDF5 file inside the output directory.
create_dataset(dataset=created, filepath=out_dir.joinpath("created_dataset.h5"))

dict_keys(['patches', 'labels', 'bboxes'])


In [103]:
#how to read the dataset
import h5py

# Read every stored sample back. Groups are looked up by the names "0", "1", ...
# up to the number of top-level keys in the file.
patches, labels, bboxes = [], [], []
with h5py.File(out_dir.joinpath("created_dataset.h5"), 'r') as f:
    n_groups = len(f.keys())
    print(n_groups)

    for idx in range(n_groups):
        sample = f[str(idx)]
        # Indexing with [()] reads the whole dataset out of the file.
        patches.append(sample["patches"][()])
        labels.append(sample["labels"][()])
        bboxes.append(sample["bboxes"][()])


250


In [104]:
# Sanity check: all three lists read from the file should have the same length.
tuple(len(seq) for seq in (patches, labels, bboxes))

(250, 250, 250)