Skip to content

Commit 7f02357

Browse files
committed
Add script to evaluate W1 distance
1 parent 31d31a5 commit 7f02357

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- jaxlib=0.4.23
1818
- flax=0.8.3
1919
- notebook=7.0.8
20+
- POT=0.9.4
2021
- pip:
2122
- dmff @ git+https://github.com/deepmodeling/DMFF@v1.0.0
2223
- rdkit==2023.3.3

eval/evaluate_wasserstein.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import ot
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
from tqdm import trange
5+
import os
6+
7+
if __name__ == '__main__':
8+
two_way_shooting = np.load('./out/baselines/mueller/paths-two-way-shooting.npy', allow_pickle=True)
9+
two_way_shooting = two_way_shooting.squeeze(2)
10+
ours = np.load('./out/toy/mueller_single_gaussian/stochastic_paths.npy')
11+
12+
assert two_way_shooting.shape == ours.shape, f'Shapes do not match: {two_way_shooting.shape} vs {ours.shape}'
13+
14+
savedir = './out/evaluation/mueller/'
15+
os.makedirs(savedir, exist_ok=True)
16+
17+
wasserstein = []
18+
for t in trange(ours.shape[1]):
19+
cur_ground_truth = np.array(two_way_shooting[:, t, :], dtype=np.float64)
20+
cur_ours = np.array(ours[:, t, :], dtype=np.float64)
21+
22+
M = ot.dist(cur_ground_truth, cur_ours, metric='euclidean')
23+
w1 = ot.emd2([], [], M)
24+
wasserstein.append(w1)
25+
26+
wasserstein = np.array(wasserstein)
27+
print('Median Wasserstein:', np.median(wasserstein))
28+
print('Mean Wasserstein:', np.mean(wasserstein))
29+
print('Std Wasserstein:', np.std(wasserstein))
30+
print('Max Wasserstein:', np.max(wasserstein))
31+
print('Min Wasserstein:', np.min(wasserstein))
32+
33+
plt.plot(wasserstein)
34+
plt.xlabel(r'$t$')
35+
plt.ylabel('Wasserstein W1 Distance')
36+
plt.savefig(f'{savedir}/wasserstein.pdf', bbox_inches='tight')
37+
plt.clf()
38+
39+
print()

0 commit comments

Comments
 (0)