In [2]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from epistatic_net.wht_sampling import SPRIGHTSample
from epistatic_net.spright_utils import SPRIGHT, make_system_simple
from datasets import FourierDataset

# Test SPRIGHT algorithm

In [3]:
spright_m = 6

data_n = 25
data_k = 2**spright_m
data_d = 5
data_n_samples = 100

dataset = FourierDataset(n=data_n, k=data_k, d=data_d, n_samples=data_n_samples)

In [4]:
spright_d = 7

spright_sample = SPRIGHTSample(data_n, spright_m, spright_d)
print(spright_sample.sampling_matrices[0].shape)
print(len(spright_sample.delay_matrices[0]), spright_sample.delay_matrices[0][0].shape)
print(len(spright_sample.sampling_locations[0]), spright_sample.sampling_locations[0][0].shape)

total_samples = (2**spright_m)*data_n*(spright_d*2+1)
print('total samples: {}'.format(total_samples))

print('samples/ambient dimension: {}'.format(total_samples/(2**data_n)))

- Loded cache/N25-m6-d7-seed0/sampling-matrix-2.p from cache (sampling matrix)
- Loded cache/N25-m6-d7-seed0/sampling-matrix-1.p from cache (sampling matrix)
- Loded cache/N25-m6-d7-seed0/sampling-matrix-0.p from cache (sampling matrix)
- Loded cache/N25-m6-d7-seed0/delays-0.p from cache (delay matrix)
- Loded cache/N25-m6-d7-seed0/delays-1.p from cache (delay matrix)
- Loded cache/N25-m6-d7-seed0/delays-2.p from cache (delay matrix)
- Loded cache/N25-m6-d7-seed0/sampling-locations-2.p from cache (samling location)
- Loded cache/N25-m6-d7-seed0/sampling-locations-1.p from cache (samling location)
- Loded cache/N25-m6-d7-seed0/sampling-locations-0.p from cache (samling location)
(6, 25)
351 (25,)
351 (64, 25)
total samples: 24000
samples/ambient dimension: 0.0007152557373046875


In [5]:
spright = SPRIGHT('frame', [1,2,3],spright_sample)

In [6]:
X_all = np.concatenate((np.vstack(spright_sample.sampling_locations[0]),np.vstack(spright_sample.sampling_locations[1]),np.vstack(spright_sample.sampling_locations[2])))
X_all, X_all_inverse_ind = np.unique(X_all, axis=0, return_inverse='True')
y_hat_all = dataset.compute_y(torch.from_numpy(X_all).float()).numpy()
print(X_all.shape, X_all_inverse_ind.shape, y_hat_all.shape)

spright.set_train_data(X_all, y_hat_all, X_all_inverse_ind)
spright.model_to_remove = dataset.compute_y

(67270, 25) (67392,) (67270,)


In [7]:
flag = spright.initial_run()
flag

True

In [8]:
if flag:
    spright.peel_rest()

In [9]:
spright.model.support.shape

(56, 25)

In [10]:
found_amps = np.zeros(dataset.freq_f.shape[0])
found_freq_indices = [np.argmax(np.sum(dataset.freq_f.numpy() == found_freq, axis=1)) for found_freq in spright.model.support]
found_amps[found_freq_indices] = spright.model.coef_*data_k
found_amps

array([ 0.92097609, -0.35252892,  0.8016816 ,  0.18910937, -0.93928587,
        0.32322738,  0.45761317, -0.02755043, -0.16505679, -0.55105776,
       -0.77345865, -0.68703681,  0.71299566,  0.88221974,  0.04493178,
        0.44608153, -0.50255522,  0.95583203,  0.        , -0.58544415,
       -0.54935235,  1.14704513, -0.11968114, -0.77372629,  0.        ,
        0.        ,  0.97405563, -0.71225116, -0.54562178,  0.        ,
       -0.02304042,  0.14881217,  0.27535573, -0.73424716,  0.23681441,
       -0.53149324,  0.98077892, -0.42322467, -0.65962634, -0.61327475,
       -0.01105803, -0.70831273, -0.36730334,  0.        ,  0.80031969,
        0.        ,  0.21025709, -0.05268596,  0.96001737, -0.08517527,
        0.86512474, -0.6444383 ,  0.28956651, -0.60469727, -0.81127916,
       -0.83717781,  0.72100954, -0.16763071,  0.        ,  0.        ,
        0.48552131,  0.22305477,  0.18692622, -0.06070981])

In [11]:
print("Percentage of found freqs:", len(found_freq_indices) / dataset.amp_f.shape[0])
print("R2 of found freqs and real freqs:", r2_score(dataset.amp_f[found_freq_indices], found_amps[found_freq_indices]))
print("R2 of considering all freqs:", r2_score(dataset.amp_f, found_amps))

Percentage of found freqs: 0.875
R2 of found freqs and real freqs: 0.9700787962435677
R2 of considering all freqs: 0.9172287651781538
