# Setup/Imports

In [1]:
import pickle
import keras
import tensorflow as tf
from keras import backend as K
import numpy as np
import sys
import os
sys.path.append(os.path.abspath('../'))
from helpers.data_generator import process_data, DataGenerator
from helpers.custom_losses import denorm_loss, hinge_mse_loss
from helpers.custom_losses import percent_correct_sign, baseline_MAE
from models.LSTMConv2D import get_model_lstm_conv2d, get_model_simple_lstm
from models.LSTMConv2D import get_model_linear_systems, get_model_conv2d
from utils.callbacks import CyclicLR, TensorBoardWrapper
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from time import strftime, localtime
import matplotlib
from matplotlib import pyplot as plt
import copy
from tqdm import tqdm_notebook

from helpers.normalization import normalize, denormalize, renormalize
#import tkinter as tk
#from tkinter import filedialog
#root = tk.Tk()
#root.withdraw()

Using TensorFlow backend.


In [2]:
num_cores = 4
config = tf.ConfigProto(intra_op_parallelism_threads=num_cores,
                        inter_op_parallelism_threads=num_cores, 
                        allow_soft_placement=True,
                        device_count = {'CPU' : 1,
                                        'GPU' : 0})
session = tf.Session(config=config)
K.set_session(session)

In [3]:
%matplotlib inline
font={'family': 'Times New Roman',
      'size': 10}
plt.rc('font', **font)
matplotlib.rcParams['figure.facecolor'] = (1,1,1,1)

# Functions

In [79]:
def hinge_loss(delta_true, delta_pred, threshold):
    """Finds the percentage of the time the prediction has the right sign
    """
    if isinstance(delta_true, dict):
        out = {}
        for k in delta_true.keys():
            if np.mean(np.abs(delta_true[k]))>threshold:
                out[k] = np.sum(np.maximum(np.sign(delta_true[k]*delta_pred[k]),0))/delta_true[k].size
            else:
                out[k] = np.nan
        return out
    if np.mean(np.abs(delta_true))>threshold:
        return np.sum(np.maximum(np.sign(delta_true*delta_pred),0))/delta_true.size
    else:
        return np.nan

def batch_hinge(model,generator,param_dict,sig):
    """Finds the percentage of the time the prediction has the right sign
    """    
    nbatches = len(generator)
    err = 0
    for ind in range(nbatches):
        inputs, targets = generator[ind]
        predictions = model.predict_on_batch(inputs)
        full_pred, full_true, delta_pred, delta_true, baseline = batch_denorm(inputs,targets,predictions,param_dict,sig)
        err += hinge_loss(delta_true,delta_pred)
    return err/nbatches    

def get_deltas(inputs, targets, profiles, predictions=None):
    delta_true = {}
    for profile in profiles:
        delta_true['target_' + profile] = targets['target_' + profile] - inputs['input_' + profile]
    if predictions is not None:
        delta_pred = {}
        for profile in profiles:
                delta_pred['target_' + profile] = predictions['target_' + profile] - inputs['input_' + profile]
        return delta_true, delta_pred
    else:
        return delta_true
                
                
                
                
def batch_denorm(inputs,targets,predictions,param_dict,sig):
    """Denormalizes and calculates deltas, prep for plotting and analysis
    """
    targets = targets['target_' + sig]
    baseline = inputs['input_' + sig][:,-1]
    predictions = predictions[sig]
    if predict_deltas:
        full_pred = predictions + baseline
        full_true = targets + baseline
    else:
        full_pred = predictions
        full_true = targets
    denorm_baseline = denormalize(baseline, param_dict[sig])
    denorm_full_pred = denormalize(full_pred, param_dict[sig])
    denorm_full_true = denormalize(full_true, param_dict[sig])
    denorm_delta_pred = denorm_full_pred - denorm_baseline
    denorm_delta_true = denorm_full_true - denorm_baseline
    return denorm_full_pred, denorm_full_true, denorm_delta_pred, denorm_delta_true, denorm_baseline

def plot_batch(y_true, y_pred, baseline, psi, labels, axlabels,shots,times):
    batch_size = y_true.shape[0]
    ncols = 4
    nrows = int(batch_size/ncols)
    figsize = (20,10*batch_size/10)
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    ax = ax.flatten()
    # Plot predictions and true deltas
    for i in range(batch_size):
        ax[i].plot(psi,y_true[i], label=labels[0])
        ax[i].plot(psi,y_pred[i], label=labels[1])
        ax[i].plot(psi,baseline[i], label=labels[2])
        ax[i].title.set_text('Shot ' + str(int(shots[i])) + ', time ' + str(int(times[i])))
        ax[i].set_xlabel(axlabels[0])
        ax[i].set_ylabel(axlabels[1])
        ax[i].set_ylim(-2*normalization_params[sig]['std'],2*normalization_params[sig]['std'])
        ax[i].legend(loc=0)
    plt.tight_layout()
    return fig,ax

def set_future_actuators(inputs, actuator_names, change):
    for sig in actuator_names:
        inputs['input_future_' + sig][:] = inputs['input_past_' + sig][:,-1] + change
    return inputs

# Load Model & Get Data

In [None]:
#file_path='/global/cscratch1/sd/abbatej/run_results/model-conv1d_profiles-thomson_temp_EFITRT1-thomson_dens_EFITRT1_act-pinj_15L-pinj_15R-pinj_21L-pinj_21R-pinj_30L-pinj_30R-pinj_33L-pinj_33R-tinj-curr-gasA_targ-temp-dens_profLB-1_actLB-8_norm-RobustScaler_activ-relu_nshots-10000_15Aug19-12-07.h5'
#file_path='/global/cscratch1/sd/abbatej/run_results/model-conv1d_profiles-thomson_temp_EFITRT1-thomson_dens_EFITRT1_act-pinj-tinj-curr-gasA_targ-temp-dens_profLB-1_actLB-8_norm-RobustScaler_activ-relu_nshots-1000_15Aug19-07-05.h5'
file_path='/global/homes/a/abbatej/plasma-profile-predictor/model-conv1d_profiles-thomson_dens_EFITRT1-thomson_temp_EFITRT1_act-pinj-curr-tinj-gasA_targ-temp-dens_profLB-1_actLB-10_norm-StandardScaler_activ-relu_nshots-50_28Aug19-15-08.h5'
#file_path = '/global/cscratch1/sd/abbatej/run_results/model-conv2d_profiles-temp-dens_act-pinj-curr-tinj-gasA_targ-temp-dens_profLB-1_actLB-8_norm-StandardScaler_activ-relu_nshots-10000_10Aug19-12-29.h5'
#file_path = '/global/cscratch1/sd/abbatej/run_results/model-conv2d_profiles-temp-dens_act-pinj-curr-tinj-gasA_targ-temp-dens_profLB-1_actLB-8_norm-StandardScaler_activ-relu_nshots-10000_11Aug19-12-26.h5'
#Rory small: file_path = '/global/cscratch1/sd/abbatej/run_results/model-conv2d_profiles-temp-dens_act-pinj-curr-tinj-gasA_targ-temp-dens_profLB-1_actLB-8_norm-StandardScaler_activ-relu_nshots-1000_11Aug19-14-01.h5'
model = keras.models.load_model(file_path, compile=False)
print('loaded model: ' + file_path.split('/')[-1])
file_path = file_path[:-3] + 'params.pkl'
with open(file_path, 'rb') as f:
     analysis_params = pickle.load(f, encoding='latin1')
print('loaded dict: ' + file_path.split('/')[-1])
print('with parameters: ' + str(analysis_params.keys()))


In [None]:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
SVG(model_to_dot(model,show_shapes=True,show_layer_names=True,rankdir='TB').create(prog='dot', format='svg'))

# Sensitivity Analysis

In [None]:
shots = [173649]
times = [6000]
inputs, targets = train_generator.get_data_by_shot_time(shots,times)
input_const = set_future_actuators(inputs, actuator_names, 0)
pred_const = model.predict_on_batch(input_const)
input_increasing = set_future_actuators(inputs, actuator_names, 0.5)
pred_increasing = model.predict_on_batch(input_increasing)
input_decreasing = set_future_actuators(inputs, actuator_names, -0.5)
pred_decreasing = model.predict_on_batch(input_decreasing)

In [None]:
inputs['input_past_pinj']

# Full Run Processing

In [65]:
base_path = os.path.expanduser('~/run_results/')
files = os.listdir(base_path)
model_files = [file for file in files if 'h5' in file]
param_files = [file[:-3] + '_params.pkl' for file in model_files]
num_models = len(model_files)
print('number of models: ' + str(num_models))

number of models: 102


In [66]:
models = []
params = []
for modelname, paramsname in tqdm_notebook(zip(model_files,param_files), total=num_models):
    models.append(keras.models.load_model(base_path+modelname, compile=False))
    with open(base_path+paramsname, 'rb') as f:
         params.append(pickle.load(f, encoding='latin1'))

HBox(children=(IntProgress(value=0, max=102), HTML(value='')))




In [91]:
sign_acc = {'temp': [], 'dens':[]}
threshold = .1
ftop_filepath_base = '/scratch/gpfs/jabbate/data_60_ms_randomized_flattop/'
for model, scenario in tqdm_notebook(zip(models,params), total=num_models, position=0):
    with open(os.path.join(ftop_filepath_base, 'val.pkl'), 'rb') as f:
        valdata = pickle.load(f)
    val_generator = DataGenerator(valdata,
                                  1,
                                  scenario['input_profile_names'],
                                  scenario['actuator_names'],
                                  scenario['target_profile_names'],
                                  scenario['scalar_input_names'],
                                  scenario['lookbacks'],
                                  scenario['lookahead'],
                                  scenario['predict_deltas'],
                                  scenario['profile_downsample'],
                                  scenario['shuffle_generators'])
    temp_sign_acc = []
    dens_sign_acc = []
    for i in tqdm_notebook(range(int(len(val_generator)/100)), position=1, leave=False):
        inputs, targets = val_generator[i]
        pred = model.predict_on_batch(inputs)
        predictions = {}
        for i, sig in enumerate(scenario['target_profile_names']):
            predictions['target_'+sig] = pred[i]
        if not scenario['predict_deltas']:
            delta_true, delta_pred = get_deltas(inputs, targets, scenario['target_profile_names'], predictions)
        else:
            delta_true = targets
            delta_pred = predictions
        foo = hinge_loss(delta_true,delta_pred, threshold)
        temp_sign_acc.append(foo['target_temp'])
        dens_sign_acc.append(foo['target_dens'])
    sign_acc['temp'].append(np.nanmean(temp_sign_acc))
    sign_acc['dens'].append(np.nanmean(dens_sign_acc))

        
        
    
    
    

HBox(children=(IntProgress(value=0, max=102), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

In [94]:
np.max(sign_acc['dens'])

0.7975265475265475

In [95]:
np.argmax(sign_acc['dens'])

12

In [96]:
params[12]

{'model_type': 'conv2d',
 'epochs': 100,
 'model_kwargs': {'max_channels': 64},
 'actuator_names': ['pinj', 'curr', 'tinj', 'gasA'],
 'scalar_input_names': [],
 'flattop_only': True,
 'input_profile_names': ['temp', 'dens'],
 'target_profile_names': ['temp', 'dens'],
 'batch_size': 128,
 'predict_deltas': False,
 'processed_filename_base': '/scratch/gpfs/jabbate/data_60_ms_randomized_flattop/',
 'normalization_dict': {'curr': {'nanmean': 843946.2390728785,
   'method': 'RobustScaler',
   'median': 993414.8958333334,
   'iqr': 404358.2291666666},
  'dens': {'nanmean': array([5.2486587 , 5.2453074 , 5.2368107 , 5.223216  , 5.2056646 ,
          5.1847086 , 5.1608973 , 5.1348605 , 5.106694  , 5.0769434 ,
          5.0455556 , 5.012879  , 4.9789305 , 4.943857  , 4.9077783 ,
          4.8707128 , 4.8328533 , 4.794133  , 4.7547617 , 4.7147164 ,
          4.674113  , 4.632993  , 4.5914083 , 4.549442  , 4.507101  ,
          4.4644856 , 4.4216003 , 4.3785224 , 4.335278  , 4.2919135 ,
         

In [None]:

sensitivity = []
ftop_filepath_base = '/scratch/gpfs/jabbate/data_60_ms_randomized_flattop/'
for model, scenario in tqdm_notebook(zip(models,params), total=num_models, position=0):
    with open(os.path.join(ftop_filepath_base, 'val.pkl'), 'rb') as f:
        valdata = pickle.load(f)
    val_generator = DataGenerator(valdata,
                                  1,
                                  scenario['input_profile_names'],
                                  scenario['actuator_names'],
                                  scenario['target_profile_names'],
                                  scenario['scalar_input_names'],
                                  scenario['lookbacks'],
                                  scenario['lookahead'],
                                  scenario['predict_deltas'],
                                  scenario['profile_downsample'],
                                  scenario['shuffle_generators'])
    delta_actuators = [-1.5, -1, -.5, 0, .5, 1, 1.5]
    predictions_pert = {sig: {act:{d:[] for d in delta_actuators} for act in scenario['actuator_names']} for sig in scenario['target_profile_names']}
    for i in tqdm_notebook(range(int(len(val_generator)/100)), position=1, leave=False):
        inputs, targets = val_generator[i]
        input_const = set_future_actuators(inputs, scenario['actuator_names'], 0)
        pred_const = model.predict_on_batch(input_const)
        predictions_const = {}
        for i, sig in enumerate(scenario['target_profile_names']):
            predictions_const[sig] = pred_const[i]
        for act in scenario['actuator_names']:
            for d in delta_actuators:
                input_pert = set_future_actuators(input_const, [act], d)
                pred_pert = model.predict_on_batch(input_pert)
                for i, sig in enumerate(scenario['target_profile_names']):
                    predictions_pert[sig][act][d].append(np.mean(pred_pert[i] - predictions_const[sig], axis=-1))
    for sig in scenario['target_profile_names']:
        for act in scenario['actuator_names']:
            for d in delta_actuators:
                predictions_pert[sig][act][d] = {'mean': np.mean(predictions_pert[sig][act][d]),'std': np.std(predictions_pert[sig][act][d])}
    sensitivity.append(predictions_pert) 
    


HBox(children=(IntProgress(value=0, max=102), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))

In [67]:
scenarios = [int(foo.split('_')[-1].split('-')[-1].split('.')[0]) for foo in model_files]

# Plot Predictions vs True

In [None]:
sig='itemp'
all_means=[]
all_stds=[]
for shot in data.keys():
    if sig in data[shot]:
        data[shot][sig][np.isinf(data[shot][sig])] = np.nan
        this_mean=np.nanmean(data[shot][sig])
        this_std=np.nanstd(data[shot][sig])
        if False: #abs(this_mean)>4000:
            print(shot)
            print(np.mean(all_means))
            print(this_mean)
            print('\n')
            print(np.std(all_stds))
            print(this_std)
            print('\n')
        else:
            all_means.append(this_mean)
            all_stds.append(this_std)
plt.hist(all_means)

In [None]:
shots=list(data.keys())
shot=np.random.choice(shots)
plt.plot(data[shot]['curr_target'])
inds=np.where(np.isclose(data[shot]['curr_target'],max(data[shot]['curr_target'])))[0]
plt.axvline(inds[0])
plt.axvline(inds[-1])
plt.show()

In [None]:
analysis_params['batch_size']=1
train_generator = DataGenerator(traindata,**analysis_params)
steps_per_epoch = len(train_generator)

In [None]:
batch_ind+=1
inputs, targets = train_generator[batch_ind]
shotnums = train_generator.cur_shotnum[:,max(profile_lookback, actuator_lookback)]
times = train_generator.cur_times[:,max(profile_lookback, actuator_lookback)]
pred = model.predict_on_batch(inputs)
# # add deltas to future actuators
# #predict on new actuators
# predictions = {}
# for i, sig in enumerate(target_profile_names):
#     predictions[sig] = pred[i]
# psi = np.linspace(0,1,profile_length)
# full_pred = {}
# full_true = {}
# delta_pred = {}
# delta_true = {}
# baseline = {}
# for sig in target_profile_names:
#     full_pred[sig], full_true[sig], delta_pred[sig], delta_true[sig], baseline[sig] = batch_denorm(inputs,targets,predictions,
#                                                                                                    normalization_params,sig)

# Design your own shot

In [None]:
plt.rcParams['figure.figsize'] = [20, 15]

In [None]:
all_shots=np.unique(traindata['shotnum'])
shotnum=np.random.choice(all_shots) #all_shots[5]
t=4700

psi = np.linspace(0,1,profile_length)
sigs=['curr','pinj','tinj','gasA','temp','dens']
profiles=['temp','dens']
input_profiles=['thomson_temp_EFITRT1','thomson_dens_EFITRT1']
changed_input_keys=['pinj']
changed_profile_keys=[] #'thomson_dens_EFITRT1']

shot_indices=np.where(traindata['shotnum'][:,0]==shotnum)[0]

time_offset=np.searchsorted(traindata['time'][shot_indices,-lookahead],t)


#prev_targets=train_generator[shot_indices[0]+time_offset-1][1]

inputs=train_generator[shot_indices[0]+time_offset][0]

changed_input=copy.deepcopy(inputs)
for key in changed_input_keys:
    #changed_input['input_past_{}'.format(key)]=np.array([np.linspace(-2,0,actuator_lookback)])
    #changed_input['input_future_{}'.format(key)]-=[0,.5,1,1.5]
    
    prev=traindata[key][shot_indices[0]+time_offset][-lookahead]
    changed_input['input_future_{}'.format(key)]=np.array([[prev]*lookahead])
    changed_input['input_future_{}'.format(key)]=np.array([np.linspace(prev,prev-2,lookahead)]) 
    #prev=traindata[key][shot_indices[0]+time_offset][-1]#0]
    #changed_input['input_past_{}'.format(key)]=np.array([[prev]*actuator_lookback])

for key in changed_profile_keys:
    #changed_input['input_{}'.format(key)]=np.array([[[0]*profile_length]])
    changed_input['input_{}'.format(key)]-=np.array([[np.linspace(0,1,profile_length)]])
    
targets=train_generator[shot_indices[0]+time_offset][1]

num_cols=4

fig=plt.figure()
for i,sig in enumerate(sigs):
    ax=fig.add_subplot(len(sigs),num_cols,i*num_cols+1)
    if len(traindata[sig][shot_indices].shape)>2:
        ax.plot(traindata['time'][shot_indices,-lookahead],
                np.mean(traindata[sig][shot_indices,-lookahead,:],axis=1),
                c='b')
#         ax.contourf(traindata['time'][shot_indices,-lookahead],
#                 psi,
#                 traindata[sig][shot_indices,-lookahead,::analysis_params['profile_downsample']].T)
        ax.set_ylim(-2,2)
        
    else:
        ax.plot(traindata['time'][shot_indices[0],:-lookahead],
                        traindata[sig][shot_indices[0],:-lookahead],
               c='b')
        ax.plot(traindata['time'][shot_indices,-lookahead],
                        traindata[sig][shot_indices,-lookahead],
               c='b')
        if sig in changed_input_keys:
            ax.plot(traindata['time'][shot_indices[0]+time_offset,-lookahead:],
                   changed_input['input_future_{}'.format(sig)].squeeze(),
                   color='r')
            ax.plot(traindata['time'][shot_indices[0]+time_offset,:-lookahead],
                   changed_input['input_past_{}'.format(sig)].squeeze(),
                   color='r')
        if sig=='curr':
            ax.set_ylim(-6,6)
            N=1
            smoothed=np.convolve(traindata[sig][shot_indices,-lookahead], np.ones((N,))/N, mode='valid') 
            #smoothed=traindata[sig][shot_indices,-lookahead][::N]
            #flattop_start_ind=np.where(np.isclose(np.diff(smoothed),0,atol=2e-3))[0][0]
            flattop_inds=np.where(np.isclose(smoothed,np.max(smoothed),atol=1e-3))[0]
            ax.axvline(traindata['time'][flattop_inds[0],-lookahead],linewidth=4)
            ax.axvline(traindata['time'][flattop_inds[-1],-lookahead],linewidth=4)
        else:
            ax.set_ylim(-2,2)
        
    
    ax.axvline(traindata['time'][shot_indices[0]+time_offset,-lookahead],
              color='r')
    ax.axvline(traindata['time'][shot_indices[0]+time_offset,-1],
              color='g')
    ax.axvline(traindata['time'][shot_indices[0]+time_offset,0],
              color='k')
    ax.set_xlim(0,np.amax(traindata['time'][shot_indices]))
    ax.set_title(sig)

# truths=train_generator[shot_indices[0]+time_offset][1]

for i,profile in enumerate(input_profiles):
    ax=fig.add_subplot(len(input_profiles),num_cols,i*num_cols+1+1)
    
    ax.plot(psi,inputs['input_{}'.format(profile)].squeeze(),c='g')
    if profile in changed_profile_keys:
        ax.plot(psi,changed_input['input_{}'.format(profile)].squeeze(),c='r')
    ax.set_title('{}'.format(profile))
    ax.axhline(0,color='k',alpha=.5)
    ax.set_ylim(-2,2)

predictions=model.predict_on_batch(inputs)
changed_predictions=model.predict_on_batch(changed_input)
for i,profile in enumerate(profiles):
    ax=fig.add_subplot(len(profiles),num_cols,i*num_cols+1+1+1)
    ax.plot(psi,targets['target_{}'.format(profile)].squeeze(),label='True',c='g')
    ax.plot(psi,predictions[i].squeeze(),label='Prediction',
            c='b')
    ax.plot(psi,changed_predictions[i].squeeze(),color='r',label='Perturbed Prediction')
    ax.set_title('{} change'.format(profile))
    ax.axhline(0,color='k',alpha=.5)
    ax.set_ylim(-2,2)
    ax.legend()
    
    ax=fig.add_subplot(len(profiles),num_cols,i*num_cols+1+1+1+1)
    ax.plot(psi,
            traindata[profile][shot_indices[0]+time_offset][0][::analysis_params['profile_downsample']],
            label='Previous',
            c='k')
    ax.plot(psi,
            traindata[profile][shot_indices[0]+time_offset][-1][::analysis_params['profile_downsample']],
            label='Next',
            c='g')
    ax.plot(psi,
            traindata[profile][shot_indices[0]+time_offset][0][::analysis_params['profile_downsample']]+predictions[i].squeeze(),
            label='Prediction',
            c='b')
    ax.plot(psi,
            traindata[profile][shot_indices[0]+time_offset][0][::analysis_params['profile_downsample']]+changed_predictions[i].squeeze(),
            label='Perturbed Prediction',
            c='r')
    ax.legend()
    ax.set_title(profile)
    ax.set_ylim(-2,2)
    ax.axhline(0,color='k',alpha=.5)
    
time=train_generator.cur_times[:,max(profile_lookback, actuator_lookback)].squeeze()
fig.suptitle('Shot {:.0f}, {:.0f} ms'.format(shotnum,time),fontsize=30)
plt.subplots_adjust(hspace=.2)
plt.show()

In [None]:
sig = 'temp'
fig, ax = plot_batch(full_true[sig],full_pred[sig],baseline[sig],
                     psi,['true','pred','baseline'],['psi',sig], shotnums,times)  

In [None]:
sig = 'dens'
fig, ax = plot_batch(full_true[sig],full_pred[sig],baseline[sig],
                     psi,['true','pred','baseline'],['psi',sig], shotnums,times)  

In [None]:
sig = 'press'
fig, ax = plot_batch(full_true[sig],full_pred[sig],baseline[sig],
                     psi,['true','pred','baseline'],['psi',sig], shotnums,times)  