# MDOF PINN - instance prediction

## Problem overview

The example problem we solve here is the 3DOF nonlinear-stiffness oscillator defined in state space:
$$
\dot{\mathbf{z}} = \mathbf{A}\mathbf{z} + \mathbf{A}_n\mathbf{z}_n + \mathbf{H}\mathbf{f}
$$
where,
$$
\mathbf{z} = \left\{ x_1, x_2, ... , x_n, \dot{x}_1, \dot{x}_2, ... , \dot{x}_n \right\}^T, \quad
\mathbf{f} = \left\{ f_1, f_2, ... , f_n \right\}^T
$$
and $\mathbf{z}_n$ is the nonlinear state vector.
$$
\mathbf{A} = \begin{bmatrix} 0 & \mathbf{I} \\ -\mathbf{M}^{-1}\mathbf{K} & -\mathbf{M}^{-1}\mathbf{C} \end{bmatrix}, \quad
\mathbf{A}_n = \begin{bmatrix} 0 \\ -\mathbf{M}^{-1} \mathbf{K}_n \end{bmatrix}, \quad
\mathbf{H} = \begin{bmatrix} 0 \\ \mathbf{M}^{-1} \end{bmatrix}
$$
with the initial conditions
$$
\mathbf{x}(0) = \mathbf{x}_0~~,~~\dot{\mathbf{x}}(0) = \dot{\mathbf{x}}_0
$$

As an example, for a 3DOF system with cubic nonlinearities, fixed at the first degree of freedom:
$$
\mathbf{z}_n = g_n(\mathbf{z}) = \left\{ x_1^3, (x_2-x_1)^3, (x_3-x_2)^3 \right\}^T, \quad
\mathbf{K}_n = \begin{bmatrix} k_{n,1} & -k_{n,2} & 0 \\ 0 & k_{n,2} & -k_{n,3} \\ 0 & 0 & k_{n,3} \end{bmatrix}
$$

In [None]:
# add parent directory to path
import sys
sys.path.append('..')

from mdof_pinn_batch import mdof_pinn_stoch, mdof_dataset, ParamClipper
import dynasim
from mdof_solutions import gen_ndof_cantilever, add_noise

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from scipy.stats import qmc

import string
import findiff

from tqdm import tqdm
from tqdm.auto import tqdm as tqdma

import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [None]:
# Simulate the reference response of the nonlinear cantilever and build the
# noisy observations and ground-truth dictionary used by the rest of the notebook.
nt = 512
time = np.linspace(0,30,nt)

# NOTE(review): F0 is not referenced again in this cell; the sine sweep below
# hard-codes its own F0 = 1.0.
F0 = 1.0  # N
n_dof = 2

# set physical parameters
k1 = 10.0
c1 = 0.25
m1 = 1.0
# cubic stiffness coefficients: only the first entry is non-zero
kn_ = np.zeros((n_dof))
kn_[0] = 100.0

# create nonlinearity
cubic_nonlin = dynasim.nonlinearities.exponent_stiffness(kn_, exponent=3, dofs=n_dof)

# instantiate system
system = dynasim.systems.cantilever(m1, c1, k1, dofs=n_dof, nonlinearity=cubic_nonlin)

# reference parameter values, printed next to the model values during training
true_params = {
    'm_' : system.m_,
    'c_' : system.c_,
    'k_' : system.k_,
    'kn_' : kn_
}

# generate excitations
# system.excitations = [
#     dynasim.actuators.rand_phase_ms(
#         freqs = np.array([0.7, 0.85, 1.6, 1.8]),
#         Sx = np.ones(4)
#     ), None]
# NOTE(review): the list has two entries (sine sweep, None), matching n_dof = 2;
# presumably one entry per DOF -- confirm against dynasim before changing n_dof.
system.excitations = [
    dynasim.actuators.sine_sweep(
        w_l = 0.7,
        w_u = 4.0,
        F0 = 1.0
    ), None]

# NOTE(review): x0 / v0 hold three entries while n_dof = 2, and z0 is never
# passed on -- simulate() below is called with z0=None.
x0 = np.array([-2.0, 0.0, 3.0])
v0 = np.array([-2.0, 0.0, 0.0])
z0 = np.concatenate((x0, v0), axis=0)

data = system.simulate(time, z0=None)

# column time vector, plus transposed displacement / velocity / forcing
# (presumably time along the first axis after the transpose; the plots below
# index them as [:, dof])
t_span = time.reshape(-1,1)
xx, vv = data['x'].T, data['xdot'].T
f = system.f.T
# noisy copies used as observations; each signal gets its own fixed seed
xx_noisy, vv_noisy = add_noise(xx, db=-30, seed=43810), add_noise(vv, db=-30, seed=13927)
f_noisy = add_noise(f, db=-30, seed=1234)

# noise-free signals; the "_hat" keys mirror the keys of the prediction dict
ground_truth = {
    "t" : t_span,
    "x_hat" : xx,
    "v_hat" : vv,
    "f_hat" : f
}

# Subplot layout shared by every figure in the notebook: one column per DOF,
# wrapped onto rows of four columns once n_dof exceeds four.
if n_dof > 4:
    sub_rows = n_dof // 4 + int((n_dof%4)!=0)  # ceil(n_dof / 4)
    sub_cols = 4
else:
    sub_rows = 1
    sub_cols = n_dof

# Each row group holds three stacked panels (displacement, velocity, forcing).
# squeeze=False keeps `axs` two-dimensional even when there is a single column
# (n_dof == 1), so the [row, column] indexing below is always valid.
fig, axs = plt.subplots(3*sub_rows, sub_cols, figsize=(8*sub_cols, 8*sub_rows), squeeze=False)
for p_count in range(n_dof):
    # j: row group, i: column of the DOF with index p_count
    j, i = divmod(p_count, sub_cols)

    axs[j*3,i].plot(time, xx[:,p_count], color="tab:blue", label="Displacement", linewidth=1.0, linestyle='--')
    axs[j*3,i].legend()

    axs[j*3+1,i].plot(time, vv[:,p_count], color="tab:red", label="Velocity", linewidth=1.0, linestyle='--')
    axs[j*3+1,i].legend()

    axs[j*3+2,i].plot(time, f[:,p_count], color="tab:gray", label="Forcing", linewidth=1.0, linestyle='--')
    axs[j*3+2,i].legend()

Normalise and create some plotting functions

In [None]:
# Wrap the noisy signals in the project dataset and build the training loader.
batch_size = 1
subsample = 1
dataset = mdof_dataset(xx_noisy, vv_noisy, f_noisy, t_span, subsample)
# NOTE(review): `phases` and `full_dataset` are not used by the active code in
# this cell (only by the commented-out block below).
phases = ['full', 'train', 'val']
full_dataset = torch.utils.data.random_split(dataset, [1.0])
# fractional split: everything goes to training, the validation split is empty
train_size = 1.0; val_size = 0.0
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)#, multiprocessing_context='fork')
# datasets = {
#     'full' : full_dataset,
#     'train' : train_dataset,
#     'val' : val_dataset
# }
# dataloaders = {
#     phase: DataLoader(dataset=datasets[phase], batch_size=batch_size, shuffle=True if phase=='train' else False, num_workers=8) for phase in phases
# }

# Compare the normalised dataset samples (scatter) with the noise-free signals
# divided by the same dataset scales (lines), one column per DOF.
# squeeze=False keeps `axs` two-dimensional even for a single column.
fig, axs = plt.subplots(3*sub_rows, sub_cols, figsize=(8*sub_cols, 8*sub_rows), squeeze=False)
for p_count in range(n_dof):
    # j: row group, i: column. The data are indexed by the DOF counter p_count,
    # not by the column i, so wrapped layouts (n_dof > 4) show the right DOF.
    j, i = divmod(p_count, sub_cols)

    axs[j*3,i].plot(ground_truth['t'][:,0]/dataset.alphas['t'], ground_truth['x_hat'][:,p_count]/dataset.alphas['x'], color="tab:blue", linewidth=1.0, linestyle='-')
    axs[j*3,i].scatter(dataset.data[:, -1], dataset.data[:, p_count], color="tab:blue", s=6)
    axs[j*3,i].grid()

    axs[j*3+1,i].plot(ground_truth['t'][:,0]/dataset.alphas['t'], ground_truth['v_hat'][:,p_count]/dataset.alphas['v'], color="tab:red", linewidth=1.0, linestyle='-')
    axs[j*3+1,i].scatter(dataset.data[:, -1], dataset.data[:, n_dof + p_count], color="tab:red", s=6)
    axs[j*3+1,i].grid()

    axs[j*3+2,i].plot(ground_truth['t'][:,0]/dataset.alphas['t'], ground_truth['f_hat'][:,p_count]/dataset.alphas['f'], color="tab:gray", linewidth=1.0, linestyle='-')
    axs[j*3+2,i].scatter(dataset.data[:, -1], dataset.data[:, 2 * n_dof + p_count], color="tab:gray", s=6)
    axs[j*3+2,i].grid()

In [None]:
# Build the subplot_mosaic layout string: every panel gets its own upper-case
# letter, rows are separated by ';', and one extra letter repeated across a
# full-width final row is reserved for the loss history.
alphabet = list(string.ascii_uppercase)
panel_rows = [
    ''.join(alphabet[row*sub_cols:(row+1)*sub_cols])
    for row in range(3*sub_rows)
]
alph_count = 3*sub_rows*sub_cols  # index of the first letter not used by a panel
panel_rows.append(alphabet[alph_count]*sub_cols)
mosaic_key = ';'.join(panel_rows)

def plot_joint_loss_hist(ax,loss_hist):
    """Redraw the loss history on ``ax`` with a symlog y-axis.

    ``loss_hist`` is indexed as a 2-D array with one row per epoch; columns
    0-3 are drawn as "L_obs", "L_cc", "L_ode" and "L". Histories longer than
    20000 epochs are thinned with a stride of floor(n_epoch / 10000) before
    plotting. The axis is cleared first; nothing is returned.
    """
    total_epochs = len(loss_hist)
    epoch_axis = np.arange(1, total_epochs + 1)

    # thin very long histories so the redraw stays cheap
    if total_epochs > 20000:
        stride = int(np.floor(total_epochs / 10000))
        loss_hist, epoch_axis = loss_hist[::stride, :], epoch_axis[::stride]

    series = (
        ("L_obs", "tab:blue"),
        ("L_cc", "tab:red"),
        ("L_ode", "tab:green"),
        ("L", "black"),
    )

    ax.cla()
    for column, (name, colour) in enumerate(series):
        ax.plot(epoch_axis, loss_hist[:, column], color=colour, label=name)
    ax.set_yscale('symlog')
    ax.legend()

def plot_result(axs_m, ground_truth, data, prediction, alphas):
    """Redraw the exact solution, the network prediction and its 2-sigma band.

    Relies on the notebook-level names ``sub_rows``, ``sub_cols``, ``n_dof``
    and ``alphabet`` to map the lettered mosaic axes onto a grid with three
    stacked panels (displacement, velocity, forcing) per DOF.

    Parameters
    ----------
    axs_m : dict
        Mosaic axes keyed by letter. Every axis in it is cleared first.
    ground_truth : dict
        Reads "t", "x_hat", "v_hat" and "f_hat"; plotted as given (no scaling).
    data :
        Unused; kept so existing calls keep working.
    prediction : dict
        Reads "t_hat", "x_hat", "v_hat", "f_hat" plus the scalars "sigma"
        (band for displacement and velocity) and "sigma_s" (band for forcing).
        These values are multiplied by the matching ``alphas`` entry.
    alphas : dict
        Scale factors; reads "t", "x", "v" and "f".
    """
    for ax in axs_m:
        axs_m[ax].cla()

    # Grid of axes: row 3*j + n holds quantity n of row group j.
    axs_top = np.array([
        [axs_m[alphabet[row*sub_cols + col]] for col in range(sub_cols)]
        for row in range(3*sub_rows)
    ])

    plot_keys = ["x_hat", "v_hat", "f_hat"]
    plot_cols = ["tab:blue", "tab:red", "tab:gray"]
    alpha_keys = ["x", "v", "f"]
    sigma_keys = ["sigma", "sigma", "sigma_s"]

    t_pred = prediction["t_hat"]*alphas["t"]
    x_max = np.amax(ground_truth["t"])

    for p_count in range(n_dof):
        j, i = divmod(p_count, sub_cols)
        for n in range(3):
            panel = axs_top[j*3+n,i]
            scale = alphas[alpha_keys[n]]
            mean = prediction[plot_keys[n]][:,p_count]
            sigma = prediction[sigma_keys[n]]

            panel.plot(ground_truth["t"], ground_truth[plot_keys[n]][:,p_count], color="grey", linewidth=2, alpha=0.5, label="Exact solution")
            panel.plot(t_pred, mean*scale, color=plot_cols[n], linewidth=2, alpha=0.8, linestyle='--', label="Neural network prediction")
            # Both band edges are scaled by the same factor; previously the
            # upper edge of the forcing band was left unscaled.
            panel.fill_between(t_pred.squeeze(), (mean-2*sigma)*scale, (mean+2*sigma)*scale, alpha=0.25, color="tab:blue", label=r"$2\sigma$ Range")

            # Limits go on the panel that was just drawn (row j*3+n); the
            # previous j*2+n index hit the wrong row for every row group but
            # the first.
            y_max = np.amax(np.abs(ground_truth[plot_keys[n]][:,p_count]))
            panel.set_xlim(-0.05*x_max, 1.05*x_max)
            panel.set_ylim(-1.1*y_max, 1.1*y_max)

# Placeholder for the latest network prediction; the training loop overwrites
# these entries (and adds the sigma values) before every redraw. The forcing
# key is "f_hat", matching the key the training loop writes and plot_result
# reads.
prediction = {
    "t_hat" : time,
    "x_hat" : None,
    "v_hat" : None,
    "f_hat" : None
}

In [None]:
def sort_data(vec2sort: np.ndarray, *data_: np.ndarray):
    """Reorder arrays along their first axis by ascending ``vec2sort``.

    Parameters
    ----------
    vec2sort : np.ndarray
        Sort key with one value per row. It is flattened first, so both a
        1-D vector and an (N, 1) column give the same ordering.
    *data_ : np.ndarray
        Arrays whose first axis has the same length as ``vec2sort``.

    Returns
    -------
    sorted_data, sort_ids
        ``sorted_data`` is the reordered copy of the single array when exactly
        one is passed, otherwise a tuple of reordered copies (empty when no
        array is passed). ``sort_ids`` is the 1-D index array that was applied.
    """
    # A stable sort keeps rows with equal keys in their original order.
    sort_ids = np.argsort(np.ravel(vec2sort), kind="stable")
    # Integer-array indexing on the first axis reorders rows of any rank and
    # always returns a new array with the input dtype.
    sorted_data_ = [np.asarray(data)[sort_ids] for data in data_]
    if len(sorted_data_) == 1:
        return sorted_data_[0], sort_ids
    return tuple(sorted_data_), sort_ids

## PINN

Neural network predicts over the full time domain:
$$
\mathcal{N}_{\mathbf{z}}(t), \qquad 
\mathbf{R} = \partial_t \mathcal{N}_{\mathbf{z}} - \mathbf{A} \mathcal{N}_{\mathbf{z}} - \mathbf{A}_n \mathcal{N}_{\mathbf{z}_n} - \mathbf{H}\mathbf{f}
$$

$$
\mathcal{L}(t;\mathbf{\theta}) := \mathcal{L}_{obs} + \mathcal{L}_{cc} + \Lambda\mathcal{L}_{ode}
$$

$$
\mathcal{L}_{obs} = \langle \hat{\mathbf{z}}^* - \mathcal{N}_{\hat{\mathbf{z}}} \rangle _{\Omega_o}
$$
$$
\mathcal{L}_{obs} = \prod_{i=1}^{N} \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{1}{2}\frac{||\hat{\mathbf{z}}^* - \mathcal{N}_{\hat{\mathbf{z}}}||^2}{\sigma^2}\right)
$$
in the log space
$$
\mathcal{L}_{obs} = -N\log(\sigma) - \frac{N}{2}\log(2\pi) - \frac{1}{2}\sum_{i=1}^{N} \frac{||\hat{\mathbf{z}}^* - \mathcal{N}_{\hat{\mathbf{z}}}||^2}{\sigma^2}
$$
<!-- $$
\mathcal{L}_{ic} = \sum_{j=1}^{N_{d}}\left[ 
\left\langle \alpha_{\dot{x}}\hat{\dot{x}}_{j,0} - \frac{\alpha_x}{\alpha_t}\partial_{\hat{t}}\mathcal{N}_{\hat{x}_j} \right\rangle ~~ + ~~
\left\langle \alpha_{x}\hat{x}_{j,0} - \alpha_x\mathcal{N}_{\hat{x}_j} \right\rangle ~~ + ~~
\left\langle \alpha_{\dot{x}}\hat{\dot{x}}_{j,0} - \alpha_{\dot{x}}\mathcal{N}_{\hat{\dot{x}}_j} \right\rangle
\right] _{\Omega\in\{t=0\}}
$$ -->
$$
\mathcal{L}_{cc} = \sum_{j=1}^{N_{d}} \left\langle \mathbf{R}[j,:] \right\rangle _{\Omega_p}, \qquad
\mathcal{L}_{ode} = \sum_{j=1}^{N_{d}} \left\langle \mathbf{R}[N_d+j,:] \right\rangle _{\Omega_p}
$$
where,
$$ \mathcal{N}_{\bullet} = \mathcal{N}_{\bullet}(\mathbf{z};\mathbf{\theta}), \qquad 
\partial_{*}\bullet = \frac{\partial\bullet}{\partial *}, \qquad 
\partial^2_{*}\bullet = \frac{\partial^2\bullet}{\partial *^2}, \qquad
\langle\bullet\rangle _{\Omega_{\kappa}} = \frac{1}{N_{\kappa}}\sum_{t\in\Omega_{\kappa}}\left|\left|\bullet\right|\right|^2 $$

The ODE loss function comes from including the normalisation of the parameters, then choosing a suitable range to aid optimisation.

$$
\frac{1}{\alpha_t^2} \partial^2_{\hat{t}}\hat{x} + 
\tilde{c}\frac{1}{\alpha_t}\partial_{\hat{t}}\hat{x} + 
\tilde{k} \hat{x} - 
\frac{\alpha_F}{\alpha_x} \hat{F} = 0 
\quad \rightarrow \quad 
\hat{m} \partial^2_{\hat{t}}\hat{x} + 
\hat{c} \partial_{\hat{t}}\hat{x} + 
\hat{k}\hat{x} - \eta\hat{F} = 0
$$
To scale loss function in a physically meaningful way, multiply the loss function by any of the following:
$$
\Lambda = 1, \alpha_t, \alpha_t^2, \alpha_x, \alpha_F^{-1}
$$

In [None]:
# Configure the PINN, optimiser, live figure and bookkeeping for training.

# Scale factors for the physical / noise parameters (all unity here), merged
# with the signal scales from the dataset; dataset entries win on shared keys.
alphas = {
    "c" : 1.0,
    "k" : 1.0,
    "kn" : 1.0,
    "sigma" : 1.0,
    "sigma_s" : 1.0
}
alphas.update(dataset.alphas)

nct = nt  # number of collocation points

# fix the torch RNG before the model is constructed
torch.manual_seed(123)

# m_, c_, k_ and kn_ are passed as type "constant" with the simulation values;
# only sigma_ and sigma_s_ are of type "variable", both starting at 1.0.
pinn_config = {
    "n_input" : 1,
    "n_output" : 2*n_dof,
    "n_hidden" : 16,
    "n_layers" : 4,
    "n_dof" : n_dof,
    "nct" : nct,
    "nonlinearity" : "cubic",
    "phys_params" : {
        "m_" : {
            "type" : "constant",
            "value" : torch.tensor(system.m_, dtype=torch.float32)
        },
        "c_" : {
            "type" : "constant",
            "value" : torch.tensor(system.c_, dtype=torch.float32)
        },
        "k_" : {
            "type" : "constant",
            "value" : torch.tensor(system.k_, dtype=torch.float32)
        },
        "kn_" : {
            "type" : "constant",
            "value" : torch.tensor(kn_, dtype=torch.float32)
        },
        "sigma_" : {
            "type" : "variable",
            "value" : torch.tensor(1.0, dtype=torch.float32)
        },
        "sigma_s_" : {
            "type" : "variable",
            "value" : torch.tensor(1.0, dtype=torch.float32)
        },
    },
    "param_func" : gen_ndof_cantilever,
    "alphas" : alphas,
    "forcing" : f
}

device = torch.device("cpu")
# configure PINN
mdof_model = mdof_pinn_stoch(pinn_config, device)
mdof_model = mdof_model.to(device)

# configure optimiser
learning_rate = 2.5e-3
betas = (0.99,0.999)
optimizer = torch.optim.Adam(mdof_model.parameters(), lr=learning_rate, betas=betas)

# NOTE(review): clipper is instantiated but not used by the training loop below.
clipper = ParamClipper()

# live figure: lettered panels per DOF plus the full-width loss panel
fig, axs = plt.subplot_mosaic(
    mosaic_key,
    figsize=(18,16),
    facecolor='w'
)
# twin y-axes for every panel; only referenced by a commented-out argument of
# the plot_result call in the training loop
axs2 = axs.copy()
for key, ax in axs.items():
    axs2[key] = ax.twinx()

print_step = 200  # redraw / report every print_step epochs
loss_hist=[]
# loss weights handed to set_switches() and loss_func(); only 'obs' is
# non-zero here (presumably this disables the other terms -- confirm in
# mdof_pinn_stoch)
lambds = {
    'obs' : 1.0e-3,
    'ic' : 0.0,
    'ode' : 0.0e-4,
    'cc' : 0.0
}

mdof_model.set_switches(lambds)
# compiled_model = mdof_model.to(device)
# compiled_model = torch.compile(mdof_model, mode="reduce-overhead")
# compiled_model = torch.compile(mdof_model, mode="max-autotune").to(device)

# buffers that collect the per-batch predictions between redraws
num_obs_samps = len(train_dataset)
num_col_samps = len(train_dataset) * subsample
z_pred = np.zeros((num_col_samps, 2*n_dof))
f_pred = np.zeros((num_col_samps, n_dof))
t_pred = np.zeros((num_col_samps, 1))

epochs = int(2e6)
epoch = 0
progress_bar = tqdm(total=epochs)
phases = ['train', 'val']
# Main optimisation loop: one pass over train_loader per epoch, with a redraw
# of the live figure and a text report every `print_step` epochs.
while epoch < epochs:
    write_string = ''
    write_string += 'Epoch {}\n'.format(epoch)
    phase_loss = 0.
    losses = [0.0] * 3
    for i, (obs_data, col_data) in enumerate(train_loader):
        # parse data sample: first 2*n_dof columns are the observed state,
        # the last column is time
        state_obs = obs_data[..., :2*n_dof].float().to(device).requires_grad_()
        time_obs = obs_data[..., -1].reshape(-1,1).float().to(device).requires_grad_()

        force_col = col_data[..., 2*n_dof:3*n_dof].float().to(device).requires_grad_()
        force_col = force_col.reshape(-1, n_dof)  # unroll collocation data
        time_col = col_data[..., -1].float().to(device).requires_grad_()
        time_col = time_col.reshape(-1, 1)  # unroll collocation data

        optimizer.zero_grad()
        loss, losses_i, _ = mdof_model.loss_func(lambds, time_obs, state_obs, time_col, force_col)
        phase_loss += loss.item()
        # accumulate the individual loss terms over the epoch
        losses = [losses[j] + loss_i for j, loss_i in enumerate(losses_i)]
        loss.backward()
        optimizer.step()
    loss_hist.append([loss_it.item() for loss_it in losses] + [phase_loss])
    write_string += '\tLoss {:.4e}\n'.format(phase_loss)

    # NOTE: only the training split is optimised; no validation pass is run.

    if (epoch+1) % print_step == 0:

        # Predict at every collocation time, batch by batch, into the buffers.
        for i, (obs_data, col_data) in enumerate(train_loader):

            inpoint_ = i * batch_size * subsample
            outpoint_ = (i + 1) * batch_size * subsample

            t_col = col_data[..., -1].to(device).float().requires_grad_()
            pred_inputs = t_col.reshape(-1, 1)
            z_pred_, f_pred_ = mdof_model.predict(pred_inputs)
            t_pred[inpoint_:outpoint_] = pred_inputs.detach().cpu().numpy()
            z_pred[inpoint_:outpoint_, :], f_pred[inpoint_:outpoint_, :] = z_pred_.detach().cpu().reshape(-1, 2*n_dof).numpy(), f_pred_.detach().cpu().reshape(-1, n_dof).numpy()

        # The loader shuffles, so put the predictions back into time order.
        (z_pred, f_pred, t_pred), _ = sort_data(t_pred[:,0], z_pred, f_pred, t_pred)

        prediction['t_hat'] = t_pred
        prediction["x_hat"] = z_pred[:,:n_dof]
        prediction["v_hat"] = z_pred[:,n_dof:]
        prediction["f_hat"] = f_pred
        prediction['sigma'] = mdof_model.sigma_.detach().item()
        prediction['sigma_s'] = mdof_model.sigma_s_.detach().item()

        plot_result(axs, ground_truth, data, prediction, alphas)#, residuals['R_ode'], axs2)

        # The loss panel is the letter after the 3*sub_rows*sub_cols DOF
        # panels of the mosaic (equal to 3*n_dof only while n_dof <= 4).
        plot_joint_loss_hist(axs[alphabet[3*sub_rows*sub_cols]], np.array(loss_hist))

        display.clear_output(wait=True)
        display.display(plt.gcf())
        write_string += 'sigma_o: {:.4e} ---- sigma_p: {:.4e}\n'.format(mdof_model.sigma_.detach(), mdof_model.sigma_s_.detach())
        write_string += 'c :                       k :                     kn : \n'
        for j in range(n_dof):
            wri_str = '{} : '.format(j+1)
            for param in ['c_','k_','kn_']:
                if pinn_config['phys_params'][param]['type']=='constant':
                    wri_str += '{:.4f} '.format(getattr(mdof_model,param)[j])
                elif pinn_config['phys_params'][param]['type']=='variable':
                    # variable parameters are reported after multiplying by their scale
                    wri_str += '{:.4f} '.format(getattr(mdof_model,param)[j]*alphas[param[:-1]])
                wri_str += '[{:.4f}]       '.format(true_params[param][j])
            # append the row to the report (it was previously built and discarded)
            write_string += wri_str + '\n'

        tqdma.write(write_string)

    epoch += 1
    progress_bar.update(1)

In [None]:
# Bundle everything needed to resume training or reproduce the figures.
checkpoint = {
    'config' : pinn_config,
    'data' : data,
    'prediction' : prediction,
    'ground_truth' : ground_truth,
    'alphas' : alphas,
    # number of completed epochs (the training-loop counter), not a leftover
    # batch index
    'epoch' : epoch,
    'model' : mdof_model.state_dict(),
    'optimizer' : optimizer.state_dict(),
    'loss' : loss_hist,
    'true_params' : true_params,
}
# torch.save(checkpoint,'checkpoints/mdof_stoch_instance_batch.pth')