In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import anndata as ad
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt

In [None]:
# Per-sample AnnData objects, keyed by sample name.
data_dic = dict()

## Import and Load Data 

In [None]:
# Output root (must keep the trailing "/": later cells build paths by string concatenation).
OUT_DIR: str = "/QRISdata/Q1851/Andrew_C/Pfizer/Visium/"

In [None]:
### Read Visium Data
def read_visium(sample, base_dir="/QRISdata/Q2051/Pfizer/Visium/RAW_DATA/Pfizer/Python/"):
    """Load one Space Ranger Visium sample into an AnnData object.

    Parameters
    ----------
    sample : str
        Sample folder name (e.g. "VLP78_A"); data are read from
        ``base_dir + sample + "/outs/"``.
    base_dir : str, optional
        Directory holding one sub-folder per sample. Must end with "/"
        because the path is built by string concatenation.

    Returns
    -------
    AnnData
        The object returned by ``sc.read_visium`` with ``obsm['spatial']``
        cast to integer.
    """
    path = base_dir + sample + "/outs/"
    adata = sc.read_visium(path)
    # ``np.int`` was only an alias for the builtin ``int`` and no longer exists
    # in NumPy >= 1.24 (AttributeError); ``int`` yields the identical dtype.
    adata.obsm['spatial'] = adata.obsm['spatial'].astype(int)
    return adata


In [None]:
# Samples to be used (one Visium sample folder per entry).
sample_list = [
    "VLP78_A", "VLP78_D",
    "VLP79_A", "VLP79_D",
    "VLP80_A", "VLP80_D",
    "VLP81_A",
    "VLP82_A", "VLP82_D",
    "VLP83_A", "VLP83_D",
]

# Load every sample, keyed by its name.
data_dic = {name: read_visium(name) for name in sample_list}
    

## Run Preprocessing

In [None]:
def process_visium(adata, min_counts=10, min_cells=3, max_value=10):
    """Filter, normalise, log-transform and scale a Visium AnnData in place.

    Parameters
    ----------
    adata : AnnData
        Object to process. Every scanpy call modifies it in place, so pass
        ``adata.copy()`` to keep the input untouched.
    min_counts : int, optional
        Passed to ``sc.pp.filter_cells``; spots below this total count are dropped.
    min_cells : int, optional
        Passed to ``sc.pp.filter_genes``; genes seen in fewer spots are dropped.
    max_value : float, optional
        Clip value passed to ``sc.pp.scale``.

    Returns
    -------
    AnnData
        The same ``adata`` object, returned for chaining.
    """
    sc.pp.filter_cells(adata, min_counts=min_counts)
    sc.pp.filter_genes(adata, min_cells=min_cells)
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.scale(adata, max_value=max_value)
    return adata

In [None]:
# Normalise a copy of each sample so the raw counts held in data_dic are left untouched.
normalised_data_dic = {
    name: process_visium(raw_adata.copy())
    for name, raw_adata in data_dic.items()
}

## Run and Plot Clustering

In [None]:
def cluster_visium(adata, resolution=0.5, n_neighbors=10, n_pcs=30):
    """Run PCA, kNN graph, UMAP and Leiden clustering on ``adata`` in place.

    Parameters
    ----------
    adata : AnnData
        Object to cluster; results are written into it by scanpy.
    resolution : float, optional
        Leiden resolution (higher gives more clusters).
    n_neighbors : int, optional
        Neighbours per spot for ``sc.pp.neighbors``.
    n_pcs : int, optional
        Number of principal components used to build the neighbour graph.

    Returns
    -------
    AnnData
        The same ``adata`` object, returned for chaining.
    """
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)
    sc.tl.umap(adata)
    sc.tl.leiden(adata, resolution=resolution)
    return adata

In [None]:
# NOTE(review): this clusters the objects in data_dic (raw counts), not the
# normalised/scaled copies in normalised_data_dic built above. Confirm this is intended.
for sample in list(data_dic):
    data_dic[sample] = cluster_visium(data_dic[sample])

In [None]:
def plot_data(adata, sample, data_to_plot, UMAP = False):
    """Plot one obs/var key of a sample spatially, optionally next to its UMAP.

    Parameters
    ----------
    adata : AnnData
        Sample to plot.
    sample : object
        Sample name; only used (via ``str``) in plot titles.
    data_to_plot : str
        Key passed as ``color`` to the plotting functions.
    UMAP : bool, optional
        When equal to True, draw a 1x2 figure (UMAP left, spatial scatter on
        the low-res image right). Otherwise draw only the spatial scatter.
    """
    label = f"{sample!s}: {data_to_plot!s}"
    show_umap = UMAP == True

    if not show_umap:
        sq.pl.spatial_scatter(
            adata,
            color=data_to_plot,
            size=1.3,
            figsize=(10, 10),
            title=label + " Spatial Plot",
        )
        return

    _, (umap_ax, spatial_ax) = plt.subplots(1, 2, figsize=(20, 10))
    # Left panel: UMAP embedding coloured by the requested key.
    sc.pl.umap(
        adata,
        color=[data_to_plot],
        size=10,
        ax=umap_ax,
        show=False,
        title=label + " UMAP",
    )
    # Right panel: spatial layout over the low-resolution tissue image.
    sq.pl.spatial_scatter(
        adata,
        color=data_to_plot,
        size=1.3,
        figsize=(10, 10),
        ax=spatial_ax,
        title=label + " Spatial Plot",
        img_res_key="lowres",
    )




In [None]:
# Leiden clusters per sample: UMAP beside the spatial layout.
for sample_name, sample_adata in data_dic.items():
    plot_data(sample_adata, sample_name, "leiden", UMAP=True)

## Add Label Transfer Data 

In [None]:
#add label transfer
def add_spot_annotations(adata, sample, base_dir="/QRISdata/Q1851/Andrew_C/Pfizer/Visium/"):
    """Attach label-transfer predictions to spots and keep only annotated spots.

    Reads ``<base_dir><sample>/label_transfer/<sample>_label_transfer.csv``.
    Its first column is used as the index and must match ``adata.obs_names``.
    Its ``predicted.id`` column is the transferred label.

    Parameters
    ----------
    adata : AnnData
        Sample to annotate. Side effect: the input's ``obs`` gains a
        ``label_transfer`` column.
    sample : str
        Sample name used to build the CSV path.
    base_dir : str, optional
        Root directory; must end with "/".

    Returns
    -------
    AnnData
        A new (non-view) AnnData containing only spots that received a label.
    """
    csv_path = base_dir + sample + "/label_transfer/" + sample + "_label_transfer.csv"
    # index_col=0 indexes by the first (unnamed) column, matching the former
    # set_index("Unnamed: 0") without depending on pandas' generated name.
    labels = pd.read_csv(csv_path, index_col=0)
    # Assignment aligns on the index; spots absent from the CSV become NaN.
    adata.obs["label_transfer"] = labels["predicted.id"]
    # .copy() materialises the subset. Later cells write neighbour graphs and
    # statistics into the result, which on an AnnData view triggers implicit copies.
    return adata[adata.obs["label_transfer"].notna()].copy()

In [None]:
# Replace each sample by its label-annotated subset.
for sample in list(data_dic):
    data_dic[sample] = add_spot_annotations(data_dic[sample], sample)

## Run Neighborhood Analysis between Clusters and Cell Type Labels

In [None]:
def run_neighborhood(adata, data_to_cluster):
    """Build a generic spatial-neighbour graph, then run neighbourhood enrichment.

    Parameters
    ----------
    adata : AnnData
        Sample; squidpy writes the graph and enrichment results into it.
    data_to_cluster : str
        ``adata.obs`` key holding the grouping to test (e.g. "leiden").

    Returns
    -------
    AnnData
        The same ``adata`` object.
    """
    graph_opts = {"coord_type": "generic", "spatial_key": "spatial"}
    sq.gr.spatial_neighbors(adata, **graph_opts)
    sq.gr.nhood_enrichment(adata, cluster_key=data_to_cluster)
    return adata


def plot_neighborhood(adata, data_to_plot):
    """Draw the neighbourhood-enrichment heatmap for ``adata.obs[data_to_plot]``."""
    heatmap_style = dict(
        method="average",
        cmap="inferno",
        vmin=-50,
        vmax=100,
        figsize=(10, 10),
    )
    sq.pl.nhood_enrichment(adata, cluster_key=data_to_plot, **heatmap_style)


In [None]:
for sample in list(data_dic):
    annotated = data_dic[sample]
    # Enrichment for both groupings (each call rebuilds the same spatial graph).
    for cluster_key in ("leiden", "label_transfer"):
        annotated = run_neighborhood(annotated, cluster_key)
    data_dic[sample] = annotated

    # Side-by-side heatmaps: Leiden clusters (left) vs transferred labels (right).
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    for ax, cluster_key in zip(axes, ("leiden", "label_transfer")):
        sq.pl.nhood_enrichment(
            annotated,
            cluster_key=cluster_key,
            method="average",
            cmap="inferno",
            vmin=-50,
            vmax=100,
            figsize=(10, 10),
            ax=ax,
            title=str(sample) + ": '" + cluster_key + "' NbrHood Enrichment",
        )
        
    

## Run Spatial Autocorrelation

In [None]:
def run_spatial_autocorrelation(adata):
    """Make gene names unique, then compute Moran's I for every gene.

    Later occurrences of a duplicated ``var_name`` get ``.1``, ``.2``, ...
    appended; the first occurrence keeps its name. Every gene can then be
    addressed by name in the Moran's I results.

    Parameters
    ----------
    adata : AnnData
        Sample to analyse. ``var_names`` is replaced, and
        ``sq.gr.spatial_autocorr`` stores its results in this object. It relies
        on the spatial-neighbour graph built earlier in this notebook by
        ``run_neighborhood``.

    Returns
    -------
    AnnData
        The same ``adata`` object.
    """
    # A set gives O(1) membership (the old list scan was quadratic in the gene
    # count). The counter loop keeps a 3rd+ occurrence from reusing ".1".
    seen = set()
    unique_names = []
    for gene in adata.var_names:
        candidate = gene
        suffix = 1
        while candidate in seen:
            candidate = f"{gene}.{suffix}"
            suffix += 1
        seen.add(candidate)
        unique_names.append(candidate)

    adata.var_names = unique_names

    # run autocorrelation
    sq.gr.spatial_autocorr(adata, mode="moran")
    return adata

In [None]:
# Moran's I for every gene of every sample.
for sample in list(data_dic):
    data_dic[sample] = run_spatial_autocorrelation(data_dic[sample])

In [None]:
## Plot auto correlation
def plot_autocorrelation(adata, sample, genes_to_plot, stat_summary = False):
    """Spatially plot selected genes, optionally printing their Moran's I rank.

    Parameters
    ----------
    adata : AnnData
        Sample with Moran's I results in ``adata.uns["moranI"]["I"]``
        (needed only when ``stat_summary`` is true).
    sample : str
        Sample name used in the printed summary.
    genes_to_plot : list of str
        Genes to colour the spatial scatter by.
    stat_summary : bool, optional
        If true, print each gene's 0-based position in the descending Moran's I
        ranking. Genes absent from the results are reported instead of raising.
    """
    if stat_summary:
        # Sort once (previously re-sorted per gene) and map gene -> first position.
        ranked_genes = adata.uns["moranI"]["I"].sort_values(ascending=False).index
        rank_of = {}
        for position, gene_name in enumerate(ranked_genes):
            rank_of.setdefault(gene_name, position)

        print("sample: " + sample)
        for gene in genes_to_plot:
            if gene in rank_of:
                print("Spatial Autocorrelation Rank of " + gene + ": " + str(rank_of[gene]))
            else:
                print("Spatial Autocorrelation Rank of " + gene + ": not found in moranI results")
        print("")
        print("Out of a total of " + str(len(ranked_genes)) + " genes")

    # plot
    sq.pl.spatial_scatter(adata, color=genes_to_plot, size=1.2, img=False, figsize=(5, 5))


In [None]:
### KAT6A/KAT6B genes
# KAT6A/KAT6B together with the other genes examined alongside them in this notebook.
KAT6_gene_list = ["KAT6A", "KAT6B", "ESR1", "BRPF1", "MEAF6", "ING5"]

for sample_name, sample_adata in data_dic.items():
    plot_autocorrelation(sample_adata, sample_name, KAT6_gene_list)

## Plot bivariate spatial correlation (LISA) between ["KAT6A","KAT6B"] and ["ESR1","BRPF1","MEAF6","ING5","KAT6B"]

In [None]:
import geopandas as gpd
from geopandas import GeoDataFrame
from libpysal.weights import Queen
from esda.moran import (Moran, Moran_BV, Moran_Local, Moran_Local_BV)
from splot.esda import lisa_cluster 


In [None]:
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides AnnData/scanpy/squidpy warnings in all later cells.
import warnings
warnings.simplefilter("ignore")

# Grid rows (gene_list1) and columns (gene_list2) for the bivariate LISA plots in gene_correlation.
gene_list1 = ["KAT6A", "KAT6B"]
gene_list2 = ["ESR1", "BRPF1", "MEAF6", "ING5", "KAT6B"]

def gene_correlation(adata, sample, OUT_DIR, plot_save_name, image_res="lowres",
                     genes_x=None, genes_y=None):
    """Draw a grid of bivariate local-Moran (LISA) cluster maps and save it as a PDF.

    For each pair (gene1 from ``genes_x``, gene2 from ``genes_y``), a
    ``Moran_Local_BV`` statistic is computed over spot points. It is drawn with
    ``splot.esda.lisa_cluster`` on top of the tissue image.

    Parameters
    ----------
    adata : AnnData
        Visium sample. Gains ``obs['imagecol']``, ``obs['imagerow']`` and
        ``obsm['gpd']`` (a GeoDataFrame of spot points in image-pixel space).
    sample : str
        Sample name. It is used in the output path and is the preferred key
        into ``adata.uns['spatial']`` (falls back to the first key).
    OUT_DIR : str
        Output root ending in "/". The PDF is written to
        ``OUT_DIR + sample + "/" + sample + "_" + plot_save_name + "_spatial_correlation.pdf"``
        and the directory is created if needed.
    plot_save_name : str
        Tag inserted into the output file name.
    image_res : str, optional
        Image key and scale-factor suffix (``"tissue_<image_res>_scalef"``).
    genes_x, genes_y : list of str, optional
        Grid rows / columns. Default to the notebook globals ``gene_list1`` /
        ``gene_list2``.

    Returns
    -------
    AnnData
        The same ``adata`` object.
    """
    genes_x = gene_list1 if genes_x is None else genes_x
    genes_y = gene_list2 if genes_y is None else genes_y

    # Use one library entry for both the scale factor and the image.
    # Previously `sample` was used for one and the first key for the other,
    # which breaks when the library id differs from the sample name.
    spatial_meta = adata.uns["spatial"]
    library_id = sample if sample in spatial_meta else next(iter(spatial_meta))
    library = spatial_meta[library_id]

    # Convert full-resolution spot coordinates into the chosen image's pixel space.
    scale = library["scalefactors"]["tissue_" + image_res + "_scalef"]
    image_coor = adata.obsm["spatial"] * scale
    adata.obs["imagecol"] = image_coor[:, 0]
    adata.obs["imagerow"] = image_coor[:, 1]

    adata.obsm["gpd"] = gpd.GeoDataFrame(
        adata.obs,
        geometry=gpd.points_from_xy(adata.obs.imagecol, adata.obs.imagerow),
    )

    # Loop invariants: weights, background image and per-gene vectors do not
    # depend on the gene pair, so build them once.
    # NOTE(review): Queen contiguity is defined on polygons; confirm it yields
    # meaningful neighbours for these point geometries (KNN/distance-band may fit better).
    w = Queen.from_dataframe(adata.obsm["gpd"])
    tissue_image = library["images"][image_res]
    # obs_vector reads one gene column without densifying the whole matrix (unlike to_df()).
    expr = {
        gene: np.asarray(adata.obs_vector(gene), dtype=np.float64)
        for gene in dict.fromkeys(list(genes_x) + list(genes_y))
    }

    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(len(genes_x), len(genes_y), figsize=(60, 20), squeeze=False)

    for i, gene1 in enumerate(genes_x):
        for j, gene2 in enumerate(genes_y):
            # Argument order kept from the original analysis: gene2 values first, gene1 second.
            # (The unused Moran / Moran_BV / Moran_Local statistics were dropped.)
            moran_loc_bv = Moran_Local_BV(expr[gene2], expr[gene1], w)

            # Plot the LISA cluster on the specified axis
            lisa_cluster(moran_loc_bv, adata.obsm["gpd"], p=0.05, markersize=10,
                         ax=ax[i, j], figsize=(10, 10))
            ax[i, j].imshow(tissue_image)
            ax[i, j].set_title(gene1 + " vs " + gene2, fontsize=20)

    Path(OUT_DIR + sample).mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_DIR + sample + "/" + sample + "_" + plot_save_name + "_spatial_correlation.pdf")
    plt.close(fig)

    return adata


In [None]:
# KAT6A/KAT6B bivariate LISA grid for every sample (one PDF written per sample).
auto_cor = {
    sample_name: gene_correlation(sample_adata, sample_name, OUT_DIR, plot_save_name="KAT6A")
    for sample_name, sample_adata in data_dic.items()
}

## Find Top Genes Spatially Correlated with KAT6A and KAT6B

In [None]:
# Spots of VLP82_D whose scaled ESR1 expression (output of sc.pp.scale in process_visium) exceeds the threshold.
ESR1_THRESHOLD = 0.4
vlp82d_norm = normalised_data_dic["VLP82_D"]
esr1_mask = np.asarray(vlp82d_norm.obs_vector("ESR1")) > ESR1_THRESHOLD
# .copy() so later cells can add graphs/results without mutating an AnnData view.
ESR1_positive_spots = vlp82d_norm[esr1_mask].copy()

# normalised_data_dic was created before label transfer and has no 'label_transfer'
# column, so copy the labels across from data_dic by spot name (unlabelled spots -> NaN).
ESR1_positive_spots.obs["label_transfer"] = (
    data_dic["VLP82_D"].obs["label_transfer"].reindex(ESR1_positive_spots.obs_names)
)

## Plot ESR1 gene expression and positive spots
sq.pl.spatial_scatter(vlp82d_norm, color="ESR1", size=1.2, img=False, figsize=(5, 5), cmap='Reds')
plot_data(ESR1_positive_spots, "VLP82_D", 'label_transfer', UMAP = False)

## Look at the Most Spatially Autocorrelated Genes within ESR1+ Spots (per-gene Moran's I)

In [None]:
#### Run Spatial Autocorrelation
# spatial_autocorr needs a spatial-neighbour graph. The normalised copies were made
# before run_neighborhood ran (and that only ran on data_dic), so build the graph here
# on the ESR1+ subset itself. The guard keeps re-running this cell idempotent.
if "spatial_connectivities" not in ESR1_positive_spots.obsp:
    sq.gr.spatial_neighbors(ESR1_positive_spots, coord_type="generic", spatial_key="spatial")
sq.gr.spatial_autocorr(ESR1_positive_spots, mode="moran")

num_view = 12  # number of highest-Moran's-I genes to plot below
top_autocorr = (
    ESR1_positive_spots.uns["moranI"]["I"].sort_values(ascending=False).head(num_view).index.tolist()
)


### Plot Top 12 Genes

In [None]:
# Spatial expression of the most autocorrelated genes within ESR1+ spots.
top_gene_style = dict(size=1, cmap="Reds", img=False, figsize=(5, 5))
sq.pl.spatial_scatter(ESR1_positive_spots, color=top_autocorr, **top_gene_style)

### Get Rank for KAT6A/KAT6B Associated Genes

In [None]:
# Per-gene Moran's I within the ESR1+ spots, highest first; printed positions are 0-based.
rank = ESR1_positive_spots.uns["moranI"]["I"].sort_values(ascending=False)

kat6_genes = set(KAT6_gene_list)
print("The Autocorrelation rank of genes to ESR1 (in ESR1+ spots)...")
print("")
for idx, gene in enumerate(rank.index):
    if gene in kat6_genes:
        print(gene + " is ranked at: " + str(idx))
print("")
print("Out of a total of: " + str(len(rank)) + " genes...")


In [None]:
# Gene names in descending Moran's I order (Series.keys() is the Series index).
rank.keys()

## Run Spatial Correlation for KAT6B and ESR1

In [None]:
import sys
import seaborn as sns

import spatialcorr

In [None]:
# Minimal AnnData for spatialcorr built from VLP82_D's expression matrix.
# NOTE(review): data_dic holds unnormalised counts; confirm that is the intended spatialcorr input.
test_data = ad.AnnData(X=data_dic['VLP82_D'].X)

In [None]:
vlp82d_annotated = data_dic['VLP82_D']
test_data.obs_names = vlp82d_annotated.obs_names
test_data.var_names = vlp82d_annotated.var_names

# Grid coordinates and condition labels read by the spatialcorr calls below
# (row_key='row', col_key='col', condition 'cluster').
for new_key, visium_key in (("row", "array_row"), ("col", "array_col")):
    test_data.obs[new_key] = vlp82d_annotated.obs[visium_key].astype(int)
test_data.obs['cluster'] = vlp82d_annotated.obs['label_transfer']

In [None]:
# Display the AnnData repr to check the spatialcorr input (obs should include 'row', 'col', 'cluster').
test_data

#### Test Spatial Data

In [None]:
# spatialcorr kernel diagnostics for VLP82_D, grouped by 'cluster', with a bandwidth of 5.
diagnostic_opts = dict(bandwidth=5, contrib_thresh=10, dsize=2, dpi=150)
spatialcorr.wrappers.kernel_diagnostics(test_data, 'cluster', **diagnostic_opts)

#### Run SpatialCorrelation on KAT6A/B Gene List

In [None]:
# Gene-set analysis of KAT6_gene_list across the 'cluster' regions:
# at most 100 permutations, dot size 5, no debugging output.
set_pipeline_opts = {"max_perms": 100, "dsize": 5, "verbose": 0}
spatialcorr.analysis_pipeline_set(test_data, KAT6_gene_list, 'cluster', **set_pipeline_opts)

In [None]:
# Local correlation estimates for the KAT6 gene set, grouped by 'cluster'.
local_plot_opts = {"estimate": "local", "dsize": 1}
spatialcorr.plot.mult_genes_plot_correlation(test_data, KAT6_gene_list, 'cluster', **local_plot_opts)


## Plot region-wide estimates of correlation

In [None]:
# estimate='regional' shows one cluster-wide estimate per 'cluster' region.
regional_plot_opts = {"estimate": "regional", "dsize": 1}
spatialcorr.plot.mult_genes_plot_correlation(test_data, KAT6_gene_list, 'cluster', **regional_plot_opts)

## Correlation between KAT6B and ESR1

In [None]:
# KAT6B vs ESR1 pair analysis: 'cluster' holds the conditions, kernel bandwidth 5,
# at most 100 permutations, dot size 5, no debugging output.
pair_opts = dict(cond_key='cluster', bandwidth=5, max_perms=100, dsize=5, verbose=0)
spatialcorr.analysis_pipeline_pair(test_data, 'KAT6B', 'ESR1', **pair_opts)

In [None]:
# Calculate the correlation at each spot
corrs, kept_inds = spatialcorr.compute_local_correlation(
    test_data, 
    'KAT6B', 
    'ESR1', 
    row_key='row', 
    col_key='col', 
    condition='cluster'
)

# Plot the scatterplot of expression values within a neighborhood centered at row 16, column 40.
spatialcorr.plot.plot_local_scatter(
    test_data[kept_inds,:], 
    'KAT6B', 
    'ESR1', 
    16, 
    40, 
    corrs,
    row_key='row', 
    col_key='col',
    cmap='RdBu_r',
    dsize=5,
    vmin=-1,
    vmax=1
)

## Plot scatterplot of gene correlation for each cell type

In [None]:
# KAT6B vs ESR1 expression scatter for each 'cluster' region.
kat6b_grid_keys = {"row_key": "row", "col_key": "col"}
spatialcorr.plot.region_scatterplots(test_data, 'KAT6B', 'ESR1', 'cluster', **kat6b_grid_keys)

In [None]:
# GATA3 vs ESR1 expression scatter for each 'cluster' region.
gata3_grid_keys = {"row_key": "row", "col_key": "col"}
spatialcorr.plot.region_scatterplots(test_data, 'GATA3', 'ESR1', 'cluster', **gata3_grid_keys)