In [1]:
%matplotlib inline
import logging
import os
from collections import namedtuple
from glob import glob

import Euclid
import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table
from matplotlib import colors
from scipy.stats import gaussian_kde
from tqdm.notebook import tqdm

%elogin
%erun PHZ_NNPZ 0.9

%load_ext autoreload
%autoreload 2
from nnpz.reference_sample.ReferenceSample import ReferenceSample
from nnpz.photometry.ListFileFilterProvider import ListFileFilterProvider
from nnpz.photometry.FnuuJyPrePostProcessor import FnuuJyPrePostProcessor
from nnpz.photometry.ReferenceSampleParallelPhotometryBuilder import ReferenceSamplePhotometryParallelBuilder as ParallelPhotometryBuilder
from nnpz.photometry.ReferenceSamplePhotometryBuilder import ReferenceSamplePhotometryBuilder as PhotometryBuilder

# Default matplotlib figure size (width, height) for figures created in this notebook
plt.rcParams['figure.figsize'] = (20, 10)

class TqdmWrapper(object):
    """Progress-listener adapter that drives a tqdm bar.

    The callable treats its argument as a running (cumulative) count and
    forwards only the difference since the previous call to ``tqdm.update``,
    which expects an increment.

    :param total: expected final count, used to size the bar.
    """

    def __init__(self, total):
        self.tqdm = tqdm(total=total)
        # Last cumulative value received; the baseline for the next increment.
        self.progress = 0

    def __call__(self, n):
        increment = n - self.progress
        self.tqdm.update(increment)
        self.progress = n

# Lightweight record yielded per (SED, redshift) item; exposes the array as `.sed`.
Iter = namedtuple('Iter', 'sed')



In [2]:
# Silence INFO/DEBUG chatter from the libraries; only warnings and errors are shown.
logging.root.setLevel(logging.WARNING)

In [3]:
# Open the (merged) NNPZ reference sample; provides per-object IDs, SEDs and PDZs.
# NOTE(review): absolute, user-specific path -- consider a configurable DATA_DIR constant.
ref_sample = ReferenceSample('/home/aalvarez/Work/Data/SC8/PHZ_Calibration/MergeReferenceSampleDirs/reference_sample_dir.dir/')

In [4]:
# Filter Provider
# glob() returns files in arbitrary filesystem order; sort so the filter (band)
# order, and therefore the band order of everything saved below, is reproducible.
subaru_trans_files = sorted(glob('/home/aalvarez/Phosphoros/AuxiliaryData/Filters/Subaru/IB6*.dat'))
# Fail early instead of silently building photometry for zero filters.
assert subaru_trans_files, 'No Subaru IB6*.dat transmission files found'
# NOTE(review): fixed path in /tmp may collide with other users/kernels on a shared host.
with open('/tmp/filter_list.txt', 'wt') as fd:
    fd.write('\n'.join(subaru_trans_files))
filter_provider = ListFileFilterProvider('/tmp/filter_list.txt')

# Photometry
fnuu = FnuuJyPrePostProcessor()

In [6]:
# Builder: parallel photometry over the filters from `filter_provider`,
# post-processed by `fnuu` (presumably flux in uJy, per the class name -- verify).
# NOTE(review): ncores=2 is a hard-coded magic number; consider a config constant.
photo_builder = ParallelPhotometryBuilder(filter_provider, fnuu, ncores=2)

In [7]:
class SedGenerator(object):
    """Lazily yield redshifted copies of SEDs.

    Each SED registered with :meth:`add` is paired with a sequence of redshifts.
    Iteration yields one ``Iter`` per (SED, redshift) pair, in insertion order,
    so the full (n_objects * n_redshifts) set of shifted SEDs is never held in
    memory at once.
    """

    def __init__(self):
        self.__seds = []
        self.__zs = []

    def add(self, sed, z):
        """Register an SED and the redshifts it should be evaluated at.

        :param sed: 2-D array; column 0 is scaled by (1 + z) and column 1 by
            1 / (1 + z)**2 on iteration (presumably wavelength and flux density).
        :param z: sized iterable of redshifts; one output item per entry.
        """
        self.__seds.append(sed)
        self.__zs.append(z)

    def __iter__(self):
        for sed, zs in zip(self.__seds, self.__zs):
            for z in zs:
                shifted = np.stack([sed[:, 0] * (z + 1), sed[:, 1] / (1 + z) ** 2], axis=1)
                yield Iter(sed=shifted)

    def __len__(self):
        """Number of items :meth:`__iter__` yields.

        Summing per-SED redshift counts is correct for an empty generator and
        for SEDs registered with different numbers of redshifts (the previous
        ``len(seds) * len(zs[0])`` raised IndexError / miscounted in those cases).
        """
        return sum(len(zs) for zs in self.__zs)

In [8]:
NSAMPLES = 500
# Seed the Monte-Carlo redshift draws so a Restart & Run All reproduces the
# same photometry files byte for byte.
RNG_SEED = 42
rng = np.random.default_rng(RNG_SEED)

# NOTE(review): not used by any of the cells visible here.
transmissions = {n: [] for n in filter_provider.getFilterNames()}
all_seds = SedGenerator()

all_ids = ref_sample.getIds()
for obj_id in tqdm(all_ids):
    sed_z = ref_sample.getSedData(obj_id)
    pdz = ref_sample.getPdzData(obj_id)
    # Assumption: the stored SED sits at the redshift of the PDZ peak -- TODO confirm.
    ref_z = pdz[:, 0][pdz[:, 1].argmax()]
    # Un-shift the SED to z = 0: column 0 / (1 + z), column 1 * (1 + z)**2
    sed = np.copy(sed_z)
    sed[:, 0] /= 1 + ref_z
    sed[:, 1] *= (1 + ref_z) ** 2

    # Draw NSAMPLES redshifts from the PDZ.
    # abs(): floating point errors may produce tiny negatives (basically 0).
    weights = np.abs(pdz[:, 1])
    weight_sum = weights.sum()
    if not weight_sum > 0:
        # Skipping the object would misalign the one-row-per-object index saved
        # later, so fail with a clear message instead of a NaN-probability error.
        raise ValueError('Object {}: PDZ has no positive probability mass'.format(obj_id))
    zpicks = rng.choice(pdz[:, 0], NSAMPLES, p=weights / weight_sum)
    all_seds.add(sed, zpicks)

HBox(children=(FloatProgress(value=0.0, max=99923.0), HTML(value='')))




In [9]:
# Compute photometry for every (SED, redshift) item produced by `all_seds`.
# TqdmWrapper converts the builder's progress callback into tqdm increments.
# `ret` is iterated below as (filter name, array) pairs.
ret = photo_builder.buildPhotometry(all_seds, progress_listener=TqdmWrapper(len(all_seds)))

HBox(children=(FloatProgress(value=0.0, max=49961500.0), HTML(value='')))

In [58]:
basepath = '/home/aalvarez/Work/Data/SC8/PHZ_Calibration/MergeReferenceSampleDirs/'

# Index shared by every per-filter file, one row per reference object:
#   col 0: object ID
#   col 1: 1 (presumably the data file number, matching the '_data_1.npy' suffix)
#   col 2: row of the object inside the data file
# Built once, outside the loop, so it is defined even if `ret` is empty.
idx = np.zeros((len(all_ids), 3), dtype=np.int64)
idx[:, 0] = all_ids
idx[:, 1] = 1
idx[:, 2] = np.arange(0, len(all_ids))

all_v = []
names = []
for k, v in ret.items():
    # One row per object, one column per MC redshift sample.
    photo = v.reshape(-1, NSAMPLES)
    assert photo.shape[0] == len(all_ids), \
        '{}: {} rows but {} objects'.format(k, photo.shape[0], len(all_ids))
    np.save(os.path.join(basepath, k + '_data_1.npy'), photo)
    np.save(os.path.join(basepath, k + '_index.npy'), idx)
    # Trailing band axis so all filters can be concatenated into one array later.
    all_v.append(photo[:, :, np.newaxis])
    names.append(k)

Merge all per-filter photometry into a single multi-band array

In [59]:
# Join the per-filter (n_objects, NSAMPLES, 1) arrays along the third axis,
# giving (n_objects, NSAMPLES, n_filters).
# NOTE(review): `z` here is photometry, not redshift -- a clearer name would help.
z = np.dstack(all_v)

In [60]:
# Index for the merged multi-band file, same layout as the per-filter index
# files (object ID, 1, row offset). Rebuilt from `all_ids` so this cell does not
# depend on `idx` leaking out of the save loop (hidden state).
z_idx = np.zeros((len(all_ids), 3), dtype=np.int64)
z_idx[:, 0] = all_ids
z_idx[:, 1] = 1
z_idx[:, 2] = np.arange(len(all_ids))

In [61]:
# Use the same '<prefix>_data_1.npy' / '<prefix>_index.npy' convention as the
# per-filter files. The index was previously saved as 'Multiband_index.npy',
# whose case did not match 'MultiBand_data_1.npy' (matters on case-sensitive filesystems).
multiband_prefix = os.path.join(basepath, 'MultiBand')
np.save(multiband_prefix + '_data_1.npy', z)
np.save(multiband_prefix + '_index.npy', z_idx)

In [62]:
# Band names, in the same order as the last axis of the multi-band data array.
names_path = os.path.join(basepath, 'MultiBand_names.npy')
np.save(names_path, np.asarray(names))

In [73]:
# Scratch check of a structured (ID, Ring) record layout.
# `np.int` was a deprecated alias of the builtin int and was removed in NumPy 1.24
# (AttributeError); use an explicit 64-bit integer, matching the index arrays above.
x = np.zeros(500, dtype={'names': ['ID', 'Ring'], 'formats': [np.int64, np.int64]})
x.shape

(500,)