In [1]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from epistatic_net.wht_sampling import SPRIGHTSample
from epistatic_net.spright_utils import SPRIGHT, make_system_simple
from datasets import FourierDataset

# Test SPRIGHT algorithm

In [2]:
data_n = 200
data_k = 200
data_d = 5
data_n_samples = 100

dataset = FourierDataset(n=data_n, k=data_k, d=data_d, n_samples=data_n_samples)

In [3]:
spright_m = 8
spright_d = 3

spright_sample = SPRIGHTSample(data_n, spright_m, spright_d)
print(spright_sample.sampling_matrices[0].shape)
print(len(spright_sample.delay_matrices[0]), spright_sample.delay_matrices[0][0].shape)
print(len(spright_sample.sampling_locations[0]), spright_sample.sampling_locations[0][0].shape)

total_samples = (2**spright_m)*data_n*(spright_d*2+1)
print('total samples: {}'.format(total_samples))

print('samples/ambient dimension: {}'.format(total_samples/(2**data_n)))

(8, 200)
1201 (200,)
1201 (256, 200)
total samples: 358400
samples/ambient dimension: 2.2303286755854332e-55


In [4]:
spright = SPRIGHT('frame', [1,2,3],spright_sample)

In [5]:
X_all = np.concatenate((np.vstack(spright_sample.sampling_locations[0]),np.vstack(spright_sample.sampling_locations[1]),np.vstack(spright_sample.sampling_locations[2])))
X_all, X_all_inverse_ind = np.unique(X_all, axis=0, return_inverse='True')
y_hat_all = dataset.compute_y(torch.from_numpy(X_all).float()).numpy()
print(X_all.shape, X_all_inverse_ind.shape, y_hat_all.shape)

spright.set_train_data(X_all, y_hat_all, X_all_inverse_ind)
spright.model_to_remove = dataset.compute_y

(922366, 200) (922368,) (922366,)


In [6]:
flag = spright.initial_run()
flag

True

In [7]:
if flag:
    spright.peel_rest()

In [8]:
spright.model.support.shape

(190, 200)

In [9]:
found_amps = np.zeros(dataset.freq_f.shape[0])
found_freq_indices = [np.argmax(np.sum(dataset.freq_f.numpy() == found_freq, axis=1)) for found_freq in spright.model.support]
found_amps[found_freq_indices] = spright.model.coef_*data_k
found_amps

array([-0.80767821, -0.24989122,  0.03991397,  0.17922974,  0.45048793,
        0.37700584,  0.56281895,  0.1232947 ,  0.        ,  0.26659153,
        0.65697359, -0.30595934,  0.73987205,  0.12565336,  0.17574888,
        0.30098752, -0.47693757,  0.21849275,  0.9044698 ,  0.        ,
       -0.2861986 ,  0.58301283,  0.14891093,  0.73297334,  0.07759817,
       -0.2436006 , -0.4879963 , -0.11644165,  0.        , -0.27953109,
       -0.04768237,  0.25208081, -0.21553426, -0.45848361, -0.10546515,
        0.        ,  0.62902249,  0.89911872, -0.6763244 ,  0.15905138,
       -0.82272208, -0.56339664,  0.1504797 ,  0.47960584, -0.62257347,
       -0.01968647, -0.30866317,  0.06312548, -0.47553044, -0.55452199,
        0.38197763,  0.80498859,  0.63011212, -0.07034613,  0.43367547,
       -0.99610964,  0.55303008,  0.53904387, -0.71114334,  0.98152852,
       -0.21638109,  0.782087  , -0.58180268,  0.15220317, -0.64151367,
       -0.58545986,  0.50842534,  0.07910379,  0.94658358,  0.23

In [10]:
print("Percentage of found freqs:", len(found_freq_indices) / dataset.amp_f.shape[0])
print("R2 of found freqs and real freqs:", r2_score(dataset.amp_f[found_freq_indices], found_amps[found_freq_indices]))
print("R2 of considering all freqs:", r2_score(dataset.amp_f, found_amps))

Percentage of found freqs: 0.95
R2 of found freqs and real freqs: 0.9999880524537472
R2 of considering all freqs: 0.9827468207168037
