-
Notifications
You must be signed in to change notification settings - Fork 1
/
kernel.py
60 lines (50 loc) · 1.76 KB
/
kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch as th
from torch.nn.functional import relu
# Small positive constant used as the default `min_sigma` of
# `kernel_from_distance_matrix`, i.e. the lower bound applied to the kernel
# width before it is used as a divisor.
EPSILON = 1E-9
# Attribution note. It is a bare string expression that follows the imports, so
# it is not the module docstring and has no effect at runtime.
"Inspired by the implementation in https://github.com/DanielTrosten/mvc"
def kernel_from_distance_matrix(dist, rel_sigma, min_sigma=EPSILON):
    """
    Turn a matrix of distances into a Gaussian kernel matrix.

    The kernel width (``sigma2``) is derived from the data: it is `rel_sigma`
    times the median of the non-negative distances, bounded below by
    `min_sigma`. The width is detached, so no gradient flows through it.

    :param dist: Distance matrix
    :type dist: th.Tensor
    :param rel_sigma: Multiplication factor for the sigma hyperparameter
    :type rel_sigma: float
    :param min_sigma: Lower bound for sigma, used for numerical stability
    :type min_sigma: float
    :return: Kernel matrix, ``exp(-dist / (2 * sigma2))`` evaluated element-wise
    :rtype: th.Tensor
    """
    # Floating point errors can leave small negative entries in `dist`;
    # replace those with zero before doing anything else.
    non_negative = relu(dist)

    # Median-based width, treated as a constant during backpropagation.
    width = (rel_sigma * th.median(non_negative)).detach()

    # Fall back to `min_sigma` when the width is too small to divide by safely.
    floor = width.new_tensor(min_sigma)
    width = th.where(width < min_sigma, floor, width)

    return th.exp(-non_negative / (2 * width))
def vector_kernel(x, rel_sigma=0.15):
    """
    Build a kernel matrix that compares every row of `x` with every other row.

    The pairwise values come from `cdist(x, x)` and are converted to a kernel
    by `kernel_from_distance_matrix`.

    :param x: Input matrix, one sample per row
    :type x: th.Tensor
    :param rel_sigma: Multiplication factor for the sigma hyperparameter
    :type rel_sigma: float
    :return: Kernel matrix
    :rtype: th.Tensor
    """
    pairwise = cdist(x, x)
    return kernel_from_distance_matrix(pairwise, rel_sigma)
def cdist(X, Y):
    """
    Pairwise squared Euclidean distance between rows of X and rows of Y.

    Uses the expansion ``||x||^2 - 2 * <x, y> + ||y||^2``. No square root is
    taken and the result is not clamped, so rounding errors can produce
    slightly negative entries.

    :param X: First input matrix, one vector per row
    :type X: th.Tensor
    :param Y: Second input matrix, one vector per row
    :type Y: th.Tensor
    :return: Matrix whose entry (i, j) is the squared distance between row i
        of X and row j of Y
    :rtype: th.Tensor
    """
    # Inner products between every row of X and every row of Y.
    cross = th.matmul(X, Y.t())

    # Squared norms: a column for the rows of X, a row for the rows of Y, so
    # that broadcasting fills in the full matrix.
    x_sq = X.pow(2).sum(dim=1, keepdim=True)
    y_sq = Y.pow(2).sum(dim=1, keepdim=True).t()

    return x_sq - cross * 2 + y_sq