In [2]:
# to access functions from root directory
import sys
sys.path.append('/data/ad181/RemoteDir/multilevel_ppo')

In [3]:
%matplotlib notebook
import os
import pickle
import numpy as np 
import matplotlib.pyplot as plt  
from tqdm import trange
from time import time
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

from utils.env_evaluate_functions import eval_actions

In [4]:
seed=1
img_dir = './env_images'
with open('env_data_v1/env_train_dict.pkl', 'rb') as input:
    env_train_dict = pickle.load(input)
os.makedirs(img_dir, exist_ok=True)

In [5]:
samples = 100
levels = [3,4]
titles = ['level 1', 'level 2']
actions_ = np.ones((env_train_dict[1].ressim_params.terminal_step,env_train_dict[1].ressim_params.action_space.shape[0]))

ts, rs = [], []
for seed in trange(samples):
    ts_seed, rs_seed = [], []
    states_seed = []
    for level in levels:
        env_train_dict[level].seed(seed)
        start_time = time()
        states, actions, rewards = eval_actions(env_train_dict[level], actions_)
        t = time() - start_time
        ts_seed.append(t)
        rs_seed.append(sum(rewards))
        states_seed.append(states)
    ts.append(ts_seed)
    rs.append(rs_seed)

  for j in range(len(p_1)-1):
100%|██████████| 100/100 [02:34<00:00,  1.55s/it]


In [49]:
titles = ['level 1\n37x111', 'level 2\n73x219']
extents = [(0,37,0,111),(0,73,0,219)]
zoom_windows = [(2,12,9,19),(4,24,18,38)]
fig, axs = plt.subplots(1,len(levels),figsize=(1.8*len(levels),3.8) )
for i,(ax,extent, z_w) in enumerate(zip(axs, extents, zoom_windows)):
    ax.axis('off')
    im = ax.imshow(states_seed[i][-1] , origin='lower', cmap='RdBu', vmin=0, vmax=1)
    axins = zoomed_inset_axes(ax, 2.8, loc='right') # zoom = 6
    axins.imshow(states_seed[i][-1], extent=extent, interpolation="nearest", origin="lower", cmap='RdBu')
    # sub region of the original image
    x1, x2, y1, y2 = 15, 20, 15, 20
    axins.set_xlim(z_w[0], z_w[1])
    axins.set_ylim(z_w[2], z_w[3])
    plt.xticks(visible=False)
    plt.yticks(visible=False)
    mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="w")
    axins.spines['bottom'].set_color('w')
    axins.spines['top'].set_color('w') 
    axins.spines['right'].set_color('w')
    axins.spines['left'].set_color('w')
    ax.set_title(titles[i])
cbar_ax = fig.add_axes([0.88, 0.12, 0.02, 0.76])
fig.colorbar(im, cax=cbar_ax, orientation="vertical") 
fig.savefig(img_dir+'/s_levels.pdf')

<IPython.core.display.Javascript object>

In [6]:
ts = np.array(ts[1:])
fig,axs = plt.subplots(1,1,figsize=(5,4))
avg_t = ts.mean(axis=0)
avg_t = avg_t/avg_t[-1]
axs.bar(titles, avg_t, color='gray', width=0.8 )
axs.grid('on')
for i,t in enumerate(avg_t[:-1]):
    axs.text(i,t+0.04,str(round(t,2)),
            horizontalalignment='center',
            verticalalignment='center')
axs.set_ylabel('normalised cost')
fig.savefig(img_dir+'/t_levels.pdf')

<IPython.core.display.Javascript object>

In [7]:
from copy import deepcopy 

rs=[]

for level in levels:
    rs_l = []
    env_ = deepcopy(env_train_dict[level])
    k_list = env_.ressim_params.k_list
    for k in k_list:
        env_.set_k(np.array([k]))
        states, actions, rewards = eval_actions(env_, actions_)
        rs_l.append(np.sum(rewards))
    rs.append(rs_l)

In [8]:
rs=np.array(rs).reshape(16,-1)
fig,axs = plt.subplots(1,1,figsize=(6,4))
order = np.argsort(rs[:,1])
# order = np.arange(16)

axs.plot( rs[order,0], 'o--')
# axs.plot( rs[order,1], '.--')
axs.plot( rs[order,1], '.--', color='gray')
axs.legend(titles)
axs.set_xlabel(r'training permeability sample index $i$ for $k_i$')
axs.set_ylabel('recovery factor')
axs.set_xticks(np.arange(16))
axs.set_xticklabels(order)
axs.grid('on', alpha=0.3)
# fig.savefig(data_dir+'/'+case_label+'_r_comparison.pdf')
fig.show()

print(rs[order,0])
print(rs[order,1])
# print(rs[order,2])

<IPython.core.display.Javascript object>

[0.67775434 0.68014562 0.69914158 0.69791898 0.68939185 0.69495011
 0.68367958 0.68042646 0.69709626 0.70398557 0.69511492 0.71186934
 0.69505695 0.69317925 0.71499144 0.7257098 ]
[0.64242999 0.65225819 0.67645373 0.67827979 0.67999412 0.68171993
 0.69351402 0.69451886 0.7001145  0.70159269 0.7040443  0.70501348
 0.70845984 0.71026568 0.71388443 0.72344193]
