In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch3d.transforms import quaternion_multiply, quaternion_to_matrix

In [2]:
## Wasserstein of two gaussians
class WassersteinGaussian(nn.Module):
    def __init__(self):
        super(WassersteinGaussian, self).__init__()

    def forward(self, loc1, scale1, rot_matrix1, loc2, scale2, rot_matrix2):
        """
        compute the Wasserstein distance between two Gaussians
        loc1, loc2: Bx3
        scale1, scale2: Bx3
        rot_matrix1, rot_matrix2: Bx3x3
        """
        
        loc_diff2 = torch.sum((loc1 - loc2)**2, dim=-1)

        ## Wasserstein distance Tr(C1 + C2 - 2(C1^0.5 * C2 * C1^0.5)^0.5)

        cov1_sqrt_diag = torch.sqrt(scale1).diag_embed() # Bx3x3

        cov2 = torch.bmm(rot_matrix2, torch.bmm(torch.diag_embed(scale2), rot_matrix2.transpose(1, 2))) # covariance matrix Bx3x3
        cov2_R1 = torch.bmm(rot_matrix1.transpose(1, 2), cov2).matmul(rot_matrix1) # Bx3x3
        # E = cv1^0.5*cv2*cv1^0.5

        E = torch.bmm(torch.bmm(cov1_sqrt_diag, cov2_R1), cov1_sqrt_diag) # Bx3x3

        E = (E + E.transpose(1, 2))/2
        E_eign = torch.linalg.eigvalsh(E)


        E_sqrt_trace = (E_eign.pow(2).pow(1/4)).sum(dim=-1)

        CovWasserstein = scale1.sum(dim=-1) + scale2.sum(dim=-1) - 2*E_sqrt_trace
        
        CovWasserstein = torch.clamp(CovWasserstein, min=0) # numerical stability for small negative values

        return torch.sqrt(loc_diff2 + CovWasserstein)

In [3]:
B = 6 # batch size
loc = torch.randn(B, 3) # location Bx3
rot = torch.randn(B, 4) # quaternion Bx4
rot = F.normalize(rot, p=2, dim=1) # normalize quaternion
scale = torch.randn(B, 3) # scale Bx3
scale = torch.exp(scale) # make sure scale is positive

# convert quaternion to rotation matrix
rot_matrix = quaternion_to_matrix(rot) # rotation matrix Bx3x3
cov = torch.bmm(rot_matrix, torch.bmm(torch.diag_embed(scale), rot_matrix.transpose(1, 2))) # covariance matrix Bx3x3



wasserstein = WassersteinGaussian()
wasserstein(loc[:3], scale[:3], rot_matrix[:3], loc[3:], scale[3:], rot_matrix[3:])

tensor([2.9402, 2.8001, 2.2903])

#### 以下为 Junli 的代码

In [4]:
w_distance = wasserstein(loc[:3], scale[:3], rot_matrix[:3], loc[:3], scale[:3], rot_matrix[:3])
# print("loc1:", loc[:3])
# print("loc2:", loc[3:])
print("w_distance:", w_distance)

w_distance: tensor([0.0014, 0.0000, 0.0010])


In [5]:
# 一样的loc 和 scale，一样的rot_matrix
rot_matrix_diff = rot_matrix[:3] + 0.001 * torch.randn(3, 3, 3)
print("rot_matrix:", rot_matrix[:3])
print("rot_matrix_diff:", rot_matrix_diff)
w_distance = wasserstein(loc[:3], scale[:3], rot_matrix[:3], loc[:3], scale[:3], rot_matrix_diff[:3])
# print("loc1:", loc[:3])
# print("loc2:", loc[:3])
# print("rot_matrix1:", rot_matrix[:3])
# print("rot_matrix2:", rot_matrix[3:])
print("w_distance:", w_distance)



rot_matrix: tensor([[[ 0.6218,  0.3789,  0.6854],
         [ 0.6160,  0.3038, -0.7268],
         [-0.4836,  0.8742, -0.0445]],

        [[ 0.0127,  0.7638,  0.6454],
         [ 0.9186, -0.2639,  0.2943],
         [ 0.3950,  0.5891, -0.7049]],

        [[-0.2254, -0.2956,  0.9283],
         [ 0.8773, -0.4760,  0.0615],
         [ 0.4237,  0.8283,  0.3666]]])
rot_matrix_diff: tensor([[[ 0.6213,  0.3766,  0.6867],
         [ 0.6166,  0.3043, -0.7273],
         [-0.4854,  0.8741, -0.0460]],

        [[ 0.0134,  0.7649,  0.6452],
         [ 0.9181, -0.2651,  0.2948],
         [ 0.3952,  0.5887, -0.7048]],

        [[-0.2265, -0.2958,  0.9279],
         [ 0.8795, -0.4752,  0.0600],
         [ 0.4248,  0.8278,  0.3665]]])
w_distance: tensor([0.0000, 0.0221, 0.0000])


In [6]:
from einops import einsum

In [7]:
## Wasserstein of two gaussians
class WassersteinExp(nn.Module):
    def __init__(self):
        super(WassersteinExp, self).__init__()

    def forward(self, loc, scale1, rot_matrix1, velocity, velocity_cov):
        """
        compute the Wasserstein Exponential of X from A
        loc: Bx3
        scale1: Bx3
        rot_matrix1: Bx3x3
        velocity: Bx3
        velocity_cov: Bx3x3 
        """
        new_loc = loc + velocity

        # new_cov = exp_A(X)
        C_ij = rot_matrix1.transpose(1, 2).bmm(velocity_cov).bmm(rot_matrix1)

       
        E_ij = scale1.unsqueeze(-1) + scale1.unsqueeze(-2) # Bx3x3
        E_ij = C_ij/(E_ij+1e-8) # Bx3x3

        gamma = torch.bmm(rot_matrix1, torch.bmm(E_ij, rot_matrix1.transpose(1, 2)))

        cov = torch.bmm(rot_matrix1, torch.bmm(torch.diag_embed(scale1), rot_matrix1.transpose(1, 2))) # covariance matrix Bx3x3

        new_cov = cov + velocity_cov + gamma.bmm(cov).bmm(gamma.transpose(1, 2))

        return new_loc, new_cov





        
## Gaussian Merge
class GaussianMerge(nn.Module):
    def __init__(self):
        super(GaussianMerge, self).__init__()

    def forward(self, loc1, scale1, rot_matrix1, loc2, scale2, rot_matrix2):
        """
        merge two Gaussians
        loc1, loc2: Bx3
        scale1, scale2: Bx3
        rot_matrix1, rot_matrix2: Bx3x3
        """
        
        cov1 = torch.bmm(rot_matrix1, torch.bmm(torch.diag_embed(scale1), rot_matrix1.transpose(1, 2))) # covariance matrix Bx3x3
        cov2 = torch.bmm(rot_matrix2, torch.bmm(torch.diag_embed(scale2), rot_matrix2.transpose(1, 2)))

        K = cov1.matmul((cov1 + cov2).inverse())
        loc_new = loc1.unsqueeze(1) + (loc2.unsqueeze(1) - loc1.unsqueeze(1)).bmm(K.transpose(1, 2))
        loc_new = loc_new.squeeze(1)
        cov_new = cov1 + K.matmul(cov1)

        return loc_new, cov_new




        


In [8]:
B = 6 # batch size
loc = torch.randn(B, 3) # location Bx3
rot = torch.randn(B, 4) # quaternion Bx4
rot = F.normalize(rot, p=2, dim=1) # normalize quaternion
scale = torch.randn(B, 3) # scale Bx3

# convert quaternion to rotation matrix
rot_matrix = quaternion_to_matrix(rot) # rotation matrix Bx3x3
cov = torch.bmm(rot_matrix, torch.bmm(torch.diag_embed(scale), rot_matrix.transpose(1, 2))) # covariance matrix Bx3x3

velocity = torch.randn(B, 3) # velocity Bx3
velocity_cov = torch.randn(B, 3, 3) # velocity covariance Bx3x3
velocity_cov = velocity_cov.transpose(-1, -2) + velocity_cov # make sure it is symmetric

wasserstein_exp = WassersteinExp()
gaussian_merge = GaussianMerge()

In [9]:
wasserstein_exp(loc, scale, rot_matrix, velocity, velocity_cov)



(tensor([[ 0.2802, -1.2681, -1.7694],
         [-0.2062, -0.1605,  0.5394],
         [-0.4850,  0.6407,  1.0364],
         [-0.4175,  2.8751,  0.0866],
         [-0.4916, -3.3968, -3.4293],
         [ 1.0917,  0.5907,  0.6577]]),
 tensor([[[ -0.1895,   2.7206,  -2.1079],
          [  2.7206,  -7.0439,   0.6326],
          [ -2.1079,   0.6326,   3.8252]],
 
         [[ 14.1999,   3.6922,   7.2304],
          [  3.6922,  -2.2090,   1.4793],
          [  7.2304,   1.4793,   1.1822]],
 
         [[  2.8330,   5.0632,  -2.2236],
          [  5.0632,   4.5879,  -3.8400],
          [ -2.2236,  -3.8400,  -3.8684]],
 
         [[ 91.6318,  24.4132, -45.2467],
          [ 24.4132,   6.4690, -12.0736],
          [-45.2467, -12.0736,  21.9829]],
 
         [[ -0.1989,  -0.4289,  -0.6315],
          [ -0.4289,  -0.4784,  -0.2688],
          [ -0.6315,  -0.2688,   0.4564]],
 
         [[ -3.7768,  -0.3736,   0.5976],
          [ -0.3736,  -1.6217,   0.8700],
          [  0.5976,   0.8700,  -3.7201]]

In [10]:
loc1 = torch.randn(B, 3) # location Bx3
rot1 = torch.randn(B, 4) # quaternion Bx4
rot1 = F.normalize(rot1, p=2, dim=1) # normalize quaternion
scale1 = torch.randn(B, 3) # scale Bx3

# convert quaternion to rotation matrix
rot_matrix1 = quaternion_to_matrix(rot1) # rotation matrix Bx3x3

loc2 = torch.randn(B, 3) # location Bx3
rot2 = torch.randn(B, 4) # quaternion Bx4
rot2 = F.normalize(rot2, p=2, dim=1) # normalize quaternion
scale2 = torch.randn(B, 3) # scale Bx3

# convert quaternion to rotation matrix
rot_matrix2 = quaternion_to_matrix(rot2) # rotation matrix Bx3x3

gaussian_merge(loc1, scale1, rot_matrix1, loc2, scale2, rot_matrix2)

(tensor([[ -1.3708,   1.9326,   2.3334],
         [  1.3164,  -0.6882,   1.3025],
         [ 22.7209, -12.4017, -41.7697],
         [ -2.5780,  -1.7663,  -0.9087],
         [  0.4079,   1.0293,   0.6521],
         [  1.9538,  -1.9284,   0.9281]]),
 tensor([[[ -1.0256,  -1.0545,  -0.7546],
          [ -1.0545,   0.6748,   0.7623],
          [ -0.7546,   0.7623,  -0.4141]],
 
         [[ -2.8576,   0.1648,  -1.8476],
          [  0.1648,  -1.3947,   0.0564],
          [ -1.8476,   0.0564,  -5.0289]],
 
         [[  8.7650,  -7.0034, -20.6040],
          [ -7.0034,   3.3736,  11.6706],
          [-20.6040,  11.6706,  40.9157]],
 
         [[ -9.7553,  -5.1141,  -1.7618],
          [ -5.1141,  -4.0525,  -1.1451],
          [ -1.7618,  -1.1451,  -0.2236]],
 
         [[ -0.8428,  -0.4747,   0.3794],
          [ -0.4747,  -0.5156,   0.1640],
          [  0.3794,   0.1640,   0.1407]],
 
         [[  0.0899,   0.5777,  -0.1069],
          [  0.5777,  -0.1246,  -1.2848],
          [ -0.1069,  -

In [11]:
## Wasserstein of two gaussians
class 

SyntaxError: invalid syntax (2731323126.py, line 2)