In [1]:
from pathlib import Path
import time
import torch
import numpy as np
import pandas as pd
from dataset_numba import WHDataset, CSTRDataset_numba

from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
from transformer_sim import Config,TSTransformer
import metrics


import nonlinear_benchmarks
from nonlinear_benchmarks.error_metrics import RMSE

In [2]:
# Reproducibility: pin every random number generator used by this notebook.
np.random.seed(55)
torch.manual_seed(452)
# system_seed controls the system generation, data_seed controls the input generation.
system_seed, data_seed = 66, 0

In [3]:
# ---- Overall settings ----
out_dir_name = "models"  # directory the checkpoints are read from

# ---- System settings ----
nu, ny = 1, 1  # presumably the input / output dimensions -- TODO confirm
# seq_len = 600
batch_size = 320

fixed_system = False  # Are we testing on a fixed system?

# ---- Compute settings ----
no_cuda = False
threads = 5
compile = False  # NOTE(review): this name shadows the Python builtin `compile`

# ---- Configure compute ----
torch.set_num_threads(threads)
# NOTE(review): `use_cuda` is computed here but not consulted below;
# the device is pinned to the CPU regardless of its value.
use_cuda = torch.cuda.is_available() and not no_cuda
device_name = "cpu"
device_type = "cpu"  # for later use in torch.autocast
device = torch.device(device_name)
torch.set_float32_matmul_precision("high")

In [4]:
# Directory holding the saved checkpoints (it is only read from here, not created).
out_dir = Path(out_dir_name)


def _load_checkpoint(file_name):
    """Load one checkpoint from `out_dir`, mapping its tensors onto `device`."""
    # NOTE(review): torch.load unpickles arbitrary Python objects (these files carry a
    # `cfg` object next to the weights) -- only load checkpoints from a trusted source.
    return torch.load(out_dir / file_name, map_location=device)


exp_data = _load_checkpoint("resume_16k_prbs_40_rep25_last.pt")  # fine-tune on WH systems
exp_data2 = _load_checkpoint("resume_16k_CSTR_40_rep25_last.pt")  # fine-tune on CSTR systems
exp_data3 = _load_checkpoint("scratch_16k_CSTR_200_rep25_4100.pt")  # scratch on CSTR systems
exp_data4 = _load_checkpoint("ckpt_16000_400skip_RNNpatch.pt")  # zero-shot

# The configuration of the first checkpoint drives the rest of the notebook.
cfg = exp_data["cfg"]
# For compatibility with initial experiment without seed
if not hasattr(cfg, "seed"):
    cfg.seed = None

In [5]:
# Display the `iter_num` entry saved in the first checkpoint.
exp_data["iter_num"]

991000

In [6]:
# Display `seq_len_ctx` from the checkpoint's config; it is used below to slice the context window.
cfg.seq_len_ctx

16000

In [7]:
# Rebuild the first network (checkpoint fine-tuned on WH systems) from its saved
# constructor arguments and restore its weights.
model_args = exp_data["model_args"]
conf = Config(**model_args)
model = TSTransformer(conf).to(device)
model.load_state_dict(exp_data["model"])
# This notebook only runs inference: switch to evaluation mode so that any
# dropout / normalisation layers behave deterministically. torch.no_grad() alone
# does not change layer behaviour.
model.eval();

40


In [9]:
# Rebuild the second network (checkpoint fine-tuned on CSTR systems) from its saved
# constructor arguments and restore its weights.
model_args2 = exp_data2["model_args"]
conf2 = Config(**model_args2)
model2 = TSTransformer(conf2).to(device)
model2.load_state_dict(exp_data2["model"])
# This notebook only runs inference: switch to evaluation mode so that any
# dropout / normalisation layers behave deterministically. torch.no_grad() alone
# does not change layer behaviour.
model2.eval();

40


In [10]:
# Rebuild the third network (checkpoint trained from scratch on CSTR systems) from its
# saved constructor arguments and restore its weights.
model_args3 = exp_data3["model_args"]
conf3 = Config(**model_args3)
model3 = TSTransformer(conf3).to(device)
model3.load_state_dict(exp_data3["model"])
# This notebook only runs inference: switch to evaluation mode so that any
# dropout / normalisation layers behave deterministically. torch.no_grad() alone
# does not change layer behaviour.
model3.eval();

40


In [10]:
# Rebuild the fourth network (zero-shot checkpoint) from its saved constructor
# arguments and restore its weights.
model_args4 = exp_data4["model_args"]
conf4 = Config(**model_args4)
model4 = TSTransformer(conf4).to(device)
model4.load_state_dict(exp_data4["model"])
# This notebook only runs inference: switch to evaluation mode so that any
# dropout / normalisation layers behave deterministically. torch.no_grad() alone
# does not change layer behaviour.
model4.eval();

40


In [8]:
# Create data loader
# Options for the WH dataset variant; the CSTR dataset used below does not consume them,
# but they are kept so the WH test set can be swapped back in.
lin_opts = dict(mag_range=cfg.mag_range, phase_range=cfg.phase_range, strictly_proper=True)

# Checkpoints from the initial experiment carry no seed (cfg.seed is None). In that case
# `cfg.seed + 1` would raise a TypeError, so forward None for every seed instead.
# NOTE(review): assumes CSTRDataset_numba accepts None seeds (unseeded RNG) -- confirm.
if cfg.seed is None:
    shift_seed = input_seed = noise_seed = None
else:
    shift_seed, input_seed, noise_seed = cfg.seed, cfg.seed + 1, cfg.seed + 2

# Total sequence length: context window + skipped samples + the `seq_len_n_in`
# samples + the new samples to be predicted.
seq_len_total = cfg.seq_len_ctx + cfg.seq_len_skip + cfg.seq_len_n_in + cfg.seq_len_new

test_ds = CSTRDataset_numba(seq_len=seq_len_total,
                            shift_seed=shift_seed, input_seed=input_seed, noise_seed=noise_seed)
# WH alternative:
# test_ds = WHDataset(nx=cfg.nx, nu=cfg.nu, ny=cfg.ny, seq_len=seq_len_total,
#                     system_seed=shift_seed, input_seed=input_seed, noise_seed=noise_seed,
#                     **lin_opts)
test_dl = DataLoader(test_ds, batch_size=batch_size, num_workers=1)

In [9]:
# Draw one test batch, split it into context and query segments, and run the first
# model (fine-tuned on WH systems) on it.
batch_y, batch_u = next(iter(test_dl))
batch_y = batch_y[:, :, [0]]  # keep only the first output channel
batch_y = batch_y.to(device)
batch_u = batch_u.to(device)
noise_std = 0.0
with torch.no_grad():
    # Context window shown to the model.
    batch_y_ctx = batch_y[:, :cfg.seq_len_ctx, :]
    batch_u_ctx = batch_u[:, :cfg.seq_len_ctx, :]
    # Query window: everything after the context and the skipped samples.
    batch_y_new = batch_y[:, cfg.seq_len_ctx+cfg.seq_len_skip:, :]
    batch_u_new = batch_u[:, cfg.seq_len_ctx+cfg.seq_len_skip:, :]

    # Size the result buffers from the batch actually received, not from the configured
    # batch_size, so a smaller batch does not break the slice assignment below.
    n_batch, n_out = batch_y_ctx.shape[0], batch_y_ctx.shape[-1]
    batch_y_mean = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    batch_y_std = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    for i in range(n_out):
        print(i)
        print(batch_y_ctx[:,:,i:i+1].shape)
        # Fill the std buffer per channel exactly like the mean buffer; previously the
        # preallocated std tensor was discarded and rebound to the raw model output.
        # NOTE(review): assumes the returned std has the same shape as the mean -- confirm.
        batch_y_mean[:,:,i:i+1], batch_y_std[:,:,i:i+1], _, _ = model(
            batch_y_ctx[:,:,i:i+1], batch_u_ctx, batch_u_new, batch_y_new[:,:,i:i+1], cfg.seq_len_n_in)
print(cfg.seq_len_n_in)

0
torch.Size([320, 16000, 1])
30


In [13]:
# Run the second model (fine-tuned on CSTR systems) on the batch that was already drawn
# and split earlier in the notebook (batch_y_ctx / batch_u_ctx / batch_u_new / batch_y_new).
# Pulling a fresh `next(iter(test_dl))` here would rebind batch_y_new, and the RMSE cells
# further down would then score earlier predictions against targets from a different
# draw unless the loader happens to replay the identical batch.
with torch.no_grad():
    # Size the result buffers from the batch actually in use.
    n_batch, n_out = batch_y_ctx.shape[0], batch_y_ctx.shape[-1]
    batch_y_mean2 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    batch_y_std2 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    for i in range(n_out):
        print(i)
        print(batch_y_ctx[:,:,i:i+1].shape)
        # Fill the std buffer per channel exactly like the mean buffer.
        # NOTE(review): assumes the returned std has the same shape as the mean -- confirm.
        batch_y_mean2[:,:,i:i+1], batch_y_std2[:,:,i:i+1], _, _ = model2(
            batch_y_ctx[:,:,i:i+1], batch_u_ctx, batch_u_new, batch_y_new[:,:,i:i+1], cfg.seq_len_n_in)
print(cfg.seq_len_n_in)

0
torch.Size([320, 16000, 1])
30


In [12]:
# Run the third model (trained from scratch on CSTR systems) on the batch that was already
# drawn and split earlier in the notebook (batch_y_ctx / batch_u_ctx / batch_u_new / batch_y_new).
# Pulling a fresh `next(iter(test_dl))` here would rebind batch_y_new, and the RMSE cells
# further down would then score earlier predictions against targets from a different
# draw unless the loader happens to replay the identical batch.
with torch.no_grad():
    # Size the result buffers from the batch actually in use.
    n_batch, n_out = batch_y_ctx.shape[0], batch_y_ctx.shape[-1]
    batch_y_mean3 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    batch_y_std3 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    for i in range(n_out):
        print(i)
        print(batch_y_ctx[:,:,i:i+1].shape)
        # Fill the std buffer per channel exactly like the mean buffer.
        # NOTE(review): assumes the returned std has the same shape as the mean -- confirm.
        batch_y_mean3[:,:,i:i+1], batch_y_std3[:,:,i:i+1], _, _ = model3(
            batch_y_ctx[:,:,i:i+1], batch_u_ctx, batch_u_new, batch_y_new[:,:,i:i+1], cfg.seq_len_n_in)
print(cfg.seq_len_n_in)

0
torch.Size([320, 16000, 1])
30


In [15]:
# Run the fourth model (zero-shot checkpoint) on the batch that was already drawn and
# split earlier in the notebook (batch_y_ctx / batch_u_ctx / batch_u_new / batch_y_new).
# Pulling a fresh `next(iter(test_dl))` here would rebind batch_y_new, and the RMSE cells
# further down would then score earlier predictions against targets from a different
# draw unless the loader happens to replay the identical batch.
with torch.no_grad():
    # Size the result buffers from the batch actually in use.
    n_batch, n_out = batch_y_ctx.shape[0], batch_y_ctx.shape[-1]
    batch_y_mean4 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    batch_y_std4 = torch.zeros([n_batch, cfg.seq_len_new, n_out])
    for i in range(n_out):
        print(i)
        print(batch_y_ctx[:,:,i:i+1].shape)
        # Fill the std buffer per channel exactly like the mean buffer.
        # NOTE(review): assumes the returned std has the same shape as the mean -- confirm.
        batch_y_mean4[:,:,i:i+1], batch_y_std4[:,:,i:i+1], _, _ = model4(
            batch_y_ctx[:,:,i:i+1], batch_u_ctx, batch_u_new, batch_y_new[:,:,i:i+1], cfg.seq_len_n_in)
print(cfg.seq_len_n_in)

0
torch.Size([320, 16000, 1])
30


In [10]:
def _as_numpy(x):
    """Return `x` as a NumPy array.

    Arrays pass through unchanged so this cell can be re-run safely; anything else is
    treated as a torch tensor and is detached and moved to the CPU before conversion.
    """
    if isinstance(x, np.ndarray):
        return x
    return x.detach().to("cpu").numpy()


# Convert predictions, targets and inputs to NumPy for the metric cells below.
# The names are rebound on purpose because those cells use them as arrays.
batch_y_mean = _as_numpy(batch_y_mean)
batch_y_std = _as_numpy(batch_y_std)
batch_y_mean2 = _as_numpy(batch_y_mean2)
batch_y_std2 = _as_numpy(batch_y_std2)
batch_y_mean3 = _as_numpy(batch_y_mean3)
batch_y_std3 = _as_numpy(batch_y_std3)
batch_y_mean4 = _as_numpy(batch_y_mean4)
batch_y_std4 = _as_numpy(batch_y_std4)
batch_y_new = _as_numpy(batch_y_new)
batch_u_new = _as_numpy(batch_u_new)


In [11]:
skip = 0
# Targets: query window with its first cfg.seq_len_n_in samples dropped, channel 0.
y_target = batch_y_new[:, cfg.seq_len_n_in:, 0]
# Predictions of the first model, channel 0.
y_pred = batch_y_mean[:, :, 0]
rmse = metrics.rmse_test(y_target, y_pred, time_axis=1)
print(f"rmse over the 16k CSTR, few shot on WH {rmse.mean()}")

rmse over the 16k CSTR, few shot on WH 0.24284391105175018


In [20]:
skip = 0
# Targets: query window with its first cfg.seq_len_n_in samples dropped, channel 0.
y_target = batch_y_new[:, cfg.seq_len_n_in:, 0]
# Predictions of the second model, channel 0.
y_pred = batch_y_mean2[:, :, 0]
rmse = metrics.rmse_test(y_target, y_pred, time_axis=1)
print(f"rmse over the 16k CSTR, few shot on CSTR {rmse.mean()}")

rmse over the 16k CSTR, few shot on CSTR 0.11476530134677887


In [14]:
skip = 0
# Targets: query window with its first cfg.seq_len_n_in samples dropped, channel 0.
y_target = batch_y_new[:, cfg.seq_len_n_in:, 0]
# Predictions of the third model, channel 0.
y_pred = batch_y_mean3[:, :, 0]
rmse = metrics.rmse_test(y_target, y_pred, time_axis=1)
print(f"rmse over the 16k CSTR, from scratch on CSTR {rmse.mean()}")

rmse over the 16k CSTR, from scratch on CSTR 0.13456487655639648


In [22]:
skip = 0
# RMSE of the fourth (zero-shot) model: targets are the query window with its first
# cfg.seq_len_n_in samples dropped, channel 0; the mean over the batch is reported.
rmse = metrics.rmse_test(batch_y_new[:, cfg.seq_len_n_in:, 0], batch_y_mean4[:,:,0], time_axis=1)
# Label fixed: it previously read "zeros-hot".
print(f"rmse over the 16k CSTR, zero-shot {rmse.mean()}")

rmse over the 16k CSTR, zeros-hot 1.2493032217025757
