In [112]:
from datetime import datetime
import os

import heliopy.data.omni as omni
from matplotlib import pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow import keras

In [131]:
# Time span of OMNI data to request; both bounds are handed to get_omni_rtn_data.
START_TIME = datetime(year=1995, month=1, day=1)
END_TIME = datetime(year=2018, month=2, day=28)

# Number of consecutive samples in one LSTM input window
# (24 hours, given that the data set is the 1-hour COHO product).
INPUT_LENGTH = 24

### Load in data

In [132]:
def get_omni_rtn_data(start_time, end_time):
    """Fetch the OMNI COHO 1-hour merged magnetic-field/plasma data set.

    Parameters
    ----------
    start_time, end_time : datetime
        Bounds of the interval to request; forwarded unchanged to heliopy.

    Returns
    -------
    Whatever ``omni._omni`` returns for the request. The caller in this
    notebook converts it with ``.to_dataframe()``.
    """
    # NOTE(review): `_omni` is a private heliopy helper (leading underscore);
    # check whether the installed heliopy exposes a public wrapper for this
    # identifier before relying on it long term.
    return omni._omni(
        start_time,
        end_time,
        identifier='OMNI_COHO1HR_MERGED_MAG_PLASMA',  # COHO 1HR data
        intervals='yearly',  # request the data in yearly intervals
        warn_missing_units=False,
    )

In [133]:
# Download the full START_TIME..END_TIME range and convert the result to a DataFrame.
data = (
    get_omni_rtn_data(START_TIME, END_TIME)
    .to_dataframe()
)

Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 1995-01-01 00:00:00 - 1996-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 1996-01-01 00:00:00 - 1997-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 1997-01-01 00:00:00 - 1998-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 1998-01-01 00:00:00 - 1999-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 1999-01-01 00:00:00 - 2000-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2000-01-01 00:00:00 - 2001-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2001-01-01 00:00:00 - 2002-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2002-01-01 00:00:00 - 2003-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2003-01-01 00:00:00 - 2004-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2004-01-01 00:00:00 - 2005-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2005-01-01 00:00:00 - 2006-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2006-01-01 00:00:00 - 2007-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2007-01-01 00:00:00 - 2008-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2008-01-01 00:00:00 - 2009-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2009-01-01 00:00:00 - 2010-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2010-01-01 00:00:00 - 2011-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2011-01-01 00:00:00 - 2012-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2012-01-01 00:00:00 - 2013-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2013-01-01 00:00:00 - 2014-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2014-01-01 00:00:00 - 2015-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2015-01-01 00:00:00 - 2016-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Downloading OMNI_COHO1HR_MERGED_MAG_PLASMA for interval 2016-01-01 00:00:00 - 2017-01-01 00:00:00


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [139]:
# Copy the "BR" column out of the DataFrame into a plain NumPy array.
# NOTE(review): "BR" is presumably the radial field component (RTN) rather than
# the field magnitude - confirm; the variable name is kept because later cells use it.
mag_field_strength = np.array(data["BR"], copy=True)

### Split into 24-hour sliding windows

In [154]:
# Build supervised (window -> next value) pairs with a stride of one sample:
# sample i uses values [i, i + INPUT_LENGTH) as input and value i + INPUT_LENGTH
# as its target, so neighbouring samples overlap in all but one value.
n_samples = len(mag_field_strength) - INPUT_LENGTH
if n_samples <= 0:
    raise ValueError(
        "Need more than %d values to build a sample, got %d" % (INPUT_LENGTH, len(mag_field_strength))
    )

lstm_inputs = np.array(
    [mag_field_strength[i:i + INPUT_LENGTH] for i in range(n_samples)]
)[:, :, np.newaxis]  # trailing axis = one feature per time step
lstm_outputs = np.array(mag_field_strength[INPUT_LENGTH:])

# Keep only samples whose window AND target are NaN-free. The mask is derived
# from the arrays it filters, so it always has exactly one entry per sample
# (the previous separate 25-wide check was one row short and silently dropped
# the final sample).
is_complete = ~(np.isnan(lstm_inputs).any(axis=(1, 2)) | np.isnan(lstm_outputs))
lstm_inputs = lstm_inputs[is_complete]
lstm_outputs = lstm_outputs[is_complete]

print("Input shape:", lstm_inputs.shape)
print("Output shape:", lstm_outputs.shape)

print("Any Nans?:", np.any(np.isnan(lstm_outputs)) or np.any(np.isnan(lstm_inputs)))

Input shape: (201388, 24, 1)
Output shape: (201388,)
Any Nans?: False


### Split into train/val/test

In [155]:
# Two-stage split: hold out 33% of the samples, then split that hold-out again
# into validation (first 67%) and test (last 33%).
# shuffle=False keeps the samples in their original order, so train/val/test are
# consecutive blocks. The samples are overlapping sliding windows; shuffling them
# would place near-identical windows in different splits and leak information
# from training into validation/test.
lstm_inputs_train, lstm_inputs_val_test, lstm_outputs_train, lstm_outputs_val_test = train_test_split(
    lstm_inputs,
    lstm_outputs,
    test_size=0.33,
    shuffle=False,
)

lstm_inputs_val, lstm_inputs_test, lstm_outputs_val, lstm_outputs_test = train_test_split(
    lstm_inputs_val_test,
    lstm_outputs_val_test,
    test_size=0.33,
    shuffle=False,
)

print("Train size:", len(lstm_inputs_train))
print("Val size:", len(lstm_inputs_val))
print("Test size:", len(lstm_inputs_test))

Train size: 134929
Val size: 44527
Test size: 21932


In [164]:
# Time-based split: train / val / test are consecutive blocks of the samples
# (earliest block for training, latest for testing) instead of a random shuffle.
# Block sizes are derived from the data rather than hard-coded: the held-out
# count is ceil(fraction * n), applied twice, which reproduces the
# 134929 / 44527 / 21932 sizes for the current 201388 samples and stays
# correct if the date range or window length changes.
holdout_fraction = 0.33
n_total = len(lstm_inputs)
n_val_test = int(np.ceil(holdout_fraction * n_total))
n_test = int(np.ceil(holdout_fraction * n_val_test))
train_end = n_total - n_val_test
val_end = n_total - n_test

lstm_inputs_train, lstm_outputs_train = lstm_inputs[:train_end], lstm_outputs[:train_end]
lstm_inputs_val, lstm_outputs_val = lstm_inputs[train_end:val_end], lstm_outputs[train_end:val_end]
lstm_inputs_test, lstm_outputs_test = lstm_inputs[val_end:], lstm_outputs[val_end:]

print("Train size:", len(lstm_inputs_train))
print("Val size:", len(lstm_inputs_val))
print("Test size:", len(lstm_inputs_test))

Train size: 134929
Val size: 44527
Test size: 21932


### LSTM!

In [165]:
# Sequence regressor: a 20-unit LSTM followed by a single linear output unit.
# input_shape=(None, 1) leaves the number of time steps open, with one feature per step.
# NOTE(review): activation="linear" replaces the LSTM's default tanh - confirm this is intended.
model = keras.models.Sequential()
model.add(
    keras.layers.LSTM(20, activation="linear", name="lstm_initial", input_shape=(None, 1))
)
model.add(
    keras.layers.Dense(1, name="dense_final", activation="linear")
)
model.summary()

Model: "sequential_22"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_initial (LSTM)          (None, 20)                1760      
_________________________________________________________________
dense_final (Dense)          (None, 1)                 21        
Total params: 1,781
Trainable params: 1,781
Non-trainable params: 0
_________________________________________________________________


In [166]:
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

# Stop when the monitored validation loss has not improved for `patience` epochs
# and roll the model back to its best weights. patience must be smaller than
# `epochs`: with patience == epochs (the previous 30/30) the callback can never
# trigger, and in tf.keras versions that restore weights only on an early stop
# the best weights were never restored either.
early_stopping = keras.callbacks.EarlyStopping(restore_best_weights=True, patience=5)

# Keep the History object (per-epoch loss / MAE) for later inspection instead of
# leaving its repr as the cell output.
history = model.fit(
    lstm_inputs_train,
    lstm_outputs_train,
    validation_data=(lstm_inputs_val, lstm_outputs_val),
    batch_size=2048,
    epochs=30,
    callbacks=[early_stopping],  # Keras documents `callbacks` as a list
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<tensorflow.python.keras.callbacks.History at 0x11ac259b0>