_Alex Malz (NYU)_

# Experiments with plotting the PZ DC1 results

In [1]:
import sys,os
import numpy as np

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.use('Agg')
mpl.rcParams['text.usetex'] = False
mpl.rcParams['mathtext.rm'] = 'serif'
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Times New Roman'
mpl.rcParams['axes.titlesize'] = 16
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['savefig.dpi'] = 250
mpl.rcParams['savefig.format'] = 'pdf'
mpl.rcParams['savefig.bbox'] = 'tight'

# %matplotlib inline

This call to matplotlib.use() has no effect because the backend has already
been chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



In [3]:
metric_names = ['KS', 'CvM', 'AD']
colors = ['b', 'k', 'r']
#colors = ['royalblue', 'k', 'r']
n_symb = 5
shapes = [(n_symb, 3, 0), (n_symb-1, 0, 0), (n_symb, 1, 0), (n_symb+1, 2, 0)]

In [4]:
def readTable(tablefile):
    """
    function to read in the table data for KS,CvM, and AD stats
    assumes that the data is in alphabetical order, same as the labels defined in main
    input: 
        tablefile: string listing table filename
    output:
        data: datafile containing the KS, CvM, and two AD stats for the codes
    """
    data = np.loadtxt(tablefile,skiprows=1,usecols=(1,2,3,4))
    return data
    

In [5]:
mpl.__version__

'2.0.2'

In [6]:
def make_patch_spines_invisible(ax):
    ax.set_frame_on(True)
    ax.patch.set_visible(False)
    for sp in ax.spines.values():
        sp.set_visible(False)

def per_metric_helper(ax, n, data, codes):
    plot_n = n+1
    in_x = np.arange(len(codes))
    ax_n = ax
    ax_n.semilogy()


#     ax.scatter(in_x, [-1]*len(codes), color='k', alpha=0, marker=shapes[n], s=50, label=metric_names[n])
    n_factor = 0.15 * (plot_n - 2)
    if plot_n>1:
        ax_n = ax.twinx()
        #rot_ang = 270
        rot_ang=90
        label_space = 15.
    else:
        #rot_ang = 90
        rot_ang=90
        label_space = 0.
    if plot_n>2:
        ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (plot_n-1)))
        make_patch_spines_invisible(ax_n)
        ax_n.spines["right"].set_visible(True)
    
    #if n != 0:
    #    data[n] *= -1.
    ax_n.semilogy()
        
    handle = ax_n.scatter(in_x+n_factor, data[n], marker=shapes[plot_n], s=30, color=colors[n], label=metric_names[n])
    ax_n.set_ylabel(metric_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
    
    print('plotted '+metric_names[n]+': '+str(data[n]))
    return(ax, ax_n, handle)

def metric_plot(codes, data):
    xs = np.arange(len(codes))
    
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    plt.xticks(xs, codes, rotation=45)
    plt.xlabel('Photo-z code', fontsize=14)



    handles = []
#     for name in metric_names:
#         handles.append(ax.scatter([0], [0], alpha=0.))
    for n in range(len(metric_names)):
        (ax, ax_n, handle) = per_metric_helper(ax, n, data, codes)
        handles.append(handle)
#     plt.xticklabels(codes)
    #plt.xticks(xs, codes, rotation='vertical')
    #plt.xticks(xs, codes, rotation=45)
    #plt.subplots_adjust(bottom=-0.15)
    plt.legend(handles, metric_names)
#     plt.savefig('KSvsCvmvsAD.png')
    return(fig)

In [7]:
def pitmain():
    codes = ("ANNZ2","BPZ","DELIGHT","EAZY","FLEXZ","GPZ","LEPHARE","METAPHOR","CMNN","SKYNET","TPZ","TRAINZ")
    labels = ("ANNz2","BPZ","Delight","EAZY","FlexZBoost","GPz","LePhare","METAPhoR","NN","SkyNet","TPZ","TrainZ")
    labeldict = dict(zip(codes,labels))

    statdata = readTable("PITtabledata_withnull.dat")
    
    all_ks = statdata[:,0]
    all_cvm = statdata[:,1]
    all_ad = statdata[:,3]
    print (all_ks)
    print (all_cvm)
    print (all_ad)
    
    metric_data = np.array([all_ks, all_cvm, all_ad])
    fig = metric_plot(codes, metric_data)
    #plt.xticks(np.arange(len(codes)),codes,rotation=45)
    #plt.title("PIT Stats")
    plt.savefig("KSvsCvMvsAD_PIT_withnull.pdf", format='pdf')

In [8]:
pitmain()

[ 0.02    0.0388  0.0876  0.0723  0.024   0.0449  0.0663  0.0438  0.0795
  0.0747  0.1138  0.0047]
[  5.22500000e+01   2.80790000e+02   1.07517000e+03   1.10558000e+03
   6.88300000e+01   2.58560000e+02   4.73050000e+02   2.98560000e+02
   1.01111000e+03   7.63000000e+02   1.80174000e+03   1.16000000e+00]
[  7.59200000e+02   1.55750000e+03   6.16750000e+03   4.41860000e+03
   4.78800000e+02   1.41480000e+03   3.83800000e+02   7.15500000e+02
   6.30750000e+03   4.21640000e+03   1.05657000e+04   7.70000000e+00]
plotted KS: [ 0.02    0.0388  0.0876  0.0723  0.024   0.0449  0.0663  0.0438  0.0795
  0.0747  0.1138  0.0047]
plotted CvM: [  5.22500000e+01   2.80790000e+02   1.07517000e+03   1.10558000e+03
   6.88300000e+01   2.58560000e+02   4.73050000e+02   2.98560000e+02
   1.01111000e+03   7.63000000e+02   1.80174000e+03   1.16000000e+00]
plotted AD: [  7.59200000e+02   1.55750000e+03   6.16750000e+03   4.41860000e+03
   4.78800000e+02   1.41480000e+03   3.83800000e+02   7.15500000e+02
   

  (prop.get_family(), self.defaultFamily[fontext]))


In [9]:
def nzmain():
    codes = ("ANNZ2","BPZ","DELIGHT","EAZY","FLEXZ","GPZ","LEPHARE","METAPHOR","CMNN","SKYNET","TPZ","TRAINZ")
    labels = ("ANNz2","BPZ","Delight","EAZY","FlexZBoost","GPz","LePhare","METAPhoR","NN","SkyNet","TPZ","TrainZ")
    labeldict = dict(zip(codes,labels))

    statdata = readTable("NZtabledata_withnull.dat")
    
    all_ks = statdata[:,0]
    all_cvm = statdata[:,1]
    all_ad = statdata[:,3]
    print (all_ks)
    print (all_cvm)
    print (all_ad)
    
    metric_data = np.array([all_ks, all_cvm, all_ad])
    fig = metric_plot(codes, metric_data)
    #plt.title("N(z) Stats")
    #plt.xticks(np.arange(len(codes)),codes,rotation=45)
   
    plt.savefig("KSvsCvMvsAD_NZ_withnull.pdf", format='pdf')

In [10]:
nzmain()

[ 0.0237  0.0131  0.0256  0.0441  0.0128  0.0168  0.0254  0.035   0.0043
  0.0483  0.0126  0.0048]
[  25.676   17.587   42.341  169.977    7.007   25.605   58.9     60.79
    0.538  385.83     9.483    1.161]
[  276.    170.9   253.2   833.3   126.9   307.1   472.    672.3     5.8
  2110.3   110.6     8.3]
plotted KS: [ 0.0237  0.0131  0.0256  0.0441  0.0128  0.0168  0.0254  0.035   0.0043
  0.0483  0.0126  0.0048]
plotted CvM: [  25.676   17.587   42.341  169.977    7.007   25.605   58.9     60.79
    0.538  385.83     9.483    1.161]
plotted AD: [  276.    170.9   253.2   833.3   126.9   307.1   472.    672.3     5.8
  2110.3   110.6     8.3]
