In [None]:
import numpy as np
from astropy import units as u
from snewpy.models import presn, ccsn
from snewpy.neutrino import Flavor

import pylab as plt
from contextlib import contextmanager

In [None]:
@contextmanager
def raises(exception):
    """A small utility to catch and show an expected exception.

    Used to demonstrate calls that are supposed to fail without aborting the
    notebook cell.

    Parameters
    ----------
    exception : type or tuple of types
        Exception class (or tuple of classes, as accepted by ``except``)
        that the wrapped code is expected to raise.

    A message is printed in both cases: the caught exception's actual type and
    message, or a note that nothing was raised. That way the output always
    shows whether the demonstrated call really failed.
    """
    try:
        yield
    except exception as e:
        # report the actual type: it may be a subclass of `exception`
        print(f'{type(e).__name__}: {e}')
    else:
        expected = exception if isinstance(exception, tuple) else (exception,)
        names = ' or '.join(exc.__name__ for exc in expected)
        print(f'No {names} was raised')
    

# Usage of the `flux.Container` interface

## 1. Initialize model

In [None]:
# Load the Bollig_2016 model from snewpy.models.ccsn for a 27 Msun progenitor
model = ccsn.Bollig_2016(progenitor_mass=27<<u.Msun)

times    = model.time # use the model's own time grid (a uniform grid, e.g. np.linspace(0,2,1500)<<u.second, could be substituted)
energies = np.linspace(0,50,501)<<u.MeV # 0-50 MeV with 0.1 MeV spacing

## 2. Calculate flux

In [None]:
# Flux for a source at 10 kpc, evaluated on the time and energy grids defined above
flux = model.get_flux(t = times, E = energies, distance=10<<u.kpc)

## Working with the Container class

The container class is defined in `snewpy.flux.Container`

In [None]:
from snewpy.flux import Container

# IPython help syntax: display the Container docstring
Container?

### Inspection

We will use the flux container obtained from the SupernovaModel in the previous step.

When printed, the Container shows its `array` dimensions, its unit (here `[1/(cm2 MeV s)]`), and the range of each of its three axes: `flavor`, `time` and `energy`.

In [None]:
# the printed summary lists the array dimensions, unit and axis ranges
print(flux)

One can access the array as an `astropy.units.Quantity` object:

In [None]:
# the underlying data, as an astropy Quantity
flux.array

and the individual axes:

In [None]:
# the values of the energy axis
flux.energy

### Slicing

The flux can be sliced the same way as a regular `np.ndarray`:

In [None]:
# get the flux for a specific flavor by indexing the flavor axis with a Flavor member
flux[Flavor.NU_E] 

In [None]:
# Or trim the time or energy dimensions;
# here we take the first 1000 points in time (axes are flavor, time, energy)
flux[:,:1000,:]

### Summation

A Container can be summed over any of its axes (`flavor`, `time`, `energy`), as long as that axis is summable (see below).

In [None]:
# Sum over all flavors (flavor is a summable axis)
flux.sum('flavor')

In [None]:
# trying to sum over time or energy will raise an exception:
# for this flux both axes are integrable (differential), not summable
with raises(ValueError):
    flux.sum('time')

### Integration
It can also be integrated over an axis (using linear interpolation between points):

In [None]:
# Integrate over the full energy range, since no limits are provided
flux.integrate('energy')

In [None]:
# Integrate the flux over the first second (from 0 to 1 s)
flux.integrate('time',limits=[0,1]<<u.s)

In [None]:
# Integrate over several time bins: [0,1], [1,2] and [2,3] s
flux.integrate('time',limits=[0,1,2,3]<<u.s)

After integration, the time dimension of the array here has length `len(limits)-1` (3 bins),
but the corresponding axis (in this case `time`) keeps all the limits (here it has 4 points).

### "Summable" and "integrable" axes

The Container automatically keeps track of which axes can be integrated and which can only be summed over.

For example, the flux obtained from `SupernovaModel.get_flux` is differential in time and energy, which means it can be integrated over `energy` and `time`.
You can see it from the class name `d2FdEdT` and the unit:

In [None]:
# the class name (d2FdEdT) and the unit show which axes are differential
flux

Internally, the axes which can be summed or integrated are kept in private attributes:

In [None]:
# private attributes, inspected here only for illustration
print(f'Can integrate over {flux._integrable_axes}')
print(f'Can sum over {flux._sumable_axes}')

After the integration over the axis, it becomes "summable":

In [None]:
# after integration over time, `time` moves from the integrable to the summable axes
fI = flux.integrate('time',limits=[0,1,2,3]<<u.s)
print(f'Can integrate over {fI._integrable_axes}')
print(f'Can sum over {fI._sumable_axes}')

In [None]:
# after integrating over time, we can sum over the time bins
fI.sum('time')

In [None]:
# but we cannot integrate over time again: time is now a summable (binned) axis
with raises(ValueError):
    fI.integrate('time')

**Note**: *integration changes the physical unit, summation or slicing do not*

### Integrate or sum (project?)

Sometimes you just want a projection of your flux to a specific axis (say, time).
In this case you want to integrate over energy if it is integrable, or just sum over the energy bins if the flux is already integrated:

In [None]:
# energy is still integrable in fI, while time is now summable
assert fI.integrate_or_sum('energy') == fI.integrate('energy')
assert fI.integrate_or_sum('time') == fI.sum('time')

## 3. Calculate rate

I made a `RateCalculator`, a subclass of `SimpleRate` (a temporary solution for cleanliness; we might merge them later).

The main difference is that it uses the `RateCalculator.run` function to calculate rates, multiplying the flux by the cross-section, number of targets, smearing matrix and efficiency.

**Note** on the rate calculation:
In `SimpleRate` we used an input from `generate_fluence`, so the flux was already integrated within the energy bins, and the cross-section was sampled in the centers of energy bins.

In the `RateCalculator` we multiply the flux by the cross-section and only after that integrate over the energy bins, which should be more precise.

In [None]:
# RateCalculator lives in snewpy.rate_calculator
from snewpy.rate_calculator import RateCalculator
rc = RateCalculator()

In [None]:
# calculate the time-differential rate for the 'icecube' configuration;
# run returns a dict of Containers, one per channel
rates = rc.run(flux, 'icecube')
rates['ibd']

`RateCalculator.run` outputs a dictionary of Container objects with rates for each channel

Since the rate calculation only operates on energy, the time structure remains the same: if the input was a flux (`1/(MeV cm2 s)`), the rate will be in `1/s`.

But if we pass a flux integrated over time (a fluence), we get just the number of events:

In [None]:
# calculate the time-integrated rate from the fluence in 0.1 s bins between 0 and 2 s.
# np.linspace gives exactly 21 bin edges; np.arange with a float step can
# gain or lose an endpoint because of rounding.
fluence = flux.integrate('time', np.linspace(0, 2, 21)<<u.s)
ratesI = rc.run(fluence, 'icecube')
ratesI['ibd']

### Saving and loading

In [None]:
#Container can be saved to a file 
# NOTE: this writes 'fluence.npz' into the current working directory
fluence.save('fluence.npz')
#and loaded using the class method
fluence1 = Container.load('fluence.npz')
# the round trip should reproduce the original container
assert fluence1 == fluence

fluence1

## 4. Plotting examples

### Utility functions

In [None]:
#Utility function to draw the flux
from snewpy.flux import Axes

def project(flux, axis, integrate=True):
    """Reduce `flux` to a projection onto `axis`.

    The complementary axis (energy when `axis` is time, time otherwise) is
    integrated out when `integrate` is True, and summed over otherwise.

    Returns
    -------
    tuple
        (values of the kept axis, reduced Container)
    """
    target = Axes[axis]  # axis name -> Axes enum member
    other = Axes.energy if target == Axes.time else Axes.time
    reducer = flux.integrate if integrate else flux.sum
    reduced = reducer(other)
    return reduced.axes[target], reduced
    
def plot_projection(flux, axis, step=False, integrate=True):
    """Plot `flux` projected onto `axis`, one line per flavor.

    Parameters
    ----------
    flux : Container
        Container to draw (e.g. a flux or a rate).
    axis : str
        Name of the Axes member to keep; the complementary axis is
        integrated or summed out by `project`.
    step : bool
        If True, the values of `axis` are treated as bin edges (one more edge
        than data values) and the data is drawn as a histogram-like step line.
    integrate : bool
        Integrate (True) or sum (False) over the complementary axis.

    Returns
    -------
    list
        The matplotlib lines returned by ``plt.plot`` / ``plt.step``.
    """
    x, fI = project(flux, axis, integrate)
    y = fI.array.squeeze().T
    labels = [Flavor(flv).to_tex() for flv in flux.flavor]
    if step:
        # we're dealing with bins, not points: x holds the bin edges and
        # y[i] is the value of the bin [x[i], x[i+1]]. where='post' draws each
        # value from its own left edge to the next edge; the last value is
        # repeated at the final edge so the last bin is drawn as well.
        y_steps = np.concatenate([y, y[-1:]], axis=0)
        l = plt.step(x, y_steps, where='post', label=labels)
    else:
        l = plt.plot(x, y, label=labels)

    plt.ylabel(f'{fI.__class__.__name__},  {y.unit}')
    plt.xlabel(f'{Axes[axis].name},  {x.unit}')
    return l

In [None]:
#Utility function to draw the energy-summed rate versus time for every channel
def plot_rates(rates):
    """Plot the energy-summed rate versus time for every channel.

    Parameters
    ----------
    rates : dict
        Mapping of channel name to rate Container, as returned by
        ``RateCalculator.run``.

    Nothing is drawn (and no axis labels are set) if `rates` is empty.
    """
    if not rates:
        return
    for ch, r in rates.items():
        rT = r.sum('energy')
        plt.plot(rT.time, rT.array.squeeze(), label=ch)
    # summation does not change the unit, so the last channel's unit labels the axis
    plt.ylabel(f'{rT.__class__.__name__},  {rT.array.unit}')
    plt.xlabel(f'time, {rT.time.unit}')
    


### Plot fluxes

In [None]:
# plot the neutrino flux: left panel vs energy (integrated over time),
# right panel vs time (integrated over energy)
fig,ax = plt.subplots(1,2, figsize=(12,6))
plt.sca(ax[0])
plot_projection(flux, 'energy', integrate=True)
plt.legend()

plt.sca(ax[1])
plot_projection(flux, 'time', integrate=True)
plt.legend()
plt.xscale('log')
plt.show()

In [None]:
# time profile of each channel's rate, summed over the energy bins
for ch, rate in rates.items():
    l = plot_projection(rate, 'time', integrate=False)
    l[0].set_label(ch)  # label the first line of this channel with the channel name
plt.yscale('log')
plt.legend(loc='right')
plt.ylim(0.1)  # lower y limit
plt.show()

In [None]:
# number of events per time bin for each channel; step=True because the
# time axis of the time-integrated rates holds bin edges
for ch, rate in ratesI.items():
    l = plot_projection(rate, 'time', integrate=False, step=True)
    l[0].set_label(ch)  # label the first line of this channel with the channel name
plt.yscale('log')
plt.legend(loc='right')
plt.ylim(0.1)  # lower y limit
plt.show()