In [None]:
# Standard library
import os
import time
import warnings

# numpy may warn that "the smallest subnormal ... is zero" (presumably triggered by a compiled
# extension enabling flush-to-zero floating point). Install the filter BEFORE any third-party or
# project import so that warnings emitted while those modules are imported are silenced too.
warnings.filterwarnings(
    "ignore",
    message=r"The value of the smallest subnormal for <class 'numpy\.float(32|64)'> type is zero",
)

# Third-party and project modules (relative order kept as before)
import numpy as np
import MilneEddington as ME
import crisp
import inv_utils as iu
import me_utils as meu
from helita.io import lp

from hmi_plot import plot_hmi_ic_mag, plot_sst_pointings

In [None]:
# Development convenience: re-import the local helper modules so that edits to them are
# picked up without restarting the kernel.
import importlib
for _helper_module in (iu, meu):
    importlib.reload(_helper_module)
print('reloaded')

In [None]:
# Load the input configuration from the YAML file 'input_config.yaml'.
input_config = iu.load_yaml_config('input_config.yaml')

In [None]:
# Validate the raw input configuration with iu.check_input_config. The returned dict is what
# later cells read, including its 'fits_info' entry. pprint=True presumably prints the
# checked config and confirm=False presumably skips an interactive confirmation — confirm.
config = iu.check_input_config(input_config, pprint=True, confirm=False)

In [None]:

# Unpack the validated input parameters into the module-level names used by later cells.
# Note that the config key 'time_index' is exposed as `tt`.
_input_keys = (
    'data_dir', 'save_dir', 'crisp_im',
    'xorg', 'xsize', 'yorg', 'ysize', 'xrange', 'yrange',
    'time_index', 'scale', 'is_north_up', 'crop', 'shape',
    'best_frame', 'contrasts',
    'hmi_con_series', 'hmi_mag_series', 'email', 'fov_angle',
    'plot_sst_pointings_flag', 'plot_hmi_ic_mag_flag', 'plot_crisp_image_flag',
)
(data_dir, save_dir, crisp_im,
 xorg, xsize, yorg, ysize, xrange, yrange,
 tt, scale, is_north_up, crop, shape,
 best_frame, contrasts,
 hmi_con_series, hmi_mag_series, email, fov_angle,
 plot_sst_pointings_flag, plot_hmi_ic_mag_flag, plot_crisp_image_flag) = (
    config[key] for key in _input_keys
)

In [None]:
# Pull the observation metadata (taken from the FITS header) out of the validated config.
fits_info = config['fits_info']
nx, ny, mu = fits_info['nx'], fits_info['ny'], fits_info['mu']
hplnt = fits_info['hplnt']
hpltt = fits_info['hpltt']
# Coordinate bounds of the selected frame `tt` (presumably helioprojective x/y ranges — confirm)
x1, x2 = hplnt[tt][0], hplnt[tt][1]
y1, y2 = hpltt[tt][0], hpltt[tt][1]
tobs = fits_info['all_start_times'][tt]
tstart = fits_info['start_time_obs']
tend = fits_info['end_time_obs']

In [None]:
# Reset the x and y ranges if cropping is enabled.
# The full-frame coordinates are rebuilt from `fits_info`/`hplnt`/`hpltt` rather than from
# x1/x2/nx/ny, because this cell overwrites those names. Re-running the cell therefore
# gives the same result instead of cropping an already-cropped range.
if crop:
    x_full = np.linspace(hplnt[tt][0], hplnt[tt][1], num=fits_info['nx'])
    y_full = np.linspace(hpltt[tt][0], hpltt[tt][1], num=fits_info['ny'])
    x_list = x_full[xrange[0]:xrange[1]]
    y_list = y_full[yrange[0]:yrange[1]]
    x1, x2 = x_list[0], x_list[-1]
    y1, y2 = y_list[0], y_list[-1]
    nx = xsize
    ny = ysize

In [None]:
# Optional overview of the SST pointings for all frames (hplnt/hpltt) using the HMI continuum series.
if plot_sst_pointings_flag:
    plot_sst_pointings(
        tstart, hmi_con_series, hplnt, hpltt,
        figsize=(6, 6),
        email=email,
        save_dir=save_dir,
    )

In [None]:
# Optional HMI continuum / magnetogram context plot for the FOV bounds x1..x2, y1..y2.
if plot_hmi_ic_mag_flag:
    plot_hmi_ic_mag(
        tobs, hmi_con_series, hmi_mag_series, email,
        x1, x2, y1, y2,
        save_dir=save_dir,
        figsize=(10, 5),
        is_north_up=is_north_up,
        fov_angle=fov_angle,
        shape=shape,
    )

In [None]:
# Optional quick-look of frame `tt` of the CRISP cube (ss=0, ww=0 presumably select the
# Stokes component and wavelength point — confirm in inv_utils.plot_crisp_image).
if plot_crisp_image_flag:
    print('SST CRISP image with North up:', not is_north_up)
    iu.plot_crisp_image(
        crisp_im, tt=tt, ss=0, ww=0,
        figsize=(6, 6), fontsize=10,
        rot_fov=fov_angle, north_up=not is_north_up,
        crop=crop, xrange=xrange, yrange=yrange,
        xtick_range=[x1, x2], ytick_range=[y1, y2],
    )

In [None]:
# Load the inversion settings and unpack them into the module-level names used by later cells.
inversion_config = iu.load_yaml_config('inversion_config.yaml')
_inversion_keys = (
    # solver setup / noise model / initial guess
    'dtype', 'nthreads', 'sigma_strength', 'sigma_list', 'erh', 'init_model_params',
    # first randomised inversion
    'nRandom1', 'nIter1', 'chi2_thres1',
    # chi2-driven median filtering
    'median_filter_chi2_mean_thres', 'median_filter_size',
    # second randomised inversion
    'nRandom2', 'nIter2', 'chi2_thres2',
    # spatially regularised inversion
    'nIter3', 'chi2_thres3', 'alpha_strength', 'alpha_list',
    # NOTE(review): nan_mask_replacements is not used by any later cell in this notebook
    'nan_mask_replacements', 'verbose',
)
(dtype, nthreads, sigma_strength, sigma_list, erh, init_model_params,
 nRandom1, nIter1, chi2_thres1,
 median_filter_chi2_mean_thres, median_filter_size,
 nRandom2, nIter2, chi2_thres2,
 nIter3, chi2_thres3, alpha_strength, alpha_list,
 nan_mask_replacements, verbose) = (inversion_config[key] for key in _inversion_keys)

In [None]:
# Load frame `tt` of the CRISP cube, cropped to xrange/yrange when `crop` is set.
# Later cells use the returned object's `.mask` (pixel selection) and `.cmap` (cavity-error map).
ll = meu.load_crisp_frame(crisp_im, tt, crop=crop, xrange=xrange, yrange=yrange)

In [None]:
# Build the Milne-Eddington inputs from the loaded frame: `obs` and `sig` (presumably the observed
# Stokes data and its noise estimate), `l0` (used later for the cavity-error velocity correction)
# and the solver object `me`. erh/dtype/nthreads come from inversion_config.yaml.
obs, sig, l0, me = meu.init_me_model(ll, sigma_strength, sigma_list, erh=erh, dtype=dtype, nthreads=nthreads)

In [None]:
# Initial model guess for the (possibly cropped) ny x nx frame, built from init_model_params.
# Presumably shaped (ny, nx, 9) — one entry per ME model parameter; confirm in me_utils.
Imodel = meu.init_model(me, ny, nx, init_model_params=init_model_params, dtype=dtype)

In [None]:
# First randomised Milne-Eddington inversion pass, starting from Imodel.
Imodel, syn, chi2 = meu.run_randomised_me_inversion(
    Imodel, me, obs, sig,
    nRandom=nRandom1, nIter=nIter1, chi2_thres=chi2_thres1,
    mu=mu, verbose=verbose,
)
# Mean chi2 over the pixels selected by ll.mask; the next filtering cell passes it to
# meu.apply_median_filter_based_on_chi2.
masked_chi2_mean = iu.masked_mean(chi2, ll.mask)
if verbose:
    print(f'Masked chi2 mean: {masked_chi2_mean:.2f}')
    for _plot_fn in (iu.plot_inversion_output, iu.plot_mag):
        _plot_fn(Imodel, ll.mask, scale=scale, save_fig=False)

In [None]:
# Development convenience: reload me_utils to pick up edits without restarting the kernel.
importlib.reload(meu)

In [None]:
# Median-filter the model based on the masked chi2 mean and median_filter_chi2_mean_thres
# (the decision rule lives in meu.apply_median_filter_based_on_chi2).
Imodel = meu.apply_median_filter_based_on_chi2(
    Imodel,
    masked_chi2_mean,
    median_filter_chi2_mean_thres,
    median_filter_size,
)
if verbose:
    for _plot_fn in (iu.plot_inversion_output, iu.plot_mag):
        _plot_fn(Imodel, ll.mask, scale=scale, save_fig=False)

In [None]:
# Second randomised inversion pass, seeded with the (possibly median-filtered) model.
Imodel, syn, chi2 = meu.run_randomised_me_inversion(
    Imodel, me, obs, sig,
    nRandom=nRandom2, nIter=nIter2, chi2_thres=chi2_thres2,
    mu=mu, verbose=verbose,
)
masked_chi2_mean = iu.masked_mean(chi2, ll.mask)
if verbose:
    print(f'Masked chi2 mean: {masked_chi2_mean:.2f}')
    for _plot_fn in (iu.plot_inversion_output, iu.plot_mag):
        _plot_fn(Imodel, ll.mask, scale=scale, save_fig=False)

In [None]:
# Spatially regularised inversion seeded with Imodel from the randomised passes.
# method / delay_bracket are passed through to the regularised solver (semantics defined in
# me_utils / MilneEddington). `verbose` follows the config, like the randomised-inversion cells.
mo, syn, chi2 = meu.run_spatially_regularized_inversion(
    me, Imodel, obs, sig, nIter3, chi2_thres3, mu, alpha_strength, alpha_list,
    method=1, delay_bracket=3, dtype=dtype, verbose=verbose,
)


In [None]:
# Estimate per-parameter uncertainties for the regularised model (singleton axes squeezed out).
# Note: this uses `mo` as returned by the inversion, i.e. before the cavity-error velocity
# correction that produces `corrected_mo`.
errors = me.estimate_uncertainties(np.squeeze(mo), obs, sig, mu=mu)

In [None]:
# Apply meu.correct_velocities_for_cavity_error using the frame's cavity-error map (ll.cmap) and
# `l0` from init_me_model; global_offset=0.0 adds no extra constant shift (presumably).
corrected_mo = meu.correct_velocities_for_cavity_error(mo, ll.cmap, l0, global_offset=0.0)

In [None]:
if verbose:
    # `masked_chi2_mean` was last computed after the second randomised inversion; the
    # regularised inversion does not update it, so the printout says which stage it belongs to.
    print(f'Masked chi2 mean (2nd randomised inversion): {masked_chi2_mean:.2f}')
    iu.plot_inversion_output(corrected_mo, ll.mask, scale=scale, save_fig=False)
    iu.plot_mag(corrected_mo, ll.mask, scale=scale, save_fig=False)

---

In [None]:
# NOTE(review): `blos_cube` and `bhor_cube` are not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel. TODO: build the Blos/Bhor arrays (presumably cubes over
# time, given the `tt` argument) before calling this plot.
iu.plot_sst_blos_bhor(blos_cube, bhor_cube, tt=tt,xrange=xrange, yrange=yrange, figsize=(20,10), fontsize=12, crop=crop)

In [None]:
# Development convenience: reload inv_utils to pick up edits without restarting the kernel.
importlib.reload(iu)

In [None]:
from einops import rearrange
# Parameter-first views of the final model and its uncertainties:
# (ny, nx, nparams) -> (nparams, ny, nx). The velocity-corrected model is squeezed the same way
# as the model passed to estimate_uncertainties, so both presumably have matching shapes.
mos_im = rearrange(np.squeeze(corrected_mo), 'ny nx nparams -> nparams ny nx')
errors_im = rearrange(errors, 'ny nx nparams -> nparams ny nx')

In [None]:
# Masked statistics of every parameter of the final (velocity-corrected) model.
final_model = np.squeeze(corrected_mo)  # squeezed like the model given to estimate_uncertainties
for i in range(final_model.shape[-1]):
    iu.masked_stats(final_model[:, :, i], ll.mask)

In [None]:
# Value written into masked-out pixels, one entry per model parameter (order in the trailing comment).
# NOTE(review): index 0 is labelled Blos below, but it is presumably the total field strength |B|
# (a later cell uses errors[:,:,0] as `b_err`) — confirm.
# NOTE(review): `nan_mask_replacements` is read from inversion_config.yaml but never used; it may
# be meant to replace this hard-coded list — confirm.
inversion_mask_replacements = [0, 0, 0, 0, 0, 0, 0, 0, 0] # Blos, inc, azi, v_los, v_dop, line op, damping, s0, s1

In [None]:
# Final model with masked-out pixels replaced by the per-parameter value from
# inversion_mask_replacements.
final_model = np.squeeze(corrected_mo)
masked_mos = np.zeros_like(final_model)
for i in range(final_model.shape[-1]):
    masked_mos[:, :, i] = iu.masked_data(
        final_model[:, :, i], ll.mask, replace_val=inversion_mask_replacements[i]
    )


In [None]:
# Plot the masked final model. Unlike other calls, no mask argument is passed here
# (presumably optional in iu.plot_inversion_output, since masked pixels are already replaced).
iu.plot_inversion_output(masked_mos,scale=scale, save_fig=False)

In [None]:
# Mask the per-parameter uncertainties; fix_inf=True asks iu.masked_data to handle infinite
# values (presumably replacing them) so the maps can be plotted.
masked_errors = np.zeros_like(errors)
for i in range(errors.shape[-1]):
    masked_errors[:, :, i] = iu.masked_data(
        errors[:, :, i], ll.mask, replace_val=inversion_mask_replacements[i], fix_inf=True
    )
iu.plot_inversion_output(masked_errors, scale=scale, save_fig=False)

In [None]:
# Masked statistics of every parameter's uncertainty.
for i in range(errors.shape[-1]):
    iu.masked_stats(errors[:, :, i], ll.mask)

In [None]:
# NaN-aware mean / min / max of the masked uncertainty of parameter 0 (field strength).
b_err = iu.masked_data(errors[:, :, 0], ll.mask)
print(np.nanmean(b_err))
print(np.nanmin(b_err))
print(np.nanmax(b_err))

In [None]:
importlib.reload(iu)
# Index 1 of the uncertainty array is the inclination (parameter order listed next to
# inversion_mask_replacements), so the panel is titled accordingly rather than 'B_tot (G)'.
iu.plot_image(masked_errors[:, :, 1], scale=scale, title='Inclination error', save_fig=False, clip=True, vmax=1, vmin=0)

In [None]:
# Apply masked_data (default replacement value) to every component of the uncertainties.
# NOTE(review): this overwrites the `masked_errors` built earlier with fix_inf=True and the
# per-parameter replacement values — confirm which version later cells should use.
masked_errors = np.zeros(errors.shape, dtype=dtype)
for i in range(errors.shape[-1]):
    masked_errors[:, :, i] = iu.masked_data(errors[:, :, i], ll.mask)

In [None]:
importlib.reload(iu)
# Uncertainty of parameter 1 (inclination) with masked-out pixels replaced by 0.
# NOTE(review): those zeros are included in the min and median below.
minc = iu.masked_data(errors[:, :, 1], ll.mask, replace_val=0)
print(np.min(minc))
print(np.max(minc))
print(np.median(minc))

In [None]:
# Plot the masked uncertainty maps.
iu.plot_inversion_output(masked_errors, ll.mask, scale=scale, save_fig=False)

In [None]:
# Overview plots of the final model.
# NOTE(review): `iu.plot_output` is not used elsewhere in this notebook (other cells call
# `iu.plot_inversion_output`) — confirm it exists in inv_utils.
# NOTE(review): plot_mag is given `Imodel` (the pre-regularisation model), not the final model —
# confirm this is intended.
iu.plot_output(np.squeeze(corrected_mo), ll.mask, scale=scale)
iu.plot_mag(Imodel, ll.mask, scale=scale, save_fig=False)

In [None]:
## save the results as fits files with the same header as the input data
# NOTE(review): `fits_header` is not defined anywhere in this notebook — TODO: obtain the input
# data's header before running this cell.
os.makedirs('temp', exist_ok=True)  # the output directory must exist before writing
iu.save_fits(np.squeeze(corrected_mo), fits_header, 'temp/inv_mos.fits', overwrite=True)

In [None]:
# Read back the saved inversion result to check the written data.
ff = iu.load_fits_data('temp/inv_mos.fits')

In [None]:
# Read back the header of the saved inversion result to check what was written.
hh = iu.load_fits_header('temp/inv_mos.fits')

#### Things to complete
- [x] Move all the inputs into a dictionary and later save them in the header of the output file. Also add the best-seeing frame number.
- [x] Make the preprocessing steps (plotting, FOV details) optional, enabled by default.
- [x] Plot a rectangle showing the cropping region when cropping is enabled.
- [ ] Save a FITS file with [blos, theta, phi, vlos + errors + mask] for each frame (temporarily) and later combine them into the final FITS file.
- [ ] Check for an option to convert the output to fcube and icube formats using the ispy or helita tools.
- [ ] Add an option to invert only a single frame if the user wants.
- [ ] Add the FOV angle and other inputs needed for ambiguity resolution and remapping to the header.

#### To do for final cube
- [x] Pick the best seeing frame from the dataset
- [ ] Run the full inversion for the best seeing frame
- [ ] Use this output as an initial guess for the other frames
