# In [None]:
import numpy as np
import matplotlib.pyplot as plt

from astropy.io import fits
from astropy.visualization import ImageNormalize, ZScaleInterval, PercentileInterval
from astropy.wcs import WCS
from astropy.wcs.utils import proj_plane_pixel_scales
from astropy import units as u
from astropy.coordinates import SkyCoord

from photutils import SkyCircularAperture
from photutils import CircularAperture
from photutils import CircularAnnulus
from photutils import aperture_photometry

# Distance range noted for HIP 90617 (kept for reference, not used):
# distance_range = [1.40745, 1.74476] * u.kpc

# Hyperion model-grid sets to fit; each name is a directory holding a
# flux.fits file (see the main loop).
arr = ('s-pbhmi s-pbsmi sp--h-i s-p-hmi sp--hmi sp--s-i s-p-smi sp--smi '
       'spu-hmi spu-smi s---s-i s---smi s-ubhmi s-ubsmi s-u-hmi s-u-smi '
       'spubsmi spubhmi').split()

obj_name = 'V921Sco'
dist = 1650  # distance to the object, pc

# Observed MSX maps, one file per band, in the order A, C, D, E.
_msx_dir = '/home/buddy/Documents/MIPT/MSX Images/V921 Sco/'
arr_userfile = [_msx_dir + 'msxmap' + band + '.fits' for band in 'ACDE']
# models = 'spubsmi'   (commented-out single-grid selection)


def find_nearest(array, value):
    """Return ``(element, index)`` for the element of *array* closest to *value*.

    Closeness is the absolute difference; on a tie the first such element wins
    (``argmin`` semantics).
    """
    values = np.asarray(array)
    nearest_idx = np.argmin(np.abs(values - value))
    return values[nearest_idx], nearest_idx

def rob_fits(user_fitsfile, distance=None):
    """Read a Hyperion model-grid flux file and rescale its fluxes to the object distance.

    Despite its name, ``user_fitsfile`` is the path of the *model* ``flux.fits``
    file (the name is kept for existing callers).

    Parameters
    ----------
    user_fitsfile : str
        Path to a FITS file with ``SPECTRAL_INFO``, ``APERTURES``, ``VALUES``
        and ``MODEL_NAMES`` extensions.
    distance : float, optional
        Object distance in pc. Defaults to the module-level ``dist``.

    Returns
    -------
    ap : ndarray
        The ``APERTURE`` column of the ``APERTURES`` extension.
    wave : ndarray
        The ``WAVELENGTH`` column of the ``SPECTRAL_INFO`` extension.
    fnu : ndarray
        ``VALUES`` data multiplied by ``(1000 / distance)**2 / 1000``.
    names : ndarray of str
        ``MODEL_NAMES`` data cast to ``'<U11'``.
    """
    d = dist if distance is None else distance
    # The context manager guarantees the file handle is released; every
    # returned array is a fresh copy, so nothing depends on the file staying open.
    with fits.open(user_fitsfile) as hdul:
        wave = np.array(hdul['SPECTRAL_INFO'].data['WAVELENGTH'])
        ap = np.array(hdul['APERTURES'].data['APERTURE'])
        # (1000/d)**2 presumably rescales grid fluxes from 1 kpc to d pc
        # (inverse-square law), and /1000 presumably converts mJy -> Jy
        # (the plots label the result in Jy) -- TODO confirm against the grid docs.
        fnu = hdul['VALUES'].data * (1000 / d) ** 2 / 1000
        names = np.array(hdul['MODEL_NAMES'].data.astype('<U11'))
    return (ap, wave, fnu, names)

def user_fits(rob_fitsfile, ap, flag):
    """Load one MSX map, draw the model apertures over it and save the figure.

    The parameter name is historical: ``rob_fitsfile`` is the path of the
    *observed* MSX FITS map (the model grid is read by ``rob_fits``).

    Parameters
    ----------
    rob_fitsfile : str
        Path to the MSX FITS image; the primary HDU holds the image and its
        header must contain ``WAVELENG``.
    ap : array_like
        Model aperture sizes in AU; ``ap / 2`` is used as the radius.
        Must have at least 20 entries (indices 16 and 19 are used below).
    flag : str
        MSX band letter, used in the saved PNG file name.

    Returns
    -------
    image_data : ndarray
        The image with its row order reversed.
    radpix : ndarray
        Aperture radii in pixels.
    wcs : astropy.wcs.WCS
        WCS built from the primary header (i.e. of the un-flipped image).
    scale : float
        Mean pixel scale in arcsec/pixel.
    waveleng : float
        ``WAVELENG`` header value (callers multiply by 1e6 to get um, so
        presumably metres).
    aupixel : float
        AU per pixel at distance ``dist``.

    Side effects: prints diagnostics, draws on the current pyplot figure and
    saves ``<obj_name>_<models>_<flag>.png``. ``models`` is the module-level
    loop variable.
    """
    with fits.open(rob_fitsfile) as hdul:
        hdul.info()
        header = hdul[0].header
        # Reverse the row order (flip the y axis); copy so the array does not
        # depend on the file staying open.
        image_data = np.array(hdul[0].data[::-1, :])
        wcs = WCS(header)
        waveleng = header['WAVELENG']

    # Pixel scale in WCS axis units (degrees for a standard celestial WCS) -> arcsec.
    scale = np.mean(proj_plane_pixel_scales(wcs)) * 3600
    n_rows, n_cols = image_data.shape
    FOV = n_rows * scale  # field of view along y, arcsec
    print("Scale =", scale, " arcsec/pix")
    print("FOV=", FOV/60/60, "deg")
    print('Image data shape:', image_data.shape)

    # 1 arcsec at 1 pc subtends 1 AU, so dist [pc] * FOV [arcsec] is the field size in AU.
    L = dist * FOV
    aupixel = L / n_rows  # AU per pixel
    radpix = np.asarray(ap) / 2 / dist / scale  # aperture radii in pixels
    print('MAX RADIUS =', radpix[19], 'pix')

    # norm = ImageNormalize(image_data, interval=ZScaleInterval())       # alternative: ZScale
    norm = ImageNormalize(image_data, interval=PercentileInterval(99.5))  # 99.5% percentile stretch

    ax = plt.subplot(projection=wcs)
    # NOTE(review): circles are centred on the image centre, so the object is
    # assumed to sit there. The image rows were reversed above but ``wcs`` was
    # not, so the world-coordinate axes may not match the displayed image -- verify.
    centre = (n_cols / 2, n_rows / 2)
    for r in radpix:
        ax.add_patch(plt.Circle(centre, r, lw=0.1, color='r', fill=False))
    # Reference circle of radius radpix[19] + radpix[16] (purpose not stated in
    # the code). Drawn once; adding it inside the loop only stacked identical patches.
    ax.add_patch(plt.Circle(centre, radpix[19] + radpix[16], lw=0.5, color='green', fill=False))

    plt.imshow(image_data, cmap='gray', norm=norm)
    plt.xlabel('Galactic Longitude')
    plt.ylabel('Galactic Latitude')
    plt.colorbar()
    plt.savefig(str(obj_name + '_' + models + '_' + flag + '.png'), dpi=300)
    print('Min:', np.min(image_data))
    print('Max:', np.max(image_data))
    print('Mean:', np.mean(image_data))
    print('Stdev:', np.std(image_data))
    return (image_data, radpix, wcs, scale, waveleng, aupixel)


def chi2_func(phot_Jy, fnu, wave_index, models, names, flag, apertures=None):
    """Score every model against observed aperture photometry at one wavelength and plot the best.

    The score is the plain sum of squared residuals over apertures; it is not
    yet divided by the measurement variance, so it is not a true chi^2.

    Parameters
    ----------
    phot_Jy : array_like
        Observed flux per aperture (Jy).
    fnu : ndarray
        Model fluxes indexed as ``fnu[model, aperture, wavelength]``.
    wave_index : int
        Index of the model wavelength to compare against.
    models : str
        Model-grid set name (plot title only).
    names : ndarray of str
        Model names, one per ``fnu`` row (plot title only).
    flag : str
        MSX band letter (plot title only).
    apertures : array_like, optional
        x values for the plot (aperture sizes, AU). Defaults to the
        module-level ``ap`` for backward compatibility.

    Returns
    -------
    (chi2, wave_index)
        ``chi2`` is a float array with one score per model; ``wave_index`` is
        returned unchanged.

    Raises
    ------
    ValueError
        If every score is NaN (from ``np.nanargmin``).
    """
    model_flux = fnu[:, :, wave_index]  # one row per model, one column per aperture
    chi2 = np.zeros(fnu.shape[0])
    for i, model_row in enumerate(model_flux):
        # TODO: divide by the squared measurement error to get a true chi^2.
        chi2[i] = ((phot_Jy - model_row) ** 2).sum()

    # nanargmin yields a single scalar index (the first one on ties). The old
    # np.where(...)[0] returned an array, which int() rejects on ties.
    best = int(np.nanargmin(chi2))

    x = ap if apertures is None else apertures
    plt.figure()
    plt.plot(x, phot_Jy, 'r.')                     # observed photometry
    plt.plot(x, fnu[best, :, wave_index], 'b.')    # best-fitting model
    plt.xlabel('AU')
    plt.ylabel('Jy')
    plt.title('%s  %s \n $\\chi^2$ = %f %s \n MSX_%s'
              % (obj_name, models, chi2[best], names[best], flag))
    # plt.savefig(str('fit_'+ obj_name +'_' + models +'.png') , dpi=300)
    plt.show()
    return (chi2, wave_index)

def _plot_acde_fit(plot, x, series, title, filename):
    """Plot observed (dots) and best-model (crosses) fluxes for several bands, then save and show.

    plot     -- pyplot drawing function, e.g. plt.plot or plt.loglog
    x        -- aperture sizes (AU) shared by every series
    series   -- iterable of (colour, observed_flux, model_flux) tuples
    title    -- figure title
    filename -- PNG file written at 300 dpi
    """
    series = list(series)
    plt.figure()
    # Every observed series first, then every model series, so models are drawn on top.
    for colour, observed, _model in series:
        plot(x, observed, colour + '.')
    for colour, _observed, model in series:
        plot(x, model, colour + '+')
    plt.xlabel('AU')
    plt.ylabel('Jy')
    plt.title(title)
    plt.savefig(filename, dpi=300)
    plt.show()


# MSX band letter and marker colour for each map in arr_userfile (same order).
msx_bands = (('A', 'r'), ('C', 'g'), ('D', 'b'), ('E', 'y'))

for models in arr:
    rob_fitsfile = '/home/buddy/Documents/MIPT/Hyperion/%s/flux.fits'%(models)
    ap, wave, fnu, names = rob_fits(rob_fitsfile)

    # photometry_MSX_* are not defined in this section; presumably another
    # notebook cell provides them -- TODO confirm.
    photometry_by_band = {'A': photometry_MSX_A, 'C': photometry_MSX_C,
                          'D': photometry_MSX_D, 'E': photometry_MSX_E}
    phot_by_band, chi2_by_band, wave_index_by_band = {}, {}, {}
    for user_file, (band, _colour) in zip(arr_userfile, msx_bands):
        image_data, radpix, wcs, scale, waveleng, aupixel = user_fits(user_file, ap, flag=band)
        phot_by_band[band] = photometry_by_band[band](image_data, radpix, wcs)
        # Model grid wavelength closest to the map's wavelength (header value * 1e6 -> um).
        rob_wave, wave_index = find_nearest(wave, waveleng*1e6)
        print('User waveleng =', waveleng*1e6, 'um')
        print('Models waveleng =', rob_wave, 'um', '/index =', wave_index)
        chi2_by_band[band], wave_index_by_band[band] = chi2_func(
            phot_by_band[band], fnu, wave_index, models, names, flag=band)

    # Keep the per-band module-level names the script has always exposed.
    phot_Jy_MSX_A, phot_Jy_MSX_C, phot_Jy_MSX_D, phot_Jy_MSX_E = (
        phot_by_band[b] for b, _c in msx_bands)
    chi2_A, chi2_C, chi2_D, chi2_E = (chi2_by_band[b] for b, _c in msx_bands)
    waveA_index, waveC_index, waveD_index, waveE_index = (
        wave_index_by_band[b] for b, _c in msx_bands)

    final_chi2 = chi2_A + chi2_C + chi2_D + chi2_E

    # nanargmin yields one scalar index (the first on ties). np.where(...)[0]
    # gave an array that int() and the %-formatting could not handle on ties.
    index = int(np.nanargmin(final_chi2))
    print('final_chi2 min =', np.nanmin(final_chi2))
    print(index)

    title = '%s  %s \n $\\chi^2$ = %0.4f %s' % (obj_name, models, final_chi2[index], names[index])
    base_name = 'fit_' + obj_name + '_' + 'ACDE' + '_' + models
    series = [(colour, phot_by_band[band], fnu[index, :, wave_index_by_band[band]])
              for band, colour in msx_bands]
    n_zoom = 10  # number of leading apertures shown in the zoomed plot

    # Each plot gets its own file; previously all three overwrote one PNG.
    _plot_acde_fit(plt.plot, ap, series, title, base_name + '.png')
    _plot_acde_fit(plt.plot, ap[:n_zoom],
                   [(c, obs[:n_zoom], mod[:n_zoom]) for c, obs, mod in series],
                   title, base_name + '_zoom.png')
    _plot_acde_fit(plt.loglog, ap, series, title, base_name + '_loglog.png')

        

#plt.figure()
#plt.plot(ap, phot_Jy_MSX_A, 'r.') # MSX_A W/m^2-sr to Jy *7.133e12*scale**2/((180/np.pi)**2*3600**2)
#plt.plot(ap, fnu[1000,:,wave_index], 'r.')
#plt.xlabel('AU')
#plt.ylabel('Jy')
#plt.show()
#print(np.log10(radpix/radpix[0]))