## Class for creating the data set and function to form PyTorch DataLoaders with the given data

In [2]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import optuna
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1 import make_axes_locatable


# This routine takes a set of maps and removes the monopole (i.e. the average value) of each map
def remove_monopole(maps, verbose=True):
    """Subtract from every map its own mean value (its monopole).

    The result is written back into ``maps`` itself, so the caller's array
    is modified; that same array object is also returned.

    Args:
        maps: array whose first axis enumerates the maps; the mean of each
            map is taken over axes 1 and 2.
        verbose: when true, print a short progress message.

    Returns:
        The input array with the mean of each map removed.
    """
    if verbose:
        print('removing monopoles')

    # one mean per map, accumulated in double precision
    monopoles = np.mean(maps, axis=(1, 2), dtype=np.float64)

    # go map by map, storing the shifted map back into the input array
    for idx, monopole in enumerate(monopoles):
        maps[idx] = maps[idx] - monopole

    return maps


# This class creates the dataset. It reads the maps and stores them in memory;
# the rotations and flips are applied on the fly, each time an item is requested

class make_dataset2(Dataset):
    """Map dataset kept in memory, with rotations/flips applied on access.

    All maps of every channel are loaded, log-rescaled, optionally
    monopole-subtracted and standardised with a mean and standard deviation.
    Only the maps of the simulations selected by ``mode`` are kept. Each call
    to ``__getitem__`` returns the map under a randomly chosen 90-degree
    rotation and, half of the time, a flip.

    Args:
        mode: 'train', 'valid', 'test' or 'all'. After shuffling the
            simulations these select the first 90%, the next 5%, the last
            5%, or all of them.
        seed: passed to ``np.random.seed`` before shuffling the simulations.
            This reseeds NumPy's global generator.
        fmaps: list of files readable with ``np.load``, one per channel. Each
            array must contain ``splits`` maps per simulation along axis 0.
        fparams: text file readable with ``np.loadtxt`` holding one row of
            parameters per simulation (six columns, matching the hard-coded
            bounds used for the min-max normalisation).
        splits: number of maps per simulation.
        fmaps_norm: list parallel to ``fmaps``. An entry of ``None`` means
            the channel is standardised with its own mean/std; otherwise it
            names a file whose maps provide the mean/std (e.g. normalising
            one simulation suite with the statistics of another).
        monopole: when exactly ``False`` the mean of each map is removed;
            any other value leaves the maps untouched.
        monopole_norm: same as ``monopole`` but for the ``fmaps_norm`` maps.
        verbose: print value ranges while loading.

    Attributes:
        size: number of maps in the dataset.
        x: float32 tensor of shape (size, channels, height, width).
        y: tensor of shape (size, 6) with the min-max normalised parameters.
            NOTE: it is float64, because the float32 parameters are combined
            with float64 bounds during normalisation.

    Raises:
        Exception: if ``mode`` is not recognised or a map file does not
            contain ``splits`` maps per simulation.
        ValueError: if the parameter file does not have six columns.
    """

    # mode -> (offset, size) of the selection, as fractions of the simulations
    _MODE_FRACTIONS = {
        'train': (0.00, 0.90),
        'valid': (0.90, 0.05),
        'test':  (0.95, 0.05),
        'all':   (0.00, 1.00),
    }

    @staticmethod
    def _log_rescale(maps):
        """Return sign(x)*log10(|x|), mapping pixels that are exactly 0 to 0."""
        # zeros are replaced by 1 so that log10 yields 0 instead of -inf
        maps = np.where(maps != 0, maps, 1)
        return np.sign(maps)*np.log10(np.abs(maps))

    def __init__(self, mode, seed, fmaps, fparams, splits, fmaps_norm,
                 monopole, monopole_norm, verbose):
        super().__init__()

        # parameters of the SIMULATIONS: one row per simulation
        params_sims = np.loadtxt(fparams)
        total_sims, num_params = params_sims.shape[0], params_sims.shape[1]
        total_maps = total_sims*splits

        # bounds used for the min-max normalisation of the parameters
        # total of 6 parameters (2 cosmological and 4 astrophysical)
        minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
        maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
        if num_params != minimum.size:
            raise ValueError('expected %d parameters per simulation, found %d'
                             % (minimum.size, num_params))

        # parameters of the MAPS: map i*splits + j belongs to simulation i.
        # The values are stored as float32 before being normalised.
        params_maps = np.repeat(params_sims, splits, axis=0).astype(np.float32)
        params_maps = (params_maps - minimum)/(maximum - minimum)

        # get the size and offset depending on the type of dataset
        try:
            frac_offset, frac_size = self._MODE_FRACTIONS[mode]
        except KeyError:
            raise Exception('Wrong name!') from None
        offset    = int(frac_offset*total_sims)
        size_sims = int(frac_size*total_sims)
        size_maps = size_sims*splits

        # randomly shuffle the simulations (not the maps), so that all maps
        # of a simulation end up in the same train/valid/test subset
        np.random.seed(seed)
        sim_numbers = np.arange(total_sims)
        np.random.shuffle(sim_numbers)
        sim_numbers = sim_numbers[offset:offset+size_sims]

        # indexes of the maps associated to the selected simulations
        indexes = (sim_numbers[:, None]*splits + np.arange(splits))
        indexes = indexes.ravel().astype(np.int32)

        # keep the parameters of the selected maps only
        params_maps = params_maps[indexes]

        # one channel per file; height and width come from the first file
        channels = len(fmaps)
        first = np.load(fmaps[0])
        height, width = first.shape[1], first.shape[2]
        del first

        data = np.zeros((size_maps, channels, height, width), dtype=np.float32)

        print('Found %d channels\nReading data...'%channels)

        for channel, (fim, fnorm) in enumerate(zip(fmaps, fmaps_norm)):

            # read all maps of the current channel
            data_c = np.load(fim)
            if data_c.shape[0]!=total_maps:
                raise Exception('sizes do not match')
            if verbose:
                print('%.3e < F(all|original) < %.3e'%(np.min(data_c), np.max(data_c)))

            # rescale maps (log scale, preserving the sign)
            data_c = self._log_rescale(data_c)
            if verbose:
                print('%.3f < F(all|rescaled)  < %.3f'%(np.min(data_c), np.max(data_c)))

            # remove monopole of the images
            if monopole is False:
                data_c = remove_monopole(data_c, verbose)

            # mean and std used to standardise the maps: either from the
            # channel itself or from the maps of another dataset (fnorm)
            if fnorm is None:
                mean, std = np.mean(data_c), np.std(data_c)
            else:
                data_norm = self._log_rescale(np.load(fnorm))
                if monopole_norm is False:
                    data_norm = remove_monopole(data_norm, verbose)
                mean, std = np.mean(data_norm), np.std(data_norm)
                del data_norm  # free the memory before the next channel

            data_c = (data_c - mean)/std
            if verbose:
                print('%.3f < F(all|normalized) < %.3f'%(np.min(data_c), np.max(data_c))) 

            # keep only the maps of the selected simulations
            data[:,channel,:,:] = data_c[indexes]

            if verbose:
                print('Channel %d contains %d maps'%(channel,size_maps))
                print('%.3f < F < %.3f'%(np.min(data_c), np.max(data_c)))

        self.size = data.shape[0]
        self.x    = torch.from_numpy(data)
        self.y    = torch.from_numpy(params_maps)

        print('{} dataset created!\n'.format(mode))

    def __len__(self):
        """Return the number of maps in the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return a randomly rotated/flipped copy of map ``idx`` and its parameters.

        The rotation and the flip are drawn from NumPy's global random
        generator on every call.
        """
        # rotation angle (0 = 0 deg, 1 = 90 deg, 2 = 180 deg, 3 = 270 deg)
        rot = np.random.randint(0, 4)
        # the upper bound of randint is exclusive: (0, 2) yields 0 or 1
        flip = np.random.randint(0, 2)

        # rotate in the (height, width) plane and optionally flip along height
        maps = torch.rot90(self.x[idx], k=rot, dims=[1,2])
        if flip==1:
            maps = torch.flip(maps, dims=[1])

        return maps, self.y[idx]

# This class creates the dataset. Rotations and flippings are done and stored
class make_dataset(Dataset):
    """Map dataset with all rotations and flips precomputed and stored.

    All maps of every channel are loaded, log-rescaled, optionally
    monopole-subtracted and standardised with a mean and standard deviation.
    The maps of the simulations selected by ``mode`` are then stored under
    the 4 rotations by multiples of 90 degrees, each with and without a flip,
    so the dataset holds 8 entries per selected map.

    Args:
        mode: 'train', 'valid', 'test' or 'all'. After shuffling the
            simulations these select the first 90%, the next 5%, the last
            5%, or all of them.
        seed: passed to ``np.random.seed`` before shuffling the simulations.
            This reseeds NumPy's global generator.
        fmaps: list of files readable with ``np.load``, one per channel. Each
            array must contain ``splits`` maps per simulation along axis 0.
        fparams: text file readable with ``np.loadtxt`` holding one row of
            parameters per simulation (six columns, matching the hard-coded
            bounds used for the min-max normalisation).
        splits: number of maps per simulation.
        fmaps_norm: list parallel to ``fmaps``. An entry of ``None`` means
            the channel is standardised with its own mean/std; otherwise it
            names a file whose maps provide the mean/std.
        monopole: when exactly ``False`` the mean of each map is removed;
            any other value leaves the maps untouched.
        monopole_norm: same as ``monopole`` but for the ``fmaps_norm`` maps.
        just_monopole: when true, every pixel of a map is replaced by the
            log10 of the map's mean (computed after undoing the log scale)
            before the standardisation. The mean/std used for the
            standardisation are still those of the unflattened maps.
        verbose: print value ranges while loading.

    Attributes:
        size: number of stored maps (8 per selected map).
        x: float32 tensor of shape (size, channels, height, width).
        y: float32 tensor of shape (size, 6) with the normalised parameters.

    Raises:
        Exception: if ``mode`` is not recognised or a map file does not
            contain ``splits`` maps per simulation.
        ValueError: if the parameter file does not have six columns.
    """

    # mode -> (offset, size) of the selection, as fractions of the simulations
    _MODE_FRACTIONS = {
        'train': (0.00, 0.90),
        'valid': (0.90, 0.05),
        'test':  (0.95, 0.05),
        'all':   (0.00, 1.00),
    }

    @staticmethod
    def _log_rescale(maps):
        """Return sign(x)*log10(|x|), mapping pixels that are exactly 0 to 0."""
        # zeros are replaced by 1 so that log10 yields 0 instead of -inf
        maps = np.where(maps != 0, maps, 1)
        return np.sign(maps)*np.log10(np.abs(maps))

    @staticmethod
    def _flatten_to_monopole(maps):
        """Fill each log-scaled map with log10 of its mean in linear scale."""
        maps = 10**maps
        maps[:] = np.mean(maps, axis=(1, 2))[:, None, None]
        return np.log10(maps)

    def __init__(self, mode, seed, fmaps, fparams, splits, fmaps_norm,
                 monopole, monopole_norm, just_monopole, verbose):
        super().__init__()

        # parameters of the SIMULATIONS: one row per simulation
        params_sims = np.loadtxt(fparams)
        total_sims, num_params = params_sims.shape[0], params_sims.shape[1]
        total_maps = total_sims*splits

        # bounds used for the min-max normalisation of the parameters
        # total of 6 parameters (2 cosmological and 4 astrophysical)
        minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
        maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
        if num_params != minimum.size:
            raise ValueError('expected %d parameters per simulation, found %d'
                             % (minimum.size, num_params))

        # parameters of the MAPS: map i*splits + j belongs to simulation i.
        # The values are stored as float32 before being normalised.
        params_maps = np.repeat(params_sims, splits, axis=0).astype(np.float32)
        params_maps = (params_maps - minimum)/(maximum - minimum)

        # get the size and offset depending on the type of dataset
        try:
            frac_offset, frac_size = self._MODE_FRACTIONS[mode]
        except KeyError:
            raise Exception('Wrong name!') from None
        offset    = int(frac_offset*total_sims)
        size_sims = int(frac_size*total_sims)
        size_maps = size_sims*splits

        # randomly shuffle the simulations (not the maps), so that all maps
        # of a simulation end up in the same train/valid/test subset
        np.random.seed(seed)
        sim_numbers = np.arange(total_sims)
        np.random.shuffle(sim_numbers)
        sim_numbers = sim_numbers[offset:offset+size_sims]

        # indexes of the maps associated to the selected simulations
        indexes = (sim_numbers[:, None]*splits + np.arange(splits))
        indexes = indexes.ravel().astype(np.int32)

        # keep the parameters of the selected maps only
        params_maps = params_maps[indexes]

        # one channel per file; height and width come from the first file
        channels = len(fmaps)
        first = np.load(fmaps[0])
        height, width = first.shape[1], first.shape[2]
        del first

        # 8 stored copies per map: 4 rotations, each with and without a flip
        data = np.zeros((size_maps*8, channels, height, width), dtype=np.float32)

        # the labels do not depend on the channel: the 8 consecutive blocks
        # of size_maps entries all carry the same parameters
        params = np.tile(params_maps, (8, 1)).astype(np.float32)

        print('Found %d channels\nReading data...'%channels)

        for channel, (fim, fnorm) in enumerate(zip(fmaps, fmaps_norm)):

            # read all maps of the current channel
            data_c = np.load(fim)
            if data_c.shape[0]!=total_maps:
                raise Exception('sizes do not match')
            if verbose:
                print('%.3e < F(all|original) < %.3e'%(np.min(data_c), np.max(data_c)))

            # rescale maps (log scale, preserving the sign)
            data_c = self._log_rescale(data_c)
            if verbose:
                print('%.3f < F(all|rescaled)  < %.3f'%(np.min(data_c), np.max(data_c)))

            # remove monopole of the images
            if monopole is False:
                data_c = remove_monopole(data_c, verbose)

            # mean and std used to standardise the maps: either from the
            # channel itself or from the maps of another dataset (fnorm)
            if fnorm is None:
                mean, std = np.mean(data_c), np.std(data_c)
            else:
                data_norm = self._log_rescale(np.load(fnorm))
                if monopole_norm is False:
                    data_norm = remove_monopole(data_norm, verbose)
                mean, std = np.mean(data_norm), np.std(data_norm)
                del data_norm  # free the memory before the next channel

            # whether to make maps with the mean value in all pixels; this
            # applies no matter where the normalisation statistics came from
            if just_monopole:
                data_c = self._flatten_to_monopole(data_c)

            data_c = (data_c - mean)/std
            if verbose:
                print('%.3f < F(all|normalized) < %.3f'%(np.min(data_c), np.max(data_c))) 

            # keep only the maps of the selected simulations
            data_c = data_c[indexes]

            # loop over all rotations (each is 90 deg); every rotation is
            # stored twice, as-is and flipped along the height axis
            counted_maps = 0
            for rot in range(4):
                data_rot = np.rot90(data_c, k=rot, axes=(1,2))

                data[counted_maps:counted_maps+size_maps,channel,:,:] = data_rot
                counted_maps += size_maps

                data[counted_maps:counted_maps+size_maps,channel,:,:] = \
                                                    np.flip(data_rot, axis=1)
                counted_maps += size_maps

                del data_rot

            if verbose:
                print('Channel %d contains %d maps'%(channel,counted_maps))
                print('%.3f < F < %.3f'%(np.min(data_c), np.max(data_c)))

        self.size = data.shape[0]
        self.x    = torch.from_numpy(data)
        self.y    = torch.from_numpy(params)

        print('{} dataset created!\n'.format(mode))

    def __len__(self):
        """Return the number of stored maps (8 per selected map)."""
        return self.size

    def __getitem__(self, idx):
        """Return the stored map ``idx`` together with its parameters."""
        return self.x[idx], self.y[idx]




def create_dataloader(mode, seed, fmaps, fparams, batch_size, splits, 
                      fmaps_norm, monopole=True, monopole_norm=True,
                      rot_flip_in_mem=True, shuffle=True, 
                      just_monopole=False, verbose=False):
    """Build the dataset for ``mode`` and wrap it in a PyTorch DataLoader.

    Args:
        mode, seed, fmaps, fparams, splits, fmaps_norm, monopole,
        monopole_norm, verbose: forwarded unchanged to the dataset class.
        batch_size: batch size given to the DataLoader.
        rot_flip_in_mem: if true use ``make_dataset`` (rotations and flips
            are precomputed and stored); otherwise use ``make_dataset2``
            (rotations and flips are applied when an item is requested).
        shuffle: forwarded to the DataLoader.
        just_monopole: forwarded to ``make_dataset`` only; it is not used
            when ``rot_flip_in_mem`` is false.

    Returns:
        A DataLoader over the created dataset.
    """
    # arguments shared by both dataset classes, in positional order
    shared = (mode, seed, fmaps, fparams, splits,
              fmaps_norm, monopole, monopole_norm)

    if rot_flip_in_mem:
        # rotations and flips are kept in memory
        dataset = make_dataset(*shared, just_monopole, verbose)
    else:
        # rotations and flips are generated on the fly
        dataset = make_dataset2(*shared, verbose)

    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
