In [None]:
import sys
# set CUDA_LAUNCH_BLOCKING=1
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import numpy as np
from skimage.measure import label, regionprops
from skimage.transform import rotate
import matplotlib.pyplot as plt

def extract_objects(mask):
    """Extract individual objects from a segmentation mask.

    Each connected component found by ``skimage.measure.label`` is returned as a
    crop of its bounding box. Pixels of *other* components that happen to fall
    inside that bounding box are zeroed, so every crop holds exactly one object.

    Args:
        mask: 2-D foreground mask (non-zero = foreground).

    Returns:
        List of arrays, one per component, in ``regionprops`` order, with the
        dtype of ``mask``.
    """
    crops = []
    for region in regionprops(label(mask)):
        minr, minc, maxr, maxc = region.bbox
        # region.image is True only on this component's pixels inside its bbox;
        # without it, a neighbouring object overlapping the box would leak in.
        crops.append(mask[minr:maxr, minc:maxc] * region.image)
    return crops

def rotate_object(obj, angle):
    """Rotate a binary object by an arbitrary angle, enlarging the canvas to fit.

    Uses ``skimage.transform.rotate`` with ``resize=True`` (output grows so nothing
    is cut off) and nearest-neighbour interpolation (``order=0``), then casts the
    result back to bool.

    Args:
        obj: 2-D object mask.
        angle: rotation angle in degrees.

    Returns:
        Boolean array holding the rotated object.
    """
    # Import locally: a later cell rebinds the global name ``rotate`` to a
    # kornia-based helper, which would otherwise silently replace this call.
    from skimage.transform import rotate as sk_rotate

    return sk_rotate(obj, angle, resize=True, order=0,
                     preserve_range=True).astype(bool)

def random_placement(canvas, objects):
    """Paste each object at a uniformly random position on ``canvas`` (in place).

    Objects are OR-ed into the canvas, so overlaps merge. An object larger than
    the canvas in either dimension is skipped silently.

    Args:
        canvas: 2-D boolean array that is modified in place.
        objects: iterable of 2-D boolean arrays.

    Returns:
        The same ``canvas`` object.
    """
    canvas_h, canvas_w = canvas.shape
    for piece in objects:
        piece_h, piece_w = piece.shape
        slack_x = canvas_w - piece_w
        if slack_x < 0:
            continue
        slack_y = canvas_h - piece_h
        if slack_y < 0:
            continue
        # Column offset is drawn before row offset.
        col = np.random.randint(0, slack_x + 1)
        row = np.random.randint(0, slack_y + 1)
        canvas[row:row + piece_h, col:col + piece_w] |= piece
    return canvas

# Example usage: a synthetic mask with one square and one disc.
mask = np.zeros((128, 128), dtype=bool)
mask[10:31, 10:31] = True  # Square object
rr, cc = np.ogrid[:128, :128]
mask[(rr - 80) ** 2 + (cc - 80) ** 2 <= 15 ** 2] = True  # Circular object (radius 15)

# Processing pipeline: split into objects -> rotate each randomly -> scatter on an empty canvas.
objects = extract_objects(mask)
rotated = [
    rotate_object(piece, np.random.uniform(0, 360))
    for piece in objects
]
result = random_placement(np.zeros_like(mask), rotated)

# Visualization: input mask next to the re-composed frame.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(mask, cmap='gray')
ax1.set_title('Original Mask')
ax2.imshow(result, cmap='gray')
ax2.set_title('Final Composition')
plt.show()


In [None]:
# import imageio.v3 as imageio
# import napari

# from micro_sam import instance_segmentation, util
# from micro_sam.multi_dimensional_segmentation import automatic_3d_segmentation
# from disentangle.core.tiff_reader import load_tiff

# def cell_segmentation():
#     """Run the instance segmentation functionality from micro_sam for segmentation of
#     HeLA cells. You need to run examples/annotator_2d.py:hela_2d_annotator once before
#     running this script so that all required data is downloaded and pre-computed.
#     """
#     image_path = "/home/ashesh.ashesh/code/Disentangle/disentangle/notebooks/test_img.tiff"
#     embedding_path = "../embeddings/embeddings-hela2d.zarr"

#     # Load the image, the SAM Model, and the pre-computed embeddings.
#     image = load_tiff(image_path)
#     predictor = util.get_sam_model()
#     embeddings = util.precompute_image_embeddings(predictor, image, save_path=embedding_path)

#     # Use the instance segmentation logic of Segment Anything.
#     # This works by covering the image with a grid of points, getting the masks for all the points
#     # and only keeping the plausible ones (according to the model predictions).
#     # While the functionality here does the same as the implementation from Segment Anything,
#     # we enable changing the hyperparameters, e.g. 'pred_iou_thresh', without recomputing masks and embeddings,
#     # to support (interactive) evaluation of different hyperparameters.

#     # Create the automatic mask generator class.
#     amg = instance_segmentation.AutomaticMaskGenerator(predictor)

#     # Initialize the mask generator with the image and the pre-computed embeddings.
#     amg.initialize(image, embeddings, verbose=True)

#     # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
#     # without having to call initialize again.
#     instances = amg.generate(pred_iou_thresh=0.88)
#     # instances = instance_segmentation.mask_data_to_segmentation(
#     #     instances, with_background=True
#     # )

#     # # instances = instance_segmentation.mask_data_to_segmentation(
#     # #     instances, shape=image.shape, with_background=True
#     # # )

#     # # Show the results.
#     # v = napari.Viewer()
#     # v.add_image(image)
#     # v.add_labels(instances)
#     # napari.run()
#     return image, instances


# def cell_segmentation_with_tiling():
#     """Run the instance segmentation functionality from micro_sam for segmentation of
#     cells in a large image. You need to run examples/annotator_2d.py:wholeslide_annotator once before
#     running this script so that all required data is downloaded and pre-computed.
#     """
#     image_path = "../data/whole-slide-example-image.tif"
#     embedding_path = "../embeddings/whole-slide-embeddings.zarr"

#     # Load the image, the SAM Model, and the pre-computed embeddings.
#     image = imageio.imread(image_path)
#     predictor = util.get_sam_model()
#     embeddings = util.precompute_image_embeddings(
#         predictor, image, save_path=embedding_path, tile_shape=(1024, 1024), halo=(256, 256)
#     )

#     # Use the instance segmentation logic of Segment Anything.
#     # This works by covering the image with a grid of points, getting the masks for all the points
#     # and only keeping the plausible ones (according to the model predictions).
#     # The functionality here is similar to the instance segmentation in Segment Anything,
#     # but uses the pre-computed tiled embeddings.

#     # Create the automatic mask generator class.
#     amg = instance_segmentation.TiledAutomaticMaskGenerator(predictor)

#     # Initialize the mask generator with the image and the pre-computed embeddings.
#     amg.initialize(image, embeddings, verbose=True)

#     # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
#     # without having to call initialize again.
#     instances = amg.generate(pred_iou_thresh=0.88)
#     instances = instance_segmentation.mask_data_to_segmentation(
#         instances, shape=image.shape, with_background=True
#     )

#     # Show the results.
#     v = napari.Viewer()
#     v.add_image(image)
#     v.add_labels(instances)
#     v.add_labels(instances)
#     napari.run()


In [None]:
import numpy as np
# import matplotlib.pyplot as plt
from cellpose import models, io
from cellpose.io import imread

io.logger_setup()

# Pretrained Cellpose model; other built-in choices: 'cyto', 'cyto2', 'cyto3'.
model = models.Cellpose(model_type='nuclei', gpu=True)

# Images to segment -- PUT PATH TO YOUR FILES HERE!
files = ['/group/jug/ashesh/data/sox2_one_img.tiff']
imgs = list(map(imread, files))
nimg = len(imgs)

# Channels to run segmentation on, one [cytoplasm, nucleus] pair per image,
# with grayscale=0, R=1, G=2, B=3. Use 0 as the second entry when there is no
# separate nucleus channel. If all images share a layout, a single pair works:
#   [0, 0] grayscale; [2, 3] G=cytoplasm, B=nucleus; [2, 1] G=cytoplasm, R=nucleus
channels = [[0, 0]]

# diameter=None -> the cell size is estimated per image. Supplying the average
# cell diameter in pixels (one number, or a list per image) is recommended.
masks, flows, styles, diams = model.eval(imgs, diameter=None, channels=channels)

### To use another or a custom model, build a CellposeModel instead:
# model = models.CellposeModel(model_type='livecell_cp3')
# masks, flows, styles = model.eval(imgs, diameter=30, channels=[0,0])

In [None]:
import matplotlib.pyplot as plt
# Cellpose label mask (left) next to the raw image (right).
_, ax = plt.subplots(ncols=2, figsize=(10, 5))
mask_panel, image_panel = ax
mask_panel.imshow(masks[0])
image_panel.imshow(imgs[0])

In [None]:
# Hand-picked (row slice, column slice) regions of the first image used as
# background texture; each is copied so later edits cannot touch imgs[0].
BACKGROUND_PATCHES = [
    imgs[0][rows, cols].copy()
    for rows, cols in (
        (slice(None, 200), slice(None, 500)),
        (slice(200, 400), slice(None, 400)),
        (slice(100, 300), slice(None, 400)),
        (slice(400, 600), slice(None, 300)),
    )
]

In [None]:
def find_objects(mask, img):
    """Crop every labelled object out of ``img``.

    Args:
        mask: 2-D integer label image; 0 is background, every other value
            marks one object.
        img: 2-D image with the same shape as ``mask``.

    Returns:
        List ordered by label value. Each entry is an array of shape ``(2, h, w)``,
        where ``h x w`` is the object's (inclusive) bounding box. Channel 0 is
        the image crop and channel 1 is the object's own mask inside that box.
        The dtype follows numpy promotion of ``img`` with bool.
    """
    objects = []
    for label_id in np.unique(mask):
        if label_id == 0:  # background
            continue
        obj_mask = mask == label_id
        rows, cols = np.nonzero(obj_mask)
        # nonzero gives inclusive max indices; +1 keeps the last row/column
        # (and stops single-pixel objects from becoming empty crops).
        r0, r1 = rows.min(), rows.max() + 1
        c0, c1 = cols.min(), cols.max() + 1
        crop = img[r0:r1, c0:c1]
        crop_mask = obj_mask[r0:r1, c0:c1]
        objects.append(np.concatenate((crop[None], crop_mask[None]), axis=0))
    return objects

In [None]:
objects = find_objects(mask=masks[0], img=imgs[0])  # per-object (image crop, mask) stacks from the first image

In [None]:
# Inspect one extracted object: intensity crop (left) and its mask (right).
idx = 5
_, ax = plt.subplots(ncols=2, figsize=(8, 4))
intensity_crop, mask_crop = objects[idx]
ax[0].imshow(intensity_crop)
ax[1].imshow(mask_crop)

## Rotate and flip the extracted objects, then tile them onto a background

In [None]:
from deepinv.transform.projective import Homography
import torch
import kornia
import torch.nn.functional as F
import itertools


def rotate(tensor, angle):
    """Rotate ``tensor`` by ``angle`` (degrees) with kornia.

    Uses bilinear interpolation, zero padding for pixels rotated in from outside,
    ``align_corners=True`` and kornia's default rotation centre (``center=None``).
    Both arguments are passed straight through to
    ``kornia.geometry.transform.rotate``. Presumably ``tensor`` is (C, H, W) or
    (B, C, H, W) and ``angle`` a 1-element tensor -- TODO confirm against kornia.
    """
    options = dict(center=None,
                   mode='bilinear',
                   padding_mode='zeros',
                   align_corners=True)
    return kornia.geometry.transform.rotate(tensor, angle, **options)

def random_flip(tensor, p = 0.5):
    """Apply a random horizontal flip, then a random vertical flip, each with probability ``p``.

    Callers index the result with ``[0]`` / ``[0, 0]``, i.e. kornia returns a
    batched tensor with extra leading dimensions.
    """
    for flip_cls in (kornia.augmentation.RandomHorizontalFlip,
                     kornia.augmentation.RandomVerticalFlip):
        tensor = flip_cls(p=p)(tensor)
    return tensor

def transform_object(tensor, max_angle=180):
    """Randomly rotate, tightly crop and randomly flip a single object.

    Args:
        tensor: (C, H, W) numpy array or torch tensor holding the object;
            zero pixels are treated as background.
        max_angle: the rotation angle is drawn uniformly from [0, max_angle) degrees.

    Returns:
        float32 tensor of shape (C, H', W'), cropped to the bounding box of the
        non-zero pixels after rotation. It is left uncropped if nothing is non-zero.
    """
    assert tensor.ndim == 3, "Tensor should be of shape (C, H, W)"
    angle = np.random.uniform(0, max_angle) * 1.0
    h, w = tensor.shape[-2:]
    # Pad to a square whose side is the diagonal so no corner is clipped by the
    # rotation. The odd remainder goes to the right/bottom so the canvas is
    # really sz x sz.
    sz = int(np.ceil(np.sqrt(h**2 + w**2)))
    pad_w, pad_h = sz - w, sz - h
    tensor = torch.as_tensor(tensor * 1.0).float()
    tensor = F.pad(tensor, (pad_w // 2, pad_w - pad_w // 2,
                            pad_h // 2, pad_h - pad_h // 2))
    object_rotated = rotate(tensor, torch.tensor([angle], dtype=tensor.dtype))
    # Crop to the tight bounding box of the non-zero pixels (max is inclusive -> +1).
    idx = object_rotated.nonzero()
    if idx.numel() > 0:
        r_min, r_max = int(idx[:, -2].min()), int(idx[:, -2].max())
        c_min, c_max = int(idx[:, -1].min()), int(idx[:, -1].max())
        object_rotated = object_rotated[..., r_min:r_max + 1, c_min:c_max + 1]
    # random_flip adds a leading batch dimension; [0] removes it.
    return random_flip(object_rotated)[0]

def get_object_ordering(objects, perm, square_size):
    """Greedily wrap objects (taken in ``perm`` order) into rows of width <= ``square_size``.

    Walking ``perm`` left to right, a new row starts whenever adding the next
    object would make the current row wider than ``square_size``. An object
    wider than ``square_size`` ends up alone in its own row.

    Args:
        objects: sequence of 2-D arrays/tensors; ``shape[1]`` is the width.
        perm: order in which to place the objects (indices into ``objects``).
        square_size: target maximum row width in pixels.

    Returns:
        List of positions in ``perm`` at which a new row begins (never 0). This
        is the ``ordering`` format consumed by ``get_combined_frame_dims`` and
        ``render_objects``.
    """
    new_row_loc = []
    row_w = 0
    for i, idx in enumerate(perm):
        obj_w = objects[idx].shape[1]
        if row_w > 0 and row_w + obj_w > square_size:
            new_row_loc.append(i)
            row_w = 0
        row_w += obj_w
    return new_row_loc

def get_combined_frame_dims(objects, perm, ordering):
    """Height and width of the frame needed to tile the objects row by row.

    ``ordering`` lists the positions in ``perm`` where new rows start. Each
    row is as tall as its tallest object and as wide as the sum of its widths.
    The frame is the sum of row heights by the widest row.

    Returns:
        ``(height, width)`` tuple of ints.
    """
    boundaries = [0, *ordering, len(objects)]
    total_h, total_w = 0, 0
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        row = [objects[perm[k]] for k in range(start, stop)]
        total_h += max((piece.shape[0] for piece in row), default=0)
        total_w = max(total_w, sum(piece.shape[1] for piece in row))
    return total_h, total_w

def get_background(size):
    """Return a random ``(height, width)`` crop taken from ``BACKGROUND_PATCHES``.

    Patches are tried in random order. A patch is usable if it is large enough
    as-is or after a 90-degree rotation (which swaps its height and width). The
    chosen patch is randomly flipped with ``random_flip``, and a random window
    of exactly ``size`` is cropped from it.

    Args:
        size: ``(height, width)`` of the requested background in pixels.

    Returns:
        float32 torch tensor of shape ``size``. It is a copy, so callers may
        write into it without altering ``BACKGROUND_PATCHES``.

    Raises:
        ValueError: if no patch is large enough in either orientation.
    """
    target_h, target_w = size
    for idx in np.random.permutation(len(BACKGROUND_PATCHES)):
        patch = np.asarray(BACKGROUND_PATCHES[idx])
        h, w = patch.shape
        fits = h >= target_h and w >= target_w
        fits_rotated = h >= target_w and w >= target_h
        if not (fits or fits_rotated):
            continue
        if not fits:
            # np.rot90 really swaps the axes. An interpolated rotation would keep
            # the h x w canvas (clipping content) and yield a transposed crop.
            patch = np.rot90(patch, k=1 if np.random.rand() > 0.5 else -1)
        # kornia augmentations need a torch tensor; rot90 returns a strided view.
        patch = torch.as_tensor(np.ascontiguousarray(patch, dtype=np.float32))
        # [0, 0] drops the leading dims random_flip adds to a 2-D input.
        patch = random_flip(patch)[0, 0]
        h, w = patch.shape
        # randint's upper bound is exclusive: +1 so an exact fit does not raise.
        x_min = np.random.randint(0, h - target_h + 1)
        y_min = np.random.randint(0, w - target_w + 1)
        # clone: the slice is a view, and render_objects writes into the result.
        return patch[x_min:x_min + target_h, y_min:y_min + target_w].clone()

    raise ValueError(f"No background patch found that fits the size {size}")

def get_rectrangle_ratio(objects, perm, ordering):
    """Aspect ratio (long side / short side, always >= 1) of the tiled frame; 1 means square."""
    frame_h, frame_w = get_combined_frame_dims(objects, perm, ordering)
    if frame_h >= frame_w:
        longer, shorter = frame_h, frame_w
    else:
        longer, shorter = frame_w, frame_h
    return longer / shorter

def render_objects(objects, perm, ordering):
    """Tile the objects row by row onto a random background crop.

    Objects are placed left to right in ``perm`` order. A new row starts at
    each position listed in ``ordering``, and rows are stacked top to bottom.
    Each object's positive pixels are *added* to the background under it.

    Returns:
        The background frame (from ``get_background``) with the objects added.
    """
    frame_h, frame_w = get_combined_frame_dims(objects, perm, ordering)
    final_frame = get_background((frame_h, frame_w))
    boundaries = [0, *ordering, len(objects)]

    top = 0
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        left = 0
        tallest = 0
        for k in range(start, stop):
            piece = objects[perm[k]]
            piece_h, piece_w = piece.shape
            region = final_frame[top:top + piece_h, left:left + piece_w]
            final_frame[top:top + piece_h, left:left + piece_w] = region + (piece > 0) * piece
            tallest = max(tallest, piece_h)
            left += piece_w
        top += tallest
    return final_frame


def combine_objects(objects):
    """Tile the objects into the most nearly square frame found by exhaustive search.

    The target row width is the side of a square whose area equals the objects'
    total area, raised to at least the largest object height and width. Every
    permutation is wrapped into rows with ``get_object_ordering``. The layout with
    the smallest aspect ratio (``get_rectrangle_ratio``) is rendered.

    Note: ``itertools.permutations`` yields n! layouts, so this is only practical
    for a handful of objects.

    Args:
        objects: non-empty sequence of 2-D arrays/tensors.

    Returns:
        The rendered frame from ``render_objects``.

    Raises:
        ValueError: if ``objects`` is empty.
    """
    if len(objects) == 0:
        raise ValueError("combine_objects needs at least one object")

    area = sum(h * w for h, w in (obj.shape for obj in objects))
    h_max = max(obj.shape[0] for obj in objects)
    w_max = max(obj.shape[1] for obj in objects)
    square_size = max(int(np.ceil(np.sqrt(area))), h_max, w_max)

    best_perm = None
    best_ordering = None
    best_ratio = None
    for perm in itertools.permutations(range(len(objects))):
        ordering = get_object_ordering(objects, perm, square_size)
        ratio = get_rectrangle_ratio(objects, perm, ordering)
        if best_ratio is None or ratio < best_ratio:
            best_perm, best_ordering, best_ratio = perm, ordering, ratio

    # render_objects expects the row-break ordering, not the target width.
    return render_objects(objects, best_perm, best_ordering)


In [None]:
# Preview six independent random 160x160 background crops.
_, ax = plt.subplots(ncols=6, figsize=(18, 3))
for panel in ax:
    panel.imshow(get_background((160, 160)))

In [None]:
# Pick 4 objects once (with replacement), then render 10 random augmentations of them.
rendered_outputs = []
idx_list = np.random.randint(0, len(objects), 4)
for render_idx in range(10):
    trans_objects = []
    for obj_idx in idx_list:
        tensor = torch.Tensor(objects[obj_idx] * 1.0)
        # channel 0 = image crop, channel 1 = object mask -> masked crop as (1, H, W)
        new_obj = transform_object((tensor[0] * tensor[1])[None], max_angle=180)
        trans_objects.append(new_obj.squeeze())
    # Objects in given order; ordering [2] puts objects 0-1 in row 1 and 2-3 in row 2.
    output = render_objects(trans_objects, [0, 1, 2, 3], [2])
    rendered_outputs.append(output)

In [None]:
# 2 x 5 grid of the rendered frames, filled row by row.
_, ax = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
for i, panel in enumerate(ax.flat):
    panel.imshow(rendered_outputs[i])
    # panel.axis('off')