In [None]:
from matplotlib.pyplot import *
import numpy as np
import pandas as pd
import os
import importlib

from ATARI.sammy_interface import sammy_classes, sammy_functions

from ATARI.ModelData.particle_pair import Particle_Pair
# from ATARI.ModelData.experimental_model import Experimental_Model
from ATARI.syndat.control import syndatOPT, Syndat_Control

from copy import copy

import ATARI.utils.plotting as myplot


In [None]:
# Path to SAMMY, passed to sammy_classes.SammyRunTimeOptions when samples are generated.
# Fill this in before running the notebook.
sammypath = ''
# Explicit check (not `assert`) so it is not stripped under `python -O` and also rejects None.
if not sammypath:
    raise ValueError("sammypath is empty: set it to your SAMMY path before running this notebook.")

In [None]:
# Use the interactive (ipympl) matplotlib backend for the figures in this notebook.
%matplotlib widget

# Overview

Syndat models for each measurement were developed, investigated, and saved in the associated example notebooks.
Here, we are going to load each Syndat model, add them to the Syndat Control module, and draw samples from them.
A few things to note:
1) The options should be re-defined, since we want to ensure that resonance and parameter sampling are turned on.
2) The energy ranges can be overwritten so that you can draw samples for a smaller window.
3) The background function is the same for all three transmission measurements; this can be implemented with the model-correlations input.


In [None]:
from ATARI.utils import atario

# Directory holding the Syndat models saved by the per-measurement example notebooks.
RESULTS_DIR = os.path.join(os.getcwd(), "results")


def load_syndat_model(tag):
    """Load the Syndat model stored as ``results/SyndatModel_<tag>.pkl``.

    NOTE(review): the ``.pkl`` files are presumably unpickled by
    ``atario.load_syndat``; only load files from a trusted source.
    """
    return atario.load_syndat(os.path.join(RESULTS_DIR, f"SyndatModel_{tag}.pkl"))


# Transmission measurement models
syndat_trans1mm = load_syndat_model("1mmT")
syndat_trans3mm = load_syndat_model("3mmT")
syndat_trans6mm = load_syndat_model("6mmT")

# Capture measurement models
syndat_cap1mm = load_syndat_model("1mmY")
syndat_cap2mm = load_syndat_model("2mmY")

In [None]:
# Energy window used for the particle pair and for truncating every Syndat model below.
energy_range_all = [197.5, 235]

# Ta181 particle pair (parameter values as in the original setup).
Ta_pair = Particle_Pair(
    isotope="Ta181",
    formalism="XCT",
    energy_range=energy_range_all,
    ac=0.81271,
    M=180.948030,
    m=1,
    I=3.5,
    i=0.5,
    l_max=1,
)

# Spin groups as (Jpi, J_ID, D, gn2_avg); both share gn2_dof, gg2_avg and gg2_dof.
spin_group_table = [
    ('3.0', 1, 9.0030, 452.56615),
    ('4.0', 2, 8.3031, 332.24347),
]
for jpi, j_id, spacing, gn2_mean in spin_group_table:
    Ta_pair.add_spin_group(Jpi=jpi, J_ID=j_id,
                           D=spacing,
                           gn2_avg=gn2_mean, gn2_dof=1,
                           gg2_avg=32.0, gg2_dof=100)

In [None]:
## Re-cut every Syndat model to the smaller energy window.
## return_copy=True presumably leaves the loaded models unmodified — TODO confirm.
## Order matters: the 'models' mask in model_correlations is presumably positional.
syndat_models = [syndat_trans1mm, syndat_trans3mm, syndat_trans6mm, syndat_cap1mm, syndat_cap2mm]
syndat_models_new = [
    model.truncate_energy_range(energy_range_all, return_copy=True)
    for model in syndat_models
]


In [None]:

# SAMMY run-time options for sample generation: printing on, bayes off,
# run directory "sammy_runDIR_gen" not kept.
sammy_rto_settings = {
    "Print": True,
    "bayes": False,
    "keep_runDIR": False,
    "sammy_runDIR": "sammy_runDIR_gen",
}
rto = sammy_classes.SammyRunTimeOptions(sammypath, sammy_rto_settings)

# Turn resonance (sampleRES) and parameter (sampleTURP) sampling on; raw data not needed here.
options = syndatOPT(sampleRES=True, sampleTURP=True, save_raw_data=False)

# The three transmission models share one set of a_b (background) parameters.
# One flag per entry of syndat_models_new: 3 transmission models, then 2 capture models.
trans_a_b = syndat_trans1mm.generative_measurement_model.model_parameters.a_b
shared_a_b_mask = [1, 1, 1, 0, 0]
model_correlations = [{'models': shared_a_b_mask, 'a_b': trans_a_b}]

syndat = Syndat_Control(
    Ta_pair,
    syndat_models=syndat_models_new,
    model_correlations=model_correlations,
    options=options,
)

In [None]:
# Draw synthetic samples from every model.
# NOTE(review): no random seed is set in this notebook, so samples presumably differ
# between runs unless Syndat seeds internally — TODO confirm.
NUM_SAMPLES = 10
syndat.sample(rto, num_samples=NUM_SAMPLES)

In [None]:
# Inspect one of the drawn samples.
sample1 = syndat.get_sample(4)

# Reduced data for each measurement in this sample. Pairing with `experiments` below
# assumes the dict preserves the order of syndat_models_new — TODO confirm.
datasets = [measurement.pw_reduced for measurement in sample1.values()]

# Use the truncated models that were passed to Syndat_Control (not the originally
# loaded ones), so each experiment corresponds to the model that generated its dataset.
experiments = [model.generative_experimental_model for model in syndat_models_new]

fig = myplot.plot_reduced_data_TY(datasets=datasets,
                                  experiments=experiments,
                                  xlim=energy_range_all,
                                  plot_datasets_true=True
                                  )

# Saving the Syndat Control module and all samples
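Similar to the Syndat models, you can use the `atario` module to save the `Syndat_Control` object (including all drawn samples) as a pickle.
Note: the save call below is not commented out, so running it may overwrite an existing `SyndatModel_All_200_235.pkl`; comment it out if you want to keep a previously saved file.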

In [None]:

# Persist the Syndat_Control object (models and drawn samples) under ./results.
syndat_control_path = os.path.join(os.getcwd(), "results", "SyndatModel_All_200_235.pkl")
atario.save_syndat_control(syndat, syndat_control_path)