# 3D Plot

In [None]:
import anndata as ad
import scanpy as sc
import scvi
import seaborn as sns
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
%matplotlib inline

In [None]:
# Load the baseline AnnData object from disk and echo its repr as the cell output.
h5ad_path = "cds_baseline.h5ad"
adata = sc.read_h5ad(h5ad_path)
adata

In [None]:
# Number of cells (observations) in the dataset. `len(adata.X)` raises
# TypeError when X is stored as a scipy sparse matrix, so use the shape-based
# count that AnnData exposes for both dense and sparse X.
adata.n_obs

In [None]:
# Keep only genes detected in at least this many cells; adata is filtered in place.
min_cells_per_gene = 10
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)


In [None]:
# Leiden clustering and UMAP both operate on the k-nearest-neighbour graph.
# Build it only if the loaded object does not already carry one, so a graph
# computed upstream (e.g. on a custom representation) is left untouched.
if 'neighbors' not in adata.uns:
    sc.pp.neighbors(adata)

# `resolution` controls cluster granularity: higher values give more (smaller)
# clusters, lower values give fewer (larger) ones. It is not bounded by 1.
sc.tl.leiden(adata, resolution=0.5)
# 3-component embedding for the 3D plot (stored in adata.obsm['X_umap']).
sc.tl.umap(adata, n_components=3)


In [None]:
# Give each Leiden cluster its own colour, sampled at evenly spaced points of the
# 'gist_rainbow' colormap; clusters are taken in order of first appearance.
leiden_series = adata.obs['leiden'].unique()
leiden_list = list(leiden_series)

cm = plt.get_cmap('gist_rainbow')
n_groups = len(leiden_list)
colors = {group: cm(index / n_groups) for index, group in enumerate(leiden_list)}

In [None]:
# Per-cell Leiden label with a positional 0..n_cells-1 index (cell barcodes dropped),
# then preview the colour each cell will be drawn with via the `colors` lookup.
leiden_colors = adata.obs['leiden'].reset_index(drop=True)
leiden_colors.astype('str').map(colors)

In [None]:
# Render a turntable of the 3D UMAP: one PNG every 2 degrees of azimuth,
# written to figures/NNN.png.
umap = adata.obsm['X_umap']  # 3-component embedding from sc.tl.umap above
point_colors = leiden_colors.astype('str').map(colors).tolist()

os.makedirs('figures', exist_ok=True)

# Per-axis data range; the black guide lines cross at the midpoint of each range
# and overshoot it by `extend` on both ends.
mins = umap[:, :3].min(axis=0)
maxs = umap[:, :3].max(axis=0)
x_center, y_center, z_center = (maxs + mins) / 2
extend = 10

# The scene is identical for every frame, so build it once and only rotate the camera.
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(projection='3d')
ax.scatter(umap[:, 0], umap[:, 1], umap[:, 2], c=point_colors)
ax.plot([x_center, x_center], [y_center, y_center],
        [mins[2] - extend, maxs[2] + extend], c='k', lw=5)  # z guide
ax.plot([x_center, x_center], [mins[1] - extend, maxs[1] + extend],
        [z_center, z_center], c='k', lw=5)  # y guide
ax.plot([mins[0] - extend, maxs[0] + extend], [y_center, y_center],
        [z_center, z_center], c='k', lw=5)  # x guide spans column 0's range
ax.axis('off')

for i in range(0, 360, 2):
    ax.view_init(20, i)
    # Zero-padded names keep a lexicographic glob in angle order.
    fig.savefig(f'figures/{i:03d}.png', dpi=100, facecolor='white')

# Release the figure (and skip rendering it inline) once all frames are saved.
plt.close(fig)

In [None]:
# Stitch every PNG in figures/ into an animated GIF using ImageMagick's `convert`.
# The shell glob expands in lexicographic order. That matches angle order only
# because frame names are zero-padded to three digits.
# `-delay 5` is in ImageMagick ticks (1/100 s by default), i.e. ~20 frames/s.
# NOTE: any stale PNGs left in figures/ from earlier runs are included as well.
!convert -delay 5 figures/*.png umap.gif