<a href="https://colab.research.google.com/github/BotaoJin/Code-for-Thesis/blob/main/AD_EnKF_nonlinear_TBP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
!pip install torchdiffeq
from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions.multivariate_normal import MultivariateNormal

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdiffeq
  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3


In [None]:
class Lorenz63(nn.Module):
  """Right-hand side of the Lorenz-63 system with learnable coefficients.

  The three coefficients are wrapped in nn.Parameter so that gradients with
  respect to them are available after a backward pass.
  """

  def __init__(self, sigma, beta, rho, x_dim = 3):
    super().__init__()
    # Registration order (sigma, beta, rho) is kept explicit.
    self.sigma, self.beta, self.rho = (nn.Parameter(p) for p in (sigma, beta, rho))
    # NOTE: the x_dim argument is accepted but the attribute is always set to 3.
    self.x_dim = 3

  def forward(self, t, u):
    """Evaluate the vector field at state u (last axis holds the 3 components).

    t is accepted to match the odeint call signature and is not used.
    """
    x, y, z = u[..., 0], u[..., 1], u[..., 2]
    dx = self.sigma*(y - x)
    dy = self.rho*x - x*z - y
    dz = x*y - self.beta*z
    return torch.stack((dx, dy, dz), dim = -1)

In [None]:
# Dimensions of the state and of the observations.
x_dim, y_dim = 3, 3
# Identity matrix named H; presumably the observation operator (the filter
# below hard-codes H = I rather than reading this variable).
H = torch.eye(3)
# Base covariance; it is scaled by gamma**2 wherever noise is sampled.
cov = torch.eye(3)
gamma = 0.1
# Number of time steps to simulate.
T_ = 100
# Time grid handed to the ODE solver for a single step.
t1 = torch.tensor([0.0, 0.01])

def generate_data(ode_func):
  """Simulate a noisy trajectory and noisy observations of it.

  Starting from [1, 1, 1], each step integrates ode_func over the time grid t1,
  perturbs the result with N(0, gamma**2*cov) noise to obtain the next state,
  and then draws an observation of that state with the same noise covariance.

  Reads the module-level names T_, x_dim, y_dim, gamma, cov and t1.

  Args:
    ode_func: dynamics passed to odeint.

  Returns:
    (X, Y) where X has shape (T_+1, x_dim) and holds the states (X[0] is the
    initial state) and Y has shape (T_, y_dim) with Y[t] observing X[t+1].
  """
  x0 = torch.tensor([1., 1., 1.])
  X = torch.empty(T_+1, x_dim)
  Y = torch.empty(T_, y_dim)
  X[0] = x0
  # Loop-invariant noise covariance shared by the state and observation noise.
  noise_cov = gamma**2*cov

  # Data generation needs no gradients. Without no_grad the parameters of
  # ode_func make every in-place write into X extend an autograd graph over
  # the whole trajectory, and the returned X would require grad.
  with torch.no_grad():
    for t in range(T_):
      x_pred = odeint(ode_func, X[t], t1)[-1]
      X[t+1] = MultivariateNormal(x_pred, noise_cov).sample()
      Y[t] = MultivariateNormal(X[t+1], noise_cov).sample()

  return X, Y

In [None]:
# generate data
# Seed PyTorch's global RNG so that the synthetic data (and every later random
# draw in the notebook) is reproducible under Restart & Run All.
torch.manual_seed(0)

true_sigma = torch.tensor(10.0)
true_beta = torch.tensor(8.0/3.0)
true_rho = torch.tensor(28.0)

true_ode_func = Lorenz63(sigma = true_sigma, beta = true_beta, rho = true_rho)
X, Y = generate_data(true_ode_func)

In [None]:
def EnKF(ode_func, Y, N_ensem, t_span, x0 = torch.tensor([1., 1., 1.])):
  """Run a perturbed-observation ensemble Kalman filter over the observations in Y.

  Reads the module-level names x_dim, y_dim, gamma and cov. The observation
  operator is taken to be the identity (see the analysis step).

  Args:
    ode_func: dynamics passed to odeint_adjoint to propagate the ensemble.
    Y: observations; Y.shape[-2] is used as the number of steps and Y[t] is the
      observation assimilated at step t.
    N_ensem: number of ensemble members.
    t_span: time grid handed to odeint_adjoint for one forecast step; only the
      state at the last time point is kept.
    x0: initial state, expanded to shape (N_ensem, x_dim).

  Returns:
    A tuple (X, res, log_likelihood): the ensemble after the last analysis
    step, a (T+1, N_ensem, x_dim) tensor holding the initial ensemble followed
    by the ensemble after every analysis step, and the sum over steps of the
    Gaussian log-density of each observation under the forecast mean and
    covariance (ensemble covariance plus gamma**2*cov).
  """
  T = Y.shape[-2]
  X = x0.expand((N_ensem, x_dim))
  res = torch.empty(T+1, N_ensem, x_dim)
  res[0] = X
  mean = torch.zeros(x_dim)
  log_likelihood = torch.tensor(0.)

  for t in range(T):
    # Forecast step: integrate every member, then add N(0, gamma**2*cov) model noise
    X = odeint_adjoint(ode_func, X, t_span, method = 'rk4', adjoint_method = 'rk4')[-1]
    X = X + MultivariateNormal(mean.expand((N_ensem, x_dim)), (gamma**2)*cov).sample() # model error for X: dim = (N_ensem, x_dim)
    X_m = X.mean(dim = -2).unsqueeze(-2) # ensemble mean, dim = (1, x_dim)
    X_ct = X - X_m # deviations of each member from the ensemble mean

    # Analysis step
    y_obs_j = Y[t].unsqueeze(-2) # dim = (1, y_dim)
    # One independently perturbed copy of the observation per ensemble member
    y_obs_perturb = MultivariateNormal(y_obs_j.expand(N_ensem, y_dim), (gamma**2)*cov).sample()

    C_uu = 1/(N_ensem - 1)*X_ct.transpose(-1, -2)@X_ct # ensemble covariance, dim = (x_dim, x_dim)
    # In this model H = I, so HX, HC and HCH^T reduce to X, C_uu and C_uu
    HX = X
    HX_m = X_m
    HC = C_uu
    HCH_T = HC
    # Cholesky factor of the innovation covariance HCH^T + gamma**2*cov
    HCH_TR_chol = torch.linalg.cholesky(HCH_T + (gamma**2)*cov)
    # Accumulate the log-density of the observation under N(HX_m, HCH^T + gamma**2*cov)
    d = MultivariateNormal(HX_m.squeeze(-2), scale_tril = HCH_TR_chol)
    log_likelihood += d.log_prob(y_obs_j.squeeze(-2))

    # Update and store X_j^{1:N}: row-vector form of the Kalman update,
    # X + (y_perturbed - HX) (HCH^T + gamma**2*cov)^{-1} HC
    pre = (y_obs_perturb-HX)@torch.cholesky_inverse(HCH_TR_chol)
    X = X + pre@HC
    res[t+1] = X

  return X, res, log_likelihood

In [None]:
# gradient ascent (TBP, sigma unknown, N_ensem = 3000)
# The observations are processed in windows of L steps. After each window the
# log-likelihood of that window is differentiated with respect to sigma and one
# ascent step is taken; beta and rho stay at their true values.
sigma, beta, rho = torch.tensor(.01), true_beta, true_rho
iter_sigma_GD = []    # sigma after every update
iter_grad_sigma = []  # gradient used for every update
eta_sigma = 1.5e-3    # step size
L = 20                # window length (number of observations per update)

for i in range(100):
  # Every pass over the data restarts the filter from the same initial state.
  x_bar = torch.tensor([1., 1., 1.])
  # Stepping by L also covers a shorter final window when L does not divide T_.
  for t_start in range(0, T_, L):
    t_end = min(t_start + L, T_)
    y = Y[t_start:t_end]

    # Fresh leaf tensor so that .grad only holds this window's gradient.
    sigma1 = sigma.clone().detach().requires_grad_(True)

    ode_func = Lorenz63(sigma1, beta, rho)
    x_bar, res, loglike = EnKF(ode_func, y, N_ensem = 3000, t_span = t1, x0 = x_bar)
    loglike.backward()
    # Truncate the graph: the next window starts from the ensemble values only.
    # Carrying the graph forward would make each backward pass traverse all
    # earlier windows without changing the gradient of the current parameter.
    x_bar = x_bar.detach()

    grad_sigma = ode_func.sigma.grad
    sigma = sigma + eta_sigma*grad_sigma

    iter_sigma_GD.append(sigma)
    iter_grad_sigma.append(grad_sigma)

In [None]:
# gradient ascent (TBP, beta unknown, N_ensem = 3000)
# The observations are processed in windows of L steps. After each window the
# log-likelihood of that window is differentiated with respect to beta and one
# ascent step is taken; sigma and rho stay at their true values.
sigma, beta, rho = true_sigma, torch.tensor(0.), true_rho
iter_beta_GD = []    # beta after every update
iter_grad_beta = []  # gradient used for every update
eta_beta = 2e-4      # step size
L = 20               # window length (number of observations per update)

for i in range(100):
  # Every pass over the data restarts the filter from the same initial state.
  x_bar = torch.tensor([1., 1., 1.])
  # Stepping by L also covers a shorter final window when L does not divide T_.
  for t_start in range(0, T_, L):
    t_end = min(t_start + L, T_)
    y = Y[t_start:t_end]

    # Fresh leaf tensor so that .grad only holds this window's gradient.
    beta1 = beta.clone().detach().requires_grad_(True)

    ode_func = Lorenz63(sigma, beta1, rho)
    x_bar, res, loglike = EnKF(ode_func, y, N_ensem = 3000, t_span = t1, x0 = x_bar)
    loglike.backward()
    # Truncate the graph: the next window starts from the ensemble values only.
    # Carrying the graph forward would make each backward pass traverse all
    # earlier windows without changing the gradient of the current parameter.
    x_bar = x_bar.detach()

    grad_beta = ode_func.beta.grad
    beta = beta + eta_beta*grad_beta

    iter_beta_GD.append(beta)
    iter_grad_beta.append(grad_beta)

In [None]:
# Gradient Ascent (TBP, rho unknown, N_ensem = 3000)
# The observations are processed in windows of L steps. After each window the
# log-likelihood of that window is differentiated with respect to rho and one
# ascent step is taken; sigma and beta stay at their true values.
sigma, beta, rho = true_sigma, true_beta, torch.tensor(0.)
iter_rho_GD = []    # rho after every update
iter_grad_rho = []  # gradient used for every update
eta_rho = 1e-3      # step size
L = 20              # window length (number of observations per update)

for i in range(100):
  # Every pass over the data restarts the filter from the same initial state.
  x_bar = torch.tensor([1., 1., 1.])
  # Stepping by L also covers a shorter final window when L does not divide T_.
  for t_start in range(0, T_, L):
    t_end = min(t_start + L, T_)
    y = Y[t_start:t_end]

    # Fresh leaf tensor so that .grad only holds this window's gradient.
    rho1 = rho.clone().detach().requires_grad_(True)

    ode_func = Lorenz63(sigma, beta, rho1)
    x_bar, res, loglike = EnKF(ode_func, y, N_ensem = 3000, t_span = t1, x0 = x_bar)
    loglike.backward()
    # Truncate the graph: the next window starts from the ensemble values only.
    # Carrying the graph forward would make each backward pass traverse all
    # earlier windows without changing the gradient of the current parameter.
    x_bar = x_bar.detach()

    grad_rho = ode_func.rho.grad
    rho = rho + eta_rho*grad_rho

    iter_rho_GD.append(rho)
    iter_grad_rho.append(grad_rho)

In [None]:
# Gradient histories, one panel per parameter. Within a panel, curve k plots
# the entries at positions k-1, k-1+5, k-1+10, ... of the recorded history.
fig = plt.figure(figsize=(15, 4.5))

index_sets = [5*np.arange(100) + offset for offset in range(5)]
grad_histories = [
  ('sigma', iter_grad_sigma),
  ('beta', iter_grad_beta),
  ('rho', iter_grad_rho),
]

for panel, (param_name, history) in enumerate(grad_histories, start = 1):
  # Convert the list of scalar tensors once per panel.
  grads = torch.tensor(history).detach().numpy()
  plt.subplot(1, 3, panel)
  for curve, idx in enumerate(index_sets, start = 1):
    plt.plot(idx, grads[idx], label = f'grad n{curve}')
  plt.ylabel(f'grad of {param_name}')
  plt.xlabel('iterations')
  plt.legend()

plt.show()

In [None]:
# Parameter estimates after every update, one panel per parameter, each drawn
# together with a constant line at the true value.
fig = plt.figure(figsize=(18, 4.5))

panels = [
  (iter_sigma_GD, true_sigma, 'iter sigma', 'True sigma'),
  (iter_beta_GD, true_beta, 'iter beta', 'True beta'),
  (iter_rho_GD, true_rho, 'iter_rho', 'True rho'),
]

for position, (history, true_value, iter_label, true_label) in enumerate(panels, start = 1):
  n_updates = len(history)
  steps = range(n_updates)
  plt.subplot(1, 3, position)
  plt.plot(steps, torch.tensor(history).detach().clone(), label = iter_label)
  plt.plot(steps, true_value * np.ones((n_updates,)), label = true_label)
  plt.xlabel('iterations')
  plt.ylabel('values')
  plt.legend()


plt.show()