#### GRNBoost2 network construction for all datasets

In [None]:
import numpy as np
import pandas as pd

def get_logExpr3(expr_npz):
    """Load a count matrix and return two log-transformed CPM-style matrices.

    The ``'count'`` array (rows are cells, columns are genes) is read from
    the ``.npz`` archive, each row is scaled so it sums to 1e6, and the
    result is cast to float32 before the log transforms.

    Args:
        expr_npz: Path to an ``.npz`` archive containing a ``'count'`` array.

    Returns:
        A tuple ``(logExpr0, logExpr1)`` where ``logExpr0`` is
        ``log1p(normalized)`` and ``logExpr1`` is
        ``log1p(normalized + 1e-5)``.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code when the
    # archive comes from an untrusted source; only load trusted files.
    # The context manager closes the underlying file handle once the array
    # has been read into memory.
    with np.load(expr_npz, allow_pickle=True) as data:
        countExpr = data['count']  # count: row-cell, column-gene
    print("raw (cells, genes): ", countExpr.shape)

    row_sums = countExpr.sum(axis=1, keepdims=True)
    # A cell with no counts would otherwise yield 0/0 = NaN for its whole
    # row; dividing by 1 instead leaves such rows at exactly 0.
    safe_row_sums = np.where(row_sums == 0, 1, row_sums)
    normalized_data = 1e6 * countExpr / safe_row_sums
    normalized_data = normalized_data.astype(np.float32)

    logExpr0 = np.log1p(normalized_data)
    logExpr1 = np.log1p(normalized_data + 1e-5)
    return logExpr0, logExpr1

In [None]:
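# For each dataset: load the per-fold filtered HVG indices, subset the training cells of each fold, and infer a GRNBoost2 network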
import os
import numpy as np
from arboreto.algo import grnboost2

# Datasets to process; each is expected to have a folder under
# dataset/5fold_data/ and an expression archive under
# ../../data/pre_data/scRNAseq_datasets4/.
scRNA_datasets = ['Muraro', 'Baron_Mouse', 'Segerstolpe', 'Baron_Human', 'Zhang_T', 'Kang_ctrl', 'AMB', 'TM', 'Zheng68K']
pathjoin = os.path.join
N_FOLDS = 5  # number of cross-validation folds stored per dataset

for base_filename in scRNA_datasets:
    print(base_filename)
    seq_folder = f"dataset/5fold_data/{base_filename}"

    # Read everything needed from the split archive up front so the file
    # handle can be closed before the long-running inference starts.
    # NOTE(review): allow_pickle=True is unsafe on untrusted archives.
    with np.load(pathjoin(seq_folder, 'seq_dict.npz'), allow_pickle=True) as seq_dict:
        gene_symbol = seq_dict['gene_symbol']
        train_indices = [seq_dict[f'train_index_{fold}'] for fold in range(1, N_FOLDS + 1)]
    print(gene_symbol)

    # One entry of filtered gene indices per fold (indexed 0..N_FOLDS-1 below).
    all_filtered_genes_file = pathjoin(seq_folder, f'{base_filename}_filtered_hvgs2000.npy')
    all_filtered_genes_array = np.load(all_filtered_genes_file, allow_pickle=True)

    expr_npz = f"../../data/pre_data/scRNAseq_datasets4/{base_filename}.npz"
    logExpr0, _ = get_logExpr3(expr_npz)

    # The output directory does not depend on the fold, so create it once.
    output_dir = pathjoin(seq_folder, 'grnboost2')
    os.makedirs(output_dir, exist_ok=True)

    for k_fold, train_index in enumerate(train_indices, start=1):
        print("train k_fold: ", k_fold)
        filtered_genes_index = all_filtered_genes_array[k_fold - 1].astype(int)
        # np.ix_ selects the training cells (rows) x filtered genes (columns).
        logExpr0_train = logExpr0[np.ix_(train_index, filtered_genes_index)]

        gene_names = gene_symbol[filtered_genes_index].tolist()
        # NOTE(review): no seed is passed, so results may differ between
        # runs; consider grnboost2(..., seed=...) if reproducibility matters.
        network = grnboost2(expression_data=logExpr0_train,
                            gene_names=gene_names)

        # Written without header or index: one tab-separated row per edge.
        network.to_csv(pathjoin(output_dir, f'grnboost2_f{k_fold}.tsv'), sep='\t', header=False, index=False)
        print(len(network))
        print(f"grnboost2_f{k_fold}.tsv is saved!")