# Debugging GED and Sinkhorn training

In [2]:
%matplotlib inline
import numpy as np
import torch
from torch.nn import MSELoss
from geomloss import SamplesLoss
from matplotlib import pyplot as plt
import sys

sys.path.append("../..")

from experiments.utils.models import create_vector_transform
from manifold_flow.flows import ManifoldFlow
from manifold_flow.training.losses import (
    make_generalized_energy_distance,
    make_sinkhorn_divergence,
)


## Load models

In [3]:
def make_model(filename):
    """Build a ManifoldFlow and populate it with weights stored at `filename`.

    The architecture is fixed: a 3-dimensional data space with a 2-dimensional
    latent space, an outer and an inner vector transform that share the same
    settings, and `pie_epsilon=0.01`. The state dict is loaded onto the CPU.

    Parameters
    ----------
    filename : str
        Path to a file readable by `torch.load` that holds a state dict
        matching this architecture.

    Returns
    -------
    ManifoldFlow
        The model after `load_state_dict` has been applied.
    """
    # Keyword settings shared by the outer and the inner transform.
    transform_settings = dict(
        linear_transform_type="permutation",
        base_transform_type="rq-coupling",
        context_features=None,
        dropout_probability=0.,
    )
    # The two positional arguments are forwarded unchanged; the first matches
    # data_dim / latent_dim below (presumably the transform dimension -- TODO confirm).
    outer = create_vector_transform(3, 5, **transform_settings)
    inner = create_vector_transform(2, 5, **transform_settings)

    flow = ManifoldFlow(
        data_dim=3,
        latent_dim=2,
        outer_transform=outer,
        inner_transform=inner,
        apply_context_to_outer=False,
        pie_epsilon=0.01,
    )

    # map_location forces all stored tensors onto the CPU.
    state_dict = torch.load(filename, map_location=torch.device("cpu"))
    flow.load_state_dict(state_dict)
    return flow

In [4]:
# Load the two checkpoints compared in this notebook. The file names differ
# only by the "_ged" tag; the variable names label them as the
# Sinkhorn-trained and the GED-trained model.
mf_sinkhorn = make_model("../data/models/gamf_2_spherical_gaussian_2_3_0.010_small_largebs.pt")
mf_ged = make_model("../data/models/gamf_2_spherical_gaussian_2_3_0.010_small_ged_largebs.pt")

## Debugging

In [70]:
# geomloss SamplesLoss with the "sinkhorn" loss; every other option is left
# at the library default.
sinkhorn = SamplesLoss("sinkhorn")

In [123]:
# Draw one batch of true samples from the training file and one batch of
# generated samples from the Sinkhorn-trained model.
batchsize = 1000
x_true_filename="../data/samples/spherical_gaussian/spherical_gaussian_2_3_0.010_x_train.npy"
all_x_true = np.load(x_true_filename)
# `np.int` was removed in NumPy 1.24; the builtin `int` yields the same dtype.
all_indices = np.arange(0, len(all_x_true), dtype=int)
# np.random.choice samples with replacement by default, so rows may repeat.
x_true = torch.tensor(all_x_true[np.random.choice(all_indices, batchsize)], dtype=torch.float)
# Row-permuted copy of the true batch.
x_true_shuffled = x_true[torch.randperm(x_true.size(0))]

x_gen = mf_sinkhorn.sample(n=batchsize)
# Row-permuted copy of the generated batch.
x_gen_shuffled = x_gen[torch.randperm(x_gen.size(0))]

In [124]:
# Keep at most the first 1000 rows of each batch. x_true has `batchsize` rows,
# so with batchsize = 1000 the true batch is unchanged (x_gen presumably has
# the same length -- TODO confirm what model.sample returns).
x_true_ = x_true[0:1000]
x_gen_ = x_gen[0:1000]

In [128]:
# Dump both batches to .npy files in the current working directory so the
# Sinkhorn evaluation can be repeated from the files alone.
np.save("x0.npy", x_true_.detach().numpy())
np.save("x1.npy", x_gen_.detach().numpy())

In [125]:
# Sinkhorn loss between the true and the generated batch. The recorded output
# for this cell is negative (-0.0224).
sinkhorn(x_true_, x_gen_)

tensor(-0.0224, grad_fn=<SelectBackward>)

In [126]:
# Sanity check: the true batch compared with itself. The recorded output is 0.
sinkhorn(x_true_, x_true_)

tensor(0.)

In [127]:
# Sanity check: the generated batch compared with itself. The recorded output is 0.
sinkhorn(x_gen_, x_gen_)

tensor(0., grad_fn=<SelectBackward>)

In [133]:
# Record which geomloss release produced the values in this notebook.
import geomloss
geomloss.__version__

'0.2.3'

In [7]:
# Self-contained check: it needs only the two saved .npy files and its own
# imports, and evaluates the Sinkhorn loss with scaling=0.99.
import numpy as np
import torch
from geomloss import SamplesLoss

# Read the saved batches back as torch tensors.
x0, x1 = (torch.from_numpy(np.load(path)) for path in ("x0.npy", "x1.npy"))

# NOTE: this rebinds the notebook-level name `sinkhorn`.
sinkhorn = SamplesLoss(loss="sinkhorn", scaling=0.99)

print(sinkhorn(x0, x1).item())

0.038145098835229874


## Sinkhorn divergence of different batches on the data set

In [43]:
def calculate_losses(
    model,
    batchsize=1000,
    blur=0.05,
    x_true_filename="../data/samples/spherical_gaussian/spherical_gaussian_2_3_0.010_x_train.npy", 
    tests=25
):
    """Evaluate the Sinkhorn and GED losses of `model` on repeated random batches.

    For each of `tests` repetitions, a batch of true samples is drawn from the
    array stored at `x_true_filename`, a batch of the same size is drawn from
    `model.sample`, and both losses are evaluated on the pair.

    Parameters
    ----------
    model : object
        Must provide `sample(n=...)` returning a torch tensor.
    batchsize : int
        Number of true and of generated samples per repetition.
    blur : float
        Accepted for interface compatibility but currently NOT used: both loss
        callables are built with their default arguments.
    x_true_filename : str
        Path of a .npy file holding the true samples.
    tests : int
        Number of repetitions.

    Returns
    -------
    tuple of four floats
        (mean Sinkhorn loss, std of Sinkhorn loss, mean GED loss, std of GED
        loss) over the repetitions. The std values come from `np.std` with its
        default settings; they are not standard errors of the mean.
    """
    all_x_true = np.load(x_true_filename)
    # `np.int` was removed in NumPy 1.24; the builtin `int` yields the same dtype.
    all_indices = np.arange(0, len(all_x_true), dtype=int)
    sinkhorn = make_sinkhorn_divergence()
    ged = make_generalized_energy_distance()
    
    sinkhorn_losses = []
    ged_losses = []
    for _ in range(tests):
        # np.random.choice samples with replacement by default, so rows may repeat.
        x_true = torch.tensor(all_x_true[np.random.choice(all_indices, batchsize)], dtype=torch.float)
        x_gen = model.sample(n=batchsize)
        # The third positional argument of both loss callables is passed as None.
        sinkhorn_losses.append(sinkhorn(x_true, x_gen, None).detach().numpy())
        ged_losses.append(ged(x_true, x_gen, None).detach().numpy())
        
    return np.mean(sinkhorn_losses), np.std(sinkhorn_losses), np.mean(ged_losses), np.std(ged_losses)

In [44]:
# Evaluate both models with the default settings of calculate_losses
# (batchsize=1000, tests=25). Each call returns
# (Sinkhorn mean, Sinkhorn std, GED mean, GED std).
mfs_s, mfs_s_err, mfs_g, mfs_g_err = calculate_losses(mf_sinkhorn)
mfg_s, mfg_s_err, mfg_g, mfg_g_err = calculate_losses(mf_ged)

In [46]:
# Report mean +/- standard error of the mean. The std is divided by sqrt(25)
# because the default-settings run used 25 repetitions.
print("\n".join([
    "Sinkhorn loss: Sinkhorn model {} +/- {}".format(mfs_s, mfs_s_err / 25**0.5),
    "               GED model      {} +/- {}".format(mfg_s, mfg_s_err / 25**0.5),
    "GED loss:      Sinkhorn model {} +/- {}".format(mfs_g, mfs_g_err / 25**0.5),
    "               GED model      {} +/- {}".format(mfg_g, mfg_g_err / 25**0.5),
]))

Sinkhorn loss: Sinkhorn model -0.029649963602423668 +/- 0.0028731148689985276
               GED model      -0.01351374015212059 +/- 0.00489116944372654
GED loss:      Sinkhorn model -0.05749441310763359 +/- 0.011745116859674453
               GED model      -0.015488267876207829 +/- 0.018498928844928743


In [47]:
# Repeat the evaluation with a larger batch (5000 samples) and fewer
# repetitions (5).
mfs_s_large, mfs_s_large_err, mfs_g_large, mfs_g_large_err = calculate_losses(mf_sinkhorn, 5000, tests=5)
mfg_s_large, mfg_s_large_err, mfg_g_large, mfg_g_large_err = calculate_losses(mf_ged, 5000, tests=5)

In [48]:
# Report mean +/- standard error of the mean. The std is divided by sqrt(5)
# because the large-batch run used 5 repetitions.
print("\n".join([
    "Sinkhorn loss: Sinkhorn model {} +/- {}".format(mfs_s_large, mfs_s_large_err / 5**0.5),
    "               GED model      {} +/- {}".format(mfg_s_large, mfg_s_large_err / 5**0.5),
    "GED loss:      Sinkhorn model {} +/- {}".format(mfs_g_large, mfs_g_large_err / 5**0.5),
    "               GED model      {} +/- {}".format(mfg_g_large, mfg_g_large_err / 5**0.5),
]))

Sinkhorn loss: Sinkhorn model -0.05015917867422104 +/- 0.002777075756708633
               GED model      -0.05232926458120346 +/- 0.0031953293393659596
GED loss:      Sinkhorn model -0.19188883900642395 +/- 0.009126323078291282
               GED model      -0.17760400474071503 +/- 0.015089282784708826
