In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scvelo as scv
import scanpy

In [None]:
# Directory containing the input .h5ad file(s); a path relative to the notebook's working directory.
datdir = "../data"

# Full gastrulation dataset

Load the full dataset of 116,312 cells, [downloaded with R](https://bioconductor.org/packages/release/data/experiment/vignettes/MouseGastrulationData/inst/doc/MouseGastrulationData.html) and converted to AnnData. Cells from the `mixed_gastrulation` samples are then removed.

In [None]:
adata = scv.read(f"{datdir}/gastrulation_atlas_full.h5ad")

# Remove cells with stage "mixed_gastrulation": that label cannot be parsed into a
# numeric time by the experimental-time step later in the notebook.
# `.copy()` turns the filtered AnnData view into an actual object, so later cells can add
# `.obs` columns without AnnData's implicit view-to-copy conversion (and its warning).
adata = adata[adata.obs['stage'] != 'mixed_gastrulation'].copy()
adata

In [None]:
# Gene symbols are held in the `SYMBOL` column of `.var`; the cells below match gene names against it.
adata.var.loc[:, 'SYMBOL']

Prepare the data, following the approach taken by S. Farrell [here](https://github.com/Spencerfar/LatentVelo/blob/main/paper_notebooks/Gastrulation.ipynb).

In [None]:
# Subsampling is currently disabled; the seed is kept so that re-enabling it is reproducible.
np.random.seed(1)
# adata = adata[np.random.choice(adata.shape[0], size=30000, replace=False)]

# Writing new .obs columns into an AnnData *view* triggers an implicit copy (with a
# warning); materialize it explicitly so this cell behaves the same on every run.
if adata.is_view:
    adata = adata.copy()

# Experimental time: drop the first character of each stage label (presumably the "E"
# of an embryonic-day label such as "E7.5"), parse the remainder as a float, and
# rescale so that the latest stage maps to 1. Vectorized instead of a Python loop.
stage_days = adata.obs['stage'].astype(str).str[1:].astype(float)
adata.obs['exp_time'] = stage_days / stage_days.max()

adata.obs['celltype_names'] = adata.obs['celltype'].copy().values

# Alternative preprocessing (LatentVelo recipe), currently disabled:
# scv.pp.filter_genes(adata, min_shared_counts=10)
# gc.collect()
# ltv.utils.anvi_clean_recipe(adata, batch_key='sequencing.batch', celltype_key='celltype', n_top_genes=None)
# gc.collect()

Take a look at the cell types present.

In [None]:
# Sorted, de-duplicated list of the cell-type labels in the full dataset.
sorted(set(adata.obs['celltype_names']))

Optionally, restrict the analysis to a subset of cell types. For now, all cell types are kept.

In [None]:
# Example of a restricted selection (currently disabled):
# types_to_keep = [
#     'Epiblast',
#     'Caudal epiblast',
#     'Caudal neurectoderm',
#     'Rostral neurectoderm',
#     'Paraxial mesoderm',
#     'Caudal Mesoderm',
#     'Primitive Streak',
#     'NMP',
# ]

# Keep every cell type that occurs in the data.
types_to_keep = np.unique(adata.obs['celltype_names'])

# Boolean row mask over cells whose type is in `types_to_keep`; `subset` is an AnnData view.
keep_mask = adata.obs['celltype_names'].isin(types_to_keep).to_numpy()
subset = adata[keep_mask]

We are interested in only a few marker genes. First, check which of them are present in the data.

In [None]:
# Marker genes of interest.
gene_list = [
    'T',  # Bra
    'Cdx2',
    'Sox1',
    'Sox2',
    'Sox3',
    'Tbx6',
    'Otx2'
]

# Case-insensitive lookup: lower-cased symbol -> symbol exactly as spelled in the data.
# Built once instead of re-lowering the whole SYMBOL column for every gene.
symbol_lookup = {
    sym.lower(): sym for sym in subset.var['SYMBOL'] if isinstance(sym, str)
}
matched = [symbol_lookup.get(gname.lower()) for gname in gene_list]
present_list = [m is not None for m in matched]

for gname, ispresent in zip(gene_list, present_list):
    print(f"Gene {gname} in subset?: {ispresent}")

# Later cells select genes with an exact `SYMBOL == gname` match, so replace each found
# name with its exact spelling; otherwise a gene reported as present could select no column.
gene_list = [m if m is not None else g for g, m in zip(gene_list, matched)]

In [None]:
# Sorted, de-duplicated list of the cell-type labels kept in `subset`.
sorted(set(subset.obs['celltype_names']))

### Plot all cells in the subset, coloring according to cell type.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ctypes = np.unique(subset.obs['celltype_names'])

# Pull coordinates/labels out once instead of slicing an AnnData view per cell type.
celltype_col = subset.obs['celltype_names'].to_numpy()
umap_x = np.asarray(subset.obsm['umap']['x'])
umap_y = np.asarray(subset.obsm['umap']['y'])
colour_col = subset.obs['colour'].to_numpy()

for ctype in ctypes:
    in_type = celltype_col == ctype
    cols = ['#' + s for s in colour_col[in_type]]
    ax.scatter(umap_x[in_type], umap_y[in_type], s=3, c=cols, label=ctype)
ax.set_xlabel('UMAP $x$')
ax.set_ylabel('UMAP $y$')
ax.set_title('Cell types')

# Many cell types: place the legend outside the axes so it does not hide the embedding.
lgnd = ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
# `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7 and removed in 3.9.
handles = lgnd.legend_handles if hasattr(lgnd, 'legend_handles') else lgnd.legendHandles
for handle in handles:
    # Enlarge legend markers; the scatter points themselves are tiny (s=3).
    handle.set_sizes([30])

### Plot gene expression for each gene of interest

In [None]:
umap_x = np.asarray(subset.obsm['umap']['x'])
umap_y = np.asarray(subset.obsm['umap']['y'])

# Background layer (all cells coloured by cell type, very transparent), computed once.
# Cells are drawn grouped by sorted cell type, matching the per-type drawing order of the
# cell-type figure.
bg_order = np.argsort(subset.obs['celltype_names'].to_numpy(), kind='stable')
bg_cols = ['#' + s for s in subset.obs['colour'].to_numpy()[bg_order]]

for gname, ispresent in zip(gene_list, present_list):
    if not ispresent:
        print(f"Gene {gname} not present in subset. Skipping")
        continue

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    ax.scatter(umap_x[bg_order], umap_y[bg_order], s=3, c=bg_cols, alpha=0.02)

    # Per-cell expression summed over all columns whose SYMBOL matches (normally exactly one).
    # .sum(axis=1) works for sparse and dense X; np.asarray(...).ravel() flattens the
    # (n_cells, 1) result so the mask below always has one entry per cell.
    gene_mask = (subset.var['SYMBOL'] == gname).to_numpy()
    expr = np.asarray(subset[:, gene_mask].X.sum(axis=1)).ravel()
    screen = expr > 0

    points = ax.scatter(umap_x[screen], umap_y[screen], s=1, c=expr[screen], label=gname, alpha=0.2)
    fig.colorbar(points, ax=ax, label='expression (adata.X)')
    ax.set_xlabel('UMAP $x$')
    ax.set_ylabel('UMAP $y$')
    ax.set_title(gname)

### Simple quantitative summary of the genes of interest

In [None]:
for gname in gene_list:
    print(gname)
    gene_mask = (adata.var['SYMBOL'] == gname).to_numpy()
    # Per-cell total over the matching column(s). A sparse .sum(axis=1) returns an
    # np.matrix; np.asarray(...).ravel() makes it a flat 1-D array so boolean indexing
    # yields a plain vector rather than a (1, k) matrix.
    tot_expr = np.asarray(adata[:, gene_mask].X.sum(axis=1)).ravel()
    screen = tot_expr > 0
    print(f"\tCells with nonzero expression: {screen.sum()} / {tot_expr.size}")
    print("\tTotal Expr:\n\t", tot_expr[screen])
    print("\tTotal Sum Expr:\n\t", np.sum(tot_expr[screen]))

#     print(np.sum(adata[:,gname].layers['spliced'], 1).max())
#     print(np.sum(adata[:,gname].layers['unspliced'], 1).max())

## Explore Sox1 expression

In [None]:
# Get all cells with nonzero counts of Sox1
sox1cells = adata[adata[:, adata.var['SYMBOL'] == 'Sox1'].X > 0,:]
print("Number of cells with nonzero Sox1 expression:", len(sox1cells))
sox1cells = sox1cells[~np.isin(sox1cells.obs['celltype'], 
                               ['ExE exctoderm', 'ExE endoderm', 'ExE mesoderm']),:]
print("Number of non-ExE cells with nonzero Sox1 expression:", len(sox1cells))

In [None]:
# Sorted, de-duplicated cell types among the non-ExE Sox1-expressing cells.
sorted(set(sox1cells.obs['celltype']))

In [None]:
# Number of Sox1-expressing (non-ExE) cells per cell type. If 'celltype' is categorical,
# value_counts() also lists every unused category with a count of 0 (e.g. the excluded
# ExE types), so keep only cell types that actually occur.
sox1_type_counts = sox1cells.obs['celltype'].value_counts()
sox1_type_counts[sox1_type_counts > 0]