In [9]:
%matplotlib qt5
import os
import numpy as np
import pandas as pd
import sqlite3
import lsst.daf.persistence as dafPersist
import lsst.afw.display as afwDisplay
import lsst.geom
import matplotlib.ticker as plticker
from astropy.visualization import (ZScaleInterval, SqrtStretch, ImageNormalize)
import astropy.units as u

import matplotlib.pyplot as plt
import matplotlib
import astropy

In [10]:
(np.__version__, matplotlib.__version__, astropy.__version__)

('1.14.5', '2.2.2', '3.0.3')

In [11]:
cwpRepo = '/home/gkovacs/data/repo_DM-17825/ingested/rerun/proc_2019-02-21'
cwpTemplateRepo = '/home/gkovacs/data/repo_DM-17825/templates'
my_dbName = '/home/gkovacs/data/repo_DM-17825/ingested/rerun/proc_2019-02-21/association.db'
mrawls_dbName = '/home/gkovacs/data/repo_DM-17825/mrawls_cw_processed2/association.db'

butlerCwp = dafPersist.Butler(cwpRepo)
butlerCwpTemplate = dafPersist.Butler(cwpTemplateRepo)

In [3]:
def loadAllPpdbObjects(repo, dbName='association.db'):
    """Load select DIAObject columns from a PPDB into a pandas dataframe.

    Parameters
    ----------
    repo : `str`
        Path to an output repository from an ap_pipe run.
    dbName : `str`, optional
        Name of the PPDB, which must reside in (or relative to) repo.

    Returns
    -------
    objTable : `pandas.DataFrame`
        DIA Object Table containing only objects with validityEnd NULL.
        Columns selected are presently hard-wired here.
    """
    connection = sqlite3.connect(os.path.join(repo, dbName))

    # These are the tables available in the ppdb
    tables = {'obj': 'DiaObject', 'src': 'DiaSource', 'ccd': 'CcdVisit'}

    # Only get objects with validityEnd NULL because that means they are still valid
    objTable = pd.read_sql_query('select diaObjectId, ra, decl, nDiaSources, \
                                  gPSFluxMean, validityEnd, flags from {0} \
                                  where validityEnd is NULL;'.format(tables['obj']), connection)
    return objTable
# ---
def defMiniRegion(objTable):
    miniRegion = ((objTable['decl'] < -5.6) & (objTable['decl'] > -5.8) & 
               (objTable['ra'] > 155.2) & (objTable['ra'] < 155.3) &
               (objTable['nDiaSources'] > 2))
    return miniRegion
# ---
def plotMiniRegion(objTable, miniRegion, title=None):
    print('Plotting {0} DIAObjects'.format(len(objTable.loc[miniRegion, 'ra'])))
    fig = plt.figure(figsize=(7,5))
    ax1 = fig.add_subplot(111)
    cb1 = ax1.scatter((objTable.loc[miniRegion, 'ra'].values*u.deg).to_value(u.rad),
                      (objTable.loc[miniRegion, 'decl'].values*u.deg).to_value(u.rad),
                      marker='.', lw=0, s=objTable.loc[miniRegion, 'nDiaSources']*8,
                      c=objTable.loc[miniRegion, 'flags'],  #c=objTable.loc[miniRegion, 'nDiaSources'],
                      alpha=0.5, )
                      #cmap=plt.cm.get_cmap('viridis'))
    binMax = np.max(objTable['nDiaSources'].values)
    #cbplot = plt.colorbar(cb1, ax=ax1)
    #cbplot.set_label('Number of DIASources')
    #cbplot.set_clim(0, binMax)
    #cbplot.solids.set_edgecolor("face")
    plt.xlabel('RA (rad)')
    plt.ylabel('Dec (rad)')
    #plt.xlim([155.3, 155.2])
    #plt.ylim([-5.8, -5.6])
    plt.xlim([2.71040, 2.70875])
    plt.ylim([-0.1014, -0.0978])
    if title:
        plt.title(title)
# ---
def load_sources(repo, obj, sqliteFile='association.db'):
    connection = sqlite3.connect(os.path.join(repo, sqliteFile))
    tables = {'obj': 'DiaObject', 'src': 'DiaSource', 'ccd': 'CcdVisit'}
    srcTable = pd.read_sql_query('select diaSourceId, diaObjectId, ccdVisitId, midPointTai, \
                                 apFlux, psFlux, apFluxErr, psFluxErr, totFlux, totFluxErr, flags \
                                 from {1} where diaObjectId = {0};'.format(obj, tables['src']), connection)
    connection.close()
    return(srcTable)
# ---
def plot_lightcurve(repo, templateRepo, obj, patch, objTable, 
                    useTotFlux=False, plotAllCutouts=False, cutoutIdx=0, labelCutouts=False,
                    diffimType='deepDiff_differenceExp'):
    sources = load_sources(repo, obj)
    ra = objTable.loc[objTable['diaObjectId'] == obj, 'ra']
    dec = objTable.loc[objTable['diaObjectId'] == obj, 'decl']
    flags = sources['flags']    
    dataIds = sources['ccdVisitId'].values  # these are ints
    dataIdDicts = []
    for dataId in dataIds:
        visit = int(str(dataId)[0:6])
        ccdnum = int(str(dataId)[6:])
        dataIdDict = {'visit': visit, 'ccdnum': ccdnum}
        dataIdDicts.append(dataIdDict)
    centerSource = lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees)
    size = lsst.geom.Extent2I(30, 30)
    
    print('DIAObject ID:', obj)
    #print('Flags:', flags)
    print('RA (deg):', ra.values)
    print('Dec (deg):', dec.values)
    print('DIASource IDs:', sources['diaSourceId'].values)
    print('Data IDs:', dataIdDicts)

    plt.figure()

    # light curve with psFlux by default (uses totFlux if useTotFlux=True)
    plt.subplot(212)
    plt.xlabel('Time (MJD)', size=16)
    if not useTotFlux:
        plt.errorbar(sources['midPointTai'], sources['psFlux']*1e9, yerr=sources['psFluxErr']*1e9, 
                     ls=':', marker='o', color='#2979C1')
        plt.ylabel('Difference Flux (nJy)', size=16)
    else:
        plt.errorbar(sources['midPointTai'], ources['totFlux']*1e9, yerr=sources['totFluxErr']*1e9,
                     ls=':', marker='o', color='#2979C1')
        plt.ylabel('Flux (nJy)', size=16)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)    
    
    # processed image
    plt.subplot(231)
    plt.gca().get_xaxis().set_ticks([])
    plt.gca().get_yaxis().set_ticks([])
    plt.title('Processed', size=16)
    butler = dafPersist.Butler(repo)
    calexpFirst = butler.get('calexp', dataIdDicts[cutoutIdx])
    calexpArray = calexpFirst.getCutout(centerSource, size).getMaskedImage().getImage().getArray()
    calexpNorm = ImageNormalize(calexpArray, interval=ZScaleInterval(), stretch=SqrtStretch())
    plt.imshow(np.rot90(np.fliplr(calexpArray)), cmap='gray', norm=calexpNorm)
    
    # template image
    plt.subplot(232)
    plt.gca().get_xaxis().set_ticks([])
    plt.gca().get_yaxis().set_ticks([])
    plt.title('Template', size=16)
    templateDataId = {'filter': 'g', 'tract': 0, 'patch': patch}
    butlerTemplate = dafPersist.Butler(templateRepo)
    template = butlerTemplate.get('deepCoadd', dataId=templateDataId)
    templateArray = template.getCutout(centerSource, size).getMaskedImage().getImage().getArray()
    templateNorm = ImageNormalize(templateArray, interval=ZScaleInterval(), stretch=SqrtStretch())
    plt.imshow(np.flipud(templateArray), cmap='gray', norm=templateNorm)
    #plt.imshow(np.fliplr(np.rot90(templateArray)), cmap='gray', norm=templateNorm)
    
    # difference image
    plt.subplot(233)
    plt.gca().get_xaxis().set_ticks([])
    plt.gca().get_yaxis().set_ticks([])
    plt.title('Difference', size=16)
    diffimFirst = butler.get(diffimType, dataIdDicts[cutoutIdx])
    diffimArray = diffimFirst.getCutout(centerSource, size).getMaskedImage().getImage().getArray()
    diffimNorm = ImageNormalize(diffimArray, interval=ZScaleInterval(), stretch=SqrtStretch())
    plt.imshow(np.rot90(np.fliplr(diffimArray)), cmap='gray', norm=diffimNorm)
    
    if plotAllCutouts:
        fig = plt.figure(figsize=(8,8))  # optional figure with cutouts for all visits
        fig.subplots_adjust(hspace=0, wspace=0)
        for idx, dataId in enumerate(dataIdDicts):
            calexp = butler.get('calexp', dataId)
            calexpArray = calexp.getCutout(centerSource, size).getMaskedImage().getImage().getArray()
            calexpNorm = ImageNormalize(calexpArray, interval=ZScaleInterval(), stretch=SqrtStretch())
            diffim = butler.get(diffimType, dataId)
            diffimArray = diffim.getCutout(centerSource, size).getMaskedImage().getImage().getArray()
            diffimNorm = ImageNormalize(diffimArray, interval=ZScaleInterval(), stretch=SqrtStretch())
            plt.subplot(10, 10, idx+1)
            plt.gca().get_xaxis().set_ticks([])
            plt.gca().get_yaxis().set_ticks([])
            plt.imshow(np.rot90(np.fliplr(calexpArray)), cmap='gray', norm=calexpNorm)
            if labelCutouts:
                if idx == 0:
                    plt.text(1, 26, 'Proc', color='lime', size=8)
                plt.text(2, 5, str(sources['midPointTai'][idx])[1:8], color='lime', size=8)            
            plt.subplot(10, 10, idx+50+1)
            plt.gca().get_xaxis().set_ticks([])
            plt.gca().get_yaxis().set_ticks([])
            plt.imshow(np.rot90(np.fliplr(diffimArray)), cmap='gray', norm=diffimNorm)
            if labelCutouts:
                if idx == 0:
                    plt.text(1, 26, 'Diff', color='lime', size=8)
                plt.text(2, 5, str(sources['midPointTai'][idx])[1:8], color='lime', size=8)

In [None]:
cwpObjTable = loadAllPpdbObjects(cwpRepo)


In [12]:
# Get the image, make the cutout and show in a bigger cutout
dataId={'visit': 411371, 'ccdnum': 57}
calexp=butlerCwp.get('calexp',dataId=dataId)

In [13]:
ra=155.27087585
decl=-5.68988946

In [16]:
image_center = lsst.geom.SpherePoint(ra,decl,lsst.geom.degrees)
cutout_extent = lsst.geom.Extent2I(50, 50)

In [17]:
exposure = butlerCwp.get('calexp',dataId=dataId)
cutout = exposure.getCutout(image_center,cutout_extent)

In [18]:
bbox= cutout.getBBox()
bbox.getCorners()
extent = (bbox.getBeginX(), bbox.getEndX(), bbox.getBeginY(), bbox.getEndY())
extentT = (bbox.getBeginY(), bbox.getEndY(), bbox.getBeginX(), bbox.getEndX())
extentR = (bbox.getEndY(), bbox.getBeginY(), bbox.getEndX(), bbox.getBeginX())
print(extentR)

(1496, 1446, 1090, 1040)


In [19]:
A = cutout.getMaskedImage().getImage().getArray()
img_norm = ImageNormalize(A,interval=ZScaleInterval(), stretch=SqrtStretch())

# Converting SkyWcs to astropy.wcs.WCS

SkyWcs has a fits metadata representation, that can be converted to a dictionary.

In [20]:
from astropy.wcs import WCS
from astropy.visualization.wcsaxes import WCSAxes

M = cutout.getWcs().getFitsMetadata()
W = WCS(M.toDict()).swapaxes(0,1)

In [21]:
W

WCS Keywords

Number of WCS axes: 2
CTYPE : 'DEC--TAN-SIP'  'RA---TAN-SIP'  
CRVAL : -5.686392881856413  155.2908557516832  
CRPIX : 1743.524956  1016.824682  
CD1_1 CD1_2  : -1.2485291485569e-07  -7.2894527279287e-05  
CD2_1 CD2_2  : 7.30983903374267e-05  -1.2922970767026e-08  
NAXIS : 0  0

In [22]:
fig = plt.figure()
WA = WCSAxes(fig,rect=(0.1,0.1,0.8,0.8), wcs=W)
ax = fig.add_axes(WA)

ax.imshow(A.T[::-1,::-1],cmap='gray',norm=img_norm,origin='lower',extent=extentR)
#ax.coords[0].set_color('blue')
#ax.coords[0].set_major_formatter('d.ddd')
ax.coords['ra'].set_major_formatter('d.ddd')
# It is confusing whether coords[0], coords[1] refer to
ax.coords[0].set_axislabel('RA')
ax.coords[1].set_axislabel('DEC')
ax.coords.grid(True,ls='dotted',color='blue')
#ax.coords['ra'].grid(color='blue')
ax.coords[0].set_ticklabel_position('l') 
#ax.coords[1].set_color('red')
ax.coords[1].set_ticklabel_position('b') 
#ax.coords['dec'].grid(color='red')

  np.sqrt(values, out=values)
  xa[xa < 0] = -1
  np.sqrt(values, out=values)
  xa[xa < 0] = -1


In [None]:
display = afwDisplay.getDisplay()

display.setMaskTransparency(60)
display.scale("asinh", "zscale")


mask = cutout.getMask()
for maskName, maskBit in mask.getMaskPlaneDict().items():
    print('{}: {}'.format(maskName, display.getMaskPlaneColor(maskName)))

In [None]:
display.mtv(cutout)

In [None]:
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.imshow(np.rot90(np.fliplr(A)),cmap='gray',norm=img_norm)

In [None]:
cwpObjTable = loadAllPpdbObjects(cwpRepo)
cwpMiniRegion = defMiniRegion(cwpObjTable)

In [None]:
my_conn = sqlite3.connect(my_dbName)
mr_conn = sqlite3.connect(mrawls_dbName)

In [None]:
cwpObjList = list(cwpObjTable.loc[cwpMiniRegion, 'diaObjectId'])
cwpMiniUnflagged = cwpMiniRegion & (cwpObjTable['flags'] == 0)
cwpObjMiniList = list(cwpObjTable.loc[cwpMiniUnflagged, 'diaObjectId'])

In [None]:

cwpObjMiniList = list(cwpObjTable.loc[cwpMiniUnflagged, 'diaObjectId'])


In [None]:
patchList = ['10,8', '11,8', '12,8', '13,8',
             '10,7', '11,7', '12,7', '13,7',
             '10,9', '11,9', '12,9', '13,9',
             '10,5', '11,5', '12,5', '13,5',
             '10,6', '11,6', '12,6', '13,6',
             '10,10', '11,10', '12,10', '13,10']

In [None]:


def patchFinder(obj, objTable, templateButler, patchList):
    for patch in patchList:
        ra = objTable.loc[objTable['diaObjectId'] == obj, 'ra']
        dec = objTable.loc[objTable['diaObjectId'] == obj, 'decl']
        centerSource = lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees)
        size = lsst.geom.Extent2I(30, 30)
        templateDataId = {'filter': 'g', 'tract': 0, 'patch': patch}
        templateImage = templateButler.get('deepCoadd', dataId=templateDataId)
        try:
            cutout = templateImage.getCutout(centerSource, size)
        except:
            continue
        else:
            templatePatch = patch
            #print('template patch:', templatePatch)
            #print('object id:', obj)
            return templatePatch
            break



In [None]:
patch = patchFinder(cwpObjList[0], cwpObjTable, butlerCwpTemplate, patchList)
cwpTemplate = butlerCwpTemplate.get('deepCoadd', dataId={'filter': 'g', 'tract': 0, 'patch': patch})

In [None]:
D = np.array([[1,2],[3,4]])

In [None]:
np.rot90(D,axes=(0,1))

In [None]:
?np.rot90

# =============

In [26]:
fig = plt.figure()
ax = fig.add_subplot(1,1,1,projection=W)
ra = ax.coords['ra']
print(ra.set_axislabel_visibility_rule)
ra.set_axislabel('RA')

ax.imshow(A,cmap='gray',norm=img_norm,origin='lower',extent=extent)
#ax.coords[0].set_color('blue')
ax.coords['ra'].set_ticklabel_position('bltr') 
ax.coords['ra'].grid(color='blue') 
#ax.coords[1].set_color('red')
ax.coords['dec'].set_ticklabel_position('bltr') 
ax.coords['dec'].grid(color='red')

<bound method CoordinateHelper.set_axislabel_visibility_rule of <astropy.visualization.wcsaxes.coordinate_helpers.CoordinateHelper object at 0x7fad0f25f2e8>>


  np.sqrt(values, out=values)
  xa[xa < 0] = -1
  np.sqrt(values, out=values)
  xa[xa < 0] = -1


In [19]:
%matplotlib qt5
from astropy.io import fits
from astropy.wcs import WCS
import matplotlib.pyplot as plt

hdu = fits.open('test.fits')[1]

wcs = WCS(hdu.header)

#plt.subplot(projection=wcs)
fig=plt.figure()
fig.add_subplot(1,1,1)
plt.imshow(hdu.data, origin='lower')
ax=plt.gca()
ax.coords['ra'].set_ticklabel_position('l')
ax.coords['ra'].set_axislabel('RA')


AttributeError: 'AxesSubplot' object has no attribute 'coords'

In [25]:
ax.get_coords_overlay('pixel')

ValueError: Unknown frame: pixel

In [21]:
wcs

WCS Keywords

Number of WCS axes: 2
CTYPE : 'RA---TAN-SIP'  'DEC--TAN-SIP'  
CRVAL : 155.290855751683  -5.68639288185641  
CRPIX : -33.175318  287.524956  
CD1_1 CD1_2  : -1.29229707670269e-08  7.30983903374267e-05  
CD2_1 CD2_2  : -7.28945272792879e-05  -1.24852914855695e-07  
NAXIS : 30  30

In [33]:
hdu = fits.open('test.fits')