In [None]:
import sys
import stlearn as st
st.settings.set_figure_params(dpi=300)
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import sys
file = Path("../stimage").resolve()
parent= file.parent
sys.path.append(str(parent))
from PIL import Image
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling
from stimage._model import CNN_NB_multiple_genes
from stimage._data_generator import DataGenerator
import tensorflow as tf
import seaborn as sns
sns.set_style("white")
from PIL import Image, ImageOps, ImageChops, ImageDraw
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
import time
from datetime import timedelta
import seaborn as sns

In [None]:
# ------------------------------------------------------------------------
# Tools for stain normalisation
# ------------------------------------------------------------------------

import numpy as np
import cv2 as cv
from PIL import Image
from staintools.preprocessing.input_validation import is_uint8_image
from staintools import ReinhardColorNormalizer, LuminosityStandardizer, StainNormalizer
from staintools.stain_extraction.macenko_stain_extractor import MacenkoStainExtractor
from staintools.stain_extraction.vahadane_stain_extractor import VahadaneStainExtractor
from staintools.miscellaneous.optical_density_conversion import convert_OD_to_RGB
from staintools.miscellaneous.get_concentrations import get_concentrations

class LuminosityStandardizerIterative(LuminosityStandardizer):
    """
    Transforms image to a standard brightness
    Modifies the luminosity channel such that a fixed percentile is saturated

    Standardiser can fit to source slide image and apply the same luminosity standardisation settings to all tiles generated
    from the source slide image: ``fit`` computes the percentile once, ``standardize_tile`` reuses it for every tile.

    Attributes
    ----------
    p : float or None
        Value of the LAB L channel at the fitted percentile; None until ``fit`` has been called.
    """
    def __init__(self):
        super().__init__()
        self.p = None

    def fit(self, I, percentile = 95):
        """Compute the luminosity percentile of a source image.

        Parameters
        ----------
        I : np.ndarray
            RGB uint8 image.
        percentile : float
            Percentile of the L channel that ``standardize_tile`` maps to 255.
        """
        assert is_uint8_image(I), "Image should be RGB uint8."
        I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
        L_float = I_LAB[:, :, 0].astype(float)
        self.p = np.percentile(L_float, percentile)

    def standardize_tile(self, I):
        """Standardise the luminosity of a tile using the percentile fitted by ``fit``.

        Parameters
        ----------
        I : np.ndarray
            RGB uint8 image.

        Returns
        -------
        np.ndarray
            RGB uint8 image whose L channel is rescaled so the fitted percentile maps to 255.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        if self.p is None:
            raise RuntimeError("fit() must be called before standardize_tile().")
        # Same input contract as fit(); cvtColor interprets non-uint8 input on a different scale
        assert is_uint8_image(I), "Image should be RGB uint8."
        I_LAB = cv.cvtColor(I, cv.COLOR_RGB2LAB)
        L_float = I_LAB[:, :, 0].astype(float)
        # Values above the fitted percentile saturate at 255
        I_LAB[:, :, 0] = np.clip(255 * L_float / self.p, 0, 255).astype(np.uint8)
        return cv.cvtColor(I_LAB, cv.COLOR_LAB2RGB)

class ReinhardColorNormalizerIterative(ReinhardColorNormalizer):
    """
    Normalise each tile from a slide to a target slide using the method of:
    E. Reinhard, M. Adhikhmin, B. Gooch, and P. Shirley,
    'Color transfer between images'
    Normaliser can fit to source slide image and apply the same normalisation settings to all tiles generated from the
    source slide image

    Attributes
    ----------
    target_means : tuple float
        mean pixel value for each channel in target image
    target_stds : tuple float
        standard deviation of pixel values for each channel in target image
    source_means : tuple float
        mean pixel value for each channel in source image
    source_stds : tuple float
        standard deviation of pixel values for each channel in source image

    Methods
    -------
    fit_target(target)
        Fit normaliser to target image
    fit_source(source)
        Fit normaliser to source image
    transform(I)
        Transform an image to normalise it to the target image
    transform_tile(I)
        Transform a tile using precomputed parameters that normalise the source slide image to the target slide image
    lab_split(I)
        Convert from RGB uint8 to LAB and split into channels
    merge_back(I1, I2, I3)
        Take separate LAB channels and merge back to give RGB uint8
    get_mean_std(I)
        Get mean and standard deviation of each channel
    """
    def __init__(self):
        super().__init__()
        self.source_means = None
        self.source_stds = None

    def fit_target(self, target):
        """Fit to a target image

        Parameters
        ----------
        target : Image RGB uint8

        Returns
        -------
        None
        """
        self.target_means, self.target_stds = self.get_mean_std(target)

    def fit_source(self, source):
        """Fit to a source image

        Parameters
        ----------
        source : Image RGB uint8

        Returns
        -------
        None
        """
        self.source_means, self.source_stds = self.get_mean_std(source)

    def transform_tile(self, I):
        """Transform a tile using precomputed parameters that normalise the source slide image to the target slide image

        Each channel ``x`` becomes ``(x - source_mean) * (target_std / source_std) + target_mean``.

        Parameters
        ----------
        I : Image RGB uint8

        Returns
        -------
        transformed_tile : Image RGB uint8

        Raises
        ------
        RuntimeError
            If ``fit_target`` or ``fit_source`` has not been called.
        """
        if getattr(self, "target_means", None) is None or self.source_means is None:
            raise RuntimeError("fit_target() and fit_source() must be called before transform_tile().")
        normalised = [
            (channel - src_mean) * (tgt_std / src_std) + tgt_mean
            for channel, src_mean, src_std, tgt_mean, tgt_std in zip(
                self.lab_split(I),
                self.source_means, self.source_stds,
                self.target_means, self.target_stds,
            )
        ]
        return self.merge_back(*normalised)

class StainNormalizerIterative(StainNormalizer):
    """Normalise each tile from a slide to a target slide using the Macenko or Vahadane method

    ``fit_source`` estimates the source slide's stain matrix and 99th-percentile stain
    concentrations once; ``transform_tile`` reuses them for every tile of that slide.
    """
    def __init__(self, method):
        super().__init__(method)
        self.stain_matrix_source = None
        self.maxC_source = None

    def fit_target(self, I):
        """Fit the target stain statistics (delegates to ``StainNormalizer.fit``)."""
        self.fit(I)

    def fit_source(self, I):
        """Estimate the source stain matrix and per-stain 99th-percentile concentrations."""
        self.stain_matrix_source = self.extractor.get_stain_matrix(I)
        source_concentrations = get_concentrations(I, self.stain_matrix_source)
        # One value per stain, shaped (1, 2) so it broadcasts across per-pixel concentrations
        self.maxC_source = np.percentile(source_concentrations, 99, axis=0).reshape((1, 2))

    def transform_tile(self, I):
        """Normalise a tile using the precomputed source statistics.

        Raises
        ------
        RuntimeError
            If ``fit_source`` has not been called.
        """
        if self.stain_matrix_source is None or self.maxC_source is None:
            raise RuntimeError("fit_source() must be called before transform_tile().")
        source_concentrations = get_concentrations(I, self.stain_matrix_source)
        # Rescale so the source's 99th-percentile concentrations match the target's
        source_concentrations *= (self.maxC_target / self.maxC_source)
        # Recombine with the target stain matrix and convert optical density back to intensities
        tmp = 255 * np.exp(-1 * np.dot(source_concentrations, self.stain_matrix_target))
        # Clip so any out-of-range value saturates instead of wrapping on the uint8 cast
        return np.clip(tmp, 0, 255).reshape(I.shape).astype(np.uint8)

class IterativeNormaliser:
    """Iteratively normalise each tile from a slide to a target using a selectable method

    Normalisation methods include: 'none', 'reinhard', 'macenko' and 'vahadane'.
    Luminosity standardisation is also selectable.

    Usage: call ``fit_target`` once on the reference image, ``fit_source`` once on the slide to
    be normalised (e.g. a down-scaled copy), then ``transform_tile`` on each tile of that slide.

    Raises
    ------
    ValueError
        If ``normalisation_method`` is not one of the supported names.
    """
    _METHODS = ('none', 'reinhard', 'macenko', 'vahadane')

    def __init__(self, normalisation_method = 'vahadane', standardise_luminosity = True):
        # Fail fast: an unknown name previously left `self.normaliser` unset and only
        # crashed later inside fit_target/transform_tile.
        if normalisation_method not in self._METHODS:
            raise ValueError(
                f"Unknown normalisation_method {normalisation_method!r}; expected one of {self._METHODS}"
            )
        self.method = normalisation_method
        self.standardise_luminosity = standardise_luminosity
        # Instantiate normaliser and luminosity standardiser
        if normalisation_method == 'none':
            self.normaliser = None
        elif normalisation_method == 'reinhard':
            self.normaliser = ReinhardColorNormalizerIterative()
        else:  # 'macenko' or 'vahadane'
            self.normaliser = StainNormalizerIterative(normalisation_method)
        if standardise_luminosity:
            self.lum_std = LuminosityStandardizerIterative()

    def fit_target(self, target_img):
        """Fit the normaliser to the reference (target) image (PIL image or array)."""
        target = np.array(target_img)
        if self.standardise_luminosity:
            target = self.lum_std.standardize(target)
        self.target_std = target
        if self.method != 'none':
            self.normaliser.fit_target(target)

    def fit_source(self, source_img):
        """Fit luminosity and stain statistics of the slide whose tiles will be transformed."""
        source = np.array(source_img)
        if self.standardise_luminosity:
            self.lum_std.fit(source)
            source = self.lum_std.standardize_tile(source)
        if self.method != 'none':
            self.normaliser.fit_source(source)

    def transform_tile(self, tile_img):
        """Normalise one tile; returns a PIL image."""
        tile = np.array(tile_img)
        if self.standardise_luminosity:
            tile = self.lum_std.standardize_tile(tile)
        if self.method != 'none':
            tile = self.normaliser.transform_tile(tile)
        return Image.fromarray(tile)

In [None]:
def thumbnail(img, size = (1000,1000)):
    """Return a down-sized copy of a Pillow image; ``img`` itself is left untouched.

    ``Image.thumbnail`` works in place, so it is applied to a copy.
    """
    preview = img.copy()
    preview.thumbnail(size)
    return preview

def scale_img(img, scale_f=10):
    """Shrink a Pillow image by an integer factor (each side floor-divided by ``scale_f``)."""
    width, height = img.size
    return img.resize((width // scale_f, height // scale_f))

def tissue_mask_grabcut(img, white_thresh=250, kernel_size=64, n_iter=5):
    """Segment tissue from background in an RGB (ideally down-scaled) slide image.

    Pixels darker than ``white_thresh`` in grayscale seed GrabCut as probable foreground and the
    rest as probable background; GrabCut then refines the labels. A morphological close + open
    with an elliptical kernel gives a coarse 'filled in' tissue outline, which is used to drop
    small debris from the GrabCut result.

    Parameters
    ----------
    img : np.ndarray
        RGB uint8 image, shape (H, W, 3).
    white_thresh : int
        Grayscale value below which a pixel is initially considered tissue.
    kernel_size : int
        Diameter in pixels of the elliptical structuring element.
    n_iter : int
        Number of GrabCut iterations.

    Returns
    -------
    PIL.Image.Image
        Binary (mode "1") mask, True where tissue was detected.
    """
    # OpenCV expects BGR channel order; make the reversed view contiguous for cv.grabCut
    img_bgr = np.ascontiguousarray(img[:, :, ::-1])
    gray = np.array(Image.fromarray(img).convert('L'))

    # Seed with *probable* labels: GrabCut never relabels pixels marked definite
    # (GC_BGD=0 / GC_FGD=1), so a 0/1 seed would make the GrabCut call a no-op.
    mask = np.where(gray < white_thresh, cv.GC_PR_FGD, cv.GC_PR_BGD).astype(np.uint8)

    # Grabcut (updates `mask` in place)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    cv.grabCut(img_bgr, mask, None, bgd_model, fgd_model, n_iter, cv.GC_INIT_WITH_MASK)
    mask_final = np.where((mask == cv.GC_BGD) | (mask == cv.GC_PR_BGD), 0, 1).astype('uint8')

    # Generate a rough 'filled in' mask of the tissue
    kernel = cv.getStructuringElement(cv.MORPH_ELLIPSE, (kernel_size, kernel_size))
    mask_closed = cv.morphologyEx(mask_final, cv.MORPH_CLOSE, kernel)
    mask_opened = cv.morphologyEx(mask_closed, cv.MORPH_OPEN, kernel)

    # Use rough mask to remove small debris in grabcut mask
    mask_cleaned = cv.bitwise_and(mask_final, mask_final, mask=mask_opened)
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the same dtype
    return Image.fromarray(mask_cleaned.astype(bool))

def filter_green(img, g_thresh = 240):
    """Replace pixels whose green channel is greater than ``g_thresh`` with white pixels.

    Used to remove background from tissue images.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image; it is converted to RGB and not modified.
    g_thresh : int
        Green-channel cut-off (0-255); pixels with G > g_thresh become white.

    Returns
    -------
    PIL.Image.Image
        RGB copy of ``img`` with the selected pixels set to (255, 255, 255).
    """
    img = img.convert('RGB')
    _, g, _ = img.split()
    # Honour the caller's threshold (was hard-coded to 240)
    green_mask = (np.array(g) > g_thresh) * 255
    green_mask_img = Image.fromarray(green_mask.astype(np.uint8), 'L')
    white_image = Image.new('RGB', img.size, (255, 255, 255))
    img_filtered = img.copy()
    img_filtered.paste(white_image, mask=green_mask_img)
    return img_filtered

def filter_grays(img, tolerance = 3):
    """Whiten every pixel whose R, G and B values all lie within ``tolerance`` of one another.

    Such near-achromatic pixels (white, gray or black) are treated as non-tissue background.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image; it is converted to RGB and not modified.
    tolerance : int
        Maximum pairwise channel difference for a pixel to count as gray.

    Returns
    -------
    PIL.Image.Image
        RGB copy of ``img`` with the gray pixels set to (255, 255, 255).
    """
    rgb = img.convert('RGB')
    red, green, blue = rgb.split()

    def _close(a, b):
        return np.array(ImageChops.difference(a, b)) <= tolerance

    is_gray = _close(red, green) & _close(red, blue) & _close(green, blue)
    gray_mask = Image.fromarray((is_gray * 255).astype(np.uint8), 'L')
    result = rgb.copy()
    result.paste(Image.new('RGB', rgb.size, (255, 255, 255)), mask = gray_mask)
    return result

In [None]:
from typing import Optional, Union
from anndata import AnnData
from pathlib import Path

# Test progress bar
from tqdm import tqdm
import numpy as np
import os


def tiling(
        adata: AnnData,
        out_path: Union[Path, str] = "./tiling",
        library_id: str = None,
        crop_size: int = 40,
        target_size: int = 299,
        stain_normaliser = None,
        image_select = "HE",
        verbose: bool = False,
        copy: bool = False,
        save_name: str = "tile_path",
) -> Optional[AnnData]:
    """\
    Tiling H&E images to small tiles based on spot spatial location

    Parameters
    ----------
    adata
        Annotated data matrix. ``adata.obs`` must contain ``imagerow`` and ``imagecol``:
        the pixel coordinates at which each tile is centred.
    out_path
        Directory to save spot image tiles (created, including parents, if missing).
    library_id
        Library id stored in ``adata.uns["spatial"]``; defaults to the first one.
    crop_size
        Side length (in source-image pixels) of the square cropped around each spot.
    target_size
        Side length tiles are resized to before saving (input size for the CNN).
    stain_normaliser
        Optional object with ``fit_source`` / ``transform_tile`` (e.g. ``IterativeNormaliser``)
        whose target is already fitted. It is fitted to a 10x down-scaled copy of the image
        (``scale_img`` default), then applied to every tile.
    image_select
        ``"HE"`` to tile the ``use_quality`` image in ``adata.uns["spatial"]``, or an image
        array (e.g. a binary tissue mask) in the same pixel coordinates to tile instead.
    verbose
        Verbose output
    copy
        Return a modified copy instead of writing to ``adata``.
    save_name
        ``adata.obs`` column in which the tile paths are stored.

    Returns
    -------
    Depending on `copy`, returns or updates `adata` with the following fields.
    **<save_name>** : `adata.obs` field (``tile_path`` by default)
        Saved path for each spot image tile
    """
    if copy:
        # Honour the documented contract: do not mutate the caller's object
        adata = adata.copy()

    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]

    # makedirs: works for nested paths and is a no-op when the directory exists
    os.makedirs(out_path, exist_ok=True)

    # isinstance, not `image_select == "HE"`: with an ndarray that is an element-wise
    # comparison on recent NumPy, and `if <array>` then raises "truth value ... ambiguous".
    if isinstance(image_select, str):
        if image_select != "HE":
            raise ValueError(f"image_select must be 'HE' or an image array, got {image_select!r}")
        spatial = adata.uns["spatial"][library_id]
        image = spatial["images"][spatial["use_quality"]]
    else:
        image = np.asarray(image_select)

    # Float images are assumed to be scaled to [0, 1]
    if np.issubdtype(image.dtype, np.floating):
        image = (image * 255).astype(np.uint8)
    img_pillow = Image.fromarray(image)

    if stain_normaliser is not None:
        stain_normaliser.fit_source(scale_img(img_pillow))

    half = crop_size / 2
    tile_names = []

    with tqdm(
            total=len(adata),
            desc="Tiling image",
            bar_format="{l_bar}{bar} [ time left: {remaining} ]",
    ) as pbar:
        for imagerow, imagecol in zip(adata.obs["imagerow"], adata.obs["imagecol"]):
            # PIL crop box is (left, upper, right, lower)
            tile = img_pillow.crop(
                (imagecol - half, imagerow - half, imagecol + half, imagerow + half)
            )
            if stain_normaliser is not None:
                tile = stain_normaliser.transform_tile(tile)
            tile = tile.resize((target_size, target_size))
            tile_name = library_id + "-" + str(imagecol) + "-" + str(imagerow) + "-" + str(crop_size)
            out_tile = Path(out_path) / (tile_name + ".jpeg")
            tile_names.append(str(out_tile))
            if verbose:
                print(
                    "generate tile at location ({}, {})".format(
                        str(imagecol), str(imagerow)
                    )
                )
            tile.save(out_tile, "JPEG")

            pbar.update(1)

    adata.obs[save_name] = tile_names
    return adata if copy else None


def calculate_bg(
    adata: AnnData,
    copy: bool = False,
    mask_key: str = "tile_tissue_mask_path",
    threshold: int = 200,
) -> Optional[AnnData]:
    """Compute, for every spot, the fraction of its tissue-mask tile that is tissue.

    Parameters
    ----------
    adata
        Annotated data whose ``obs[mask_key]`` holds paths to mask tiles (e.g. written by
        ``tiling(..., save_name="tile_tissue_mask_path")``).
    copy
        Return a modified copy instead of writing to ``adata``.
    mask_key
        ``adata.obs`` column holding the mask tile paths.
    threshold
        Pixel values strictly greater than this count as tissue. Mask tiles are saved as JPEG,
        so pixel values are only approximately 0/255.

    Returns
    -------
    Depending on ``copy``, returns or updates ``adata`` with ``obs["tissue_area"]``
    (fraction of pixels above ``threshold``, in [0, 1]).
    """
    if copy:
        adata = adata.copy()
    tissue_area_list = []
    for img_path in adata.obs[mask_key]:
        # No format argument: the file extension selects the reader (the former `0` was a
        # cv2.imread-style flag that plt.imread treats as a format name).
        tile_mask = plt.imread(img_path)
        tissue_area_list.append((tile_mask > threshold).mean())
    adata.obs["tissue_area"] = np.array(tissue_area_list)
    return adata if copy else None

In [None]:
import warnings
from typing import Optional, Union
from anndata import AnnData
from matplotlib import pyplot as plt
# from .utils import get_img_from_fig, checkType


def tissue_area_plot(
        adata: AnnData,
        threshold: float = None,
        library_id: str = None,
        data_alpha: float = 1.0,
        tissue_alpha: float = 1.0,
        vmin: float = None,
        vmax: float = None,
        cmap: str = "Spectral_r",
        spot_size: Union[float, int] = 6.5,
        show_legend: bool = False,
        show_color_bar: bool = True,
        show_axis: bool = False,
        cropped: bool = True,
        margin: int = 100,
        name: str = None,
        output: str = None,
        copy: bool = False,
) -> Optional[AnnData]:
    """Scatter each spot's ``tissue_area`` over the tissue image.

    Parameters
    ----------
    adata
        Annotated data with ``obs["tissue_area"]`` (see ``calculate_bg``), ``obs["imagecol"]``,
        ``obs["imagerow"]`` and the image in ``uns["spatial"]``.
    threshold
        If given, only spots with ``tissue_area > threshold`` are drawn.
    library_id
        Library whose ``use_quality`` image is drawn; defaults to the first one.
    data_alpha, tissue_alpha
        Opacity of the spots and of the background image.
    vmin, vmax
        Colour-scale limits; default to the min/max of the plotted values. 0 is a valid limit.
    cmap
        Matplotlib colormap name.
    spot_size
        Scatter marker size.
    show_legend
        Currently unused.
    show_color_bar, show_axis
        Toggle the colour bar and the axes.
    cropped
        Zoom to the extent of all spots (not only those passing ``threshold``) plus ``margin`` px.
    name
        File name used when ``output`` is set (default ``"tissue_area_plot"``).
    output
        Directory to save the figure in; nothing is saved when ``None``.
    copy
        Currently unused.

    Returns
    -------
    None. The figure is shown and optionally saved.
    """
    colors = adata.obs["tissue_area"]
    if threshold is not None:
        colors = colors[colors > threshold]

    filter_obs = adata.obs.loc[colors.index]
    imagecol = filter_obs["imagecol"]
    imagerow = filter_obs["imagerow"]

    # Option for turning off showing figure (switches pyplot interactive mode off globally)
    plt.ioff()

    fig, a = plt.subplots()
    # Compare against None: `if vmin:` silently discarded an explicit vmin=0 (or vmax=0)
    if vmin is None:
        vmin = min(colors)
    if vmax is None:
        vmax = max(colors)

    # Plot scatter plot based on pixel of spots
    plot = a.scatter(imagecol, imagerow, edgecolor="none", alpha=data_alpha, s=spot_size, marker="o",
                     vmin=vmin, vmax=vmax, cmap=plt.get_cmap(cmap), c=colors)
    plot.set_clim(vmin=vmin, vmax=vmax)
    if show_color_bar:
        # The colormap comes from the scatter mappable
        cb = fig.colorbar(plot, cax=fig.add_axes([0.9, 0.3, 0.03, 0.38]))
        cb.outline.set_visible(False)

    if not show_axis:
        a.axis('off')

    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]

    spatial = adata.uns["spatial"][library_id]
    image = spatial["images"][spatial["use_quality"]]
    # Overlay the tissue image beneath the spots
    a.imshow(image, alpha=tissue_alpha, zorder=-1)

    if cropped:
        all_col = adata.obs["imagecol"]
        all_row = adata.obs["imagerow"]
        a.set_xlim(all_col.min() - margin, all_col.max() + margin)
        a.set_ylim(all_row.min() - margin, all_row.max() + margin)
        # Image row coordinates grow downwards
        a.set_ylim(a.get_ylim()[::-1])

    if name is None:
        name = "tissue_area_plot"
    if output is not None:
        # Same DPI that `plt.figure().dpi` returned, without opening (and leaking) a new figure
        fig.savefig(os.path.join(output, name), dpi=plt.rcParams["figure.dpi"],
                    bbox_inches='tight', pad_inches=0)

    plt.show()



In [None]:
# Root of the 10x Visium breast-cancer dataset; override with the STIMAGE_BC_VISIUM_PATH env var
BASE_PATH = Path(os.environ.get(
    "STIMAGE_BC_VISIUM_PATH",
    "/clusterdata/uqxtan9/Xiao/STimage/dataset/breast_cancer_10x_visium",
))
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(parents=True, exist_ok=True)


def load_visium_sample(sample: str, count_file: str, image_file: str):
    """Read one Visium sample with stLearn and attach its full-resolution image.

    The image is stored at ``adata.uns["spatial"][sample]["images"]["fulres"]``
    (presumably the ``use_quality`` key selected by ``quality="fulres"`` - confirm).

    Parameters
    ----------
    sample : str
        Sub-directory of ``BASE_PATH``; also used as the library id.
    count_file : str
        Filtered feature-barcode matrix (.h5) inside the sample directory.
    image_file : str
        Full-resolution image file inside the sample directory.
    """
    sample_dir = BASE_PATH / sample
    adata = st.Read10X(sample_dir, library_id=sample, count_file=count_file, quality="fulres")
    adata.uns["spatial"][sample]["images"]["fulres"] = plt.imread(str(sample_dir / image_file))
    return adata


Sample1 = load_visium_sample(
    "block1",
    count_file="V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5",
    image_file="V1_Breast_Cancer_Block_A_Section_1_image.tif",
)
Sample2 = load_visium_sample(
    "block2",
    count_file="V1_Breast_Cancer_Block_A_Section_2_filtered_feature_bc_matrix.h5",
    image_file="V1_Breast_Cancer_Block_A_Section_2_image.tif",
)
Sample3 = load_visium_sample(
    "FFPE",
    count_file="Visium_FFPE_Human_Breast_Cancer_filtered_feature_bc_matrix.h5",
    image_file="Visium_FFPE_Human_Breast_Cancer_image.tif",
)

In [None]:
# Genes of interest for this breast-cancer analysis
gene_list = [
    "SLITRK6", "PGM5", "LINC00645",
    "TTLL12", "COX6C", "CPB1",
    "KRT5", "MALAT1",
]
gene_list

In [None]:
# block1 full-resolution image as a uint8 PIL image: the stain-normalisation template
template_img = Image.fromarray(Sample1.uns["spatial"]["block1"]["images"]["fulres"].astype("uint8"))

In [None]:
start = time.time()

# Reference (target) image for stain normalisation: block1 full-resolution image
template_img = Image.fromarray(
    Sample1.uns["spatial"]["block1"]["images"]["fulres"].astype("uint8")
)

# Vahadane stain normalisation plus luminosity standardisation, fitted on a 10x
# down-scaled template (scale_img default factor)
normaliser = IterativeNormaliser(
    normalisation_method='vahadane',
    standardise_luminosity=True,
)
normaliser.fit_target(scale_img(template_img))

elapsed = time.time() - start
print(time.strftime('%H:%M:%S', time.gmtime(elapsed)))

In [None]:
# Preview the block1 template at reduced size; thumbnail() works on a copy, template_img is unchanged
thumbnail(template_img)

# standard tiling

In [None]:
start = time.time()

# Samples tiled without stain normalisation (add Sample1 / Sample2 here to include them).
# Alternative preprocessing (count binarisation, normalize_total, scale) is not applied.
samples_to_tile = [Sample3]

for adata in samples_to_tile:
    # NOTE(review): these calls' return values are discarded, so they presumably modify adata
    # in place - re-running this cell may log-transform twice; reload the data first.
    st.pp.filter_genes(adata, min_cells=3)
    st.pp.log1p(adata)

    # pre-processing for spot image: one tile directory per library
    library = next(iter(adata.uns["spatial"]))
    out_dir = TILE_PATH / library
    out_dir.mkdir(parents=True, exist_ok=True)
    tiling(adata, out_dir, crop_size=299)

end = time.time()
print(time.strftime('%H:%M:%S', time.gmtime(end - start)))

# tiling + stain normalisation

In [None]:
# Separate copy so the stain-normalised tiling below writes its own obs["tile_path"] column
# instead of overwriting the paths stored on Sample3 by the standard tiling run.
Sample3_stain_norm = Sample3.copy()

In [None]:
start = time.time()

for adata in [
    Sample3_stain_norm,
]:
    # Own directory for normalised tiles: tiling() names files "<library>-<col>-<row>-<crop>.jpeg",
    # so reusing TILE_PATH/<library> would overwrite the un-normalised tiles, leaving
    # Sample3.obs["tile_path"] pointing at stain-normalised images.
    TILE_PATH_ = TILE_PATH / (list(adata.uns["spatial"].keys())[0] + "_stain_norm")
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    tiling(adata, TILE_PATH_, crop_size=299, stain_normaliser=normaliser)

end = time.time()
print(time.strftime('%H:%M:%S', time.gmtime(end - start)))

In [None]:
# Full-resolution FFPE slide. NOTE: despite the name, this is the slide being normalised
# (the normaliser's target was fit on the block1 template).
target_img = Image.fromarray(Sample3.uns["spatial"]["FFPE"]['images']["fulres"])

In [None]:
# Normalise a 10x down-scaled copy of the whole slide. The normaliser's source statistics
# were fitted inside tiling() (stain_normaliser.fit_source on scale_img of this same slide).
target_img_norm = normaliser.transform_tile(scale_img(target_img))

In [None]:
# Whiten likely background (bright green channel, then near-gray pixels) before tissue detection
target_img_norm_filtered = filter_green(target_img_norm, g_thresh = 250)
target_img_norm_filtered = filter_grays(target_img_norm_filtered, tolerance=3)

In [None]:
# Binary (PIL mode "1") tissue mask at the down-scaled resolution
tissue_mask = tissue_mask_grabcut(np.array(target_img_norm_filtered))

In [None]:
# Visual checks: original slide, normalised thumbnail, detected tissue mask
thumbnail(target_img)

In [None]:
target_img_norm

In [None]:
tissue_mask

In [None]:
tissue_mask_up_scale = tissue_mask.resize(target_img.size, Image.ANTIALIAS)

In [None]:
tissue_mask_up_scale

In [None]:
for adata in [Sample3_stain_norm]:
    # Tile the up-scaled binary tissue mask at the same spot locations and crop size as the
    # image tiles; paths go to obs["tile_tissue_mask_path"], which calculate_bg() reads.
    mask_library = list(adata.uns["spatial"].keys())[0]
    TILE_PATH_ = TILE_PATH / f"{mask_library}_tissue_mask"
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    tiling(
        adata,
        TILE_PATH_,
        crop_size=299,
        image_select=np.array(tissue_mask_up_scale),
        save_name="tile_tissue_mask_path",
    )

In [None]:
N_PREVIEW_TILES = 10  # number of spots shown next to their tissue-mask tiles

n_show = min(N_PREVIEW_TILES, len(Sample3_stain_norm))
fig, axarr = plt.subplots(n_show, 2, figsize=(5, 15), squeeze=False)
tile_paths = Sample3_stain_norm.obs["tile_path"]
mask_paths = Sample3_stain_norm.obs["tile_tissue_mask_path"]
for i in range(n_show):
    # .iloc for positional access; `series[i]` on a non-integer index relies on pandas'
    # deprecated positional fallback
    axarr[i, 0].imshow(plt.imread(tile_paths.iloc[i]))
    axarr[i, 1].imshow(plt.imread(mask_paths.iloc[i]), cmap="gray_r")
# Hide axes on every panel (plt.axis('off') only affected the last active axes)
for ax in axarr.flat:
    ax.axis("off")
axarr[0, 0].set_title("normalised tile")
axarr[0, 1].set_title("tissue mask")
fig.tight_layout()

In [None]:
# Per-spot fraction of mask-tile pixels above 200, stored in obs["tissue_area"] (updates in place)
calculate_bg(Sample3_stain_norm)

In [None]:
# Inspect per-spot metadata, including the tile paths and tissue_area
Sample3_stain_norm.obs

In [None]:
# Spatial map of tissue_area; vmin/vmax request a 0-1 colour range
tissue_area_plot(Sample3_stain_norm, vmin=0, vmax=1)

In [None]:
# Red line: quantile of tissue_area, presumably a candidate cut-off for low-tissue spots (TODO confirm)
TISSUE_AREA_QUANTILE = 0.15

ax = sns.histplot(Sample3_stain_norm.obs, x="tissue_area")
ax.axvline(Sample3_stain_norm.obs["tissue_area"].quantile(TISSUE_AREA_QUANTILE), color='r')
ax.set(title="Per-spot tissue fraction", xlabel="tissue_area", ylabel="spots");