In [None]:
import numpy as np
import pandas as pd
from pathlib import Path

import anndata as ad
import scanpy as sc
import stlearn as st
import matplotlib.pyplot as plt
import pickle

# Import Data

In [None]:
# Presumably the base directory for this project's Visium outputs (inferred from the name).
# NOTE(review): no other cell visible here references OUT_DIR — confirm it is still needed.
OUT_DIR = "/QRISdata/Q1851/Andrew_C/Pfizer/Visium/"

In [None]:
def read_pfizer_visium(sample):
    """Load one Visium sample from its ``outs/`` directory.

    Parameters
    ----------
    sample : str
        Sample folder name (e.g. ``"VLP78_A"``); it is spliced into the
        hard-coded RAW_DATA path below.

    Returns
    -------
    The object returned by ``sc.read_visium`` with ``obsm['spatial']`` cast
    to an integer dtype.
    """
    PATH = "/QRISdata/Q2051/Pfizer/Visium/RAW_DATA/Pfizer/Python/" + sample + "/outs/"
    adata = sc.read_visium(PATH)
    # `np.int` was only an alias for the builtin `int`; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so use the builtin directly.
    adata.obsm['spatial'] = adata.obsm['spatial'].astype(int)
    return adata

In [None]:
# Visium samples included in this analysis.
sample_list = [
    "VLP78_A", "VLP78_D",
    "VLP79_A", "VLP79_D",
    "VLP80_A", "VLP80_D",
    "VLP81_A",
    "VLP82_A", "VLP82_D",
    "VLP83_A", "VLP83_D",
]

# Map each sample name to its freshly loaded dataset.
data_dic = {name: read_pfizer_visium(name) for name in sample_list}
    

In [None]:
# Attach label-transfer results to a sample.
def add_spot_annotations(adata, sample):
    """Add the per-spot ``predicted.id`` labels as ``obs["Cell Types"]``.

    The labels are read from ``<sample>_label_transfer.csv`` and indexed by
    the CSV's ``"Unnamed: 0"`` column, so the assignment aligns on the
    ``adata.obs`` index. ``adata`` is modified in place (new obs column);
    the returned object is ``adata`` subset to the spots whose label is not
    missing.
    """
    csv_path = "/QRISdata/Q1851/Andrew_C/Pfizer/Visium/" + sample + "/label_transfer/" + sample + "_label_transfer.csv"
    labels = pd.read_csv(csv_path).set_index("Unnamed: 0")
    adata.obs["Cell Types"] = labels["predicted.id"]
    has_label = adata.obs['Cell Types'].notna()
    return adata[has_label]

In [None]:
# Replace every loaded dataset with its annotated, label-filtered version.
# Iterating over a snapshot of the keys keeps the reassignment explicit.
for sample_name in list(data_dic):
    data_dic[sample_name] = add_spot_annotations(data_dic[sample_name], sample_name)
    

In [None]:
# NOTE(review): pickle.load executes arbitrary code from the file — only use
# with a trusted pickle.
gt_pickle_path = '/QRISdata/Q2051/Jacky/visium_adata_gt.pkl'
with open(gt_pickle_path, 'rb') as handle:
    data_dict = pickle.load(handle)

# Merge the samples loaded above (data_dic) into the unpickled mapping
# (data_dict); entries from data_dic win on key collisions.
data_dict.update(data_dic)

# Process Data

In [None]:
def process_visium(adata):
    """Run the preprocessing steps on ``adata`` and return the same object.

    Steps, in order: filter cells (``min_counts=10``), filter genes
    (``min_cells=3``), normalise totals, log1p-transform, scale with
    ``max_value=10``. The return values of the scanpy calls are discarded,
    so the function relies on them modifying ``adata`` in place; pass a
    copy to keep the input untouched.
    """
    steps = (
        (sc.pp.filter_cells, {"min_counts": 10}),
        (sc.pp.filter_genes, {"min_cells": 3}),
        (sc.pp.normalize_total, {}),
        (sc.pp.log1p, {}),
        (sc.pp.scale, {"max_value": 10}),
    )
    for step, kwargs in steps:
        step(adata, **kwargs)
    return adata

In [None]:
# Preprocess a copy of every dataset so the entries of data_dict stay unmodified.
normalised_data_dic = {
    name: process_visium(dataset.copy())
    for name, dataset in data_dict.items()
}

# Cluster Data 

In [None]:
def cluster_visium(adata):
    """Compute PCA, neighbours, UMAP and Leiden clusters on ``adata``.

    The scanpy calls' return values are discarded, so results are expected
    to be stored on ``adata`` itself, which is then returned.
    """
    leiden_resolution = 0.5

    # Dimensionality reduction, then the neighbour graph built on 30 PCs.
    sc.tl.pca(adata, svd_solver="arpack")
    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)

    # Embedding and clustering.
    sc.tl.umap(adata)
    sc.tl.leiden(adata, resolution=leiden_resolution)
    return adata

In [None]:
# Cluster every preprocessed dataset and store the result under the same key.
for sample_name in list(normalised_data_dic):
    normalised_data_dic[sample_name] = cluster_visium(normalised_data_dic[sample_name])

In [None]:
def plot_data(adata, sample, data_to_plot, UMAP = False, dont_show = True):
    """Draw a spatial plot of ``data_to_plot``, optionally next to a UMAP.

    Parameters
    ----------
    adata
        Dataset passed straight to ``sc.pl.umap`` / ``sc.pl.spatial``.
    sample
        Sample identifier; only used to build the panel titles.
    data_to_plot
        Key passed as ``color`` to the scanpy plotting functions.
    UMAP : bool
        When true, draw a two-panel figure (UMAP left, spatial right);
        otherwise draw the spatial plot only.
    dont_show : bool
        When true the figure is built and closed without being displayed;
        when false it is displayed before being closed.

    Returns
    -------
    None
    """
    label = str(data_to_plot)
    spatial_title = str(sample) + ":  " + label + " Spatial Plot"

    if UMAP:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
        # show=False on both panels: the figure must stay open until both
        # axes are drawn, so display/closing is handled once at the end.
        sc.pl.umap(adata, color=[data_to_plot], size=10, ax=ax1, show=False,
                   title=str(sample) + ": " + label + " UMAP")
        sc.pl.spatial(adata, color=[data_to_plot], size=1.3, ax=ax2, show=False,
                      title=spatial_title)
    else:
        sc.pl.spatial(adata, color=[data_to_plot], size=1.3, show=False,
                      title=spatial_title)
        # scanpy created the figure itself; grab it so it can be closed below.
        fig = plt.gcf()

    if not dont_show:
        plt.show()
    # Always release the figure so looping over many samples does not
    # accumulate open figures.
    plt.close(fig)
    
    

In [None]:
# Leiden clusters for every sample: UMAP and spatial panels side by side.
for sample_name, clustered in normalised_data_dic.items():
    plot_data(adata=clustered, sample=sample_name, data_to_plot="leiden", UMAP=True)

# Look at Target Genes

In [None]:
# Candidate gene names to visualise. plot_gene_list keeps only the entries
# that are present in a dataset's var_names, so names missing from the data
# are silently skipped.
# NOTE(review): the list appears to mix receptor aliases with gene symbols
# (e.g. "MC4" next to "MC4R") — confirm which spellings the data actually uses.
GENE_LIST = ["OTR", "V1AR", "V1BR", "V2R",
               "AVPR1A", "AVPR1B","AVPR3", "AVPR2",
               "NK1",
               "NTSR1", "SORT1",
               "MCHR1", "MCHR2","GPR145",
               "PAC1", "PAC1R", "VPAC1","VPAC1R","VPAC2","VPAC2R",
               "NPY1R","NPY2R","PPYR1","NPY5R",
               "MC3", "MC4", "MC5","MC3R", "MC4R", "MC5R",
               "MC1", "MC1R",
               "SST","SSTR1","SSTR2",
               "OX1","OX2","HCRT","HCRTR1","HCRTR2"]


In [None]:
def plot_gene_list(data, sample):
    """Draw spatial expression panels for every GENE_LIST entry found in ``data``.

    Parameters
    ----------
    data
        Dataset whose ``var_names`` are checked against ``GENE_LIST`` and
        which is passed to ``sc.pl.spatial``.
    sample
        Sample identifier. Currently unused: it was only referenced by a
        disabled ``savefig`` call; kept so existing callers keep working.

    Returns
    -------
    None. The figure is built with ``show=False`` and closed immediately.
    """
    gene_list_to_plot = [gene for gene in GENE_LIST if gene in data.var_names]
    if not gene_list_to_plot:
        # None of the target genes is in this dataset: skip the scanpy call
        # rather than requesting a figure with zero panels.
        return

    with plt.rc_context():
        sc.pl.spatial(data, color=gene_list_to_plot, show=False, alpha_img=0.5,
                      use_raw=False, vmin=0, vmax=2)
        plt.close()


In [None]:
# Build the target-gene panels for every clustered sample.
for sample, data in normalised_data_dic.items():
    plot_gene_list(data=data, sample=sample)