In [8]:
import os
import glob
import optuna
import warnings
import h5py as h5
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from copy import deepcopy

from gensit.config import Config
from gensit.inputs import Inputs
from gensit.outputs import Outputs
from gensit.utils.misc_utils import *
from gensit.utils.math_utils import *
from gensit.utils.probability_utils import *
from gensit.contingency_table import instantiate_ct
from gensit.contingency_table.MarkovBasis import instantiate_markov_basis
from gensit.static.plot_variables import LATEX_RC_PARAMETERS, COLOR_NAMES

In [9]:
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Apply the project's LaTeX font settings to every figure drawn from here on
# (`plt.rcParams` is the same global mapping as `mpl.rcParams`).
plt.rcParams.update(LATEX_RC_PARAMETERS)

## Import samples

In [3]:
# Location of the Joint GeNSIT (SIM_NN) experiment whose outputs are analysed below.
dataset = 'DC'
experiment_group_id = 'exp1/'  # experiment group sub-directory
experiment_id = "JointTableSIM_NN_SweepedNoise__totally_and_cell_constrained_21_05_2024_13_25_40"

# Path components are joined with '/', including a trailing separator.
experiment_dir = '/'.join(['..', '..', 'data', 'outputs', dataset, experiment_group_id, experiment_id, ''])
# Same directory, expressed relative to the current working directory.
relative_experiment_dir = os.path.relpath(experiment_dir, os.getcwd())

In [4]:
# Notebook logger: console output at INFO level; file_level='EMPTY' is passed
# straight to setup_logger (presumably disables file logging — TODO confirm).
logger = setup_logger(__name__, console_level='INFO', file_level='EMPTY')

In [5]:
# Output processing settings for the Joint GeNSIT run:
# - keep only the sweep with the given loss combination and sigma,
# - apply burn-in / thinning / trimming along `iter`,
# - read only the `table` samples.
settings = dict(
    logging_mode="INFO",
    coordinate_slice=[
        "da.loss_name==str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])",
        "da.sigma==0.14142",
    ],
    slice=True,
    metadata_keys=[],
    burnin_thinning_trimming=[
        {'iter': dict(burnin=10000, thinning=9, trimming=10000)},
    ],
    sample=["table"],
    group_by=[],
    filename_ending="test",
    force_reload=False,
    n_workers=1,
)

In [6]:
# Build the outputs object from the experiment's config.json, applying `settings`.
jointgensit_outputs = Outputs(
    config=os.path.join(relative_experiment_dir, "config.json"),
    settings=settings,
    inputs=None,
    slice=True,
    level='NOTE',
)
# Raise the console logging level to 'NOTE' to reduce output noise.
jointgensit_outputs.logger.setLevels(console_level='NOTE')
# Read all requested samples.
jointgensit_outputs.load()

# From here on, work only with the first sweep of the SIM_NN experiment.
jointgensit_outputs = jointgensit_outputs.get(0)

42:29.176 config INFO ----------------------------------------------------------------------------------
42:29.186 config INFO Parameter space size: 
 --- sigma: ['sigma', 'to_learn'] (3)
 --- loss_name: ['loss_name', 'loss_function', 'loss_kwargs'] (2)
42:29.196 config INFO Total = 6.
42:29.206 config INFO ----------------------------------------------------------------------------------
42:29.372 outputs INFO //////////////////////////////////////////////////////////////////////////////////
42:29.382 outputs INFO Slicing coordinates:
42:29.392 outputs INFO loss_name==str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])
42:29.402 outputs INFO sigma==0.14142
42:29.413 outputs INFO iter: burnin = 10000, thinning = 9, trimming = 10000
42:29.423 outputs INFO //////////////////////////////////////////////////////////////////////////////////
42:29.433 outputs INFO Reading samples table.
42:30.207 outputs PROGRESS Slicing table
42:30.217 outputs PROGRESS Before coordinate sli

In [7]:
# Show the coordinates of the loaded table samples (iterations, origins,
# destinations and the sweep parameters retained after slicing).
jointgensit_outputs.data.table.coords

Coordinates:
  * iter           (iter) int32 10001 10010 10019 10028 ... 99974 99983 99992
  * origin         (origin) int16 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179
  * destination    (destination) int16 1 2 3 4 5 6 7 ... 174 175 176 177 178 179
  * sweep          (sweep) object MultiIndex
  * sigma          (sweep) float32 0.1414
  * to_learn       (sweep) <U17 "['alpha', 'beta']"
  * loss_name      (sweep) <U63 "['dest_attraction_ts_likelihood_loss', 'tabl...
  * loss_function  (sweep) <U20 "['custom', 'custom']"
  * loss_kwargs    (sweep) <U26 "{'noise_percentage': None}"

In [8]:
# Load the observed data for the same config (ground-truth table, test-cell mask, ...),
# reusing the outputs' logger, and convert it to xarray objects.
inputs = Inputs(
    config=jointgensit_outputs.config,
    synthetic_data=False,
    logger=jointgensit_outputs.logger,
)
inputs.cast_to_xarray()

# Reference table against which all predictions below are scored.
ground_truth_table = inputs.data.ground_truth_table

42:35.988 inputs NOTE Loading Harris Wilson data ...
42:36.188 inputs NOTE Margins not provided
42:36.205 inputs NOTE Cells subset values file not provided


In [9]:
# Sample-mean table: average the table samples over all iterations and sweeps, in float64.
jointgensit_table_mean = jointgensit_outputs.data.table.mean(dim=['iter', 'sweep'], dtype='float64')

In [10]:
# Relative column-sum error on the test cells, per destination:
#   sum_origin(prediction - ground truth) / sum_origin(ground truth)
# restricted to cells where `test_cells_mask` is True. `where(..., drop=True)` drops
# origins/destinations without any test cell; remaining non-test cells become NaN and
# are skipped by the float sums. The values are signed — the L1 total is formed later
# with abs(...).sum().
jointgensit_relative_colsum_l1_error = (jointgensit_table_mean-ground_truth_table).where(inputs.data.test_cells_mask,drop=True).sum('origin',dtype='float64')
# Accumulate the denominator in float64 too, consistent with the numerator.
jointgensit_relative_colsum_l1_error /= ground_truth_table.where(inputs.data.test_cells_mask,drop=True).sum('origin',dtype='float64')

# Alternative metric: per-cell relative error (zero ground truth replaced by 1), summed over origins.
# jointgensit_relative_colsum_l1_error = (jointgensit_table_mean-ground_truth_table)/ground_truth_table.where(ground_truth_table!=0,1)
# jointgensit_relative_colsum_l1_error = jointgensit_relative_colsum_l1_error.where(inputs.data.test_cells_mask,drop=True).sum('origin',dtype='float64')

In [11]:
# (total absolute error, total signed error) of the relative column sums, ignoring NaNs.
(
    jointgensit_relative_colsum_l1_error.pipe(abs).sum(skipna=True).values,
    jointgensit_relative_colsum_l1_error.sum(skipna=True).values,
)

(array(223.05628674), array(215.8073646))

In [34]:
# Per-destination signed relative column-sum errors of Joint GeNSIT on the test cells.
jointgensit_relative_colsum_l1_error

In [35]:
# SRMSE of the Joint GeNSIT mean table, evaluated on the test cells only
# (`ground_truth_table` is the same object as `inputs.data.ground_truth_table`).
srmse(
    prediction=jointgensit_table_mean,
    ground_truth=ground_truth_table,
    mask=inputs.data.test_cells_mask,
)

In [36]:
# Output processing settings for the GMEL comparison run:
# keep only seed 1, no burn-in/thinning/trimming, and read the `intensity` samples.
gmel_settings = dict(
    logging_mode="INFO",
    coordinate_slice=["da.seed==1"],
    slice=True,
    metadata_keys=[],
    burnin_thinning_trimming=[],
    sample=["intensity"],
    group_by=[],
    filename_ending="test",
    force_reload=False,
    n_workers=1,
)

# Build the GMEL (graph attention network) comparison outputs from its config.json.
gmel_outputs = Outputs(
    config=f'../../data/outputs/{dataset}/comparisons/GraphAttentionNetwork_Comparison_UnsetNoise__doubly_and_cell_constrained_all_region_features_16_05_2024_21_06_14/config.json',
    settings=gmel_settings,
    inputs=None,
    slice=True,
    level='NOTE',
)
# Raise the console logging level to 'NOTE' to reduce output noise.
gmel_outputs.logger.setLevels(console_level='NOTE')
# Read all requested samples.
gmel_outputs.load()

# From here on, work only with the first sweep of the experiment.
gmel_outputs = gmel_outputs.get(0)

55:12.015 config INFO ----------------------------------------------------------------------------------
55:12.032 config INFO Parameter space size: 
 --- seed (10)
55:12.048 config INFO Total = 10.
55:12.064 config INFO ----------------------------------------------------------------------------------
55:12.105 outputs INFO //////////////////////////////////////////////////////////////////////////////////
55:12.121 outputs INFO Slicing coordinates:
55:12.137 outputs INFO seed==1
55:12.154 outputs INFO //////////////////////////////////////////////////////////////////////////////////
55:12.170 outputs INFO Reading samples intensity.
55:19.615 outputs PROGRESS Slicing intensity
55:19.626 outputs PROGRESS Before coordinate slicing intensity[0]: {'seed': 10, 'iter': 10000, 'origin': 179, 'destination': 179}
55:24.274 outputs PROGRESS After coordinate slicing intensity[0]: {'iter': 10000, 'origin': 179, 'destination': 179}
55:24.284 outputs PROGRESS After index slicing intensity: {'iter': 

In [37]:
# Show the coordinates of the GMEL intensity samples after slicing to seed 1.
gmel_outputs.data.intensity.coords

Coordinates:
  * seed         (seed) int32 1
  * iter         (iter) int32 1 2 3 4 5 6 7 ... 9995 9996 9997 9998 9999 10000
  * origin       (origin) int16 1 2 3 4 5 6 7 8 ... 173 174 175 176 177 178 179
  * destination  (destination) int16 1 2 3 4 5 6 7 ... 174 175 176 177 178 179

In [38]:
# Mean GMEL intensity: average over all iterations and (sliced) seeds, in float64.
gmel_intensity_mean = gmel_outputs.data.intensity.mean(dim=['iter', 'seed'], dtype='float64')

In [None]:
# Relative column-sum error of the GMEL mean intensity on the test cells, per destination:
# signed sum over origins of (prediction - ground truth), divided by the ground-truth sum.
gmel_relative_colsum_l1_error = (
    (gmel_intensity_mean - ground_truth_table)
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
)
gmel_relative_colsum_l1_error /= (
    ground_truth_table
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
)

# Alternative metric: per-cell relative error (zero ground truth replaced by 1), summed over origins.
# gmel_relative_colsum_l1_error = (gmel_intensity_mean-ground_truth_table)/ground_truth_table.where(ground_truth_table!=0,1)
# gmel_relative_colsum_l1_error = gmel_relative_colsum_l1_error.where(inputs.data.test_cells_mask,drop=True).sum('origin',dtype='float64')

In [40]:
# (total absolute error, total signed error) of the GMEL relative column sums.
(
    gmel_relative_colsum_l1_error.pipe(abs).sum().values,
    gmel_relative_colsum_l1_error.sum().values,
)

(array(21.14783516), array(-13.09536224))

In [41]:
# Per-destination signed relative column-sum errors of GMEL on the test cells.
gmel_relative_colsum_l1_error

In [42]:
# Mean SRMSE across seeds: average each seed's samples over iterations, score each
# seed's mean intensity separately on the test cells, then average the per-seed scores.
(
    gmel_outputs.data.intensity
    .groupby('seed').mean('iter', dtype='float64')
    .groupby('seed').map(
        srmse,
        ground_truth=inputs.data.ground_truth_table,
        mask=inputs.data.test_cells_mask,
    )
    .mean('seed')
)

In [43]:
# SRMSE of the seed- and iteration-averaged GMEL intensity on the test cells
# (`ground_truth_table` is the same object as `inputs.data.ground_truth_table`).
srmse(
    prediction=gmel_intensity_mean,
    ground_truth=ground_truth_table,
    mask=inputs.data.test_cells_mask,
)

# GMEL vs Joint GeNSIT: relative row-sum errors on test cells

In [131]:
# Marginal summed over; 'destination' leaves one relative error per origin (row sums).
sum_dim = 'destination'


def relative_marginal_error(prediction, ground_truth, dim):
    """Signed relative error of the `dim`-marginal of `prediction`.

    Both arguments are expected to be masked to the same cells already
    (non-test cells NaN), so NaNs are skipped by the float64 sums.

    Returns sum_dim(prediction - ground_truth) / sum_dim(ground_truth).
    """
    error_sum = (prediction - ground_truth).sum(dim, dtype='float64')
    return error_sum / ground_truth.sum(dim, dtype='float64')


# Restrict everything to the test cells; labels without any test cell are dropped.
ground_truth = inputs.data.ground_truth_table.where(inputs.data.test_cells_mask, drop=True)

jointgensit_prediction = jointgensit_table_mean.where(inputs.data.test_cells_mask, drop=True)
jointgensit_quantity = relative_marginal_error(jointgensit_prediction, ground_truth, sum_dim)

gmel_prediction = gmel_intensity_mean.where(inputs.data.test_cells_mask, drop=True)
gmel_quantity = relative_marginal_error(gmel_prediction, ground_truth, sum_dim)

In [132]:
# Summary statistics of the per-origin relative errors for each model.
# NaN-aware reductions keep one undefined ratio (0/0, when an origin's test-cell
# ground truth sums to zero) from turning every statistic into nan; an inf still
# propagates, signalling a non-zero prediction against a zero ground truth.
for model_label, model_quantity in [("Joint GeNSIT", jointgensit_quantity), ("GMEL", gmel_quantity)]:
    model_errors = model_quantity.values
    print(f"{model_label} min error: {np.nanmin(model_errors)} max error: {np.nanmax(model_errors)} total abs error: {np.nansum(np.abs(model_errors))}")

Joint GeNSIT min error: -0.8706719271623672 max error: 8.526016666666665 total abs error: 91.94442282389474
GMEL min error: -0.9747560113550354 max error: 20.20517249229849 total abs error: 167.89391226906542


In [133]:
# Count the origins where each model has the smaller absolute relative error.
# The comparison is element-wise on raw arrays, so first make sure both models'
# errors cover the same labels in the same order.
assert gmel_quantity.dims == jointgensit_quantity.dims, "error arrays have different dimensions"
for label_dim in gmel_quantity.dims:
    assert np.array_equal(gmel_quantity[label_dim].values, jointgensit_quantity[label_dim].values), \
        f"models disagree on {label_dim} labels"

gmel_abs_error = np.abs(gmel_quantity.values)
jointgensit_abs_error = np.abs(jointgensit_quantity.values)
# Ties are credited to Joint GeNSIT (>=); locations where either error is NaN count for neither model.
print('Joint GeNSIT lower error locations:', int(np.sum(gmel_abs_error >= jointgensit_abs_error)))
print('GMEL lower error locations:', int(np.sum(gmel_abs_error < jointgensit_abs_error)))

Joint GeNSIT lower error locations: 160
GMEL lower error locations: 19


In [116]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Create the heatmap
# plt.figure(figsize=(20, 20))
# sns.heatmap(inputs.data.ground_truth_table.where(inputs.data.test_cells_mask), cmap="viridis", linewidths=0.5, linecolor="gray", cbar=True)

# # Show the plot
# plt.show()


# Colourbar for the spatial residual maps

In [11]:
import matplotlib.colors as mcolors

# Stand-alone horizontal colourbar for the spatial residual figures:
# diverging 'bwr_r' colour map centred at 0 on a symmetric [-1, 1] range.
residual_norm = mcolors.TwoSlopeNorm(vmin=-1.0, vcenter=0.0, vmax=1.0)
residual_mappable = plt.cm.ScalarMappable(norm=residual_norm, cmap='bwr_r')

fig, ax = plt.subplots(figsize=(15, 20))
# Only the colourbar is wanted, so hide the host axes it is attached to.
ax.set_visible(False)

cbar = fig.colorbar(residual_mappable, ax=ax, orientation='horizontal')
cbar.ax.tick_params(labelsize=16)

write_figure(
    figure=fig,
    filepath="../../data/outputs/DC/exp1/paper_figures/colourbar/spatial_residual_colourbar",
    figure_format="pdf",
)