In [None]:
import os
from copy import deepcopy

import h5py as h5
import numpy as np
import pandas as pd
from tqdm import tqdm

from multiresticodm.config import Config
from multiresticodm.inputs import Inputs
from multiresticodm.outputs import Outputs
from multiresticodm.utils.misc_utils import *
from multiresticodm.utils.math_utils import *
from multiresticodm.contingency_table import instantiate_ct
from multiresticodm.utils.probability_utils import *

In [None]:
# Render matplotlib figures inline below the cell that creates them.
%matplotlib inline

# Automatically reload imported modules (e.g. the multiresticodm package) so
# edits to their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Experiment output folder (relative to the notebook's working directory)
# and the config file stored inside it.
experiment_dir = (
    '../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/'
    'JointTableSIM_MCMC_SweepedNoise_16_05_2023_20_09_04/'
)
config_path = os.path.join(experiment_dir, 'config.json')

In [None]:
# Post-processing options passed to Outputs. The burn-in / thinning /
# trimming values apply to the 'iter' dimension; their exact semantics are
# defined by Outputs (not visible here).
settings = dict(
    logging_mode="INFO",
    # Optional xarray filter expressions, e.g.:
    # "da.loss_name.isin([str(['dest_attraction_ts_likelihood_loss']),str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])])",
    coordinate_slice=[],
    metadata_keys=[],
    burnin_thinning_trimming=[
        {'iter': {"burnin": 10000, "thinning": 100, "trimming": 1000}}
    ],
    n_workers=1,
    filename_ending="test",
    sample=["table", "intensity"],
    force_reload=True,
)

In [None]:
# Build the Outputs handler for the whole experiment folder (no Inputs yet).
outputs = Outputs(
    config=config_path,
    settings=settings,
    base_dir=experiment_dir,
    inputs=None,
    slice=True,
)

# Mute console logging, then collect the outputs stored in the folder.
outputs.logger.setLevels(console_level='EMPTY')
outputs.load()

In [None]:
def validate_tables(out):
    """Check whether every sampled contingency table in ``out`` is admissible.

    Builds a contingency-table object from ``out``'s config and input data,
    prints a short summary of its constraints, then checks each table sample
    (grouped by ``id``) with ``ct.table_admissible``. If any table fails, it
    also reports whether *all* tables satisfy the margin constraints and
    whether *all* tables satisfy the cell constraints.

    Parameters
    ----------
    out :
        Outputs-like object exposing ``config``, ``inputs`` (with
        ``cast_from_xarray()`` and ``data_vars()``) and ``get_sample('table')``.

    Returns
    -------
    bool
        True if every table sample is admissible (True for no samples).
    """
    import torch

    # NOTE(review): presumably converts out.inputs in place; the change is
    # visible to the caller after this function returns.
    out.inputs.cast_from_xarray()
    ct = instantiate_ct(
        config = out.config,
        **out.inputs.data_vars(),
        level = 'EMPTY'
    )
    samples = out.get_sample('table')
    print('axes constraints', ct.constraints['constrained_axes'])
    print('cell constraints', len(ct.constraints['cells']))

    # Convert every sample to a tensor once, instead of regrouping and
    # re-converting for each of the three checks below.
    tables = [torch.tensor(tab.values.squeeze()) for _, tab in samples.groupby('id')]

    tables_admissible = all(ct.table_admissible(table) for table in tables)
    print('Tables admissible', tables_admissible)
    if not tables_admissible:
        # Both diagnostics use all() so they answer the same question as the
        # overall check ("do *all* tables pass?"); any() would hide failures.
        print('Tables margins admissible', all(ct.table_margins_admissible(table) for table in tables))
        print('Tables cells admissible', all(ct.table_cells_admissible(table) for table in tables))
    return tables_admissible

In [None]:
# Compute validation metrics for every sweep: SRMSE of the mean intensity
# and, when table samples can be processed, SRMSE of the mean table.
n_sweeps = len(outputs.data)

# The ground truth is always read from sweep 0, so it is loop-invariant: load
# and cast it once instead of twice per iteration. Guarded so an empty
# experiment does not trigger outputs.get(0).
# NOTE(review): assumes srmse does not modify ground_truth in place — confirm.
ground_truth_table = (
    outputs.get(0).get_sample('ground_truth_table').astype('float32')
    if n_sweeps > 0 else None
)

data = []
for i in tqdm(
    range(n_sweeps),
    leave=True,
    disable=True,
    desc='Computing validation metrics'
):
    print(f"{i+1}/{n_sweeps}")
    sweep_outputs = outputs.get(i)
    sweep_outputs.inputs = Inputs(
        config = sweep_outputs.config,
        synthetic_data = False,
        logger = outputs.logger
    )

    mean_intensity = sweep_outputs.compute_statistic(
        data = sweep_outputs.get_sample('intensity'),
        sample_name = 'intensity',
        statistic = 'signedmean',
        dim = ['id']
    )
    intensity_srmse = srmse(
        prediction = mean_intensity,
        ground_truth = ground_truth_table
    )
    # One record per sweep: the sweep parameter values followed by the metrics.
    datum = dict(zip(
        outputs.config.sweep_param_names,
        mean_intensity['sweep'].values[0]
    ))
    print('sweep',{k:v for k,v in datum.items() if k not in ['covariance','to_learn']})
    datum['intensity_srmse'] = intensity_srmse.values[0]

    try:
        mean_table = sweep_outputs.compute_statistic(
            data = sweep_outputs.get_sample('table'),
            sample_name = 'table',
            statistic = 'mean',
            dim = ['id']
        )
        table_srmse = srmse(
            prediction = mean_table,
            ground_truth = ground_truth_table
        )
        datum['table_srmse'] = table_srmse.values[0]
    except Exception as exc:
        # Table metrics may be unavailable for some sweeps; record NaN and
        # report the reason instead of silently swallowing every error
        # (a bare except would also swallow KeyboardInterrupt).
        datum['table_srmse'] = np.nan
        print(f'table_srmse unavailable: {exc!r}')
    data.append(datum)
    print()

In [None]:
# One row per sweep. Drop sweep parameters that are not useful as scalar
# columns; errors='ignore' keeps this working for sweeps where some of these
# parameters are absent. A new frame is built (no inplace mutation), so
# re-running the cell is idempotent.
df = (
    pd.DataFrame.from_records(data)
    .drop(columns=['covariance', 'to_learn', 'axes', 'cells'], errors='ignore')
)

In [None]:
# Display the per-sweep validation metrics table.
df

In [None]:
# The experiment's 'samples' sub-folder, derived from experiment_dir instead of
# a hard-coded absolute, user-specific path so the notebook is portable.
# The trailing '' keeps the trailing separator the original literal had.
root_path = os.path.join(os.path.abspath(experiment_dir), 'samples', '')
# The same folder expressed relative to the current working directory
# (relpath's default start is the current directory).
relative_path = os.path.relpath(root_path)

In [None]:
# Post-processing options passed to Outputs. The burn-in / thinning /
# trimming values apply to the 'iter' dimension; their exact semantics are
# defined by Outputs (not visible here).
settings = dict(
    logging_mode="INFO",
    # Optional xarray filter expressions, e.g.:
    # "da.loss_name.isin([str(['dest_attraction_ts_likelihood_loss']),str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])])",
    coordinate_slice=[],
    metadata_keys=[],
    burnin_thinning_trimming=[
        {'iter': {"burnin": 10000, "thinning": 100, "trimming": 1000}}
    ],
    n_workers=1,
    filename_ending="test",
    sample=["table", "intensity"],
    force_reload=True,
)

In [None]:
# Load a single sweep's outputs directly from its sub-folder of the experiment.
# The path is built from experiment_dir so the experiment name is not repeated
# as a literal (on POSIX this yields exactly the previously hard-coded string,
# including the trailing separator).
current_sweep_dir = os.path.join(
    experiment_dir,
    'samples',
    'sigma_high',
    'title__total_intensity_row_table_constrained',
    '',
)
current_sweep_outputs = Outputs(
    config = current_sweep_dir,
    settings = settings,
    inputs = None,
    slice = True
)
current_sweep_outputs.load()