In [12]:
from astropy.table import Table

In [2]:
data = Table.read('/arc/home/aydanmckay/lamost8_gaia3.fits')

In [3]:
# data

In [4]:
from gaiaxpy import calibrate

In [5]:
f = '/arc/home/aydanmckay/XpContinuousMeanSpectrum_407725-409897.csv'

In [6]:
calibrated_spectra, sampling = calibrate(f)
# calibrated_spectra

                              

In [13]:
# importing the required libraries
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
  
# defining the Dataset class
class BasicSet(Dataset):
    def __init__(self):
        numbers = list(range(0, 100, 1))
        self.data = numbers
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]
basicset = BasicSet()
  
# implementing dataloader on the dataset and printing per batch
basicdataloader = DataLoader(basicset, batch_size=10, shuffle=True)
for i, batch in enumerate(basicdataloader):
    print(i, batch)

0 tensor([48, 52,  1, 76, 90, 98, 79, 80, 54, 85])
1 tensor([57, 86, 69, 44, 91, 60, 71, 75, 34, 14])
2 tensor([55, 26, 78, 36, 68, 51, 92, 59, 24, 12])
3 tensor([39, 19, 11, 99, 82, 72, 25, 87, 21, 27])
4 tensor([66, 89, 70, 23, 95, 32, 46, 29, 35, 93])
5 tensor([47, 13, 94, 77, 65, 20, 16,  5, 43, 37])
6 tensor([56, 73,  3,  8, 38, 97, 58, 18,  6, 17])
7 tensor([40, 84, 53, 50, 15, 28, 61, 49, 67,  7])
8 tensor([63, 88,  9, 64, 22, 30, 83, 42,  4, 74])
9 tensor([45,  2, 33, 81,  0, 96, 10, 41, 31, 62])


In [14]:
# for i in dataloader:
#     print(i)

In [15]:
# dataloader.batch_size

In [30]:
import h5py
import numpy as np
from pathlib import Path
from torch.utils import data

class HDF5Dataset(data.Dataset):
    """Represents an abstract HDF5 dataset.
    
    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
        recursive: If True, searches for h5 files in subdirectories.
        load_data: If True, loads all the data immediately into RAM. Use this if
            the dataset is fits into memory. Otherwise, leave this at false and 
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
        transform: PyTorch transform to apply to every data instance (default=None).
    """
    def __init__(self, file_path, recursive, load_data, data_cache_size=3, transform=None):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        self.transform = transform

        # Search for all h5 files
        p = Path(file_path)
        assert(p.is_dir())
        if recursive:
            files = sorted(p.glob('**/*.h5'))
        else:
            files = sorted(p.glob('*.h5'))
        if len(files) < 1:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
            
    def __getitem__(self, index):
        # get data
        x = self.get_data("data", index)
        if self.transform:
            x = self.transform(x)
        else:
            x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        return len(self.get_data_infos('data'))
    
    def _add_data_infos(self, file_path, load_data):
        with h5py.File(file_path) as h5_file:
            print(h5_file)
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                print(gname,group)
                for dname, ds in group.items():
                    print(dname)
                    # if data is not loaded its cache index is -1
                    idx = -1
                    if load_data:
                        # add data to the data cache
                        idx = self._add_to_cache(ds.value, file_path)
                    
                    # type is derived from the name of the dataset; we expect the dataset
                    # name to have a name such as 'data' or 'label' to identify its type
                    # we also store the shape of the data in case we need it
                    print(ds)
                    # self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.value.shape, 'cache_idx': idx})
                    self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        with h5py.File(file_path) as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds.value, file_path)

                    # find the beginning index of the hdf5 file we are looking for
                    file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # remove one item from the cache at random
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # remove invalid cache_idx
            self.data_info = [{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1} if di['file_path'] == removal_keys[0] else di for di in self.data_info]

    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == type]
        return data_info_type

    def get_data(self, type, i):
        """Call this function anytime you want to access a chunk of data from the
            dataset. This will make sure that the data is loaded in case it is
            not part of the data cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        
        # get new cache_idx assigned by _load_data_info
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [33]:
num_epochs = 50
loader_params = {'batch_size': 100, 'shuffle': True, 'num_workers': 6}

dataset = HDF5Dataset('/arc/home/aydanmckay', recursive=False, load_data=False, 
   data_cache_size=4, transform=None)

data_loader = data.DataLoader(dataset, **loader_params)

for i in range(num_epochs):
    for x,y in data_loader:
        print(x,y)
        break
    pass

<HDF5 file "mydata.h5" (mode r)>
group_1 <HDF5 group "/group_1" (2 members)>
data
<HDF5 dataset "data": shape (3, 578275), type "|V12">
label
<HDF5 dataset "label": shape (110, 578275), type "<f8">
group_2 <HDF5 group "/group_2" (2 members)>
data
<HDF5 dataset "data": shape (3, 64253), type "|V12">
label
<HDF5 dataset "label": shape (110, 64253), type "<f8">


AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_124/1725768936.py", line 40, in __getitem__
    x = self.get_data("data", index)
  File "/tmp/ipykernel_124/1725768936.py", line 125, in get_data
    self._load_data(fp)
  File "/tmp/ipykernel_124/1725768936.py", line 85, in _load_data
    idx = self._add_to_cache(ds.value, file_path)
AttributeError: 'Dataset' object has no attribute 'value'


In [18]:
import pandas as pd
    
# defining the Dataset class
class data_set(Dataset):
    def __init__(self,file):
        datum = Table.read(file)
        self.data = datum
  
    def __len__(self):
        return len(self.data)
  
    def __getitem__(self, index):
        return self.data[index]

In [21]:
dataset = data_set('/arc/home/aydanmckay/gaiahike/bp_rp_lamost.fits')

len(dataset)
# implementing dataloader on the dataset and printing per batch
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for i, batch in enumerate(dataloader):
    print(i, batch)
    break
# So this works the way that it should be working! at least for fits files.

0 [tensor([-0.3640, -0.0110]), tensor([0.0690, 0.0694]), tensor([12.4963, 12.7018]), tensor([6330331376789732224, 6330335530022502784]), tensor([5302.0200, 6069.3301]), tensor([4.4990, 4.2070]), tensor([219.1310, 219.0966], dtype=torch.float64), tensor([-7.6277, -7.5659], dtype=torch.float64), tensor([12029.7314, 10450.9819], dtype=torch.float64), tensor([-1411.8249, -1930.3344], dtype=torch.float64), tensor([-219.1894,   35.6158], dtype=torch.float64), tensor([129.3349,  52.6922], dtype=torch.float64), tensor([-68.3525, -62.8730], dtype=torch.float64), tensor([-86.3349, -19.8526], dtype=torch.float64), tensor([-56.4289, -28.1879], dtype=torch.float64), tensor([37.5401, 21.8474], dtype=torch.float64), tensor([-42.2532, -17.0194], dtype=torch.float64), tensor([-5.6693,  8.7219], dtype=torch.float64), tensor([4.0050, 6.2666], dtype=torch.float64), tensor([1.2507, 9.6351], dtype=torch.float64), tensor([-13.1181,   8.2519], dtype=torch.float64), tensor([-5.8534, -3.2109], dtype=torch.float

In [20]:
dataset[:2]

feh,ebv,phot_g_mean_mag,source_id,teff,logg,ra,dec,bp_1,bp_2,bp_3,bp_4,bp_5,bp_6,bp_7,bp_8,bp_9,bp_10,bp_11,bp_12,bp_13,bp_14,bp_15,bp_16,bp_17,bp_18,bp_19,bp_20,bp_21,bp_22,bp_23,bp_24,bp_25,bp_26,bp_27,bp_28,bp_29,bp_30,bp_31,bp_32,bp_33,bp_34,bp_35,bp_36,bp_37,bp_38,bp_39,bp_40,bp_41,bp_42,bp_43,bp_44,bp_45,bp_46,bp_47,bp_48,bp_49,bp_50,bp_51,bp_52,bp_53,bp_54,bp_55,rp_1,rp_2,rp_3,rp_4,rp_5,rp_6,rp_7,rp_8,rp_9,rp_10,rp_11,rp_12,rp_13,rp_14,rp_15,rp_16,rp_17,rp_18,rp_19,rp_20,rp_21,rp_22,rp_23,rp_24,rp_25,rp_26,rp_27,rp_28,rp_29,rp_30,rp_31,rp_32,rp_33,rp_34,rp_35,rp_36,rp_37,rp_38,rp_39,rp_40,rp_41,rp_42,rp_43,rp_44,rp_45,rp_46,rp_47,rp_48,rp_49,rp_50,rp_51,rp_52,rp_53,rp_54,rp_55,bpe_1,bpe_2,bpe_3,bpe_4,bpe_5,bpe_6,bpe_7,bpe_8,bpe_9,bpe_10,bpe_11,bpe_12,bpe_13,bpe_14,bpe_15,bpe_16,bpe_17,bpe_18,bpe_19,bpe_20,bpe_21,bpe_22,bpe_23,bpe_24,bpe_25,bpe_26,bpe_27,bpe_28,bpe_29,bpe_30,bpe_31,bpe_32,bpe_33,bpe_34,bpe_35,bpe_36,bpe_37,bpe_38,bpe_39,bpe_40,bpe_41,bpe_42,bpe_43,bpe_44,bpe_45,bpe_46,bpe_47,bpe_48,bpe_49,bpe_50,bpe_51,bpe_52,bpe_53,bpe_54,bpe_55,rpe_1,rpe_2,rpe_3,rpe_4,rpe_5,rpe_6,rpe_7,rpe_8,rpe_9,rpe_10,rpe_11,rpe_12,rpe_13,rpe_14,rpe_15,rpe_16,rpe_17,rpe_18,rpe_19,rpe_20,rpe_21,rpe_22,rpe_23,rpe_24,rpe_25,rpe_26,rpe_27,rpe_28,rpe_29,rpe_30,rpe_31,rpe_32,rpe_33,rpe_34,rpe_35,rpe_36,rpe_37,rpe_38,rpe_39,rpe_40,rpe_41,rpe_42,rpe_43,rpe_44,rpe_45,rpe_46,rpe_47,rpe_48,rpe_49,rpe_50,rpe_51,rpe_52,rpe_53,rpe_54,rpe_55
float32,float32,float32,int64,float32,float32,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32,float32
-0.364,0.069047,12.496257,6330331376789732224,5302.02,4.499,219.131024159976,-7.627747806774632,12029.73135517054,-1411.8248991274577,-219.18935203560088,129.3348848741412,-68.352525797212,-86.33485312462123,-56.4288840072494,37.54006039545181,-42.253209717230725,-5.669295948772287,4.004988672031573,1.250717833021587,-13.118147685652408,-5.85335558891185,2.9191199559863232,-5.965406814068504,5.177502148362913,-1.5113064835079113,-18.69389942111117,-4.379951794261879,-2.5716124742393425,30.66784215809901,2.876339656676244,6.454842594753531,-0.61580096904663,-1.7317637586582897,-2.588875180594774,-4.735372993232711,2.723404584233639,-1.1091093260693008,-3.756056828804223,2.890264768375606,2.5824034496278974,-3.497321617526314,5.589937985732052,-0.7089629513504392,-0.0718246189054392,-1.4329553488097782,-0.1144560616810339,-1.4060874340832676,-1.5372627056743755,0.8771477849850234,-1.7386455707367487,1.4631076869393875,-1.328924177816735,-3.274643706779776,0.4116611148617004,-1.0174916596287713,-3.738424779524977,-0.6897691817834506,-1.4106974343514804,-0.7303907037970297,1.306858843175598,0.0786983391005103,-0.1582966919415726,13046.33144200603,-1641.1659606254582,-72.79720741636,-2.677841493543232,10.055829317762027,-20.143786078792143,-31.05569621899466,13.881915332812644,-12.743859650291846,-0.3267925620378819,-3.2582235331732914,2.558959605270238,1.5825889322962132,-0.0240732849839234,-7.55396995370354,-4.283636898855539,-3.767099542506233,-4.618424796668488,-0.3826847544621197,-1.577114815999789,2.8324680315599533,2.280628581999916,-0.3554105126805091,-0.5694349607244443,1.8658913680483216,1.644453718653922,-5.809488718174509,-11.63108364139222,3.0483466373117345,5.633366560349695,4.401141964353226,2.559657930312138,0.6903658363187258,-3.0978335446019725,5.73198818206075,-2.0008666025375224,0.2465886085468562,-1.56444414310446,1.6276078835139491,4.030101009558468,2.187565694786453,-0.4969029627920216,-0.1314582528368399,2.9290008561937504,-0.3511891745659584,0.042714171153479,-0.4795971956055125,-0.0047541227767506,-0.4630688994679037,-0.0370378120982881,1.1114843293653611,-0.1675854429076297,0.6724553206073789,-0.0576367933700185,-0.2589425153329447,8.561699,7.358293,7.6541004,7.633931,7.9895453,6.338178,6.002439,11.029905,6.072579,6.942428,7.5021367,6.873598,7.094597,5.0050745,5.3846564,6.5940742,6.068771,5.204529,5.515418,3.6735656,6.363709,4.426275,5.4215794,4.7866,5.1068697,4.6616454,3.692223,2.315076,3.8289065,2.5999148,2.931904,2.8181386,1.950434,2.7553375,1.9909607,2.217723,1.9788728,1.9711441,2.0606232,2.1406999,1.7199641,2.1398177,2.668717,1.6198235,2.0461445,2.6404948,2.8194392,1.577793,1.5943475,1.3191528,1.2595327,0.65944815,0.93880457,0.46021208,0.18836597,4.0155005,3.766243,3.701974,4.0720973,3.7484212,3.9839602,3.883636,3.6255498,3.544636,3.8870735,3.2846787,3.8213787,2.6216357,3.7501376,3.6374953,3.8894465,3.6576526,3.3956966,3.4787664,3.791071,2.994246,3.632406,3.2658532,3.1895466,3.5184424,3.3933523,3.4710796,3.163262,2.8702621,2.6644864,2.8838375,2.3235414,2.7359493,2.2942936,2.2069607,2.0608687,1.9893657,1.8448502,1.6733536,1.4514717,1.3412256,1.424264,1.6948938,1.3802388,1.3380924,1.3827865,1.1668892,1.0986447,1.2063049,1.1319106,0.95395327,0.8091774,0.7472917,0.31202096,0.15312435
-0.011,0.069432,12.701829,6330335530022502784,6069.33,4.207,219.0966001120319,-7.565927097484413,10450.981910560537,-1930.3344105529316,35.615812818200475,52.69218574703245,-62.87298551956982,-19.852564120676377,-28.187943536664285,21.847374723681234,-17.019447673512392,8.721890502290824,6.2665831801422405,9.635065798618005,8.251927581655977,-3.2108990470866963,0.8044745243967977,-0.9209852324015668,6.246545953406075,-1.917701623417716,2.1844547136362635,-12.634269744319026,-4.075712710813704,8.456264680259682,9.76001021680542,1.595993475407398,4.061371526016699,0.9573741379844802,-0.6469127390911892,2.8862223856838938,-3.181575704343859,1.6490460456530285,1.701258841142636,1.7298380979490655,-2.768455832477895,-3.4989954072216585,1.2041100360184671,4.02940177850777,-2.976214229960119,1.2931910873327723,-2.3860109425045555,-1.5157774168240885,2.034453540951048,-1.852978006990836,-1.7151643668441203,1.3810026711899197,-1.151165570409796,-0.207366870760382,1.6770265070899069,-1.7181508147575446,-0.5781798442961438,-0.2005066729617539,-0.2783862389047136,-0.3424224378605865,1.0000986341497642,0.1399350782839868,-0.3752983926383662,9752.04893944838,-1551.7652900258624,-112.796791292405,-48.47676802655127,0.9451947898401424,1.8597990316997817,-21.87189424262403,-4.527711885014163,-10.902243439742335,-1.4595328861673846,7.382360766125392,1.3011996091072986,-1.8510208958874892,1.9654346332252493,-6.066745024190991,-3.935317431344419,-0.5009986255532003,3.5787569856699264,-0.9511405520364408,-0.9789355691436454,-1.0947417711533278,-0.4792429176340852,-1.1984403390121705,1.7491379838799168,-0.1876362003998214,-3.907016148195573,1.3100022783810066,1.3540556254385032,-2.8978221213551847,-2.503115416540165,-4.048710745442849,1.5433298224994128,0.9900319371514414,-3.3453990080322416,1.8936067263130856,1.269586483125818,0.0285527495216979,0.6667782707861785,-0.797214142170213,-0.0882670164914845,-0.5365249641330267,1.361111312173033,0.4001007791042277,-0.4611911634788381,1.5020961761438931,-2.0104611607766154,0.2945297332682139,1.0826764480460125,0.3321117093049603,0.3157423658021551,-0.1732299211406029,0.5461690846164245,-0.0422086804084448,-0.0164574428426962,-0.244732106145499,8.447972,7.0834284,7.5440526,7.565726,7.7916765,6.104835,5.758369,11.053668,6.1651845,6.7502637,7.416419,7.052223,7.148545,5.0879917,5.8355975,6.441967,5.923933,4.9942284,5.475122,3.8486683,6.1385727,4.6468277,5.2434373,4.791053,4.9283032,4.572143,3.7644968,2.4040082,3.7863429,2.517023,3.019513,3.1905012,2.1015317,2.8211997,1.970466,2.2226217,2.0188181,2.0939827,2.0704157,2.290988,1.7165297,2.1951559,2.8740735,1.5214671,2.0076058,3.427612,3.3711236,1.927568,2.3196056,1.7015835,1.268457,0.58226115,1.1152003,0.40442988,0.17098732,4.228244,3.7142632,3.6580877,4.2376585,3.7133286,3.9989412,4.0074525,3.5329695,3.5416896,3.9283783,3.2082567,3.8726315,2.442821,3.7463374,3.507902,3.990694,3.488716,3.0971622,3.0333781,3.788047,2.586261,3.617715,3.1034687,3.0581264,3.4357011,3.2590563,3.2022147,2.865373,2.4829512,2.3784854,2.5526812,2.00205,2.446931,1.9701358,1.8695737,1.7675664,1.6795605,1.5981243,1.392807,1.2006375,1.0248477,1.1394763,1.3978342,1.0637425,1.0956715,1.1603894,0.917158,1.0857401,0.85240316,0.9193995,0.9409097,0.6815596,0.5583514,0.24447003,0.13148484


In [None]:
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

In [None]:
dataloader