# FITS Data Loading

To load FITS data, we use the custom `FITSDataset` class below. This code is taken from https://github.com/aritraghsh09/GaMPEN/blob/master/ggt/data/dataset.py and https://github.com/amritrau/fitsdataset/blob/master/fitsdataset/dataset.py. Please cite these two repositories if you use any part of the following code.

The custom `FITSDataset` turns a folder of `.fits` files and an associated `.csv` catalog into a PyTorch `Dataset`, which can then be wrapped in a `DataLoader`. This makes it easy to use these images to train or test any PyTorch model.

Before running this notebook, install most of the library dependencies it needs by running the following commands in a terminal:
```bash
pip install fitsdataset
pip install tqdm
```
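
The notebook also imports `astropy`, `numpy`, `pandas`, `torch` and, for the augmentation examples further down, `kornia`. A quick way to check which of these are importable in your environment is sketched below; `missing_packages` is a hypothetical helper of our own, built only on the standard library's `importlib.util.find_spec`.

```python
import importlib.util

# Top-level modules this notebook imports (kornia is only needed for the
# augmentation examples)
REQUIRED = ["astropy", "numpy", "pandas", "torch", "tqdm", "kornia"]


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be imported (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies found.")
```

Anything reported as missing can then be installed with `pip` in the same way as above.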

`FITSDataset` creates a dataset similar to the MNIST dataset we created in Problem #2 of PSet 10. From this dataset you can easily create PyTorch data loaders for your problem.

The options for `FITSDataset` are summarized below:

* `data_dir` -- You need to create a directory to store all the data for your project. This argument should be the full path to that directory.

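    This directory must contain a file called `info.csv` with, at minimum, a column called `file_name` giving the name of each FITS file and a column holding the target you are trying to predict. Because each `file_name` is joined directly onto the `cutouts` path, it must include the file extension (e.g. `.fits`).

    The directory must also contain a folder called `cutouts` holding all the FITS images for your project. On first use, `FITSDataset` creates a third folder, `tensors`, in which it caches each image as a `.pt` file so that later loads are faster.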

* `channels` -- The number of channels/layers in your input images. 

* `cutout_size` -- The height/width of your input images, in pixels. Images are assumed to be square, so only a single integer is accepted.

* `label_col` -- The name of the column in `info.csv` that you are trying to predict.

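* `normalize` -- `True` or `False`. Setting this to `True` applies an arsinh transformation to your input images; arsinh is roughly linear for small pixel values and logarithmic for large ones, so it compresses the images' dynamic range. This is often helpful when training CNNs.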


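* `expand_factor` -- An integer (>= 1) that artificially expands the size of your dataset: the dataset reports `expand_factor` times as many items as there are images, and indices wrap around so that each image is reused `expand_factor` times. If you use this option, you should also pass a set of random transformations to the `transform` argument, so that the repeated copies differ.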

* `transform` -- Additional transformations to apply to your input images. For example, to randomly transform your images (e.g. while using `expand_factor`), you could use

    ```python
    import torch.nn as nn
    import kornia.augmentation as K

    T = nn.Sequential(
        K.RandomHorizontalFlip(),
        K.RandomVerticalFlip(),
        K.RandomRotation(360),  # random rotation by up to 360 degrees
    )
    ```

    then you can pass the above `T` variable to the transform argument.

    If you only want to use `transform` for cropping, you can instead set it to

    ```python
    import torch.nn as nn
    import kornia.augmentation as K

    T = nn.Sequential(
        K.CenterCrop(143),  # crop the images to 143x143 pixels
    )
    ```

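    You can also combine the cropping and random transformations above. Whenever you crop, set `cutout_size` to the cropped size (143 in this example), because `FITSDataset` reshapes every image to `(channels, cutout_size, cutout_size)`.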


* `repeat_dims` -- Set this to `True` to artificially give your images more than one channel, i.e., to copy the same image into as many channels as you specified in `channels`.

* `load_labels` -- Leave this set to `True` unless you do not have labels, e.g. for a test set. When it is `False`, placeholder labels (all ones) are returned instead.
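
Before constructing the dataset, it can save time to check that your data directory matches the layout described above. The sketch below assumes only that layout (`info.csv` with a `file_name` column plus your label column, and a `cutouts` folder); `check_data_dir` is a hypothetical helper using only the standard library, and it checks that files exist without opening them as FITS.

```python
import csv
import tempfile
from pathlib import Path


def check_data_dir(data_dir, label_col):
    """Return a list of problems with `data_dir`; an empty list means the
    layout looks usable by FITSDataset (hypothetical helper)."""
    data_dir = Path(data_dir)
    catalog = data_dir / "info.csv"
    cutouts = data_dir / "cutouts"
    if not catalog.is_file():
        return [f"missing {catalog}"]
    problems = []
    if not cutouts.is_dir():
        problems.append(f"missing {cutouts}/")
    with open(catalog, newline="") as f:
        reader = csv.DictReader(f)
        columns = reader.fieldnames or []
        for required in ("file_name", label_col):
            if required not in columns:
                problems.append(f"info.csv has no column '{required}'")
        if "file_name" in columns:
            for row in reader:
                # file_name is joined directly onto cutouts/, so it must
                # include the extension (e.g. ".fits")
                if not (cutouts / row["file_name"]).is_file():
                    problems.append(f"missing cutout {row['file_name']}")
    return problems


if __name__ == "__main__":
    # Toy layout: info.csv plus empty placeholder files (not real FITS data)
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "cutouts").mkdir()
        with open(root / "info.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["file_name", "R_e"])
            for i in range(3):
                name = f"gal_{i}.fits"  # hypothetical file names
                writer.writerow([name, 1.5 * i])
                (root / "cutouts" / name).touch()
        print(check_data_dir(root, "R_e"))   # []
        print(check_data_dir(root, "bt_g"))  # ["info.csv has no column 'bt_g'"]
```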


```python
from astropy.io import fits
import numpy as np
from functools import partial
from pathlib import Path
from tqdm import tqdm
import pandas as pd

import torch
from torch.utils.data import Dataset
import torch.multiprocessing as mp


# Share tensors between processes through the file system, which avoids
# running out of file descriptors when many tensors are passed around
mp.set_sharing_strategy("file_system")


def arsinh_normalize(X):
    """Normalize a Torch tensor with arsinh, i.e. log(X + sqrt(X**2 + 1))."""
    # torch.asinh computes the same function but avoids the cancellation the
    # explicit formula suffers for large negative values in float32
    return torch.asinh(X)


def load_tensor(filename, tensors_path):
    """Load a cached Torch tensor from disk and return it as a NumPy array."""
    return torch.load(tensors_path / (filename + ".pt")).numpy()


class FITSDataset(Dataset):
    """Dataset from FITS files. Pre-caches FITS files as PyTorch tensors to
    improve data load speed."""

    def __init__(
        self,
        data_dir,
        channels=1,
        cutout_size=167,
        label_col="bt_g",
        normalize=True,
        transform=None,
        expand_factor=1,
        repeat_dims=False,
        load_labels=True,
    ):

        # Set data directories
        self.data_dir = Path(data_dir)

        # Set cutouts shape
        self.cutout_shape = (channels, cutout_size, cutout_size)

        # Set requested transforms
        self.normalize = normalize
        self.transform = transform
        self.repeat_dims = repeat_dims

        # Set data expansion factor (must be an int and >= 1)
        if not isinstance(expand_factor, int) or expand_factor < 1:
            raise ValueError("expand_factor must be an int >= 1")
        self.expand_factor = expand_factor

        # Read the catalog CSV file and define paths
        catalog = self.data_dir / "info.csv"
        self.data_info = pd.read_csv(catalog)
        self.cutouts_path = self.data_dir / "cutouts"
        self.tensors_path = self.data_dir / "tensors"
        self.tensors_path.mkdir(parents=True, exist_ok=True)

        # Retrieve labels & filenames
        if load_labels:
            self.labels = np.asarray(self.data_info[label_col])
        else:
            # Generate placeholder labels with the shape real labels would
            # have: 1-D for a single column name, 2-D for a list of names
            n_rows = len(self.data_info)
            if isinstance(label_col, str):
                self.labels = np.ones(n_rows)
            else:
                self.labels = np.ones((n_rows, len(label_col)))

        self.filenames = np.asarray(self.data_info["file_name"])

        # Convert each FITS file to a cached PyTorch tensor, skipping files
        # whose tensor already exists in tensors/
        print("Generating PyTorch tensors from FITS files...")
        for filename in tqdm(self.filenames):
            filepath = self.tensors_path / (filename + ".pt")
            if not filepath.is_file():
                load_path = self.cutouts_path / filename
                t = FITSDataset.load_fits_as_tensor(load_path)
                torch.save(t, filepath)

        # Preload the tensors
        n = len(self.filenames)
        print(f"Preloading {n} tensors...")
        load_fn = partial(load_tensor, tensors_path=self.tensors_path)
        with mp.Pool(mp.cpu_count()) as p:
            # Load to NumPy, then convert to PyTorch (hack to solve system
            # issue with multiprocessing + PyTorch tensors)
            self.observations = list(
                tqdm(p.imap(load_fn, self.filenames), total=n)
            )
        self.observations = [torch.from_numpy(x) for x in self.observations]

    def __getitem__(self, index):
        """Magic method to index into the dataset."""
        if isinstance(index, slice):
            start, stop, step = index.indices(len(self))
            return [self[i] for i in range(start, stop, step)]
        elif isinstance(index, (int, np.integer)):
            # Load image as tensor ("wrap around" when expand_factor > 1)
            X = self.observations[index % len(self.observations)]

            # Get image label ("wrap around"; make sure to cast to float!)
            y = torch.tensor(self.labels[index % len(self.labels)])
            y = y.float()

            # Normalize if necessary
            if self.normalize:
                X = arsinh_normalize(X)

            # Transform X; kornia augmentations typically return a batched
            # 4-D tensor, hence the different repeat below
            if self.transform is not None:
                X = self.transform(X)

            # Repeat dimensions along the channels axis
            if self.repeat_dims:
                if self.transform is None:
                    X = X.unsqueeze(0)
                    X = X.repeat(self.cutout_shape[0], 1, 1)
                else:
                    X = X.repeat(1, self.cutout_shape[0], 1, 1)

            # Reshape to (channels, cutout_size, cutout_size); this fails if
            # cutout_size does not match the (possibly cropped) image size
            X = X.reshape(self.cutout_shape).float()

            # Return X, y
            return X, y
        elif isinstance(index, tuple):
            raise NotImplementedError("Tuple as index")
        else:
            raise TypeError("Invalid argument type: {}".format(type(index)))

    def __len__(self):
        """Return the effective length of the dataset."""
        return len(self.labels) * self.expand_factor

    @staticmethod
    def load_fits_as_tensor(filename):
        """Open a FITS file and convert it to a Torch tensor."""
        fits_np = fits.getdata(filename, memmap=False)
        # astype(np.float32) also converts FITS big-endian data to native
        # byte order, which torch.from_numpy requires
        return torch.from_numpy(fits_np.astype(np.float32))
```
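
To see what `normalize=True` does to pixel values, the NumPy sketch below compares the explicit arsinh formula, log(x + sqrt(x² + 1)), with `np.arcsinh` (PyTorch provides the same function as `torch.asinh`). The pixel values are made up for illustration. The last part shows why a library implementation is preferable in float32: for a large negative input the explicit formula cancels to log(0).

```python
import numpy as np

# Made-up pixel values, from slightly negative background to a bright core
pixels = np.array([-2.0, 0.0, 0.5, 10.0, 1000.0, 1e5])

# The explicit arsinh formula, evaluated in NumPy (float64)
explicit = np.log(pixels + np.sqrt(pixels**2 + 1))
library = np.arcsinh(pixels)  # same function, library implementation
print(np.allclose(explicit, library))  # True

# Roughly linear near 0 and logarithmic for large values: 1e5 maps to ~12.2
print(np.round(library, 2))

# In float32, x**2 + 1 rounds back to x**2 for x = -1e4, so the explicit
# formula becomes log(0) = -inf, while arcsinh stays finite
x32 = np.array([-1e4], dtype=np.float32)
with np.errstate(divide="ignore"):
    explicit32 = np.log(x32 + np.sqrt(x32**2 + 1))
library32 = np.arcsinh(x32)
print(explicit32, library32)
```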

Example usage (set `data_dir` to the path of your own data directory):

```python
dataset = FITSDataset(data_dir="/home/ag2422/project/git_repos/class-materials/yale-phys378/final-projects/experiments",
                      cutout_size=239,
                      label_col="R_e",
                      normalize=True)
```

```
Generating PyTorch tensors from FITS files...
100%|██████████| 30000/30000 [00:00<00:00, 61057.67it/s]
Preloading 30000 tensors...
100%|██████████| 30000/30000 [00:16<00:00, 1848.07it/s]
```


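Each element of `dataset` is an `(image, label)` tuple. For example,
* `dataset[0]` is the tuple for the 0th image
    * `dataset[0][0]` is the image, a tensor of shape `(channels, cutout_size, cutout_size)`
    * `dataset[0][1]` is the label

and likewise for every other index. Slices such as `dataset[0:10]` return a list of these tuples.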
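
When `expand_factor` is greater than 1, indices beyond the number of images wrap around. The toy class below is a hypothetical stand-in (no FITS files or PyTorch needed) that mimics only the indexing arithmetic of `FITSDataset.__len__` and `__getitem__`: the reported length is the number of images times `expand_factor`, and index `i` maps back to image `i % n_images`. In the real dataset a random `transform` would make `dataset[4]` and `dataset[1]` differ even though they come from the same image.

```python
class WrapAroundToy:
    """Hypothetical stand-in mimicking FITSDataset's wrap-around indexing."""

    def __init__(self, images, labels, expand_factor=1):
        self.images = list(images)
        self.labels = list(labels)
        self.expand_factor = expand_factor

    def __len__(self):
        # Effective length, as in FITSDataset.__len__
        return len(self.labels) * self.expand_factor

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices clamps the slice to the effective length
            return [self[i] for i in range(*index.indices(len(self)))]
        # Wrap around to reuse each image and label
        return (self.images[index % len(self.images)],
                self.labels[index % len(self.labels)])


toy = WrapAroundToy(["img0", "img1", "img2"], [0.1, 0.2, 0.3], expand_factor=2)
print(len(toy))    # 6
print(toy[4])      # ('img1', 0.2)
print(toy[1:6:2])  # [('img1', 0.2), ('img0', 0.1), ('img2', 0.3)]
```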

To pass this dataset to a PyTorch `DataLoader`, you can use, for example:

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=128, shuffle=True)
```
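
The number of batches one pass over `train_loader` yields follows from the dataset length and `batch_size`. With the 30000 images of the example above and `batch_size=128`, that is ceil(30000 / 128) = 235 batches, the last holding only 30000 − 234 × 128 = 48 images, unless `drop_last=True` is passed to `DataLoader`, which discards that partial batch. A small sketch of this arithmetic, with a hypothetical helper name:

```python
def batches_per_epoch(n_images, batch_size, expand_factor=1, drop_last=False):
    """Number of batches a DataLoader yields per epoch (hypothetical helper).

    The dataset length is n_images * expand_factor, matching
    FITSDataset.__len__; drop_last mirrors the DataLoader argument of the
    same name, which discards a final incomplete batch.
    """
    n = n_images * expand_factor
    if drop_last:
        return n // batch_size
    return (n + batch_size - 1) // batch_size  # integer ceiling division


print(batches_per_epoch(30000, 128))                   # 235
print(batches_per_epoch(30000, 128, drop_last=True))   # 234
print(batches_per_epoch(30000, 128, expand_factor=2))  # 469
```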


## Warning
With the line `y = y.float()` in `FITSDataset.__getitem__`, every label is cast to a float. You will need to change this line if you are working on a classification problem where, e.g., your classes are labelled `0`, `1`, `2`, etc.: PyTorch's `nn.CrossEntropyLoss` expects such class-index targets as integers, so `y = y.long()` is the usual replacement.
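
If the class labels in your `info.csv` are stored as floats (e.g. `0.0`, `1.0`, `2.0`), one way to prepare them is to check that they are whole numbers and convert them to 64-bit integers, which is the dtype PyTorch's `torch.long` corresponds to. The NumPy helper below is a hypothetical sketch, not part of `FITSDataset`.

```python
import numpy as np


def to_class_indices(labels):
    """Convert float labels such as 0.0, 1.0, 2.0 to int64 class indices
    (hypothetical helper); raise if any label is not a whole number."""
    labels = np.asarray(labels, dtype=np.float64)
    # NaN != round(NaN), so missing labels are rejected here as well
    if not np.all(labels == np.round(labels)):
        raise ValueError("classification labels must be whole numbers")
    return labels.astype(np.int64)


# e.g. labels stored in an info.csv column as 0.0, 2.0, 1.0, 1.0
y = to_class_indices([0.0, 2.0, 1.0, 1.0])
print(y, y.dtype)  # [0 2 1 1] int64
```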