In [0]:
#-------------------------------------------------------
# demo code to display the results of rv accuracy determinations
# at different noise levels. Data are generated by the
# radial-velocity-estimation-09.ipynb notebook.
#
# Alex Szalay, 2021-09-27
#-------------------------------------------------------
import numpy as np
import scipy as sp
import os
import collections
import copy
import time

import matplotlib.pyplot as plt
from scipy import stats
from numpy import linalg
%matplotlib inline

#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# set working directory (change it to your own)
# to the location of this notebook.
# The relative paths used further down (./figs, ./1D, ./2D)
# are resolved against this directory.
#!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
os.chdir('/home/idies/workspace/Storage/szalay/persistent/PFS/RV')


In [0]:
#------------------------------------------------------
# load the table of grid parameters (header row skipped)
#------------------------------------------------------
DATAPATH = '/home/idies/workspace/pfs/'
para = np.genfromtxt(DATAPATH+'para.csv', delimiter=',', skip_header=1)

# columns 5..9 of the table, in file order: M/H, Teff, logG, C/H, ALP
MH, Teff, logG, CH, AH = (para[:,k] for k in range(5,10))

# distinct values of each parameter
uM, uT, uG, uC, uA = (np.unique(col) for col in (MH, Teff, logG, CH, AH))

# debugging aid, switched off: list the distinct grid values
if False:
    print('Teff:',uT)
    print('logG:',uG)
    print('M/H :',uM)
    print('ALP :',uA)
    print('C/H :',uC)

#--------------------------------------
# read the raw Mauna Kea sky spectrum
#--------------------------------------
sky = np.genfromtxt(DATAPATH+'skybg_50_10.csv', delimiter=',')
# first column is scaled by 10 in place
# (presumably a wavelength unit conversion, e.g. nm -> Angstrom; TODO confirm)
sky[:,0] *= 10

#-----------------------------------------------
# make sure the subdirectory for figures exists
#-----------------------------------------------
if not os.path.isdir('./figs'):
    os.mkdir('./figs')
    

def isValidGrid(m,c,a,t,g):
    """Tell whether (m, c, a, t, g) is a point of the parameter grid.

    The values are compared for exact equality against the module-level
    columns MH, CH, AH, Teff and logG.

    Returns a true value when at least one row matches all five values.
    """
    matches = (Teff==t) & (logG==g) & (MH==m) & (CH==c) & (AH==a)
    return matches.any()


In [0]:
#-------------------------------------------------------------
# string formatting to compose filenames from the parameters
#-------------------------------------------------------------

def fmt(x):
    """Encode a signed parameter as 'p' or 'm' followed by |10*x| as two digits.

    The magnitude |10*x| is shifted by 0.2 before rounding, so half-steps go
    up (2.5 -> 3, 7.5 -> 8). Negative values get the prefix 'm', everything
    else 'p'.  Example: 0.25 -> 'p03', -0.5 -> 'm05'.
    """
    tenths = np.round(np.abs(10*x) + 0.2).astype(np.int32)
    prefix = 'm' if np.sign(x) < 0 else 'p'
    # the formatted magnitude always starts with '+'; swap it for the prefix
    digits = '{:+03.0f}'.format(tenths)[1:]
    return prefix + digits

def fmn(x):
    """Return floor(x) as an integer string, e.g. 4250.0 -> '4250'."""
    whole = np.floor(x)
    return str(whole.astype(np.int32))

def ffm(x):
    """Format x with an explicit sign and two decimals, e.g. 1.5 -> '+1.50'."""
    return format(x, '+4.2f')
    
def getfilename(m,c,a,t,g,R):
    """Compose the long filename used on the website.

    m, c and a are encoded with fmt(), t with fmn(), g as fmn(10*g) and
    R with the integer format '{:d}'.
    """
    pieces = [
        'am', fmt(m),
        'c',  fmt(c),
        'o',  fmt(a),
        't',  fmn(t),
        'g',  fmn(10*g),
        'v20modrt0b', '{:d}'.format(R),
        'rs.asc.bz2',
    ]
    return ''.join(pieces)

def getname(m,c,a,t,g):
    """Compose the short name of a spectrum, e.g. 'T6000G25Mm05Ap03Cp00'.

    The layout is T<t> G<10*g> M<m> A<a> C<c>, with t and 10*g encoded by
    fmn() and m, a, c encoded by fmt().
    """
    return ''.join((
        'T', fmn(t),
        'G', fmn(10*g),
        'M', fmt(m),
        'A', fmt(a),
        'C', fmt(c),
    ))

def decodeS(s):
    """Decode a short spectrum name back into its parameters.

    Parameters
    ----------
    s : str
        Name laid out as T<digits>G<digits>M<p|m><digits>A<p|m><digits>
        C<p|m><digits>, e.g. 'T6000G25Mm05Ap03Cp00'.

    Returns
    -------
    tuple (T, G, M, A, C) of floats. G, M, A and C are the encoded integers
    divided by 10.

    Raises
    ------
    ValueError
        If `s` does not follow the layout above.
    """
    import re

    # Fields are located by their letter markers rather than by fixed
    # column positions, so names with a 5-digit T field or a 1-digit G
    # field are decoded correctly as well.
    match = re.fullmatch(
        r'T(\d+)G(-?\d+)M([pm]\d+)A([pm]\d+)C([pm]\d+)', s)
    if match is None:
        raise ValueError('cannot decode spectrum name: {!r}'.format(s))

    def tenths(field):
        # 'p'/'m' stand for the sign; '03' and '08' are mapped back to
        # 2.5 and 7.5 (the half-steps that fmt() rounds up to 3 and 8).
        text = field.replace('m','-').replace('p','+')
        text = text.replace('03','02.5').replace('08','07.5')
        return float(text)/10

    T = float(match.group(1))
    G = float(match.group(2))/10
    M = tenths(match.group(3))
    A = tenths(match.group(4))
    C = tenths(match.group(5))
    return (T,G,M,A,C)

In [0]:
#----------------------------------------------
# build figure to show the rv error vs S/N
#----------------------------------------------
def showMSE(qr, sname, tname, prt=0):
    """Plot sqrt(MSE) of the radial velocity against S/N on log-log axes.

    Parameters
    ----------
    qr : list of 2-D arrays, or a single 2-D array
        One table per curve. Column 2 goes on the x-axis (labelled S/N) and
        the square root of column 4 on the y-axis.
    sname : list of str, or str
        Legend label(s), one per table in `qr`.
    tname : str
        Used in the title and the output file name. The curve whose label
        equals `tname` is drawn with a thicker line.
    prt : int, optional
        Unused; kept so existing calls remain valid.

    The figure is written to figs/mse-<tname>.png and then shown.
    """
    # A single table is handled as a one-element list, so both input forms
    # share the same plotting code and the line width is always defined.
    if isinstance(qr, list):
        tables, labels = qr, sname
    else:
        tables = [qr]
        labels = [sname] if isinstance(sname, str) else sname

    plt.figure(figsize=(8,8))
    for n, table in enumerate(tables):
        q = np.array(table)
        lw = 2 if labels[n]==tname else 0.75
        plt.loglog(q[:,2],np.sqrt(q[:,4]),'o-',lw=lw)

    plt.xlabel('S/N')
    plt.ylabel('sqrt(MSE(rv)) [km/s]')
    plt.title('('+tname+')')
    plt.xticks([10,20,50,100,200],['10','20','50','100','200'])
    plt.grid(which='both')

    # Force the y-range to span exactly two decades, centred (in log space)
    # on the automatically chosen range. Skipped when the automatic limits
    # are not positive, which happens when nothing was plotted.
    ylim = plt.ylim()
    if ylim[0] > 0 and ylim[1] > 0:
        r = np.sqrt(100*ylim[0]/ylim[1])
        plt.ylim((ylim[0]/r,100*ylim[0]/r))
    plt.xlim((5,500))

    plt.legend(labels)

    plt.savefig('figs/mse-'+tname+'.png',dpi=300)
    plt.show()

def showRVHisto(a):
    """Plot one histogram per noise level for a single template.

    Parameters
    ----------
    a : 2-D array
        Column 1 holds the noise level, column 2 is averaged per noise level
        (the value used as S/N), and column 3 holds the values that are
        histogrammed (labelled delta_rv by the caller).

    Returns
    -------
    1-D float array with the mean of column 2 for each distinct noise level,
    in increasing order of noise level.

    The curves are drawn on the current matplotlib figure. Each histogram is
    shifted upwards by its mean S/N so the noise levels do not overlap.
    """
    uNL = np.unique(a[:,1])
    bins = np.linspace(-1,1,51)
    bb = bins[:-1]          # left edges of the 50 bins
    clr = ['b','g','r','m','c','orange','violet']
    SN = []
    for i, nl in enumerate(uNL):
        #--------------------------------
        # go through each noise level
        #--------------------------------
        aa = a[a[:,1]==nl]
        sn = np.mean(aa[:,2])
        SN.append(sn)
        hh,_ = np.histogram(aa[:,3],bins)
        # cycle through the palette so that more noise levels than colors
        # do not raise an IndexError
        plt.plot(bb,hh+sn,color=clr[i % len(clr)], lw=0.5)
    return np.array(SN, dtype=float)



def showBias(A, sname, tname, prt=0):
    """Plot the histograms of all templates in one figure.

    Parameters
    ----------
    A : sequence of 2-D arrays
        One table per template, each passed on to showRVHisto().
    sname : unused
    tname : str
        Used in the title and the output file name.
    prt : unused

    The legend shows, for each noise level, the S/N averaged over all
    templates. The figure is written to figs/bias-<tname>.png and shown.
    """
    plt.figure(figsize=(8,8))

    # one row of S/N values per template, one column per noise level
    per_template = np.array([showRVHisto(table) for table in A])
    mean_sn = np.mean(per_template, axis=0)

    labels = [str(np.round(value)) for value in mean_sn]
    plt.legend(labels, loc='upper right')

    plt.grid()
    plt.xlabel('delta_rv[km/s]')
    plt.ylabel('freq')
    plt.title('('+tname+')')
    plt.xlim((-1,1))

    plt.savefig('figs/bias-'+tname+'.png', dpi=300)

    plt.show()
    

In [0]:
def getRVStats(code, sname):
    """Read the 1D results of one experiment and plot MSE and bias figures.

    Parameters
    ----------
    code : str
        Experiment code; together with `sname` it selects the directory
        ./1D/<code><sname>/.
    sname : str
        Short name of the reference spectrum. It is also used as the title
        and file-name tag of the two figures.

    Returns
    -------
    (R, S, A) : three lists of equal length with, per template that could be
    read, the contents of <template>-SN.csv, the template name, and the
    contents of <template>-A.csv.
    """
    tpath = './1D/'+code+sname+'/'
    #--------------------------------------------------
    # get the output from all the distinct simulations
    #--------------------------------------------------
    S = []
    R = []
    A = []
    P = np.genfromtxt(tpath+'parameters.csv', delimiter=',')
    # a file with a single template comes back 1-D; keep it as one row
    if P.ndim == 1 and P.size > 0:
        P = P[np.newaxis, :]
    #----------------------------
    # go through each template
    #----------------------------
    for pmt in P:
        g = getname(*pmt)
        try:
            a = np.genfromtxt(tpath+g+'-A.csv',delimiter=',')
            r = np.genfromtxt(tpath+g+'-SN.csv',delimiter=',')
        except (OSError, ValueError) as err:
            # missing or unreadable output: skip this template but say so
            print('skipping '+g+': '+str(err))
            continue
        S.append(g)
        R.append(r)
        A.append(a)

    # use this function's own argument for the figure tag instead of
    # depending on a global variable that happens to be called `name`
    showMSE(R,S,sname)
    showBias(A,S,sname)

    return R,S,A


In [0]:
def getLLHStats(code, sname):
    """Read the data for the 2D LLH contours of one experiment.

    Parameters
    ----------
    code : str
        Experiment code.
    sname : str
        Short name of the spectrum.

    Returns
    -------
    (A, P) : the contents of ./2D/<code>-<sname>-map.csv and of
    ./2D/<code>-<sname>-parameters.csv as arrays, or (None, None) when
    either file cannot be read. Both outcomes are 2-tuples, so callers can
    always unpack two values.
    """
    tpath = './2D/'+code+'-'+sname
    #--------------------------------------------------
    # get the output from all the distinct simulations
    #--------------------------------------------------
    try:
        P = np.genfromtxt(tpath+'-parameters.csv', delimiter=',')
        A = np.genfromtxt(tpath+'-map.csv',delimiter=',')
    except (OSError, ValueError) as err:
        print(tpath+'-parameters.csv / -map.csv could not be read: '+str(err))
        return None, None

    return A,P

def showOneContour(a,p,sn,ax):
    """Draw one filled contour map on `ax`.

    Parameters
    ----------
    a : 1-D array with n*n map values
    p : 2-D array with n*n rows of parameters, exactly two columns of which
        vary. The columns are labelled using the list `ps` below
        (NOTE(review): confirm this column order against the files in ./2D).
    sn : str
        Title of the panel.
    ax : matplotlib axes to draw on

    Returns
    -------
    The object returned by ax.contourf(), e.g. for attaching a colorbar.

    Raises
    ------
    ValueError
        If fewer than two columns of `p` vary.
    """
    # Import the submodule explicitly: `import scipy as sp` alone does not
    # guarantee that sp.ndimage is available on every SciPy version.
    from scipy import ndimage

    #---------------------------------------------------
    # get the columns which change
    # the temperature [3] is always the x-axis
    #---------------------------------------------------
    ps = ['M/H','Alpha','C/H','Teff','log g']
    vs = np.where(np.std(p,axis=0)>0)[0]
    if vs.size < 2:
        raise ValueError('showOneContour needs two varying parameter columns')
    if (vs[1]==3):
        vs = vs[::-1]
    #----------------------
    # get the 2D arrays
    #---------------------
    n = int(np.sqrt(a.shape[0]))
    v = a.reshape((n,n))
    y  = p[:,vs[1]].reshape((n,n))
    x  = p[:,vs[0]].reshape((n,n))
    #------------------------
    # get the center point
    #------------------------
    m = int((n-1)/2)
    xc = x[m,m]
    yc = y[m,m]
    #------------------------------------------------
    # smooth the contours: upsample by a factor of 8
    #------------------------------------------------
    X = ndimage.zoom(x, 8)
    Y = ndimage.zoom(y, 8)
    V = ndimage.zoom(v, 8)
    #---------------------------
    # plot it all
    #---------------------------
    im = ax.contourf(X,Y,V,16)
    ax.plot(xc,yc,'k*')      # mark the center of the grid
    ax.set_title(sn)
    ax.set_xlabel(ps[vs[0]])
    ax.set_ylabel(ps[vs[1]])
    return im
    
def showContours(code, name):
    """Plot the two contour panels of one experiment and save the figure.

    Parameters
    ----------
    code : str
        Experiment code; its first letter selects the S/N
        ('P','Q','R','S','Z' -> 10, 20, 30, 50, 100).
    name : str
        Short name of the spectrum.

    Returns
    -------
    (a, p) : the map values and the parameters as arrays, or (None, None)
    when the data could not be read.

    Raises
    ------
    ValueError
        If the first letter of `code` is not one of the known S/N letters.

    The first half of the rows goes into the left panel and the second half
    into the right panel. The figure is written to
    figs/<code>-<name>-<S/N>.png and then shown.
    """
    #--------------
    # get the S/N
    #--------------
    sn_by_letter = dict(zip(['P','Q','R','S','Z'], [10,20,30,50,100]))
    if code[0] not in sn_by_letter:
        raise ValueError('unknown S/N code letter: {!r}'.format(code[0]))
    SN = sn_by_letter[code[0]]

    #----------------
    # get the data
    #----------------
    # only the first two returned items are used, whatever the tuple length
    stats = getLLHStats(code,name)
    A, P = stats[0], stats[1]
    if A is None or P is None:
        print('no contour data for '+code+'-'+name)
        return None, None

    a = np.array(A)
    p = np.array(P)
    # integer division replaces np.int(), which newer NumPy no longer has
    n2 = a.shape[0] // 2

    #--------------------
    # build the labels
    #--------------------
    sn = name+'  S/N= {:.0f}'.format(SN)
    fname = code+'-'+name+'-{:.0f}'.format(SN)

    #-----------------
    # do the plots
    #-----------------
    fig, axs = plt.subplots(1, 2, figsize=(18,6.5))

    #--------------
    # Teff - log g
    #--------------
    im = showOneContour(a[:n2],p[:n2],sn,axs[0])
    fig.colorbar(im, ax=axs[0])

    #--------------
    #  Teff - M/H
    #--------------
    im = showOneContour(a[n2:],p[n2:],sn,axs[1])
    fig.colorbar(im, ax=axs[1])

    plt.savefig('figs/'+fname+'.png',dpi=300)

    plt.show()
    return a,p


In [0]:
######################################################
# Generate all the contour plots.
# Needs data from the ./2D subdirectory.
# The different S/N cases are coded with the
# letters ['P','Q','R','S','Z'].
# Figures are written to the ./figs subdirectory.
######################################################
name = 'T6000G25Mm05Ap03Cp00'
carr = ['P', 'Q', 'R', 'S', 'Z']

# one figure per S/N letter; A and P keep the data of the last one
for c in carr:
    A, P = showContours(c, name)


In [0]:
###########################################################################
# Generate all the MSE and BIAS plots.
# Needs data from the ./1D subdirectory.
# The example shows the results from the experiment with the code 'W'.
###########################################################################
names = [
    'T4250G10Mm15Ap03Cp00',
    'T4750G40Mm10Ap03Cp00',
    'T4750G25Mm20Ap03Cp00',
    'T6000G25Mm05Ap03Cp00',
    'T6500G40Mm10Ap03Cp00',
    'T8000G25Mm20Ap03Cp00',
]

# one pair of figures per reference spectrum;
# R, S and A keep the results of the last one
for name in names:
    R, S, A = getRVStats('W', name)


