In [8]:
import random
import torch
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.table import Table
class DECAMDataset(Dataset):
    """
    DECAM dataset class used for pre-training the encoders for IQA.

    Args:
        root (string): root directory of the dataset
        patch_size (int): size of the patches to extract from the images
        max_distortions (int): maximum number of distortions to apply to the images
        num_levels (int): number of levels of distortion to apply to the images
        pristine_prob (float): probability of not distorting the images

    Returns:
        dictionary with keys:
            img_A_orig (Tensor): first view of the image pair
            img_A_ds (Tensor): downsampled version of the first view of the image pair (scale factor 2)
            img_B_orig (Tensor): second view of the image pair
            img_B_ds (Tensor): downsampled version of the second view of the image pair (scale factor 2)
            img_A_name (string): name of the image of the first view of the image pair
            img_B_name (string): name of the image of the second view of the image pair
            distortion_functions (list): list of the names of the distortion functions applied to the images
            distortion_values (list): list of the values of the distortion functions applied to the images
    """
    def __init__(self,
                 root: str,
                 patch_size: int = 224,
                 max_distortions: int = 4,
                 num_levels: int = 5,
                 pristine_prob: float = 0.05):

        root = Path(root)

        
        filenames_csv_path = "../../data/decam_dr10_good_exp.csv"
        exp_df = pd.read_csv(filenames_csv_path, header=None, names=["expnum"])

        # Path to FITS file containing image data
        file_path = "/global/cfs/cdirs/cosmo/work/legacysurvey/dr10/survey-ccds-decam-dr10.fits.gz"
        image_table = Table.read(file_path)

        # List to store paths of selected images
        self.ref_images = []
        self.hdu_numbers = []

        test = image_table[:400]
        idx = np.isin(test["expnum"], exp_df["expnum"])
        matched_exp = test[idx]
        self.ref_images =  matched_exp["image_filename"]
        self.hdu_numbers = matched_exp['image_hdu']


        # Convert paths to Path objects
        self.ref_images = [Path(root,path) for path in self.ref_images]

        self.patch_size = patch_size
        self.max_distortions = max_distortions
        self.num_levels = num_levels
        self.pristine_prob = pristine_prob

        assert 0 <= self.max_distortions <= 7, "The parameter max_distortions must be in the range [0, 7]"
        assert 1 <= self.num_levels <= 5, "The parameter num_levels must be in the range [1, 5]"
        
    def __getitem__(self, index: int) -> dict:
        print(index)
        
        img_A_path = self.ref_images[index]
        print(img_A_path)
        hdu_number = self.hdu_numbers[index]

        hdul_A = fits.open(img_A_path)
        img_A = hdul_A[hdu_number].data
        print(img_A)
        # Select another exposure randomly
        other_exp_index = np.random.choice(np.setdiff1d(range(len(self.ref_images)), [index]))
        img_B_path = self.ref_images[other_exp_index]
        hdul_B = fits.open(img_B_path)
        img_B = hdul_B[hdu_number].data
    

        img_A_orig = transforms.ToTensor()(img_A_orig)
        img_B_orig = transforms.ToTensor()(img_B_orig)

        distort_functions_A = []
        distort_values_A = []
        distort_functions_B = []
        distort_values_B = []


        return {
            "img_A_orig": img_A_orig,"img_B_orig": img_B_orig,
        
        }
    def __len__(self) -> int:
        return len(self.ref_images)


In [6]:
import argparse
import torch
from torch.utils.data import DataLoader
from pathlib import Path
import random
import os
import numpy as np



    # Initialize the training dataset and dataloader
train_dataset = DECAMDataset(root="/global/cfs/cdirs/cosmo/work/legacysurvey/dr10/images",
                                    patch_size=224,
                                    max_distortions=4,
                                    num_levels=5,
                                    pristine_prob=0.05)


In [9]:
train_dataloader = DECAMDataset.__getitem__(train_dataset, index=0)


0
/global/cfs/cdirs/cosmo/work/legacysurvey/dr10/images/decam/CP/V3.1.2/CP20140825/c4d_140825_015328_ooi_i_v1.fits.fz
[[1653.6238 1784.6875 1745.4961 ... 1323.6549 1579.623  1378.1066]
 [1214.6892 1520.8129 1483.4412 ... 1310.9122 1528.0323 1304.8257]
 [1123.5559 1336.0892 1332.3662 ... 1315.0166 1441.5266 1293.9194]
 ...
 [1091.2863 1503.143  1289.112  ... 1543.8789 1838.7279 1293.9272]
 [1133.6361 1455.5337 1198.4419 ... 1445.0701 1701.9904 1332.0646]
 [1690.9584 1507.9193 1295.4703 ... 1573.3124 1840.6239 1608.7803]]


UnboundLocalError: cannot access local variable 'img_A_orig' where it is not associated with a value