In [2]:
import numpy as np
import os
import torch
from misc.example_helper import *
import importlib
import processing
import models
%load_ext autoreload
%autoreload 2
importlib.reload(processing)
importlib.reload(models)
from processing.parametric_data_manager import ParametricSHREDDataManager
from models.shred_models import SHRED

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os

# Load the Kuramoto–Sivashinsky dataset.
# os.path.join replaces the original Windows-only backslash path (whose '\K'
# is also an invalid string escape); the context manager closes the .npz
# archive once the arrays have been read into memory.
DATA_PATH = os.path.join('kuramoto_sivashinsky', 'KuramotoSivashinsky_data.npz')
with np.load(DATA_PATH) as dataset:
    data = dataset['u']  # shape (500, 201, 100)
    mu = dataset['mu']   # shape (500, 201, 2)

# Initialize ParametricSHREDDataManager. The time axis is derived from the
# loaded data (axis 1) instead of hard-coding its length.
manager = ParametricSHREDDataManager(
    lags=20,
    train_size=0.7,
    val_size=0.15,
    test_size=0.15,
    scaling="minmax",
    compression=20,
    time=np.arange(data.shape[1]),
)

# Two mobile-sensor trajectories walking across axis 2 in opposite directions.
# NOTE(review): currently unused — the mobile_sensors argument to add() below
# is commented out.
mobile_sensors = [
    forward_backward_walk(start=0, end=data.shape[2], timesteps=data.shape[1], forward_first=True),
    forward_backward_walk(start=0, end=data.shape[2], timesteps=data.shape[1], forward_first=False),
]

# Add data to manager (with sensors).
# NOTE(review): random_sensors=2 picks locations at random and no seed is set
# in this cell, so the chosen sensors (and all downstream results) can change
# between runs — consider seeding the RNG the manager uses.
manager.add(
    data=data,
    random_sensors=2,
    stationary_sensors=[(0,), (1,)],
    # mobile_sensors=mobile_sensors,
    params=mu,
)


sensor summary                       sensor type location/trajectory
0  stationary (randomly selected)               (17,)
1  stationary (randomly selected)               (55,)
2      stationary (user selected)                (0,)
3      stationary (user selected)                (1,)
6
compressed full_state_data: (70350, 20)
done generating dataset


In [4]:
# Split the registered data into train / validation / test datasets.
train_set, valid_set, test_set = manager.preprocess()

# Report the reconstructor input (X) and target (Y) tensor shapes per split.
print('Data Shapes:')
print('Reconstructor Data')
for split_name, split in (('train', train_set), ('valid', valid_set), ('test', test_set)):
    recon = split.reconstructor
    print(f'{split_name} X:', recon.X.shape)
    print(f'{split_name} Y:', recon.Y.shape)

Data Shapes:
Reconstructor Data
train X: torch.Size([70350, 21, 6])
train Y: torch.Size([70350, 20])
valid X: torch.Size([15075, 21, 6])
valid Y: torch.Size([15075, 20])
test X: torch.Size([15075, 21, 6])
test Y: torch.Size([15075, 20])


In [18]:
# Inspect the full lag window of training sample 201 for the first input column.
# NOTE(review): 201 equals the number of timesteps per trajectory, so this is
# presumably the first window of the second training trajectory (which would
# explain the zero-padded lags) — confirm the manager does not shuffle samples.
train_set.reconstructor.X[201][:, 0]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.7233])

In [None]:
# Training settings.
SEED = 0
NUM_EPOCHS = 10

# Seed torch's global RNG so runs are repeatable.
# NOTE(review): assumes SHRED draws its randomness (weight init, batch
# shuffling) from torch's global RNG — confirm.
torch.manual_seed(SEED)

# Initialize SHRED with an LSTM sequence model and an SDN decoder.
shred = SHRED(sequence='LSTM', decoder='SDN')

# Fit SHRED; the returned dict of per-epoch validation errors is displayed.
fit_history = shred.fit(train_set, valid_set, num_epochs=NUM_EPOCHS)
fit_history

Newest Version
looking good
looking good

Fitting Reconstructor...


Epoch 1/10: 100%|██████████| 1100/1100 [00:11<00:00, 95.36batch/s, loss=0.00928, L2=0.203, val_loss=0.00417, val_L2=0.142] 
Epoch 2/10: 100%|██████████| 1100/1100 [00:10<00:00, 102.60batch/s, loss=0.00455, L2=0.147, val_loss=0.00345, val_L2=0.129]
Epoch 3/10: 100%|██████████| 1100/1100 [00:10<00:00, 101.64batch/s, loss=0.00357, L2=0.131, val_loss=0.00268, val_L2=0.114]
Epoch 4/10: 100%|██████████| 1100/1100 [00:10<00:00, 100.92batch/s, loss=0.00305, L2=0.121, val_loss=0.0024, val_L2=0.108]
Epoch 5/10: 100%|██████████| 1100/1100 [00:11<00:00, 97.78batch/s, loss=0.0027, L2=0.114, val_loss=0.00196, val_L2=0.0972]
Epoch 6/10: 100%|██████████| 1100/1100 [00:10<00:00, 104.50batch/s, loss=0.0024, L2=0.107, val_loss=0.00178, val_L2=0.0926]
Epoch 7/10: 100%|██████████| 1100/1100 [00:10<00:00, 105.76batch/s, loss=0.00216, L2=0.101, val_loss=0.00162, val_L2=0.0883]
Epoch 8/10: 100%|██████████| 1100/1100 [00:10<00:00, 106.13batch/s, loss=0.002, L2=0.0975, val_loss=0.00165, val_L2=0.0893]
Epoch 9/1

{'Reconstructor Validation Errors': array([0.14175616, 0.12896176, 0.113763  , 0.10759839, 0.09716538,
        0.09256653, 0.08825097, 0.0892787 , 0.08320087, 0.08591171],
       dtype=float32),
 'Forecaster Validation Errors': None}

In [19]:
# Reconstruct the whole test set. torch.no_grad() stops autograd from building
# a graph over every test sample just to discard it with .detach().
# NOTE(review): this calls the private `_reconstructor` attribute directly and
# does not switch it to eval mode; if the decoder uses dropout/batch-norm,
# confirm that SHRED leaves the model in eval mode after fit().
with torch.no_grad():
    prediction = shred._reconstructor(test_set.reconstructor.X).detach().cpu().numpy()
prediction.shape

(15075, 20)