## Gaussian kernel
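For samples $x_i, y_j \in \mathbb{R}^d$:
$$k(x_i,y_j)=\exp\left(-\frac{\|x_i-y_j\|^2}{\text{bandwidth}}\right)=\exp\left(-\frac{\sum_{\ell=1}^{d}(x_{i\ell}-y_{j\ell})^2}{\text{bandwidth}}\right)$$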
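A minimal NumPy sketch of this kernel evaluated between two sample sets. NumPy stands in for the PyTorch used in the notebook, and the helper name `gaussian_kernel_matrix` is hypothetical. With this convention, `bandwidth` plays the role of $2\sigma^2$ in the more common form $\exp(-\|x-y\|^2/2\sigma^2)$.

```python
import numpy as np


# Hypothetical helper for illustration (NumPy stand-in for the notebook's PyTorch code).
def gaussian_kernel_matrix(X, Y, bandwidth):
    """Return K with K[i, j] = exp(-||X[i] - Y[j]||^2 / bandwidth).

    X has shape (n, d) and Y has shape (m, d); K has shape (n, m).
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    # Broadcast to (n, m, d) and sum squared differences over the feature axis.
    sq_dist = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-sq_dist / bandwidth)


X = [[0.0, 0.0], [1.0, 1.0]]
Y = [[0.0, 0.0]]
K = gaussian_kernel_matrix(X, Y, bandwidth=1.0)  # [[1.0], [exp(-2)]] since ||(1, 1)||^2 = 2
```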

## MMD: distances between mean embeddings of features.

Let $P$ and $Q$ be distributions over a set $X$.
The MMD is defined by a feature map $\phi:X\to H$, where $H$ is a reproducing kernel Hilbert space (RKHS):
$$\text{MMD}(P,Q)=\left\|\mathbb{E}_{X \sim P}[\phi(X)]-\mathbb{E}_{Y \sim Q}[\phi(Y)]\right\|_H$$

Simplest: $X=H=\mathbb{R}^d$ and $\phi(x)=x$.

$$\text{MMD}(P,Q)=\|\mathbb{E}_{X\sim P}[X]-\mathbb{E}_{Y\sim Q}[Y]\|_{\mathbb{R}^d}=\|\mu_P-\mu_Q\|_{\mathbb{R}^d}.$$
Matching distributions with this MMD only matches their means.
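A small numeric check of this linear case, assuming NumPy and a hypothetical helper `linear_mmd`: two samples with the same mean but different spread get an MMD of exactly zero.

```python
import numpy as np


# Hypothetical helper for illustration: MMD with phi(x) = x.
def linear_mmd(X, Y):
    """Euclidean distance between the sample means of X (n, d) and Y (m, d)."""
    X = np.asarray(X, dtype=float).reshape(len(X), -1)  # 1-D input -> (n, 1)
    Y = np.asarray(Y, dtype=float).reshape(len(Y), -1)
    return float(np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0)))


# Same mean (0), different spread: the linear MMD cannot tell them apart.
print(linear_mmd([-1.0, 1.0], [-2.0, 2.0]))  # 0.0
```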

Stronger: $X=\mathbb{R}$ and $\phi(x)=(x, x^2)$, which compares the first two moments:
$$\text{MMD}(P,Q)=\sqrt{(\mathbb{E}X-\mathbb{E}Y)^2 +(\mathbb{E}X^2-\mathbb{E}Y^2)^2}$$
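The same kind of check for $\phi(x)=(x,x^2)$, again a NumPy sketch with a hypothetical helper `moment_mmd`: two samples with equal means but different second moments are now separated.

```python
import numpy as np


# Hypothetical helper for illustration: MMD with phi(x) = (x, x^2) on scalar samples.
def moment_mmd(X, Y):
    """sqrt((E X - E Y)^2 + (E X^2 - E Y^2)^2) estimated from samples."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    d1 = X.mean() - Y.mean()                # difference of first moments
    d2 = (X ** 2).mean() - (Y ** 2).mean()  # difference of second moments
    return float(np.sqrt(d1 ** 2 + d2 ** 2))


print(moment_mmd([-1.0, 1.0], [-2.0, 2.0]))  # 3.0
```

Here $\mathbb{E}X=\mathbb{E}Y=0$, $\mathbb{E}X^2=1$ and $\mathbb{E}Y^2=4$, so the MMD is $\sqrt{0^2+3^2}=3$.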

Much stronger: let $\phi$ map into a general RKHS and apply the kernel trick to compute the MMD. For characteristic kernels, including the Gaussian kernel, the MMD is zero if and only if the distributions are identical.

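Specifically, letting $k(x,y)=\langle \phi(x), \phi(y)\rangle_H$ and expanding the squared norm, you get
$$\text{MMD}^2(P,Q)=\mathbb{E}_{X, X' \sim P}\,k(X, X')+\mathbb{E}_{Y, Y'\sim Q}\,k(Y, Y')-2\,\mathbb{E}_{X\sim P, Y\sim Q}\,k(X,Y),$$
where $X, X'$ are independent draws from $P$ and $Y, Y'$ are independent draws from $Q$.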
Each expectation can be estimated straightforwardly by averaging the kernel over pairs of samples.
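A minimal sketch of the sample estimate, assuming NumPy, a single Gaussian kernel with a fixed bandwidth, and hypothetical helper names. Replacing each expectation by an average over all sample pairs gives the biased (V-statistic) estimate of $\text{MMD}^2$, the same form of estimate that the loss at the end of this section uses.

```python
import numpy as np


# Hypothetical helpers for illustration (NumPy, single fixed-bandwidth Gaussian kernel).
def gaussian_kernel_matrix(A, B, bandwidth):
    """K[i, j] = exp(-||A[i] - B[j]||^2 / bandwidth)."""
    sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-sq_dist / bandwidth)


def mmd2_biased(X, Y, bandwidth=1.0):
    """Biased (V-statistic) estimate of MMD^2 from samples X (n, d) and Y (m, d).

    Each expectation is replaced by an average over all pairs, including i == j.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    k_xx = gaussian_kernel_matrix(X, X, bandwidth).mean()  # estimates E k(X, X')
    k_yy = gaussian_kernel_matrix(Y, Y, bandwidth).mean()  # estimates E k(Y, Y')
    k_xy = gaussian_kernel_matrix(X, Y, bandwidth).mean()  # estimates E k(X, Y)
    return float(k_xx + k_yy - 2.0 * k_xy)


rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(200, 2))
Y = rng.normal(0.0, 2.0, size=(200, 2))  # same mean, larger variance
print(mmd2_biased(X, X))  # zero for identical samples
print(mmd2_biased(X, Y))  # positive
```

Dropping the diagonal ($i=j$) terms from the $X$-$X$ and $Y$-$Y$ averages gives the unbiased (U-statistic) estimate instead; that estimate can come out slightly negative.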


An alternative characterization of the MMD:
$$\text{MMD}(P,Q)=\sup_{f\in H:\|f\|_H\leq 1}\left|\mathbb{E}_{X\sim P}[f(X)]-\mathbb{E}_{Y\sim Q}[f(Y)]\right|.$$
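In an RKHS the supremum is attained by the witness function $f^*=(\mu_P-\mu_Q)/\|\mu_P-\mu_Q\|_H$, with $\mu_P=\mathbb{E}_{X\sim P}[\phi(X)]$, so this characterization gives the same value as the mean-embedding definition. A NumPy sketch with empirical samples (helper names hypothetical): the unnormalized witness $f(t)=\frac{1}{n}\sum_i k(x_i,t)-\frac{1}{m}\sum_j k(y_j,t)$ has a mean gap equal to the biased $\text{MMD}^2$ and an RKHS norm equal to its square root, so dividing by the norm recovers the MMD.

```python
import numpy as np


# Hypothetical helpers for illustration (NumPy, Gaussian kernel, bandwidth 1).
def gaussian_kernel_matrix(A, B, bandwidth):
    """K[i, j] = exp(-||A[i] - B[j]||^2 / bandwidth)."""
    sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-sq_dist / bandwidth)


def witness(X, Y, T, bandwidth=1.0):
    """Unnormalized empirical witness f(t) = mean_i k(x_i, t) - mean_j k(y_j, t) at rows of T."""
    return (gaussian_kernel_matrix(X, T, bandwidth).mean(axis=0)
            - gaussian_kernel_matrix(Y, T, bandwidth).mean(axis=0))


rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(100, 1))
Y = rng.normal(1.0, 1.0, size=(100, 1))

# E_P[f] - E_Q[f] for the unnormalized witness equals the biased MMD^2 estimate ...
gap = witness(X, Y, X).mean() - witness(X, Y, Y).mean()
# ... and ||f||_H = sqrt(MMD^2), so f / ||f||_H attains the sup with value MMD.
mmd = np.sqrt(gap)
print(mmd)
```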

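The implementation below is the multi-kernel MMD (MK-MMD) used in DAN. `kernel_num` is the number of Gaussian kernels, and `kernel_mul` is the ratio between the bandwidths of consecutive kernels. The base bandwidth is `fix_sigma` if given, otherwise the mean squared distance between distinct samples in the pooled source and target batch; `bandwidth /= kernel_mul ** (kernel_num // 2)` turns it into the smallest bandwidth, `min_bandwidth`. Then one bandwidth is computed per kernel, `kernel_num` (5 by default) in total: `bandwidth_list = [min_bandwidth, min_bandwidth * kernel_mul, min_bandwidth * kernel_mul**2, ...]`. The resulting kernels are summed with equal weights.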
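The bandwidth schedule can be checked in isolation. This NumPy sketch mirrors the bandwidth lines of the kernel method in the code; the helper names `base_bandwidth` and `bandwidth_list` are hypothetical.

```python
import numpy as np


# Hypothetical helper: mean squared distance over distinct pairs of rows.
def base_bandwidth(total):
    """Sum of pairwise squared distances divided by n^2 - n (the diagonal is 0)."""
    total = np.asarray(total, dtype=float)
    n = total.shape[0]
    sq_dist = ((total[None, :, :] - total[:, None, :]) ** 2).sum(axis=2)
    return float(sq_dist.sum() / (n ** 2 - n))


# Hypothetical helper: geometric grid of kernel_num bandwidths with ratio kernel_mul.
def bandwidth_list(bandwidth, kernel_mul=2.0, kernel_num=5):
    min_bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
    return [min_bandwidth * kernel_mul ** i for i in range(kernel_num)]


# Squared distances 1, 4, 1 (each pair counted twice): 12 / 6 = 2.0
b = base_bandwidth([[0.0], [1.0], [2.0]])
print(bandwidth_list(b))  # [0.5, 1.0, 2.0, 4.0, 8.0]
```

With `kernel_mul=2` and `kernel_num=5`, the base bandwidth sits in the middle of the grid, two factors of 2 above the smallest and below the largest.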

A single kernel also works, but MK-MMD can achieve better performance, as shown in DAN.

```python
import torch
import torch.nn as nn


class MMD_loss(nn.Module):
    """Multi-kernel MMD (MK-MMD) loss built from a bank of Gaussian kernels."""

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super().__init__()
        self.kernel_num = kernel_num  # number of Gaussian kernels
        self.kernel_mul = kernel_mul  # ratio between consecutive bandwidths
        self.fix_sigma = fix_sigma    # optional fixed base bandwidth

    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        n_samples = source.size(0) + target.size(0)
        total = torch.cat([source, target], dim=0)

        # Pairwise squared Euclidean distances between all rows of `total`.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean squared distance over off-diagonal pairs (the diagonal is 0),
            # detached so that no gradient flows through the bandwidth.
            bandwidth = L2_distance.detach().sum() / (n_samples ** 2 - n_samples)
        # Smallest bandwidth, so the base bandwidth lies roughly mid-grid.
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        # Equal-weight sum of the Gaussian kernels.
        kernel_val = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        n_source = source.size(0)
        kernels = self.gaussian_kernel(source, target, kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        XX = kernels[:n_source, :n_source]  # source-source block
        YY = kernels[n_source:, n_source:]  # target-target block
        XY = kernels[:n_source, n_source:]  # source-target block
        YX = kernels[n_source:, :n_source]  # target-source block
        # Biased estimate of MMD^2. Averaging each block separately equals
        # torch.mean(XX + YY - XY - YX) for equal batch sizes and also works
        # when the source and target batch sizes differ.
        loss = XX.mean() + YY.mean() - XY.mean() - YX.mean()
        return loss
```