In [1]:
import h5py

import numpy as np
import matplotlib.pyplot as plt
import time
import sys, os, pdb

sys.path.append('..')
from gtv import GraphTotalVariance

In [2]:
# Generate data to test the GTV. Use a block correlation structure and a small number of features
from utils import gen_covariance, gen_beta, gen_data
from sklearn.linear_model import Lasso

In [3]:
cov = gen_covariance('block', n_features = 30, block_size = 5, correlation=0.5)
beta = gen_beta(n_features = 30, block_size = 5, sparsity = 0.5)
X, X_test, y, y_test = gen_data(n_samples = 100, n_features = 30, kappa = 0.3, covariance = cov, beta = beta)

In [4]:
from sklearn.metrics import r2_score

In [5]:
def fit_gtv(lambda_S, lambda_TV, lambda_1, method):
    r2_scores = np.zeros((lambda_S.size, lambda_TV.size, lambda_1.size))
    for i1, l1 in enumerate(lambda_S):
        for i2, l2 in enumerate(lambda_TV):
            for i3, l3 in enumerate(lambda_1):
                gtv = GraphTotalVariance(l1, l2, l3, minimizer = method)
                gtv.fit(X, y, cov)
                gtv.coef_[gtv.coef_ < 1e-6] = 0
                r2_scores[i1, i2, i3] = r2_score(y_test, X_test @ gtv.coef_)
    return r2_scores

In [6]:
# Test on a 3D grid of tuning parameters
lambda_S = np.linspace(0, 1, num = 4)
lambda_TV = np.linspace(0, 1, num = 4)
lambda_1 = np.linspace(0, 1, num = 4)

%time r2_scores1 = fit_gtv(lambda_S, lambda_TV, lambda_1, 'quadprog')
%time r2_scores2 = fit_gtv(lambda_S, lambda_TV, lambda_1, 'lbfgs')

CPU times: user 55.1 s, sys: 35.4 ms, total: 55.2 s
Wall time: 9.67 s
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
torch.float64
CPU times: user 1.31 s, sys: 12 ms

In [9]:
r2_scores1

array([[[0.79381533, 0.79370027, 0.78157121, 0.76056549],
        [0.79381541, 0.7761026 , 0.70928522, 0.64732196],
        [0.79381565, 0.72562416, 0.63857565, 0.59785815],
        [0.79381605, 0.68615869, 0.61827477, 0.59785711]],

       [[0.70728587, 0.69439712, 0.67772728, 0.65730298],
        [0.70728586, 0.68034856, 0.64829772, 0.61270049],
        [0.70728585, 0.6653569 , 0.62367663, 0.59786102],
        [0.70728581, 0.64956973, 0.61827461, 0.59785711]],

       [[0.68395234, 0.67080189, 0.65399766, 0.63353965],
        [0.68395233, 0.66206177, 0.63607124, 0.60659534],
        [0.68395233, 0.65308378, 0.62138362, 0.59786249],
        [0.68395231, 0.64363264, 0.61827536, 0.59785711]],

       [[0.67391034, 0.66065151, 0.64380089, 0.62335847],
        [0.67391034, 0.65432825, 0.63091883, 0.60404876],
        [0.67391034, 0.647918  , 0.62045763, 0.59785957],
        [0.67391033, 0.64118278, 0.61827467, 0.59785711]]])

In [10]:
r2_scores2

array([[[0.79384212, 0.79809197, 0.79373128, 0.78879472],
        [0.79384212, 0.79300986, 0.77542208, 0.74646824],
        [0.79384212, 0.78604919, 0.74024282, 0.70674877],
        [0.79384212, 0.77000659, 0.71475918, 0.68162725]],

       [[0.70728328, 0.70130491, 0.69438331, 0.68651888],
        [0.70728328, 0.69672317, 0.6855888 , 0.67390289],
        [0.70728328, 0.69255181, 0.67843733, 0.66480867],
        [0.70728328, 0.68879026, 0.67274237, 0.65838122]],

       [[0.68396037, 0.67784305, 0.67081036, 0.66285796],
        [0.68396037, 0.67585013, 0.66698887, 0.65736743],
        [0.68396037, 0.67403564, 0.66380411, 0.6531758 ],
        [0.68396037, 0.67239396, 0.66114631, 0.64994261]],

       [[0.67392697, 0.66773516, 0.66065745, 0.65268689],
        [0.67392697, 0.66665082, 0.65859004, 0.64964667],
        [0.67392697, 0.66567776, 0.65674798, 0.64721201],
        [0.67392697, 0.66471858, 0.65520658, 0.64520745]]])