<a href="https://colab.research.google.com/github/GHes31415/DeepBSDE/blob/master/grad_desc_wass_bary_gauss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

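Here I'm going to try the following loss function for finding the barycenter of Gaussian distributions in the case $n = 2$:

$$\mathcal{L}(\theta) = \frac{1}{k}\sum_{i=1}^{k}\left\lVert \Sigma_\theta^{1/2} - \Sigma_\theta^{-1/2}\left(\Sigma_\theta^{1/2}\,\Sigma_i\,\Sigma_\theta^{1/2}\right)^{1/2}\right\rVert_F, \qquad \Sigma_\theta = \Theta^2,$$

where $\Theta$ is $\theta$ reshaped to an $n \times n$ matrix and $\Sigma_1, \dots, \Sigma_k$ are the given covariance matrices.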


In [None]:
# The two 2x2 covariance matrices whose barycenter is sought.
sig_1 = torch.tensor([[4, 2], [2, 3]], dtype=torch.float32)
sig_2 = torch.tensor([[3, -1], [-1, 5]], dtype=torch.float32)
cov_mats = [sig_1, sig_2]
# Derive the dimension and the matrix count from the data itself so they
# cannot drift out of sync when a matrix is added or resized.
n = cov_mats[0].shape[0]
k = torch.tensor(len(cov_mats))

In [None]:
# Frobenius norm of the commutator sig_1 sig_2 - sig_2 sig_1;
# a value of zero would mean the two matrices commute.
torch.norm(torch.matmul(sig_1, sig_2) - torch.matmul(sig_2, sig_1))

tensor(4.2426)

In [None]:
def sqrt_cov(sig: torch.Tensor) -> torch.Tensor:
  '''
  Return a matrix square root of `sig` computed from its eigendecomposition.

  sig: nxn SPD matrix

  The result is U diag(D^0.5) U^{-1}, where D and U are the eigenvalues and
  eigenvectors given by torch.linalg.eig. Only the real part is returned,
  as a float32 tensor.
  '''
  # torch.linalg.eig returns complex tensors even when the input is real.
  D, U = torch.linalg.eig(sig)
  root = U @ torch.diag(D**0.5) @ torch.linalg.inv(U)
  # Take the real part explicitly. Casting the complex tensor straight to
  # float32 drops the imaginary part too, but does so implicitly and emits
  # a UserWarning.
  return root.real.type(torch.float32)

In [None]:
def loss(theta,cov_mats,k):
  '''
  Average of the per-matrix distances between the current estimate and
  each covariance matrix.

  theta:    torch.tensor size n^2,
            current estimate.
  cov_mats: list of size k of tensors with shape nxn,
            covariance matrices.
  k:        torch.tensor size 1,
            number of covariance matrices.

  Returns the sum of the distances from distance_to_covs divided by k.
  '''
  # The per-matrix terms are exactly what distance_to_covs computes, so reuse
  # it rather than keeping a second copy of the same formula in sync.
  distances = distance_to_covs(theta, cov_mats, k)
  # Start the sum from an integer zero tensor so an empty cov_mats still
  # yields a tensor result.
  return sum(distances, torch.tensor(0)) / k

In [None]:
def distance_to_covs(theta,cov_mats,k):
  '''
  Distance from the current estimate to each covariance matrix.

  theta:    torch.tensor size n^2,
            current estimate.
  cov_mats: list of size k of tensors with shape nxn,
            covariance matrices.
  k:        torch.tensor size 1,
            number of covariance matrices (not used in this function).

  Returns a list with one scalar tensor per matrix `sig` in cov_mats:
  norm(S - S^{-1} sqrt_cov(S sig S)), where S = sqrt_cov(theta_m @ theta_m)
  and theta_m is theta reshaped to a square matrix.
  '''
  # Recover the side length from theta itself instead of a module-level
  # variable; reshape still raises if theta's size is not a perfect square.
  dim = int(round(theta.numel() ** 0.5))
  theta_m = theta.reshape((dim, dim))
  # NOTE(review): theta_m @ theta_m is symmetric only when theta_m is. If an
  # SPD estimate is intended, theta_m @ theta_m.T may be what is wanted -- confirm.
  root = sqrt_cov(torch.matmul(theta_m, theta_m))
  root_inv = torch.linalg.inv(root)

  def _distance(sig):
    # One term of the objective for a single covariance matrix.
    inner_root = sqrt_cov(root @ sig @ root)
    return torch.norm(root - root_inv @ inner_root)

  return [_distance(sig) for sig in cov_mats]


In [None]:
# Step size handed to Adam below; keeping it in one variable means the
# declared learning rate and the one actually used cannot disagree.
lr = 1e-4
n_iters = 100000
# Random initial guess in [0, 1) of length n^2, created as a leaf tensor so
# the optimizer can update it.
theta = torch.rand(n**2, requires_grad=True)
optimizer = torch.optim.Adam([theta], lr=lr)

In [None]:

for epoch in range(n_iters):
  # Forward pass and backpropagation.
  current_loss = loss(theta, cov_mats, k)
  current_loss.backward()

  # Adam update, then clear the gradient so it does not accumulate into the
  # next iteration's backward pass.
  optimizer.step()
  optimizer.zero_grad()

  # Report progress every 1000 iterations.
  if epoch % 1000 == 0:
    print(f'epoch = {epoch}, loss = {current_loss:.8f}')
    print(theta)




epoch = 0, loss = 1.84617901
tensor([0.4191, 0.8683, 0.5330, 0.4013], requires_grad=True)
epoch = 1000, loss = 1.68737435
tensor([0.3300, 0.9705, 0.6306, 0.3109], requires_grad=True)
epoch = 2000, loss = 1.54614902
tensor([0.2632, 1.0737, 0.7261, 0.2410], requires_grad=True)
epoch = 3000, loss = 1.41580331
tensor([0.2227, 1.1757, 0.8214, 0.1955], requires_grad=True)
epoch = 4000, loss = 1.29207897
tensor([0.2152, 1.2762, 0.9168, 0.1745], requires_grad=True)
epoch = 5000, loss = 1.17218232
tensor([0.2649, 1.3753, 1.0121, 0.1309], requires_grad=True)
epoch = 6000, loss = 1.04323173
tensor([ 0.4262,  1.4720,  1.1061, -0.0318], requires_grad=True)
epoch = 7000, loss = 0.91775823
tensor([ 0.5705,  1.5647,  1.1967, -0.1742], requires_grad=True)
epoch = 8000, loss = 0.81344509
tensor([ 0.6894,  1.6521,  1.2824, -0.2866], requires_grad=True)
epoch = 9000, loss = 0.73887587
tensor([ 0.7896,  1.7312,  1.3607, -0.3763], requires_grad=True)
epoch = 10000, loss = 0.69671923
tensor([ 0.8682,  1.7981

In [None]:
# Display the parameter vector after the optimisation loop.
theta

tensor([-0.3450,  1.7701,  1.8119,  0.7172], requires_grad=True)

In [None]:
# Per-matrix distances at the current theta (the terms that loss averages).
distances = distance_to_covs(theta,cov_mats,k)

In [None]:
# Display the list of per-matrix distances.
distances

[tensor(0.6196, grad_fn=<LinalgVectorNormBackward0>),
 tensor(0.6875, grad_fn=<LinalgVectorNormBackward0>)]