# SPItorch Tutorial Notebook

This notebook gives a _guided tour_ of the most important bits of SPItorch, to help folks get up and running quickly.

The SPItorch package is abbreviated to `spt` to save your fingers a bit of typing. Let's import it:

In [None]:
import spt

Because we're in a notebook, we'll change directory to the root of the SPItorch project so that our example file paths work on any machine. We'll also pick a default dtype and device (using the GPU if one is available):

In [None]:
try:  # One-time setup: skip if this cell has already been run
    assert _SETUP
except NameError:
    import os
    import torch as t

    # Move to the SPItorch project root so that relative example paths resolve
    os.chdir(os.path.split(spt.__path__[0])[0])

    dtype = t.float32
    device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
    if device == t.device("cuda"):
        print("Using GPU for training")
        !nvidia-smi -L
    else:
        print("CUDA is unavailable; training on CPU.")
    _SETUP = True

# Forward Modelling

Here we show some examples of how Prospector / py-FSPS forward models are wrapped in SPItorch.

We can load a specific observation (e.g. a galaxy) from our catalogue, if we know its index, as follows:

In [None]:
index = 42
o = spt.load_observation(index)

Or if we'd like to pick out a random observation, we can just leave out the index:

In [None]:
o = spt.load_observation()
print(o)

As you can see, this is just a pandas `Series`, and the index has also been added into the series under the `idx` key (albeit as a float).
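A likely reason for the float is that pandas stores a numeric `Series` with a single dtype, so an integer added to an all-float series is upcast. The sketch below uses a hypothetical observation with made-up flux columns (`f_u`, `f_g`, `f_r`; the real catalogue columns will differ) and shows casting `idx` back to an `int` before using it as an index:

```python
import pandas as pd

# Hypothetical photometry columns; the real catalogue columns will differ.
obs = pd.Series({"f_u": 1.25, "f_g": 3.50, "f_r": 4.75})

# Adding an integer to an all-float Series keeps the float64 dtype,
# so the stored index comes back as 42.0 rather than 42.
obs["idx"] = 42

index = int(obs["idx"])  # cast back to int before using it as a row index
print(obs.dtype, obs["idx"], index)
```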

We can initialise a 'prospector object' (i.e. a thin convenience wrapper around Prospector methods) with this observation as follows:

In [None]:
p = spt.Prospector(o)

This might have taken a few seconds because the SPS libraries had to be loaded; loading them is one of the slower steps in the project.

We can visualise our observation as follows:

In [None]:
p.visualise_obs(show=True, save=False)

You can inspect the model, its parameters and its filters by printing out the prospector wrapper:

In [None]:
print(p)

The prospector model has a number of internal (`theta`) parameters, which are initialised from the `init` values in the parameter description (or drawn from the prior if `init` is missing). We can inspect them as follows:

In [None]:
# theta holds the current values of the 'free parameters', which we optimise / fit
print(p.model.theta)
print(p.model.theta_labels())  # a label for each entry of theta
print(p.model.free_params)     # names of the free parameters

# We do not attempt to optimise / fit the fixed parameters:
print(p.model.fixed_params)
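The initialisation rule described above (use `init` if it is given, otherwise draw from the prior, and leave fixed parameters out of `theta`) can be sketched in plain Python. The parameter names and the `lambda` priors here are hypothetical stand-ins loosely modelled on a Prospector parameter description; they are not SPItorch's actual configuration or Prospector's prior objects:

```python
import random

# Hypothetical parameter descriptions, loosely in the style of a Prospector
# model config: each says whether it is free, and may give an 'init' value.
param_specs = {
    "logzsol": {"isfree": True, "init": -0.5, "prior": lambda: random.uniform(-2.0, 0.2)},
    "dust2": {"isfree": True, "prior": lambda: random.uniform(0.0, 2.0)},  # no 'init'
    "zred": {"isfree": False, "init": 0.1},  # fixed: not part of theta
}


def initial_theta(specs):
    """Build the initial free-parameter vector: use 'init' if given, else sample the prior."""
    labels, theta = [], []
    for name, spec in specs.items():
        if not spec["isfree"]:
            continue  # fixed parameters are never optimised, so they stay out of theta
        labels.append(name)
        theta.append(spec["init"] if "init" in spec else spec["prior"]())
    return labels, theta


random.seed(0)
labels, theta = initial_theta(param_specs)
print(labels, theta)
```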

We can also visualise the model's photometry and spectrum with these initial parameters, and compare them to the observations (note that they'll probably be quite different at this stage!):

In [None]:
p.visualise_model(show=True, save=False)
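One common way to put a number on how far the model is from the observations is the chi-square statistic: the sum of squared residuals, each divided by its observational uncertainty. This is a generic sketch with made-up fluxes for three bands; it is not necessarily how SPItorch or Prospector compute their likelihood internally:

```python
import numpy as np


def chi_square(obs_flux, model_flux, obs_unc):
    """Sum of squared, uncertainty-normalised residuals between data and model."""
    obs_flux, model_flux, obs_unc = map(np.asarray, (obs_flux, model_flux, obs_unc))
    return float(np.sum(((obs_flux - model_flux) / obs_unc) ** 2))


# Made-up fluxes and uncertainties for three bands, purely for illustration.
obs_flux = [1.0, 2.0, 3.0]
obs_unc = [0.1, 0.2, 0.3]
perfect = chi_square(obs_flux, [1.0, 2.0, 3.0], obs_unc)  # model matches the data
off_by_one_sigma = chi_square(obs_flux, [1.1, 2.2, 3.3], obs_unc)  # 1 sigma off in each band
print(perfect, off_by_one_sigma)
```

A model that is off by one uncertainty in every band contributes roughly 1 per band, so large values relative to the number of bands indicate a poor fit.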

Prospector comes with a number of optimisation methods. These are a little slow (particularly if you have a lot of observations that you need to get through), and the point of the machine learning inference section is to try to speed these up.

### Numerical Optimisation

To bring our predicted photometry a little closer to the observations, we can use some numerical optimisers to get the model's `theta` values into the right ballpark:

In [None]:
results = p.numerical_fit()

The model's current `theta` is automatically updated to the best theta found by the optimisation. We run multiple starts to avoid local minima (this is configured in the settings), so let's pick out the lowest-cost result and compare:

In [None]:
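import numpy as np

# Each optimiser start returns its own result; keep the one with the lowest cost
ind_best = np.argmin([r.cost for r in results])
print(f'Best index: {ind_best}')
theta_best = results[ind_best].x.copy()
print(f'Best theta:\n{theta_best}')
# This should match the model's current theta
print(f'Current model theta:\n{p.model.theta}')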

Having done that, we can visualise the model's photometric predictions again and see how well they line up with the observations:

In [None]:
p.visualise_model(theta=theta_best, show=True, save=False)
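The idea behind running multiple starts can be sketched on a toy problem. The `.cost` and `.x` attributes used above match the result objects returned by `scipy.optimize.least_squares`, so this sketch assumes that optimiser; the sinusoid model, the data and the grid of starting points are hypothetical and unrelated to SPItorch's actual settings:

```python
import numpy as np
from scipy.optimize import least_squares

# Toy problem (not SPItorch's): recover the amplitude and frequency of a sinusoid.
# Frequency fits like this have many local minima, so a single start can fail.
x = np.linspace(0.0, 10.0, 200)
y = 2.0 * np.sin(1.5 * x)


def residuals(theta):
    amplitude, frequency = theta
    return amplitude * np.sin(frequency * x) - y


# Several starting points spread over the plausible frequency range.
starts = [np.array([1.0, f0]) for f0 in np.linspace(0.5, 3.0, 11)]
results = [least_squares(residuals, x0) for x0 in starts]

# As in the cell above: keep the result with the lowest cost.
ind_best = int(np.argmin([r.cost for r in results]))
theta_best = results[ind_best].x
print(ind_best, theta_best, results[ind_best].cost)
```

Starts far from the true frequency can settle in a local minimum with a large cost; taking the minimum over all starts guards against relying on any one of them.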

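### EMCEE Sampling

We can also sample the posterior over `theta` with EMCEE, an ensemble MCMC sampler, and then visualise the model again: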

In [None]:
p.emcee_fit()

In [None]:
p.visualise_model(show=True, save=False)
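For intuition about what EMCEE does, note that emcee implements the Goodman & Weare affine-invariant ensemble sampler, whose default update is the "stretch move". Below is a minimal serial NumPy sketch of that move on a toy 2-D Gaussian target. It is not SPItorch's or emcee's code: emcee uses a parallelised variant that splits the walkers into two halves, and adds diagnostics this sketch omits. The function name `stretch_move_sampler` is hypothetical:

```python
import numpy as np


def stretch_move_sampler(log_prob, p0, nsteps, a=2.0, seed=0):
    """Minimal affine-invariant ensemble sampler (Goodman & Weare stretch move).

    Walkers are updated one at a time, each against a randomly chosen other walker.
    Returns the chain with shape (nsteps, nwalkers, ndim).
    """
    rng = np.random.default_rng(seed)
    walkers = np.array(p0, dtype=float)
    nwalkers, ndim = walkers.shape
    lp = np.array([log_prob(w) for w in walkers])
    chain = np.empty((nsteps, nwalkers, ndim))
    for step in range(nsteps):
        for k in range(nwalkers):
            # Pick a different walker j to stretch towards / away from.
            j = rng.integers(nwalkers - 1)
            if j >= k:
                j += 1
            # Draw z from g(z) proportional to 1/sqrt(z) on [1/a, a] (inverse-CDF sampling).
            z = ((a - 1.0) * rng.random() + 1.0) ** 2 / a
            proposal = walkers[j] + z * (walkers[k] - walkers[j])
            lp_new = log_prob(proposal)
            # Accept with probability min(1, z^(ndim - 1) * p(proposal) / p(current)).
            if np.log(rng.random()) < (ndim - 1) * np.log(z) + lp_new - lp[k]:
                walkers[k] = proposal
                lp[k] = lp_new
        chain[step] = walkers
    return chain


# Usage on a toy 2-D Gaussian target with mean (1, -2) and unit variance.
mean = np.array([1.0, -2.0])


def log_prob(theta):
    return -0.5 * np.sum((theta - mean) ** 2)


p0 = np.random.default_rng(1).normal(size=(32, 2))  # 32 walkers in 2 dimensions
chain = stretch_move_sampler(log_prob, p0, nsteps=2000)
samples = chain[500:].reshape(-1, 2)  # discard burn-in, then flatten over walkers
print(samples.mean(axis=0), samples.std(axis=0))
```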

You can load previously saved fit results, specified either by a known file path or by an (index, fitting method) pair, as follows:

In [None]:
from spt.types import MCMCMethod

# Load from an explicit file path...
try:
    p.load_fit_results(file='/path/to/my/file')
except ValueError:
    # ...which fails here, since this is only a placeholder path
    pass

# ...or by observation index and fitting method
p.load_fit_results(index=p.index, method=MCMCMethod.EMCEE)

This populates the `Prospector.fit_results` property, from which you can make trace plots and the like. Its keys show what is available:

In [None]:
p.fit_results.keys()
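Which key of `fit_results` holds the chain isn't shown here, so the trace-plot sketch below assumes a hypothetical NumPy chain of shape `(nsteps, nwalkers, ndim)`, the layout that emcee's `get_chain()` returns; the random-walk data and the helper name `trace_plot` are made up for illustration:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend for running without a display; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def trace_plot(chain, labels=None):
    """Plot each parameter's value against step number, one line per walker."""
    nsteps, nwalkers, ndim = chain.shape
    fig, axes = plt.subplots(ndim, 1, sharex=True, figsize=(8, 2 * ndim), squeeze=False)
    for i, ax in enumerate(axes[:, 0]):
        # Passing a 2-D array plots one line per column, i.e. one per walker
        ax.plot(chain[:, :, i], color="k", alpha=0.3, lw=0.5)
        ax.set_ylabel(labels[i] if labels is not None else f"theta[{i}]")
    axes[-1, 0].set_xlabel("step")
    fig.tight_layout()
    return fig


# Hypothetical chain of shape (nsteps, nwalkers, ndim), filled with random-walk noise.
rng = np.random.default_rng(0)
chain = np.cumsum(rng.normal(scale=0.1, size=(500, 16, 3)), axis=0)
fig = trace_plot(chain, labels=["a", "b", "c"])
```

Walkers that have converged should wander around a stable band; a visible drift suggests more burn-in is needed.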