In [None]:
import sys
import stlearn as st
st.settings.set_figure_params(dpi=300)
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import sys
file = Path("../stimage").resolve()
parent= file.parent
sys.path.append(str(parent))
from PIL import Image
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling
from stimage._model import CNN_NB_multiple_genes
from stimage._data_generator import DataGenerator
import tensorflow as tf
import seaborn as sns
sns.set_style("white")
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
import anndata

In [None]:
from matplotlib import pyplot as plt
import numpy as np
from typing import Optional, Union
from anndata import AnnData


def BB_plot(
    adata: AnnData,
    library_id: str = None,
    gene: str = None,
    name: str = None,
    data_alpha: float = 0.8,
    tissue_alpha: float = 1.0,
    cmap: str = "Spectral_r",
    spot_size: tuple = (3, 20),
    ob1: str = None,
    ob2: str = None,
    show_color_bar: bool = True,
    show_size_legend: bool = True,
    show_axis: bool = False,
    cropped: bool = True,
    margin: int = 100,
    dpi: int = 150,
    output: str = None,
    vmin: float = 0,
    vmax: float = 8,
) -> None:
    """\
    Bubble plot of two per-spot quantities for one gene over the tissue image.

    For the selected ``gene``, the values in ``adata.obsm[ob1]`` drive the
    marker size and the values in ``adata.obsm[ob2]`` drive the marker colour.
    Both ``obsm`` entries are wrapped in a DataFrame whose columns are
    ``adata.var_names``, so they must have one column per variable.

    Parameters
    ----------
    adata
        Annotated data matrix. ``adata.obs`` must contain ``imagecol`` and
        ``imagerow``; ``adata.uns["spatial"]`` must hold the tissue image.
    library_id
        Key in ``adata.uns["spatial"]``. Defaults to the first key.
    gene
        Gene (column of ``adata.var_names``) to plot. Required.
    name
        File name of the saved figure. Only used when ``output`` is given;
        defaults to ``"<gene>_<ob1>_<ob2>.png"``.
    data_alpha
        Opacity of the spots.
    tissue_alpha
        Opacity of the tissue image.
    cmap
        Name of the colour map.
    spot_size
        ``(min, max)`` marker size that the ``ob1`` values are scaled into.
    ob1
        Key in ``adata.obsm`` mapped to the marker size. Required.
    ob2
        Key in ``adata.obsm`` mapped to the marker colour. Required.
    show_color_bar
        Show the colour bar (labelled with ``ob2``) or not.
    show_size_legend
        Show the size legend (titled with ``ob1``) or not.
    show_axis
        Show the axes or not.
    cropped
        Crop the view to the bounding box of the spots plus ``margin``.
    margin
        Padding, in image pixels, around the spots when ``cropped`` is true.
    dpi
        Resolution used when saving the figure.
    output
        Directory to save the figure in. Nothing is saved when ``None``.
    vmin, vmax
        Limits of the colour scale.

    Returns
    -------
    None. The figure is shown (and optionally saved), then closed.

    Raises
    ------
    ValueError
        If ``gene``, ``ob1`` or ``ob2`` is not provided.
    """
    from pathlib import Path

    from sklearn.preprocessing import MinMaxScaler

    if gene is None or ob1 is None or ob2 is None:
        raise ValueError("`gene`, `ob1` and `ob2` must all be provided.")

    imagecol = adata.obs["imagecol"]
    imagerow = adata.obs["imagerow"]

    size_values = pd.DataFrame(adata.obsm[ob1], columns=adata.var_names)[gene]
    color_values = pd.DataFrame(adata.obsm[ob2], columns=adata.var_names)[gene]

    # Rescale the size-driving values into the requested marker-size range.
    scaler = MinMaxScaler(feature_range=spot_size)
    marker_sizes = scaler.fit_transform(
        size_values.to_numpy().reshape(-1, 1)
    ).ravel()

    # Option for turning off showing figure (kept from the original; note that
    # this switches matplotlib's interactive mode off globally).
    plt.ioff()

    fig, a = plt.subplots()

    # Scatter plot positioned at the pixel coordinates of the spots.
    plot = a.scatter(
        imagecol,
        imagerow,
        edgecolor="none",
        alpha=data_alpha,
        s=marker_sizes,
        marker="o",
        vmin=vmin,
        vmax=vmax,
        cmap=plt.get_cmap(cmap),
        c=color_values,
    )

    if show_color_bar:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        axins = inset_axes(
            a,
            width="100%",
            height="100%",
            loc="upper left",
            bbox_to_anchor=(1.0, 0.73, 0.05, 0.35),
            bbox_transform=a.transAxes,
            borderpad=4.3,
        )
        cb = plt.colorbar(plot, cax=axins)
        cb.ax.set_xlabel(ob2, fontsize=10)
        cb.ax.xaxis.set_label_coords(0.98, 1.20)
        cb.outline.set_visible(False)

    if show_size_legend:
        size_min, size_max = spot_size
        markers = [
            size_min,
            size_min + 1 / 3 * (size_max - size_min),
            size_min + 2 / 3 * (size_max - size_min),
            size_max,
        ]
        # Draw the (empty) legend handles on the main axes explicitly; using
        # the pyplot state machine here would target whichever axes is current.
        legend_markers = [a.scatter([], [], s=size, c="grey") for size in markers]
        # Map marker sizes back to data values for the labels. The result is
        # flattened first so that float() only ever sees scalars.
        legend_values = scaler.inverse_transform(
            np.array(markers, dtype=float).reshape(-1, 1)
        ).ravel()
        labels = [str(float(value)) for value in legend_values]
        a.legend(
            handles=legend_markers,
            labels=labels,
            loc="lower left",
            bbox_to_anchor=(1, 0.05),
            scatterpoints=1,
            frameon=False,
            title=ob1,
        )

    if not show_axis:
        a.axis("off")
    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]

    spatial = adata.uns["spatial"][library_id]
    image = spatial["images"][spatial["use_quality"]]
    # Overlay the tissue image underneath the spots.
    a.imshow(
        image,
        alpha=tissue_alpha,
        zorder=-1,
    )

    if cropped:
        a.set_xlim(imagecol.min() - margin, imagecol.max() + margin)
        # Upper limit first: image rows grow downwards, so the y axis is flipped.
        a.set_ylim(imagerow.max() + margin, imagerow.min() - margin)

    if output is not None:
        file_name = name if name is not None else "{}_{}_{}.png".format(gene, ob1, ob2)
        out_dir = Path(output)
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(out_dir / file_name), dpi=dpi, bbox_inches="tight", pad_inches=0)

    plt.show()
    # Release the figure so repeated calls (e.g. one per gene) do not pile up.
    plt.close(fig)




In [None]:
from scipy import stats

def plot_correlation(df, attr_1, attr_2):
    """Regression plot of ``attr_2`` against ``attr_1`` annotated with R^2.

    ``df`` is any table indexable by column name. The squared Pearson
    correlation of the two columns is written onto the plot at 90% of each
    column's maximum. Returns the seaborn grid produced by ``sns.lmplot``.
    """
    x_column = df[attr_1]
    y_column = df[attr_2]
    r_squared = stats.pearsonr(x_column, y_column)[0] ** 2

    grid = sns.lmplot(data=df, x=attr_1, y=attr_2, height=5, legend=True)
    grid.set_axis_labels(attr_1, attr_2)

    annotation_xy = (max(x_column) * 0.9, max(y_column) * 0.9)
    plt.annotate(r'$R^2:{0:.2f}$'.format(r_squared), annotation_xy)
    return grid

In [None]:
# NOTE(review): absolute, user-specific cluster path -- adjust for your environment.
BASE_PATH = Path("/clusterdata/uqxtan9/Xiao/STimage/dataset/breast_cancer_10x_visium")
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(parents=True, exist_ok=True)


def _load_visium_block(sample, section):
    """Read one Visium sample folder and attach its full-resolution image.

    ``sample`` is the sub-folder of ``BASE_PATH`` (also used as library id);
    ``section`` is the section number embedded in the 10x file names.
    Returns the object produced by ``st.Read10X`` with the TIFF stored under
    ``uns["spatial"][sample]["images"]["fulres"]``.
    """
    file_prefix = "V1_Breast_Cancer_Block_A_Section_{}".format(section)
    adata = st.Read10X(
        BASE_PATH / sample,
        library_id=sample,
        count_file=file_prefix + "_filtered_feature_bc_matrix.h5",
        quality="fulres",
    )
    # The full-resolution image is read and attached manually rather than
    # through Read10X's `source_image_path` argument.
    adata.uns["spatial"][sample]["images"]["fulres"] = plt.imread(
        BASE_PATH / sample / (file_prefix + "_image.tif")
    )
    return adata


Sample1 = _load_visium_block("block1", 1)
Sample2 = _load_visium_block("block2", 2)

In [None]:
Sample1

In [None]:
Sample2

In [None]:
gene_list=["SLITRK6", "PGM5", "LINC00645", 
           "TTLL12", "COX6C", "CPB1",
           "KRT5", "MALAT1"]
gene_list

In [None]:
# Pre-process both samples and cut one image tile per spot.
# NOTE(review): the return values of filter_genes/log1p are discarded, so they
# presumably modify `adata` in place -- re-running this cell would then apply
# log1p a second time. Restart the kernel (or reload the samples) before
# re-running.
for adata in [
    Sample1,
    Sample2,
]:
    # Earlier experiment (binarising counts), kept for reference:
#     count_df = adata.to_df()
#     count_df[count_df <=1] = 0
#     count_df[count_df >1] = 1
#     adata.X = count_df
#     adata[:,gene_list]
    st.pp.filter_genes(adata,min_cells=3)
#     st.pp.normalize_total(adata)
    st.pp.log1p(adata)
#     st.pp.scale(adata)

    # pre-processing for spot image
    # Tiles go into a per-sample folder named after the first library id.
    TILE_PATH_ = TILE_PATH / list(adata.uns["spatial"].keys())[0]
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    # NOTE(review): a later cell reads `obs["tile_path"]`, so `tiling`
    # presumably records each tile's file path there -- confirm in stimage._utils.
    tiling(adata, TILE_PATH_, crop_size=299)

In [None]:
# Build the train / validation / test tf.data pipelines.
n_genes = len(gene_list)
# 70% of Sample1's spots, sampled with a fixed seed.
training_index = Sample1.obs.sample(frac=0.7, random_state=1).index
# training_dataset = Sample1[training_index,].copy()

# NOTE(review): training uses ALL of Sample1, while the validation set below is
# the 30% of Sample1 not in `training_index`. Every validation spot is therefore
# also a training spot, so `val_loss` is not an out-of-sample estimate. The
# commented-out line above would give a disjoint split -- confirm the intent.
training_dataset = Sample1.copy()

# Despite its name, `valid_index` is True for spots that ARE in `training_index`;
# it is negated when selecting the validation subset.
valid_index = Sample1.obs.index.isin(training_index)
valid_dataset = Sample1[~valid_index,].copy()

# Sample2 is held out entirely as the test set.
test_dataset = Sample2.copy()

# Each element is (image of shape [299, 299, 3], tuple of n_genes labels of shape [1]).
# The lambdas look up `training_dataset` / `valid_dataset` / `test_dataset` and
# `gene_list` each time a pipeline is iterated, so rebinding those globals later
# changes what the pipeline yields.
train_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=training_dataset, 
                          genes=gene_list, aug=False),
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
# NOTE(review): cache() is applied after shuffle()/batch()/repeat(3), so the
# cached stream is presumably the already shuffled, 3x repeated one (same batch
# order every epoch, 3x the cache size). Confirm against the tf.data guide.
train_gen_ = train_gen.shuffle(buffer_size=500).batch(128).repeat(3).cache().prefetch(tf.data.experimental.AUTOTUNE)
valid_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=valid_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
valid_gen_ = valid_gen.shuffle(buffer_size=500).batch(128).repeat(3).cache().prefetch(tf.data.experimental.AUTOTUNE)
test_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=test_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
# The test pipeline is not shuffled and uses batches of one element.
test_gen_ = test_gen.batch(1)

In [None]:
# Build the multi-gene model for 299x299 RGB inputs, one output per gene.
model = CNN_NB_multiple_genes((299, 299, 3), n_genes)
# NOTE(review): patience=20 equals the 20 epochs requested in the fit cell, so
# this callback can never stop training early; with restore_best_weights=False
# the final-epoch weights are kept. Lower `patience` if early stopping is wanted.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20,
                                            restore_best_weights=False)

In [None]:
train_history = model.fit(train_gen_,
                          epochs=20,
                          validation_data=valid_gen_,
                          callbacks=[callback]
                          )

In [None]:
model.load_weights("./final.h5")

In [None]:
from scipy.stats import nbinom

# One output array per gene; column 0 and column 1 of each array are passed to
# scipy's negative binomial as its `n` and `p` parameters.
test_predictions = model.predict(test_gen_)

nb_parameters = [
    (test_predictions[gene_idx][:, 0], test_predictions[gene_idx][:, 1])
    for gene_idx in range(n_genes)
]
# Per-gene mean and standard deviation of the predicted distribution.
y_preds = [nbinom.mean(nb_n, nb_p) for nb_n, nb_p in nb_parameters]
y_preds_std = [nbinom.std(nb_n, nb_p) for nb_n, nb_p in nb_parameters]

# Store as (spots x genes) matrices alongside the test data.
test_dataset.obsm["predicted_gene"] = np.array(y_preds).T
test_dataset.obsm["predicted_gene_std"] = np.array(y_preds_std).T

In [None]:
test_dataset_ = test_dataset[:,gene_list].copy()
test_dataset_.X = test_dataset_.obsm["predicted_gene"]

In [None]:
for i in gene_list:
    print(i)
    BB_plot(test_dataset_, gene=i, ob2="predicted_gene", ob1="predicted_gene_std")

In [None]:
BB_plot(test_dataset_, gene="COX6C", ob2="predicted_gene", ob1="predicted_gene_std")

In [None]:
BB_plot(test_dataset_, gene="KRT5", ob2="predicted_gene", ob1="predicted_gene_std")

In [None]:
test_dataset_.uns["spatial"]['block2'].keys()

In [None]:
pd.DataFrame(test_dataset_.obsm["predicted_gene"], columns=test_dataset_.var_names)

In [None]:
for i in gene_list:
    print(i)
    gene_plot(test_dataset_, genes=i, spot_size=8)

In [None]:
gene_plot(test_dataset, genes="COX6C", spot_size=8)

In [None]:
gene_plot(test_dataset_, genes="KRT5", spot_size=8, vmax=1, vmin=0)

In [None]:
for i in gene_list:
    print(i)
    gene_plot(test_dataset, genes=i, spot_size=8)

In [None]:
gene_list[6]

In [None]:
def model_predict_gene(gene):
    """Return a predictor for one gene, built on the notebook-level ``model``.

    ``gene`` must be present in the notebook-level ``gene_list``; its position
    selects which of the model's outputs is used. The returned callable takes
    a batch ``x``, runs ``model.predict`` on it and returns the negative
    binomial mean of that gene's output as a column vector of shape (n, 1).
    """
    gene_position = gene_list.index(gene)
    from scipy.stats import nbinom

    def model_predict(x):
        gene_output = model.predict(x)[gene_position]
        # Column 0 -> `n`, column 1 -> `p` of the negative binomial.
        nb_n, nb_p = gene_output[:, 0], gene_output[:, 1]
        return nbinom.mean(nb_n, nb_p).reshape(-1, 1)

    return model_predict

In [None]:
import matplotlib.pyplot as plt
from libpysal.weights.contiguity import Queen
from libpysal import examples
import numpy as np
import pandas as pd
import geopandas as gpd
import os
import splot
from splot.esda import moran_scatterplot, lisa_cluster
from esda.moran import Moran, Moran_Local
from esda.moran import Moran_BV, Moran_Local_BV
from splot.esda import plot_moran_bv_simulation, plot_moran_bv, plot_local_autocorrelation

In [None]:
# NOTE(review): the path is an empty string, so this cell presumably fails as
# written -- fill in the .h5ad file to load. It also rebinds `test_dataset`,
# which the cells below use as the ground-truth expression and which the
# `test_gen` pipeline looks up lazily.
test_dataset = anndata.read_h5ad("")

In [None]:
# Per-gene agreement between ground truth (`test_dataset`) and prediction
# (`test_dataset_`): bivariate Moran's I and squared Pearson correlation.

# The spot layout is identical for every gene, so the point GeoDataFrame, the
# Queen contiguity weights and the expression tables are built once instead of
# once per gene.
test_dataset_.obsm["gpd"] = gpd.GeoDataFrame(
    test_dataset_.obs,
    geometry=gpd.points_from_xy(
        test_dataset_.obs.imagecol,
        test_dataset_.obs.imagerow,
    ),
)
w = Queen.from_dataframe(test_dataset_.obsm["gpd"])
tissue_image = test_dataset_.uns["spatial"]["block2"]["images"]["fulres"]

observed_df = test_dataset.to_df()
predicted_df = test_dataset_.to_df()

moran_list = []
cor_list = []
for gene in gene_list:
    x = observed_df[gene].values   # ground truth
    y = predicted_df[gene].values  # prediction

    test_dataset_.obsm["gpd"]["gc_{}".format(gene)] = x
    test_dataset_.obsm["gpd"]["pred_{}".format(gene)] = y

    # Only the bivariate statistic is recorded here; the univariate and local
    # variants are computed for a single gene in the cells further down.
    moran_bv = Moran_BV(y, x, w)
    moran_list.append(moran_bv.I)

    # Note: this is the SQUARED Pearson correlation (R^2).
    cor_list.append(stats.pearsonr(x, y)[0] ** 2)
    

In [None]:
cor_list

In [None]:
moran_list

In [None]:
df = pd.DataFrame([cor_list, moran_list, gene_list]).transpose()

In [None]:
df.columns = ["Pearson_corr", "Moran_I", "genes"]

In [None]:
df_ = pd.melt(df, id_vars=['genes'], value_vars=['Pearson_corr', 'Moran_I'])

In [None]:
df_["tile_size"] = 900

In [None]:
# df_all = df_

In [None]:
# Accumulate the results of this run (one `tile_size`) into `df_all`.
# On a fresh kernel `df_all` does not exist yet, so it is created from `df_`;
# otherwise the new rows are appended. `DataFrame.append` was removed in
# pandas 2.0, so `pd.concat` is used instead (same result).
# Re-running this cell appends the same rows again.
if "df_all" in globals():
    df_all = pd.concat([df_all, df_])
else:
    df_all = df_.copy()

In [None]:
df_all

In [None]:
import seaborn as sns

In [None]:
sns.boxplot(x="tile_size", y="value",
            hue="variable", #palette=["m", "g"],
            data=df_all)
sns.despine(offset=10, trim=True)
plt.show()

In [None]:
gene = gene_list[4]

In [None]:
# (Removed a stray `plt.imread()` call: with no file argument it always raised
# TypeError and stopped "Run All". The tile preview is in the next cell.)

In [None]:
# Preview the tile of the 31st spot. `.iloc` selects by position explicitly;
# a bare integer key on the string-indexed `obs` column relies on a positional
# fallback that pandas has deprecated.
plt.imshow(plt.imread(test_dataset.obs["tile_path"].iloc[30]))
plt.show()

In [None]:
gene

In [None]:
test_dataset_.obsm["gpd"] = gpd.GeoDataFrame(test_dataset_.obs,
                                             geometry=gpd.points_from_xy(
                                                 test_dataset_.obs.imagecol, 
                                                 test_dataset_.obs.imagerow))


In [None]:
test_dataset_.obsm["gpd"]

In [None]:
x = test_dataset.to_df()[gene].values
y = test_dataset_.to_df()[gene].values
w = Queen.from_dataframe(test_dataset_.obsm["gpd"])

In [None]:
test_dataset_.obsm["gpd"]["gc_{}".format(gene)] = x
test_dataset_.obsm["gpd"]["pred_{}".format(gene)] = y
tissue_image = test_dataset_.uns["spatial"]["block2"]["images"]["fulres"]

In [None]:
moran = Moran(y,w)
moran_bv = Moran_BV(y, x, w)
moran_loc = Moran_Local(y, w)
moran_loc_bv = Moran_Local_BV(y, x, w)

In [None]:
plot_correlation(test_dataset_.obsm["gpd"],
                 "pred_{}".format(gene),
                 "gc_{}".format(gene))
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(5,5))
moran_scatterplot(moran_bv, ax=ax)
ax.set_xlabel('prediction of gene {}'.format(gene))
ax.set_ylabel('Spatial lag of ground truth of gene {}'.format(gene))
plt.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(5,5))
moran_scatterplot(moran_loc_bv, p=0.05, ax=ax)
ax.set_xlabel('prediction of gene {}'.format(gene))
ax.set_ylabel('Spatial lag of ground truth of gene {}'.format(gene))
plt.tight_layout()
plt.show()

In [None]:
def plot_choropleth(gdf, 
                    attribute_1,
                    attribute_2,
                    bg_img,
                    alpha=0.5,
                    scheme='Quantiles', 
                    cmap='YlGnBu', 
                    legend=True):
    """Draw two stacked choropleth panels over a background image.

    Parameters
    ----------
    gdf
        Table with a ``plot`` method accepting ``column``/``scheme``/``ax``
        (a GeoDataFrame in this notebook).
    attribute_1, attribute_2
        Column names shown in the top and bottom panel respectively.
    bg_img
        Image drawn in each panel with ``imshow``.
    alpha, scheme, cmap, legend
        Forwarded to ``gdf.plot``.

    Returns
    -------
    (fig, axs)
        The figure and the array holding the two panel axes.
    """
    fig, axs = plt.subplots(2, 1, figsize=(5, 8),
                            subplot_kw={'adjustable': 'datalim'})

    # Both panels are built the same way; only the plotted column differs.
    for panel, attribute in zip(axs, (attribute_1, attribute_2)):
        gdf.plot(column=attribute, scheme=scheme, cmap=cmap,
                 legend=legend, legend_kwds={'loc': 'upper left',
                                             'bbox_to_anchor': (0.92, 0.8)},
                 ax=panel, alpha=alpha, markersize=2)

        panel.imshow(bg_img)
        panel.set_title('choropleth plot for {}'.format(attribute), y=0.8)
        panel.set_axis_off()

    plt.tight_layout()

    # Return the axes created here (the original returned an unrelated global `ax`).
    return fig, axs

In [None]:
plot_choropleth(test_dataset_.obsm["gpd"], 
                "gc_{}".format(gene),
                "pred_{}".format(gene),
                tissue_image)
plt.show()

In [None]:
lisa_cluster(moran_loc_bv, test_dataset_.obsm["gpd"], p=0.05, 
             figsize = (9,9), markersize=12, **{"alpha":0.8})
plt.imshow(test_dataset_.uns["spatial"]["block2"]["images"]["fulres"])
plt.show()

In [None]:
moran_bv.I

In [None]:
import skimage
from skimage.color import rgb2hed
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
from skimage.measure import label
import scipy as sp
from scipy import ndimage as ndi
from skimage.morphology import area_opening
import math
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.preprocessing import Binarizer

In [None]:
from lime import lime_image



In [None]:
def watershed_segment(image):
    """Segment an RGB tile into integer-labelled regions via watershed.

    Steps, as implemented below:
      1. Convert to HED with ``rgb2hed`` and keep channel 0, rescaled so its
         80th percentile maps to 255.
      2. Build a mask of pixels below the Otsu threshold, with holes filled.
      3. Seed a watershed from the local maxima of the mask's distance transform.
      4. Apply ``area_opening`` (area threshold 64) to the label image.
      5. Relabel so the labels are consecutive integers starting at 0.

    Parameters
    ----------
    image
        RGB image array accepted by ``rgb2hed``.

    Returns
    -------
    numpy.ndarray
        Integer label image with the same height and width as ``image``.
    """
    from skimage.filters import threshold_otsu

    channel_h = rgb2hed(image)[:, :, 0]
    channel_h *= 255.0 / np.percentile(channel_h, q=80)
#     channel_h = np.clip(channel_h, a_min=0, a_max=255)

    thresh = threshold_otsu(channel_h)
    # `scipy.ndimage.morphology` is a deprecated namespace; call via `ndi`.
    im_fgnd_mask = ndi.binary_fill_holes(channel_h < thresh)

    distance = ndi.distance_transform_edt(im_fgnd_mask)
    coords = peak_local_max(distance, footprint=np.ones((5, 5)), labels=im_fgnd_mask)
    seed_mask = np.zeros(distance.shape, dtype=bool)
    seed_mask[tuple(coords.T)] = True
    markers, _ = ndi.label(seed_mask)

    labels = watershed(channel_h, markers, mask=im_fgnd_mask)
    # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
    opened = area_opening(labels, area_threshold=64).astype(int)

    # Map the surviving label values onto 0..k-1 in sorted order. The inverse
    # index from np.unique is exactly that mapping; reshape because its shape
    # for N-d input differs between NumPy versions.
    _, consecutive = np.unique(opened, return_inverse=True)
    return consecutive.reshape(opened.shape)




In [None]:
def LIME_plot(image, label_, gene1, model_predict_gene, gene_list):
    """Explain the model's prediction for one gene on one tile with LIME.

    Prints the prediction and ground truth, then shows three panels: the tile,
    the per-segment LIME weights, and the weights overlaid on the tile.

    Parameters
    ----------
    image
        Tile exposing ``.numpy()`` (a TensorFlow tensor in this notebook).
    label_
        Sequence of per-gene labels, each exposing ``.numpy()``; indexed by the
        position of ``gene1`` in ``gene_list``.
    gene1
        Gene to explain; must be in ``gene_list``.
    model_predict_gene
        Factory that maps a gene name to a batch-prediction callable.
    gene_list
        Ordered gene names matching the entries of ``label_``.
    """
    gene_i = gene_list.index(gene1)
    # Build the per-gene predictor once and reuse it for LIME and the printout.
    predict_fn = model_predict_gene(gene1)
    image_array = image.numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(image_array.astype("double"),
                                             predict_fn,
                                             top_labels=1, num_samples=1000,
                                             hide_color=0,
#                                              num_features=10000,
#                                              model_regressor=SGDRegressor(),
                                             segmentation_fn=watershed_segment)
    segment_weights = dict(explanation.local_exp[explanation.top_labels[0]])
    # Paint each pixel with its segment's weight. Segments without a weight get
    # 0.0 instead of None so the heat map is always a float array.
    heatmap1 = np.vectorize(
        lambda segment: segment_weights.get(segment, 0.0), otypes=[float]
    )(explanation.segments)
#     heatmap1 = (heatmap1 - heatmap1.min()) / heatmap1.max()

    prediction = predict_fn(np.expand_dims(image_array, 0)).ravel()
    print("{}:".format(gene1))
    print("prediction: " + str(prediction))
    print("ground truth: " + str(label_[gene_i].numpy().ravel()))

    # Symmetric colour limits around zero for the diverging colour map. Using
    # the largest absolute weight keeps vmin <= vmax even when every weight is
    # negative (the original +/- max() produced inverted limits in that case).
    color_limit = np.abs(heatmap1).max()

    fig, (ax_tile, ax_heat, ax_overlay) = plt.subplots(1, 3, figsize=(15, 12))
    ax_tile.imshow(image_array.astype(int))

    heat_img = ax_heat.imshow(heatmap1, alpha=1, cmap='RdYlBu_r',
                              vmax=color_limit, vmin=-color_limit)
    fig.colorbar(heat_img, ax=ax_heat, shrink=0.30)

    ax_overlay.imshow(image_array.astype(int))
    overlay_img = ax_overlay.imshow(heatmap1, alpha=0.3, cmap='RdYlBu_r',
                                    vmax=color_limit, vmin=-color_limit)
    fig.colorbar(overlay_img, ax=ax_overlay, shrink=0.30)

    fig.tight_layout()
    plt.show()
    # This function is called once per gene in a loop; free the figure.
    plt.close(fig)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)

In [None]:
for image_, label_ in test_gen.shuffle(2000).take(1):
    for gene in gene_list:
        LIME_plot(image_, label_, gene, model_predict_gene, gene_list)