In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import pickle
import warnings
import numpy as np
import pylab as plt

from astropy.io import fits
from natsort import natsorted
from astropy.time import Time

import sys
sys.path.append('/home/scratch/psalas/LASSI/lassi-analysis_v2')

from zernikies import getZernikeCoeffsOLS
from utils.utils import midPoint, stride, rolling_std
from lassiAnalysis import extractZernikesLeicaScanPair
from analyzeActiveSurface import processActiveSurfaceFITSPair

In [7]:
def zernikeOLS(x, y, z, nZern):
    """
    Fit Zernike coefficients to a deformation map using inverse-variance weights.

    The weight of each pixel is 1/std**2, where std is the scatter of ``z`` in a
    4x4 rolling window. Pixels whose local scatter is <= 20e-6, or whose weight
    is not finite, are masked out of the fit.

    Parameters
    ----------
    x, y : array_like
        Map coordinates; ``midPoint(x)`` / ``midPoint(y)`` is subtracted from
        each and non-finite entries are masked before fitting.
    z : numpy.ma.MaskedArray
        Deformation map; masked pixels are filled with NaN before the rolling
        statistics are computed.
    nZern : int
        Number of Zernike coefficients, forwarded to ``getZernikeCoeffsOLS``.

    Returns
    -------
    The coefficients returned by ``getZernikeCoeffsOLS`` for the weighted fit.
    """
    # Local scatter of the map in 4x4 windows.
    local_std = rolling_std(stride(z.filled(np.nan), r_c=(4, 4)), no_null=False)

    # NaN-pad 2 before / 1 after on both axes (presumably restoring the map's
    # shape after the 4x4 windowing -- confirm against stride()).
    local_std = np.pad(local_std, ((2, 1), (2, 1)), mode='constant', constant_values=np.nan)

    # Near-zero scatter would give huge 1/std**2 weights, so drop those pixels.
    local_std = np.ma.masked_where(local_std <= 20e-6, local_std)

    # Inverse-variance weights; anything non-finite is excluded from the fit.
    weights = np.ma.masked_invalid(np.power(local_std, -2.))

    x_centred = np.ma.masked_invalid(x - midPoint(x))
    y_centred = np.ma.masked_invalid(y - midPoint(y))

    return getZernikeCoeffsOLS(x_centred, y_centred, z, nZern, weights=weights)

def process(maskRadius, sigma, outputDict, signal_registration, signal_type, radialMask,
            fixed_reference=False, refScanFile=None):
    """
    Fit Zernike coefficients to LASSI scan pairs and pickle the results.

    Two modes, selected by ``signal_type``:

    * ``b"SIGNAL"``: every ``*.zernike.fits`` file in ``{fitsDir}/LASSI``
      (natural-sorted, first three dropped) defines a signal scan. The ScanLog
      row three entries before that scan is taken as the active-surface (AS)
      signal file, and the row six entries before the AS signal row as the AS
      reference file. The single non-zero Zernike in the AS signal file is the
      "input zernike". ``processActiveSurfaceFITSPair`` on the AS pair gives
      the "active surface zernike value". The LASSI scan pair is then fitted
      and the recovered coefficient is compared against both.
    * ``b"REF"``: every plain ``*.fits`` file in ``{fitsDir}/LASSI``
      (natural-sorted, slice ``[19:-12]``, ``zernike``/``smoothed`` products
      excluded) whose ``REFSCAN`` header keyword is non-zero is fitted against
      ``refScanFile``.

    This function reads these notebook globals: ``fitsDir``, ``scanDir``, ``n``,
    ``nZern``, ``guess``, ``scan0``, ``scanf``, ``sigma_clip_diff``,
    ``kernel_size`` and ``do_replace_nans``.
    NOTE(review): ``scan0``/``scanf`` are only written to the output header. The
    files actually processed are chosen by the hard-coded slices above; confirm
    that the two agree.

    Parameters
    ----------
    maskRadius : float
        Forwarded to ``extractZernikesLeicaScanPair`` as ``rMaskRadius``.
    sigma : float
        Forwarded to ``extractZernikesLeicaScanPair`` as ``sigma``.
    outputDict : str
        Path of the pickle file that receives the results dictionary.
    signal_registration : bool
        Only recorded in the output header.
    signal_type : bytes
        ``b"SIGNAL"`` or ``b"REF"``.
    radialMask : bool
        Forwarded to ``extractZernikesLeicaScanPair``.
    fixed_reference : bool, optional
        If True, every scan is differenced against ``refScanFile``. Otherwise
        (SIGNAL mode only), each scan is differenced against the scan named by
        its ``RSCANNUM`` header keyword.
    refScanFile : str, optional
        Reference ``.ptx.csv`` file used when ``fixed_reference`` is True.

    Raises
    ------
    ValueError
        Raised for any of the following:
        an unknown ``signal_type``;
        REF mode without ``fixed_reference``;
        ``fixed_reference`` without ``refScanFile``;
        an AS file that does not contain exactly one non-zero Zernike.
    """
    # Validate up front so a bad configuration does not fail after minutes of fitting.
    if signal_type not in (b"SIGNAL", b"REF"):
        raise ValueError("signal_type must be b'SIGNAL' or b'REF', got {0!r}".format(signal_type))
    if signal_type == b"REF" and not fixed_reference:
        # REF files are never looked up via RSCANNUM, so there is no per-scan reference.
        raise ValueError("signal_type b'REF' requires fixed_reference=True")
    if fixed_reference and refScanFile is None:
        raise ValueError("fixed_reference=True requires a refScanFile")

    # Use the .zernike.fits files to select signal scans.
    zern_files = natsorted(glob.glob("{0}/LASSI/*.zernike.fits".format(fitsDir)))[3:]

    # Use the plain fits files (no zernike/smoothed products) to select reference scans.
    fits_files = natsorted(glob.glob("{0}/LASSI/*.fits".format(fitsDir)))[19:-12]
    fits_files = [ff for ff in fits_files if "zernike" not in ff and "smoothed" not in ff]

    # Read the ScanLog fully into memory (memmap=False) so it stays valid after closing.
    with fits.open("{0}/ScanLog.fits".format(fitsDir), memmap=False) as hdul:
        scanArr = hdul[1].data

    zFitDict = {}
    zFitDict['header'] = {'scan0': scan0, 'scanf': scanf, 'nZern': nZern,
                          'signal_registration': signal_registration,
                          'sigma_clip_diff': sigma_clip_diff,
                          'kernel_size': kernel_size,
                          'do_replace_nans': do_replace_nans,
                          'guess': guess, 'maskRadius': maskRadius,
                          'fixed_reference': fixed_reference,
                          'reference_scan_file': refScanFile,
                          'signal_type': signal_type,
                          'sigma': sigma,
                          'radial_mask': radialMask}

    if signal_type == b"SIGNAL":
        nFiles = len(zern_files)
        # Builtin int/float: the np.int/np.float aliases were removed in NumPy 1.24.
        z_idx = np.zeros(nFiles, dtype=int)
        scans = np.zeros(nFiles, dtype=int)
        # The AS Zernike contents start at Z1 (nZern-1 terms); the fits start at Z0.
        z_as = np.zeros((nFiles, nZern-1), dtype=float)
        z_as_obs = np.zeros((nFiles, nZern), dtype=float)
        z_in_as = np.zeros(nFiles, dtype=float)

        for i, zf in enumerate(zern_files):

            with fits.open(zf) as hdul:
                scans[i] = hdul[0].header['MC_SCAN']

            # AS signal file: 3 ScanLog rows before the LASSI scan.
            # AS reference file: 6 rows before the AS signal row.
            idx = np.where(scanArr['SCAN'] == scans[i])[0][0]
            idx_as = idx - 3
            sig_scan = scanArr['SCAN'][idx_as]
            ref_scan = scanArr['SCAN'][idx_as-6]
            as_file_sig = '/'.join(scanArr[idx_as]['FILEPATH'].split('/')[-2:])
            as_file_ref = '/'.join(scanArr[idx_as - 6]['FILEPATH'].split('/')[-2:])

            # Input Zernike values from the AS signal file (copied before closing).
            with fits.open("{0}/{1}".format(fitsDir, as_file_sig)) as hdul:
                z_as[i] = hdul[1].data['value']

            # Exactly one Zernike term is expected to be driven per AS signal scan.
            nonzero = np.flatnonzero(z_as[i])
            if nonzero.size != 1:
                raise ValueError("Expected exactly one non-zero AS Zernike in {0}, found {1}"
                                 .format(as_file_sig, nonzero.size))
            z_idx[i] = nonzero[0]

            _, _, _, _, fitlist = processActiveSurfaceFITSPair("{0}/{1}".format(fitsDir, as_file_ref),
                                                               "{0}/{1}".format(fitsDir, as_file_sig),
                                                               column='ABSOLUTE', filterDisabled=True,
                                                               verbose=False, plot=False)
            z_as_obs[i] = fitlist

            # +1 shifts the AS index (starting at Z1) onto the fit index (starting at Z0).
            z_in_as[i] = z_as_obs[i][z_idx[i]+1]*1e6 # microns

            zFitDict[scans[i]] = {'input zernike': z_idx[i] + 1,
                                  'input zernike value': z_as[i][z_idx[i]],
                                  'active surface zernike value': abs(z_in_as[i]),
                                  'active surface reference scan': ref_scan,
                                  'active surface signal scan': sig_scan,
                                 }

    all_files = fits_files if signal_type == b"REF" else zern_files

    # Scan times parsed from file names of the form YYYY_MM_DD_hh:mm:ss.*
    dates = ['{0}T{1}'.format('-'.join(fn.split('/')[-1].split('.')[0].split("_")[:3]),
                              fn.split('/')[-1].split('.')[0].split("_")[-1]) for fn in all_files]
    dates = Time(dates)

    # Measure Zernike coefficients from the scans.
    for i, ff in enumerate(all_files):

        if signal_type == b"SIGNAL":
            sf = ff.replace(".zernike.fits", ".fits")
            sigScanFile = "{0}/{1}".format(scanDir, sf.split('/')[-1].replace('.fits', '.ptx.csv'))
            # The signal scan's header names its reference scan.
            with fits.open(sf) as hdul:
                head = hdul[0].header
                scan = head['SCAN']
                refScan = head['RSCANNUM']
        else:
            # Only process files flagged as reference scans.
            with fits.open(ff) as hdul:
                head = hdul[0].header
                scan = head['SCAN']
                isRefScan = head['REFSCAN'] != 0
            if not isRefScan:
                continue
            sigScanFile = "{0}/{1}".format(scanDir, ff.split('/')[-1].replace('.fits', '.ptx.csv'))

        if fixed_reference:
            pairRefFile = refScanFile
        else:
            # Column 2 of the ScanLog row -- presumably FILEPATH; confirm.
            rf = scanArr[np.where(scanArr['SCAN'] == refScan)[0][0]][2].split('/')[-1]
            pairRefFile = "{0}/{1}".format(scanDir, rf.replace('.fits', '.ptx.csv'))

        x, y, dz, fl = extractZernikesLeicaScanPair(pairRefFile, sigScanFile, n=n, nZern=nZern,
                                                    pFitGuess=guess, rMaskRadius=maskRadius,
                                                    radialMask=radialMask,
                                                    sigma=sigma, verbose=False)

        # Weighted fit of the same deformation map.
        fl_wls = zernikeOLS(x, y, dz, nZern)

        dz_rms = np.nanstd(dz.filled(np.nan))

        if signal_type == b"SIGNAL":
            entry = zFitDict[scan]
            entry['time mjd'] = dates[i].mjd

            entry['recovered zernike'] = fl
            entry['recovered zernike wls'] = fl_wls

            # Relative differences; fitted coefficients are scaled by 1e6 to match izv.
            iz = entry['input zernike']
            izv = entry['input zernike value']
            entry['recovered zernike difference'] = (fl[iz]*1e6 - izv)/izv
            entry['recovered zernike difference wls'] = (fl_wls[iz]*1e6 - izv)/izv

            izv = entry['active surface zernike value']
            entry['recovered zernike difference AS'] = (fl[iz]*1e6 - izv)/izv
            entry['recovered zernike difference AS wls'] = (fl_wls[iz]*1e6 - izv)/izv

            entry['deformation map rms'] = dz_rms

        else:
            zFitDict[scan] = {'recovered zernike': fl,
                              'deformation map rms': dz_rms
                             }

    with open(outputDict, "wb") as fh:
        pickle.dump(zFitDict, fh)

In [8]:
# Fit configuration: an n x n grid and nZern coefficients (Z0..Z36); the
# active-surface files carry the nZern-1 terms Z1..Z36.
n, nZern = 512, 37
# Scan range written into the output header by process().
scan0, scanf = 16, 125
# Options recorded in the output header.
signal_registration = sigma_clip_diff = do_replace_nans = False
kernel_size = 1
# Passed to extractZernikesLeicaScanPair as pFitGuess.
guess = [60., 0., 0., -50., 0., 0.]
# Reference handling, masking and which scans to process (b'SIGNAL' or b'REF').
fixed_reference, radialMask = True, False
signal_type = b'SIGNAL'
# Data locations.
scanDir = '/home/scratch/psalas/LASSI/gpus/output/'
fitsDir = '/home/gbtdata/TLASSI_200315'
refScanFile = '/home/scratch/psalas/LASSI/gpus/output/2020_03_16_05:22:24.ptx.csv'

In [9]:
%%time
outputDicts = ["zFitDict_{}c.pickle".format(i) for i in range(2,3)]
maskRadius = [50]*len(outputDicts)
sigmas = [3]*len(outputDicts)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for out, mr, sigma in zip(outputDicts, maskRadius, sigmas):
        print(mr, sigma, out)
        process(mr, sigma, out, signal_registration, signal_type, radialMask,
                fixed_reference=fixed_reference, refScanFile=refScanFile)

50 3 zFitDict_2c.pickle
CPU times: user 44min 30s, sys: 36min 37s, total: 1h 21min 8s
Wall time: 8min 4s
