In [1]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import scipy.stats
from tqdm import tqdm

# Pick the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): `device` is only printed here; the network and inputs below stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
from nn_resampler import nn_resampler
from phase_est_smc import phase_est_smc

In [None]:
# Experiment configuration shared by every resampler section below.
num_particles = 1000  # number of SMC particles (num of w points)
num_samples = 10000   # number of samples to draw from the particle distribution (to be binned)
num_bins = 100        # number of bins
n_runs = 100          # number of different omega*
t0 = 10               # starting time
max_iters = 100       # maximum number of iterations before breaking

# Seed the global RNGs so the true omegas drawn with np.random.uniform in each section are
# reproducible on Restart & Run All.
# NOTE(review): phase_est_smc may draw from other RNGs internally — confirm they are covered.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Build the NN resampler with (num_bins, num_bins) — presumably input/output histogram widths,
# TODO confirm against nn_resampler — and load its pretrained weights.
net = nn_resampler(num_bins, num_bins)
# The network is never moved to `device` and is fed CPU tensors below, so map the checkpoint onto
# the CPU; this also lets a GPU-saved checkpoint load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
net.load_state_dict(torch.load("model/nn_resampler.model", map_location="cpu"))
net.eval();  # inference mode; trailing ';' suppresses the module repr in the notebook

## NN Resampler

In [None]:
# NN resampler: run n_runs independent SMC estimations. Whenever the loop continues, the particle
# cloud is binned into a num_bins histogram, passed through `net`, and the prediction is turned
# back into particles with nn_bins_to_particles.
true_omegas = []
nn_preds = []
nn_data = []

for _ in range(n_runs):
    # omega* drawn uniformly from [-pi, pi).
    true_omega = np.random.uniform(low=-1, high=1) * np.pi
    true_omegas.append(true_omega)

    smc = phase_est_smc(true_omega, t0, max_iters)
    smc.init_particles(num_particles)
    resample_counts = 0

    while True:
        smc.particles(threshold=num_particles/10, num_measurements=1)
        bins, edges = smc.get_bins(num_bins, num_samples)

        if smc.break_flag:
            break

        # Inference only: no_grad skips building an autograd graph on every call.
        with torch.no_grad():
            # float32 tensor of shape [1, num_bins] (batch of one histogram)
            nn_pred = net(torch.as_tensor(bins, dtype=torch.float32).unsqueeze(0))
        smc.nn_bins_to_particles(nn_pred.cpu().numpy(), edges)

        resample_counts += 1

    # Pad with the last estimate (as the Liu-West section does) so each run contributes max_iters
    # entries and np.array(nn_data) stays rectangular; a no-op for runs already that long.
    while 0 < len(smc.data) < max_iters:
        smc.data.append(smc.data[-1])

    nn_data.append(smc.data)
    nn_preds.append(smc.curr_omega_est)

    # A run counts as failed when the final estimate is more than 1 rad from omega*.
    failed = np.abs(true_omega - smc.curr_omega_est) > 1
    print("True omega: {:f}, prediction by NN: {:f}, num of resample calls: {:d}{}".format(
        true_omega, smc.curr_omega_est, resample_counts, ". Failed" if failed else ""))

In [None]:
# Mean squared error of the final NN estimates over all runs.
nn_mse = np.square(np.subtract(true_omegas, nn_preds)).mean()
print(nn_mse)

In [None]:
# Per-iteration squared error of the NN runs: rows are runs, columns are recorded estimates.
nn_data_squared = (np.array(nn_data) - np.array(true_omegas).reshape(-1, 1)) ** 2
nn_data_mean = np.mean(nn_data_squared, axis=0)
nn_data_median = np.median(nn_data_squared, axis=0)

num_data_points = nn_data_squared.shape[1]
nn_iters = np.arange(1, num_data_points + 1, dtype=int)

plt.plot(nn_iters, nn_data_mean, label='Mean')
plt.plot(nn_iters, nn_data_median, label='Median')
plt.legend()
plt.title("NN")
plt.xlabel("Iters")
# Raw string: "\o" is an invalid escape in a normal string literal (Deprecation/SyntaxWarning).
plt.ylabel(r"$(\omega - \omega*)^2$")
plt.yscale('log')
plt.show()

## Gaussian Bin Resampler

In [None]:
# Gaussian-bin (GB) baseline: the same SMC loop as the NN section, except the particle histogram
# itself (reshaped to [1, num_bins]) is fed back through nn_bins_to_particles instead of a
# network prediction.
true_omegas = []
gb_preds = []
gb_data = []

for _ in range(n_runs):
    # omega* drawn uniformly from [-pi, pi).
    true_omega = np.random.uniform(low=-1, high=1) * np.pi
    true_omegas.append(true_omega)

    smc = phase_est_smc(true_omega, t0, max_iters)
    smc.init_particles(num_particles)
    resample_counts = 0

    while True:
        smc.particles(threshold=num_particles/10, num_measurements=1)
        bins, edges = smc.get_bins(num_bins, num_samples)

        if smc.break_flag:
            break

        smc.nn_bins_to_particles(bins[np.newaxis, :], edges)

        resample_counts += 1

    # Pad with the last estimate (as the Liu-West section does) so each run contributes max_iters
    # entries and np.array(gb_data) stays rectangular; a no-op for runs already that long.
    while 0 < len(smc.data) < max_iters:
        smc.data.append(smc.data[-1])

    gb_data.append(smc.data)
    gb_preds.append(smc.curr_omega_est)

    # A run counts as failed when the final estimate is more than 1 rad from omega*.
    failed = np.abs(true_omega - smc.curr_omega_est) > 1
    print("True omega: {:f}, prediction by GB: {:f}, num of resample calls: {:d}{}".format(
        true_omega, smc.curr_omega_est, resample_counts, ". Failed" if failed else ""))

In [None]:
# Mean squared error of the final GB estimates over all runs.
gb_mse = np.square(np.subtract(true_omegas, gb_preds)).mean()
print(gb_mse)

In [None]:
# Per-iteration squared error of the GB runs: rows are runs, columns are recorded estimates.
gb_data_squared = (np.array(gb_data) - np.array(true_omegas).reshape(-1, 1)) ** 2
gb_data_mean = np.mean(gb_data_squared, axis=0)
gb_data_median = np.median(gb_data_squared, axis=0)

num_data_points = gb_data_squared.shape[1]
gb_iters = np.arange(1, num_data_points + 1, dtype=int)

plt.plot(gb_iters, gb_data_mean, label='Mean')
plt.plot(gb_iters, gb_data_median, label='Median')
plt.legend()
plt.title("GB")
plt.xlabel("Iters")
# Raw string: "\o" is an invalid escape in a normal string literal (Deprecation/SyntaxWarning).
plt.ylabel(r"$(\omega - \omega*)^2$")
plt.yscale('log')
plt.show()

## Liu-West Resampler

In [None]:
# Liu-West resampler: run n_runs independent SMC estimations, calling smc.liu_west_resample()
# between updates instead of the histogram-based resamplers.
true_omegas = []
lw_preds = []
lw_data = []

for _ in range(n_runs):
    # omega* drawn uniformly from [-pi, pi).
    true_omega = np.random.uniform(low=-1, high=1) * np.pi
    true_omegas.append(true_omega)

    smc = phase_est_smc(true_omega, t0, max_iters)
    smc.init_particles(num_particles)
    resample_counts = 0

    while True:
        smc.particles(threshold=num_particles/10, num_measurements=1)
        # Check for termination before resampling, as the NN/GB/KDE loops do. This way no
        # redundant resample runs after the final update, and it is not counted.
        if smc.break_flag:
            break

        smc.liu_west_resample()
        resample_counts += 1

        # NOTE(review): kept from the original loop in case liu_west_resample itself sets
        # break_flag — confirm whether this second check is needed.
        if smc.break_flag:
            break

    # Pad with the last estimate so every run contributes max_iters entries and
    # np.array(lw_data) stays rectangular (guarded against an empty history).
    while 0 < len(smc.data) < max_iters:
        smc.data.append(smc.data[-1])

    lw_data.append(smc.data)
    lw_preds.append(smc.curr_omega_est)

    # A run counts as failed when the final estimate is more than 1 rad from omega*.
    failed = np.abs(true_omega - smc.curr_omega_est) > 1
    print("True omega: {:f}, prediction by LW: {:f}, num of resample calls: {:d}{}".format(
        true_omega, smc.curr_omega_est, resample_counts, ". Failed" if failed else ""))
    

In [None]:
# Mean squared error of the final Liu-West estimates over all runs.
lw_mse = np.square(np.subtract(true_omegas, lw_preds)).mean()
print(lw_mse)

In [None]:
# Per-iteration squared error of the LW runs: rows are runs, columns are recorded estimates.
lw_data_squared = (np.array(lw_data) - np.array(true_omegas).reshape(-1, 1)) ** 2
lw_data_mean = np.mean(lw_data_squared, axis=0)
lw_data_median = np.median(lw_data_squared, axis=0)

num_data_points = lw_data_squared.shape[1]
lw_iters = np.arange(1, num_data_points + 1, dtype=int)

plt.plot(lw_iters, lw_data_mean, label='Mean')
plt.plot(lw_iters, lw_data_median, label='Median')
plt.legend()
plt.title("LW")
plt.xlabel("Iters")
# Raw string: "\o" is an invalid escape in a normal string literal (Deprecation/SyntaxWarning).
plt.ylabel(r"$(\omega - \omega*)^2$")
plt.yscale('log')
plt.show()

## KDE Resampler

In [None]:
# KDE resampler: run n_runs independent SMC estimations, calling smc.kde_resample between updates.
true_omegas = []
kde_preds = []
kde_data = []

for _ in range(n_runs):
    # omega* drawn uniformly from [-pi, pi).
    true_omega = np.random.uniform(low=-1, high=1) * np.pi
    true_omegas.append(true_omega)

    smc = phase_est_smc(true_omega, t0, max_iters)
    smc.init_particles(num_particles)
    resample_counts = 0

    while True:
        # NOTE(review): threshold is num_particles/5 here vs num_particles/10 in the other
        # sections — confirm this difference is intentional.
        smc.particles(threshold=num_particles/5, num_measurements=1)
        if smc.break_flag:
            break
        smc.kde_resample(num_samples=num_samples, method=1)
        resample_counts += 1

    # Pad with the last estimate (as the Liu-West section does) so each run contributes max_iters
    # entries and np.array(kde_data) stays rectangular; a no-op for runs already that long.
    while 0 < len(smc.data) < max_iters:
        smc.data.append(smc.data[-1])

    kde_data.append(smc.data)
    kde_preds.append(smc.curr_omega_est)

    # A run counts as failed when the final estimate is more than 1 rad from omega*.
    failed = np.abs(true_omega - smc.curr_omega_est) > 1
    print("True omega: {:f}, prediction by KDE: {:f}, num of resample calls: {:d}{}".format(
        true_omega, smc.curr_omega_est, resample_counts, ". Failed" if failed else ""))

In [None]:
# Mean squared error of the final KDE estimates over all runs.
kde_mse = np.square(np.subtract(true_omegas, kde_preds)).mean()
print(kde_mse)

In [None]:
# Per-iteration squared error of the KDE runs: rows are runs, columns are recorded estimates.
kde_data_squared = (np.array(kde_data) - np.array(true_omegas).reshape(-1, 1)) ** 2
kde_data_mean = np.mean(kde_data_squared, axis=0)
kde_data_median = np.median(kde_data_squared, axis=0)

num_data_points = kde_data_squared.shape[1]
kde_iters = np.arange(1, num_data_points + 1, dtype=int)

plt.plot(kde_iters, kde_data_mean, label='Mean')
plt.plot(kde_iters, kde_data_median, label='Median')
plt.legend()
plt.title("KDE")
plt.xlabel("Iters")
# Raw string: "\o" is an invalid escape in a normal string literal (Deprecation/SyntaxWarning).
plt.ylabel(r"$(\omega - \omega*)^2$")
plt.yscale('log')
plt.show()

In [None]:
# Compare all four resamplers: mean (left) and median (right) squared error per iteration.
# Both panels use a log y-axis, matching the per-method plots above.
resampler_curves = [
    ('NN Resampler', nn_data_mean, nn_data_median),
    ('GB Resampler', gb_data_mean, gb_data_median),
    ('LW Resampler', lw_data_mean, lw_data_median),
    ('KDE Resampler', kde_data_mean, kde_data_median),
]

# Derive the x-axis from the curves themselves rather than from the num_data_points left behind
# by whichever plotting cell ran last; the curves must share one length to share an axis.
curve_lengths = {len(mean_curve) for _, mean_curve, _ in resampler_curves}
assert len(curve_lengths) == 1, "resampler error curves have different lengths: {}".format(sorted(curve_lengths))
x_iters = np.arange(1, curve_lengths.pop() + 1, dtype=int)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
for label, mean_curve, median_curve in resampler_curves:
    ax1.plot(x_iters, mean_curve, label=label)
    ax2.plot(x_iters, median_curve, label=label)

for ax, title in ((ax1, "Mean vs n_iters"), (ax2, "Median vs n_iters")):
    ax.set_title(title)
    ax.set_xlabel("Iters")
    # Raw string: "\o" is an invalid escape in a normal string literal.
    ax.set_ylabel(r"$(\omega - \omega*)^2$")
    ax.set_yscale("log")
    ax.legend()

fig.tight_layout()
plt.show()