In [None]:
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
import Pk_library as PKL
import scipy.ndimage
import imageio
import matplotlib.pyplot as plt

import utilities
import flow_architecture
import losses

In [None]:
# Global GPU configuration used by every later cell.
device = 'cuda'  # with set_device below, 'cuda' resolves to GPU `device_id`
float_dtype = np.float32  # NOTE(review): not referenced in any visible cell
# New tensors created by factory functions default to CUDA float32.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases
# (set_default_dtype / set_default_device replace it) — consider migrating.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
device_id = 2
torch.cuda.set_device(device_id)  # make GPU 2 the current CUDA device

In [None]:
# Output directory for all saved maps and figures of this run.
save_dir = "nbody_384px_mask_1p0_flow/"
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Theory spectrum for the 384px maps, presumably multipoles (ell) and the matching C_ell values.
cl_theo_ell, cl_theo = (
    np.load('sample_test_data/384px_cl_theo_ell.npy'),
    np.load('sample_test_data/384px_cl_theo.npy'),
)

In [None]:
class Parameters:
    """Run configuration: data geometry, noise settings, mask (and its patches), flow setup."""

    def __init__(self):
        # --- Data parameters ---
        self.nx = 128  # side length used for the flow lattice and for each patch
        self.dx = 0.00018425707547169813  # pixel size; units not stated here — TODO confirm

        # --- Fitting parameters ---
        self.nlev_t = 1.0
        self.noise_fac = self.nlev_t
        self.noise_pix = 2 * self.nlev_t ** 2
        self.use_ql = False  # The nbody power spectrum is matched with the training data
        self.wf_batch_size = 100  # The number of maps to fit

        # Crop rows/cols 19:485 of the PNG's first channel, scale to [0, 1], then
        # nearest-neighbour (order=0) resample so the 466-px crop becomes 384 px.
        crop = slice(19, 485)
        mask512 = (imageio.imread("masks/mask2_512.png")[crop, crop, 0] / 255).astype(float)
        self.mask = scipy.ndimage.zoom(mask512, 384 / (485 - 19), order=0)

        # Split the mask into 128-px patches for four tilings: the default one, then
        # displace=1, 2, 3. Presumably this order must match the data patches — confirm.
        tilings = ({}, {'displace': 1}, {'displace': 2}, {'displace': 3})
        self.mask_patches = torch.cat(
            [utilities.make_small_maps_from_big_map(torch.tensor(self.mask, dtype=torch.float32), 128, **tiling)
             for tiling in tilings],
            dim=0,
        )

        # --- Pre-trained flow parameters ---
        self.flow_n_layers = 16
        self.flow_hidden = [12, 12]
        self.trained_flow_dir = 'pretrained_flows/'


params = Parameters()

# Number of patches optimized independently. Presumably 4 tilings x 9 patches of the
# 384-px map — TODO confirm against params.mask_patches.shape[0].
n_maps = 36

In [None]:
# Show the resampled observation mask with a colour bar and a title so the figure stands alone.
fig, ax = plt.subplots()
mask_im = ax.imshow(params.mask)
fig.colorbar(mask_im, ax=ax)
ax.set_title('Observation mask')
print(params.mask.shape)

In [None]:
# Flow prior: a base distribution on an nx x nx lattice (loc 0, scale 1 here)
# followed by the pretrained affine flow layers.
prior = flow_architecture.SimpleNormal(torch.zeros((params.nx, params.nx)), torch.ones((params.nx, params.nx)))

layers = flow_architecture.make_flow1_affine_layers(lattice_shape=(params.nx, params.nx),
                                                    n_layers=params.flow_n_layers, hidden_sizes=params.flow_hidden,
                                                    kernel_size=[3, 3, 3], torch_device=device, padding_mode='zeros')
model = {'layers': layers, 'prior': prior}

# map_location remaps tensors that were saved on another GPU index onto the device used here.
checkpoint = torch.load(os.path.join(params.trained_flow_dir, 'dict_nonperiodic'), map_location=device)
model['layers'].load_state_dict(checkpoint['model_state_dict'])

# The flow is a fixed, pretrained prior. No optimizer owns its weights, so freeze them:
# backward() then stops computing (and accumulating) unused parameter gradients,
# while gradients w.r.t. the maps being optimized still propagate through the layers.
for flow_param in model['layers'].parameters():
    flow_param.requires_grad_(False)

In [None]:
# Ground-truth maps, standardized with one global mean and std over the whole array.
y_true_np = np.load('sample_test_data/384px_true_map.npy')
y_true_np = (y_true_np - y_true_np.mean()) / y_true_np.std()
np.save(save_dir + 'true_maps', y_true_np)

# Observed data: utilities.add_noise(std=params.noise_fac) applied to the truth, then the mask.
# NOTE(review): no RNG seed is set before add_noise; seed it if reproducible noise is required.
noisy_np = utilities.add_noise(y_true_np, std=params.noise_fac)
y_pred_np = noisy_np * params.mask
np.save(save_dir + 'masked_maps', y_pred_np)

In [None]:
# Shared colour scale and figure size for the map plots: vmin follows the data minimum,
# vmax is fixed at 11 rather than data-driven.
vmin, vmax, figsize = np.min(y_true_np), 11, (6, 6)

In [None]:
utilities.imshow(
    y_true_np[0],
    title='Truth',
    vmin=vmin,
    vmax=vmax,
    figsize=figsize,
    axis=False,
    colorbar=False,
    file_name=save_dir + 'truth',
)

In [None]:
utilities.imshow(
    y_pred_np[0],
    title='Masked',
    vmin=vmin,
    vmax=vmax,
    figsize=figsize,
    axis=False,
    colorbar=False,
    file_name=save_dir + 'masked',
)

In [None]:
def _patch_tiling(**tiling_kwargs):
    """Split the first masked map into 128-px patches for one tiling (fresh float32 tensor per call)."""
    source = torch.tensor(y_pred_np[0], dtype=torch.float32)
    return utilities.make_small_maps_from_big_map(source, 128, **tiling_kwargs)


# Same four tilings as the mask patches: default, then displace=1..3.
y_pred_0 = _patch_tiling()
y_pred_1 = _patch_tiling(displace=1)
y_pred_2 = _patch_tiling(displace=2)
y_pred_3 = _patch_tiling(displace=3)

In [None]:
# Stack all tilings along the patch axis: default tiling first, then displace=1..3.
y_pred_all = torch.cat([y_pred_0, y_pred_1, y_pred_2, y_pred_3], dim=0)

In [None]:
# Tensors handed to an optimizer must be leaf tensors. Each one is placed on `device`
# first and only then marked requires_grad, so a real device move can never turn it
# into a non-leaf copy. Existing tensors are copied with detach().clone() rather than
# torch.tensor(tensor).
y_true = torch.tensor(y_true_np, dtype=torch.float32, device=device, requires_grad=True)
y_pred_nograd = y_pred_all.detach().clone().to(device=device, dtype=torch.float32)

# One trainable tensor per patch, with a leading singleton axis added by unsqueeze(0).
y_pred_flow = [
    y_pred_all[n].unsqueeze(0).detach().clone().to(device=device, dtype=torch.float32).requires_grad_(True)
    for n in range(n_maps)
]

# Plain-Wiener-filter inputs: only slot 0 is filled (the first full masked map);
# the remaining wf_batch_size - 1 slots stay None.
y_pred_wf = [None] * params.wf_batch_size
y_pred_wf[0] = torch.tensor(np.expand_dims(y_pred_np[0], 0), dtype=torch.float32, device=device, requires_grad=True)

In [None]:
lossfunctions = losses.Lossfunctions(
    params,
    cl_theo_ell=cl_theo_ell,
    cl_theo=cl_theo,
)

In [None]:
# Per-run histories. optimize() appends to the loss lists; the J2 containers are
# passed to it but not written there.
loss_list_flow, J2_ave_list_flow = [], []
J2_map_list_flow = [None] * params.wf_batch_size
loss_list_wf, J2_ave_list_wf = [], []
J2_map_list_wf = [None] * params.wf_batch_size

In [None]:
# One Adam optimizer per trainable patch, so each patch is updated independently.
# Sized from y_pred_flow (n_maps entries) instead of a hard-coded 36.
optimizer_flow = [torch.optim.Adam([patch], lr=0.05) for patch in y_pred_flow]

In [None]:
def optimize(y_pred_nograd, y_pred, optimizer, steps, loss_list, J2_ave_list, J2_map_list, use_flow, print_freq=100):
    """Optimize every patch independently for `steps` passes and record the mean loss.

    Each step visits every patch ``n`` once: its gradients are zeroed, the loss of
    ``y_pred[n]`` against ``y_pred_nograd[n]`` is back-propagated, and ``optimizer[n]``
    takes a step.

    Args:
        y_pred_nograd: observed patches; entry ``n`` is the data term for patch ``n``.
        y_pred: trainable tensors, one per patch (updated in place by the optimizers).
        optimizer: optimizers with ``optimizer[n]`` owning ``y_pred[n]``; its length
            sets how many patches are visited per step.
        steps: number of passes over all patches.
        loss_list: receives the mean per-patch loss of every step (mutated in place).
        J2_ave_list, J2_map_list: kept for interface compatibility; not updated here.
        use_flow: if True use ``lossfunctions.loss_wiener_J3_flow_patching`` with the
            module-level ``prior`` and ``model['layers']``; otherwise
            ``lossfunctions.loss_wiener_J3``.
        print_freq: print progress every ``print_freq`` steps (step 0 included).
    """
    n_patches = len(optimizer)
    for i in range(steps):
        loss_ave = 0.0
        # NOTE(review): J2 is never computed in this loop, so the printed J2 is always 0.
        J2_ave = 0

        for n in range(n_patches):
            optimizer[n].zero_grad()
            if use_flow:
                loss_1, loss_2 = lossfunctions.loss_wiener_J3_flow_patching(y_pred_nograd[n], y_pred[n], prior, model['layers'], patch_id=n)
            else:
                loss_1, loss_2 = lossfunctions.loss_wiener_J3(y_pred_nograd[n], y_pred[n])
            loss = loss_1 + loss_2
            loss.backward()
            optimizer[n].step()
            # Divide by the number of patches actually visited so loss_ave is a true mean.
            loss_ave += loss.item() / n_patches

        loss_list.append(loss_ave)
        if i % print_freq == 0:
            print("step =", i, "loss =", loss_ave, "J2 =", J2_ave)

## Flow prior: optimize each masked patch under the pretrained flow

In [None]:
# First stage: a few steps at the optimizers' initial learning rate.
optimize(y_pred_nograd, y_pred_flow, optimizer_flow, steps=3,
         loss_list=loss_list_flow, J2_ave_list=J2_ave_list_flow, J2_map_list=J2_map_list_flow,
         use_flow=True)

In [None]:
# Lower the learning rate to 0.01 for the next stage on EVERY per-patch optimizer,
# not just the first one.
for patch_optimizer in optimizer_flow:
    for group in patch_optimizer.param_groups:
        group['lr'] = 0.01

In [None]:
# Second stage: continue optimizing with the reduced learning rate.
optimize(y_pred_nograd, y_pred_flow, optimizer_flow, steps=10,
         loss_list=loss_list_flow, J2_ave_list=J2_ave_list_flow, J2_map_list=J2_map_list_flow,
         use_flow=True)

In [None]:
# Plot (a copy of) the per-step mean loss history.
utilities.plot_lists(list(loss_list_flow), title='Flow loss', file_name=save_dir + 'flow_loss')

In [None]:
# Pull the first optimized patch back to NumPy and plot its first slice.
y_pred_flow_np = y_pred_flow[0].detach().cpu().numpy()
utilities.imshow(
    y_pred_flow_np[0],
    title='Optimized map with flow prior',
    vmin=vmin,
    vmax=vmax,
    figsize=figsize,
    axis=False,
    colorbar=False,
    file_name=save_dir + 'flow_result',
)

In [None]:
np.save(file=save_dir + 'y_pred_flow_np', arr=y_pred_flow_np)

In [None]:
# Stack the optimized patches into one (n_maps, nx, nx) buffer. The copies run under
# no_grad so the buffer stays a plain tensor instead of recording an autograd graph
# from the grad-requiring patches.
y_pred_flow_cc = torch.zeros((n_maps, params.nx, params.nx))
with torch.no_grad():
    for n in range(n_maps):
        y_pred_flow_cc[n, :, :] = y_pred_flow[n]

In [None]:
# Convert the stacked patches with utilities.grab and save them.
y_pred_flow_cc_np = utilities.grab(y_pred_flow_cc)
np.save(file=save_dir + 'flow_maps', arr=y_pred_flow_cc_np)