## Tutorial: Cell identity inference.
This tutorial demonstrates how to identify cell types in scRNA-seq data using Cell Graph-Net. Please refer to the homepage for software environment configuration and installation instructions. The tutorial uses a demonstration dataset derived from single-cell transcriptomic data of human bone.

### Preparation

In [None]:
# Packages used throughout the tutorial: cellgraph provides the
# train/test/embed/predict/explain entry points called in later cells, and
# scanpy (as `sc`) is used to read the .h5ad input file.
import cellgraph
import warnings 
import scanpy as sc 
# Install an "ignore" filter so warnings raised after this point are not shown.
# Warnings emitted while the imports above were executing are not affected.
warnings.filterwarnings ("ignore")

In [None]:
import torch

# Report the installed PyTorch build so version mismatches are easy to spot.
print(torch.__version__)

# The CUDA device queries below raise when no CUDA device is usable, which
# would abort this cell on CPU-only machines. Check availability first and
# print a short notice instead of failing.
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability(device=None),  torch.cuda.get_device_name(device=None))
else:
    print("CUDA is not available: no GPU capability or device name to report.")

In [None]:
# Read the reference dataset from its .h5ad file; the path is relative to the
# notebook's working directory. The result is kept in `ref_adata` and is
# extended with hierarchy and PPI information in the next section.
ref_adata = sc.read_h5ad('../data/hBone/hBone_ref_adata.h5ad')

### Constructing the model

In [None]:
import json

import pandas as pd

from cellgraph.data import interactions, hierarchy

# --- Hierarchy ---------------------------------------------------------------
# Number of layers of the model; also the number of hierarchy levels requested.
n_layers = 3
# 'HSA' presumably selects the human hierarchy (the variable name suggests
# Reactome) -- confirm against the cellgraph documentation.
reactome = hierarchy.hierarchy_layer(species='HSA')
layers = reactome.get_layers(n_levels=n_layers)
# Keep the layers on the AnnData object as a JSON string under .uns['hierarchy'].
ref_adata.uns['hierarchy'] = json.dumps(layers)

# --- PPI ---------------------------------------------------------------------
# Load the human PPI table from CSV and map it onto the reference data.
# `ref_adata` is rebound to whatever data_mapping_ppi returns.
human_ppi = pd.read_csv('../data/ppi/human_string_higconf.csv')
ref_adata = interactions.data_mapping_ppi(
    ref_adata,
    human_ppi,
    top_genes=3000,
)

In [None]:
# Evaluate the object as the cell's final expression so the notebook displays
# its representation, letting the reader inspect the result of the cell above.
ref_adata

## Train

In [None]:
# Training configuration: the reference dataset, the device index, the log
# directory, and the model options that are unpacked as keyword arguments
# into cellgraph.train in the next cell.
dataset = "../data/hBone/hBone_ref_adata.h5ad"
# NOTE(review): 1 presumably selects the second GPU; adjust for single-GPU
# machines -- confirm how cellgraph interprets device_id.
device_id = 1
log_dir = "../log/hBone"
model_args = dict(
    add_one_hot=1,
    skip_raw=1,
    pool="mean",
    nhid=64,
    lr=0.001,
    bootstrap_num=-1,
    encoder="gat",
    heads=4,
)

In [None]:
# Launch training with the configuration defined in the previous cell;
# the entries of model_args are passed through as keyword arguments.
cellgraph.train(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    **model_args,
)

## Test

In [None]:
# Evaluation configuration: the query dataset, evaluated with the same
# log directory that was used for training.
dataset = "../data/hBone/hBone_query_adata.h5ad"
log_dir = "../log/hBone"
device_id = 1
# Presumably names the processed/cached form of the query data -- confirm
# against the cellgraph documentation.
fn_process = "processed-test"

In [None]:
# Run evaluation on the query dataset using the settings from the cell above.
cellgraph.test(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    fn_process=fn_process,
)

## Embed

In [None]:
# Settings for the embedding step; redefined here so the section can be run
# without re-running the "Test" cells.
dataset = "../data/hBone/hBone_query_adata.h5ad"
log_dir = "../log/hBone"
device_id = 1
fn_process = "processed-test"

In [None]:
# Run the embedding step on the query dataset; out_embed is set to "output".
cellgraph.embed(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    out_embed="output",
    fn_process=fn_process,
)

## Predict

In [None]:
# Settings for the prediction step on the query dataset.
dataset = "../data/hBone/hBone_query_adata.h5ad"
log_dir = "../log/hBone"
device_id = 1
fn_process = "processed-test"
# Passed to cellgraph.predict as predict_type in the next cell.
predict_type = 'cell'

In [None]:
# Run prediction on the query dataset and keep the returned value in `cells`.
cells = cellgraph.predict(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    fn_process=fn_process,
    predict_type=predict_type,
)

## Explain Feature

In [None]:
# Settings for the feature-explanation step on the query dataset.
dataset = "../data/hBone/hBone_query_adata.h5ad"
log_dir = "../log/hBone"
device_id = 1
fn_process = "processed-test"

In [None]:
# Run feature explanation three times, once per explanation method.
# Method 1: "grad".
cellgraph.explain_feature(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    explain_method="grad",
    fn_process=fn_process,
)
# Method 2: "grad_cam".
cellgraph.explain_feature(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    explain_method="grad_cam",
    fn_process=fn_process,
)
# Method 3: "attention", which additionally sets return_sample and prod_value to 0.
cellgraph.explain_feature(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    explain_method="attention",
    return_sample=0,
    prod_value=0,
    fn_process=fn_process,
)

## Explain PPI

In [None]:
# Settings for the PPI-explanation step on the query dataset.
dataset = "../data/hBone/hBone_query_adata.h5ad"
log_dir = "../log/hBone"
device_id = 1
fn_process = "processed-test"
# Options unpacked as keyword arguments into cellgraph.explain_ppi in the
# next cell. The exp_* entries presumably control the explainer's own
# training (epochs, learning rate) -- confirm against the cellgraph docs.
exp_dict = dict(
    correlation=0,
    multi_atten=1,
    train_sample_gt=0,
    ce_loss_gt=0,
    exp_train_epochs=100,
    exp_lr=0.01,
)

In [None]:
# Run the PPI explanation with the settings from the cell above; the entries
# of exp_dict are passed through as keyword arguments.
cellgraph.explain_ppi(
    dataset=dataset,
    device_id=device_id,
    log_dir=log_dir,
    fn_process=fn_process,
    **exp_dict,
)