In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 25 18:42:18 2019

@author: roeeyairpartoush
"""

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.axes as pltax

from astropy.io import fits


# =============== PLOT_CUT ================
def plot_cut(imgs, mask_ims, noise_ims, inds, cntr, ang, lnlen, axs_img, avgWid):
    """Sample selected images along a straight cut, plot and return the profiles.

    Parameters
    ----------
    imgs, mask_ims, noise_ims : sequences of 2-D arrays
        Images with their masks and noise maps, indexed by the entries of
        ``inds`` and passed to ``samp_image``.
    inds : sequence of int
        Indices of the images to sample.
    cntr : array of shape (2, 1)
        Cut centre as ``[[x], [y]]`` in pixels.
    ang : float
        Cut direction in degrees.
    lnlen : float
        Cut length in pixels; ``floor(lnlen) + 1`` samples are taken.
    axs_img : sequence of matplotlib Axes
        Indexed like ``imgs``. The sample locations are scattered on these axes.
    avgWid : int or float
        Number of parallel cuts averaged together, each offset by one pixel
        along the cut normal. Values <= 1 give a single cut, and non-integer
        values are truncated.

    Returns
    -------
    smp_mat : ndarray of shape (len(inds), floor(lnlen) + 1)
        Averaged profile of each selected image. NaN marks samples that fall
        off the image or on masked pixels.
    vec : ndarray of shape (N, 2)
        (x, y) locations of the last cut sampled, as returned by
        ``samp_image``. It is an empty (0, 2) array when ``inds`` is empty.

    Side effects: plots each averaged profile in the figure 'Light Curves'
    and scatters the sample points on ``axs_img``.
    """
    fig = plt.figure('Light Curves')
    ax = fig.gca()

    ang = np.radians(ang)
    unitvec = np.array([[np.cos(ang)], [np.sin(ang)]])   # (2, 1) along the cut
    uninorm = np.array([[np.sin(ang)], [-np.cos(ang)]])  # (2, 1) normal to the cut

    # np.linspace requires an integer sample count; a float from np.floor raises TypeError.
    Nsmp = int(np.floor(lnlen)) + 1
    # Truncate the width so the number of cuts summed equals the divisor used below.
    Nbtch = max(int(avgWid), 1)
    smp_mat = np.zeros((len(inds), Nsmp))
    vec = np.empty((0, 2))

    for i, ind in enumerate(inds):
        smp_sum = np.zeros(Nsmp)
        for n in range(Nbtch):
            # Offsets run 0..Nbtch-1 along the normal, so the averaged strip
            # lies on one side of cntr rather than being centred on it.
            cntr_tmp = cntr + uninorm * n
            str_pnt = cntr_tmp - unitvec * lnlen / 2
            end_pnt = cntr_tmp + unitvec * lnlen / 2

            # linspace over (2, 1) endpoints gives (Nsmp, 2, 1); drop the last axis -> (Nsmp, 2) of (x, y).
            vec = np.squeeze(np.linspace(str_pnt, end_pnt, Nsmp), 2)
            img_smp, vec = samp_image(imgs[ind], mask_ims[ind], noise_ims[ind], vec, 'roundloc')

            smp_sum = smp_sum + img_smp
            plt.sca(axs_img[ind])
            plt.scatter(vec[:, 0], vec[:, 1], s=0.1)

        smp = smp_sum / Nbtch
        plt.sca(ax)
        plt.plot(np.arange(Nsmp), smp)
        smp_mat[i, :] = smp

    return smp_mat, vec

# =============== SAMP_IMAGE ================
def samp_image(img, mask, noise, vec, method):
    """Sample ``img`` at the (x, y) locations listed in ``vec``.

    Parameters
    ----------
    img, mask, noise : 2-D arrays
        Passed through to ``smp_ind``, where pixels with ``mask != 0`` sample
        as NaN. ``noise`` is not used by ``smp_ind``.
    vec : ndarray of shape (N, 2)
        Column 0 is x (column index) and column 1 is y (row index).
    method : {'bilinear', 'roundloc'}
        'bilinear' weights the four neighbouring integer pixels by the
        fractional offsets.
        'roundloc' reads one pixel per location. Despite the name, no
        rounding is done here: ``smp_ind`` converts coordinates with
        ``astype(int)``, which truncates toward zero.

    Returns
    -------
    img_smp : ndarray of shape (N,)
        Sampled values; NaN where a location is off-image or masked.
    vec : ndarray
        The same object that was passed in.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method == 'bilinear':
        x = vec[:, 0]
        y = vec[:, 1]
        fx, cx = np.floor(x), np.ceil(x)
        fy, cy = np.floor(y), np.ceil(y)

        # Fractional position inside the pixel cell: the interpolation weights.
        x_mod = np.mod(x, 1)
        y_mod = np.mod(y, 1)

        smp_fxfy = smp_ind(img, mask, noise, np.column_stack((fx, fy)))
        smp_fxcy = smp_ind(img, mask, noise, np.column_stack((fx, cy)))
        smp_cxfy = smp_ind(img, mask, noise, np.column_stack((cx, fy)))
        smp_cxcy = smp_ind(img, mask, noise, np.column_stack((cx, cy)))

        img_smp = ((smp_fxfy * (1 - x_mod) + smp_cxfy * x_mod) * (1 - y_mod)
                   + (smp_fxcy * (1 - x_mod) + smp_cxcy * x_mod) * y_mod)

    elif method == 'roundloc':
        img_smp = smp_ind(img, mask, noise, vec)

    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError("unknown sampling method: %r (expected 'bilinear' or 'roundloc')" % (method,))

    return img_smp, vec

# =============== SMP_IND ================
def smp_ind(img, mask, noise, vec):
    """Read ``img`` at the pixel locations in ``vec``.

    Parameters
    ----------
    img : 2-D array
        Image to sample.
    mask : array, same shape as ``img``
        Pixels where ``mask != 0`` are reported as NaN.
    noise : array
        Unused; accepted for a uniform call signature.
    vec : ndarray of shape (N, 2)
        Column 0 is x (column index) and column 1 is y (row index).
        Coordinates are converted with ``astype(int)``, which truncates
        toward zero.

    Returns
    -------
    ndarray of float, shape (N,)
        Sampled values. Locations that ``filt_inds`` rejects as off-image
        are NaN.
    """
    img = img.astype(float)  # copy, so the caller's image is not modified
    # np.nan (not np.NaN, which NumPy 2.0 removed).
    img[mask != 0] = np.nan

    smps = np.full(vec.shape[0], np.nan)
    pts, notnan = filt_inds(img.shape, vec)

    inds = np.ravel_multi_index([pts[:, 1].astype(int), pts[:, 0].astype(int)], img.shape)
    smps[notnan] = img.ravel()[inds]  # ravel avoids a second full copy (img is already a private copy)

    return smps

# =============== FILT_INDS ================
def filt_inds(im_shp, vec):
    """Keep only the (x, y) locations that fall inside an image of shape ``im_shp``.

    A location is valid when ``0 <= x < im_shp[1]`` and ``0 <= y < im_shp[0]``.
    These are 0-based bounds, so every valid point maps to an in-range pixel
    after ``astype(int)`` truncation. Locations that are already NaN are
    rejected too.

    Parameters
    ----------
    im_shp : tuple
        Image shape as (rows, cols).
    vec : array of shape (N, 2)
        Column 0 is x (column index) and column 1 is y (row index). It is not
        modified.

    Returns
    -------
    vec : ndarray of float, shape (K, 2)
        The valid locations, in their original order.
    notnan : ndarray of bool, shape (N,)
        True where the corresponding input row was kept.
    """
    # Work on a float copy: the caller's array is not mutated, and NaN can be
    # stored even when the input is integer-typed.
    vec = np.array(vec, dtype=float)
    vec_x = vec[:, 0]
    vec_y = vec[:, 1]

    vec_x[(vec_x < 0) | (vec_x >= im_shp[1])] = np.nan
    vec_y[(vec_y < 0) | (vec_y >= im_shp[0])] = np.nan

    notnan = ~(np.isnan(vec_x) | np.isnan(vec_y))
    # Boolean indexing keeps the (K, 2) shape even when K is 0 or 1.
    return vec[notnan], notnan

# =============== LOAD_DIFIMG ================
def load_difimg(home_dir, prefix, midfix, sufix, tmplt_img, image_files, clim=70):
    """Load difference images and their mask/noise companions, displaying each image.

    For every ``name`` in ``image_files`` the base path is
    ``home_dir + prefix + name + midfix + tmplt_img + sufix``. Three FITS
    files are read from it: ``<base>.fits``, ``<base>.mask.fits`` and
    ``<base>.noise.fits``. Each image is shown in its own figure
    (figure numbers 1, 2, ...) with a grey colormap clipped to
    ``[-clim, clim]``. The image axes share x/y limits.

    Parameters
    ----------
    home_dir, prefix, midfix, sufix, tmplt_img : str
        Pieces concatenated into each base path.
    image_files : sequence of str
        Per-image part of the base path.
    clim : float, optional
        Symmetric display limit; 70 by default (the previously hard-coded value).

    Returns
    -------
    image_data, mask_img, noise_img : lists of arrays
        Data from the ``.fits``, ``.mask.fits`` and ``.noise.fits`` files, in input order.
    ax_img : list of matplotlib Axes
        The axes each image was drawn on.
    """
    image_data = []
    mask_img = []
    noise_img = []
    ax_img = []

    for ind, name in enumerate(image_files):
        flnm = home_dir + prefix + name + midfix + tmplt_img + sufix
        image_data.append(fits.getdata(flnm + '.fits'))
        mask_img.append(fits.getdata(flnm + '.mask.fits'))
        noise_img.append(fits.getdata(flnm + '.noise.fits'))

        fig = plt.figure(ind + 1)
        if ax_img:
            # Link to the previous image's axes so all panels share x/y limits.
            ax = fig.add_subplot(111, sharex=ax_img[-1], sharey=ax_img[-1])
        else:
            ax = fig.add_subplot(111)
        ax_img.append(ax)

        plt.sca(ax)
        plt.imshow(image_data[-1], cmap='gray', vmin=-clim, vmax=clim)

    return image_data, mask_img, noise_img, ax_img

# =============== PLOT_DIF ================
def plot_dif(base_img, diff_imgs, clim=70):
    """Subtract ``base_img`` from each image in ``diff_imgs`` and display the results.

    Each difference is shown in its own figure (figure numbers 100, 101, ...)
    with a grey colormap clipped to ``[-clim, clim]``. The axes share x/y
    limits.

    Parameters
    ----------
    base_img : array
        Image subtracted from every entry of ``diff_imgs``.
    diff_imgs : sequence of arrays
        Images to difference against ``base_img``.
    clim : float, optional
        Symmetric display limit; 70 by default (the previously hard-coded value).

    Returns
    -------
    image_data : list of arrays
        ``diff_imgs[i] - base_img`` for each i.
    ax_img : list of matplotlib Axes
        The axes each difference was drawn on.
    """
    image_data = []
    ax_img = []

    for ind, img in enumerate(diff_imgs):
        fig = plt.figure(ind + 100)
        if ax_img:
            # Link to the previous panel's axes so all panels share x/y limits.
            ax = fig.add_subplot(111, sharex=ax_img[-1], sharey=ax_img[-1])
        else:
            ax = fig.add_subplot(111)
        ax_img.append(ax)

        image_data.append(img - base_img)
        plt.sca(ax)
        plt.imshow(image_data[-1], cmap='gray', vmin=-clim, vmax=clim)

    return image_data, ax_img


# =============== FLNM2TIME ================
def flnm2time(names):
    """Return the number of days between each name's date and the date of ``names[0]``.

    Parameters
    ----------
    names : sequence of str
        Each string begins with an 8-character YYYYMMDD date (per the
        original comment, the later epoch of the difference image).

    Returns
    -------
    times : ndarray of float, shape (len(names),)
        Exact calendar-day differences, so month lengths and leap years are
        handled. ``times[0]`` is 0, and empty input gives an empty array.

    Raises
    ------
    ValueError
        If a name does not start with a valid calendar date.
    """
    import datetime  # local import keeps this block self-contained

    times = np.zeros(len(names))
    if len(names) == 0:
        return times

    # The old 365-day-year / 30-day-month approximation gave e.g.
    # 20190131 -> 20190201 = 0 days; real date arithmetic fixes that.
    t0 = datetime.date(*sdate(names[0]))
    for i, name in enumerate(names):
        times[i] = (datetime.date(*sdate(name)) - t0).days
    return times

# =============== SDATE ================
def sdate(str_date):
    """Split the leading YYYYMMDD of ``str_date`` into ``(year, month, day)`` ints."""
    year = int(str_date[0:4])
    month = int(str_date[4:6])
    day = int(str_date[6:8])
    return year, month, day

#def avg_mats(mats):
#    # assuming mats is a list of all-same-sized Numpy.ndarrays
#    avg_m = np.zeros(mats[0].shape)
#    
#    for i in np.arange(len(mats)):
#        asf