In [None]:
# Execute utils/experiment_chooser.ipynb in this namespace so `experiment_chooser`
# is defined here. choose() presumably lets the user pick which experiment to
# process (later resolved via experiment_chooser.fetch_base()) — confirm in that notebook.
%run utils/experiment_chooser.ipynb
experiment_chooser.choose()

In [None]:
from collections import defaultdict
import tifffile
from multiprocessing import Pool
import numpy as np
from skimage import filters
from skimage.exposure import rescale_intensity, adjust_gamma
import os
from pystackreg import StackReg

# Input collection (sub-directory of processed_imgs/) holding the per-timepoint
# images, and output collection that receives one stack per (vertex, channel).
IN_COLLECTION = "stitched"
OUT_COLLECTION = "stacked"

experiment_base = experiment_chooser.fetch_base()

os.makedirs(experiment_base / "processed_imgs" / OUT_COLLECTION, exist_ok=True)

# Only regular, non-hidden files are candidates. The filename parsers expect
# "<vertex>-<channel>-<tp>.<ext>", so stray sub-directories or OS metadata files
# (e.g. ".DS_Store") would otherwise crash them. Sorting makes the processing
# order deterministic, since glob order is filesystem-dependent.
paths = sorted(
    p
    for p in experiment_base.glob(f"processed_imgs/{IN_COLLECTION}/*")
    if p.is_file() and not p.name.startswith(".")
)

def get_vertex(path):
    """Return the vertex field of a '<vertex>-<channel>-<tp>.<ext>' filename.

    The label is everything before the first '.', and the vertex is the part of
    that label before the first '-'.
    """
    stem, _, _ = path.name.partition(".")
    vertex, _, _ = stem.partition("-")
    return vertex

def get_channel(path):
    """Return the channel field of a '<vertex>-<channel>-<tp>.<ext>' filename.

    The label is everything before the first '.', and the channel is its second
    '-'-separated token. Raises IndexError when the label contains no '-'.
    """
    stem, _, _ = path.name.partition(".")
    fields = stem.split("-")
    return fields[1]

def get_tp(path):
    """Return the timepoint of a '<vertex>-<channel>-<tp>.<ext>' filename as an int.

    The label is everything before the first '.', and the timepoint is its last
    '-'-separated token.
    """
    stem = path.name.partition(".")[0]
    *_, tp_field = stem.split("-")
    return int(tp_field)

def handle_stack(args, low_pct=0.1, high_pct=99.9):
    """Register one (vertex, channel) time series and write it as a single stack.

    Parameters
    ----------
    args : tuple
        ``((vertex, channel), paths)`` -- one item of ``groups.items()``. It is a
        single tuple so the function can be passed directly to ``Pool.imap``.
    low_pct, high_pct : float, optional
        Percentiles of the registered stack used as the intensity window for
        rescaling; values outside the window are clipped. The defaults (0.1, 99.9)
        reproduce the original behaviour.

    Side effects
    ------------
    Writes ``processed_imgs/<OUT_COLLECTION>/<vertex>-<channel>.tif`` under
    ``experiment_base`` as a uint16 stack ordered by timepoint.
    """
    vertex_channel, paths = args
    vertex, channel = vertex_channel
    # Order frames chronologically so registration and the output follow time.
    paths = sorted(paths, key=get_tp)
    stack = np.array([tifffile.imread(path) for path in paths])
    # Rigid-body registration of every frame onto the first frame.
    sr = StackReg(StackReg.RIGID_BODY)
    registered = sr.register_transform_stack(stack, reference='first')
    # One intensity window for the whole stack keeps timepoints comparable.
    # np.percentile reduces over all axes by default, so no flattened copy of the
    # (potentially large) float stack is needed.
    rl, rh = np.percentile(registered, (low_pct, high_pct))
    # in_range is fixed, so a single vectorised call is equivalent to rescaling
    # frame by frame and re-stacking.
    rescaled = rescale_intensity(registered, in_range=(rl, rh), out_range=np.uint16)
    out_path = experiment_base / "processed_imgs" / OUT_COLLECTION / f"{vertex}-{channel}.tif"
    tifffile.imwrite(out_path, rescaled)

In [None]:
# Group input images into one time series per (vertex, channel).
groups = defaultdict(list)
for path in paths:
    groups[(get_vertex(path), get_channel(path))].append(path)

# Worker count comes from the production config when it exists; otherwise use
# every core. Only a missing config module is tolerated. Any other error raised
# while importing the config should surface instead of being silently ignored.
try:
    from notebooks.config.prod import CPUS
except ImportError:
    CPUS = os.cpu_count()

# NOTE(review): handle_stack is defined in the notebook's __main__. This works with
# the "fork" start method; under "spawn"/"forkserver" the workers may be unable to
# locate it. TODO: move it into a module if this must run on such platforms.
# Results are discarded, so imap_unordered lets progress advance as soon as any
# stack finishes.
with Pool(CPUS) as p:
    for idx, _ in enumerate(p.imap_unordered(handle_stack, groups.items()), start=1):
        print(f"\rStacked {idx} / {len(groups)}", end="", flush=True)
print()