In this notebook, try to make global model work by augmenting trajectories with parameters values. Don't do encoder as it is an optional compression step which would be needed only if we were trying to work with high-dim data like images!

we also try a corrected `get_batch` procedure - see `get_batch2()` from `/Users/ajivani/ode_demo_2_true_bt` which adapts example from `torchdiffeq` but makes sure that for each "IC" in our batch from a trajectory, the time vector is correctly indexed instead of just starting from zero like they did in the library's `get_batch()` routine which seems to be clearly wrong!

In [2]:
import pandas as pd

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# import sparselinear as sl

In [5]:
adjoint=True

In [6]:
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [8]:
runModel=True

In [9]:
import numpy as np
import scipy.linalg as la
import scipy.sparse as sparse
import matplotlib.pyplot as plt

import opinf

In [10]:
import re
import os
import time

In [11]:
%matplotlib inline

In [12]:
plt.rc("axes.spines", right=True, top=True)
plt.rc("figure", dpi=300, 
       figsize=(9, 3)
      )
plt.rc("font", family="serif")
plt.rc("legend", edgecolor="none", frameon=True)
plt.style.use("dark_background")

In [13]:
import edge_utils as edut

### Load Data, split into train and test

In [14]:
# we are removing some data where the edge detection is not necessarily super reliable.
sims_to_remove = np.array([33, 39, 63, 73, 113, 128, 131, 142, 193, 218, 253, 264, 273, 312, 313, 324])

In [15]:
ed_2161, sd_2161 = edut.load_edge_data_blobfree(2161)

In [16]:
theta_s_2161, theta_e_2161 = np.linspace(0, 360, 512)[160] + 1.2 * 180 - 360, np.linspace(0, 360, 512)[320] + 1.2 * 180 - 360
print("Range of angles for CR2161: {} {}".format(theta_s_2161, theta_e_2161))

Range of angles for CR2161: -31.279843444227026 81.44031311154595


In [17]:
nTimes, nTheta_2161, nSims_2161 = ed_2161.shape
nTimes, nTheta_2161, nSims_2161

(90, 160, 278)

In [18]:
theta_grid = np.linspace(np.ceil(theta_s_2161), np.ceil(theta_e_2161), nTheta_2161)

In [19]:
sd_modified = np.setdiff1d(sd_2161, sims_to_remove)

In [20]:
len(sd_modified)

262

In [21]:
# extract training and test sets. for now just random.

from numpy.random import Generator, PCG64
rng = Generator(PCG64())

nTrain = int(np.floor(0.8 * len(sd_modified)))
nTest = len(sd_modified) - nTrain

nTrain, nTest

(209, 53)

In [22]:
sd_train = np.sort(np.random.choice(sd_modified, nTrain, replace=False))
sd_test = np.setdiff1d(sd_modified, sd_train)
sd_test

array([ 34,  35,  40,  42,  43,  50,  52,  54,  59,  76,  78,  79,  89,
        91,  99, 100, 101, 130, 135, 140, 141, 152, 155, 158, 161, 172,
       184, 185, 188, 191, 200, 201, 203, 204, 212, 214, 221, 222, 230,
       239, 241, 243, 257, 268, 274, 275, 278, 303, 315, 317, 321, 322,
       329])

In [23]:
sd_test.shape

(53,)

In [24]:
def getRValuesAllSims(edge_data_matrix):
    """
    Return r values for all sims at once so we don't lose time in training processing r values repeatedly
    """
    r_data_matrix = np.zeros(edge_data_matrix.shape)
    nsims = edge_data_matrix.shape[2]
    for i in range(nsims):
        r_vals, theta_vals = edut.getRValues(edge_data_matrix, simIdx=i, minStartIdx=0)
        r_data_matrix[:, :, i] = r_vals

    return r_data_matrix

In [25]:
rd_2161 = getRValuesAllSims(ed_2161)

### Set up for Parametrized and Regular Neural ODEs

In [26]:
# get the param list for CR2161, get background values as well

In [27]:
cme_params = pd.read_csv("./restarts_CR2161.csv")
cme_params.head()

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,restartdir,realization
0,0.58375,22.526622,0.682075,204.725,1,4,4
1,0.52625,6.919483,0.519179,164.225,1,4,4
2,0.72625,5.0875,0.662022,233.975,1,4,4
3,0.36125,18.36077,0.420743,200.675,-1,4,4
4,0.69125,8.12712,0.735101,184.925,1,4,4


In [28]:
bg_params = pd.read_csv("./backgrounds_CR2161.csv")
bg_params

Unnamed: 0,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.5987,547686.9,259077.5
1,0.6876,514248.2,170240.5
2,0.7334,437169.7,242477.4
3,0.801,378833.7,147856.5
4,0.9431,491082.5,205150.8
5,1.0077,349735.7,230084.9
6,1.0404,668772.0,278163.0
7,1.0921,463650.6,299730.0
8,1.2009,337215.4,113682.7
9,1.2292,583776.3,160912.2


In [29]:
unique_bg = cme_params.restartdir.unique()
unique_bg

array([ 4, 12])

In [30]:
bg_params.FactorB0[3]

0.801

In [31]:
bg_params.shape

(30, 3)

In [32]:
# identify which rows correspond to a particular background
cme_param_bg_idx = []
for bg in unique_bg:
    cme_param_bg_idx.append(np.array((cme_params.restartdir == bg).to_list()))

# now append columns to cme_params dataframe to merge

In [33]:
cme_params

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,restartdir,realization
0,0.58375,22.526622,0.682075,204.725,1,4,4
1,0.52625,6.919483,0.519179,164.225,1,4,4
2,0.72625,5.087500,0.662022,233.975,1,4,4
3,0.36125,18.360770,0.420743,200.675,-1,4,4
4,0.69125,8.127120,0.735101,184.925,1,4,4
...,...,...,...,...,...,...,...
295,0.36250,32.665259,0.436133,197.750,-1,12,5
296,0.46250,26.757500,0.443711,224.750,1,12,5
297,0.76250,16.930533,0.874492,215.750,-1,12,5
298,0.61250,21.948929,0.553164,188.750,1,12,5


In [34]:
bg_params_to_insert = np.zeros((cme_params.shape[0], bg_params.shape[1]))
for bg_idx, bg in enumerate(unique_bg):
    bg_params_to_insert[cme_param_bg_idx[bg_idx], 0] = bg_params.FactorB0[bg - 1]
    bg_params_to_insert[cme_param_bg_idx[bg_idx], 1] = bg_params.PoyntingFluxPerBSi[bg - 1]
    bg_params_to_insert[cme_param_bg_idx[bg_idx], 2] = bg_params.LperpTimesSqrtBSi[bg - 1]

cme_params["FactorB0"] = pd.Series(bg_params_to_insert[:, 0])
cme_params["PoyntingFluxPerBSi"] = pd.Series(bg_params_to_insert[:, 1])
cme_params["LperpTimesSqrtBSi"] = pd.Series(bg_params_to_insert[:, 2])
cme_params["SimID"] = pd.Series(np.linspace(31, 330, 300))
sim_idx_successful = sd_2161 - 30 - 1
sim_idx_to_drop = np.setdiff1d(np.linspace(0, 299, 300), sim_idx_successful).astype(int).tolist()

sim_idx_to_drop
cme_params.drop(["restartdir"], axis=1, inplace=True)
cme_params.drop(sim_idx_to_drop, inplace=True)
cme_params.drop(["SimID"], axis=1, inplace=True)
cme_params

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,realization,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.58375,22.526622,0.682075,204.725,1,4,0.8010,378833.7,147856.5
1,0.52625,6.919483,0.519179,164.225,1,4,0.8010,378833.7,147856.5
2,0.72625,5.087500,0.662022,233.975,1,4,0.8010,378833.7,147856.5
3,0.36125,18.360770,0.420743,200.675,-1,4,0.8010,378833.7,147856.5
4,0.69125,8.127120,0.735101,184.925,1,4,0.8010,378833.7,147856.5
...,...,...,...,...,...,...,...,...,...
294,0.73750,15.331483,0.721367,170.750,-1,5,1.3016,304514.4,283042.3
295,0.36250,32.665259,0.436133,197.750,-1,5,1.3016,304514.4,283042.3
296,0.46250,26.757500,0.443711,224.750,1,5,1.3016,304514.4,283042.3
297,0.76250,16.930533,0.874492,215.750,-1,5,1.3016,304514.4,283042.3


In [35]:
cme_params.ApexHeight.min(), cme_params.ApexHeight.max()

(0.280873828, 0.928935547)

In [36]:
0.875 * 0.3, 1.25 * 0.8

(0.2625, 1.0)

In [37]:
cme_params.BStrength.min(), cme_params.BStrength.max()

(5.0875, 42.76293137)

In [38]:
(0.5 / 0.8) * 19.25 * 0.37, (2.0 / 0.3) * 19.25 * 0.37

(4.4515625, 47.483333333333334)

In [39]:
cme_params.OrientationCme.min(), cme_params.OrientationCme.max()

(155.225, 244.775)

In [40]:
200 - 45, 200 + 45

(155, 245)

In [41]:
type(cme_params.min())
# cme_params.max()

pandas.core.series.Series

In [42]:
cme_p_min = cme_params.min()
cme_p_min.BStrength = 4.4516
cme_p_min.Radius = 0.3
cme_p_min.OrientationCme = 155
cme_p_min.iHelicity = -1
cme_p_min.ApexHeight = 0.2625
cme_p_min.realization = 1
cme_p_min.FactorB0 = 0.54
cme_p_min.PoyntingFluxPerBSi = 0.3e6
cme_p_min.LperpTimesSqrtBSi = 0.3e5

cme_p_min

Radius                     0.3000
BStrength                  4.4516
ApexHeight                 0.2625
OrientationCme           155.0000
iHelicity                 -1.0000
realization                1.0000
FactorB0                   0.5400
PoyntingFluxPerBSi    300000.0000
LperpTimesSqrtBSi      30000.0000
dtype: float64

In [43]:
cme_p_max = cme_params.max()
cme_p_max.BStrength = 47.4833
cme_p_max.Radius = 0.8
cme_p_max.OrientationCme = 245
cme_p_max.iHelicity = 1
cme_p_max.ApexHeight = 1
cme_p_max.realization = 12
cme_p_max.FactorB0 = 2.7
cme_p_max.PoyntingFluxPerBSi = 1.1e6
cme_p_max.LperpTimesSqrtBSi = 3e5

cme_p_max

Radius                8.000000e-01
BStrength             4.748330e+01
ApexHeight            1.000000e+00
OrientationCme        2.450000e+02
iHelicity             1.000000e+00
realization           1.200000e+01
FactorB0              2.700000e+00
PoyntingFluxPerBSi    1.100000e+06
LperpTimesSqrtBSi     3.000000e+05
dtype: float64

In [44]:
# define min and max series in terms of actual ranges of parameters?
# cme_min = pd.Series([0.3, 0.54, 0.3e6, 0.3e5, ])

In [45]:
# now rescale each column of `cme_params` to lie between 0 and 1.

# cme_params_norm = (cme_params - cme_params.min()) / (cme_params.max() - cme_params.min())
cme_params_norm = (cme_params - cme_p_min) / (cme_p_max - cme_p_min)
cme_params_norm

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,realization,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.5675,0.420040,0.568916,0.5525,1.0,0.272727,0.120833,0.098542,0.436506
1,0.4525,0.057350,0.348039,0.1025,1.0,0.272727,0.120833,0.098542,0.436506
2,0.8525,0.014777,0.541725,0.8775,1.0,0.272727,0.120833,0.098542,0.436506
3,0.1225,0.323231,0.214567,0.5075,0.0,0.272727,0.120833,0.098542,0.436506
4,0.7825,0.085414,0.640815,0.3325,1.0,0.272727,0.120833,0.098542,0.436506
...,...,...,...,...,...,...,...,...,...
294,0.8750,0.252834,0.622193,0.1750,0.0,0.363636,0.352593,0.005643,0.937194
295,0.1250,0.655648,0.235434,0.4750,0.0,0.363636,0.352593,0.005643,0.937194
296,0.3250,0.518360,0.245710,0.7750,1.0,0.363636,0.352593,0.005643,0.937194
297,0.9250,0.289994,0.829820,0.6750,0.0,0.363636,0.352593,0.005643,0.937194


In [46]:
cme_params_to_augment = cme_params_norm.to_numpy()
cme_params_to_augment.shape

(278, 9)

In [47]:
# cme_params_to_augment = np.expand_dims(cme_params_to_augment, axis=1)

In [48]:
param_dim = cme_params_to_augment.shape[1]
param_dim

9

In [49]:
# now augment dataset with these scaled values.
rd_2161.shape

(90, 160, 278)

In [50]:
data_dim = rd_2161.shape[1]
data_dim

160

In [51]:
input_dim = rd_2161.shape[1]
input_dim, param_dim

(160, 9)

In [52]:
augmented_r = np.zeros((rd_2161.shape[0], data_dim + param_dim, rd_2161.shape[2]))
augmented_r[:, :(data_dim), :] = rd_2161
for iii in range(rd_2161.shape[2]):
    augmented_r[:, (data_dim):, iii] = cme_params_to_augment[iii, :]

In [53]:
augmented_r.shape

(90, 169, 278)

In [54]:
augmented_r[:, :, 0]

array([[ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       [ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       [ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       ...,
       [17.87190238, 17.85375065, 17.844671  , ...,  0.12083333,
         0.09854213,  0.43650556],
       [17.97142902, 18.01082051, 18.0539653 , ...,  0.12083333,
         0.09854213,  0.43650556],
       [18.17501243, 18.21278563, 18.25486527, ...,  0.12083333,
         0.09854213,  0.43650556]])

In [55]:
def getDataForSim(edge_data_matrix, r_data_matrix, sim_data, sid):
    """
    Take in a randomly chosen sim from the training set and return the following:
    y0_train_torch
    y_train_torch
    i.e. IC and data in torch tensor format on Device
    t_train_torch
    and correct sim_index from sim_data
    """
    
    sim_index = np.argwhere(sim_data == sid)[0][0]
    
    r_sim = r_data_matrix[:, :, sim_index]
    
    tMinIdx, tMin, tMaxIdx, tMax = edut.getTMinTMax(edge_data_matrix, simIdx=sim_index)
    
    r_sim_valid = r_sim[tMinIdx:(tMaxIdx+1), :]
    valid_times = np.arange(tMin, tMax + 2, step=2)
    
    tTrainEnd = tMin + np.floor((2/3)*(tMax - tMin))
    
    
    trainEndIdx = np.argmin(np.abs(valid_times - tTrainEnd))
    #     trainEndIdx = np.argwhere(valid_times == tTrainEnd)[0][0]
    
    tTrain = valid_times[:(trainEndIdx + 1)]
    
    tTest = valid_times[(trainEndIdx + 1):]
    
    tTrainScaled = (tTrain - tMin) / (tMax - tMin)
    tTestScaled = (tTest - tMin) / (tMax - tMin)
    
    tAllScaled = (valid_times - tMin) / (tMax - tMin)
    
    y0_train_orig = r_sim_valid[0, :]
    y0_train_torch = torch.from_numpy(np.float32(y0_train_orig))
    y0_train_torch = y0_train_torch.reshape((1, len(y0_train_torch))).to(device)
    
    
    y_train_orig = r_sim_valid[:(trainEndIdx + 1), :]
    y_train_torch = torch.from_numpy(np.expand_dims(np.float32(y_train_orig), axis=1)).to(device)
    
    y_full_torch = torch.from_numpy(np.expand_dims(np.float32(r_sim_valid), axis=1)).to(device)
    
    t_train_torch = torch.tensor(np.float32(tTrainScaled)).to(device)
    t_scaled_torch = torch.tensor(np.float32(tAllScaled)).to(device)
    
    return y0_train_torch, y_train_torch, y_full_torch, t_train_torch, t_scaled_torch, sim_index

In [56]:
def get_batch(torch_train_data, torch_train_time, batch_time=5, batch_size=10):
    s = torch.from_numpy(np.random.choice(np.arange(len(torch_train_time) - batch_time, dtype=np.int64),
                                          batch_size,
                                          replace=False))
    batch_y0 = torch_train_data[s]  # (M, D)
    batch_t = torch.zeros((batch_size, batch_time))
    for i in range(batch_size):
        batch_t[i, :] = torch_train_time[s[i]:(s[i] + batch_time)]
        
    batch_y = torch.stack([torch_train_data[s + i] for i in range(batch_time)], dim=0)  # (T, M, D)
    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

In [57]:
y0t, ytt, yft, ttt, tst, si = getDataForSim(ed_2161, augmented_r, sd_2161, 31)

In [58]:
# ytt.shape

In [59]:
by0, bt, by = get_batch(ytt, ttt)

In [60]:
by.shape

torch.Size([5, 10, 1, 169])

In [61]:
by[0, 0, :, :]

tensor([[ 8.0900,  8.0899,  8.0927,  8.0983,  8.1067,  8.1176,  8.1311,  8.1471,
          8.1653,  8.1859,  8.2085,  8.2333,  8.2600,  8.2885,  8.3189,  8.3508,
          8.3844,  8.4195,  8.4559,  8.4936,  8.5325,  8.5725,  8.6136,  8.6555,
          8.6982,  8.7416,  8.7857,  8.8302,  8.8750,  8.9199,  8.9651,  9.0101,
          9.0550,  9.0998,  9.1443,  9.1884,  9.2321,  9.2753,  9.3179,  9.3600,
          9.4015,  9.4424,  9.4826,  9.5221,  9.5609,  9.5990,  9.6363,  9.6729,
          9.7086,  9.7435,  9.7776,  9.8109,  9.8433,  9.8747,  9.9053,  9.9350,
          9.9639,  9.9920, 10.0192, 10.0454, 10.0708, 10.0955, 10.1195, 10.1425,
         10.1647, 10.1861, 10.2067, 10.2267, 10.2461, 10.2645, 10.2822, 10.2991,
         10.3153, 10.3306, 10.3452, 10.3591, 10.3721, 10.3843, 10.3955, 10.4059,
         10.4153, 10.4239, 10.4317, 10.4400, 10.4473, 10.4537, 10.4591, 10.4635,
         10.4668, 10.4692, 10.4704, 10.4706, 10.4696, 10.4675, 10.4642, 10.4597,
         10.4539, 10.4467, 1

In [62]:
# by.shape

In [63]:
class ODEFunc(nn.Module):

    def __init__(self):
        super(ODEFunc, self).__init__()

        self.net1 = nn.Sequential(
            nn.Linear(169, 100),
            nn.Tanh(),
            nn.Linear(100, 100),
            nn.Tanh(),
            nn.Linear(100, 169),
        )
        
        for m in self.net1.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)
                
    def forward(self, t, y):
        return self.net1(y)

In [64]:
class RunningAverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.losses = []
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val
        self.losses.append(self.avg)

In [65]:
ii = 0
func = ODEFunc().to(device)
optimizer = optim.Adam(func.parameters(), lr=1e-3)
end = time.time()

niters=6000
test_freq=5

In [66]:
batch_size = 10
batch_time = 5

In [67]:
time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

if runModel:
    # step_loss = []
    for itr in range(1, niters + 1):
        optimizer.zero_grad()
        chosen_sim = np.random.choice(sd_train)
        y0_train_torch, y_train_torch, y_full_torch, t_train_torch, t_scaled_torch, sim_index = getDataForSim(ed_2161,
                                                                                                        augmented_r,
                                                                                                        sd_2161,
                                                                                                        chosen_sim)
        batch_y0, batch_t, batch_y = get_batch(y_train_torch, t_train_torch)
        pred_y = torch.zeros_like(batch_y)

        for i in range(batch_size):
            pred_y[:, i, :, :] = odeint(func, batch_y0[i, :, :], batch_t[i, :]).to(device)

        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        #     step_loss.append(loss.item())

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % test_freq == 0:
            with torch.no_grad():
                ## more radical defn would be loss on ENTIRE series?
                #             pred_y = odeint(func, y0_train_torch, t_train_torch)
                #             loss = torch.mean(torch.abs(pred_y - y_train_torch))
                pred_y_full_series = odeint(func, y0_train_torch, t_scaled_torch)
                #             print(pred_y_full_series.shape)
                loss_full_series = torch.mean(torch.abs(pred_y_full_series[:, :, :input_dim] 
                                                        - y_full_torch[:, :, :input_dim]))
                print("Iter {:04d} | Total Loss {:.6f} | Sim ID {:03d} ".format(itr,
                                                                                loss_full_series.item(), 
                                                                                chosen_sim))
                ii += 1

        end = time.time()

Iter 0005 | Total Loss 4.977581 | Sim ID 085 
Iter 0010 | Total Loss 7.705052 | Sim ID 148 
Iter 0015 | Total Loss 4.386827 | Sim ID 176 
Iter 0020 | Total Loss 2.280607 | Sim ID 311 
Iter 0025 | Total Loss 6.737478 | Sim ID 127 
Iter 0030 | Total Loss 4.065915 | Sim ID 323 
Iter 0035 | Total Loss 5.184512 | Sim ID 181 
Iter 0040 | Total Loss 2.064832 | Sim ID 112 
Iter 0045 | Total Loss 4.532670 | Sim ID 287 
Iter 0050 | Total Loss 1.621136 | Sim ID 082 
Iter 0055 | Total Loss 4.594798 | Sim ID 145 
Iter 0060 | Total Loss 0.949450 | Sim ID 276 
Iter 0065 | Total Loss 4.223021 | Sim ID 105 
Iter 0070 | Total Loss 0.955569 | Sim ID 234 
Iter 0075 | Total Loss 1.586343 | Sim ID 044 
Iter 0080 | Total Loss 1.792297 | Sim ID 252 
Iter 0085 | Total Loss 0.530773 | Sim ID 074 
Iter 0090 | Total Loss 0.630445 | Sim ID 271 
Iter 0095 | Total Loss 1.540943 | Sim ID 060 
Iter 0100 | Total Loss 0.360589 | Sim ID 319 
Iter 0105 | Total Loss 3.587350 | Sim ID 156 
Iter 0110 | Total Loss 1.923126 | 

KeyboardInterrupt: 

In [None]:
# torch.save(func, "dydt_Ay_global_model_5000_steps.pkl")

In [None]:
# torch.save(optimizer, "dydt_Ay_global_model_optimizer.pkl")

In [None]:
func_saved = torch.load("dydt_Ay_global_model_5000_steps.pkl")

In [None]:
optim_saved = torch.load("dydt_Ay_global_model_optimizer.pkl")

In [None]:
func_saved.eval()

In [None]:
optim_saved.state_dict

### Testing

In [None]:
sd_test

In [None]:
y01, yt1, yf1, tt1, ts1, si1 = getDataForSim(ed_2161, augmented_r, sd_2161, 326)

In [None]:
si1

In [None]:
yf1.shape

In [None]:
with torch.no_grad():
    yt_pred_326 = odeint(func, y01, ts1)

In [None]:
yt_pred_326.shape

In [None]:
yf1_np = yf1.cpu().numpy()[:, 0, :input_dim]
yt_pred_np = yt_pred_326.cpu().numpy()[:, 0, :input_dim]

In [None]:
edut.plotTrainPredData(yf1_np, 
                       yt_pred_np, 
                       ed_2161, 
                       sd_2161, 
                       theta=np.linspace(-31, 82, 160),
                       simIdx=si1, 
                       savefig=False)

In [None]:
edut.plotTrainPredData1Model(yf1_np,
                             yt_pred_np,
                             ed_2161,
                             sd_2161,
                             theta=np.linspace(-31, 82, 160), 
                             simIdx=si1,
                             savefig=False)

In [None]:
y02, yt2, yf2, tt2, ts2, si2 = getDataForSim(ed_2161, augmented_r, sd_2161, 70)

In [None]:
si2

In [None]:
yt2.shape

In [None]:
ts2.shape

In [None]:
with torch.no_grad():
    yt_pred_70 = odeint(func, y02, ts2)

In [None]:
yf2_np = yf2.cpu().numpy()[:, 0, :input_dim]
yt_pred2_np = yt_pred_70.cpu().numpy()[:, 0, :input_dim]

In [None]:
edut.plotTrainPredData(yf2_np, 
                       yt_pred2_np, 
                       ed_2161, 
                       sd_2161, 
                       theta=np.linspace(-31, 82, 160),
                       simIdx=si2, 
                       savefig=False)

In [None]:
edut.plotTrainPredData1Model(yf2_np,
                             yt_pred2_np,
                             ed_2161,
                             sd_2161,
                             theta=np.linspace(-31, 82, 160), 
                             simIdx=si2,
                             savefig=False)

In [None]:
save_test_preds=False

In [None]:
if save_test_preds:
    for ss in sd_test:
        y0_test, yt_test, yf_test, tt_test, ts_test, si_test = getDataForSim(ed_2161, augmented_r, sd_2161, ss)

        with torch.no_grad():
            yt_pred_test = odeint(func, y0_test, ts_test)

        y_test_np = yf_test.cpu().numpy()[:, 0, :input_dim]
        y_pred_np = yt_pred_test.cpu().numpy()[:, 0, :input_dim]

        edut.plotTrainPredData1Model(y_test_np,
                                 y_pred_np,
                                 ed_2161,
                                 sd_2161,
                                 theta=np.linspace(-31, 82, 160), 
                                 simIdx=si_test,
                                 savefig=True,
                                 savedir="./test_data_notebook_05_global_model/")