In [None]:
import sys
import os
# Add the root directory of your project to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(''), '..')))


# External packages
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import torch
from torch.utils.data import DataLoader
from functools import partial
import IPython.display
from importlib import reload


# Internal packages
from src.data.processing import CMAPSS_dataset
from src.models.forecast_model import TruncNormNetwork
from src.models.rl_agent import DQNAgent
from src.models.rl_network import Network as RLNet
from src.training.loss_functions import CRPS_truncnorm_int
from src.utils.environment import ForecastEnv
from src.utils.replaybuffer import ReplayBuffer
from src.utils.epsgreed_funcs import eps_decay
import src.tests.rl_tests
from src.training.loss_functions import _lognorm_cdf
from src.tests.rl_tests import rl_test_engine, rl_test_grid
from src.tests.forecast_tests import forecast_test_testset, forecast_test_valengine


# Import scripts
from scripts.process_data import process_data
from scripts.run_forecast_training import train_forecast, train_forecast_bm
from scripts.run_rl_training import train_rl

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Data set parameters

In [None]:
dataset = 4 # 1-based index i of the data subset FD00i
dataset_name = f'FD00{dataset}'
window_size = [31, 21, 38, 19] # sliding-window size per subset; read below as window_size[dataset-1]
upper_RUL = 128 # max value for piecewise RUL

In [None]:
# Process the data sets according to the parameters above.
# Note: the complete window_size list is passed, not only the entry for `dataset`.
process_data(window_size, upper_RUL)

Forecasting and Scheduling model parameters

In [None]:
run_id = '-MainRun-'+dataset_name # Run ID for save paths

# Fractions used to split the engines into forecast-training, RL-training and held-out test groups.
forecast_data_fraction = 0.7
rl_data_fraction = 0.2
test_data_fraction = 0.1

# Keyword arguments unpacked into train_forecast_bm.
Forecast_parameters = {
    'num_epochs': 50,
    'batch_size': 128,
    'learning_rate': 1e-3,
    'hidden_dim': 64,
    'dropout': 0.2,
    'num_lstms': 1,
    'std_max': 0.5,
    'val_fraction': 0.03, # share of the forecast engines split off for validation
    'lambd': 0.,
}

# Keyword arguments unpacked into train_rl and load_rl.
RL_parameters = {
    'n_actions': 3,
    'planning_window': 10,
    'n_obs': (10,), # TODO: should this be tied to planning_window?
    'num_frames': 40000, # previously tried: 48000, 46_000, 24000, 42_000
    'memory_size': 5000, # previously tried: 2000
    'batch_size': 64,
    'target_update': 200, # previously tried: 150
    'epsilon_func': eps_decay('exp', decay=1/7000), # previously tried: ('exp', decay=1/6000); ('lin', decay=1/6000, bumprate=12000); ('exp', decay=1/2000, bumprate=12000)
    'vmin': -80,
    'vmax': 1,
    'atom_size': 20,
    #'cvar_alpha': 1,
    'window_size': window_size[dataset-1] # -1 if bm otherwise window_size
}


In [None]:
# Set random seeds for data split (torch drives random_split; numpy is seeded alongside it)
RANDOMSEED = 19960417 # previously tried: 1996417, 1741996
torch.manual_seed(RANDOMSEED)
np.random.seed(RANDOMSEED)

In [None]:
# Construct dataloaders for training
traindata = torch.load(f'../data/processed/trainset{dataset_name}_w{window_size[dataset-1]}_M{upper_RUL}.pt')

# Split on engine ids (not on individual windows) so that every window of an
# engine ends up in exactly one of the forecast / RL / test groups.
use_engines, rl_engines, test_engines = torch.utils.data.random_split(
    traindata.engineid.unique(),
    [forecast_data_fraction, rl_data_fraction, test_data_fraction])

# Persist the engine ids reserved for RL training.
torch.save(rl_engines.dataset[rl_engines.indices], f'../data/used/rl_engines_{dataset}{run_id}.pt')

# Second split: carve the validation engines out of the forecast engines.
train_engines, val_engines = torch.utils.data.random_split(
    use_engines,
    [1-Forecast_parameters['val_fraction'], Forecast_parameters['val_fraction']])


def _windows_of(engine_split):
    """Return the Subset of `traindata` whose engine id belongs to `engine_split`."""
    selected_ids = engine_split.dataset[engine_split.indices]
    # np.isin replaces the deprecated np.in1d; ravel() keeps in1d's flattening of the first argument.
    mask = np.isin(np.asarray(traindata.engineid).ravel(), selected_ids)
    return torch.utils.data.Subset(traindata, np.where(mask)[0])


train_set = _windows_of(train_engines)
val_set = _windows_of(val_engines)

forecast_dataloaders = {
    'train': DataLoader(train_set, batch_size=Forecast_parameters['batch_size'], shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=Forecast_parameters['batch_size'], shuffle=True, num_workers=0)
}

In [None]:
val_engines.indices

In [None]:
len(test_engines.indices)

In [None]:
# Reset random seeds so training the first forecast model (no crpsweight argument) starts from a known RNG state
torch.manual_seed(RANDOMSEED)
np.random.seed(RANDOMSEED)
fmodel, flosses, foptim, _ = train_forecast_bm(dataset_name, run_id,**Forecast_parameters, 
                                                train_fraction=forecast_data_fraction,
                                                upper_RUL=upper_RUL,
                                                window_size=window_size[dataset-1], 
                                                dataloaders=forecast_dataloaders)

In [None]:
# Reset random seeds again so the second forecast model starts from the same RNG state as the first.
# The only difference from the previous call is crpsweight=50.
torch.manual_seed(RANDOMSEED)
np.random.seed(RANDOMSEED)
wfmodel, wflosses, wfoptim, _ = train_forecast_bm(dataset_name, run_id,**Forecast_parameters, 
                                                    train_fraction=forecast_data_fraction,
                                                    upper_RUL=upper_RUL,
                                                    window_size=window_size[dataset-1], 
                                                    dataloaders=forecast_dataloaders,
                                                    crpsweight=50)
    

In [None]:
# Training and validation loss per epoch, one panel per forecast model.
fig, (ax, wax) = plt.subplots(1,2, figsize=(10,4), layout='tight')
loss_epochs = np.arange(Forecast_parameters['num_epochs'])+1
loss_panels = (
    (ax, flosses, 'Risk-Neutral Forecast'),
    (wax, wflosses, 'Risk-Averse Forecast'),
)
for loss_ax, loss_history, loss_title in loss_panels:
    # Column 0 of each loss record is the value that gets plotted.
    loss_ax.plot(loss_epochs, np.asarray(loss_history['train'])[:,0], zorder=2, label='Training')
    loss_ax.plot(loss_epochs, np.asarray(loss_history['val'])[:,0], zorder=3, label='Validation')
    loss_ax.grid(zorder=1)
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_title(loss_title)
    loss_ax.legend()

plt.show()

Save the forecast models

In [None]:
# Create the model directory for this run (including missing parents); no-op when it already exists.
os.makedirs(f'../results/models/models{run_id}', exist_ok=True)

# Each checkpoint stores the whole model object together with the dataloaders used to train it.
torch.save({'model': fmodel, 'dataloader': forecast_dataloaders}, f'../results/models/models{run_id}/forecastmodel.pt')
torch.save({'model': wfmodel, 'dataloader': forecast_dataloaders}, f'../results/models/models{run_id}/w_forecastmodel.pt')

In [None]:
# Load them for evaluation
# forecastmodel.pt carries both the model and the dataloaders, so read it only once.
forecast_checkpoint = torch.load(f'../results/models/models{run_id}/forecastmodel.pt')
fmodel = forecast_checkpoint['model']
forecast_dataloaders = forecast_checkpoint['dataloader']
wfmodel = torch.load(f'../results/models/models{run_id}/w_forecastmodel.pt')['model']

# Switch both models to evaluation mode.
fmodel.eval()
wfmodel.eval()

In [None]:
# Collect the distinct values found at position 3 of every validation sample (used below as engine ids).
val_samples = forecast_dataloaders['val'].dataset
forecast_val_engines_ids = np.unique([val_samples[i][3] for i in range(len(val_samples))])
forecast_val_engines_ids

In [None]:
# Run the per-engine validation check of the first forecast model on every validation engine.
# `.dataset.dataset` unwraps the Subset to reach the full underlying data set.
for val_engine in forecast_val_engines_ids:
    forecast_test_valengine(fmodel, val_engine, forecast_dataloaders['val'].dataset.dataset)

In [None]:
# Same per-engine validation check for the crpsweight-trained forecast model.
for val_engine in forecast_val_engines_ids:
    forecast_test_valengine(wfmodel, val_engine, forecast_dataloaders['val'].dataset.dataset)

Test models using dedicated test set

In [None]:
# Load the processed test set that matches the selected subset, window size and RUL cap.
test_dataset = torch.load(f'../data/processed/testsetFD00{dataset}_w{window_size[dataset-1]}_M{upper_RUL}.pt')

In [None]:
# Output folders for the test-set figures of both forecast models.
# makedirs also creates missing parent folders and is a no-op when the folder exists.
os.makedirs(f'../results/figures/testimg{run_id}-normal', exist_ok=True)
os.makedirs(f'../results/figures/testimg{run_id}-weighted', exist_ok=True)
f'../results/figures/testimg{run_id}-normal'

In [None]:
# Evaluate the first forecast model on the test set; figures are written to the "-normal" folder.
forecast_testmetrics, forecast_testforecasts = forecast_test_testset(fmodel, test_dataset, to_plot=1, savepath=f'../results/figures/testimg{run_id}-normal')

In [None]:
# Evaluate the crpsweight-trained forecast model on the test set; figures are written to the "-weighted" folder.
wforecast_testmetrics, wforecast_testforecasts = forecast_test_testset(wfmodel, test_dataset, to_plot=1, savepath=f'../results/figures/testimg{run_id}-weighted')

In [None]:
# Forecast mean with 95% interval for every test engine, both models side by side.
# Index 0 of a forecast triple is plotted as the true RUL; indices 1 and 2 are used as
# the log-scale location and shape of a lognormal distribution.
fig, ax = plt.subplots(layout='tight', figsize=(20,6))
ci_series = (
    (forecast_testforecasts, -.2, 'tab:blue', 'Risk Neutral 95% CI'),
    (wforecast_testforecasts, .2, 'tab:orange', 'Risk Averse 95% CI'),
)
for ci_forecasts, ci_shift, ci_color, ci_label in ci_series:
    lows, ups = scipy.stats.lognorm.interval(0.95, s=ci_forecasts[2], scale=np.exp(ci_forecasts[1]))
    # Lognormal mean: exp(mu + sigma^2 / 2).
    ci_mean = np.exp(np.array(ci_forecasts[1]) + 0.5*np.array(ci_forecasts[2])**2)
    ci_whiskers = np.array([np.abs(ci_mean-lows), np.abs(ups-ci_mean)])
    ax.errorbar(np.arange(1,len(ci_forecasts[0])+1)+ci_shift, ci_mean, yerr=ci_whiskers,
                capsize=2, fmt='o', color=ci_color, zorder=2, label=ci_label)
ax.scatter(np.arange(1,len(wforecast_testforecasts[0])+1), wforecast_testforecasts[0],
           color='tab:green', marker='X', zorder=3, label='True RUL', edgecolor='k',s=60)
ax.grid(zorder=-1)
ax.set_xlim(0,len(forecast_testforecasts[0])+1)
ax.set_ylabel('RUL')
ax.set_xlabel('Test engine ID')
ax.legend()
plt.savefig(f'../results/figures/testimg{run_id}-normal/alltestengines.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Per test engine: forecast probability mass at or below the true RUL versus above it.
fig, (ax, wax) = plt.subplots(2, sharex=True, figsize=(10, 6))

# Raw strings keep the LaTeX backslashes intact without invalid escape sequences
# such as "\m" and "\l"; the real newline is appended separately.
under_label = r'$\mathrm{Pr}(F(X)\leq \mathrm{RUL})$'
over_label = r'$\mathrm{Pr}(F(X) > \mathrm{RUL})$'

# Lognormal CDF evaluated at the true RUL: Phi((ln RUL - mu) / sigma).
cdfvals = scipy.stats.norm.cdf((np.log(np.array(forecast_testforecasts[0])[:,0])-np.array(forecast_testforecasts[1]))/np.array(forecast_testforecasts[2]))
ax.bar(np.arange(1,len(forecast_testforecasts[0])+1), cdfvals, label=under_label + '\n(underestimation)')
ax.bar(np.arange(1,len(forecast_testforecasts[0])+1), 1-cdfvals, bottom=cdfvals, color='tab:orange', label=over_label + '\n(overestimation)')
ax.set_ylabel('Probability')
ax.set_title('Risk-Neutral Forecast')

wcdfvals = scipy.stats.norm.cdf((np.log(np.array(wforecast_testforecasts[0])[:,0])-np.array(wforecast_testforecasts[1]))/np.array(wforecast_testforecasts[2]))
wax.bar(np.arange(1,len(wforecast_testforecasts[0])+1), wcdfvals, label=under_label)
wax.bar(np.arange(1,len(wforecast_testforecasts[0])+1), 1-wcdfvals, bottom=wcdfvals, color='tab:orange', label=over_label)
wax.set_ylabel('Probability')
wax.set_xlabel('Test engine ID')
wax.set_title('Risk-Averse Forecast')

# One shared legend, anchored to the right of the upper panel.
ax.legend(ncol=1, loc='upper left', bbox_to_anchor=(1, .1))
plt.show()

# Average probabilities over all test engines, in percent.
print('Risk\t|\tunder\t|\tover')
print(f'Neutral\t|\t{cdfvals.sum()/len(cdfvals)*100:.2f}\t|\t{np.sum(1-cdfvals)/len(cdfvals)*100:.2f}')
print(f'Averse\t|\t{wcdfvals.sum()/len(wcdfvals)*100:.2f}\t|\t{np.sum(1-wcdfvals)/len(wcdfvals)*100:.2f}')

In [None]:
# One figure per forecast model, one row per metric, one dot per test engine.
for metric_table in (forecast_testmetrics, wforecast_testmetrics):
    fig, ax = plt.subplots(len(metric_table.keys()), sharex=True, layout='constrained', figsize=(6,10))
    for i, key in enumerate(metric_table.keys()):
        engine_numbers = range(1,len(metric_table[key])+1)
        ax[i].scatter(engine_numbers, metric_table[key], color='k', marker='.', zorder=2)
        ax[i].set_ylabel(key)
        ax[i].grid(zorder=-1)
    ax[-1].set_xlabel('')
    plt.show()

In [None]:
# Summary table: metrics whose name contains "SF" are summed over the engines, all others averaged.
print('\t\tRisk-Neutral\tRisk-Averse')
for key in forecast_testmetrics.keys():
    aggregate = np.sum if "SF" in key else np.mean
    # Keys containing "median" get one tab fewer after the name.
    name_pad = '\t' if "median" in key else '\t\t'
    neutral_cell = f'{key}:{name_pad}{aggregate(forecast_testmetrics[key]):.2f}'
    averse_cell = f'\t\t{aggregate(wforecast_testmetrics[key]):.2f}'
    print(neutral_cell + averse_cell)

In [None]:
import src.tests.forecast_tests
reload(src.tests.forecast_tests)

In [None]:
from src.tests.forecast_tests import calibration_tests

In [None]:
# Calibration diagnostics for the first forecast model; the outputs feed the three plots below.
quantiles, empcdf, xs, Fbar, Gbar, PIT_vals, PIT_ruls = calibration_tests(fmodel, test_dataset, device)

In [None]:
# Same diagnostics for the crpsweight-trained model. Note that `quantiles` and `xs` are overwritten here.
quantiles, wempcdf, xs, wFbar, wGbar, wPIT_vals, wPIT_ruls = calibration_tests(wfmodel, test_dataset, device)

In [None]:
# Empirical CDF against the estimated CDF level for both forecast models;
# the dotted diagonal is the identity line.
fig, ax = plt.subplots()
ax.plot(quantiles, empcdf, marker='o', linestyle='--', label='Risk-Neutral')
ax.plot(quantiles, wempcdf, marker='o', linestyle='--', label='Risk-Averse')
ax.plot(quantiles, quantiles, color='k', linestyle=':', label='Identity')
ax.grid()
ax.set_aspect('equal', 'box')
ax.set_xlabel('Estimated CDF')
ax.set_ylabel('Empirical CDF')
# Without a legend the two model curves cannot be told apart in the saved figure.
ax.legend()
plt.savefig(f'../results/figures/testimg{run_id}-normal/margcal.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Difference between the two averaged curves returned by calibration_tests, per forecast model.
fig, ax = plt.subplots()
ax.plot(xs, Fbar-Gbar, label='Risk-Neutral')
ax.plot(xs, wFbar-wGbar, label='Risk-Averse')
ax.grid()
ax.set_xlabel('RUL')
ax.set_ylabel(r'$\overline{F}(x)-\overline{G}(x)$')
ax.set_box_aspect(1)
# Without a legend the two model curves cannot be told apart in the saved figure.
ax.legend()
plt.savefig(f'../results/figures/testimg{run_id}-normal/probcal.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# PIT histograms of both forecast models; the black outline is the unit box (uniform density on [0, 1]).
fig, (ax, axw) = plt.subplots(2, sharex=True)
pit_panels = (
    (ax, PIT_vals, {}),
    (axw, wPIT_vals, {'color': 'tab:orange'}),
)
for pit_ax, pit_sample, pit_style in pit_panels:
    pit_ax.hist(pit_sample, 10, density=True, zorder=2, edgecolor='black', alpha=0.8, **pit_style)
    # The title reports the PIT variance and its reciprocal.
    pit_ax.set_title(f'Var(PIT)$={np.var(pit_sample):.5f}\\approx\\frac{{1}}{{{1/np.var(pit_sample):.2f}}}$')
    pit_ax.plot([0,0,1,1],[0,1,1,0], color='black', zorder=3, linewidth=3)
    pit_ax.grid(zorder=1)
    pit_ax.set_ylabel('Density')
axw.set_xlabel('PIT')
plt.savefig(f'../results/figures/testimg{run_id}-normal/PITdiagram.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
%matplotlib inline

Distributional Reinforcement Learning for Maintenance Scheduling

In [None]:
# Train an agent on top of the first forecast model, with no cvar_alpha override.
rl_agent = train_rl(fmodel, traindata, rl_engines, val_engines,
                    seed=RANDOMSEED, # TODO: set up seeding @ top of notebook
                    **RL_parameters, plot_interval=200, cf=8, c0=4, c1=.35)

In [None]:
# Train an agent on top of the crpsweight-trained forecast model, with no cvar_alpha override.
w_rl_agent = train_rl(wfmodel, traindata, rl_engines, val_engines,
                    seed=RANDOMSEED, # TODO: set up seeding @ top of notebook
                    **RL_parameters, plot_interval=200, cf=8, c0=4, c1=.35)

In [None]:
# Train an agent on top of the first forecast model with cvar_alpha=0.8.
rl_agent_cvar = train_rl(fmodel, traindata, rl_engines, val_engines,
                    seed=RANDOMSEED, # TODO: set up seeding @ top of notebook
                    **RL_parameters, cvar_alpha=0.8, plot_interval=200, cf=8, c0=4, c1=.35)

In [None]:
# Train an agent on top of the crpsweight-trained forecast model with cvar_alpha=0.8.
w_rl_agent_cvar = train_rl(wfmodel, traindata, rl_engines, val_engines,
                    seed=RANDOMSEED, # TODO: set up seeding @ top of notebook
                    **RL_parameters, cvar_alpha=0.8, plot_interval=200, cf=8, c0=4, c1=.35)

In [None]:
# Persist the four trained agents, each under the tag of its forecast / decision-rule combination.
agents_to_save = (
    (rl_agent, 'neutral-mean'),
    (rl_agent_cvar, 'neutral-cvar'),
    (w_rl_agent, 'averse-mean'),
    (w_rl_agent_cvar, 'averse-cvar'),
)
for trained_agent, agent_tag in agents_to_save:
    trained_agent.save(f'../results/models/models{run_id}/rlmodel-{agent_tag}.pt')

In [None]:
def load_rl(loadpath, forecastmodel, dataset, engines, test_engines, seed, n_actions, planning_window, n_obs, num_frames, memory_size, batch_size, target_update, epsilon_func, vmin, vmax, atom_size, window_size, cvar_alpha=1, plot_interval=200, cf=4, c1=0.1, c0=2, cn=0.1):
    """Rebuild a DQNAgent around a fresh ForecastEnv and load saved weights from `loadpath`.

    The argument list accepts the same keyword set that the notebook passes to
    `train_rl`, so `RL_parameters` can be unpacked into both. `num_frames` and
    `plot_interval` are accepted but not used here.

    Returns:
        The DQNAgent constructed with `load=loadpath`.
    """
    forecast_env = ForecastEnv(
        n_actions, planning_window, n_obs, forecastmodel, dataset,
        engines, test_engines, window_size, seed, cf, c1, c0, cn,
    )
    return DQNAgent(
        forecast_env, memory_size, batch_size, target_update, epsilon_func, seed,
        v_min=vmin, v_max=vmax, atom_size=atom_size, cvar_alpha=cvar_alpha, load=loadpath,
    )


In [None]:
# Reload the four saved agents. Every call shares the data, seed, RL parameters and cost
# settings; only the checkpoint path, the forecast model and cvar_alpha differ.
load_saved_agent = partial(
    load_rl,
    dataset=traindata, engines=rl_engines, test_engines=val_engines,
    seed=RANDOMSEED, # TODO: set up seeding @ top of notebook
    **RL_parameters, plot_interval=200, cf=8, c0=4, c1=.35,
)
rl_agent = load_saved_agent(f'../results/models/models{run_id}/rlmodel-neutral-mean.pt', fmodel)
rl_agent_cvar = load_saved_agent(f'../results/models/models{run_id}/rlmodel-neutral-cvar.pt', fmodel, cvar_alpha=0.8)
w_rl_agent = load_saved_agent(f'../results/models/models{run_id}/rlmodel-averse-mean.pt', wfmodel)
w_rl_agent_cvar = load_saved_agent(f'../results/models/models{run_id}/rlmodel-averse-cvar.pt', wfmodel, cvar_alpha=0.8)

In [None]:
# Flag all four agents as being in test mode.
for test_mode_agent in (rl_agent, w_rl_agent, rl_agent_cvar, w_rl_agent_cvar):
    test_mode_agent.is_test = True

In [None]:
# Output folder for the RL figures; also creates missing parents, no-op when it already exists.
os.makedirs(f'../results/figures/RL{run_id}', exist_ok=True)

In [None]:
import matplotlib as mpl
def plot_crossover(engineid, ruls, cvars, Pfail, actions, save=True, dset='test', model_type='neutral-mean'):
    """Plot the per-action values of one engine against its failure probability.

    Args:
        engineid: Identifier shown in the title and used in the output file name.
        ruls: True RUL per step, drawn as a thick grey line on a secondary y-axis.
        cvars: 2-D array, one row per step and one column per action.
        Pfail: Array of failure probabilities per step (x-axis); must support
            integer-array indexing.
        actions: Chosen action per step; 0 is treated as "do nothing".
        save: If True, write the figure to disk and close it; otherwise show it.
        dset: Data-set tag used in the output directory name.
        model_type: Model tag used in the output directory name.

    Returns:
        The smallest `Pfail` value at which a non-zero action was chosen, or
        None when every action is 0.
    """
    # One colour per action: black for "do nothing", then the tab10 palette.
    cmap = mpl.colors.ListedColormap(['k']+list(plt.get_cmap('tab10').colors))
    fig, ax = plt.subplots(layout='tight')
    ax2 = ax.twinx()
    ax2.plot(Pfail, ruls, color='grey', alpha=0.5, linewidth=10)
    ax2.set_ylabel('True RUL')

    for i in range(cvars.shape[1]):
        ax.plot(Pfail, cvars[:,i], linestyle=':', marker='o', markersize=4, color=cmap(i), label='do nothing' if i==0 else f'replace {i:02d}', zorder=2 if i!=0 else 3)
    ax.grid(zorder=1)

    ax.set_ylabel('CVaR of action distribution')
    ax.set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    # NOTE(review): the colour bar always lists 11 entries ("do nothing" plus
    # replace 01..10), independent of cvars.shape[1] — confirm this is intended.
    cbar = fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap),
                        ax=ax2, ticks=np.arange(0,1,1/11)+1/22, pad=0.12)
    cbar.set_ticklabels(['do nothing']+[f'replace {i:02d}' for i in range(1,11)])

    ax.set_title(f'Test engine: {engineid}')

    # Mark the lowest failure probability at which the agent chose to replace.
    blinex = None
    replaces = np.where(np.asarray(actions)!=0)[0]
    if len(replaces) != 0:
        blinex = np.min(Pfail[replaces])
        blineytext = ax.get_ylim()[1]*.8
        ax.axvline(blinex, color='k', zorder=1)
        ax.annotate('do nothing', [blinex, blineytext], ha='right', textcoords='offset points', xytext=(-10,0), zorder=10, weight='bold')
        ax.annotate('replace', [blinex, blineytext], ha='left', textcoords='offset points', xytext=(10,0), zorder=10, weight='bold')
        arrowlength = (ax.get_xlim()[1]-ax.get_xlim()[0])*.1
        ax.arrow(blinex, blineytext*.9, -arrowlength, 0, head_width=.4, head_length=3*1.5*0.001, color='k', zorder=10)
        ax.arrow(blinex, blineytext*.9, arrowlength, 0, head_width=.4, head_length=3*1.5*0.001, color='k', zorder=10)

    if save:
        outdir = f'../results/figures/RL{dset}img{run_id}-{model_type}'
        # makedirs also creates missing parents and tolerates an existing folder.
        os.makedirs(outdir, exist_ok=True)
        # Save through the figure handle instead of relying on the implicit current figure.
        fig.savefig(f'{outdir}/RLcrossover{engineid}.png', facecolor='white', bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

    return blinex

In [None]:
rl_agent.env.term_phase = False
w_rl_agent.env.term_phase = False
rl_agent_cvar.env.term_phase = False
w_rl_agent_cvar.env.term_phase = False

# Threshold for the failure probability. The x-axis of plot_crossover is labelled
# Pr(RUL <= 10) and the metrics cell further down uses the planning window, so the
# planning window is used here too (previously n_actions-1, i.e. 2 with the current settings).
pfail_threshold = torch.tensor(RL_parameters['planning_window'])


def _pfail_from_states(state_seq):
    """Lognormal CDF at `pfail_threshold`, using entries 0 and 1 of the last slot along axis 1 of each state."""
    return _lognorm_cdf(pfail_threshold,
                        torch.from_numpy(state_seq[:,-1,0]),
                        torch.from_numpy(state_seq[:,-1,1])).detach().cpu().numpy()


for engineid in range(len(rl_agent.env.test_engines)):
    states, statesCDF, ruls, rewards, actions, dists, cvars = rl_agent.test(engineid)
    w_states, statesCDF, w_ruls, w_rewards, w_actions, w_dists, w_cvars = w_rl_agent.test(engineid)
    states_cv, statesCDF_cv, ruls_cv, rewards_cv, actions_cv, dists_cv, cvars_cv = rl_agent_cvar.test(engineid)
    w_states_cv, statesCDF_cv, w_ruls_cv, w_rewards_cv, w_actions_cv, w_dists_cv, w_cvars_cv = w_rl_agent_cvar.test(engineid)

    Pfail = _pfail_from_states(states)
    w_Pfail = _pfail_from_states(w_states)
    Pfail_cv = _pfail_from_states(states_cv)
    w_Pfail_cv = _pfail_from_states(w_states_cv)

    # Show (do not save) one crossover figure per agent.
    plot_crossover(engineid, ruls, np.asarray(cvars), Pfail, actions, save=False)
    plot_crossover(engineid, w_ruls, np.asarray(w_cvars), w_Pfail, w_actions, save=False)
    plot_crossover(engineid, ruls_cv, np.asarray(cvars_cv), Pfail_cv, actions_cv, save=False)
    plot_crossover(engineid, w_ruls_cv, np.asarray(w_cvars_cv), w_Pfail_cv, w_actions_cv, save=False)
    

In [None]:
# Engine ids that the first random_split assigned to the held-out test group.
test_ids = test_engines.dataset[test_engines.indices]
test_ids

In [None]:
def run_test_engine(model: DQNAgent, idx:int, term: bool = True):
    """Roll out `model` greedily on one engine and collect per-step diagnostics.

    The agent's environment is re-initialised by hand for the engine with id
    `idx`: its data is fetched, pushed through the environment's forecast model,
    and the transformed output is stored as the environment's state sequence.

    Args:
        model: Agent to evaluate; its `is_test` flag and environment are modified in place.
        idx: Engine id passed to `model.env.dataset.get_unit_by_id`.
        term: Accepted but not used.

    Returns:
        Tuple `(final_rul, actions, rewards, dists, cvars, ruls, states, pre_states)`:
        `final_rul` is `cur_rul` at the environment time right after the first
        non-zero action (0 if no such action occurred); `actions`, `rewards`,
        `dists` and `cvars` are per-step arrays; `ruls` is `cur_rul` up to the
        final environment time; `states` is the transformed state sequence and
        `pre_states` the raw forecast-model output.
    """
    model.is_test = True
    model.env.term_phase = False

    cur_data, model.env.cur_rul, _, _ = model.env.dataset.get_unit_by_id(idx)
    cur_data = cur_data.float().to(model.env.device)
    with torch.no_grad():
        pre_states = torch.cat(model.env.model(cur_data[:,model.env.dataoffset:]),dim=-1)
        model.env.states = model.env._transform_states(pre_states)
    model.env.t = 0
    model.env.terminal = 0

    state = model.env.states[0].to(model.device)
    actions, dists, cvars, rewards = [], [], [], []
    done = False
    final_rul = 0
    # Explicit flag instead of testing final_rul == 0, so a recorded RUL of exactly 0
    # is not overwritten by a later non-zero action.
    replacement_recorded = False
    with torch.no_grad():
        while not done:
            # Named q_values to avoid shadowing the builtin eval().
            q_values = model.dqn(state)
            action = q_values.argmax(1).cpu().numpy()[0]
            actions.append(action)
            cvars.append(q_values.detach().cpu().numpy()[0])
            dists.append(model.dqn.dist(state).cpu().numpy()[0])

            state, reward, done = model.step(action)
            rewards.append(reward)
            if state is not None:
                state = torch.FloatTensor(state).to(model.device)
            if action != 0 and not replacement_recorded:
                final_rul = model.env.cur_rul[model.env.t].numpy()
                replacement_recorded = True

    actions = np.asarray(actions)
    dists = np.asarray(dists)
    cvars = np.asarray(cvars)
    rewards = np.asarray(rewards)

    try:
        ruls = model.env.cur_rul[:model.env.t].numpy()
    except IndexError:
        # Kept from the original as a fallback to the full RUL sequence.
        ruls = model.env.cur_rul.numpy()

    return final_rul, actions, rewards, dists, cvars, ruls, model.env.states, pre_states

In [None]:
# Map each forecast / decision-rule combination to its agent.
model_dict = {'neutral-mean': rl_agent,
              'neutral-cvar': rl_agent_cvar,
              'averse-mean': w_rl_agent,
              'averse-cvar': w_rl_agent_cvar}
model_labels = model_dict.keys() # ['neutral-mean', 'neutral-cvar', 'averse-mean', 'averse-cvar']
# One list per metric and model; each list receives one entry per test engine.
metrics = {label: {'final ruls': [],
                   'actions': [],
                   'rewards': [],
                   'dists': [],
                   'cvars': [],
                   'crossp': [],
                   'pfails': [],
                   'termidx': [],
                   'states': []} 
                   for label in model_labels}

# Reference reward per test engine; filled only while the first label is processed.
opt_rewards = []
for label in model_labels:
    for idx in test_engines.dataset[test_engines.indices]:
        # The last two return values are (model.env.states, pre_states), so `states` is the raw forecast output.
        final_rul, actions, rewards, dists, cvars, ruls, statesCDF, states = run_test_engine(model_dict[label], idx)
        if label == list(model_labels)[0]:
            # NOTE(review): 0.1, 4 and 0.35 look like the cn, c0 and c1 cost settings used for the agents — confirm.
            opt_rewards.append((ruls[0]-RL_parameters['planning_window'])*0.1 - (4-RL_parameters['planning_window']*0.35))
        try:
            # 1-based position of the first non-zero action.
            term_idx = np.where(actions != 0)[0][0]+1
        except IndexError:
            # No non-zero action was ever taken; slicing with None below keeps the full sequence.
            term_idx = None
        # Lognormal CDF at the planning window, from entries 0 and 1 of the last slot along axis 1 of each state.
        Pfail = _lognorm_cdf(torch.tensor(RL_parameters['planning_window']), states[:,-1,0], states[:,-1,1]).detach().cpu().numpy()
        # Saves the crossover figure and returns the failure probability at the first non-zero action (or None).
        crossp = plot_crossover(idx, ruls[:term_idx], cvars[:term_idx], Pfail[:term_idx], actions[:term_idx], dset='test', model_type=label, save=True)

        metrics[label]['final ruls'].append(final_rul)
        metrics[label]['actions'].append(actions)
        metrics[label]['rewards'].append(rewards)
        metrics[label]['dists'].append(dists)
        metrics[label]['cvars'].append(cvars)
        metrics[label]['crossp'].append(crossp)
        metrics[label]['pfails'].append(Pfail)
        metrics[label]['termidx'].append(term_idx)
        metrics[label]['states'].append(states[:,-1].cpu().numpy())

    

In [None]:
def _as_metric_array(values):
    """Convert a per-engine list to an array, tolerating entries of different length.

    Engines have trajectories of different length, so lists such as 'actions'
    are ragged. NumPy >= 1.24 raises ValueError for those instead of building an
    object array, so the object array is built explicitly in that case.
    """
    try:
        return np.asarray(values)
    except ValueError:
        ragged = np.empty(len(values), dtype=object)
        # Element-wise assignment avoids any attempt to broadcast the entries.
        for position, entry in enumerate(values):
            ragged[position] = entry
        return ragged


for label in model_labels:
    for key in metrics[label].keys():
        metrics[label][key] = _as_metric_array(metrics[label][key])

In [None]:
import matplotlib.lines
import matplotlib.patheffects
import matplotlib.colors
# Legend proxies: black line with a green halo for the forecast band, orange line for the true RUL.
handles = [matplotlib.lines.Line2D([0],[0], color='k', label='Forecast 95% CI',path_effects=[matplotlib.patheffects.Stroke(linewidth=8, foreground=matplotlib.colors.to_rgba('tab:green',.5)),matplotlib.patheffects.Normal()]),
           matplotlib.lines.Line2D([0],[0], color='tab:orange', label='True RUL', linewidth=3)]
# Bar colour per action index.
cs = np.array(['k', 'tab:red', 'tab:blue'])
# Confidence levels of the nested forecast intervals, ascending; identical for every engine.
alphas = np.arange(0.95, 0.05-.01, -.01)[::-1]
# Fill colour of the forecast band.
pc = 'tab:green'

for eng in range(len(test_engines)):
    fig, ax = plt.subplots(2, 1, layout='tight', sharex=True, figsize=(8,8))

    ax[0].xaxis.set_tick_params(which='both', labelbottom=True)
    ax[0].invert_xaxis()
    ax[0].set_ylabel('RUL Prediction')
    ax[1].set_ylabel('RUL Prediction')
    ax[1].set_xlabel('True RUL')
    for i, label in enumerate(model_labels):
        # Labels come in pairs that share a forecast model: the even label draws the
        # forecast band, and both labels of the pair add their action bar to the same panel.
        if i % 2 == 0:
            axi = i//2
            ruls = np.arange(metrics[label]['states'][eng].shape[0])[::-1]
            ax[axi].set_yticks(np.arange(0,70,10))
            ax[axi].axhline(0, color='k', linewidth=1)
            ax[axi].plot(ruls, ruls, color='tab:orange', linestyle='--', zorder=4, linewidth=3)
            # Piecewise RUL target, capped at the configured upper_RUL.
            ax[axi].plot(ruls, np.minimum(upper_RUL, ruls), color='tab:orange', linestyle='-', zorder=5, linewidth=3)

            lows, ups = scipy.stats.lognorm.interval(alphas, s=metrics[label]['states'][eng][:,1,np.newaxis], scale=np.exp(metrics[label]['states'][eng][:,0,np.newaxis]))
            # Lognormal mean: exp(mu + sigma^2 / 2).
            means = np.exp(metrics[label]['states'][eng][:,0] + metrics[label]['states'][eng][:,1]**2 /2)
            ax[axi].plot(ruls, means, color='k', lw=1, zorder=3)
            ax[axi].fill_between(ruls, means, lows[:,0], color=pc ,alpha=1-(alphas[0]*(.85-.1)+.1), linewidth=0, zorder=2)
            ax[axi].fill_between(ruls, means, ups[:,0], color=pc ,alpha=1-(alphas[0]*(.85-.1)+.1), linewidth=0, zorder=2)
            # Stack the bands between consecutive levels; wider intervals are drawn more transparent.
            for j in range(alphas.shape[0]-1):
                ax[axi].fill_between(ruls, lows[:,j], lows[:,j+1], color=pc ,alpha=1-(alphas[j+1]*(.85-.1)+.1), linewidth=0, zorder=2)
                ax[axi].fill_between(ruls, ups[:,j], ups[:,j+1], color=pc ,alpha=1-(alphas[j+1]*(.85-.1)+.1), linewidth=0, zorder=2)
        # Action taken at each of the last 40 steps, as a coloured strip below the forecast.
        ax[axi].bar(ruls[-40:], 5, width=1, bottom=-6-(i%2)*6, color=cs[metrics[label]['actions'][eng]][-40:], zorder=3, align='edge')
        # Guarded index: trajectories shorter than 41 steps anchor the text at their first step.
        ax[axi].annotate(label, (ruls[-min(41, len(ruls))]+.25,2.5-6-(i%2)*6), ha='right', va='center', zorder=2)
        ax[axi].grid(zorder=-2)
        ax[axi].legend(handles=handles, loc='upper right')
    ax[0].set_title('Neutral Forecast', y=1.0, pad=-20, loc='center')
    ax[1].set_title('Averse Forecast', y=1.0, pad=-20, loc='center')
    ax[0].set_xlim(50,0)
    ax[0].set_ylim(-12.2,60)
    ax[1].set_ylim(-12.2,60)
    plt.savefig(f'../results/figures/RL{run_id}/engineactions{eng}.png', facecolor='white', bbox_inches='tight')
    # Display only the first engine inline; close the rest to free memory.
    if eng == 0:
        plt.show()
    else:
        plt.close()

In [None]:
# Distribution of the failure probability at which each agent first chose a non-zero action.
fig, ax = plt.subplots()
# Engines with no non-zero action have crossp None and are dropped; the remaining values are
# cast to float because an array that contained None has object dtype.
crossover_points = [metrics[label]['crossp'][np.where(metrics[label]['crossp'] != None)].astype(float) for label in model_labels]
ax.hist(crossover_points, bins=np.linspace(0,1,20), zorder=2, edgecolor='k', label=list(model_labels))
ax.legend()
ax.grid(zorder=1)
# Raw string: "\m" and "\l" are invalid escape sequences in a normal string literal.
ax.set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
plt.show()

In [None]:
# Histogram of the recorded final RUL per agent, on unit-wide bins shared by all four panels.
fig, ax = plt.subplots(2,2,sharex=True,sharey=True)
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
ax = ax.flatten()
final_rul_bins = np.arange(0,1+np.max([metrics[label]['final ruls'] for label in model_labels]))-.5
for i, lb in enumerate(model_labels):
    panel = ax[i]
    panel.hist(metrics[lb]['final ruls'], bins=final_rul_bins, zorder=2, label=lb,
               histtype='bar', color=colors[i], edgecolor='k')
    panel.grid(zorder=1)
    panel.set_title(lb)
for bottom_panel in (ax[-2], ax[-1]):
    bottom_panel.set_xlabel('Final RUL')
plt.show()

In [None]:
# Remaining life at the moment the replacement actually happens, per agent and engine.
lifeleft = {label: [] for label in model_labels}
for label in model_labels:
    for i in range(len(metrics[label]['final ruls'])):
        first_replace_idx = metrics[label]['termidx'][i]
        recorded_rul = metrics[label]['final ruls'][i]
        if first_replace_idx is None:
            # The agent never chose a non-zero action for this engine. Previously this
            # fell through to `None - 1` and raised a TypeError; keep the recorded value.
            # NOTE(review): confirm that such engines should be kept rather than skipped.
            lifeleft[label].append(recorded_rul)
        elif metrics[label]['actions'][i][first_replace_idx-1] == 2:
            # Action 2: the planning window is subtracted from the recorded RUL.
            lifeleft[label].append(recorded_rul-RL_parameters['planning_window'])
        else:
            lifeleft[label].append(recorded_rul)




In [None]:
# 2x2 grid: distribution of the RUL left at the scheduled replacement, one panel per agent.
fig, ax = plt.subplots(2,2, sharex=False, sharey=False, layout='tight')
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
ymax = 0
ax = ax.flatten()
# Unit-wide bins shared by all panels.
lifeleft_bins = np.arange(-1, 1+np.max([lifeleft[label] for label in model_labels]))-.5
for i, lb in enumerate(model_labels):
    ax[i].hist(lifeleft[lb], bins=lifeleft_bins, zorder=2, label=lb, edgecolor='k', color=colors[i])
    ax[i].grid(zorder=1)
    ax[i].set_title(lb)
    ax[i].axvline(0.5, color='k', linestyle='-', linewidth=2)
    xlim = ax[i].get_xlim()
    # Shade everything left of 0.5 over the full panel height. axvspan's ymin/ymax are
    # axes fractions (default 0..1), so the data-coordinate ylim must not be passed.
    ax[i].axvspan(xlim[0], 0.5, color='grey', alpha=0.4, hatch='x')
    ax[i].set_xlim(*xlim)
    ax[i].set_xlabel('RUL at scheduled replacements')
    ymax = max(ymax, ax[i].get_ylim()[1])
# Give all panels the same y-range.
for i in range(len(ax)):
    ax[i].set_ylim(0,ymax)
plt.savefig(f'../results/figures/RL{run_id}/lifeleft.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Same histograms as the 2x2 grid, but one stand-alone figure per agent.
figs = []
axes = []
ymax=0
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
# Unit-wide bins shared by all figures.
lifeleft_bins = np.arange(-1, 1+np.max([lifeleft[label] for label in model_labels]))-.5
for i, lb in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(4,3))

    ax.hist(lifeleft[lb], bins=lifeleft_bins, zorder=2, label=lb, edgecolor='k', color=colors[i])
    ax.grid(zorder=1)
    ax.axvline(0.5, color='k', linestyle='-', linewidth=2)
    xlim = ax.get_xlim()
    # Shade everything left of 0.5 over the full panel height. axvspan's ymin/ymax are
    # axes fractions (default 0..1), so the data-coordinate ylim must not be passed.
    ax.axvspan(xlim[0], 0.5, color='grey', alpha=0.4, hatch='x')
    ax.set_xlim(*xlim)
    ax.set_xlabel('RUL at scheduled replacements')
    ymax = max(ymax, ax.get_ylim()[1])
    figs.append(fig)
    axes.append(ax)
# Give all figures the same y-range before saving.
for i in range(len(axes)):
    axes[i].set_ylim(0,ymax)

for i, lb in enumerate(model_labels):
    figs[i].savefig(f'../results/figures/RL{run_id}/lifeleft-{lb}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Summary statistics of the remaining life at replacement, one row per agent.
print('\t\tmean\tstd\tmax\tmin')
for lb in model_labels:
    print(f'{lb}\t{np.mean(lifeleft[lb]):.2f}\t{np.std(lifeleft[lb]):.2f}\t{np.max(lifeleft[lb]):.2f}\t{np.min(lifeleft[lb]):.2f}')

In [None]:
# Total reward per engine, counting rewards only up to and including the first non-zero action
# (a termidx of None keeps the whole trajectory).
sumrewards = {
    label: [
        engine_rewards[:first_replace].sum()
        for engine_rewards, first_replace in zip(metrics[label]['rewards'], metrics[label]['termidx'])
    ]
    for label in model_labels
}



In [None]:
# Common bin edges for all reward histograms: pool the reference rewards with every agent's
# totals and use sqrt(4 * number of engines) bins.
_, bins = np.histogram(np.concatenate([opt_rewards, *[sumrewards[lb] for lb in model_labels]]), bins=int(np.sqrt(4*len(opt_rewards))))

In [None]:
# 2x2 grid: each agent's total rewards (coloured) next to the reference rewards (black).
fig, ax = plt.subplots(2,2, sharex=False, sharey=False, layout='tight')
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
ax = ax.flatten()
for i, lb in enumerate(model_labels):
    ax[i].hist([sumrewards[lb], opt_rewards], bins=bins, density=True, edgecolor='k', zorder=3, color=[colors[i], 'k'])
    # The height of the mean markers is taken from the first panel and reused for all panels.
    if i == 0:
        ylim = ax[i].get_ylim()[1]
    # Dashed vertical lines at the mean of the agent's rewards and of the reference rewards.
    ax[i].vlines([np.mean(sumrewards[lb]), np.mean(opt_rewards)], ymin=0, ymax=ylim*1.4, colors=[colors[i], 'k'], zorder=2, linewidth=2, alpha=0.8, linestyle='--')
    ax[i].grid(zorder=1)
    ax[i].set_title(lb)
    ax[i].set_xlabel('Total reward')
plt.savefig(f'../results/figures/RL{run_id}/rewards.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Same reward histograms as the 2x2 overview, but one stand-alone figure per
# model, each saved to its own file.
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
figs, axes = [], []
for i, lb in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(4,3))
    pair_colors = [colors[i], 'k']
    ax.hist([sumrewards[lb], opt_rewards], bins=bins, density=True, edgecolor='k', zorder=3, color=pair_colors)
    if i == 0:
        # Marker height is fixed by the first figure so that all figures match.
        ylim = ax.get_ylim()[1]
    mean_positions = [np.mean(sumrewards[lb]), np.mean(opt_rewards)]
    ax.vlines(mean_positions, ymin=0, ymax=ylim*1.4, colors=pair_colors, zorder=2, linewidth=2, alpha=0.8, linestyle='--')
    ax.grid(zorder=1)
    ax.set_xlabel('Total reward')
    figs.append(fig)
    axes.append(ax)

for lb, model_fig in zip(model_labels, figs):
    model_fig.savefig(f'../results/figures/RL{run_id}/rewards-{lb}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Mean total reward of the optimal policy, followed by one line per model.
report_lines = [f'Optimal reward:\t{np.mean(opt_rewards):.2f}']
report_lines.extend(f'{lb}:\t{np.mean(sumrewards[lb]):.2f}' for lb in model_labels)
print('\n'.join(report_lines))

In [None]:
# Histogram of the per-episode difference between each model's total reward and
# the optimal reward; the dashed line and the title give the mean difference.
fig, ax = plt.subplots(2,2, sharex=True, sharey=True)
ax = ax.flatten()
for i, lb in enumerate(model_labels):
    reward_gap = np.asarray(sumrewards[lb])-np.asarray(opt_rewards)
    mean = np.mean(reward_gap)
    ax[i].hist(reward_gap, density=True, edgecolor='k', zorder=2, color=colors[i])
    ax[i].axvline(mean, color=colors[i], linestyle='--', zorder=1, alpha=0.7)
    ax[i].grid(zorder=1)
    ax[i].set_title(f'{lb}: {mean:.2f}')
# Only the last two panels get an x label (the x axis is shared).
for labelled_panel in (ax[-2], ax[-1]):
    labelled_panel.set_xlabel('Reward - Optimal Reward')
plt.show()

In [None]:
import matplotlib.colors
# Three-colour map indexed by action id (0, 1, 2); the colorbars in the
# following cells label the entries 'do nothing', 'replace now' and
# 'replace <planning_window>'.
cmap = matplotlib.colors.ListedColormap(['k','tab:red','tab:blue'])

In [None]:
# Stacked bars per model: for each 0.1-wide bin of the `pfails` values
# (x axis: Pr(RUL <= 10)) the share of decisions that fell on each of the three
# actions. The number above a bar is the count of decisions in that bin.
bins = np.arange(0,1+.1,.1)
fig, ax = plt.subplots(2,2, layout='constrained', figsize=(8,6))
ax = ax.flatten()
for k, label in enumerate(model_labels):
    ax[k].set_title(label)
    ax[k].set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax[k].set_yticks([])
    ax[k].set_ylim(0,1.1)
    pf = np.concatenate(metrics[label]['pfails'])
    act = np.concatenate(metrics[label]['actions'])
    bar_heights = []
    for i, (low, up) in enumerate(zip(bins[:-1], bins[1:])):
        # Bins are half-open [low, up); only the last one also includes its upper edge.
        in_bin = np.logical_and(pf>=low, pf<up if i < len(bins)-2 else pf<=up)
        bar_heights.append([len(pf[np.logical_and(act==j, in_bin)]) for j in range(3)])
        s = sum(bar_heights[i])
        bot = 0
        if s != 0:
            for j in range(len(bar_heights[i])):
                bar_heights[i][j] /= s
                ax[k].bar((up+low)/2, bar_heights[i][j], width=0.09, bottom=bot, color=cmap(j))
                bot += bar_heights[i][j]
        else:
            # Empty bin: normalising would divide by zero. Draw an invisible
            # placeholder and put the "0" label at full bar height instead.
            ax[k].bar((up+low)/2, 1, width=0.09, alpha=0)
            bot = 1
        ax[k].text((up+low)/2, bot, f'{s}', ha='center', va='bottom', fontsize=8)
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=(np.arange(1/3/2,1,1/3)))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('actions', loc='left')
plt.show()


In [None]:
# Stacked action-share bars per 0.1-wide bin of the `pfails` values, using only
# the part of every episode before its `termidx`. The number above each bar is
# the count of decisions in that bin.
bins = np.arange(0,1+.1,.1)
fig, ax = plt.subplots(2,2, layout='constrained', figsize=(8,6))
ax = ax.flatten()
for k, label in enumerate(model_labels):
    ax[k].set_title(label)
    ax[k].set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax[k].set_yticks([])
    ax[k].set_ylim(0,1.1)
    # Truncate every episode at its termination index before pooling.
    pf = np.concatenate([metrics[label]['pfails'][i][:metrics[label]['termidx'][i]] for i in range(len(metrics[label]['pfails']))])
    act = np.concatenate([metrics[label]['actions'][i][:metrics[label]['termidx'][i]] for i in range(len(metrics[label]['actions']))])
    bar_heights = []
    for i, (low, up) in enumerate(zip(bins[:-1], bins[1:])):
        # Bins are half-open [low, up); only the last one also includes its upper edge.
        in_bin = np.logical_and(pf>=low, pf<up if i < len(bins)-2 else pf<=up)
        bar_heights.append([len(pf[np.logical_and(act==j, in_bin)]) for j in range(3)])
        s = sum(bar_heights[i])
        bot = 0
        if s != 0:
            for j in range(len(bar_heights[i])):
                bar_heights[i][j] /= s
                ax[k].bar((up+low)/2, bar_heights[i][j], width=0.09, bottom=bot, color=cmap(j))
                bot += bar_heights[i][j]
        else:
            # Empty bin: invisible placeholder so the count label still has a slot.
            ax[k].bar((up+low)/2, 1, width=0.09, alpha=0)
        ax[k].text((up+low)/2, 1, f'{s}', ha='center', va='bottom', fontsize=8)
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=(np.arange(1/3/2,1,1/3)))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('actions', loc='left')
plt.show()


In [None]:
# Display the saved crossover image of the first test id for every model.
testid = test_ids[0]
for lb in model_labels:
    crossover_png = f'../results/figures/RLtestimg{run_id}-{lb}/RLcrossover{testid}.png'
    IPython.display.display(IPython.display.Image(crossover_png))

In [None]:
# Re-import the RL test helpers so that edits to src/tests/rl_tests.py are
# picked up without restarting the kernel.
import src.tests.rl_tests
reload(src.tests.rl_tests)

In [None]:
from src.tests.rl_tests import rl_engine_anim_allmodels

In [None]:
# Inspect the entries of the underlying dataset that the agent's test-engine indices select.
rl_agent.env.test_engines.dataset[rl_agent.env.test_engines.indices]

In [None]:
# Render one animation per test engine for all four agents. Engines whose
# '<idx>-ani.mp4' already exists in the output directory are skipped, so the
# cell can be re-run after an interruption.
video_dir = f'../results/figures/RLVideo{run_id}'
# makedirs(exist_ok=True) replaces mkdir + `except FileExistsError` and also
# creates missing parent directories.
os.makedirs(video_dir, exist_ok=True)
video_engine_ids = test_engines.dataset[test_engines.indices]
print(video_engine_ids)
for idx in video_engine_ids:
    if not os.path.exists(f'{video_dir}/{idx}-ani.mp4'):
        print(idx)
        rl_engine_anim_allmodels(idx, rl_agent, rl_agent_cvar, w_rl_agent, w_rl_agent_cvar, savepath=video_dir)

In [None]:
# Test-episode trajectories in the (mean, std) plane, one panel per model.
# Columns 0 and 1 of each state array are used as mu and sigma of a log-normal:
#   mean = exp(mu + sigma^2 / 2)
#   std  = sqrt((exp(sigma^2) - 1) * exp(2*mu + sigma^2))
# Points are coloured by the action taken; non-zero actions are drawn larger.
fig, ax = plt.subplots(2,2,sharex=False,sharey=False, layout='constrained', figsize=(8,6))
cmap = matplotlib.colors.ListedColormap(['k','tab:red','tab:blue'])
ax = ax.flatten()
for i, label in enumerate(model_labels):
    # Only the first ten episodes are drawn (fewer if fewer are available).
    for j in range(min(10, len(metrics[label]['states']))):
        mu_j = metrics[label]['states'][j][:,0]
        sig_j = metrics[label]['states'][j][:,1]
        act_j = metrics[label]['actions'][j]
        mean_j = np.exp(mu_j+0.5*sig_j**2)
        std_j = np.sqrt((np.exp(sig_j**2)-1)*np.exp(2*mu_j+sig_j**2))
        msk = act_j != 0
        ax[i].scatter(mean_j[msk], std_j[msk],
                      color=cmap(act_j[msk]), s=50, alpha=1, zorder=3, edgecolor='k', linewidth=.1)
        ax[i].scatter(mean_j[~msk], std_j[~msk],
                      color=cmap(act_j[~msk]), s=3, alpha=1, zorder=2)
    ax[i].grid(zorder=1)
    ax[i].set_title(label)
    ax[i].set_ylabel('std')
    ax[i].set_xlabel('mean')
# The colormap has three bands, so the colorbar gets three ticks/labels (the
# red 'replace now' band was previously unlabeled). `plt.cm` is used because
# the alias `mpl` is only imported in a later cell.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=np.arange(1/3/2,1,1/3))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('Action', loc='left')
plt.savefig(f'../results/figures/RL{run_id}/testengine-traj.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.lines

# One stand-alone trajectory figure per model, covering all test episodes.
# Columns 0 and 1 of each state array are used as mu and sigma of a log-normal;
# the plot shows its mean exp(mu + sigma^2/2) against its standard deviation
# sqrt((exp(sigma^2) - 1) * exp(2*mu + sigma^2)), coloured by the action taken.
figs, axes = [], []


for i, label in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(5,4))
    for j in range(len(metrics[label]['states'])):
        mu_j = metrics[label]['states'][j][:,0]
        sig_j = metrics[label]['states'][j][:,1]
        act_j = metrics[label]['actions'][j]
        mean_j = np.exp(mu_j+0.5*sig_j**2)
        std_j = np.sqrt((np.exp(sig_j**2)-1)*np.exp(2*mu_j+sig_j**2))
        msk = act_j != 0
        # Non-zero actions: large outlined markers on top.
        ax.scatter(mean_j[msk], std_j[msk],
                   color=cmap(act_j[msk]), s=50, alpha=1, zorder=3, edgecolor='k', linewidth=.1)
        # Action 0: small markers underneath.
        ax.scatter(mean_j[~msk], std_j[~msk],
                   color=cmap(act_j[~msk]), s=3, alpha=1, zorder=2)
    ax.grid(zorder=1)
    ax.set_ylabel('Standard deviation')
    ax.set_xlabel('Mean')
    dn = matplotlib.lines.Line2D([0],[0], label='Do nothing', color='k', marker='o', linestyle='')
    rn = matplotlib.lines.Line2D([0],[0], label='Replace now', color='tab:red', marker='o', linestyle='')
    r10 = matplotlib.lines.Line2D([0],[0], label='Replace in 10', color='tab:blue', marker='o', linestyle='')
    ax.legend(handles=[dn, rn, r10])
    figs.append(fig)
    # Fixed: this used to be `axes.append(axes)`, which appended the list to itself.
    axes.append(ax)
for i, label in enumerate(model_labels):
    figs[i].savefig(f'../results/figures/RL{run_id}/testengine-traj-{label}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Run the forecast model of every second model label over its training engines
# and collect the last-step outputs. From these, compute the log-normal mean and
# standard deviation and, for every unit-wide mean bin [b, b+1) with
# b = 1..128, the smallest and largest standard deviation seen in training.
train_states = {'neutral': [], 'averse': []}
train_means = {'neutral': None, 'averse': None}
train_sds = {'neutral': None, 'averse': None}
mins, maxs = {'neutral': [], 'averse': []}, {'neutral': [], 'averse': []}
for label in list(model_labels)[::2]:
    # Results are keyed by the part of the label before the first '-'.
    labelidx = label.split('-')[0]
    label_env = model_dict[label].env
    for train_engineid in label_env.engines:
        cur_data, _, _, _ = label_env.dataset.get_unit_by_id(train_engineid)
        cur_data = cur_data.float().to(label_env.device)
        model_out = label_env.model(cur_data[:,label_env.dataoffset:])
        states = torch.cat(model_out,dim=-1)[:,-1].detach().cpu().numpy()
        train_states[labelidx].append(states)
    per_engine = [train_states[labelidx][e] for e in range(len(label_env.engines))]
    train_mean = np.concatenate([np.exp(s[:,0]+0.5*s[:,1]**2) for s in per_engine])
    train_sd = np.concatenate([np.sqrt((np.exp(s[:,1]**2)-1)*np.exp(2*s[:,0]+s[:,1]**2)) for s in per_engine])
    train_means[labelidx] = train_mean
    train_sds[labelidx] = train_sd
    for bin_start in np.arange(1,129,1):
        sd_in_bin = train_sd[np.logical_and(train_mean >= bin_start, train_mean < bin_start+1)]
        if sd_in_bin.size > 0:
            mins[labelidx].append(sd_in_bin.min())
            maxs[labelidx].append(sd_in_bin.max())
        else:
            # No training point in this bin: use -inf / +inf as sentinels.
            mins[labelidx].append(-np.inf)
            maxs[labelidx].append(np.inf)
    mins[labelidx] = np.asarray(mins[labelidx])
    maxs[labelidx] = np.asarray(maxs[labelidx])

In [None]:
def grid_actions(agent, states):
    """Return the greedy action of ``agent`` for every state in ``states``.

    Args:
        agent: Object whose ``dqn`` attribute is callable on ``states`` and
            returns a tensor with one row per state and one column per action.
        states: Batch of states, passed to ``agent.dqn`` unchanged.

    Returns:
        NumPy array holding, for each state, the index of the largest value
        along dimension 1 of the network output.
    """
    with torch.no_grad():
        # Named `action_values` rather than `eval` to avoid shadowing the builtin.
        action_values = agent.dqn(states)
        action = action_values.argmax(1).cpu().numpy()
    return action
    

In [None]:
import gc

# Drop the references to the per-engine state arrays and release cached GPU
# memory before the large action grid is built.
states = None
try:
    del train_states
except NameError:
    # Already deleted by an earlier run of this cell; keeps the cell re-runnable.
    pass
# empty_cache() does not involve autograd, so no `torch.no_grad()` context is needed.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# 500 x 500 grid of (mu, sigma) pairs. Both axes are the log of a linear range:
# mu covers log(0.001) .. log(128), sigma covers 0.001 .. 1.5 (log of
# exp(0.001) .. exp(1.5)). Every model's greedy action is evaluated on the grid.
Mus = torch.linspace(0.001, 128, 500).log()
Sigs = torch.linspace(np.exp(0.001), np.exp(1.5), 500).log()
grid_mu, grid_sig = torch.meshgrid(Mus, Sigs, indexing='ij')
pre_states = torch.vstack((grid_mu.flatten(), grid_sig.flatten())).T.unsqueeze(1)
states = rl_agent.env._transform_states(pre_states).to(rl_agent.device)
grid_actions_dict = {}
for label in model_labels:
    grid_actions_dict[label] = grid_actions(model_dict[label], states)
# Release the grid tensor again; only the action arrays are needed below.
states = None
torch.cuda.empty_cache()

In [None]:
# Log-normal mean and standard deviation of every grid point, flattened in the
# same order as the action grid. `msk` keeps points with mean <= 128 and std <= 30.
grid_sig_sq = grid_sig**2
means = np.exp(grid_mu+0.5*grid_sig_sq).flatten()
stds = np.sqrt((np.exp(grid_sig_sq)-1)*np.exp(2*grid_mu+grid_sig_sq)).flatten()
msk = (means.numpy()<=128) & (stds.numpy()<=30)

In [None]:
import scipy.interpolate

# Resample the actions from the irregular (mean, std) points onto a regular
# 1000 x 1000 grid over mean 0..128 and std 0..30 with nearest-neighbour lookup.
xs, ys = np.mgrid[0:128:1000j,0:30:1000j]
resampled_action_dict = {}
for label in model_labels:
    resampled_action_dict[label] = scipy.interpolate.griddata(
        (means[msk],stds[msk]), grid_actions_dict[label][msk], (xs,ys), method='nearest')

In [None]:
import scipy.spatial

# Convex hull of the retained grid points in the (mean, std) plane.
hull_points = np.column_stack((means[msk].numpy(), stds[msk].numpy()))
hull = scipy.spatial.ConvexHull(hull_points)

In [None]:
import matplotlib.path

# Turn the hull into a closed polygon and flag which cells of the regular
# (xs, ys) grid lie inside it.
masked_means, masked_stds = means[msk], stds[msk]
poly_verts = [(masked_means[idx].item(), masked_stds[idx].item()) for idx in hull.vertices]
poly_verts.append(poly_verts[0])  # close the ring
poly_path = matplotlib.path.Path(poly_verts)
regular_grid_points = np.vstack((xs.flatten(),ys.flatten())).T
poly_msk = poly_path.contains_points(regular_grid_points).reshape(xs.shape)

In [None]:
# Pool the states of all test episodes of all models and convert the first two
# columns (used as mu and sigma) into log-normal mean and standard deviation.
per_model_states = [np.concatenate(metrics[label]['states']) for label in model_labels]
test_points = np.concatenate(per_model_states)
test_mu, test_sigma = test_points[:,0], test_points[:,1]
test_means = np.exp(test_mu+0.5*test_sigma**2)
test_stds = np.sqrt(np.exp(2*test_mu+test_sigma**2)*(np.exp(test_sigma**2)-1))

In [None]:
# Convex hull of the pooled test states in the (mean, std) plane, as a closed
# polygon, plus a mask of the retained grid points that fall inside it.
test_hull = scipy.spatial.ConvexHull(np.column_stack((test_means, test_stds)))
test_ring = [(test_means[idx],test_stds[idx]) for idx in test_hull.vertices]
test_ring.append(test_ring[0])  # close the ring
test_poly_path = matplotlib.path.Path(test_ring)
test_poly_msk = test_poly_path.contains_points(np.vstack((means[msk].flatten(),stds[msk].flatten())).T)
test_poly_verts = np.asarray(test_ring)

In [None]:
# For both training sets ('neutral', 'averse'): convex hull of the training
# points in the (mean, std) plane as a closed polygon, and masks telling which
# retained grid points, which cells of the regular (xs, ys) grid, and which
# per-bin min/max std points lie inside that polygon.
train_hull, train_poly_verts, train_poly_path = {}, {}, {}
train_poly_msk, train_poly_msk_int = {}, {}
min_train_msk, max_train_msk = {}, {}
mean_bin_starts = np.arange(1,129,1)
for key in ['neutral', 'averse']:
    train_hull[key] = scipy.spatial.ConvexHull(np.vstack((train_means[key], train_sds[key])).T)
    train_ring = [(train_means[key][idx],train_sds[key][idx]) for idx in train_hull[key].vertices]
    train_ring.append(train_ring[0])  # close the ring
    train_poly_path[key] = matplotlib.path.Path(train_ring)
    train_poly_msk[key] = train_poly_path[key].contains_points(np.vstack((means[msk].flatten(),stds[msk].flatten())).T)
    train_poly_verts[key] = np.asarray(train_ring)
    train_poly_msk_int[key] = train_poly_path[key].contains_points(np.vstack((xs.flatten(),ys.flatten())).T).reshape(xs.shape)
    min_train_msk[key] = train_poly_path[key].contains_points(np.vstack((mean_bin_starts, mins[key])).T)
    max_train_msk[key] = train_poly_path[key].contains_points(np.vstack((mean_bin_starts, maxs[key])).T)

In [None]:
# Show which per-bin min / max std points of the 'neutral' training data fall inside its hull polygon.
min_train_msk['neutral'], max_train_msk['neutral']

In [None]:
# Membership in the test-state hull polygon, once for every (mean, std) grid
# point and once for every cell of the regular (xs, ys) grid.
all_grid_points = np.vstack((means,stds)).T
scatter_test_poly_msk = test_poly_path.contains_points(all_grid_points)
regular_grid_pts = np.vstack((xs.flatten(),ys.flatten())).T
test_poly_msk_int = test_poly_path.contains_points(regular_grid_pts).reshape(xs.shape)

In [None]:
# Decision field of every model in the (mean, std) plane. Actions inside the
# hull of the training data are drawn opaque, actions outside it (but inside
# the hull of the evaluated grid) faded. The solid white line is the training
# hull; the dashed white lines are the per-bin min / max training std.
fig, ax = plt.subplots(2,2, sharex=False, sharey=False, layout='constrained', figsize=(8,6))
ax = ax.flatten()
for i, label in enumerate(model_labels):
    # Training statistics are keyed by the label prefix before the first '-'.
    key = label.split('-')[0]
    masked_actions = resampled_action_dict[label].astype(float)
    masked_actions[~poly_msk] = np.nan
    masked_actions_out = masked_actions.copy()
    masked_actions_out[train_poly_msk_int[key]] = np.nan
    masked_actions[~train_poly_msk_int[key]] = np.nan
    ax[i].pcolormesh(xs,ys, masked_actions, cmap=cmap, vmin=0,vmax=2)
    ax[i].pcolormesh(xs,ys, masked_actions_out, cmap=cmap, vmin=0,vmax=2, alpha=0.1)
    ax[i].set_title(label)
    ax[i].plot(train_poly_verts[key][:,0], train_poly_verts[key][:,1], color='white', alpha=1, linestyle='-', linewidth=2)
    ax[i].set_xlim(0,40)
    ax[i].set_ylim(0,10)
    ax[i].set_xlabel('mean')
    ax[i].set_ylabel('std')

    ax[i].plot(np.arange(1,129,1), mins[key], linestyle='--', color='white')
    ax[i].plot(np.arange(1,129,1), maxs[key], linestyle='--', color='white')
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=np.arange(1/3/2,1,1/3))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('Action', loc='left')
plt.savefig(f'../results/figures/RL{run_id}/decisionfield.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.patches

# One stand-alone decision-field figure per model: actions inside the training
# hull opaque, actions outside it faded, training hull as a solid white line
# and per-bin min / max training std as dashed white lines.
figs, axes = [], []
legend_entries = (('Do nothing', 'k'), ('Replace now', 'tab:red'), ('Replace in 10', 'tab:blue'))

for i, label in enumerate(model_labels):
    key = label.split('-')[0]
    inside_train_hull = train_poly_msk_int[key]
    hull_outline = train_poly_verts[key]
    fig, ax = plt.subplots(layout='tight', figsize=(5,4))
    masked_actions = resampled_action_dict[label].astype(float)
    masked_actions[~poly_msk] = np.nan
    masked_actions_out = masked_actions.copy()
    masked_actions_out[inside_train_hull] = np.nan
    masked_actions[~inside_train_hull] = np.nan
    ax.pcolormesh(xs,ys, masked_actions, cmap=cmap, vmin=0,vmax=2)
    ax.pcolormesh(xs,ys, masked_actions_out, cmap=cmap, vmin=0,vmax=2, alpha=0.2)
    ax.plot(hull_outline[:,0], hull_outline[:,1], color='white', alpha=1, linestyle='-', linewidth=2)
    ax.set_xlim(0,40)
    ax.set_ylim(0,10)
    ax.set_xlabel('Mean')
    ax.set_ylabel('Standard Deviation')
    ax.legend(handles=[matplotlib.patches.Patch(label=text, color=shade) for text, shade in legend_entries], loc='upper left')
    for bin_extreme in (mins[key], maxs[key]):
        ax.plot(np.arange(1,129,1), bin_extreme, linestyle='--', color='white')
    figs.append(fig)
    axes.append(ax)

for label, model_fig in zip(model_labels, figs):
    model_fig.savefig(f'../results/figures/RL{run_id}/decisionfield-{label}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Stacked histogram of the actions taken as a function of the x value labelled
# 'True RUL': the reversed step index within each episode's action array
# (last step = 0). One panel per model, common y range.
fig, ax = plt.subplots(2,2,sharex=False,sharey=False,layout='constrained', figsize=(8,6))
ymax = 0
ax = ax.flatten()
for i, label in enumerate(model_labels):
    rul_all = np.concatenate([np.arange(0,act_arr.shape[0])[::-1] for act_arr in metrics[label]['actions']])
    act_all = np.concatenate(metrics[label]['actions'])
    n, _, _ = ax[i].hist([rul_all[act_all==a] for a in range(3)], density=True, bins=np.arange(0,250,1), color=cmap([0,1,2]), histtype='barstacked', linewidth=1, zorder=3)
    ax[i].grid(zorder=4, color='w')
    ax[i].set_title(label)
    ax[i].set_xlim(0,50)
    ymax = max(ymax, n.max())
    ax[i].set_xlabel('True RUL')
    ax[i].set_ylabel('Action frequency')
for axi in ax:
    axi.set_ylim(0,ymax)
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=np.arange(1/3/2,1,1/3))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('Action', loc='left')
plt.savefig(f'../results/figures/RL{run_id}/action-ruldist.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# For every model: among all steps on which a given action was taken, the share
# that satisfies the stated RUL condition (RUL = reversed step index within the
# episode's action array).
# NOTE(review): each value is count(condition AND action) / count(action), i.e.
# conditional on the action, while the printed labels read as conditional on
# the RUL - confirm which one is intended.
for label in model_labels:
    print(f'{label}:')
    rul_all = np.concatenate([np.arange(0,act_arr.shape[0])[::-1] for act_arr in metrics[label]['actions']])
    act_all = np.concatenate(metrics[label]['actions'])
    act_hists = [rul_all[act_all==a] for a in range(3)]
    share_specs = [
        ('Pr(action=0|RUL<=10)', act_hists[0] <= 10),
        ('Pr(action=1|RUL>=10)', act_hists[1] >= 10),
        ('Pr(action=2|RUL<10)', act_hists[2] < 10),
    ]
    for caption, hits in share_specs:
        # An action that was never taken has no defined share; print nan instead
        # of raising ZeroDivisionError.
        share = np.count_nonzero(hits)/hits.shape[0] if hits.shape[0] else float('nan')
        print(f'\t{caption} = {share:.5f}')

In [None]:
# One stand-alone stacked histogram per model of the actions taken versus the
# reversed step index within each episode (x label 'True RUL'); all figures
# share the same y range.
figs, axes = [], []
ymax = 0
for i, label in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(5,4))
    rul_all = np.concatenate([np.arange(0,act_arr.shape[0])[::-1] for act_arr in metrics[label]['actions']])
    act_all = np.concatenate(metrics[label]['actions'])
    # Fixed: this was `n,_,_, ax.hist(...)`, a bare tuple expression that never
    # assigned `n`, so `ymax` was computed from a stale or undefined value.
    n, _, _ = ax.hist([rul_all[act_all==a] for a in range(3)], density=True, bins=np.arange(0,250,1), color=cmap([0,1,2]), histtype='barstacked', linewidth=2, zorder=3)
    ax.grid(zorder=1)
    ax.set_xlim(0,40)
    ymax = max(ymax, n.max())
    ax.set_xlabel('True RUL')
    ax.set_ylabel('Action frequency')
    dn = matplotlib.patches.Patch(label='Do nothing', edgecolor='k', facecolor=matplotlib.colors.to_rgba('k',1), linewidth=2)
    rn = matplotlib.patches.Patch(label='Replace now', edgecolor='tab:red', facecolor=matplotlib.colors.to_rgba('tab:red',1), linewidth=2)
    r10 = matplotlib.patches.Patch(label='Replace in 10', edgecolor='tab:blue', facecolor=matplotlib.colors.to_rgba('tab:blue',1), linewidth=2)
    ax.legend(handles=[dn,rn,r10], loc='lower right')
    figs.append(fig)
    axes.append(ax)
for axi in axes:
    axi.set_ylim(0,ymax)
for i, label in enumerate(model_labels):
    figs[i].savefig(f'../results/figures/RL{run_id}/action-ruldist-{label}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Distribution of the `pfails` values (x axis: Pr(RUL <= 10)) separately for
# each action taken on the test episodes; outline plus translucent fill, one
# panel per model with a common y range.
fig, ax = plt.subplots(2,2,sharex=False,sharey=False,layout='constrained', figsize=(8,6))
ymax = 0
ax = ax.flatten()
for i, label in enumerate(model_labels):
    pfail_all = np.concatenate(metrics[label]['pfails'])
    action_all = np.concatenate(metrics[label]['actions'])
    pfail_by_action = [pfail_all[action_all==a] for a in range(3)]
    ax[i].hist(pfail_by_action, density=True, bins=20, color=cmap([0,1,2]), histtype='step', linewidth=2, zorder=3)
    ax[i].hist(pfail_by_action, density=True, bins=20, color=cmap([0,1,2]), histtype='stepfilled', alpha=.3, zorder=2)
    ax[i].grid(zorder=1)
    ax[i].set_title(label)
    ax[i].set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax[i].set_ylabel('Frequency of action')
    ymax = max(ymax, ax[i].get_ylim()[1])
for axi in ax:
    axi.set_ylim(0,ymax)
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=np.arange(1/3/2,1,1/3))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('Action', loc='left')

plt.savefig(f'../results/figures/RL{run_id}/action-pr10dist-test.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# One stand-alone figure per model: distribution of the `pfails` values for each
# action taken on the test episodes (outline plus translucent fill); all
# figures share the same y range.
figs, axes = [], []
ymax = 0
legend_spec = (('Do nothing', 'k'), ('Replace now', 'tab:red'), ('Replace in 10', 'tab:blue'))
for i, label in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(5,4))
    pfail_all = np.concatenate(metrics[label]['pfails'])
    action_all = np.concatenate(metrics[label]['actions'])
    pfail_by_action = [pfail_all[action_all==a] for a in range(3)]
    ax.hist(pfail_by_action, density=True, bins=20, color=cmap([0,1,2]), histtype='step', linewidth=2, zorder=3)
    ax.hist(pfail_by_action, density=True, bins=20, color=cmap([0,1,2]), histtype='stepfilled', alpha=.3, zorder=2)
    ax.grid(zorder=1)
    ax.set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax.set_ylabel('Frequency of action')
    ymax = max(ymax, ax.get_ylim()[1])
    legend_handles = [
        matplotlib.patches.Patch(label=text, edgecolor=shade, facecolor=matplotlib.colors.to_rgba(shade,.3), linewidth=2)
        for text, shade in legend_spec
    ]
    ax.legend(handles=legend_handles, loc='upper center')
    figs.append(fig)
    axes.append(ax)
for axi in axes:
    axi.set_ylim(0,ymax)
for label, model_fig in zip(model_labels, figs):
    model_fig.savefig(f'../results/figures/RL{run_id}/action-pr10dist-{label}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Same action-vs-Pr(RUL <= 10) histograms, but evaluated on the synthetic state
# grid restricted to the hull of the test states. `pf` is the log-normal CDF at
# 10: 0.5 * (1 + erf((ln 10 - mu) / (sigma * sqrt(2)))).
pf = 0.5*(1+scipy.special.erf((np.log(10)-grid_mu.flatten()[msk][test_poly_msk])/(grid_sig.flatten()[msk][test_poly_msk]*np.sqrt(2)))).numpy()
ymax = 0
fig, ax = plt.subplots(2,2,sharex=True,sharey=True,layout='constrained', figsize=(8,6))
ax = ax.flatten()
for i, label in enumerate(model_labels):
    grid_action_in_hull = grid_actions_dict[label][msk][test_poly_msk]
    pf_by_action = [pf[grid_action_in_hull==a] for a in range(3)]
    ax[i].hist(pf_by_action, density=True, bins=20, color=cmap([a for a in range(3)]), histtype='step', linewidth=2, zorder=3)
    ax[i].hist(pf_by_action, density=True, bins=20, color=cmap([a for a in range(3)]), histtype='stepfilled', zorder=2, alpha=.3)
    ax[i].grid(zorder=1)
    ax[i].set_title(label)
    ax[i].set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax[i].set_ylabel('Frequency of action')
    ymax = max(ymax, ax[i].get_ylim()[1])
for axi in ax:
    axi.set_ylim(0,ymax)
# `plt.cm` is used instead of `mpl.cm`: the alias `mpl` is only imported in a
# later cell, so `mpl.cm` fails when the notebook is run top to bottom.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax.ravel(), ticks=np.arange(1/3/2,1,1/3))
cbar.set_ticklabels(['do nothing']+['replace now']+[f'replace {RL_parameters["planning_window"]:02d}'])
cbar.ax.set_title('Action', loc='left')
plt.savefig(f'../results/figures/RL{run_id}/action-pr10dist-grid.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# One stand-alone figure per model of the grid-based action-vs-Pr(RUL <= 10)
# histograms. `pf` is the log-normal CDF at 10 for the grid points inside the
# hull of the test states.
figs, axes = [], []
ymax = 0
pf = 0.5*(1+scipy.special.erf((np.log(10)-grid_mu.flatten()[msk][test_poly_msk])/(grid_sig.flatten()[msk][test_poly_msk]*np.sqrt(2)))).numpy()
legend_spec = (('Do nothing', 'k'), ('Replace now', 'tab:red'), ('Replace in 10', 'tab:blue'))
for i, label in enumerate(model_labels):
    fig, ax = plt.subplots(layout='tight', figsize=(5,4))
    grid_action_in_hull = grid_actions_dict[label][msk][test_poly_msk]
    pf_by_action = [pf[grid_action_in_hull==a] for a in range(3)]
    ax.hist(pf_by_action, density=True, bins=20, color=cmap([a for a in range(3)]), histtype='step', linewidth=2, zorder=3)
    ax.hist(pf_by_action, density=True, bins=20, color=cmap([a for a in range(3)]), histtype='stepfilled', zorder=2, alpha=.3)
    ax.grid(zorder=1)
    ax.set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
    ax.set_ylabel('Frequency of action')
    ymax = max(ymax, ax.get_ylim()[1])
    legend_handles = [
        matplotlib.patches.Patch(label=text, edgecolor=shade, facecolor=matplotlib.colors.to_rgba(shade,.3), linewidth=2)
        for text, shade in legend_spec
    ]
    ax.legend(handles=legend_handles, loc='upper center')
    figs.append(fig)
    axes.append(ax)
for axi in axes:
    axi.set_ylim(0,ymax)
for label, model_fig in zip(model_labels, figs):
    model_fig.savefig(f'../results/figures/RL{run_id}/action-pr10dist-grid-{label}.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# Axis-free bar plot of a single distribution over the agent's support: the
# [-1, -1] entry of the first 'neutral-mean' episode's `dists` array. The bar
# width is the spacing (v_max - v_min) / (atom_size - 1) between atoms.
zj = rl_agent.support.cpu().numpy()
delta_z = float(rl_agent.v_max - rl_agent.v_min) / (rl_agent.atom_size - 1)
last_dist = metrics['neutral-mean']['dists'][0][-1,-1]
fig, ax = plt.subplots(figsize=(8,1))
ax.bar(zj, last_dist, width=delta_z, align='center', color='tab:cyan', edgecolor='w')
ax.axis('off')
plt.savefig('testd1.png', transparent=True)
plt.show()

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap, ListedColormap

In [None]:
# Model whose distributions are plotted in the following cells: the first entry of model_labels.
lb = list(model_labels)[0]

In [None]:
# 3D bar plots of the distributions of the first episode of model `lb` over its
# last 30 steps, one panel per action column (0, 1, 2). x is the reversed step
# index, y the agent's support; bars are shaded along x with a two-colour
# gradient taken from 'tab20c' (panel 0) or 'tab20' (other panels).
fig, ax = plt.subplots(1,3, subplot_kw={'projection': '3d'}, figsize=(8,4), layout='constrained')
ax = ax.flatten()
x = np.arange(30)[::-1]
y = rl_agent.support.cpu().numpy()
XX, YY = np.meshgrid(x, y, indexing='ij')
for i in range(len(ax)):
    z = metrics[lb]['dists'][0][-30:,i]
    if i == 0:
        palette, first_shade = 'tab20c', 17
    else:
        palette, first_shade = 'tab20', 2*i
    cmap = LinearSegmentedColormap.from_list('testcmap', list(zip([0,1],[plt.get_cmap(palette)(first_shade/20), plt.get_cmap(palette)((first_shade+1)/20)])))
    ax[i].bar3d(XX.flatten(), YY.flatten(), np.zeros(len(z.flatten())),1,np.diff(y)[0],z.flatten(), color=cmap((XX.flatten())/len(x)), alpha=0.8, edgecolor='k', linewidth=0.1)
    ax[i].invert_xaxis()
    ax[i].view_init(elev=20, azim=20)
    ax[i].get_zaxis().set_ticks([])
    ax[i].get_zaxis().line.set_linewidth(0)
    # `w_xaxis` / `w_yaxis` were deprecated aliases of `xaxis` / `yaxis` and were
    # removed in Matplotlib 3.8; the plain attributes work on all versions.
    ax[i].xaxis.set_pane_color((1.,1.,1.,1.))
    ax[i].yaxis.set_pane_color((1.,1.,1.,1.))
    ax[i].grid(False)
plt.show()

In [None]:
# Single 3D figure with the distributions of the three action columns of the
# first episode of model `lb` over its last `xrange` steps. The three blocks
# are placed side by side along y, shifted by `yoffset` each.
cmap = ['grey', 'tab:red', 'tab:blue']
xrange = 50
yoffset = 95
fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, layout='tight', figsize=(12,6))
x = np.arange(xrange)[::-1]
y = rl_agent.support.cpu().numpy()
XX, YY = np.meshgrid(x, y, indexing='ij')
bar_depth = np.diff(y)[0]
for i in range(3):
    z = metrics[lb]['dists'][0][-xrange:,i]
    ax.bar3d(XX.flatten(), YY.flatten()+yoffset*i, np.zeros(len(z.flatten())),1,bar_depth,z.flatten(), color=cmap[i], alpha=1, edgecolor='k', linewidth=.03, shade=True)
ax.invert_xaxis()
# Every block gets the same four tick labels (-60 .. 0), placed at its shifted position.
tick_positions = np.concatenate([np.linspace(-60,0,4)+yoffset*block for block in range(3)])
tick_labels = np.concatenate([np.linspace(-60,0,4, dtype=int) for block in range(3)])
ax.set_yticks(tick_positions, labels=tick_labels, rotation=0, va='center', ha='center')
ax.view_init(elev=15, azim=20)
ax.set_box_aspect(aspect=(6,8,2))
ax.set_xlabel('True RUL')
ax.set_ylabel('Reward Distribution')
# Caption above each block; the first one keeps the default text colour.
block_captions = (('do nothing', {}), ('replace now', {'color': 'tab:red'}), ('replace 10', {'color': 'tab:blue'}))
for block, (caption, text_style) in enumerate(block_captions):
    ax.text(xrange, -81/2+yoffset*block, .9, caption, (0,1,0), ha='center', **text_style)
plt.savefig(f'../results/figures/RL{run_id}/3dreward-dist.png', facecolor='white', bbox_inches='tight')
plt.show()

In [None]:
# 3D bar plot of the distribution in column i+1 of the first episode of model
# `lb` over its last 30 steps, shaded along x with a 'tab20' two-colour gradient.
i = 1
x = np.arange(30)[::-1]
y = rl_agent.support.cpu().numpy()
XX, YY = np.meshgrid(x, y, indexing='ij')
z = metrics[lb]['dists'][0][-30:,i+1]
gradient_ends = [plt.get_cmap('tab20')((2*i+1)/20), plt.get_cmap('tab20')((2*i)/20)]
cmap = LinearSegmentedColormap.from_list('testcmap', list(zip([0,1], gradient_ends)))
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
bar_colors = cmap((30-XX.flatten())/len(x))
ax.bar3d(XX.flatten(), YY.flatten(), np.zeros(len(z.flatten())),1,np.diff(y)[0],z.flatten(), color=bar_colors, alpha=0.8)
ax.invert_xaxis()
ax.view_init(elev=30, azim=45)
plt.show()

In [None]:
# Coarser 300 x 300 grid: mu linearly spaced in [-6, 6], sigma in [0.001, 1.5].
Mus = torch.linspace(start=-6, end=6, steps=300)
Sigs = torch.linspace(start=0.001, end=1.5, steps=300)

grid_mu, grid_sig = torch.meshgrid(Mus, Sigs, indexing='ij')

In [None]:
# Flatten the grid into rows of (mu, sigma), move them to the agent's device
# and insert a singleton dimension at position 1.
grid_pairs = torch.vstack([grid_mu.flatten(), grid_sig.flatten()]).T
pre_states = grid_pairs.to(rl_agent.device).unsqueeze(1)

In [None]:
# Shape check of the raw grid states.
pre_states.shape

In [None]:
# Pass the raw (mu, sigma) grid through the environment's state transform
# (presumably the same preprocessing the agent's network inputs receive - TODO confirm).
states = rl_agent.env._transform_states(pre_states)

In [None]:
# Shape check of the transformed grid states.
states.shape

In [None]:
# Evaluate the agent's network on every grid state without tracking gradients.
with torch.no_grad():
    # NOTE(review): `eval` shadows the Python builtin for the rest of the
    # session; kept as-is because later cells may refer to this name.
    eval = rl_agent.dqn(states)
    # Greedy action per grid state (argmax over dimension 1 of the network output).
    action = eval.argmax(1).cpu().numpy()
    # Raw network output per action; plotted as 'CVaR' further below.
    cvar = eval.detach().cpu().numpy()
    # Output of the network's `dist` method for the same states.
    dist = rl_agent.dqn.dist(states).cpu().numpy()

In [None]:
import matplotlib.colors

In [None]:
# Discrete 11-colour map: black at index 0, followed by the ten 'tab10' colours.
cmap = matplotlib.colors.ListedColormap(['k', *plt.get_cmap('tab10').colors])

In [None]:
# Policy map: greedy action for every (exp(mu), sigma) grid point.
fig, ax = plt.subplots()
# vmin/vmax are padded by 0.5 so that each integer action 0..10 sits in the middle
# of its own colour band (11 bands for the 11 colours of `cmap`); with a 0..10
# range the bands are 10/11 wide and half-integer ticks drift out of their band.
im = ax.pcolormesh(
    np.exp(grid_mu),
    grid_sig,
    action.reshape((Mus.shape[0], Sigs.shape[0])),
    cmap=cmap,
    vmin=-0.5,
    vmax=10.5,
)
ax.set_xlim(0, 128)
# Raw strings: '\m' and '\s' are invalid escape sequences in a normal string literal.
ax.set_xlabel(r'$e^\mu$')
ax.set_ylabel(r'$\sigma$')
# One tick per action, centred on its colour band.
cbar = fig.colorbar(im, ax=ax, ticks=np.arange(0, 11, 1))
cbar.set_ticklabels(['do nothing']+[f'replace {i:02d}' for i in range(1,11)])
plt.show()

In [None]:
# One heat map of the network output per action (11 actions) on a 3 x 4 grid of
# axes; the unused twelfth panel is switched off.
fig, ax = plt.subplots(3, 4, layout='constrained', figsize=(12, 8))
ax = ax.flatten()
grid_shape = (Mus.shape[0], Sigs.shape[0])
for i, panel in enumerate(ax):
    if i >= 11:
        panel.axis('off')
        continue
    panel.set_title(f'k={i:02d}' if i != 0 else 'do nothing')
    panel.set_xlabel(r'$\mathrm{exp}(\mu)$')
    panel.set_ylabel(r'$\sigma$')
    # Shared colour limits (agent's v_min..v_max) so the panels are comparable.
    im = panel.pcolormesh(
        np.exp(grid_mu),
        grid_sig,
        cvar[:, i].reshape(grid_shape),
        vmin=rl_agent.v_min,
        vmax=rl_agent.v_max,
        cmap='turbo',
    )
    panel.set_xlim(0, 128)
# A single colour bar for all panels, taken from the last heat map drawn.
fig.colorbar(im, ax=ax.ravel().tolist(), label='CVaR')
plt.show()

In [None]:
# Best network output per grid state against exp(mu + sigma^2/2) (the mean of a
# log-normal(mu, sigma) variable), coloured by the index of the best action.
fig, ax = plt.subplots()
lognormal_mean = np.exp(grid_mu + 0.5*grid_sig**2).flatten()
best_value = np.max(cvar, axis=1)
best_action = np.argmax(cvar, axis=1)
ax.scatter(lognormal_mean, best_value, s=1, alpha=0.3, color=cmap(best_action))
ax.set_xlim(0, 150)
ax.set_xlabel('mean')
plt.show()

In [None]:
# Best action per grid state shown in the (mean, variance) plane, using the
# log-normal moment formulas:
#   mean     = exp(mu + sigma^2/2)
#   variance = exp(2*mu + sigma^2) * (exp(sigma^2) - 1)
fig, ax = plt.subplots()
lognormal_mean = np.exp(grid_mu + 0.5*grid_sig**2).flatten()
lognormal_var = (np.exp(2*grid_mu+grid_sig**2)*(np.exp(grid_sig**2)-1)).flatten()
best_action = np.argmax(cvar, axis=1)
ax.scatter(lognormal_mean, lognormal_var, s=5, alpha=1, color=cmap(best_action))
ax.set_xlim(0, 50)
ax.set_ylim(0, 50)
ax.set_xlabel('mean')
ax.set_ylabel('variance')
plt.show()

In [None]:
# Sanity check: shape of the stacked coordinates, row 0 = exp(mu + sigma^2/2) ("mean"),
# row 1 = exp(2*mu + sigma^2) * (exp(sigma^2) - 1) ("variance"), one column per grid point.
np.vstack([np.exp(grid_mu + 0.5*grid_sig**2).flatten(), (np.exp(2*grid_mu+grid_sig**2)*(np.exp(grid_sig**2)-1)).flatten()]).shape

In [None]:
# Sanity check: display the shape of the greedy-action array computed from the network output.
action.shape

In [None]:
import scipy.interpolate


# Policy map in the (mean, variance) plane. The (mu, sigma) grid is irregular in
# these coordinates, so the greedy action is resampled onto a regular mesh by
# nearest-neighbour lookup before plotting.
fig, ax = plt.subplots()
xs, ys = np.mgrid[0:50:999j, 0:50:666j]  # 999 x 666 regular mesh over [0, 50] x [0, 50]
lognormal_mean = np.exp(grid_mu + 0.5*grid_sig**2).flatten()
lognormal_var = (np.exp(2*grid_mu+grid_sig**2)*(np.exp(grid_sig**2)-1)).flatten()
resampled_action = scipy.interpolate.griddata(
    (lognormal_mean, lognormal_var),
    action,
    (xs, ys),
    method='nearest',
)
# vmin/vmax are padded by 0.5 so that each integer action 0..10 sits in the middle
# of its own colour band (11 bands for the 11 colours of `cmap`); with a 0..10
# range the bands are 10/11 wide and half-integer ticks drift out of their band.
im = ax.pcolormesh(xs, ys, resampled_action, cmap=cmap, vmin=-0.5, vmax=10.5)
# One tick per action, centred on its colour band.
cbar = fig.colorbar(im, ax=ax, ticks=np.arange(0, 11, 1))
cbar.set_ticklabels(['do nothing']+[f'replace {i:02d}' for i in range(1,11)])
ax.set_xlabel('mean')
ax.set_ylabel('variance')
plt.show()

In [None]:
# Histogram, per chosen action, of Pr(RUL <= 10) over the grid states.
fig, ax = plt.subplots()
# Keep only grid points with exp(mu + sigma^2/2) <= 150 and
# exp(2*mu + sigma^2) * (exp(sigma^2) - 1) <= 100 (the "mean" and "variance"
# expressions used in the cells above).
mask = np.asarray(np.logical_and(np.exp(grid_mu + 0.5*grid_sig**2).flatten() <= 150, (np.exp(2*grid_mu+grid_sig**2)*(np.exp(grid_sig**2)-1)).flatten()<=100))
# Log-normal CDF evaluated at 10: 0.5 * (1 + erf((ln 10 - mu) / (sigma * sqrt(2)))).
pf = 0.5*(1+scipy.special.erf((np.log(10)-grid_mu)/(grid_sig*np.sqrt(2))))
pf = pf.flatten()
argcvar = np.argmax(cvar, axis=1)

# Build the per-action samples once and reuse them for both histogram layers.
action_ids = np.arange(11)
action_labels = ['do nothing' if i==0 else f'replace {i:02d}' for i in range(11)]
pf_by_action = [pf[np.logical_and(argcvar == i, mask)] for i in range(11)]
action_colors = cmap(action_ids)

# Outline layer (carries the labels) on top of a translucent filled layer.
ax.hist(pf_by_action, density=True, bins=20, histtype='step', color=action_colors, alpha=1, linewidth=2, zorder=3, label=action_labels)
ax.hist(pf_by_action, density=True, bins=20, histtype='stepfilled', color=action_colors, alpha=.3, zorder=2)
ax.grid(zorder=1)
# Stand-alone colour bar acting as the legend. It spans 0..1 by default, so the
# 11 bands are centred at (k + 0.5) / 11. `plt.cm` is used because pyplot is the
# matplotlib import this notebook is known to have in scope.
cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax, ticks=(np.arange(0,11,1)+.5)/11)
cbar.set_ticklabels(['do nothing']+[f'replace {i:02d}' for i in range(1,11)])
cbar.ax.set_title('Action', loc='left')
ax.set_xlabel(r'$\mathrm{Pr}(\mathrm{RUL}\leq 10)$')
ax.set_ylabel('Frequency of action')
plt.show()