In [None]:
#=================================================================================================================
# redMaPPer catalogue utilities:
#=================================================================================================================

#-----------------------------------------------------------------------------------------------------------------
def import_redMaPPer(filedir):
    """
    Read a redMaPPer cluster catalogue stored as a FITS table.

    Progress information (file path, number of rows, raw FITS column names and
    the keys of the resulting dict) is printed to stdout.

    Parameters
    ----------
    filedir : str
        Path of the FITS catalogue file passed to ``Table.read``.

    Returns
    -------
    dict
        One array per key, taken from the FITS column given in brackets:
        id [ID], ra [RA_CEN], dec [DEC_CEN], spec_z [Z_SDSS],
        photo_z [Z_LAMBDA], photo_z_err [Z_LAMBDA_ERR], rich [LAMBDA],
        rich_err [LAMBDA_ERR], prob [P_CEN], s [S],
        plus the derived n_gal = rich / s.
        ra, dec and prob are per-centre-candidate columns (indices 0:5).
    """
    print('\nReading catalogue found in: %s ...' % filedir)

    table = Table.read(filedir, format='fits')
    print('\nNumber of raw objects found: %i' % len(table))
    print('\nRaw columns in fits file:\n\n%s' % table.colnames)

    # (output key, FITS column) pairs; the tuple order fixes the key order
    # of the returned dict and the order in which columns are read.
    column_map = (
        ('id', 'ID'),
        ('ra', 'RA_CEN'),
        ('dec', 'DEC_CEN'),
        ('spec_z', 'Z_SDSS'),
        ('photo_z', 'Z_LAMBDA'),
        ('photo_z_err', 'Z_LAMBDA_ERR'),
        ('rich', 'LAMBDA'),
        ('rich_err', 'LAMBDA_ERR'),
        ('prob', 'P_CEN'),
        ('s', 'S'),
    )
    catalog = {key: table[column].data for key, column in column_map}

    # Derived quantity: richness divided by the S column.
    catalog['n_gal'] = catalog['rich'] / catalog['s']

    print('\nCreated catalog with keys:\n\n/ ', end='')
    for name in catalog:
        print(name, '/ ', end='')
    print('\n')

    return catalog


#-----------------------------------------------------------------------------------------------------------------
def weighted_redshift(sample):
    """
    Return the probability-weighted spectroscopic redshift of each cluster.

    For every cluster (row), the spectroscopic redshifts of its centre
    candidates (columns) are averaged using ``sample['prob']`` as weights.
    Only candidates with ``spec_z > 0`` take part; non-positive (or NaN)
    values are treated as "no spectroscopic redshift".

    Parameters
    ----------
    sample : dict
        Must provide ``'id'`` (one entry per cluster) and 2-D ``'spec_z'``
        and ``'prob'`` arrays of matching shape (clusters x candidates).

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(sample['id'])``. The value is ``-1.``
        when a cluster has no candidate with ``spec_z > 0``, or when its valid
        candidates carry zero total probability (the weighted mean is
        undefined there).
    """
    spec_z = np.asarray(sample['spec_z'])
    prob = np.asarray(sample['prob'])

    valid = spec_z > 0.
    # Zero out invalid candidates instead of masking them, so they add
    # nothing to either sum and their sentinel/NaN values never leak in.
    weights = np.where(valid, prob, 0.)
    sum_prob = weights.sum(axis=1)
    sum_z = (weights * np.where(valid, spec_z, 0.)).sum(axis=1)

    # The mean is defined only with at least one valid candidate and a
    # non-zero total weight; otherwise keep the -1 "no spec-z" sentinel
    # rather than reporting a meaningless 0.
    defined = valid.any(axis=1) & (sum_prob != 0.)

    sample_z_avg = np.full(len(sample['id']), -1.)
    sample_z_avg[defined] = sum_z[defined] / sum_prob[defined]

    return sample_z_avg


#-----------------------------------------------------------------------------------------------------------------
def create_raw_sample(redm_filedir):
    """
    Build the working redMaPPer sample from a FITS catalogue file.

    The catalogue is read with ``import_redMaPPer``. The output of
    ``weighted_redshift`` is then stored under ``'w_spec_z'`` (computed before
    the coordinates are reduced), and ``'ra'``/``'dec'`` are cut down to their
    first column, the most probable central-galaxy position.

    Parameters
    ----------
    redm_filedir : str
        Path of the redMaPPer FITS catalogue.

    Returns
    -------
    dict
        The catalogue dict with the extra ``'w_spec_z'`` key and 1-D
        ``'ra'``/``'dec'`` arrays.
    """
    catalog = import_redMaPPer(redm_filedir)
    catalog['w_spec_z'] = weighted_redshift(catalog)

    # Keep only the most probable centre candidate (column 0).
    for coord in ('ra', 'dec'):
        catalog[coord] = catalog[coord][:, 0]

    return catalog


#-----------------------------------------------------------------------------------------------------------------
def create_redm_subsample(raw_redm, rich_cutoff, z_min, z_max):
    """
    Return the subset of a redMaPPer sample passing richness and redshift cuts.

    A cluster is kept when all of the following hold:

    * Richness: ``rich >= rich_cutoff``
    * Photometric redshift: ``z_min <= photo_z <= z_max``
    * Weighted spectroscopic redshift: ``w_spec_z == -1.`` (no spec-z
      available) or ``w_spec_z >= z_min``

    The number of selected clusters is printed to stdout.

    Parameters
    ----------
    raw_redm : dict
        Sample whose values are arrays indexed by cluster. It must include
        ``'rich'``, ``'photo_z'`` and ``'w_spec_z'``.
    rich_cutoff : float
        Minimum richness (inclusive).
    z_min, z_max : float
        Inclusive photometric-redshift range; ``z_min`` is also the lower
        bound applied to ``w_spec_z``.

    Returns
    -------
    dict
        Same keys as ``raw_redm``, each array restricted to the selected
        clusters.
    """
    rich_mask = raw_redm['rich'] >= rich_cutoff
    photo_z_mask = (raw_redm['photo_z'] >= z_min) & (raw_redm['photo_z'] <= z_max)
    # NOTE(review): only the lower bound is enforced on w_spec_z, so a cluster
    # whose spectroscopic redshift exceeds z_max is kept if its photo_z is in
    # range. Confirm this asymmetry is intended.
    spec_z_mask = (raw_redm['w_spec_z'] == -1.) | (raw_redm['w_spec_z'] >= z_min)

    global_mask = rich_mask & photo_z_mask & spec_z_mask

    print('\nNumber of redMaPPer clusters: %i\n' % np.count_nonzero(global_mask))
    print('*'*20)

    return {key: values[global_mask] for key, values in raw_redm.items()}


#-----------------------------------------------------------------------------------------------------------------
class Cluster(object):
    """
    Defines Cluster object.
    """
    
    def __init__(self, sample, id):
        self.id = id
        self.sample_mask = sample['id']==id
        self.ra = sample['ra'][self.sample_mask][0]
        self.dec = sample['dec'][self.sample_mask][0]
        if 'photo_z' in sample:
            self.photo_z = sample['photo_z'][self.sample_mask][0]
        if 'spec_z' in sample:
            self.spec_z = sample['spec_z'][self.sample_mask][0]
        if 'z' in sample:
            self.z = sample['z'][self.sample_mask][0]
        if 'rich' in sample:
            self.rich = sample['rich'][self.sample_mask][0]
                
    def set_id(self, id):
        return self.id
    
    def get_ra(self):
        return self.ra
    
    def get_dec(self):
        return self.dec
    
    def get_photo_z(self):
        return self.photo_z
    
    def get_spec_z(self):
        return self.spec_z
    