### cGANs in Astronomy
Normalising the data properly!

In [None]:
# import packages
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import time
from astropy.io import fits
import warnings

warnings.filterwarnings("ignore", module="matplotlib\..*")

In [None]:
# Path to working directory (use directory where your data is saved).
# Set the CGAN_DATA_DIR environment variable to point at your own copy of the
# data without editing the notebook; the original location is the fallback.
# os.path.join(..., "") guarantees a trailing separator, which is required
# because the dataset code builds sub-paths by plain string concatenation.
home = os.path.join(
    os.environ.get("CGAN_DATA_DIR", "/Users/ruby/Documents/Python Scripts/Filters/"),
    "",
)

Before we create our dataset and dataloaders, both the train and test data need to be normalised for the model to train. The data are normalised to the interval [0, 1] using the function below. You may wish to replace it with a more efficient method such as `MinMaxScaler()` from the `sklearn.preprocessing` package.

In [1]:
# normalise the values between `lower` and `upper` (0 and 1 by default)
def Normalise(data, lower=0, upper=1):
    """Linearly rescale ``data`` so its minimum maps to ``lower`` and its maximum to ``upper``.

    Parameters
    ----------
    data : array-like object exposing ``.min()`` / ``.max()`` and arithmetic
        (e.g. a NumPy array). The minimum and maximum are taken over the
        whole array, not per axis.
    lower, upper : numbers
        Bounds of the output interval. With the defaults the result is the
        usual min-max normalisation onto [0, 1].

    Returns
    -------
    Array of the same shape as ``data`` with values in [lower, upper].

    Notes
    -----
    * If every element of ``data`` is identical there is no range to stretch;
      all outputs are set to ``lower`` instead of dividing 0 by 0.
    * NaNs are not treated specially: if ``data`` contains a NaN, ``min``/``max``
      are NaN and so is the whole result.
    * An empty array raises ``ValueError`` (from ``data.min()``).
    """
    data_min = data.min()
    span = data.max() - data_min
    if span == 0:
        # (data - data_min) is all zeros; multiplying by 0.0 promotes integer
        # input to float while keeping the shape, then shift onto `lower`.
        return (data - data_min) * 0.0 + lower
    unit = (data - data_min) / span          # values in [0, 1]
    return unit * (upper - lower) + lower    # stretch / shift onto [lower, upper]

Dataset and Dataloaders

Here we create the dataset for our model. First, the data must be transformed to the correct size for the model (256x256). The data in each waveband file are read and normalised using the above function; the short wavebands (F115W, F150W, F200W) are stacked to form the input and the long wavebands (F277W, F356W, F444W) are stacked to form the label. Both the inputs and labels are converted to tensors of shape (channel, height, width) before being resized to 256x256. We then split the dataset into training and testing, using 90% of the data to train and 10% for testing, before creating the dataloaders.

In [None]:
SIZE = 256  # side length in pixels that every cutout is resized to

# create the dataset for the fits files
class FilterDataset(Dataset):
    """Paired short-/long-waveband galaxy cutouts read from FITS files.

    Sample ``idx`` is assembled from the files ``<path><filters[i]><idx>.fits``
    for ``i`` in ``range(nbands)``. The bands of one sample are stacked and
    normalised together with ``Normalise``; bands 0-2 become the input and
    bands 3-5 the label, each a float32 tensor of shape (3, SIZE, SIZE).

    NOTE(review): ``filters`` and ``nbands`` are module-level names that are not
    defined in this cell; they must exist before the first item is requested.
    Presumably ``filters`` is ``['F115W/', 'F150W/', 'F200W/', 'F277W/',
    'F356W/', 'F444W/']`` (with trailing slashes, since paths are built by
    string concatenation) and ``nbands`` is 6 -- confirm against the defining cell.
    """

    # positions (within the loaded band stack) of the input / label channels
    INPUT_BANDS = (0, 1, 2)   # f115w, f150w, f200w
    LABEL_BANDS = (3, 4, 5)   # f277w, f356w, f444w

    def __init__(self, path):
        ''' path = path to directory containing fits files '''
        self.path = path
        self.transforms_inputs = torch.nn.Sequential(transforms.Resize((SIZE, SIZE), antialias=True))
        self.transforms_labels = torch.nn.Sequential(transforms.Resize((SIZE, SIZE), antialias=True))
        self.f115w_path = path+'F115W/'
        # Count only the FITS files so that stray directory entries (e.g. hidden
        # OS metadata files) neither inflate the length nor have to be
        # compensated for with a guessed "- 1".
        self.l1 = sum(
            entry.lower().endswith('.fits') for entry in os.listdir(self.f115w_path)
        )

    def __len__(self):
        # number of galaxy cutouts; valid 'idx' values for __getitem__ are
        # 0 .. len-1, matching files named '<idx>.fits'
        return self.l1

    def __getitem__(self, idx):
        """Return ``{'Inputs': tensor, 'Labels': tensor}`` for cutout ``idx``.

        Raises ``IndexError`` if fewer than six bands were loaded.
        """
        # get the name of the fits file
        name = str(idx)+'.fits'
        # Collect the bands in a *local* list: accumulating them on the instance
        # made every later sample include all previously loaded images.
        bands = []
        for i in range(nbands):
            root = self.path+filters[i]
            # the primary HDU (index 0) holds the image; copy the pixels out
            # so the file can be closed straight away
            with fits.open(root+name) as hdul:
                bands.append(np.array(hdul[0].data))
        # stack the bands along a new leading axis and normalise them jointly,
        # which preserves the relative flux between bands of the same galaxy
        data_norm = Normalise(np.array(bands))

        # Select the channels along the leading axis. The array is already in
        # channel-first order, so it can be wrapped as a tensor directly; the
        # cast to float32 also guarantees native byte order.
        inputs = torch.from_numpy(data_norm[list(self.INPUT_BANDS)].astype("float32"))
        # now resize the inputs to SIZE x SIZE
        inputs = self.transforms_inputs(inputs)

        # do the same for the labels
        labels = torch.from_numpy(data_norm[list(self.LABEL_BANDS)].astype("float32"))
        labels = self.transforms_labels(labels)

        # return the inputs with corresponding labels in a dictionary
        return {'Inputs': inputs, 'Labels': labels}

In [None]:
# build the dataset from the FITS cutouts stored under `home`
dataset = FilterDataset(home)
print(dataset)

In [None]:
# ---- configuration for the train / test split ----
BATCH_SIZE = 16          # number of samples per batch
VALIDATION_SPLIT = 0.1   # fraction of the dataset held out for testing (10%)
SHUFFLE_DATASET = True   # shuffle the indices before they are split
RANDOM_SEED = 42         # seed for NumPy's global RNG so the split is repeatable

# one index per dataset entry
DATASET_SIZE = len(dataset)
indices = list(range(DATASET_SIZE))

# number of entries that go to the test set
split = int(np.floor(VALIDATION_SPLIT * DATASET_SIZE))

# optionally shuffle (in place) so the held-out entries are a random subset
if SHUFFLE_DATASET:
    np.random.seed(RANDOM_SEED)
    np.random.shuffle(indices)

# the first `split` indices are held out for testing, the remainder train
test_indices = indices[:split]
train_indices = indices[split:]

# samplers draw (in random order) only from their own subset of indices
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# both loaders wrap the same dataset; the sampler decides which entries are seen
trainloader = DataLoader(dataset, sampler=train_sampler, batch_size=BATCH_SIZE)
testloader = DataLoader(dataset, sampler=test_sampler, batch_size=BATCH_SIZE)
#print(len(trainloader), len(testloader))

In [None]:
# pull a single batch from the training loader as a sanity check
data = next(iter(trainloader))
inputs_, labels_ = (data[key] for key in ('Inputs', 'Labels'))
# print(inputs_.shape, labels_.shape)
# Previously observed for both tensors: torch.Size([16, 270, 256, 256]),
# followed by a kernel crash. Three bands are stacked for each of the inputs
# and labels, so a channel dimension of 3 (not 270) would be expected here --
# TODO: check how the bands are stacked in FilterDataset.__getitem__.

The rest of the model....