<a href="https://colab.research.google.com/github/BotaoJin/Code-for-Thesis/blob/main/AD_EnKF_linear_TBP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import math
from torch.distributions.multivariate_normal import MultivariateNormal
import matplotlib.pyplot as plt

In [None]:
# Experiment configuration shared by the cells below.
T_ = 100                 # number of time steps in each simulated trajectory
x_dim = y_dim = 2        # state and observation dimensions
sigma = gamma = 0.1      # std. dev. of the model noise (sigma) and of the observation noise (gamma)
eps = 0.01               # NOTE(review): not referenced by the visible cells -- confirm before removing
mean = torch.zeros(x_dim)  # zero mean shared by every noise distribution
cov = torch.eye(x_dim)     # identity; scaled by sigma**2 / gamma**2 where noise is drawn

In [None]:
def KFilter(theta):
    """Simulate one trajectory of the linear-Gaussian state-space model.

    Despite its name, this function does not filter anything; it draws
    synthetic data from

        X[t] = X[t-1] @ diag(theta) + zeta_t,   zeta_t ~ N(mean, sigma**2 * cov)
        Y[t] = X[t] + eta_t,                    eta_t  ~ N(mean, gamma**2 * cov)

    where the state preceding X[0] is fixed to [1, 2].

    Args:
        theta: diagonal of the transition matrix, given as a sequence of
            numbers or a 1-D tensor (integer entries are converted to float).

    Returns:
        Tuple ``(X, Y)`` of tensors with shapes ``(T_, x_dim)`` and
        ``(T_, y_dim)`` holding the states and the noisy observations.
    """
    x0 = torch.tensor([1., 2.])  # initialization of X0
    # as_tensor accepts lists and tensors alike; detach so a caller's autograd
    # graph is never dragged into the simulated data.
    A_theta = torch.diag(torch.as_tensor(theta, dtype=x0.dtype).detach())

    # The noise laws do not change over time, so build them once.
    zeta = MultivariateNormal(mean, (sigma**2)*cov)
    eta = MultivariateNormal(mean, (gamma**2)*cov)

    X = torch.zeros(T_, x_dim)
    Y = torch.zeros(T_, y_dim)

    x = x0
    for t in range(T_):
        # Draw order (model noise first, then observation noise) is kept so a
        # fixed seed reproduces the same trajectory as before.
        x = x@A_theta + zeta.sample()
        X[t, :] = x
        Y[t, :] = x + eta.sample()

    return X, Y

In [None]:
# Generate the training data: simulate `num_data_set` independent trajectories
# with the true parameter (.9, .8) and average them time step by time step.
num_data_set = 2
state_runs, obs_runs = [], []
for i in range(num_data_set):
    X_data, Y_data = KFilter([.9, .8])
    state_runs.append(X_data)
    obs_runs.append(Y_data)

# Stack to (T_, num_data_set, dim), then collapse the data-set axis by its mean.
X = torch.stack(state_runs, dim=-2).mean(dim=-2)
Y = torch.stack(obs_runs, dim=-2).mean(dim=-2)

In [None]:
device = torch.device("cpu")

def EnKF_log_likelihood(theta, Y, N_ensem, x0 = torch.tensor([1.,2.])):
    """Run an ensemble Kalman filter with perturbed observations and
    accumulate its estimate of the log-likelihood of ``theta``.

    Each step forecasts ``X <- X * theta + N(mean, sigma**2 * cov)`` and then
    assimilates ``Y[j]`` with observation operator H = I and observation
    noise covariance ``gamma**2 * cov``.

    Args:
        theta: tensor multiplied element-wise into the ensemble (the diagonal
            of the transition matrix); gradients flow through it.
        Y: observations; ``Y.shape[-2]`` is the number of steps and ``Y[j]``
            is the observation assimilated at step ``j``.
        N_ensem: number of ensemble members (at least 2).
        x0: initial state, anything expandable to ``(N_ensem, x_dim)``: a
            single state vector or a full ensemble.

    Returns:
        Tuple ``(X, log_likelihood)``: the ensemble after the last analysis
        step, shape ``(N_ensem, x_dim)``, and a scalar tensor. Both stay
        attached to the autograd graph of ``theta`` (and of ``x0`` if it
        carries one), so callers that want truncated backpropagation must
        detach ``X`` before feeding it back in.

    Raises:
        ValueError: if ``N_ensem < 2`` (no sample covariance can be formed).
    """
    if N_ensem < 2:
        raise ValueError("N_ensem must be at least 2 to form a sample covariance")

    T = Y.shape[-2]
    X = x0.expand((N_ensem, x_dim))
    R = (gamma**2)*cov  # observation-noise covariance

    # Both noise laws are time-invariant: build them once instead of per step.
    model_noise = MultivariateNormal(mean.expand(N_ensem, x_dim), (sigma**2)*cov)
    obs_noise = MultivariateNormal(mean.expand(N_ensem, y_dim), R)

    log_likelihood = torch.tensor(0., device = device)

    for j in range(T):
        # Forecast step; the noise sample adds one row per member, dim = (N_ensem, x_dim)
        X = X * theta + model_noise.sample()
        X_m = X.mean(dim = -2, keepdim = True)  # dim = (1, x_dim)
        X_ct = X - X_m
        C_uu = X_ct.transpose(-1, -2)@X_ct / (N_ensem - 1)  # dim = (x_dim, x_dim)

        # Analysis step. H = I in this model, so H X = X and H C H^T = C_uu.
        y_obs_j = Y[j].unsqueeze(-2)  # dim = (1, y_dim)
        y_obs_perturb = y_obs_j + obs_noise.sample()  # dim = (N_ensem, y_dim)

        S_chol = torch.linalg.cholesky(C_uu + R)
        predictive = MultivariateNormal(X_m.squeeze(-2), scale_tril = S_chol)
        # Out-of-place accumulation keeps the autograd history simple.
        log_likelihood = log_likelihood + predictive.log_prob(y_obs_j.squeeze(-2))

        # Update X: rows of (y_perturbed - X) @ S^{-1} @ C_uu. Solving against the
        # Cholesky factor avoids materialising S^{-1} explicitly.
        innovation = y_obs_perturb - X
        weighted = torch.cholesky_solve(innovation.transpose(-1, -2), S_chol).transpose(-1, -2)
        X = X + weighted@C_uu

    return X, log_likelihood

In [None]:
def true_log_likelihood(Y, theta, m0 = torch.tensor([1., 2.]), C0 = torch.tensor([[1., 0.],[0., 1.]])):
    """Exact Kalman-filter log-likelihood of ``theta`` for the linear model.

    The model is ``x_j = diag(theta) x_{j-1} + N(0, sigma**2 * cov)`` with
    observations ``y_j = x_j + N(0, gamma**2 * cov)``.

    Args:
        Y: observations; ``Y.shape[-2]`` is the number of steps and ``Y[j]``
            the observation at step ``j``.
        theta: 1-D tensor holding the diagonal of the transition matrix.
        m0: filtering mean before the first step.
        C0: filtering covariance before the first step.

    Returns:
        Tuple ``(m, C, log_likelihood)``: filtering mean and covariance after
        the last observation, and the accumulated log-likelihood as a scalar
        tensor.
    """
    J = Y.shape[-2]
    Id = torch.eye(C0.shape[-1], dtype = C0.dtype, device = C0.device)
    Q = (sigma**2)*cov  # model-noise covariance (the same noise the simulator and the EnKF forecast use)
    R = (gamma**2)*cov  # observation-noise covariance
    log_norm_const = Y.shape[-1] * math.log(2*math.pi)

    m = m0
    C = C0
    sum_log_likelihood = Y.new_zeros(())

    for j in range(J):
        # Prediction part: m_hat = A m, C_hat = A C A^T + Q with A = diag(theta).
        m_hat = theta * m
        # Row i is scaled by theta[i] and column k by theta[k], i.e. A C A^T;
        # this stays correct when C is not diagonal.
        C_hat = theta.unsqueeze(-1) * C * theta.unsqueeze(-2) + Q

        # Analysis part
        dj = Y[j] - m_hat  # innovation
        S = C_hat + R
        S_inv = torch.linalg.inv(S)
        K = C_hat@S_inv  # Kalman gain (H = I)
        m = m_hat + dj@K.t()
        C = (Id - K)@C_hat

        # Accumulate -2 * log N(y_j; m_hat, S); the factor -1/2 is applied at the end.
        sum_log_likelihood = sum_log_likelihood + log_norm_const + dj@S_inv@dj + torch.logdet(S)

    return m, C, -0.5 * sum_log_likelihood

In [None]:
# Gradient ascent on the EnKF log-likelihood with truncated backpropagation:
# every sweep over the data is cut into subsequences of `Length` steps, and
# theta is updated after each subsequence using only that subsequence's gradient.
eta = 5e-4 # learning rate
theta = torch.tensor([0., 0.]) # initial value of theta
diff = 1 # NOTE(review): not used in the visible cells -- confirm before removing
n_iterations = 100
# In the model, given that T = 100
Length = 20 # subsequence length is 20
iter_theta1_TBP = []
iter_theta2_TBP = []
grad_theta1_TBP = []
grad_theta2_TBP = []
log_like_theta_TBP = []
true_log_like_TBP = []

# Per-coordinate scaling applied to the gradient step.
step_scale = torch.tensor([1.5, 1.])

# Initial state / filtering moments; every sweep restarts from these.
x0 = m0 = torch.tensor([1., 2.])
C0 = torch.tensor([[1., 0.],[0., 1.]])

for k in range(n_iterations):
    x = x0
    m = m0
    C = C0
    # Stepping by Length also covers a shorter final subsequence when T_ is
    # not a multiple of Length, instead of silently dropping it.
    for t0 in range(0, T_, Length):
        t1 = min(t0 + Length, T_)
        y = Y[t0:t1]
        theta_k = theta.clone().detach().requires_grad_(True)

        x, L = EnKF_log_likelihood(theta_k, y, N_ensem = 3000, x0 = x) # calculate likelihood, and update x_t
        L.backward()
        grad_log_likelihood = theta_k.grad
        # Truncation point: cut the ensemble off from this subsequence's graph so
        # the next backward pass never walks into (or keeps alive) earlier graphs.
        x = x.detach()

        log_like_theta_TBP.append(L.detach())
        grad_theta1_TBP.append(grad_log_likelihood[0])
        grad_theta2_TBP.append(grad_log_likelihood[1])

        # The exact likelihood is only a diagnostic and the parameter update is
        # plain arithmetic; neither should be recorded by autograd.
        with torch.no_grad():
            m, C, true_L = true_log_likelihood(y, theta_k, m0 = m, C0 = C)
            theta = theta_k + eta * step_scale * grad_log_likelihood
        true_log_like_TBP.append(true_L)

        iter_theta1_TBP.append(theta[0])
        iter_theta2_TBP.append(theta[1])

In [None]:
# Diagnostics of the training run. The recorded lists hold one entry per
# parameter update, five updates (subsequences) per sweep; entries are split
# by subsequence position so each curve follows one subsequence across sweeps.
fig = plt.figure(figsize=(18, 4.5))

segments_per_sweep = 5
# Derive the number of sweeps from what was actually recorded rather than
# hard-coding it, so the index arrays can never run past the data.
n_sweeps = len(log_like_theta_TBP) // segments_per_sweep
n1, n2, n3, n4, n5 = (segments_per_sweep*np.arange(n_sweeps) + offset
                      for offset in range(segments_per_sweep))
segment_indices = (n1, n2, n3, n4, n5)

# Convert each recorded list once instead of once per curve.
enkf_log_like = torch.tensor(log_like_theta_TBP).detach().numpy()
exact_log_like = torch.tensor(true_log_like_TBP).detach().numpy()
grad_theta1 = torch.tensor(grad_theta1_TBP).detach().numpy()
grad_theta2 = torch.tensor(grad_theta2_TBP).detach().numpy()

plt.subplot(1,3,1)
for pos, idx in enumerate(segment_indices, start=1):
    plt.plot(idx, enkf_log_like[idx], label = f'EnKF log-like(n{pos})')
for pos, idx in enumerate(segment_indices, start=1):
    # These curves are the exact Kalman-filter likelihood, so label them as such.
    plt.plot(idx, exact_log_like[idx], label = f'true log-like(n{pos})')
plt.ylabel('log likelihood')
plt.xlabel('iterations')
plt.legend()

plt.subplot(1,3,2)
for pos, idx in enumerate(segment_indices, start=1):
    plt.plot(idx, grad_theta1[idx], label = f'grad n{pos}')
plt.ylabel('grad of theta1')
plt.xlabel('iterations')
plt.legend()

plt.subplot(1,3,3)
for pos, idx in enumerate(segment_indices, start=1):
    plt.plot(idx, grad_theta2[idx], label = f'grad n{pos}')
plt.ylabel('grad of theta2')
plt.xlabel('iterations')
plt.legend()

plt.show()

In [None]:
# Trajectory of each component of theta over the parameter updates, drawn
# against the value that was used to simulate the data.
fig = plt.figure(figsize=(12, 4.5))

panels = (
    (iter_theta1_TBP, .9, 'values of theta1', 'true value of theta1'),
    (iter_theta2_TBP, .8, 'values of theta2', 'true value of theta2'),
)

for position, (history, truth, estimate_label, truth_label) in enumerate(panels, start=1):
    plt.subplot(1,2,position)
    n_updates = len(history)
    plt.plot(range(n_updates), torch.tensor(history).clone().numpy(), label = estimate_label)
    plt.plot(range(n_updates), truth*np.ones((n_updates,)), label = truth_label)
    plt.xlabel('iterations')
    plt.ylabel('values')
    plt.legend()

plt.show()