In [None]:
import random
import time

import numpy as np
import pandas as pd
import scipy.io
import scipy.sparse
from arboreto.algo import grnboost2
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

tfs_path = "/home/amorin/Data/Metadata/TFs_human.tsv"
mat_dense_path = "/space/scratch/amorin/R_objects/GSE180928_mcg_filt.tsv"
mat_sparse_path = "/space/scratch/amorin/R_objects/GSE180928_mcg_filt.mtx"

# Number of SLURM jobs to request. With processes=1, each job runs a single
# Dask worker that uses its 8 cores as threads. Increase for more parallelism.
N_JOBS = 1

# Per-job resources: 8 cores and 16 GB of memory for one Dask worker process.
cluster = SLURMCluster(cores=8,
                       processes=1,
                       memory="16GB",
                       account="amorin",
                       walltime="1:00:00",
                       queue="normal")
# A SLURMCluster is created with no workers. Without scale(), no job is
# submitted and grnboost2 would wait indefinitely for a worker to appear.
cluster.scale(jobs=N_JOBS)

custom_client = Client(cluster)
# Block until the requested workers have connected. This way the timing cells
# below measure GRNBoost2 itself, not time spent waiting in the SLURM queue.
custom_client.wait_for_workers(n_workers=N_JOBS)

In [None]:
# Load the TF symbol list and the same expression matrix in two forms:
# dense (TSV read into a DataFrame) and sparse (Matrix Market file).
tfs = pd.read_table(tfs_path)["Symbol"].tolist()
mat_dense = pd.read_table(mat_dense_path, index_col=0)
mat_sparse = scipy.io.mmread(mat_sparse_path).tocsc()  # CSC needed for arboreto

# Later cells slice both matrices with the same column indices. They also reuse
# the dense column labels as gene names for the sparse run, so the two matrices
# must have identical shapes for the comparison to be meaningful.
assert mat_sparse.shape == mat_dense.shape, (
    f"dense {mat_dense.shape} and sparse {mat_sparse.shape} matrices differ in shape"
)

In [None]:
# Random subset of columns (genes) for speed. The same column indices are
# applied to the dense and sparse matrices, so both runs see identical input.
N_SUBSET = 1000   # number of columns to keep; capped at the matrix width
SUBSET_SEED = 5

# A dedicated Random seeded with SUBSET_SEED draws the same sample as
# random.seed(SUBSET_SEED) followed by random.sample, without touching the
# global RNG state.
subset_rng = random.Random(SUBSET_SEED)
n_cols = mat_dense.shape[1]
samp_ix = subset_rng.sample(range(n_cols), min(N_SUBSET, n_cols))

mat_dense_sub = mat_dense.iloc[:, samp_ix]
mat_sparse_sub = mat_sparse[:, samp_ix]
genes_sub = mat_dense_sub.columns.tolist()

# Use a sorted list rather than a set, so the TF argument has a stable,
# reproducible order across kernel restarts.
tfs_sub = sorted(set(tfs).intersection(genes_sub))
assert tfs_sub, "no TFs among the sampled genes; resample or enlarge N_SUBSET"

In [None]:
# Time GRNBoost2 on the dense pandas input. Gene names are taken from the
# DataFrame columns, so gene_names is not passed.
# perf_counter is a monotonic, high-resolution clock. time.time can jump if
# the system clock is adjusted, which would corrupt the measured interval.
start = time.perf_counter()

network_dense = grnboost2(expression_data=mat_dense_sub,
                          tf_names=tfs_sub,
                          seed=4,
                          client_or_address=custom_client)

end = time.perf_counter()

print(f"dense GRNBoost2 runtime: {end - start:.1f} s")

In [None]:
# Time GRNBoost2 on the sparse CSC input. A bare matrix carries no labels,
# so the gene names are passed explicitly (in the same column order).
# perf_counter is a monotonic, high-resolution clock suited to elapsed timing.
start = time.perf_counter()

network_sparse = grnboost2(expression_data=mat_sparse_sub,
                           tf_names=tfs_sub,
                           gene_names=genes_sub,
                           seed=4,
                           client_or_address=custom_client)

end = time.perf_counter()

print(f"sparse GRNBoost2 runtime: {end - start:.1f} s")

In [None]:
# Tear down in dependency order: disconnect the client first, then shut down
# the cluster that owns the SLURM worker jobs.
for dask_resource in (custom_client, cluster):
    dask_resource.close()