In [1]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from epistatic_net.wht_sampling import SPRIGHTSample
from epistatic_net.spright_utils import SPRIGHT, make_system_simple
from utils import FourierDataset

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# SPRIGHT settings. Both are handed to SPRIGHTSample further down together with
# data_n; the cache directory name in the recorded output ("N25-m5-d3") mirrors
# these values.
spright_m, spright_d = 5, 3

# Synthetic target function settings, forwarded to FourierDataset as n, d and
# n_samples. Their exact semantics live in utils.FourierDataset -- presumably
# n is the number of input bits (the sampled points have data_n columns);
# TODO confirm the meaning of d and n_samples against that class.
data_n, data_d, data_n_samples = 25, 5, 100

# The number of frequencies is tied to the SPRIGHT setting: k = 2**spright_m.
data_k = 1 << spright_m

dataset = FourierDataset(
    n=data_n,
    k=data_k,
    d=data_d,
    n_samples=data_n_samples,
)

In [13]:
# Build the SPRIGHT sampling object for this problem size. In the recorded run
# it loads its sampling matrices, delay matrices and sampling locations from
# cache/N25-m5-d3-seed0/ instead of regenerating them.
# NOTE(review): the "seed0" in the cache path suggests a default seed inside
# SPRIGHTSample -- confirm if reproducibility matters.
spright_sample = SPRIGHTSample(data_n, spright_m, spright_d)

- Loded cache/N25-m5-d3-seed0/sampling-matrix-2.p from cache (sampling matrix)
- Loded cache/N25-m5-d3-seed0/sampling-matrix-0.p from cache (sampling matrix)
- Loded cache/N25-m5-d3-seed0/sampling-matrix-1.p from cache (sampling matrix)
- Loded cache/N25-m5-d3-seed0/delays-2.p from cache (delay matrix)
- Loded cache/N25-m5-d3-seed0/delays-1.p from cache (delay matrix)
- Loded cache/N25-m5-d3-seed0/delays-0.p from cache (delay matrix)
- Loded cache/N25-m5-d3-seed0/sampling-locations-2.p from cache (samling location)
- Loded cache/N25-m5-d3-seed0/sampling-locations-0.p from cache (samling location)
- Loded cache/N25-m5-d3-seed0/sampling-locations-1.p from cache (samling location)


In [14]:
# Create the SPRIGHT solver on top of the sampling object built above.
# NOTE(review): 'frame' and [1,2,3] are passed positionally and their meaning
# is defined in epistatic_net.spright_utils.SPRIGHT -- presumably a query/decoding
# mode and the list of sampling groups (or delays) to use; TODO confirm.
spright = SPRIGHT('frame', [1,2,3],spright_sample)

In [15]:
# Stack the query points of every SPRIGHT sampling group into one 2-D array
# (one row per requested sample). Groups are addressed by integer position, as in
# the original cell, but the number of groups is no longer hard-coded to three.
sampling_groups = spright_sample.sampling_locations
X_all = np.concatenate(
    [np.vstack(sampling_groups[group]) for group in range(len(sampling_groups))]
)

# Deduplicate rows so the target function is evaluated once per distinct point.
# X_all_inverse_ind maps every originally requested row back to its row in the
# deduplicated X_all. `return_inverse` is a boolean flag: pass True, not the
# string 'True' (which only worked because a non-empty string is truthy).
X_all, X_all_inverse_ind = np.unique(X_all, axis=0, return_inverse=True)

# Evaluate the synthetic function on the distinct points; compute_y returns an
# object exposing .numpy() (presumably a torch tensor).
y_hat_all = dataset.compute_y(X_all).numpy()
print(X_all.shape, X_all_inverse_ind.shape, y_hat_all.shape)

# Hand the samples to the solver: distinct points, their values, and the
# inverse index that expands them back to the full sampling pattern.
spright.set_train_data(X_all, y_hat_all, X_all_inverse_ind)
# NOTE(review): model_to_remove is set to the ground-truth evaluator itself;
# its role is defined inside SPRIGHT -- confirm this is the intended callable.
spright.model_to_remove = dataset.compute_y

(14492, 25) (14496,) (14492,)


In [16]:
# Run SPRIGHT's initial pass. The returned flag decides, in the next cell,
# whether peel_rest() is called; it is echoed here as the cell output
# (True in the recorded run).
flag = spright.initial_run()
flag

True

In [17]:
# Continue with the remaining peeling step only when the initial run returned
# a truthy flag; otherwise the solver is left in its post-initial_run state.
if flag:
    spright.peel_rest()

In [18]:
# Shape of the recovered support: one row per recovered frequency, one column
# per input dimension ((29, 25) in the recorded run, i.e. 29 frequencies found).
spright.model.support.shape

(29, 25)

In [19]:
# Align SPRIGHT's recovered frequencies with the ground-truth frequency table so
# recovered and true amplitudes can be compared entry by entry.
true_freqs = dataset.freq_f.numpy()
recovered_coefs = np.asarray(spright.model.coef_)

# One slot per ground-truth frequency; frequencies that were not recovered stay 0.
found_amps = np.zeros(true_freqs.shape[0])
found_freq_indices = []  # rows of dataset.freq_f that SPRIGHT recovered
matched_positions = []   # matching positions in spright.model.support / coef_

for position, found_freq in enumerate(spright.model.support):
    # Require an exact row match. The previous argmax over per-row match counts
    # always returned *some* row, so a recovered frequency that is not in the
    # ground truth was silently attributed to the most similar true frequency.
    exact_rows = np.flatnonzero(np.all(true_freqs == found_freq, axis=1))
    if exact_rows.size > 0:
        found_freq_indices.append(int(exact_rows[0]))
        matched_positions.append(position)

# Surface spurious frequencies instead of hiding them; prints nothing when every
# recovered frequency is a true one.
n_spurious = len(spright.model.support) - len(matched_positions)
if n_spurious:
    print(f"Warning: {n_spurious} recovered frequencies are not in dataset.freq_f and were skipped")

# Rescale the recovered coefficients by data_k, as the original cell did.
# NOTE(review): presumably this converts SPRIGHT's normalisation to the one used
# by dataset.amp_f -- confirm against FourierDataset.
found_amps[found_freq_indices] = recovered_coefs[matched_positions] * data_k
found_amps

array([ 0.83100438,  0.5102114 , -0.48533083,  0.        , -0.06299139,
       -0.70023113, -0.09746977, -0.73933198,  0.03420725,  0.6426868 ,
       -0.97603921,  0.17991741, -0.38687808, -0.4554383 , -0.61094019,
       -0.87368608, -0.70157013, -0.33340202,  0.43164988,  0.4316573 ,
        0.        ,  0.        ,  0.38160576, -0.80137256,  0.83749306,
        0.59483401, -0.99849275, -0.9704108 ,  0.92420608,  0.55827206,
       -0.58688168,  0.56860295])

In [20]:
# Summarise recovery quality against the ground-truth amplitudes.
true_amps = dataset.amp_f

# NOTE: despite the label, this is a fraction in [0, 1] (0.90625 in the recorded
# run), not a percentage.
found_fraction = len(found_freq_indices) / true_amps.shape[0]
print("Percentage of found freqs:", found_fraction)

# Agreement restricted to the frequencies SPRIGHT actually recovered.
r2_found_only = r2_score(
    true_amps[found_freq_indices],
    found_amps[found_freq_indices],
)
print("R2 of found freqs and real freqs:", r2_found_only)

# Agreement over every true frequency; missed ones count as amplitude 0.
r2_all_freqs = r2_score(true_amps, found_amps)
print("R2 of considering all freqs:", r2_all_freqs)

Percentage of found freqs: 0.90625
R2 of found freqs and real freqs: 0.9999633820757393
R2 of considering all freqs: 0.9795496786936753
