In [4]:
import os

import torch
import scanpy as sc
import anndata as ad

from scDiffusion.utils.utility_fn import check_isolation
from scDiffusion.sc_graph.build_graph import build_adj_graph, build_diffusion_graph, build_graph, build_gnd_steps_graph
from scDiffusion.sc_graph.call_attention import call_attention, call_gnd_attention
from scDiffusion.sc_graph.call_modularity import call_modularity, call_gnd_modularity, view_modularity
from scDiffusion.sc_analysis.clustering import clustering, evaluate_clustering

from scDiffusion.grand.feature_encoder import encode_features
from scDiffusion.grand.graph_DIF import graph_diffusion


In [5]:
# Select the compute device: CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# On a GPU run, also report which card is being used.
if device == "cuda":
    print(torch.cuda.get_device_name())

cuda
Tesla P100-PCIE-12GB


In [3]:
# Location of the input data. DATA_TYPE is not referenced by any of the cells
# visible in this notebook; it is kept for compatibility.
DATA_PATH = 'data/Klein/'
DATA_TYPE = '10X'

# Directory that receives the intermediate AnnData checkpoints.
OUTPUT_PATH = 'outputs/Klein/'

# Create the output directory up front so that the checkpoint writes further
# down cannot fail on a missing folder after a long training run.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Load dataset

In [None]:
# Read the Klein dataset (h5ad file under DATA_PATH) into `adata`.
adata = sc.read_h5ad(
    DATA_PATH + 'klein.h5ad'
)

In [None]:
# Show the summary of the freshly loaded AnnData object as the cell output.
adata

# Preprocess

We use the normalized and log-transformed data (NOT the scaled data).

In [None]:
# Normalise every cell to a total of 10,000 counts, then apply log1p.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Flag the 2000 most variable genes. No mean/dispersion cutoffs are passed:
# scanpy ignores them whenever `n_top_genes` is set, so supplying them would
# only suggest a filter that is never applied.
# NOTE(review): 2000 presumably has to match the first entry of the encoder's
# D_encode_list used later in this notebook — confirm before changing it.
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
sc.pl.highly_variable_genes(adata)


In [None]:
# Keep the full normalised, log-transformed matrix reachable through `adata.raw`.
adata.raw = adata

# Restrict the working object to the highly variable genes. `.copy()` turns the
# slice into an independent AnnData instead of a view, so later cells that
# write into `adata` do not trigger an implicit view-to-copy conversion.
adata = adata[:, adata.var.highly_variable].copy()
# Scaling is intentionally left disabled: this notebook works on the
# normalised, log-transformed (unscaled) values.
#sc.pp.scale(adata, max_value=10)

In [None]:
# Show the AnnData summary after the highly-variable-gene subset.
adata

# Feature encoder

In [None]:
# Run the feature encoder with layer widths 2000 -> 300 -> 50 (encode) and
# 50 -> 300 -> 2000 (decode), ELU activations, and no activation on the last
# encoder/decoder layer.
# NOTE(review): later cells read the result from `X_fae`; presumably it is
# stored there by encode_features — confirm.
adata = encode_features(
    adata,
    D_encode_list=[2000, 300, 50],
    D_decode_list=[50, 300, 2000],
    max_epoch=2000,
    lr=1e-3,
    device=device,
    activation=torch.nn.ELU(),
    encode_last_activation=False,
    decode_last_activation=False,
)

# Checkpoint the encoded object so the following cells can restart from disk.
adata.write(OUTPUT_PATH + 'encoded_adata.h5ad')

In [None]:
# Reload the encoded checkpoint written by the previous cell.
adata = ad.read_h5ad(
    OUTPUT_PATH + 'encoded_adata.h5ad'
)

### View umap of the encoded data

In [None]:
# Neighbour graph (10 neighbours) on the `X_fae` representation, followed by UMAP.
sc.pp.neighbors(
    adata,
    use_rep='X_fae',
    n_neighbors=10,
    n_pcs=50,
)
sc.tl.umap(adata)

In [None]:
# UMAP of the encoded data, coloured by `labels`.
sc.pl.umap(
    adata,
    color=['labels'],
)

### Check isolation nodes

In [None]:
# Run the isolation check on the `X_fae` representation.
# NOTE(review): predict_pct=0.05 is presumably the fraction of cells flagged
# as isolated — confirm against check_isolation.
adata = check_isolation(
    adata,
    use_rep='X_fae',
    predict_pct=0.05,
)

In [None]:
# UMAP coloured by the `isolation` annotation.
sc.pl.umap(
    adata,
    color=['isolation'],
)

The isolation information can be used to prune edges in the single-cell graph.

# Diffusion

In [None]:
# Build the adjacency graph with k=50; no explicit representation is selected
# (use_rep=None).
adata = build_adj_graph(
    adata,
    use_rep=None,
    k=50,
    data_dtype=torch.float32,
    device=device,
)

This graph provides the adjacency that can be used in the loss function.

## Build diffusion based graph

In [None]:
# Build the graph used by the diffusion step: k between 0 and 10, no
# self-edges, no random edge removal, and no pruning.
adata = build_diffusion_graph(
    adata,
    use_rep=None,
    k_min=0,
    k_max=10,
    self_edge=False,
    remov_edge_prob=None,
    prune=False,
    device=device,
)


This graph is used in the graph neural diffusion process.

If prune=True, edges will be pruned according to the isolation labels of the graph nodes.

In [None]:
# Run graph neural diffusion on the `X_fae` representation.
adata = graph_diffusion(
    adata,
    use_rep='X_fae',
    max_epoch=2000,
    lr=1e-4,
    device=device,
    # Diffusion layer settings.
    num_features_diffusion=50,
    num_heads_diffusion=6,
    num_steps_diffusion=8,
    time_increment_diffusion=0.1,
    attention_type='sum',
    activation=torch.nn.ELU(),
    dropout=0.0,
    # No extra encoder; a single decoder layer of width 300.
    encoder=None,
    decoder=[300],
    # Model persistence and logging.
    save_model=True,
    log_diffusion=True,
    load_model_state=False,
    # Loss configuration; the adjacency term has weight 0.0 here.
    loss_adj=0.0,
    use_adj='adj_edge_index',
    loss_reduction="sum",
    # Graph rebuilding is switched off; the arguments repeat the values passed
    # to build_diffusion_graph.
    rebuild_graph=False,
    rebuild_graph_args={
        'k_min': 0,
        'k_max': 10,
        'remov_edge_prob': None,
    },
)

# Checkpoint the diffused object so the following cells can restart from disk.
adata.write(OUTPUT_PATH + 'diffused_adata.h5ad')



In [None]:
# Reload the diffused checkpoint written by the previous cell.
adata = ad.read_h5ad(
    OUTPUT_PATH + 'diffused_adata.h5ad'
)


# View UMAP for diffused data

In [None]:
# Neighbour graph (10 neighbours) on the `X_dif` representation, followed by UMAP.
sc.pp.neighbors(
    adata,
    use_rep='X_dif',
    n_neighbors=10,
    n_pcs=50,
)
sc.tl.umap(adata)

In [None]:
# UMAP of the diffused data, coloured by `labels`.
sc.pl.umap(
    adata,
    color=['labels'],
)

# Clustering

In [None]:
# Build the k=10 graph on the `X_dif` representation, without self-edges or
# pruning.
adata = build_graph(
    adata,
    use_rep="X_dif",
    k=10,
    self_edge=False,
    prune=False,
    data_dtype=torch.float32,
    device=device,
)

If prune=True, edges will be pruned according to the isolation labels of the graph nodes.

In [None]:
# Call the attention step, leaving attention type, head count and dropout
# unspecified (None).
# NOTE(review): None presumably falls back to the settings saved with the
# trained diffusion model — confirm against call_attention.
adata = call_attention(
    adata,
    attention_type=None,
    num_heads_diffusion=None,
    dropout=None,
    device=device,
)

In [None]:
# Cluster at resolution 0.05 with no initial membership; later cells read the
# result from `att_leiden`.
adata = clustering(
    adata,
    resolution=0.05,
    initial_membership=None,
)

In [None]:
# UMAP panels coloured by the computed clusters (`att_leiden`) and by `labels`.
sc.pl.umap(
    adata,
    color=['att_leiden', 'labels'],
)

In [None]:
# Score the `att_leiden` clusters against the reference annotation.
# The reference column is `labels`, the same `adata.obs` annotation that the
# UMAP cells of this notebook colour by.
evaluate_clustering(adata.obs['att_leiden'], adata.obs['labels'])

# View modularity

In [None]:
# Modularity of the `att_leiden` partition, computed with edge weights; the
# value is shown as the cell output.
modularity = call_modularity(
    adata,
    use_label='att_leiden',
    edge_weight=True,
)
modularity

### Modularity in each diffusion step

In [None]:
# Build the k=10 graphs for the diffusion steps, without self-edges or pruning.
adata = build_gnd_steps_graph(
    adata,
    k=10,
    self_edge=False,
    prune=False,
    data_dtype=torch.float32,
    device=device,
)

In [None]:
# Call the per-step attention computation, leaving attention type, head count
# and dropout unspecified (None).
adata = call_gnd_attention(
    adata,
    attention_type=None,
    num_heads_diffusion=None,
    dropout=None,
    device=device,
)

In [None]:
# Per-step modularity of the `att_leiden` partition, with edge weights; the
# result is shown as the cell output.
gnd_modularity = call_gnd_modularity(
    adata,
    use_label='att_leiden',
    edge_weight=True,
)
gnd_modularity

In [None]:
# Same computation without edge weights. This rebinds `gnd_modularity`, so the
# unweighted result is the one passed to the plotting cell that follows.
gnd_modularity = call_gnd_modularity(
    adata,
    use_label='att_leiden',
    edge_weight=False,
)
gnd_modularity

In [None]:
# Plot the modularity values held in `gnd_modularity`; no output file is
# requested (save_fig=None).
view_modularity(
    gnd_modularity,
    save_fig=None,
)