# Initial example of XCA latent-space plotting for a classifier.

Using everyone's favorite prototype, BaTiO$_3$.
Unfortunately, even a convolutional VAE struggles to separate the phases of BaTiO$_3$,
so we use a slightly more complex model that includes a predictive component,
and pull out only its encoder and decoder.

This leaves room for a predictive VAE agent that observes, makes its own predictions,
and compares them against the XCA feed-forward predictions. The federation grows...



In [None]:
from federation.plumbing.filesystem import ObservationalDirectoryAgent
from federation.xca.vae import VAECompanion
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
encoder_path = "../saved_models/BTO_VAE/encoder"
decoder_path = "../saved_models/BTO_VAE/decoder"
data_path = "../example_data/BTO/"
training_data_paths = sorted(Path("../saved_models/BTO_VAE/eg_training_data").glob("*.nc"))
# Example training pattern (used here only for its 2theta grid)
eg_model_data = xr.open_dataset(training_data_paths[0])
# Example experimental pattern; the first column is 2theta
eg_exp_data = np.loadtxt(Path(data_path) / "BTO_150K.IoQ")

# Extract the 2theta grids from the examples
model_tth = eg_model_data.coords["2theta"].values
exp_tth = eg_exp_data[:, 0]

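The companion receives both `model_tth` and `exp_tth`, which suggests that measured patterns are mapped onto the 2θ grid the model was trained on. The sketch below shows one way to do that with linear interpolation. Whether `VAECompanion` actually interpolates this way, or fills out-of-range points with zero, is an assumption, and `resample_to_model_grid` is a hypothetical helper.

```python
import numpy as np


def resample_to_model_grid(exp_tth, exp_intensity, model_tth):
    """Linearly interpolate an experimental pattern onto the model's 2theta grid.

    Hypothetical helper: grid points outside the measured range are set to 0.
    """
    exp_tth = np.asarray(exp_tth, dtype=float)
    exp_intensity = np.asarray(exp_intensity, dtype=float)
    order = np.argsort(exp_tth)  # np.interp requires increasing x values
    return np.interp(model_tth, exp_tth[order], exp_intensity[order],
                     left=0.0, right=0.0)


# Toy pattern sampled at 10, 20, 30, 40, 50 degrees
exp_tth = np.linspace(10.0, 50.0, 5)
exp_intensity = 2.0 * exp_tth
model_tth = np.array([5.0, 15.0, 25.0, 55.0])
resampled = resample_to_model_grid(exp_tth, exp_intensity, model_tth)
# resampled holds 0, 30, 50, 0 (the ends fall outside the measured range)
```
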
In [None]:
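# Identity transform; this could also trim the data to a region of interest (ROI).
# Data is automatically normalized onto (-1, 1) within the XCA companion.
def data_transform(data):
    return data

# Parse the independent variable (temperature) from filenames such as
# "BTO_150K.IoQ": stem "BTO_150K" -> "150K" -> drop the unit -> 150.0
def independent_from_path(path):
    return float(path.stem.split("_")[-1][:-1])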

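Because the companion normalizes data onto (-1, 1) itself, `data_transform` can stay the identity. Min-max rescaling is one common way to do that normalization; the sketch below assumes that scheme (the companion's exact method is not shown in this notebook, and `scale_to_unit_interval` is hypothetical). It also checks the filename parser against the naming pattern used in this directory.

```python
import numpy as np
from pathlib import Path


def scale_to_unit_interval(pattern):
    """Min-max scale a 1-D pattern onto [-1, 1] (assumed normalization scheme)."""
    pattern = np.asarray(pattern, dtype=float)
    lo, hi = pattern.min(), pattern.max()
    if hi == lo:  # flat pattern: avoid division by zero
        return np.zeros_like(pattern)
    return 2.0 * (pattern - lo) / (hi - lo) - 1.0


def independent_from_path(path):
    # Same parsing as above: "BTO_150K.IoQ" -> "150K" -> 150.0
    return float(path.stem.split("_")[-1][:-1])


scaled = scale_to_unit_interval([0.0, 5.0, 10.0])  # -1, 0, 1
temperature = independent_from_path(Path("BTO_150K.IoQ"))  # 150.0
```
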
In [None]:
companion = VAECompanion(encoder_path=encoder_path,
                         decoder_path=decoder_path,
                         model_tth=model_tth,
                         exp_tth=exp_tth,
                         coordinate_transform=None,
                         latent_dims=(0, 1))

agent = ObservationalDirectoryAgent(companion,
                                    data_path,
                                    path_spec="*.IoQ",
                                    data_transform=data_transform,
                                    independent_from_path=independent_from_path)

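`latent_dims=(0, 1)` presumably selects which two latent coordinates become the axes of the 2-D latent-space plot when the encoder has more dimensions than can be drawn. A sketch of that selection, assuming encoded patterns arrive as an `(n_patterns, n_latent)` array; `project_latent` is hypothetical.

```python
import numpy as np


def project_latent(z, latent_dims=(0, 1)):
    """Return the two chosen latent coordinates as (x, y) for a scatter plot."""
    z = np.asarray(z)
    return z[:, latent_dims[0]], z[:, latent_dims[1]]


z = np.arange(12.0).reshape(3, 4)  # 3 encoded patterns, 4 latent dimensions
x, y = project_latent(z, latent_dims=(0, 2))
# x holds 0, 4, 8 and y holds 2, 6, 10
```
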
In [None]:
agent.load_dir()

### We take a subset of training data and "prime" the plot.
This shows how the training classes are distributed in latent space.
It also sets the normalization factor for the size of the points produced by `observe()`.

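Priming fixes the reference maximum from the training data. Assuming per-pattern mean squared error as the reconstruction metric (the companion may use a different one), that factor is simply the largest error over the priming set. `reconstruction_errors` and the toy reconstructions are hypothetical.

```python
import numpy as np


def reconstruction_errors(X, X_hat):
    """Per-pattern mean squared error for arrays of shape (n_patterns, n_points)."""
    return np.mean((np.asarray(X) - np.asarray(X_hat)) ** 2, axis=1)


# Toy training set and hypothetical decoder reconstructions of it
X = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
X_hat = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]])
errors = reconstruction_errors(X, X_hat)  # 0.0, 0.5, 2.0
max_error = errors.max()  # 2.0, reused to scale later observed points
```
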
In [None]:
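# Load each saved training pattern and stack them along a new "idx" dimension
arrays = [xr.load_dataarray(path) for path in training_data_paths]
X = xr.concat(arrays, dim="idx", combine_attrs="drop_conflicts").data
# The source CIF of each pattern serves as its class label
labels = [da.attrs['input_cif'] for da in arrays]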



In [None]:
agent.companion.prime_plot(X, labels)

### `observe()` plots the measured data over the primed plot.
- If the plot was not primed with training data, it simply populates an empty latent space.
- Calling `observe()` with no arguments plots the most recent measurement. Alternatively, specific values of the independent variable can be requested with `independent`.
- By default, only one point is displayed at a time. This can be changed by setting `hold`.
- Point size reflects the reconstruction error relative to a maximum.
    - By default, this maximum is the largest reconstruction error of the dependent (measured) data.
    - If the plot is primed, the maximum instead defaults to the largest reconstruction error of the training data used in priming.
    - Alternatively, it can be set explicitly with the `max_error` parameter.

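The sizing rule in the list above can be sketched as a linear scaling of each point's error by the maximum. The `base_size` constant and the clipping of errors above the maximum are assumptions for illustration, not documented behavior of `observe()`; `point_sizes` is hypothetical.

```python
import numpy as np


def point_sizes(errors, max_error=None, base_size=100.0):
    """Marker sizes proportional to reconstruction error relative to a maximum."""
    errors = np.asarray(errors, dtype=float)
    if max_error is None:
        max_error = errors.max()  # unprimed default: largest observed error
    if max_error <= 0:
        return np.zeros_like(errors)  # nothing to scale against
    # Assumed clipping so errors above max_error do not give oversized markers
    return base_size * np.clip(errors / max_error, 0.0, 1.0)


default_sizes = point_sizes([0.05, 0.1, 0.2])  # scaled relative to 0.2
fixed_sizes = point_sizes([0.05, 0.1, 0.4], max_error=0.2)  # 0.4 is clipped
```

Passing `max_error` explicitly keeps point sizes comparable across calls, whereas the default rescales whenever a new largest error appears.
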
In [None]:
agent.companion.observe()

In [None]:
agent.companion.observe(independent=[225])
agent.companion.hold = True

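Setting `hold = True` before the next call keeps earlier points on the plot, so the 225 K point above stays visible when the 150 K point is added. A toy model of that bookkeeping follows; `LatentScatter` is hypothetical and only stands in for whatever the companion does with its plotted points.

```python
class LatentScatter:
    """Toy model of which observed points remain displayed (hypothetical)."""

    def __init__(self, hold=False):
        self.hold = hold
        self.points = []

    def observe(self, point):
        if not self.hold:
            self.points.clear()  # default: only the newest point is shown
        self.points.append(point)


scatter = LatentScatter()
scatter.observe((0.1, 0.2))
scatter.observe((0.3, 0.4))  # replaces the first point
scatter.hold = True
scatter.observe((0.5, 0.6))  # kept alongside the previous point
# scatter.points is now [(0.3, 0.4), (0.5, 0.6)]
```
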
In [None]:
agent.companion.observe(independent=[150], max_error=0.2)


