In [1]:
import h5py

import numpy as np
import matplotlib.pyplot as plt
import time
import sys, os, pdb

sys.path.append('..')
from gtv import GraphTotalVariance

In [18]:
from utils import selection_accuracy

In [2]:
# Generate data to test the GTV. Use a block correlation structure and a small number of features
from utils import gen_covariance, gen_beta, gen_data
from sklearn.linear_model import Lasso

In [3]:
cov = gen_covariance('block', n_features = 30, block_size = 5, correlation=0.5)
beta = gen_beta(n_features = 30, block_size = 5, sparsity = 0.5)
X, X_test, y, y_test = gen_data(n_samples = 100, n_features = 30, kappa = 0.3, covariance = cov, beta = beta)

In [4]:
from sklearn.metrics import r2_score

In [34]:
def fit_gtv(X, X_test, y, y_test, beta, lambda_S, lambda_TV, lambda_1, method):
    r2_scores = np.zeros((lambda_S.size, lambda_TV.size, lambda_1.size))
    sa = np.zeros(r2_scores.shape)
    for i1, l1 in enumerate(lambda_S):
        for i2, l2 in enumerate(lambda_TV):
            for i3, l3 in enumerate(lambda_1):
                gtv = GraphTotalVariance(l1, l2, l3, minimizer = method, use_skeleton = True)
                gtv.fit(X, y, cov)
                gtv.coef_[gtv.coef_ < 1e-6] = 0
                r2_scores[i1, i2, i3] = r2_score(y_test, X_test @ gtv.coef_)
                sa[i1, i2, i3] = selection_accuracy(beta.T, gtv.coef_[np.newaxis, :])[0]
    return r2_scores, sa

In [None]:
# How do things scale as we increase number of features?
n_features = [30, 50, 100, 200, 500]
block_size = [5, 10, 20, 40, 100, 200]
lambda_1 = np.linspace(0, 1, 10)
lambda_TV = np.linspace(0, 1, 10)
lambda_S = np.linspace(0, 1, 10)

for i, nf in enumerate(n_features):
    cov = gen_covariance('falloff', n_features = nf, L = nf/10)
    beta = gen_beta(n_features = nf, block_size = nf, sparsity = 0.5)
    X, X_test, y, y_test = gen_data(n_samples = 3 * nf, n_features = nf, kappa = 0.3, covariance = cov, beta = beta)

    %time r2_scores2, sa2 = fit_gtv(X, X_test, y, y_test, beta, lambda_S, lambda_TV, lambda_1, 'lbfgs')    

Wall time: 14.9 s
Wall time: 35.6 s
Wall time: 2min 18s
Wall time: 12min 31s


In [8]:
r2_scores

array([[[0.6119165 , 0.63560495, 0.65054672, 0.66132972, 0.66852789,
         0.6731042 , 0.67579184, 0.67638185, 0.67582676, 0.67533533],
        [0.61188252, 0.63610683, 0.65097509, 0.66164501, 0.66878169,
         0.67329823, 0.67582268, 0.67627267, 0.67568932, 0.67526051],
        [0.61180144, 0.63641011, 0.65197869, 0.6624272 , 0.66905829,
         0.67334169, 0.6756171 , 0.67571427, 0.67557446, 0.67513094],
        [0.61159988, 0.63749551, 0.65162684, 0.66206784, 0.66838966,
         0.67244544, 0.67471709, 0.67461856, 0.67397323, 0.67278514],
        [0.61171414, 0.63747747, 0.65184426, 0.66067427, 0.6672928 ,
         0.67145584, 0.6719144 , 0.67056843, 0.66873794, 0.66765159],
        [0.61182052, 0.63752578, 0.65139901, 0.65939157, 0.66557156,
         0.66709342, 0.6660093 , 0.6633515 , 0.66178439, 0.6608888 ],
        [0.61142748, 0.63914468, 0.65038764, 0.6583246 , 0.66093359,
         0.66119406, 0.65823604, 0.65673969, 0.65603476, 0.6560384 ],
        [0.61110023, 0.6405

In [11]:
r2_scores2

array([[[0.6488013 , 0.65694438, 0.660529  , 0.66250234, 0.66193242,
         0.66220785, 0.66160868, 0.66013287, 0.65791625, 0.65537217],
        [0.64889028, 0.6568326 , 0.66069393, 0.66271391, 0.6623629 ,
         0.66246059, 0.66204015, 0.66108369, 0.65856351, 0.65586059],
        [0.6489324 , 0.65658266, 0.66164847, 0.66320914, 0.66270046,
         0.66152098, 0.66072701, 0.66072251, 0.65829796, 0.65469819],
        [0.64895177, 0.65753195, 0.66205149, 0.6626993 , 0.65895811,
         0.65554254, 0.6551628 , 0.65588556, 0.654733  , 0.65074999],
        [0.64843606, 0.65581445, 0.66010088, 0.65855957, 0.65190859,
         0.64853395, 0.64912423, 0.64937803, 0.64773722, 0.64377792],
        [0.6487171 , 0.6530236 , 0.65594273, 0.65262637, 0.64361866,
         0.64120805, 0.64175033, 0.64078753, 0.63681539, 0.63307935],
        [0.64861756, 0.65002543, 0.65040568, 0.64459313, 0.63509882,
         0.63284286, 0.63193778, 0.62824952, 0.62327765, 0.61967781],
        [0.6484713 , 0.6478

In [15]:
r2_scores3

array([[[0.53229873, 0.58492935, 0.62162581, 0.64376762, 0.65737806,
         0.66342733, 0.66647425, 0.66959034, 0.67161827, 0.67248226],
        [0.53234366, 0.58920941, 0.62451108, 0.64612285, 0.65953942,
         0.66614727, 0.66835559, 0.67113324, 0.67283415, 0.6734984 ],
        [0.53245358, 0.59924051, 0.63158177, 0.64968946, 0.66282007,
         0.66602269, 0.66800156, 0.66928503, 0.6696824 , 0.66750066],
        [0.53191845, 0.60563083, 0.63441827, 0.65050771, 0.65642299,
         0.65608301, 0.65282235, 0.65051631, 0.64738612, 0.64510336],
        [0.53232085, 0.61258457, 0.63085241, 0.64115902, 0.64113715,
         0.63721104, 0.62904734, 0.62510729, 0.62369547, 0.62190078],
        [0.53188468, 0.62112945, 0.62671876, 0.63020666, 0.62529982,
         0.61232886, 0.59904832, 0.59842622, 0.59627534, 0.59375427],
        [0.53285136, 0.62701601, 0.61633654, 0.61506641, 0.59983183,
         0.57644347, 0.57305442, 0.57175471, 0.57041456, 0.56719749],
        [0.5360483 , 0.6238

In [16]:
np.amax(r2_scores2)

0.6632091445616742

In [17]:
np.amax(r2_scores3)

0.6734983992447747

In [29]:
np.amax(r2_scores3)

0.6761218399909579

In [36]:
np.argmax(r2_scores2)

142

In [37]:
sa2.ravel()[142]

0.7833333333333333