<a href="https://colab.research.google.com/github/BotaoJin/Code-for-Thesis/blob/main/EM_nonlinear.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal
!pip install torchdiffeq
from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint
import numpy as np
import matplotlib.pyplot as plt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdiffeq
  Downloading torchdiffeq-0.2.3-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.3


In [None]:
class Lorenz63(nn.Module):
  """Lorenz-63 vector field with learnable parameters.

      dx/dt = sigma * (y - x)
      dy/dt = x * (rho - z) - y
      dz/dt = x * y - beta * z

  sigma, beta and rho are wrapped in nn.Parameter, so gradients of any loss
  built from this module accumulate in ``self.sigma.grad`` etc.

  Args:
    sigma, beta, rho: scalar tensors with the initial parameter values.
    x_dim: state dimension stored on the module (forward itself reads the
      three components u[..., 0], u[..., 1], u[..., 2]).
  """
  def __init__(self, sigma, beta, rho, x_dim = 3):
    super().__init__()
    self.sigma = nn.Parameter(sigma)
    self.beta = nn.Parameter(beta)
    self.rho = nn.Parameter(rho)
    # Honour the argument instead of hard-coding 3.
    self.x_dim = x_dim

  def forward(self, t, u):
    """Return du/dt for state(s) u of shape (..., 3); t is unused (autonomous)."""
    x, y, z = u[..., 0], u[..., 1], u[..., 2]
    return torch.stack((
      self.sigma * (y - x),
      self.rho * x - x * z - y,
      x * y - self.beta * z,
    ), dim = -1)

In [None]:
# Ground-truth Lorenz-63 parameters (Lorenz's standard values) and the model
# used to synthesise the data.
true_sigma, true_beta, true_rho = torch.tensor(10.0), torch.tensor(8.0 / 3.0), torch.tensor(28.0)
true_model = Lorenz63(sigma=true_sigma, beta=true_beta, rho=true_rho)

In [None]:
# Problem dimensions, noise level and the integration window of one step.
x_dim = 3                    # state dimension (Lorenz-63)
y_dim = 3                    # observation dimension
H = torch.eye(y_dim, x_dim)  # observation operator (identity)
cov = torch.eye(x_dim)       # unit covariance, scaled by gamma**2 where used
gamma = 1e-1                 # noise standard deviation
T_ = 100                     # number of time steps / observations
t1 = torch.tensor([0.0, 0.01])  # start/end time of one transition

def generate_data(ode_func, x0 = None):
  """Simulate one noisy trajectory of `ode_func` and its noisy observations.

      X[t+1] = Phi(X[t]) + w_t,   w_t ~ N(0, gamma^2 * cov)
      Y[t]   = X[t+1]    + v_t,   v_t ~ N(0, gamma^2 * cov)

  where Phi integrates `ode_func` over the global window `t1` with odeint.

  Args:
    ode_func: vector field (t, u) -> du/dt.
    x0: initial state; defaults to [1., 1., 1.].

  Returns:
    X: (T_+1, x_dim) latent states, X[0] = x0.
    Y: (T_, y_dim) observations, Y[t] observes X[t+1].
  """
  if x0 is None:
    x0 = torch.tensor([1., 1., 1.])
  X = torch.empty(T_+1, x_dim)
  Y = torch.empty(T_, y_dim)
  X[0] = x0

  # Pure simulation: ode_func holds nn.Parameters, so without no_grad every
  # step would record an (unused) autograd graph through X.
  with torch.no_grad():
    for t in range(T_):
      x_det = odeint(ode_func, X[t], t1)[-1]
      X[t+1] = MultivariateNormal(x_det, gamma**2*cov).sample()
      Y[t] = MultivariateNormal(X[t+1], gamma**2*cov).sample()

  return X, Y

In [None]:
# generate data: simulate `num_data_set` independent trajectories from the true
# model, then average states and observations across the realisations.
# NOTE(review): the realisations are independent noisy runs of a chaotic system,
# so the averaged X is not itself a model trajectory and the averaged Y has a
# different noise level than the gamma**2 * I assumed downstream — confirm this
# averaging is intended.
num_data_set = 5
runs = [generate_data(true_model) for _ in range(num_data_set)]
X = torch.stack([states for states, _ in runs], dim = 1).mean(-2)  # (T_+1, x_dim)
Y = torch.stack([obs for _, obs in runs], dim = 1).mean(-2)        # (T_, y_dim)

In [None]:
# EM Algorithm
# E-step: sample V from MCMC Method
def Sample_V_MCMC(ode_func, V, Y, step_size = 0.01):
  """One random-walk Metropolis step on the whole latent trajectory V.

  Unnormalised log target (gamma and t1 are notebook globals):
      log F(V) = -1/(2 gamma^2) ||Y - V[1:]||_F^2
                 -1/(2 gamma^2) ||V[1:] - Phi(V[:-1])||_F^2
  where Phi advances every state of V[:-1] over t1 with odeint.

  Args:
    ode_func: vector field (t, u) -> du/dt, applied to all of V[:-1] at once.
    V: (T_+1, x_dim) current state of the chain.
    Y: (T_, y_dim) observations; Y[t] is compared with V[t+1].
    step_size: standard deviation of the isotropic Gaussian proposal
      (default 0.01, the previously hard-coded value).

  Returns:
    (acc, V_next): acc is 1 if the proposal was accepted, else 0; V_next is a
    detached copy of the chain's new state with requires_grad=True.
  """
  # NOTE(review): the proposal also moves V[0], which enters the target only via
  # the transition V[0] -> V[1] (there is no prior on the initial state), while
  # the data were generated from a fixed x0 — consider keeping V[0] fixed.
  with torch.no_grad():  # sampling only; don't build a graph through ode_func's parameters
    proposal = MultivariateNormal(torch.zeros(x_dim).expand(V.size()), step_size**2*torch.eye(x_dim))
    V_p = V + proposal.sample()

    def log_target(traj):
      pred = odeint(ode_func, traj[:-1], t1)[-1]
      return (-1/(2*gamma**2)*torch.linalg.norm(Y - traj[1:])**2
              - 1/(2*gamma**2)*torch.linalg.norm(traj[1:] - pred)**2)

    # The Gaussian random walk is symmetric, q(V_p | V) == q(V | V_p), so the
    # Hastings ratio is exactly 1 and alpha = min(1, F(V_p) / F(V)).
    log_alpha = torch.clamp(log_target(V_p) - log_target(V), max = 0.)
    u = torch.distributions.uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0])).sample()
    accepted = bool(u <= torch.exp(log_alpha))

  if accepted:
    return 1, V_p.clone().detach().requires_grad_(True)
  return 0, V.clone().detach().requires_grad_(True)

In [None]:
def EnKF(ode_func, Y, N_ensem, t_span, x0 = torch.tensor([1., 1., 1.])):
  """Stochastic (perturbed-observation) Ensemble Kalman Filter with H = I.

  Model and observation noise both use the covariance gamma**2 * cov (globals).

  Args:
    ode_func: nn.Module vector field (t, u) -> du/dt (odeint_adjoint needs a Module).
    Y: (T, y_dim) observations; Y[t] is assimilated after forecast step t+1.
    N_ensem: ensemble size (needs >= 2 for the sample covariance).
    t_span: time grid of one forecast step; the final time's state is used.
    x0: initial state shared by every ensemble member.

  Returns:
    X: (N_ensem, x_dim) final analysis ensemble.
    res: (T+1, N_ensem, x_dim) ensemble after every analysis, res[0] = x0.
    log_likelihood: sum over t of log N(Y[t]; forecast mean, C_uu + gamma^2 * cov).
  """
  T = Y.shape[-2]
  X = x0.expand((N_ensem, x_dim))
  res = torch.empty(T+1, N_ensem, x_dim)
  res[0] = X
  mean = torch.zeros(x_dim)
  R = (gamma**2)*cov  # model-error and observation-noise covariance
  log_likelihood = torch.tensor(0.)

  for t in range(T):
    # Forecast step: propagate every member, then add model error.
    X = odeint_adjoint(ode_func, X, t_span, method = 'rk4', adjoint_method = 'rk4')[-1]
    X = X + MultivariateNormal(mean.expand((N_ensem, x_dim)), R).sample()  # (N_ensem, x_dim)
    X_m = X.mean(dim = -2, keepdim = True)  # (1, x_dim)
    X_ct = X - X_m                          # (N_ensem, x_dim)

    # Analysis step
    y_obs_j = Y[t].unsqueeze(-2)  # (1, y_dim)
    y_obs_perturb = MultivariateNormal(y_obs_j.expand(N_ensem, y_dim), R).sample()  # (N_ensem, y_dim)

    C_uu = 1/(N_ensem - 1)*X_ct.transpose(-1, -2)@X_ct  # (x_dim, x_dim)
    # In this model H = I, so H X = X and H C H^T = C.
    HX = X
    HX_m = X_m
    HC = C_uu
    HCH_T = HC
    S_chol = torch.linalg.cholesky(HCH_T + R)  # Cholesky factor of S = H C H^T + R
    d = MultivariateNormal(HX_m.squeeze(-2), scale_tril = S_chol)
    log_likelihood = log_likelihood + d.log_prob(y_obs_j.squeeze(-2))

    # Update: X <- X + (y_pert - HX) S^{-1} H C  (row-vector form; S symmetric).
    # cholesky_solve applies S^{-1} without forming the explicit inverse.
    innov = y_obs_perturb - HX  # (N_ensem, y_dim)
    pre = torch.cholesky_solve(innov.transpose(-1, -2), S_chol).transpose(-1, -2)
    X = X + pre@HC
    res[t+1] = X

  return X, res, log_likelihood

In [None]:
def Exp_value(ode_func, res, t_span = None):
  """Monte Carlo M-step objective over sampled trajectories.

      1/(N * gamma^2) * sum_{t=0}^{T-2} || res[t+1] - Phi(res[t]) ||_F^2

  where Phi advances a state over `t_span` with one fixed-step RK4 solve and
  N = res.size(-2) is the number of samples. This is the transition part of
  the Gaussian negative log-likelihood without its factor 1/2 (which only
  rescales the gradient).

  Args:
    ode_func: nn.Module vector field (required by odeint_adjoint). It must act
      independently on leading dimensions, as Lorenz63.forward does via u[..., k].
    res: (T, N, x_dim) sampled latent trajectories.
    t_span: integration times of one transition; defaults to [0., .01].

  Returns:
    Scalar tensor (0.0 when res has fewer than two time points).
  """
  if t_span is None:
    t_span = torch.tensor([0., .01])
  N = res.size(dim = -2)
  if res.size(dim = 0) < 2:
    return 0.0

  # One batched solve for all T-1 transitions instead of one call per step;
  # 'rk4' is fixed-step, so each state is advanced exactly as before.
  V_hat = odeint_adjoint(ode_func, res[:-1], t_span, method = 'rk4', adjoint_method = 'rk4')[-1]
  exp_value = ((res[1:] - V_hat)**2).sum()

  return 1/(N*gamma**2)*exp_value

In [None]:
# EM algorithm with MCMC (sigma unknown; beta and rho held at their true values)
# E-step: `sample_size` Metropolis steps started from a forward simulation of the
#         current model; every state of the chain is kept as a sample.
# M-step: I gradient steps on Exp_value w.r.t. sigma.
sigma = torch.tensor(0.01)
beta = true_beta
rho = true_rho
J = 50      # EM iterations
I = 1       # gradient steps per M-step
eta = 0.05  # M-step learning rate

iter_sigma_EM_MCMC = [sigma]
sample_size = 20
sample_iter = 0
num_acc = 0

for j in range(J):
  ode_func = Lorenz63(sigma = sigma, beta = beta, rho = rho)
  V = torch.empty(T_+1, x_dim)
  V[0] = torch.tensor([1., 1., 1.])
  # Chain initialisation is pure simulation; no_grad stops ode_func's
  # nn.Parameters from recording a graph through V.
  with torch.no_grad():
    for t in range(T_):
      V[t+1] = odeint(ode_func, V[t], t1)[-1] + MultivariateNormal(torch.zeros(x_dim), gamma**2*cov).sample()
  res = torch.empty(T_+1, sample_size, x_dim)

  for sz in range(sample_size):
    acc, V = Sample_V_MCMC(ode_func, V, Y)
    res[:,sz,:] = V
    num_acc += acc
    sample_iter += 1

  # Samples are fixed data for the M-step: detach so backward reaches only the
  # new model's parameters and can run I > 1 times without retain_graph.
  samples = res.detach()
  for i in range(I):
    sigma1 = sigma.detach().requires_grad_(True)
    ode_fun = Lorenz63(sigma1, beta, rho)
    loss = Exp_value(ode_fun, samples)  # not `J`: J is the EM iteration count
    loss.backward()
    grad_ = ode_fun.sigma.grad
    sigma = sigma - eta * grad_
    iter_sigma_EM_MCMC.append(sigma)

In [None]:
# EM algorithm with MCMC (beta unknown; sigma and rho held at their true values)
# E-step: `sample_size` Metropolis steps started from a forward simulation of the
#         current model; every state of the chain is kept as a sample.
# M-step: I gradient steps on Exp_value w.r.t. beta.
sigma = true_sigma
beta = torch.tensor(0.)
rho = true_rho
J = 50      # EM iterations
I = 1       # gradient steps per M-step
eta = 0.01  # M-step learning rate

iter_beta_EM_MCMC = [beta]
sample_size = 20
sample_iter = 0
num_acc = 0

for j in range(J):
  ode_func = Lorenz63(sigma = sigma, beta = beta, rho = rho)
  V = torch.empty(T_+1, x_dim)
  V[0] = torch.tensor([1., 1., 1.])
  # Chain initialisation is pure simulation; no_grad stops ode_func's
  # nn.Parameters from recording a graph through V.
  with torch.no_grad():
    for t in range(T_):
      V[t+1] = odeint(ode_func, V[t], t1)[-1] + MultivariateNormal(torch.zeros(x_dim), gamma**2*cov).sample()
  res = torch.empty(T_+1, sample_size, x_dim)

  for sz in range(sample_size):
    acc, V = Sample_V_MCMC(ode_func, V, Y)
    res[:,sz,:] = V
    num_acc += acc
    sample_iter += 1

  # Samples are fixed data for the M-step: detach so backward reaches only the
  # new model's parameters and can run I > 1 times without retain_graph.
  samples = res.detach()
  for i in range(I):
    beta1 = beta.detach().requires_grad_(True)
    ode_fun = Lorenz63(sigma, beta1, rho)
    loss = Exp_value(ode_fun, samples)  # not `J`: J is the EM iteration count
    loss.backward()
    grad_ = ode_fun.beta.grad
    beta = beta - eta * grad_
    iter_beta_EM_MCMC.append(beta)

In [None]:
# EM algorithm with MCMC (rho unknown; sigma and beta held at their true values)
# E-step: `sample_size` Metropolis steps started from a forward simulation of the
#         current model; every state of the chain is kept as a sample.
# M-step: I gradient steps on Exp_value w.r.t. rho.
sigma = true_sigma
beta = true_beta
rho = torch.tensor(0.01)
J = 50     # EM iterations
I = 1      # gradient steps per M-step
eta = 0.5  # M-step learning rate

iter_rho_EM_MCMC = [rho]
sample_size = 20
sample_iter = 0
num_acc = 0

for j in range(J):
  ode_func = Lorenz63(sigma = sigma, beta = beta, rho = rho)
  V = torch.empty(T_+1, x_dim)
  V[0] = torch.tensor([1., 1., 1.])
  # Chain initialisation is pure simulation; no_grad stops ode_func's
  # nn.Parameters from recording a graph through V.
  with torch.no_grad():
    for t in range(T_):
      V[t+1] = odeint(ode_func, V[t], t1)[-1] + MultivariateNormal(torch.zeros(x_dim), gamma**2*cov).sample()
  res = torch.empty(T_+1, sample_size, x_dim)

  for sz in range(sample_size):
    acc, V = Sample_V_MCMC(ode_func, V, Y)
    res[:,sz,:] = V
    num_acc += acc
    sample_iter += 1

  # Samples are fixed data for the M-step: detach so backward reaches only the
  # new model's parameters and can run I > 1 times without retain_graph.
  samples = res.detach()
  for i in range(I):
    rho1 = rho.detach().requires_grad_(True)
    ode_fun = Lorenz63(sigma, beta, rho1)
    loss = Exp_value(ode_fun, samples)  # not `J`: J is the EM iteration count
    loss.backward()
    grad_ = ode_fun.rho.grad
    rho = rho - eta * grad_
    iter_rho_EM_MCMC.append(rho)

In [None]:
# EM algo for EnKF
# E-step: run the EnKF with the current parameter estimates and use the filtered
#         ensemble as samples of V.
# M-step: I gradient steps on Exp_value for sigma, beta and rho jointly.
sigma, beta, rho = torch.tensor(.01), torch.tensor(0.), torch.tensor(0.)
# Alternative starting points (one unknown parameter at a time):
#sigma, beta, rho = torch.tensor(.01), true_beta, true_rho
#sigma, beta, rho = true_sigma, torch.tensor(0.), true_rho
#sigma, beta, rho = true_sigma, true_beta, torch.tensor(0.)

J = 50      # EM iterations
I = 1       # gradient steps per M-step
eta = 1e-3  # base learning rate (scaled per parameter below)

iter_sigma_EM_EnKF = [sigma]
iter_beta_EM_EnKF = [beta]
iter_rho_EM_EnKF = [rho]

for j in range(J):
  # E-step: rebuild the model from the *current* estimates every iteration, so
  # the filter uses the parameters produced by the previous M-step (reusing the
  # M-step model would feed it the pre-update values, lagging one iteration).
  ode_func = Lorenz63(sigma = sigma, beta = beta, rho = rho)
  x, res, loglike = EnKF(ode_func, Y, N_ensem = 20, t_span = t1)
  # The ensemble is fixed data for the M-step: detach so backward neither
  # traverses nor has to retain the filter's autograd graph.
  res = res.detach()

  # M-step: Maximize the Expected Value using Gradient Descent
  for i in range(I):
    sigma1 = sigma.clone().detach().requires_grad_(True)
    beta1 = beta.clone().detach().requires_grad_(True)
    rho1 = rho.clone().detach().requires_grad_(True)

    m_step_func = Lorenz63(sigma = sigma1, beta = beta1, rho = rho1)
    loss = Exp_value(m_step_func, res)  # not `J`: J is the EM iteration count
    loss.backward()

    grad_sigma = m_step_func.sigma.grad
    grad_beta = m_step_func.beta.grad
    grad_rho = m_step_func.rho.grad

    # Per-parameter step-size multipliers on the base rate eta.
    sigma = sigma - 2*eta*grad_sigma
    beta = beta - 0.3*eta*grad_beta
    rho = rho - eta*grad_rho

    iter_sigma_EM_EnKF.append(sigma)
    iter_beta_EM_EnKF.append(beta)
    iter_rho_EM_EnKF.append(rho)

In [None]:
import matplotlib.pyplot as plt

# Parameter estimates per EM iteration for both E-step variants, against the truth.
fig, axes = plt.subplots(1, 3, figsize=(18, 4.5))
panels = (
  ('sigma', iter_sigma_EM_MCMC, iter_sigma_EM_EnKF, true_sigma),
  ('beta', iter_beta_EM_MCMC, iter_beta_EM_EnKF, true_beta),
  ('rho', iter_rho_EM_MCMC, iter_rho_EM_EnKF, true_rho),
)

for ax, (name, mcmc_hist, enkf_hist, true_value) in zip(axes, panels):
  ax.plot(range(len(mcmc_hist)), [float(v) for v in mcmc_hist], label = 'MCMC')
  ax.plot(range(len(enkf_hist)), [float(v) for v in enkf_hist], label = 'EnKF')
  # axhline spans the whole x-range, so it does not depend on either history's length.
  ax.axhline(float(true_value), color = 'C2', label = f'True {name}')
  ax.set_xlabel('iterations')
  ax.set_ylabel('values')
  ax.set_title(name)
  ax.legend()

plt.show()