# 1. Import the necessary packages

In [None]:
import ComicGTN
import os
import re
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from ComicGTN.utils import *
from ComicGTN.settings import *
from ComicGTN.GCNConv import *
from ComicGTN.FastGTNConv import *
from ComicGTN.ComicGTN_model import *

# 2. Read in data

In [None]:
# Directory holding the benchmark inputs: the feature-by-cell count matrices
# (.mtx), the name/label tables (.tsv) and the k-mer / motif tables (.h5).
workdir = "./ComicGTN/data/BMMC-bench-1"

# Feature-by-cell matrices: observations (rows) are genes / peaks and
# variables (columns) are cells, as set by the obs/var names below.
RNA_seq = sc.read(os.path.join(workdir, "Gene_Cell.mtx"))
ATAC_seq = sc.read(os.path.join(workdir, "Peak_Cell.mtx"))
Cell_names = pd.read_csv(os.path.join(workdir, "Cell_names.tsv"), sep = "\t", header = None)
Cell_types = pd.read_csv(os.path.join(workdir, "Cell_types.tsv"), sep = "\t", header = None)
Gene_names = pd.read_csv(os.path.join(workdir, "Gene_names.tsv"), sep = "\t", header = None)
Peak_names = pd.read_csv(os.path.join(workdir, "Peak_names.tsv"), sep = "\t", header = None)
RNA_seq.obs_names = Gene_names[0]
RNA_seq.var_names = Cell_names[0].astype(str)
ATAC_seq.obs_names = Peak_names[0]
ATAC_seq.var_names = Cell_names[0].astype(str)
RNA_count = RNA_seq.X
ATAC_count = ATAC_seq.X


def _to_str(name):
    """Return *name* as ``str``, decoding it first if it is UTF-8 ``bytes``.

    Index entries read from the .h5 tables may come back as ``bytes`` or as
    ``str`` depending on the h5py/anndata version; calling ``.decode`` on a
    ``str`` would raise ``AttributeError``.
    """
    return name.decode("utf-8") if isinstance(name, bytes) else name


# Peak-level k-mer and motif frequency tables. Their rows are labelled with
# the first len(...) entries of Peak_names.
Kmer_adata = ad.io.read_hdf(os.path.join(workdir, "output_kmers_motifs/freq_kmer.h5"), "mat")
Motif_adata = ad.io.read_hdf(os.path.join(workdir, "output_kmers_motifs/freq_motif.h5"), "mat")
Kmer_adata.obs_names = Peak_names.iloc[0:len(Kmer_adata.obs.index), 0]
Kmer_adata.var_names = [_to_str(x) for x in Kmer_adata.var.index]
Motif_adata.obs_names = Peak_names.iloc[0:len(Motif_adata.obs.index), 0]
Motif_adata.var_names = [_to_str(x) for x in Motif_adata.var.index]
Kmer_count = Kmer_adata.X
Motif_count = Motif_adata.X


# Remove_Scaffold comes from one of the ComicGTN star imports; judging by its
# name it filters out peaks on scaffold contigs (verify). It returns the
# filtered peak matrix and the matching peak names.
ATAC_count_re, Peak_names_re = Remove_Scaffold(ATAC_count, Peak_names)

# 3. Initialize parameters

In [None]:
import argparse


def _str2bool(value):
    """Parse a command-line boolean such as ``"True"``, ``"false"``, ``"1"`` or ``"no"``.

    ``type = bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser(description = "Training GTN on heterogeneous graph.")
parser.add_argument("--num_FastGTN_layers", type = int, default = 1)
parser.add_argument("--num_channels", type = int, default = 8)
parser.add_argument("--num_layers", type = int, default = 4)
parser.add_argument("--node_dim", type = int, default = 128)
parser.add_argument("--non_local", type = _str2bool, default = False)
parser.add_argument("--non_local_weight", type = int, default = 0)
parser.add_argument("--K", type = int, default = 1)
parser.add_argument("--beta", type = float, default = 0)
parser.add_argument("--channel_agg", type = str, default = "mean")
parser.add_argument("--remove_self_loops", type = _str2bool, default = True)
parser.add_argument("--smoothing", type = float, default = 0.1)
# One or more floats, e.g. `--rare_weight 2.0 3.0`; `type = list` would split
# a single argument string into individual characters.
parser.add_argument("--rare_weight", type = float, nargs = "+", default = [2.0, 3.0])
parser.add_argument("--temperature", type = float, default = 0.1)
parser.add_argument("--hard_neg_k", type = int, default = 3)
parser.add_argument("--rand_neg_ratio", type = float, default = 0.3)
parser.add_argument("--lr", type = float, default = 0.0005)
parser.add_argument("--weight_decay", type = float, default = 0.005)


# Parse an empty argument list so the notebook always runs with the defaults
# above (a notebook kernel's own sys.argv is not meant for this parser).
args = parser.parse_args([])
num_FastGTN_layers = args.num_FastGTN_layers
num_channels = args.num_channels
num_layers = args.num_layers
node_dim = args.node_dim
non_local = args.non_local
non_local_weight = args.non_local_weight
K = args.K
beta = args.beta
channel_agg = args.channel_agg
remove_self_loops = args.remove_self_loops
smoothing = args.smoothing
rare_weight = args.rare_weight
temperature = args.temperature
hard_neg_k = args.hard_neg_k
rand_neg_ratio = args.rand_neg_ratio
lr = args.lr
weight_decay = args.weight_decay

# 4. Run the main program to detect rare cell populations

In [None]:
if __name__ == "__main__":
    # torch is used directly here; import it explicitly instead of relying on
    # it leaking out of one of the ComicGTN star imports.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("You will use : ", device)

    # Number of cell nodes per subgraph. Defined once so that subgraph
    # extraction and testing are guaranteed to use the same value.
    cell_node_num = 30

    # Initial cluster assignments from the RNA counts, converted to ints; they
    # feed the rare-label computation and both training stages below.
    initial_pre = Initial_Clustering(RNA_count)
    cluster_ini_num = len(set(initial_pre))
    ini_clu = [int(i) for i in initial_pre]

    # NOTE(review): subgraphs are extracted from the scaffold-filtered peak
    # matrix (ATAC_count_re), while training and testing below receive the
    # unfiltered ATAC_count. Confirm the peak indices in total_node_idx are
    # meant to index the unfiltered matrix.
    total_node_idx, cell_node_idx, dic_cell, dic_peak = Subgraph_Extraction(
        RNA_count, ATAC_count_re, Kmer_count, Motif_count,
        neighbor_node_num = [20, 20, 5, 2], cell_node_num = cell_node_num)

    rare_labels = Calculate_Frequency(ini_clu)
    # Presumably the number of edge (relation) types in the heterogeneous
    # graph consumed by the GTN -- verify against NodeFeatureEmbedding.
    num_edge_type = 5
    node_model = NodeFeatureEmbedding(RNA_count, ATAC_count, Kmer_count, Motif_count, total_node_idx, ini_clu,
                                      rare_labels, args, device, num_edge_type, epochs = 1)
    GTN, cell_emb, gene_emb, peak_emb, kmer_emb, motif_emb = node_model.train_process(batch_num = len(total_node_idx))

    Comic_model = Comic(GTN = GTN, batch_num = len(total_node_idx), rare_labels = rare_labels,
                        args = args, device = device, epochs = 1)
    # NOTE(review): this assignment rebinds the name of the imported
    # `ComicGTN` package; any later `ComicGTN.<submodule>` access will break.
    ComicGTN = Comic_model.train_process(total_node_idx, RNA_count, ATAC_count, Kmer_count, Motif_count, ini_clu)

    Comic_result = ComicGTN_test(RNA_count, ATAC_count, Kmer_count, Motif_count, total_node_idx, cell_node_idx,
                                 cell_node_num, ComicGTN, device, co_emb = True)

# 5. Calculate evaluation metrics

In [None]:
from collections import Counter

# Ground-truth labels of the cells that were scored, and the model's
# predicted cluster for each of them.
y_true = np.array(Cell_types[0])[cell_node_idx]
y_pred = Comic_result["predicted_cluster_label"]

# Annotated cell types treated as rare for this benchmark (set for O(1) lookup).
rare_types = {"B1 B", "CD16+ Mono", "Proerythroblast", "CD8+ T naive", "cDC2", "pDC", "Transitional B",
              "MK/E prog", "Lymph prog", "G/M prog", "ID2-hi myeloid prog", "HSC", "Plasma cell"}

# Predicted clusters reported as rare. This uses the same Get_Rare_Items call
# (0.03 threshold) as the "Save inferred results" step. It must be computed
# here: rare_items was previously used before it was defined.
rare_items = Get_Rare_Items(Counter(y_pred), 0.03)
rare_pred_set = set(rare_items)

# Binarize: 1 = rare, 0 = abundant, for both truth and prediction.
y_true_binary = [1 if label in rare_types else 0 for label in y_true]
y_pred_binary = [1 if pred in rare_pred_set else 0 for pred in y_pred]


F1, G_mean, MCC, Kappa = Calculate_Metrics(y_true_binary, y_pred_binary)
print("F1 score is " + "{:.4f}".format(F1))
print("G mean is " + "{:.4f}".format(G_mean))
print("MCC is " + "{:.4f}".format(MCC))
print("Kappa is " + "{:.4f}".format(Kappa))

# 6. Save inferred results

In [None]:
from collections import Counter

outputdir = "./ComicGTN/tutorial/output"
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(outputdir, exist_ok = True)


# Predicted clusters selected by Get_Rare_Items (threshold 0.03) are relabelled
# R1, R2, ... in the order Get_Rare_Items returns them; all other clusters
# become "Abundant".
rare_items = Get_Rare_Items(Counter(y_pred), 0.03)
mapping = {num: f"R{i + 1}" for i, num in enumerate(rare_items)}
custom_labels = [mapping.get(item, "Abundant") for item in y_pred]

# Reorder the labels by cell index so the output rows follow the order of the
# original cell list rather than the order in which cells were scored.
sorted_cell_node_idx, sorted_custom_labels = zip(*sorted(zip(cell_node_idx, custom_labels)))
sorted_custom_labels = pd.DataFrame(list(sorted_custom_labels))
sorted_custom_labels.to_csv(os.path.join(outputdir, "ComicGTN_pred.tsv"), sep = "\t", header = False, index = False)

# 7. UMAP visualization of raw data

In [None]:
import matplotlib.pyplot as plt


sc.set_figure_params(scanpy = False)
plt.rcParams["font.family"] = "Arial"
# Same benchmark directory as the "Read in data" step; the previous
# developer-specific absolute path only existed on one machine.
workdir = "./ComicGTN/data/BMMC-bench-1"
RNA_seq = sc.read(os.path.join(workdir, "Gene_Cell.mtx"))
Cell_names = pd.read_csv(os.path.join(workdir, "Cell_names.tsv"), sep = "\t", header = None)
Cell_types = pd.read_csv(os.path.join(workdir, "Cell_types.tsv"), sep = "\t", header = None)
Gene_names = pd.read_csv(os.path.join(workdir, "Gene_names.tsv"), sep = "\t", header = None)
RNA_count = RNA_seq.X


# Cell-by-gene AnnData (the gene-by-cell counts are transposed). The dtype is
# set with an explicit cast because the AnnData(dtype=...) argument is deprecated.
adata = ad.AnnData(RNA_count.transpose().astype("int32"))
adata.obs_names = Cell_names[0]
adata.var_names = Gene_names[0]
adata.var_names_make_unique()
sc.pp.normalize_total(adata, target_sum = 1e4)
sc.pp.log1p(adata)
adata.obs["cell_types"] = Cell_types[0].tolist()
adata.obs["cell_types"] = adata.obs["cell_types"].astype("category")
sc.pp.neighbors(adata)
sc.tl.umap(adata)


fig, ax = plt.subplots(figsize = (10, 8))
sc.pl.umap(adata, size = 20, color = ["cell_types"], title = "", frameon = False, ax = ax, show = False)
handles, labels = ax.get_legend_handles_labels()
# Fixed legend order; cell types absent from the plot are skipped.
desired_order = ["B1 B", "CD14+ Mono", "CD16+ Mono", "CD4+ T activated", "CD4+ T naive", "CD8+ T", "cDC2",
                 "Erythroblast", "G/M prog", "HSC", "ILC", "Lymph prog", "MK/E prog", "NK", "Naive CD20+ B",
                 "Normoblast", "pDC", "Plasma cell", "Proerythroblast", "Transitional B"]
ordered_labels = [label for label in desired_order if label in labels]
ordered_handles = [handles[labels.index(label)] for label in ordered_labels]


# Replace scanpy's legend with the reordered two-column one.
ax.get_legend().remove()
ax.legend(ordered_handles, ordered_labels, loc = "center left", bbox_to_anchor = (0.95, 0.5), ncol = 2, fontsize = 18, frameon = False)

# 8. UMAP visualization of ComicGTN inferences

In [None]:
# Labels written by the "Save inferred results" step, which saves them under
# outputdir (not workdir). One row per cell is assumed, in cell order, i.e.
# cell_node_idx covered every cell exactly once -- verify if cell counts differ.
ComicGTN_pred = pd.read_csv(os.path.join(outputdir, "ComicGTN_pred.tsv"), sep = "\t", header = None)
# Explicit cast instead of the deprecated AnnData(dtype=...) argument.
adata = ad.AnnData(RNA_count.transpose().astype("int32"))
adata.obs_names = Cell_names[0]
adata.var_names = Gene_names[0]
adata.var_names_make_unique()
sc.pp.normalize_total(adata, target_sum = 1e4)
sc.pp.log1p(adata)
adata.obs["labels"] = ComicGTN_pred[0].tolist()
adata.obs["labels"] = adata.obs["labels"].astype("category")
sc.pp.neighbors(adata)
sc.tl.umap(adata)


# Build the palette from the labels actually present, so that more than ten
# rare groups (R11, R12, ...) do not leave categories without a colour.
rare_group_labels = sorted(
    (label for label in adata.obs["labels"].cat.categories if label != "Abundant"),
    key = lambda label: (len(label), label),  # natural order for R<n>: R2 before R10
)
default_colors = sc.pl.palettes.default_20 if len(rare_group_labels) <= 20 else sc.pl.palettes.default_102
custom_palette = {"Abundant": "#CCCCCC"}
for i, label in enumerate(rare_group_labels):
    custom_palette[label] = default_colors[i % len(default_colors)]

fig, ax = plt.subplots(figsize = (10, 8))
sc.pl.umap(adata, size = 20, color = ["labels"], palette = custom_palette, title = "", frameon = False, ax = ax, show = False)
handles, labels = ax.get_legend_handles_labels()
# "Abundant" first, then R1, R2, ... in numeric order.
desired_order = ["Abundant"] + rare_group_labels
ordered_labels = [label for label in desired_order if label in labels]
ordered_handles = [handles[labels.index(label)] for label in ordered_labels]


# Replace scanpy's legend with the reordered single-column one.
ax.get_legend().remove()
ax.legend(ordered_handles, ordered_labels, loc = "center left", bbox_to_anchor = (0.95, 0.5), ncol = 1, fontsize = 18, frameon = False)