In [None]:
import cv2 as cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sp
import tqdm as tqdm
import pathml.ml.hovernet as hn
import squidpy.im as sp_im
import spatialdata as sd
import spatialdata.models as sd_m
import spatialdata_io as sd_io
import multiscale_spatial_image as msi
import spatialdata.transformations as sd_t
import xenium_utils as xu
from PIL import Image

In [None]:
# All inputs for one Xenium run live under a single project directory.
# Change DATA_ROOT / SAMPLE to point the notebook at another run.
DATA_ROOT = '/scratch/project_mnt/S0010/Andrew_N/XeniumData'
SAMPLE = 'MPS-1'
OUTPUT_DIR = f'{DATA_ROOT}/{SAMPLE}-output'
# Stem shared by the exported tile image and its mask.
TILE_NAME = '0_311'

image_filename = f'{OUTPUT_DIR}/images/{TILE_NAME}.png'
mask_filename = f'{OUTPUT_DIR}/masks/{TILE_NAME}.npy'
tif_input = f'{OUTPUT_DIR}/output_image.tif'
h5_annotated = f'{OUTPUT_DIR}/imputed_annotated.h5ad'
transform_file = f'{DATA_ROOT}/alignments/{SAMPLE}-matrix.csv'
xenium_file = f'{DATA_ROOT}/{SAMPLE}'
# Size threshold passed to hn.remove_small_objs when cleaning the loaded mask.
obj_threshold = 10

In [None]:
# Load the H&E image through the project helper; it is indexed below as a
# multiscale image (full_image_parsed['scale0']['image']).
full_image_parsed = xu.load_registered_image(
    tif_input
)

In [None]:
# Annotated AnnData; its obs carries the predicted cell types ('predicted.id').
adata_annotated = sp.read_h5ad(filename=h5_annotated)
adata_annotated

In [None]:
# Read the raw Xenium run with 8 jobs. cells_as_shapes=True presumably provides the
# "cell_circles" shape layer that `merged` relies on below.
sdata = sd_io.xenium(
    xenium_file,
    cells_as_shapes=True,
    n_jobs=8,
)
sdata

In [None]:
# Copy the predicted cell types onto the Xenium table *by position*. This assumes
# the annotated AnnData lists cells in the same order as sdata.table (TODO confirm).
# An index-aligned assignment would match a fresh 0..n-1 RangeIndex (from
# reset_index) against the table's obs names and leave NaN wherever they differ.
# set_axis relabels positionally instead, keeps the column dtype (e.g. categorical),
# and raises if the two tables have different lengths.
sdata.table.obs["celltype_major"] = adata_annotated.obs["predicted.id"].set_axis(
    sdata.table.obs.index
)
sdata

In [None]:
# Combine the H&E image with the Xenium shape layers and table.
# "cell_circles" must be kept: without it, bbox queries drop the AnnData table.
merged = sd.SpatialData(
    images={"he": full_image_parsed},
    shapes={
        layer: sdata.shapes[layer]
        for layer in ("cell_circles", "cell_boundaries", "nucleus_boundaries")
    },
    table=sdata.table,
)

In [None]:
# H&E alignment matrix, stored as a headerless CSV. It is either 2x3 (the last
# homogeneous row is implied) or a full 3x3 matrix.
A = pd.read_csv(transform_file, header=None).to_numpy(dtype=float)
if A.shape == (2, 3):
    # Complete the homogeneous matrix with the implicit [0, 0, 1] row.
    A = np.vstack([A, [0, 0, 1]])
if A.shape != (3, 3):
    # Fail loudly rather than building a malformed 2D affine transform.
    raise ValueError(
        f"Expected a 2x3 or 3x3 affine matrix in {transform_file}, got shape {A.shape}"
    )
affineT = sd_t.Affine(A, input_axes=("x", "y"), output_axes=("x", "y"))

In [None]:
# Extent of the full-resolution H&E level, taken from the last two array axes
# (presumably y, x), as a single corner pair.
height, width = (
    full_image_parsed["scale0"]["image"]
    .shape[-2:]
)
coords = [
    [[0, 0], [height, width]],
]
coords

In [None]:
# img_sep, mask_sep = xu.sdata_load_img_mask(merged, affineT=affineT, img_key='he', expand_px=3, return_sep=True)

In [None]:
#mask_sep.shape

In [None]:
img_key, shape_key, label_key = 'he', 'nucleus_boundaries', 'celltype_major'
# Express the shapes in the image element's intrinsic coordinates: apply the
# shapes' own (default) transformation, then the inverse of the image's.
t_shapes = sd_t.Sequence(
    [
        sd_t.get_transformation(merged.shapes[shape_key]),
        sd_t.get_transformation(merged.images[img_key]).inverse(),
    ]
)


In [None]:
# Move the nucleus shapes into image space, then apply the H&E alignment matrix.
shapes = sd.transform(
    sd.transform(merged.shapes[shape_key], t_shapes),
    affineT,
)

In [None]:
# Cast the shape IDs to integers so they line up with the integer label index below.
shapes = shapes.set_axis(shapes.index.astype(int))

In [None]:
labels = merged.table.obs[label_key].to_frame()
# Positional relabel: assumes the table rows are in the same order as the shapes
# (TODO confirm). Pandas raises if the lengths differ.
labels.index = shapes.index.astype(int)
shapes_df = shapes.merge(labels, how='inner', right_index=True, left_index=True)
# Coerce to categorical so `.cat.codes` also works when the labels are plain
# strings; this is a no-op (categories unchanged) if the column is already categorical.
shapes_df['label'] = shapes_df[label_key].astype('category').cat.codes

In [None]:
# One GeoDataFrame of nucleus shapes per cell type. observed=False is passed
# explicitly because pandas is deprecating the implicit default for categorical keys.
# This keeps the current behaviour (a group for every category, even empty ones),
# so each category keeps its own mask channel.
shapes_df_dict = {k: v for k, v in shapes_df.groupby(label_key, observed=False)}
# Show group sizes rather than dumping every GeoDataFrame.
{k: len(v) for k, v in shapes_df_dict.items()}

In [None]:
# Materialise the H&E pixels as a NumPy array. For a multiscale image only the
# full-resolution "scale0" level is used. Re-parsing clears any transformation
# attribute, which does not matter here because only the raw values are kept.
img = merged.images[img_key]
img = (
    sd_m.Image2DModel.parse(img["scale0"].ds.to_array().squeeze(axis=0))
    if isinstance(img, msi.multiscale_spatial_image.MultiscaleSpatialImage)
    else img
).values

In [None]:
def new_mask_for_polygons(polygons, im_size, vals=None, verbose=False):
    """Rasterise polygon geometries into a single float64 mask with OpenCV.

    Each geometry is filled with its value from ``vals``, and that polygon's holes
    are then reset to 0. Geometries are drawn in order, so where they overlap the
    later one wins.

    Args:
        polygons: Iterable of shapely-style geometries (e.g. ``GeoSeries.tolist()``).
            Geometries whose ``geom_type`` is not ``'Polygon'`` (e.g. MultiPolygon)
            are drawn as their convex hull. ``None`` and empty geometries, and hulls
            that collapse to a point or line, are skipped.
        im_size: Shape of the output mask, passed straight to ``np.zeros``
            (the notebook passes ``img.shape[-2:]``).
        vals: Fill values. Either a list/tuple/ndarray with one value per geometry,
            a single scalar used for every geometry, or ``None`` to fill with 1.
        verbose: If True, print a one-line summary of what was drawn.

    Returns:
        np.ndarray of dtype float64 and shape ``im_size``; 0 wherever nothing was drawn.

    Raises:
        ValueError: If a per-geometry ``vals`` sequence does not hold exactly one
            value per geometry.
    """
    polygons = list(polygons)
    img_mask = np.zeros(im_size, np.float64)
    if len(polygons) == 0:
        if verbose:
            print("No polygons to draw")
        return img_mask

    if vals is None:
        fill_vals = np.ones(len(polygons))
    elif np.ndim(vals) == 0:
        # A scalar is broadcast to every geometry.
        fill_vals = np.full(len(polygons), float(vals))
    else:
        fill_vals = np.asarray(vals, dtype=np.float64)
        if fill_vals.shape != (len(polygons),):
            raise ValueError(
                f"vals must provide one value per polygon: got shape {fill_vals.shape} "
                f"for {len(polygons)} polygons"
            )

    def to_pixels(coords):
        # cv2.fillPoly wants an (N, 2) int32 array of (x, y) vertices; any z is dropped.
        return np.asarray(coords, dtype=np.float64)[:, :2].round().astype(np.int32)

    drawn = 0
    for poly, val in zip(polygons, fill_vals):
        if poly is None or poly.is_empty:
            continue
        shape = poly if poly.geom_type == 'Polygon' else poly.convex_hull
        if shape.geom_type != 'Polygon':
            # Degenerate hull (Point/LineString): it has no exterior ring to fill.
            continue
        # Some OpenCV versions reject NumPy scalars as a colour
        # ("Scalar value for argument 'color' is not numeric"), so pass a plain float.
        cv2.fillPoly(img_mask, [to_pixels(shape.exterior.coords)], float(val))
        # Clear this polygon's own holes immediately, so a hole never erases a
        # geometry that is drawn later inside it.
        for hole in shape.interiors:
            cv2.fillPoly(img_mask, [to_pixels(hole.coords)], 0.0)
        drawn += 1

    if verbose:
        print(
            f"Drew {drawn}/{len(polygons)} polygons into a mask of shape {img_mask.shape}; "
            f"non-zero pixels: {np.count_nonzero(img_mask)}"
        )
    return img_mask

In [None]:
from itertools import islice

# Preview only the first TWO cell-type groups (the name is historical).
# Materialise them as a list: a bare islice is a one-shot iterator, so re-running
# the mask cell below would otherwise silently produce an empty list.
first_five = list(islice(shapes_df_dict.items(), 2))

In [None]:
# One mask per previewed cell-type group. Each nucleus is filled with its shape
# id + 1 so that id 0 cannot collide with the background value 0. `.tolist()`
# hands OpenCV plain Python numbers; see
# https://stackoverflow.com/questions/60484383/typeerror-scalar-value-for-argument-color-is-not-numeric-when-using-opencv
masks = [
    new_mask_for_polygons(
        group_shapes['geometry'].tolist(),
        img.shape[-2:],
        vals=(group_shapes.index.to_numpy() + 1).tolist(),
    )
    for _group_name, group_shapes in first_five
]


In [None]:
# Summarise each per-group mask as (shape, number of painted pixels) instead of
# dumping the raw float arrays.
[(m.shape, int(np.count_nonzero(m))) for m in masks]

In [None]:
# Sanity check: True means the mask for that previewed group is entirely zero.
print(*(np.all(masks[i] == 0) for i in range(2)))

In [None]:
# Stack the per-group masks along a new leading axis, then append a background
# channel that is 1.0 wherever the per-group masks sum to zero.
masks = np.stack(masks)
mask_bg = np.where(masks.sum(axis=0) == 0, 1.0, 0.0)
mask = np.concatenate((masks, mask_bg[np.newaxis, ...]), axis=0)
mask

In [None]:
# One flag per channel of the stacked mask (cell-type groups first, background
# last); True means that channel is entirely zero. `masks` has no background
# channel, so indexing it at 2 fails when only two groups are previewed.
# Iterate over `mask` instead.
print(*(np.all(channel == 0) for channel in mask))

In [None]:
np.shape(masks)

In [None]:
# Combined image/mask array from the project helper, aligned with affineT.
# expand_px=3 presumably grows each shape by 3 px; see xenium_utils.
img_mask = xu.sdata_load_img_mask(
    merged,
    img_key='he',
    affineT=affineT,
    expand_px=3,
)
img_mask

In [None]:
# Tile the combined image/mask into equal 256 x 256 crops. The result is a
# generator: it can be consumed only once, so re-run this cell before re-running
# the crop loop.
imgc = sp_im.ImageContainer(img_mask)
gen = imgc.generate_equal_crops(
    as_array='image',
    size=256,
    squeeze=True,
)

In [None]:
# Walk all crops, keeping the data of tiles 311 and 312. After the loop, `mask`
# and `image` hold the last matched tile; if none matched, they keep their
# previous values.
for i, tile in enumerate(tqdm.tqdm(gen)):
    if i not in (311, 312):
        continue
    # Channels 0-2 are RGB; 3:-1 keeps the remaining channels except the last.
    # Alternative channels-first layout: np.moveaxis(tile[:,:,3:], 2, 0)
    mask = tile[:, :, 3:-1]
    print(mask)
    image = Image.fromarray(tile[:, :, :3].astype(np.uint8))

In [None]:
# Show the mask slice of the last tile matched by the crop loop. NOTE: if no tile
# matched (or `gen` was already exhausted by an earlier run), this is still the
# stacked mask from before the loop.
mask

In [None]:
# Display the exported tile image for visual comparison with the mask.
fig, ax = plt.subplots()
ax.imshow(plt.imread(image_filename))
ax.axis('off')  # hide ticks and frame
plt.show()

In [None]:
mask = np.load(str(mask_filename))
# Only a cell's last expression is displayed, so a bare `np.unique(mask)` line
# before it would be computed and discarded. Show the shape and the distinct
# values together instead of dumping the whole array.
mask.shape, np.unique(mask)


In [None]:
# imshow wants channels last; NOTE: it only renders (H, W, 3) or (H, W, 4) arrays.
plt.imshow(np.transpose(mask, (1, 2, 0)))

In [None]:
np.transpose(mask, (1, 2, 0)).shape

In [None]:
# Drop nuclei smaller than obj_threshold and hand their pixels to the background.
# Presumably mask is (channels, H, W), with the background as the last channel
# (non-zero = background, 0 = nucleus) -- TODO confirm for the loaded .npy.
nucleus_mask = (mask[-1] == 0).astype(np.uint8)  # invert the bg mask -> nucleus foreground
filter_mask = (hn.remove_small_objs(nucleus_mask, obj_threshold) != 0)  # surviving nuclei
filter_mask_bg = (filter_mask == 0)  # original background plus the removed small objects
mask[:-1] = np.multiply(filter_mask, mask[:-1])  # zero removed objects in every class channel
# Mark removed objects as background while keeping existing background values.
# Multiplying mask[-1] by filter_mask_bg would be a no-op (every background pixel
# is already in filter_mask_bg), leaving removed objects with no label at all.
mask[-1] = np.where(mask[-1] != 0, mask[-1], filter_mask_bg)

In [None]:
np.shape(mask)