In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sp
import pathml.ml.hovernet as hn
import spatialdata as sd
import spatialdata_io as sd_io
import spatialdata.transformations as sd_t
import xenium_utils as xu

In [None]:
# --- Run configuration -------------------------------------------------------
# Every path below is derived from DATA_ROOT / SAMPLE_ID / TILE_ID, so pointing
# the notebook at another dataset location, sample or tile is a one-line edit.
DATA_ROOT = '/scratch/project_mnt/S0010/Andrew_N/XeniumData'
SAMPLE_ID = 'MPS-1'
TILE_ID = '0_311'
OUTPUT_DIR = f'{DATA_ROOT}/{SAMPLE_ID}-output'

# PNG tile shown with plt.imread / plt.imshow further down.
image_filename = f'{OUTPUT_DIR}/images/{TILE_ID}.png'
# .npy mask array loaded with np.load further down.
mask_filename = f'{OUTPUT_DIR}/masks/{TILE_ID}.npy'
# Image handed to xu.load_registered_image.
tif_input = f'{OUTPUT_DIR}/output_image.tif'
# Annotated AnnData file; its obs['predicted.id'] column is used as the cell-type label.
h5_annotated = f'{OUTPUT_DIR}/imputed_annotated.h5ad'
# Header-less CSV holding the affine alignment matrix.
transform_file = f'{DATA_ROOT}/alignments/{SAMPLE_ID}-matrix.csv'
# Xenium output directory read by spatialdata_io.xenium.
xenium_file = f'{DATA_ROOT}/{SAMPLE_ID}'
# Size threshold passed to hn.remove_small_objs when filtering the mask.
obj_threshold = 10

In [None]:
# Load the registered image through the project helper. Later cells index the
# result as full_image_parsed['scale0']['image'] and store it under the "he"
# image key of the merged SpatialData object.
full_image_parsed = xu.load_registered_image(tif_input)

In [None]:
# Read the annotated AnnData file; its obs['predicted.id'] column is copied
# onto the Xenium table further down.
adata_annotated = sp.read_h5ad(h5_annotated)
# Bare name as the last expression so the notebook displays the object summary.
adata_annotated

In [None]:
# Parse the Xenium output directory into a SpatialData object.
# NOTE(review): cells_as_shapes=True is presumably what provides the
# "cell_circles" shapes consumed when building `merged` — confirm against the
# installed spatialdata_io version. n_jobs=8 is a fixed worker count.
sdata = sd_io.xenium(xenium_file, n_jobs=8, cells_as_shapes=True)
# Bare name as the last expression so the notebook displays the object summary.
sdata

In [None]:
# Copy the predicted cell-type labels onto the Xenium table as `celltype_major`.
#
# The earlier form assigned a DataFrame carrying a fresh RangeIndex (from
# reset_index()). pandas aligns on index labels when a Series/DataFrame is
# assigned to a column, so every row whose obs label is not one of 0..n-1 was
# filled with NaN instead of a label. Assigning the raw values makes the copy
# strictly positional, which is what the reset_index() was reaching for.
# NOTE(review): this assumes both tables list the same cells in the same
# order — confirm against how imputed_annotated.h5ad was produced.
predicted_labels = adata_annotated.obs['predicted.id'].to_numpy()
assert len(predicted_labels) == len(sdata.table.obs), (
    "annotated AnnData and Xenium table contain different numbers of cells"
)
sdata.table.obs['celltype_major'] = predicted_labels
sdata

In [None]:
# Bundle the registered H&E image with the Xenium shapes and table in a single
# SpatialData object.
xenium_shape_keys = (
    "cell_circles",  # Required for bbox queries otherwise adata table disappears
    "cell_boundaries",
    "nucleus_boundaries",
)
merged = sd.SpatialData(
    images={"he": full_image_parsed},
    shapes={shape_key: sdata.shapes[shape_key] for shape_key in xenium_shape_keys},
    table=sdata.table,
)

In [None]:
# Read the alignment matrix from a header-less CSV.
A = pd.read_csv(transform_file, header=None).to_numpy()
if len(A) == 2:
    # Only two rows were stored: append the homogeneous [0, 0, 1] row so the
    # matrix becomes 3x3.
    A = np.concatenate((A, [[0, 0, 1]]), axis=0)
# Wrap the matrix as an x/y -> x/y affine transformation.
affineT = sd_t.Affine(A, input_axes=("x", "y"), output_axes=("x", "y"))

In [None]:
# Extent of the 'scale0' image: the last two entries of its shape are taken as
# (height, width).
he_shape = full_image_parsed['scale0']['image'].shape
height, width = he_shape[-2], he_shape[-1]
# One box given as [origin, far corner].
# NOTE(review): corners are ordered (height, width); confirm that matches the
# axis order expected by whatever consumes `coords`.
coords = [[[0, 0], [height, width]]]
coords

In [None]:
# Run the project helper on the merged object, using the "he" image and the
# affine built above. expand_px=3 is forwarded unchanged.
# NOTE(review): the helper's return value and side effects are not visible
# from this notebook — check xenium_utils.
xu.sdata_load_img_mask(
    merged,
    img_key='he',
    affineT=affineT,
    expand_px=3,
)

In [None]:
# Display the saved PNG tile using an explicit figure/axes pair.
tile_pixels = plt.imread(image_filename)
tile_fig, tile_ax = plt.subplots()
tile_ax.imshow(tile_pixels)
tile_ax.axis('off')  # Optional: Remove the axis
plt.show()

In [None]:
# Load the mask array saved for this tile (mask_filename is already a path
# string, so no str() conversion is needed).
mask = np.load(mask_filename)
# Only the last expression of a cell is displayed, so the former bare
# `np.unique(mask)` line was computed and thrown away, while the trailing
# `mask` dumped the whole array. Show a compact summary instead: shape, dtype
# and the distinct values present.
mask.shape, mask.dtype, np.unique(mask)


In [None]:
# Move the leading axis to the end (axes 1, 2, 0) before handing the mask to
# imshow.
# NOTE(review): imshow only accepts 3 or 4 entries on the last axis — confirm
# the mask's channel count.
plt.imshow(np.transpose(mask, (1, 2, 0)))

In [None]:
# Shape of the mask after moving its leading axis to the end.
np.transpose(mask, (1, 2, 0)).shape

In [None]:
# Remove small nucleus regions from the mask, in place.
# NOTE(review): assumes the last channel is the background map (non-zero where
# there is no nucleus) and the leading channels hold the nucleus classes — confirm.
nucleus_mask = (mask[-1] == 0).astype(np.uint8) # invert the bg mask: 1 = nucleus pixel
# remove_small_objs needs a uint8 binary input; non-zero output = region kept.
filter_mask = (hn.remove_small_objs(nucleus_mask, obj_threshold) != 0)
filter_mask_bg = (filter_mask == 0)
# Zero out the removed regions in every leading channel (the 2-D filter
# broadcasts across the channel axis).
mask[:-1] = np.multiply(filter_mask, mask[:-1])
# The background channel used to be multiplied by filter_mask_bg, which never
# changed it: filter_mask_bg is False only where a nucleus was kept, and the
# background is already 0 there. Pixels of the removed small regions were
# therefore left at 0 in *every* channel. Hand them back to the background.
# NOTE(review): writes 1 as the background value, i.e. assumes a binary {0, 1}
# background channel — confirm against the np.unique(mask) output.
removed_px = (nucleus_mask != 0) & filter_mask_bg
mask[-1, removed_px] = 1

In [None]:
# Display the mask's shape after the in-place small-object filtering above.
mask.shape