# Zarr creation using centre positions for multiSEM

**NOTE**: I am relying on the stitching parameters provided by the acquisition software. This notebook only takes care of creating the final zarr, for which the tiles have to be fused.

# Single hexagon

## Loading centre position metadata from CSV file using pandas

In [None]:
import multisemzarr as msz

import pandas as pd
from pathlib import Path
import numpy as np
import skimage.io as skio
#from tqdm.notebook import tqdm

In [None]:
# Root folder of the multiSEM acquisition.
dataset_path = Path('/PROJECTS/CCI/BRAIN/Multibeam/Test_20Sections_20230217_15-17-11/Test_20Sections_20230217_15-17-11/')

# Section folders are named "<id>_<region>".
# NOTE: `id` shadows the Python builtin of the same name for the rest of the
# notebook session.
id, region = '024', 'Region4'
section = f'{id}_{region}'
section_path = dataset_path / section

# Positions file for this region, parsed by msz.read_stitched_imagepositions.
csv_p = section_path / f'{region}_stitched_imagepositions.txt'

image_positions = msz.read_stitched_imagepositions(csv_p)
image_positions

We get information of each tile based on the naming convention of multiSEM

In [None]:
# Extend the table with the per-tile information that msz.get_info_from_path
# derives for the tiles of this section.
image_positions = msz.get_info_from_path(
    image_positions,
    section_path=section_path,
)

# Random handful of rows as a quick sanity check.
image_positions.sample(n=5)

I assume that all tiles have the same size

In [None]:
# Tile dimensions and positions are filled in based on the first tile; all
# other tiles are assumed to have the same size. The result is then
# translated to 0,0.
image_positions = msz.translation00(msz.get_info_from_image(image_positions))

image_positions.sample(n=5)

## intensity corrections

In [None]:
# Set to True to work on a handful of hexagons instead of the full section.
testing = False

if not testing:
    # Full image: work on a copy of the complete table.
    hex_pos = image_positions.copy()
else:
    # Testing: keep only the tiles that belong to a few hexagons.
    test_hexagons = ['000011', '000012', '000013', '000014', '000015']
    in_subset = image_positions["hexagon"].isin(test_hexagons)
    hex_pos = image_positions[in_subset].copy()

    # Renumber the rows of the subset from 0.
    hex_pos.reset_index(inplace=True)
hex_pos

In [None]:
# Run the intensity-correction helper with the 'q30' method and keep the
# table it returns.
hex_pos = msz.get_intensity_correction(
    hex_pos,
    method='q30',
)

## Check size of array 
This is to make it compatible with downscaling later on

In [None]:
hex_pos = msz.translation00(hex_pos)

# Furthest position reached by any tile along each axis (corner + size).
x_size_tmp = (hex_pos['corner_x'] + hex_pos["size_x"]).max()
y_size_tmp = (hex_pos['corner_y'] + hex_pos["size_y"]).max()

# Size of the output array as returned by msz.optimal_size.
# NOTE(review): the 5 is presumably the number of downscaling steps the size
# has to be compatible with - confirm against msz.optimal_size.
total_x, total_y = (
    int(msz.optimal_size(extent, 5)) for extent in (x_size_tmp, y_size_tmp)
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
# Figure and axis that will hold the tile layout.
fig, ax = plt.subplots()

# Scatter plot of the tile centres.
ax.scatter(hex_pos["centre_x"].to_numpy(), hex_pos["centre_y"].to_numpy())

# One red outline per tile, anchored at the tile corner.
tile_boxes = zip(
    hex_pos["corner_x"],
    hex_pos["corner_y"],
    hex_pos["size_x"],
    hex_pos["size_y"],
)
for c_x, c_y, width, height in tile_boxes:
    ax.add_patch(Rectangle((c_x, c_y), width, height, edgecolor='red', fill=False))

# Blue outline of the full output array.
ax.add_patch(Rectangle((0, 0), total_x, total_y, edgecolor='blue', fill=False))

# Axes as in an image: y axis reversed, x ticks on top, y ticks on the left.
ax.set_ylim(ax.get_ylim()[::-1])
ax.xaxis.tick_top()
ax.yaxis.tick_left()

# Display the plot.
plt.title("Tile distribution")
plt.show()

## Init ZARR array

In [None]:
import zarr
import skimage.io as skio

In [None]:
def rm_tree(pth):
    """Recursively delete the directory ``pth`` and everything inside it.

    Symbolic links are removed as links and are never followed, so data that
    lives outside ``pth`` is left untouched.

    Args:
        pth: Directory to delete (``str`` or ``pathlib.Path``).

    Raises:
        FileNotFoundError: If ``pth`` does not exist.
        NotADirectoryError: If ``pth`` is a regular file.
    """
    pth = Path(pth)
    if pth.is_symlink():
        # Only drop the link itself; its target is not ours to delete.
        pth.unlink()
        return
    for child in pth.iterdir():
        if child.is_dir() and not child.is_symlink():
            rm_tree(child)
        else:
            # Regular files, symlinks (even to directories) and other
            # non-directory entries are all removed with unlink().
            child.unlink()
    pth.rmdir()

In [None]:
# Location of the intermediate full-resolution zarr the tiles are fused into.
z0_str = "./data/"+section+".zarr"
z0_path = Path(z0_str)

# Start from a clean store: remove whatever a previous run left behind.
if z0_path.exists():
    rm_tree(z0_path)

store = zarr.DirectoryStore(z0_path)

# One tile is read to get the dtype and the chunk size. Positional access is
# used because hex_pos is not guaranteed to carry the label 0 in its index
# (e.g. a filtered subset), in which case hex_pos['abs_path'][0] raises KeyError.
img_tile = skio.imread(hex_pos['abs_path'].iloc[0])

# Square chunks whose side is the longest side of a tile.
chunk_size = np.max(img_tile.shape)
print(f'Chunk size: {chunk_size},{chunk_size}')
z = zarr.creation.open_array(
    store=store,
    mode='a',
    shape=(total_y, total_x),
    chunks=(chunk_size, chunk_size),
    dtype=img_tile.dtype,
)
z

## Dynamically fill in values

In [None]:
from tqdm.auto import tqdm
from functools import partial
from multiprocess import Pool

In [None]:
def correct_write_tile3(tile_info, zarr_array):
    """Flat-field correct one tile, rescale its intensity and write it to the mosaic.

    The tile is read from disk, passed through
    ``multisemzarr.flat_field_correction``, rescaled so that its median equals
    ``tile_info['median_int']``, multiplied by ``tile_info['int_corr']`` and
    written, transposed, into ``zarr_array[y1:y2, x1:x2]``.

    Values are saturated to the valid range before every cast to an integer
    dtype; a plain ``astype`` would wrap out-of-range values around.

    Args:
        tile_info: Mapping-like row providing 'abs_path', 'median_int',
            'corner_x', 'corner_y', 'size_x', 'size_y' and 'int_corr'.
        zarr_array: Array the corrected tile is written into.
    """
    # Imports are local to the function, presumably so that it is
    # self-contained when executed in a pool worker.
    from skimage.io import imread
    from numpy import clip, iinfo, integer, issubdtype, median, multiply, transpose
    from multisemzarr import flat_field_correction

    def _cast(values, dtype):
        # Saturate instead of wrapping around when the target dtype is integer.
        if issubdtype(dtype, integer):
            limits = iinfo(dtype)
            values = clip(values, limits.min, limits.max)
        return values.astype(dtype)

    tile = imread(tile_info['abs_path'])
    corr_tile = flat_field_correction(tile)

    # Bring the corrected tile back to the median of the original tile.
    original_med = tile_info['median_int']
    corr_med = median(corr_tile)
    # A zero median would give an infinite scale; leave the tile unscaled then.
    scale = original_med / corr_med if corr_med != 0 else 1.0
    corr_tile = _cast(multiply(corr_tile, scale), tile.dtype)

    # Destination window of this tile inside the mosaic.
    x1 = tile_info["corner_x"]
    x2 = x1 + tile_info["size_x"]
    y1 = tile_info["corner_y"]
    y2 = y1 + tile_info["size_y"]

    corr_factor = tile_info['int_corr']

    # The tile is transposed before it is stored.
    scaled = multiply(transpose(corr_tile).astype(float), corr_factor)
    zarr_array[y1:y2, x1:x2] = _cast(scaled, zarr_array.dtype)

def correct_write_tile2(tile_info, zarr_array):
    """Equalize one tile, rescale its intensity and write it to the mosaic.

    The tile is read from disk, passed through
    ``skimage.exposure.equalize_adapthist`` (``clip_limit=0.00``), rescaled so
    that its median equals ``tile_info['median_int']``, multiplied by
    ``tile_info['int_corr']`` and written, transposed, into
    ``zarr_array[y1:y2, x1:x2]``.

    Values are saturated to the valid range before every cast to an integer
    dtype; a plain ``astype`` would wrap out-of-range values around.

    Args:
        tile_info: Mapping-like row providing 'abs_path', 'median_int',
            'corner_x', 'corner_y', 'size_x', 'size_y' and 'int_corr'.
        zarr_array: Array the corrected tile is written into.
    """
    # Imports are local to the function, presumably so that it is
    # self-contained when executed in a pool worker.
    from skimage.io import imread
    from numpy import clip, iinfo, integer, issubdtype, median, multiply, transpose
    from skimage.exposure import equalize_adapthist

    def _cast(values, dtype):
        # Saturate instead of wrapping around when the target dtype is integer.
        if issubdtype(dtype, integer):
            limits = iinfo(dtype)
            values = clip(values, limits.min, limits.max)
        return values.astype(dtype)

    tile = imread(tile_info['abs_path'])

    # Equalize, then bring the result back to the median of the original tile.
    original_med = tile_info['median_int']
    img_adapteq = equalize_adapthist(tile, clip_limit=0.00)
    adapted_med = median(img_adapteq)
    # A zero median would give an infinite scale; leave the tile unscaled then.
    scale = original_med / adapted_med if adapted_med != 0 else 1.0
    img_adapteq = _cast(multiply(img_adapteq, scale), tile.dtype)

    # Destination window of this tile inside the mosaic.
    x1 = tile_info["corner_x"]
    x2 = x1 + tile_info["size_x"]
    y1 = tile_info["corner_y"]
    y2 = y1 + tile_info["size_y"]

    corr_factor = tile_info['int_corr']

    # The tile is transposed before it is stored.
    scaled = multiply(transpose(img_adapteq).astype(float), corr_factor)
    zarr_array[y1:y2, x1:x2] = _cast(scaled, zarr_array.dtype)

def correct_write_tile(tile_info, zarr_array):
    """Scale one tile by its intensity correction and write it to the mosaic.

    The tile is read from disk, multiplied by ``tile_info['int_corr']`` and
    written, transposed, into ``zarr_array[y1:y2, x1:x2]``.

    Values are saturated to the valid range before the cast to an integer
    dtype; a plain ``astype`` would wrap out-of-range values around.

    Args:
        tile_info: Mapping-like row providing 'abs_path', 'corner_x',
            'corner_y', 'size_x', 'size_y' and 'int_corr'.
        zarr_array: Array the corrected tile is written into.
    """
    # Imports are local to the function, presumably so that it is
    # self-contained when executed in a pool worker.
    from skimage.io import imread
    from numpy import clip, iinfo, integer, issubdtype, multiply, transpose

    tile = imread(tile_info['abs_path'])

    # Destination window of this tile inside the mosaic.
    x1 = tile_info["corner_x"]
    x2 = x1 + tile_info["size_x"]
    y1 = tile_info["corner_y"]
    y2 = y1 + tile_info["size_y"]

    corr_factor = tile_info['int_corr']

    # The tile is transposed before it is stored.
    scaled = multiply(transpose(tile).astype(float), corr_factor)
    if issubdtype(zarr_array.dtype, integer):
        # Saturate instead of wrapping around on the integer cast.
        limits = iinfo(zarr_array.dtype)
        scaled = clip(scaled, limits.min, limits.max)
    zarr_array[y1:y2, x1:x2] = scaled.astype(zarr_array.dtype)

# Tiles are written in batches that share the same 'tile_number'. A batch is
# completely written (pool.map blocks until all its tiles are done) before the
# next batch starts.
# NOTE(review): presumably tiles with the same tile number never overlap, which
# is what makes writing them concurrently safe - confirm.
chunks = [
    hex_pos[hex_pos['tile_number'].isin([tn])].copy()
    for tn in np.unique(hex_pos['tile_number'])
]

print(f'found {len(chunks)} unique tile ids')

write_tile = partial(correct_write_tile3, zarr_array=z)

# One pool is created once and reused for every batch instead of spawning 20
# new worker processes per tile number.
with Pool(20) as pool:
    for current in tqdm(chunks):
        hex_list = [row for _, row in current.iterrows()]
        pool.map(write_tile, hex_list)

        # Sequential alternative, useful for debugging:
        #for row in hex_list:
        #    correct_write_tile(zarr_array=z, tile_info=row)



## To open in napari

This image can now be opened in Napari by drag and drop, using ```napari builtins```

## Changing now to ome-zarr

However, I want to add ome-zarr support. For that I need some minimal metadata, and optionally some resolution levels

For downsampling I will use ```dask-array``` as suggested in [this discussion](https://forum.image.sc/t/creating-an-ome-zarr-dynamically-from-tiles-stored-as-a-series-of-images-list-of-centre-positions-using-python/81657/12?u=camachodejay) 

In [None]:
import dask.array as da
# like numpy.mean, but maintains dtype, helper function
def mean_dtype(arr, **kwargs):
    """Return ``np.mean(arr, **kwargs)`` cast back to the dtype of ``arr``.

    Args:
        arr: Input array; its dtype is the dtype of the result.
        **kwargs: Forwarded unchanged to ``np.mean`` (e.g. ``axis``).
    """
    averaged = np.mean(arr, **kwargs)
    return averaged.astype(arr.dtype)

In [None]:
# The data has to be rechunked at this stage, otherwise writing the zarr later
# on fails (the exact reason is still not clear).
# The chunk shape must be given as ONE tuple: a second positional argument of
# rechunk() is `threshold`, not the chunk size of the second axis.
d0 = da.from_zarr(store).rechunk((img_tile.shape[1], img_tile.shape[0]))
d0

In [None]:
# Chunk shape shared by all downscaled levels: half a tile along each axis.
# It is passed to rechunk() as ONE tuple, because a second positional argument
# of rechunk() is `threshold`, not the chunk size of the second axis.
half_tile_chunks = (img_tile.shape[1] // 2, img_tile.shape[0] // 2)

# Every level is coarsened directly from the full-resolution array d0 by
# averaging blocks of factor x factor pixels.
d1, d2, d3, d4, d5 = (
    da.coarsen(mean_dtype, d0, {0: factor, 1: factor}).rechunk(half_tile_chunks)
    for factor in (2, 4, 8, 16, 32)
)
d5

In [None]:
from ome_zarr.io import parse_url
from ome_zarr.writer import write_multiscale
from ome_zarr.writer import write_multiscales_metadata

In [None]:
# Scale of the full-resolution level and the unit used for both axes.
initial_pix_size = 4
initial_pix_unit = 'nanometer'

# One [scale, translation] pair per pyramid level: full resolution followed by
# the levels downscaled by 2, 4, 8, 16 and 32.
coordtfs = [
    [{'type': 'scale', 'scale': [initial_pix_size*factor, initial_pix_size*factor]},
     {'type': 'translation', 'translation': [0, 0]}]
    for factor in (1, 2, 4, 8, 16, 32)
]
axes = [
    {'name': axis_name, 'type': 'space', 'unit': initial_pix_unit}
    for axis_name in ('y', 'x')
]

# Open the zarr group manually, starting from an empty location.
path_str = "./data/"+section+"corr-ome.zarr"
path = Path(path_str)

if path.exists():
    rm_tree(path)

store = parse_url(path, mode='w').store
root = zarr.group(store=store)

# Use the OME writer for the multiscale pyramid.
pyramid = [d0, d1, d2, d3, d4, d5]
write_multiscale(
    pyramid,
    group=root,
    axes=axes,
    coordinate_transformations=coordtfs,
)

# Add omero metadata: the napari ome-zarr plugin uses this to pass rendering
# options to napari.
omero_channel = {
    'color': 'ffffff',
    'label': region,
    'active': True,
}
root.attrs['omero'] = {'channels': [omero_channel]}

In [None]:
# Remove the intermediate zarr at z0_path if it is still on disk.
if z0_path.exists():
    rm_tree(z0_path)