In [None]:
import h5py as h5
import numpy as np
import pandas as pd
from copy import deepcopy
from tqdm import tqdm


from multiresticodm.config import Config
from multiresticodm.inputs import Inputs
from multiresticodm.outputs import Outputs
from multiresticodm.utils.misc_utils import *
from multiresticodm.utils.math_utils import *
from multiresticodm.contingency_table import instantiate_ct
from multiresticodm.utils.probability_utils import *

In [None]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
# (autoreload 2 re-imports edited modules, e.g. multiresticodm, before each cell runs)
%load_ext autoreload
%autoreload 2

In [None]:
# Get important paths.
# The experiment directory can be overridden with the MULTIRESTICODM_EXPERIMENT_DIR
# environment variable; the default is the original hard-coded location.
# `os` was previously only reachable through the star imports; import it explicitly.
import os

experiment_dir = os.environ.get(
    'MULTIRESTICODM_EXPERIMENT_DIR',
    '/home/iz230/MultiResTICODM/data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/JointTableSIM_MCMC_SweepedNoise_16_05_2023_20_09_04/'
)
config_path = os.path.join(experiment_dir,'config.json')

In [None]:
# Output processing settings handed to `Outputs`:
# - coordinate_slice: filter expressions written against `da`
#   (presumably the sweep coordinates DataArray — confirm in Outputs)
# - burnin_thinning_trimming: per-dimension MCMC post-processing ('iter' dimension)
# - sample: which samples to load for each sweep
settings = dict(
    logging_mode="INFO",
    coordinate_slice=[
        "da.loss_name.isin([str(['dest_attraction_ts_likelihood_loss']),str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])])",
        "da.title == '_doubly_10%_cell_constrained'",
    ],
    metadata_keys=[],
    burnin_thinning_trimming=[
        {'iter': dict(burnin=0, thinning=1, trimming=10000)},
    ],
    n_workers=1,
    filename_ending="test",
    sample=["table", "intensity"],
    force_reload=False,
)

In [None]:
# Read config of the experiment located at `config_path`
conf = Config(
    path = config_path
)
# NOTE(review): presumably populates the sweep parameter data on `conf` — confirm in Config
conf.get_sweep_data()

In [None]:
# Wrap the experiment outputs; `settings` controls which sweeps/samples are processed.
# inputs=None here: Inputs are attached per sweep in the metrics loop below.
outputs = Outputs(
    config = conf,
    settings = settings,
    base_dir = experiment_dir,
    console_handling_level = 'INFO',
    inputs = None,
    logger = conf.logger,
    print_slice = True
)

In [None]:
# Collect outputs from folder
# NOTE(review): the meaning of `indx = 0` is defined in Outputs.load — confirm there
outputs.load(indx = 0)

In [None]:
# Compute validation metrics for each loaded sweep:
# - SRMSE of the posterior-mean intensity and of the posterior-mean table against
#   the ground-truth table,
# - whether every sampled table satisfies the contingency-table constraints.
# One row of results per sweep is appended to `data`.
# `torch` was previously only reachable through the star imports; import it explicitly.
import torch

# Debug cap on the number of sweeps processed. The original cell stopped after the
# first sweep with an unconditional `break`; set to None to process every sweep.
MAX_SWEEPS = 1

n_sweeps = len(outputs.data)
n_to_process = n_sweeps if MAX_SWEEPS is None else min(n_sweeps, MAX_SWEEPS)

# The ground truth is always read from sweep 0, so it is loop-invariant: fetch it once
ground_truth_table = (
    outputs.get(0).get_sample('ground_truth_table').astype('float32')
    if n_to_process > 0 else None
)

data = []
for i in tqdm(
    range(n_to_process),
    leave=True,
    disable=True,
    desc='Computing validation metrics'
):
    print(f"{i+1}/{n_sweeps}")
    sweep_outputs = outputs.get(i)
    sweep_outputs.inputs = Inputs(
        config = sweep_outputs.config,
        synthetic_data = False,
        logger = outputs.logger
    )

    # SRMSE of the posterior-mean intensity
    mean_intensity = sweep_outputs.compute_statistic(
        data = sweep_outputs.get_sample('intensity'),
        sample_name = 'intensity',
        statistic = 'signedmean',
        dim = ['id']
    )
    intensity_srmse = srmse(
        prediction = mean_intensity,
        ground_truth = ground_truth_table
    )
    # Create a data row keyed by the sweep parameter names
    datum = dict(zip(
        outputs.config.sweep_param_names,
        mean_intensity['sweep'].values[0]
    ))
    print('sweep',{k:v for k,v in datum.items() if k not in ['covariance','to_learn']})
    datum['intensity_srmse'] = intensity_srmse.values[0]

    # SRMSE of the posterior-mean table; the table samples are fetched once and
    # reused for the admissibility checks below
    samples = sweep_outputs.get_sample('table')
    mean_table = sweep_outputs.compute_statistic(
        data = samples,
        sample_name = 'table',
        statistic = 'mean',
        dim = ['id']
    )
    table_srmse = srmse(
        prediction = mean_table,
        ground_truth = ground_truth_table
    )
    datum['table_srmse'] = table_srmse.values[0]

    # Build an empty contingency table carrying this sweep's constraints
    sweep_outputs.inputs.cast_from_xarray()
    ct = instantiate_ct(
        config = sweep_outputs.config,
        **sweep_outputs.inputs.data_vars(),
        level = 'EMPTY'
    )
    print(ct.constraints)

    # Convert every table sample (grouped by 'id') to a tensor once
    table_tensors = [torch.tensor(tab.values.squeeze()) for _, tab in samples.groupby('id')]
    tables_admissible = all(ct.table_admissible(tab) for tab in table_tensors)
    print('Tables admissible',tables_admissible)
    if not tables_admissible:
        # Diagnose which constraint family fails. Both checks use all(), consistent
        # with the overall check; the margins check used any() before, which
        # reported True as soon as a single sample had admissible margins.
        print('Tables margins admissible',all(ct.table_margins_admissible(tab) for tab in table_tensors))
        print('Tables cells admissible',all(ct.table_cells_admissible(tab) for tab in table_tensors))

    data.append(datum)
    print('\n')

In [None]:
# One row of validation metrics per processed sweep. Sweep parameters that are not
# useful as table columns are dropped. errors='ignore' keeps this working when a
# parameter was not swept (or no sweep was processed) instead of raising KeyError.
# Chaining instead of inplace keeps the cell idempotent.
df = pd.DataFrame.from_records(data).drop(
    columns=['covariance','to_learn','axes','cells'],
    errors='ignore'
)

In [None]:
# Per-sweep validation metrics
df

In [None]:
# Samples folder and id of the current experiment, derived from `experiment_dir`
# rather than repeating the hard-coded path. `root_path` keeps its trailing separator.
import os

root_path = os.path.join(experiment_dir, 'samples', '')
relative_path = os.path.relpath(root_path,os.getcwd())
experiment_id = os.path.basename(os.path.normpath(experiment_dir))

In [None]:
# Legacy experiment folders whose samples are migrated below.
# Joint table + SIM runs (exp6: low noise, exp14: high noise).
joint_mcmc_files = [
    'exp6_JointTableSIMLatentMCMC_LowNoise_unconstrained_18_05_2023_11_16_28',
    'exp6_JointTableSIMLatentMCMC_LowNoise_grand_total_17_05_2023_21_33_50',
    'exp6_JointTableSIMLatentMCMC_LowNoise_row_margin_production_constrained_13_06_2023_11_52_01',
    'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_26_05_2023_15_50_42',
    'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_permuted_cells_10%_05_06_2023_10_30_49',
    'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_permuted_cells_20%_05_06_2023_10_30_53',
    'exp14_JointTableSIMLatentMCMC_HighNoise_unconstrained_23_05_2023_11_55_33',
    'exp14_JointTableSIMLatentMCMC_HighNoise_grand_total_23_05_2023_11_15_23',
    'exp14_JointTableSIMLatentMCMC_HighNoise_row_margin_production_constrained_13_06_2023_14_03_14',
    'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_19_05_2023_10_55_00',
    'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_permuted_cells_10%_07_06_2023_09_35_18',
    'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_permuted_cells_20%_05_06_2023_12_40_31',
]

# SIM-only runs (exp5). These used to be listed in joint_mcmc_files as well, which
# made mcmc_files contain duplicates, so every SIM-only experiment was loaded twice.
simple_mcmc_files = [
    'exp5_SIMLatentMCMC_LowNoise_grand_total_18_05_2023_11_09_58',
    'exp5_SIMLatentMCMC_LowNoise_row_margin_27_01_2023_18_46_59',
    'exp5_SIMLatentMCMC_HighNoise_grand_total_23_05_2023_11_07_44',
    'exp5_SIMLatentMCMC_HighNoise_row_margin_06_02_2023_16_54_39',
]

mcmc_files = joint_mcmc_files+simple_mcmc_files

In [None]:
# for fl in tqdm(mcmc_files):
#     cmd = f"scp -r $BACKUP/ticodm_mcmc_outputs/cambridge_work_commuter_lsoas_to_msoas/{fl} ~/MultiResTICODM/data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/mcmc_samples"
#     os.system(cmd)

In [None]:
def get_old_experiment_id(exp_id,sweep_param):
    """Map a new-format experiment id and its sweep parameters to the legacy run folder.

    Parameters
    ----------
    exp_id : str
        New experiment id. The family is detected by substring: 'JointTableSIM_MCMC'
        is tested before the more general 'SIM_MCMC'.
    sweep_param : dict
        Must provide 'sigma' and 'title' whenever `exp_id` belongs to a known family.

    Returns
    -------
    str or None
        Name of the legacy experiment folder, or None when no legacy run matches.
    """
    # Ordered ((sigma, title), legacy folder) pairs, checked in sequence
    joint_candidates = (
        (('low', '_unconstrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_unconstrained_18_05_2023_11_16_28'),
        (('low', '_total_constrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_grand_total_17_05_2023_21_33_50'),
        (('low', '_row_constrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_row_margin_production_constrained_13_06_2023_11_52_01'),
        (('low', '_doubly_constrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_26_05_2023_15_50_42'),
        (('low', '_doubly_10%_cell_constrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_permuted_cells_10%_05_06_2023_10_30_49'),
        (('low', '_doubly_20%_cell_constrained'), 'exp6_JointTableSIMLatentMCMC_LowNoise_both_margins_permuted_cells_20%_05_06_2023_10_30_53'),
        (('high', '_unconstrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_unconstrained_23_05_2023_11_55_33'),
        (('high', '_total_constrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_grand_total_23_05_2023_11_15_23'),
        (('high', '_row_constrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_row_margin_production_constrained_13_06_2023_14_03_14'),
        (('high', '_doubly_constrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_19_05_2023_10_55_00'),
        (('high', '_doubly_10%_cell_constrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_permuted_cells_10%_07_06_2023_09_35_18'),
        (('high', '_doubly_20%_cell_constrained'), 'exp14_JointTableSIMLatentMCMC_HighNoise_both_margins_permuted_cells_20%_05_06_2023_12_40_31'),
    )
    sim_candidates = (
        (('low', '_total_constrained'), 'exp5_SIMLatentMCMC_LowNoise_grand_total_18_05_2023_11_09_58'),
        (('low', '_row_constrained'), 'exp5_SIMLatentMCMC_LowNoise_row_margin_27_01_2023_18_46_59'),
        (('high', '_total_constrained'), 'exp5_SIMLatentMCMC_HighNoise_grand_total_23_05_2023_11_07_44'),
        (('high', '_row_constrained'), 'exp5_SIMLatentMCMC_HighNoise_row_margin_06_02_2023_16_54_39'),
    )

    if 'JointTableSIM_MCMC' in exp_id:
        candidates = joint_candidates
    elif 'SIM_MCMC' in exp_id:
        candidates = sim_candidates
    else:
        return None

    for (wanted_sigma, wanted_title), legacy_id in candidates:
        if sweep_param['sigma'] == wanted_sigma and sweep_param['title'] == wanted_title:
            return legacy_id
    return None

def sample_to_shape(sam_name, n_samples=100000, n_rows=69, n_cols=13):
    """Return the shape that the flattened legacy samples named `sam_name` are reshaped to.

    The defaults reproduce the original hard-coded layout: 100000 samples, tables
    of 69 x 13, and a 13-long log destination attraction vector. Pass other values
    to reuse this for experiments of a different size.

    Parameters
    ----------
    sam_name : str
        'table', 'log_destination_attraction', 'sign', or any other name (e.g.
        'theta'), which is treated as two values per sample.
    n_samples : int
        Number of samples (leading dimension).
    n_rows, n_cols : int
        Table dimensions; `n_cols` is also used as the length of the log
        destination attraction vector (both were 13 in the original).

    Returns
    -------
    tuple of int
    """
    if sam_name == 'table':
        return (n_samples, n_rows, n_cols)
    if sam_name == 'log_destination_attraction':
        return (n_samples, n_cols, 1)
    if sam_name == 'sign':
        return (n_samples,)
    # e.g. 'theta': two columns per sample (split into alpha/beta when loading)
    return (n_samples, 2)

def sample_to_dtype(sam_name):
    """Return the numpy dtype name that legacy samples named `sam_name` are cast to.

    'table' -> 'int32', 'sign' -> 'uint8'; 'log_destination_attraction' and every
    other name (e.g. 'theta') -> 'float32'.
    """
    known_dtypes = (
        ('table', 'int32'),
        ('log_destination_attraction', 'float32'),
        ('sign', 'uint8'),
    )
    return next(
        (dtype_name for name, dtype_name in known_dtypes if sam_name == name),
        'float32'
    )


In [None]:
# Load the legacy MCMC samples of every old experiment into
# `olddata[experiment_folder][sample_name]`; 'theta' is stored as 'alpha' and 'beta'.
import os

mcmc_samples_dir = '/home/iz230/MultiResTICODM/data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/mcmc_samples'

olddata = {}
for fl in mcmc_files:
    olddata[fl] = {}
    root_path = f'{mcmc_samples_dir}/{fl}/samples/'
    relative_path = os.path.relpath(root_path,os.getcwd())
    print(relative_path)
    # Only the joint table/SIM runs have table samples
    if 'JointTableSIMLatentMCMC' in fl:
        samples_pool = ['sign','log_destination_attraction','table','theta']
    else:
        samples_pool = ['sign','log_destination_attraction','theta']

    for sample_name in samples_pool:
        # Candidate files, in load order: `<name>_samples.npy`, then
        # `<name>_batch_<j>_samples.npy` for j = 0..99
        candidate_files = [f"{sample_name}_samples.npy"] + [
            f"{sample_name}_batch_{j}_samples.npy" for j in range(100)
        ]
        existing_files = [
            fname for fname in candidate_files
            if os.path.isfile(os.path.join(relative_path,fname))
        ]
        if not existing_files:
            raise FileNotFoundError(f"No '{sample_name}' sample files found in {relative_path}")

        # Flatten each batch and concatenate in file order. This yields the same flat
        # element order as stacking equally-sized batches, but also tolerates a
        # shorter final batch (np.array stacking fails on uneven batches).
        old_samples = np.concatenate([
            np.load(os.path.join(relative_path,fname)).ravel() for fname in existing_files
        ])
        old_samples = old_samples.reshape(sample_to_shape(sample_name))
        old_samples = old_samples.astype(sample_to_dtype(sample_name))

        if sample_name != 'theta':
            olddata[fl][sample_name] = old_samples
            print(sample_name,old_samples.shape)
            print(existing_files[0])
            print('\n')
        else:
            # Column 0 of theta is stored as 'alpha', column 1 as 'beta'
            for i,param in enumerate(['alpha','beta']):
                olddata[fl][param] = old_samples[:,i]
                print(param,old_samples[:,i].shape)
                print('\n')
        

In [None]:
# Table samples of the last sweep processed in the metrics loop (`samples` is left
# over from that loop) as a plain int32 array. The assignment was commented out,
# so `tables_new[0]` raised NameError.
tables_new = samples.values.squeeze().astype('int32')
tables_new[0]

In [None]:
# Legacy table samples of the low-noise, doubly + 10% cell constrained joint run,
# looked up via the same mapping used for the h5 migration below
tables_old = olddata[
    get_old_experiment_id(
        'JointTableSIM_MCMC',
        {'sigma': 'low', 'title': '_doubly_10%_cell_constrained'}
    )
]['table'].squeeze()

In [None]:
# Overwrite the sample datasets of the new-format sweep outputs (one data.h5 per
# sigma/title sweep) with the legacy samples loaded into `olddata`.
# WARNING: files are opened in 'r+' mode and datasets are overwritten in place, so
# this permanently modifies the experiment outputs on disk — back them up first.
import os

new_experiment_ids = ['JointTableSIM_MCMC_SweepedNoise_16_05_2023_20_09_04','SIM_MCMC_SweepedNoise_16_05_2023_20_09_04']
sweep_sigmas = ['high','low']
sweep_titles = ['_doubly_10%_cell_constrained','_doubly_20%_cell_constrained','_doubly_constrained','_row_constrained','_total_constrained','_unconstrained']

for experid in new_experiment_ids:
    root_path = f'/home/iz230/MultiResTICODM/data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/{experid}/samples/'
    relative_path = os.path.relpath(root_path,os.getcwd())
    for sigma in sweep_sigmas:
        for title in sweep_titles:
            filepath = os.path.join(relative_path,f"sigma_{sigma}",f"title_{title}","data.h5")
            if not os.path.exists(filepath):
                continue
            with h5.File(filepath,'r+') as h5data:
                print(h5data[experid].attrs['sweep_values'])

                old_experid = get_old_experiment_id(experid,{"sigma":sigma,"title":title})
                if old_experid is None:
                    # Previously surfaced as an opaque `KeyError: None` from olddata
                    raise KeyError(
                        f"No legacy experiment mapped to {experid} with sigma={sigma!r}, title={title!r}"
                    )

                for sample_name,sample_data in olddata[old_experid].items():
                    print(sample_name)
                    # Write into the existing dataset (its shape/dtype are kept by h5py)
                    h5data[experid][sample_name][...] = sample_data
                print('\n')