forked from austinyi/NetOTC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fgw_dist.m
63 lines (57 loc) · 1.54 KB
/
fgw_dist.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
%%
% Fused Gromov-Wasserstein distance.
%
function [FGW, coupling] = fgw_dist(M, C1, C2, mu1, mu2, q, alpha, n_iter)
% FGW_DIST  Fused Gromov-Wasserstein (FGW) distance via Frank-Wolfe.
%
%   [FGW, coupling] = fgw_dist(M, C1, C2, mu1, mu2, q, alpha)
%   [FGW, coupling] = fgw_dist(M, C1, C2, mu1, mu2, q, alpha, n_iter)
%
%   Minimizes, over couplings P produced by the algorithm below,
%       (1-alpha) * sum_{i,j} M(i,j)^q * P(i,j)
%     +   alpha   * sum_{i,j,k,l} |C1(i,k) - C2(j,l)|^q * P(i,j) * P(k,l)
%   using a conditional-gradient (Frank-Wolfe) scheme with a grid line search.
%
%   Inputs (shapes inferred from indexing; confirm against callers):
%     M      - m x n feature cost matrix.
%     C1     - m x m structure matrix of the first object.
%     C2     - n x n structure matrix of the second object.
%     mu1    - first marginal; presumably an m x 1 column vector (the
%              initial coupling is formed as mu1 .* mu2').
%     mu2    - second marginal; presumably an n x 1 column vector.
%     q      - exponent applied to M and to |C1(i,k) - C2(j,l)|.
%     alpha  - weight of the structure term; the feature term gets (1-alpha).
%     n_iter - (optional) number of Frank-Wolfe iterations. Default 100.
%
%   Outputs:
%     FGW      - value of the objective above at the final coupling.
%     coupling - m x n coupling after the last iteration.
%
%   Requires computeot_lp (project function) to solve the linear OT
%   subproblem; it is called as computeot_lp(cost', mu1, mu2') and its
%   first output is reshaped to m x n below.

if nargin < 8 || isempty(n_iter)
    n_iter = 100;
end

% The feature cost does not depend on the coupling: raise it to q once.
Mq = M.^q;

% Start from the independent (product) coupling.
coupling = mu1 .* mu2';
[m, n] = size(coupling);

% Step sizes tried by the grid line search.
tau_vec = 0:0.1:1;

for iter = 1:n_iter %#ok<FXUP>
    % Structure tensor product at the current coupling.
    T_cur = fgw_tensor(C1, C2, coupling, q);

    % Gradient of the objective. The structure term alpha*<T(P),P> is
    % quadratic in P; its gradient is 2*alpha*T(P) when C1 and C2 are
    % symmetric (assumed here, as in the original code; TODO confirm).
    G = (1-alpha)*Mq + 2*alpha*T_cur;

    % Linear minimization oracle: OT problem with cost G.
    [pi_new, ~] = computeot_lp(G', mu1, mu2');
    P_new = reshape(pi_new', n, m)';

    % Grid line search along coupling -> P_new. T(.) is linear in its
    % coupling argument, so T at the interpolated coupling equals the same
    % interpolation of T_cur and T_new. This avoids recomputing the
    % O(m^2 n^2) tensor product for every grid point.
    T_new = fgw_tensor(C1, C2, P_new, q);
    losses = zeros(size(tau_vec));
    for t = 1:numel(tau_vec)
        tau = tau_vec(t);
        P_tau = (1-tau)*coupling + tau*P_new;
        T_tau = (1-tau)*T_cur + tau*T_new;
        losses(t) = fgw_objective(Mq, T_tau, P_tau, alpha);
    end
    % min returns the first index on ties, i.e. the smallest such tau.
    [~, tau_idx] = min(losses);
    tau = tau_vec(tau_idx);

    % tau == 0 leaves the coupling unchanged. Every later iteration would
    % then repeat the same step (assuming computeot_lp is deterministic),
    % so stop early.
    if tau == 0
        break;
    end

    coupling = (1-tau)*coupling + tau*P_new;
end

% Evaluate the objective directly at the final coupling.
FGW = fgw_objective(Mq, fgw_tensor(C1, C2, coupling, q), coupling, alpha);
end

function T = fgw_tensor(C1, C2, P, q)
% FGW_TENSOR  Structure tensor product of the coupling P.
%   T(i,j) = sum_{k,l} |C1(i,k) - C2(j,l)|^q * P(k,l), with size(T) == size(P).
%   The (i,k) loop keeps memory at O(n^2) instead of building the full
%   m*n x m*n tensor.
[m, n] = size(P);
T = zeros(m, n);
for i = 1:m
    for k = 1:m
        % L(j,l) = |C1(i,k) - C2(j,l)|^q for all (j,l) at once.
        L = abs(C1(i,k) - C2).^q;
        T(i,:) = T(i,:) + (L * P(k,:).').';
    end
end
end

function val = fgw_objective(Mq, T, P, alpha)
% FGW_OBJECTIVE  FGW objective at coupling P, given Mq = M.^q and the
%   precomputed structure tensor product T = fgw_tensor(C1, C2, P, q).
val = (1-alpha)*sum(Mq(:).*P(:)) + alpha*sum(T(:).*P(:));
end