In this notebook, we try to make a global model work by augmenting the trajectories with the parameter values. We skip the encoder: it is an optional compression step that would only be needed if we were working with high-dimensional data such as images.

We also try a corrected `get_batch` procedure (see `get_batch2()` in `/Users/ajivani/ode_demo_2_true_bt`). It adapts the example from `torchdiffeq`, but makes sure that for each initial condition ("IC") in our batch, the time vector is indexed from that IC's actual position in the trajectory. The library example's `get_batch()` routine instead always starts the time vector from zero, which seems clearly wrong.

We also do a "testing" run. We take only 3 or 4 sims in turn and train on multiple batches of each one, one sim at a time. This shows how the network adapts, and whether it learns a general parametrized version of them in the first place. We also track the error across all 3 simulations in the log.

In [1]:
import pandas as pd

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# import sparselinear as sl

In [4]:
adjoint=True

In [5]:
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

In [6]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [7]:
runModel=False
runTestCase=True

In [8]:
import numpy as np
import scipy.linalg as la
import scipy.sparse as sparse
import matplotlib.pyplot as plt

import opinf

In [9]:
import re
import os
import time

In [10]:
%matplotlib inline

In [11]:
plt.rc("axes.spines", right=True, top=True)
plt.rc("figure", dpi=300, 
       figsize=(9, 3)
      )
plt.rc("font", family="serif")
plt.rc("legend", edgecolor="none", frameon=True)
plt.style.use("dark_background")

In [12]:
import edge_utils as edut

### Load Data, split into train and test

In [13]:
# we are removing some data where the edge detection is not necessarily super reliable.
sims_to_remove = np.array([33, 39, 63, 73, 113, 128, 131, 142, 193, 218, 253, 264, 273, 312, 313, 324])

In [14]:
ed_2161, sd_2161 = edut.load_edge_data_blobfree(2161)

In [15]:
theta_s_2161, theta_e_2161 = np.linspace(0, 360, 512)[160] + 1.2 * 180 - 360, np.linspace(0, 360, 512)[320] + 1.2 * 180 - 360
print("Range of angles for CR2161: {} {}".format(theta_s_2161, theta_e_2161))

Range of angles for CR2161: -31.279843444227026 81.44031311154595


In [16]:
nTimes, nTheta_2161, nSims_2161 = ed_2161.shape
nTimes, nTheta_2161, nSims_2161

(90, 160, 278)

In [17]:
theta_grid = np.linspace(np.ceil(theta_s_2161), np.ceil(theta_e_2161), nTheta_2161)

In [18]:
sd_modified = np.setdiff1d(sd_2161, sims_to_remove)

In [19]:
len(sd_modified)

262

In [20]:
# extract training and test sets. for now just random.

from numpy.random import Generator, PCG64
rng = Generator(PCG64())

nTrain = int(np.floor(0.8 * len(sd_modified)))
nTest = len(sd_modified) - nTrain

nTrain, nTest

(209, 53)

In [21]:
# Draw the split from a seeded generator so it is reproducible across kernel restarts;
# the unseeded global `np.random` state picked a different split on every run.
SPLIT_SEED = 2161
split_rng = Generator(PCG64(SPLIT_SEED))
sd_train = np.sort(split_rng.choice(sd_modified, nTrain, replace=False))
sd_test = np.setdiff1d(sd_modified, sd_train)
sd_test

array([ 40,  41,  47,  49,  57,  60,  65,  70,  82,  83,  85,  86,  87,
        91,  92,  93,  96, 110, 129, 138, 143, 153, 154, 190, 198, 203,
       206, 208, 227, 229, 231, 233, 237, 243, 247, 249, 257, 260, 261,
       266, 267, 269, 275, 277, 298, 299, 300, 301, 305, 316, 325, 327,
       328])

In [22]:
sd_test.shape

(53,)

In [23]:
def getRValuesAllSims(edge_data_matrix):
    """
    Precompute r values for every simulation in one pass.

    `edut.getRValues` (with minStartIdx=0) is called once per simulation, i.e. once
    per index along the last axis of `edge_data_matrix`. The results are stacked into
    an array of the same shape, so training does not recompute them repeatedly.
    """
    r_all = np.zeros(edge_data_matrix.shape)
    n_sims = edge_data_matrix.shape[2]
    for sim_idx in range(n_sims):
        # getRValues returns (r values, theta values); only the r values are kept.
        r_for_sim, _theta_unused = edut.getRValues(edge_data_matrix, simIdx=sim_idx, minStartIdx=0)
        r_all[:, :, sim_idx] = r_for_sim
    return r_all

In [24]:
rd_2161 = getRValuesAllSims(ed_2161)

### Set up for Parametrized and Regular Neural ODEs

In [25]:
# get the param list for CR2161, get background values as well

In [26]:
cme_params = pd.read_csv("./restarts_CR2161.csv")
cme_params.head()

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,restartdir,realization
0,0.58375,22.526622,0.682075,204.725,1,4,4
1,0.52625,6.919483,0.519179,164.225,1,4,4
2,0.72625,5.0875,0.662022,233.975,1,4,4
3,0.36125,18.36077,0.420743,200.675,-1,4,4
4,0.69125,8.12712,0.735101,184.925,1,4,4


In [27]:
bg_params = pd.read_csv("./backgrounds_CR2161.csv")
bg_params

Unnamed: 0,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.5987,547686.9,259077.5
1,0.6876,514248.2,170240.5
2,0.7334,437169.7,242477.4
3,0.801,378833.7,147856.5
4,0.9431,491082.5,205150.8
5,1.0077,349735.7,230084.9
6,1.0404,668772.0,278163.0
7,1.0921,463650.6,299730.0
8,1.2009,337215.4,113682.7
9,1.2292,583776.3,160912.2


In [28]:
unique_bg = cme_params.restartdir.unique()
unique_bg

array([ 4, 12])

In [29]:
bg_params.FactorB0[3]

0.801

In [30]:
bg_params.shape

(30, 3)

In [31]:
# identify which rows correspond to a particular background
cme_param_bg_idx = []
for bg in unique_bg:
    cme_param_bg_idx.append(np.array((cme_params.restartdir == bg).to_list()))

# now append columns to cme_params dataframe to merge

In [32]:
cme_params

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,restartdir,realization
0,0.58375,22.526622,0.682075,204.725,1,4,4
1,0.52625,6.919483,0.519179,164.225,1,4,4
2,0.72625,5.087500,0.662022,233.975,1,4,4
3,0.36125,18.360770,0.420743,200.675,-1,4,4
4,0.69125,8.127120,0.735101,184.925,1,4,4
...,...,...,...,...,...,...,...
295,0.36250,32.665259,0.436133,197.750,-1,12,5
296,0.46250,26.757500,0.443711,224.750,1,12,5
297,0.76250,16.930533,0.874492,215.750,-1,12,5
298,0.61250,21.948929,0.553164,188.750,1,12,5


In [33]:
# Attach the background parameters to every CME row. `restartdir` is a 1-based
# background number, so row `restartdir - 1` of `bg_params` holds its values.
bg_cols = ["FactorB0", "PoyntingFluxPerBSi", "LperpTimesSqrtBSi"]
bg_row_idx = cme_params.restartdir.to_numpy() - 1
bg_params_to_insert = bg_params[bg_cols].to_numpy()[bg_row_idx]

# SimIDs run 31..330, so SimID s sits at row s - 31. Drop every row whose sim did not
# complete successfully (i.e. is not in sd_2161).
sim_idx_successful = sd_2161 - 30 - 1
sim_idx_to_drop = np.setdiff1d(np.arange(300), sim_idx_successful).astype(int).tolist()

# Build the result in one chain instead of mutating the frame in place step by step.
cme_params = (
    cme_params
    .assign(**{col: bg_params_to_insert[:, j] for j, col in enumerate(bg_cols)})
    .drop(columns=["restartdir"])
    .drop(index=sim_idx_to_drop)
)
cme_params

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,realization,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.58375,22.526622,0.682075,204.725,1,4,0.8010,378833.7,147856.5
1,0.52625,6.919483,0.519179,164.225,1,4,0.8010,378833.7,147856.5
2,0.72625,5.087500,0.662022,233.975,1,4,0.8010,378833.7,147856.5
3,0.36125,18.360770,0.420743,200.675,-1,4,0.8010,378833.7,147856.5
4,0.69125,8.127120,0.735101,184.925,1,4,0.8010,378833.7,147856.5
...,...,...,...,...,...,...,...,...,...
294,0.73750,15.331483,0.721367,170.750,-1,5,1.3016,304514.4,283042.3
295,0.36250,32.665259,0.436133,197.750,-1,5,1.3016,304514.4,283042.3
296,0.46250,26.757500,0.443711,224.750,1,5,1.3016,304514.4,283042.3
297,0.76250,16.930533,0.874492,215.750,-1,5,1.3016,304514.4,283042.3


In [34]:
cme_params.ApexHeight.min(), cme_params.ApexHeight.max()

(0.280873828, 0.928935547)

In [35]:
0.875 * 0.3, 1.25 * 0.8

(0.2625, 1.0)

In [36]:
cme_params.BStrength.min(), cme_params.BStrength.max()

(5.0875, 42.76293137)

In [37]:
(0.5 / 0.8) * 19.25 * 0.37, (2.0 / 0.3) * 19.25 * 0.37

(4.4515625, 47.483333333333334)

In [38]:
cme_params.OrientationCme.min(), cme_params.OrientationCme.max()

(155.225, 244.775)

In [39]:
200 - 45, 200 + 45

(155, 245)

In [40]:
type(cme_params.min())
# cme_params.max()

pandas.core.series.Series

In [41]:
cme_p_min = cme_params.min()
cme_p_min.BStrength = 4.4516
cme_p_min.Radius = 0.3
cme_p_min.OrientationCme = 155
cme_p_min.iHelicity = -1
cme_p_min.ApexHeight = 0.2625
cme_p_min.realization = 1
cme_p_min.FactorB0 = 0.54
cme_p_min.PoyntingFluxPerBSi = 0.3e6
cme_p_min.LperpTimesSqrtBSi = 0.3e5

cme_p_min

Radius                     0.3000
BStrength                  4.4516
ApexHeight                 0.2625
OrientationCme           155.0000
iHelicity                 -1.0000
realization                1.0000
FactorB0                   0.5400
PoyntingFluxPerBSi    300000.0000
LperpTimesSqrtBSi      30000.0000
dtype: float64

In [42]:
cme_p_max = cme_params.max()
cme_p_max.BStrength = 47.4833
cme_p_max.Radius = 0.8
cme_p_max.OrientationCme = 245
cme_p_max.iHelicity = 1
cme_p_max.ApexHeight = 1
cme_p_max.realization = 12
cme_p_max.FactorB0 = 2.7
cme_p_max.PoyntingFluxPerBSi = 1.1e6
cme_p_max.LperpTimesSqrtBSi = 3e5

cme_p_max

Radius                8.000000e-01
BStrength             4.748330e+01
ApexHeight            1.000000e+00
OrientationCme        2.450000e+02
iHelicity             1.000000e+00
realization           1.200000e+01
FactorB0              2.700000e+00
PoyntingFluxPerBSi    1.100000e+06
LperpTimesSqrtBSi     3.000000e+05
dtype: float64

In [43]:
# define min and max series in terms of actual ranges of parameters?
# cme_min = pd.Series([0.3, 0.54, 0.3e6, 0.3e5, ])

In [44]:
# now rescale each column of `cme_params` to lie between 0 and 1.

# cme_params_norm = (cme_params - cme_params.min()) / (cme_params.max() - cme_params.min())
cme_params_norm = (cme_params - cme_p_min) / (cme_p_max - cme_p_min)
cme_params_norm

Unnamed: 0,Radius,BStrength,ApexHeight,OrientationCme,iHelicity,realization,FactorB0,PoyntingFluxPerBSi,LperpTimesSqrtBSi
0,0.5675,0.420040,0.568916,0.5525,1.0,0.272727,0.120833,0.098542,0.436506
1,0.4525,0.057350,0.348039,0.1025,1.0,0.272727,0.120833,0.098542,0.436506
2,0.8525,0.014777,0.541725,0.8775,1.0,0.272727,0.120833,0.098542,0.436506
3,0.1225,0.323231,0.214567,0.5075,0.0,0.272727,0.120833,0.098542,0.436506
4,0.7825,0.085414,0.640815,0.3325,1.0,0.272727,0.120833,0.098542,0.436506
...,...,...,...,...,...,...,...,...,...
294,0.8750,0.252834,0.622193,0.1750,0.0,0.363636,0.352593,0.005643,0.937194
295,0.1250,0.655648,0.235434,0.4750,0.0,0.363636,0.352593,0.005643,0.937194
296,0.3250,0.518360,0.245710,0.7750,1.0,0.363636,0.352593,0.005643,0.937194
297,0.9250,0.289994,0.829820,0.6750,0.0,0.363636,0.352593,0.005643,0.937194


In [45]:
cme_params_to_augment = cme_params_norm.to_numpy()
cme_params_to_augment.shape

(278, 9)

In [46]:
# cme_params_to_augment = np.expand_dims(cme_params_to_augment, axis=1)

In [47]:
param_dim = cme_params_to_augment.shape[1]
param_dim

9

In [48]:
# now augment dataset with these scaled values.
rd_2161.shape

(90, 160, 278)

In [49]:
data_dim = rd_2161.shape[1]
data_dim

160

In [50]:
input_dim = rd_2161.shape[1]
input_dim, param_dim

(160, 9)

In [51]:
# Augment each sim's r trajectory with its normalized parameter vector. Components
# [0, data_dim) hold the r values and components [data_dim, data_dim + param_dim) hold
# the parameters, repeated unchanged at every time step.
augmented_r = np.zeros((rd_2161.shape[0], data_dim + param_dim, rd_2161.shape[2]))
augmented_r[:, :data_dim, :] = rd_2161
# Row k of cme_params_to_augment belongs to sim k. Transposing to (param_dim, n_sims)
# lets NumPy broadcast the same vector across the time axis.
augmented_r[:, data_dim:, :] = cme_params_to_augment[:rd_2161.shape[2], :].T

In [52]:
augmented_r.shape

(90, 169, 278)

In [53]:
augmented_r[:, :, 0]

array([[ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       [ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       [ 3.99369242,  3.99004235,  3.98640029, ...,  0.12083333,
         0.09854213,  0.43650556],
       ...,
       [17.87190238, 17.85375065, 17.844671  , ...,  0.12083333,
         0.09854213,  0.43650556],
       [17.97142902, 18.01082051, 18.0539653 , ...,  0.12083333,
         0.09854213,  0.43650556],
       [18.17501243, 18.21278563, 18.25486527, ...,  0.12083333,
         0.09854213,  0.43650556]])

In [54]:
def getDataForSim(edge_data_matrix, r_data_matrix, sim_data, sid, train_frac=2/3):
    """
    Extract one simulation's valid trajectory as torch tensors, split into train/full.

    Parameters
    ----------
    edge_data_matrix : np.ndarray
        Edge data with simulations along the last axis. Used only to find the valid
        time window via `edut.getTMinTMax`.
    r_data_matrix : np.ndarray
        r values (possibly parameter-augmented) with simulations along the last
        axis; rows are time steps.
    sim_data : np.ndarray
        Sim IDs in the same order as the last axis of the data matrices.
    sid : int
        Sim ID to extract. It is looked up in `sim_data`; it is not a positional index.
    train_frac : float, optional
        Fraction of the valid time span used for training. The default 2/3 is the
        value that was previously hard-coded.

    Returns
    -------
    y0_train_torch : torch.Tensor, shape (1, D)
        State at the first valid time step.
    y_train_torch : torch.Tensor, shape (n_train, 1, D)
        Training part of the trajectory.
    y_full_torch : torch.Tensor, shape (n_valid, 1, D)
        Whole valid trajectory.
    t_train_torch : torch.Tensor, shape (n_train,)
        Training times, rescaled so that tMin -> 0 and tMax -> 1.
    t_scaled_torch : torch.Tensor
        All valid times, rescaled the same way.
    sim_index : int
        Position of `sid` in `sim_data`.

    All tensors are float32 and are moved to the notebook-level `device`.

    Raises
    ------
    ValueError
        If `sid` does not appear in `sim_data`.
    """
    hits = np.argwhere(sim_data == sid)
    if len(hits) == 0:
        raise ValueError("Sim ID {} not found in sim_data".format(sid))
    sim_index = hits[0][0]

    r_sim = r_data_matrix[:, :, sim_index]

    tMinIdx, tMin, tMaxIdx, tMax = edut.getTMinTMax(edge_data_matrix, simIdx=sim_index)

    r_sim_valid = r_sim[tMinIdx:(tMaxIdx + 1), :]
    # NOTE(review): assumes a 2-unit time cadence, so that len(valid_times) equals the
    # number of rows of r_sim_valid. TODO confirm against edut.getTMinTMax.
    valid_times = np.arange(tMin, tMax + 2, step=2)

    # Nominal end of the training window, snapped to the nearest available time.
    tTrainEnd = tMin + np.floor(train_frac * (tMax - tMin))
    trainEndIdx = np.argmin(np.abs(valid_times - tTrainEnd))

    tTrain = valid_times[:(trainEndIdx + 1)]

    tSpan = tMax - tMin
    tTrainScaled = (tTrain - tMin) / tSpan
    tAllScaled = (valid_times - tMin) / tSpan

    y0_train_torch = torch.from_numpy(np.float32(r_sim_valid[0, :]))
    y0_train_torch = y0_train_torch.reshape((1, len(y0_train_torch))).to(device)

    # Insert a singleton batch axis: (time, 1, D).
    y_train_orig = r_sim_valid[:(trainEndIdx + 1), :]
    y_train_torch = torch.from_numpy(np.expand_dims(np.float32(y_train_orig), axis=1)).to(device)
    y_full_torch = torch.from_numpy(np.expand_dims(np.float32(r_sim_valid), axis=1)).to(device)

    t_train_torch = torch.tensor(np.float32(tTrainScaled)).to(device)
    t_scaled_torch = torch.tensor(np.float32(tAllScaled)).to(device)

    return y0_train_torch, y_train_torch, y_full_torch, t_train_torch, t_scaled_torch, sim_index

In [55]:
def get_batch(torch_train_data, torch_train_time, batch_time=5, batch_size=10, target_device=None):
    """
    Sample `batch_size` random windows of `batch_time` consecutive steps from one trajectory.

    Unlike the torchdiffeq example, each window keeps its own slice of the time
    vector instead of restarting at t = 0.

    Parameters
    ----------
    torch_train_data : torch.Tensor
        Trajectory indexed by time along dim 0 (e.g. (T, 1, D) from `getDataForSim`).
    torch_train_time : torch.Tensor
        1-D time vector whose length matches dim 0 of `torch_train_data`.
    batch_time : int
        Number of consecutive time steps per window.
    batch_size : int
        Number of windows. Start indices are drawn without replacement.
    target_device : torch.device or str, optional
        Device for the outputs. Defaults to the notebook-level `device`.

    Returns
    -------
    batch_y0 : torch.Tensor, shape (batch_size, ...)
        First state of each window.
    batch_t : torch.Tensor, shape (batch_size, batch_time)
        Time vector of each window, in the default float dtype.
    batch_y : torch.Tensor, shape (batch_time, batch_size, ...)
        States along each window.

    Raises
    ------
    ValueError
        If the trajectory has fewer than `batch_size` distinct windows.
    """
    if target_device is None:
        target_device = device

    n_times = len(torch_train_time)
    # A window starting at s covers s .. s + batch_time - 1, so the last valid start is
    # n_times - batch_time (inclusive). The +1 keeps the final sample reachable.
    n_starts = n_times - batch_time + 1
    if n_starts < batch_size:
        raise ValueError(
            "need {} windows of length {}, but a series of length {} only has {}".format(
                batch_size, batch_time, n_times, max(n_starts, 0)))

    s = torch.from_numpy(np.random.choice(np.arange(n_starts, dtype=np.int64),
                                          batch_size,
                                          replace=False))
    # (batch_size, batch_time) matrix of time indices, one row per window.
    window_idx = s.unsqueeze(1) + torch.arange(batch_time, dtype=torch.int64).unsqueeze(0)

    batch_y0 = torch_train_data[s]  # (M, ...)
    batch_t = torch_train_time[window_idx].to(torch.get_default_dtype())  # (M, T)
    batch_y = torch_train_data[window_idx.t()]  # (T, M, ...)
    return batch_y0.to(target_device), batch_t.to(target_device), batch_y.to(target_device)

In [56]:
y0t, ytt, yft, ttt, tst, si = getDataForSim(ed_2161, augmented_r, sd_2161, 31)

In [57]:
# ytt.shape

In [58]:
by0, bt, by = get_batch(ytt, ttt)

In [59]:
by.shape

torch.Size([5, 10, 1, 169])

In [60]:
by[0, 0, :, :]

tensor([[ 9.4231,  9.4450,  9.4696,  9.4967,  9.5263,  9.5583,  9.5925,  9.6288,
          9.6673,  9.7078,  9.7501,  9.7943,  9.8402,  9.8877,  9.9367,  9.9872,
         10.0390, 10.0921, 10.1463, 10.2016, 10.2579, 10.3150, 10.3730, 10.4316,
         10.4909, 10.5507, 10.6103, 10.6702, 10.7304, 10.7907, 10.8509, 10.9110,
         10.9710, 11.0308, 11.0902, 11.1493, 11.2079, 11.2660, 11.3237, 11.3807,
         11.4371, 11.4928, 11.5479, 11.6022, 11.6557, 11.7083, 11.7601, 11.8110,
         11.8609, 11.9099, 11.9580, 12.0050, 12.0510, 12.0960, 12.1401, 12.1830,
         12.2248, 12.2656, 12.3053, 12.3437, 12.3810, 12.4170, 12.4516, 12.4849,
         12.5169, 12.5475, 12.5767, 12.6045, 12.6309, 12.6561, 12.6797, 12.7020,
         12.7229, 12.7425, 12.7607, 12.7777, 12.7935, 12.8081, 12.8212, 12.8330,
         12.8435, 12.8527, 12.8607, 12.8676, 12.8733, 12.8780, 12.8818, 12.8846,
         12.8865, 12.8877, 12.8880, 12.8875, 12.8862, 12.8840, 12.8809, 12.8770,
         12.8721, 12.8664, 1

In [61]:
# by.shape

In [62]:
class ODEFunc(nn.Module):
    """Autonomous neural vector field dy/dt = f(y).

    `f` is a 3-layer tanh MLP mapping a `state_dim` vector to a `state_dim` vector.
    The default `state_dim=169` matches this notebook's augmented state: 160 r values
    plus 9 normalized parameters. The time argument of `forward` is ignored.

    Parameters
    ----------
    state_dim : int
        Size of the state vector (input and output width).
    hidden_dim : int
        Width of the two hidden layers.
    init_std : float
        Std of the normal distribution used to initialize every Linear weight.
        Biases are initialized to zero.
    """

    def __init__(self, state_dim=169, hidden_dim=100, init_std=0.1):
        super().__init__()

        self.net1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, state_dim),
        )

        for m in self.net1.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=init_std)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        """Return dy/dt at state `y`; `t` is unused (autonomous system)."""
        return self.net1(y)

In [63]:
class RunningAverageMeter:
    """Track the latest value and an exponential moving average of a scalar stream.

    Every `update` appends the new moving average to `losses`. `losses` is therefore a
    per-update history of the smoothed value, not of the raw values.
    """

    def __init__(self, momentum=0.99):
        # Weight kept on the previous average at each update.
        self.momentum = momentum
        # History of smoothed values; reset() leaves it untouched.
        self.losses = []
        self.reset()

    def reset(self):
        """Forget the current value and average (the `losses` history is kept)."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Fold `val` into the moving average and record the new average."""
        is_first = self.val is None
        self.avg = val if is_first else self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val
        self.losses.append(self.avg)

In [64]:
ii = 0
func = ODEFunc().to(device)
optimizer = optim.Adam(func.parameters(), lr=1e-3)
end = time.time()

niters=6000
test_freq=5

In [65]:
batch_size = 10
batch_time = 5

In [66]:
time_meter = RunningAverageMeter(0.97)
loss_meter = RunningAverageMeter(0.97)

if runModel:
    # Start the clock here. `end` from the setup cell would otherwise make the first
    # timing include however long the notebook sat idle in between.
    end = time.time()
    for itr in range(1, niters + 1):
        optimizer.zero_grad()
        # Global model: every iteration trains on a randomly chosen training sim.
        chosen_sim = np.random.choice(sd_train)
        (y0_train_torch, y_train_torch, y_full_torch,
         t_train_torch, t_scaled_torch, sim_index) = getDataForSim(ed_2161, augmented_r, sd_2161, chosen_sim)
        # Pass the sizes explicitly so get_batch and the loop below always agree.
        batch_y0, batch_t, batch_y = get_batch(y_train_torch, t_train_torch,
                                               batch_time=batch_time, batch_size=batch_size)
        pred_y = torch.zeros_like(batch_y)

        # Each window has its own time vector, so windows are integrated one at a time.
        for i in range(batch_size):
            pred_y[:, i, :, :] = odeint(func, batch_y0[i, :, :], batch_t[i, :])

        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        if itr % test_freq == 0:
            with torch.no_grad():
                # L1 error over the ENTIRE valid series of the chosen sim, restricted to
                # the first input_dim (r) components, i.e. excluding the appended parameters.
                pred_y_full_series = odeint(func, y0_train_torch, t_scaled_torch)
                loss_full_series = torch.mean(torch.abs(pred_y_full_series[:, :, :input_dim]
                                                        - y_full_torch[:, :, :input_dim]))
                print("Iter {:04d} | Total Loss {:.6f} | Sim ID {:03d} ".format(itr,
                                                                                loss_full_series.item(),
                                                                                chosen_sim))
                ii += 1

        end = time.time()

In [67]:
# torch.save(func, "dydt_Ay_global_model_5000_steps.pkl")

In [68]:
# torch.save(optimizer, "dydt_Ay_global_model_optimizer.pkl")

In [69]:
# func_saved = torch.load("dydt_Ay_global_model_5000_steps.pkl")

In [70]:
# optim_saved = torch.load("dydt_Ay_global_model_optimizer.pkl")

In [71]:
# func_saved.eval()

In [72]:
# optim_saved.state_dict

### Testing for full model

In [73]:
sd_test

array([ 40,  41,  47,  49,  57,  60,  65,  70,  82,  83,  85,  86,  87,
        91,  92,  93,  96, 110, 129, 138, 143, 153, 154, 190, 198, 203,
       206, 208, 227, 229, 231, 233, 237, 243, 247, 249, 257, 260, 261,
       266, 267, 269, 275, 277, 298, 299, 300, 301, 305, 316, 325, 327,
       328])

In [74]:
# y01, yt1, yf1, tt1, ts1, si1 = getDataForSim(ed_2161, augmented_r, sd_2161, 326)

In [75]:
# si1

In [76]:
# yf1.shape

In [77]:
# with torch.no_grad():
#     yt_pred_326 = odeint(func, y01, ts1)

In [78]:
# yt_pred_326.shape

In [79]:
# yf1_np = yf1.cpu().numpy()[:, 0, :input_dim]
# yt_pred_np = yt_pred_326.cpu().numpy()[:, 0, :input_dim]

In [80]:
# edut.plotTrainPredData(yf1_np, 
#                        yt_pred_np, 
#                        ed_2161, 
#                        sd_2161, 
#                        theta=np.linspace(-31, 82, 160),
#                        simIdx=si1, 
#                        savefig=False)

In [81]:
# edut.plotTrainPredData1Model(yf1_np,
#                              yt_pred_np,
#                              ed_2161,
#                              sd_2161,
#                              theta=np.linspace(-31, 82, 160), 
#                              simIdx=si1,
#                              savefig=False)

In [82]:
# y02, yt2, yf2, tt2, ts2, si2 = getDataForSim(ed_2161, augmented_r, sd_2161, 70)

In [83]:
# si2

In [84]:
# yt2.shape

In [85]:
# ts2.shape

In [86]:
# with torch.no_grad():
#     yt_pred_70 = odeint(func, y02, ts2)

In [87]:
# yf2_np = yf2.cpu().numpy()[:, 0, :input_dim]
# yt_pred2_np = yt_pred_70.cpu().numpy()[:, 0, :input_dim]

In [88]:
# edut.plotTrainPredData(yf2_np, 
#                        yt_pred2_np, 
#                        ed_2161, 
#                        sd_2161, 
#                        theta=np.linspace(-31, 82, 160),
#                        simIdx=si2, 
#                        savefig=False)

In [89]:
# edut.plotTrainPredData1Model(yf2_np,
#                              yt_pred2_np,
#                              ed_2161,
#                              sd_2161,
#                              theta=np.linspace(-31, 82, 160), 
#                              simIdx=si2,
#                              savefig=False)

In [90]:
save_test_preds=False

In [91]:
if save_test_preds:
    for ss in sd_test:
        y0_test, yt_test, yf_test, tt_test, ts_test, si_test = getDataForSim(ed_2161, augmented_r, sd_2161, ss)

        with torch.no_grad():
            yt_pred_test = odeint(func, y0_test, ts_test)

        y_test_np = yf_test.cpu().numpy()[:, 0, :input_dim]
        y_pred_np = yt_pred_test.cpu().numpy()[:, 0, :input_dim]

        edut.plotTrainPredData1Model(y_test_np,
                                 y_pred_np,
                                 ed_2161,
                                 sd_2161,
                                 theta=np.linspace(-31, 82, 160), 
                                 simIdx=si_test,
                                 savefig=True,
                                 savedir="./test_data_notebook_05_global_model/")

### Test Case with limited Sims

Candidate sims for this test: 31, 32 and 82.

In [92]:
test_case_sims = np.array([31, 32, 82])

In [93]:
y0_train_torch_31, y_train_torch_31, y_full_torch_31, t_train_torch_31, t_scaled_torch_31, sim_index_31 = getDataForSim(ed_2161,augmented_r,sd_2161,31)

y0_train_torch_32, y_train_torch_32, y_full_torch_32, t_train_torch_32, t_scaled_torch_32, sim_index_32 = getDataForSim(ed_2161,augmented_r,sd_2161,32)

y0_train_torch_82, y_train_torch_82, y_full_torch_82, t_train_torch_82, t_scaled_torch_82, sim_index_82 = getDataForSim(ed_2161,augmented_r,sd_2161,82)

In [100]:
y_full_torch_82.shape

torch.Size([68, 1, 169])

In [94]:
y_full_torch_82.cpu().numpy().shape

(68, 1, 169)

In [95]:
np.expand_dims(np.linspace(-31, 82, 160), axis=1).T.shape

(1, 160)

In [96]:
time_meter_test = RunningAverageMeter(0.97)
loss_meter_test = RunningAverageMeter(0.97)

In [97]:
sim_iters = 400 # we will train on each individual sim for 400 iterations, and see how the adaptation is working
                # when we move to the next sim. Save test plots for all sims when the 400 iters are complete!

In [98]:
testCaseSaveDir = "./test_case_adaptation_3sims"

In [99]:
if runTestCase:
    # Load each monitored sim's full series once. They are re-evaluated after every
    # optimizer step to see how training on one sim affects the others.
    eval_sims = []
    for sim_id in test_case_sims:
        y0_eval, _y_train_unused, y_full_eval, _t_train_unused, t_scaled_eval, sim_index_eval = getDataForSim(
            ed_2161, augmented_r, sd_2161, sim_id)
        eval_sims.append((sim_id, y0_eval, y_full_eval, t_scaled_eval, sim_index_eval))
    loss_header = " ".join("sim {}".format(sim_id) for sim_id in test_case_sims)

    # Start the clock here; `end` from an earlier cell would skew the first timing.
    end = time.time()
    for each_sim in test_case_sims:
        y0_train_torch, y_train_torch, y_full_torch, t_train_torch, t_scaled_torch, sim_index = getDataForSim(
            ed_2161, augmented_r, sd_2161, each_sim)
        for itr in range(1, sim_iters + 1):
            optimizer.zero_grad()
            # Pass the sizes explicitly so get_batch and the loop below always agree.
            batch_y0, batch_t, batch_y = get_batch(y_train_torch, t_train_torch,
                                                   batch_time=batch_time, batch_size=batch_size)
            pred_y = torch.zeros_like(batch_y)

            for i in range(batch_size):
                pred_y[:, i, :, :] = odeint(func, batch_y0[i, :, :], batch_t[i, :])

            loss = torch.mean(torch.abs(pred_y - batch_y))
            loss.backward()
            optimizer.step()

            # Record into this run's own meters, not the ones from the global training run.
            time_meter_test.update(time.time() - end)
            loss_meter_test.update(loss.item())

            with torch.no_grad():
                eval_preds = []
                eval_losses = []
                for sim_id, y0_eval, y_full_eval, t_scaled_eval, sim_index_eval in eval_sims:
                    pred_full = odeint(func, y0_eval, t_scaled_eval)
                    eval_preds.append(pred_full)
                    eval_losses.append(torch.mean(torch.abs(pred_full[:, :, :input_dim]
                                                            - y_full_eval[:, :, :input_dim])).item())

                print("Iter {:04d} | Total Loss {} {} ".format(itr,
                                                               loss_header,
                                                               " ".join("{:.6f}".format(l) for l in eval_losses)))

                if itr == sim_iters:
                    # Convert both truth and prediction to numpy, as the other plotting cells do.
                    for (sim_id, y0_eval, y_full_eval, t_scaled_eval, sim_index_eval), pred_full in zip(eval_sims, eval_preds):
                        edut.plotTrainPredData1Model(y_full_eval.cpu().numpy()[:, 0, :input_dim],
                                                     pred_full.cpu().numpy()[:, 0, :input_dim],
                                                     ed_2161,
                                                     sd_2161,
                                                     theta=theta_grid,
                                                     simIdx=sim_index_eval,
                                                     savefig=True,
                                                     savedir=testCaseSaveDir)

            end = time.time()

Iter 0001 | Total Loss sim 31 sim 32 sim 82 7.991735 2.368146 4.079865 
Iter 0002 | Total Loss sim 31 sim 32 sim 82 7.943471 2.320563 4.030716 
Iter 0003 | Total Loss sim 31 sim 32 sim 82 7.897148 2.275946 3.983633 
Iter 0004 | Total Loss sim 31 sim 32 sim 82 7.851968 2.233128 3.938531 
Iter 0005 | Total Loss sim 31 sim 32 sim 82 7.808712 2.191646 3.895600 
Iter 0006 | Total Loss sim 31 sim 32 sim 82 7.766994 2.151397 3.853577 
Iter 0007 | Total Loss sim 31 sim 32 sim 82 7.725617 2.111818 3.812057 
Iter 0008 | Total Loss sim 31 sim 32 sim 82 7.683292 2.072041 3.769743 
Iter 0009 | Total Loss sim 31 sim 32 sim 82 7.639496 2.031129 3.726472 
Iter 0010 | Total Loss sim 31 sim 32 sim 82 7.595599 1.990217 3.682751 
Iter 0011 | Total Loss sim 31 sim 32 sim 82 7.551115 1.948665 3.638293 
Iter 0012 | Total Loss sim 31 sim 32 sim 82 7.505651 1.906226 3.592832 
Iter 0013 | Total Loss sim 31 sim 32 sim 82 7.459363 1.863190 3.546540 
Iter 0014 | Total Loss sim 31 sim 32 sim 82 7.412303 1.819639 3.

Iter 0115 | Total Loss sim 31 sim 32 sim 82 1.971312 3.663073 1.947110 
Iter 0116 | Total Loss sim 31 sim 32 sim 82 1.923324 3.711409 1.995434 
Iter 0117 | Total Loss sim 31 sim 32 sim 82 1.875617 3.759492 2.043505 
Iter 0118 | Total Loss sim 31 sim 32 sim 82 1.828382 3.807082 2.091083 
Iter 0119 | Total Loss sim 31 sim 32 sim 82 1.781247 3.854689 2.138681 
Iter 0120 | Total Loss sim 31 sim 32 sim 82 1.735194 3.901264 2.185248 
Iter 0121 | Total Loss sim 31 sim 32 sim 82 1.689526 3.947502 2.231477 
Iter 0122 | Total Loss sim 31 sim 32 sim 82 1.644722 3.992830 2.276798 
Iter 0123 | Total Loss sim 31 sim 32 sim 82 1.600158 4.038020 2.321981 
Iter 0124 | Total Loss sim 31 sim 32 sim 82 1.555943 4.082773 2.366729 
Iter 0125 | Total Loss sim 31 sim 32 sim 82 1.512186 4.127039 2.410990 
Iter 0126 | Total Loss sim 31 sim 32 sim 82 1.468697 4.170946 2.454893 
Iter 0127 | Total Loss sim 31 sim 32 sim 82 1.425451 4.214592 2.498534 
Iter 0128 | Total Loss sim 31 sim 32 sim 82 1.382534 4.257959 2.

Iter 0229 | Total Loss sim 31 sim 32 sim 82 0.130484 5.601471 3.885381 
Iter 0230 | Total Loss sim 31 sim 32 sim 82 0.130045 5.602615 3.886525 
Iter 0231 | Total Loss sim 31 sim 32 sim 82 0.130317 5.601636 3.885546 
Iter 0232 | Total Loss sim 31 sim 32 sim 82 0.130474 5.600943 3.884854 
Iter 0233 | Total Loss sim 31 sim 32 sim 82 0.130584 5.599872 3.883783 
Iter 0234 | Total Loss sim 31 sim 32 sim 82 0.130724 5.598076 3.881986 
Iter 0235 | Total Loss sim 31 sim 32 sim 82 0.130982 5.596594 3.880504 
Iter 0236 | Total Loss sim 31 sim 32 sim 82 0.131674 5.593521 3.877431 
Iter 0237 | Total Loss sim 31 sim 32 sim 82 0.131901 5.591934 3.875843 
Iter 0238 | Total Loss sim 31 sim 32 sim 82 0.132100 5.590533 3.874442 
Iter 0239 | Total Loss sim 31 sim 32 sim 82 0.132466 5.588900 3.872808 
Iter 0240 | Total Loss sim 31 sim 32 sim 82 0.132172 5.589200 3.873109 
Iter 0241 | Total Loss sim 31 sim 32 sim 82 0.131783 5.589887 3.873795 
Iter 0242 | Total Loss sim 31 sim 32 sim 82 0.130590 5.593013 3.

Iter 0343 | Total Loss sim 31 sim 32 sim 82 0.129034 5.599046 3.882952 
Iter 0344 | Total Loss sim 31 sim 32 sim 82 0.130294 5.596498 3.880404 
Iter 0345 | Total Loss sim 31 sim 32 sim 82 0.131218 5.595346 3.879253 
Iter 0346 | Total Loss sim 31 sim 32 sim 82 0.132423 5.593097 3.877003 
Iter 0347 | Total Loss sim 31 sim 32 sim 82 0.133515 5.591214 3.875122 
Iter 0348 | Total Loss sim 31 sim 32 sim 82 0.134784 5.588063 3.871970 
Iter 0349 | Total Loss sim 31 sim 32 sim 82 0.136561 5.583786 3.867694 
Iter 0350 | Total Loss sim 31 sim 32 sim 82 0.137710 5.580819 3.864727 
Iter 0351 | Total Loss sim 31 sim 32 sim 82 0.139162 5.577162 3.861071 
Iter 0352 | Total Loss sim 31 sim 32 sim 82 0.141135 5.572346 3.856256 
Iter 0353 | Total Loss sim 31 sim 32 sim 82 0.141746 5.570694 3.854605 
Iter 0354 | Total Loss sim 31 sim 32 sim 82 0.142433 5.568964 3.852874 
Iter 0355 | Total Loss sim 31 sim 32 sim 82 0.141969 5.570020 3.853932 
Iter 0356 | Total Loss sim 31 sim 32 sim 82 0.140473 5.572813 3.

Iter 0056 | Total Loss sim 31 sim 32 sim 82 4.616851 1.153054 0.805450 
Iter 0057 | Total Loss sim 31 sim 32 sim 82 4.774428 1.019901 0.895074 
Iter 0058 | Total Loss sim 31 sim 32 sim 82 4.930349 0.891147 1.029503 
Iter 0059 | Total Loss sim 31 sim 32 sim 82 5.073948 0.780389 1.166254 
Iter 0060 | Total Loss sim 31 sim 32 sim 82 5.216002 0.674693 1.305814 
Iter 0061 | Total Loss sim 31 sim 32 sim 82 5.357314 0.579488 1.445771 
Iter 0062 | Total Loss sim 31 sim 32 sim 82 5.498801 0.507568 1.586529 
Iter 0063 | Total Loss sim 31 sim 32 sim 82 5.640064 0.472284 1.726768 
Iter 0064 | Total Loss sim 31 sim 32 sim 82 5.735385 0.464186 1.822758 
Iter 0065 | Total Loss sim 31 sim 32 sim 82 5.775557 0.454415 1.863187 
Iter 0066 | Total Loss sim 31 sim 32 sim 82 5.782638 0.436750 1.870260 
Iter 0067 | Total Loss sim 31 sim 32 sim 82 5.770497 0.412398 1.858106 
Iter 0068 | Total Loss sim 31 sim 32 sim 82 5.752674 0.385894 1.840273 
Iter 0069 | Total Loss sim 31 sim 32 sim 82 5.734885 0.359267 1.

Iter 0170 | Total Loss sim 31 sim 32 sim 82 5.703447 0.156873 1.790692 
Iter 0171 | Total Loss sim 31 sim 32 sim 82 5.700142 0.154716 1.787388 
Iter 0172 | Total Loss sim 31 sim 32 sim 82 5.696590 0.152727 1.783838 
Iter 0173 | Total Loss sim 31 sim 32 sim 82 5.694280 0.151453 1.781530 
Iter 0174 | Total Loss sim 31 sim 32 sim 82 5.691929 0.150453 1.779180 
Iter 0175 | Total Loss sim 31 sim 32 sim 82 5.688728 0.149546 1.775981 
Iter 0176 | Total Loss sim 31 sim 32 sim 82 5.685021 0.148578 1.772276 
Iter 0177 | Total Loss sim 31 sim 32 sim 82 5.683523 0.147940 1.770779 
Iter 0178 | Total Loss sim 31 sim 32 sim 82 5.683233 0.147157 1.770491 
Iter 0179 | Total Loss sim 31 sim 32 sim 82 5.683555 0.146438 1.770813 
Iter 0180 | Total Loss sim 31 sim 32 sim 82 5.684481 0.145928 1.771741 
Iter 0181 | Total Loss sim 31 sim 32 sim 82 5.683424 0.145342 1.770687 
Iter 0182 | Total Loss sim 31 sim 32 sim 82 5.682410 0.145151 1.769675 
Iter 0183 | Total Loss sim 31 sim 32 sim 82 5.681509 0.145043 1.

Iter 0284 | Total Loss sim 31 sim 32 sim 82 5.685022 0.153077 1.772260 
Iter 0285 | Total Loss sim 31 sim 32 sim 82 5.690001 0.152935 1.777238 
Iter 0286 | Total Loss sim 31 sim 32 sim 82 5.694265 0.152808 1.781502 
Iter 0287 | Total Loss sim 31 sim 32 sim 82 5.698155 0.152766 1.785393 
Iter 0288 | Total Loss sim 31 sim 32 sim 82 5.701324 0.153012 1.788561 
Iter 0289 | Total Loss sim 31 sim 32 sim 82 5.703718 0.153018 1.790957 
Iter 0290 | Total Loss sim 31 sim 32 sim 82 5.703290 0.152391 1.790531 
Iter 0291 | Total Loss sim 31 sim 32 sim 82 5.701565 0.151624 1.788808 
Iter 0292 | Total Loss sim 31 sim 32 sim 82 5.700186 0.151486 1.787433 
Iter 0293 | Total Loss sim 31 sim 32 sim 82 5.699418 0.151816 1.786667 
Iter 0294 | Total Loss sim 31 sim 32 sim 82 5.697819 0.152342 1.785071 
Iter 0295 | Total Loss sim 31 sim 32 sim 82 5.696431 0.153337 1.783686 
Iter 0296 | Total Loss sim 31 sim 32 sim 82 5.696091 0.154446 1.783348 
Iter 0297 | Total Loss sim 31 sim 32 sim 82 5.696608 0.155843 1.

Iter 0398 | Total Loss sim 31 sim 32 sim 82 5.693548 0.148557 1.780811 
Iter 0399 | Total Loss sim 31 sim 32 sim 82 5.697878 0.149341 1.785137 
Iter 0400 | Total Loss sim 31 sim 32 sim 82 5.702185 0.150591 1.789440 
Saved image for Sim 031
Saved image for Sim 032
Saved image for Sim 082
Iter 0001 | Total Loss sim 31 sim 32 sim 82 5.694820 0.148434 1.782078 
Iter 0002 | Total Loss sim 31 sim 32 sim 82 5.676953 0.145659 1.764219 
Iter 0003 | Total Loss sim 31 sim 32 sim 82 5.649508 0.148155 1.736788 
Iter 0004 | Total Loss sim 31 sim 32 sim 82 5.613467 0.160890 1.700768 
Iter 0005 | Total Loss sim 31 sim 32 sim 82 5.570153 0.184111 1.657479 
Iter 0006 | Total Loss sim 31 sim 32 sim 82 5.521385 0.215922 1.608741 
Iter 0007 | Total Loss sim 31 sim 32 sim 82 5.469225 0.253720 1.556617 
Iter 0008 | Total Loss sim 31 sim 32 sim 82 5.414633 0.296050 1.502062 
Iter 0009 | Total Loss sim 31 sim 32 sim 82 5.357544 0.342595 1.445017 
Iter 0010 | Total Loss sim 31 sim 32 sim 82 5.297568 0.393212 1.

Iter 0111 | Total Loss sim 31 sim 32 sim 82 3.888179 1.742327 0.085041 
Iter 0112 | Total Loss sim 31 sim 32 sim 82 3.886242 1.744264 0.086305 
Iter 0113 | Total Loss sim 31 sim 32 sim 82 3.885092 1.745416 0.087680 
Iter 0114 | Total Loss sim 31 sim 32 sim 82 3.883594 1.746915 0.088610 
Iter 0115 | Total Loss sim 31 sim 32 sim 82 3.882703 1.747808 0.089705 
Iter 0116 | Total Loss sim 31 sim 32 sim 82 3.884849 1.745661 0.091210 
Iter 0117 | Total Loss sim 31 sim 32 sim 82 3.885813 1.744697 0.093383 
Iter 0118 | Total Loss sim 31 sim 32 sim 82 3.886710 1.743799 0.095810 
Iter 0119 | Total Loss sim 31 sim 32 sim 82 3.887153 1.743356 0.097372 
Iter 0120 | Total Loss sim 31 sim 32 sim 82 3.888268 1.742241 0.099114 
Iter 0121 | Total Loss sim 31 sim 32 sim 82 3.889084 1.741427 0.100633 
Iter 0122 | Total Loss sim 31 sim 32 sim 82 3.890745 1.739767 0.102049 
Iter 0123 | Total Loss sim 31 sim 32 sim 82 3.891754 1.738760 0.103023 
Iter 0124 | Total Loss sim 31 sim 32 sim 82 3.892252 1.738264 0.

Iter 0225 | Total Loss sim 31 sim 32 sim 82 3.910974 1.719549 0.094400 
Iter 0226 | Total Loss sim 31 sim 32 sim 82 3.907974 1.722553 0.094847 
Iter 0227 | Total Loss sim 31 sim 32 sim 82 3.906042 1.724489 0.095972 
Iter 0228 | Total Loss sim 31 sim 32 sim 82 3.903499 1.727035 0.096314 
Iter 0229 | Total Loss sim 31 sim 32 sim 82 3.899625 1.730914 0.096862 
Iter 0230 | Total Loss sim 31 sim 32 sim 82 3.895198 1.735343 0.096972 
Iter 0231 | Total Loss sim 31 sim 32 sim 82 3.892093 1.738452 0.097437 
Iter 0232 | Total Loss sim 31 sim 32 sim 82 3.889731 1.740815 0.097628 
Iter 0233 | Total Loss sim 31 sim 32 sim 82 3.888159 1.742386 0.097696 
Iter 0234 | Total Loss sim 31 sim 32 sim 82 3.886501 1.744042 0.097993 
Iter 0235 | Total Loss sim 31 sim 32 sim 82 3.886344 1.744195 0.097600 
Iter 0236 | Total Loss sim 31 sim 32 sim 82 3.887137 1.743400 0.096537 
Iter 0237 | Total Loss sim 31 sim 32 sim 82 3.887199 1.743335 0.095107 
Iter 0238 | Total Loss sim 31 sim 32 sim 82 3.886932 1.743599 0.

Iter 0339 | Total Loss sim 31 sim 32 sim 82 3.894839 1.735680 0.109159 
Iter 0340 | Total Loss sim 31 sim 32 sim 82 3.891579 1.738941 0.106894 
Iter 0341 | Total Loss sim 31 sim 32 sim 82 3.887405 1.743119 0.104711 
Iter 0342 | Total Loss sim 31 sim 32 sim 82 3.883728 1.746801 0.102907 
Iter 0343 | Total Loss sim 31 sim 32 sim 82 3.879306 1.751229 0.101451 
Iter 0344 | Total Loss sim 31 sim 32 sim 82 3.872936 1.757606 0.100422 
Iter 0345 | Total Loss sim 31 sim 32 sim 82 3.867098 1.763453 0.100546 
Iter 0346 | Total Loss sim 31 sim 32 sim 82 3.861417 1.769142 0.100997 
Iter 0347 | Total Loss sim 31 sim 32 sim 82 3.857580 1.772985 0.101857 
Iter 0348 | Total Loss sim 31 sim 32 sim 82 3.853691 1.776875 0.103157 
Iter 0349 | Total Loss sim 31 sim 32 sim 82 3.850956 1.779610 0.104191 
Iter 0350 | Total Loss sim 31 sim 32 sim 82 3.850828 1.779738 0.103942 
Iter 0351 | Total Loss sim 31 sim 32 sim 82 3.850874 1.779693 0.104024 
Iter 0352 | Total Loss sim 31 sim 32 sim 82 3.852044 1.778521 0.