# Tutorial for 3DVar and 4DVar calculation

### Import the `deepda` package and the example function `forwardModel_r`

In [1]:
import torch
import deepda
from examples.forwardModel import forwardModel_r
from math import ceil

# Device string passed as `device=` when the tutorial's tensors are created below.
device = "cpu"

### 3DVar Example
#### Preparing the required parameters

In [2]:
def H(x: torch.Tensor) -> torch.Tensor:
    """Identity observation operator: hand the state back unchanged.

    Args:
        x: State tensor to observe.

    Returns:
        The very same tensor object that was passed in (no copy is made).
    """
    observed = x
    return observed

In [3]:
# Background-error and observation-error covariance matrices: 3x3 identities.
B = torch.eye(3, device=device)
R = torch.eye(3, device=device)
# Observation vector to be assimilated.
y = torch.tensor([10.0, 20.0, 30.0], device=device)
# Background state: zeros matching y's shape, dtype and device.
xb = torch.zeros_like(y)

#### Method chaining is supported

In [4]:
# For reference, the same settings written as a direct call:
# apply_3DVar(H, B, R, xb, y, learning_rate=2, max_iterations=300)
#
# Every setter is chained on the result of the previous one; the final
# execute() call runs the configured case. The expression is left unassigned
# so the notebook displays the returned result.
(
    deepda.CaseBuilder()
    .set_background_covariance_matrix(B)
    .set_observation_covariance_matrix(R)
    .set_observations(y)
    .set_background_state(xb)
    .set_algorithm(deepda.Algorithms.Var3D)
    .set_device(deepda.Device.CPU)
    .set_observation_model(H)
    .set_learning_rate(2)
    .set_max_iterations(300)
    .execute()
)

Iterations: 0, J: 1400.0, Norm of J gradient: 74.83314514160156
Iterations: 1, J: 1184.0, Norm of J gradient: 62.22539520263672
Iterations: 2, J: 1017.4765014648438, Norm of J gradient: 50.39654541015625
Iterations: 3, J: 898.08154296875, Norm of J gradient: 39.807682037353516
Iterations: 4, J: 818.659912109375, Norm of J gradient: 30.810375213623047
Iterations: 5, J: 768.9251708984375, Norm of J gradient: 23.481937408447266
Iterations: 6, J: 740.5262451171875, Norm of J gradient: 18.005840301513672
Iterations: 7, J: 728.07568359375, Norm of J gradient: 14.986849784851074
Iterations: 8, J: 727.2283935546875, Norm of J gradient: 14.758960723876953
Iterations: 9, J: 733.8892211914062, Norm of J gradient: 16.46552276611328
Iterations: 10, J: 744.373046875, Norm of J gradient: 18.841018676757812
Iterations: 11, J: 755.5263061523438, Norm of J gradient: 21.076297760009766
Iterations: 12, J: 764.700927734375, Norm of J gradient: 22.75099754333496
Iterations: 13, J: 769.8867797851562, Norm of

{'assimilated_background_state': tensor([ 5.0000, 10.0000, 15.0000]),
 'intermediate_results': {'J': [1400.0,
   1184.0,
   1017.4765014648438,
   898.08154296875,
   818.659912109375,
   768.9251708984375,
   740.5262451171875,
   728.07568359375,
   727.2283935546875,
   733.8892211914062,
   744.373046875,
   755.5263061523438,
   764.700927734375,
   769.8867797851562,
   770.0601806640625,
   765.42431640625,
   757.24609375,
   747.353759765625,
   737.554443359375,
   729.1776123046875,
   722.83984375,
   718.4612426757812,
   715.5013427734375,
   713.30810546875,
   711.4209594726562,
   709.7034912109375,
   708.290771484375,
   707.4234619140625,
   707.2655029296875,
   707.782470703125,
   708.7230224609375,
   709.6989135742188,
   710.3231201171875,
   710.3370971679688,
   709.6754150390625,
   708.4522094726562,
   706.8919677734375,
   705.2422485351562,
   703.7081298828125,
   702.4252319335938,
   701.469482421875,
   700.8795776367188,
   700.6685791015625,
   70

### 4DVar Example

In [5]:
def forwardModel_wrap(x0, time, rayleigh, prandtl, b):
    """Adapt `forwardModel_r` to the calling convention used in this tutorial.

    The initial state is flattened before it is handed to `forwardModel_r`;
    the model output is then transposed and given an extra axis at position 1.

    Args:
        x0: Initial state tensor; flattened with `ravel()` before use.
        time: Time points, forwarded unchanged to `forwardModel_r`.
        rayleigh: Control parameter, forwarded unchanged.
        prandtl: Control parameter, forwarded unchanged.
        b: Control parameter, forwarded unchanged.

    Returns:
        The transposed output of `forwardModel_r` with a singleton axis
        inserted at dimension 1.
        NOTE(review): presumably shaped (len(time), 1, state_dim) -- confirm
        against `forwardModel_r`.
    """
    flat_state = x0.ravel()
    trajectory = forwardModel_r(flat_state, time, rayleigh, prandtl, b)
    # The in-place unsqueeze acts on the tensor object produced by `.T`.
    return trajectory.T.unsqueeze_(1)

In [6]:
# Control parameters handed to the forward model.
rayleigh = 35
prandtl = 10.0
b = 8.0 / 3.0
# Alternative setting left disabled: all three control parameters at zero.
# rayleigh = 0.
# prandtl = 0.
# b = 0.

# Initial condition for the true reference trajectory.
x0 = torch.tensor([0.0, 1.0, 2.0], device=device)

# Time discretisation.
dt = 1.0e-3  # time step size
T = 1.0  # total integration time; shorter horizons run faster (1 is used here)
n_steps = ceil(T / dt)  # number of integration steps
# Discrete time grid: n_steps + 1 evenly spaced points from 0 to T inclusive.
time = torch.linspace(0.0, T, steps=n_steps + 1, device=device)

# Numerical integration from x0 with the control parameters above.
xt = forwardModel_wrap(x0, time, rayleigh, prandtl, b)

In [7]:
sigobs = 2.0  # standard deviation of the observation noise
# How often do we observe the true state?
dtobs = 0.5  # time between observations
# Number of time steps between two observations. round() is used instead of
# int(): int() truncates, so a quotient that lands just below a whole number
# through floating-point error (e.g. 0.3 / 0.1) would silently lose one step.
gap = round(dtobs / dt)
# Observation times: the initial time followed by every gap-th grid point.
# time[:1] is the leading 0. of the grid, so the values match a literal zero
# while staying on time's device and dtype; a bare torch.tensor([0]) is an
# int64 CPU tensor and makes torch.cat fail once `device` is not the CPU.
time_obs = torch.cat([time[:1], time[gap::gap]])
# Observation-error covariance: sigobs**2 on the diagonal, scaled by 1e-8.
# This rebinds R, replacing the matrix used in the 3DVar example.
R = 1e-8 * torch.diag(
    torch.tile(torch.tensor(sigobs**2, device=device), (x0.size(0),))
)
# Element-wise square root of R; only needed by the disabled noise lines below.
sqrt_s = torch.sqrt(R)
# Noise-free observations y = H(xt), sampled at every gap-th step after the first.
y = H(xt[gap::gap, :])
# Optional observation noise, left disabled so the observations stay exact:
# noise = 0.125 * torch.randn(size=y.shape, device=device) @ sqrt_s
# y = Hxt + epsilon
# y = y + noise
# Prepend the initial state as the observation at the first entry of time_obs.
y = torch.cat([x0.reshape((1, 1, -1)), y])

#### Example of setting all parameters at once

In [8]:
# Background state for the 4DVar run: a zero vector of length 3.
xb = torch.zeros(3, device=device)
# For reference, the same settings written as a direct call:
# apply_4DVar(time_obs, gap, forwardModel_wrap, H, B, R, xb, y, learning_rate=7.5e-3, max_iterations=1000, args=(rayleigh, prandtl, b))
# NOTE(review): the reference call passes max_iterations=1000, but the
# dictionary below has no such entry -- confirm the builder default is intended.
params_dict = {
    # Algorithm selection and operators.
    "algorithm": deepda.Algorithms.Var4D,
    "observation_model": H,
    # Error covariances, background state and observations.
    "background_covariance_matrix": B,
    "observation_covariance_matrix": R,
    "background_state": xb,
    "observations": y,
    # Forward model and the observation schedule.
    "forward_model": forwardModel_wrap,
    "observation_time_steps": time_obs,
    "gap": gap,
    # Optimiser setting and the extra arguments (rayleigh, prandtl, b).
    "learning_rate": 7.5e-3,
    "args": (rayleigh, prandtl, b),
}
# Configure the whole case from the dictionary in one call, then run it.
(
    deepda.CaseBuilder()
    .set_all_parameters(params_dict)
    .execute()
)

Iterations: 0, J: 79945850880.0, Norm of J gradient: 626040539971584.0
Iterations: 1, J: 62386728960.0, Norm of J gradient: 1012593197056.0
Iterations: 2, J: 56148619264.0, Norm of J gradient: 790117351424.0
Iterations: 3, J: 52210466816.0, Norm of J gradient: 668724690944.0
Iterations: 4, J: 49430224896.0, Norm of J gradient: 583518781440.0
Iterations: 5, J: 47358296064.0, Norm of J gradient: 517317885952.0
Iterations: 6, J: 45762953216.0, Norm of J gradient: 463468036096.0
Iterations: 7, J: 44506681344.0, Norm of J gradient: 418592718848.0
Iterations: 8, J: 43500474368.0, Norm of J gradient: 380636200960.0
Iterations: 9, J: 42683551744.0, Norm of J gradient: 348207218688.0
Iterations: 10, J: 42012975104.0, Norm of J gradient: 320286523392.0
Iterations: 11, J: 41457139712.0, Norm of J gradient: 296093253632.0
Iterations: 12, J: 40992686080.0, Norm of J gradient: 275023593472.0
Iterations: 13, J: 40601661440.0, Norm of J gradient: 256586645504.0
Iterations: 14, J: 40270303232.0, Norm o

{'assimilated_background_state': tensor([-0.0559, -0.0558,  0.1543]),
 'intermediate_results': {'Jb': [125000000.0,
   124629216.0,
   124613952.0,
   124946080.0,
   125463656.0,
   126080072.0,
   126754192.0,
   127464568.0,
   128198752.0,
   128948824.0,
   129709264.0,
   130476024.0,
   131246024.0,
   132016776.0,
   132786240.0,
   133552752.0,
   134314944.0,
   135071632.0,
   135821840.0,
   136564768.0,
   137299776.0,
   138026320.0,
   138744000.0,
   139452496.0,
   140151648.0,
   140841280.0,
   141521328.0,
   142191824.0,
   142852864.0,
   143504496.0,
   144146960.0,
   144780416.0,
   145405088.0,
   146021232.0,
   146629136.0,
   147229120.0,
   147821456.0,
   148406496.0,
   148984576.0,
   149556000.0,
   150121120.0,
   150680240.0,
   151233744.0,
   151781920.0,
   152325072.0,
   152863552.0,
   153397648.0,
   153927632.0,
   154453824.0,
   154976464.0,
   155495808.0,
   156012144.0,
   156525696.0,
   157036720.0,
   157545344.0,
   158051856.0,
   1