In [3]:
# install pytorch 
# conda install -c pytorch pytorch

In [4]:
import torch

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

![title](mmd_math_equation.png)

In [45]:
def MMD(x, y, kernel, a):
    """Emprical maximum mean discrepancy. The lower the result
       the more evidence that distributions are the same.

    Args:
        x: first sample, distribution P
        y: second sample, distribution Q
        kernel: kernel type such as "multiscale" or "rbf"
    """
    
    ## torch.mm performs matrix multiplication
    ## x.t() performs a transpose
    ## xx.diag() contains the square of each row/point in x
    xx, yy, xy = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    
    ## calculating the first summation of parts A and C
    # unsqueze adds another dimension to the tensor (,n) to (1,n)
    # expand_as duplicates the array to as many rows to match the array given 
    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))
    
    # compute the distance matrix: (x - y)^2 = x^2 - 2*x*y + y^2
    dxx = rx.t() - 2 * xx + rx # Used for A in (1)
    dyy = ry.t() - 2 * yy + ry # Used for B in (1)
    dxy = rx.t() - 2 * xy + ry # Used for C in (1)
    
#     XX, YY, XY = (torch.zeros(xx.shape).to(device),
#                   torch.zeros(xx.shape).to(device),
#                   torch.zeros(xx.shape).to(device))
    
    if kernel == "multiscale":
#         bandwidth_range = [0.2, 0.5, 0.9, 1.3]
#         for a in bandwidth_range:
#             XX += a**2 * (a**2 + dxx)**-1
#             YY += a**2 * (a**2 + dyy)**-1
#             XY += a**2 * (a**2 + dxy)**-1
        XX = a**2 * (a**2 + dxx)**-1
        YY = a**2 * (a**2 + dyy)**-1
        XY = a**2 * (a**2 + dxy)**-1
            
    if kernel == "rbf":
#         bandwidth_range = [10, 15, 20, 50]
#         for a in bandwidth_range:
#             XX += torch.exp(-0.5*dxx/a)
#             YY += torch.exp(-0.5*dyy/a)
#             XY += torch.exp(-0.5*dxy/a)
        XX = torch.exp(-0.5*dxx/a)
        YY = torch.exp(-0.5*dyy/a)
        XY = torch.exp(-0.5*dxy/a)

    m = x.size()[0]
    beta = 1/(m*(m-1))
    gamma = 2/(m*m)
    
#     return torch.mean(XX - 2 * XY + YY)
    return beta * (torch.sum(XX) + torch.sum(YY)) - gamma * torch.sum(XY)

In [50]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from scipy.stats import dirichlet 
from torch.distributions.multivariate_normal import MultivariateNormal 


m = 20 # sample size
x_mean = torch.zeros(2)+1
y_mean = torch.zeros(2)
x_cov = 2*torch.eye(2) # IMPORTANT: Covariance matrices must be positive definite
y_cov = 3*torch.eye(2) - 1

px = MultivariateNormal(x_mean, x_cov)
qy = MultivariateNormal(y_mean, y_cov)
x = px.sample([m]).to(device)
y = qy.sample([m]).to(device)

result = MMD(x, y, kernel="multiscale", a = 1.3)

print(f"MMD result of X and Y is {result.item()}")

MMD result of X and Y is 0.1570422649383545


In [34]:
## correct implementation for mmd found here:
## https://discuss.pytorch.org/t/maximum-mean-discrepancy-mmd-and-radial-basis-function-rbf/1875/2

In [36]:
x.size()[0]

20

In [42]:
XX

NameError: name 'XX' is not defined