In [None]:
import os
import json

import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt

from scipy.interpolate import splev,splrep,interp1d

from sunkit_dem import Model
from sunkit_dem.models import hk12

from sunkit_dem.util import quantity_1d_to_sequence

from astropy.coordinates import SkyCoord
from astropy.nddata import StdDevUncertainty
from astropy.wcs import WCS
from astropy.visualization import ImageNormalize, LogStretch
from sunpy.map import Map
from sunpy.net import Fido,attrs

import ndcube

%matplotlib inline

In [None]:
# This notebook was written against ndcube 2.0.0; fail early with a clear
# message instead of hitting API differences further down.
assert ndcube.__version__.startswith('2.'), f"ndcube 2.x required, found {ndcube.__version__}"
ndcube.__version__

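## Get AIA Data
We download the AIA EUV channels `[94, 131, 171, 193, 211, 304, 335]` Å and then drop the `304` Å channel, which is commonly excluded from DEM inversions because its He II emission is optically thick.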

In [None]:
# Search for AIA images with wavelengths between 94 and 335 Angstrom in a 10 s
# window; combining the attrs with `&` is equivalent to passing them separately.
q = Fido.search(
    attrs.Time('2012/02/11T15:00:00', end='2012/02/11T15:00:10')
    & attrs.Instrument('AIA')
    & attrs.Wavelength(94 * u.angstrom, 335 * u.angstrom)
)

In [None]:
# Download the search results. Files already on disk are reused rather than
# re-downloaded (overwrite=False), which keeps "Restart & Run All" cheap.
files = Fido.fetch(q, overwrite=False)
# Later cells assume every channel was downloaded, so retry failed downloads once.
if files.errors:
    files = Fido.fetch(files)

In [None]:
# Load every downloaded file as a Map and order the channels by increasing wavelength.
maps = sorted((Map(path) for path in files), key=lambda smap: smap.wavelength)

In [None]:
# Drop the 304 Angstrom channel by wavelength rather than by list position, so the
# cell is idempotent (re-running it cannot remove a different channel) and does
# not depend on exactly which channels were downloaded.
maps = [m for m in maps if not np.isclose(m.wavelength.to_value(u.angstrom), 304)]

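## TODO: AIA prep with `aiapy`
The images should be processed to level 1.5 with `aiapy` (e.g. registration and degradation correction); for now they are only normalized by exposure time.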

In [None]:
# For now, just divide by exposure time (a stand-in for proper AIA prep).
# The copied header's exposure time is set to 1 s to match the normalized data,
# so re-running this cell divides by 1 instead of normalizing a second time.
maps = [
    Map(m.data / m.exposure_time.to_value(u.s), {**m.meta, 'exptime': 1.0})
    for m in maps
]

In [None]:
# Crop every channel to the same region of interest and downsample it to 50x50 pixels.
# NOTE(review): resampling may only be needed because aia_prep is not applied yet
# (channels not on a common pixel grid) — confirm once aiapy prep is added.
# bottom_left is the corner with the smaller Tx and Ty, as the argument name says.
maps = [
    m.submap(
        bottom_left=SkyCoord(-175 * u.arcsec, -325 * u.arcsec, frame=m.coordinate_frame),
        top_right=SkyCoord(75 * u.arcsec, -75 * u.arcsec, frame=m.coordinate_frame),
    ).resample([50, 50] * u.pixel)
    for m in maps
]

In [None]:
# Array shape of each cropped and resampled channel.
[smap.data.shape for smap in maps]

In [None]:
# Plot one channel (index 2 after sorting and dropping 304 — presumably 171 A)
# on WCS-aware axes. `Figure.gca(projection=...)` was deprecated in Matplotlib 3.4
# and removed in 3.6; `add_subplot(projection=...)` is the supported replacement.
fig = plt.figure()
ax = fig.add_subplot(projection=maps[2])
maps[2].plot(axes=ax);

In [None]:
def _map_to_cube(smap, relative_uncertainty=0.2):
    """Wrap a 2D map as a 3D NDCube with a length-1 wavelength axis in front.

    The map's WCS header is extended with a third 'WAVE' axis whose single value
    is the map's wavelength in Angstrom, and the data gain a leading axis of
    length 1 so the cubes can be stacked along wavelength.

    Parameters
    ----------
    smap : sunpy.map.GenericMap
        Exposure-normalized image; its data are labelled 'ct / pixel / s'.
    relative_uncertainty : float
        Standard deviation assigned to each pixel as a fraction of its value.

    Returns
    -------
    ndcube.NDCube
    """
    header = smap.wcs.to_header()
    # Add wavelength as a third WCS axis.
    header['CTYPE3'] = 'WAVE'
    header['CUNIT3'] = u.angstrom.to_string()
    header['CDELT3'] = 1
    header['CRPIX3'] = 1
    header['CRVAL3'] = smap.wavelength.to(u.angstrom).value
    header['NAXIS3'] = 1
    # Set the image size explicitly (NAXIS1 is x/columns, NAXIS2 is y/rows).
    header['NAXIS1'] = smap.data.shape[1]
    header['NAXIS2'] = smap.data.shape[0]

    data = u.Quantity(smap.data[np.newaxis, :, :], 'ct / pixel / s')
    # A bare ndarray would be stored by NDData as UnknownUncertainty; state explicitly
    # that this is a standard deviation.
    uncertainty = StdDevUncertainty(relative_uncertainty * data.value)
    return ndcube.NDCube(data, WCS(header), meta=smap.meta, uncertainty=uncertainty)


cubes = [_map_to_cube(m) for m in maps]
# Stack the single-wavelength cubes along array axis 0 (the wavelength axis).
seq = ndcube.NDCubeSequence(cubes, common_axis=0)

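## Generate the response functions
The AIA temperature responses are read from a precomputed JSON file and interpolated onto the temperature bin centers; a better, Python-native way to obtain them is still needed.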

In [None]:
# Temperature grid: bin edges at log10(T / K) = 5.5, 5.6, ... (step 0.1, stopping
# before 7). Bin centers are the geometric means of neighbouring edges, i.e. the
# midpoints in log T.
temperature_bin_edges = 10**np.arange(5.5, 7, 0.1) * u.K
temperature_bin_centers = np.sqrt(temperature_bin_edges[1:] * temperature_bin_edges[:-1])

channels = [int(m.wavelength.to(u.angstrom).value) for m in maps]

# Precomputed AIA temperature responses. Set SDO_AIA_RESPONSE_FILE to point at the
# file instead of hard-coding a path that only exists on one machine.
RESPONSE_FILE = os.environ.get('SDO_AIA_RESPONSE_FILE', 'sdo_aia.json')
with open(RESPONSE_FILE) as json_file:
    response_data = json.load(json_file)

# Interpolate each channel's tabulated response onto the bin centers with a spline.
response = {}
for c in channels:
    channel_data = response_data[f'{c}']
    tck = splrep(channel_data['temperature_response_x'],
                 channel_data['temperature_response_y'])
    response[c] = u.Quantity(splev(temperature_bin_centers.value, tck),
                             channel_data['temperature_response_y_units'])

In [None]:
# Display the temperature bin edges (a Quantity in K) used for the DEM.
temperature_bin_edges

## Inspecting a single NDCube

In [None]:
# Array shape of the first cube in the sequence (its leading axis is the length-1 wavelength axis).
seq[0].data.shape

In [None]:
# Uncertainty attached to the first cube (20% of the data when the cubes were built).
seq[0].uncertainty

In [None]:
# WCS, extra coordinates and physical unit of the first cube.
seq[0].wcs, seq[0].extra_coords, seq[0].unit

## NDCubeSequence

In [None]:
# Dimensions of the full sequence.
seq.dimensions

In [None]:
# Physical types associated with each array axis of the sequence.
seq.array_axis_physical_types

In [None]:
# Dimensions of the sequence viewed as a single cube joined along its common axis (axis 0).
seq.cube_like_dimensions

## Define and fit the model

In [None]:
hk12 = Model(seq, response, temperature_bin_edges, model='hk12')

In [None]:
dem_2d = hk12.fit()

## NDCollection

In [None]:
# Names of the products returned by the DEM fit.
dem_2d.keys()

In [None]:
# `import sunpy` alone does not guarantee the `sunpy.map` subpackage is loaded;
# import it explicitly because later cells call `sunpy.map.Map`.
import sunpy.map

In [None]:
# Map of the DEM in the first temperature bin. Elsewhere in this notebook the DEM
# array is indexed as (temperature, y, x), so the temperature index comes first;
# the image geometry is taken from the channel used for the pixel grid.
my_map = Map(dem_2d['dem'].data[0, :, :], maps[2].meta)

In [None]:
# Plot the DEM in every temperature bin on a grid of WCS-aware axes.
dem_cube = dem_2d['dem']
n_bins = dem_cube.data.shape[0]
n_cols = 4
n_rows = int(np.ceil(n_bins / n_cols))
# NOTE(review): the titles label this value log T — confirm the model's
# temperature axis is stored as log10(T) and not in K.
temperatures = dem_cube.axis_world_coords(0)[0]

fig = plt.figure(figsize=(5 * n_cols, 5 * n_rows))
for i in range(n_bins):
    mp = Map(dem_cube.data[i, :, :], maps[2].meta)
    ax = fig.add_subplot(n_rows, n_cols, i + 1, projection=mp)
    mp.plot(
        axes=ax,
        norm=ImageNormalize(vmin=1e20, vmax=1e22, stretch=LogStretch()),
        cmap='inferno',
    )
    # Raw f-string: "\l" in a regular string literal is an invalid escape sequence.
    ax.set_title(rf"$\log{{T}} = {temperatures[i]:.2f}$")


In [None]:
# Pixel nearest to the point of interest (Tx, Ty) = (40", -180").
# world_to_pixel returns (x, y) = (column, row), while the cells below index the
# DEM as data[:, pix_lat, pix_lon] with (row, column) last. So pix_lat must hold
# the row (y) and pix_lon the column (x).
# Pixel centres sit at integer pixel coordinates, so round to the nearest pixel.
pix_x, pix_y = maps[2].world_to_pixel(
    SkyCoord(Tx=40 * u.arcsec, Ty=-180 * u.arcsec, frame=maps[2].coordinate_frame)
)
pix_lat = int(np.round(pix_y.value))
pix_lon = int(np.round(pix_x.value))

In [None]:
# DEM versus temperature-bin index at the selected pixel, on explicit axes.
fig, ax = plt.subplots()
ax.plot(dem_2d['dem'].data[:, pix_lat, pix_lon])
ax.set_yscale('log')
ax.set_ylim(1e19, 1e21)
ax.set_xlabel('Temperature bin index')
ax.set_ylabel('DEM')
ax.set_title(f'DEM at pixel [{pix_lat}, {pix_lon}]');

In [None]:
# Repeats the listing of fit products shown earlier in the notebook.
dem_2d.keys()

In [None]:
# Same per-pixel DEM, this time using the NDCube's built-in plot method.
dem_2d['dem'][:, pix_lat, pix_lon].plot()
plt.yscale('log')
plt.ylim(bottom=1e19, top=1e21)

In [None]:
# Uncertainty of the DEM at the selected pixel.
dem_2d['dem'][:,pix_lat,pix_lon].uncertainty