In [6]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import scipy.stats
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [7]:
from nn_resampler import nn_resampler
from phase_est_smc import phase_est_smc

In [8]:
net = nn_resampler(100,100)
#net = net.to(device)
net.load_state_dict(torch.load("net_bn_aft_relu.model"))
net.eval()

nn_resampler(
  (enc1): Linear(in_features=100, out_features=50, bias=True)
  (enc2): Linear(in_features=50, out_features=25, bias=True)
  (enc3): Linear(in_features=25, out_features=10, bias=True)
  (dec1): Linear(in_features=10, out_features=25, bias=True)
  (dec2): Linear(in_features=25, out_features=50, bias=True)
  (dec3): Linear(in_features=50, out_features=100, bias=True)
  (bn_enc1): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_enc2): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_enc3): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_dec1): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_dec2): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [9]:
num_particles = 100 # number of SMC particles (num of w points)
num_samples = 10000 # number of samples to draw from the particle distribution (to be binned)
num_bins = 100 # number of bins
n_iters = 1000 # number of iterations of resampling

In [10]:
for i in range(1000): # 1000 different phases to estimate
    t0 = 0.1
    omega_star = np.random.uniform() * 2 * np.pi
    smc = phase_est_smc(omega_star, t0)
    smc.init_particles(num_particles)
    running_est = []

    while True:

        particle_pos, particle_wgts = smc.particles(threshold=num_particles/5, num_measurements=1)
        data = np.random.choice(particle_pos, size = num_samples, p=particle_wgts)

        mean = np.mean(data)
        std = np.std(data)

        # if standard deviation of samples are too small it means the distribution is strongly peaked
        # and will not change anymore, so break out of loop
#         if std <= 1e-8:
#             break

        data = (data-mean)/std
        bins, edges = np.histogram(data, num_bins)
        bins = bins/num_samples

        nn_pred = net(torch.tensor(bins).float().unsqueeze(0)) ## convert to float tensor, then make dim [1, num_bins]

        smc.nn_resample(nn_pred.detach().numpy(), edges, mean, std)

    #     edges_ = edges[:-1]
    #     edge_width = edges_[1] - edges_[0]

    #     plt.bar(edges_*std + mean, bins, align='edge', width = edge_width)
    #     plt.bar(edges_*std + mean, nn_pred[0].detach().numpy(), align='edge', width = edge_width, alpha=0.7)
    #     plt.show() 
        curr_omega_est = smc.particle_pos[np.argmax(smc.particle_wgts)]
        running_est.append(curr_omega_est)
        
        if len(running_est) > 10:
            last_10 = running_est[-10:]
            # np.allclose(a,b) returns true if every element in a and b are equal to a tolerance
            # np.full_like makes an array with shake of last_10 with every element being last_10[0]
            # if all elements of last_10 are equal then they should all equal the first element
            if np.allclose(last_10, np.full_like(last_10, last_10[0])):
                break
                
    print("True omega: {:f}, prediction by SMC: {:f}".format(omega_star,curr_omega_est))


True omega: 5.220185, prediction by SMC: 3.039305
True omega: 4.683895, prediction by SMC: 4.683269
True omega: 1.874683, prediction by SMC: 1.874923
True omega: 3.613499, prediction by SMC: 3.617202
True omega: 4.099283, prediction by SMC: 4.099906
True omega: 1.616176, prediction by SMC: 1.539601
True omega: 5.953770, prediction by SMC: 5.956761
True omega: 3.915083, prediction by SMC: 3.557439
True omega: 4.946350, prediction by SMC: 4.590607
True omega: 3.165810, prediction by SMC: 3.240458
True omega: 0.170685, prediction by SMC: 0.525838
True omega: 6.241775, prediction by SMC: 6.114848
True omega: 3.013720, prediction by SMC: 3.013713
True omega: 3.681087, prediction by SMC: 3.433617
True omega: 3.205513, prediction by SMC: 3.202549
True omega: 1.553516, prediction by SMC: 1.655820
True omega: 1.856648, prediction by SMC: 1.839056
True omega: 0.909317, prediction by SMC: 0.845002
True omega: 2.674069, prediction by SMC: 1.761906
True omega: 1.891746, prediction by SMC: 1.853371


True omega: 2.251050, prediction by SMC: 1.438179
True omega: 4.201655, prediction by SMC: 3.846252
True omega: 3.397567, prediction by SMC: 3.406111
True omega: 1.671702, prediction by SMC: 1.603182
True omega: 0.563642, prediction by SMC: 0.562096
True omega: 0.455914, prediction by SMC: 0.453797
True omega: 4.818545, prediction by SMC: 4.814192
True omega: 6.159728, prediction by SMC: 5.090290
True omega: 2.909747, prediction by SMC: 2.281671
True omega: 1.128853, prediction by SMC: 1.071177
True omega: 1.540027, prediction by SMC: 2.071102
True omega: 5.853980, prediction by SMC: 5.854007
True omega: 4.878610, prediction by SMC: 4.827607
True omega: 3.502063, prediction by SMC: 3.503495
True omega: 2.679388, prediction by SMC: 2.761269
True omega: 1.409629, prediction by SMC: 1.446164
True omega: 3.539685, prediction by SMC: 3.542430
True omega: 3.039044, prediction by SMC: 3.176755
True omega: 0.356837, prediction by SMC: 0.369310
True omega: 2.257989, prediction by SMC: 2.220772


True omega: 4.793722, prediction by SMC: 4.649042
True omega: 3.874712, prediction by SMC: 3.874883
True omega: 0.115654, prediction by SMC: 0.467576
True omega: 3.631169, prediction by SMC: 3.580959
True omega: 2.358295, prediction by SMC: 2.358294
True omega: 4.462723, prediction by SMC: 4.465553
True omega: 5.218265, prediction by SMC: 5.225125
True omega: 4.099452, prediction by SMC: 4.395703
True omega: 0.262861, prediction by SMC: 0.795477
True omega: 5.397509, prediction by SMC: 5.339172
True omega: 1.107719, prediction by SMC: 1.056697
True omega: 0.660867, prediction by SMC: 0.661722
True omega: 0.206668, prediction by SMC: 0.212959
True omega: 2.686649, prediction by SMC: 1.387807
True omega: 1.520329, prediction by SMC: 1.493846
True omega: 0.494338, prediction by SMC: 0.491723
True omega: 1.736817, prediction by SMC: 1.693635
True omega: 1.256541, prediction by SMC: 1.205432
True omega: 5.004457, prediction by SMC: 4.867706
True omega: 4.908674, prediction by SMC: 4.916324


True omega: 3.235851, prediction by SMC: 3.236527
True omega: 0.158991, prediction by SMC: 0.575897
True omega: 6.204930, prediction by SMC: 3.369526
True omega: 4.382189, prediction by SMC: 4.346003
True omega: 4.008839, prediction by SMC: 3.784610
True omega: 1.495753, prediction by SMC: 1.447892
True omega: 1.748233, prediction by SMC: 1.733806
True omega: 0.880137, prediction by SMC: 0.854671
True omega: 2.507005, prediction by SMC: 2.467013
True omega: 2.379815, prediction by SMC: 2.424387
True omega: 1.963533, prediction by SMC: 1.751395
True omega: 2.354356, prediction by SMC: 2.365278
True omega: 1.571349, prediction by SMC: 1.573673
True omega: 1.005541, prediction by SMC: 1.005220
True omega: 5.510325, prediction by SMC: 5.328557
True omega: 2.515575, prediction by SMC: 2.535055
True omega: 1.961850, prediction by SMC: 1.727892
True omega: 3.859396, prediction by SMC: 3.865605
True omega: 3.733312, prediction by SMC: 3.716789
True omega: 0.101984, prediction by SMC: 0.230180


True omega: 0.531459, prediction by SMC: 0.428655
True omega: 5.312230, prediction by SMC: 4.223634
True omega: 6.202419, prediction by SMC: 6.018108
True omega: 1.709747, prediction by SMC: 1.640034
True omega: 1.777224, prediction by SMC: 1.738605
True omega: 6.028102, prediction by SMC: 5.949913
True omega: 0.255339, prediction by SMC: 0.255409
True omega: 3.750647, prediction by SMC: 3.750640
True omega: 1.097239, prediction by SMC: 1.105245
True omega: 2.357936, prediction by SMC: 2.335284
True omega: 0.912125, prediction by SMC: 1.003881
True omega: 5.629708, prediction by SMC: 5.629704
True omega: 2.079031, prediction by SMC: 1.659372
True omega: 2.562378, prediction by SMC: 2.525922
True omega: 0.274313, prediction by SMC: 0.317087
True omega: 2.892110, prediction by SMC: 2.892445
True omega: 0.549645, prediction by SMC: 0.570515
True omega: 2.814894, prediction by SMC: 2.780954
True omega: 0.137702, prediction by SMC: 0.137683
True omega: 1.466043, prediction by SMC: 1.466041


True omega: 0.248863, prediction by SMC: 0.248761
True omega: 5.971687, prediction by SMC: 5.970631
True omega: 1.788286, prediction by SMC: 1.784368
True omega: 1.154572, prediction by SMC: 1.120645
True omega: 1.767818, prediction by SMC: 2.048439
True omega: 5.720390, prediction by SMC: 5.671771
True omega: 1.118675, prediction by SMC: 1.021615
True omega: 4.840975, prediction by SMC: 3.906954
True omega: 4.551293, prediction by SMC: 4.411355
True omega: 5.314532, prediction by SMC: 5.308654
True omega: 1.320823, prediction by SMC: 1.345948
True omega: 5.711725, prediction by SMC: 5.600732
True omega: 5.270219, prediction by SMC: 3.684542
True omega: 5.111362, prediction by SMC: 5.094989
True omega: 3.068547, prediction by SMC: 2.540777
True omega: 2.716853, prediction by SMC: 2.219222
True omega: 3.439751, prediction by SMC: 3.457686
True omega: 3.019021, prediction by SMC: 3.012562
True omega: 3.508122, prediction by SMC: 3.535353
True omega: 3.385973, prediction by SMC: 3.425670


True omega: 3.317994, prediction by SMC: 2.884713
True omega: 1.407886, prediction by SMC: 1.500632
True omega: 1.208979, prediction by SMC: 1.225223
True omega: 4.105549, prediction by SMC: 4.133614
True omega: 2.127242, prediction by SMC: 2.147000
True omega: 1.883382, prediction by SMC: 1.883383
True omega: 3.350660, prediction by SMC: 3.200654
True omega: 1.113290, prediction by SMC: 1.113423
True omega: 5.449460, prediction by SMC: 5.453223
True omega: 2.400844, prediction by SMC: 2.400969
True omega: 0.132409, prediction by SMC: 0.159084
True omega: 2.055329, prediction by SMC: 2.045383
True omega: 0.463757, prediction by SMC: 0.370627
True omega: 0.156926, prediction by SMC: 0.904870
True omega: 3.027931, prediction by SMC: 3.011064
