In [None]:
'''
    This code generates a train-test-validation split based on the ZTF light curves found at 
    
    /global/cfs/projectdirs/desi/users/akim/Projects/QSO_Distance/data/lc/*"
    
    and the DESI data found at 
    
    /global/cfs/projectdirs/desi/users/akim/Projects/QSO_Distance/data/dates.hdf5
    
    which were the initial datasets given to us by Dr. Kim at the start of the project. We note 
    that each dataset has the same 1,271,420 targetIDs, which we split into training/validation/test
    subsets in a 60-20-20 ratio, which are saved in train_split.json.
    
    Additionally, we consider the DESI dataset located at 
    
    /pscratch/sd/r/raichoor/desi-ztf-qso/
    
    which has 1,237,825 targetIDs, of which 1,040,800 overlap with the targetIDs in the original dataset.
'''

In [1]:
import h5py
import json
import random
import numpy as np
import glob
from tqdm import tqdm
import pandas as pd

In [1]:
# get LC data from ZTF

In [31]:
lc_path = "/global/cfs/projectdirs/desi/users/akim/Projects/QSO_Distance/data/lc/*"
paths = glob.glob(lc_path)

keys = []
for path in tqdm(paths):
    desi_data = h5py.File(path, 'r')
    for key in desi_data.keys():
        keys.append(int(key))

100%|██████████| 1174/1174 [00:40<00:00, 29.04it/s]


In [32]:
key_set = set(keys)
print(len(key_set))

1271420


In [None]:
#get DESI info from dates file

In [3]:
dates = pd.read_hdf("/global/cfs/projectdirs/desi/users/akim/Projects/QSO_Distance/data/dates.hdf5")

In [33]:
DESI_date_keys = list(dates["targetid"])
DESI_date_keys_set = set(DESI_date_keys)
len(DESI_date_keys_set)

1271420

## New DESI dataset

In [15]:
#load all of the fits files
import fitsio
import os
from time import time
ardir = os.path.join(os.getenv("PSCRATCH", "/pscratch/sd/r/raichoor/desi-ztf-qso/"))
sumfn = os.path.join(ardir, "desi-ztf-qso-iron-pernight-summary.fits")

def get_ws(ardir):
    fn = sorted(
        glob.glob(os.path.join(ardir, "pernight-spectra","desi-ztf-qso-iron-*-*.fits"))
    )[0]
    ws = fitsio.read(fn, "BRZ_WAVE")
    return ws

ws = get_ws(ardir)
nwave = len(ws)

start = time()
d = fitsio.read(sumfn, "FIBERMAP")

print("reading {} done (took {:.1f}s)".format(sumfn, time() - start))

reading /pscratch/sd/t/thomaslu/desi-ztf-qso-iron-pernight-summary.fits done (took 2.7s)


In [35]:
new_desi_keys = list(d["TARGETID"])
new_desi_keys_set = set(d["TARGETID"])
len(new_desi_keys_set)

1237825

In [34]:
len(DESI_date_keys_set.union(key_set))
## original DESI and ZTF have perfect overlap, size 1271420

1271420

In [36]:
len(new_desi_keys_set.union(key_set))
#original dataset and fits dataset are disjoint, combined size 1468445

1468445

In [38]:
print(f"overlap: {1237825+1271420-1468445}")

overlap: 1040800


## Make the split

In [38]:
from sklearn.model_selection import train_test_split

numpy.random.seed(42)
train_targets, test = train_test_split(keys, test_size = 0.4)
test_targets, val_targets = train_test_split(test, test_size = 0.5)

In [39]:
print(len(train_targets)/len(keys))
print(len(test_targets)/len(keys))
print(len(val_targets)/len(keys))

0.5999995539157001
0.20000022304214995
0.20000022304214995


In [41]:
train_test_split = {"train": train_targets, "test": test_targets, "validation": val_targets}
json_d = json.dumps(train_test_split)
with open("train_split.json", "w") as f:
    f.write(json_d)