In [None]:
# Load IPython's autoreload extension (its mode is set in the next cell).
%load_ext autoreload

In [None]:
# Mode 2: reload all imported modules before each cell runs, so edits to them take effect without a kernel restart.
%autoreload 2

In [None]:
!pip install torchvision 
!pip install geomloss

In [None]:
import triku as tk
import scanpy as sc
import pandas as pd
import numpy as np
import scipy.sparse as spr
import scipy.stats as sts
import os
import gc
from itertools import product
import pickle
import ray
import itertools

from tqdm.notebook import tqdm

from bokeh.io import show, output_notebook, reset_output
from bokeh.plotting import figure
from bokeh.models import LinearColorMapper

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D

from sklearn.metrics import adjusted_rand_score as ARS
from sklearn.metrics import adjusted_mutual_info_score as NMI
from sklearn.metrics import silhouette_score, davies_bouldin_score

from geomloss import SamplesLoss
import torch

import time 

# Clear any previous Bokeh output configuration, then route Bokeh output to the notebook.
# NOTE(review): none of the cells below produce a Bokeh plot.
reset_output()
output_notebook()

In [None]:
from cellassign import assign_cats

In [None]:
# Default resolution for matplotlib figures rendered in the notebook.
mpl.rcParams.update({'figure.dpi': 150})

In [None]:
!python setup.py install

In [None]:
# Output directory for exported frames, created if missing.
# NOTE(review): the frame-export cells below save to 'figures/gif/', not to save_dir.
save_dir = "../exports/fancy_gif/"

os.makedirs(save_dir, exist_ok=True)

In [None]:
# Load the 10x filtered feature-barcode matrix (per the filename, the 5k PBMC protein v3 run) from HDF5.
adata = sc.read_10x_h5('data/10x/5k_pbmc_protein_v3_filtered_feature_bc_matrix.h5')
# Make duplicated gene names unique so they can be used as an index.
adata.var_names_make_unique()

In [None]:
# Drop ribosomal (RPS*/RPL*) and mitochondrial (MT-*) genes.
# A prefix match is used rather than a substring match, so that genes which merely
# contain these letters somewhere in their symbol (such as TRPS1 or MRPL*/MRPS*) are kept.
names_out = np.array([not name.startswith(('RPS', 'RPL', 'MT-')) for name in adata.var_names])

# .copy() materialises the subset; without it adata is a view and the in-place
# scanpy calls in the following cells would act on (and warn about) a view.
adata = adata[:, names_out].copy()

In [None]:
# Basic QC: keep genes detected in at least 5 cells, then cells with at least 500 total counts.
sc.pp.filter_genes(adata, min_cells=5)
sc.pp.filter_cells(adata, min_counts=500)

In [None]:
# Per-cell library-size normalisation followed by a log(1 + x) transform.
# NOTE(review): normalize_per_cell is a legacy scanpy API; normalize_total is its documented successor.
sc.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)

In [None]:
# PCA and a 35-nearest-neighbour graph for the reference Leiden clustering / UMAP below.
sc.pp.pca(adata)
sc.pp.neighbors(adata, n_neighbors=35)

# First analysis: cluster without feature selection (FS) and annotate the cell populations

In [None]:
# Marker genes per cell type; passed to assign_cats below to label the Leiden clusters.
# NOTE(review): some markers are listed under two types ('VCAN' in Macrophage and Monocyte,
# 'IGKC' in B cell and Plasma cell), which may blur those assignments.
# NOTE(review): the 'Plasma cell' list (e.g. LILRA4, IL3RA) looks like plasmacytoid-DC markers and
# the 'Plasm. DC' list (e.g. FCER1A, CLEC10A) like conventional-DC markers — confirm the labels are not swapped.
dict_cats = {
    'NK cell': ['NKG7', 'GNLY', 'GZMA', 'CST7', 'PRF1', 'KLRD1', 'FGFBP2', 'HOPX', 'KLRF1', 'SPON2'], 
    'T naive': ['CD3D', 'CD3E', 'TRAC', 'TCF7', 'RGS10', 'MAL', 'FHIT', 'LRRN3'], 
    'T CD4': ['CD4', 'AQP3', 'TRADD', 'PLP2', 'ITGB1'], 
    'T CD8': ['CD8A', 'LYAR', 'GZMK', 'CD8B', 'TRGC2'], 
    'T reg': ['CDC14A', 'TTN', 'PCSK7', 'ANKRD36C', 'GOLGA8A', 'DGKA', 'EML4', 'SYNE2'], 
    'Macrophage': ['S100A8', 'VCAN', 'S100A9', 'NCF2', 'CD14', 'MNDA', 'CSTA'], 
    'Plasm. DC': ['FCER1A', 'ALDH2', 'HLA-DMA', 'HLA-DQB1', 'CLEC10A', 'GSN'], 
    'APC': ['AIF1', 'PSAP', 'FCGR3A', 'SERPINA1', 'LYN', 'MS4A7'], 
    'Monocyte': ['VCAN', 'CSF3R', 'DMXL2', 'ITGAX', 'RASSF4', 'STAB1', 'IRAK3'], 
    'B cell': ['CD79A', 'RALGPS2', 'IGHM', 'IGKC', 'TNFRSF13C', 'VPREB3', 'SWAP70', 'IGHD', 'PDLIM1', 'CD22'], 
    'Plasma cell': ['ITM2C', 'CCDC50', 'JCHAIN', 'IL3RA', 'IGKC', 'TPM2', 'LILRA4'],     
}

In [None]:
# Over-cluster with Leiden (resolution 3), compute the reference UMAP, and label each
# Leiden cluster with a cell type from dict_cats, written to adata.obs['cell_type'].
sc.tl.leiden(adata, resolution=3)
sc.tl.umap(adata)

assign_cats(adata, dict_cats, column_groupby='leiden', key_added='cell_type')

In [None]:
# Reference UMAP (no feature selection) coloured by the assigned cell types.
sc.pl.umap(adata, color=['cell_type'])

# Second: run the feature-selection (FS) methods and export the animation frames

In [None]:
def dist_func(umap_orig, umap_new, dict_class_mask, blur=0.01, loss=None):
    """Class-wise distance between two embeddings of the same cells.

    For each class mask, the rows of ``umap_orig`` and ``umap_new`` selected by the
    mask are compared with ``loss``. The per-class losses are then averaged, each
    weighted by the square root of the class size.

    Parameters
    ----------
    umap_orig, umap_new : np.ndarray
        Coordinate arrays with the same row order (one row per cell). Presumably
        (n_cells, 2) UMAP coordinates of matching dtype — TODO confirm for other inputs.
    dict_class_mask : dict
        Maps a class label to a boolean mask over the rows.
    blur : float, optional
        ``blur`` of the default ``SamplesLoss("sinkhorn", ...)``; ignored when ``loss`` is given.
    loss : callable, optional
        ``loss(x, y)`` taking two torch tensors and returning a scalar tensor.
        Defaults to a Sinkhorn ``SamplesLoss`` with the given ``blur``.

    Returns
    -------
    float
        The sqrt(class size)-weighted mean of the per-class losses. Classes with no
        members are skipped: they would carry zero weight, but their loss may be NaN
        and NaN * 0 would still poison the sum. Returns ``nan`` when no class has members.
    """
    if loss is None:
        loss = SamplesLoss("sinkhorn", blur=blur)

    losses, weights = [], []
    for mask in dict_class_mask.values():
        mask = np.asarray(mask, dtype=bool)
        n_members = int(mask.sum())
        if n_members == 0:
            continue
        value = loss(torch.from_numpy(umap_orig[mask, :]), torch.from_numpy(umap_new[mask, :]))
        losses.append(value.item())
        weights.append(np.sqrt(n_members))

    if not weights:
        return float('nan')

    # Compute the weights once and reuse them for both numerator and denominator.
    weights = np.asarray(weights)
    return float((weights * np.asarray(losses)).sum() / weights.sum())

In [None]:
# Score genes with triku and with scanpy's highly_variable_genes.
# The frame-export cells below rank genes by adata.var['triku_distance'] and
# adata.var['dispersions_norm']; these calls are expected to populate those columns.
tk.tl.triku(adata)
sc.pp.highly_variable_genes(adata)

In [None]:
# Export one frame per gene count N: a UMAP of the cells computed from only the N genes
# with the largest triku distance. For a smooth animation, several UMAP seeds are tried
# per frame, and the embedding closest (per dist_func) to the previous frame is kept.
n_genes_schedule = list(range(1, 40)) + list(range(40, 120, 2)) + list(range(120, 1000, 5))
n_umap_seeds = 8      # candidate UMAP seeds per frame
first_frame_seed = 0  # the first frame has no predecessor to align to, so one fixed seed is used

# Per-cell-type boolean masks for dist_func; they do not change between frames.
dict_class_mask = {cat: (adata.obs['cell_type'] == cat).values for cat in adata.obs['cell_type'].cat.categories}

# savefig does not create missing directories.
os.makedirs('figures/gif', exist_ok=True)

for N in tqdm(n_genes_schedule):
    is_first_frame = N == n_genes_schedule[0]
    fig, axs = plt.subplots(1, 1, figsize=(5, 4))

    # Top-N genes by triku distance; NaN scores are ranked lowest so they are never picked.
    triku_scores = adata.var['triku_distance'].values
    triku_scores = np.where(np.isnan(triku_scores), -np.inf, triku_scores)
    ind_triku = np.argpartition(triku_scores, -N)[-N:]
    mask_bool_triku = np.zeros(adata.n_vars, dtype=bool)
    mask_bool_triku[ind_triku] = True
    adata.var['highly_variable'] = mask_bool_triku
    adata.obsm['X_triku'] = adata.X[:, mask_bool_triku]

    for use_rep in ['X_triku']:
        basis = use_rep.replace('_', '_umap_')  # 'X_triku' -> 'X_umap_triku'
        sc.pp.neighbors(adata, n_neighbors=17, use_rep=use_rep)

        list_umaps = []
        for seed in ([first_frame_seed] if is_first_frame else range(n_umap_seeds)):
            # copy=True leaves the reference embedding in adata.obsm['X_umap'] untouched.
            umap_coords = sc.tl.umap(adata, copy=True, random_state=seed).obsm['X_umap']
            # Min-max scale each axis to [0, 1] so consecutive frames are comparable.
            col_min = umap_coords.min(axis=0)
            list_umaps.append((umap_coords - col_min) / (umap_coords.max(axis=0) - col_min))

        if is_first_frame:
            best_umap = list_umaps[0]
        else:
            # adata.obsm[basis] still holds the previous frame's embedding at this point.
            list_dists = [dist_func(adata.obsm[basis], umap_new, dict_class_mask) for umap_new in list_umaps]
            best_umap = list_umaps[int(np.argmin(list_dists))]
        adata.obsm[basis] = best_umap

        sc.pl.embedding(adata, color=['cell_type'], ax=axs, show=False, basis=basis, title='', frameon=False)

        xl, yl = axs.get_xlim(), axs.get_ylim()
        xt, yt = (xl[0] + xl[1]) / 2, yl[0] - 0.1 * (yl[1] - yl[0])
        axs.text(xt, yt, use_rep.replace('X_', ''), ha='center')  # method name below the plot
        axs.text(xl[0], yl[1], N, ha='left', c='#bcbcbc')  # gene count in the top-left corner

    fig.savefig(f'figures/gif/triku_{str(N).zfill(4)}.png', bbox_inches='tight', dpi=175)
    plt.show()
    plt.close(fig)  # hundreds of frames are generated; release each figure explicitly

In [None]:
# Export one three-panel frame per gene count N. The panels compare UMAPs computed from:
#   - the top-N triku genes,
#   - the top-N scanpy genes (by normalized dispersion),
#   - N random genes.
# Each panel keeps the UMAP seed closest (per dist_func) to that panel's previous frame.
n_genes_schedule = list(range(1, 40)) + list(range(40, 120, 2)) + list(range(120, 1000, 5))
n_umap_seeds = 8      # candidate UMAP seeds per panel and frame
first_frame_seed = 0  # the first frame has no predecessor to align to, so one fixed seed is used
list_reps = ['X_triku', 'X_scanpy', 'X_random']
rng = np.random.default_rng(0)  # seeded so the random-gene panel is reproducible

# Per-cell-type boolean masks for dist_func; built here so this cell does not depend on
# variables left over from earlier cells.
dict_class_mask = {cat: (adata.obs['cell_type'] == cat).values for cat in adata.obs['cell_type'].cat.categories}

# savefig does not create missing directories.
os.makedirs('figures/gif', exist_ok=True)

for N in tqdm(n_genes_schedule):
    is_first_frame = N == n_genes_schedule[0]

    fig, axs = plt.subplots(1, 3, figsize=(10, 3))
    # Light vertical separators between the panels, placed in figure coordinates.
    for x_sep in (0.375, 0.65):
        fig.add_artist(Line2D([x_sep, x_sep], [0.1, 0.9], color='#bcbcbc', linewidth=1, transform=fig.transFigure))

    # Gene selection per method. For the ranked methods, NaN scores are ranked lowest
    # so argpartition never counts them among the top N.
    gene_masks = {}
    for use_rep, score_col in (('X_triku', 'triku_distance'), ('X_scanpy', 'dispersions_norm')):
        scores = adata.var[score_col].values
        scores = np.where(np.isnan(scores), -np.inf, scores)
        gene_masks[use_rep] = np.zeros(adata.n_vars, dtype=bool)
        gene_masks[use_rep][np.argpartition(scores, -N)[-N:]] = True
    gene_masks['X_random'] = np.zeros(adata.n_vars, dtype=bool)
    gene_masks['X_random'][rng.choice(adata.n_vars, size=N, replace=False)] = True

    for use_rep in list_reps:
        adata.obsm[use_rep] = adata.X[:, gene_masks[use_rep]]

    for idx, use_rep in enumerate(list_reps):
        # neighbors reads use_rep directly; highly_variable is only kept in sync for reference.
        adata.var['highly_variable'] = gene_masks[use_rep]
        basis = use_rep.replace('_', '_umap_')  # e.g. 'X_triku' -> 'X_umap_triku'
        sc.pp.neighbors(adata, n_neighbors=17, use_rep=use_rep)

        list_umaps = []
        for seed in ([first_frame_seed] if is_first_frame else range(n_umap_seeds)):
            # copy=True leaves the reference embedding in adata.obsm['X_umap'] untouched.
            umap_coords = sc.tl.umap(adata, copy=True, random_state=seed).obsm['X_umap']
            # Min-max scale each axis to [0, 1] so consecutive frames are comparable.
            col_min = umap_coords.min(axis=0)
            list_umaps.append((umap_coords - col_min) / (umap_coords.max(axis=0) - col_min))

        # Only align to the previous frame when one exists; on the first frame the single
        # fresh embedding is used as is (previously this fell through and picked from a
        # stale list left over by another cell/panel).
        if is_first_frame:
            best_umap = list_umaps[0]
        else:
            list_dists = [dist_func(adata.obsm[basis], umap_new, dict_class_mask) for umap_new in list_umaps]
            best_umap = list_umaps[int(np.argmin(list_dists))]
        adata.obsm[basis] = best_umap

        # Only the last panel carries the cell-type legend.
        legend_kwargs = {} if idx == len(list_reps) - 1 else {'legend_loc': False}
        sc.pl.embedding(adata, color=['cell_type'], ax=axs[idx], show=False, basis=basis, title='', frameon=False, **legend_kwargs)

        xl, yl = axs[idx].get_xlim(), axs[idx].get_ylim()
        xt, yt = (xl[0] + xl[1]) / 2, yl[0] - 0.1 * (yl[1] - yl[0])
        axs[idx].text(xt, yt, use_rep.replace('X_', ''), ha='center')  # method name below the panel

        if idx == 0:
            axs[idx].text(xl[0], yl[1], N, ha='left', c='#bcbcbc')  # gene count, first panel only

    fig.savefig(f'figures/gif/{str(N).zfill(4)}.png', bbox_inches='tight', dpi=175)
    plt.show()
    plt.close(fig)  # hundreds of frames are generated; release each figure explicitly