In [None]:
from datetime import datetime
from matplotlib import pyplot as plt
import multiprocessing as mp
import numpy as np

import vip_hci as vip
from hciplot import plot_frames, plot_cubes

from vip_hci.fm import normalize_psf
from vip_hci.psfsub import median_sub, pca
from vip_hci.preproc import find_scal_vector, frame_rescaling
from vip_hci.fits import open_fits
from vip_hci.metrics import significance, snr, snrmap
from vip_hci.var import mask_circle

import analysis
import redux_utils

In [None]:
# Input data: one FITS cube per wavelength channel, file name built from a
# zero-padded channel number.
data_dir = "./data/005_center_multishift/"
data_name = "wl_channel_%05i.fits"
datapath = f"{data_dir}{data_name}"

# Inclusive range of channel numbers to load.
channel_min, channel_max = 55, 64
channel_nums = np.arange(channel_min, channel_max + 1)

# Wavelength grid (presumably microns): 30 evenly spaced samples over
# [3.59, 3.99]; the entries at indices 10..19 are kept to pair with channel_nums.
# NOTE(review): the [10:20] slice is hard-coded independently of
# channel_min/channel_max — confirm the two stay in sync.
full_wavelength_grid = np.linspace(3.59, 3.99, num=30, endpoint=True)
wavelengths = full_wavelength_grid[10:20]

In [None]:
# Load one cube per wavelength channel and stack them along a new leading axis.
# Later cells index the result as cube[channel, frame, y, x].
cube = np.array([open_fits(datapath % channel_num, verbose=False)
                 for channel_num in channel_nums])
print(cube.shape)
n_channels = cube.shape[0]
n_frames = cube.shape[1]

# Every loaded channel needs a matching wavelength; a mismatch would silently
# misalign the spectral scaling computed below.
assert n_channels == len(wavelengths), (n_channels, len(wavelengths))

In [None]:
# Model the stellar PSF of each channel as the median over time of every
# `redux_utils.everynthframe`-th frame. A companion that moves between frames
# gets smeared by this median — beware.
frame_step = redux_utils.everynthframe
psf = np.median(cube[:, ::frame_step], axis=1)

In [None]:
# Normalize the model PSF of each channel. With fwhm="fit" and full_output=True,
# normalize_psf also hands back the host-star flux and FWHM per channel.
psfn, flux_st, fwhm = normalize_psf(
    psf,
    fwhm="fit",
    debug=False,
    full_output=True,
)

In [None]:
# Central mask for find_scal_vector. mask_rad is passed to mask_circle as its
# radius argument (in pixels), not a diameter.
mask_rad = 10
mask = mask_circle(np.ones_like(cube[0, 0]), mask_rad)

# One positional-argument tuple per time frame for find_scal_vector:
# (all channels at frame i, wavelengths, stellar fluxes, mask, nfp=2, fm="stddev").
# The shared arrays are referenced rather than copied once per frame.
# n_frames keeps the true frame count from the loaded cube; the benchmark cells
# below slice the first `n` entries themselves, while the be_slow path uses all of them.
# NOTE(review): assumes find_scal_vector does not modify its array arguments in place.
input_list = [
    (cube[:, i], wavelengths, flux_st, mask, 2, "stddev")
    for i in range(n_frames)
]

In [None]:
# Serial benchmark: wall-clock time of find_scal_vector for each of the first n frames.
n = 100
dt = np.zeros(shape=(n,))  # seconds per call

for i in range(n):
    tstart = datetime.now()
    find_scal_vector(*input_list[i])
    tend = datetime.now()
    # total_seconds() includes the days component and is already in seconds;
    # no further rescaling is applied.
    dt[i] = (tend - tstart).total_seconds()

In [None]:
# Mean serial time per frame (s), and that rate extrapolated to every frame in
# the loaded cube (minutes).
n_frames_total = cube.shape[1]
print(np.mean(dt))
print(n_frames_total * np.mean(dt) / 60)


In [None]:
# Parallel benchmark: the same first n frames distributed over a process pool.
n = 100
tstart = datetime.now()
with mp.Pool(redux_utils.numworkers) as pool:
    # starmap unpacks each argument tuple; each result is a (scales, fluxes) pair.
    output = np.array(pool.starmap(find_scal_vector, input_list[:n], chunksize=5))
tend = datetime.now()
dt_mp = (tend - tstart).total_seconds()  # wall time in seconds for all n frames

In [None]:
# Parallel wall time for n frames (s), and that rate extrapolated to every
# frame in the loaded cube (minutes).
print(dt_mp)
print(cube.shape[1] / n * dt_mp / 60)

In [None]:
# Set be_slow = True to fit scaling factors for every frame in input_list in
# parallel (long-running).
be_slow = False

if be_slow:
    tstart = datetime.now()
    print(tstart)
    with mp.Pool(redux_utils.numworkers) as pool:
        output = np.array(pool.starmap(find_scal_vector, input_list, chunksize=redux_utils.chunksize))
    tend = datetime.now()
    print(tend)
    # Separate name so the per-frame benchmark array `dt` is not overwritten.
    elapsed = tend - tstart
    print(elapsed)
else:
    # Fast path: fit scaling factors once, on the time-averaged cube.
    # NOTE(review): opt_scal_mean / opt_flux_mean are not used by the later cells
    # shown here; those read `output`, which on this branch still holds the
    # subset computed by the parallel benchmark cell above.
    opt_scal_mean, opt_flux_mean = find_scal_vector(np.mean(cube, axis=1), wavelengths, flux_st, mask=mask, nfp=2, fm="stddev")

In [None]:
# Separate the per-frame fits into spatial-scale and flux factors (presumably
# each shaped (frames, channels)), then take the per-channel median over frames.
opt_scals, opt_fluxes = output[:, 0], output[:, 1]

opt_scal_med, opt_flux_med = (np.median(factors, axis=0)
                              for factors in (opt_scals, opt_fluxes))

In [None]:
# For each channel, rescale its first frame by the median spatial factor,
# weight it by frame 0's fitted flux factor, and subtract frame 0 of the last
# channel; presumably a diagnostic of the fitted scaling.
# NOTE(review): the spatial scale is the median over frames (opt_scal_med) while
# the flux factor is frame 0's own fit (opt_fluxes[0]) — confirm this mix is intended.
reference_frame = cube[-1, 0]
res_scaling = np.zeros_like(cube[:, 0, :, :])
for ch in range(n_channels):
    rescaled = frame_rescaling(cube[ch, 0], scale=opt_scal_med[ch])
    res_scaling[ch] = opt_fluxes[0, ch] * rescaled - reference_frame

In [None]:
# Classical ASDI: median-subtraction PSF model with spectral rescaling using the
# median per-channel scale and flux factors. mask_rad comes from the mask cell above.
imlib = 'vip-fft'
interpolation = 'lanczos4'
angles = redux_utils.angles
# NOTE(review): VIP documents `interpolation` for its opencv/ndimage backends;
# with imlib='vip-fft' it may have no effect — confirm the intended backend.
med_asdi = median_sub(cube, angles, scale_list=opt_scal_med, flux_sc_list=opt_flux_med,
                      radius_int=mask_rad, imlib=imlib, interpolation=interpolation,
                      nproc=redux_utils.numworkers)

In [None]:
# SNR and Gaussian-equivalent significance of the companion in the median-ASDI image.
pl_loc = (12, 41)  # companion pixel position, presumably (x, y) as snr() expects
# Star at the frame centre, derived from the image size (e.g. 63x63 -> (31, 31)).
st_loc = (med_asdi.shape[-1] // 2, med_asdi.shape[-2] // 2)
pl_rad = np.hypot(pl_loc[0] - st_loc[0], pl_loc[1] - st_loc[1])  # separation in pixels
fwhm_mean = np.mean(fwhm)
pl_snr = snr(med_asdi, pl_loc, fwhm=fwhm_mean, exclude_negative_lobes=True)
pl_sgn = significance(pl_snr, pl_rad, fwhm_mean, student_to_gauss=True)
print(pl_snr, pl_sgn)

In [None]:
# Save the median-ASDI image annotated with the companion SNR and significance.
fig, ax = plt.subplots()
ax.imshow(med_asdi)
ax.set_title("HD 1160 B Detection\nMedian-ASDI")
# Raw string: "\s" is an invalid escape in a normal literal (SyntaxWarning on
# recent Python); mathtext still receives "$\sigma$".
ax.text(0.6, 0.9, r"snr: %.1f = %.1f$\sigma$" % (pl_snr, pl_sgn),
        transform=ax.transAxes, fontsize=12, bbox=dict(facecolor='#f5f5dc', alpha=0.5))
fig.tight_layout()
fig.savefig("med_asdi.png")
plt.close(fig)

In [None]:
# Full-frame PCA-ASDI
# Single step (adimsdi="single"), using the same imlib/interpolation as the
# median-ASDI cell so the reductions are comparable.
pca_asdi = pca(cube, angles, scale_list=opt_scal_med, ncomp=redux_utils.numcomps,
               adimsdi="single", crop_ifs=False, mask_center_px=mask_rad,
               imlib=imlib, interpolation=interpolation, mask_val=0, scaling="temp-standard",
               nproc=redux_utils.numworkers)

In [None]:
# Plot the single-step PCA-ASDI image and run the project's statistics helper on it.
plot_frames(pca_asdi, colorbar=True)
# NOTE(review): analysis.calc_stats is project-local; presumably (image, fwhm) — confirm.
analysis.calc_stats(pca_asdi, fwhm_mean)

In [None]:
# Double step (adimsdi="double"): ncomp is a pair, one value per PCA step.
# Uses the same imlib/interpolation as the other reductions.
pca_asdi_dbl = pca(cube, angles, scale_list=opt_scal_med, ncomp=(redux_utils.numcomps, redux_utils.numcomps),
                   adimsdi="double", crop_ifs=False, mask_center_px=mask_rad,
                   imlib=imlib, interpolation=interpolation, mask_val=0, scaling="temp-standard",
                   nproc=redux_utils.numworkers)

In [None]:
# Plot the double-step PCA-ASDI image and compute its statistics. The stats call
# must receive the reduced image, not the `pca` function itself.
plot_frames(pca_asdi_dbl, colorbar=True)
analysis.calc_stats(pca_asdi_dbl, fwhm_mean)

In [None]:
# Annular PCA-ASDI
# Double step
pca_asdi_dbl = pca(cube, angles, scale_list=opt_scal_med, ncomp=(redux_utils.numcomps, redux_utils.numcomps),
                   adimsdi="double", crop_ifs=False, mask_center_px=mask_rad, asize=fwhm_mean,
                   interpolation=interpolation, mask_val=0, scaling="temp-standard", nproc=redux_utils.numworkers)

In [None]:
plot_frames(pca_asdi_dbl, colorbar=True)
analysis.calc_stats(pca, fwhm_mean)