# WD Database for Python

The goal is to download the following database: http://vizier.u-strasbg.fr/viz-bin/VizieR?-source=J%2FMNRAS%2F455%2F3413 for use in a machine-learning-inspired scheme.

In [1]:
#Preamble. Standard packages to load
import astropy
from astropy.table import Table, Column, MaskedColumn, vstack 
import numpy as np
from astroquery.vizier import Vizier
import matplotlib.pyplot as plt
import urllib2
# special IPython command to prepare the notebook for matplotlib
%matplotlib inline 
from astroquery.sdss import SDSS
from astropy import units as u
from astropy import coordinates as coords
from astropy.io import fits
import astropy.io.ascii as ascii
import os
import random



In [2]:
# Search VizieR three times: once by keyword and twice by catalog identifier.
catalog_list_1, catalog_list_2, catalog_list_3 = [
    Vizier.find_catalogs(query)
    for query in ('New white dwarf SDSS DR12',
                  'J/ApJS/204/5',
                  'J/MNRAS/446/4078')
]

In [None]:
# Show, for each search result, the mapping from catalog id to description.
for found in (catalog_list_1, catalog_list_2, catalog_list_3):
    descriptions = {}
    for key, entry in found.items():
        descriptions[key] = entry.description
    print(descriptions)

{u'J/MNRAS/455/3413': u'New white dwarf and subdwarf stars in SDSS DR12 (Kepler+, 2016)'}
{u'J/ApJS/204/5': u'SDSS DR7 white dwarf catalog (Kleinman+, 2013)'}
{u'J/MNRAS/446/4078': u'New white dwarf stars in SDSS DR10 (Kepler+, 2015)'}


In [None]:
# ROW_LIMIT = -1 lifts astroquery's default cap on returned rows, so the
# queries below retrieve the catalogs in full.
Vizier.ROW_LIMIT = -1
catalogs_1, catalogs_2, catalogs_3 = [
    Vizier.get_catalogs(found.keys())
    for found in (catalog_list_1, catalog_list_2, catalog_list_3)
]

In [None]:
# Print the summary of each query result.
for table_list in (catalogs_1, catalogs_2, catalogs_3):
    print(table_list)

In [None]:
# Stack the first table of each query result into one combined table.
first_tables = [result[0] for result in (catalogs_1, catalogs_2, catalogs_3)]
catalogs = vstack(first_tables)

In [None]:
# Display the combined table built by the vstack call.
catalogs

In [None]:
#This is a way to add coordinates if we need to. I don't think we need to right now.
#catalogs['Coordinates'] = coords.SkyCoord(catalogs['_RAJ2000'], catalogs['_DEJ2000'], frame='icrs')

In [None]:
#Here we do clean-up trying to merge those columns which were not properly merged
#because they were named different things in different catalogs. These include
#SDSS identifiers, a weird underscore for a log(g) parameter, different ways of
#specifying spectral type, and different ways of calibrating signal to noise.

def _is_masked(value):
    """Return True when *value* is numpy's masked-entry constant."""
    return isinstance(value, np.ma.core.MaskedConstant)


def _split_identifier(table, identifiers):
    """Split 'a-b-c' strings into the Plate, MJD and Fiber columns of *table*.

    *identifiers* is a column of dash-separated strings; masked entries are
    skipped so rows that lack this identifier keep their current values.
    """
    for ind, obj in enumerate(identifiers):
        if _is_masked(obj):
            continue
        parts = obj.split('-')
        table['Plate'][ind] = parts[0]
        table['MJD'][ind] = parts[1]
        table['Fiber'][ind] = parts[2]


def _fill_masked(table, target, source, companions=()):
    """Copy *source* values into rows where the *target* column is masked.

    *companions* is a sequence of (destination, origin) column-name pairs
    that are copied along with every row that gets filled.
    """
    for ind, obj in enumerate(source):
        if _is_masked(table[target][ind]) and not _is_masked(obj):
            table[target][ind] = obj
            for destination, origin in companions:
                table[destination][ind] = table[origin][ind]


# PMF is processed before PMJ, so a row carrying both ends up with the
# PMJ-derived values (same precedence as processing them one after the other).
PMF = catalogs['PMF']
_split_identifier(catalogs, PMF)

PMJ = catalogs['PMJ']
_split_identifier(catalogs, PMJ)

# log g and its uncertainty travel together.
log_g_ah = catalogs['log_g_']
_fill_masked(catalogs, 'logg', log_g_ah, companions=(('e_logg', 'e_log_g_'),))

Types = catalogs['SpType']
_fill_masked(catalogs, 'Type', Types)

SN = catalogs['SNg']
_fill_masked(catalogs, 'S_N', SN)

### Let's select a quality sample of WD spectra

In [None]:
WD = catalogs[catalogs['Type'] == 'DA']

# Every message below is formatted into a single string: this notebook uses
# print as a statement elsewhere, and there print("a", b) shows a tuple.
print("We start with %d WDs" % len(WD))

# First, we want to remove systems with NaN's - only found in log g
good_WD = WD[np.where(~np.isnan(WD['logg']))]
print("We removed %d systems with NaNs" % len(WD[np.isnan(WD['logg'])]))

# Now, we want to remove systems in which the log g was assumed. These all have e_logg=0.0
good_WD = good_WD[good_WD['e_logg'] != 0.0]
print("Number with determined log g %d" % len(good_WD))

# Next, we only want objects with a S/N above 10
good_WD = good_WD[good_WD['S_N']>10]
print("Number with S/N > 10 %d" % len(good_WD))

# Next, we want objects with log g uncertainties smaller than, say, 0.2
good_WD = good_WD[good_WD['e_logg']<0.2]
print("Number with log g error less than 0.2 %d" % len(good_WD))

# # Let's do the same with T_eff uncertainties - limit to 15% of T_eff
# good_WD = good_WD[good_WD['e_Teff']<0.15*good_WD['Teff']]
# print("Number with Teff uncertainties less than 15%",len(good_WD))

# Print the median Teff error
print("Median T_eff error: %s" % (np.median(good_WD['e_Teff']),))

# Print the median log g error
print("Median log g error: %s" % (np.median(good_WD['e_logg']),))

In [None]:
def download_data(cat):
    """Download the SDSS spectrum for every row of *cat* into ``data/``.

    A new ``'file'`` column is added to *cat*. For each row the spectrum is
    looked up by the row's Plate, MJD and Fiber values; when the target file
    already exists on disk it is reused, otherwise it is downloaded and
    written. Rows whose lookup or download fails are reported with a printed
    message and their ``'file'`` entry is left unset.

    Parameters
    ----------
    cat : table with ``'Plate'``, ``'MJD'`` and ``'Fiber'`` columns.
        Modified in place.
    """
    directory = 'data/'
    # Writing a spectrum fails if the target directory does not exist yet.
    if not os.path.isdir(directory):
        os.makedirs(directory)
    cat['file'] = MaskedColumn(length=len(cat),dtype='S32')
    for ind,plate in enumerate(cat['Plate']):
        # All identifiers come from `cat` itself so they stay aligned with `ind`.
        mjd = cat['MJD'][ind]
        fiber = cat['Fiber'][ind]
        try:
            spec = SDSS.get_spectra_async(plate=plate, mjd=mjd, fiberID=fiber)
            # The file name is the last path component of the URL found in
            # the text form of the first async result.
            url_of_interest = str(spec[0]).split()[4]
            filename = directory+url_of_interest.split('/')[-1]
        except Exception:
            print("No spectra found in database: %s %s %s" % (plate, mjd, fiber))
            # Without a file name there is nothing to check or download.
            continue
        if os.path.exists(filename):
            cat['file'][ind] = filename
            continue
        try:
            spec = SDSS.get_spectra(plate=plate, mjd=mjd, fiberID=fiber)
            spec[0].writeto(filename)
        except Exception:
            print("Could not download spectra: %s %s %s" % (plate, mjd, fiber))
            continue
        cat['file'][ind] = filename

In [None]:
# Fetch one spectrum directly and print the first returned object.
spec = SDSS.get_spectra(plate=6679,mjd=56401,fiberID=756)

# Call form of print, matching the other cells; works as statement or function.
print(spec[0])

In [None]:
# Scratch cell: prints the literal 1 (presumably a check that the kernel responds).
print(1)

In [None]:
# Fetch spectra for the quality-selected sample; this adds a 'file' column to good_WD.
download_data(good_WD)

In [None]:
def get_filename(plate,mjd,fiber,wd):
    """Return the spectrum file name stored in *wd* for one observation.

    The table is narrowed in turn by its ``'Plate'``, ``'MJD'`` and
    ``'Fiber'`` columns (the column names used by the rest of this file).
    When a step leaves no rows, a short message naming that step is printed
    and ``''`` is returned. Otherwise the ``'file'`` entry of the first
    matching row is returned as a ``str``.

    A table lacking one of these columns raises the underlying lookup error
    instead of being reported as a missing plate/mjd/fiber.

    Parameters
    ----------
    plate, mjd, fiber : values compared for equality against the columns.
    wd : table with ``'Plate'``, ``'MJD'``, ``'Fiber'`` and ``'file'`` columns.
    """
    criteria = (('Plate', plate, 'No plate number'),
                ('MJD', mjd, 'No mjd date'),
                ('Fiber', fiber, 'No fiber number'))
    matches = wd
    for column, wanted, complaint in criteria:
        matches = matches[matches[column] == wanted]
        if len(matches) == 0:
            print(complaint)
            return ''
    return str(matches['file'][0])

In [None]:
def plot_spec(plate,mjd,fiber,wd):
    """Plot the spectrum of one observation, full range and a zoomed view.

    The file name is resolved with ``get_filename(plate, mjd, fiber, wd)``.
    The ``'loglam'`` and ``'flux'`` columns are read from the second HDU of
    that FITS file; ``10**loglam`` is used as the x axis. Two panels are
    drawn side by side: the whole spectrum, and the same curve with the
    x range limited to 3800-4400.
    """
    # The context manager closes the FITS file once plotting is done; the
    # plot is produced inside it so the data is still accessible.
    with fits.open(get_filename(plate,mjd,fiber,wd)) as fits_spec:
        wavelength = 10**fits_spec[1].data['loglam']
        flux = fits_spec[1].data['flux']
        fig, ax = plt.subplots(1, 2, figsize=(12,4))
        ax[0].plot(wavelength, flux)
        ax[1].plot(wavelength, flux)
        ax[1].set_xlim(3800, 4400)
        plt.show()

In [None]:
# Plot the spectrum with plate 337, MJD 51997, fiber 195.
# NOTE(review): get_filename reads a 'file' column from the table passed here,
# while the visible download_data call was made on good_WD - confirm that WD
# carries that column.
plot_spec(337,51997,195,WD)

In [None]:
# Histogram of the S_N column of the WD table.
plt.hist(WD['S_N'])
plt.xlabel('Signal to Noise')

### Create training, test, and validation sets

In [None]:
# Random permutation of the row indices of good_WD
indices = np.random.permutation(len(good_WD))
good_shuffle_WD = good_WD[indices]

# First 300 shuffled rows -> validation, next 300 -> test, remainder -> training
n_validation, n_test = 300, 300
validation_WD = good_shuffle_WD[:n_validation]
test_WD = good_shuffle_WD[n_validation:n_validation + n_test]
training_WD = good_shuffle_WD[n_validation + n_test:]

### Plot up systems in T_eff and log g space to see where they lie

In [None]:
# Overlay the three subsets in the (log g, T_eff) plane with error bars only,
# drawn in the order train, test, val.
for subset, tag in ((training_WD, 'train'),
                    (test_WD, 'test'),
                    (validation_WD, 'val')):
    plt.errorbar(subset['logg'], subset['Teff'],
                 xerr=subset['e_logg'], yerr=subset['e_Teff'],
                 ls='none', fmt='', capsize=0, label=tag)
plt.legend()

# Axis labels
plt.ylabel(r'T$_{\rm eff}$')
plt.xlabel(r'Log $g$')

# Logarithmic temperature axis with a fixed range
plt.yscale('log')
plt.ylim(5.0e3, 1.0e5)

plt.show()

Now we will look at just the DAs.

In [None]:
# List the distinct values present in the SpType column of the selected sample.
set(good_WD['SpType'])

In [None]:
# Keep the rows whose SpType entry is exactly 'DA'.
# NOTE(review): the earlier clean-up fills 'Type' from 'SpType', not the other
# way round - confirm that 'SpType' (rather than 'Type') is the intended column.
DA_good = good_WD[good_WD['SpType']=='DA']

# One formatted string, so the message reads the same whether print is a
# statement or a function.
print("Number of DAs in sample %d" % len(DA_good))

In [None]:
# Error-bar-only scatter of the DA sample: T_eff on x, log g on y.
plt.errorbar(DA_good['Teff'], DA_good['logg'],
             xerr=DA_good['e_Teff'], yerr=DA_good['e_logg'],
             capsize=0, fmt='', ls='none')

# Logarithmic temperature axis with a fixed range
plt.xscale('log')
plt.xlim(5.0e3, 1.0e5)

plt.xlabel(r'T$_{\rm eff}$')
plt.ylabel(r'Log $g$')

plt.show()

In [None]:
# Columns written to DA_good.csv; all other columns of DA_good are left out.
export_columns = [
    '_RAJ2000', '_DEJ2000', 'SDSS', 'S_N',
    'umag', 'e_umag', 'gmag', 'e_gmag', 'rmag', 'e_rmag',
    'imag', 'e_imag', 'zmag', 'e_zmag', 'E_B-V_', 'pm',
]
ascii.write(DA_good, 'DA_good.csv', format='csv', include_names=export_columns)

In [None]:
# Row positions whose SpType string contains an 'A'.
ind = [num for num, thing in enumerate(good_WD['SpType']) if 'A' in thing]

All_A = good_WD[ind]
# One formatted string, so the message reads the same whether print is a
# statement or a function.
print("Number of As in sample %d" % len(All_A))

In [None]:
# Display the table of rows whose SpType contains an 'A'.
All_A

In [None]:
# Error-bar-only scatter of the 'A' sample in red: T_eff on x, log g on y.
plt.errorbar(All_A['Teff'], All_A['logg'],
             xerr=All_A['e_Teff'], yerr=All_A['e_logg'],
             color='red', capsize=0, fmt='', ls='none')

# Logarithmic temperature axis with a fixed range
plt.xscale('log')
plt.xlim(5.0e3, 1.0e5)

plt.xlabel(r'T$_{\rm eff}$')
plt.ylabel(r'Log $g$')

plt.show()