In [None]:
import numpy as np
import pandas as pd
import geneselection.solvers.enet_pca as epca
from geneselection.datasets.correlated_random_variables import hub_spoke_data

In [None]:
# Analysis configuration for the elastic-net PCA gene selection below; downstream cells read every value from here.
params = dict(lambda_path = np.geomspace(1, 0.01, num=100),    # lambda path: 100 log-spaced values, descending from 1 to 0.01
              alpha = 0.9,                                     # fraction of regularization devoted to the L1 penalty
              n_pcs = 2,                                       # number of PCs to predict with the multitask elastic net
              pc_weights = "scaled",                           # relative importance in predicting PCs (scaled = all selected PCs are equally important)
              n_bootstraps = 100,                              # number of bootstrap replicates
              n_processes = 25,                                # number of parallel processes to use
              thresholds = np.linspace(0.01, 1, num=100))      # selection thresholds for including genes (0.01, 0.02, ..., 1.00)

In [None]:
# Build the simulated dataset that the gene-selection pipeline is run on.
# Seed NumPy's global RNG first so Restart & Run All reproduces the same data.
# NOTE(review): this only takes effect if hub_spoke_data draws from np.random's
# global state; if it exposes its own seed/random_state argument, pass it instead.
np.random.seed(0)
adata = hub_spoke_data(n_samples=20000,
                       n_groups=50,
                       group_size=20,
                       n_singeltons=5000,       # keyword spelled as hub_spoke_data expects it; check its signature before "fixing"
                       diagonal_weight=1/np.e,
                       off_diagonal_weight=1)
# Normalize types: float64 expression matrix and a plain-string gene (var) index.
# NOTE(review): presumably required by the epca solvers downstream — confirm.
adata.X = adata.X.astype(np.float64)
adata.var.index = adata.var.index.astype(str)

In [None]:
# Run the bootstrapped elastic-net PCA fits, forwarding the solver settings
# from `params` as keyword arguments (the comprehension adds no notebook-level names).
boot_results = epca.parallel_runs(
    adata,
    **{key: params[key] for key in ("n_processes",
                                    "n_bootstraps",
                                    "n_pcs",
                                    "alpha",
                                    "lambda_path",
                                    "pc_weights")},
)

In [None]:
# Plot the bootstrap results over the grid of selection thresholds and lambdas
# defined in `params`.
epca.thresh_lambda_plot(
    boot_results,
    adata,
    lambdas=params["lambda_path"],
    thresholds=params["thresholds"],
)

In [None]:
# Plot from the bootstrap results. Note the positional order is (adata, boot_results),
# the reverse of epca.thresh_lambda_plot(boot_results, adata, ...).
# NOTE(review): judging by the name, this shows how hub genes persist across the fits — confirm.
epca.hub_persistence_plot(adata, boot_results)

In [None]:
# Choose one operating point on the (lambda, threshold) grid and collect the genes
# selected there. The indices presumably index params["lambda_path"] and
# params["thresholds"]; with the params cell as written, lambda_path[60] ≈ 0.061
# and thresholds[60] = 0.61.
LAMBDA_INDEX = 60
THRESHOLD_INDEX = 60
# Fail fast with a clear message if the params grids are ever shrunk below these indices.
assert 0 <= LAMBDA_INDEX < len(params["lambda_path"]), "LAMBDA_INDEX is outside params['lambda_path']"
assert 0 <= THRESHOLD_INDEX < len(params["thresholds"]), "THRESHOLD_INDEX is outside params['thresholds']"
# Result is used with adata.var.loc below, so presumably var-index labels — confirm.
residual_variance_genes = epca.get_selected_genes(boot_results,
                                                  adata,
                                                  lambda_index=LAMBDA_INDEX,
                                                  selection_threshold_index=THRESHOLD_INDEX,
                                                  thresholds=params["thresholds"])

In [None]:
# Display the var (per-gene) annotation rows for the genes selected above.
adata.var.loc[residual_variance_genes]