In [1]:
import os
from copy import deepcopy

import h5py as h5
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm


from multiresticodm.config import Config
from multiresticodm.inputs import Inputs
from multiresticodm.outputs import Outputs
from multiresticodm.utils.misc_utils import *
from multiresticodm.utils.math_utils import *
from multiresticodm.utils.probability_utils import *
from multiresticodm.contingency_table import instantiate_ct
from multiresticodm.markov_basis import instantiate_markov_basis

In [2]:
# Render matplotlib figures inline, below the cell that produces them
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
# Mode 2 re-imports modules before each cell is executed, so edits made to
# imported source files take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
def validate_tables(out):
    """Print admissibility diagnostics for the table samples stored in ``out``.

    A contingency table object is instantiated from ``out.config`` and the
    input data variables, and every sampled table (one per ``id``) is checked
    against it. If at least one table is inadmissible, the margin and cell
    checks are reported separately to show which kind of constraint fails.

    Parameters
    ----------
    out :
        Outputs-like object exposing ``config``, ``logger``, ``inputs`` and
        ``get_sample``. When ``out.inputs`` is ``None`` it is populated in
        place with a newly built ``Inputs``.

    Returns
    -------
    None
        All results are reported through ``print``.
    """
    if out.inputs is None:
        out.inputs = Inputs(
            config = out.config,
            synthetic_data = False,
            logger = out.logger
        )
    else:
        try:
            out.inputs.cast_from_xarray()
        except Exception as exc:
            # Still best effort, but no longer silent, and KeyboardInterrupt /
            # SystemExit are allowed to propagate
            print(f'Could not cast inputs from xarray ({type(exc).__name__}: {exc}); using inputs as they are')
    ct = instantiate_ct(
        config = out.config,
        **out.inputs.data_vars(),
        level = 'EMPTY'
    )
    samples = out.get_sample('table')
    print(dict(samples.sizes))
    print('axes constraints',ct.constraints['constrained_axes'])
    print('cell constraints',len(ct.constraints['cells']))

    def every_table(check):
        # Tensors are built lazily so only one sampled table is held at a time
        return all(
            check(torch.tensor(tab.values.squeeze()))
            for _,tab in samples.groupby('id')
        )

    tables_admissible = every_table(ct.table_admissible)
    print('Tables admissible',tables_admissible)
    if not tables_admissible:
        # Both diagnostics use the same "every table" quantifier as the check
        # above; the margins line previously used any(), which reported True
        # as soon as a single table had admissible margins
        print('Tables margins admissible',every_table(ct.table_margins_admissible))
        print('Tables cells admissible',every_table(ct.table_cells_admissible))

In [None]:
# Experiment whose outputs are inspected below.
# Other runs that can be swapped in:
#   NonJointTableSIM_NN_SweepedNoise_26_01_2024_13_26_18
#   NonJointTableSIM_NN_SweepedNoise_30_01_2024_23_25_12
#   NonJointTableSIM_NN_SweepedNoise_01_02_2024_00_45_29
#   JointTableSIM_NN_SweepedNoise_23_01_2024_21_33_25
#   JointTableSIM_NN_SweepedNoise_30_01_2024_21_29_31
experiment_id = 'NonJointTableSIM_NN_SweepedNoise_01_02_2024_14_38_56'

# Output directory of that experiment, then the same location expressed
# relative to the current working directory
experiment_dir = f'../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/{experiment_id}/'
relative_experiment_dir = os.path.relpath(experiment_dir, start=os.getcwd())

In [None]:
# Settings handed to Outputs when loading and post-processing experiment data.
# Optional entries that are currently switched off:
#   "coordinate_slice": [
#       "da.loss_name == str(['dest_attraction_ts_likelihood_loss'])",
#       "~da.title.isin([str('_unconstrained'), str('_total_constrained')])"
#   ]
#   "burnin_thinning_trimming" example value:
#       {'iter': {"burnin":10000, "thinning":90, "trimming":1000}}
settings = dict(
    logging_mode = "INFO",
    metadata_keys = [],
    burnin_thinning_trimming = [],
    n_workers = 1,
    group_by = [],
    filename_ending = "test",
    sample = ["table"],
    force_reload = True,
)

In [None]:
# Sweep to inspect, given as the sub-path below the experiment's `samples/` folder
sweep_id = 'loss_name_[dest_attraction_ts_likelihood_loss]/seed_1/'
# Sigma and title values to loop over (extend the lists to load more sweeps)
sigmas = ['low']#,'high','learned']
titles = ['_doubly_constrained']#,'_doubly_10%_cell_constrained','_doubly_20%_cell_constrained']

# The bar is created with a manual total, so it only moves when update() is
# called; the context manager closes it when the loop finishes or raises
with tqdm(
    total = len(sigmas)*len(titles),
    desc = 'Loading sweep data'
) as progress:
    for sig in sigmas:
        for titl in titles:
            # Create current sweep id
            current_sweep_id = os.path.join('samples',sweep_id,f"sigma_{sig}/title_{titl}/")

            print(f"sigma_{sig}/title_{titl}/")
            # Initialise outputs
            current_sweep_outputs = Outputs(
                config = os.path.join(relative_experiment_dir,current_sweep_id),
                settings = settings,
                inputs = None,
                slice = True,
                level = 'INFO'
            )
            # Silence outputs
            current_sweep_outputs.logger.setLevels(console_level='EMPTY')
            # Load all data
            current_sweep_outputs.load()
            # Get first collection id
            current_sweep_outputs0 = current_sweep_outputs.get(0)
            # Optional diagnostics (uncomment as needed)
            # validate_tables(current_sweep_outputs0)
            # print('SRMSE',srmse(current_sweep_outputs0.data.table.mean('id'),current_sweep_outputs.inputs.data.ground_truth_table).values.squeeze())
            # print(np.where(np.isnan(current_sweep_outputs0.data.table.values.squeeze())))
            # Record that one (sigma, title) combination has been loaded
            progress.update(1)

In [None]:
# Experiment whose outputs are loaded as a whole in the next cell.
# Other runs that can be swapped in:
#   NonJointTableSIM_NN_SweepedNoise_26_01_2024_13_26_18
#   NonJointTableSIM_NN_SweepedNoise_30_01_2024_23_25_12
#   NonJointTableSIM_NN_SweepedNoise_01_02_2024_00_45_29
#   JointTableSIM_NN_SweepedNoise_23_01_2024_21_33_25
#   JointTableSIM_NN_SweepedNoise_30_01_2024_21_29_31
experiment_id = 'NonJointTableSIM_NN_SweepedNoise_01_02_2024_14_38_56'

# Output directory of that experiment, then the same location expressed
# relative to the current working directory
experiment_dir = f'../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp2/{experiment_id}/'
relative_experiment_dir = os.path.relpath(experiment_dir, start=os.getcwd())

In [None]:
# Initialise outputs
# Here `config` points at the experiment directory itself, whereas the sweep
# loading loop earlier in the notebook passes one `samples/<sweep>` sub-folder
current_sweep_outputs = Outputs(
    config = relative_experiment_dir,
    settings = settings,
    inputs = None,
    slice = True,
    level = 'INFO'
)
# Silence outputs
current_sweep_outputs.logger.setLevels(console_level='EMPTY')
# Load all data
current_sweep_outputs.load()
# Optional follow-up, currently disabled: take the first collection and
# validate its tables
# Get first collection id
# current_sweep_outputs0 = current_sweep_outputs.get(0)
# Validate tables
# validate_tables(current_sweep_outputs0)

# Temporary scripts

In [None]:
# Point the notebook at a run stored under exp1: its identifier, its output
# directory, and that directory relative to the current working directory
experiment_id = 'JointTableSIM_NN_SweepedNoise_01_02_2024_12_01_32'
experiment_dir = f'../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/{experiment_id}/'
relative_experiment_dir = os.path.relpath(experiment_dir, start=os.getcwd())

In [None]:
# for d in os.walk(relative_experiment_dir):
#     if 'doubly' in d[0]:
#         for f in ['data.h5','metadata.json','outputs.log']:
#             if os.path.exists(os.path.join(d[0],f)):
#                 os.remove(os.path.join(d[0],f))
#         if os.path.exists(d[0]):
#             os.rmdir(d[0])

In [None]:
# for sweep_config in tqdm(current_sweep_outputs.config.sweep_configurations,total=len(current_sweep_outputs.config.sweep_configurations)):
#     new_config,sweep = current_sweep_outputs.config.prepare_experiment_config(sweep_config)
#     sweep_id = current_sweep_outputs.config.get_sweep_id(sweep = sweep)
#     filepath = os.path.join(relative_experiment_dir,'samples',sweep_id,'metadata.json')
#     metadata = read_json(filepath = filepath)
#     metadata['neural_network']['loss'] = new_config['neural_network']['loss']
#     write_json(data = metadata, filepath = filepath)


In [None]:
# Print the swept parameters of every sweep configuration held by the loaded
# outputs' config.
# NOTE(review): `new_config` and `sweep_id` are computed but not used in this
# cell, and `sweep_id` rebinds the notebook-level name set in the sweep loading cell.
for sweep_config in tqdm(
    current_sweep_outputs.config.sweep_configurations,
    total = len(current_sweep_outputs.config.sweep_configurations),
    desc = 'Looping sweep configurations'
):
    # Expand the sweep configuration into a config and the sweep it describes
    new_config,sweep = current_sweep_outputs.config.prepare_experiment_config(sweep_config)
    # Identifier of this sweep
    sweep_id = current_sweep_outputs.config.get_sweep_id(sweep = sweep)
    # Show the sweep itself
    print_json(sweep,indent=1)

In [None]:
# Sweep parameter to append to the sweep metadata stored in each data.h5 file
added_key = 'proposal'

# NOTE(review): as written this cell is a dry run. Each file is opened
# read-only ('r') and the delete/assign statements at the bottom are commented
# out, so the new lists are computed but nothing on disk changes. Enabling the
# writes presumably also needs a writable mode such as 'r+' - confirm against
# the h5py documentation before uncommenting.
for sweep_config in tqdm(
    current_sweep_outputs.config.sweep_configurations,
    total = len(current_sweep_outputs.config.sweep_configurations),
    desc = 'Modifying h5 data sweep values'
):
    # Resolve this sweep and the data file stored under its sweep id
    new_config,sweep = current_sweep_outputs.config.prepare_experiment_config(sweep_config)
    sweep_id = current_sweep_outputs.config.get_sweep_id(sweep = sweep)
    filepath = os.path.join(relative_experiment_dir,'samples',sweep_id,'data.h5')

    with h5.File(filepath,'r') as h5data:
        # Get existing attributes
        # (read from the group named after the notebook-level `experiment_id`)
        existing_sweep_params = deepcopy(list(h5data[experiment_id].attrs['sweep_params']))
        existing_sweep_values = deepcopy(list(h5data[experiment_id].attrs['sweep_values']))
        # Construct new attributes
        # (raises KeyError when `added_key` is not a key of `sweep`)
        new_sweep_params = existing_sweep_params + [added_key]
        new_sweep_values = existing_sweep_values + [sweep[added_key]]
        # print(new_sweep_params)
        # print(new_sweep_values)
        # Delete existing attributes
        # del h5data[experiment_id].attrs['sweep_params']
        # del h5data[experiment_id].attrs['sweep_values']
        # Create new attributes
        # h5data[experiment_id].attrs['sweep_params'] = new_sweep_params
        # h5data[experiment_id].attrs['sweep_values'] = new_sweep_values

In [8]:
# Hard-coded list of 174 floating-point values; the next cell converts it to a
# NumPy array and ranks the entries.
# NOTE(review): the notebook does not record where these numbers come from or
# what they measure - TODO document their source.
temp = [
    56.45505393831983,
    62.41335433232076,
    241.03685154496443,
    36.79460277031524,
    53.872603970990475,
    51.40828422174658,
    56.07868831067209,
    36.79460277031524,
    4.047070601109788,
    4.047070601109788,
    1.0835769491470515,
    1.0835769491470515,
    70.87158413195424,
    1.544627233920827,
    259.44076952471994,
    1.61867085267597,
    1.730646405081271,
    1.5654299819682425,
    4.241071940774812,
    1.730646405081271,
    4.212798232903896,
    1.263788192393704,
    1.263788192393704,
    219.48948956238047,
    216.57272892599536,
    252.59125143498346,
    218.02623173622607,
    4.688677227687005,
    218.02623173622607,
    216.57272892599536,
    4.074232047156511,
    4.688677227687005,
    4.074232047156511,
    1.0981703589795813,
    1.0981703589795813,
    3.564069500191919,
    3.1597776848892845,
    250.9073160164024,
    3.587989340238615,
    1.6404708062054956,
    3.0970027303710035,
    3.1597776848892845,
    4.269535404364509,
    1.6404708062054956,
    4.269535404364509,
    1.1431404597854,
    1.1431404597854,
    45.8826457125557,
    47.126845620619456,
    250.9073160164024,
    252.59125143498346,
    46.81266781732768,
    104.46263330578574,
    1.5971605952337455,
    106.58004734301677,
    4.269535404364509,
    4.269535404364509,
    47.126845620619456,
    47.126845620619456,
    1.5971605952337455,
    4.074232047156511,
    4.074232047156511,
    1.0981703589795813,
    1.1431404597854,
    1.0981703589795813,
    1.1431404597854,
    4.688677227687005,
    4.688677227687005,
    1.6404708062054956,
    1.6404708062054956,
    166.84360159035322,
    164.6264437185948,
    157.09585921738943,
    179.582632736149,
    171.36790027749134,
    18.598425434504914,
    179.582632736149,
    94.49009207894893,
    90.16778755208009,
    142.09867904843043,
    142.09867904843043,
    159.2115965984378,
    166.84360159035322,
    180.7878807959811,
    178.3854196335239,
    136.5086468992917,
    16.822923912760512,
    166.84360159035322,
    178.3854196335239,
    90.16778755208009,
    91.38215029620801,
    138.3471195905516,
    138.3471195905516,
    92.61286784857025,
    87.78726003097803,
    161.3558283312015,
    157.09585921738943,
    144.01243723759012,
    12.617907774208758,
    157.09585921738943,
    91.99545102395501,
    90.16778755208009,
    136.5086468992917,
    136.5086468992917,
    164.6264437185948,
    186.93654557042277,
    184.45237578758378,
    163.52893817223134,
    25.639833985080045,
    22.279766388892316,
    186.93654557042277,
    163.52893817223134,
    98.35946346672688,
    90.16778755208009,
    58.766889662025896,
    58.766889662025896,
    138.3471195905516,
    169.09061972585334,
    162.43874930232087,
    172.51801611096633,
    18.848905517916737,
    16.161125462946675,
    169.09061972585334,
    172.51801611096633,
    90.16778755208009,
    87.20201380559797,
    45.27291858997497,
    45.27291858997497,
    123.47682809535691,
    134.69460537686658,
    148.91035515125563,
    157.09585921738943,
    16.161125462946675,
    117.82856972987608,
    13.04704762980551,
    148.91035515125563,
    86.62066920722167,
    93.23442839476174,
    39.33996025539529,
    39.33996025539529,
    160.28012680639486,
    179.582632736149,
    182.00121773870583,
    163.52893817223134,
    23.66221799500506,
    20.97808373029105,
    163.52893817223134,
    98.35946346672688,
    90.16778755208009,
    58.766889662025896,
    58.766889662025896,
    163.52893817223134,
    159.2115965984378,
    173.67585081377587,
    172.51801611096633,
    19.752451139456557,
    18.848905517916737,
    172.51801611096633,
    91.38215029620801,
    87.20201380559797,
    45.27291858997497,
    45.27291858997497,
    105.86951633445906,
    89.56667120220071,
    157.09585921738943,
    148.91035515125563,
    16.488704740146463,
    13.67247401059599,
    89.56667120220071,
    148.91035515125563,
    92.61286784857025,
    93.23442839476174,
    39.33996025539529,
    39.33996025539529
  ]

In [9]:
# Work on an array so the values can be arg-sorted
temp = np.array(temp)

# Descending rank of every entry: the position that comes i-th in ascending
# order receives len(temp) - i, so the largest value ends up with rank 1 and
# the smallest with rank len(temp)
temp2 = np.zeros(len(temp))
temp2[np.argsort(temp)] = len(temp) - np.arange(len(temp))

In [10]:
# Display the computed ranks as a plain Python list
temp2.tolist()

[97.0,
 92.0,
 6.0,
 115.0,
 99.0,
 100.0,
 98.0,
 114.0,
 145.0,
 146.0,
 173.0,
 174.0,
 91.0,
 162.0,
 1.0,
 158.0,
 153.0,
 161.0,
 139.0,
 152.0,
 140.0,
 163.0,
 164.0,
 7.0,
 10.0,
 3.0,
 8.0,
 133.0,
 9.0,
 11.0,
 144.0,
 134.0,
 143.0,
 172.0,
 171.0,
 148.0,
 150.0,
 4.0,
 147.0,
 157.0,
 151.0,
 149.0,
 136.0,
 154.0,
 135.0,
 166.0,
 165.0,
 105.0,
 103.0,
 5.0,
 2.0,
 104.0,
 68.0,
 159.0,
 66.0,
 138.0,
 137.0,
 101.0,
 102.0,
 160.0,
 141.0,
 142.0,
 169.0,
 167.0,
 170.0,
 168.0,
 132.0,
 131.0,
 155.0,
 156.0,
 30.0,
 34.0,
 47.0,
 18.0,
 27.0,
 123.0,
 19.0,
 71.0,
 80.0,
 55.0,
 56.0,
 43.0,
 32.0,
 16.0,
 20.0,
 62.0,
 124.0,
 31.0,
 21.0,
 81.0,
 77.0,
 58.0,
 59.0,
 74.0,
 87.0,
 41.0,
 45.0,
 54.0,
 130.0,
 46.0,
 76.0,
 83.0,
 60.0,
 61.0,
 33.0,
 12.0,
 14.0,
 36.0,
 116.0,
 118.0,
 13.0,
 39.0,
 69.0,
 79.0,
 93.0,
 96.0,
 57.0,
 28.0,
 40.0,
 26.0,
 122.0,
 126.0,
 29.0,
 24.0,
 82.0,
 89.0,
 109.0,
 108.0,
 64.0,
 63.0,
 50.0,
 48.0,
 127.0,
 65.0,
 129.0,
 