In [3]:
import pyglstudy as gl
import numpy as np

In [49]:
def generate_group_lasso_data(
    n,
    p,
    n_groups,
    rho=0.1,
):
    Sigma = np.full((p, p), rho)
    np.fill_diagonal(Sigma, 1)
    X = np.random.normal(size=(n, p)) @ Sigma
    beta = np.random.normal(size=(p,))
    y = X @ beta + np.random.normal(size=(n,))

    A = (X.T @ X) / n
    r = (X.T @ y) / n

    order = np.arange(1, p) 
    groups = np.sort(np.random.choice(
        order, (n_groups-1,), replace=False,
    ))
    groups = np.concatenate([[0], groups, [p]], dtype=np.int32)
    group_sizes = groups[1:(n_groups+1)] - groups[:n_groups]
    groups = groups[:n_groups]
    
    return {
        "X": X,
        "beta": beta,
        "y": y,
        "A": A,
        "r": r,
        "groups": groups,
        "group_sizes": group_sizes,
    }

In [53]:
n = 100
p = 10
n_groups = 5
alpha = 1.0
penalty = np.ones(p, dtype=float)

X, beta, y, A, r, groups, group_sizes = generate_group_lasso_data(n, p, n_groups).values()
A = np.reshape(A, A.shape, order='F')
strong_set = np.arange(0, n_groups, dtype=np.int32)
strong_g1 = np.array([i for i in range(len(strong_set)) if group_sizes[strong_set[i]] == 1], dtype=np.int32)
strong_g2 = np.array([i for i in range(len(strong_set)) if group_sizes[strong_set[i]] > 1], dtype=np.int32)
assert((len(strong_g1) + len(strong_g2)) == n_groups)
strong_begins = np.cumsum(np.concatenate(
    [[0], np.array([group_sizes[i] for i in range(len(strong_set))], dtype=np.int32)],
), dtype=np.int32)[:-1]
strong_A_diag = np.concatenate(
    [
        np.diag(A)[groups[i] : (groups[i] + group_sizes[i])]
        for i in strong_set
    ]
)
lmdas = np.array([1, 0.5, 0.1, 0.01, 0.001])
max_cds = np.int32(1e5)
thr = 1e-7
newton_tol = 1e-8
newton_max_iters = np.int32(100)
rsq = 0.0
strong_beta = np.zeros((p,), dtype=float)
strong_grad = np.copy(r)
active_set = np.empty((0,), dtype=np.int32)
active_g1 = np.empty((0,), dtype=np.int32)
active_g2 = np.empty((0,), dtype=np.int32)
active_begins = np.empty((0,), dtype=np.int32)
active_order = np.empty((0,), dtype=np.int32)
is_active = np.zeros((p,), dtype=bool)

gl.group_lasso(
    A, 
    groups,
    group_sizes,
    alpha,
    penalty,
    strong_set,
    strong_g1,
    strong_g2,
    strong_begins,
    strong_A_diag,
    lmdas,
    max_cds,
    thr,
    newton_tol,
    newton_max_iters,
    rsq,
    strong_beta,
    strong_grad,
    active_set,
    active_g1,
    active_g2,
    active_begins,
    active_order,
    is_active,
)

TypeError: group_lasso(): incompatible function arguments. The following argument types are supported:
    1. (arg0: numpy.ndarray[numpy.float64[m, n], flags.writeable, flags.f_contiguous], arg1: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg2: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg3: float, arg4: numpy.ndarray[numpy.float64[m, 1], flags.writeable], arg5: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg6: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg7: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg8: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg9: numpy.ndarray[numpy.float64[m, 1], flags.writeable], arg10: numpy.ndarray[numpy.float64[m, 1], flags.writeable], arg11: int, arg12: float, arg13: float, arg14: int, arg15: float, arg16: numpy.ndarray[numpy.float64[m, 1], flags.writeable], arg17: numpy.ndarray[numpy.float64[m, 1], flags.writeable], arg18: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg19: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg20: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg21: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg22: numpy.ndarray[numpy.int32[m, 1], flags.writeable], arg23: numpy.ndarray[bool[m, 1], flags.writeable]) -> dict

Invoked with: array([[1.19167284, 0.36299895, 0.23316748, 0.30828545, 0.43217507,
        0.50207576, 0.31465945, 0.35067446, 0.23765489, 0.24866225],
       [0.36299895, 1.10586134, 0.43981417, 0.4187271 , 0.14655645,
        0.24375045, 0.2855293 , 0.19285799, 0.2632533 , 0.29834778],
       [0.23316748, 0.43981417, 1.14111544, 0.39230663, 0.12996303,
        0.3601077 , 0.30191575, 0.30587335, 0.24102518, 0.27156045],
       [0.30828545, 0.4187271 , 0.39230663, 1.09273656, 0.25340065,
        0.40031797, 0.41961632, 0.25316936, 0.3080634 , 0.29516623],
       [0.43217507, 0.14655645, 0.12996303, 0.25340065, 1.04263233,
        0.35591924, 0.28106039, 0.44377382, 0.28832453, 0.11609087],
       [0.50207576, 0.24375045, 0.3601077 , 0.40031797, 0.35591924,
        1.16569068, 0.39009442, 0.27879836, 0.32735786, 0.4006966 ],
       [0.31465945, 0.2855293 , 0.30191575, 0.41961632, 0.28106039,
        0.39009442, 1.17660284, 0.27155064, 0.3157164 , 0.24783469],
       [0.35067446, 0.19285799, 0.30587335, 0.25316936, 0.44377382,
        0.27879836, 0.27155064, 1.0275587 , 0.21735846, 0.07699091],
       [0.23765489, 0.2632533 , 0.24102518, 0.3080634 , 0.28832453,
        0.32735786, 0.3157164 , 0.21735846, 0.85001565, 0.1212861 ],
       [0.24866225, 0.29834778, 0.27156045, 0.29516623, 0.11609087,
        0.4006966 , 0.24783469, 0.07699091, 0.1212861 , 0.88568011]]), array([0, 3, 4, 7, 8], dtype=int32), array([3, 1, 3, 1, 2], dtype=int32), 1.0, array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), array([0, 1, 2, 3, 4], dtype=int32), array([1, 3], dtype=int32), array([0, 2, 4], dtype=int32), array([0, 3, 4, 7, 8], dtype=int32), array([1.19167284, 1.10586134, 1.14111544, 1.09273656, 1.04263233,
       1.16569068, 1.17660284, 1.0275587 , 0.85001565, 0.88568011]), array([1.   , 0.5  , 0.1  , 0.01 , 0.001]), 100000, 1e-07, 1e-08, 100, 0.0, array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), array([ 1.69190184, -0.11410927, -0.69422144, -0.54206909,  0.81187762,
        0.07250747,  0.97257487,  0.47304594, -0.02786722, -0.42553432]), array([], dtype=int32), array([], dtype=int32), array([], dtype=int32), array([], dtype=int32), array([], dtype=int32), array([False, False, False, False, False, False, False, False, False,
       False])