In [1]:
import os
import glob
import optuna
import warnings
import h5py as h5
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from copy import deepcopy

from gensit.config import Config
from gensit.inputs import Inputs
from gensit.outputs import Outputs
from gensit.utils.misc_utils import *
from gensit.utils.math_utils import *
from gensit.utils.probability_utils import *
from gensit.contingency_table import instantiate_ct
from gensit.contingency_table.MarkovBasis import instantiate_markov_basis
from gensit.static.plot_variables import LATEX_RC_PARAMETERS, COLOR_NAMES

In [2]:
# Render matplotlib figures inline, directly below the cell that draws them
%matplotlib inline

# AUTO RELOAD EXTERNAL MODULES
# (autoreload mode 2 re-imports modules such as gensit before each cell runs,
#  so edits to the library are picked up without restarting the kernel)
%load_ext autoreload
%autoreload 2

In [3]:
# Apply the project's LaTeX font/text rc settings globally; every figure drawn
# later in the notebook inherits them (pyplot shares matplotlib's rcParams object).
plt.rcParams.update(LATEX_RC_PARAMETERS)

In [4]:
# Notebook-level logger named after this module. Console output starts at INFO;
# file_level='EMPTY' is handed straight to gensit's setup_logger
# (presumably disables file logging — confirm in gensit.utils.misc_utils).
logger = setup_logger(
    __name__,
    file_level = 'EMPTY',
    console_level = 'INFO',
)

# GeNSIT (Joint)

In [5]:
# Output-processing settings for the joint GeNSIT experiment (table + SIM_NN).
# Only the "table" sample is read, restricted to the sweep matching every
# `coordinate_slice` expression (both losses active, sigma == 0.14142).
jointgensit_settings = {
    "logging_mode": "INFO",
    "coordinate_slice": [
        "da.loss_name==str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])",
        "da.sigma==0.14142"
    ],
    "slice": True,
    "metadata_keys": [],
    # Example: {'iter': {"burnin": 10000, "thinning": 9, "trimming": 10000}}
    "burnin_thinning_trimming": [],
    "sample": ["table"],
    "group_by": [],
    "filename_ending": "test",
    "force_reload": False,
    "n_workers": 1
}
# Initialise outputs (plain string literal: the path has nothing to interpolate)
jointgensit_outputs = Outputs(
    config = '../../data/outputs/DC/exp1/JointTableSIM_NN_SweepedNoise__totally_and_cell_constrained_21_05_2024_13_25_40/config.json',
    settings = jointgensit_settings,
    inputs = None,
    slice = True,
    level = 'NOTE'
)
# Silence outputs
jointgensit_outputs.logger.setLevels(console_level='NOTE')
# Load all data
jointgensit_outputs.load()

# Get data from the first sweep of the joint GeNSIT (JointTableSIM_NN) experiment
jointgensit_outputs = jointgensit_outputs.get(0)

35:30.460 config INFO ----------------------------------------------------------------------------------
35:30.477 config INFO Parameter space size: 
 --- sigma: ['sigma', 'to_learn'] (3)
 --- loss_name: ['loss_name', 'loss_function', 'loss_kwargs'] (2)
35:30.493 config INFO Total = 6.
35:30.509 config INFO ----------------------------------------------------------------------------------
35:30.728 outputs INFO //////////////////////////////////////////////////////////////////////////////////
35:30.743 outputs INFO Slicing coordinates:
35:30.758 outputs INFO loss_name==str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss'])
35:30.773 outputs INFO sigma==0.14142
35:30.789 outputs INFO //////////////////////////////////////////////////////////////////////////////////
35:30.807 outputs INFO Reading samples table.
35:43.070 outputs PROGRESS Slicing table
35:43.080 outputs PROGRESS Before coordinate slicing table[0]: {'iter': 100000, 'origin': 179, 'destination': 179}
36:04.955

In [8]:
# Inspect the coordinates of the loaded table sample (iter/origin/destination plus the sweep MultiIndex)
jointgensit_outputs.data.table.coords

Coordinates:
  * iter           (iter) int32 1 2 3 4 5 6 ... 99996 99997 99998 99999 100000
  * origin         (origin) int16 1 2 3 4 5 6 7 ... 173 174 175 176 177 178 179
  * destination    (destination) int16 1 2 3 4 5 6 7 ... 174 175 176 177 178 179
  * sweep          (sweep) object MultiIndex
  * sigma          (sweep) float32 0.1414
  * to_learn       (sweep) <U17 "['alpha', 'beta']"
  * loss_name      (sweep) <U63 "['dest_attraction_ts_likelihood_loss', 'tabl...
  * loss_function  (sweep) <U20 "['custom', 'custom']"
  * loss_kwargs    (sweep) <U26 "{'noise_percentage': None}"

In [9]:
# Load the real (non-synthetic) inputs using the joint run's config and logger,
# cast them to xarray, and keep a handle on the ground-truth table used as the
# reference for every error metric below.
inputs = Inputs(
    logger = jointgensit_outputs.logger,
    config = jointgensit_outputs.config,
    synthetic_data = False,
)
inputs.cast_to_xarray()
ground_truth_table = inputs.data.ground_truth_table

37:20.742 inputs NOTE Loading Harris Wilson data ...


37:20.844 inputs NOTE Margins not provided
37:20.861 inputs NOTE Cells subset values file not provided


In [None]:
# Sanity check: grand total over all cells of the ground-truth table
ground_truth_table.sum()

: 

In [None]:
# Mean table over the 'iter' dimension (accumulated in float64), then drop the
# single-entry 'sweep' dimension left over from coordinate slicing.
jointgensit_table_mean = (
    jointgensit_outputs.data.table
    .mean(dim='iter', dtype='float64')
    .squeeze(dim='sweep')
)

In [None]:
# Relative L1 error of the destination (column) sums over test cells only:
#   sum_origin |mean - truth| / sum_origin truth   -> one value per destination.
# Numerator and denominator are both accumulated in float64 (the denominator
# previously used the table's own dtype).
# NOTE(review): a destination whose test cells all have zero ground truth gives inf/nan here.
jointgensit_relative_colsum_l1_error = (
    abs(jointgensit_table_mean - ground_truth_table)
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
) / (
    ground_truth_table
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
)

In [None]:
# (total absolute relative error, total signed relative error) across destinations, NaNs skipped
(
    abs(jointgensit_relative_colsum_l1_error).sum(skipna=True).values,
    jointgensit_relative_colsum_l1_error.sum(skipna=True).values,
)

In [None]:
# SRMSE of the joint GeNSIT mean table against the ground truth, restricted to test cells
srmse(
    mask = inputs.data.test_cells_mask,
    ground_truth = inputs.data.ground_truth_table,
    prediction = jointgensit_table_mean,
).values

# SIM_NN

In [None]:
# Output-processing settings for the SIM_NN experiment: read the "intensity"
# sample for the sweep with sigma == 0.14142, with force_reload enabled.
sim_nn_settings = {
    "logging_mode": "INFO",
    "coordinate_slice": [
        "da.sigma==0.14142"
    ],
    "slice": True,
    "metadata_keys": [],
    # Example: {'iter': {"burnin": 10000, "thinning": 9, "trimming": 10000}}
    "burnin_thinning_trimming": [],
    "sample": ["intensity"],
    "group_by": [],
    "filename_ending": "test",
    "force_reload": True,
    "n_workers": 1
}
# Initialise outputs (plain string literal: the path has nothing to interpolate)
sim_nn_outputs = Outputs(
    config = '../../data/outputs/DC/exp1/SIM_NN_SweepedNoise__totally_and_cell_constrained_20_05_2024_15_59_08/config.json',
    settings = sim_nn_settings,
    inputs = None,
    slice = True,
    level = 'NOTE'
)
# Silence outputs
sim_nn_outputs.logger.setLevels(console_level='NOTE')
# Load all data
sim_nn_outputs.load()

# Get data from first sweep of the SIM_NN experiment
sim_nn_outputs = sim_nn_outputs.get(0)

In [None]:
# Intensity samples of the SIM_NN run and their mean over the 'iter' dimension (float64)
sim_nn_intensity = sim_nn_outputs.get_sample('intensity')
sim_nn_intensity_mean = sim_nn_intensity.mean(dim='iter', dtype='float64')

In [None]:
# Per-seed SRMSE of the SIM_NN mean intensity against the ground truth on test cells.
# `sim_nn_intensity_mean` already is the intensity array, so it is grouped directly
# (as in the GMEL cell). The previous `.data.intensity` would fail: `.data` on an
# xarray object is the raw underlying array, which has no `.intensity` attribute.
# NOTE(review): assumes get_sample('intensity') returns an xarray DataArray with a 'seed' dim — confirm.
sim_nn_srmses = sim_nn_intensity_mean.groupby('seed').map(
    srmse,
    ground_truth = inputs.data.ground_truth_table,
    mask = inputs.data.test_cells_mask
).values

# GMEL

In [None]:
# Output-processing settings for the GMEL (graph attention network) comparison run:
# read the "intensity" sample for every seed (no coordinate slicing).
gmel_settings = {
    "logging_mode": "INFO",
    "coordinate_slice": [
        # e.g. "da.seed==1" to restrict to a single seed
    ],
    "slice": True,
    "metadata_keys": [],
    "burnin_thinning_trimming": [],
    "sample": ["intensity"],
    "group_by": [],
    "filename_ending": "test",
    "force_reload": True,
    "n_workers": 1
}

# Initialise outputs (plain string literal: the path has nothing to interpolate)
gmel_outputs = Outputs(
    config = '../../data/outputs/DC/comparisons/GraphAttentionNetwork_Comparison_UnsetNoise__doubly_and_cell_constrained_all_region_features_16_05_2024_21_06_14/config.json',
    settings = gmel_settings,
    inputs = None,
    slice = True,
    level = 'NOTE'
)
# Silence outputs
gmel_outputs.logger.setLevels(console_level='NOTE')
# Load all data
gmel_outputs.load()

# Get data from first sweep of the experiment
gmel_outputs = gmel_outputs.get(0)

In [None]:
# GMEL intensity averaged over the 'iter' dimension in float64; the 'seed' dimension is kept
gmel_intensity_mean = gmel_outputs.data.intensity.mean(dim='iter', dtype='float64')

In [None]:
# Relative L1 error of the destination (column) sums on test cells for the
# seed-averaged GMEL intensity:
#   sum_origin |mean - truth| / sum_origin truth   -> one value per destination.
gmel_relative_colsum_l1_error = (
    abs(gmel_intensity_mean.mean(['seed'], dtype='float64') - ground_truth_table)
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
) / (
    ground_truth_table
    .where(inputs.data.test_cells_mask, drop=True)
    .sum('origin', dtype='float64')
)

In [None]:
# (total absolute relative error, total signed relative error) across destinations
(
    abs(gmel_relative_colsum_l1_error).sum().values,
    gmel_relative_colsum_l1_error.sum().values,
)

In [None]:
# Per-seed SRMSE of the GMEL mean intensity against the ground truth on test cells
gmel_srmses = gmel_intensity_mean.groupby('seed').map(
    lambda seed_intensity: srmse(
        seed_intensity,
        ground_truth = inputs.data.ground_truth_table,
        mask = inputs.data.test_cells_mask,
    )
).values

# GMEL vs GeNSIT (Joint)

In [None]:
sum_dim = 'destination'

ground_truth = inputs.data.ground_truth_table.where(inputs.data.test_cells_mask,drop=True)

jointgensit_prediction = jointgensit_table_mean.where(inputs.data.test_cells_mask,drop=True)
# jointgensit_quantity = ((jointgensit_prediction - ground_truth)).sum(dim=sum_dim,dtype='float64',skipna=True)
jointgensit_quantity = (jointgensit_prediction-ground_truth_table).sum(sum_dim,dtype='float64')
jointgensit_quantity /= ground_truth.sum(sum_dim,dtype='float64')

gmel_prediction = gmel_intensity_mean.where(inputs.data.test_cells_mask,drop=True)
# gmel_quantity = ((gmel_prediction - ground_truth)).sum(dim=sum_dim,dtype='float64',skipna=True)
gmel_quantity = (gmel_prediction-ground_truth_table).sum(sum_dim,dtype='float64')
gmel_quantity /= ground_truth.sum(sum_dim,dtype='float64')

In [None]:
print(f"Joint GeNSIT min error: {np.min(jointgensit_quantity.values)} max error: {np.max(jointgensit_quantity.values)} total abs error: {np.sum(abs(jointgensit_quantity.values))}")
print(f"GMEL min error: {np.min(gmel_quantity.values)} max error: {np.max(gmel_quantity.values)} total abs error: {np.sum(abs(gmel_quantity.values))}")

In [None]:
print('Joint GeNSIT lower error locations:',sum([1 if abs(gmel_quantity.values[i]) >= abs(jointgensit_quantity.values[i]) else 0 for i in range(len(gmel_quantity.values))]))
print('GMEL lower error locations:',sum([1 if abs(gmel_quantity.values[i]) < abs(jointgensit_quantity.values[i]) else 0 for i in range(len(gmel_quantity.values))]))

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# # Create the heatmap
# plt.figure(figsize=(20, 20))
# sns.heatmap(inputs.data.ground_truth_table.where(inputs.data.test_cells_mask), cmap="viridis", linewidths=0.5, linecolor="gray", cbar=True)

# # Show the plot
# plt.show()


# Colourbar

In [None]:
import matplotlib.colors as mcolors

# Stand-alone horizontal diverging colourbar (-1 .. 0 .. 1, reversed blue-white-red)
# for the spatial residual paper figures. The host axes is hidden so that only
# the colourbar is rendered and saved.
fig, ax = plt.subplots(figsize=(15,20))
ax.set_visible(False)

cbar = fig.colorbar(
    plt.cm.ScalarMappable(
        norm = mcolors.TwoSlopeNorm(vcenter=0.0, vmin=-1.0, vmax=1.0),
        cmap = 'bwr_r',
    ),
    orientation = 'horizontal',
    ax = ax,
)
cbar.ax.tick_params(labelsize=16)

# Save the colourbar as a PDF next to the other paper figures
write_figure(
    figure_format="pdf",
    figure=fig,
    filepath="../../data/outputs/DC/exp1/paper_figures/colourbar/spatial_residual_colourbar",
)
)