In [None]:
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
import Pk_library as PKL
import imageio
import matplotlib.pyplot as plt

import utilities
import flow_architecture
import losses

In [None]:
# Runtime / device configuration.
float_dtype = np.float32
# 'cuda' without an index names the *current* CUDA device, which is switched to
# `device_id` below, so tensors sent to `device` end up on GPU `device_id`.
device = 'cuda'
device_id = 1
# Newly created tensors default to CUDA float32.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.cuda.set_device(device_id)

In [None]:
# Output directory for all saved arrays and figures. The trailing '/' matters:
# outputs are written as `save_dir + name`.
save_dir = "nbody_128px_mask_1p0_flow/"
# exist_ok avoids the check-then-create race of `if not os.path.exists(...)`.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Theory spectrum inputs handed to losses.Lossfunctions
# (presumably multipoles ell and the matching C_ell values — confirm).
cl_theo_ell, cl_theo = (
    np.load('sample_test_data/128px_cl_theo_ell.npy'),
    np.load('sample_test_data/128px_cl_theo.npy'),
)

In [None]:
class Parameters():
    """Run configuration: map geometry, noise model, mask, and pretrained-flow settings."""

    def __init__(self, mask_file="masks/mask1_128.png"):
        """Build the configuration.

        Args:
            mask_file: Path of the mask image; it is loaded into ``self.mask``.
        """
        # --- Data parameters ---
        self.nx = 128                        # maps are nx x nx pixels
        self.dx = 0.00018425707547169813     # pixel size; NOTE(review): units not stated here — confirm

        # --- Fitting parameters ---
        self.nlev_t = 1.0
        self.noise_fac = self.nlev_t         # used as `std` when noise is added to the true maps
        self.noise_pix = 2*(self.nlev_t)**2  # NOTE(review): consumed by losses.Lossfunctions — confirm meaning
        self.use_ql = False                  # the nbody power spectrum is matched with the training data
        self.wf_batch_size = 100             # number of maps to fit
        self.mask_file = mask_file
        # For an unmasked run, replace with np.ones((self.nx, self.nx)).
        self.mask = self._load_mask(mask_file)

        # --- Pre-trained flow parameters ---
        self.flow_n_layers = 16
        self.flow_hidden = [12, 12]
        self.trained_flow_dir = 'pretrained_flows/'

    @staticmethod
    def _load_mask(path):
        """Read a mask image and scale its pixel values by 1/255 to floats.

        Multi-channel images (e.g. RGB/RGBA) use their first channel; single-channel
        images are used as-is (indexing ``[:, :, 0]`` on them would raise IndexError).
        Assumes 8-bit pixel values.
        """
        image = imageio.imread(path)
        if image.ndim == 3:
            image = image[:, :, 0]
        return (image / 255).astype(float)

params = Parameters()

In [None]:
# Prior of the flow over nx x nx maps, built from zero and unit tensors
# (presumably loc=0, scale=1 — see flow_architecture.SimpleNormal).
prior = flow_architecture.SimpleNormal(torch.zeros((params.nx, params.nx)), torch.ones((params.nx, params.nx)))

layers = flow_architecture.make_flow1_affine_layers(lattice_shape=(params.nx, params.nx),
                                                    n_layers=params.flow_n_layers, hidden_sizes=params.flow_hidden,
                                                    kernel_size=[3, 3, 3], torch_device=device, padding_mode='circular')
model = {'layers': layers, 'prior': prior}

# map_location remaps the stored tensors onto this run's device. Without it, torch.load
# restores them onto the device they were saved from, which may not exist on this machine.
checkpoint = torch.load(os.path.join(params.trained_flow_dir, 'dict_periodic'), map_location=device)
model['layers'].load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Ground-truth maps; a copy is written to save_dir for later comparison.
y_true_np = np.load('sample_test_data/128px_true_maps_periodic.npy')
np.save(save_dir + 'true_maps', y_true_np)

# Observed maps: the truth with noise added by utilities.add_noise (std = params.noise_fac),
# then multiplied by the mask. Despite its name, y_pred_np holds these observations;
# both fits below start from them.
y_pred_np = (
    utilities.add_noise(y_true_np, std=params.noise_fac)
    * params.mask
)
np.save(save_dir + 'masked_maps', y_pred_np)

In [None]:
# Shared colour scale and size for every map plot: the lower limit is the minimum of
# the true maps; the upper limit is a fixed 11 (NOTE(review): origin of 11 undocumented).
vmin, vmax = np.min(y_true_np), 11
figsize = (4, 4)

In [None]:
# First ground-truth map.
utilities.imshow(
    y_true_np[0],
    vmin=vmin, vmax=vmax,
    title='Truth',
    figsize=figsize,
    axis=False, colorbar=False,
    file_name=save_dir + 'truth',
)

In [None]:
# First observed (noisy, masked) map.
utilities.imshow(
    y_pred_np[0],
    vmin=vmin, vmax=vmax,
    title='Masked',
    figsize=figsize,
    axis=False, colorbar=False,
    file_name=save_dir + 'masked',
)

In [None]:
def _as_leaf_tensor(array, requires_grad):
    """Copy `array` into a float32 tensor allocated directly on `device`.

    Allocating on the target device (rather than `torch.tensor(...).to(device)`) guarantees
    a leaf tensor even if a device move would otherwise happen. torch.optim optimizers
    raise on non-leaf tensors.
    """
    return torch.tensor(array, dtype=torch.float32, device=device, requires_grad=requires_grad)

# Ground-truth maps: only used to evaluate the J2 diagnostic, never optimized.
y_true = _as_leaf_tensor(y_true_np, requires_grad=False)
# Observed maps: fixed data term of the losses.
y_pred_nograd = _as_leaf_tensor(y_pred_np, requires_grad=False)
# One trainable copy of each observed map per method (a leading singleton axis is added),
# so every map can be optimized by its own optimizer.
y_pred_flow = [_as_leaf_tensor(np.expand_dims(y_pred_np[n], 0), requires_grad=True)
               for n in range(params.wf_batch_size)]
y_pred_wf = [_as_leaf_tensor(np.expand_dims(y_pred_np[n], 0), requires_grad=True)
             for n in range(params.wf_batch_size)]

In [None]:
# Loss terms used by optimize(): loss_wiener_J3, loss_wiener_J3_flow and loss_J2.
lossfunctions = losses.Lossfunctions(
    params,
    cl_theo_ell=cl_theo_ell,
    cl_theo=cl_theo,
)

In [None]:
# Histories filled by optimize(): one batch-averaged loss and J2 per step, plus
# the most recent J2 of each individual map.
loss_list_flow, J2_ave_list_flow = [], []
loss_list_wf, J2_ave_list_wf = [], []
J2_map_list_flow = [None for _ in range(params.wf_batch_size)]
J2_map_list_wf = [None for _ in range(params.wf_batch_size)]

In [None]:
# One Adam optimizer per map, each owning only that map's trainable tensor.
optimizer_flow = [torch.optim.Adam([y_pred_flow[n]], lr=0.05) for n in range(params.wf_batch_size)]

In [None]:
def optimize(y_pred_nograd, y_pred, optimizer, steps, loss_list, J2_ave_list, J2_map_list, use_flow, print_freq=100):
    """Optimize every map independently for `steps` rounds and record loss/J2 histories.

    In each round, each of the ``params.wf_batch_size`` maps takes one step with its own
    optimizer on ``loss_1 + loss_2``. The pair comes from ``lossfunctions.loss_wiener_J3_flow``
    when ``use_flow`` is set, otherwise from ``lossfunctions.loss_wiener_J3``. After the
    step, the J2 diagnostic of the map against the module-level ``y_true`` is evaluated.

    Args:
        y_pred_nograd: Fixed reference maps passed first to the loss, indexed per map with ``[n]``
            (in this notebook, the noisy masked observations).
        y_pred: List of per-map tensors being optimized (updated in place by the optimizers).
        optimizer: List of optimizers; ``optimizer[n]`` updates ``y_pred[n]``.
        steps: Number of optimization rounds.
        loss_list: Receives the batch-averaged loss once per round.
        J2_ave_list: Receives the batch-averaged J2 once per round.
        J2_map_list: Overwritten in place so ``J2_map_list[n]`` holds map n's latest J2.
        use_flow: Use the flow-prior loss; this needs module-level ``prior`` and ``model``.
        print_freq: Print progress every ``print_freq`` rounds.

    Also relies on module-level ``params``, ``lossfunctions`` and ``y_true``.
    """
    n_maps = params.wf_batch_size
    for i in range(steps):
        loss_ave = 0.0
        J2_ave = 0.0

        for n in range(n_maps):
            optimizer[n].zero_grad()
            if use_flow:
                loss_1, loss_2 = lossfunctions.loss_wiener_J3_flow(y_pred_nograd[n], y_pred[n], prior, model['layers'])
            else:
                loss_1, loss_2 = lossfunctions.loss_wiener_J3(y_pred_nograd[n], y_pred[n])
            loss = loss_1 + loss_2
            loss.backward()
            optimizer[n].step()
            # backward() without arguments already requires a scalar loss, so .item() is safe.
            loss_ave += loss.item() / n_maps
            # J2 is only a diagnostic: do not build an autograd graph for it.
            with torch.no_grad():
                J2_map_list[n] = lossfunctions.loss_J2(y_true[n], y_pred[n]).detach().cpu().numpy()
            J2_ave += J2_map_list[n] / n_maps

        loss_list.append(loss_ave)
        J2_ave_list.append(J2_ave)
        if i % print_freq == 0:
            print("step =", i, "loss =", loss_ave, "J2 =", J2_ave)

## Flow

In [None]:
# First 300 steps of the flow-prior fit, at the optimizers' initial learning rate.
optimize(y_pred_nograd, y_pred_flow, optimizer_flow,
         steps=300,
         loss_list=loss_list_flow,
         J2_ave_list=J2_ave_list_flow,
         J2_map_list=J2_map_list_flow,
         use_flow=True)

In [None]:
# Lower the learning rate for the fine-tuning phase. Every map has its own optimizer,
# so each one must be updated (indexing only optimizer_flow[0] left the rest at lr=0.05).
for opt in optimizer_flow:
    for group in opt.param_groups:
        group['lr'] = 0.01

In [None]:
# Fine-tuning: 1000 further flow-prior steps, appended to the same histories.
optimize(y_pred_nograd, y_pred_flow, optimizer_flow,
         steps=1000,
         loss_list=loss_list_flow,
         J2_ave_list=J2_ave_list_flow,
         J2_map_list=J2_map_list_flow,
         use_flow=True)

In [None]:
# Loss history of the flow-prior fit.
utilities.plot_lists(
    loss_list_flow[:],
    title='Flow loss',
    file_name=save_dir + 'flow_loss',
)

In [None]:
# J2 history of the flow-prior fit.
utilities.plot_lists(
    J2_ave_list_flow,
    title='Flow J2',
    file_name=save_dir + 'flow_J2',
)

In [None]:
# Final batch-averaged J2 of the flow fit, divided by the number of pixels of an nx x nx map
# (was the hard-coded 128*128, which silently breaks if params.nx changes).
J2_ave_list_flow[-1] / (params.nx * params.nx)

In [None]:
# First map optimized with the flow prior; the array keeps its leading singleton axis,
# hence the [0] when plotting.
y_pred_flow_np = y_pred_flow[0].detach().cpu().numpy()
utilities.imshow(
    y_pred_flow_np[0],
    title='Optimized map with flow prior',
    vmin=vmin, vmax=vmax,
    figsize=figsize,
    axis=False, colorbar=False,
    file_name=save_dir + 'flow_result',
)

In [None]:
# Save the first flow-optimized map. save_dir already ends with '/', so no extra separator is
# added (consistent with the other saves).
np.save(save_dir + 'y_pred_flow_np', y_pred_flow_np)

## Wiener filtering

In [None]:
# One Adam optimizer per map for the Wiener-filter fit.
optimizer_wf = [torch.optim.Adam([y_pred_wf[n]], lr=0.01) for n in range(params.wf_batch_size)]

In [None]:
# Wiener-filter fit: 3000 steps without the flow prior.
optimize(y_pred_nograd, y_pred_wf, optimizer_wf,
         steps=3000,
         loss_list=loss_list_wf,
         J2_ave_list=J2_ave_list_wf,
         J2_map_list=J2_map_list_wf,
         use_flow=False)

In [None]:
# Loss history of the Wiener-filter fit.
utilities.plot_lists(
    loss_list_wf[:],
    title='WF loss',
    file_name=save_dir + 'wf_loss',
)

In [None]:
# J2 history of the Wiener-filter fit.
utilities.plot_lists(
    J2_ave_list_wf,
    title='WF J2',
    file_name=save_dir + 'wf_J2',
)

In [None]:
# First map optimized by Wiener filtering; the array keeps its leading singleton axis,
# hence the [0] when plotting.
y_pred_wf_np = y_pred_wf[0].detach().cpu().numpy()
# save_dir already ends with '/', so no extra separator (consistent with the other outputs).
utilities.imshow(y_pred_wf_np[0], title='Optimized map with Wiener filtering',
                 vmin=vmin, vmax=vmax, figsize=figsize, axis=False, colorbar=False, file_name=save_dir + 'wf_result')

In [None]:
# Save the first Wiener-filtered map.
np.save(file=save_dir + 'y_pred_wf_np', arr=y_pred_wf_np)

In [None]:
def _stack_detached(maps):
    """Concatenate per-map tensors along dim 0 into one batch tensor, detached from autograd.

    Each optimized map carries a leading singleton axis, so the result has shape
    ``(len(maps), ...)``. Detaching avoids recording an autograd graph, which the previous
    element-wise copy into a zeros tensor did because the maps require grad.
    """
    return torch.cat([m.detach() for m in maps], dim=0)

y_pred_flow_cc = _stack_detached(y_pred_flow)
y_pred_wf_cc = _stack_detached(y_pred_wf)

In [None]:
# Convert both result batches with utilities.grab (presumably tensor -> NumPy) and save them.
y_pred_flow_cc_np, y_pred_wf_cc_np = map(utilities.grab, (y_pred_flow_cc, y_pred_wf_cc))
np.save(save_dir + 'flow_maps', y_pred_flow_cc_np)
np.save(save_dir + 'wf_maps', y_pred_wf_cc_np)