In [21]:
import numpy as np
import torch
class distribution:
    '''
    Namespace of static helpers for measuring the distance between two sample
    sets. Currently only the multi-kernel Gaussian (RBF) MMD is implemented.
    '''

    def __init__(self):
        pass

    @staticmethod
    def analyze(X, Y, evaluation='mmd', **kwargs):
        '''
        Compute a distance between the sample sets X and Y.

        Params:
            X, Y: sample sets, in any form accepted by distribution.mmd_rbf.
            evaluation: name of the distance; only 'mmd' is supported.
            **kwargs: forwarded to the selected metric (for 'mmd':
                kernel_mul, kernel_num, fix_sigma).
        Return:
            The distance as a 0-dim tensor.
        Raises:
            ValueError: if `evaluation` is not a supported metric name
                (instead of silently returning None).
        '''
        if evaluation == 'mmd':
            return distribution.mmd_rbf(X, Y, **kwargs)
        raise ValueError(f"unsupported evaluation {evaluation!r}; expected 'mmd'")

    @staticmethod
    def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        '''
        Build the summed multi-bandwidth Gaussian kernel matrix over the rows of
        `source` stacked on top of the rows of `target`.

        Params:
            source: 2-D tensor; rows are samples, columns are features.
            target: 2-D tensor with the same number of columns as `source`.
            kernel_mul: ratio between consecutive bandwidths. The bandwidths are
                spread around a base bandwidth bw, e.g. bw/kernel_mul, bw,
                bw*kernel_mul.
            kernel_num: number of Gaussian kernels (bandwidths) summed; must be >= 1.
            fix_sigma: if truthy, used as the base bandwidth instead of the
                data-driven mean squared pairwise distance. `kernel_num` kernels
                are still summed around it; pass kernel_num=1 for a single kernel.
        Return:
            Tensor of shape (n + m, n + m): the sum over all bandwidths bw of
            exp(-||x_i - x_j||^2 / bw). The kernels are summed, not averaged.
        Raises:
            ValueError: if kernel_num < 1.
        '''
        if kernel_num < 1:
            raise ValueError(f"kernel_num must be >= 1, got {kernel_num}")

        # Stack the rows of both sets: shape (n + m, d).
        total = torch.cat([source, target], dim=0)
        n_samples = int(total.size(0))

        # Broadcasting gives diff[i, j, :] = total[j] - total[i]; summing the
        # squares over the feature axis yields the squared L2 distance i <-> j.
        diff = total.unsqueeze(0) - total.unsqueeze(1)
        L2_distance_square = (diff ** 2).sum(2)

        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Data-driven base bandwidth: mean squared distance over the
            # n_samples * (n_samples - 1) off-diagonal pairs.
            pair_count = n_samples ** 2 - n_samples
            bandwidth = torch.sum(L2_distance_square).item() / pair_count if pair_count else 0.0
            if bandwidth == 0:
                # Every distance is 0 (all points coincide, or a single point),
                # so exp(0) == 1 for any positive bandwidth; avoid 0/0 -> NaN.
                bandwidth = 1.0

        # Centre kernel_num bandwidths on the base value, e.g. base 1 with
        # kernel_mul=2, kernel_num=5 gives [0.25, 0.5, 1, 2, 4].
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

        kernel_val = [torch.exp(-L2_distance_square / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    @staticmethod
    def _as_2d_tensor(data):
        '''Convert a tensor or array-like into a 2-D tensor of shape (num_samples, -1).'''
        if isinstance(data, torch.Tensor):
            # Keep tensors as tensors so the device, dtype and autograd graph
            # survive (np.array() fails on CUDA tensors and on tensors that
            # require grad).
            return data.reshape(data.shape[0], -1)
        # np.array copies, so torch.from_numpy never sees negative strides
        # and never aliases the caller's buffer.
        array = np.array(data)
        return torch.from_numpy(array.reshape(array.shape[0], -1))

    @staticmethod
    def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        '''
        Compute the Gaussian-kernel MMD between the source and target samples.

        Params:
            source: tensor or array-like; axis 0 indexes samples, and all
                remaining axes are flattened into features.
            target: same convention as `source` (the number of samples may differ).
            kernel_mul, kernel_num, fix_sigma: see distribution.guassian_kernel.
        Return:
            loss: 0-dim tensor, mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts).
            This is the biased estimate of the squared MMD (diagonal kernel terms
            are included). It is differentiable when the inputs are tensors that
            require grad.
        '''
        source = distribution._as_2d_tensor(source)
        target = distribution._as_2d_tensor(target)
        source_num = int(source.size(0))

        kernels = distribution.guassian_kernel(source, target,
            kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

        # Split the kernel matrix into its source/target blocks. Each block is
        # averaged over its own size (n*n, m*m, n*m), so n != m is handled.
        XX = torch.mean(kernels[:source_num, :source_num])
        YY = torch.mean(kernels[source_num:, source_num:])
        XY = torch.mean(kernels[:source_num, source_num:])
        YX = torch.mean(kernels[source_num:, :source_num])

        return XX + YY - XY - YX

In [2]:
import numpy as np
import torch

# Sanity check: torch.from_numpy keeps the (rows, cols) shape of a 2-D array.
a = np.arange(1, 7).reshape(2, 3)
torch.from_numpy(a).shape

torch.Size([2, 3])

In [19]:
# The values 1, 2, 3 as a (3, 1) column tensor.
a = torch.from_numpy(np.array([[1], [2], [3]]))