In [16]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.mlab as mlab
import math

from astropy.table import Table

% matplotlib inline
import make_dictionaries

In [17]:
# Question / answer identifiers for this run; they pick the sub-directory of
# 'output_files/' from which the debiased table is read further down.
question = 't04_spiral'
answer = 'a08_spiral'

# Question metadata exposed by the local make_dictionaries module.
# NOTE(review): presumably the per-question dictionary that the helper
# functions in this notebook index by question name -- confirm.
questions = make_dictionaries.questions

In [3]:
# Read the debiased table written for the selected question/answer pair
# (path is relative to the notebook's working directory).
debiased_values = Table.read('output_files/' + question + '/' + answer + '/debiased.fits')
print('Loaded debiased values...')
# Read the full-sample table (also a relative path).
full_data = Table.read('../../fits/full_sample_debiased.fits')
print('Loaded full data...')

Loaded debiased values...
Loaded full data...


In [37]:
from scipy.stats import binned_statistic

def bin_by_column(column, nbins, fixedcount=True):
    """Assign every entry of ``column`` to one of ``nbins`` bins.

    Parameters
    ----------
    column : array-like
        Values to bin.
    nbins : int
        Number of bins.
    fixedcount : bool
        If True, entries are binned by rank (their position in sorted order,
        rescaled to [0, 1]), so the bins hold roughly equal numbers of
        entries. If False, the bins are equal-width between the minimum and
        maximum of ``column``.

    Returns
    -------
    x : ndarray
        Mean of ``column`` inside each bin (NaN for an empty bin).
    bins : ndarray of int
        1-based bin number of every entry, as given by ``np.digitize``.
    """
    if fixedcount:
        # Replace each value by its rank position, spread evenly over [0, 1].
        order = np.argsort(column)
        values = np.empty(len(column))
        values[order] = np.linspace(0, 1, len(column))
        bin_edges = np.linspace(0, 1, nbins + 1)
    else:
        values = column
        bin_edges = np.linspace(np.min(column), np.max(column), nbins + 1)

    # Push the last edge past the maximum so that the largest value lands in
    # the final bin instead of overflowing (digitize uses half-open bins).
    bin_edges[-1] += 1

    bins = np.digitize(values, bins=bin_edges)
    x = binned_statistic(values, column, bins=bin_edges).statistic
    return x, bins


def reduce_sample(full_data,questions,question,p_cut=0.5,N_cut=5,normalised_values=True,dataset='w13'):
    '''For a given question, get thresholded data.

    A row is kept when the product of the vote-fraction columns of every
    preceding (question, answer) pair is greater than ``p_cut`` and the
    ``_count`` column of the last preceding pair is at least ``N_cut``.

    Parameters
    ----------
    full_data : table-like
        Must support column access by name, boolean-mask row selection,
        ``len()`` and ``.copy()``.
    questions : dict
        Maps a question name to a dict holding 'pre_questions' and
        'pre_answers' (parallel sequences; ``None`` or empty for a primary
        question).
    question : str
        Key into ``questions``.
    p_cut : float
        Strict lower bound on the combined vote fraction.
    N_cut : int
        Inclusive lower bound on the count column.
    normalised_values : bool
        Only consulted when ``dataset`` is neither 'w13' nor 'raw'.
    dataset : str
        'w13' uses the '_debiased' columns, 'raw' the '_weighted_fraction'
        columns; any other value uses '_debiased_rh' or
        '_debiased_rh_normalised' depending on ``normalised_values``.

    Returns
    -------
    data_reduced : table-like
        The selected rows (a copy of ``full_data`` for a primary question).
    select : ndarray of bool
        Row mask over ``full_data`` that produced ``data_reduced``.
    '''
    # Get the reference sample from the previous data:

    previous_q = questions[question]['pre_questions']
    previous_a = questions[question]['pre_answers']

    if dataset == 'w13':
        suffix = '_debiased'
    elif dataset == 'raw':
        suffix = '_weighted_fraction'
    # NOTE(review): normalised_values=True picks the suffix *without*
    # '_normalised', which looks inverted. Behaviour is kept unchanged here;
    # confirm the intended mapping before relying on this branch.
    elif normalised_values == True:
        suffix = '_debiased_rh'
    else:
        suffix = '_debiased_rh_normalised'

    # Identity test instead of `!= None` (which is ambiguous for array-like
    # values); an empty sequence is treated like a primary question rather
    # than crashing on previous_q[-1].
    has_previous = previous_q is not None and len(previous_q) > 0

    if has_previous:

        p_col = np.ones(len(full_data))

        for m in range(len(previous_q)):
            p_col = p_col*(full_data[previous_q[m] + '_' + previous_a[m] + suffix])
        N_col = (full_data[previous_q[-1] + '_' + previous_a[-1] + '_count'])

        select = (p_col > p_cut) & (N_col >= N_cut)
        data_reduced = full_data[select]
        print('{}/{} galaxies with p>{} and N>={}.'.format(len(data_reduced),
                                                          len(full_data),p_cut,N_cut))

    else:
        data_reduced = full_data.copy()
        print('Primary question, so all {} galaxies used.'.format(len(data_reduced)))

        select = np.ones(len(data_reduced), dtype=bool)

    return data_reduced,select


def plot_thresholds(questions,question_dictionary,full_data,p_th=0.5):
    """Plot, per bin of 'REDSHIFT_1', the fraction of rows above ``p_th``.

    For every question in ``questions`` the sample is first thresholded with
    ``reduce_sample`` (default arguments) and restricted to rows where the
    'in_volume_limit' column is true. The rows are split into 20 bins of
    'REDSHIFT_1' with ``bin_by_column``, and for each answer three curves are
    drawn on the current axes:

    * solid  -- '<Q>_<A>_weighted_fraction'
    * dotted -- '<Q>_<A>_debiased'
    * dashed -- '<Q>_<A>_debiased_rh'

    Only the solid curve carries the answer's legend label. Every answer gets
    one colour; colours repeat when a question has more answers than there
    are colours.

    Parameters
    ----------
    questions : iterable of str
        Question names to plot (keys of ``question_dictionary``).
    question_dictionary : dict
        Per-question dict providing 'answers' and 'answerlabels', plus
        whatever ``reduce_sample`` reads from it.
    full_data : table-like
        Table holding the columns named above.
    p_th : float
        Threshold handed to ``get_fractions``.

    Returns
    -------
    None
    """
    plt.figure()
    plt.subplots_adjust(hspace=0,wspace=0)

    lines = ['solid','dotted','dashed']
    suffixes = ['_weighted_fraction','_debiased','_debiased_rh']  # same order as `lines`
    colors = ['red','blue','green']

    for Q in questions:

        answers = question_dictionary[Q]['answers']
        a_labels = question_dictionary[Q]['answerlabels']

        data,_ = reduce_sample(full_data,question_dictionary,Q)
        vl_data = data[data['in_volume_limit']]

        # The binning depends only on the sample, not on the answer, so it is
        # computed once per question.
        z_vals,bins = bin_by_column(vl_data['REDSHIFT_1'],20)

        for x,A in enumerate(answers):

            # Cycle through the palette so that questions with more answers
            # than colours do not raise IndexError.
            color = colors[x % len(colors)]
            columns = [vl_data[Q + '_' + A + suffix] for suffix in suffixes]

            for x2,column in enumerate(columns):
                fracs = get_fractions(column,bins,p_th)
                plt.plot(z_vals,fracs,linewidth=2,linestyle=lines[x2],
                         label=a_labels[x] if x2 == 0 else None,color=color)

        # One legend refresh per question is enough; it picks up every
        # labelled line drawn so far.
        plt.legend(prop={'size':10})

    return None


def make_axes(n_morph,xlabel=r'$\log(f_v)$',ylabel='cumulative fraction',sharex=True,sharey=True,stack=False,width=20):
    """Create a figure with one panel per morphology.

    Parameters
    ----------
    n_morph : int
        Number of panels required (must be at least 1).
    xlabel, ylabel : str
        Axis labels. In grid mode the x label goes on the bottom row and the
        y label on the first column; they are not applied in stacked mode.
    sharex, sharey : bool
        Passed through to ``plt.subplots``.
    stack : bool
        If True, lay the panels out in a single column with a fixed
        10 x (3 * n_morph) figure size. Otherwise use a near-square grid and
        delete the unused trailing panels.
    width : float
        Figure width in grid mode; the height follows from the grid shape.

    Returns
    -------
    fig : matplotlib Figure
    axes : array of Axes
        Flattened (row-major) in grid mode, including the deleted spare
        panels at the end. In stacked mode this is exactly what
        ``plt.subplots`` returns (a bare Axes when ``n_morph == 1``).

    Raises
    ------
    ValueError
        If ``n_morph`` is smaller than 1.
    """
    if n_morph < 1:
        raise ValueError('n_morph must be at least 1, got {}'.format(n_morph))

    # Smallest near-square grid that holds n_morph panels.
    x_dimension = math.ceil(math.sqrt(n_morph))
    y_dimension = math.ceil(n_morph/x_dimension)
    n_spare = x_dimension*y_dimension - n_morph

    height = (y_dimension/x_dimension)*width

    if stack == True:
        fig,axes = plt.subplots(n_morph,1,sharex=sharex,sharey=sharey,figsize=(10,3*n_morph))
    else:
        # squeeze=False always yields a 2-D array, so the 1- and 2-panel
        # layouts are labelled by the same code as larger grids.
        fig,grid = plt.subplots(y_dimension,x_dimension,sharex=sharex,sharey=sharey,
                                figsize=(width,height),squeeze=False)
        for ax in grid[-1,:]:
            ax.set_xlabel(xlabel)
        for ax in grid[:,0]:
            ax.set_ylabel(ylabel)
        axes = grid.ravel()
        # Drop the unused panels at the end of the last row.
        for m in range(-n_spare,0):
            fig.delaxes(axes[m])

    plt.subplots_adjust(hspace=0,wspace=0)

    return fig,axes


def get_fractions(column,bins,th=0.5):
    """Return, for each distinct bin label, the fraction of entries above ``th``.

    Parameters
    ----------
    column : array-like
        Values to threshold.
    bins : array-like
        Bin label of every entry of ``column`` (same length).
    th : float
        Threshold; an entry counts when it is strictly greater than ``th``.

    Returns
    -------
    ndarray
        One fraction per distinct label, ordered by ascending label.
    """
    labels = np.unique(bins)
    fracs = np.zeros(len(labels))

    for idx,label in enumerate(labels):
        members = column[bins == label]
        n_above = np.sum(members > th)
        fracs[idx] = n_above/len(members)

    return fracs