In [None]:
import sys
import stlearn as st
st.settings.set_figure_params(dpi=300)
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import sys
# Make the local `stimage` package importable: resolve ../stimage relative to
# the current working directory and append its parent directory to sys.path.
file = Path("../stimage").resolve()
parent= file.parent
sys.path.append(str(parent))
from PIL import Image
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling
from stimage._model import CNN_NB_multiple_genes
from stimage._data_generator import DataGenerator
import tensorflow as tf
import seaborn as sns
sns.set_style("white")
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np

In [None]:
from scipy import stats

def plot_correlation(df, attr_1, attr_2):
    """Plot two columns of ``df`` against each other with a regression fit.

    The squared Pearson correlation of the two columns is written onto the
    current axes as an ``R^2`` label, positioned at 90% of each column's
    maximum value (data coordinates).

    Parameters
    ----------
    df : column-indexable table; also handed to ``sns.lmplot`` as ``data``.
    attr_1 : name of the column drawn on the x axis.
    attr_2 : name of the column drawn on the y axis.

    Returns
    -------
    The object returned by ``sns.lmplot``.
    """
    x_values = df[attr_1]
    y_values = df[attr_2]

    # pearsonr returns (statistic, p-value); only the statistic is used.
    pearson_r = stats.pearsonr(x_values, y_values)[0]
    r_squared = pearson_r ** 2

    g = sns.lmplot(data=df, x=attr_1, y=attr_2, height=5, legend=True)
    g.set_axis_labels(attr_1, attr_2)

    label_xy = (max(x_values) * 0.9, max(y_values) * 0.9)
    plt.annotate(r'$R^2:{0:.2f}$'.format(r_squared), label_xy)
    return g

In [None]:
# Root directory of the 10x Visium breast-cancer data and the scratch
# directory that later receives the image tiles.
BASE_PATH = Path("/clusterdata/uqxtan9/Xiao/STimage/dataset/breast_cancer_10x_visium")
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(parents=True, exist_ok=True)


def load_visium_section(sample, section):
    """Read one tissue section and attach its full-resolution image.

    Parameters
    ----------
    sample : str
        Sub-directory of ``BASE_PATH`` holding the section; also used as the
        ``library_id`` under ``adata.uns["spatial"]``.
    section : int or str
        Section number that appears in the 10x file names
        (``V1_Breast_Cancer_Block_A_Section_<section>_...``).

    Returns
    -------
    The object returned by ``st.Read10X`` with the TIFF image stored under
    ``uns["spatial"][sample]["images"]["fulres"]``.
    """
    sample_dir = BASE_PATH / sample
    file_prefix = "V1_Breast_Cancer_Block_A_Section_{}".format(section)

    adata = st.Read10X(sample_dir,
                       library_id=sample,
                       count_file=file_prefix + "_filtered_feature_bc_matrix.h5",
                       quality="fulres")

    # NOTE(review): the trailing 0 lands in plt.imread's second positional
    # parameter (`format`); kept exactly as in the original call -- TODO
    # confirm it is intentional and not a leftover cv2-style flag.
    image = plt.imread(sample_dir / (file_prefix + "_image.tif"), 0)
    adata.uns["spatial"][sample]['images']["fulres"] = image
    return adata


# One call per section replaces two copy-pasted loading stanzas, so the file
# names of the two sections can no longer drift apart.
Sample1 = load_visium_section("block1", 1)
Sample2 = load_visium_section("block2", 2)

In [None]:
# Display the notebook representation of the "block1" dataset.
Sample1

In [None]:
# Display the notebook representation of the "block2" dataset.
Sample2

In [None]:
# Genes handled by the rest of the notebook. The order matters: later cells
# address genes by their position in this list.
gene_list = [
    "SLITRK6",
    "PGM5",
    "LINC00645",
    "TTLL12",
    "COX6C",
    "CPB1",
    "KRT5",
    "MALAT1",
]
gene_list

In [None]:
for adata in (Sample1, Sample2):
    # Expression pre-processing: gene filtering (min_cells=3) followed by
    # log1p. Binarisation, total-count normalisation and scaling were left
    # disabled in this notebook and are therefore not applied.
    # NOTE(review): neither call's return value is used, so both presumably
    # modify `adata` in place -- re-running this cell would then apply log1p
    # a second time.
    st.pp.filter_genes(adata, min_cells=3)
    st.pp.log1p(adata)

    # Spot-image pre-processing: tiles go into a sub-directory named after
    # the first key of adata.uns["spatial"].
    library_id = list(adata.uns["spatial"].keys())[0]
    TILE_PATH_ = TILE_PATH / library_id
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    tiling(adata, TILE_PATH_, crop_size=299)

In [None]:
n_genes = len(gene_list)

# Sample1 is split 70/30 into training and validation spots (fixed
# random_state for a repeatable split); Sample2 is kept whole as the test set.
training_index = Sample1.obs.sample(frac=0.7, random_state=1).index
training_dataset = Sample1[training_index,].copy()

# NOTE: despite its name, `valid_index` is a boolean mask that is True for the
# *training* spots; it is negated to select the validation spots.
valid_index = Sample1.obs.index.isin(training_index)
valid_dataset = Sample1[~valid_index,].copy()

test_dataset = Sample2.copy()


def _tile_dataset(adata, **generator_kwargs):
    """Wrap a ``DataGenerator`` over ``adata`` as an unbatched tf.data.Dataset.

    Each element is declared as a float32 tensor of shape [299, 299, 3]
    paired with a tuple of ``n_genes`` float32 tensors of shape [1].
    Extra keyword arguments are forwarded to ``DataGenerator`` unchanged.
    """
    return tf.data.Dataset.from_generator(
        lambda: DataGenerator(adata=adata, genes=gene_list, **generator_kwargs),
        output_types=(tf.float32, tuple([tf.float32]*n_genes)),
        output_shapes=([299,299,3], tuple([1]*n_genes))
    )


train_gen = _tile_dataset(training_dataset, aug=False)
valid_gen = _tile_dataset(valid_dataset)
test_gen = _tile_dataset(test_dataset)

# cache() is placed *before* shuffle(): caching after shuffle/batch/repeat
# stores the first shuffled order and replays the identical batches on every
# later epoch (and keeps three passes of batched tiles in memory). With the
# cache first, the generator still runs only once, while every pass is
# reshuffled. repeat(3) is kept so one Keras epoch is still three passes.
train_gen_ = (
    train_gen
    .cache()
    .shuffle(buffer_size=500)
    .batch(128)
    .repeat(3)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Validation data is only evaluated, so shuffling it and repeating it three
# times added cost without adding information; it is cached, batched and
# prefetched only.
valid_gen_ = (
    valid_gen
    .cache()
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Test tiles are fed one at a time, in generator order.
test_gen_ = test_gen.batch(1)

In [None]:
# Model taking 299x299 RGB tiles with one output per gene in gene_list.
tile_shape = (299, 299, 3)
model = CNN_NB_multiple_genes(tile_shape, n_genes)

# Stop training once val_loss has not improved for 20 epochs. The weights of
# the last epoch are kept (best weights are explicitly NOT restored).
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=20,
    restore_best_weights=False,
)

In [None]:
# Fit for at most 50 epochs, validating each epoch and letting the
# early-stopping callback end the run sooner.
train_history = model.fit(
    train_gen_,
    validation_data=valid_gen_,
    epochs=50,
    callbacks=[callback],
)

In [None]:
from scipy.stats import nbinom

# One output array per gene; column 0 is used as the negative-binomial `n`
# and column 1 as `p`.
test_predictions = model.predict(test_gen_)

# Per-gene predicted expression = mean of the predicted distribution.
y_preds = [
    nbinom.mean(test_predictions[gene_idx][:, 0], test_predictions[gene_idx][:, 1])
    for gene_idx in range(n_genes)
]

# Stack to (spots, genes) and store next to the test data.
test_dataset.obsm["predicted_gene"] = np.array(y_preds).T

In [None]:
# Copy of the test data restricted to gene_list (in that order), with its
# expression matrix replaced by the predictions. `test_dataset` itself keeps
# its original matrix.
test_dataset_ = test_dataset[:,gene_list].copy()
test_dataset_.X = test_dataset_.obsm["predicted_gene"]

In [None]:
# One spatial plot of the predicted values per gene, preceded by its name.
for gene_name in gene_list:
    print(gene_name)
    gene_plot(test_dataset_, genes=gene_name, spot_size=8)

In [None]:
# COX6C plotted from `test_dataset`, whose matrix was not overwritten with
# predictions (compare with the `test_dataset_` plots).
gene_plot(test_dataset, genes="COX6C", spot_size=8)

In [None]:
# KRT5 plotted from `test_dataset`, whose matrix was not overwritten with
# predictions (compare with the `test_dataset_` plots).
gene_plot(test_dataset, genes="KRT5", spot_size=8)

In [None]:
# Show which gene sits at position 6 of gene_list ("KRT5" for the list
# defined in this notebook).
gene_list[6]

In [None]:
def model_predict_gene(gene):
    """Return a single-gene prediction function backed by the global ``model``.

    The position of ``gene`` in the global ``gene_list`` is looked up once,
    when this factory is called (``ValueError`` if the gene is not listed).

    The returned callable takes the model input ``x``, runs
    ``model.predict(x)``, picks the output belonging to ``gene``, treats its
    columns 0 and 1 as the negative-binomial ``n`` and ``p``, and returns the
    distribution mean reshaped to a column of shape ``(-1, 1)``.
    """
    gene_position = gene_list.index(gene)
    from scipy.stats import nbinom

    def model_predict(x):
        gene_output = model.predict(x)[gene_position]
        expected = nbinom.mean(gene_output[:, 0], gene_output[:, 1])
        return expected.reshape(-1, 1)

    return model_predict

In [None]:
import matplotlib.pyplot as plt
from libpysal.weights.contiguity import Queen
from libpysal import examples
import numpy as np
import pandas as pd
import geopandas as gpd
import os
import splot
from splot.esda import moran_scatterplot, lisa_cluster
from esda.moran import Moran, Moran_Local
from esda.moran import Moran_BV, Moran_Local_BV
from splot.esda import plot_moran_bv_simulation, plot_moran_bv, plot_local_autocorrelation

In [None]:
# Gene analysed in the spatial-autocorrelation cells below; position 4 is
# "COX6C" for the gene_list defined in this notebook.
gene = gene_list[4]

In [None]:
# Spot metadata as a GeoDataFrame with one point per spot, located at
# (imagecol, imagerow); stored on the AnnData under obsm["gpd"].
spot_obs = test_dataset_.obs
spot_points = gpd.points_from_xy(spot_obs.imagecol, spot_obs.imagerow)
test_dataset_.obsm["gpd"] = gpd.GeoDataFrame(spot_obs, geometry=spot_points)


In [None]:
# Display the GeoDataFrame stored under obsm["gpd"].
test_dataset_.obsm["gpd"]

In [None]:
# x: values of `gene` from test_dataset (matrix not replaced by predictions).
# y: values of `gene` from test_dataset_ (matrix replaced by predictions).
x = test_dataset.to_df()[gene].values
y = test_dataset_.to_df()[gene].values
# Spatial weights built from the GeoDataFrame geometries.
# NOTE(review): the geometries are points (built with points_from_xy); Queen
# contiguity is normally defined on polygons -- confirm the resulting
# neighbour structure is the intended one.
w = Queen.from_dataframe(test_dataset_.obsm["gpd"])

In [None]:
# Add the two value vectors as columns of the GeoDataFrame:
# "gc_<gene>" holds x (from test_dataset), "pred_<gene>" holds y (predictions).
test_dataset_.obsm["gpd"]["gc_{}".format(gene)] = x
test_dataset_.obsm["gpd"]["pred_{}".format(gene)] = y
# Full-resolution image of the "block2" library, used as plot background.
tissue_image = test_dataset_.uns["spatial"]["block2"]["images"]["fulres"]

In [None]:
# Global and local Moran statistics under the weights `w`:
# univariate on the predictions (y), and bivariate with the predictions (y)
# as first variable and the test_dataset values (x) as second variable.
moran = Moran(y,w)
moran_bv = Moran_BV(y, x, w)
moran_loc = Moran_Local(y, w)
moran_loc_bv = Moran_Local_BV(y, x, w)

In [None]:
# Regression plot with the "pred_<gene>" column on the x axis and the
# "gc_<gene>" column on the y axis.
plot_correlation(test_dataset_.obsm["gpd"],
                 "pred_{}".format(gene),
                 "gc_{}".format(gene))
plt.show()

In [None]:
# Scatterplot of the global bivariate Moran statistic.
fig, ax = plt.subplots(figsize=(5, 5))
moran_scatterplot(moran_bv, ax=ax)
ax.set(
    xlabel='prediction of gene {}'.format(gene),
    ylabel='Spatial lag of ground truth of gene {}'.format(gene),
)
plt.tight_layout()
plt.show()

In [None]:
# Scatterplot of the local bivariate Moran statistic, drawn with p=0.05.
fig, ax = plt.subplots(figsize=(5, 5))
moran_scatterplot(moran_loc_bv, p=0.05, ax=ax)
ax.set(
    xlabel='prediction of gene {}'.format(gene),
    ylabel='Spatial lag of ground truth of gene {}'.format(gene),
)
plt.tight_layout()
plt.show()

In [None]:
def plot_choropleth(gdf, 
                    attribute_1,
                    attribute_2,
                    bg_img,
                    alpha=0.5,
                    scheme='Quantiles', 
                    cmap='YlGnBu', 
                    legend=True):
    """Draw two stacked choropleth panels on top of a background image.

    Parameters
    ----------
    gdf : object with a GeoDataFrame-style ``plot`` method.
    attribute_1 : column shown in the upper panel.
    attribute_2 : column shown in the lower panel.
    bg_img : image passed to ``imshow`` on each panel.
    alpha : transparency forwarded to ``gdf.plot``.
    scheme : classification scheme forwarded to ``gdf.plot``.
    cmap : colormap forwarded to ``gdf.plot``.
    legend : whether ``gdf.plot`` draws a legend.

    Returns
    -------
    (fig, axs) : the figure and the two axes created by ``plt.subplots(2, 1)``.
    """
    fig, axs = plt.subplots(2,1, figsize=(5, 8),
                            subplot_kw={'adjustable':'datalim'})

    # Both panels are built the same way; only the plotted column differs.
    for panel, attribute in zip(axs, (attribute_1, attribute_2)):
        gdf.plot(column=attribute, scheme=scheme, cmap=cmap,
                 legend=legend, legend_kwds={'loc': 'upper left',
                                             'bbox_to_anchor': (0.92, 0.8)},
                 ax=panel, alpha=alpha, markersize=2)

        panel.imshow(bg_img)
        panel.set_title('choropleth plot for {}'.format(attribute), y=0.8)
        panel.set_axis_off()

    plt.tight_layout()

    # Return the axes created here. The previous `return fig, ax` referenced a
    # name that is not defined in this function, so it either raised NameError
    # or handed back an unrelated global `ax` left over from another cell.
    return fig, axs

In [None]:
# Choropleths of the "gc_<gene>" column (upper panel) and the "pred_<gene>"
# column (lower panel) over the tissue image.
plot_choropleth(test_dataset_.obsm["gpd"], 
                "gc_{}".format(gene),
                "pred_{}".format(gene),
                tissue_image)
plt.show()

In [None]:
# LISA cluster map of the local bivariate Moran result (p=0.05), with the
# full-resolution "block2" image drawn on the current axes afterwards.
spot_frame = test_dataset_.obsm["gpd"]
lisa_cluster(moran_loc_bv, spot_frame, p=0.05,
             figsize=(9, 9), markersize=12, alpha=0.8)

background = test_dataset_.uns["spatial"]["block2"]["images"]["fulres"]
plt.imshow(background)
plt.show()