In [None]:
# Re-import edited modules (e.g. deepxde_al_patch) automatically before executing code; render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from functools import partial
import os
import pickle as pkl
from collections.abc import MutableMapping

import matplotlib.pyplot as plt
import matplotlib.tri as tri
import matplotlib.ticker as ticker
import numpy as np
import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["DDE_BACKEND"] = "jax"

# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

from jax import config
config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import optax

# Report the devices JAX can see. Querying an unavailable platform (e.g. "gpu" when
# CUDA_VISIBLE_DEVICES is empty) raises, so count it as 0 instead of skipping the
# whole report; this is a best-effort diagnostic and must never stop the notebook.
device_counts = {}
for _platform in ("cpu", "gpu"):
    try:
        device_counts[_platform] = jax.local_device_count(_platform)
    except Exception:
        device_counts[_platform] = 0
print(f'Jax: CPUs={device_counts["cpu"]} - GPUs={device_counts["gpu"]}')
    
import deepxde_al_patch.deepxde as dde

from deepxde_al_patch.model_loader import construct_model, construct_net
from deepxde_al_patch.modified_train_loop import ModifiedTrainLoop
from deepxde_al_patch.plotters import plot_residue_loss, plot_error, plot_prediction
from deepxde_al_patch.train_set_loader import load_data

from deepxde_al_patch.ntk import NTKHelper
from deepxde_al_patch.utils import get_pde_residue, print_dict_structure

In [None]:
# Global matplotlib styling applied to every figure produced by this notebook.
plt.rcParams.update({
    'figure.dpi': 150,
    'font.size': 18,
    'figure.titlesize': 24,
    'text.usetex': False,
})

# NOTE(review): `main_graph` is not referenced in the code shown here; graph_root is chosen per figure set.
main_graph = 'al_pinn_graphs_final/main'
# Root directory that result folders and graph output folders are joined onto.
data_folder = '../../'

In [None]:
def contour_on_ax(ax, xs, zs, levels, res=200, rm_axis=False, cmap="RdBu_r"):
    """Draw a filled contour of scattered values ``zs`` located at 2-D points ``xs`` onto ``ax``.

    The scattered values are linearly interpolated over a Delaunay triangulation of ``xs``
    onto a regular ``res`` x ``res`` grid spanning the bounding box of ``xs``.

    Args:
        ax: Matplotlib Axes to draw on.
        xs: Array whose columns 0 and 1 are used as the x / y coordinates.
        zs: One value per row of ``xs``.
        levels: Contour levels passed to ``contourf`` (callers share them across panels).
        res: Number of grid points along each axis.
        rm_axis: If True, hide tick labels; otherwise place ticks every 1 (or every 0.5
            when the coordinate span is <= 1).
        cmap: Colormap used by ``contourf`` (default keeps the original diverging map).

    Returns:
        The contour set returned by ``ax.contourf`` (usable as a colorbar mappable).
    """
    xi, yi = [np.linspace(np.min(xs[:, i]), np.max(xs[:, i]), res) for i in range(2)]
    Xi, Yi = np.meshgrid(xi, yi)
    interpolator = tri.LinearTriInterpolator(tri.Triangulation(xs[:, 0], xs[:, 1]), zs)
    zi = interpolator(Xi, Yi)  # masked outside the triangulated region
    cb = ax.contourf(xi, yi, zi, levels=levels, cmap=cmap)
    if rm_axis:
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    else:
        for axis, col in ((ax.xaxis, 0), (ax.yaxis, 1)):
            span = np.max(xs[:, col]) - np.min(xs[:, col])
            axis.set_major_locator(ticker.MultipleLocator(1 if span > 1.0 else 0.5))
    return cb


def _colour_range(values, p_d, sym_colour, ptile):
    """Return shared (min, max) colour limits for a list of equally sized arrays."""
    if ptile:
        lo = np.percentile(values, p_d)
        hi = np.percentile(values, 100 - p_d)
    else:
        lo = np.min(values)
        hi = np.max(values)
    if sym_colour and (lo < 0 < hi):
        m = max(-lo, hi)
        lo, hi = -m, m
    return lo, hi


def plot_contours(xs, ys_list, titles, res=200, sym_colour=False, ptile=False, cbar=True, axislabels=None):
    """Plot a grid of filled contours: one row per output column, one panel per array in ``ys_list``.

    Every panel in a row shares the same colour levels (and colorbar), so the panels are
    directly comparable.

    Args:
        xs: Sample coordinates shared by all panels (columns 0 / 1 are x / y).
        ys_list: List of 2-D arrays; ``ys_list[0].shape[1]`` sets the number of rows and
            column ``i`` of each array is drawn in row ``i``.
        titles: Titles for the panels of the first row.
        res: Interpolation grid resolution and number of colour levels.
        sym_colour: Make the colour limits symmetric about 0 when values span both signs.
        ptile: Use the 1st / 99th percentiles as colour limits and clip values to them.
        cbar: Draw one colorbar per row and keep tick labels; otherwise hide tick labels.
        axislabels: Optional ``(xlabel, ylabel)``; defaults to empty labels.

    Returns:
        ``(fig, axs)``; ``axs`` is a flat list of Axes when there is a single row, otherwise
        a 2-D array of shape ``(nrows, len(ys_list))``.
    """
    if axislabels is None:
        axislabels = ('', '')

    nrows = ys_list[0].shape[1]
    ncols = len(ys_list)
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        sharex=True,
        sharey=True,
        # Without a colorbar the figure is one panel narrower; never let the width collapse to 0.
        figsize=(3 * max(ncols + (1 if cbar else -1), 1), 3 * nrows + 2),
        constrained_layout=True,
        squeeze=False,  # always 2-D, so single-column layouts index like every other layout
    )

    p_d = 1  # percentile trimmed from each tail when ptile=True
    for i in range(nrows):
        row_vals = [y[:, i] for y in ys_list]
        min_, max_ = _colour_range(row_vals, p_d, sym_colour, ptile)
        if ptile:
            # Single-row plots clip exactly to the range; multi-row plots keep a tiny margin inside it.
            eps = 0.0 if nrows == 1 else 1e-9
            row_vals = [np.clip(v, min_ + eps, max_ - eps) for v in row_vals]
        levels = np.linspace(min_, max_, num=res)
        for ax, zs in zip(axs[i], row_vals):
            cb = contour_on_ax(ax, xs, zs, levels, res, rm_axis=not cbar)
            ax.set_xlabel(axislabels[0])
        axs[i, 0].set_ylabel(axislabels[1])
        if cbar:
            # Levels are shared across the row, so the last contour set represents all of it.
            fig.colorbar(cb, ax=axs[i])

    for ax, title in zip(axs[0], titles):
        ax.set_title(title)

    if nrows == 1:
        # Historical return type for single-output plots: a flat list of Axes.
        return fig, axs.ravel().tolist()
    return fig, axs


def plot_training_data(ax, samples):
    """Overlay the sampled training points on ``ax``.

    Residual points (``samples['res']``) are black circles, optional anchor points
    (``samples['anc']``) are blue triangles, and the k-th array in ``samples['bcs']``
    is drawn as squares in colour ``C{k}`` starting from ``C1``. Columns 0 / 1 of each
    array are used as x / y. Markers sit above the contours and are not clipped to the axes.
    """
    marker_style = dict(ms=4., alpha=0.95, zorder=10, clip_on=False)

    def _scatter(points, fmt, colour):
        ax.plot(points[:, 0], points[:, 1], fmt, color=colour, **marker_style)

    _scatter(samples['res'], 'o', 'black')
    if 'anc' in samples.keys():
        _scatter(samples['anc'], '^', 'blue')
    for k, bc_points in enumerate(samples['bcs'], start=1):
        _scatter(bc_points, 's', f'C{k}')

In [None]:
for run_idx in [0, 2, 3, 4, 5, 6]:

    # Figure-set configuration for this pass, keyed by run_idx:
    #   graph_root -- output folder (relative to data_folder) for the generated PDFs
    #   max_runs   -- cap on the number of runs loaded per algorithm
    #   algs       -- result sub-folder name -> (legend label, line-style kwargs); order = legend order
    # An index without an entry leaves the previous configuration in place.
    run_configs = {
        0: ('al_pinn_graphs_final/main', 10, {
            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls=':', marker='p')),
            'random_Hammersley_prop-0.8': ('Hammersley', dict(c='grey', ls=':', marker='h')),
            'residue_prop-0.8': ('RAD', dict(c='red', ls='--', marker='v')),
            'residue_prop-0.8_alltype': ('RAD-All', dict(c='orange', ls='--', marker='^')),
            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S (ours)', dict(c='green', ls='-', marker='s')),
            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (ours)', dict(c='blue', ls='-', marker='o')),
        }),
        # 1 (disabled): ('al_pinn_graphs_final/main_w_pinn-r', 10, same algorithms as 0 plus
        #   'sampling_residue_scale-none_mem_autoal': ('PINNACLE-R', dict(c='lightblue', ls='-', marker='>'))
        #   placed before PINNACLE-S)
        2: ('al_pinn_graphs_final/mult_nonad', 5, {
            'random_pseudo_prop-0.5': ('Unif-0.5', dict(c='black', ls='-', marker='s')),
            'random_pseudo_prop-0.8': ('Unif-0.8', dict(c='black', ls='--', marker='p')),
            'random_pseudo_prop-0.95': ('Unif-0.95', dict(c='black', ls='-.', marker='h')),
            'random_Hammersley_prop-0.5': ('Hamm-0.5', dict(c='m', ls='-', marker='s')),
            'random_Hammersley_prop-0.8': ('Hamm-0.8', dict(c='m', ls='--', marker='p')),
            'random_Hammersley_prop-0.95': ('Hamm-0.95', dict(c='m', ls='-.', marker='h')),
            'random_Sobol_prop-0.5': ('Sobol-0.5', dict(c='red', ls='-', marker='s')),
            'random_Sobol_prop-0.8': ('Sobol-0.8', dict(c='red', ls='--', marker='p')),
            'random_Sobol_prop-0.95': ('Sobol-0.95', dict(c='red', ls='-.', marker='h')),
        }),
        3: ('al_pinn_graphs_final/mult_adapt', 5, {
            'residue_prop-0.5': ('RAD-0.5', dict(c='red', ls='-', marker='^')),
            'residue_prop-0.8': ('RAD-0.8', dict(c='red', ls='--', marker='v')),
            'residue_prop-0.95': ('RAD-0.95', dict(c='red', ls='-.', marker='>')),
            'residue_prop-0.5_alltype': ('RAD-All-0.5', dict(c='lightblue', ls='-', marker='^')),
            'residue_prop-0.8_alltype': ('RAD-All-0.8', dict(c='lightblue', ls='--', marker='v')),
            'residue_prop-0.95_alltype': ('RAD-All-0.95', dict(c='lightblue', ls='-.', marker='>')),
            # 'sampling_residue_scale-none_mem_autoal': ('PINNACLE-R', dict(c='gray', ls=':', marker='p')),
            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S', dict(c='green', ls='-', marker='o')),
            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K', dict(c='blue', ls='-', marker='s')),
        }),
        4: ('al_pinn_graphs_final/pinnacle-k', 5, {
            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls='--', marker='p')),
            'kmeans_alignment_scale-none_autoal': ('No Memory', dict(c='orange', ls='-', marker='^')),
            'kmeans_alignment_scale-none_mem': ('No Auto Trigger', dict(c='m', ls='-', marker='v')),
            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K', dict(c='blue', ls='-', marker='o')),
        }),
        5: ('al_pinn_graphs_final/pinnacle-s', 5, {
            'random_pseudo_prop-0.8': ('Uniform Rand', dict(c='black', ls='--', marker='p')),
            'sampling_alignment_scale-none_autoal': ('No Memory', dict(c='orange', ls='-', marker='^')),
            'sampling_alignment_scale-none_mem': ('No Auto Trigger', dict(c='m', ls='-', marker='v')),
            # 'sampling_residue_scale-none_mem_autoal': ('Diff. Criterion', dict(c='grey', ls='-', marker='h')),
            'sampling_alignment_scale-none_mem_autoal': ('PINNACLE-S', dict(c='green', ls='-', marker='o')),
        }),
        6: ('al_pinn_graphs_final/autoloss', 10, {
            'random_Hammersley_prop-0.8': ('Hammersley', dict(c='black', ls=':', marker='p')),
            'random_Hammersley_prop-0.8_auto': ('Hammersley + Loss weighing', dict(c='red', ls='--', marker='s')),
            'kmeans_alignment_scale-none_mem_autoal': ('PINNACLE-K (ours)', dict(c='blue', ls='-', marker='o')),
        }),
    }
    run_config = run_configs.get(run_idx)
    if run_config is not None:
        graph_root, max_runs, algs = run_config

    ######## RUN NORMAL PLOTS

    # One entry per PDE setting: (result folder relative to data_folder, checkpoint steps, title).
    # The last checkpoint step is the horizon every kept run must have reached.
    # Inverse-problem settings currently disabled:
    #   ('al_pinn_results/kdv-1d{1.0-0.0}_inv_anc/nn-None-4-32_adam_bcsloss-1.0_budget-1000-200-10', [0, 10000, 20000, 100000], '1D Korteweg–De Vries (Inv)')
    #   ('al_pinn_results/fd-2d{1.0-0.01}_inv_anc[0,1]/nn-laaf-6-64_adam_bcsloss-1.0_budget-1000-200-30', [0, 10000, 20000, 100000], '2D Fluid Dynamics (Inv)')
    #   ('al_pinn_results/eik1-3d{}_anc0/nn-laaf-8-32_adam_bcsloss-1.0_budget-500-100-5', [10000, 20000, 50000, 100000], '3D Eikonal (Inv)')
    case_list = [
        ('al_pinn_results/diffhc-1d{}_ic/nn-None-2-32_adam_bcsloss-1.0_budget-100-50-0',
         [10000, 20000, 30000],
         '1D Diffusion w/ Hard-Constrained PINN'),
        ('al_pinn_results/kdv-1d{1.0-0.0}_ic/nn-None-4-32_adam_bcsloss-1.0_budget-300-100-0',
         [0, 10000, 20000, 100000],
         '1D Korteweg–De Vries'),
        ('al_pinn_results/sw-2d{}_pb-0_ic/nn-None-4-32_adam_bcsloss-1.0_budget-1000-200-0',
         [10000, 20000, 50000, 100000],
         'Shallow Water'),
        ('al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',
         [0, 10000, 50000, 200000],
         '1D Advection'),
        ('al_pinn_results/conv-1d{1.0}_pb-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-300-100-0',
         [0, 100000, 200000, 300000],
         '1D Advection'),
        ('al_pinn_results/conv-1d{1.0}_pb-80_ic/nn-None-8-128_adam_bcsloss-1.0_budget-1000-200-0',
         [0, 10000, 50000, 200000],
         '1D Advection'),
        ('al_pinn_results/burgers-1d{0.02}_pb-20_ic/nn-None-4-128_adam_bcsloss-1.0_budget-300-100-0',
         [0, 10000, 50000, 200000],
         "1D Burger's"),
    ]

    if 'main' not in graph_root:
        # Non-main figure sets (ablations / sweeps) only cover the advection and Burgers' settings.
        case_list = [case for case in case_list if any(tag in case[0] for tag in ('conv', 'burger'))]


    for case_folder, steps_plot, suptit in case_list:
        
        if ('conv' in case_folder) or ('burger' in case_folder) or ('fd' in case_folder):
            max_run_2 = 10
        else:
            max_run_2 = 5

        if steps_plot[-1] > 150000:
            tick_spacing = 100000
        elif steps_plot[-1] > 50000:
            tick_spacing = 50000
        else:
            tick_spacing = 10000

        throwout = []

        print('---------------------------------------------------\n\n')
        print('PROCESSING:', graph_root, case_folder)

        max_steps = steps_plot[-1]

        root_folder = os.path.join(data_folder, case_folder)

        _, arch, depth, width = root_folder.split('/')[-1].split('_')[0].split('-')

    #     net, _ = construct_net(
    #         input_dim=2, 
    #         output_dim=1, 
    #         hidden_layers=int(depth), 
    #         hidden_dim=int(width), 
    #         arch=(None if arch == 'None' else arch)
    #     )

        cases = {x: os.listdir(f'{root_folder}/{x}') for x in algs.keys() if os.path.exists(f'{root_folder}/{x}')}
        print('Exist:', list(cases.keys()))

        if len(list(cases.keys())) < 1:
            continue

        # Load every run of every algorithm. A run is kept only if it has a `None` entry
        # (test inputs/targets) and its largest snapshot step <= max_steps equals max_steps.
        # At most min(max_run_2, max_runs) runs are kept per algorithm; everything else,
        # including runs that fail to load, is listed in `throwout`.
        data = dict()           # algorithm -> list of per-run summaries
        steps_min = dict()      # algorithm -> smallest final step among kept runs
        plotted_cases = dict()  # algorithm -> names of the kept run folders
        run_cap = min(max_run_2, max_runs)

        for c, run_dirs in cases.items():

            s_min = float('inf')

            runs = []
            runs_cases = []

            # Sorted so the same runs are selected on every execution (os.listdir order is arbitrary).
            for r in sorted(run_dirs):

                try:

                    d = dict()

                    if len(runs) < run_cap:

                        # Merge every snapshot file of the run (presumably keyed by training step);
                        # sorted so overlapping keys resolve the same way every time.
                        for file in sorted(os.listdir(f'{root_folder}/{c}/{r}')):

                            if file.startswith('snapshot_data'):

                                fname = f'{root_folder}/{c}/{r}/{file}'

                                # NOTE: pickle.load can execute arbitrary code -- only point this at trusted result folders.
                                with open(fname, 'rb') as f:
                                    d_update = pkl.load(f)

                                d.update(d_update)

                    steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])

                    if (len(steps_range) > 0) and (max_steps == steps_range[-1]) and (None in d.keys()):

                        print(c, r, max(x for x in d.keys() if x is not None))

                        s_min = min(s_min, steps_range[-1])

                        x_test = d[None]['x_test']

                        d_modified = {
                            'x_test': x_test,
                            'y_test': d[None]['y_test'],
                            'steps': steps_range,
                            'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
                            # Root-mean-square test error (NaNs ignored), despite the "mean" key.
                            'err_mean': [np.nanmean(d[k]['error_test']**2)**0.5 for k in steps_range],
                            'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],
                            'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],
                            'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],
                            'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],
                            'res': [d[k]['residue_test'] for k in steps_range],
                            'err': [d[k]['error_test'] for k in steps_range],
                            'pred': [d[k]['pred_test'] for k in steps_range],
                            'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
                            'inv': [d[k]['params'][1] for k in steps_range],
                        }

                        if 'sw-2d' in case_folder:
                            # Shallow water: keep only the first output column.
                            d_modified['y_test'] = d_modified['y_test'][:,0:1]
                            d_modified['pred'] = [y[:,0:1] for y in d_modified['pred']]
                            d_modified['err'] = [y[:,0:1] for y in d_modified['err']]

                        if ('kdv' in case_folder) and ('inv' in case_folder):
                            d_modified['inv'] = [
                                (v[0], 0.0025 * np.exp(v[1]))
                                for v in d_modified['inv']
                            ]

                        # Extra per-field error metrics for multi-output problems.
                        if 'darcy' in case_folder:
                            a_pred = [y[:,0] for y in d_modified['pred']]
                            a_true = d_modified['y_test'][:,0]
                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2)**0.5 for y in a_pred]
                            d_modified['a_err_q90'] = [np.percentile((a_true - y), 90) for y in a_pred]
                            f_true = np.array(a_true > 0.5, dtype=float)
                            f_pred = [np.array(y > 0.5, dtype=float) for y in a_pred]
                            d_modified['bool_err_mean'] = [np.mean(np.abs(f_true - y)) for y in f_pred]
                            multidim = True
                        elif 'reacdiff' in case_folder:
                            multidim = True
                        elif 'fd-2d' in case_folder:
                            a_pred = [y[:,2] for y in d_modified['pred']]
                            a_true = d_modified['y_test'][:,2]
                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2)**0.5 for y in a_pred]
                            d_modified['a_err_q90'] = [np.percentile((a_true - y), 90) for y in a_pred]
                            multidim = True
                        elif 'eik' in case_folder:
                            a_pred = [y[:,1] for y in d_modified['pred']]
                            a_true = d_modified['y_test'][:,1]
                            d_modified['a_err_mean'] = [np.mean((a_true - y)**2)**0.5 for y in a_pred]
                            d_modified['a_err_q90'] = [np.percentile((a_true - y), 90) for y in a_pred]
                            multidim = True
                            if ('eik3' in case_folder) or ('eik5' in case_folder):
                                f_true = np.array(a_true > 0.5, dtype=float)
                                f_pred = [np.array(y > 0.5, dtype=float) for y in a_pred]
                                d_modified['bool_err_mean'] = [np.mean(np.abs(f_true - y)) for y in f_pred]
                        else:
                            multidim = False

                        if 'inv' in case_folder:
                            # True inverse parameters are encoded in the folder name as "{p1-p2-...}".
                            d_modified['inv_param_true'] = [float(x) for x in case_folder.split('{')[1].split('}')[0].split('-')]
                            d_modified['inv_param_pred'] = tuple([float(d[k]['params'][1][i]) for k in steps_range] for i in range(len(d_modified['inv_param_true'])))
                            inv_params = d_modified['inv_param_true']
                            plot_inv_param = True
                        else:
                            plot_inv_param = False

                        runs.append(d_modified)
                        runs_cases.append(r)

                    else:
                        throwout.append(f'{root_folder}/{c}/{r}')

                except Exception as e:
                    # Best effort: skip the broken run, but report why instead of silently dropping it.
                    throwout.append(f'{root_folder}/{c}/{r} ({type(e).__name__}: {e})')

            if len(runs) > 0:
                data[c] = runs
                steps_min[c] = s_min
                plotted_cases[c] = runs_cases

        print('To plot algorithms =', {k: len(data[k]) for k in data.keys()})
        print()

        print('Throwout:')
        [print(th) for th in throwout]

        if len([k for k in data.keys()]) < 1:
            continue

        graph_folder = os.path.join(data_folder, graph_root, case_folder)
        os.makedirs(graph_folder, exist_ok=True)

        with open(os.path.join(graph_folder, 'cases_plotted'), 'w+') as f:
            f.write(str(plotted_cases))

        def _plot_metric(key, ylabel, fname, stat='mean'):
            """Plot per-step metric `key` for every algorithm in `data` on a log y-axis and save it.

            stat='mean'   -> across-run mean with a one-sided (upward) 1-std error bar.
            stat='median' -> across-run median with a 20th-80th percentile error bar.
            Returns the Axes (its figure already closed) so its legend handles can be reused.
            """
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.set_yscale("log")
            for alg, alg_runs in data.items():
                ys = [run[key] for run in alg_runs]
                if stat == 'median':
                    centre = np.percentile(ys, 50, axis=0)
                    yerr = [centre - np.percentile(ys, 20, axis=0), np.percentile(ys, 80, axis=0) - centre]
                else:
                    centre = np.mean(ys, axis=0)
                    spread = np.std(ys, axis=0)
                    yerr = [np.zeros_like(spread), spread]
                label, style = algs[alg]
                ax.errorbar(alg_runs[0]['steps'], centre, yerr, capsize=2, label=label, alpha=0.7,
                            markerfacecolor='none', markeredgecolor=style['c'], color=style['c'],
                            ls=style['ls'], marker=style['marker'])
            ax.set_xlabel('Steps')
            ax.set_ylabel(ylabel)
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
            fig.savefig(os.path.join(graph_folder, fname), bbox_inches='tight', pad_inches=0.1)
            plt.close('all')
            return ax

        ax = _plot_metric('err_mean', 'Mean error', 'err_mean.pdf')

        # Stand-alone legend figures (3-column grid and single-row), reusing the err_mean handles.
        handles, labels = ax.get_legend_handles_labels()
        for legend_file, legend_size, legend_ncol in (
            ('labels.pdf', (6, 1.5), 3),
            ('labels_flat.pdf', (6, 1), len(handles)),
        ):
            figl = plt.figure(figsize=legend_size)
            axl = figl.add_subplot(111)
            axl.legend(handles, labels, loc="center", ncol=legend_ncol)
            axl.axis('off')
            figl.savefig(os.path.join(graph_folder, legend_file), bbox_inches='tight', pad_inches=0.05)
            plt.close('all')

        _plot_metric('err_mean', 'Mean error', 'err_med.pdf', stat='median')

        # Additional per-field errors for multi-output problems (metrics built while loading).
        if ('darcy' in case_folder) or ('eik' in case_folder) or ('fd' in case_folder):

            if 'darcy' in case_folder:
                field_name = 'a'
            elif 'eik' in case_folder:
                field_name = 'v'
            else:
                field_name = 'p'

            _plot_metric('a_err_mean', f'Mean error of {field_name}(x)', f'err-{field_name}_avg.pdf')
            _plot_metric('a_err_mean', f'Mean error of {field_name}(x)', f'err-{field_name}_med.pdf', stat='median')
            _plot_metric('a_err_q90', f'90th Quantile Error of {field_name}(x)', f'err-{field_name}_q90.pdf')

            if ('eik3' in case_folder) or ('eik5' in case_folder):
                _plot_metric('bool_err_mean', f'Mean boolean error of {field_name}(x)',
                             f'err-{field_name}-bool_avg.pdf')
                _plot_metric('bool_err_mean', f'Median boolean error of {field_name}(x)',
                             f'err-{field_name}-bool_med.pdf', stat='median')

        for q in [50, 90, 95, 100]:
            _plot_metric(f'err_q{q}', f'{q}th Quantile error', f'err_q{q}.pdf')

        _plot_metric('res_mean', 'Mean residue', 'res_mean.pdf')

        if plot_inv_param:
            for i in range(len(inv_params)):

                # Estimated inverse parameter i: across-run mean +/- 1 std vs. the true value (black line).
                fig, ax = plt.subplots(figsize=(5, 4))
                for c in data.keys():
                    ys = [y['inv_param_pred'][i] for y in data[c]]
                    mean = np.mean(ys, axis=0)
                    err = np.std(ys, axis=0)
                    label, marker = algs[c]
                    ax.errorbar(data[c][0]['steps'], mean, err, capsize=2, label=label, alpha=0.7, 
                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])
                ax.axhline(y=inv_params[i], color='black')
                ax.set_ylim(max(0., inv_params[i]-0.05), None)
                ax.set_xlabel('Steps')
                ax.set_ylabel(f'Inv. param {i+1}')
                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}.pdf'), bbox_inches='tight', pad_inches=0.1)
                plt.close(fig)  # previously left open, accumulating figures across parameters

                # Same, as across-run median with a 20th-80th percentile band.
                fig, ax = plt.subplots(figsize=(5, 4))
                for c in data.keys():
                    ys = [y['inv_param_pred'][i] for y in data[c]]
                    mean = np.percentile(ys, 50, axis=0)
                    err1 = mean - np.percentile(ys, 20, axis=0)
                    err2 = np.percentile(ys, 80, axis=0) - mean
                    label, marker = algs[c]
                    ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, 
                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])
                ax.axhline(y=inv_params[i], color='black')
                ax.set_ylim(max(0., inv_params[i]-0.05), None)
                ax.set_xlabel('Steps')
                ax.set_ylabel(f'Inv. param {i+1}')
                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_med.pdf'), bbox_inches='tight', pad_inches=0.1)
                plt.close(fig)

                # Absolute estimation error |true - predicted| on a log scale (mean + 1 std above).
                fig, ax = plt.subplots(figsize=(5, 4))
                for c in data.keys():
                    ys = [y['inv_param_pred'][i] for y in data[c]]
                    inv_err = np.abs(inv_params[i] - np.array(ys))
                    mean = np.mean(inv_err, axis=0)
                    err = np.std(inv_err, axis=0)
                    label, marker = algs[c]
                    ax.set_yscale("log")
                    ax.errorbar(data[c][0]['steps'], mean, [np.zeros_like(err), err], capsize=2, label=label, alpha=0.7, 
                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])
                ax.set_xlabel('Steps')
                ax.set_ylabel(f'Inv. param {i+1} error')
                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_err.pdf'), bbox_inches='tight', pad_inches=0.1)
                plt.close(fig)


                fig, ax = plt.subplots(figsize=(5, 4))
                for c in data.keys():
                    ys = [y['inv_param_pred'][i] for y in data[c]]
                    inv_err = np.abs(inv_params[i] - np.array(ys))
                    mean = np.percentile(inv_err, 50, axis=0)
                    err1 = mean - np.percentile(inv_err, 20, axis=0)
                    err2 = np.percentile(inv_err, 80, axis=0) - mean
                    label, marker = algs[c]
                    ax.set_yscale("log")
                    ax.errorbar(data[c][0]['steps'], mean, [err1, err2], capsize=2, label=label, alpha=0.7, 
                        markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])
                ax.set_xlabel('Steps')
                ax.set_ylabel('Mean error')
                ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
                # fig.suptitle(suptit)
                fig.savefig(os.path.join(graph_folder, f'inv_param_{i}_err_med.pdf'), bbox_inches='tight', pad_inches=0.1)

                plt.close('all')


        k0 = list(data.keys())[0]
        # Metric whose final value selects the "best" run of each algorithm for the contour plots.
        best_from = 'err_mean'

        # Contour plots are only produced for these problem families.
        if ('conv' in case_folder) or ('burger' in case_folder) or ('kdv' in case_folder) or ('eik' in case_folder) or ('sw-2d' in case_folder) or ('diff' in case_folder):

            if ('conv' in case_folder) or ('burger' in case_folder) or ('kdv' in case_folder) or ('diff' in case_folder):
                axislabel = ('x', 't')
            else:
                axislabel = ('x', 'y')

            # Index of the run with the lowest final `best_from` value, per algorithm (computed once).
            best_run = {c: np.argmin([x[best_from][-1] for x in data[c]]) for c in data.keys()}

            # Last and second-to-last recorded snapshot.
            for stidx in [-1, -2]:

                for dim in range(data[k0][0]['y_test'].shape[1]):

                    if data[k0][0]['x_test'].shape[1] > 2:
                        # More than two input coordinates: plot the slice where the third one equals z_max.
                        z_max = 1.
                        xidx = data[k0][0]['x_test'][:,2] == z_max
                        print(f'Only plotting slice where last dim value = {z_max}')
                    else:
                        xidx = jnp.ones(shape=data[k0][0]['x_test'].shape[0], dtype=bool)

                    if '_change/' in case_folder:
                        # The initial model is masked with `xidx` like every other panel so it lines up with `xs`.
                        start_y = [data[k0][0]['pred'][0][xidx,dim:dim+1], data[k0][0]['y_test'][xidx,dim:dim+1]]
                        start_title = ['Initial model', 'True solution']
                    else:
                        start_y = [data[k0][0]['y_test'][xidx,dim:dim+1]]
                        start_title = ['True solution']

                    fig, axs = plot_contours(
                        xs=data[k0][0]['x_test'][xidx,:2], 
                        ys_list=start_y + [data[c][best_run[c]]['pred'][stidx][xidx,dim:dim+1] for c in data.keys()], 
                        titles=start_title + [algs[c][0] for c in data.keys()], 
                        axislabels=axislabel,
                    )
                    # NOTE(review): `c` in these file names is the value left by an earlier
                    # `for c in data.keys()` loop (the last algorithm) - confirm this is the intended step source.
                    fig.savefig(os.path.join(graph_folder, f'pred_s{data[c][0]["steps"][stidx]}-d{dim}.png'), bbox_inches='tight', pad_inches=0.1)
                    plt.close('all')

                    fig, axs = plot_contours(
                        xs=data[k0][0]['x_test'][xidx,:2], 
                        ys_list=[data[c][best_run[c]]['err'][stidx][xidx,dim:dim+1] for c in data.keys()],
                        titles=[algs[c][0] for c in data.keys()], 
                        axislabels=axislabel,
                    )
                    fig.savefig(os.path.join(graph_folder, f'err_s{data[c][0]["steps"][stidx]}-d{dim}.png'), bbox_inches='tight', pad_inches=0.1)
                    plt.close('all')

                # `xidx` depends only on x_test, so the mask left by the last `dim` iteration applies here.
                fig, axs = plot_contours(
                    xs=data[k0][0]['x_test'][xidx,:2], 
                    ys_list=[data[c][best_run[c]]['res'][stidx][xidx] for c in data.keys()],
                    titles=[algs[c][0] for c in data.keys()], 
                    axislabels=axislabel,
                )
                fig.savefig(os.path.join(graph_folder, f'res_s{data[c][0]["steps"][stidx]}.png'), bbox_inches='tight', pad_inches=0.1)
                plt.close('all')

            # Stand-alone legend for the training-point markers. It does not depend on the
            # snapshot step, so it is written once rather than once per `stidx`.
            if '2d' in case_folder:
                fig, ax = plt.subplots(figsize=(5, 4))
                ms = 6.
                samples = data[c][0]['chosen_pts'][0]
                ax.plot(samples['res'][:, 0], samples['res'][:, 1], 'o', color='black', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='PDE CL Pts')
                for i, (bc_pts, name) in enumerate(zip(samples['bcs'], ['IC CL Pts', 'BC CL Pts'])):
                    ax.plot(bc_pts[:, 0], bc_pts[:, 1], 's', color=f'C{i+1}', ms=ms, alpha=0.95, zorder=10, clip_on=False, label=name)
                if 'anc' in samples.keys():
                    ax.plot(samples['anc'][:, 0], samples['anc'][:, 1], '^', color='blue', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='Exp Pts')
                figl = plt.figure(figsize=(6, 1.5))
                axl = figl.add_subplot(111)
                axl.legend(*ax.get_legend_handles_labels() , loc="center", ncol=4)
                axl.axis('off')
                figl.savefig(os.path.join(graph_folder, f'labels_trainpts.pdf'), bbox_inches='tight', pad_inches=0.05)
                plt.close('all')

            # Per-algorithm evolution over `steps_plot`, with the training points of the SAME run overlaid.
            if (('main' in graph_root) or ('pinnacle' in graph_root)) and ('2d' in case_folder):

                for c in data.keys():

                    min_idx = best_run[c]
                    steps = [data[c][min_idx]['steps'].index(s) for s in steps_plot]

                    fig, axs = plot_contours(
                        xs=data[k0][0]['x_test'], 
                        ys_list=[data[c][min_idx]['pred'][s] for s in steps],
                        titles=[f'Step {s}' for s in steps_plot], 
                        axislabels=axislabel,
                    )
                    # With several output dims, the first row of axes holds the first output.
                    for ax, s in zip(axs[0] if multidim else axs, steps):
                        plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                    fig.savefig(os.path.join(graph_folder, f'data_pred_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                    plt.close('all')

                    fig, axs = plot_contours(
                        xs=data[k0][0]['x_test'], 
                        ys_list=[data[c][min_idx]['err'][s] for s in steps], 
                        titles=[f'Step {s}' for s in steps_plot], 
                        axislabels=axislabel,
                    )
                    for ax, s in zip(axs[0] if multidim else axs, steps):
                        plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                    fig.savefig(os.path.join(graph_folder, f'data_err_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                    plt.close('all')

                    fig, axs = plot_contours(
                        xs=data[c][min_idx]['x_test'], 
                        ys_list=[data[c][min_idx]['res'][s] for s in steps], 
                        titles=[f'Step {data[c][0]["steps"][s]}' for s in steps], 
                        axislabels=axislabel,
                    )
                    # NOTE(review): assumes the residue figure has a 1D axes array even when `multidim` - confirm.
                    for ax, s in zip(axs, steps):
                        plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                    fig.savefig(os.path.join(graph_folder, f'data_res_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                    plt.close('all')


    ######################################## FINE TUNING EXPERIMENTS

    # Fine-tuning (changed initial condition) experiments.
    # Each entry: (results folder relative to `data_folder`, snapshot steps to plot, figure title).
    case_list = [
        ('al_pinn_results_ic_change/conv-1d{1.0}_ftic-898-40_ic/nn-None-8-128_adam_bcsloss-1.0_budget-200-50-0',
         [0, 10000, 50000, 200000],
         '1D Advection (FT)'),
        ('al_pinn_results_ic_change/burgers-1d{0.02}_ftic-131-20_ic/nn-None-4-128_adam_bcsloss-1.0_budget-200-50-0',
         [0, 10000, 50000, 200000],
         '1D Burgers (FT)'),
    ]

    for case_folder, steps_plot, suptit in case_list:

        # The last requested snapshot is the step every kept run must reach; it also
        # decides the x-axis tick spacing of the metric plots.
        max_steps = steps_plot[-1]
        if max_steps > 150000:
            tick_spacing = 100000
        elif max_steps > 50000:
            tick_spacing = 50000
        else:
            tick_spacing = 10000

        # Run folders skipped because they lack the required snapshots.
        throwout = []

        print('---------------------------------------------------\n\n')
        print('PROCESSING:', graph_root, case_folder)

        root_folder = os.path.join(data_folder, case_folder)

        # The last path component starts with '<prefix>-<arch>-<depth>-<width>_...'.
        net_spec = root_folder.split('/')[-1].split('_')[0]
        _, arch, depth, width = net_spec.split('-')

        net, _ = construct_net(
            input_dim=2,
            output_dim=1,
            hidden_layers=int(depth),
            hidden_dim=int(width),
            arch=None if arch == 'None' else arch,
        )

        cases = {x: os.listdir(f'{root_folder}/{x}') for x in algs.keys() if os.path.exists(f'{root_folder}/{x}')}
        print('Exist:', list(cases.keys()))

        data = dict()
        steps_min = dict()
        plotted_cases = dict()

        for c in cases.keys():

            s_min = float('inf')

            runs = []
            runs_cases = []

            included = set()

            for r in sorted(cases[c]):

                try:

                    ver = r.split('_')[0]
                    if ver in included:
                        continue

                    d = dict()

                    for file in os.listdir(f'{root_folder}/{c}/{r}'):

                        if file.startswith('snapshot_data'):

                            fname = f'{root_folder}/{c}/{r}/{file}'

                            with open(fname, 'rb') as f:
                                d_update = pkl.load(f)

                            d.update(d_update)

                    steps_range = sorted([x for x in d.keys() if (x is not None) and (max_steps >= x)])
                    if (len(steps_range) > 0) and (max_steps == steps_range[-1]) and (None in d.keys()):

                        print(c, r, sorted([x for x in d.keys() if (x is not None)])[-1])
                        included.add(ver)

                        s_min = min(s_min, steps_range[-1])

                        x_test = d[None]['x_test']

                        d_modified = {
                            'x_test': x_test,
                            'y_test': d[None]['y_test'],
                            'steps': steps_range,
                            'res_mean': [d[k]['residue_test_mean'] for k in steps_range],
#                             'err_mean': [d[k]['error_test_mean'] for k in steps_range],
                            'err_mean': [np.nanmean(d[k]['error_test']**2)**0.5 for k in steps_range],
                            'err_q50': [np.percentile(d[k]['error_test'], 50) for k in steps_range],
                            'err_q90': [np.percentile(d[k]['error_test'], 90) for k in steps_range],
                            'err_q95': [np.percentile(d[k]['error_test'], 95) for k in steps_range],
                            'err_q100': [np.percentile(d[k]['error_test'], 100) for k in steps_range],
                            'res': [d[k]['residue_test'] for k in steps_range],
                            'err': [d[k]['error_test'] for k in steps_range],
                            'pred': [d[k]['pred_test'] if 'pred_test' in d[k].keys() 
                                     else net.apply(d[k]['params'][0], x_test)
                                     for k in steps_range],
                            'chosen_pts': [d[k]['al_intermediate']['chosen_pts'] for k in steps_range],
                            'inv': [d[k]['params'][1] for k in steps_range],
                        }

                        arr_shape = [d_modified['y_test'].shape[1]] + [np.unique(x).shape[0] for x in d_modified['x_test'].T]
                        d_modified['y_test_fft'] = np.fft.fftn(
                            d_modified['y_test'].reshape(*arr_shape), 
                            axes=[1, 2]
                        )
                        d_modified['pred_fft'] = [np.fft.fftn(
                            y.reshape(*arr_shape), axes=[1, 2]) 
                            for y in d_modified['pred']]

                        d_modified['fft_err'] = [np.abs(yf - d_modified['y_test_fft'])
                            for yf in d_modified['pred_fft']]

            #             idxs = np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))[0].T
            #             idxs = np.array([idxs, idxs])

                        idxs = np.array(np.meshgrid(np.arange(arr_shape[1]), np.arange(arr_shape[2]))).swapaxes(1, 2)

                        klow = (idxs <= 4).all(axis=0).astype(float)
                        kmid = (idxs <= 12).all(axis=0).astype(float) - klow
                        khigh = (idxs <= np.inf).all(axis=0).astype(float) - kmid - klow

                        for s, k in [('low', klow), ('mid', kmid), ('high', khigh)]:
                            d_modified[f'fft_mean_{s}'] = [np.sum(yf * k[None, :]) / (np.sum(k) * yf.shape[0])
                                for yf in d_modified['fft_err']]

                        runs.append(d_modified)
                        runs_cases.append(r)

                    else:
                        throwout.append(f'{root_folder}/{c}/{r}')

                except:
                    pass

            if len(runs) > 0:
                data[c] = runs
                steps_min[c] = s_min
                plotted_cases[c] = runs_cases

        print('To plot algorithms =', {k: len(data[k]) for k in data.keys()})

        if len([k for k in data.keys()]) == 0:
            continue

        graph_folder = os.path.join(data_folder, graph_root, case_folder)
        os.makedirs(graph_folder, exist_ok=True)

        with open(os.path.join(graph_folder, 'cases_plotted'), 'w+') as f:
            f.write(str(plotted_cases))



        # Summary metrics vs. training step, one log-scale figure per entry:
        # (metric key, y-label, output file, statistic, pad_inches), where statistic is
        #   'mean'   -> mean across runs with an upper-only std bar
        #   'median' -> median across runs with a 20th-80th percentile band
        metric_plots = (
            [('err_mean', 'Mean error', 'err_mean.pdf', 'mean', 0.1),
             ('err_mean', 'Mean error', 'err_med.pdf', 'median', 0.1)]
            + [(f'err_q{q}', f'{q}th Quantile error', f'err_q{q}.pdf', 'mean', 0) for q in [50, 90, 95, 100]]
            + [('res_mean', 'Mean residue', 'res_mean.pdf', 'mean', 0.1)]
            + [(f'fft_mean_{s}', f'Mean FFT ({s}) diff.', f'fft-{s}_mean.pdf', 'mean', 0.1) for s in ['low', 'mid', 'high']]
        )

        for metric, ylabel, fname, stat, pad in metric_plots:
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.set_yscale("log")
            for c in data.keys():
                ys = [y[metric] for y in data[c]]
                if stat == 'median':
                    center = np.percentile(ys, 50, axis=0)
                    yerr = [center - np.percentile(ys, 20, axis=0), np.percentile(ys, 80, axis=0) - center]
                else:
                    center = np.mean(ys, axis=0)
                    spread = np.std(ys, axis=0)
                    yerr = [np.zeros_like(spread), spread]
                label, marker = algs[c]
                ax.errorbar(data[c][0]['steps'], center, yerr, capsize=2, label=label, alpha=0.7,
                            markerfacecolor='none', markeredgecolor=marker['c'], color=marker['c'], ls=marker['ls'], marker=marker['marker'])
            ax.set_xlabel('Steps')
            ax.set_ylabel(ylabel)
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
            fig.savefig(os.path.join(graph_folder, fname), bbox_inches='tight', pad_inches=pad)
            # Close each figure right after saving so figures never accumulate.
            plt.close('all')


        axislabel = ('x', 't')

        # Final-step comparison figures use run `idx` (the first run) of every algorithm.
        k0 = list(data.keys())[0]
        idx = 0
        # Every kept run ends exactly at `max_steps`, so the final step is the same for all algorithms.
        final_step = data[k0][idx]['steps'][-1]

        if '_change/' in case_folder:
            start_y = [data[k0][idx]['pred'][0], data[k0][idx]['y_test']]
            start_title = ['Initial model', 'True solution']
        else:
            start_y = [data[k0][idx]['y_test']]
            start_title = ['True solution']

        if graph_root == main_graph:
            preferred = [
                'random_Hammersley_prop-0.8', 'residue_prop-0.8_alltype', 
                'sampling_alignment_scale-none_mem_autoal', 'kmeans_alignment_scale-none_mem_autoal'
            ]
            # Skip algorithms that have no loaded runs instead of raising KeyError.
            keys_used = [k for k in preferred if k in data]
        else:
            keys_used = list(data.keys())

        fig, axs = plot_contours(
            xs=data[k0][idx]['x_test'], 
            ys_list=start_y + [data[c][idx]['pred'][-1] for c in keys_used], 
            titles=start_title + [algs[c][0] for c in keys_used], 
            axislabels=axislabel,
        )
        fig.savefig(os.path.join(graph_folder, f'pred_s{final_step}.png'), bbox_inches='tight', pad_inches=0.1)
        plt.close('all')

        fig, axs = plot_contours(
            xs=data[k0][idx]['x_test'], 
            ys_list=[data[c][idx]['err'][-1] for c in keys_used],
            titles=[algs[c][0] for c in keys_used], 
            axislabels=axislabel,
        )
        fig.savefig(os.path.join(graph_folder, f'err_s{final_step}.png'), bbox_inches='tight', pad_inches=0.1)
        plt.close('all')

        fig, axs = plot_contours(
            xs=data[k0][idx]['x_test'], 
            ys_list=[data[c][idx]['res'][-1] for c in keys_used],
            # algs[c] is a (label, marker) pair: take the label, independent of the run index.
            titles=[algs[c][0] for c in keys_used], 
            axislabels=axislabel,
        )
        fig.savefig(os.path.join(graph_folder, f'res_s{final_step}.png'), bbox_inches='tight', pad_inches=0.1)
        plt.close('all')

        # Stand-alone legend for the training-point markers, built from the first run's
        # initial training set.
        fig, ax = plt.subplots(figsize=(5, 4))
        ms = 6.
        samples = data[k0][0]['chosen_pts'][0]
        ax.plot(samples['res'][:, 0], samples['res'][:, 1], 'o', color='black', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='PDE CL Pts')
        for i, (bc_pts, name) in enumerate(zip(samples['bcs'], ['IC CL Pts', 'BC CL Pts'])):
            ax.plot(bc_pts[:, 0], bc_pts[:, 1], 's', color=f'C{i+1}', ms=ms, alpha=0.95, zorder=10, clip_on=False, label=name)
        if 'anc' in samples.keys():
            ax.plot(samples['anc'][:, 0], samples['anc'][:, 1], '^', color='blue', ms=ms, alpha=0.95, zorder=10, clip_on=False, label='Exp Pts')
        figl = plt.figure(figsize=(6, 1.5))
        axl = figl.add_subplot(111)
        axl.legend(*ax.get_legend_handles_labels() , loc="center", ncol=4)
        axl.axis('off')
        figl.savefig(os.path.join(graph_folder, f'labels_trainpts.pdf'), bbox_inches='tight', pad_inches=0.05)
        plt.close('all')

        # Per-algorithm evolution over `steps_plot`, with that run's training points overlaid.
        if graph_root == main_graph:

            for c in data.keys():

                # Run shown for each algorithm (the first one). For the best run instead use
                # np.argmin([x['err_mean'][-1] for x in data[c]]); the overlays below follow min_idx.
                min_idx = 0
                steps = [data[c][min_idx]['steps'].index(s) for s in steps_plot]

                fig, axs = plot_contours(
                    xs=data[k0][0]['x_test'], 
                    ys_list=[data[c][min_idx]['pred'][s] for s in steps],
                    titles=[f'Step {s}' for s in steps_plot],
                    axislabels=axislabel,
                )
                for ax, s in zip(axs, steps):
                    plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                fig.savefig(os.path.join(graph_folder, f'data_pred_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                plt.close('all')

                fig, axs = plot_contours(
                    xs=data[k0][0]['x_test'], 
                    ys_list=[data[c][min_idx]['err'][s] for s in steps], 
                    titles=[f'Step {s}' for s in steps_plot], 
                    axislabels=axislabel,
                )
                for ax, s in zip(axs, steps):
                    plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                fig.savefig(os.path.join(graph_folder, f'data_err_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                plt.close('all')

                fig, axs = plot_contours(
                    xs=data[c][min_idx]['x_test'], 
                    ys_list=[data[c][min_idx]['res'][s] for s in steps], 
                    titles=[f'Step {data[c][0]["steps"][s]}' for s in steps], 
                    axislabels=axislabel,
                )
                for ax, s in zip(axs, steps):
                    plot_training_data(ax, data[c][min_idx]['chosen_pts'][s])
                fig.savefig(os.path.join(graph_folder, f'data_res_{algs[c][0]}.png'), bbox_inches='tight', pad_inches=0.1)
                plt.close('all')

    # Log banner marking that the plotting above has finished.
    print('=================================DONE=============================\n\n')