In [None]:
import sys
import quicklens as ql

# Project-local modules: helper functions (grab, power-spectrum / fNL estimators),
# the normalizing-flow model, and the training-data generators.
import utilities 
import flow_architecture
import training_data

# NOTE(review): quicklens is imported a second time below; sys, base64, io, pylab,
# scipy.ndimage and torch.fft are not referenced elsewhere in this notebook.
import base64
import io
import time
import pickle
import numpy as np
import pylab as pl
import quicklens as ql
import scipy.ndimage
import torch
import torch.fft
print(f'TORCH VERSION: {torch.__version__}')
import importlib
%matplotlib inline
import matplotlib.pyplot as plt
# Release memory held by PyTorch's CUDA caching allocator (e.g. left over from an
# earlier run in this kernel) before building the model.
torch.cuda.empty_cache()

In [None]:
# Re-import project modules after editing them on disk, without restarting the kernel.
# Uncomment the others while iterating on them:
#importlib.reload(training_data)
#importlib.reload(flow_architecture)
importlib.reload(utilities)

In [None]:
# Select the compute device and the matching default precision:
# single precision on GPU, double precision on CPU.
if torch.cuda.is_available():
    torch_device = 'cuda'
    float_dtype = np.float32 # single
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Prefer GPU 1 (the original choice) but fall back to the highest available
    # index, so single-GPU machines do not fail with an invalid device ordinal.
    gpu_index = min(1, torch.cuda.device_count() - 1)
    torch.cuda.set_device(gpu_index)
    print(f"TORCH DEVICE: {torch_device} {torch.cuda.current_device()}")
else:
    torch_device = 'cpu'
    float_dtype = np.float64 # double
    torch.set_default_tensor_type(torch.DoubleTensor)
    # torch.cuda.current_device() initialises CUDA and raises when CUDA is
    # unavailable, so it must only be queried on the GPU branch.
    print(f"TORCH DEVICE: {torch_device}")

# Wall-clock reference for the elapsed-time column printed during training.
startTime = time.time()

# Training data

In [None]:
# Which training set to use: "gaussbased" or "nbodybased".
datamode = "nbodybased"
# Fail fast on a typo: otherwise neither data cell below runs and `trainingdata`
# is silently left undefined until a much later cell raises NameError.
if datamode not in ("gaussbased", "nbodybased"):
    raise ValueError(f"unknown datamode {datamode!r}; expected 'gaussbased' or 'nbodybased'")

In [None]:
if datamode == "gaussbased":
    # Map geometry: 64 px across 4 deg (the 16 px alternative spanned 1 deg).
    nx = 64
    dx = 4. * utilities.d2r / float(nx)
    fnlmode = True
    fnl = 0.2  # alternative: 0.05
    trainingdata = training_data.TrainingDataGaussBased(nx, dx, fnl, fnlmode)
    # Cached variant with pre-generated training/validation maps:
    #trainingdata = training_data.TrainingDataGaussBased_Cached(nx,dx,fnl,fnlmode,nmaps_train=10000,nmaps_valid=10000)
    lmax = trainingdata.lmax

In [None]:
if datamode == "nbodybased":
    # The N-body data object carries its own map geometry and multipole range.
    trainingdata = training_data.TrainingDataNbodyBased()
    nx, dx, lmax = trainingdata.nx, trainingdata.dx, trainingdata.lmax
    fnl = trainingdata.fnl  # arbitrary normalisation in the nbody case

In [None]:
# Sanity check of the data normalisation: overall pixel standard deviation
# across 1000 freshly drawn training maps.
samples_test = trainingdata.draw_samples_of_px(1000)
print(samples_test.std())

# Defining the flow

In [None]:
# Each map is a square lattice of nx-by-nx pixels.
lattice_shape = (nx,) * 2

# Base distribution of the flow: "correlated" or "whitenoise".
priormode = "correlated"

In [None]:
# Build the base (latent) distribution of the flow according to `priormode`.
if priormode == "whitenoise":
    # Per-pixel Normal with zero mean and unit scale (presumably independent pixels).
    prior = flow_architecture.SimpleNormal(torch.zeros(lattice_shape), torch.ones(lattice_shape))
elif priormode == "correlated":
    # Parameters live on the real-FFT half plane: nx x (nx//2 + 1) modes, with a
    # trailing axis of 2 (presumably real/imaginary parts — TODO confirm in
    # flow_architecture). `nx // 2 + 1` equals the former int(nx/2+1) for any nx.
    rfourier_shape = (nx, nx // 2 + 1, 2)
    prior = flow_architecture.CorrelatedNormal(torch.zeros(rfourier_shape), torch.ones(rfourier_shape),nx,dx,trainingdata.cl_theo,torch_device)
else:
    # Without this, `prior` would silently stay undefined (or stale from a
    # previous run) and fail much later.
    raise ValueError(f"unknown priormode {priormode!r}; expected 'whitenoise' or 'correlated'")

In [None]:
# Draw a batch from the prior and show the first eight draws on a 2x4 grid.
nsamples = 100
torch_z = prior.sample_n(nsamples)
z = utilities.grab(torch_z)
torch_logp = prior.log_prob(torch_z)
logp = utilities.grab(torch_logp)

fig, ax = plt.subplots(2, 4, figsize=(10, 5))
# ax.flat walks the grid row by row, so panel `ind` sits at row ind // 4, column ind % 4.
for ind, panel in enumerate(ax.flat):
    panel.imshow(z[ind], cmap='viridis')
    panel.xaxis.set_visible(False)
    panel.yaxis.set_visible(False)
plt.show()

print(np.std(z))

In [None]:
# Flow: a stack of affine layers built by flow_architecture.make_flow1_affine_layers,
# each parameterised by a small conv net with the hidden sizes and kernel below.
n_layers = 16            # standard: 16
hidden_sizes = [16, 16]  # standard: [16, 16]
kernel_size = 3
layers = flow_architecture.make_flow1_affine_layers(
    lattice_shape=lattice_shape,
    n_layers=n_layers,
    hidden_sizes=hidden_sizes,
    kernel_size=kernel_size,
    torch_device=torch_device,
)
model = dict(layers=layers, prior=prior)

# Training

In [None]:
# Adam over the flow-layer parameters only (the prior has no trained parameters here).
base_lr = 0.001  # standard: 0.001
optimizer = torch.optim.Adam(params=model['layers'].parameters(), lr=base_lr)

In [None]:
# Training schedule: N_era outer iterations of N_epoch optimiser steps each;
# every step draws a new batch of `batch_size` maps.
N_era, N_epoch = 10000, 100
batch_size = 128
# NOTE(review): print_freq and plot_freq are not read anywhere in this notebook.
print_freq = N_epoch
plot_freq = 1

# Running per-step records of the training and validation losses.
history = dict(loss=[])
lossList, validationList = [], []

In [None]:
def train_step(model, optimizer, metrics, trainingdata):
    """Run one maximum-likelihood optimisation step of the flow on a fresh batch.

    Draws ``batch_size`` training maps (module-level setting), maps them to the
    latent space with the reverse flow, and minimises the negative mean
    log-likelihood ``-(log p_u(u) + log_J_Tinv)``.

    Args:
        model: dict with keys ``'layers'`` (flow layers) and ``'prior'``.
        optimizer: torch optimiser over the layer parameters.
        metrics: dict whose ``'loss'`` list receives the scalar loss.
        trainingdata: object providing ``draw_samples_of_px(n)`` returning a numpy array.

    Side effects:
        Appends the loss value to ``metrics['loss']`` and to the module-level ``lossList``.
    """
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()

    x = trainingdata.draw_samples_of_px(batch_size)
    # Use the default dtype configured at startup (float32 on CUDA, float64 on CPU).
    # Always casting to float32 presumably mismatches the double-precision
    # parameters created on the CPU path.
    x = torch.as_tensor(x, dtype=torch.get_default_dtype(), device=torch_device)

    u, log_pu, log_J_Tinv = flow_architecture.apply_reverse_flow_to_sample(x, prior, layers)

    # Negative log-likelihood of the data under the flow (change of variables).
    loss = -(log_pu + log_J_Tinv).mean()

    loss.backward()
    optimizer.step()

    lossval = utilities.grab(loss)
    lossList.append(lossval)

    metrics['loss'].append(lossval)

In [None]:
# Either train from scratch or restore a previously saved model and loss history.
use_pretrained = False

print("  era  |     sample loss      |   validation loss   |  time")
print("--------------------------------------------------------------")
if not use_pretrained:
    for era in range(N_era):
        # Progress row: losses from the last step of the previous era (blank for era 0).
        print ("  ",era,' ','|','                    'if era==0 else lossList[-1],'|','                    'if era==0 else validationList[-1],'|', round(time.time()-startTime), "s")
        for epoch in range(N_epoch):
            train_step(model, optimizer, history, trainingdata)

            # Validation loss on a separately drawn batch.
            # NOTE(review): trainingdata.draw_samples_of_pv is the only sampler in this
            # notebook returning (samples, logp); confirm it is the intended source
            # of validation maps.
            v,_ = trainingdata.draw_samples_of_pv(batch_size)
            v = torch.as_tensor(v, dtype=torch.get_default_dtype(), device=torch_device)
            # No gradients are needed for validation; skip building the autograd graph.
            with torch.no_grad():
                u_v, log_pu_v, log_J_Tinv_v = flow_architecture.apply_reverse_flow_to_sample(v, prior, layers)
                validation = -(log_pu_v + log_J_Tinv_v).mean()
            validationval = utilities.grab(validation)
            validationList.append(validationval)
else:
    print ("restoring state dict.")
    save_model_dir = "models/gauss_cached15/"
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    # NOTE: torch.load / pickle.load unpickle arbitrary objects — only load trusted files.
    model['layers'].load_state_dict(torch.load(save_model_dir+"model", map_location=torch_device))
    with open(save_model_dir+'train_loss.pkl', 'rb') as f:
        lossList = pickle.load(f)
    with open(save_model_dir+'val_loss.pkl', 'rb') as f:
        validationList = pickle.load(f)


In [None]:
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

# Training-loss curve, one point per optimisation step ("batch").
idmin, idmax = 0, len(lossList)

lossListmod = np.array(lossList)
validationListmod = np.array(validationList)
# The validation curve is computed but currently not drawn; to add it:
#ax.plot(np.arange(idmin,len(validationListmod),1),validationListmod[idmin:idmax],label='Validation')

fig, ax = plt.subplots(figsize=(5, 3))
steps = np.arange(idmin, idmax, 1)
ax.plot(steps, lossListmod[idmin:idmax], color='red', label='Training')
ax.legend(loc=1, frameon=False, fontsize=14)
ax.grid(True)
ax.set_xlabel('Batch', fontsize=14)
ax.set_ylabel('Loss', fontsize=14)
fig.tight_layout()
fig.savefig('loss_nbody.pdf')  # gauss run: loss_gauss.pdf

# Visual inspection of the results

In [None]:
# Round-trip check: push one prior draw through the forward flow, map the result
# back with the reverse flow, and print the log-densities from both directions.
u, log_pu, z, log_pz = flow_architecture.apply_flow_to_prior(prior, layers, batch_size=1)
u_rev, log_pu_rev, log_J_Tinv = flow_architecture.apply_reverse_flow_to_sample(z, prior, layers)
for quantity in (log_pu, log_pu_rev, log_pz, log_J_Tinv, log_pu + log_J_Tinv):
    print(quantity)

In [None]:
def _show_three_maps(maps, title, filename):
    """Show maps[0..2] side by side with hidden axes, titled `title`; save to `filename`."""
    fig, ax = plt.subplots(1, 3, figsize=(9, 3.5))
    for k in range(3):
        ax[k].imshow(maps[k], cmap='viridis')
        ax[k].xaxis.set_visible(False)
        ax[k].yaxis.set_visible(False)
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    plt.savefig(filename)
    plt.show()


# Compare prior draws, their images under the flow, and real training maps.
u, log_pu, z, log_pz = flow_architecture.apply_flow_to_prior(prior, layers, batch_size=16)
z = utilities.grab(z)
u = utilities.grab(u)
_show_three_maps(u, 'Prior samples', "samples_nbody_prior.pdf")
_show_three_maps(z, 'Model samples', "samples_nbody_model.pdf")

x = trainingdata.draw_samples_of_px(16)
_show_three_maps(x, 'Target samples', "samples_nbody_target.pdf")

# Test-sample generation and power spectrum comparison (truth vs. flow vs. prior)

In [None]:
# Build three test ensembles of `ntest` maps each (truth, flow, prior), generated
# in batches of `batchsize_test`.
ntest = 10000
batchsize_test = 100
if ntest % batchsize_test != 0:
    raise ValueError("ntest must be a multiple of batchsize_test")
# Derive the batch count from the sizes (instead of a hard-coded 100) so changing
# ntest or batchsize_test keeps every array exactly filled.
n_batches_test = ntest // batchsize_test

samples_true = np.zeros( (ntest,nx,nx) )
samples_true_logp_true = np.zeros( (ntest) )
samples_flow = np.zeros( (ntest,nx,nx) )
samples_prior = np.zeros( (ntest,nx,nx) )

#draw truth samples
for batch_id in range(n_batches_test):
    rows = slice(batch_id*batchsize_test, (batch_id+1)*batchsize_test)
    z,log_p = trainingdata.draw_samples_of_pv(batchsize_test)
    samples_true[rows] = z
    # NOTE(review): stores -0.5*log_p; what log_p represents is not visible here and
    # this array is not read later in the notebook — confirm the intent.
    samples_true_logp_true[rows] = -0.5*log_p

#make flow samples
for batch_id in range(n_batches_test):
    rows = slice(batch_id*batchsize_test, (batch_id+1)*batchsize_test)
    u, log_pu, z, log_pz = flow_architecture.apply_flow_to_prior(prior, layers, batch_size=batchsize_test)
    z = utilities.grab(z)
    u = utilities.grab(u)
    samples_flow[rows] = z

#make gaussianized samples (direct draws from the flow's prior)
for batch_id in range(n_batches_test):
    rows = slice(batch_id*batchsize_test, (batch_id+1)*batchsize_test)
    # Draw exactly one batch's worth (not the unrelated `nsamples` from the prior
    # inspection cell) so the draw always matches the slice it fills.
    torch_z = prior.sample_n(batchsize_test)
    z = utilities.grab(torch_z)
    samples_prior[rows] = z

In [None]:
# Binned ensemble-averaged power spectra of the three test ensembles.
lbins = np.linspace(100, lmax, 25)  # multipole bins
ensembles = {'true': samples_true, 'flow': samples_flow, 'prior': samples_prior}
spectra = {name: utilities.estimate_ps_ensemble(maps, nx, dx, lbins)
           for name, maps in ensembles.items()}
cl_avg_true, ell_binned = spectra['true']
cl_avg_flow, ell_binned = spectra['flow']
cl_avg_prior, ell_binned = spectra['prior']

In [None]:
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

# Compare ell^2 C_ell of truth, flow and prior ensembles on a log y-axis.
# Raw strings keep the mathtext backslashes literal; '\e' / '\l' in normal strings
# are invalid escape sequences (SyntaxWarning on Python 3.12+).
fig = plt.figure(figsize=(5,3))
ax = fig.add_subplot(111)
ax.plot(ell_binned, cl_avg_true*ell_binned**2., color='red', label=r'$C_\ell^{true}$')
ax.plot(ell_binned, cl_avg_flow*ell_binned**2., color='black', label=r'$C_\ell^{flow}$')
ax.plot(ell_binned, cl_avg_prior*ell_binned**2., color='green', ls='dotted', label=r'$C_\ell^{prior}$')
ax.legend(loc=1, frameon=False, fontsize=14)
ax.set_yscale('log')
ax.set_xlim(0, trainingdata.lmax-600)
ax.set_xlabel(r'$\ell$', fontsize=14)
ax.set_ylabel(r'$\ell^2 C_\ell$', fontsize=14)
ax.grid(True)
fig.tight_layout()
plt.savefig('ps_nbody.pdf')
plt.show()

# Non-Gaussianity ($f_{\rm NL}$) estimation

In [None]:
#draw some gaussian samples for variance calculation
samples_true_gauss = np.zeros( (ntest,nx,nx) )

#draw truth samples
for batch_id in range(100):
    z,log_p = trainingdata.draw_samples_of_pv_gauss(batchsize_test)
    samples_true_gauss[batch_id*batchsize_test:(batch_id+1)*batchsize_test] = z  

In [None]:
print ("FNL LOCAL")
cl_theo_normed = trainingdata.cl_theo_normed
fnls_flow,_ = utilities.estimate_fnl_local_ensemble(samples_flow, samples_true_gauss, cl_theo_normed,nx,dx)
fnls_true,_ = utilities.estimate_fnl_local_ensemble(samples_true, samples_true_gauss, cl_theo_normed,nx,dx)

# Raw per-map estimates: mean, scatter across maps, and their ratio.
# Both rows use the same "std per map" label (the FLOW row previously said just "std").
for label, fnls in (("TRUTH", fnls_true), ("FLOW", fnls_flow)):
    print(f"{label}: mean fnl local", np.mean(fnls), "std per map", np.std(fnls), "snr ratio", np.mean(fnls)/np.std(fnls))
print (" ")

# Rescale so the truth mean equals the input fnl (takes into account ng variance).
# The snr ratio is unaffected by this common scale factor.
renormfactor = fnl/np.mean(fnls_true)
for label, fnls in (("TRUTH", fnls_true), ("FLOW", fnls_flow)):
    print(f"{label} renorm: mean fnl local", np.mean(fnls)*renormfactor, "std per map", np.std(fnls)*renormfactor, "snr ratio", np.mean(fnls)/np.std(fnls))

In [None]:
print("FNL EQUILATERAL")
fnls_flow,_ = utilities.estimate_fnl_equilateral_ensemble(samples_flow, samples_true_gauss, trainingdata.cl_theo_ell,nx,dx)
fnls_true,_ = utilities.estimate_fnl_equilateral_ensemble(samples_true, samples_true_gauss, trainingdata.cl_theo_ell,nx,dx)

# Per-map statistics, first raw and then scaled by the local-fNL renormalisation
# factor computed in the previous cell (the snr ratio is scale-free).
estimates = (("TRUTH", fnls_true), ("FLOW", fnls_flow))
for label, fnls in estimates:
    print(f"{label}: mean fnl equilateral", np.mean(fnls), "std per map", np.std(fnls), "snr ratio", np.mean(fnls)/np.std(fnls))
print(" ")
for label, fnls in estimates:
    print(f"{label} renorm: mean fnl equilateral", np.mean(fnls)*renormfactor, "std per map", np.std(fnls)*renormfactor, "snr ratio", np.mean(fnls)/np.std(fnls))