In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import sys
sys.path.append('../')
from jax.config import config
import tensorflow as tf
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
tf.config.set_visible_devices([], device_type='GPU')
import utils
from plot_analysis_utils import label_map

In [None]:
# Per-figure settings, keyed by the figure index `fi` used in the plotting cells.
# The entries only differ in their index-derived title and the upper y-limit,
# so they are generated from the list of upper limits.
figure_config = {
    idx: {
        'title': 'ODE {}'.format(idx + 1),
        'xlabel': 'x',
        'ylabel': 'y',
        'ylim': [0, y_upper],
    }
    for idx, y_upper in enumerate((0.07, 0.07, 0.025))
}

In [None]:
# Directory the error pickles are read from and the PDF figures are saved to.
analysis_folder = '../analysis/analysis0511a-v4-ind'
# Both tokens below are embedded in the error-pickle file names.
stamp = '20230515-094404_1000000'
dataset_name = 'data0511a'
# One figure is produced per entry; a result key is drawn in that figure
# when it contains at least one of the entry's substrings.
patterns_list = [
    ['ode', 'series'],
    ['pde'],
    ['mfc'],
]

In [None]:
def pattern_match(patterns, name):
    """Return True if at least one entry of `patterns` is a substring of `name`.

    An empty `patterns` iterable yields False.
    """
    return any(candidate in name for candidate in patterns)

In [None]:
# arXiv version: the legend is placed below the axes.
# For every pattern group, draw the mean relative error of each matching result
# key against the number of demos and save the figure as a PDF.

# Demo counts that have an error pickle; also used as the x-axis ticks.
# Defined once up front so it exists even when no result key matches a pattern group.
demo_num_list = (1, 2, 3, 4, 5)

for fi, patterns in enumerate(patterns_list):
  fig, ax = plt.subplots(figsize=(4,5))
  # key -> ([demo counts], [mean relative errors]); x and y are recorded together so a
  # key that is absent from one pickle neither shifts nor truncates its curve.
  error_record = {}
  for demo_num in demo_num_list:
    err_path = "{}/err_{}_{}_{}_{}.pickle".format(analysis_folder, stamp, dataset_name, demo_num, demo_num + 1)
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(err_path, 'rb') as file:
      results = pickle.load(file)

    for key, value in results.items():
      if pattern_match(patterns, key): # key match the patterns
        demo_counts, err_means = error_record.setdefault(key, ([], []))
        demo_counts.append(demo_num)
        err_means.append(value["relative_error_mean"])

  for key, (demo_counts, err_means) in error_record.items():
    style = label_map[key]
    ax.plot(demo_counts, np.array(err_means), label=style['legend'],
                                              linestyle=style['linestyle'],
                                              marker=style['marker'], markersize=7)

  ax.set_xticks(list(demo_num_list))
  ax.set_xlabel('number of demos')
  ax.set_ylabel('relative error')
  ax.set_ylim(figure_config[fi]['ylim'])
  if error_record:  # skip the legend (and matplotlib's warning) when nothing was drawn
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), fontsize = 10)
  # Raise the bottom edge of the axes to leave room for the legend underneath.
  plt.subplots_adjust(bottom=0.46, left = 0.2, right = 0.95, top = 0.95)
  plt.savefig('{}/ind_err_{}_err.pdf'.format(analysis_folder, fi))
  plt.close(fig)  # release only the figure created in this iteration


In [None]:
# Formal version: the legend is placed to the right of the axes.
# For every pattern group, draw the mean relative error of each matching result
# key against the number of demos and save the figure as a PDF.

# Demo counts that have an error pickle; also used as the x-axis ticks.
# Defined once up front so it exists even when no result key matches a pattern group.
demo_num_list = (1, 2, 3, 4, 5)

for fi, patterns in enumerate(patterns_list):
  fig, ax = plt.subplots(figsize=(7,2.5))
  # key -> ([demo counts], [mean relative errors]); x and y are recorded together so a
  # key that is absent from one pickle neither shifts nor truncates its curve.
  error_record = {}
  for demo_num in demo_num_list:
    err_path = "{}/err_{}_{}_{}_{}.pickle".format(analysis_folder, stamp, dataset_name, demo_num, demo_num + 1)
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(err_path, 'rb') as file:
      results = pickle.load(file)

    for key, value in results.items():
      if pattern_match(patterns, key): # key match the patterns
        demo_counts, err_means = error_record.setdefault(key, ([], []))
        demo_counts.append(demo_num)
        err_means.append(value["relative_error_mean"])

  for key, (demo_counts, err_means) in error_record.items():
    style = label_map[key]
    ax.plot(demo_counts, np.array(err_means), label=style['legend'],
                                              linestyle=style['linestyle'],
                                              marker=style['marker'], markersize=7)

  ax.set_xticks(list(demo_num_list))
  ax.set_xlabel('number of demos')
  ax.set_ylabel('relative error')
  ax.set_ylim(figure_config[fi]['ylim'])
  if error_record:  # skip the legend (and matplotlib's warning) when nothing was drawn
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize = 10)
  # Pull the right edge of the axes in to leave room for the legend beside them.
  plt.subplots_adjust(bottom=0.17, left = 0.12, right = 0.54, top = 0.95)
  plt.savefig('{}/ind_err_{}_err_right.pdf'.format(analysis_folder, fi))
  plt.close(fig)  # release only the figure created in this iteration
