In [43]:
import numpy as np
import math
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

%matplotlib qt



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
from sklearn.linear_model import LinearRegression
from util import *

%matplotlib qt
T = 1000

def doppler(x, epsilon):
    return np.sin(2 * np.pi * (1 + epsilon) / (x + epsilon))

X = np.linspace(0, T, T)/T


std_levels = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
var_levels = [s**2 for s in std_levels]

func_name = "Doppler"
G = doppler(abs(1 - X), 0.38)

Ys = []
num_runs = 20
for j in range(num_runs):
    Ysj = []
    for s in std_levels:
        noise = np.random.normal(0, s, T)
        Ysj.append(G + noise)
    Ys.append(Ysj)
Ys = np.array(Ys)

In [52]:
from akorn import AKORN
Zs = []
for i in range(len(std_levels)): 
    Z = []
    if std_levels[i] != 0:
        for j in range(num_runs):
            ak = AKORN(X, Ys[j][i], var_levels[i])
            ak.train()
            fv = ak.preds
            Z.append(fv)
    else:
        ak = AKORN(X, Ys[j][i], var_levels[i])
        ak.train()
        fv = ak.preds 
        for j in range(num_runs):
            Z.append(fv)
    Zs.append(Z)
    


print("Done!")

lr =  0.12500048312257067
run


KeyboardInterrupt: 

In [46]:
mses_akorn = [np.mean([((Zs[i][j] - G)**2).mean() for j in range(num_runs)]) for i in range(len(std_levels))]

In [48]:
#Compare to trendfiltering

import rpy2
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr, data



utils = importr('utils')
base = importr('base')
glmgen = importr('glmgen')

Xr = robjects.FloatVector(X)

from trendfilter import trend_filter
Z_tf = []
Z_tf_dof = []

for j in range(num_runs):
    Zj_tf = []
    Zj_tf_dof = []
    for i in range(len(std_levels)):
        Y = Ys[j][i]
        var = var_levels[i]

        Yr = robjects.FloatVector(Y)
        tf = glmgen.trendfilter(Xr, Yr, k = 1)
        rcode1 = "predict(%s)" %(tf.r_repr())
        tf_fit = np.array(robjects.r(rcode1)).T

        best_mse = float('inf')
        best_mse_est = float('inf')
        best_mse_dof = float('inf')
        best_tf_fit = None
        best_tf_fit_dof = None
        for t in tf_fit:
            tf_mse = ((t - G)**2).mean() #Oracle risk
            tf_mse_est = (np.sum(((t - Y)**2)) + 2*var*count_linear_pieces(t)) #Stein risk estimate
            if tf_mse < best_mse:
                best_mse = tf_mse
                best_tf_fit = t
            if tf_mse_est < best_mse_est:
                best_mse_est = tf_mse_est
                best_mse_dof = tf_mse
                best_tf_fit_dof = t

        Zj_tf.append(best_tf_fit)
        Zj_tf_dof.append(best_tf_fit_dof)
    Z_tf.append(Zj_tf)
    Z_tf_dof.append(Zj_tf_dof)
mses_tf = [np.mean([((Z_tf[j][i] - G)**2).mean() for j in range(num_runs)]) for i in range(len(std_levels))]
mses_tf_dof = [np.mean([((Z_tf_dof[j][i] - G)**2).mean() for j in range(num_runs)]) for i in range(len(std_levels))]
    
print("done")

done


In [75]:
import pywt
def oracle_wavelets(Y,G, std):
    def wavelet_denoise(Y, a, std):
        coeffs = pywt.wavedec(Y, "db2")
        thresholded = []
        for c in coeffs:
            thresholded.append(pywt.threshold(c, a*std*math.sqrt(2*math.log(T)), mode = "soft"))
        signal = pywt.waverec(thresholded, "db2")
        return signal
    best_fit = None
    best_mse = float('inf')
    for a in np.linspace(1, 100, 300):
        sig = wavelet_denoise(Y, a, std)
        sig_mse = np.mean(np.square(sig - G))
        if sig_mse < best_mse:
            best_fit = sig
            best_mse = sig_mse
    return (best_fit, best_mse)

Z_wav = []
for j in range((num_runs)): 
    Z = []
    for i in range(len(std_levels)):
        std = std_levels[i]
        fit = oracle_wavelets(Ys[i], G, std)[0]
        Z.append(fit)
    Z_wav.append(Z)
mses_wav = [np.mean([((Z_wav[j][i] - G)**2).mean() for j in range(num_runs)]) for i in range(len(std_levels))]
print(mses_wav)

[0.08779655463597784, 0.017183279486670745, 0.018125684419646192, 0.024961864966664405, 0.03437264722933579, 0.0437856143711426]


In [80]:
np.array(Z_wav[0][0])[0] 

array([-2.42427606e-16,  4.56088292e-03,  9.12829691e-03,  1.37021605e-02,
        1.82823915e-02,  2.28689068e-02,  2.74616226e-02,  3.20604542e-02,
        3.66653163e-02,  4.12761225e-02,  4.58927857e-02,  5.05152181e-02,
        5.51433310e-02,  5.97770349e-02,  6.44162393e-02,  6.90608531e-02,
        7.37107843e-02,  7.83659400e-02,  8.30262264e-02,  8.76915491e-02,
        9.23618125e-02,  9.70369205e-02,  1.01716776e-01,  1.06401281e-01,
        1.11090336e-01,  1.15783842e-01,  1.20481699e-01,  1.25183804e-01,
        1.29890056e-01,  1.34600350e-01,  1.39314584e-01,  1.44032652e-01,
        1.48754448e-01,  1.53479865e-01,  1.58208796e-01,  1.62941131e-01,
        1.67676762e-01,  1.72415578e-01,  1.77157467e-01,  1.81902317e-01,
        1.86650015e-01,  1.91400448e-01,  1.96153498e-01,  2.00909052e-01,
        2.05666991e-01,  2.10427198e-01,  2.15189554e-01,  2.19953939e-01,
        2.24720233e-01,  2.29488313e-01,  2.34258057e-01,  2.39029341e-01,
        2.43802041e-01,  

In [82]:
plt.plot(X, Z_wav[0][0][])

[<matplotlib.lines.Line2D at 0x28f995e10>]