<img style="float: middle;" src="../images/logo.png">

**This script has been generated by Feng Zhang and Onkar Mulay from the *Genomics and Machine Learning Lab*.** 

If you have any questions, please do not hesitate to contact us (feng.zhang@uq.edu.au or o.mulay@uq.edu.au).

- [1 Neighborhood Coordination](#Neighborhood-Coordination)<br>
- [2 Cell community identification](#2-Cell-community-identification)<br>
    - [2.1 Get the neighborhood cells](#Get-the-neighborhood-cells)<br>
    - [2.2 Calculate the proportion](#Count-the-number-of-each-cell-type-in-those-neighbors-for-each-cell)<br>
    - [2.3 Cluster as cell community](#Cluster-cells-into-different-cell-community)<br>
    - [2.4 Visualization](#plot)<br>

## Neighborhood Coordination
![neighbor](images/neighbor.jpg)

## 2 Cell community identification
### Import modules

In [None]:
from sklearn.cluster import MiniBatchKMeans
from sklearn.neighbors import NearestNeighbors
import os, time,sys
import numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns


In [None]:

def get_windows(job,n_neighbors):
    '''
    Find the nearest neighbors of the selected cells inside one tissue region.

    Relies on the module-level names `exps`, `tissue_group`, `X` and `Y`
    being defined before the call.

    job:  tuple of (start time of the whole run, index of the region in `exps`,
          region name, index labels of the cells to query in that region)
    n_neighbors:  the number of neighbors to find for each cell

    Returns an int32 array with one row per queried cell. Each row holds the
    index labels (taken from the region's dataframe index) of the
    `n_neighbors` nearest cells, ordered by increasing distance.
    '''
    start_time, idx, tissue_name, indices = job
    job_start = time.time()

    progress = f"{idx + 1}/{len(exps)}"
    region_label = f": {exps[idx]!s}"
    print("Starting:", progress, region_label)

    region_cells = tissue_group.get_group(tissue_name)
    query_xy = region_cells.loc[indices][[X, Y]].values
    region_xy = region_cells[[X, Y]].values

    # NOTE(review): the model is fitted on the same region the query cells come
    # from, so each cell is presumably returned as its own first neighbor.
    # An earlier variant asked for n_neighbors+1 -- confirm which is intended.
    knn = NearestNeighbors(n_neighbors=n_neighbors).fit(region_xy)
    distances, positions = knn.kneighbors(query_xy)

    # Reorder every row of neighbor positions by ascending distance, then map
    # the positions within the region back to dataframe index labels.
    order = distances.argsort(axis=1)
    sorted_positions = np.take_along_axis(positions, order, axis=1)
    neighbors = region_cells.index.values[sorted_positions]

    end_time = time.time()
    print("Finishing:", progress, region_label, end_time - job_start, end_time - start_time)
    return neighbors.astype(np.int32)


### Set path and load data


In [None]:
# Directory that is put at the front of the module search path so the
# project-local `voronoi` module can be imported (presumably voronoi.py lives
# in this directory -- adjust the path to your own environment).
# Alternative location used on another system:
#script_dir = '/scratch/project/stseq/Feng/package-vignette/scripts/topic/community'
script_dir = '/data/module2/data/'
sys.path.insert(0, script_dir)
from voronoi import draw_voronoi_scatter

In [None]:
# Default analysis parameters.
# Column names of the x / y spatial coordinates.
X, Y = 'sdimx', 'sdimy'
# Column holding the tissue region of each cell, and the column holding its cell type.
reg, cluster_col = 'fov', 'custom_clust'
# Metadata columns carried along into the per-window result tables.
keep_cols = [X, Y, reg, cluster_col]
# Window sizes to evaluate; k=5 means the 5 nearest neighbors are collected for each center cell.
ks = [5, 10, 20]
# Search once with the largest window; smaller windows are prefixes of it.
n_neighbors = max(ks)

In [None]:
# Prefer a copy of the table in the working directory; otherwise fall back to
# the shared data location.
csv_path = 'spatial.csv' if os.path.exists('spatial.csv') else '/data/module2/data/spatial.csv'
# reset_index() adds the previous row index as an extra 'index' column.
cells = pd.read_csv(csv_path).reset_index()

cells.head()

### Get the neighborhood cells

In [None]:
# Quick look at the table: dimensions, column names and the distinct cell types.
print(cells.shape, cells.columns, cells[cluster_col].unique(), sep='\n')

In [None]:
# Find the neighbor window of every cell, one tissue region at a time.

# Append one indicator column per cell type to the cell table.
cells = pd.concat([cells, pd.get_dummies(cells[cluster_col])], axis=1)
sum_cols = cells[cluster_col].unique()
# Indicator matrix: one row per cell, one column per cell type (in sum_cols order).
values = cells[sum_cols].values

tissue_group = cells[[X, Y, reg]].groupby(reg)
exps = list(cells[reg].unique())

# One job per region: (run start time, position of the region in exps,
# region name, index labels of the region's cells).
tissue_chunks = []
for region_name, row_labels in tissue_group.groups.items():
    # array_split with a single section keeps the region as one chunk.
    for piece in np.array_split(row_labels, 1):
        tissue_chunks.append((time.time(), exps.index(region_name), region_name, piece))

tissues = [
    get_windows(job, n_neighbors)
    for job in tissue_chunks
]

In [None]:
# Inspect the neighbor search result: number of chunks, then type, shape and
# first five rows of the first chunk.
first_window = tissues[0]
print(len(tissues))
print(type(first_window))
print(first_window.shape)
first_window[:5,]

### Count the number of each cell type in those neighbors for each cell

In [None]:
# For every window size and every region chunk, count how many of each cell's
# k nearest neighbors belong to each cell type.
out_dict = {}
n_types = len(sum_cols)
for k in ks:
    for nbrs, job in zip(tissues, tissue_chunks):
        tissue_name, indices = job[2], job[3]
        n_cells = len(nbrs)
        # Keep only the k closest neighbors of each cell and look up their
        # cell-type indicator rows.
        nearest = nbrs[:, :k].flatten()
        per_neighbor = values[nearest].reshape(n_cells, k, n_types)
        # Summing over the neighbor axis gives per-cell-type counts.
        counts = per_neighbor.sum(axis=1)
        out_dict[(tissue_name, k)] = (counts.astype(np.float16), indices)


In [None]:
# Preview the first five rows of the neighbor cell-type counts stored under the
# key (region, k) = (1, 10). Element 0 of each entry is the count matrix and
# element 1 the index labels of the cells.
# NOTE(review): assumes a region labelled 1 exists in the data -- confirm.
out_dict[(1, 10)][0][:5,]

In [None]:
# Stack the per-region count matrices into one dataframe per window size and
# attach the cell metadata columns.
windows = {}
for k in ks:
    per_region = []
    for exp in exps:
        counts, row_labels = out_dict[(exp, k)]
        per_region.append(
            pd.DataFrame(counts, index=row_labels.astype(int), columns=sum_cols)
        )
    # Put the rows back into the row order of the cell table.
    stacked = pd.concat(per_region, axis=0).loc[cells.index.values]
    windows[k] = pd.concat([cells[keep_cols], stacked], axis=1)


In [None]:
# Inspect the combined tables: container type, available window sizes, and the
# columns of the k=10 table, followed by the table itself.
preview = windows[10]
print(type(windows))
print(windows.keys())
print(preview.columns)
print(len(preview.columns))
preview

### Cluster cells into different cell community

In [None]:
# Clustering settings.
k = 10                                   # window size whose counts are clustered
n_neighborhoods = 10                     # number of clusters requested from k-means
neighborhood_name = f"neighborhood{k}"   # column that will hold the cluster label
k_centroids = dict()                     # cluster centers, keyed by window size

In [None]:
# Cluster the neighbor cell-type counts for window size `k` into
# `n_neighborhoods` groups and store the label of each cell in `cells`.
# The window size and the output column follow `k` / `neighborhood_name`
# instead of a hard-coded 10, so changing `k` no longer clusters the wrong
# table or fails on a missing column.
windows2 = windows[k]

km = MiniBatchKMeans(n_clusters = n_neighborhoods,random_state=0)

labelskm = km.fit_predict(windows2[sum_cols].values)
k_centroids[k] = km.cluster_centers_
# The labels come back in the row order of windows2, which was reindexed to
# cells.index when `windows` was built, so they line up with `cells` row by row.
cells[neighborhood_name] = labelskm
cells[neighborhood_name] = cells[neighborhood_name].astype('category')


### plot
#### clusters heatmap

In [None]:
# Heatmap of which cell types (ClusterIDs) are enriched in each niche.
k_to_plot = 10
niche_clusters = k_centroids[k_to_plot]
# Overall frequency of each cell type across all cells.
tissue_avgs = values.mean(axis=0)
# Add the overall frequencies to each cluster center, normalise every row to
# sum to one, then take the log2 ratio against the overall frequencies.
shifted = niche_clusters + tissue_avgs
row_fractions = shifted / shifted.sum(axis=1, keepdims=True)
fc = pd.DataFrame(np.log2(row_fractions / tissue_avgs), columns=sum_cols)
s = sns.clustermap(fc, cmap='bwr', row_cluster=False)
plt.show()

#### The scatter and voronoi plot

In [None]:
# Scatter plot of the cells of one field of view, coloured by cell type.
fov=12
# Take an explicit copy: the subset is modified below, and assigning into a
# boolean-mask slice of `cells` triggers pandas' SettingWithCopyWarning.
sub_cell = cells.loc[cells['fov']==fov].copy()
sub_cell.loc[:,'sdimy'] = sub_cell['sdimy']*(-1) # to compare with hoodscanR results
sns.scatterplot(data = sub_cell,x = 'sdimx',y='sdimy',hue = 'custom_clust',palette='bright')
# Place the legend outside the axes, to the right of the plot.
plt.legend( bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()

In [None]:
# Same field of view, now coloured by the assigned neighborhood cluster.
ax = sns.scatterplot(
    x='sdimx',
    y='sdimy',
    hue='neighborhood10',
    palette='bright',
    data=sub_cell,
)
# Place the legend outside the axes, to the right of the plot.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()


In [None]:
# Voronoi + scatter view of the same field of view, with both layers coloured
# by the neighborhood cluster. invert_y is switched off for the helper.
# NOTE(review): presumably because sdimy was already negated when sub_cell was
# built -- confirm against voronoi.draw_voronoi_scatter.
p = draw_voronoi_scatter(
    sub_cell,
    [],
    X='sdimx',
    Y='sdimy',
    voronoi_hue='neighborhood10',
    scatter_hue='neighborhood10',
    voronoi_kwargs={"invert_y": False},
)
plt.show()