In [21]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import json
import geopandas as gpd
from shapely.geometry import Polygon
import ast
import spatialdata as sd
import spatialdata_io as sdio
import spatialdata_plot as sdplot
import sopa
from anndata import read_h5ad
from matplotlib import pyplot as plt
import rasterio
from rasterio.features import rasterize
from spatialdata import transform
import anndata as ad
import cv2
import random
from squidpy.im import ImageContainer 
from spatialdata.transformations import (
    Affine,
    Identity,
    MapAxis,
    Scale,
    Sequence,
    Translation,
    get_transformation,
    get_transformation_between_coordinate_systems,
    set_transformation,
)
import gc
from matplotlib.colors import TABLEAU_COLORS

In [2]:
import subprocess
from pathlib import Path
from datetime import datetime
import os
import matplotlib as mpl
import tempfile
import warnings

# Notebook-wide configuration: plotting defaults, host-dependent storage root,
# and the project directory tree (created on demand).
mpl.rcParams['pdf.fonttype'] = 42
PROJECT = "Pete"
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

# The storage root depends on which machine runs the kernel; it is selected by
# looking for a known fragment in the short hostname (first match wins).
system = subprocess.check_output(["hostname", "-s"]).decode("utf-8").strip()
_HOST_ROOTS = {
    "bun": "/QRISdata/Q1851/Xiao/",
    "imb-quan-gpu": "/home/uqxtan9/Q1851/Xiao/",
    "gpunode": "/scratch/imb/Xiao/Q1851/Xiao/",
}
BASE_PATH_ = Path()
for _fragment, _root in _HOST_ROOTS.items():
    if _fragment in system:
        BASE_PATH_ = Path(_root)
        break
else:
    # Previously this case was silent: every directory below was then created
    # relative to the current working directory.
    warnings.warn(
        f"Unrecognised host '{system}': BASE_PATH_ falls back to the current "
        "working directory."
    )

# Per-job temporary directory. Under SLURM the path is derived from the job id;
# outside SLURM fall back to the system temp dir instead of failing with KeyError.
_slurm_job_id = os.environ.get("SLURM_JOB_ID")
if _slurm_job_id is not None:
    TMP_PATH = Path("/scratch/temp/") / _slurm_job_id
else:
    TMP_PATH = Path(tempfile.gettempdir()) / f"{PROJECT}_tmp"
    TMP_PATH.mkdir(exist_ok=True, parents=True)
    warnings.warn(f"SLURM_JOB_ID is not set; using {TMP_PATH} as TMP_PATH.")

SCRATCH_PATH = Path("/scratch/project_mnt/S0010/Xiao")
BASE_PATH = BASE_PATH_ / "Working_project" / PROJECT
DATA_PATH = BASE_PATH / "DATA/"
PROCESSED = BASE_PATH / "PROCESSED"
OUT_PATH = BASE_PATH / "OUT"
QC_PATH = OUT_PATH / "QC"
CELL_TYPE_PATH = OUT_PATH / "CELL_TYPE"
NICHE_PATH = OUT_PATH / "NICHE"
CCI_PATH = OUT_PATH / "CCI"
PRED_PATH = OUT_PATH / "PRED"

# Create the processed/output tree if it does not exist yet (DATA_PATH is not created).
for _out_dir in (
    PROCESSED,
    OUT_PATH,
    QC_PATH,
    CELL_TYPE_PATH,
    NICHE_PATH,
    CCI_PATH,
    PRED_PATH,
):
    _out_dir.mkdir(exist_ok=True, parents=True)

In [24]:
# Display the temporary directory resolved in the configuration cell.
TMP_PATH

PosixPath('/scratch/temp/12548117')

In [3]:
# Unpack the hovernet_out_test_QMDL02.tar archive into TMP_PATH.
# NOTE(review): the result-loading cell reads from
# SCRATCH_PATH / f"hovernet_out_test_{sample_id}", not from TMP_PATH — confirm
# that the extracted files end up where they are read.
!tar -xf {BASE_PATH_ / "Hovernet_results" / "hovernet_out_test_QMDL02.tar"} -C {TMP_PATH}

In [12]:
def segmentation_lines(mask_in):
    """
    Return the pixel coordinates that outline the segmented regions of a mask.

    The mask is dilated with a 3x3 kernel and compared with the original:
    pixels that are set in exactly one of the two images form the outline.
    Handy for overlaying tissue-detection or other segmentation results.

    Args:
        mask_in (np.ndarray): mask of dtype np.uint8; non-zero pixels are
            treated as foreground.

    Returns:
        tuple: ``(x, y)`` index arrays (columns, rows) of the outline pixels.

    Raises:
        AssertionError: if ``mask_in`` is not of dtype np.uint8.
    """
    assert (
        mask_in.dtype == np.uint8
    ), f"Input mask dtype {mask_in.dtype} must be np.uint8"
    structuring_element = np.ones((3, 3), np.uint8)
    grown = cv2.dilate(mask_in, structuring_element)
    foreground = mask_in.astype(bool)
    # For boolean arrays "!=" is an element-wise exclusive or.
    outline = grown.astype(bool) != foreground
    rows, cols = np.nonzero(outline)
    return cols, rows


def plot_segmentation(masks, ax, palette=None, markersize=5):
    """
    Plot segmentation contours. Supports multi-class masks.

    Every non-zero value found anywhere in ``masks`` is treated as a label.
    For each label, the outline of the pixels carrying it is drawn channel by
    channel, using one colour per channel.

    Args:
        masks (np.ndarray): Mask array of shape (n_masks, H, W). Zeroes are background pixels.
        ax: matplotlib axis to draw on.
        palette: sequence of colours indexed by channel. If None, defaults to
            matplotlib.colors.TABLEAU_COLORS. Colours are reused cyclically
            when there are more channels than colours.
        markersize (int): Size of markers used on plot. Defaults to 5

    Raises:
        AssertionError: if ``masks`` is not 3-dimensional.
        ValueError: if there is something to draw but ``palette`` is empty.
    """
    assert masks.ndim == 3
    n_channels = masks.shape[0]

    if palette is None:
        palette = list(TABLEAU_COLORS.values())

    # Zero is background; everything else is a label to outline.
    nucleus_labels = [label for label in np.unique(masks) if label != 0]
    if not nucleus_labels:
        return

    n_colors = len(palette)
    if n_colors == 0:
        raise ValueError("palette must contain at least one color")

    # Labels outermost, channels innermost: keeps the drawing order stable.
    for label in nucleus_labels:
        for i in range(n_channels):
            nuclei_mask = masks[i, ...] == label
            if not nuclei_mask.any():
                # Label absent from this channel: nothing to outline, so skip
                # the dilation and the empty scatter call.
                continue
            x, y = segmentation_lines(nuclei_mask.astype(np.uint8))
            # Wrap around the palette so extra channels do not raise IndexError.
            ax.scatter(x, y, color=palette[i % n_colors], marker=".", s=markersize)

In [4]:
def plot_confusion_matrix(cm, categories, sample_id, save_path):
    """Plot a row-normalised confusion matrix as an annotated heatmap and save it.

    Each row of ``cm`` is converted to percentages of that row's total. Only
    rows/columns 1..4 are displayed (index 0 is left out); every displayed cell
    is annotated with its raw count and its row percentage.

    Args:
        cm (array-like): square matrix of counts, at least 5x5; rows are shown
            on the 'True' axis and columns on the 'Predicted' axis.
        categories: tick labels for the displayed rows/columns.
            NOTE(review): presumably exactly 4 entries, matching the 1:5 slice.
        sample_id: identifier used in the title and in the output file name.
        save_path (pathlib.Path): directory receiving ``cm_<sample_id>.png``.
    """
    cm = np.asarray(cm)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Rows without any sample stay at 0 % instead of producing NaN and a
    # divide-by-zero RuntimeWarning.
    cm_normalized = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    shown = (slice(1, 5), slice(1, 5))
    counts = cm[shown]
    percentages = cm_normalized[shown]

    # dtype=object: a fixed-width string dtype derived from cm's dtype
    # (e.g. '<U11' for int32) would silently truncate longer annotations.
    annotations = np.empty(counts.shape, dtype=object)
    for i in range(annotations.shape[0]):
        for j in range(annotations.shape[1]):
            annotations[i, j] = f"{int(counts[i, j])}\n({percentages[i, j]:.1f}%)"

    fig = plt.figure(figsize=(6, 5))
    sns.heatmap(percentages, annot=annotations, fmt='',
                cmap='Blues', xticklabels=categories, yticklabels=categories)

    plt.title(f'Confusion Matrix ({sample_id})')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    plt.savefig(save_path / f"cm_{sample_id}.png", dpi=300, bbox_inches='tight')
    # Close exactly the figure created here, not whichever one is current.
    plt.close(fig)

In [5]:
# Sample identifiers; the cells below work on the first entry.
sample_id_list = [
    "QMDL02",
    "QMDL04",
    "QMDL05",
    "QMDL01",
    "QMDL03",
]

# Lookup applied to the 'predicted_cell_type' column to produce the
# 'prediction_mapped' labels.
cell_type_mapping = dict(
    Neoplastic="Cancer Epithelial",
    Connective="Stromal",
    Epithelial="Cancer Epithelial",  # or 'Stromal', depending on the context
    Inflammatory="Immune",
    Dead=None,  # No direct mapping
)

In [26]:
# Pick the first sample and load its two saved arrays from the results folder.
# NOTE(review): the archive was extracted to TMP_PATH, but the files are read
# from SCRATCH_PATH here — confirm both point at the same data.
sample_id = sample_id_list[0]
result_path = SCRATCH_PATH / f"hovernet_out_test_{sample_id}"
# Both arrays are later sliced as [channels, rows, cols]; presumably reference
# ("truths") and predicted ("preds") class masks — TODO confirm.
img_true = np.load(result_path / 'truths_mapped.npy')
img_pred = np.load(result_path / 'preds_mapped.npy')

In [6]:
# Locations of the sample image and of the per-cell CellViT table.
img_path = PROCESSED / "images" / f"{sample_id}.ome.tif"
cellvit_csv = PROCESSED / "CellViT" / f"df_{sample_id}.csv"

df = pd.read_csv(cellvit_csv)
# The 'contour' column is read as text; parse each entry back into a Python
# object and build one polygon per row from it.
df["contour"] = df["contour"].apply(ast.literal_eval)
df["geometry"] = df["contour"].apply(Polygon)
gdf = gpd.GeoDataFrame(df, geometry="geometry")

In [7]:
# Translate the predicted class names through cell_type_mapping; classes mapped
# to None (and any name missing from the mapping) end up as missing values.
gdf["prediction_mapped"] = gdf["predicted_cell_type"].map(cell_type_mapping)

In [8]:
# Colour values of matplotlib's TABLEAU_COLORS as a plain list.
colors = list(TABLEAU_COLORS.values())

In [19]:
img = ImageContainer(img_path)

# Turn the requested tile edge (256) into the (ys, xs) pair used below.
# NOTE(review): both helpers are private ImageContainer methods.
size = img._convert_to_pixel_space(img._get_size(256))

y, x = img.shape
ys, xs = size

# Start offset of every tile along each axis: 0, step, 2*step, ... while the
# offset is still inside the image (the last tile may extend past the border).
unique_ycoord = np.arange(0, y, ys)
unique_xcoord = np.arange(0, x, xs)

# Every (y, x) combination, with x varying fastest.
yy, xx = np.meshgrid(unique_ycoord, unique_xcoord, indexing="ij")
ycoords = yy.ravel()
xcoords = xx.ravel()

# tile index -> (y, x) offset of the tile
mapping = {
    tile_idx: corner for tile_idx, corner in enumerate(zip(ycoords, xcoords))
}

In [10]:
# Get unique labels (excluding None/null values)
unique_labels = gdf['prediction_mapped'].dropna().unique()

# Initialize empty array to store all masks: one full-image uint8 layer per
# label, indexed as masks[label_index, row, col]
masks = np.zeros((len(unique_labels), y, x), dtype=np.uint8)

# Create a mask for each label
for idx, label in enumerate(unique_labels):
    # Filter geodataframe for current label
    label_gdf = gdf[gdf['prediction_mapped'] == label]
    
    # Create mask for this label: pixels covered by any of its polygons are
    # set to 1, everything else to 0. No transform is passed, so polygon
    # coordinates are presumably already in image pixel units — TODO confirm.
    mask = rasterize(
        label_gdf.geometry,
        out_shape=(y, x),
        fill=0,
        default_value=1,
        dtype=np.uint8
    )
    
    masks[idx] = mask

print(f"Generated {len(unique_labels)} masks with shape {masks.shape}")

Generated 3 masks with shape (3, 29753, 26735)


In [31]:
# Display the first component of the tile size.
size[0]

256

In [39]:
# Save one three-panel figure per randomly chosen tile. Every panel shows the
# same image crop, overlaid with contours from (left to right): img_pred, the
# CellViT masks, img_true.
N_TILES = 50
TILE_SEED = 0  # fixed seed: a re-run saves the same tiles under the same names
tile_rng = random.Random(TILE_SEED)
# Never ask for more tiles than exist (random sampling would raise ValueError).
selected_keys = tile_rng.sample(list(mapping.keys()), min(N_TILES, len(mapping)))
selected_items = {key: mapping[key] for key in selected_keys}
crop_size = size[0]
PRED_TILE_PATH = PRED_PATH / "PRED_TILE"
PRED_TILE_PATH.mkdir(exist_ok=True, parents=True)
for idx, (key, (tile_row, tile_col)) in enumerate(selected_items.items()):
    # mapping stores (y, x) offsets, i.e. (row, col): row first for both
    # crop_corner and the array slices below.
    crop = img.crop_corner(tile_row, tile_col, size=crop_size)
    rows = slice(tile_row, tile_row + crop_size)
    cols = slice(tile_col, tile_col + crop_size)
    # [:-1] leaves out the last channel of the loaded arrays
    # (presumably a background channel — TODO confirm).
    panels = (
        img_pred[:-1, rows, cols],  # First column - STimage Classification
        masks[:, rows, cols],       # Second column - CellViT
        img_true[:-1, rows, cols],  # Third column - Ground truth
    )

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
    for ax, panel_masks in zip(axes, panels):
        crop.show(ax=ax)
        plot_segmentation(masks=panel_masks, ax=ax)
        ax.set_title('', fontsize=14)
    fig.tight_layout()

    fig.savefig(PRED_TILE_PATH /f"tile_plot_{idx}.png", dpi=300, bbox_inches='tight')
    # Close this figure explicitly so figures do not pile up in memory.
    plt.close(fig)

In [38]:
# Display the directory the tile figures were written to.
PRED_TILE_PATH

PosixPath('/QRISdata/Q1851/Xiao/Working_project/Pete/OUT/PRED/PRED_TILE')