In [12]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the cells in this notebook -- the resampling
# network and its input tensors stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [13]:
from nn_resampler import nn_resampler
from phase_est_smc import phase_est_smc

In [14]:
num_particles = 1000 # number of SMC particles (num of w points)
num_samples = 10000 # number of samples to draw from the particle distribution (to be binned)
num_bins = 100 # number of bins
n_iters = 1000 # number of different omega*
t0 = 0.1 # starting time
max_iters = 100 # maximum number of iterations before breaking

In [15]:
# Load the pretrained resampling network used by the "failed runs" generation cell.
# NOTE(review): the constructor arguments are presumably (input size, output size) and must
# equal num_bins for the histograms fed to the network -- confirm against nn_resampler.
net = nn_resampler(100, 100);
# The network is never moved to `device` and is fed CPU tensors, so load the weights onto
# the CPU explicitly; without map_location a GPU-saved checkpoint fails on CPU-only machines.
net.load_state_dict(torch.load("model/nn_no_normalize.model", map_location="cpu"));
net.eval();  # inference mode: disables training-only behaviour such as dropout, if present

## Generate pre-resampling data from arbitrary SMC runs (runs may be good or poor)

In [16]:
# Record the histogram produced at the first resampling step of many fresh SMC runs.
n_data_max = 10000    # number of SMC runs (= histograms) to generate
report_every = 1000   # print progress every this many runs
bins_data = []
edges_data = []
omega_star_list = []  # true omega* of each run; initialised here so the cell runs on a fresh kernel

for n_data in range(1, n_data_max + 1):

    # true parameter, drawn uniformly from [-pi, pi)
    omega_star = np.random.uniform(low=-1, high=1) * np.pi
    omega_star_list.append(omega_star)

    smc = phase_est_smc(omega_star, t0, max_iters)
    smc.init_particles(num_particles)

    # take data from first resample step only
    # (threshold is the resampling trigger passed to phase_est_smc -- presumably on effective sample size)
    particle_pos, particle_wgts = smc.particles(threshold=num_particles / 5, num_measurements=1)
    bins, edges = smc.get_bins(num_bins, num_samples)

    bins_data.append(bins)
    edges_data.append(edges)

    if n_data % report_every == 0:
        print("Current progress: {:d}/{:d}".format(n_data, n_data_max))

# One (bins, edges) pair per run.
assert len(bins_data) == len(edges_data) == n_data_max

Current progress: 100/10000
Current progress: 200/10000
Current progress: 300/10000
Current progress: 400/10000
Current progress: 500/10000
Current progress: 600/10000
Current progress: 700/10000
Current progress: 800/10000
Current progress: 900/10000
Current progress: 1000/10000
Current progress: 1100/10000
Current progress: 1200/10000
Current progress: 1300/10000
Current progress: 1400/10000
Current progress: 1500/10000
Current progress: 1600/10000
Current progress: 1700/10000
Current progress: 1800/10000
Current progress: 1900/10000
Current progress: 2000/10000
Current progress: 2100/10000
Current progress: 2200/10000
Current progress: 2300/10000
Current progress: 2400/10000
Current progress: 2500/10000
Current progress: 2600/10000
Current progress: 2700/10000
Current progress: 2800/10000
Current progress: 2900/10000
Current progress: 3000/10000
Current progress: 3100/10000
Current progress: 3200/10000
Current progress: 3300/10000
Current progress: 3400/10000
Current progress: 3500/

In [18]:
# Stack the collected histograms (presumably one equal-length row per run) and persist them.
good_bins = np.array(bins_data)
good_edges = np.array(edges_data)
assert len(good_bins) == len(good_edges), "bins and edges must pair up one-to-one"
os.makedirs("data", exist_ok=True)  # np.save does not create missing directories
np.save("data/good_bins.npy", good_bins)
np.save("data/good_edges.npy", good_edges)  # explicit suffix, matching the bins file

## Generate data from SMC runs in which NN-based resampling failed

In [19]:
# Harvest training examples from SMC runs in which the network-driven resampling failed.
n_data = 0            # failed runs harvested so far (each contributes up to 2 histograms)
n_data_max = 5000     # stop after this many failed runs
report_every = 100    # print progress every this many failed runs
fail_tol = 1          # a run counts as failed when |final estimate - omega*| exceeds this
bins_data = []
edges_data = []
omega_star_list = []  # true omega* of every attempted run; initialised here so the cell runs on a fresh kernel

while n_data < n_data_max:

    # true parameter, drawn uniformly from [-pi, pi)
    omega_star = np.random.uniform(low=-1, high=1) * np.pi
    omega_star_list.append(omega_star)

    smc = phase_est_smc(omega_star, t0, max_iters)
    smc.init_particles(num_particles)
    resample_counts = 0  # number of network resampling steps in the current run (diagnostic only)

    # Run this SMC instance until it sets break_flag, letting the network do every resampling step.
    while True:

        particle_pos, particle_wgts = smc.particles(threshold=num_particles / 5, num_measurements=1)
        bins, edges = smc.get_bins(num_bins, num_samples)

        if smc.break_flag:
            break

        # Inference only: no_grad stops autograd from building (and holding) a graph on every call.
        with torch.no_grad():
            # convert to float tensor, then make dim [1, num_bins]
            nn_pred = net(torch.tensor(bins).float().unsqueeze(0))
        smc.nn_bins_to_particles(nn_pred.numpy(), edges)

        resample_counts += 1

    # If the run failed, keep the histograms recorded before its first and (if present) second
    # resampling steps as training data.
    if abs(smc.curr_omega_est - omega_star) > fail_tol:

        history = smc.memory.bins_edges_bef_res
        # A run may have recorded fewer than two -- or no -- pre-resampling histograms.
        n_stored = min(2, len(history))
        for i in range(n_stored):
            bins_data.append(history[i][0])
            edges_data.append(history[i][1])

        if n_stored > 0:
            n_data += 1

            if n_data % report_every == 0:
                print("Current progress: {:d}/{:d}".format(n_data, n_data_max))

Current progress: 100/5000
Current progress: 200/5000
Current progress: 300/5000
Current progress: 400/5000
Current progress: 500/5000
Current progress: 600/5000
Current progress: 700/5000
Current progress: 800/5000
Current progress: 900/5000
Current progress: 1000/5000
Current progress: 1100/5000
Current progress: 1200/5000
Current progress: 1300/5000
Current progress: 1400/5000
Current progress: 1500/5000
Current progress: 1600/5000
Current progress: 1700/5000
Current progress: 1800/5000
Current progress: 1900/5000
Current progress: 2000/5000
Current progress: 2100/5000
Current progress: 2200/5000
Current progress: 2300/5000
Current progress: 2400/5000
Current progress: 2500/5000
Current progress: 2600/5000
Current progress: 2700/5000
Current progress: 2800/5000
Current progress: 2900/5000
Current progress: 3000/5000
Current progress: 3100/5000
Current progress: 3200/5000
Current progress: 3300/5000
Current progress: 3400/5000
Current progress: 3500/5000
Current progress: 3600/5000
C

In [21]:
# Stack the histograms from failed runs (presumably equal-length rows) and persist them.
poor_bins = np.array(bins_data)
poor_edges = np.array(edges_data)
assert len(poor_bins) == len(poor_edges), "bins and edges must pair up one-to-one"
os.makedirs("data", exist_ok=True)  # np.save does not create missing directories
np.save("data/poor_bins.npy", poor_bins)
np.save("data/poor_edges.npy", poor_edges)  # explicit suffix, matching the bins file