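# Comparison with `cr.sparse`

This notebook compares skscope's OMP, IHT, HTP and GraSP solvers with the corresponding `cr.sparse` pursuit algorithms (GraSP is paired with `cr.sparse`'s CoSaMP). The comparison uses synthetic sparse linear regression problems and reports support-recovery accuracy and wall-clock time.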

In [2]:
import numpy as np
import pandas as pd
import time
import jax.numpy as jnp
from skscope.solver import *
import cr.sparse.dict as crdict
from cr.sparse.pursuit import iht, omp, htp, cosamp
from abess.datasets import make_glm_data

In [23]:
def test_time(n=500, p=1000, s=5, random_state=None):
    """Benchmark skscope sparse solvers against their cr-sparse counterparts.

    A Gaussian sparse linear-regression problem is generated with
    ``make_glm_data`` whose true coefficient vector has ``s`` non-zero
    entries (magnitudes 1-3, random signs). Each algorithm is run once per
    package, recording support-recovery accuracy and wall-clock time.

    Parameters
    ----------
    n : int
        Number of samples.
    p : int
        Number of features.
    s : int
        Number of non-zero true coefficients; also the sparsity level given
        to every solver.
    random_state : int or None
        Seed for ``np.random.default_rng``. It only controls which features
        form the true support and their coefficient values; it is not
        forwarded to ``make_glm_data``.

    Returns
    -------
    pandas.DataFrame
        Indexed by (Algorithm, Package) with columns ``Accuracy`` (fraction
        of the true support recovered) and ``Time`` (seconds, rounded to 4
        decimals). The table is also shown via ``display``.
    """
    print('='*20 + f'  n={n}, p={p}, s={s}  ' + '='*20 )
    rng = np.random.default_rng(random_state)
    true_support_set = rng.choice(np.arange(p), size=s, replace=False)
    real_coef = np.zeros(p)
    real_coef[true_support_set] = rng.choice(np.arange(1, 4), size=s) * rng.choice([1, -1], size=s)
    data = make_glm_data(n=n, p=p, k=s, family='gaussian', coef_=real_coef)
    X, y = data.x, data.y

    # One reference support, taken from the coefficients that generated y,
    # so both packages are scored against exactly the same set.
    true_support = set(np.flatnonzero(data.coef_).tolist())

    # skscope solver class -> cr-sparse module it is compared against.
    # 'Grasp' is paired with cr-sparse's CoSaMP implementation.
    algorithms = {
        'OMP': (OMPSolver, omp),
        'IHT': (IHTSolver, iht),
        'HTP': (HTPSolver, htp),
        'Grasp': (GraspSolver, cosamp),
    }
    iterables = [list(algorithms), ['cr-sparse', 'skscope']]
    index = pd.MultiIndex.from_product(iterables, names=['Algorithm', 'Package'])
    res = pd.DataFrame(columns=['Accuracy', 'Time'], index=index)

    # Convert to JAX arrays once, outside the timed region, so the
    # conversion cost is not charged to cr-sparse on every algorithm.
    X_jax, y_jax = jnp.asarray(X), jnp.asarray(y)

    def objective(params):
        loss = jnp.mean((y - X @ params) ** 2)
        return loss

    for algo, (solver_cls, model) in algorithms.items():
        solver = solver_cls(p, sparsity=s)

        # cr-sparse. JAX dispatches work asynchronously, so the selected
        # support is materialized on the host *inside* the timed region;
        # otherwise the timer could stop before the computation finishes.
        t_begin = time.perf_counter()
        solution = model.matrix_solve(X_jax, y_jax, s)
        support_cr = np.asarray(solution.I)
        t_cr = time.perf_counter() - t_begin
        acc_cr = len(set(support_cr.tolist()) & true_support) / s
        res.loc[(algo, 'cr-sparse')] = [acc_cr, np.round(t_cr, 4)]

        # skscope (same host materialization for a like-for-like timing).
        t_begin = time.perf_counter()
        params = np.asarray(solver.solve(objective, jit=True))
        t_skscope = time.perf_counter() - t_begin
        acc_skscope = len(set(np.flatnonzero(params).tolist()) & true_support) / s
        res.loc[(algo, 'skscope')] = [acc_skscope, np.round(t_skscope, 4)]

    display(res)
    return res

In [24]:
# Problem sizes as (n samples, p features, s non-zero coefficients),
# run from smallest to largest.
settings = [
    (500, 1000, 5),
    (2000, 5000, 10),
    (5000, 10000, 10),
]
for n, p, s in settings:
    test_time(n=n, p=p, s=s)



====================  n=500, p=1000, s=5  ====================
Algorithm,Package,Accuracy,Time
OMP,cr-sparse,0.2,0.0113
OMP,skscope,1.0,0.3887
IHT,cr-sparse,0.4,1.3391
IHT,skscope,1.0,0.3695
HTP,cr-sparse,0.8,1.1967
HTP,skscope,1.0,0.4277
Grasp,cr-sparse,1.0,1.4531
Grasp,skscope,1.0,0.3773




====================  n=2000, p=5000, s=10  ====================
Algorithm,Package,Accuracy,Time
OMP,cr-sparse,0.1,17.8637
OMP,skscope,1.0,1.9605
IHT,cr-sparse,0.6,3.517
IHT,skscope,1.0,1.5742
HTP,cr-sparse,0.7,53.1873
HTP,skscope,1.0,1.5623
Grasp,cr-sparse,1.0,58.4635
Grasp,skscope,1.0,1.6789




====================  n=5000, p=10000, s=10  ====================
Algorithm,Package,Accuracy,Time
OMP,cr-sparse,0.1,9.7462
OMP,skscope,1.0,6.7015
IHT,cr-sparse,0.6,11.4237
IHT,skscope,1.0,6.1202
HTP,cr-sparse,0.7,346.9801
HTP,skscope,1.0,8.7026
Grasp,cr-sparse,1.0,635.7011
Grasp,skscope,1.0,8.6967
