|
| 1 | +import ot |
| 2 | +import numpy as np |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +from tqdm import trange |
| 5 | +import os |
| 6 | + |
| 7 | +if __name__ == '__main__': |
| 8 | + two_way_shooting = np.load('./out/baselines/mueller/paths-two-way-shooting.npy', allow_pickle=True) |
| 9 | + two_way_shooting = two_way_shooting.squeeze(2) |
| 10 | + ours = np.load('./out/toy/mueller_single_gaussian/stochastic_paths.npy') |
| 11 | + |
| 12 | + assert two_way_shooting.shape == ours.shape, f'Shapes do not match: {two_way_shooting.shape} vs {ours.shape}' |
| 13 | + |
| 14 | + savedir = './out/evaluation/mueller/' |
| 15 | + os.makedirs(savedir, exist_ok=True) |
| 16 | + |
| 17 | + wasserstein = [] |
| 18 | + for t in trange(ours.shape[1]): |
| 19 | + cur_ground_truth = np.array(two_way_shooting[:, t, :], dtype=np.float64) |
| 20 | + cur_ours = np.array(ours[:, t, :], dtype=np.float64) |
| 21 | + |
| 22 | + M = ot.dist(cur_ground_truth, cur_ours, metric='euclidean') |
| 23 | + w1 = ot.emd2([], [], M) |
| 24 | + wasserstein.append(w1) |
| 25 | + |
| 26 | + wasserstein = np.array(wasserstein) |
| 27 | + print('Median Wasserstein:', np.median(wasserstein)) |
| 28 | + print('Mean Wasserstein:', np.mean(wasserstein)) |
| 29 | + print('Std Wasserstein:', np.std(wasserstein)) |
| 30 | + print('Max Wasserstein:', np.max(wasserstein)) |
| 31 | + print('Min Wasserstein:', np.min(wasserstein)) |
| 32 | + |
| 33 | + plt.plot(wasserstein) |
| 34 | + plt.xlabel(r'$t$') |
| 35 | + plt.ylabel('Wasserstein W1 Distance') |
| 36 | + plt.savefig(f'{savedir}/wasserstein.pdf', bbox_inches='tight') |
| 37 | + plt.clf() |
| 38 | + |
| 39 | + print() |
0 commit comments