**Notebook to compute differential emission measure (DEM) maps from SDO/AIA data using the Hannah & Kontar regularized-inversion implementation (demreg)**

@Author: David Long
@Editor: Mohamed Nedal

In [44]:
import warnings
warnings.simplefilter('ignore')

from sys import path as sys_path
import os.path
import platform
import datetime as dt
from aiapy.calibrate.prep import correct_degradation
import numpy as np
import glob
import matplotlib
# matplotlib.use('Agg')
from matplotlib import pyplot as plt
import matplotlib.colors as colors
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits as fits
from sunpy.map import Map
from sunpy.net import Fido, attrs as a
from sunpy.coordinates import propagate_with_solar_surface
import scipy.io as io
sys_path.append('/home/mnedal/repos/demreg/python')
from dn2dem_pos import dn2dem_pos
script_path = os.path.abspath('./scripts')
if script_path not in sys_path:
    sys_path.append(script_path)
from general_routines import closest
from aiapy.calibrate import register, update_pointing, estimate_error
import aiapy.psf
import asdf
from bisect import bisect_left, bisect_right
from sunpy.time import parse_time

In [3]:
## Define constants and make the data directories
if platform.system() == 'Linux':
    data_disk = '/home/mnedal/data/AIA/'
else:
    # Paths are only configured for Linux; fail here with a clear message
    # rather than with a NameError on `data_disk` further down.
    raise RuntimeError('data_disk is only configured for Linux, not for ' + platform.system())

os.makedirs(data_disk, exist_ok=True)

## Event information
start_time = '2024/05/14 17:00:00'
end_time   = '2024/05/14 19:00:00'

cadence = 10*u.second  # sampling cadence requested from Fido (a.Sample)
img_time_range = [dt.datetime.strptime(start_time, "%Y/%m/%d %H:%M:%S"), dt.datetime.strptime(end_time, "%Y/%m/%d %H:%M:%S")]

# Reference image and pixel corners that define the field of view.
# NOTE(review): ref_time (2012) is years before the event (2024), yet the FOV
# centre taken from it is later rotated with propagate_with_solar_surface to
# each image time -- confirm the reference time is intended.
ref_time = '2012/06/14 01:00:00'
bottom_left = [1381, 881]*u.pixel
top_right   = [2215, 1881]*u.pixel

# Date sub-directories ('YYYY/MM/DD') for the reference image and the event images.
ref_file_date = dt.datetime.strftime(dt.datetime.strptime(ref_time, '%Y/%m/%d %H:%M:%S'), '%Y/%m/%d')
# The event directory must follow the event start time (it was copied from ref_time).
img_file_date = dt.datetime.strftime(img_time_range[0], '%Y/%m/%d')

## Define and make the output directories
output_dir = f'{data_disk}/DEM/{img_file_date}/'
os.makedirs(output_dir, exist_ok=True)

# AIA EUV passbands (Angstrom) used for the DEM.
passband = [94, 131, 171, 193, 211, 335]

In [4]:
## Function to get the list of AIA filenames
# DATE-OBS layouts seen in AIA headers: with and without a trailing 'Z'.
_DATE_OBS_FORMATS = ('%Y-%m-%dT%H:%M:%S.%fZ', '%Y-%m-%dT%H:%M:%S.%f')


def _parse_date_obs(date_obs):
    """Parse a DATE-OBS header string into a naive datetime.

    Tries each layout in _DATE_OBS_FORMATS in turn.

    Raises
    ------
    ValueError
        If the string matches none of the known layouts.
    TypeError
        If the header had no DATE-OBS keyword (date_obs is None).
    """
    for fmt in _DATE_OBS_FORMATS:
        try:
            return dt.datetime.strptime(date_obs, fmt)
        except ValueError:
            continue
    raise ValueError(f'Unrecognised DATE-OBS value: {date_obs!r}')


def get_filelist(data_disk, passband, img_file_date):
    """List the FITS files of one passband for one date directory.

    Files are looked up in ``<data_disk><passband zero-padded to 4>/<img_file_date>/*.fits``.

    Parameters
    ----------
    data_disk : str
        Root data directory, expected to end with a path separator.
    passband : int
        AIA passband in Angstrom (e.g. 94 -> '0094').
    img_file_date : str
        Date sub-directory, e.g. 'YYYY/MM/DD'.

    Returns
    -------
    files : list of str
        Sorted file paths.
    files_dt : list of datetime
        DATE-OBS of each file (read from HDU 1), parallel to ``files``.
    """
    pattern = data_disk + str(passband).rjust(4, "0") + '/' + img_file_date + '/*.fits'
    files = sorted(glob.glob(pattern))
    files_dt = [_parse_date_obs(fits.getheader(file_i, 1).get('DATE-OBS')) for file_i in files]
    return files, files_dt

In [5]:
## Function to download data
def get_data(start_time, end_time, img_file_date, cadence, pband, data_disk):
    """Search for AIA images of one passband with Fido and download them.

    Files go to ``<data_disk><pband zero-padded to 4>/<img_file_date>``;
    files already present are not re-downloaded (``overwrite=False``).

    Parameters
    ----------
    start_time, end_time : str or datetime
        Search window, passed to ``a.Time``.
    img_file_date : str
        Date sub-directory, e.g. 'YYYY/MM/DD'.
    cadence : astropy Quantity
        Sampling cadence passed to ``a.Sample``.
    pband : int
        AIA passband in Angstrom.
    data_disk : str
        Root data directory, expected to end with a path separator.

    Returns
    -------
    The object returned by ``Fido.fetch``. It was previously discarded;
    returning it lets callers inspect the downloaded paths and any failures.
    """
    wvlnth = a.Wavelength(int(pband)*u.Angstrom, int(pband)*u.Angstrom)
    result = Fido.search(a.Time(start_time, end_time), a.Instrument('AIA'), wvlnth, a.Sample(cadence))
    download_dir = data_disk + str(pband).rjust(4, '0') + '/' + img_file_date
    files = Fido.fetch(result, path=download_dir, overwrite=False, progress=True)
    # Fido.fetch does not raise on individual failed downloads; report them.
    if files.errors:
        print(f'{len(files.errors)} download(s) failed for {str(pband).rjust(4, "0")} passband')
    return files

In [40]:
## Function to reduce filerange to within time range of interest
def reduce_filerange(files_in, file_time_in, img_time_range):
    """Keep only the files whose times lie within img_time_range (inclusive).

    file_time_in must be sorted ascending and parallel to files_in, because
    the bounds are found by binary search.

    Returns
    -------
    (files_out, file_time_out) : the matching slices of the two inputs.
    """
    t_first, t_last = img_time_range[0], img_time_range[1]
    window = slice(bisect_left(file_time_in, t_first),
                   bisect_right(file_time_in, t_last))
    return files_in[window], file_time_in[window]

Get the reference image that defines the field of view

In [6]:
# Reference 193 A image: defines the field-of-view centre and half-widths.
strt_time = dt.datetime.strptime(ref_time, "%Y/%m/%d %H:%M:%S")

files, files_dt = get_filelist(data_disk, 193, ref_file_date)
if not files:
    # Nothing on disk yet: fetch a +/-10 s window around the reference time.
    get_data(strt_time-dt.timedelta(seconds=10), strt_time+dt.timedelta(seconds=10), ref_file_date, 10*u.second, 193, data_disk)
    files, files_dt = get_filelist(data_disk, 193, ref_file_date)
if not files:
    raise FileNotFoundError('No 193 A reference file found or downloaded for ' + ref_time)

# Nearest file in time to the reference time ('ref_map' avoids shadowing the builtin map()).
ref_index = min(range(len(files_dt)), key=lambda k: abs(files_dt[k] - strt_time))
ref_map = Map(files[ref_index])

# Box half-widths and centre in pixels on the reference image.
pix_width = [(top_right[0]-bottom_left[0])/2, (top_right[1]-bottom_left[1])/2]
pix_centre = [pix_width[0]+bottom_left[0], pix_width[1]+bottom_left[1]]

# pixel_to_world already returns SkyCoords in the map's coordinate frame.
crd_bl = ref_map.pixel_to_world(bottom_left[0], bottom_left[1])
crd_tr = ref_map.pixel_to_world(top_right[0], top_right[1])
crd_cent = ref_map.pixel_to_world(pix_centre[0], pix_centre[1])
# Half-widths of the box in arcsec (Tx, Ty).
crd_width = [(crd_tr.Tx.arcsecond-crd_bl.Tx.arcsecond)/2, (crd_tr.Ty.arcsecond-crd_bl.Ty.arcsecond)/2]

In [7]:
## Function to get AIA submap
def get_submap(time_array, index, img, f_0171, crd_cent, crd_width):
    """Cut the region of interest out of a 171 A image near the target time.

    The 171 A file is the one picked by ``closest`` from time_array[2] for the
    time time_array[index][img]. crd_cent is rotated to that image's frame with
    propagate_with_solar_surface, and a box of half-widths crd_width (arcsec)
    is placed around it.

    Returns
    -------
    The sunpy submap of that box.
    """
    target = time_array[index][img]
    aia_map = Map(f_0171[closest(np.array(time_array[2][:]), target)])
    frame = aia_map.coordinate_frame
    with propagate_with_solar_surface():
        centre = crd_cent.transform_to(frame)
    x_mid, y_mid = centre.Tx.arcsecond, centre.Ty.arcsecond
    half_x, half_y = crd_width[0], crd_width[1]
    corner_bl = SkyCoord((x_mid - half_x)*u.arcsec, (y_mid - half_y)*u.arcsec, frame=frame)
    corner_tr = SkyCoord((x_mid + half_x)*u.arcsec, (y_mid + half_y)*u.arcsec, frame=frame)
    return aia_map.submap(corner_bl, top_right=corner_tr)

In [8]:
## Function to prep AIA images, deconvolve with PSF and produce submap
def _roi_pixel_box(aia_map, crd_cent, crd_width):
    """Return integer pixel corners (bottom_left, top_right) of the ROI on aia_map's grid.

    crd_cent is rotated to aia_map's frame with propagate_with_solar_surface and
    a box of half-widths crd_width (arcsec) is placed around it.
    """
    frame = aia_map.coordinate_frame
    with propagate_with_solar_surface():
        diffrot_cent = crd_cent.transform_to(frame)
    cx, cy = diffrot_cent.Tx.arcsecond, diffrot_cent.Ty.arcsecond
    bl = SkyCoord((cx-crd_width[0])*u.arcsec, (cy-crd_width[1])*u.arcsec, frame=frame)
    tr = SkyCoord((cx+crd_width[0])*u.arcsec, (cy+crd_width[1])*u.arcsec, frame=frame)
    bl_x, bl_y = aia_map.world_to_pixel(bl)
    tr_x, tr_y = aia_map.world_to_pixel(tr)
    return ([int(bl_x.value), int(bl_y.value)]*u.pixel,
            [int(tr_x.value), int(tr_y.value)]*u.pixel)


def prep_images(time_array, index, img, f_0094, f_0131, f_0171, f_0193, f_0211, f_0335, crd_cent, crd_width):
    """Prep the six AIA channels for one time step and cut out the region of interest.

    For each channel (94, 131, 171, 193, 211, 335 A, in that order) the file
    picked by ``closest`` for time time_array[index][img] is:
    PSF-deconvolved, pointing-updated, registered, degradation-corrected,
    cut to the ROI, and normalised by exposure time.

    Returns
    -------
    map_array : sunpy MapSequence
        Exposure-normalised ROI submaps, one per channel, in channel order.
    error_array : ndarray
        Per-pixel uncertainties in ct / (pix s), shape data.shape + (6,).
    """
    target_time = time_array[index][img]
    channel_files = [f_0094, f_0131, f_0171, f_0193, f_0211, f_0335]
    farray = [files_k[closest(np.array(times_k[:]), target_time)]
              for files_k, times_k in zip(channel_files, time_array)]
    maps = Map(farray)

    print('Prepping images & deconvolving with PSF')
    map_arr = []
    err_list = []
    box = None
    for aia_map in maps:
        psf = aiapy.psf.psf(aia_map.wavelength)
        aia_map_deconvolved = aiapy.psf.deconvolve(aia_map, psf=psf)
        aia_map_registered = register(update_pointing(aia_map_deconvolved))
        aia_map_corrected = correct_degradation(aia_map_registered)
        if box is None:
            # Define the pixel box once, on the first *registered* map: it is in
            # the same pixel grid it is applied to, and every channel gets the
            # same array shape.
            box = _roi_pixel_box(aia_map_corrected, crd_cent, crd_width)
        counts_map = aia_map_corrected.submap(box[0], top_right=box[1])
        exptime = aia_map_corrected.exposure_time
        map_arr.append(counts_map/exptime)
        # Errors must be estimated from counts (DN) per individual pixel, so
        # n_sample stays at its default of 1. The previous code passed the total
        # pixel count (shrinking errors by sqrt(N)) and fed it DN/s. The
        # counts-based error is then scaled to match the DN/s data.
        err = estimate_error(counts_map.data*(u.ct/u.pix), counts_map.wavelength)
        err_list.append((err/exptime).to_value(u.ct/(u.pix*u.s)))

    error_array = np.stack(err_list, axis=-1)
    map_array = Map(map_arr, sequence=True, sortby=None)
    print('Images prepped & region of interest selected')
    return map_array, error_array

In [9]:
## Function to calculate DEM
def calculate_dem(map_array, err_array, tresp_file='/disk/solar2/dml/idl/aia_tresp_en.dat',
                  *, t_min=5.6, t_max=7.4, t_space=0.1, max_iter=15):
    """Run the demreg ``dn2dem_pos`` inversion on a stack of AIA submaps.

    Parameters
    ----------
    map_array : sequence of sunpy maps
        One map per channel, all with the same data shape; stacked along a
        new last axis in the given order.
    err_array : ndarray
        Per-pixel uncertainties, same shape as the stacked images
        (data.shape + (n_channels,)).
    tresp_file : str
        IDL save file with the keys 'logt' and 'tr' (temperature responses).
        Defaults to the path previously hard-coded here.
    t_min, t_max, t_space : float
        log10(T) grid limits and spacing.
    max_iter : int
        Passed through to dn2dem_pos.

    Returns
    -------
    dem, edem, elogt, chisq, dn_reg
        As returned by dn2dem_pos, with dem clipped to be non-negative.
    mlogt : list
        Midpoints of consecutive logtemps values.
    logtemps : ndarray
        The log10(T) grid (passed to dn2dem_pos as 10**logtemps).

    Raises
    ------
    ValueError
        If the response file's channel count differs from len(map_array).
    """
    nf = len(map_array)
    image_array = np.zeros(map_array[0].data.shape + (nf,))
    for k in range(nf):
        image_array[:, :, k] = map_array[k].data

    trin = io.readsav(tresp_file)
    tresp_logt = np.array(trin['logt'])
    nt = len(tresp_logt)
    n_resp = len(trin['tr'][:])
    if n_resp != nf:
        raise ValueError(f'Response file has {n_resp} channels but {nf} images were given')
    trmatrix = np.zeros((nt, n_resp))
    for i in range(n_resp):
        trmatrix[:, i] = trin['tr'][i]

    # round() keeps float error in (t_max - t_min)/t_space from dropping a grid point.
    n_edges = int(round((t_max - t_min) / t_space)) + 1
    logtemps = np.linspace(t_min, t_max, num=n_edges)
    temps = 10**logtemps
    mlogt = list((logtemps[:-1] + logtemps[1:]) / 2)
    dem, edem, elogt, chisq, dn_reg = dn2dem_pos(image_array, err_array, trmatrix, tresp_logt, temps, max_iter=max_iter)
    dem = dem.clip(min=0)
    return dem, edem, elogt, chisq, dn_reg, mlogt, logtemps

In [10]:
## Function to plot the DEM images
def plot_dem_images(submap, dem, logtemps, img_arr_tit):
    """Plot DEM maps (two temperature bins averaged per panel) and save the figure.

    Parameters
    ----------
    submap : sunpy map
        Supplies the WCS/metadata and observation time; its data shape should
        match dem.shape[:2].
    dem : ndarray
        DEM cube with temperature bins along the last axis. Panel i averages
        bins 2i and 2i+1; an odd trailing bin is not plotted.
    logtemps : sequence of float
        log10(T) bin edges; panel i is labelled logtemps[2i] - logtemps[2i+2].
    img_arr_tit : str
        Output image path.

    Raises
    ------
    ValueError
        If dem has fewer than two temperature bins.
    """
    n_panels = dem.shape[2] // 2
    if n_panels == 0:
        raise ValueError('dem needs at least two temperature bins to plot')
    nc = 3
    nr = (n_panels + nc - 1) // nc  # ceil(n_panels / nc); 3 rows for 18 bins
    plt.rcParams.update({'font.size':12, 'font.family':"sans-serif",
                         'font.sans-serif':"Arial", 'mathtext.default':"regular"})
    fig, axes = plt.subplots(nrows=nr, ncols=nc, figsize=(10, 4*nr), sharex=True, sharey=True,
                             subplot_kw=dict(projection=submap), layout='constrained', squeeze=False)
    fig.suptitle('Image time = '+dt.datetime.strftime(submap.date.datetime, "%Y-%m-%dT%H:%M:%S"))
    fig.supxlabel('Solar X (arcsec)')
    fig.supylabel('Solar Y (arcsec)')
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still available.
    cmap = plt.get_cmap('cubehelix_r')

    image = None
    for i, axi in enumerate(axes.flat):
        if i >= n_panels:
            axi.set_visible(False)
            continue
        lo, hi = 2*i, 2*i + 1
        new_dem = (dem[:, :, lo] + dem[:, :, hi])/2.
        plotmap = Map(new_dem, submap.meta)
        image = plotmap.plot(axes=axi, norm=colors.LogNorm(vmin=1e19, vmax=1e22), cmap=cmap)

        y = axi.coords[1]
        y.set_axislabel(' ')
        if i % nc != 0:  # only the first column keeps y tick labels
            y.set_ticklabel_visible(False)
        x = axi.coords[0]
        x.set_axislabel(' ')
        if i < (nr - 1)*nc:  # only the bottom row keeps x tick labels
            x.set_ticklabel_visible(False)

        axi.set_title(f'Log(T) = {logtemps[lo]:.2f} - {logtemps[hi+1]:.2f}')

    # No tight_layout(): it conflicts with layout='constrained', which already
    # makes room for the suptitle and the colorbar.
    fig.colorbar(image, ax=axes.ravel().tolist(), label=r'$\mathrm{DEM\;[cm^{-5}\;K^{-1}]}$', fraction=0.03, pad=0.02)
    fig.savefig(img_arr_tit, bbox_inches='tight')
    plt.close(fig)

---

## Main processing starts here

---

In [None]:
## Download the data
# Expected number of images per passband at the requested cadence (loop-invariant).
# NOTE(review): if the instrument cadence is coarser than `cadence` (AIA EUV is
# nominally 12 s), this count is never reached and get_data is re-run on every
# execution; fetch uses overwrite=False, so existing files are not re-downloaded.
n_img = (img_time_range[1]-img_time_range[0]).total_seconds() / cadence.to_value(u.second)

for pband in passband:
    files, file_time = get_filelist(data_disk, pband, img_file_date)
    files_in_range, _ = reduce_filerange(files, file_time, img_time_range)
    label = str(pband).rjust(4, "0")

    if n_img > len(files_in_range):
        print('Fewer than expected FITS files for '+label+' passband')
        print('Downloading data for '+label+' passband')
        get_data(start_time, end_time, img_file_date, cadence, pband, data_disk)
    else:
        print('Data already downloaded for '+label+' passband')

In [None]:
## Get list of files from each passband to identify the smallest number of files
print('Getting list of files')
# Order must match prep_images' f_0094 ... f_0335 arguments and time_array indices.
channels = (94, 131, 171, 193, 211, 335)
flist = []
time_array = []
for chan in channels:
    files_k, times_k = get_filelist(data_disk, chan, img_file_date)
    files_k, times_k = reduce_filerange(files_k, times_k, img_time_range)
    flist.append(files_k)
    time_array.append(times_k)

# Per-channel names used by the processing loop below.
f_0094, f_0131, f_0171, f_0193, f_0211, f_0335 = flist
time_0094, time_0131, time_0171, time_0193, time_0211, time_0335 = time_array

flength = [len(files_k) for files_k in flist]
missing = [str(chan) for chan, n in zip(channels, flength) if n == 0]
if missing:
    # Otherwise `closest` would later be called on an empty time array.
    raise FileNotFoundError('No FITS files in the time range for passband(s): ' + ', '.join(missing))
# The channel with the fewest files drives the processing loop.
index = np.argmin(flength)

In [None]:
# Begin image processing
start_img = closest(np.array(time_array[index][:]), dt.datetime.strptime(start_time, "%Y/%m/%d %H:%M:%S"))

for img in range(start_img, len(flist[index])):
    img_time = time_array[index][img]
    print('Processing image, time = '+dt.datetime.strftime(img_time, "%Y-%m-%dT%H:%M:%S"))
    stamp = dt.datetime.strftime(img_time, "%Y%m%d_%H%M%S")

    # Get and process images (cached on disk, keyed on whether the error file exists).
    err_arr_tit = output_dir+'error_data_'+stamp+'.asdf'
    # Intentionally NOT an f-string: MapSequence.save substitutes each map's
    # index into '{index:03}', giving one file per channel.
    map_arr_tit = output_dir+'prepped_data_'+stamp+'_{index:03}.fits'

    if not os.path.exists(err_arr_tit):
        map_array, err_array = prep_images(time_array, index, img, f_0094, f_0131, f_0171, f_0193, f_0211, f_0335, crd_cent, crd_width)
        map_array.save(map_arr_tit, overwrite=True)
        tree = {'err_array': err_array}
        with asdf.AsdfFile(tree) as asdf_file:
            asdf_file.write_to(err_arr_tit, all_array_compression='zlib')
    else:
        print('Loading previously prepped images')
        # Copy into memory so the file can be closed (avoids leaking a handle per image).
        with asdf.open(err_arr_tit) as arrs:
            err_array = np.array(arrs['err_array'])
        ffin = sorted(glob.glob(output_dir+'prepped_data_'+stamp+'*.fits'))
        map_array = Map(ffin)

    # Calculate DEMs (cached on disk as well).
    dem_arr_tit = output_dir+'dem_data_'+stamp+'.asdf'

    if not os.path.exists(dem_arr_tit):
        print('Calculating DEM')
        dem, edem, elogt, chisq, dn_reg, mlogt, logtemps = calculate_dem(map_array, err_array)
        tree = {'dem':dem, 'edem':edem, 'mlogt':mlogt, 'elogt':elogt, 'chisq':chisq, 'logtemps':logtemps}
        with asdf.AsdfFile(tree) as asdf_file:
            asdf_file.write_to(dem_arr_tit, all_array_compression='zlib')
    else:
        print('Loading previously calculated DEM')
        with asdf.open(dem_arr_tit) as arrs:
            dem = np.array(arrs['dem'])
            edem = np.array(arrs['edem'])
            mlogt = list(arrs['mlogt'])
            elogt = np.array(arrs['elogt'])
            chisq = np.array(arrs['chisq'])
            logtemps = np.array(arrs['logtemps'])

    # Use the prepped 171 A submap (channel index 2) for the plot WCS: the DEM
    # was computed from exactly these arrays, so shape and pixel grid match.
    # Re-cutting the raw 171 A file need not reproduce the registered grid.
    submap = map_array[2]
    img_arr_tit = output_dir+'DEM_images_'+stamp+'.png'
    plot_dem_images(submap, dem, logtemps, img_arr_tit)
    print('DEM plotted')