In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
D = 3  # ambient dimension of the target; samples lie on S^(D-1), i.e. the 2-sphere
K = 1  # dimension of the source samples fed to the generator (1-dimensional manifold)
L = 5  # half-width of the source distribution Uniform[-L, L]
B = 4_096  # number of samples drawn per batch

In [None]:
# Source samples are uniform on [-L, L]; a standard normal is the raw material
# for the target (its samples are normalised onto the sphere where they are used).
source_dist = torch.distributions.Uniform(low=-L, high=L)
target_dist = torch.distributions.Normal(loc=0, scale=1)

Let's visualize the target distribution samples $\sim \mathcal{U}(S^{D - 1})$.

In [None]:
# Draw B Gaussian vectors in R^D and rescale every row to unit L2 norm,
# which places the samples on the unit sphere.
gaussian_draws = target_dist.sample((B, D))
target_test = torch.nn.functional.normalize(gaussian_draws, p=2).numpy()
xs, ys, zs = target_test[:, 0], target_test[:, 1], target_test[:, 2]

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(xs, ys, zs, c='blue', alpha=0.2, s=10)
# Same [-1, 1] window on every axis so the sphere is not distorted.
for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_limits([-1, 1])
plt.show()

## $\phi$ : Vanilla Generative Adversarial Framework

In [None]:
class Phi(torch.nn.Module):
    """Generator: a sine-activated MLP from K-dimensional inputs to D-dimensional outputs.

    The network has three hidden layers of 256 units, each followed by a sine.
    The first sine uses a fixed frequency of 30; the two later ones are scaled
    by ``omega``. The output layer is linear.

    Args:
        omega: frequency multiplier applied before the second and third sine.
        *args, **kwargs: forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, omega=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        width = 256
        self.omega = omega
        # Layers are created in fc1..fc4 order, so parameter initialisation
        # consumes the random number generator in the same sequence as before.
        self.fc1 = torch.nn.Linear(K, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.fc3 = torch.nn.Linear(width, width)
        self.fc4 = torch.nn.Linear(width, D)

    def forward(self, X):
        """Map a batch ``X`` whose last dimension is ``K`` to outputs of size ``D``."""
        hidden = torch.sin(30 * self.fc1(X))
        for layer in (self.fc2, self.fc3):
            hidden = torch.sin(self.omega * layer(hidden))
        return self.fc4(hidden)

In [None]:
class Discriminator(torch.nn.Module):
    """Critic: a sine-activated MLP from D-dimensional points to one raw score each.

    The layout mirrors the generator: three hidden layers of 256 units, a fixed
    frequency of 30 in the first sine, ``omega`` in the two later ones, and a
    linear output layer with a single unit (no final activation).

    Args:
        omega: frequency multiplier applied before the second and third sine.
        *args, **kwargs: forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, omega=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        width = 256
        self.omega = omega
        # Layers are created in fc1..fc4 order, so parameter initialisation
        # consumes the random number generator in the same sequence as before.
        self.fc1 = torch.nn.Linear(D, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.fc3 = torch.nn.Linear(width, width)
        self.fc4 = torch.nn.Linear(width, 1)

    def forward(self, X):
        """Score a batch ``X`` whose last dimension is ``D``; returns one value per row."""
        hidden = torch.sin(30 * self.fc1(X))
        for layer in (self.fc2, self.fc3):
            hidden = torch.sin(self.omega * layer(hidden))
        return self.fc4(hidden)

In [None]:
LR = 1e-4  # Adam learning rate, shared by the generator and critic optimizers
EPOCHS = 250  # number of outer training rounds
D_ITERATIONS = 10  # critic updates per round
G_ITERATIONS = 5  # generator updates per round

In [None]:
# Prefer the GPU, but fall back to the CPU so this cell also runs on machines
# without CUDA instead of raising at `.to('cuda')`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Phi().to(device)
critic = Discriminator().to(device)

# One Adam optimizer per network, both using the shared learning rate LR.
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=LR)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=LR)
# The critic's last layer has no activation, so the sigmoid is applied inside the loss.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
# Place the data wherever the generator lives rather than hard-coding 'cuda',
# so the samples and the models can never end up on different devices.
device = next(generator.parameters()).device

# Fixed batches: B source points from Uniform[-L, L]^K and B target points
# obtained by normalising Gaussian vectors in R^D to unit L2 norm.
source_samples = source_dist.sample((B, K)).to(device)
target_samples = torch.nn.functional.normalize(target_dist.sample((B, D)), p=2).to(device)

# Label tensors for the binary cross-entropy loss (1 for the first class, 0 for the second),
# created directly on the target device to avoid an extra host-to-device copy.
one_labels = torch.ones(B, 1, device=device)
zero_labels = torch.zeros(B, 1, device=device)

In [None]:
# Alternating GAN training on the fixed sample batches: every epoch performs
# D_ITERATIONS critic updates followed by G_ITERATIONS generator updates.
for epoch in range(EPOCHS):
    for _ in range(D_ITERATIONS):
        # Clear gradients *before* the backward pass: the generator phase below
        # back-propagates through the critic and leaves gradients on its
        # parameters, which would otherwise be added to this step's gradients.
        critic_optimizer.zero_grad()

        target_fwd = critic(target_samples)
        target_err = criterion(target_fwd, one_labels)

        # Detach the generated points so the critic loss does not deposit
        # gradients on the generator's parameters (they are not updated here).
        phi = generator(source_samples).detach()
        source_fwd = critic(phi)
        source_err = criterion(source_fwd, zero_labels)

        critic_loss = target_err + source_err
        critic_loss.backward()
        critic_optimizer.step()
    for _ in range(G_ITERATIONS):
        # Start every generator step from clean gradients as well.
        gen_optimizer.zero_grad()

        phi = generator(source_samples)
        source_fwd = critic(phi)
        # The generator is trained to make the critic label its outputs as 1.
        gen_loss = criterion(source_fwd, one_labels)

        gen_loss.backward()
        gen_optimizer.step()

    print(f"Epoch {epoch + 1} | Critic Loss: {critic_loss.item():.5f} | Phi Loss: {gen_loss.item():.5f}")

The vanilla GAN for $\phi : \mathcal{U}[-L, L] \rightarrow \mathcal{U}(S^{D-1})$ is nowhere near accurate, even for a trivial problem like $D = 3$.

In [None]:
# Evaluate the trained generator on fresh source samples, on whatever device
# the generator lives on (no hard-coded 'cuda').
model_device = next(generator.parameters()).device
data = source_dist.sample((10000, K)).to(model_device)
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    manifold = generator(data).cpu()

# Convert to NumPy once instead of once per coordinate.
points = manifold.numpy()

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(points[:, 0], points[:, 1], points[:, 2], c='blue', alpha=0.1, s=10)
plt.show()

## $\phi$ : Sliced Wasserstein (SW)

### Discriminator-less SW-Generator

In the paper *Generative Modeling using the Sliced Wasserstein Distance* by Deshpande *et al.*, the authors propose a Sliced Wasserstein generative model that does not use an adversarial training framework. Instead, the generator network $G_{\theta}$ directly minimizes the Sliced Wasserstein distance between the data distribution $\mathcal{P}_d$ and the generated distribution $G_{\theta}(\mathcal{P}_z)$ (Section 3).

In [None]:
class SlicedWasserstein(torch.nn.Module):
    """Monte-Carlo estimate of the squared sliced 2-Wasserstein distance.

    Both sample sets are projected onto the *same* ``slices`` random unit
    directions. For every direction the sorted 1-D projections are compared
    with a mean squared difference, and the per-direction values are summed.

    Args:
        slices: number of random projection directions drawn per call.
        *args, **kwargs: forwarded unchanged to ``torch.nn.Module.__init__``.
    """

    def __init__(self, slices: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.slices = slices

    def forward(self, source, target):
        """Return the sliced Wasserstein loss between two equally sized sample sets.

        Args:
            source: tensor of shape ``(N, dim)``.
            target: tensor of shape ``(N, dim)``.

        Returns:
            A scalar tensor: the sum over directions of the mean squared
            difference between the sorted projections.

        Raises:
            ValueError: if ``source`` and ``target`` differ in shape; matching
                sorted samples one-to-one requires equal counts and dimensions.
        """
        if source.shape != target.shape:
            raise ValueError(
                f"source and target must have the same shape, got "
                f"{tuple(source.shape)} and {tuple(target.shape)}"
            )

        # Normalised Gaussian vectors are uniformly distributed on the unit
        # sphere (uniform [0, 1) draws would only cover the positive orthant).
        # They are created on the inputs' device and dtype so CUDA inputs work.
        directions = torch.nn.functional.normalize(
            torch.randn(self.slices, source.shape[1], device=source.device, dtype=source.dtype),
            p=2,
            dim=1,
        )

        # Both sets must use the same directions, otherwise the sorted
        # projections being compared belong to different 1-D marginals.
        source_sorted = torch.sort(torch.inner(source, directions), dim=0).values
        target_sorted = torch.sort(torch.inner(target, directions), dim=0).values

        # Sum of squares directly instead of squaring a norm: same value, but
        # the gradient stays well defined when the difference is exactly zero.
        per_slice = torch.sum(torch.square(target_sorted - source_sorted), dim=0) / target.shape[0]
        return torch.sum(per_slice)

### SW-Generative Adversarial Framework