In [6]:
from pathlib import Path
from matplotlib import pyplot as plt
# from cellpose import plot, models
from numcodecs import blosc, Blosc
from typing import  Union, Set, Dict, Tuple, List, Any, Optional, Callable
from collections.abc import Iterator

import os
import numpy as np
import pandas as pd

def _database_to_prefix(database: Tuple[str, str]):
    return f'{database[0]}_{database[1]}'

class Well:
    """A single plate well together with the locations of its data files.

    Attributes:
        name: Well identifier (e.g. ``'A1'``); also what ``str()`` returns.
        row, column: Plate coordinates of the well.
        well_path: Directory holding this well's data.
        ions_path: Callable mapping a two-string database tuple to an ions file path.
        annotation_path: Callable mapping a two-string database tuple to an
            annotation file path.
    """

    def __init__(
        self,
        name: str,
        row: int,
        column: int,
        well_path: Path,
        ions_path: Callable[[Tuple[str, str]], Path],
        annotation_path: Callable[[Tuple[str, str]], Path],
    ) -> None:
        # Identity and plate coordinates.
        self.name, self.row, self.column = name, row, column
        # Filesystem locations; the two callables are resolved per database.
        self.well_path = well_path
        self.ions_path, self.annotation_path = ions_path, annotation_path

    def __str__(self) -> str:
        label = self.name
        return label

class Tile:
    """One grid tile inside a well, with the directories of its image and label arrays.

    Attributes:
        name: Tile identifier; also what ``str()`` returns.
        well_name: Name of the parent well.
        row, column: Plate coordinates of the parent well.
        grid_column, grid_row: Position of the tile inside its well.
        pre_maldi, post_maldi: Directories of the pre/post image arrays.
        cellpose, nuclei, gridfit: Directories of the label arrays.
    """

    def __init__(
        self,
        name: str,
        well_name: str,
        row: int,
        column: int,
        grid_column: int,
        grid_row: int,
        pre_maldi: Path,
        post_maldi: Path,
        cellpose: Path,
        nuclei: Path,
        gridfit: Path,
    ) -> None:
        # Identity.
        self.name, self.well_name = name, well_name
        # Plate coordinates of the well, then the tile's grid position within it.
        self.row, self.column = row, column
        self.grid_column, self.grid_row = grid_column, grid_row
        # Image array directories.
        self.pre_maldi, self.post_maldi = pre_maldi, post_maldi
        # Label array directories.
        self.cellpose, self.nuclei, self.gridfit = cellpose, nuclei, gridfit

    def __str__(self) -> str:
        label = self.name
        return label


class Experiment:
    """Navigate the wells and tiles of a plate stored as a directory tree.

    Layout produced by :meth:`well`, :meth:`tile` and :meth:`_image_path`::

        <zarr_path>/<row letter>/<column number>/        well directory
        <well dir>/pre/labels/{cellpose,nuclei}          pre label arrays
        <well dir>/post/labels/gridfit                   post label array
        <array dir>/0/<channel>/<grid_row>/<grid_column> one blosc chunk per tile

    Args:
        zarr_path: Root directory of the store.
        no_rows: Number of plate rows (named with letters starting at 'A').
        no_columns: Number of plate columns (named with numbers starting at 1).
        no_grids_pr_row: Number of ``grid_column`` values per well.
        no_grids_pr_column: Number of ``grid_row`` values per well.
    """

    def __init__(self, zarr_path: Path, no_rows: int, no_columns: int, no_grids_pr_row: int, no_grids_pr_column: int) -> None:
        blosc.init()

        self.zarr_path = zarr_path
        self.no_rows = no_rows
        self.no_columns = no_columns
        self.no_grids_pr_row = no_grids_pr_row
        self.no_grids_pr_column = no_grids_pr_column
        self.no_grids = no_grids_pr_row * no_grids_pr_column

    def well_from_row_column(self, row: int, column: int) -> Well:
        """Return the well at zero-based (``row``, ``column``), e.g. (0, 0) -> 'A1', (1, 11) -> 'B12'."""
        well_name = f"{chr(ord('A') + row)}{column + 1}"
        return self.well(well_name=well_name)

    def well(self, well_name: str) -> Well:
        """Build a :class:`Well` from a name such as ``'A1'`` or ``'B12'``.

        The first character is the row letter; everything after it is the
        1-based column number.

        Raises:
            IndexError: if ``well_name`` is empty.
            ValueError: if the part after the row letter is not an integer.
        """
        zarr_path = self.zarr_path
        row_name: str = well_name[0]
        # Take the whole remainder so multi-digit columns ('A10', 'B12') parse
        # correctly; well_from_row_column produces such names when no_columns > 9.
        col_name: str = well_name[1:]
        row = ord(row_name) - ord('A')
        col = int(col_name) - 1
        well_path: Path = zarr_path.joinpath(row_name, col_name)

        def ions_path(database: Tuple[str, str]) -> Path:
            # Per-well ions file, stored inside the well directory.
            return well_path.joinpath(f'{_database_to_prefix(database)}_ions.h5ad')

        def annotation_path(database: Tuple[str, str]) -> Path:
            # Annotation file, stored at the store root and suffixed with the well name.
            return zarr_path.joinpath(f'{_database_to_prefix(database)}_{well_name}.h5ad')

        return Well(name=well_name, row=row, column=col, well_path=well_path, ions_path=ions_path, annotation_path=annotation_path)

    def tile(self, well: Well, grid_column: int, grid_row: int) -> Tile:
        """Build the :class:`Tile` at zero-based grid position (``grid_column``, ``grid_row``) of ``well``.

        The tile name is ``<well>_<grid_column+1>x<grid_row+1>``, e.g. ``'A1_1x1'``.
        """
        pre_path: Path = well.well_path.joinpath('pre')
        post_path: Path = well.well_path.joinpath('post')
        grid_name = f'{well.name}_{grid_column + 1}x{grid_row + 1}'
        return Tile(
            name=grid_name,
            well_name=well.name,
            row=well.row,
            column=well.column,
            grid_column=grid_column,
            grid_row=grid_row,
            pre_maldi=pre_path,
            post_maldi=post_path,
            cellpose=pre_path.joinpath('labels', 'cellpose'),
            nuclei=pre_path.joinpath('labels', 'nuclei'),
            gridfit=post_path.joinpath('labels', 'gridfit'),
        )

    def wells(self) -> Iterator:
        """Yield every well in row-major order (A1, A2, ..., B1, ...)."""
        for row in range(self.no_rows):
            for column in range(self.no_columns):
                yield self.well_from_row_column(row=row, column=column)

    def tiles_in_well(self, well: Well) -> Iterator:
        """Yield every tile of ``well``; ``grid_row`` varies fastest."""
        for grid_column in range(self.no_grids_pr_row):
            for grid_row in range(self.no_grids_pr_column):
                yield self.tile(well=well, grid_column=grid_column, grid_row=grid_row)

    def tiles(self) -> Iterator:
        """Yield every tile of every well, in :meth:`wells` order."""
        for well in self.wells():
            yield from self.tiles_in_well(well)

    def _image_path(self, tile: Tile, channel: int, path: Path) -> Path:
        """Return the chunk file for ``tile``/``channel`` inside the array directory ``path``."""
        return path.joinpath('0', str(channel), str(tile.grid_row), str(tile.grid_column))

    def store_image(self, tile: Tile, channel: int, path: Path, data: np.ndarray) -> None:
        """Blosc-compress (lz4) ``data`` and write it as the chunk for ``tile``/``channel``.

        Only the raw C-order bytes of ``data`` are written, not its dtype or
        shape, so :meth:`load_image` must decode with a matching dtype/shape.
        """
        image_path = self._image_path(tile, channel, path)
        image_path.parent.mkdir(parents=True, exist_ok=True)
        compressor = Blosc(cname='lz4', clevel=5, shuffle=1, blocksize=1)
        with open(image_path, 'wb') as f:
            f.write(compressor.encode(data.tobytes(order='C')))

    def load_image(self, tile: Tile, channel: int, path: Path, shape: Tuple[int, int] = (2048, 2048)) -> np.ndarray:
        """Read the chunk for ``tile``/``channel`` from the array directory ``path``.

        Paths containing ``'gridfit'`` are decoded as ``uint8``; all others as
        ``uint16``. A missing chunk yields an all-zero image. An unreadable chunk
        also yields zeros, after a warning.

        Args:
            tile: Tile whose grid position selects the chunk.
            channel: Channel index.
            path: Array directory (e.g. ``tile.pre_maldi`` or ``tile.cellpose``).
            shape: Shape to reshape the decoded buffer into.

        Returns:
            The decoded image. NOTE(review): ``np.frombuffer`` over an immutable
            buffer gives a read-only array, unlike the writable zero fallback.
        """
        import warnings

        image_path = self._image_path(tile, channel, path)
        dtype = np.uint8 if 'gridfit' in str(path) else np.uint16
        # Path.exists must be *called*; the bare attribute is always truthy.
        if image_path.exists():
            try:
                with open(image_path, 'rb') as f:
                    raw = blosc.decompress(f.read())
                return np.frombuffer(buffer=raw, dtype=dtype).reshape(shape)
            except Exception as err:  # best-effort: fall back to zeros, but say why
                warnings.warn(f'Could not read {image_path}: {err!r}; returning zeros')
        return np.zeros(shape, dtype=dtype)

In [16]:
from tifffile import imwrite

# Root of the hackathon data on the author's machine; adjust before running elsewhere.
DATA_ROOT = Path("/Users/alberto-mac/Documents/DA_ESPORTARE/LOCAL_EMBL_FILES/scratch/bailoni/projects/BII_hackathon_2023/data")

experiment = Experiment(
    DATA_ROOT / "data.zarr",
    no_rows=2,
    no_columns=4,
    no_grids_pr_row=3,
    no_grids_pr_column=3,
)
out_dir = DATA_ROOT / "input_tiles_tif"
# Create the output directory up front so the writes below cannot fail on a missing folder.
out_dir.mkdir(parents=True, exist_ok=True)

# One exported TIFF per entry: (file suffix, channel index, Tile attribute naming the source array).
TIF_EXPORTS = (
    ("BF", 0, "pre_maldi"),
    ("DAPI", 1, "pre_maldi"),
    ("masks", 0, "cellpose"),
)

for tile in experiment.tiles():
    for suffix, channel, source_attr in TIF_EXPORTS:
        imwrite(
            out_dir / f"tile_{tile}_{suffix}.tif",
            experiment.load_image(tile, channel, getattr(tile, source_attr)),
            imagej=True,
            metadata={"axes": "YX"},
        )



(2048, 2048)