# Experiments: Entropic Optimal Transport Between Gaussians in High Dimensions

## 1. Imports

In [1]:
import gc
import math
import os
import sys
import warnings

sys.path.append("..")

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from scipy.linalg import inv, sqrtm
from scipy.stats import ortho_group
from torch import nn
from tqdm import tqdm

from src import distributions

# from src.icnn import DenseICNN
# from src.tools import compute_l1_norm, ewma
from src.fid_score import calculate_frechet_distance
from src.tools import freeze, unfreeze

# NOTE: this seed is superseded by torch.manual_seed(OUTPUT_SEED) in section 3,
# which runs before any random draw in the visible cells.
torch.random.manual_seed(0xBADBEEF)
# Silences *all* Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
# NOTE(review): no visible cell plots anything; this magic appears unused.
%matplotlib inline

## 2. Config

In [5]:
DIM = 2  # dimension of both Gaussian marginals
assert DIM > 1  # scipy's ortho_group.rvs (section 3) requires dim >= 2

OUTPUT_SEED = 0xC0FFEE  # NumPy / torch seed set right before building the Gaussians
L1 = 1e-10  # NOTE(review): not referenced in any visible cell
GPU_DEVICE = 0  # CUDA device index passed to torch.cuda.set_device
BATCH_SIZE = 512  # samples per training step; also the chunk size during evaluation
EPSILON = 10  # entropic regularisation of the ground truth; SDE noise std per step is sqrt(EPSILON * dt)
N_STEPS = 50  # number of Euler-Maruyama steps on [0, 1]
TIME_DIM = 1  # width of the time feature concatenated to x before the shift network
T_LR = 3e-4  # Adam learning rate of the generator T
D_LR = 3e-4  # Adam learning rate of the critic D
T_ITERS = 10  # generator updates per critic update
CONSTANT_TIME = False  # if True, the shift network always receives t = 0
USE_POSITIONAL_ENCODING = False  # if True, t goes through SDE.time (sin/cos embedding + MLP); keep TIME_DIM even
INTEGRAL_SCALE = 1 / (DIM)  # weight of the drift-energy integral in both losses
T_GRADIENT_MAX_NORM = float("inf")  # inf: clip_grad_norm_ only measures the norm (it is logged)
D_GRADIENT_MAX_NORM = float("inf")  # same for the critic
IS_RESNET_GENERATOR = False  # if True, the shift model itself returns (trajectory, shifts)
PREDICT_SHIFT = True  # True: network outputs the drift; False: it outputs the next noise-free point
N_LAST_STEPS_WITHOUT_NOISE = 1  # NOTE(review): not referenced in any visible cell

# Generator shift MLP: hidden width and number of hidden Linear+ReLU layers.
T_N_HIDDEN = 512
T_N_LAYERS = 3

# Critic MLP: hidden width and number of hidden Linear+ReLU layers.
D_N_HIDDEN = 512
D_N_LAYERS = 3

MAX_STEPS = 10000  # outer training steps (each = T_ITERS generator + 1 critic update)
CONTINUE = -1  # -1: train from scratch; k >= 0: load checkpoint k and resume at step k + 1

In [6]:
EXP_NAME = f"Gaussians_test_EPSILON_{EPSILON}_STEPS_{N_STEPS}_DIM_{DIM}"

# Hyper-parameters stored with the W&B run. OUTPUT_SEED and MAX_STEPS are included
# so the seed and the planned run length are recorded alongside the metrics.
config = dict(
    DIM=DIM,
    T_ITERS=T_ITERS,
    D_LR=D_LR,
    T_LR=T_LR,
    BATCH_SIZE=BATCH_SIZE,
    N_STEPS=N_STEPS,
    EPSILON=EPSILON,
    CONSTANT_TIME=CONSTANT_TIME,
    USE_POSITIONAL_ENCODING=USE_POSITIONAL_ENCODING,
    TIME_DIM=TIME_DIM,
    INTEGRAL_SCALE=INTEGRAL_SCALE,
    T_GRADIENT_MAX_NORM=T_GRADIENT_MAX_NORM,
    D_GRADIENT_MAX_NORM=D_GRADIENT_MAX_NORM,
    IS_RESNET_GENERATOR=IS_RESNET_GENERATOR,
    T_N_HIDDEN=T_N_HIDDEN,
    T_N_LAYERS=T_N_LAYERS,
    D_N_HIDDEN=D_N_HIDDEN,
    D_N_LAYERS=D_N_LAYERS,
    PREDICT_SHIFT=PREDICT_SHIFT,
    OUTPUT_SEED=OUTPUT_SEED,
    MAX_STEPS=MAX_STEPS,
)

In [7]:
# The networks below are moved to the GPU with .cuda(), so fail fast without one.
assert torch.cuda.is_available(), "This notebook requires a CUDA-capable GPU."
torch.cuda.set_device(GPU_DEVICE)

## 3. Initialize Gaussians

In [8]:
np.random.seed(OUTPUT_SEED)
torch.manual_seed(OUTPUT_SEED)

# Both marginals, and therefore the joint plan over (x, y), are centred.
mu_0 = np.zeros(DIM)
mu_T = np.zeros(DIM)
mu_optimal_plan = np.zeros(2 * DIM)


def make_gaussian_sampler(dim):
    """Build a random centred Gaussian with covariance W W^T, where W = R diag(s).

    R is drawn with scipy's ortho_group (global NumPy RNG) and s is log-spaced
    between 0.5 and 2. Returns (R, W, W W^T, sampler); the sampler is a
    distributions.LinearTransformer over a standard normal with weight W
    (presumably x = W z — confirm in src.distributions).
    """
    rotation = ortho_group.rvs(dim)
    scales = np.exp(np.linspace(np.log(0.5), np.log(2), dim))
    weight = rotation @ np.diag(scales)
    sampler = distributions.LinearTransformer(
        distributions.StandartNormalSampler(dim=dim), weight, bias=None
    )
    return rotation, weight, weight @ weight.T, sampler


# Y is built before X so the global NumPy RNG produces the same rotations as before.
rotation_Y, weight_Y, sigma_Y, Y_sampler = make_gaussian_sampler(DIM)
rotation_X, weight_X, sigma_X, X_sampler = make_gaussian_sampler(DIM)

# NOTE(review): calculate_frechet_distance presumably returns the *squared* Frechet
# distance, so BW is W_2^2 / 2 (OT cost with 1/2 ||x - y||^2) — confirm in src.fid_score.
BW = calculate_frechet_distance(mu_0, sigma_X, mu_T, sigma_Y) / 2
print("True Wasserstein-2 Distance: ", BW)

# Monte-Carlo sanity check: the total sample variance should be close to
# np.trace(sigma_X) / np.trace(sigma_Y) if the samplers draw W z.
X = X_sampler.sample(100000).cpu().detach().numpy()
Var_X = np.sum(np.var(X, axis=0))
print("Variance of X:", Var_X)

Y = Y_sampler.sample(100000).cpu().detach().numpy()
Var_Y = np.sum(np.var(Y, axis=0))
print("Variance of Y:", Var_Y)

torch.cuda.empty_cache()

True Wasserstein-2 Distance:  1.0355121796910387
Variance of X: 4.2659435
Variance of Y: 4.267611


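## 4. Functions for computing the BW-UVP metric

BW-UVP compares the Gaussian fit $(\hat m, \hat\Sigma)$ of a set of samples with the ground-truth Gaussian $(m, \Sigma)$:
$$\text{BW-UVP} = 100 \cdot \frac{\tfrac12\|\hat m - m\|^2 + \tfrac12\operatorname{tr}\hat\Sigma + \tfrac12\operatorname{tr}\Sigma - \operatorname{tr}\big(\hat\Sigma^{1/2}\Sigma\hat\Sigma^{1/2}\big)^{1/2}}{\tfrac12\operatorname{tr}\Sigma}.$$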

In [9]:
def symmetrize(X):
    """Return the real part of the symmetric part of a square matrix, ``(X + X^T) / 2``.

    Used to clean up ``scipy.linalg.sqrtm`` results, which may carry tiny
    asymmetries or imaginary parts from round-off.
    """
    symmetric_part = 0.5 * (X + X.T)
    return np.real(symmetric_part)


def get_D_sigma(covariance_0, covariance_T, epsilon):
    """Compute ``D = (4 S0^{1/2} S_T S0^{1/2} + epsilon^2 I)^{1/2}``.

    S0 and S_T are the endpoint covariances. The result feeds ``get_C_sigma``,
    which builds the ground-truth cross-covariance. Matrix roots are symmetrized
    to remove round-off asymmetry.
    """
    root_0 = symmetrize(sqrtm(covariance_0))
    identity = np.eye(covariance_0.shape[0])
    inner = 4 * root_0 @ covariance_T @ root_0 + (epsilon**2) * identity
    return symmetrize(sqrtm(inner))


def get_C_sigma(covariance_0, D_sigma, epsilon):
    """Compute ``C = 0.5 * (S0^{1/2} D S0^{-1/2} - epsilon I)``.

    This is used as the (x_0, x_T) cross-covariance block of the ground-truth
    plan (see ``get_optimal_plan_covariance``).
    """
    root = symmetrize(sqrtm(covariance_0))
    root_inv = inv(root)
    identity = np.eye(covariance_0.shape[0])
    return 0.5 * (root @ D_sigma @ root_inv - epsilon * identity)


def get_mu_t(t, mu_0, mu_T):
    """Ground-truth mean at time t: linear interpolation between the endpoint means."""
    weight_0, weight_T = 1 - t, t
    return weight_0 * mu_0 + weight_T * mu_T


def get_covariance_t(t, covariance_0, covariance_T, C_sigma, epsilon):
    """Ground-truth marginal covariance at time t.

    Computes ``(1-t)^2 S0 + t^2 S_T + t(1-t)(C + C^T) + epsilon t(1-t) I``.
    """
    s = 1 - t
    identity = np.eye(covariance_0.shape[0])
    return (
        (s**2) * covariance_0
        + (t**2) * covariance_T
        + t * s * (C_sigma + C_sigma.T)
        + epsilon * t * s * identity
    )


def get_conditional_covariance_t(t, covariance_0, covariance_T, C_sigma, epsilon):
    """Ground-truth covariance of x_t given x_0.

    Computes ``t^2 (S_T - C^T S0^{-1} C) + epsilon t(1-t) I``.
    """
    dim = covariance_0.shape[0]
    residual = covariance_T - C_sigma.T @ inv(covariance_0) @ C_sigma
    return (t**2) * residual + epsilon * t * (1 - t) * np.eye(dim)


def get_conditional_mu_t(x0, mu_0, mu_T, t, covariance_0, C_sigma, epsilon):
    """Ground-truth conditional mean of x_t given x_0.

    Computes ``(1-t) x0 + t (mu_T + C^T S0^{-1} (x0 - mu_0))``.

    Args:
        x0: a single point of shape (d,) or a batch of shape (n, d).
        mu_0, mu_T: endpoint means, shape (d,).
        t: time in [0, 1].
        covariance_0: endpoint covariance S0, shape (d, d), symmetric.
        C_sigma: cross-covariance C, shape (d, d).
        epsilon: unused; kept for signature compatibility.

    Returns:
        An array with the same shape as ``x0``. (The previous implementation
        broadcast ``mu_T (d,)`` against a ``(d, 1)`` column and returned ``(d, d)``.)
    """
    x0 = np.asarray(x0)
    # Row-vector form: (x0 - mu_0) @ (S0^{-1} C) == C^T S0^{-1} (x0 - mu_0) because
    # S0 is symmetric; this also works row-wise for a batch.
    coefficient = inv(covariance_0) @ C_sigma
    endpoint_mean = mu_T + (x0 - mu_0) @ coefficient
    return (1 - t) * x0 + t * endpoint_mean


def get_optimal_plan_covariance(covariance_0, covariance_T, C_sigma):
    """Covariance of the joint (x_0, x_T) plan as a float64 2x2 block matrix.

    The diagonal blocks are the endpoint covariances. The off-diagonal blocks are
    the cross-covariance ``C_sigma`` (top-right) and its transpose (bottom-left).
    """
    blocks = [
        [covariance_0, C_sigma],
        [C_sigma.T, covariance_T],
    ]
    return np.block(blocks).astype(np.float64)


def compute_BW_UVP(samples, true_mu, true_covariance):
    """BW-UVP (in percent) between the Gaussian fit of ``samples`` and N(true_mu, true_covariance).

    Args:
        samples: array of shape (n, d), one sample per row.
        true_mu: reference mean, shape (d,).
        true_covariance: reference covariance, shape (d, d).

    Returns:
        ``100 * BW / (0.5 * tr(true_covariance))``, where
        ``BW = 0.5 ||true_mu - m||^2 + 0.5 tr(S) + 0.5 tr(true_cov) - tr((S^{1/2} true_cov S^{1/2})^{1/2})``
        and (m, S) are the sample mean and covariance.
    """
    fit_mean = samples.mean(axis=0)
    fit_cov = np.cov(samples.T)
    fit_cov_root = symmetrize(sqrtm(fit_cov))

    mean_part = 0.5 * ((true_mu - fit_mean) ** 2).sum()
    cross_root = symmetrize(sqrtm(fit_cov_root @ true_covariance @ fit_cov_root))
    half_true_trace = 0.5 * np.trace(true_covariance)
    cov_part = 0.5 * np.trace(fit_cov) + half_true_trace - np.trace(cross_root)

    return 100 * ((mean_part + cov_part) / half_true_trace)

### Ground-truth parameters for the BW-UVP metrics

In [10]:
# Ground-truth endpoint covariances.
covariance_0, covariance_T = sigma_X, sigma_Y

D_sigma = get_D_sigma(covariance_0, covariance_T, EPSILON)
C_sigma = get_C_sigma(covariance_0, D_sigma, EPSILON)

# Ground-truth mean / covariance of the marginal at each of the N_STEPS + 1 grid times.
mu_t = np.stack([get_mu_t(s, mu_0, mu_T) for s in np.linspace(0, 1, N_STEPS + 1)])
covariance_t = np.stack(
    [
        get_covariance_t(s, covariance_0, covariance_T, C_sigma, EPSILON)
        for s in np.linspace(0, 1, N_STEPS + 1)
    ]
)

# Covariance of the joint (x_0, x_T) plan, compared against samples during training.
optimal_plan_covariance = get_optimal_plan_covariance(covariance_0, covariance_T, C_sigma)

## 5. Model and training utilities

In [11]:
class Swish(nn.Module):
    """SiLU ("swish") activation ``x * sigmoid(x)`` wrapped as an ``nn.Module``."""

    def forward(self, input):
        return F.silu(input)


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar time into ``dim`` features.

    Each input value ``s`` is mapped to ``u = s * scale + 1`` and then to
    ``[sin(w_k u)..., cos(w_k u)...]``, with frequencies ``w_k = 10000^(-2k/dim)``
    for ``k < ceil(dim / 2)``. The output has shape ``input.shape + (dim,)``.
    For odd ``dim``, the last cosine feature is dropped so the width is exactly ``dim``.
    """

    def __init__(self, dim, scale):
        super().__init__()

        self.dim = dim
        self.scale = scale

        inv_freq = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000) / dim)
        )

        self.register_buffer("inv_freq", inv_freq)

    def forward(self, input):
        shape = input.shape

        shifted = input * self.scale + 1
        # Outer product (n,) x (ceil(dim/2),) via broadcasting (torch.ger is deprecated).
        sinusoid_in = shifted.reshape(-1).float()[:, None] * self.inv_freq[None, :]
        pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1)
        # sin + cos give 2 * ceil(dim / 2) features, i.e. dim + 1 when dim is odd;
        # trim to dim so the reshape below is valid for any dim.
        pos_emb = pos_emb[:, : self.dim]

        return pos_emb.reshape(*shape, self.dim)


def make_net(n_inputs, n_outputs, n_layers=3, n_hiddens=100):
    """Build an MLP: ``max(n_layers, 1)`` hidden Linear+ReLU layers of width ``n_hiddens``.

    A final Linear layer maps to ``n_outputs``; there is no output activation.
    """
    fan_ins = [n_inputs] + [n_hiddens] * (n_layers - 1)
    modules = []
    for fan_in in fan_ins:
        modules += [nn.Linear(fan_in, n_hiddens), nn.ReLU()]
    modules.append(nn.Linear(n_hiddens, n_outputs))
    return nn.Sequential(*modules)


class SDE(nn.Module):
    """Neural SDE ``dX = f(X, t) dt + sqrt(epsilon) dW`` on t in [0, 1].

    The SDE is simulated with ``n_steps`` Euler-Maruyama steps of size
    ``delta_t = 1 / n_steps``. ``shift_model`` receives ``x`` concatenated with a
    time feature. It returns either the drift ``f`` (module flag ``PREDICT_SHIFT``
    True) or the next noise-free point (False). With ``is_resnet_generator``, the
    whole simulation is delegated to ``shift_model``, which must return
    ``(trajectory, shifts)`` itself.

    The module-level flags ``PREDICT_SHIFT``, ``CONSTANT_TIME`` and
    ``USE_POSITIONAL_ENCODING`` are read at call time. All tensors are created on
    the device of the input, so the module also runs on CPU.

    ``forward(x0)`` returns ``(trajectory, times, shifts)``:
    - trajectory has shape ``(batch, n_steps + 1, ...)``;
    - times has ``n_steps + 1`` entries;
    - shifts has shape ``(batch, n_steps, ...)`` (non-ResNet path).
    """

    def __init__(self, shift_model, epsilon, n_steps, time_dim, is_resnet_generator):
        super().__init__()
        self.shift_model = shift_model
        self.epsilon = epsilon
        self.n_steps = n_steps
        self.delta_t = 1 / n_steps
        self.is_resnet_generator = is_resnet_generator

        # Learned time embedding; only used when USE_POSITIONAL_ENCODING is True, but
        # always registered so the parameter list (and checkpoints) stay the same.
        self.time = nn.Sequential(
            TimeEmbedding(time_dim, scale=n_steps),
            nn.Linear(time_dim, time_dim),
            Swish(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, x0):
        if self.is_resnet_generator:
            trajectory, shifts = self.shift_model(x0)
            times = torch.linspace(0, 1, self.n_steps + 1)

            return trajectory, times, shifts

        x, t = x0, 0.0
        trajectory, times, shifts = [x0], [t], []

        for _ in range(self.n_steps):
            x, t, shift = self._step(x, t)

            trajectory.append(x)
            times.append(t)
            shifts.append(shift)

        return (
            torch.stack(trajectory, dim=1),
            torch.tensor(times),
            torch.stack(shifts, dim=1),
        )

    def _step(self, x, t):
        """One Euler-Maruyama step; returns (next x, next t, drift used for this step)."""
        if PREDICT_SHIFT:
            shift = self._get_shift(x, t)
            shifted_x = x + shift * self.delta_t
        else:
            # The network predicts the noise-free next point; recover the implied drift.
            shifted_x = self._get_shift(x, t)
            shift = (shifted_x - x) / self.delta_t
        noise = self._sample_noise(x)

        return shifted_x + noise, t + self.delta_t, shift

    def _get_shift(self, x, t):
        """Evaluate shift_model on x concatenated with a per-sample time feature."""
        batch_size = x.shape[0]

        if CONSTANT_TIME:
            t = 0.0

        if USE_POSITIONAL_ENCODING:
            t = self.time(torch.full((batch_size,), t, device=x.device))
        else:
            t = torch.full((batch_size, 1), t, dtype=x.dtype, device=x.device)

        x_t = torch.cat((x, t), dim=-1)

        return self.shift_model(x_t)

    def _sample_noise(self, x):
        """Brownian increment over one step: N(0, epsilon * delta_t * I), shaped like x.

        Drawn directly on x's device (CUDA RNG for GPU inputs) instead of on the CPU
        followed by a copy.
        """
        scale = math.sqrt(self.epsilon) * math.sqrt(self.delta_t)
        return scale * torch.randn_like(x)

    def set_n_steps(self, n_steps):
        """Change the number of integration steps (and hence delta_t).

        NOTE: the TimeEmbedding inside ``self.time`` keeps the ``scale`` it was built with.
        """
        self.n_steps = n_steps
        self.delta_t = 1 / n_steps


def integrate(values, times):
    """Left Riemann sum of per-interval values over a time grid.

    Args:
        values: tensor of shape (batch, n_intervals); column k is the value on
            [times[k], times[k + 1]].
        times: 1-D tensor with n_intervals + 1 grid points.

    Returns:
        Tensor of shape (batch,).
    """
    # Move the interval widths to values' device (not merely the current CUDA
    # device), so CPU and any GPU index work.
    deltas = (times[1:] - times[:-1]).to(values.device)
    return (values * deltas[None, :]).sum(dim=1)


def compute_metrics(
    X_sampler, T, N_STEPS, mu_t, covariance_t, n_samples=100000, batch_size=None
):
    """BW-UVP of the simulated marginal at every time step 1..N_STEPS.

    Draws ``n_samples`` points from ``X_sampler``, pushes them through ``T`` in
    chunks of ``batch_size`` without gradients, and scores the marginal at each
    step against ``N(mu_t[step], covariance_t[step])`` using ``compute_BW_UVP``.

    Args:
        X_sampler: object whose ``sample(n)`` returns an (n, d) tensor.
        T: the SDE generator; ``T(x)[0]`` is the (batch, N_STEPS + 1, d) trajectory.
        N_STEPS: number of time steps to score.
        mu_t, covariance_t: ground-truth means / covariances indexed by step.
        n_samples: number of Monte-Carlo samples (default 100000, as before).
        batch_size: chunk size; defaults to the global BATCH_SIZE.

    Returns:
        List of N_STEPS floats; element i is the BW-UVP at time step i + 1.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    X = X_sampler.sample(n_samples)

    # torch.split yields exactly ceil(n / batch_size) non-empty chunks.
    with torch.no_grad():
        trajectory = np.concatenate(
            [T(batch)[0].cpu().numpy() for batch in torch.split(X, batch_size)],
            axis=0,
        )

    return [
        compute_BW_UVP(trajectory[:, step, :], mu_t[step], covariance_t[step])
        for step in range(1, N_STEPS + 1)
    ]


def calculate_noise_norm(dim, n_steps, epsilon):
    """Expected total norm of the injected noise along one trajectory.

    Each of the ``n_steps`` steps adds ``N(0, s^2 I_dim)`` with
    ``s = sqrt(epsilon / n_steps)``. Its expected norm is
    ``s * sqrt(2) * Gamma((dim + 1) / 2) / Gamma(dim / 2)`` (the chi mean),
    evaluated in log space so large ``dim`` does not overflow. The result is
    ``n_steps`` times that.
    """
    step_std = math.sqrt(1 / n_steps) * math.sqrt(epsilon)
    log_mean_norm = (
        math.log(step_std)
        + math.log(math.sqrt(2))
        + math.lgamma((dim + 1) / 2)
        - math.lgamma(dim / 2)
    )
    return n_steps * math.exp(log_mean_norm)

## 6. Training

### Network initialization

In [12]:
# Critic D: R^DIM -> R.
D = make_net(DIM, 1, n_layers=D_N_LAYERS, n_hiddens=D_N_HIDDEN).cuda()

# Generator T: SDE whose shift network maps (x, time feature) -> R^DIM.
T = SDE(
    make_net(DIM + TIME_DIM, DIM, n_layers=T_N_LAYERS, n_hiddens=T_N_HIDDEN).cuda(),
    EPSILON,
    N_STEPS,
    TIME_DIM,
    IS_RESNET_GENERATOR,
).cuda()

print("T params:", sum(p.numel() for p in T.parameters()))
print("D params:", sum(p.numel() for p in D.parameters()))

T params: 528390
D params: 527361


In [13]:
T_opt = torch.optim.Adam(T.parameters(), lr=T_LR, weight_decay=1e-10)
D_opt = torch.optim.Adam(D.parameters(), lr=D_LR, weight_decay=1e-10)

if CONTINUE > -1:
    # OUTPUT_PATH (checkpoint directory) and SEED (used in checkpoint file names) are
    # not defined in any visible cell, and no visible cell saves checkpoints; fail
    # with a clear message instead of an unexplained NameError.
    missing = [name for name in ("OUTPUT_PATH", "SEED") if name not in globals()]
    if missing:
        raise NameError(
            f"CONTINUE={CONTINUE} requires {', '.join(missing)} to be defined "
            "(checkpoint directory and seed used in checkpoint file names)."
        )

    def _load_checkpoint(prefix):
        """Load '<prefix>_<SEED>_<CONTINUE>.pt' from OUTPUT_PATH."""
        return torch.load(os.path.join(OUTPUT_PATH, f"{prefix}_{SEED}_{CONTINUE}.pt"))

    T_opt.load_state_dict(_load_checkpoint("T_opt"))
    T.load_state_dict(_load_checkpoint("T"))
    D_opt.load_state_dict(_load_checkpoint("D_opt"))
    D.load_state_dict(_load_checkpoint("D"))

In [14]:
# Start the W&B run for this experiment, recording `config` as its hyper-parameters.
wandb.init(config=config, name=EXP_NAME)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: [32m[41mERROR[0m API key must be 40 characters long, yours was 8
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc 

### Training loop

In [15]:
metrics = []

# Expected total norm of the injected noise along one trajectory; logged once for reference.
noise_norm = calculate_noise_norm(DIM, N_STEPS, EPSILON)
wandb.log({"Noise norm": noise_norm}, step=0)


def _drift_energy(shifts, times):
    """Return (INTEGRAL_SCALE * integral of ||shift||^2 dt, squared norms).

    ``shifts`` is (batch, n_steps, ...) as stacked by SDE.forward. The integral has
    shape (batch,) and the squared norms have shape (batch, n_steps).
    """
    norm = torch.norm(shifts.flatten(start_dim=2), p=2, dim=-1) ** 2
    return INTEGRAL_SCALE * integrate(norm, times), norm


def _generator_phase(step):
    """Run T_ITERS Adam updates of the generator T against the frozen critic D."""
    unfreeze(T)
    freeze(D)
    for _ in range(T_ITERS):
        T_opt.zero_grad()

        X0, X1 = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)
        trajectory, times, shifts = T(X0)
        XN = trajectory[:, -1]
        integral, norm = _drift_energy(shifts, times)

        # Average each term separately: D(.) is (batch, 1) and integral is (batch,),
        # so adding them first silently broadcasts to a (batch, batch) matrix.
        T_loss = integral.mean() + D(X1).mean() - D(XN).mean()
        T_loss.backward()
        T_gradient_norm = torch.nn.utils.clip_grad_norm_(
            T.parameters(), max_norm=T_GRADIENT_MAX_NORM
        )
        T_opt.step()

    wandb.log(
        {
            "T gradient norm": T_gradient_norm.item(),
            "Mean norm": torch.sqrt(norm).mean().item(),
            "T_loss": T_loss.item(),
        },
        step=step,
    )


def _critic_phase(step):
    """Run one Adam update of the critic D on samples from the frozen generator T."""
    freeze(T)
    unfreeze(D)
    D_opt.zero_grad()

    X0, X1 = X_sampler.sample(BATCH_SIZE), Y_sampler.sample(BATCH_SIZE)
    # T is not trained here, so do not build its autograd graph.
    with torch.no_grad():
        trajectory, times, shifts = T(X0)
        integral, _ = _drift_energy(shifts, times)
    XN = trajectory[:, -1]

    D_X1 = D(X1)
    D_XN = D(XN)

    # The integral does not depend on D (no gradient). It is kept so that D_loss is
    # the negated generator objective, and it is logged below.
    D_loss = -integral.mean() - D_X1.mean() + D_XN.mean()
    D_loss.backward()
    D_gradient_norm = torch.nn.utils.clip_grad_norm_(
        D.parameters(), max_norm=D_GRADIENT_MAX_NORM
    )
    D_opt.step()

    wandb.log(
        {
            "D gradient norm": D_gradient_norm.item(),
            "D_loss": D_loss.item(),
            "integral": integral.mean().item(),
            "D_X1": D_X1.mean().item(),
            "D_XN": D_XN.mean().item(),
        },
        step=step,
    )


def _sample_final_points(n_samples):
    """Push fresh X samples through T; return (X, T(X) at the final time) as NumPy arrays."""
    X = X_sampler.sample(n_samples)
    with torch.no_grad():
        TX = np.concatenate(
            [T(batch)[0][:, -1].cpu().numpy() for batch in torch.split(X, BATCH_SIZE)],
            axis=0,
        )
    return X.cpu().numpy(), TX


def _evaluate(step):
    """Log BW-UVP of every marginal, of the final marginal and of the joint plan."""
    metrics.append(compute_metrics(X_sampler, T, N_STEPS, mu_t, covariance_t))

    # Key BW_UVP_{i} holds the marginal at time step i + 1 (compute_metrics scores
    # steps 1..N_STEPS); key names are unchanged so existing dashboards still match.
    logs = {f"BW_UVP_{i}": metrics[-1][i] for i in range(N_STEPS)}
    logs["BW_UVP_mean"] = np.mean(metrics[-1])
    logs["BW_UVP_max"] = max(metrics[-1])

    X, TX = _sample_final_points(100000)
    logs["BW_UVP_target_distr"] = compute_BW_UVP(TX, mu_T, covariance_T)
    logs["BW_UVP_optimal_plan"] = compute_BW_UVP(
        np.concatenate((X, TX), axis=1), mu_optimal_plan, optimal_plan_covariance
    )
    wandb.log(logs, step=step)


for step in tqdm(range(CONTINUE + 1, MAX_STEPS)):
    _generator_phase(step)
    _critic_phase(step)

    if step % 100 == 0:
        _evaluate(step)

    # Batch tensors and autograd graphs are locals of the helpers above and are
    # released when they return; this only trims Python / CUDA caches.
    gc.collect()
    torch.cuda.empty_cache()

 10%|█         | 1008/10000 [18:05<2:41:22,  1.08s/it]


KeyboardInterrupt: 

In [16]:
# Close the W&B run opened with wandb.init so the remaining logged data is uploaded.
wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
BW_UVP_0,█▁▁▁▁▁▁▁▁▁▂
BW_UVP_1,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_10,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_11,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_12,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_13,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_14,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_15,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_16,█▁▁▁▁▁▁▁▁▁▁
BW_UVP_17,█▁▁▁▁▁▁▁▁▁▁

0,1
BW_UVP_0,0.0054
BW_UVP_1,0.00854
BW_UVP_10,0.06112
BW_UVP_11,0.06331
BW_UVP_12,0.06465
BW_UVP_13,0.07629
BW_UVP_14,0.0783
BW_UVP_15,0.08824
BW_UVP_16,0.09573
BW_UVP_17,0.10076
