# Welcome to 511 keV imaging with COSIpy-classic
In this notebook, we'll use a Richardson-Lucy deconvolution algorithm to image 511 keV emission from the center of the Milky Way Galaxy. This analysis requires significant computer memory (>50 GB), so you may need a machine with more memory than a typical laptop or desktop. Please refer to the README for additional information on each step of the analysis.

## Import packages
We will need to import the cosipy-classic functions from COSIpy_dc1.py, response_dc1, and COSIpy_tools_dc1, as well as some other standard Python packages. [Note: this version of the notebook requires recent versions of Numba and Jaxopt - JDB]

In [None]:
from COSIpy_dc1 import *
import response_dc1
from COSIpy_tools_dc1 import *
from numba import jit, njit, prange, set_num_threads
from tqdm.autonotebook import tqdm

import pickle

# set parallelism for whole notebook
set_num_threads(8)


### For the modified RL algorithm implemented here, we need to define a jaxopt objective function that fits background plus two images (the current image plus a delta image given by the RL formalism)
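
As a rough illustration of what this objective computes, here is a minimal NumPy sketch of a negative log-posterior for a single background amplitude plus two sky components. The helper names (`poisson_logpmf`, `norm_logpdf`, `neg_log_posterior`) and all toy numbers are assumptions made for this example; the notebook itself uses the JAX version defined in the next cell.

```python
import math
import numpy as np

_lgamma = np.vectorize(math.lgamma)

def poisson_logpmf(y, mu):
    # log P(y | mu) for Poisson counts, elementwise (assumes mu > 0)
    return y * np.log(mu) - mu - _lgamma(y + 1.0)

def norm_logpdf(x, mu, sigma):
    # log of the normal density, elementwise
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

def neg_log_posterior(Abg, flux, conv_sky, bg_model, y,
                      mu_flux, sigma_flux, mu_Abg, sigma_Abg):
    # model in data space: scaled background + flux[0] * current image + flux[1] * delta image
    model = Abg * bg_model + np.sum(flux[:, None, None] * conv_sky, axis=0)
    lp = (np.sum(poisson_logpmf(y, model))
          + np.sum(norm_logpdf(flux, mu_flux, sigma_flux))
          + np.sum(norm_logpdf(Abg, mu_Abg, sigma_Abg)))
    return -lp  # minimising this maximises the posterior

# toy data space with 2 time bins x 3 CDS bins (hypothetical values)
bg_model = np.full((2, 3), 4.0)
conv_sky = np.stack([np.full((2, 3), 2.0), np.full((2, 3), 1.0)])
y = np.full((2, 3), 8.0)  # equals 1*4 + 1*2 + 2*1 in every bin
priors = dict(mu_flux=np.array([1.0, 2.0]), sigma_flux=np.array([0.01, 4.0]),
              mu_Abg=np.array([1.0]), sigma_Abg=np.array([1.0]))

nll_true = neg_log_posterior(np.array([1.0]), np.array([1.0, 2.0]),
                             conv_sky, bg_model, y, **priors)
nll_off = neg_log_posterior(np.array([2.0]), np.array([1.0, 2.0]),
                            conv_sky, bg_model, y, **priors)
print(nll_true < nll_off)
```

The narrow prior on the first flux component keeps the current image close to its present normalisation, while the wide prior on the second lets the fit choose how much of the delta image to add.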


In [None]:
import jax  # jax.config is an attribute of the jax package, so import the package itself
import jax.numpy as jnp
import jax.scipy.stats as jstats
import jaxopt

# to better match Stan's behavior
jax.config.update("jax_enable_x64", True)

# objective function for MLE
def objective(params, data):
        (Abg, flux) = params
        (conv_sky, bg_model, bg_idx_arr, y, mu_flux, sigma_flux, mu_Abg, sigma_Abg) = data
 
        M = Abg[bg_idx_arr[:,None]] * bg_model + jnp.sum(flux[:,None,None] * conv_sky, axis=0)

        # ensure that we don't accidentally use negative Poisson means, which would blow up the likelihood
        M = jnp.maximum(M, 0)
        
        lp = jnp.sum(jstats.poisson.logpmf(y, M), axis=None) + \
             jnp.sum(jstats.norm.logpdf(flux, mu_flux, sigma_flux)) + \
             jnp.sum(jstats.norm.logpdf(Abg, mu_Abg, sigma_Abg))
        
        return -lp  # minimize to maximize LL

#opt = { 'disp': True }
optimizer = jaxopt.ScipyBoundedMinimize(fun=objective, method="l-bfgs-b", tol=1e-10)#, options=opt)

## Define file names
This file contains the 10X flux 511 keV simulation and Ling BG. 

You can optionally image only 511 keV (without background) by changing this file to the 511 keV-only simulation. You will have to adjust the RL algorithm parameters later in the notebook.

In [None]:
data_dir = '../data_products' # directory containing data & response files
filename = 'GC511_10xFlux_and_Ling.inc1.id1.extracted.tra.gz'# 511 keV with Ling BG
response_filename = data_dir + '/511keV_imaging_response.npz' # detector response
background_filename = data_dir + '/Scaled_Ling_BG_1x.npz' # background response
background_mode = 'from file'

pklfname = "511keV.pkl"

## Read simulation and define analysis object
Read in the data set and create the main cosipy-classic "analysis1" object, which provides various functionalities to study the specified file. This cell usually takes a few minutes to run.

In [None]:
try:
    # reuse the cached dataset if a previous run pickled it
    with open(pklfname, 'rb') as f:
        analysis1 = pickle.load(f)
    
except Exception:
    print("loading analysis dataset")
    analysis1 = COSIpy(data_dir, filename)
    analysis1.read_COSI_DataSet()
    with open(pklfname, 'wb') as f:
        pickle.dump(analysis1, f)
        

# Bin the data
Calling "get_binned_data()" may take several minutes, depending on the size of the dataset and the number of bins. Keep an eye on memory here: if your time bins are very small, for example, this could be an expensive operation.

As currently written, "get_binned_data()" uses about **4 GB memory**. [But this was reduced by at least half by storing the bin counts as ints, and by replacing lists with arrays - JDB.]
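
To get a feel for how bin counts and storage type drive memory use, here is a small back-of-the-envelope sketch. The bin numbers and the choice of 32-bit integers are assumptions made for illustration only, not values read from the dataset or the library.

```python
import numpy as np

# hypothetical bin counts, chosen only for illustration
n_time, n_energy, n_phi, n_fisbel = 100, 1, 30, 1145
shape = (n_time, n_energy, n_phi, n_fisbel)
n_bins = int(np.prod(shape))

# storage needed for one dense array of counts
bytes_float64 = n_bins * np.dtype(np.float64).itemsize
bytes_int32 = n_bins * np.dtype(np.int32).itemsize
print(f"float64: {bytes_float64 / 1e6:.1f} MB, int32: {bytes_int32 / 1e6:.1f} MB")

# halving the time bin size doubles the number of time bins, and so the memory
bytes_int32_fine = 2 * n_time * n_energy * n_phi * n_fisbel * np.dtype(np.int32).itemsize
```

The same scaling applies to any intermediate arrays built during binning, which is why very small time bins become expensive.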

In [None]:
# Define the bin sizes
Delta_T = 1800 # time bin size in seconds
energy_bin_edges = np.array([501, 521]) # as defined in the response
pixel_size = 6. # as defined in the response

analysis1.dataset.time_binning_tags_fast(time_bin_size=Delta_T)
analysis1.dataset.init_binning(energy_bin_edges=energy_bin_edges, pixel_size=pixel_size) # initiate the binning
analysis1.dataset.get_binned_data_fast() # bin data

## Examine the shape of the binned data.
The binned data are contained in "analysis1.dataset.binned_data." This is a 4-dimensional array over (time, energy, $\phi$, FISBEL). Because each FISBEL bin encodes two angles (the direction of the scattered photon), the array represents the 5 dimensions of the Compton data space.

The number of bins in each dimension is shown by calling "shape."

In [None]:
print("time, energy, phi, fisbel")
print(analysis1.dataset.binned_data.shape)

In [None]:
# Can print the width of each time bin and the total time
print(analysis1.dataset.times.times_wid)
print(analysis1.dataset.times.total_time)

## Plot raw spectrum & light curve
For a single energy bin, the spectrum is necessarily a top hat in the sole non-zero bin.

In [None]:
analysis1.dataset.plot_lightcurve()

analysis1.dataset.plot_raw_spectrum()
plt.xscale('log')

# Define the pointing object with the cosipy pointing class.
This may also take several minutes to run.  [Down to a few seconds after acceleration, even without parallelism - JDB]

In [None]:
# definition of pointings (balloon stability + Earth rotation)
pointing1 = Pointing(dataset=analysis1.dataset,)

# Define the BG model

In [None]:
# Ling BG simulation to model atmospheric background
background1 = BG(dataset=analysis1.dataset,mode=background_mode,filename=background_filename)

# Read in the Response Matrix

In [None]:
# 511 keV response
rsp = response_dc1.SkyResponse(filename=response_filename,pixel_size=pixel_size) # read in detector response

## Exploring the shape of the data space
The shape of the response spans (Galactic latitude $b$, Galactic longitude $\ell$, Compton scattering angle $\phi$,  FISBEL, energy). There is 1 energy bin for the 511 keV response ("analysis1.dataset.energies.n_energy_bins"). This is why there is no fifth dimension for the energy printed below. The data and background objects span (time, Compton scattering angle, FISBEL).
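
The later cells flatten the ($\phi$, FISBEL) axes into a single Compton-data-space (CDS) axis and keep only the bins that are actually used. A toy sketch of that reshaping follows; the array sizes and the `counts` array used to build `nonzero_idx` are made up for this example (in the notebook the index comes from the background object).

```python
import numpy as np

n_b, n_l, n_phi, n_fisbel, n_energy = 3, 4, 2, 5, 1   # toy sizes
rng = np.random.default_rng(0)
response = rng.random((n_b, n_l, n_phi, n_fisbel, n_energy))

# toy counts per (phi, FISBEL) bin; only bins with data are kept
counts = np.array([[0, 3, 0, 1, 0],
                   [2, 0, 0, 0, 5]])
nonzero_idx = np.flatnonzero(counts.ravel())   # flat CDS indices 1, 3, 5, 9

# merge (phi, FISBEL) into one CDS axis, then select used bins and the energy bin
ebin = 0
response_cds = response.reshape(n_b, n_l, n_phi * n_fisbel, n_energy)[:, :, nonzero_idx, ebin]
print(response_cds.shape)
```

Flat CDS index `k` corresponds to `phi = k // n_fisbel` and `fisbel = k % n_fisbel`, because the reshape uses NumPy's default row-major order.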

In [None]:
rsp.rsp.response_grid_normed_efinal.shape

In [None]:
np.shape(analysis1.dataset.binned_data)

In [None]:
np.shape(background1.bg_model)

# Imaging Setup

## Define a grid on the sky to make images
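
A quick sanity check for a regular longitude/latitude grid is that the pixel solid angles add up to the full sphere, $4\pi$ sr. The sketch below rebuilds the solid-angle array for a 6-degree grid on its own; the variable names mirror those in the next cell but the block is independent of it.

```python
import numpy as np

pixel_size = 6.0
n_l, n_b = int(360 / pixel_size), int(180 / pixel_size)
b_edges = np.linspace(-90, 90, n_b + 1)

# solid angle of a lon/lat pixel: d_omega = d_l * (sin(b_top) - sin(b_bottom))
d_l = np.deg2rad(pixel_size)
domega_row = d_l * np.diff(np.sin(np.deg2rad(b_edges)))
domega = np.repeat(domega_row, n_l).reshape(n_b, n_l)

total = domega.sum()
print(total, 4 * np.pi)
```

Pixels near the poles cover much less solid angle than pixels near the equator, which is why maps are divided by `domega` before plotting.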

In [None]:
# Convenient variable for deg --> radian conversion
deg2rad = np.pi/180.

# We define our sky-grid on a regular (pixel_size x pixel_size) grid for testing (later finer grid)
binsize = pixel_size

# Number of pixels in l and b
n_l = int(360/binsize)
n_b = int(180/binsize)

# Galactic coordinates: l and b pixel edges
l_arrg = np.linspace(-180, 180, n_l+1)
b_arrg = np.linspace(-90, 90, n_b+1)

# Making a grid
L_ARRg, B_ARRg = np.meshgrid(l_arrg, b_arrg)

# Choosing the centre points as representative
l_arr = l_arrg[0:-1] + binsize/2
b_arr = b_arrg[0:-1] + binsize/2
L_ARR, B_ARR = np.meshgrid(l_arr, b_arr)

# Define solid angle for each pixel for normalisations later
domega = (binsize * deg2rad * np.diff(np.sin(np.deg2rad(b_arrg)))).repeat(n_l).reshape(n_b, n_l)

## Convert sky grid to zenith/azimuth pairs for all pointings

In [None]:
# calculate the zeniths and azimuths on that grid for all times
zensgrid,azisgrid = zenaziGrid_fast(pointing1.ypoins[:,0], pointing1.ypoins[:,1],
                                    pointing1.xpoins[:,0], pointing1.xpoins[:,1],
                                    pointing1.zpoins[:,0], pointing1.zpoins[:,1],
                                    L_ARR.ravel(), B_ARR.ravel())

In [None]:
# Reshape for next routines ... 
zensgrid = zensgrid.reshape(n_b, n_l, len(pointing1.xpoins))
azisgrid = azisgrid.reshape(n_b, n_l, len(pointing1.xpoins))

## Get observation indices for non-zero bins

In [None]:
# Choose an energy bin to analyze
ebin = 0 # We only have one energy bin (501-521 keV), so the index is necessarily 0.
nonzero_idx = background1.calc_this[ebin]


## Reduce the response dimensions

In [None]:
sky_response_CDS = rsp.rsp.response_grid_normed_efinal.reshape(
    n_b,
    n_l,
    analysis1.dataset.phis.n_phi_bins*\
    analysis1.dataset.fisbels.n_fisbel_bins, analysis1.dataset.energies.n_energy_bins)[:, :, nonzero_idx, ebin]

In [None]:
# reduced response dimensions:
# lat x lon x CDS
sky_response_CDS.shape

## Function to get the response of an image for arbitrary time binning
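
The loop in the next cell relies on the pointing times being sorted, so that each time bin maps to one contiguous index range found with `np.searchsorted`. Here is a small stand-alone sketch of that idea with made-up times, compared against the mask-based selection.

```python
import numpy as np

# sorted cumulative pointing times and the edges of three time bins (toy values)
cdtpoins = np.array([0.5, 1.0, 1.5, 2.5, 3.0, 4.5, 5.0])
tmin = np.array([0.0, 2.0, 4.0])   # bins (0,2], (2,4], (4,6]
tmax = np.array([2.0, 4.0, 6.0])

# first index whose value is strictly greater than each lower edge
bmins = np.searchsorted(cdtpoins, tmin, side="right")
bmins = np.append(bmins, cdtpoins.size)   # sentinel closes the last bin

# half-open index range [start, stop) for every bin
ranges = [(int(bmins[j]), int(bmins[j + 1])) for j in range(tmin.size)]

# the same selection with boolean masks, one full pass over the array per bin
masks = [np.where((cdtpoins > tmin[j]) & (cdtpoins <= tmax[j]))[0] for j in range(tmin.size)]
print(ranges)
```

The range form avoids building a boolean mask over all pointings for every time bin, which matters when there are many bins.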

In [None]:
# Input assumptions:
# - cdtpoins is sorted in nondecreasing order, as it results from a cumsum() call in the dataset construction code.
# - tmin and tmax are lower and upper endpoints of time bins for Pointing data. Hence, tmax[i] = tmin[i+1].
# - n_ph_dx gives the indices of the time bins with > 0 data items to be processed in the loop.  There are n_hours such nonempty bins.

@njit(fastmath=True,parallel=True,nogil=True)
def gir_loop(cdtpoins, tmin, n_ph_dx, n_hours, n_lat, n_lon, Response, weights, zens, azis):

    image_response = np.empty((n_hours,n_lat,n_lon,Response.shape[2]), dtype=np.float32)
 
    # Elts associated with jth time bin have values > tmin[j] and <= tmax[j].
    # But there are no elts with value > tmax[j] and < tmin[j+1].
    # Hence, indices for jth bin are bmins[j] .. bmins[j+1] - 1, inclusive.
    bmins = np.searchsorted(cdtpoins, tmin[n_ph_dx], 'right') # least i with value > tmin

    # Add end of cdtpoins array as sentinel to close last bin range.
    bmins = np.append(bmins, cdtpoins.size)
    
    for c in prange(n_hours):
        
        acc = np.empty(Response.shape[2], dtype=np.float64)

        for LAT in range(n_lat):
            for LON in range(n_lon):
                acc[:] = 0.
                for v in range(bmins[c], bmins[c+1]):
                    acc += Response[zens[LAT,LON,v], azis[LAT,LON,v],:] * weights[LAT,LON,v] # accumulate in 64 bits
                image_response[c,LAT,LON,:] = acc

    return image_response


@jit
def get_image_response_from_pixelhit_general(Response,zenith,azimuth,domega,dt,n_hours,binsize=6,cut=90,altitude_correction=False,al=None):
    """
    Get Compton response from hit pixel for each zenith/azimuth vector(!) input.
    Binsize determines regular(!!!) sky coordinate grid in degrees.

    :param: zenith        Zenith positions of all points of predefined sky grid with
                          respect to the instrument (in deg)
    :param: azimuth       Azimuth positions of all points of predefined sky grid with
                          respect to the instrument (in deg)
    :param: domega        Latitude weighting of pixels on sky grid
    :option: binsize      Default 6 deg (matching the sky dimension of the response). If set
                          differently, make sure it matches the sky dimension as otherwise,
                          false results may be returned
    :option: cut          Threshold to cut the response calculation after a certain zenith angle.
                          Default 90
    :param: n_hours       Number of hours in observation
    :option: altitude_correction Default False: use interpolated transmission probability, normalised to 33 km and 500 keV,
                          to modify number of expected photons as a function of altitude and zenith angle of observation
    :option: al           Altitude values according to dt from construct_pointings(); used if altitude_correction is set to True
    """

    # assuming useful input:
    # azimuthal angle is periodic in the range [0,360[
    # zenith ranges from [0,180[

    # and which pixel centre
    #hit_pixel_z = (hit_pixel_zi + 0.5) * binsize
    
    zens = np.floor(zenith/binsize).astype(np.int32)
    azis = np.floor(azimuth/binsize).astype(np.int32)
        
    nz = zenith.shape[2]

    n_l = int(360/binsize)
    n_b = int(180/binsize)
    
    # take care of regular grid by applying weighting with latitude 
    weights = domega.repeat(nz).reshape(n_b, n_l, nz) * dt
    
    # remove zeniths for which the pixel center is above the threshold
    weights[zens > cut/binsize - 0.5] = 0.

    # check for negative indices and remove
    # NB: za_idx contains zeniths and azimuths, which are computed by zenazigrid.  The
    # zeniths are outputs of arccos() and so lie in the range 0..pi (before conversion
    # to degrees and pixel IDs, which cannot change the sign).  The azimuths are
    # explicitly converted to be non-negative.  So this code is a no-op. - JDB
    #weights[zens < 0] = 0.
    #weights[azis < 0] = 0.
    #zens[zens < 0] = 0
    #azis[azis < 0] = 0
            
    #if altitude_correction == True:
    #    altitude_response = return_altitude_response()
    #else:
    #    altitude_response = one_func

    # get responses at pixels    
    return gir_loop(pointing1.cdtpoins, \
                    analysis1.dataset.times.times_min, analysis1.dataset.times.n_ph_dx, \
                          n_hours, n_b, n_l, Response, weights, zens, azis)
    
    ## Original code replaced by above loop call
    
    #image_response = np.zeros((n_hours,n_lat,n_lon,Response.shape[2]))

    #for c in tqdm(range(n_hours)):
    #    cdx = np.where((pointing1.cdtpoins > analysis1.dataset.times.times_min[analysis1.dataset.times.n_ph_dx[c]]) &
    #                   (pointing1.cdtpoins <= analysis1.dataset.times.times_max[analysis1.dataset.times.n_ph_dx[c]]))[0]
    # 
    #    # this calculation is basically a look-up of the response entries. In general, weighting (integration) with the true shape can be introduced, however with a lot more computation time (Simpson's rule in 2D ...)
    #    image_response[c,:,:,:] += np.sum(Response[za_idx[0,:,:,cdx],za_idx[1,:,:,cdx],:]*np.einsum('klij->iklj', weights[:,:,cdx,None])*dt[cdx,None,None,None],axis=0)#*altitude_weights[:,:,None]
        
    #return image_response


## Calculate the general response for the current data set
This has to be done only once (for the data set).

Takes ~20 minutes to run and ~60 GB memory! [Now down to 1-4 minutes (not sure why it varies so much!) and 35 GB - JDB]

In [None]:
cut = 90 
sky_response_scaled = [] # clear out any old (large!) matrix if we are running this more than once
sky_response_scaled = get_image_response_from_pixelhit_general(
    Response=sky_response_CDS,
    zenith=zensgrid,
    azimuth=azisgrid,
    domega=domega,
    dt=pointing1.dtpoins,
    n_hours=analysis1.dataset.times.n_ph,
    binsize=pixel_size,
    cut=cut)
    #altitude_correction=False)

In [None]:
# data-set-specific response dimensions
# times x lat x lon x CDS
sky_response_scaled.shape

## Exposure map
i.e. the response weighted by time
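
In array terms, the exposure map is the scaled response summed over the time axis and the CDS axis, leaving one value per sky pixel. The sketch below shows this on a random toy response (sizes are arbitrary) and checks that a single reduction matches an explicit loop over time bins.

```python
import numpy as np

rng = np.random.default_rng(1)
n_t, n_b, n_l, n_cds = 4, 3, 5, 6   # toy sizes
response = rng.random((n_t, n_b, n_l, n_cds)).astype(np.float32)

# exposure per sky pixel: sum over time bins (axis 0) and CDS bins (axis 3),
# accumulated in 64 bits even though the response is stored in 32 bits
expo_map = response.sum(axis=(0, 3), dtype=np.float64)

# same result with an explicit loop over time bins
expo_loop = np.zeros((n_b, n_l))
for i in range(n_t):
    expo_loop += response[i].sum(axis=2, dtype=np.float64)

print(expo_map.shape)
```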

In [None]:
# mostly loop-nested version of computation.
@njit(fastmath=True,parallel=True,nogil=True)
def emap_fast(response, n_b, n_l):
    expo_map = np.empty((n_b, n_l))
    n_i = response.shape[0]
    n_j = response.shape[3]

    for x in prange(n_b):
        for y in range(n_l):
            expo_map[x,y] = 0
            for i in range(n_i):
                for j in range(n_j):
                    expo_map[x,y] += response[i,x,y,j]
                
    return expo_map

# original computation
def emap(response, n_b, n_l):
    expo_map = np.zeros((n_b, n_l))

    for i in range(response.shape[0]):
        expo_map += np.sum(response[i,:,:,:], axis=2)
    return expo_map
 

expo_map = emap_fast(sky_response_scaled, n_b, n_l)

## Plotting the exposure map weighted with the pixel size

In [None]:
plt.subplot(projection='aitoff')
p = plt.pcolormesh(L_ARRg*deg2rad,B_ARRg*deg2rad,np.roll(expo_map/domega,axis=1,shift=0))
plt.contour(L_ARR*deg2rad,B_ARR*deg2rad,np.roll(expo_map/domega,axis=1,shift=0),colors='black')
plt.colorbar(p, orientation='horizontal')

# Set up the RL algorithm

### Define function for a starting map for the RL deconvolution. We choose an isotropic map, i.e. all pixels on the sky are initialized with the same value
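
The normalisation means the map integrated over the whole sky equals the chosen amplitude `A0`, so every pixel holds `A0 / (4*pi)` per steradian. The following stand-alone sketch checks this; `iso_map` is a hypothetical re-implementation written for this example that also returns the pixel solid angles.

```python
import numpy as np

def iso_map(ll, bb, A0, binsize):
    # flat map whose sky integral equals A0; ll, bb are pixel centres in degrees
    shape = np.ones(ll.shape)
    domega = np.deg2rad(binsize) * (np.sin(np.deg2rad(bb + binsize / 2))
                                    - np.sin(np.deg2rad(bb - binsize / 2)))
    return A0 * shape / np.sum(shape * domega), domega

binsize = 6.0
l_centres = np.arange(-180, 180, binsize) + binsize / 2
b_centres = np.arange(-90, 90, binsize) + binsize / 2
ll, bb = np.meshgrid(l_centres, b_centres)

m, domega = iso_map(ll, bb, A0=0.01, binsize=binsize)
total_flux = np.sum(m * domega)
print(total_flux)
```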

In [None]:
def IsoMap(ll,bb,A0,binsize=pixel_size):
    shape = np.ones(ll.shape)
    norm = np.sum(shape*(binsize*np.pi/180)*(np.sin(np.deg2rad(bb+binsize/2)) - np.sin(np.deg2rad(bb-binsize/2))))
    val = A0*shape/norm
    return val

### Number of time bins

In [None]:
d2h = analysis1.dataset.binned_data.shape[0]
d2h

### Select only one energy bin (as above) for data set

In [None]:
print('ebin: ',ebin)
dataset = analysis1.dataset.binned_data[:,ebin,:,:].reshape(d2h,
                                                            analysis1.dataset.phis.n_phi_bins*analysis1.dataset.fisbels.n_fisbel_bins)[:,nonzero_idx]


### Same for background

In [None]:
background_model = background1.bg_model_reduced[ebin]

### Check for consistency of data and background
They must have the same dimensions. If not, the algorithm won't work.

In [None]:
dataset.shape, background_model.shape

## Set an initial guess for the background amplitude
Feel free to play with this value, but here are suggestions informed by testing thus far:

### If source+BG:
We suggest setting "fitted_bg" to 0.9 or 0.99 when the loaded data/simulation (analysis1 object) contains both source and background. This is a rough estimate of the background contribution (90, 99%) to the entire data set.

### If analyzing source only:
When the analysis1 object does not contain background, we suggest setting this parameter to 1E-6, i.e. very close to zero background contribution.

In [None]:
fitted_bg = np.array([0.99])


# Richardson-Lucy algorithm

## Individual steps are explained in the code.
The steps follow the algorithm as outlined in [Knoedlseder et al. 1999](https://ui.adsabs.harvard.edu/abs/1999A%26A...345..813K/abstract). Refer to that paper for a mathematical description of the algorithm.

The total memory used during these iterations is about 74 GB!! You might not be able to do much else with your machine while this is running. 

[Now down to < 15 secs per iteration. We run until the MAP likelihood converges, which takes around 20 iterations - JDB]
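
For orientation, here is a minimal sketch of a plain Richardson-Lucy iteration on a tiny random problem with a fixed, known background. Everything in it (the response `R`, the toy sky, the helper `neg2_loglike`) is an assumption for illustration; the notebook's loop additionally fits the background amplitude and an acceleration parameter at every step.

```python
import numpy as np

rng = np.random.default_rng(42)
n_data, n_pix = 40, 8
R = rng.random((n_data, n_pix))        # toy response: data bins x sky pixels
truth = np.zeros(n_pix)
truth[3] = 50.0                        # a single bright pixel
bg = np.full(n_data, 2.0)              # known, fixed background
data = R @ truth + bg                  # noise-free toy data

def neg2_loglike(d, m):
    # -2 ln(Poisson likelihood), up to a model-independent constant
    return 2.0 * np.sum(m - d * np.log(m))

sky = np.full(n_pix, 1.0)              # flat starting map
exposure = R.sum(axis=0)               # response summed over the data space
stats = []
for _ in range(200):
    expectation = R @ sky + bg
    stats.append(neg2_loglike(data, expectation))
    # delta image: back-projected relative residuals, weighted by the current map
    delta = sky * (R.T @ (data / expectation - 1.0)) / exposure
    sky = sky + 1.0 * delta            # acceleration 1 gives the plain RL update

print(stats[0], stats[-1])
```

With an acceleration of 1 the update can be written as `sky * (R.T @ (data / expectation)) / exposure`, which can never make a pixel negative. Larger accelerations converge faster but need the positivity limit computed inside the loop below.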

In [None]:
# Might not use this depending on if you choose to smooth the delta map

from scipy.ndimage import gaussian_filter

In [None]:
from accelerate import convolve_fast, convdelta_fast
import time

# Experiment with these variables!
#############################
# initial map (isotropic flat, small value)
map_init = IsoMap(L_ARR, B_ARR, 0.01)

# maximum number of RL iterations. About 50 are usually enough for testing; about 150 give fully converged images.
maxiters = 150

# if MAP likelihood changes by less than this fraction in 1 iteration, terminate
# NB: this is somewhat arbitrary and may be much too small in practice; I can't see
# a difference in the images below 1e-7
ltol = 1e-9

# acceleration parameter
afl_scl = 1000.
#############################

# Define regions of the sky that we actually cannot see
# here we select everything, i.e. we have no bad exposure

bad_expo = np.where(expo_map/domega <= 0)

#############################

## Define background model cuts, indices, and resulting number of cuts
bg_cuts, idx_arr, Ncuts = background1.bg_cuts, background1.idx_arr, background1.Ncuts
 
# temporary background model
tmp_model_bg = np.zeros((d2h, background_model.shape[1]))

for g in range(d2h):
    tmp_model_bg[g,:] = background_model[g,:]*fitted_bg[idx_arr-1][g]

## Save intermediate iterations: initialise arrays to save images and other parameters
# maps per iteration
map_iterations = np.zeros((maxiters, n_b, n_l))

# likelihood of maps (vs. initial i.e. basically only background)
map_likelihoods = np.zeros(maxiters)

# store per-iter fit likelihoods, ie fit quality
intermediate_lp = np.zeros(maxiters)

# store per-iter acceleration parameters (lambda)
acc_par = np.zeros(maxiters)

# store per-iter fitted background parameters 
bg_pars = np.zeros((maxiters,Ncuts))


## Zeroth iteration: copy initial map to become the 'old map' (see below)
map_old = map_init

# cf. Knoedlseder+1999 for the definitions of the denominator etc.
# this is the response R summed over the CDS and the time bins
denominator = expo_map
   
# convolve this map with the response
print('Convolving with response (init expectation)')
tstart = time.time()
expectation_init = convolve_fast(sky_response_scaled, map_init,
                                 sky_response_scaled.shape[0], sky_response_scaled.shape[3],
                                 n_b, n_l)
tend = time.time()
print(f'Time in convolution: {tend - tstart:.2f}s')

# set old expectation (in data space bins) to new expectation (convolved image)
expectation_old = expectation_init

# setting the map to zero where we selected a bad exposure (we didn't, but to keep it general)
map_old[bad_expo] = 0

# check for each pixel to be finite
map_old[np.isnan(map_old)] = 0

# save map from prior iteration
map_iterations[0,:,:] = map_old 

# expectation (in data space) is the image (expectation_old) plus the background (tmp_model_bg)
expectation_tot_old = expectation_old + tmp_model_bg 

# calculate likelihood of current total expectation
map_likelihoods[0] = cashstat(dataset,expectation_tot_old)

###########################################################
###########################################################
## here run over the number of iterations #################
###########################################################
## the time for the convolutions is very large ############
## this can be 10 minutes (!) per iteration ###############
## this should be tested for a few iterations #############
## and then run overnight or similar ######################
###########################################################
###########################################################
dml = np.inf
for its in tqdm(range(1,maxiters)):
    
    if dml < ltol:
        print(f"MAP likelihood change was less than {ltol} -- terminating")
        break

    # calculate numerator of RL algorithm
   
    print(f'Calculating Delta image, iteration {its}, numerator')
    tstart = time.time()
    W = dataset / expectation_tot_old - 1.
    numerator = convdelta_fast(sky_response_scaled, W, n_b, n_l, W.shape[0], W.shape[1])
    tend = time.time()
    print(f'Time in Delta image calc: {tend - tstart:.2f}s')
    
    # calculate delta map (denominator scaled by fourth root to avoid exposure edge effects)
    # You can try changing 0.25 to 0, 0.5, for example
    delta_map_tot_old = (numerator/denominator)*map_old*(denominator)**0.25
    
    # Alternatively, you can also try to smooth it 
    #delta_map_tot_old = gaussian_filter(delta_map_tot_old, 0.5)
    
    #################################
     
    # check again for finite values and zero our bad exposure regions
    delta_map_tot_old[bad_expo] = 0
    delta_map_tot_old[np.isnan(delta_map_tot_old)] = 0
    
    # plot each iteration's map and its delta map 
    # (not required, but nice to see how the algorithm is doing)
    plt.figure(figsize=(16,6))
    plt.subplot(121)
    plt.pcolormesh(L_ARRg,B_ARRg,np.roll(map_old, axis=1, shift=0)) 
    plt.colorbar()

    plt.subplot(122)
    plt.pcolormesh(L_ARRg,B_ARRg,np.roll(delta_map_tot_old, axis=1, shift=0)) 
    plt.colorbar()
    plt.show()
        
    # convolve delta image
    print(f'Convolving Delta image, iteration {its}')
    tstart = time.time()
    conv_delta_map_tot = convolve_fast(sky_response_scaled, delta_map_tot_old,
                                  sky_response_scaled.shape[0], sky_response_scaled.shape[3],
                                  n_b, n_l)
    tend = time.time()
    print(f'Time in convolution: {tend - tstart:.2f}s')

    # find maximum acceleration parameter to multiply delta image with
    # so that the total image is still positive everywhere
    assert np.min(map_old) >= 0, "map_old contains negative entries!"

    neg = delta_map_tot_old < 0

    # If there are no negative entries in delta_map_tot_old, there is no upper bound on the
    # acceleration.  Original code used a value of ~10000 in this case.  If we use much larger
    # value, RL seems to oscillate rather than converging smoothly and gives a worse final
    # likelihood (observed on the Point_Sources notebook, which is the only one with this issue).
    if not neg.any():
        afl = 10000
    else:
        afl = int(np.floor(np.min(-afl_scl * map_old[neg] / delta_map_tot_old[neg]))) 
        afl = min(afl, 10000)
    
    print('Maximum acceleration parameter found: ', afl/afl_scl)

    # fit:
    y = dataset.astype(int)

    conv_sky = np.concatenate([[expectation_old],[conv_delta_map_tot/afl_scl]])
    
    mu_Abg = fitted_bg    # can play with this
    sigma_Abg = fitted_bg # can play with this
    mu_flux = np.array([1,afl/2])
    sigma_flux = np.array([1e-2,afl])

    init_params =  (jnp.ones(Ncuts) * fitted_bg, jnp.array([1, afl/2.]))
    
    acceleration_factor_limit = afl * 0.95
    lower_bounds = (jnp.ones(Ncuts) * 1e-8,    jnp.ones(2) * 1e-8)
    upper_bounds = (jnp.ones(Ncuts) * jnp.inf, jnp.ones(2) * acceleration_factor_limit)
    
    tstart = time.time()
    res = optimizer.run(init_params, bounds=(lower_bounds, upper_bounds),
                        data=(conv_sky, tmp_model_bg, idx_arr, y,
                              mu_flux, sigma_flux, mu_Abg, sigma_Abg))
    tend = time.time()
    print(f'Time in optimizer: {tend - tstart:.2f}s')

    if not res.state.success:
        print("*** Optimizer failed! rerun with options = { 'disp': True } to see error messages")
       
        # proceed with a safe acceleration <= 1 (safe = new map does not go negative at any pixel)
        print("proceeding with a safe acceleration parameter")
        # acceleration_factor_limit is in units of 1/afl_scl, so rescale before comparing with 1
        accScale = min(1., acceleration_factor_limit / afl_scl)
    else:
        # save values
        print(f'Saving new map, and fitted parameters, iteration {its}')
        intermediate_lp[its-1] = -res.state.fun_val
        
        newAbg, newflux = res.params
        newAcc = newflux[1]
        bg_pars[its-1,:] = newAbg
        acc_par[its-1]   = newAcc

        accScale = newAcc/afl_scl
    
    # make new map as old map plus scaled delta map (must copy out of Jax to make writable on next line)
    map_new = np.array(map_old + accScale * delta_map_tot_old)
    
    # setting the map to zero where we selected a bad exposure (we didn't, but to keep it general)
    map_new[bad_expo] = 0
    
    # check for each pixel to be finite
    map_new[np.isnan(map_new)] = 0

    # save map from this iteration
    map_iterations[its,:,:] = map_new 

    # make new expectation as old expectation plus scaled conv_delta map
    expectation_new = np.array(expectation_old + accScale * conv_delta_map_tot)

    # expectation (in data space) is the image (expectation_new) plus the background (tmp_model_bg)
    expectation_tot_new = expectation_new + tmp_model_bg 

    # calculate likelihood of current total expectation
    map_likelihoods[its] = cashstat(dataset, expectation_tot_new)

    # how much did the MAP likelihood improve since the last iteration?
    dml = np.abs((map_likelihoods[its] - map_likelihoods[its-1])/map_likelihoods[its-1])
    
    # swap maps
    map_old = map_new
    
    # and expectations
    expectation_old = expectation_new
    expectation_tot_old = expectation_tot_new
    
    print(f"After iteration {its}: MAP likelihood = {map_likelihoods[its]:.2f}, rel. change = {dml:.2e}")
    # and repeat


## Plot the fitted background parameter and the map flux

In [None]:
plt.figure(figsize=(14,6))
plt.subplot(121)
plt.plot(range(its-1), [i[0] for i in bg_pars[:its-1]], '.-')
plt.xlabel('Iteration')
plt.ylabel('BG params')


plt.subplot(122)
map_fluxes = np.zeros(its)
for i in range(its):
    map_fluxes[i] = np.sum(map_iterations[i,:,:]*domega)
    
plt.plot(map_fluxes[:its],'o-')
plt.xlabel('Iteration')
plt.ylabel('Flux')# [ph/keV]')

## Did the algorithm converge? Look at the likelihoods.
intermediate_lp: Fit likelihoods, i.e. fit quality

map_likelihoods: likelihood of maps (vs. initial i.e. basically only background)
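
As a reference for reading these curves, the sketch below shows one common form of the Cash statistic and the relative-change test used as a stopping rule. `cash_statistic` and `relative_change` are hypothetical helpers written for this example; the library's `cashstat` may differ by sign or by a model-independent constant.

```python
import numpy as np

def cash_statistic(data, model):
    # Cash statistic up to a model-independent constant: C = 2 * sum(m - d * ln m)
    data = np.asarray(data, dtype=float)
    model = np.asarray(model, dtype=float)
    return 2.0 * np.sum(model - data * np.log(model))

def relative_change(stats, i):
    # fractional change of the statistic between iteration i-1 and i
    return abs((stats[i] - stats[i - 1]) / stats[i - 1])

d = np.array([3.0, 5.0, 2.0])
c_best = cash_statistic(d, d)          # the statistic is minimised when model == data
c_worse = cash_statistic(d, 1.5 * d)

history = [100.0, 99.0, 98.99]
converged = relative_change(history, 2) < 1e-3
print(c_best < c_worse, converged)
```

A curve that flattens out means further iterations change the fit very little; the tolerance decides how flat is flat enough.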

In [None]:
plt.figure(figsize=(14,6))
plt.subplot(121)
plt.plot(np.arange(1, its), intermediate_lp[:its-1], '.-')
plt.xlabel('Iteration')
plt.ylabel('likelihood (intermediate_lp)')

plt.subplot(122)
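# if the RL loop stopped early via `break`, entry `its` was never filled (still 0), so skip it
last = its if map_likelihoods[its] != 0 else its - 1
plt.plot(range(last+1), map_likelihoods[:last+1], '.-')
plt.xlabel('Iteration')
plt.ylabel('likelihood (map_likelihoods)')

print(f'final MAP likelihood = {map_likelihoods[last]}')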

## Make the image!
You can loop over all iterations to make a GIF or just show one iteration (usually the final iteration).

In [None]:
from IPython.display import Image
from IPython.display import Video

from matplotlib import animation

from matplotlib import colors

from scipy.ndimage import gaussian_filter as smooth

In [None]:
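# Choose an image to plot (default: the last completed iteration).
# If the RL loop stopped early via `break`, map_iterations[its] was never filled
# and is still all zeros, so fall back to the previous iteration.
idx = its if map_iterations[its].any() else its - 1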


In [None]:
# Choose a color map like viridis (matplotlib default), nipy_spectral, twilight_shifted, etc. Not jet.
cmap = plt.get_cmap('viridis') 

# Bad exposures will be gray
cmap.set_bad('lightgray')


##################
# Select here which pixels should be gray
map_iterations_nan = np.copy(map_iterations)

# Also gray out pixels with low but non-zero exposure (avoiding the edge effects)
# You can play with this. Most success in testing with 1e4, 1e3
bad_expo = np.where(expo_map/domega <= 1e4) 

for i in range(maxiters):
    map_iterations_nan[i, bad_expo[0], bad_expo[1]] = np.nan
#################    


# Set up the plot
fig, ax = plt.subplots(figsize=(10.24,7.68), subplot_kw={'projection':'aitoff'}, nrows=1, ncols=1)

ax.set_xticks(np.array([-120,-60,0,60,120])*deg2rad)
ax.set_xticklabels([r'$-120^{\circ}$'+'\n',
                            r'$-60^{\circ}$'+'\n',
                            r'$0^{\circ}$'+'\n',
                            r'$60^{\circ}$'+'\n',
                            r'$120^{\circ}$'+'\n'])
ax.tick_params(axis='x', colors='orange')

ax.set_yticks(np.array([-60,-30,0,30,60])*deg2rad)
ax.tick_params(axis='y', colors='orange')

plt.xlabel('Gal. Lon. [deg]')
plt.ylabel('Gal. Lat. [deg]')



# "ims" is a list of lists, each row is a list of artists to draw in the
# current frame; here we are just animating one artist, the image, in
# each frame
ims = []


# If you want to make a GIF of all iterations:
#for i in range(its):

# If you only want to plot one image:
for i in [idx]:

    ttl = plt.text(0.5, 1.01, f'RL iteration {i}', horizontalalignment='center', 
                   verticalalignment='bottom', transform=ax.transAxes)
    
    # Either gray-out bad exposure (map_iterations_nan) or don't mask (map_iterations)
    # Masking out bad exposure 
    #image = map_iterations_nan[i,:,:]
    image = map_iterations[i,:,:]

    
    img = ax.pcolormesh(L_ARRg*deg2rad,B_ARRg*deg2rad,
                        
                        # Can shift the image along longitude. Here, no shift.
                        np.roll(image, axis=1, shift=0),
            
                        # Optionally smooth with gaussian filter
                        #smooth(np.roll(image, axis=1, shift=0), 0.75/pixel_size),
                        
                        # use the colormap defined above so masked (NaN) pixels are drawn in gray
                        cmap=cmap,
                        
                        # Optionally set the color scale. Default: linear
                        #norm=colors.PowerNorm(0.33)
                       )
    ax.grid()
    
    ims.append([img, ttl])

cbar = fig.colorbar(img, orientation='horizontal')
cbar.ax.set_xlabel(r'[Arbitrary Units]')
    

# Can save a sole image as a PDF 
#plt.savefig(data_dir + f'/images/511keV_RL_image_iteration{idx}.pdf', bbox_inches='tight')
    
    
# # Can save all iterations as a GIF
# ani = animation.ArtistAnimation(fig, ims, interval=200, blit=True, repeat_delay=0)
# ani.save(f'/home/jacqueline/511keV_RL_image_{idx}iterations.gif')


# What do we see?

We clearly see the "bulge" emission of positron-electron annihilation at the center of the Milky Way. This was also seen in the published image of real COSI-balloon flight data [(Siegert et al. 2020)](https://iopscience.iop.org/article/10.3847/1538-4357/ab9607/meta):

<img width="600" alt="Siegert_2020_COSI_511keV" src="https://user-images.githubusercontent.com/33991471/196853486-68a90111-245b-442d-841c-756f47c9c14f.png">


The extended disk emission seen in the SPI image above is not visible here. This is expected; SPI saw about 1 photon per week from the disk and has over a decade of observation time. There is not enough data in the 46-day balloon flight to image the disk.

However, we can still probe the emission morphology of the bulge by fitting a 2-D Gaussian, for example, to our simulated image. Constraining the parameters of this fit is important for modeling the physics (positron propagation, point sources of positrons, etc.) behind this enduring mystery.  
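
Fitted Gaussian widths are usually quoted as a full width at half maximum, $\mathrm{FWHM} = 2\sqrt{2\ln 2}\,\sigma \approx 2.355\,\sigma$. A short sketch of that conversion follows, using a made-up width rather than a fit result.

```python
import numpy as np

def fwhm_from_sigma(sigma):
    # full width at half maximum of a Gaussian with standard deviation sigma
    return 2.0 * np.sqrt(2.0 * np.log(2.0)) * sigma

sigma_deg = 8.0   # hypothetical width in degrees, for illustration only
fwhm_deg = fwhm_from_sigma(sigma_deg)

# a unit-amplitude Gaussian evaluated half a FWHM from its centre drops to one half
half_height = np.exp(-(fwhm_deg / 2.0) ** 2 / (2.0 * sigma_deg ** 2))
print(fwhm_deg, half_height)
```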

## Fit a 2D Gaussian to the emission

In [None]:
def gauss_2d(xtuple, A, x0, y0, sigma_x, sigma_y, theta):
    # theta: rotate the blob by positive, counterclockwise angle theta
    # https://en.wikipedia.org/wiki/Gaussian_function#Two-dimensional_Gaussian_function
    # NB: code actually rotates *clockwise* by theta because sign of b formula is inverted
    (x, y) = xtuple
    x0 = float(x0)
    y0 = float(y0)
    a = np.cos(theta)**2/(2*sigma_x**2) + np.sin(theta)**2/(2*sigma_y**2)
    b = np.sin(2*theta)/(4*sigma_x**2)  - np.sin(2*theta)/(4*sigma_y**2)
    c = np.sin(theta)**2/(2*sigma_x**2) + np.cos(theta)**2/(2*sigma_y**2)
    tot = A*np.exp( -( a*(x-x0)**2 + 2*b*(x-x0)*(y-y0) + c*(y-y0)**2 ) )
    return tot.ravel()

In [None]:
import scipy.optimize as opt
initial_guess = (2, 0, 0, 10, 10, 2)
# use pixel centres (not lower edges) as coordinates, matching the contour plot below
x = L_ARR*deg2rad
y = B_ARR*deg2rad
z = map_iterations_nan[idx,:,:]
z[np.isnan(z)] = 0.

# added bounds and boosted feval max to improve convergence and give sensible angle theta - JDB
popt, pcov = opt.curve_fit(gauss_2d, (x, y), z.ravel(), p0=initial_guess, 
                           bounds=([0,-np.inf,-np.inf,0,0,-np.pi],[np.inf,np.inf,np.inf,np.inf,np.inf,np.pi]),maxfev=10000)

im_fitted = gauss_2d((x, y), *popt)

In [None]:
fig, ax = plt.subplots(figsize=(10.24,7.68),subplot_kw={'projection':'aitoff'},nrows=1,ncols=1)

ax.set_xticks(np.array([-120,-60,0,60,120])*deg2rad)
ax.tick_params(axis='x', colors='orange')
ax.set_xticklabels([r'$-120^{\circ}$'+'\n',
                            r'$-60^{\circ}$'+'\n',
                            r'$0^{\circ}$'+'\n',
                            r'$60^{\circ}$'+'\n',
                            r'$120^{\circ}$'+'\n'])
ax.set_yticks(np.array([-60,-30,0,30,60])*deg2rad)
ax.tick_params(axis='y', colors='orange')

ax.set_xlabel('Gal. Lon. [deg]')
ax.set_ylabel('Gal. Lat. [deg]')

# Plot original image (keep the handle so the colorbar refers to this figure)
mesh = ax.pcolormesh(L_ARRg*deg2rad, B_ARRg*deg2rad, z.reshape(len(x), len(x[0])), cmap=plt.cm.viridis)

# Plot contours
num_contours = 2
levels = [np.max(im_fitted)*0.05, np.max(im_fitted)*0.1,
          np.max(im_fitted)*0.5, np.max(im_fitted)*0.8]

#plt.contour(x, y, im_fitted.reshape(len(x), len(x[0])), levels=num_contours, colors='w')

plt.contour(L_ARR*deg2rad, B_ARR*deg2rad, im_fitted.reshape(len(x), len(x[0])), levels = levels, colors='white')

# colorbar for the image drawn above, not for `img` from the earlier figure
cbar = fig.colorbar(mesh, orientation='horizontal')
#cbar.ax.set_xlabel(r'Flux [10$^{-2}$ ph cm$^{-2}$ s$^{-1}$]')
    
ax.grid()

In [None]:
print('A:', popt[0])
print('x0 [deg]:', popt[1]*180/np.pi)
print('y0 [deg]:', popt[2]*180/np.pi)
print('sigma_x [deg]:', popt[3]*180/np.pi, '--> FWHM_x [deg]:', 2*np.sqrt(2*np.log(2))*popt[3]*180/np.pi)
print('sigma_y [deg]:', popt[4]*180/np.pi, '--> FWHM_y [deg]:', 2*np.sqrt(2*np.log(2))*popt[4]*180/np.pi)
print('theta [deg]:', popt[5]*180/np.pi)