In [93]:
import sys
import os
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from epistatic_net.wht_sampling import SPRIGHTSample
from epistatic_net.spright_utils import SPRIGHT, make_system_simple
from utils import FourierDataset

In [77]:
spright_d = 3
spright_m = 4

data_n = 13
data_k = 2**spright_m
data_d = 5
data_n_samples = 100

dataset = FourierDataset(n=data_n, k=data_k, d=data_d, n_samples=data_n_samples)

In [78]:
spright_sample = SPRIGHTSample(data_n, spright_m, spright_d)

- Loded cache/N13-m4-d3-seed0/sampling-matrix-2.p from cache (sampling matrix)
- Loded cache/N13-m4-d3-seed0/sampling-matrix-1.p from cache (sampling matrix)
- Loded cache/N13-m4-d3-seed0/sampling-matrix-0.p from cache (sampling matrix)
- Loded cache/N13-m4-d3-seed0/delays-2.p from cache (delay matrix)
- Loded cache/N13-m4-d3-seed0/delays-0.p from cache (delay matrix)
- Loded cache/N13-m4-d3-seed0/delays-1.p from cache (delay matrix)
- Loded cache/N13-m4-d3-seed0/sampling-locations-0.p from cache (samling location)
- Loded cache/N13-m4-d3-seed0/sampling-locations-1.p from cache (samling location)
- Loded cache/N13-m4-d3-seed0/sampling-locations-2.p from cache (samling location)


In [79]:
spright = SPRIGHT('frame', [1,2,3],spright_sample)

In [80]:
X_all = np.concatenate((np.vstack(spright_sample.sampling_locations[0]),np.vstack(spright_sample.sampling_locations[1]),np.vstack(spright_sample.sampling_locations[2])))
X_all, X_all_inverse_ind = np.unique(X_all, axis=0, return_inverse='True')
y_hat_all = dataset.compute_y(X_all).numpy()
print(X_all.shape, X_all_inverse_ind.shape, y_hat_all.shape)

spright.set_train_data(X_all, y_hat_all, X_all_inverse_ind)
spright.model_to_remove = dataset.compute_y

(3062, 13) (3792,) (3062,)


In [81]:
flag = spright.initial_run()
flag

True

In [82]:
if flag:
    spright.peel_rest()

In [83]:
spright.model.support.shape

(11, 13)

In [94]:
found_amps = np.zeros(dataset.freq_f.shape[0])
found_freq_indices = [np.argmax(np.sum(dataset.freq_f.numpy() == found_freq, axis=1)) for found_freq in spright.model.support]
found_amps[found_freq_indices] = spright.model.coef_*data_k
found_amps

array([-0.74584536, -0.75940823,  0.        ,  0.90431971,  0.        ,
       -0.50140492, -0.15087225,  0.        ,  0.4282222 ,  0.        ,
        0.75691562, -0.83738433,  0.56431174,  0.28851532,  0.30091978,
        0.        ])

In [96]:
print("Percentage of found freqs:", len(found_freq_indices) / dataset.amp_f.shape[0])
print("R2 of found freqs and real freqs:", r2_score(dataset.amp_f, found_amps))

Percentage of found freqs: 0.6875
R2 of found freqs and real freqs: 0.8330591871765951
