In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib 
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
from sklearn import metrics
from sklearn.neighbors import NearestNeighbors
import tabulate
import itertools
from mpl_toolkits import mplot3d
import matplotlib.patches as mpatches
import umap
from importlib import reload
from scipy import stats
import matplotlib.image as mpimg
import gget

# local helper module (utils.py); reload() re-executes it so edits to the
# module are picked up without restarting the kernel
import utils as ut
reload(ut)

# scanpy logging verbosity; higher values print more log messages
sc.settings.verbosity = 3

In [None]:
clusterPath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/data/scanpy/clusters.csv"


def readTable(path):
    """Read a table stored either as Parquet or as CSV.

    The cluster table path ends in ``.csv`` but was previously read with
    ``pd.read_parquet``, which fails on a genuine CSV file. Parquet files
    begin with the 4-byte magic ``PAR1``, so sniff that header and dispatch
    to the matching pandas reader.

    Parameters
    ----------
    path : str
        Path to the table on disk.

    Returns
    -------
    pandas.DataFrame
    """
    with open(path, 'rb') as fh:
        isParquet = fh.read(4) == b'PAR1'
    return pd.read_parquet(path) if isParquet else pd.read_csv(path)


cf = readTable(clusterPath)
print(cf.shape)
print(cf['cellType'].unique())
cf.head()

In [None]:
dirPath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/data/SPT/"

# Map of sample name (sub-directory of dirPath) -> AnnData read from <sample>/outs
adata = {}

# os.listdir returns entries in arbitrary order; sort so adata's key order is reproducible
for f in sorted(os.listdir(dirPath)):
    samplePath = os.path.join(dirPath, f)
    # skip the "all" entry and any stray non-directory entries (read_visium needs a directory)
    if f == "all" or not os.path.isdir(samplePath):
        continue
    data = sc.read_visium(os.path.join(samplePath, "outs"))
    # Upper-case first, then de-duplicate: de-duplicating first and upper-casing
    # afterwards can re-create duplicates for symbols that differ only by case.
    data.var_names = [x.upper() for x in data.var_names]
    data.var_names_make_unique()
    adata[f] = data

adata.keys()

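# Filtering and normalization

Drop spots with fewer than `min_counts` total counts, then normalize, log-transform and scale each sample.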

In [None]:
keys = ['ND', 'HFD8', 'HFD14']

min_counts = 10

for key in keys:
    print(key)
    data = adata[key]
    # This cell modifies adata[key] in place. Re-running it would normalize,
    # log-transform and scale values that are already processed, so skip any
    # sample that sc.pp.log1p has already marked in data.uns['log1p'].
    if 'log1p' in data.uns:
        print(f"{key}: already processed, skipping")
        continue
    sc.pp.filter_cells(data, min_counts=min_counts)  # drop spots with fewer than min_counts total counts
    sc.pp.normalize_total(data, target_sum=1e4)  # Normalize each cell by total counts over all genes
    sc.pp.log1p(data)  # Logarithmize data via `X = \log(X + 1)`
    sc.pp.scale(data)  # unit variance and zero mean


In [None]:
def getImage(data, library_id=None, resolution='hires'):
    """Return a tissue image and the scale factor that maps spot coordinates onto it.

    Parameters
    ----------
    data : AnnData
        Object whose ``uns['spatial']`` maps library ids to dicts with
        ``'images'`` and ``'scalefactors'`` entries (as read by ``sc.read_visium``).
    library_id : str, optional
        Library to use. Defaults to the first key of ``uns['spatial']``,
        which was the only behaviour before this parameter existed.
    resolution : str, default 'hires'
        Key under ``images``; the matching factor is read from
        ``scalefactors['tissue_<resolution>_scalef']``.

    Returns
    -------
    (img, scale_factor)

    Raises
    ------
    ValueError
        If ``uns['spatial']`` is empty and no ``library_id`` is given.
    """
    sptData = data.uns['spatial']
    if library_id is None:
        if not sptData:
            raise ValueError("data.uns['spatial'] contains no libraries")
        library_id = next(iter(sptData))
    libData = sptData[library_id]

    img = libData['images'][resolution]
    scale_factor = libData['scalefactors'][f'tissue_{resolution}_scalef']
    return img, scale_factor


# Per-sample axis limits used to crop the spot plots; the spot coordinates are
# multiplied by the hires scale factor before plotting, so these are in
# hires-image pixels. Each ylim is [bottom, top] with bottom > top, which makes
# plt.ylim flip the y-axis so the plot matches image orientation.
trim = dict(
    ND={'xlim': [250, 1500], 'ylim': [1600, 650]},
    HFD8={'xlim': [190, 1350], 'ylim': [1680, 400]},
    HFD14={'xlim': [250, 1400], 'ylim': [1650, 450]},
)

print(trim)

In [None]:
# Sorted cell-type labels: this order fixes both the colour assignment and
# the index of each type in later per-type score lists.
ctypes = sorted(cf['cellType'].unique())
nTypes = len(ctypes)

# presumably one colour per type sampled from the colormap -- see utils.ncolor
colorList = ut.ncolor(nTypes, cmap='Spectral')
colorDict = {ctype: color for ctype, color in zip(ctypes, colorList)}

# attach each row's colour to the cluster table
cf['color'] = cf['cellType'].map(colorDict)
cf.head()

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 10, 10
plt.rcParams['figure.facecolor'] = 'w'

lft = 1        # marker genes need mean log fold-change > lft
nGene = 30     # keep at most nGene markers per cell type (most supporting clusters first)
alpha = 0.05   # marker genes need pvals <= alpha
q = 20         # per-spot expression percentile used to binarize each spot
layer = 6      # ascending rank of the Jaccard score to assign; with 7 cell types,
               # 6 is the best match (-1 selects the best match for any number of types)
showImage = False  # draw the hires tissue image behind the spots


def getMarkerMatrix(cf, key, ctypes, alpha=0.05, lft=1, nGene=30):
    """Binary gene x cell-type marker table for one sample.

    Keeps rows of ``cf`` for sample ``key`` with ``pvals <= alpha``, averages
    ``logfoldchanges`` per (gene, cellType), keeps pairs whose mean exceeds
    ``lft``, then retains up to ``nGene`` genes per cell type with the largest
    number of supporting clusters.

    Returns a DataFrame indexed by gene with one 0/1 column per entry of
    ``ctypes``; cell types with no surviving markers get an all-zero column
    (previously they were missing and caused a KeyError).
    """
    rf = cf[(cf['key'] == key) & (cf['pvals'] <= alpha)]
    rf = rf.groupby(['gene', 'cellType']).agg(
        clusterCount=('clusterId', 'count'),
        meanLFC=('logfoldchanges', 'mean'),
    ).reset_index()
    rf = rf[rf['meanLFC'] > lft].set_index('gene')
    if rf.empty:
        return pd.DataFrame(0.0, index=pd.Index([], name='gene'), columns=ctypes)

    rf = rf.groupby(['cellType'])['clusterCount'].nlargest(nGene).reset_index()
    rf['flag'] = 1.0
    markers = pd.pivot_table(rf, values='flag', index='gene', columns='cellType')
    return markers.reindex(columns=ctypes).fillna(0.0)


def assignSpotTypes(df, markers, ctypes, q=20, layer=6):
    """Label each spot by comparing its top genes to each cell type's marker set.

    A spot's genes strictly above its own ``q``-th expression percentile form
    a set that is scored against each marker set with the Jaccard index
    |spot & markers| / |spot | markers| (0 when the union is empty, as
    ``sklearn.metrics.jaccard_score`` returns). Cell types are ranked by
    ascending score and the one at rank ``layer`` is assigned.

    Parameters
    ----------
    df : DataFrame, spots x genes expression (e.g. ``AnnData.to_df()``).
    markers : DataFrame, gene x cell-type 0/1 table (see ``getMarkerMatrix``).
    ctypes : list of cell-type labels; fixes the score/column order.

    Returns
    -------
    list with one cell-type label per row of ``df``.
    """
    # Align markers to the expression genes (non-marker genes -> 0), columns in ctypes order.
    markerMat = markers.reindex(index=df.columns, columns=ctypes, fill_value=0.0).to_numpy() > 0
    expr = df.to_numpy()
    spotOn = expr > np.percentile(expr, q, axis=1, keepdims=True)  # spots x genes

    # Only marker genes can contribute to an intersection, so restrict the
    # matrix product to them instead of the full spots x genes matrix.
    isMarker = markerMat.any(axis=1)
    inter = spotOn[:, isMarker].astype(float) @ markerMat[isMarker].astype(float)
    union = spotOn.sum(axis=1, keepdims=True) + markerMat.sum(axis=0, keepdims=True) - inter
    scores = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

    # argsort is ascending, so the cell type at rank `layer` is order[:, layer].
    # (The previous code searched for `layer` inside the argsort result, i.e.
    # used the inverse permutation, and so returned the wrong cell type.)
    order = np.argsort(scores, axis=1)
    return [ctypes[i] for i in order[:, layer]]


# Legend entries are the same for every sample, so build them once.
handles = [mpatches.Patch(color=colorDict[ctype], ec='k', label=ctype) for ctype in ctypes]

for key in keys:
    data = adata[key]
    img, scale_factor = getImage(data)
    df = data.to_df()

    # Scale spot coordinates by the hires scale factor so they line up with the image.
    coords = data.obsm['spatial'] * scale_factor
    x = coords[:, 0]
    y = coords[:, 1]

    markers = getMarkerMatrix(cf, key, ctypes, alpha=alpha, lft=lft, nGene=nGene)
    spotTypes = assignSpotTypes(df, markers, ctypes, q=q, layer=layer)
    spotColor = [colorDict[t] for t in spotTypes]

    print(key)
    print(pd.Series(spotTypes).value_counts())

    fig, ax = plt.subplots()
    if showImage:
        ax.imshow(img.astype(float), interpolation='none', cmap='binary', alpha=0.7, zorder=1)
    ax.scatter(x, y,
               c=spotColor,
               marker="o",
               edgecolor="None",
               zorder=2,
               alpha=0.7,
               s=60)
    ax.set_xlim(trim[key]['xlim'])
    ax.set_ylim(trim[key]['ylim'])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(handles=handles, bbox_to_anchor=(1.04, 1.02))
    ax.set_title(key)
    plt.show()
    plt.close(fig)

In [None]:
# Sorted cell-type labels; this order defines colorDict and the per-type score order above
ctypes

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 10, 10
plt.rcParams['figure.facecolor'] = 'w'

expT = 0      # gene expression threshold
nSpots = 10   # spots per sample sent to Enrichr (one web request each); None = every spot
database = 'PanglaoDB_Augmented_2021'

# NOTE(review): the filtering cell z-scores every gene (sc.pp.scale), so
# `row > expT` with expT = 0 selects genes above their across-spot mean,
# not merely detected genes. Revisit if raw counts are what is intended.

for key in keys:
    df = adata[key].to_df()
    spotsToScore = df if nSpots is None else df.head(nSpots)

    spotType = []
    for idx, row in spotsToScore.iterrows():
        expressedGenes = df.columns[row > expT].tolist()
        if not expressedGenes:
            spotType.append(None)
            continue

        ef = gget.enrichr(expressedGenes, database=database)
        # Enrichr can return no terms (and gget may return None on failure);
        # record None instead of raising IndexError on an empty result.
        if ef is None or len(ef) == 0:
            spotType.append(None)
        else:
            spotType.append(ef['path_name'].iloc[0])

    print(key)
    print(pd.Series(spotType).value_counts(dropna=False))