In [None]:
from dataclasses import dataclass

@dataclass()
class Location:
    """Position of a square patch inside a frame, plus where its data came from.

    h, w: row / column offset of the patch's top-left corner within the frame.
    sz: side length of the (square) patch in pixels.
    frame_idx: index of the frame inside its stack; -1 until assigned.
    SgI_fpath, SgGT_fpath, SgP_fpath, GT_fpath: source file paths; "" until assigned.
    reviewer: reviewer label; "" until assigned.
    """
    # Patch geometry -- required, positional.
    h: int
    w: int
    sz: int
    # Provenance -- optional, filled in after the patch has been cut.
    frame_idx: int = -1
    SgI_fpath: str = ""
    SgGT_fpath: str = ""
    SgP_fpath: str = ""
    GT_fpath: str = ""
    reviewer: str = ""

import numpy as np
def patchify_one_frame(SgI: np.ndarray, SgGT: np.ndarray, SgP: np.ndarray, GT: np.ndarray, patch_size=128, rng=None):
    """Cut one randomly positioned square patch, at the same place, from every input.

    Inputs:
        SgI: 2D segmentation output for the superimposed input frame.
        SgGT: 2D segmentation output for the GT frame; same shape as SgI.
        SgP: 2D segmentation output for the prediction; same shape as SgI.
        GT: GT frame. Its last two axes must match SgI's shape; any leading
            axes are kept whole in the patch (a plain 2D frame also works).
        patch_size: side length of the square patch, in pixels.
        rng: optional ``np.random.Generator``. When None, the global
            ``np.random`` state is used.
    Outputs:
        A dict with keys 'GT', 'SgI', 'SgGT', 'SgP' (views of the cropped
        windows) and 'patch_location' (a Location with the top-left corner
        and the patch size).
    Raises:
        AssertionError: if the input shapes are inconsistent.
        ValueError: if the frame is smaller than ``patch_size`` along either axis.
    """
    assert SgI.ndim == 2, f"SgI must be 2D, got shape {SgI.shape}"
    assert SgI.shape == SgGT.shape == SgP.shape == GT.shape[-2:], (
        f"Shape Mismatch: SgI{SgI.shape}, SgGT{SgGT.shape}, SgP{SgP.shape}, GT{GT.shape}"
    )
    height, width = SgI.shape
    if height < patch_size or width < patch_size:
        raise ValueError(f"patch_size={patch_size} does not fit in a frame of shape {SgI.shape}")

    randint = np.random.randint if rng is None else rng.integers
    # The upper bound is exclusive for both samplers, so +1 makes the last valid
    # offset reachable and lets a frame exactly patch_size wide yield offset 0.
    h = int(randint(0, height - patch_size + 1))
    w = int(randint(0, width - patch_size + 1))
    rows = slice(h, h + patch_size)
    cols = slice(w, w + patch_size)
    return {
        # `...` keeps any leading axes of GT and crops only the two spatial ones.
        'GT': GT[..., rows, cols],
        'SgI': SgI[rows, cols],
        'SgGT': SgGT[rows, cols],
        'SgP': SgP[rows, cols],
        'patch_location': Location(h, w, patch_size),
    }


def patchify_one_stack(SgI, SgGT, SgP, GT, patch_size=128):
    """Cut random, aligned square patches from every frame of a stack.

    Inputs:
        SgI: 3D stack of segmentations of the superimposed input; axis 0 indexes frames.
        SgGT: segmentation stack of the GT frames; same shape as SgP and GT.
        SgP: segmentation stack of the predictions.
        GT: GT stack; its first three axes must match SgI's shape.
        patch_size: side length of each square patch, in pixels.
    Outputs:
        A dict with keys 'GT', 'SgI', 'SgGT', 'SgP', 'patch_location', each
        mapping to a list with one entry per patch, in frame order. Each
        Location's frame_idx is set to the index of the frame it came from.
    Notes:
        Each frame yields (H // patch_size) * (W // patch_size) patches drawn
        independently at random positions, so patches may overlap.
    """
    assert len(SgI.shape) == 3
    assert SgGT.shape == SgP.shape == GT.shape, f"Shape Mismatch: SgGT{SgGT.shape}, SgP{SgP.shape}, GT{GT.shape}"
    # The first three axes of every stack must agree with SgI.
    assert SgI.shape == GT.shape[:3], f"Shape Mismatch: SgI{SgI.shape}, GT{GT.shape}"

    n_frames, height, width = SgI.shape
    # Count along both spatial axes; (W // patch_size)**2 is only right for square frames.
    num_patches_per_frame = (height // patch_size) * (width // patch_size)

    array_keys = ('GT', 'SgI', 'SgGT', 'SgP')
    patches = {key: [] for key in array_keys + ('patch_location',)}
    for frame_idx in range(n_frames):
        for _ in range(num_patches_per_frame):
            one_patch = patchify_one_frame(SgI[frame_idx], SgGT[frame_idx], SgP[frame_idx], GT[frame_idx], patch_size)
            for key in array_keys:
                patches[key].append(one_patch[key])
            location = one_patch['patch_location']
            location.frame_idx = frame_idx
            patches['patch_location'].append(location)
    return patches



In [None]:
import os
from disentangle.core.tiff_reader import load_tiff

rootdir = '/facility/imganfacusers/Ashesh/NatureMethodsSegmentationOutputs/Analysis_2405_D18-M3-S0-L8_13_1/'
GT_fpath = '/facility/imganfacusers/Ashesh/NatureMethodsSegmentation/2405_D18-M3-S0-L8_13/GT.tif'
GT = load_tiff(GT_fpath)
reviewers = ['DDN', 'JB', 'EC']
frame_indices = [0, 1, 2, 3, 4]
# Quick-look mode: process only the first reviewer and first seg file
# (this is what the former bare `break` statements did). Set to False for a full pass.
DEBUG_FIRST_ONLY = True

# One patches-dict per (reviewer, seg file) processed; `patches` keeps the last one.
all_patches = []
for reviewer in reviewers:
    subdir = os.path.join(rootdir, reviewer)
    SgGT_dir = os.path.join(subdir, 'GT')
    SgI_dir = os.path.join(subdir, 'input')
    SgP_dir = os.path.join(subdir, 'pred')
    for frame_idx in frame_indices:
        SgGT_fpath = os.path.join(SgGT_dir, f'seg_{frame_idx}.tif')
        SgI_fpath = os.path.join(SgI_dir, f'seg_{frame_idx}.tif')
        SgP_fpath = os.path.join(SgP_dir, f'seg_{frame_idx}.tif')

        SgGT = load_tiff(SgGT_fpath)
        SgI = load_tiff(SgI_fpath)
        SgP = load_tiff(SgP_fpath)

        patches = patchify_one_stack(SgI, SgGT, SgP, GT, patch_size=128)
        # Tag every patch with its provenance. The previous code assigned to a
        # `location` name that does not exist in this scope (NameError on a fresh kernel).
        for location in patches['patch_location']:
            location.SgI_fpath = SgI_fpath
            location.SgGT_fpath = SgGT_fpath
            location.SgP_fpath = SgP_fpath
            location.GT_fpath = GT_fpath
            location.reviewer = reviewer
        all_patches.append(patches)

        if DEBUG_FIRST_ONLY:
            break
    if DEBUG_FIRST_ONLY:
        break

In [None]:
# Display the shape of the loaded GT stack (sanity check before patching).
GT.shape

In [None]:
# Display the shape of the most recently loaded prediction segmentation stack.
SgP.shape