To execute a cell with code, select it & press Shift+Enter

# Registering stitched data

## Define file-paths:

Your files should be organized in the structure below, where `<name>` and `<channel>` stand for the actual image name and channel name (the channel is the part of the file name after the last underscore):
('*' can be any characters)
```
base_dir
├───round1
│   ├───*A1*
│   │   ├───<name>_<channel1>.ome.tif
│   │   ├───<name>_<channel2>.ome.tif
│   │   .
│   │   .
├───round2
│   ├───*A1*
│   │   ├───<name>_<channel1>.ome.tif
│   │   .
│   │   .
```

In the cell below, change the following file-paths:
* `reference_channel`: the reference channel, which is used for registration (usually 'DAPI')
* `reference_round`: the reference round to which all other rounds are aligned, given as a string (e.g. "1" for the folder `round1`)
* `base_dir`: the directory that contains all `round*` folders
* `save_dir`: the directory where the registered images will be saved

In [None]:
# INPUT REQUIRED
# channel used to compute the registration; must equal the <channel> part of the file names
reference_channel = "DAPI"
# round every other round is aligned to; a string, because it is compared with the
# digits parsed from the "round<N>" folder names
reference_round = "1"
# folder that contains the round1, round2, ... sub-folders
base_dir = r"Z:\zmbstaff\9780\Processed_Data\MD_1\stitched"
# output folder for the registered images (created by the file-loading cell if missing)
save_dir = r"Z:\zmbstaff\9780\Processed_Data\MD_1\registered"

## Import modules and functions:

In [None]:
import glob
import os
import re
import time

import matplotlib.pyplot as plt
import napari
import numpy as np
import ome_types
import pandas as pd
import tifffile
from skimage.measure import block_reduce
from skimage.registration import phase_cross_correlation

In [None]:
def register(img1, img2, upsample_factor):
    """Estimate the translation that registers ``img2`` onto ``img1``.

    ``phase_cross_correlation`` needs inputs of identical shape, so both images
    are cropped to their common (per-axis minimum) extent first. Images of any
    dimensionality are supported, as long as both have the same number of axes.

    Parameters
    ----------
    img1 : ndarray
        Reference (fixed) image.
    img2 : ndarray
        Moving image, with the same ``ndim`` as ``img1``.
    upsample_factor : int
        Sub-pixel precision; shifts are resolved to ``1 / upsample_factor`` px.

    Returns
    -------
    ndarray
        Shift per axis, in pixels of the input images, required to register
        ``img2`` with ``img1``.

    Raises
    ------
    ValueError
        If the two images have a different number of dimensions.
    """
    if img1.ndim != img2.ndim:
        raise ValueError(
            f"images must have the same number of dimensions, got {img1.ndim} and {img2.ndim}"
        )
    # crop both images to the overlapping corner region of identical shape
    common = tuple(slice(0, size) for size in np.minimum(img1.shape, img2.shape))
    shift, _error, _phasediff = phase_cross_correlation(
        img1[common],
        img2[common],
        disambiguate=False,
        upsample_factor=upsample_factor,
    )
    return shift

In [None]:
def block_reduce_seq(data, block_size, seq_size=256):
    """Mean-downsample ``data`` by ``block_size``, processing the last axis in chunks.

    The last axis is split into chunks of ``block_size[-1] * seq_size`` elements
    and each chunk is reduced separately. This keeps the reduction temporaries
    limited to one chunk at a time. Chunk starts are multiples of
    ``block_size[-1]``, so every block lies entirely inside a single chunk.

    Parameters
    ----------
    data : ndarray
        Image to downsample; any number of axes, chunked along the last one.
    block_size : tuple of int
        Downsampling factor per axis (one entry per axis of ``data``).
    seq_size : int, optional
        Number of output elements along the last axis produced per chunk.

    Returns
    -------
    ndarray
        Block-averaged image, cast back to ``data.dtype``.
    """
    step = block_size[-1] * seq_size
    fragments_ds = []
    for start in range(0, data.shape[-1], step):
        fragment = data[..., start : start + step]
        fragment_ds = block_reduce(
            fragment,
            block_size=block_size,
            func=np.mean,
            # Accumulate in float64: summing the block (e.g. 2*6*6 = 72 uint16
            # values) in the input dtype silently wraps around above 65535.
            func_kwargs={"dtype": np.float64},
        )
        # the mean of a block lies within the input range, so the cast back is safe
        fragments_ds.append(fragment_ds.astype(data.dtype))
    return np.concatenate(fragments_ds, axis=-1)

In [None]:
# for registration, the images will be downsampled with these factors (can be adjusted, if needed)
# one factor per image axis, in (z, y, x) order; the largest factor is also passed
# to register() as the sub-pixel upsample_factor
downsample_factors = np.array((2, 6, 6))

## Locate files and generate dataframe

In [None]:
# load files into dataframe
os.makedirs(save_dir, exist_ok=True)
fns = glob.glob(os.path.join(base_dir, "round*", "*", "*.ome.tif"))
# Expected layout: <base_dir>/round<N>/<well folder>/<name>_<channel>.ome.tif
# - the well folder may contain extra characters around the well ID (e.g.
#   "scan_A1"); the first capital letter directly followed by digits is the well
# - <name> is greedy, so <channel> is the text after the last underscore
# - the dots are escaped so only a literal ".ome.tif" suffix matches
pattern = re.compile(
    r".*[\/\\]round(?P<round>\d+)[\/\\]"
    r"[^\/\\]*?(?P<well>[A-Z]\d+)(?!\d)[^\/\\]*[\/\\]"
    r"(?P<name>.*)_(?P<channel>.*)\.ome\.tif"
)
files = []
for fn in fns:
    match = pattern.fullmatch(fn)
    if match is None:
        # report and skip instead of failing on None.groupdict()
        print(f"skipping file with unexpected path or name: {fn}")
        continue
    row = match.groupdict()
    row["path"] = fn
    files.append(row)
if not files:
    raise FileNotFoundError(
        f"no files matching round*/<well>/<name>_<channel>.ome.tif found in {base_dir}"
    )
files = pd.DataFrame(files)
wells = files["well"].unique()
rounds = files["round"].unique()

# add some metadata to dataframe: voxel size and image dimensions of every file
for index, file in files.iterrows():
    ome_dict = ome_types.to_dict(ome_types.from_tiff(file.path))
    pixels = ome_dict["images"][0]["pixels"]
    # dx, dy, dz remain defined after the loop (values of the last file);
    # later cells of this notebook read them
    (dx, dy, dz) = [
        pixels[key] for key in ["physical_size_x", "physical_size_y", "physical_size_z"]
    ]
    files.loc[index, ["dx", "dy", "dz"]] = (dx, dy, dz)
    (dim_x, dim_y, dim_z) = [pixels[key] for key in ["size_x", "size_y", "size_z"]]
    files.loc[index, ["dim_x", "dim_y", "dim_z"]] = (dim_x, dim_y, dim_z)

In [None]:
# inspect the parsed file table (one row per image file)
files

## process all wells:

### load and save channels as individual files

In [None]:
for well in wells:
    well_files = files.query("well==@well")
    for name in well_files["name"].unique():
        print("\nPROCESSING WELL " + well + ", " + name)

        # rows of this acquisition (one well, one name) and the subset imaged
        # in the reference channel, which drives the registration
        name_files = well_files.query("name==@name")
        reg_files = name_files.query("channel==@reference_channel")
        shift_cols = ["shift_z", "shift_y", "shift_x"]

        print("\nloading and registering reference channel...")
        start_time = time.time()

        # the reference round is the fixed image and gets a zero shift
        fix_file = reg_files.query("round==@reference_round").iloc[0]
        fix_img = tifffile.imread(fix_file["path"])
        fix_img_ds = block_reduce_seq(
            fix_img, block_size=tuple(downsample_factors), seq_size=256
        )
        files.loc[fix_file.name, ["shift_x", "shift_y", "shift_z"]] = 0
        for index, file in reg_files.query("round!=@reference_round").iterrows():
            mov_img = tifffile.imread(file["path"])
            mov_img_ds = block_reduce_seq(
                mov_img, block_size=tuple(downsample_factors), seq_size=256
            )
            shift_px_ds = register(fix_img_ds, mov_img_ds, np.max(downsample_factors))
            # the shift was measured on the downsampled images; scale it back
            # to full-resolution pixels
            shift_px = np.round(shift_px_ds * downsample_factors).astype(int)
            files.loc[index, shift_cols] = shift_px

        # Every channel of a round inherits the shift of that round's reference
        # channel. Filtering by name as well as well keeps acquisitions that
        # share a well from picking up each other's shifts.
        for index, file in name_files.iterrows():
            file_round = file["round"]
            shift_px = files.query(
                "well==@well and name==@name and channel==@reference_channel"
                " and round==@file_round"
            ).iloc[0][shift_cols]
            files.loc[index, shift_cols] = shift_px

        # Offset the shifts so the minimum is 0 on every axis and size the common
        # canvas to hold every shifted image. Only this acquisition's rows are
        # used: other names in the same well may not have shifts yet (NaN).
        shifts = files.loc[name_files.index, shift_cols].to_numpy(dtype=int)
        shifts = shifts - shifts.min(axis=0)
        files.loc[name_files.index, shift_cols] = shifts
        dims = files.loc[name_files.index, ["dim_z", "dim_y", "dim_x"]].to_numpy(
            dtype=int
        )
        nz, ny, nx = np.max(shifts + dims, axis=0)

        print(f"took {(time.time() - start_time):.1f}s")

        print("\nloading and aligning all channels...")
        start_time = time.time()

        out_dir = os.path.join(save_dir, name + "_" + well)
        os.makedirs(out_dir, exist_ok=True)
        for index in name_files.index:
            # re-read the row: name_files was taken before the shifts were written
            file = files.loc[index]
            channel_name = f"round{file['round']}_{file['channel']}"
            print(f"Processing {channel_name}")
            print("Loading...")
            img = tifffile.imread(file["path"])
            # paste the image at its shift offset inside a zero-filled canvas
            slc_z = slice(int(file["shift_z"]), int(file["shift_z"] + file["dim_z"]))
            slc_y = slice(int(file["shift_y"]), int(file["shift_y"] + file["dim_y"]))
            slc_x = slice(int(file["shift_x"]), int(file["shift_x"] + file["dim_x"]))
            img_reg = np.zeros((nz, ny, nx), dtype=img.dtype)
            img_reg[slc_z, slc_y, slc_x] = img
            print("Saving...")
            with tifffile.TiffWriter(
                os.path.join(out_dir, channel_name + ".ome.tif"),
                bigtiff=True,
            ) as tif:
                metadata = {
                    "axes": "ZYX",
                    # this file's own voxel size, not the value left over from
                    # the last file read in the metadata cell
                    "PhysicalSizeX": file["dx"],
                    "PhysicalSizeXUnit": "µm",
                    "PhysicalSizeY": file["dy"],
                    "PhysicalSizeYUnit": "µm",
                    "PhysicalSizeZ": file["dz"],
                    "PhysicalSizeZUnit": "µm",
                    "Channel": {"Name": channel_name},
                }
                options = dict(
                    photometric="minisblack",
                )
                tif.write(img_reg, metadata=metadata, **options)

        print(f"took {(time.time() - start_time):.1f}s")

### load and save all channels together

In [None]:
for well in wells:
    well_files = files.query("well==@well")
    for name in well_files["name"].unique():
        print("\nPROCESSING WELL " + well + ", " + name)

        # rows of this acquisition (one well, one name) and the subset imaged
        # in the reference channel, which drives the registration
        name_files = well_files.query("name==@name")
        reg_files = name_files.query("channel==@reference_channel")
        shift_cols = ["shift_z", "shift_y", "shift_x"]

        print("\nloading and registering reference channel...")
        start_time = time.time()

        # the reference round is the fixed image and gets a zero shift
        fix_file = reg_files.query("round==@reference_round").iloc[0]
        fix_img = tifffile.imread(fix_file["path"])
        fix_img_ds = block_reduce_seq(
            fix_img, block_size=tuple(downsample_factors), seq_size=256
        )
        files.loc[fix_file.name, ["shift_x", "shift_y", "shift_z"]] = 0
        for index, file in reg_files.query("round!=@reference_round").iterrows():
            mov_img = tifffile.imread(file["path"])
            mov_img_ds = block_reduce_seq(
                mov_img, block_size=tuple(downsample_factors), seq_size=256
            )
            shift_px_ds = register(fix_img_ds, mov_img_ds, np.max(downsample_factors))
            # the shift was measured on the downsampled images; scale it back
            # to full-resolution pixels
            shift_px = np.round(shift_px_ds * downsample_factors).astype(int)
            files.loc[index, shift_cols] = shift_px

        # Every channel of a round inherits the shift of that round's reference
        # channel. Filtering by name as well as well keeps acquisitions that
        # share a well from picking up each other's shifts.
        for index, file in name_files.iterrows():
            file_round = file["round"]
            shift_px = files.query(
                "well==@well and name==@name and channel==@reference_channel"
                " and round==@file_round"
            ).iloc[0][shift_cols]
            files.loc[index, shift_cols] = shift_px

        # Offset the shifts so the minimum is 0 on every axis and size the common
        # canvas to hold every shifted image. Only this acquisition's rows are
        # used: other names in the same well may not have shifts yet (NaN).
        shifts = files.loc[name_files.index, shift_cols].to_numpy(dtype=int)
        shifts = shifts - shifts.min(axis=0)
        files.loc[name_files.index, shift_cols] = shifts
        dims = files.loc[name_files.index, ["dim_z", "dim_y", "dim_x"]].to_numpy(
            dtype=int
        )
        nz, ny, nx = np.max(shifts + dims, axis=0)

        print(f"took {(time.time() - start_time):.1f}s")

        print("\nloading and aligning all channels...")
        start_time = time.time()

        # (channel, z, y, x) stack in the reference image's dtype; every channel
        # is pasted at its shift offset inside the common canvas
        channel_names = []
        imgs_reg = np.zeros((len(name_files), nz, ny, nx), dtype=fix_img.dtype)
        for n, index in enumerate(name_files.index):
            # re-read the row: name_files was taken before the shifts were written
            file = files.loc[index]
            channel_name = f"round{file['round']}_{file['channel']}"
            channel_names.append(channel_name)
            print(f"Loading {channel_name}")
            img = tifffile.imread(file["path"])
            slc_z = slice(int(file["shift_z"]), int(file["shift_z"] + file["dim_z"]))
            slc_y = slice(int(file["shift_y"]), int(file["shift_y"] + file["dim_y"]))
            slc_x = slice(int(file["shift_x"]), int(file["shift_x"] + file["dim_x"]))
            imgs_reg[n, slc_z, slc_y, slc_x] = img

        print(f"took {(time.time() - start_time):.1f}s")

        print("\nsaving data...")
        start_time = time.time()

        # All channels are placed on one pixel grid, so the reference image's
        # voxel size is used (assumes every round shares it) instead of the
        # value left over from the last file read in the metadata cell.
        size_x, size_y, size_z = fix_file["dx"], fix_file["dy"], fix_file["dz"]
        subresolutions = 2
        with tifffile.TiffWriter(
            os.path.join(save_dir, name + "_" + well + ".ome.tif"), bigtiff=True
        ) as tif:
            metadata = {
                "axes": "CZYX",
                "PhysicalSizeX": size_x,
                "PhysicalSizeXUnit": "µm",
                "PhysicalSizeY": size_y,
                "PhysicalSizeYUnit": "µm",
                "PhysicalSizeZ": size_z,
                "PhysicalSizeZUnit": "µm",
                "Channel": {"Name": channel_names},
            }
            options = dict(
                photometric="minisblack",
                tile=(128, 128),
                resolutionunit="CENTIMETER",
                maxworkers=32,
            )
            # 1e4 µm per cm: pixels per centimeter for resolutionunit CENTIMETER
            tif.write(
                imgs_reg,
                subifds=subresolutions,
                resolution=(1e4 / size_y, 1e4 / size_x),
                metadata=metadata,
                **options,
            )
            # write pyramid levels to the two subifds
            # TODO: in production use resampling to generate sub-resolution images
            for level in range(subresolutions):
                mag = 2 ** (level + 1)
                tif.write(
                    imgs_reg[..., ::mag, ::mag],
                    subfiletype=1,
                    resolution=(1e4 / mag / size_y, 1e4 / mag / size_x),
                    **options,
                )

        print(f"took {(time.time() - start_time):.1f}s")

## process wells individually:

In [None]:
# INPUT REQUIRED
# well ID as parsed from the well folder name (see the "well" column of `files`)
well = "A1"
# image name, i.e. the part of the file name before "_<channel>.ome.tif"
name = "stitched"

In [None]:
# narrow the file table step by step: selected well -> selected name -> reference channel
well_files = files[files["well"] == well]
name_files = well_files[well_files["name"] == name]
reg_files = name_files[name_files["channel"] == reference_channel]

In [None]:
%%time
# the reference round of the reference channel is the fixed image ...
fix_file = reg_files[reg_files["round"] == reference_round].iloc[0]
fix_img = tifffile.imread(fix_file["path"])
# ... downsampled once here and reused for every registration below
fix_img_ds = block_reduce_seq(fix_img, block_size=tuple(downsample_factors), seq_size=256)
# the fixed image defines the origin, so its shift is zero on every axis
files.loc[fix_file.name, ["shift_x", "shift_y", "shift_z"]] = 0

In [None]:
%%time
# register the reference channel of every other round against the fixed image
moving_rows = reg_files[reg_files["round"] != reference_round]
for row_index, row in moving_rows.iterrows():
    mov_img = tifffile.imread(row["path"])
    mov_img_ds = block_reduce_seq(mov_img, block_size=tuple(downsample_factors), seq_size=256)
    shift_px_ds = register(fix_img_ds, mov_img_ds, np.max(downsample_factors))
    # shift measured on the downsampled grid, converted to full-resolution pixels
    full_res_shift = np.round(shift_px_ds * downsample_factors).astype(int)
    files.loc[row_index, ["shift_z", "shift_y", "shift_x"]] = full_res_shift

In [None]:
shift_cols = ["shift_z", "shift_y", "shift_x"]
# Every channel of a round inherits the shift of that round's reference channel.
# Filtering by name as well as well keeps acquisitions that share a well from
# picking up each other's shifts.
for index, file in name_files.iterrows():
    file_round = file["round"]
    shift_px = files.query(
        "well==@well and name==@name and channel==@reference_channel"
        " and round==@file_round"
    ).iloc[0][shift_cols]
    files.loc[index, shift_cols] = shift_px
# Offset the shifts so the minimum is 0 on every axis and size the common canvas
# (nz, ny, nx) to hold every shifted image. Only this acquisition's rows are used:
# other names in the same well may not have shifts (NaN).
shifts = files.loc[name_files.index, shift_cols].to_numpy(dtype=int)
shifts = shifts - shifts.min(axis=0)
files.loc[name_files.index, shift_cols] = shifts
dims = files.loc[name_files.index, ["dim_z", "dim_y", "dim_x"]].to_numpy(dtype=int)
nz, ny, nx = np.max(shifts + dims, axis=0)

In [None]:
# inspect the file table, now including the per-file shifts (in full-resolution pixels)
files

### load and save channels as individual files

In [None]:
%%time
out_dir = os.path.join(save_dir, name + "_" + well)
os.makedirs(out_dir, exist_ok=True)
for index in name_files.index:
    # re-read the row: name_files was taken before the shifts were written
    file = files.loc[index]
    channel_name = f"round{file['round']}_{file['channel']}"
    print(f"Processing {channel_name}")
    print("Loading...")
    img = tifffile.imread(file["path"])
    # paste the image at its shift offset inside a zero-filled common canvas
    slc_z = slice(int(file["shift_z"]), int(file["shift_z"] + file["dim_z"]))
    slc_y = slice(int(file["shift_y"]), int(file["shift_y"] + file["dim_y"]))
    slc_x = slice(int(file["shift_x"]), int(file["shift_x"] + file["dim_x"]))
    img_reg = np.zeros((nz, ny, nx), dtype=img.dtype)
    img_reg[slc_z, slc_y, slc_x] = img
    print("Saving...")
    with tifffile.TiffWriter(
        os.path.join(out_dir, channel_name + ".ome.tif"),
        bigtiff=True,
    ) as tif:
        metadata = {
            "axes": "ZYX",
            # this file's own voxel size, not the value left over from the last
            # file read in the metadata cell
            "PhysicalSizeX": file["dx"],
            "PhysicalSizeXUnit": "µm",
            "PhysicalSizeY": file["dy"],
            "PhysicalSizeYUnit": "µm",
            "PhysicalSizeZ": file["dz"],
            "PhysicalSizeZUnit": "µm",
            "Channel": {"Name": channel_name},
        }
        options = dict(
            photometric="minisblack",
        )
        tif.write(img_reg, metadata=metadata, **options)

### load and save all channels together

In [None]:
%%time
# Build a (channel, z, y, x) stack in the reference image's dtype; every channel
# is pasted at its shift offset inside the common canvas.
channel_names = []
imgs_reg = np.zeros((len(name_files), nz, ny, nx), dtype=fix_img.dtype)
for n, index in enumerate(name_files.index):
    # re-read the row: name_files was taken before the shifts were written
    file = files.loc[index]
    # avoid binding the builtin name `round` at notebook scope
    channel_name = f"round{file['round']}_{file['channel']}"
    channel_names.append(channel_name)
    print(f"Loading {channel_name}")
    img = tifffile.imread(file["path"])
    slc_z = slice(int(file["shift_z"]), int(file["shift_z"] + file["dim_z"]))
    slc_y = slice(int(file["shift_y"]), int(file["shift_y"] + file["dim_y"]))
    slc_x = slice(int(file["shift_x"]), int(file["shift_x"] + file["dim_x"]))
    imgs_reg[n, slc_z, slc_y, slc_x] = img

In [None]:
# show every registered channel as its own additively-blended layer,
# scaled to physical units (z, y, x)
layer_kwargs = dict(
    blending="additive",
    channel_axis=0,
    name=channel_names,
    contrast_limits=(0, 2000),
    scale=(dz, dy, dx),
)
viewer = napari.Viewer()
viewer.add_image(imgs_reg, **layer_kwargs)

In [None]:
%%time
subresolutions = 2
# All channels are placed on one pixel grid, so the reference image's voxel size
# is used (assumes every round shares it) instead of the value left over from
# the last file read in the metadata cell.
size_x, size_y, size_z = fix_file["dx"], fix_file["dy"], fix_file["dz"]
with tifffile.TiffWriter(
    os.path.join(save_dir, name + "_" + well + ".ome.tif"), bigtiff=True
) as tif:
    metadata = {
        "axes": "CZYX",
        "PhysicalSizeX": size_x,
        "PhysicalSizeXUnit": "µm",
        "PhysicalSizeY": size_y,
        "PhysicalSizeYUnit": "µm",
        "PhysicalSizeZ": size_z,
        "PhysicalSizeZUnit": "µm",
        "Channel": {"Name": channel_names},
    }
    options = dict(
        photometric="minisblack",
        tile=(128, 128),
        resolutionunit="CENTIMETER",
        maxworkers=32,
    )
    # 1e4 µm per cm: pixels per centimeter for resolutionunit CENTIMETER
    tif.write(
        imgs_reg,
        subifds=subresolutions,
        resolution=(1e4 / size_y, 1e4 / size_x),
        metadata=metadata,
        **options,
    )
    # write pyramid levels to the two subifds
    # TODO: in production use resampling to generate sub-resolution images
    for level in range(subresolutions):
        mag = 2 ** (level + 1)
        tif.write(
            imgs_reg[..., ::mag, ::mag],
            subfiletype=1,
            resolution=(1e4 / mag / size_y, 1e4 / mag / size_x),
            **options,
        )