In [None]:
import sys
import stlearn as st
st.settings.set_figure_params(dpi=300)
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import sys
# Put the directory that contains the local `stimage` package on sys.path so the
# `from stimage...` imports below resolve without installing the package.
file = Path("../stimage").resolve()
parent= file.parent
sys.path.append(str(parent))
from PIL import Image
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling
from stimage._model import CNN_NB_multiple_genes, negative_binomial_layer, negative_binomial_loss
from stimage._data_generator import DataGenerator, DataGenerator_LSTM_one_output, DataGenerator_LSTM_multi_output
import tensorflow as tf
import seaborn as sns
sns.set_style("white")
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
import geopandas as gpd
from sklearn.neighbors import KDTree

In [None]:
import matplotlib.pyplot as plt
from libpysal.weights.contiguity import Queen
from libpysal import examples
import numpy as np
import pandas as pd
import geopandas as gpd
import os
import splot
from splot.esda import moran_scatterplot, lisa_cluster
from esda.moran import Moran, Moran_Local
from esda.moran import Moran_BV, Moran_Local_BV
from splot.esda import plot_moran_bv_simulation, plot_moran_bv, plot_local_autocorrelation

In [None]:
from anndata import AnnData
from typing import Iterable, Union, Optional
import pandas as pd
def spatial_autocorr(adata_true: AnnData,
                     adata_pred: AnnData,
                     model_name: str,
                     p: float = 0.05,
                     save_plots: str = None,
) -> Optional[AnnData]:
    """Score how well predicted expression follows the spatial pattern of the truth.

    For every gene in `adata_true.var_names` a global bivariate Moran's I
    (`Moran_BV`) and a local bivariate Moran (`Moran_Local_BV`) are computed
    between the prediction and the ground truth.

    Parameters
    ----------
    adata_true : AnnData
        Ground-truth expression. `obs` must provide `imagecol` / `imagerow`.
        This object is annotated in place (see "Side effects").
    adata_pred : AnnData
        Predicted expression with the same genes (looked up by name).
    model_name : str
        Suffix used for every key written to `adata_true`.
    p : float
        Significance threshold applied to the local statistic, both for the
        labels and for the saved cluster maps.
    save_plots : str or pathlib.Path, optional
        Directory receiving one `lisa_cluster_<gene>` figure per gene. When
        falsy, nothing is plotted and the tissue image is not required.

    Returns
    -------
    AnnData
        A copy of the annotated `adata_true`.

    Side effects
    ------------
    Writes `obsm["gpd_<model_name>"]`, `obsm["spatial_autocorr_<model_name>"]`,
    `var["moran_i_<model_name>"]` and `var["n_sigs_<model_name>"]` on
    `adata_true`.
    """
    assert len(adata_true.var_names) == len(adata_pred.var_names)

    gpd_name = "gpd_{}".format(model_name)

    adata_true.obsm[gpd_name] = gpd.GeoDataFrame(adata_true.obs,
                                                 geometry=gpd.points_from_xy(
                                                          adata_true.obs.imagecol,
                                                          adata_true.obs.imagerow))
    # NOTE(review): Queen contiguity is normally defined for polygons; confirm it
    # produces the intended neighbourhood structure for point geometries.
    w = Queen.from_dataframe(adata_true.obsm[gpd_name])

    # Plotting prerequisites are resolved only when plots were requested, so the
    # statistics can be computed for data without a "fulres" image.
    out_dir = None
    tissue_image = None
    if save_plots:
        out_dir = Path(save_plots)
        out_dir.mkdir(parents=True, exist_ok=True)
        library_id = list(adata_true.uns["spatial"].keys())[0]
        tissue_image = adata_true.uns["spatial"][library_id]["images"]["fulres"]

    # Build the dense frames once instead of twice per gene.
    true_df = adata_true.to_df()
    pred_df = adata_pred.to_df()

    spatial_autocorr_label = []
    moran_i = []
    n_sigs = []
    for gene in adata_true.var_names:
        x = np.array(true_df[gene].values, dtype='float')
        y = np.array(pred_df[gene].values, dtype='float')

        moran_bv = Moran_BV(y, x, w)
        moran_loc_bv = Moran_Local_BV(y, x, w)

        labels, n_sig = hot_cold_label(moran_loc_bv, p)
        spatial_autocorr_label.append(labels)
        moran_i.append(moran_bv.I)
        n_sigs.append(n_sig)

        if out_dir is not None:
            # Use the caller's threshold so the map agrees with the stored labels.
            fig, ax = lisa_cluster(moran_loc_bv, adata_true.obsm[gpd_name], p=p,
                                   figsize=(9, 9), markersize=12)
            ax.imshow(tissue_image)
            fig.savefig(out_dir / "lisa_cluster_{}".format(gene))
            # One figure per gene: close it so they do not accumulate in memory.
            plt.close(fig)

    adata_true.obsm["spatial_autocorr_{}".format(model_name)] = np.array(spatial_autocorr_label).transpose()
    adata_true.var["moran_i_{}".format(model_name)] = moran_i
    adata_true.var["n_sigs_{}".format(model_name)] = n_sigs

    return adata_true.copy()


def hot_cold_label(moran_loc, p):
    """Label every observation by its significant Moran quadrant.

    Returns `(labels, n_sig)`: `labels` holds one of 'ns', 'HH', 'LH', 'LL',
    'HL' per observation, and `n_sig` is the number of 'HH' plus 'LL' entries.
    """
    code_to_label = ('ns', 'HH', 'LH', 'LL', 'HL')
    labels = []
    n_sig = 0
    for code in moran_hot_cold_spots(moran_loc, p):
        label = code_to_label[code]
        labels.append(label)
        # Only the concordant quadrants (high-high, low-low) are counted.
        if label in ('HH', 'LL'):
            n_sig += 1
    return labels, n_sig

In [None]:
from matplotlib import patches, colors
def lisa_cluster(moran_loc, gdf, p=0.05, ax=None,
                 legend=True, legend_kwds=None, **kwargs):
    """
    Plot a LISA cluster map.

    Local variant of `splot.esda.lisa_cluster` that takes its palette from the
    `mask_local_auto` defined in this notebook.

    Parameters
    ----------
    moran_loc : esda.moran.Moran_Local or Moran_Local_BV instance
        Local Moran result providing `p_sim` and `q`.
    gdf : geopandas dataframe instance
        Geometries to draw. `gdf` itself is not changed: a labelled copy made
        with `gdf.assign` is plotted.
    p : float, optional
        Significance threshold; observations are coloured by their significant
        quadrant. Default = 0.05
    ax : matplotlib Axes instance, optional
        Axes to draw into. When None a new figure is created and an optional
        `figsize` entry in `kwargs` is used for it. Default = None
    legend : boolean, optional
        Whether to draw a legend. Default = True
    legend_kwds : dict, optional
        Legend formatting options, e.g.
        ``legend_kwds={'loc': 'upper left', 'bbox_to_anchor': (0.92, 1.05)}``
        Default = None
    **kwargs : keyword arguments, optional
        Forwarded to geopandas.GeoDataFrame.plot().

    Returns
    -------
    fig : matplotlib Figure instance
        Figure holding the cluster map
    ax : matplotlib Axes instance
        Axes the map was drawn into

    Examples
    --------
    >>> moran_loc = Moran_Local(y, w)
    >>> fig, ax = lisa_cluster(moran_loc, gdf)
    >>> plt.show()

    """
    # Of the four values returned, only the palette (position 1) and the
    # per-observation labels (position 3) are needed for the map.
    masks = mask_local_auto(moran_loc, p=p)
    palette = masks[1]
    spot_labels = masks[3]

    hmap = colors.ListedColormap(palette)

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(1, figsize=kwargs.pop('figsize', None))

    labelled = gdf.assign(cl=spot_labels)
    labelled.plot(column='cl', categorical=True,
                  k=2, cmap=hmap, linewidth=0.1, ax=ax,
                  edgecolor='white', legend=legend,
                  legend_kwds=legend_kwds, **kwargs)
    ax.set_axis_off()
    ax.set_aspect('equal')
    return fig, ax


def mask_local_auto(moran_loc, p=0.5):
    '''
    Build the labels and colours used to draw local spatial autocorrelation.

    Parameters
    ----------
    moran_loc : esda.moran.Moran_Local instance
        Local Moran result providing `p_sim` and `q`.
    p : float
        Significance threshold forwarded to `moran_hot_cold_spots`.

    Returns
    -------
    cluster_labels : list of str
        The five category names in sorted order.
    colors5 : list of str
        One colour per category actually present in `labels`, in the sorted
        order of those categories.
    colors : list of str
        One colour per observation.
    labels : list of str
        One category name per observation.
    '''
    quadrant_codes = moran_hot_cold_spots(moran_loc, p)

    # Position in this list is the quadrant code (0 = not significant).
    cluster_labels = ['ns', 'HH', 'LH', 'LL', 'HL']
    labels = [cluster_labels[code] for code in quadrant_codes]

    # Non-significant observations are fully transparent.
    palette = {'ns': '#ffffff00',
               'HH': '#d7191cff',
               'LH': '#abd9e9ff',
               'LL': '#2c7bb6ff',
               'HL': '#fdae61ff'}

    # Per-observation colours (for Bokeh).
    colors = [palette[label] for label in labels]
    # Per-category colours restricted to the categories present (for MPL).
    present = np.unique(np.array(labels))
    colors5 = [palette[label] for label in present]

    # HACK need this, because MPL sorts these labels while Bokeh does not
    cluster_labels.sort()
    return cluster_labels, colors5, colors, labels


def moran_hot_cold_spots(moran_loc, p=0.05):
    """Return, per observation, its quadrant code (1-4) if `p_sim < p`, else 0."""
    # Zero out the quadrant of every non-significant observation.
    significant_quadrant = (moran_loc.p_sim < p) * moran_loc.q
    # Re-assemble the codes quadrant by quadrant; values outside 1-4 map to 0.
    cluster = sum(code * (significant_quadrant == code) for code in (1, 3, 2, 4))
    return cluster

In [None]:
from anndata import AnnData
from typing import Iterable, Union, Optional
import pandas as pd
def enrich_group(adata: AnnData,
                 gene_list: Iterable,
                 enrich_name: Union[pd.Index, list],
) -> Optional[AnnData]:
    """Collapse a set of genes into a single summed "enrichment" feature.

    Parameters
    ----------
    adata : AnnData
        Source expression data.
    gene_list : Iterable
        Gene names to sum; names missing from `adata.var_names` are ignored.
    enrich_name : pd.Index, list or str
        Name of the resulting single variable (a bare string is accepted).

    Returns
    -------
    AnnData
        One-variable AnnData sharing `obs`, `uns` and `obsm` with the gene
        subset of `adata`.
    """
    adata_ = adata[:,adata.var_names.isin(gene_list)].copy()
    # `sum(axis=1)` gives an (n, 1) np.matrix for sparse X and a 1-D array for
    # dense X; normalise both to an (n_obs, 1) ndarray.
    summed = np.asarray(adata_.X.sum(axis=1)).reshape(-1, 1)
    adata_enrich = AnnData(X=summed,
                       obs=adata_.obs,
                       uns=adata_.uns,
                       obsm=adata_.obsm)
    # A bare string would otherwise be treated as a sequence of characters.
    adata_enrich.var_names = [enrich_name] if isinstance(enrich_name, str) else enrich_name
    return adata_enrich

In [None]:
from scipy import stats

def plot_correlation(df, attr_1, attr_2):
    """Scatter `attr_2` against `attr_1` with a regression line and an R^2 label.

    Returns the seaborn grid created by `sns.lmplot`. The annotation is placed
    at 90% of the maximum of each column.
    """
    x_values = df[attr_1]
    y_values = df[attr_2]
    # Squared Pearson correlation of the two columns.
    r = stats.pearsonr(x_values, y_values)[0] ** 2

    g = sns.lmplot(x=attr_1, y=attr_2, data=df, height=5, legend=True)
    g.set_axis_labels(attr_1, attr_2)

    anchor = (max(x_values) * 0.9, max(y_values) * 0.9)
    plt.annotate(r'$R^2:{0:.2f}$'.format(r), anchor)
    return g

In [None]:
# Dataset location and scratch directory for image tiles.
# NOTE(review): BASE_PATH is an absolute, machine-specific path; the notebook
# only runs where this directory exists.
BASE_PATH = Path("/clusterdata/uqxtan9/Xiao/STimage/dataset/breast_cancer_10x_visium")
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(parents=True, exist_ok=True)

# Section 1: read the counts, then attach the image read from the .tif under
# the "fulres" key.
SAMPLE = "block1"
Sample1 = st.Read10X(BASE_PATH / SAMPLE, 
                  library_id=SAMPLE, 
                  count_file="V1_Breast_Cancer_Block_A_Section_1_filtered_feature_bc_matrix.h5",
                  quality="fulres",)
                  #source_image_path=BASE_PATH / SAMPLE /"V1_Breast_Cancer_Block_A_Section_1_image.tif")
img = plt.imread(BASE_PATH / SAMPLE /"V1_Breast_Cancer_Block_A_Section_1_image.tif", 0)
Sample1.uns["spatial"][SAMPLE]['images']["fulres"] = img

# Section 2: same steps. SAMPLE and img are reused, so after this cell they
# refer to block2.
SAMPLE = "block2"
Sample2 = st.Read10X(BASE_PATH / SAMPLE, 
                  library_id=SAMPLE, 
                  count_file="V1_Breast_Cancer_Block_A_Section_2_filtered_feature_bc_matrix.h5",
                  quality="fulres",)
                  #source_image_path=BASE_PATH / SAMPLE /"V1_Breast_Cancer_Block_A_Section_1_image.tif")
img = plt.imread(BASE_PATH / SAMPLE /"V1_Breast_Cancer_Block_A_Section_2_image.tif", 0)
Sample2.uns["spatial"][SAMPLE]['images']["fulres"] = img

In [None]:
# The eight genes that every model in this notebook predicts; the order fixes
# the order of the model outputs (`gene_0` ... `gene_7`).
gene_list=["SLITRK6", "PGM5", "LINC00645", 
           "TTLL12", "COX6C", "CPB1",
           "KRT5", "MALAT1"]
gene_list

In [None]:
# Filter and log-transform both sections in place, then cut image tiles.
# NOTE(review): this mutates Sample1/Sample2, so re-running the cell applies
# log1p a second time; restart from the loading cell before re-running.
for adata in [
    Sample1,
    Sample2,
]:
#     count_df = adata.to_df()
#     count_df[count_df <=1] = 0
#     count_df[count_df >1] = 1
#     adata.X = count_df
#     adata[:,gene_list]
    st.pp.filter_genes(adata,min_cells=3)
#     st.pp.normalize_total(adata)
    st.pp.log1p(adata)
#     st.pp.scale(adata)

    # pre-processing for spot image
    # One tile directory per library id (the first key of uns["spatial"]).
    TILE_PATH_ = TILE_PATH / list(adata.uns["spatial"].keys())[0]
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    # presumably writes one 299-px crop per spot and records its path in
    # adata.obs (the generators read obs["tile_path"]) - confirm in stimage._utils
    tiling(adata, TILE_PATH_, crop_size=299)

In [None]:
# Number of output heads every model is built with.
n_genes = len(gene_list)
# training_index = Sample1.obs.sample(frac=0.9, random_state=1).index
# training_dataset = Sample1[training_index,].copy()

# valid_index = Sample1.obs.index.isin(training_index)
# valid_dataset = Sample1[~valid_index,].copy()

# All of section 1 is used for training.
training_dataset = Sample1.copy()

# NOTE(review): the validation set is a 30% sample of the same section that is
# used in full for training, so validation spots are also training spots.
valid_index = Sample1.obs.sample(frac=0.3, random_state=1).index
valid_dataset = Sample1[valid_index,].copy()

# Section 2 is held out as the test set.
test_dataset = Sample2.copy()

# LSTM

In [None]:
# tf.data pipelines for the one-output LSTM model: each element is a sequence
# of 7 tiles of 299x299x3 plus one length-1 target per gene.
train_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_one_output(adata=training_dataset, 
                          genes=gene_list, aug=False),
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([7,299,299,3], tuple([1]*n_genes))
)
# NOTE(review): cache() sits after shuffle(), so the order produced on the
# first pass is presumably what later epochs replay - confirm this is intended.
train_gen_ = train_gen.shuffle(buffer_size=100).batch(8).repeat(1).cache().prefetch(tf.data.experimental.AUTOTUNE)
valid_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_one_output(adata=valid_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([7,299,299,3], tuple([1]*n_genes))
)
valid_gen_ = valid_gen.shuffle(buffer_size=100).batch(8).repeat(1).cache().prefetch(tf.data.experimental.AUTOTUNE)
test_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_one_output(adata=test_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([7,299,299,3], tuple([1]*n_genes))
)
# The test set is neither shuffled nor cached, so predictions keep spot order.
test_gen_ = test_gen.batch(1)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Dropout, Lambda, TimeDistributed, LSTM
from tensorflow.keras.models import Model


def CNN_LSTM_NB_multiple_genes(steps_tile_shape, n_genes):
    """Build and compile a ResNet50 + LSTM model with one output head per gene.

    Parameters
    ----------
    steps_tile_shape : tuple
        Shape of one input sequence. The first entry is the number of steps;
        the remainder (`steps_tile_shape[1:]`) is the tile shape given to ResNet50.
    n_genes : int
        Number of heads to create, named ``gene_0`` ... ``gene_{n_genes-1}``.

    Returns
    -------
    Model
        Keras model compiled with `negative_binomial_loss` and Adam(0.0001).
    """
    tile_input = Input(shape=steps_tile_shape, name="steps_tile_input")
    # ImageNet-initialised backbone without the classification top; all layers
    # stay trainable (the freezing code below is commented out).
    resnet_base = ResNet50(input_shape=steps_tile_shape[1:], weights='imagenet', include_top=False)
    #     stage_5_start = resnet_base.get_layer("conv5_block1_1_conv")
    #     for i in range(resnet_base.layers.index(stage_5_start)):
    #         resnet_base.layers[i].trainable = False

#     for i in resnet_base.layers:
#         i.trainable = False
    cnn_out = resnet_base.output
    cnn_out = GlobalAveragePooling2D()(cnn_out)
    # Wrap the backbone as a single-tile encoder so it can be applied per step.
    cnn = Model(inputs=resnet_base.input, outputs=cnn_out)
    #     cnn = Dropout(0.5)(cnn)
    #     cnn = Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l1(0.01),
    #                 activity_regularizer=tf.keras.regularizers.l2(0.01))(cnn)
    # cnn = Dense(256, activation='relu')(cnn)
    encoded_frames = TimeDistributed(cnn)(tile_input)
    # return_sequences is left at its default, so only the final LSTM state is
    # passed to the heads.
    encoded_sequence = LSTM(256)(encoded_frames)
    output_layers = []
    for i in range(n_genes):
        # Two values per gene, handed to negative_binomial_layer (imported from
        # stimage._model; downstream cells read them as columns 0 and 1).
        output = Dense(2)(encoded_sequence)
        output_layers.append(Lambda(negative_binomial_layer, name="gene_{}".format(i))(output))

    model = Model(inputs=[tile_input], outputs=output_layers)
    #     losses={}
    #     for i in range(8):
    #         losses["gene_{}".format(i)] = negative_binomial_loss(i)
    #     optimizer = tf.keras.optimizers.RMSprop(0.001)
    optimizer = tf.keras.optimizers.Adam(0.0001)
    # The same loss is applied to every head.
    model.compile(loss=negative_binomial_loss,
                  optimizer=optimizer)
    return model

In [None]:
# One-output LSTM model over sequences of 7 tiles.
model = CNN_LSTM_NB_multiple_genes((7, 299, 299, 3), n_genes)
# Stop after 20 epochs without val_loss improvement; the final (not the best)
# weights are kept. Not used until a fit() call passes it in `callbacks`.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20,
                                            restore_best_weights=False)

In [None]:
# Load previously trained weights from the working directory instead of training here.
model.load_weights("./LSTM_CNN_ft.h5")

In [None]:
# Predict on the test section and turn each head's output into an expected count.
test_predictions = model.predict(test_gen_)
from scipy.stats import nbinom
y_preds = []
if n_genes >1:
    # With several heads, predict() returns one array per gene.
    for i in range(n_genes):
        # Columns 0 and 1 are used as the (n, p) parameters of scipy's nbinom.
        n = test_predictions[i][:, 0]
        p = test_predictions[i][:, 1]
        y_pred = nbinom.mean(n, p)
        y_preds.append(y_pred)
    # Stack to (spots, genes).
    test_dataset.obsm["predicted_gene"] = np.array(y_preds).transpose()
else:
    # A single head returns one array directly.
    n = test_predictions[:, 0]
    p = test_predictions[:, 1]
    y_pred = nbinom.mean(n, p)
    test_dataset.obsm["predicted_gene"] = y_pred

In [None]:
# Ground truth restricted to the predicted genes, in gene_list order.
test_final = test_dataset[:,gene_list].copy()

In [None]:
# Same obs/var as the ground truth, with X replaced by the LSTM predictions.
test_dataset_lstm = test_final.copy()
test_dataset_lstm.X = test_dataset.obsm["predicted_gene"]

In [None]:
# for i in gene_list:
#     print(i)
#     gene_plot(test_dataset_lstm, genes=i, spot_size=8)

In [None]:
# for i in gene_list:
#     print(i)
#     gene_plot(test_final, genes=i, spot_size=8)

In [None]:
# OUTPUT_PATH = Path("./lstm_one_plot")
# OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
# final_adata = spatial_autocorr(test_final, test_dataset_lstm, "lstm_one", save_plots=OUTPUT_PATH)

In [None]:
# NOTE(review): the only assignments to `final_adata` visible in this notebook
# are commented out, so on a fresh kernel this cell raises NameError.
final_adata.var

In [None]:
import numpy as np
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import preprocess_input as preprocess_resnet
from tensorflow.keras.preprocessing import image
from stimage._imgaug import seq_aug
from sklearn.neighbors import KDTree

class DataGenerator_LSTM_multi_output(keras.utils.Sequence):
    """Keras Sequence yielding one spot together with its nearest neighbours.

    Item `index` is a pair `(X, y)`. Both are stacked along axis 0 as the six
    nearest neighbouring spots (in random order) followed by the spot itself,
    so the spot of interest is always the last step.

    Note: this definition shadows the class of the same name imported from
    `stimage._data_generator` at the top of the notebook.
    """

    def __init__(self, adata, dim=(299, 299), n_channels=3, genes=None, aug=False, tile_path="tile_path"):
        'Initialization'
        self.dim = dim
        self.adata = adata
        self.n_channels = n_channels
        self.genes = genes
        self.num_genes = len(genes)
        self.aug = aug
        self.tile_path = tile_path
        # Spot coordinates and the KD-tree over them, built on first item access
        # and then reused for every later query.
        self._coords = None
        self._tree = None
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of items per epoch (one item per spot)'
        return int(self.adata.n_obs)

    def __getitem__(self, index):
        'Generate the tile stack and label stack for one spot'
        obs_temp = self.adata.obs_names[index]

        # Look up the spot's neighbours in the cached spatial index.
        tree = self._neighbour_tree()
        quary_spots = np.array(self._coords.loc[obs_temp, :]).reshape(1, -1)
        _, indices = tree.query(quary_spots, k=7)
        # The first hit is dropped as the query spot itself, leaving six neighbours.
        nearest_index = indices.ravel()[1:]
        # Randomise neighbour order; shuffling the array directly avoids handing
        # a flatiter to np.random.shuffle.
        np.random.shuffle(nearest_index)
        obs_temp_neighbour = self.adata.obs_names[nearest_index]

        # Generate data
        X_img_final = self._load_neighbour(self._load_img, obs_temp, obs_temp_neighbour)
        y_final = self._load_neighbour(self._load_label, obs_temp, obs_temp_neighbour)

        return X_img_final, y_final

    def _neighbour_tree(self):
        'Return the KD-tree over spot coordinates, building it on first use'
        if self._tree is None:
            self._coords = self.adata.obs[['imagecol', 'imagerow']]
            # leaf_size must be at least 1, even for very small datasets.
            self._tree = KDTree(self._coords, leaf_size=max(1, self.adata.n_obs // 2))
        return self._tree

    def _load_neighbour(self, fn, obs_temp, obs_temp_neighbour):
        'Apply `fn` to the neighbours and the spot; stack neighbours first, spot last'
        obs_data = fn(obs_temp)
        obs_neighbour = np.stack([fn(x) for x in obs_temp_neighbour], axis=0)
        obs_data = np.expand_dims(obs_data, axis=0)
        obs_data_final = np.concatenate([obs_neighbour, obs_data], axis=0)
        return obs_data_final

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(self.adata.n_obs)

    def _load_img(self, obs):
        'Load the tile of spot `obs` as a uint8 array, optionally augmented'
        img_path = self.adata.obs.loc[obs, self.tile_path]
        X_img = image.load_img(img_path, target_size=self.dim)
        X_img = image.img_to_array(X_img).astype('uint8')
        if self.aug:
            X_img = seq_aug(image=X_img)
        # ResNet preprocessing (preprocess_resnet) is intentionally not applied here.
        return X_img

    def _load_label(self, obs):
        'Return the expression of `self.genes` for spot `obs`, one row per gene'
        # Convert to a frame once rather than once per gene.
        label_df = self.adata[obs, self.genes].copy().to_df()
        return np.array([label_df[i].values for i in self.genes])

    def get_classes(self):
        'Return the expression of the selected genes for all spots as a DataFrame'
        return self.adata.to_df().loc[:, self.genes]

    def get_nearest_index(self, quary_spots, candidates, k_nn, leaf_size):
        'Indices of the `k_nn` nearest rows of `candidates`, without the first hit'
        tree = KDTree(candidates, leaf_size=max(1, leaf_size))
        _, indices = tree.query(quary_spots, k=k_nn)
        return indices.ravel()[1:]

# LSTM multiple output

In [None]:
# tf.data pipelines for the multi-output LSTM model: each element is a sequence
# of 7 tiles plus a label stack with one row per step and per gene. The gene
# dimension follows n_genes instead of a hard-coded 8, so the declared shape
# stays correct if gene_list changes.
train_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_multi_output(adata=training_dataset,
                          genes=gene_list, aug=False),
            output_types=(tf.float32, tf.float32),
            output_shapes=([7,299,299,3], [7,n_genes,1])
)
# NOTE(review): cache() sits after shuffle(), so the order produced on the
# first pass is presumably what later epochs replay - confirm this is intended.
train_gen_ = train_gen.shuffle(buffer_size=100).batch(8).repeat(1).cache().prefetch(tf.data.experimental.AUTOTUNE)
valid_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_multi_output(adata=valid_dataset,
                          genes=gene_list),
            output_types=(tf.float32, tf.float32),
            output_shapes=([7,299,299,3], [7,n_genes,1])
)
valid_gen_ = valid_gen.shuffle(buffer_size=100).batch(8).repeat(1).cache().prefetch(tf.data.experimental.AUTOTUNE)
test_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator_LSTM_multi_output(adata=test_dataset,
                          genes=gene_list),
            output_types=(tf.float32, tf.float32),
            output_shapes=([7,299,299,3], [7,n_genes,1])
)
# The test set is neither shuffled nor cached, so predictions keep spot order.
test_gen_ = test_gen.batch(1)

In [None]:
def map_data(inputs, outputs):
    """Return `inputs` unchanged and `outputs` as a one-element list holding a tuple of its items."""
    unpacked = tuple(outputs)
    return inputs, [unpacked]

In [None]:
# Experiment: re-pack the label tensor of each element via map_data.
# NOTE(review): tf.py_function is invoked here directly, with a tuple of dtypes
# in the `inp` position and no `Tout`, rather than being wrapped in a function
# handed to map(); this looks like it fails when run - confirm. `abc` is not
# used anywhere else in this notebook.
abc = train_gen.map(tf.py_function(map_data, (tf.float32, tf.float32)))

In [None]:
# Sanity check: dtype of the tile stack and shape of the label stack for one
# element of the unbatched training dataset.
for a, b in train_gen.shuffle(10).take(1):
    print(a.dtype)
    print(b.shape)

In [None]:
# Instantiate the generator directly (outside tf.data) for inspection.
c= DataGenerator_LSTM_multi_output(adata=training_dataset, 
                          genes=gene_list, aug=False)

In [None]:
# Debug loop: prints dtype/shape for every item the generator yields.
# NOTE(review): there is no early exit, so this loads the tiles of every spot.
for a, b in c:
    print(a.dtype)
    print(b.shape)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Dropout, Lambda, TimeDistributed, LSTM
from tensorflow.keras.models import Model


def CNN_LSTM_NB_multiple_output_genes(steps_tile_shape, n_genes):
    """Build and compile a ResNet50 + LSTM model that predicts at every step.

    Differs from `CNN_LSTM_NB_multiple_genes` only in `return_sequences=True`:
    each gene head receives the LSTM output of every step rather than only the
    final state.

    Parameters
    ----------
    steps_tile_shape : tuple
        Shape of one input sequence. The first entry is the number of steps;
        the remainder (`steps_tile_shape[1:]`) is the tile shape given to ResNet50.
    n_genes : int
        Number of heads to create, named ``gene_0`` ... ``gene_{n_genes-1}``.

    Returns
    -------
    Model
        Keras model compiled with `negative_binomial_loss` and Adam(0.0001).
    """
    tile_input = Input(shape=steps_tile_shape, name="steps_tile_input")
    # ImageNet-initialised backbone without the classification top; all layers
    # stay trainable (the freezing code below is commented out).
    resnet_base = ResNet50(input_shape=steps_tile_shape[1:], weights='imagenet', include_top=False)
    #     stage_5_start = resnet_base.get_layer("conv5_block1_1_conv")
    #     for i in range(resnet_base.layers.index(stage_5_start)):
    #         resnet_base.layers[i].trainable = False

#     for i in resnet_base.layers:
#         i.trainable = False
    cnn_out = resnet_base.output
    cnn_out = GlobalAveragePooling2D()(cnn_out)
    # Wrap the backbone as a single-tile encoder so it can be applied per step.
    cnn = Model(inputs=resnet_base.input, outputs=cnn_out)
    #     cnn = Dropout(0.5)(cnn)
    #     cnn = Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l1(0.01),
    #                 activity_regularizer=tf.keras.regularizers.l2(0.01))(cnn)
    # cnn = Dense(256, activation='relu')(cnn)
    encoded_frames = TimeDistributed(cnn)(tile_input)
    # Keep the LSTM output of every step so each head predicts per step.
    encoded_sequence = LSTM(256, return_sequences=True)(encoded_frames)
    output_layers = []
    for i in range(n_genes):
        # Two values per gene and step, handed to negative_binomial_layer.
        output = Dense(2)(encoded_sequence)
        output_layers.append(Lambda(negative_binomial_layer, name="gene_{}".format(i))(output))

    model = Model(inputs=[tile_input], outputs=output_layers)
    #     losses={}
    #     for i in range(8):
    #         losses["gene_{}".format(i)] = negative_binomial_loss(i)
    #     optimizer = tf.keras.optimizers.RMSprop(0.001)
    optimizer = tf.keras.optimizers.Adam(0.0001)
    # The same loss is applied to every head.
    model.compile(loss=negative_binomial_loss,
                  optimizer=optimizer)
    return model

In [None]:
# Multi-output LSTM model; rebinding `model` replaces the one-output model.
model = CNN_LSTM_NB_multiple_output_genes((7, 299, 299, 3), n_genes)
# Stop after 20 epochs without val_loss improvement; the final (not the best)
# weights are kept.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20,
                                            restore_best_weights=False)

In [None]:
# Train the multi-output model.
# NOTE(review): with epochs=5 and an early-stopping patience of 20, the
# callback cannot trigger during this run.
train_history = model.fit(train_gen_,
                          epochs=5,
                          validation_data=valid_gen_,
                          callbacks=[callback]
                          )

In [None]:
# Shape of the first gene head's predictions.
# NOTE(review): `test_predictions` is only assigned by the predict cell of the
# one-output LSTM model earlier on; unless the multi-output model is predicted
# with first, this shows the earlier model's output.
test_predictions[0].shape

In [None]:
# Predict with the multi-output model built and trained above. Without this
# call, `test_predictions` would still hold the 2-D per-gene output of the
# one-output model and the three-index lookups below would fail.
test_predictions = model.predict(test_gen_)
from scipy.stats import nbinom
y_preds = []
if n_genes >1:
    # With several heads, predict() returns one array per gene.
    for i in range(n_genes):
        # The generator stacks the neighbours first and the spot itself last,
        # so step -1 is the prediction for the spot of interest. Columns 0 and
        # 1 are used as the (n, p) parameters of scipy's nbinom.
        n = test_predictions[i][:, -1, 0]
        p = test_predictions[i][:, -1, 1]
        y_pred = nbinom.mean(n, p)
        y_preds.append(y_pred)
    # Stack to (spots, genes).
    test_dataset.obsm["predicted_gene"] = np.array(y_preds).transpose()
else:
    # A single head returns one array directly; it carries the same step axis,
    # so the last step is selected here as well.
    n = test_predictions[:, -1, 0]
    p = test_predictions[:, -1, 1]
    y_pred = nbinom.mean(n, p)
    test_dataset.obsm["predicted_gene"] = y_pred

In [None]:
# Same obs/var as the ground truth, with X replaced by the multi-output LSTM predictions.
test_dataset_lstm_m = test_final.copy()
test_dataset_lstm_m.X = test_dataset.obsm["predicted_gene"]

In [None]:
# One spatial plot of the predicted expression per gene.
for i in gene_list:
    print(i)
    gene_plot(test_dataset_lstm_m, genes=i, spot_size=8)

# CNN_NB

In [None]:
# tf.data pipelines for the single-tile CNN model: each element is one
# 299x299x3 tile plus one length-1 target per gene. These rebind the
# train_gen/valid_gen/test_gen names used by the LSTM sections above.
train_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=training_dataset, 
                          genes=gene_list, aug=False),
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
# NOTE(review): cache() sits after shuffle() and repeat(3), so the three
# repeated passes are presumably cached as one long sequence - confirm this is intended.
train_gen_ = train_gen.shuffle(buffer_size=500).batch(128).repeat(3).cache().prefetch(tf.data.experimental.AUTOTUNE)
valid_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=valid_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
valid_gen_ = valid_gen.shuffle(buffer_size=500).batch(128).repeat(3).cache().prefetch(tf.data.experimental.AUTOTUNE)
test_gen = tf.data.Dataset.from_generator(
            lambda:DataGenerator(adata=test_dataset, 
                          genes=gene_list), 
            output_types=(tf.float32, tuple([tf.float32]*n_genes)), 
            output_shapes=([299,299,3], tuple([1]*n_genes))
)
# The test set is neither shuffled nor cached, so predictions keep spot order.
test_gen_ = test_gen.batch(1)

In [None]:
# Single-tile CNN model; rebinding `model` replaces the LSTM model.
model = CNN_NB_multiple_genes((299, 299, 3), n_genes)
# Stop after 20 epochs without val_loss improvement; the final (not the best)
# weights are kept. Not used until a fit() call passes it in `callbacks`.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20,
                                            restore_best_weights=False)

In [None]:
# Load previously trained weights from the working directory instead of training here.
model.load_weights("./CNN_NB_8genes_model.h5")

In [None]:
# Predict on the test section and turn each head's output into an expected count.
# This overwrites test_dataset.obsm["predicted_gene"] written by earlier models.
test_predictions = model.predict(test_gen_)
from scipy.stats import nbinom
y_preds = []
if n_genes >1:
    # With several heads, predict() returns one array per gene.
    for i in range(n_genes):
        # Columns 0 and 1 are used as the (n, p) parameters of scipy's nbinom.
        n = test_predictions[i][:, 0]
        p = test_predictions[i][:, 1]
        y_pred = nbinom.mean(n, p)
        y_preds.append(y_pred)
    # Stack to (spots, genes).
    test_dataset.obsm["predicted_gene"] = np.array(y_preds).transpose()
else:
    # A single head returns one array directly.
    n = test_predictions[:, 0]
    p = test_predictions[:, 1]
    y_pred = nbinom.mean(n, p)
    test_dataset.obsm["predicted_gene"] = y_pred

In [None]:
# Same obs/var as the ground truth, with X replaced by the CNN predictions.
test_dataset_NB = test_final.copy()
test_dataset_NB.X = test_dataset.obsm["predicted_gene"]

In [None]:
# for i in gene_list:
#     print(i)
#     gene_plot(test_dataset_NB, genes=i, spot_size=8)

In [None]:
# OUTPUT_PATH = Path("./cnn_nb_plot")
# OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
# final_adata = spatial_autocorr(test_final, test_dataset_NB, "cnn_nb", save_plots=OUTPUT_PATH)

In [None]:
# NOTE(review): the only assignments to `final_adata` visible in this notebook
# are commented out, so on a fresh kernel this cell raises NameError.
final_adata.var

# CNN NB two level tiling

In [None]:
# Cut two tile sizes per spot for both sections. The expression preprocessing
# is left commented out here; it was already applied by the earlier loop.
for adata in [
    Sample1,
    Sample2,
]:
#     count_df = adata.to_df()
#     count_df[count_df <=1] = 0
#     count_df[count_df >1] = 1
#     adata.X = count_df
#     adata[:,gene_list]
#     st.pp.filter_genes(adata,min_cells=3)
#     st.pp.normalize_total(adata)
#     st.pp.log1p(adata)
#     st.pp.scale(adata)

    # pre-processing for spot image
    # One tile directory per library id (the first key of uns["spatial"]).
    TILE_PATH_ = TILE_PATH / list(adata.uns["spatial"].keys())[0]
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    # A 299-px and a 900-px crop per spot, saved under separate names.
    tiling(adata, TILE_PATH_, crop_size=299, target_size=299,save_name="small_tile")
    tiling(adata, TILE_PATH_, crop_size=900, target_size=900,save_name="large_tile")

In [None]:
def CNN_NB_multiple_genes(tile_shape, n_genes):
    """Build and compile a frozen-ResNet50 model with one output head per gene.

    Note: this definition shadows the `CNN_NB_multiple_genes` imported from
    `stimage._model` at the top of the notebook for every cell run after it.

    Parameters
    ----------
    tile_shape : tuple
        Shape of one input tile.
    n_genes : int
        Number of heads to create, named ``gene_0`` ... ``gene_{n_genes-1}``.

    Returns
    -------
    Model
        Keras model compiled with `negative_binomial_loss` and Adam(0.0001).
    """
    tile_input = Input(shape=tile_shape, name="tile_input")
    resnet_base = ResNet50(input_tensor=tile_input, weights='imagenet', include_top=False)
    #     stage_5_start = resnet_base.get_layer("conv5_block1_1_conv")
    #     for i in range(resnet_base.layers.index(stage_5_start)):
    #         resnet_base.layers[i].trainable = False

    # Freeze the whole backbone; only the per-gene Dense heads are trainable.
    for i in resnet_base.layers:
        i.trainable = False
    cnn = resnet_base.output
    cnn = GlobalAveragePooling2D()(cnn)
    #     cnn = Dropout(0.5)(cnn)
    #     cnn = Dense(512, activation='relu', kernel_regularizer=tf.keras.regularizers.l1(0.01),
    #                 activity_regularizer=tf.keras.regularizers.l2(0.01))(cnn)
    # cnn = Dense(256, activation='relu')(cnn)
    output_layers = []
    for i in range(n_genes):
        # Two values per gene, handed to negative_binomial_layer.
        output = Dense(2)(cnn)
        output_layers.append(Lambda(negative_binomial_layer, name="gene_{}".format(i))(output))

    model = Model(inputs=tile_input, outputs=output_layers)
    #     losses={}
    #     for i in range(8):
    #         losses["gene_{}".format(i)] = negative_binomial_loss(i)
    #     optimizer = tf.keras.optimizers.RMSprop(0.001)
    optimizer = tf.keras.optimizers.Adam(0.0001)
    # The same loss is applied to every head.
    model.compile(loss=negative_binomial_loss,
                  optimizer=optimizer)
    return model

In [None]:
# Placeholder for a per-gene loop whose body was never written. A `for`
# statement without a body is a SyntaxError, so an explicit no-op keeps the
# cell runnable. TODO: move the single-gene analysis below into this loop or
# delete the cell.
for gene in gene_list:
    pass

In [None]:
# GeoDataFrame of the test spots (points at imagecol/imagerow), used to build
# spatial weights and to draw the cluster maps below.
test_final.obsm["gpd"] = gpd.GeoDataFrame(test_final.obs,
                                            geometry=gpd.points_from_xy(
                                            test_final.obs.imagecol, 
                                            test_final.obs.imagerow))

In [None]:
# Display the per-gene annotations of the ground-truth subset.
test_final.var

In [None]:
# Manual single-gene analysis: pick one gene (index 4 of gene_list).
gene = gene_list[4]
# x: ground truth, y: one-output LSTM prediction, both per spot.
x = np.array(test_final.to_df()[gene].values, dtype='float')
y = test_dataset_lstm.to_df()[gene].values
# Spatial weights over the test spots.
w = Queen.from_dataframe(test_final.obsm["gpd"])

In [None]:
# Display the gene panel for reference.
gene_list

In [None]:
# Store truth and prediction next to the geometries so they can be plotted.
test_final.obsm["gpd"]["gc_{}".format(gene)] = np.array(x, dtype='float')
test_final.obsm["gpd"]["pred_{}".format(gene)] = y
# NOTE(review): the library id "block2" is hard-coded here.
tissue_image = test_final.uns["spatial"]["block2"]["images"]["fulres"]

In [None]:
# Univariate statistics use the ground truth only; the bivariate (_BV) ones
# take the prediction as first variable and the ground truth as second.
moran = Moran(x,w)
moran_bv = Moran_BV(y, x, w)
moran_loc = Moran_Local(x, w)
moran_loc_bv = Moran_Local_BV(y, x, w)

In [None]:
# Prediction (x-axis) against ground truth (y-axis) with a regression line and R^2.
plot_correlation(test_final.obsm["gpd"],
                 "pred_{}".format(gene),
                 "gc_{}".format(gene))
plt.show()

In [None]:
# Global Moran scatterplot of the ground truth alone. `moran` was computed from
# x (the ground truth) only, so the x-axis shows the ground truth and not the
# prediction.
fig, ax = plt.subplots(figsize=(5,5))
moran_scatterplot(moran, ax=ax)
ax.set_xlabel('ground truth of gene {}'.format(gene))
ax.set_ylabel('Spatial lag of ground truth of gene {}'.format(gene))
plt.tight_layout()
plt.show()

In [None]:
# Global bivariate Moran scatterplot (moran_bv was built from the prediction
# and the ground truth).
fig, ax = plt.subplots(figsize=(5,5))
moran_scatterplot(moran_bv, ax=ax)
ax.set_xlabel('prediction of gene {}'.format(gene))
ax.set_ylabel('Spatial lag of ground truth of gene {}'.format(gene))
plt.tight_layout()
plt.show()

In [None]:
# Local bivariate Moran scatterplot with a significance threshold of p=0.05.
fig, ax = plt.subplots(figsize=(5,5))
moran_scatterplot(moran_loc_bv, p=0.05, ax=ax)
ax.set_xlabel('prediction of gene {}'.format(gene))
ax.set_ylabel('Spatial lag of ground truth of gene {}'.format(gene))
plt.tight_layout()
plt.show()

In [None]:
# Cluster map of the local bivariate statistic, with the tissue image drawn
# into the same axes afterwards.
lisa_cluster(moran_loc_bv, test_final.obsm["gpd"], p=0.05, 
             figsize = (9,9), markersize=12,)# **{"alpha":1})
plt.imshow(tissue_image)
plt.show()

In [None]:
# Cluster map of the local univariate statistic (ground truth only), for
# comparison with the bivariate map.
lisa_cluster(moran_loc, test_final.obsm["gpd"], p=0.05, 
             figsize = (9,9), markersize=12,)# **{"alpha":1})
plt.imshow(tissue_image)
plt.show()

In [None]:
# Per-gene spatial autocorrelation results, indexed by the first CSV column.
# NOTE(review): no cell visible in this notebook writes this file; it has to
# exist in the working directory already.
df = pd.read_csv("./8genes_spatial_autocorr.csv", index_col=0)

In [None]:
# Columns belonging to the LSTM model.
df_lstm = df[["moran_i_lstm", "n_sigs_lstm"]].copy()

In [None]:
# Columns belonging to the CNN model.
df_nb =  df[["moran_i_cnn_nb", "n_sigs_cnn_nb"]].copy()

In [None]:
# Expose the index as an explicit "genes" column for plotting.
df_lstm["genes"] = df_lstm.index

In [None]:
# Expose the index as an explicit "genes" column for plotting.
df_nb["genes"] = df_nb.index

In [None]:
# Move the old index into a leading column, which duplicates "genes" added above.
# NOTE(review): rebinding the same names makes this cell non-idempotent - a
# second run adds another column and breaks the 5-name rename further down.
df_lstm = df_lstm.reset_index()
df_nb = df_nb.reset_index()

In [None]:
# Tag each frame with its model name so they can be told apart after concatenation.
df_lstm["model"] = "lstm"
df_nb["model"] = "cnn_nb"

In [None]:
# Give both frames identical, model-independent column names (positional
# rename: each frame must have exactly these five columns in this order).
df_lstm.columns = ['index', 'moran_i', 'n_sigs', 'genes', 'model']
df_nb.columns = ['index', 'moran_i', 'n_sigs', 'genes', 'model']

In [None]:
# Long-format table: one row per (gene, model).
df_final = pd.concat([df_lstm, df_nb], axis=0,ignore_index=True)

In [None]:
# Display the combined comparison table.
df_final

In [None]:
# Bivariate Moran's I per gene, one line per model.
# NOTE(review): `markers`/`dashes` relate to seaborn's `style` semantic; with no
# `style=` given they may have no visible effect - confirm.
sns.lineplot(
    data=df_final,
    x="genes", y="moran_i", hue="model",
    markers=True, dashes=False
)

In [None]:
# Number of significant HH/LL spots per gene, one line per model.
# NOTE(review): `markers`/`dashes` relate to seaborn's `style` semantic; with no
# `style=` given they may have no visible effect - confirm.
sns.lineplot(
    data=df_final,
    x="genes", y="n_sigs", hue="model",
    markers=True, dashes=False
)