# Predicting Wavebands with a cGAN
One of the most exciting applications of deep learning is image-to-image translation, which includes tasks such as image colourisation and super-resolution. Several years ago these tasks needed a lot of human input and hard-coding but, with the development of deep learning in recent years, the whole process can now be done end-to-end with machine learning.

Here, we focus on predicting long-wavelength ($LW$, *2.4-5.0µm*) JWST data of simulated strong gravitational lenses, given the short-wavelength ($SW$, *0.6-2.4µm*) data as input to the network. The network we use is a *conditional Generative Adversarial Network* (hereafter cGAN), and its architecture and methods closely follow those of
[_**Image-to-Image Translation with Conditional Adversarial Networks**_](https://arxiv.org/abs/1611.07004), also known as *pix2pix*, which proposes a general solution to many image-to-image translation problems, image colourisation among them. Using the trained cGAN, we are then interested in predicting $LW$ JWST data of the strong gravitational lenses given Euclid VIS and NISP instrument data as input. Euclid VIS and NISP have pixel scales of $0.1"/pix$ and $0.3"/pix$, respectively, compared with $0.03"/pix$ for JWST NIRCam, so we expect JWST to give a more resolved image of the gravitational lens. If the image of a lens observed by JWST differs from the one we predict, then it differs in an interesting way: perhaps we are seeing more obscure arcs and rings, or perhaps it is not a lens at all. We therefore propose that the cGAN can act as an anomaly detector.

In this approach, two losses are used, namely an L1 loss, which makes this task a regression task, and an adversarial (GAN) loss, which helps to solve the problem in an unsupervised manner. 

### The World of GANs ###
The architecture used in this problem is a conditional GAN which uses an extra loss function, the L1 loss. It is useful to understand the setup of a GAN.
In a GAN, there is a Generator and a Discriminator network, which are trained together to solve a problem. In this model, the Generator network takes a 3-channel input, composed of the stacked short-waveband data from the NIRCam *F115W, F150W* and *F200W* filters, and produces a 3-channel output, that of the long-waveband filters, NIRCam *F277W, F356W* and *F444W*. The Discriminator network takes the generated long-waveband data and decides whether it is real or fake. Naturally, the Discriminator also needs to see real inputs (those not produced by the Generator) and should learn that they are real.

The condition on this model is that both the Generator and Discriminator *see* the input.
Let's take a further look into what the cGAN is doing. Consider _**x**_ to be the input to the network, _**z**_ as the input noise for the Generator, and _**y**_ the output we expect from the Generator. Let G and D denote the Generator and Discriminator networks, respectively. The loss of the cGAN can be described via:
\
![cGAN-loss](images/cGAN-loss.png).
\
Note that _**x**_ is the condition that we have introduced and it is seen by both networks. Also note that we are *not* feeding an *n*-dimensional vector of random noise to the Generator, as is common in GANs; instead, the noise is introduced in the form of dropout layers in the Generator network.
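In pix2pix, "both networks see the input" is implemented by concatenating the condition _**x**_ with the real or generated image along the channel axis before it enters the Discriminator (in PyTorch this is `torch.cat(..., dim=1)`, the same operation the U-Net skip connections use below). The NumPy sketch here only illustrates the shapes, using arrays shaped like our batches; `discriminator_input` is a hypothetical helper. Note that the *GANModel* class later in this notebook builds its *PatchDiscriminator* with `input_c=3` and feeds it only the long-waveband data, so a Discriminator conditioned as in this sketch would need `input_c=6`.

```python
import numpy as np

def discriminator_input(x, y):
    ''' Hypothetical helper: stack the condition x (SW bands) and an image y
        (real or generated LW bands) along the channel axis, pix2pix-style. '''
    return np.concatenate([x, y], axis=1)  # arrays are [batch, channel, height, width]

x = np.zeros((16, 3, 256, 256), dtype=np.float32)       # SW input: F115W, F150W, F200W
y_fake = np.ones((16, 3, 256, 256), dtype=np.float32)   # stand-in for G(x)
d_in = discriminator_input(x, y_fake)
print(d_in.shape)  # (16, 6, 256, 256)
```

The first three channels carry the condition and the last three the candidate $LW$ image, so the Discriminator judges the pair rather than the output alone.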

#### Loss Function ####
The goal of training is to minimise the loss. The loss function above encourages outputs that look real; however, to further steer the model in the right direction and to introduce some supervision into the task, we combine it with the L1 loss (also known as the mean absolute error):
\
![L1 loss](images/L1-loss.png).
\
With the L1 loss alone, the model still learns features from the data, but it is conservative: when unsure, it predicts an average of the plausible outputs, which reduces the L1 loss as much as possible but produces blurry images (compare the blurring effect of an L1 or L2 loss in a super-resolution task). Combining the adversarial loss with the L1 loss gives the overall loss function for the model:
\
![loss](images/Loss.png),
\
where *λ* is the coefficient to balance the contribution of the two losses to the final loss. Note that the discriminator loss does not involve the L1 loss. 
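For reference, and assuming the images above reproduce the objectives of the pix2pix paper, the three losses can be written out as follows, using the definitions of _**x**_, _**y**_, _**z**_, G, D and *λ* given above:

$$
\begin{aligned}
\mathcal{L}_{cGAN}(G, D) &= \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right] \\
\mathcal{L}_{L1}(G) &= \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right] \\
G^{*} &= \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\, \mathcal{L}_{L1}(G)
\end{aligned}
$$

The Discriminator maximises $\mathcal{L}_{cGAN}$, which is the same as minimising a binary cross-entropy with real data labelled 1 and generated data labelled 0. In practice, as in pix2pix, the Generator is trained to maximise $\log D(x, G(x, z))$ rather than to minimise $\log(1 - D(x, G(x, z)))$, which is why the Generator step later labels its generated data as real.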

### Implementing the cGAN ###
The cGAN is implemented in PyTorch, a convenient package for machine learning models. This example uses ~1000 simulated strong gravitational lenses of size $64\times 64$ pixels, generated with lenstronomy. The notebook for simulating strong gravitational lenses with lenstronomy is given in this repository. You are welcome to train the network on your own data, provided it is in the same format; this is discussed in the README page.

In [1]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import time
from astropy.io import fits
import warnings
import glob
from astropy.wcs import WCS
from astropy.visualization import make_lupton_rgb

warnings.filterwarnings("ignore", module=r"matplotlib\..*")  # raw string avoids an invalid "\." escape

In [2]:
# path to YOUR working directory
home = "/Users/ruby/Documents/Python Scripts/cGAN/Euclid-JWST/Data/Lens_Source/HighRes/"

# list of filters which is then split into SW filters and LW filters
filters = ['F115W/', 'F150W/', 'F200W/', 'F277W/', 'F356W/', 'F444W/']
# now the Euclid filters
filters_Euclid = ['VIS', 'NISP-J', 'NISP-Y', 'NISP-H']
nbands = len(filters)
n_filters = len(filters_Euclid)

The data that we simulated in the previous notebooks must be processed before being fed to the network. Since large pixel values will cause long training times, we normalise the data appropriately in each of the 6 wavebands. To do this, we compute a histogram of the pixel values in every data file we previously extracted for each waveband, sum these histograms, and take the 10th and 90th percentiles of the summed distribution as the lower and upper bounds.
The Euclid data is normalised in the same way.

In [3]:
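def Filter_Percentiles(path, filtername, fraction=0.1):
    ''' Estimate the lower and upper `fraction` percentiles (10th and 90th by
        default) of all pixel values in a waveband from a summed histogram. '''
    file_list = glob.glob(os.path.join(path, filtername, "*.fits"))
    # fixed bin edges shared by every file; pixels outside [1e-9, log10(5)] are not counted
    bins = np.linspace(1e-9, np.log10(5), 100)
    hists_sum = np.zeros(len(bins) - 1, dtype=np.int64)
    for fname in file_list:
        with fits.open(fname) as fitsfile:
            img = fitsfile[0].data  # linear pixel values (np.log10 removed)
            hist, _ = np.histogram(img, bins=bins)
            hists_sum += hist
    total = hists_sum.sum()
    # lower bound: left edge of the first bin at which the cumulative count
    # from the faint end exceeds `fraction` of all counted pixels
    lower_idx = np.argmax(np.cumsum(hists_sum) > fraction * total)
    percentile_lower = bins[lower_idx]
    # upper bound: right edge of the first bin at which the cumulative count
    # from the bright end (starting at the last bin) exceeds `fraction`
    upper_idx = np.argmax(np.cumsum(hists_sum[::-1]) > fraction * total)
    percentile_upper = bins[-(upper_idx + 1)]
    return percentile_lower, percentile_upper, hists_sum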
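As a sanity check of the cumulative-count logic, the sketch below applies the same histogram approach to synthetic, uniformly distributed pixel values (a hypothetical stand-in for the FITS data) and compares it with `np.percentile`. Because each estimate is snapped to a bin edge, it should agree with the exact percentile to within one bin width (about 0.007 for these 99 bins). `histogram_percentiles` is a hypothetical helper that mirrors the function above without the file handling.

```python
import numpy as np

def histogram_percentiles(values, bins, fraction=0.1):
    ''' Estimate the lower/upper `fraction` percentiles from a histogram,
        using the same cumulative-count rule as Filter_Percentiles. '''
    counts, edges = np.histogram(values, bins=bins)
    total = counts.sum()
    lower_idx = np.argmax(np.cumsum(counts) > fraction * total)
    upper_idx = np.argmax(np.cumsum(counts[::-1]) > fraction * total)
    return edges[lower_idx], edges[-(upper_idx + 1)]

rng = np.random.default_rng(0)
pixels = rng.uniform(0.0, 0.6, size=100_000)   # hypothetical pixel values
bins = np.linspace(1e-9, np.log10(5), 100)
lo, hi = histogram_percentiles(pixels, bins)
width = bins[1] - bins[0]

print(lo, np.percentile(pixels, 10))   # histogram estimate vs exact 10th percentile
print(hi, np.percentile(pixels, 90))   # histogram estimate vs exact 90th percentile
```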
          


Now, we compute the percentiles of each waveband and append them to a list.

In [4]:
waveband_percentiles = []
for filter_ in filters:
    lower, upper, sum1 = Filter_Percentiles(home, filter_)
    waveband_percentiles.append([lower, upper])

Since *waveband_percentiles* contains the information of the lower and upper extreme percentiles of the wavebands in the order $F115W...F444W$, we can set the lower and upper percentiles for each waveband individually to use for the normalisation. 

In [5]:
f115w_lower, f115w_upper = waveband_percentiles[0][0], waveband_percentiles[0][1] 
f150w_lower, f150w_upper = waveband_percentiles[1][0], waveband_percentiles[1][1] 
f200w_lower, f200w_upper = waveband_percentiles[2][0], waveband_percentiles[2][1] 
f277w_lower, f277w_upper = waveband_percentiles[3][0], waveband_percentiles[3][1] 
f356w_lower, f356w_upper = waveband_percentiles[4][0], waveband_percentiles[4][1] 
f444w_lower, f444w_upper = waveband_percentiles[5][0], waveband_percentiles[5][1]

Before we create our dataset and dataloaders, both the training and test data need to be normalised for the model to train. The data are normalised with the lower and upper percentiles calculated above, which we apply whilst building the dataset. The function below defines a min-max normalisation whose lower and upper bounds are the lower and upper percentiles, respectively.

In [6]:
def Normalise(data, lower, upper):
    return ((data - lower)/ (upper - lower))

To recover actual pixel values from the network's predictions, rather than normalised values, we have to invert this normalisation. The following function is the normalisation equation above rearranged; we will use it after training/testing the network.

In [7]:
def Inverse(data, lower, upper):
    return (data * (upper - lower) + lower)
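A quick numerical check of the two functions, using hypothetical percentile values, shows that the round trip is exact and that the normalisation is not clipped: pixels below the lower percentile become negative and pixels above the upper percentile exceed 1. Since the Generator defined below ends in a Tanh layer (output in [-1, 1]), label pixels normalised above 1 cannot be reproduced exactly.

```python
import numpy as np

def Normalise(data, lower, upper):
    return ((data - lower) / (upper - lower))

def Inverse(data, lower, upper):
    return (data * (upper - lower) + lower)

lower, upper = 0.02, 0.40                     # hypothetical 10th/90th percentiles
pixels = np.array([0.0, 0.02, 0.21, 0.40, 0.60])
scaled = Normalise(pixels, lower, upper)
restored = Inverse(scaled, lower, upper)

print(scaled.min() < 0, scaled.max() > 1)     # True True: values are not clipped to [0, 1]
print(np.allclose(restored, pixels))          # True: Inverse undoes Normalise
```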

#### Dataset and Dataloaders ####
Here we create the dataset for our model. First, the data must be resized to the size the model expects (256x256). The network uses relatively small kernels (of size $4\times4$), meaning that each layer has a small receptive field; taking too large an image size will result in a narrow network, which is difficult to train. Additionally, the layers within the U-Net are easier to design when the image size is a power of 2: each down-sampling layer halves the height and width, and the U-Net below down-samples 8 times, so resizing the inputs to $256=2^8$ brings them down to exactly $1\times1$ at the bottleneck.
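The size requirement can be checked with the standard output-size formula for a convolution, floor((size + 2*padding - kernel) / stride) + 1, applied to the 4x4, stride-2, padding-1 down-sampling convolutions of the U-Net. The two helpers below are hypothetical and exist only for this check: a 256x256 input reaches 1x1 after exactly 8 layers, whereas the original 64x64 cutouts would reach 1x1 after 6 layers and then shrink to a side length of 0, which PyTorch rejects with an error.

```python
def conv_out(size, k=4, s=2, p=1):
    ''' Output side length of a Conv2d layer (dilation 1). '''
    return (size + 2 * p - k) // s + 1

def down_sample_sizes(size, n_down=8):
    ''' Side lengths after each of n_down 4x4, stride-2, padding-1 convolutions. '''
    sizes = [size]
    for _ in range(n_down):
        sizes.append(conv_out(sizes[-1]))
    return sizes

print(down_sample_sizes(256))  # [256, 128, 64, 32, 16, 8, 4, 2, 1]
print(down_sample_sizes(64))   # [64, 32, 16, 8, 4, 2, 1, 0, 0]
```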

The data in each waveband file are read and normalised using the function above; the short-waveband data are stacked to form the input and the long-waveband data are stacked to form the label. Both the inputs and labels are converted to tensors of shape (channel, height, width) and then resized to 256x256. We then split the dataset into training and testing sets, using 90% of the data for training and 10% for testing, before creating the dataloaders.

To test the network on predicting JWST $LW$ data from the Euclid bands, we replace the JWST $SW$ inputs in the *__getitem__* method with the Euclid VIS and NISP data. For the NISP-only input, we stack NISP-H, NISP-Y and NISP-J (as RGB); for the VIS input, we load the VIS data and stack it with NISP-J and NISP-Y.

In [12]:
SIZE = 256
# create the dataset which will be split into train and test
class FilterDataset(Dataset):
    def __init__(self, path):
        ''' path = path to directory containing fits files '''
        self.path = path
        self.transforms_inputs = torch.nn.Sequential(transforms.Resize((SIZE, SIZE)))
                                                     
        self.transforms_labels = torch.nn.Sequential(transforms.Resize((SIZE, SIZE)))
                                                    
        self.f115w_path = path+'F115W/'
        self.f150w_path = path+'F150W/'
        self.f200w_path = path+'F200W/'
        self.f277w_path = path+'F277W/'
        self.f356w_path = path+'F356W/'
        self.f444w_path = path+'F444W/'
        self.VIS_path = path+'VIS/'
        self.NISP_J_path = path+'NISP-J/'
        self.NISP_Y_path = path+'NISP-Y/'
        self.NISP_H_path = path+'NISP-H/'
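        # number of samples: count only the '*_lens.fits' files so that stray files
        # (e.g. .DS_Store) are ignored; files are named 0_lens.fits, 1_lens.fits, ...
        self.l1 = len(glob.glob(self.f115w_path + '*_lens.fits'))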
    
    def __len__(self):
        # return total number of fits files for the galaxy cutouts consistent
        # with the 'idx' in the __getitem__ method
        return (self.l1)
    
    def __getitem__(self, idx):
        # get the name of the fits file
        name = str(idx)+'_lens.fits'
        # for each input filter, read the primary HDU (the only 'SCI' data) and
        # normalise it; fits.getdata closes the file after reading, so file
        # handles do not accumulate over many calls
        data1 = Normalise(fits.getdata(self.f115w_path+name, ext=0), f115w_lower, f115w_upper)
        data2 = Normalise(fits.getdata(self.f150w_path+name, ext=0), f150w_lower, f150w_upper)
        data3 = Normalise(fits.getdata(self.f200w_path+name, ext=0), f200w_lower, f200w_upper)
        # now the same for the label filters
        data4 = Normalise(fits.getdata(self.f277w_path+name, ext=0), f277w_lower, f277w_upper)
        data5 = Normalise(fits.getdata(self.f356w_path+name, ext=0), f356w_lower, f356w_upper)
        data6 = Normalise(fits.getdata(self.f444w_path+name, ext=0), f444w_lower, f444w_upper)
        # now for the Euclid data - load this when you need it
        #data_vis = Normalise(fits.getdata(self.VIS_path+name, ext=0), f115w_lower, f115w_upper)
        #data_J = Normalise(fits.getdata(self.NISP_J_path+name, ext=0), f150w_lower, f150w_upper)
        #data_Y = Normalise(fits.getdata(self.NISP_Y_path+name, ext=0), f115w_lower, f115w_upper)
        #data_H = Normalise(fits.getdata(self.NISP_H_path+name, ext=0), f200w_lower, f200w_upper)
        
        # stack the input filters (f115w, f150w, f200w) into an (H, W, C) array;
        # change this for the VIS and NISP bands in the same manner
        inputs = np.dstack((data1, data2, data3)).astype("float32")
        # ToTensor converts the (H, W, C) float32 array to a (C, H, W) tensor
        # (no rescaling is applied to float input)
        inputs = transforms.ToTensor()(inputs)
        # now resize the inputs to 256x256
        inputs = self.transforms_inputs(inputs)
        
        # do the same for the labels
        labels = np.dstack((data4, data5, data6)).astype("float32")
        labels = transforms.ToTensor()(labels)
        labels = self.transforms_labels(labels)
        
        # return the inputs with corresponding labels in a dictionary
        return {'Inputs': inputs, 'Labels': labels}
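To make the shapes in *__getitem__* concrete: `np.dstack` stacks three 64x64 bands into a (height, width, channel) array, and for a float32 array `transforms.ToTensor()` only moves the channel axis to the front (it rescales only uint8 or PIL input). The NumPy sketch below mimics that axis change with `np.transpose`; the random bands are placeholders for the normalised FITS data.

```python
import numpy as np

rng = np.random.default_rng(0)
# placeholders for three normalised 64x64 bands (e.g. F115W, F150W, F200W)
band1, band2, band3 = (rng.random((64, 64)) for _ in range(3))

stacked = np.dstack((band1, band2, band3)).astype("float32")  # (H, W, C)
chw = np.transpose(stacked, (2, 0, 1))                        # (C, H, W), like ToTensor
print(stacked.shape, chw.shape)  # (64, 64, 3) (3, 64, 64)
# each channel of the (C, H, W) array is one of the original bands
print(np.allclose(chw[1], band2))  # True
```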

In [13]:
dataset = FilterDataset(path=home)
print(dataset)

<__main__.FilterDataset object at 0x1297235e0>


Now, we create the dataloaders, namely *trainloader* and *testloader*.

In [14]:
# split the generated dataset into training and testing 
BATCH_SIZE = 16                                     # set the batch size
VALIDATION_SPLIT = 0.1                              # fraction of the data held out for testing (10%)
SHUFFLE_DATASET = True                              # shuffle the indices before splitting
RANDOM_SEED = 42                                    # seed for a reproducible shuffle

# create indices for training and test split
DATASET_SIZE = len(dataset)
# list the dataset with an index for each entry
indices = list(range(DATASET_SIZE))
# define the split for the dataset
split = int(np.floor(DATASET_SIZE * VALIDATION_SPLIT))
if SHUFFLE_DATASET:
    np.random.seed(RANDOM_SEED)
    np.random.shuffle(indices)
# split the dataset into training and testing 
train_indices, test_indices = indices[split:], indices[:split]

# create data samplers and dataloaders
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
# create dataloaders
trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
testloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)
#print(len(trainloader), len(testloader)) 
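The split above is made once, over indices, so the training and test sets never share a sample; each `SubsetRandomSampler` then draws its own subset in a new random order every epoch. The sketch below repeats the index logic for a hypothetical dataset of 1000 lenses and shows how many batches of 16 each loader would yield (the last batch is smaller, since `drop_last` is not set).

```python
import math
import numpy as np

DATASET_SIZE = 1000        # hypothetical: roughly the number of simulated lenses
VALIDATION_SPLIT = 0.1
RANDOM_SEED = 42
BATCH_SIZE = 16

indices = list(range(DATASET_SIZE))
split = int(np.floor(DATASET_SIZE * VALIDATION_SPLIT))
np.random.seed(RANDOM_SEED)
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]

print(len(train_indices), len(test_indices))            # 900 100
print(set(train_indices).isdisjoint(test_indices))      # True
print(math.ceil(len(train_indices) / BATCH_SIZE),
      math.ceil(len(test_indices) / BATCH_SIZE))        # 57 7
```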

Let's check the shape of the inputs and labels. They must be 3 channels and $256\times256$.

In [15]:
data = next(iter(trainloader))
inputs_, labels_ = data['Inputs'], data['Labels']
print(inputs_.shape, labels_.shape)

torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])


#### Generator Architecture ####
The following code implements a U-Net to be used as the Generator of the cGAN. It builds the U-Net from the middle (the bottom of the U) outwards: at every iteration it adds a down-sampling module to the left and an up-sampling module to the right of the current module, until it reaches the input and output modules:
\
![U-Net](images/U-Net.png).
\
The blue boxes show the order in which the related modules are built. The U-Net shown in the following code has more layers than depicted above. In the code, we go 8 layers down, so, starting with a 256x256 input, we will get a 1x1 (256/2⁸) image in the middle of the U-Net which then gets up-sampled to produce a 256x256 image. 

In [16]:
# U-Net module
class UnetBlock(nn.Module):
    ''' U-Net is used as the generator of the GAN.
        Creates the U-Net from the middle part down and adds down-sampling and
        up-sampling modules to the left and right of the middle module.
        8 layers down so start with a 256x256 tensor with 3 channels, down-sample 
        to a 1x1 tensor, then up-sample to a 256x256 tensor with 3 channels. '''
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        ''' ni = number of filters in the inner (down-sampling) convolution layer
            nf = number of filters in the outer (up-sampling) convolution layer
            input_c = number of input channels (defaults to nf; = 3 for the outermost block)
            submodule = previously defined submodules
            dropout = whether to add a Dropout(0.5) layer after up-sampling '''
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(in_channels=input_c, out_channels=ni, kernel_size=4, stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        
        if outermost: # outermost block: first down-sampling and last up-sampling layer
            upconv = nn.ConvTranspose2d(in_channels=ni*2, out_channels=nf, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost: # innermost block: the 1x1 bottleneck, which has no submodule
            upconv = nn.ConvTranspose2d(in_channels=ni, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(in_channels=ni*2, out_channels=nf, kernel_size=4, stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else: # add skip connections
            return torch.cat([x, self.model(x)], dim=1)

class Unet(nn.Module):
    ''' U-Net based generator.'''
    def __init__(self, input_c=3, output_c=3, n_down=8, num_filters=64):
        ''' input_c = number of input channels (= 3)
            output_c = number of output channels (= 3)
            n_down = number of downsamples: we start with 256x256 and after 
                                            8 layers, we have a 1x1 tensor at the bottleneck.
            num_filters = number of filters in the first (outermost) convolution layer. '''
        super().__init__()
        unet_block = UnetBlock(num_filters*8, num_filters*8, innermost=True)
        for _ in range(n_down - 5):
            # adds intermediate layers with num_filters * 8 filters
            unet_block = UnetBlock(num_filters*8, num_filters*8, submodule=unet_block, dropout=True)
        out_filters = num_filters*8
        for _ in range(3):
            # gradually reduce the number of filters to num_filters
            unet_block = UnetBlock(out_filters//2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
    
    def forward(self, x):
        return self.model(x)
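The sketch below recomputes, in plain Python, the feature-map size and channel count along the down-sampling path of `Unet(input_c=3, output_c=3, n_down=8, num_filters=64)`. The function `unet_plan` is hypothetical and simply mirrors the filter counts passed to `UnetBlock` above. On the way back up, every block except the outermost returns `torch.cat([x, self.model(x)], dim=1)`, so the channel count entering each up-sampling `ConvTranspose2d` is doubled, which is why those layers use `in_channels=ni*2`.

```python
def unet_plan(size=256, input_c=3, num_filters=64, n_down=8):
    ''' (side length, channels) after each down-sampling conv of the Unet class:
        channels go num_filters, x2, x4, x8 and then stay at x8. '''
    channels = [num_filters * min(2 ** i, 8) for i in range(n_down)]
    plan = [(size, input_c)]
    for c in channels:
        size //= 2              # each 4x4, stride-2, padding-1 conv halves the size
        plan.append((size, c))
    return plan

for side, ch in unet_plan():
    print(f"{side}x{side} x {ch} channels")
# 256x256 x 3, 128x128 x 64, 64x64 x 128, 32x32 x 256, 16x16 x 512, ..., 1x1 x 512
```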

#### Discriminator Architecture ####
The following code describes the architecture of the Discriminator which implements a model by stacking blocks of Convolution - Batch Normalisation - Leaky ReLU to decide whether the input is real or fake. The first and last blocks do *not* use batch normalisation and the last block has *no* activation function (because the activation function will be embedded in the loss function we will use later).

In [17]:
# Using a Patch-Discriminator
class PatchDiscriminator(nn.Module):
    ''' Patch discriminator stacks blocks of convolution-batchnorm-leakyrelu 
        to decide whether the input tensor is real or fake. 
        Patch discriminator outputs one number for every NxN pixels of the input
        and decides whether each "patch" is real/fake. 
        Patches will be 70 by 70. '''
    def __init__(self, input_c, num_filters=64, n_down=3):
        ''' input_c = number of input channels (= 3)
            num_filters = number of filters in the first convolution layer
            n_down = number of convolution-batchnorm-leakyrelu blocks after the first layer '''
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        # use if statement to take care of not using a stride of 2 in the last block of the loop
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i+1), s=1 if i == (n_down-1) else 2) for i in range(n_down)]
        # do not use normalisation or activation for the last layer of the model
        model += [self.get_layers(num_filters * 2 ** n_down, 3, s=1, norm=False, act=False)] # output 3-channel prediction
        self.model = nn.Sequential(*model)
    
    # make a separate method for the repetitive layers
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        ''' norm = batch norm layer
        act = apply activation '''
        layers = [nn.Conv2d(in_channels=ni, out_channels=nf, kernel_size=k, stride=s, padding=p, bias=not norm)]
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

Let's take a look at its blocks.

In [18]:
PatchDiscriminator(3)

PatchDiscriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (

And the shape of its output.

In [19]:
discriminator = PatchDiscriminator(3)
input_ = torch.randn(16, 3, 256, 256) # [Batch, Channels, Height, Width]
output = discriminator(input_)
output.shape

torch.Size([16, 3, 30, 30])

Notice that we are using a Patch Discriminator. What is a Patch Discriminator?
\
In a vanilla discriminator, the model outputs one number (a scalar), which represents how much the model thinks the input is real (or fake). In a patch discriminator, the model outputs one number for every patch of ~70x70 pixels of the input and for each of them, decides whether it is real (or fake), separately. Using such a model for this task is reasonable because the local changes that the model needs to make are important. Making a decision on the whole input regarding whether it is real or fake, as in a vanilla discriminator, cannot take care of the subtleties of this task. 

Here, the model's output shape is 30x30 but that does not mean that the patches are of size 30x30. The actual patch size is obtained when we compute the [_**receptive field**_](https://www.researchgate.net/figure/The-PatchGAN-discriminator-where-the-receptive-field-of-the-discriminator-is-N-N-Gz_fig5_336431839) of each of these 900 (30x30=900) output numbers, which will be 70 by 70 in this case. 
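The 70x70 figure can be checked by hand. Working backwards from one output value, each layer with kernel k and stride s widens the receptive field via r_in = (r_out - 1) * s + k. The sketch below applies this to the five Conv2d layers of `PatchDiscriminator(3)` as defined above (4x4 kernels with strides 2, 2, 2, 1 and 1, padding 1), and also reproduces the 30x30 output size for a 256x256 input.

```python
def conv_out(size, k, s, p):
    ''' Output side length of a Conv2d layer (dilation 1). '''
    return (size + 2 * p - k) // s + 1

# (kernel, stride, padding) of the five Conv2d layers of PatchDiscriminator(3, n_down=3)
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256
for k, s, p in layers:
    size = conv_out(size, k, s, p)      # 128, 64, 32, 31, 30

rf = 1                                  # one output value sees 1x1 of the last layer
for k, s, _ in reversed(layers):
    rf = (rf - 1) * s + k               # 4, 7, 16, 34, 70

print(size, rf)  # 30 70
```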

#### GAN Loss ####
We now define the adversarial GAN loss for the final model. In the `__init__` method, we decide which type of loss to use ("vanilla" binary cross-entropy or "lsgan" mean squared error) and register two constant tensors as the "real" and "fake" labels, 1 and 0 respectively. When the module is called, it expands the appropriate label into a tensor of all 1's or all 0's with the same shape as the Discriminator's predictions and computes the loss.

In [20]:
# Unique loss function for the GAN 
class GANLoss(nn.Module):
    ''' Calculates the GAN loss of the final model.
        Uses a "vanilla" loss and registers constant tensors for the real
        and fake labels. Returns tensors full of zeros or ones to compute the loss'''
        
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer(name='real_label', tensor=torch.tensor(real_label))
        self.register_buffer(name='fake_label', tensor=torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss() # binary cross entropy loss on raw (logit) scores
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss() # least-squares GAN loss
        else:
            raise NotImplementedError(f"gan_mode '{gan_mode}' is not implemented")
        
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds) # expand to the same size as predictions
    
    def forward(self, preds, target_is_real):
        # nn.Module.__call__ runs forward(), so the loss is still computed by calling the module
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
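For intuition, the "vanilla" loss applies a sigmoid to each raw Discriminator score and then takes the binary cross-entropy against the expanded label, averaged over all elements. The pure-Python sketch below computes this by hand for three hypothetical patch scores; `bce_with_logits` is a hypothetical helper (PyTorch's `nn.BCEWithLogitsLoss` computes the same quantity in a more numerically stable form).

```python
import math

def bce_with_logits(logits, target):
    ''' Mean of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))] over all scores x. '''
    total = 0.0
    for x in logits:
        p = 1.0 / (1.0 + math.exp(-x))  # sigmoid
        total += -(target * math.log(p) + (1 - target) * math.log(1 - p))
    return total / len(logits)

patch_logits = [2.0, -1.0, 0.0]                  # hypothetical scores for three patches
loss_real = bce_with_logits(patch_logits, 1.0)   # labels expanded to all ones
loss_fake = bce_with_logits(patch_logits, 0.0)   # labels expanded to all zeros
print(round(loss_real, 4), round(loss_fake, 4))  # 0.7111 1.0444
```

Positive scores mean "looks real": labelling these patches as real gives the lower loss.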

#### Model Initialisation ####
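Here, we initialise the weights of the convolution layers from a normal distribution with mean $µ=0$ and standard deviation $σ=0.02$ (batch-normalisation weights are drawn with mean 1 and the same standard deviation, and all biases are set to zero). We initialise the entire model by sending the model to the device (I'm using "cuda", but you can choose whichever device you wish to use) and then initialising its weights.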

In [21]:
# Initialise the weights of the model here
def Init_Weights(net, init='norm', gain=.02):
    ''' The image-to-image translation paper states that the model is initialised 
        with a mean of 0.0 and std 0.02'''
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                # fills tensor with values drawn from the normal distribution N(0, gain^2)
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier': 
                # fills tensor with values from N(0, std^2), std = gain*sqrt(2/(fan_in+fan_out))
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming': 
                # fills tensor with values from N(0, std^2), std = sqrt(2/fan_in) for a=0
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(tensor=m.bias.data, val=0.0) # tensor filled with zeros
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, gain)
            nn.init.constant_(tensor=m.bias.data, val=0.0)
    
    net.apply(init_func)
    print(f"model initialised with {init} initialisation")
    return net

def Init_Model(model, device):
    model = model.to(device)
    model = Init_Weights(model)
    return model

#### The Model ####
The following code is a class that brings all the previous sections together and implements the methods required to train the model.

Firstly, in the `__init__` method, we define the Generator and Discriminator networks using the classes defined above and initialise them with the **Init_Model** function. We define the two loss functions that we have discussed and the optimisers of both the Generator and Discriminator (we use the [_**Adam**_](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) optimiser, a variant of gradient descent that adapts the step size of each parameter using running averages of its gradients). It is also worth noting the learning rates of the Generator and Discriminator, *lr_G=lr_D=0.0002*. The learning rate must be chosen carefully: one that is too large might lead to a divergent solution, while one that is too small makes training unnecessarily long and may leave the model stuck in a local minimum. A learning rate of *0.0002*, together with Adam momentum parameters *β1=0.5* and *β2=0.999*, follows the pix2pix paper and is sufficient here.
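To make the comparison with gradient descent concrete, the sketch below implements a single Adam update in NumPy, following the update rule documented for `torch.optim.Adam` (without weight decay) and using the betas passed to the optimisers here, (0.5, 0.999). PyTorch's default β1 is 0.9, so β1=0.5 gives the running average of gradients a shorter memory. `adam_step` is a hypothetical helper. On the first step the bias-corrected update is roughly lr multiplied by the sign of the gradient, so parameters with very different gradient magnitudes move by about the same amount.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=2e-4, beta1=0.5, beta2=0.999, eps=1e-8):
    ''' One Adam update; returns the new parameters and moment estimates. '''
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the zero start
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.array([0.5, -0.3])
grad = np.array([10.0, -0.01])                  # gradients of very different sizes
theta_new, m, v = adam_step(theta, grad, np.zeros(2), np.zeros(2), t=1)
print(theta - theta_new)  # both steps are about lr in size: roughly [2e-4, -2e-4]
```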

The majority of the computations are done in the **optimise** method of this class. First, and only once per iteration (batch of the training set), we call the module's forward method and store the outputs in the *fake_fits* variable of the class. 

The Discriminator is trained first, using the **backward_D** method: we feed the *fake* data produced by the Generator to the Discriminator (detaching them from the Generator's graph so that they act as constants for the Discriminator) and label them as *fake*. A batch of *real* data from the training set is then fed to the Discriminator and labelled as *real*. The losses for the *fake* and *real* data are calculated and averaged, and the backward method is called on the result. The Generator is then trained. In the **backward_G** method, the Discriminator is fed the *fake* data and we try to fool it by assigning *real* labels and calculating the adversarial (GAN) loss. As previously mentioned, the L1 loss is also used to compute the distance between the predicted output and the target output; it is multiplied by the coefficient *λ* (here λ=100) to balance the two losses before being added to the adversarial loss.

The backward method of the loss is finally called. 
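To see how λ=100 balances the two terms, here is a hand computation with made-up numbers: one Discriminator score for a fake patch, one for a real patch, and three generated and target pixel values in normalised units. All values are hypothetical; in the model, the GAN losses are averaged over every patch and the L1 loss over every pixel and channel of the batch. Binary cross-entropy on a raw score x is log(1 + e^(-x)) for label 1 and log(1 + e^x) for label 0, written here with a small `softplus` helper.

```python
import math

def softplus(x):
    return math.log1p(math.exp(x))

lambda_L1 = 100.0
fake_logit, real_logit = -1.5, 2.0          # hypothetical Discriminator scores

# Discriminator step (backward_D): fake labelled 0, real labelled 1, then averaged
loss_D_fake = softplus(fake_logit)          # BCE-with-logits against a 0 label
loss_D_real = softplus(-real_logit)         # BCE-with-logits against a 1 label
loss_D = 0.5 * (loss_D_fake + loss_D_real)

# Generator step (backward_G): the same fake score, now labelled 1
loss_G_GAN = softplus(-fake_logit)

# L1 term: mean absolute error between generated and target pixels, times lambda
fake_pixels = [0.20, 0.55, 0.90]
real_pixels = [0.25, 0.50, 1.00]
loss_G_L1 = lambda_L1 * sum(abs(f - r) for f, r in zip(fake_pixels, real_pixels)) / len(fake_pixels)
loss_G = loss_G_GAN + loss_G_L1

print(round(loss_D, 4), round(loss_G_GAN, 4), round(loss_G_L1, 4), round(loss_G, 4))
# 0.1642 1.7014 6.6667 8.3681
```

Even a small pixel error of about 0.07 gives an L1 term several times larger than the adversarial term, so λ=100 keeps the prediction close to the target while the GAN loss sharpens the details.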

In [22]:
# now to initialise the main GAN network
class GANModel(nn.Module):
    ''' Initialises the model defining the generator and discriminator in the
        __init__ function using the functions given and initialises the loss
        functions '''
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, beta1=.5, beta2=.999, lambda_L1=100.): 
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        if net_G is None:
            self.net_G = Init_Model(Unet(input_c=3, output_c=3, n_down=8, num_filters=64), self.device)
        else:
            # use the generator that was passed in (e.g. a pre-trained one) without re-initialising it
            self.net_G = net_G.to(self.device)
        
        self.net_D = Init_Model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GAN_loss = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1_loss = nn.L1Loss()
        # initialise optimisers for generator and discriminator  
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1,beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1,beta2))
        # initialise empty lists to append the generator and discriminator losses to
        self.generator_losses, self.discriminator_losses = [], []
    
    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        # Get the input data and labels
        self.inputs = data['Inputs'].to(self.device)
        self.labels = data['Labels'].to(self.device)
    
    def forward(self):
        # For each batch in the training set, forward method is called and
        # outputs stored in fake_fits variable
        self.fake_fits = self.net_G(self.inputs)
        
    def backward_D(self):
        ''' Discriminator loss uses both the generated and the target tensors.
            loss_D_fake is the sigmoid cross-entropy loss of the Discriminator's
            predictions on the generated (fake) tensors and an array of zeros.
            loss_D_real is the sigmoid cross-entropy loss of its predictions on
            the target (real) tensors and an array of ones.
            Discriminator loss is loss_D = 0.5 * (loss_D_real + loss_D_fake). '''
        # Train the discriminator by feeding the fake images produced by the 
        # generator 
        fake_fits = self.fake_fits
        fake_preds = self.net_D(fake_fits.detach()) # detach from generator's graph so they act like constants
        # label the fake images as fake 
        self.loss_D_fake = self.GAN_loss(preds=fake_preds, target_is_real=False)
        # Now feed a batch of real images from the training set and label them as real
        real_fits = self.labels
        real_preds = self.net_D(real_fits)
        self.loss_D_real = self.GAN_loss(preds=real_preds, target_is_real=True)
        # Add the two losses for fake and real, take the average and call backward()
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * .5
        self.loss_D.backward()
        self.discriminator_losses += [self.loss_D.item()]
    
    def backward_G(self):
        ''' The adversarial generator loss loss_G_GAN is the sigmoid cross-entropy
            between the discriminator's predictions on the generated tensors and
            an array of ones. The L1 loss keeps the generated tensors structurally
            close to the target tensors.
            Generator loss is defined as loss_G = loss_G_GAN + loss_G_L1*lambda_L1. '''
        # Train the generator by feeding the discriminator the fake fits data and
        # trying to fool it: assign real labels and calculate the adversarial loss.
        # Do NOT detach here, otherwise no gradient reaches the generator.
        fake_fits = self.fake_fits
        fake_preds = self.net_D(fake_fits)
        self.loss_G_GAN = self.GAN_loss(preds=fake_preds, target_is_real=True)
        # L1 loss: mean absolute difference between the prediction and the
        # target, scaled by the constant lambda_L1
        self.loss_G_L1 = self.L1_loss(self.fake_fits, self.labels) * self.lambda_L1
        # Add L1 loss to the adversarial loss then call backward()
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
        self.generator_losses += [self.loss_G_GAN.item()]
        
    def optimise(self):
        # Now optimise by the usual method of zeroing the gradients and calling
        # step() on the optimiser
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
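For reference, the losses computed in `backward_D` and `backward_G` can be written out explicitly. This assumes that `GANLoss(gan_mode='vanilla')`, which is defined outside this section, is the sigmoid cross-entropy on logits used in pix2pix, averaged over the PatchGAN outputs. Here $x$ is the $SW$ input, $y$ the true $LW$ label and $\sigma$ the logistic sigmoid. Note that, unlike the conditional discriminator $D(x, y)$ of pix2pix, `net_D` in the code above receives only the $LW$ tensor, so $x$ enters only through $G(x)$:

$$
\begin{aligned}
\mathrm{BCE}(z, t) &= -\big[\,t \log \sigma(z) + (1 - t)\log\big(1 - \sigma(z)\big)\big],\\
\mathcal{L}_D &= \tfrac{1}{2}\Big[\mathrm{BCE}\big(D(y), 1\big) + \mathrm{BCE}\big(D(G(x)), 0\big)\Big],\\
\mathcal{L}_G &= \mathrm{BCE}\big(D(G(x)), 1\big) + \lambda_{L1}\,\frac{1}{N}\sum_{i=1}^{N}\big|\,y_i - G(x)_i\,\big|,
\end{aligned}
$$

with $\lambda_{L1} = 100$ by default and $N$ the number of elements averaged by `nn.L1Loss`. $\mathcal{L}_D$ corresponds to `loss_D` and $\mathcal{L}_G$ to `loss_G`.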

#### Helper Functions ####
Below are the functions used to log the losses, plot how they evolve during training, and illustrate the performance of the model.

In [23]:
class AverageMeter:
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
        
def Create_Loss_Meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

# Update the loss meters after each batch
def Update_Losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)
    return loss_meter_dict

# Plot the losses for both the Generator and Discriminator
# (one value is logged per optimisation step, i.e. per batch)
def Loss_Plot(model, save=False):
    gen_loss = model.generator_losses
    dis_loss = model.discriminator_losses
    fig = plt.figure(figsize=(12,6))
    plt.plot(gen_loss, label='Generator Loss', color='red')
    plt.plot(dis_loss, label='Discriminator Loss', color='blue', linestyle='--')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.legend()
    if save:
        # save before show(), which can close the figure in some backends
        fig.savefig(f"loss_{time.time()}.png")
    plt.show()
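`AverageMeter` keeps a running mean weighted by batch size, which is why `Update_Losses` passes `count=data['Inputs'].size(0)`. The short sketch below, using made-up loss values, shows that the meter's average equals the sample-weighted mean of the per-batch losses rather than their plain mean; the class is repeated so the snippet runs on its own.

```python
# Copy of the notebook's AverageMeter so this snippet is self-contained
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        # val is a per-batch mean loss, count the number of samples in that batch
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count


# Hypothetical per-batch mean losses; the last batch is smaller
batch_losses = [0.50, 0.40, 0.10]
batch_sizes = [16, 16, 4]

meter = AverageMeter()
for loss, n in zip(batch_losses, batch_sizes):
    meter.update(loss, count=n)

weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
unweighted = sum(batch_losses) / len(batch_losses)
print(f"meter.avg  = {meter.avg:.4f}")
print(f"weighted   = {weighted:.4f}")
print(f"unweighted = {unweighted:.4f}")
```

Weighting by batch size means a small final batch does not distort the epoch average.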

We also define functions that plot the output of the cGAN alongside the input to the network and the true label.

In [24]:
def Visualise_Train(model, data, save=False):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
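    # Normalise out of place so model.labels and model.inputs are not modified,
    # then move to the CPU and reorder (N, C, H, W) -> (N, H, W, C) for imshow
    fake_fits = model.fake_fits.detach()
    fake_fits = fake_fits / torch.max(fake_fits)
    real_fits = model.labels / torch.max(model.labels)
    inputs = model.inputs / torch.max(model.inputs)
    inputs = inputs.permute(0, 2, 3, 1).cpu()
    fake_fits = fake_fits.permute(0, 2, 3, 1).cpu()
    real_fits = real_fits.permute(0, 2, 3, 1).cpu()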
    fig = plt.figure(figsize=(9,11))
    for i in range(2):
        ax = plt.subplot(3, 2, i+1)
        try:
            ax.imshow(inputs[i])
            ax.set_title(r"SW Channel [$0.6-2.3\mu$m]")
            ax.axis("off")
        except IndexError:
            continue
        ax = plt.subplot(3, 2, i+1+2)
        ax.imshow(fake_fits[i])
        ax.set_title("Generated LW")
        ax.axis("off")
        ax = plt.subplot(3, 2, i+1+4)
        ax.imshow(real_fits[i])
        ax.set_title(r"Actual LW Channel [$2.4-5.0\mu$m]")
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    if save:
        fig.savefig("train.png")

def Visualise_Test(model, data, save=False):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.eval()
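    # Normalise out of place so model.labels and model.inputs are not modified,
    # then move to the CPU and reorder (N, C, H, W) -> (N, H, W, C) for imshow
    fake_fits = model.fake_fits.detach()
    fake_fits = fake_fits / torch.max(fake_fits)
    real_fits = model.labels / torch.max(model.labels)
    inputs = model.inputs / torch.max(model.inputs)
    inputs = inputs.permute(0, 2, 3, 1).cpu()
    fake_fits = fake_fits.permute(0, 2, 3, 1).cpu()
    real_fits = real_fits.permute(0, 2, 3, 1).cpu()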
    fig = plt.figure(figsize=(9,11))
    for i in range(2):
        ax = plt.subplot(3, 2, i+1)
        try:
            ax.imshow(inputs[i])
            ax.set_title(r"SW Channel [$0.6-2.3\mu$m]")
            ax.axis("off")
        except IndexError:
            continue
        ax = plt.subplot(3, 2, i+1+2)
        ax.imshow(fake_fits[i])
        ax.set_title("Generated LW")
        ax.axis("off")
        ax = plt.subplot(3, 2, i+1+4)
        ax.imshow(real_fits[i])
        ax.set_title(r"Actual LW Channel [$2.4-5.0\mu$m]")
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig("test2.png")
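The visualisation helpers rest on two array operations: moving the channel axis last, since `imshow` expects `(H, W, C)` images, and rescaling to [0, 1] without modifying the tensors the model keeps. Below is a NumPy-only sketch of that step; in PyTorch, `tensor.permute(0, 2, 3, 1).cpu().numpy()` plays the same role. The helper name `to_display` is hypothetical, and it normalises each image by its own maximum, whereas the notebook divides by the maximum of the whole batch.

```python
import numpy as np

def to_display(batch):
    """Convert an (N, C, H, W) batch to (N, H, W, C) images scaled to [0, 1].

    Hypothetical helper: it never modifies the caller's array and
    normalises each image by its own maximum.
    """
    batch = np.asarray(batch, dtype=np.float64)
    images = np.transpose(batch, (0, 2, 3, 1))          # channels last for imshow
    peak = images.max(axis=(1, 2, 3), keepdims=True)    # per-image maximum
    peak = np.where(peak > 0, peak, 1.0)                # avoid dividing by zero
    return images / peak                                # out-of-place division

# Hypothetical batch: 2 images, 3 channels, 4 x 5 pixels
rng = np.random.default_rng(0)
batch = rng.uniform(0, 10, size=(2, 3, 4, 5))
original = batch.copy()
shown = to_display(batch)
print(shown.shape)
```

Per-image scaling makes faint and bright lenses equally visible side by side; batch-wide scaling, as in the notebook, preserves their relative brightness instead.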

#### Training Function ####
The functions below train and test the network, respectively: *Train_Model()* feeds batches from the training set to the network so that it can learn features from them, and *Test_Model()* runs the network on the test set. Here, we set the number of epochs for training. We recommend training for 100 epochs, although results are already visible after 40 epochs. Each epoch is one full pass over the training set, and the generator and discriminator weights are updated after every batch.


We've also defined a function to log the loss results.

In [25]:
def Log_Results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")
        
# now train the network, display epochs and losses
def Train_Model(model, trainloader, epochs, display_every=30):
    print("Starting training....")
    start = time.time()
    vis_data = next(iter(trainloader)) # fixed batch for visualising the model output after training
    for e in range(epochs):
        # dictionary of objects to log the losses of the complete network
        loss_meter_dict = Create_Loss_Meters()
        for i, data in enumerate(tqdm(trainloader), start=1):
            model.setup_input(data)
            model.optimise()
            Update_Losses(model, loss_meter_dict, count=data['Inputs'].size(0)) # updates the log objects
            if i % display_every == 0:
                print(f"Iteration {i}/{len(trainloader)}")
        print(f"\nEpoch {e+1}/{epochs}")
        Log_Results(loss_meter_dict) # prints the average losses for this epoch
    Loss_Plot(model, save=False)
    Accuracy_Plot(model, save=False)
    Visualise_Train(model, vis_data)
    end = time.time() - start
    print("Time to train network: {:.2f}s".format(end))

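def Test_Model(model, testloader, display_every=30):
    ''' Evaluate the trained model on the test set without updating its weights. '''
    print("Testing...")
    vis_data = next(iter(testloader))
    model.net_G.eval()
    model.net_D.eval()
    loss_meter_dict = Create_Loss_Meters()
    with torch.no_grad():
        for i, data in enumerate(tqdm(testloader), start=1):
            model.setup_input(data)
            model.forward()
            # compute the same losses as in training, but without backward() or step()
            fake_preds = model.net_D(model.fake_fits)
            real_preds = model.net_D(model.labels)
            model.loss_D_fake = model.GAN_loss(preds=fake_preds, target_is_real=False)
            model.loss_D_real = model.GAN_loss(preds=real_preds, target_is_real=True)
            model.loss_D = (model.loss_D_fake + model.loss_D_real) * .5
            model.loss_G_GAN = model.GAN_loss(preds=fake_preds, target_is_real=True)
            model.loss_G_L1 = model.L1_loss(model.fake_fits, model.labels) * model.lambda_L1
            model.loss_G = model.loss_G_GAN + model.loss_G_L1
            Update_Losses(model, loss_meter_dict, count=data['Inputs'].size(0))
            if i % display_every == 0:
                print(f"Iteration {i}/{len(testloader)}")
    Log_Results(loss_meter_dict)
    Visualise_Test(model, vis_data)
    print("Finished testing")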
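The loop bookkeeping above depends on how many batches a loader yields per epoch. The sketch below uses a hypothetical dataset of 1000 samples and a batch size of 16 (not the notebook's actual sizes) to show how many batches that gives, at which iterations a check of the form `i % display_every == 0` evaluated inside the batch loop would fire, and how many optimiser steps 100 epochs correspond to, since `optimise()` updates the weights once per batch. It assumes a PyTorch-style loader, where `len(loader)` is `ceil(n_samples / batch_size)` unless `drop_last=True`.

```python
import math

def iterations_per_epoch(n_samples, batch_size, drop_last=False):
    # Batches yielded per epoch by a PyTorch-style DataLoader
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

def display_iterations(n_batches, display_every):
    # 1-based iterations at which `i % display_every == 0` fires inside the batch loop
    return [i for i in range(1, n_batches + 1) if i % display_every == 0]

n_samples, batch_size, epochs = 1000, 16, 100   # hypothetical sizes
n_batches = iterations_per_epoch(n_samples, batch_size)
print("batches per epoch:", n_batches)
print("progress printed at:", display_iterations(n_batches, 30))
print("optimiser steps over", epochs, "epochs:", n_batches * epochs)
```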

Each epoch takes 3 to 4 minutes on a powerful GPU. When evaluating on the test set, the networks should be put in evaluation mode (`eval()`) and no optimisation steps should be taken, so that the test data do not influence the weights. An example of training the model is shown below:

In [26]:
model = GANModel()
#Train_Model(model, trainloader, epochs=100)

model initialised with norm initialisation
model initialised with norm initialisation
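At the 3 to 4 minutes per epoch quoted above, the expected wall-clock time follows from simple arithmetic. This sketch assumes every epoch takes the same time, which ignores start-up and plotting overhead.

```python
def training_time_hours(epochs, minutes_per_epoch):
    # Wall-clock estimate, assuming a constant time per epoch
    return epochs * minutes_per_epoch / 60

for epochs in (40, 100):
    low, high = training_time_hours(epochs, 3), training_time_hours(epochs, 4)
    print(f"{epochs} epochs: {low:.1f}-{high:.1f} h")
```

So the recommended 100-epoch run takes roughly 5 to 7 hours, and the 40 epochs after which results are already visible take roughly 2 to 3 hours.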


### Results ###
After training the cGAN for 100 epochs, we evaluated it on the test set over 40 epochs and plotted the output with the *Visualise_Test()* function. The first plot below shows an output on the test set; the second shows the input, generated data and true label for each individual NIRcam filter.

![output3](images/Output3.png)


![bands output](images/Bands-Output.png)

As we can see, the model has learned the features in the data and some of the colourisation. It closely reproduces the lenses as seen in the $LW$ filters, with very little visible difference between the output and the true label: the arcs, rings, multiple images and the lensing galaxy are all clearly visible. We should note, however, that part of what the model learns is simple structure such as the mostly black background.
Nevertheless, the model is a good baseline for predicting longer wavebands and could be applied to predict lenses in unseen data.
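The caveat about the black background can be checked numerically: because most pixels are empty sky, a pixel-averaged L1 error can be small even for a prediction that contains no lens at all. The sketch below uses a synthetic ring, not data from this notebook, and compares an all-black prediction with a noisy but correct one, first over all pixels and then only over pixels where the source is bright. The 0.1 threshold defining source pixels is an arbitrary choice.

```python
import numpy as np

def mae(pred, truth, mask=None):
    # Mean absolute error, optionally restricted to pixels where mask is True
    err = np.abs(pred - truth)
    return float(err.mean() if mask is None else err[:, mask].mean())

# Synthetic 3-channel 64x64 image: black background plus a ring of radius 15 px
yy, xx = np.mgrid[:64, :64]
r = np.hypot(yy - 32, xx - 32)
ring = np.exp(-0.5 * ((r - 15) / 1.5) ** 2)
truth = np.stack([ring, 0.8 * ring, 0.5 * ring])   # shape (3, 64, 64)

blank = np.zeros_like(truth)                        # predicts an empty sky
noisy = truth + 0.05 * np.random.default_rng(1).standard_normal(truth.shape)
source = ring > 0.1                                 # bright-source pixels only

for name, pred in [("all-black", blank), ("noisy", noisy)]:
    print(f"{name:9s}  all pixels: {mae(pred, truth):.3f}  "
          f"source pixels: {mae(pred, truth, source):.3f}")
```

In this toy setup the all-black prediction is only modestly worse over all pixels but far worse on the source pixels, so an error restricted to the lensed light (or compared against such a baseline) gives a fairer picture of how well the arcs are reproduced.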

### Euclid VIS to JWST NIRcam LW
Predicting $LW$ JWST NIRcam data given Euclid VIS data, Euclid NISP data, or a mixture of both would be a beneficial application of the cGAN. The pixel scales of the two Euclid instruments ($0.1"/pix$ for VIS and $0.3"/pix$ for NISP) differ from that of JWST NIRcam ($0.03"/pix$), so we should expect a different output from before. Predicting strong gravitational lenses as observed by JWST from their Euclid observations could serve as a method of anomaly detection: candidates that look like gravitational lenses in Euclid could be shown to be non-lenses as observed by JWST. Although this does not solve the problem of finding strong gravitational lenses, it could help improve the purity and recall of our strong lens finding methods.
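To put the resolution gap in numbers, the sketch below converts an angular scale into pixels for each instrument, using the pixel scales quoted at the start of this notebook and, as an example angle, the $\approx 1.8"$ Einstein radius of the ring discussed in the purity and recall section.

```python
# Pixel scales quoted in this notebook, in arcsec per pixel
PIXEL_SCALE = {"Euclid VIS": 0.1, "Euclid NISP": 0.3, "JWST NIRcam": 0.03}

def pixels_across(angle_arcsec, instrument):
    # Number of detector pixels spanned by an angle on the given instrument
    return angle_arcsec / PIXEL_SCALE[instrument]

theta_E = 1.8  # arcsec
for name, scale in PIXEL_SCALE.items():
    coarser = scale / PIXEL_SCALE["JWST NIRcam"]
    print(f"{name:12s} {pixels_across(theta_E, name):5.1f} px per theta_E, "
          f"pixel side {coarser:.1f}x that of NIRcam")
```

A VIS pixel spans about 3.3 NIRcam pixels on a side and a NISP pixel 10, so the network has to produce detail on scales that the Euclid input does not sample directly.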

Below is the output of the cGAN on the test set after recreating the dataloaders in *FilterDataset* with the stacked Euclid VIS data as input:

![Results VIS](images/VIS_Output1.png)

We can explore the result further by plotting, for a particular lens from the figure above, each individual filter together with the difference image between the predicted and the true $LW$ data, shown below:

![bands output vis](images/Wavebands_VIS1.png)

In both plots, we see that the predicted JWST NIRcam $LW$ output is much more resolved than the Euclid VIS input: the network recovers finer detail of the lenses, as observed by JWST, than the input itself contains. Thus, the cGAN is capable of interpreting the Euclid information and translating it to JWST pixel resolution. For reference, we provide the residual between the cGAN's prediction and the ground truth JWST $LW$:

![residual VIS](images/Residual_VIS.png)
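Residual maps like the one above can also be summarised per filter. The sketch below computes the prediction-minus-truth residual for each $LW$ channel and its root-mean-square value. It assumes the prediction and ground truth are arrays of shape `(C, H, W)` on the same pixel grid, one channel per filter; the helper name `residual_stats` and the filter labels are placeholders, not the notebook's actual names.

```python
import numpy as np

def residual_stats(pred, truth, filters):
    """Per-filter residual (prediction minus truth) and its RMS.

    pred, truth: arrays of shape (C, H, W), one channel per LW filter.
    filters: list of C filter labels (placeholders here).
    """
    residual = pred - truth
    rms = np.sqrt(np.mean(residual ** 2, axis=(1, 2)))   # one value per channel
    return residual, dict(zip(filters, rms.tolist()))

# Toy example: the prediction is offset by a constant in two of the filters
truth = np.zeros((3, 8, 8))
pred = truth + np.array([0.0, 0.1, -0.2])[:, None, None]
residual, rms = residual_stats(pred, truth, ["LW filter 1", "LW filter 2", "LW filter 3"])
print(rms)
```

A single RMS per filter makes it easy to see whether the cGAN does systematically worse in some bands than in others.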


We see similar results when predicting JWST NIRcam $LW$ data of the strong gravitational lenses with Euclid NISP data as input to the network. This uses the same procedure as above, with the input data changed to Euclid NISP in the *FilterDataset* class.

![results nisp](images/NISP_Output1.png)

Again, we can visualise the prediction of each individual JWST NIRcam $LW$ filter.

![nisp bands](images/Wavebands_NISP111.png)

The residual between the cGAN's prediction and the ground truth JWST data is shown below:

![residual nisp](images/Residual_NISP1.png)

For more results like those shown above, including each individual filter prediction, see the *images/* folder in this repository.

### Improved Purity and Recall
The example below shows a clear, red ring observed in the Euclid NISP band, with an Einstein radius of $\approx 1.8"$. Both the prediction by the cGAN and the true JWST NIRcam $LW$ data show no such ring at all:

![anomaly](images/Anomaly.png)

To further visualise each individual band, we can plot both the prediction by the cGAN and the true information in each individual filter:

![anomaly bands](images/Anomaly_Bands.png)

Again, we see no such ring in any JWST NIRcam $LW$ filter, in either the prediction or the true label.
The missing ring is not an artefact of the cGAN: the true JWST data also lack it, so a lens candidate seen by Euclid is clearly not seen as a lens by JWST. Again, this does not solve the problem of lens finding, but it is clear evidence that this particular example is a potential non-lens. This could improve the purity of our lens sample.
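The check made by eye above, a ring in the Euclid image that is absent from the JWST data, could in principle be automated. The sketch below is a hypothetical illustration, not part of this notebook: it measures the fraction of each image's flux in a thin annulus at the Einstein radius and flags the candidate when that fraction is much smaller in the reference image (for example the cGAN's JWST prediction) than in the Euclid candidate. It assumes both images are on the same pixel grid and centred on the lens; the helper names, annulus width and 0.3 threshold are arbitrary choices.

```python
import numpy as np

def annulus_fraction(image, centre, radius_px, width_px):
    # Fraction of a 2-D image's total flux lying within width_px/2 of radius_px
    yy, xx = np.indices(image.shape)
    r = np.hypot(yy - centre[0], xx - centre[1])
    in_annulus = np.abs(r - radius_px) <= width_px / 2
    total = image.sum()
    return float(image[in_annulus].sum() / total) if total > 0 else 0.0

def ring_missing(candidate, reference, centre, radius_px, width_px=3.0, threshold=0.3):
    # Hypothetical flag: True if the reference holds far less annulus flux than the candidate
    f_cand = annulus_fraction(candidate, centre, radius_px, width_px)
    f_ref = annulus_fraction(reference, centre, radius_px, width_px)
    return f_ref < threshold * f_cand

# Synthetic example: a 1.8" Einstein radius is 60 px at 0.03"/pix
yy, xx = np.indices((161, 161))
r = np.hypot(yy - 80, xx - 80)
galaxy = np.exp(-0.5 * (r / 5.0) ** 2)                 # central lensing galaxy
ring = 0.3 * np.exp(-0.5 * ((r - 60.0) / 1.5) ** 2)    # Einstein ring
candidate = galaxy + ring   # stands in for the Euclid image with a ring
reference = galaxy          # stands in for a JWST prediction without one

print(ring_missing(candidate, reference, (80, 80), 60.0))
print(ring_missing(candidate, candidate, (80, 80), 60.0))
```

Using flux fractions rather than raw fluxes keeps the comparison independent of the overall brightness of each image, which differs between instruments and bands.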