<a href="https://colab.research.google.com/github/BotaoJin/Code-for-Thesis/blob/main/EM_linear.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.distributions.multivariate_normal import MultivariateNormal

In [None]:
# Shared settings for the linear-Gaussian state-space model used by every
# cell below.
T_ = 100                  # number of simulated time steps / observations
sigma, gamma = 0.1, 0.1   # std of the state noise / of the observation noise
x_dim = y_dim = 2         # state and observation dimensions
# Zero mean and identity covariance; the noise distributions below are built
# from these after scaling the covariance by sigma**2 or gamma**2.
mean, cov = torch.zeros(x_dim), torch.eye(x_dim)
eps = 0.01                # not referenced by any of the visible cells

In [None]:
def KFilter(theta, x0=None):
    """Simulate one trajectory of the linear-Gaussian state-space model.

    Despite its name this routine performs no filtering; it generates
    synthetic states and observations from

        X_t = X_{t-1} @ diag(theta) + zeta_t,   zeta_t ~ N(mean, sigma**2 * cov)
        Y_t = X_t + eta_t,                      eta_t  ~ N(mean, gamma**2 * cov)

    using the module-level ``T_``, ``x_dim``, ``y_dim``, ``mean``, ``cov``,
    ``sigma`` and ``gamma``.  The observation is the state plus noise, so the
    code requires ``y_dim == x_dim``.

    Args:
        theta: diagonal entries of the transition matrix (length ``x_dim``).
        x0: initial state; defaults to ``[1., 2.]`` when omitted.

    Returns:
        Tuple ``(X, Y)`` of tensors with shapes ``(T_, x_dim)`` and
        ``(T_, y_dim)``.  ``X[0]`` is the state after the first transition
        from ``x0``; ``x0`` itself is not stored.
    """
    if x0 is None:
        x0 = torch.tensor([1., 2.])

    A_theta = torch.diag(torch.tensor(theta))

    # The noise laws do not depend on t, so build them once instead of on
    # every iteration.  The sampling order (zeta, then eta, per step) is kept,
    # so a seeded run reproduces the same trajectory as before.
    zeta = MultivariateNormal(mean, (sigma**2)*cov)
    eta = MultivariateNormal(mean, (gamma**2)*cov)

    X = torch.zeros(T_, x_dim)
    Y = torch.zeros(T_, y_dim)

    previous = x0
    for t in range(T_):
        X[t, :] = previous@A_theta + zeta.sample()
        Y[t, :] = X[t, :] + eta.sample()
        previous = X[t, :]

    return X, Y

In [None]:
# Simulate the synthetic data set(s) with transition diagonal (0.9, 0.8).
num_data_set = 1
simulated = [KFilter([.9, .8]) for _ in range(num_data_set)]

# Stack the runs along a new middle axis, giving (T_, num_data_set, dim),
# then average over that axis so X and Y end up as (T_, dim).
X = torch.stack([states for states, _ in simulated], dim=1).mean(dim=-2)
Y = torch.stack([observations for _, observations in simulated], dim=1).mean(dim=-2)

In [None]:
# EM algorithm
# E-step: update the latent path V with one Metropolis-Hastings (MCMC) move
def Sample_V_MCMC(theta, V, Y, proposal_std=0.01):
    """Perform one random-walk Metropolis-Hastings update of the path ``V``.

    The unnormalised log-target is

        -||Y - V[1:]||^2 / (2*gamma**2) - ||V[1:] - V[:-1]*theta||^2 / (2*sigma**2)

    with ``gamma`` and ``sigma`` read from module scope.  The proposal adds
    independent Gaussian noise with standard deviation ``proposal_std`` to
    every entry of ``V``.

    The proposal is symmetric, so the forward and reverse proposal densities
    cancel in the acceptance ratio and are not evaluated.  (The earlier
    version evaluated them with covariance ``0.01*2`` and a hard-coded
    dimension of 2, which did not match the ``0.01**2`` actually used to draw
    the proposal.)

    NOTE(review): the noise is added to the whole of ``V``, including
    ``V[0]``, which the caller initialises to the known start state --
    confirm that perturbing it is intended.

    Args:
        theta: per-coordinate transition coefficients, broadcast against
            ``V[:-1]``.
        V: current path; ``V[1:]`` must have the same shape as ``Y``.
        Y: observations.
        proposal_std: standard deviation of the random-walk proposal.

    Returns:
        Tuple ``(acc, V_new)`` where ``acc`` is 1 if the proposal was accepted
        (``V_new`` is the proposal) and 0 otherwise (``V_new`` is ``V``).
    """
    def log_target(path):
        # Observation misfit plus transition misfit, both as squared norms.
        return (-1/(2*gamma**2)*torch.linalg.norm(Y - path[1:])**2
                - 1/(2*sigma**2)*torch.linalg.norm(path[1:] - path[:-1]*theta)**2)

    dim = V.shape[-1]
    step = MultivariateNormal(torch.zeros(dim).expand(V.size()),
                              proposal_std**2*torch.eye(dim)).sample()
    V_p = V + step

    # Acceptance probability min(1, target(V_p) / target(V)), computed in log
    # space and capped at log(1) = 0 before exponentiating.
    log_ratio = log_target(V_p) - log_target(V)
    acc_prob = torch.exp(torch.clamp(log_ratio, max=0.))

    a = torch.distributions.uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0])).sample()
    if a <= acc_prob:  # Accept
        return 1, V_p
    return 0, V  # Reject

In [None]:
# EM Algorithm
# E-step: sample V from P(V|Y, theta) with EnKF Method
from torch.distributions.multivariate_normal import MultivariateNormal

def Sample_V_EnKF(theta, Y, N_ensem, x0 = torch.tensor([1., 2.])):
    """Run a perturbed-observation ensemble Kalman filter over ``Y``.

    Each step forecasts every ensemble member with
    ``X * theta + N(mean, sigma**2 * cov)`` and then applies the Kalman
    analysis update with observation operator ``H = I`` and observation noise
    covariance ``gamma**2 * cov``.  ``x_dim``, ``y_dim``, ``mean``, ``cov``,
    ``sigma`` and ``gamma`` are read from module scope.

    Args:
        theta: per-coordinate transition coefficients, broadcast against the
            ensemble.
        Y: observations; the second-to-last axis is time.
        N_ensem: ensemble size.  Must be at least 2, because the sample
            covariance divides by ``N_ensem - 1``.
        x0: initial state shared by every ensemble member.

    Returns:
        Tensor of shape ``(T + 1, N_ensem, x_dim)`` with ``T = Y.shape[-2]``.
        Slice 0 holds ``x0`` repeated for every member; slice ``j + 1`` holds
        the analysis ensemble after assimilating ``Y[j]``.
    """
    T = Y.shape[-2]
    X = x0.expand((N_ensem, x_dim))
    res = torch.empty(T+1, N_ensem, x_dim)
    res[0] = X

    # Neither of these depends on the time step, so build them once.
    model_noise = MultivariateNormal(mean.expand(N_ensem, x_dim), (sigma**2)*cov)
    obs_cov = (gamma**2)*cov

    for j in range(T):
        # Forecast step: propagate each member and add model error.
        X = X * theta
        X = X + model_noise.sample()  # dim = (N_ensem, x_dim)
        X_ct = X - X.mean(dim = -2).unsqueeze(-2)  # centred ensemble

        # Analysis step: every member sees its own perturbed copy of Y[j].
        y_obs_j = Y[j].unsqueeze(-2)  # dim = (1, y_dim)
        y_obs_perturb = MultivariateNormal(y_obs_j.expand(N_ensem, y_dim), obs_cov).sample()

        # Ensemble sample covariance, dim = (x_dim, x_dim).
        C_uu = 1/(N_ensem - 1)*X_ct.transpose(-1, -2)@X_ct

        # With H = I the innovation covariance is C_uu + R.
        innov_cov_chol = torch.linalg.cholesky(C_uu + obs_cov)

        # Row-vector form of X + K (y - X) with K = C_uu (C_uu + R)^{-1}.
        pre = (y_obs_perturb - X)@torch.cholesky_inverse(innov_cov_chol)
        X = X + pre@C_uu
        res[j+1] = X

    return res

In [None]:
# EM algorithm
# M-step objective: sample average of the squared transition residuals, minimised over theta
def Exp_value(theta, res):
    """Return the sample-averaged squared transition residual for ``theta``.

    Computes

        sum_t sum_n ||res[t+1, n] - res[t, n] * theta||^2 / (N * sigma**2)

    where ``N = res.size(-2)`` is the number of samples and ``sigma`` is the
    module-level state-noise standard deviation.  The residual is a
    state-transition error, so it is scaled by ``sigma**2``; the earlier
    version divided by the observation-noise ``gamma**2`` instead.

    Args:
        theta: per-coordinate transition coefficients, broadcast against
            ``res``; may require grad.
        res: sampled paths with time on axis 0 and samples on axis -2.

    Returns:
        Scalar tensor, differentiable with respect to ``theta``.
    """
    N = res.size(dim = -2)
    # One vectorised residual over all consecutive time pairs instead of a
    # Python loop of per-step Frobenius norms.
    residual = res[1:] - res[:-1]*theta
    return 1./(N * sigma**2)*residual.pow(2).sum()

In [None]:
# EM algorithm with MCMC
# -- settings ---------------------------------------------------------------
L = 600               # number of EM iterations
I = 1                 # gradient steps per M-step
sample_size = 100     # MCMC moves (and stored paths) per E-step
eta = 3e-4            # base gradient step size
theta = torch.tensor([0., 0.])

# -- bookkeeping ------------------------------------------------------------
iter_theta1_EM_MCMC = [theta[0]]
iter_theta2_EM_MCMC = [theta[1]]
sample_iter = 0       # total number of MCMC moves attempted
num_acc = 0           # number of accepted moves

# Loop-invariant pieces: the state-noise law used to initialise each chain and
# the per-coordinate step size (the first coordinate moves three times faster).
init_noise = MultivariateNormal(torch.zeros(x_dim), sigma**2*cov)
step_size = eta*torch.tensor([3., 1.])

for em_iter in range(L):
    # E-step: start the chain from a path simulated under the current theta.
    res = torch.empty(T_+1, sample_size, x_dim)
    V = torch.empty(T_+1, x_dim)
    V[0] = torch.tensor([1., 2.])
    for t in range(1, T_+1):
        V[t] = V[t-1]*theta + init_noise.sample()

    # Store the chain state after every move, accepted or not.
    for k in range(sample_size):
        acc, V = Sample_V_MCMC(theta, V, Y)
        res[:, k, :] = V
        num_acc += acc
        sample_iter += 1

    # M-step: gradient descent on the sampled objective.
    for _ in range(I):
        theta1 = theta.clone().detach().requires_grad_(True)
        J = Exp_value(theta1, res)
        J.backward()
        grad_ = theta1.grad
        theta = theta - step_size*grad_
        iter_theta1_EM_MCMC.append(theta[0])
        iter_theta2_EM_MCMC.append(theta[1])

In [None]:
# EM algorithm with EnKF Method
# -- settings ---------------------------------------------------------------
L = 600               # number of EM iterations
I = 1                 # gradient steps per M-step
N_ensem = 100         # ensemble size used in the E-step
eta = 1e-4            # gradient step size
alpha = 0.5           # not referenced by the code in this cell
theta = torch.tensor([0., 0.])

# Parameter traces, starting from the initial guess.
iter_theta1_EM_EnKF = [theta[0]]
iter_theta2_EM_EnKF = [theta[1]]

for em_iter in range(L):
    # E-step: ensemble of paths produced by the EnKF under the current theta.
    res = Sample_V_EnKF(theta, Y, N_ensem)

    # M-step: gradient descent on the sampled objective.
    for _ in range(I):
        theta1 = theta.clone().detach().requires_grad_(True)
        J = Exp_value(theta1, res)
        J.backward()
        grad_ = theta1.grad
        theta = theta - eta*grad_

    # One trace entry per EM iteration (after all I gradient steps).
    iter_theta1_EM_EnKF.append(theta[0])
    iter_theta2_EM_EnKF.append(theta[1])

In [None]:
# Plot the parameter traces of both EM variants next to the reference values
# 0.9 and 0.8, one panel per coordinate of theta.
fig = plt.figure(figsize=(12, 4.5))

panels = (
    (1, iter_theta1_EM_MCMC, iter_theta1_EM_EnKF, .9, 'values of theta1'),
    (2, iter_theta2_EM_MCMC, iter_theta2_EM_EnKF, .8, 'values of theta2'),
)

for position, mcmc_trace, enkf_trace, true_value, y_label in panels:
    plt.subplot(1, 2, position)
    n_mcmc = len(mcmc_trace)
    plt.plot(range(n_mcmc), torch.tensor(mcmc_trace).detach().numpy(), label = 'MCMC')
    plt.plot(range(len(enkf_trace)), torch.tensor(enkf_trace).detach().numpy(), label = 'EnKF')
    # Horizontal reference line spanning the MCMC trace length.
    plt.plot(range(n_mcmc), np.full((n_mcmc,), true_value), label = 'True value')
    plt.xlabel('iterations')
    plt.ylabel(y_label)
    plt.legend()

plt.show()