In [1]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from epistatic_net.wht_sampling import SPRIGHTSample
from epistatic_net.spright_utils import SPRIGHT, make_system_simple
from datasets import FourierDataset

# Test SPRIGHT algorithm

In [2]:
data_n = 200
data_k = 200
data_d = 5
data_n_samples = 100

dataset = FourierDataset(n=data_n, k=data_k, d=data_d, n_samples=data_n_samples)

In [3]:
spright_m = 8
spright_d = 3

spright_sample = SPRIGHTSample(data_n, spright_m, spright_d)
print(spright_sample.sampling_matrices[0].shape)
print(len(spright_sample.delay_matrices[0]), spright_sample.delay_matrices[0][0].shape)
print(len(spright_sample.sampling_locations[0]), spright_sample.sampling_locations[0][0].shape)

total_samples = (2**spright_m)*data_n*(spright_d*2+1)
print('total samples: {}'.format(total_samples))

print('samples/ambient dimension: {}'.format(total_samples/(2**data_n)))

Loading SPRIGHT samples from cache ...
(8, 200)
1201 (200,)
1201 (256, 200)
total samples: 358400
samples/ambient dimension: 2.2303286755854332e-55


In [4]:
spright = SPRIGHT('frame', [1,2,3],spright_sample)

In [5]:
X_all = np.concatenate((np.vstack(spright_sample.sampling_locations[0]),np.vstack(spright_sample.sampling_locations[1]),np.vstack(spright_sample.sampling_locations[2])))
X_all, X_all_inverse_ind = np.unique(X_all, axis=0, return_inverse='True')
y_hat_all = dataset.compute_y(torch.from_numpy(X_all).float()).numpy()
print(X_all.shape, X_all_inverse_ind.shape, y_hat_all.shape)

spright.set_train_data(X_all, y_hat_all, X_all_inverse_ind)
spright.model_to_remove = dataset.compute_y

(922366, 200) (922368,) (922366,)


In [6]:
flag = spright.initial_run()
flag

True

In [7]:
if flag:
    spright.peel_rest()

In [8]:
spright.model.support.shape

(190, 200)

In [9]:
found_amps = np.zeros(dataset.freq_f.shape[0])
found_freq_indices = [np.argmax(np.sum(dataset.freq_f.numpy() == found_freq, axis=1)) for found_freq in spright.model.support]
found_amps[found_freq_indices] = spright.model.coef_*data_k
found_amps

array([-0.80648135, -0.249206  ,  0.03825224,  0.17780295,  0.4494334 ,
        0.37708214,  0.56539142,  0.12405929,  0.        ,  0.2658774 ,
        0.6566655 , -0.30016956,  0.74267446,  0.12622902,  0.17125008,
        0.29983136, -0.47769909,  0.21828967,  0.90598585,  0.        ,
       -0.285008  ,  0.58495877,  0.15245997,  0.73243782,  0.07820443,
       -0.24369992, -0.48432907, -0.11483389,  0.        , -0.28386448,
       -0.04534004,  0.25199746, -0.21578812, -0.45663654, -0.10470216,
        0.        ,  0.62818223,  0.89647289, -0.67458268,  0.1586621 ,
       -0.82370262, -0.56226251,  0.14986253,  0.47955813, -0.62375613,
       -0.02146116, -0.31199437,  0.06393513, -0.47464132, -0.55210664,
        0.37918096,  0.80271826,  0.62862604, -0.07081426,  0.43335832,
       -0.99616234,  0.55477262,  0.54046153, -0.71171709,  0.98413329,
       -0.2211212 ,  0.7849796 , -0.58141274,  0.15581471, -0.64116682,
       -0.58535729,  0.50974535,  0.08403472,  0.9448627 ,  0.23

In [10]:
print("Percentage of found freqs:", len(found_freq_indices) / dataset.amp_f.shape[0])
print("R2 of found freqs and real freqs:", r2_score(dataset.amp_f[found_freq_indices], found_amps[found_freq_indices]))
print("R2 of considering all freqs:", r2_score(dataset.amp_f, found_amps))

Percentage of found freqs: 0.95
R2 of found freqs and real freqs: 0.9999874843629595
R2 of considering all freqs: 0.9827462624049894


In [19]:
spright.model.coef_

array([-4.39585879e-03, -3.60553771e-03, -2.92678646e-03, -2.85706718e-03,
       -2.26700189e-04,  4.15479628e-03, -1.19830613e-03,  3.00852974e-03,
        3.76573050e-03, -1.09526733e-03, -3.20583412e-03,  3.95050490e-03,
       -4.83570430e-03, -2.37430153e-03,  7.62299828e-04,  3.14313021e-03,
       -3.37291340e-03, -2.70638773e-03,  3.49386946e-03,  3.71337232e-03,
        1.89590480e-03,  4.70489450e-03, -5.74169434e-04, -7.03069207e-04,
        6.09998500e-04,  3.15276283e-03,  7.15809094e-04, -3.54945636e-03,
       -1.41932238e-03, -4.96060953e-03, -1.07305820e-04,  1.05997020e-03,
       -1.38188997e-03, -2.67028185e-03, -4.11851310e-03, -4.98081169e-03,
       -2.49942848e-03,  1.88541071e-03,  3.83295203e-03,  1.62383968e-04,
        4.36354630e-03,  7.79073547e-04,  3.78804674e-03,  3.92489802e-03,
        2.54872677e-03,  3.91022173e-04,  8.89014755e-04, -4.66572628e-03,
       -2.36697343e-03, -3.81708976e-03, -1.76694305e-03,  8.56250423e-04,
        4.92066646e-03,  