# Example showing how to train the CNN
## Training without a GPU takes a very long time

In [None]:
%matplotlib inline

import numpy as np
import time
import os
import sys
import random
import gc

import matplotlib.pyplot as plt

from deepmass import map_functions as mf
from deepmass import lens_data as ld
from deepmass import wiener
from deepmass import cnn_keras as cnn

# This demonstration uses the validation data as training data
### (the separate full training set is too large to store in the git repository; the first `n_test` validation maps are held out for testing)

In [None]:
# Experiment configuration.
map_size = 256  # map side length passed to the model (presumably pixels)
n_test = int(1000)  # number of leading maps held out as the test/validation set
n_epoch = 20  # number of training epochs
batch_size = 32  # maps per batch, also used to compute the step counts
# Previously written as `1-5`, which evaluates to -4 (a negative learning
# rate); scientific notation 1e-5 is presumably what was intended.
learning_rate = 1e-5

In [None]:
# Build the SV mask from the Ncov map: pixels whose loaded value exceeds 1.0
# become 0.0 (masked out) and every other pixel becomes 1.0 (kept).
ncov_map = np.load('../picola_training/Ncov.npy')
mask = np.real(np.where(ncov_map > 1.0, 0.0, 1.0)).astype(np.float32)

# Show the mask; the handles are bound to `_` so the notebook does not echo them.
mask_image = plt.imshow(mask, origin='lower', clim=(0,1))
mask_colorbar = plt.colorbar()
_ = mask_image, mask_colorbar

In [None]:
# Load the Wiener-filtered maps (used below as the network's noisy input),
# then run a garbage-collection pass.
wiener_path = '../picola_training/validation_data/test_array_wiener.npy'
wiener_array = np.load(wiener_path)
gc.collect()

In [None]:
# Load the clean maps (used below as the network's training/test targets),
# then run a garbage-collection pass.
clean_path = '../picola_training/validation_data/test_array_clean.npy'
clean_array = np.load(clean_path)
gc.collect()

In [None]:
# Split the maps: the first n_test maps are held out for testing/validation
# and the remainder are used for training. Inputs (Wiener-filtered) and
# targets (clean) must be paired one-to-one along the first axis.
if wiener_array.shape[0] != clean_array.shape[0]:
    raise ValueError(
        f"wiener_array has {wiener_array.shape[0]} maps but clean_array has "
        f"{clean_array.shape[0]}; inputs and targets must be paired."
    )
if not 0 < n_test < wiener_array.shape[0]:
    # Fail here instead of deep inside fit_generator with an empty split.
    raise ValueError(
        f"n_test={n_test} must be between 1 and {wiener_array.shape[0] - 1} "
        "so both the training and test sets are non-empty."
    )

train_array_noisy = wiener_array[n_test:]
train_array_clean = clean_array[n_test:]

test_array_noisy = wiener_array[:n_test]
test_array_clean = clean_array[:n_test]
gc.collect()

In [None]:
# Wrap each (noisy input, clean target) pair of arrays in a BatchGenerator
# for fit_generator; the training generator is built first, then the test one.
train_gen, test_gen = (
    cnn.BatchGenerator(noisy_maps, clean_maps, gen_batch_size=batch_size)
    for noisy_maps, clean_maps in ((train_array_noisy, train_array_clean),
                                   (test_array_noisy, test_array_clean))
)

# Load and train model

In [None]:
# Construct the U-Net-like baseline from the configured map size and learning
# rate, then obtain the model object that is trained below.
cnn_instance = cnn.UnetlikeBaseline(learning_rate=learning_rate,
                                    map_size=map_size)
cnn_model = cnn_instance.model()


In [None]:
# Train the network. Keras expects integer step counts, but np.ceil returns
# a float, so the rounded-up batch counts are converted with int(). Rounding
# up ensures any remainder smaller than batch_size is still counted as a step.
train_steps = int(np.ceil(train_array_noisy.shape[0] / batch_size))
val_steps = int(np.ceil(test_array_noisy.shape[0] / batch_size))

# NOTE(review): fit_generator is deprecated in TF2 Keras in favour of fit();
# it is kept here presumably for the Keras version this notebook targets.
history = cnn_model.fit_generator(generator=train_gen,
                                  epochs=n_epoch,
                                  steps_per_epoch=train_steps,
                                  validation_data=test_gen,
                                  validation_steps=val_steps)

gc.collect()

In [None]:
# Plot training and validation loss per epoch. The x-axis is derived from the
# recorded history rather than n_epoch, so the plot stays correct if training
# stops early (e.g. via a callback) and records fewer epochs.
train_loss = history.history['loss']
val_loss = history.history['val_loss']
_ = plt.plot(np.arange(1, len(train_loss) + 1), train_loss,
             label = 'loss', marker = 'o')
_ = plt.plot(np.arange(1, len(val_loss) + 1), val_loss,
             label = 'val loss', marker = 'x')
_ = plt.xlabel('epoch')
_ = plt.ylabel('loss')
_ = plt.legend()

# Apply model

In [None]:
# Run the trained network on the held-out Wiener-filtered maps.
test_output = cnn_model.predict(x=test_array_noisy)

In [None]:
# Score the network output against the clean targets with
# mf.mean_square_error (both arrays flattened to 1-D first) and print it.
truth_flat = test_array_clean.flatten()
output_flat = test_output.flatten()
result_mse = mf.mean_square_error(truth_flat, output_flat)
print('Result MSE =' + str(result_mse))

In [None]:
# Side-by-side comparison of one held-out map: truth, Wiener filter, DeepMass.
# NOTE(review): xticks/yticks are defined but never applied in this cell;
# kept in case another cell relies on them.
xticks=[None,'65°','75°','85°']
yticks=[]

# Display settings shared by every panel. Maps are drawn as
# (value - display_offset) / display_scale, presumably undoing the offset and
# scale applied when the data were normalised (TODO confirm). Pixels where
# mask == 0 are set to NaN so they render blank.
display_offset = 0.5
display_scale = 3
display_clim = (-0.025, 0.025)
example_index = 0

panels = [(r'${\rm Truth\ (Target)}$', test_array_clean),
          (r'${\rm Wiener\ filter}$', test_array_noisy),
          (r'${\rm DeepMass}$', test_output)]

_ = plt.figure(figsize =(15,4.5))
for panel_number, (title, maps) in enumerate(panels, start=1):
    _ = plt.subplot(1, len(panels), panel_number), plt.title(title, fontsize=16)
    shown = (maps[example_index, :, :, 0] - display_offset) / display_scale
    _ = plt.imshow(np.where(mask != 0., shown, np.nan),
                   origin='lower', cmap='inferno', clim=display_clim)
    plt.xlabel(r'${\rm RA}$')
    if panel_number == 1:
        # Only the leftmost panel carries the DEC axis label.
        plt.ylabel(r'${\rm DEC}$', labelpad = 20.)

# Negative wspace pulls the panels together (presumably to remove the gaps
# left around the square images).
plt.subplots_adjust(wspace=-0.3)
