In [None]:
import sys
import stlearn as st
st.settings.set_figure_params(dpi=300)
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import sys
file = Path("../stimage").resolve()
parent= file.parent
sys.path.append(str(parent))
from PIL import Image
from stimage._utils import gene_plot, Read10X, ReadOldST, tiling
from stimage._model import CNN_NB_multiple_genes, negative_binomial_layer, negative_binomial_loss
from stimage._data_generator import DataGenerator, DataGenerator_LSTM_one_output
import tensorflow as tf
import seaborn as sns
sns.set_style("white")
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np
import geopandas as gpd
from sklearn.neighbors import KDTree

In [None]:
from anndata import AnnData
from typing import Iterable, Union, Optional
import pandas as pd
def enrich_group(adata: AnnData,
                 gene_list: Iterable,
                 enrich_name: Union[pd.Index, list],
) -> Optional[AnnData]:
    """Collapse a set of genes into a single summed "enrichment" feature.

    Parameters
    ----------
    adata
        Source object; it is not modified (a copy of the gene subset is used).
    gene_list
        Names of the genes to aggregate. Names absent from
        ``adata.var_names`` are ignored by the ``isin`` selection.
    enrich_name
        One-element index/list naming the resulting feature.

    Returns
    -------
    AnnData
        New object with one variable whose value for each observation is the
        sum of ``X`` over the selected genes. ``obs``, ``uns`` and ``obsm``
        are taken from the subset copy; ``var`` annotations and layers are
        not carried over.
    """
    adata_ = adata[:, adata.var_names.isin(gene_list)].copy()
    # ``sum(axis=1)`` returns an (n_obs, 1) ``np.matrix`` for scipy sparse
    # input but a flat (n_obs,) vector for a dense ndarray. Normalise both
    # to a plain 2-D ndarray so X is always a well-formed (n_obs, 1) matrix.
    summed = np.asarray(adata_.X.sum(axis=1)).reshape(-1, 1)
    adata_enrich = AnnData(X=summed,
                           obs=adata_.obs,
                           uns=adata_.uns,
                           obsm=adata_.obsm)
    adata_enrich.var_names = enrich_name
    return adata_enrich

In [None]:
from scipy import stats

def plot_correlation(df, attr_1, attr_2):
    """Plot ``attr_2`` against ``attr_1`` with a regression fit and an R^2 label.

    Parameters
    ----------
    df
        Table containing both columns.
    attr_1, attr_2
        Column names used for the x and y axes respectively.

    Returns
    -------
    The seaborn grid object produced by ``sns.lmplot``.
    """
    x_values = df[attr_1]
    y_values = df[attr_2]

    # Squared Pearson correlation coefficient of the two columns.
    r_squared = stats.pearsonr(x_values, y_values)[0] ** 2

    grid = sns.lmplot(x=attr_1, y=attr_2, data=df, height=5, legend=True)
    grid.set_axis_labels(attr_1, attr_2)

    # The label is anchored (in data coordinates) at 90 % of each column's
    # maximum, drawn on the current pyplot axes.
    label_xy = (max(x_values) * 0.9, max(y_values) * 0.9)
    plt.annotate(r'$R^2:{0:.2f}$'.format(r_squared), label_xy)
    return grid

In [None]:
# Root directory of the 10x Visium breast-cancer dataset and scratch
# directory that receives the per-spot image tiles.
BASE_PATH = Path("/clusterdata/uqxtan9/Xiao/STimage/dataset/breast_cancer_10x_visium")
TILE_PATH = Path("/tmp") / "tiles"
TILE_PATH.mkdir(parents=True, exist_ok=True)


def load_visium_sample(sample, section_prefix, base_path=BASE_PATH):
    """Read one sample with ``st.Read10X`` and attach its full-resolution image.

    Parameters
    ----------
    sample
        Sub-directory of ``base_path`` holding the sample; also used as the
        ``library_id``.
    section_prefix
        Common file-name prefix of the count matrix
        (``<prefix>_filtered_feature_bc_matrix.h5``) and of the image
        (``<prefix>_image.tif``).
    base_path
        Dataset root directory.

    Returns
    -------
    The object returned by ``st.Read10X`` with the TIFF stored under
    ``uns["spatial"][sample]["images"]["fulres"]``.
    """
    sample_dir = base_path / sample
    adata = st.Read10X(sample_dir,
                       library_id=sample,
                       count_file=section_prefix + "_filtered_feature_bc_matrix.h5",
                       quality="fulres")
    # The second positional argument of ``plt.imread`` is ``format``; the 0 is
    # kept as-is. NOTE(review): presumably it only serves to skip
    # extension-based format detection - confirm before removing.
    image = plt.imread(sample_dir / (section_prefix + "_image.tif"), 0)
    adata.uns["spatial"][sample]["images"]["fulres"] = image
    return adata


Sample1 = load_visium_sample("block1", "V1_Breast_Cancer_Block_A_Section_1")
Sample2 = load_visium_sample("block2", "V1_Breast_Cancer_Block_A_Section_2")

In [None]:
# Individual genes of interest. Note that `gene_list` is rebound to the
# single "KRT_enrich" feature in a later cell before any model code reads it.
gene_list = [
    "SLITRK6",
    "PGM5",
    "LINC00645",
    "TTLL12",
    "COX6C",
    "CPB1",
    "KRT5",
    "MALAT1",
]
gene_list

In [None]:
# Expression pre-processing and tile extraction for both samples.
# The return values of the `st.pp` calls are not captured, so this cell relies
# on them updating `adata` in place; re-running it would therefore apply
# log1p a second time.
for adata in (Sample1, Sample2):
    st.pp.filter_genes(adata, min_cells=3)
    st.pp.log1p(adata)

    # pre-processing for spot image: tiles go into a sub-directory named after
    # the first key of `uns["spatial"]` (the library id set at load time).
    TILE_PATH_ = TILE_PATH / next(iter(adata.uns["spatial"].keys()))
    TILE_PATH_.mkdir(parents=True, exist_ok=True)
    tiling(adata, TILE_PATH_, crop_size=299)

In [None]:
# Per-sample gene names matching the pattern "KRT.*" (DataFrame.filter applies
# the regex as a search, i.e. anywhere in the name), then the names present in
# both samples.
gene_list_1, gene_list_2 = (
    sample.to_df().filter(regex="KRT.*").columns
    for sample in (Sample1, Sample2)
)
gene_list_share = gene_list_1.intersection(gene_list_2)

In [None]:
# From here on `gene_list` names the single aggregated feature that the model
# predicts (it replaces the list of individual genes defined earlier).
gene_list = pd.Index(["KRT_enrich"])

In [None]:
# Replace each sample by its one-feature version: the sum over the shared
# "KRT.*" genes, stored under the name held in `gene_list`.
# The names Sample1/Sample2 are rebound, so re-running this cell would feed
# the already-enriched objects back into enrich_group.
Sample1, Sample2 = [
    enrich_group(sample, gene_list_share, gene_list)
    for sample in (Sample1, Sample2)
]

In [None]:
n_genes = len(gene_list)

# Hold out 30 % of Sample1's spots for validation (fixed seed for
# reproducibility) and train on the remaining 70 %.
# Previously the training set was the whole of Sample1, so every validation
# spot was also a training spot and `val_loss` (which drives EarlyStopping)
# was measured on data the model had already been fitted to.
valid_index = Sample1.obs.sample(frac=0.3, random_state=1).index
valid_dataset = Sample1[valid_index,].copy()

is_valid = Sample1.obs.index.isin(valid_index)
training_dataset = Sample1[~is_valid,].copy()

# The second section is kept entirely unseen for testing.
test_dataset = Sample2.copy()

In [None]:
def _make_tile_dataset(adata, **generator_kwargs):
    """Wrap ``DataGenerator_LSTM_one_output`` for ``adata`` in a ``tf.data.Dataset``.

    Extra keyword arguments are forwarded to the generator (e.g. ``aug=False``).
    Each element is declared as a float32 ``[7, 299, 299, 3]`` input tensor
    plus a tuple of ``n_genes`` float32 targets of shape ``[1]``.
    """
    return tf.data.Dataset.from_generator(
        lambda: DataGenerator_LSTM_one_output(adata=adata,
                                              genes=gene_list,
                                              **generator_kwargs),
        output_types=(tf.float32, tuple([tf.float32] * n_genes)),
        output_shapes=([7, 299, 299, 3], tuple([1] * n_genes)),
    )


train_gen = _make_tile_dataset(training_dataset, aug=False)
valid_gen = _make_tile_dataset(valid_dataset)
test_gen = _make_tile_dataset(test_dataset)

# Cache BEFORE shuffling/batching: a cache placed after `shuffle().batch()`
# stores the first epoch's shuffled batches and replays exactly the same
# order and batch composition in every later epoch.
train_gen_ = (
    train_gen
    .cache()
    .shuffle(buffer_size=100)
    .batch(8)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Validation loss does not depend on element order, so no shuffle is needed.
valid_gen_ = (
    valid_gen
    .cache()
    .batch(8)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Test data is fed one spot at a time, in generator order.
test_gen_ = test_gen.batch(1)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Dropout, Lambda, TimeDistributed, LSTM
from tensorflow.keras.models import Model


def CNN_LSTM_NB_multiple_genes(steps_tile_shape, n_genes):
    """Build and compile a ResNet50 + LSTM model with one output head per gene.

    Parameters
    ----------
    steps_tile_shape
        Shape of one input sample; the first entry is the sequence length and
        the remaining entries (``steps_tile_shape[1:]``) are the shape of a
        single image tile passed to ResNet50.
    n_genes
        Number of output heads to create.

    Returns
    -------
    A compiled Keras ``Model`` whose outputs are named ``gene_0`` ...
    ``gene_{n_genes-1}``, using ``negative_binomial_loss`` and Adam with a
    learning rate of 0.0001.
    """
    tile_input = Input(shape=steps_tile_shape, name="steps_tile_input")

    # Per-tile encoder: ImageNet-initialised ResNet50 without its top
    # classifier, followed by global average pooling. No layer is frozen.
    backbone = ResNet50(input_shape=steps_tile_shape[1:], weights='imagenet', include_top=False)
    pooled_features = GlobalAveragePooling2D()(backbone.output)
    tile_encoder = Model(inputs=backbone.input, outputs=pooled_features)

    # Encode every tile of the sequence, then summarise the sequence.
    encoded_frames = TimeDistributed(tile_encoder)(tile_input)
    encoded_sequence = LSTM(256)(encoded_frames)

    def _gene_head(gene_idx):
        # Two values per gene, transformed by `negative_binomial_layer`.
        raw_params = Dense(2)(encoded_sequence)
        return Lambda(negative_binomial_layer, name="gene_{}".format(gene_idx))(raw_params)

    gene_heads = [_gene_head(gene_idx) for gene_idx in range(n_genes)]

    model = Model(inputs=[tile_input], outputs=gene_heads)
    adam = tf.keras.optimizers.Adam(0.0001)
    model.compile(loss=negative_binomial_loss, optimizer=adam)
    return model

In [None]:
# Model over sequences of seven 299x299 RGB tiles, one output head per gene.
model = CNN_LSTM_NB_multiple_genes(steps_tile_shape=(7, 299, 299, 3),
                                   n_genes=n_genes)

# Stop when `val_loss` has not improved for 20 epochs. The weights from the
# final epoch are kept (best weights are not restored).
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=20,
    restore_best_weights=False,
)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the model.
model.summary()

In [None]:
# Train for at most 50 epochs, validating after each one; `callback` may end
# training earlier. The returned History object is kept for inspection.
train_history = model.fit(
    train_gen_,
    validation_data=valid_gen_,
    epochs=50,
    callbacks=[callback],
)

In [None]:
# model.save("./LSTM_CNN.h5")

In [None]:
# Persist the trained model to an HDF5 file in the working directory.
model.save("./LSTM_CNN_ft.h5")

In [None]:
# Reload the weights from the file written by the `model.save` cell above.
# This previously loaded "./LSTM_CNN.h5", which nothing in this notebook
# creates: a fresh run failed here, and a stale file silently replaced the
# weights that had just been trained.
model.load_weights("./LSTM_CNN_ft.h5")

In [None]:
from scipy.stats import nbinom

test_predictions = model.predict(test_gen_)

# `predict` may hand back a list of arrays (several output heads) or a single
# array (one head); treat both uniformly as a list of per-gene outputs.
if isinstance(test_predictions, (list, tuple)):
    head_outputs = list(test_predictions)
else:
    head_outputs = [test_predictions]

# Column 0 / column 1 of each head are used as the (n, p) parameters of a
# negative binomial; its mean is taken as the predicted expression.
y_preds = [nbinom.mean(head[:, 0], head[:, 1]) for head in head_outputs]

# Always store a 2-D (n_spots, n_genes) matrix, including for a single gene,
# so the `obsm` entry has the same layout regardless of `n_genes`.
test_dataset.obsm["predicted_gene"] = np.column_stack(y_preds)

In [None]:
# Copy the test sample restricted to the modelled feature(s) and overwrite its
# expression matrix with the predictions, leaving `test_dataset` itself with
# the measured values (presumably so both can be plotted side by side).
test_dataset_ = test_dataset[:,gene_list].copy()
test_dataset_.X = test_dataset_.obsm["predicted_gene"]

In [None]:
# gene_plot(test_dataset_, genes="KRT5", spot_size=8, vmin=0,vmax=3)

In [None]:
# gene_plot(test_dataset_, genes="TTLL12", spot_size=8, vmin=0,vmax=4)

In [None]:
# for i in gene_list:
#     print(i)
#     gene_plot(test_dataset_, genes=i, spot_size=8)

In [None]:
# for i in gene_list:
#     print(i)
#     gene_plot(test_dataset, genes=i, spot_size=8)