In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
from pprint import pprint

from tiled.client import from_uri
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'
# mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.usetex'] = False
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
plt.rc('axes', labelsize=12)
mpl.rcParams['figure.dpi'] = 300

# Basic Tutorial

The [AIMM post-processing pipeline](https://github.com/AI-multimodal/aimm-post-processing) is built around the `Operator` object. An `Operator` takes a `client`-like object and executes a post-processing operation on it. The specific operation is defined by the operator. All metadata and provenance are tracked.
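
To make the pattern concrete, here is a minimal sketch of what an operator *could* look like. It is not aimmdb's implementation: `ToyOperator`, `ToyScale`, `_process_data`, and the `"_id"` metadata key are hypothetical names chosen for illustration. The sketch accepts either a client-like object (anything with `.read()` and `.metadata`) or a `{"data": ..., "metadata": ...}` dictionary, and it returns a dictionary of the same shape.

```python
from types import SimpleNamespace

import pandas as pd


class ToyOperator:
    """Hypothetical sketch of the operator pattern (not aimmdb's implementation)."""

    def _process_data(self, df):
        # Subclasses override this with the actual transformation.
        raise NotImplementedError

    def __call__(self, node):
        # Accept a client-like object exposing .read() and .metadata,
        # or a plain {"data": ..., "metadata": ...} dictionary.
        if isinstance(node, dict):
            df, metadata = node["data"], node["metadata"]
        else:
            df, metadata = node.read(), node.metadata
        return {
            # Work on a copy so the input data are never mutated.
            "data": self._process_data(df.copy()),
            # Record where the result came from and how it was made.
            "metadata": {
                "post_processing": {
                    "parents": [metadata.get("_id")],  # "_id" is a hypothetical key
                    "operator": type(self).__name__,
                    "kwargs": dict(vars(self)),
                }
            },
        }


class ToyScale(ToyOperator):
    """Hypothetical unary operator that multiplies selected columns by a constant."""

    def __init__(self, factor, y_columns):
        self.factor = factor
        self.y_columns = y_columns

    def _process_data(self, df):
        df[self.y_columns] = df[self.y_columns] * self.factor
        return df


# A stand-in for a DataFrameClient: anything with .read() and .metadata works.
df = pd.DataFrame({"energy": [1.0, 2.0], "itrans": [0.5, 0.4]})
fake_client = SimpleNamespace(read=lambda: df, metadata={"_id": "abc"})
result = ToyScale(factor=2.0, y_columns=["itrans"])(fake_client)
```

Keeping the input handling in `__call__` and the transformation in `_process_data` means each new operator only has to describe its own math, while provenance is recorded in one place.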

In [None]:
from aimmdb.postprocessing import operations

Connect to the `tiled` server. Here we use the [aimmdb](https://github.com/AI-multimodal/aimmdb) instance hosted at [aimm.lbl.gov](https://aimm.lbl.gov/api). Note that my API key is stored in the environment variable `TILED_API_KEY`.

In [None]:
CLIENT = from_uri("https://aimm.lbl.gov/api")

In [None]:
list(CLIENT["dataset"])

## Unary operators

A [unary operator](https://en.wikipedia.org/wiki/Unary_operation) takes a single input. Here, that means the operator acts on one data point (i.e., one `DataFrameClient`) at a time. We'll provide some examples below.

First, let's get a single `DataFrameClient` object:

In [None]:
df_client = CLIENT["uid"]["Bt5hUbgkfzR"]
type(df_client)

### The identity

The simplest operation we can perform is nothing! Let's see what it does. First, feel free to inspect `df_client` to see what it contains. The `read()` method gives access to the actual data, and the `metadata` property gives access to the metadata:

In [None]:
_ = df_client.read()    # is a pandas.DataFrame
_ = df_client.metadata  # is a python dictionary

The identity operator is instantiated and then run on the `df_client`.

In [None]:
op = operations.Identity()
result = op(df_client)

Every operator returns a dictionary with two keys, `"data"` and `"metadata"`, which correspond to the results of `read()` and `metadata` above. The data is the correspondingly modified `pandas.DataFrame` (which, for the identity, is of course the same as what we started with). The metadata is newly created for the derived, post-processed object.

First, let's check that the original and "post-processed" data are the same.

In [None]:
assert (df_client.read() == result["data"]).all().all()
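
One caveat on the check above: elementwise `==` treats `NaN` as unequal to itself, so it would report a mismatch for identical data that contains missing values. A sketch of a more robust comparison uses `pandas.testing.assert_frame_equal`, shown here on small toy frames (the values are made up for illustration):

```python
import numpy as np
import pandas as pd

original = pd.DataFrame({"energy": [7700.0, 7710.0], "itrans": [0.5, np.nan]})
processed = original.copy()

# Elementwise == treats NaN as unequal to itself, so this evaluates to False.
naive_equal = bool((original == processed).all().all())

# assert_frame_equal treats NaNs in matching positions as equal and raises
# AssertionError, describing the difference, on a real mismatch.
pd.testing.assert_frame_equal(original, processed)

modified = processed.assign(itrans=processed["itrans"] * 2)
try:
    pd.testing.assert_frame_equal(original, modified)
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```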

Next, the metadata:

In [None]:
result["metadata"]

First, a new unique id is assigned. Second, because this is a derived quantity, the original metadata is replaced by a `post_processing` key. This key contains all the information needed for provenance, including the parents (just one in the case of a unary operator), the operator details (including code version), any keyword arguments used during instantiation, and the datetime at which the operation was run. We use `MSONable` from the [monty](https://pythonhosted.org/monty/_modules/monty/json.html) library to take care of most of this for us.
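
The provenance record described above can be pictured with a small, self-contained sketch. `make_provenance_metadata` is a hypothetical helper, and its key names and the placeholder version string are assumptions for illustration, not aimmdb's actual schema (which is generated via `MSONable`). It shows the fields listed above: a fresh unique id, the parent ids, operator details, keyword arguments, and a timestamp, all as plain JSON.

```python
import datetime
import json
import uuid


def make_provenance_metadata(parent_ids, operator_name, operator_version, kwargs):
    """Hypothetical helper: build metadata for a derived object.

    Key names are illustrative only; they are not aimmdb's schema.
    """
    return {
        "_id": uuid.uuid4().hex,  # fresh unique id for the derived object
        "post_processing": {
            "parents": list(parent_ids),  # a single entry for a unary operator
            "operator": {"name": operator_name, "version": operator_version},
            "kwargs": dict(kwargs),
            "datetime": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        },
    }


meta = make_provenance_metadata(
    parent_ids=["Bt5hUbgkfzR"],
    operator_name="Identity",
    operator_version="0.0.0",  # placeholder version string
    kwargs={},
)
# The record is plain JSON, so it can be stored alongside the derived data.
serialized = json.dumps(meta)
```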

We can compare against the original metadata to see the differences.

In [None]:
df_client.metadata

### Standardizing the grids

Often (and especially for machine learning applications) we need to interpolate our spectral data onto a common grid. We can do this easily with the `StandardizeGrid` unary operator.
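
Conceptually, standardizing the grid amounts to building an evenly spaced grid from `x0`, `xf`, and `nx` and linearly interpolating each `y_columns` entry onto it. The sketch below assumes linear interpolation with `numpy.interp`; `standardize_grid` is a hypothetical function, and the real `StandardizeGrid` operator may interpolate or handle out-of-range points differently (`np.interp` clamps to the end values).

```python
import numpy as np
import pandas as pd


def standardize_grid(df, x0, xf, nx, x_column, y_columns):
    """Hypothetical sketch: linearly interpolate y_columns onto np.linspace(x0, xf, nx)."""
    grid = np.linspace(x0, xf, nx)
    # np.interp requires the sample x values to be increasing.
    df = df.sort_values(x_column)
    out = {x_column: grid}
    for col in y_columns:
        # Points outside the measured range are clamped to the end values by np.interp.
        out[col] = np.interp(grid, df[x_column].to_numpy(), df[col].to_numpy())
    return pd.DataFrame(out)


# Toy data where itrans = 2 * energy, so interpolated values are easy to check.
df = pd.DataFrame({"energy": [0.0, 1.0, 2.0, 3.0], "itrans": [0.0, 2.0, 4.0, 6.0]})
new = standardize_grid(df, x0=0.5, xf=2.5, nx=5, x_column="energy", y_columns=["itrans"])
```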

In [None]:
op = operations.StandardizeGrid(x0=7550.0, xf=8900.0, nx=100, x_column="energy", y_columns=["itrans"])
result = op(df_client)

Here's a visualization of what it's done:

In [None]:
d0 = df_client.read()

In [None]:
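fig, ax = plt.subplots(1, 1, figsize=(3, 2))
ax.plot(d0["energy"], d0["itrans"], 'k-', label="original")
ax.plot(result["data"]["energy"], result["data"]["itrans"], 'r-', label="standardized")
ax.set_xlabel("energy")
ax.set_ylabel("itrans")
ax.legend(fontsize=6)
plt.show()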

### Spectral postprocessing

To make XAS spectra usable, we need to do a couple of things. In particular:
1. Subtract off the pre-edge background trend
2. Normalize the post-edge (tail) region to 1

We provide tools to do this systematically. For what follows, we'll do this on a Co spectrum.
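
As a toy illustration of the two steps (on a synthetic spectrum, not the Co data), they can be sketched in plain NumPy. This assumes the background is a low-order polynomial fit over a pre-edge window and that normalization divides by the mean over the tail window. `subtract_pre_edge` and `normalize_tail` are hypothetical names, and the `RemoveBackground` operator's Victoreen-based fit may differ from the plain polynomial used here.

```python
import numpy as np


def subtract_pre_edge(energy, mu, x0, xf, order=0):
    """Fit a polynomial of degree `order` to mu on [x0, xf] and subtract it everywhere."""
    mask = (energy >= x0) & (energy <= xf)
    coeffs = np.polyfit(energy[mask], mu[mask], order)
    return mu - np.polyval(coeffs, energy)


def normalize_tail(energy, mu, x0, xf=None):
    """Divide mu by its mean over [x0, xf]; xf=None means up to the last point."""
    upper = energy.max() if xf is None else xf
    mask = (energy >= x0) & (energy <= upper)
    return mu / mu[mask].mean()


# Synthetic spectrum: constant background of 0.3 plus an edge of height 1.5 at 7709 eV.
energy = np.linspace(7500.0, 8500.0, 1001)
mu = 0.3 + 1.5 * (energy >= 7709.0)
flat = subtract_pre_edge(energy, mu, x0=7510.0, xf=7690.0, order=0)
normalized = normalize_tail(energy, flat, x0=8000.0)
```

After both steps the pre-edge sits at 0 and the tail at 1, which is the target state described in the list above.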

In [None]:
import numpy as np

node = CLIENT["uid"]["Bt5hUbgkfzR"]
df = node.read()

# Transmission absorption: mu * t = -ln(I_t / I_0)
df["mutrans"] = -np.log(df["itrans"] / df["i0"])

energy = df["energy"]
mutrans = df["mutrans"]
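
The `mutrans` column computed above follows the Beer-Lambert law for a transmission measurement, with $I_0$ the incident intensity (`i0`), $I_t$ the transmitted intensity (`itrans`), and $t$ the sample thickness:

$$
I_t = I_0 \, e^{-\mu(E)\, t}
\quad\Longrightarrow\quad
\mu(E)\, t = -\ln\!\left(\frac{I_t}{I_0}\right)
$$

Since $t$ is fixed for a given sample, `mutrans` is proportional to $\mu(E)$.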

Here's what the spectrum looks like before postprocessing:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 2))
ax.plot(energy, mutrans)
plt.show()

In [None]:
remove_background = operations.RemoveBackground(x0=7510, xf=7690, y_columns=["mutrans"], victoreen_order=0)
standardize_intensity = operations.StandardizeIntensity(x0=8000, xf=None, y_columns=["mutrans"])

In [None]:
tmp_data = remove_background({'data': df, 'metadata': node.metadata})
new_data = standardize_intensity(tmp_data)

In [None]:
new_df = new_data["data"]
energy = new_df["energy"]
mutrans = new_df["mutrans"]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 2))
ax.plot(energy, mutrans)
plt.show()

An alternative normalization scheme uses [X-ray Larch](https://xraypy.github.io/xraylarch/):
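
For intuition about what an edge-step normalization does, here is a sketch of the general Athena/Larch-style idea: locate the edge energy $E_0$ at the maximum of the first derivative, fit a line to the pre-edge and a quadratic to the post-edge, take the edge step as the gap between the two at $E_0$, and compute $(\mu - \text{pre-edge line}) / \text{step}$. This is not Larch's `pre_edge` code. `edge_step_normalize` is a hypothetical function, and the default fit windows are illustrative assumptions, not Larch's defaults.

```python
import numpy as np


def edge_step_normalize(energy, mu, pre=(-150.0, -30.0), post=(150.0, 400.0)):
    """Hypothetical sketch of edge-step normalization in the style of Athena/Larch.

    `pre` and `post` are fit windows in eV relative to the edge energy e0.
    """
    # Edge energy: the point of maximum first derivative of mu(E).
    e0 = energy[np.argmax(np.gradient(mu, energy))]
    x = energy - e0  # center on e0 to keep the polynomial fits well conditioned
    pre_mask = (x >= pre[0]) & (x <= pre[1])
    post_mask = (x >= post[0]) & (x <= post[1])
    pre_coeffs = np.polyfit(x[pre_mask], mu[pre_mask], 1)     # pre-edge line
    post_coeffs = np.polyfit(x[post_mask], mu[post_mask], 2)  # post-edge quadratic
    # Edge step: gap between the post-edge curve and the pre-edge line at e0 (x = 0).
    step = np.polyval(post_coeffs, 0.0) - np.polyval(pre_coeffs, 0.0)
    normalized = (mu - np.polyval(pre_coeffs, x)) / step
    return e0, step, normalized


# Synthetic spectrum: sloped background plus a smooth unit edge at 7709 eV.
energy = np.linspace(7600.0, 8200.0, 601)
mu = 0.2 + 1e-4 * (energy - 7600.0) + 0.5 * (1.0 + np.tanh((energy - 7709.0) / 2.0))
e0, step, norm = edge_step_normalize(energy, mu)
```

Unlike dividing by the tail mean, this scheme divides by the edge jump at $E_0$, so a sloping post-edge is not forced to exactly 1 everywhere.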

In [None]:
normalize_xas = operations.NormalizeLarch(y_columns=["mutrans"])
larch_norm_data = normalize_xas(tmp_data)

larch_norm_df = larch_norm_data["data"]
energy = larch_norm_df["energy"]
mutrans = larch_norm_df["mutrans"]

fig, ax = plt.subplots(1, 1, figsize=(3, 2))
ax.plot(energy, mutrans)
plt.show()

## Multiple input operators

MultiOperators are defined to act on an arbitrary number of inputs and return a single output. Any number of `DataFrameClient` objects can be passed to a MultiOperator and acted on.
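
Averaging is a natural multi-input operation. The sketch below assumes the inputs are first interpolated onto the first input's grid and then averaged point by point. `average_spectra` is a hypothetical function, and the real `AverageData` operator may choose its grid differently. A real MultiOperator would also record every input as a parent in the provenance metadata.

```python
import numpy as np
import pandas as pd


def average_spectra(dfs, x_column="energy", y_column="mutrans"):
    """Hypothetical sketch: interpolate every input onto the first input's grid, then average."""
    grid = dfs[0][x_column].to_numpy()
    # One row per input spectrum, all sampled on the same grid.
    stacked = np.vstack([
        np.interp(grid, df[x_column].to_numpy(), df[y_column].to_numpy())
        for df in dfs
    ])
    return pd.DataFrame({x_column: grid, y_column: stacked.mean(axis=0)})


a = pd.DataFrame({"energy": [1.0, 2.0, 3.0], "mutrans": [0.0, 1.0, 2.0]})
b = pd.DataFrame({"energy": [1.0, 2.0, 3.0], "mutrans": [2.0, 3.0, 4.0]})
avg = average_spectra([a, b])
```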

For example, here we use the `AverageData` MultiOperator to average two spectra:

In [None]:
average_data = operations.AverageData(y_column="mutrans")
avg_result = average_data(larch_norm_data, {'data': df, 'metadata': node.metadata})
avg_result["metadata"]

In [None]:
avg_data = avg_result["data"]
energy = avg_data["energy"]
mutrans = avg_data["mutrans"]

fig, ax = plt.subplots(1, 1, figsize=(3, 2))
ax.plot(energy, mutrans)
plt.show()