**Definition (Wasserstein Distance)**

Let $P$ and $Q$ be two probability distributions and denote by $\Pi(P,Q)$ the set of joint distributions (couplings) $\gamma(x,y)$ whose marginals are $P$ and $Q.$ The Wasserstein distance (more precisely, the 1-Wasserstein or Earth Mover's distance) between $P$ and $Q$ is
$$
W(P,Q)=\inf_{\gamma\in\Pi(P,Q)}\mathbb{E}_{(x,y)\sim \gamma}\|x-y\|.
$$
Intuitively, $W(P,Q)$ is the minimal cost of transporting the mass of $P$ so that it becomes $Q$ (or vice versa): each coupling $\gamma$ is a transport plan, and moving a unit of mass from $x$ to $y$ costs $\|x-y\|.$
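For one-dimensional empirical distributions with the same number of equally weighted samples, the optimal coupling simply matches the sorted samples in order, so $W$ reduces to an average of absolute differences. The sketch below is plain Python; the function name `wasserstein_1d` is our own, and it assumes equal sample sizes.

```python
def wasserstein_1d(xs, ys):
    """W1 between two 1-D empirical distributions with equally many, equally weighted samples."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1-D, the optimal transport plan sends the k-th smallest x to the k-th smallest y.
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


# Q is P shifted by 3, so every unit of mass travels a distance of 3.
print(wasserstein_1d([0.0, 1.0, 2.0], [3.0, 4.0, 5.0]))  # 3.0
# Reordering samples does not change the distribution, hence W = 0.
print(wasserstein_1d([2.0, 0.0, 1.0], [0.0, 1.0, 2.0]))  # 0.0
```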

**Theorem (Kantorovich-Rubinstein Duality)**

For all probability distributions $P$ and $Q,$ we have
$$
W(P,Q)=\sup_{\|f\|_L\leqslant 1} \mathbb{E}_{x\sim P}[f(x)]-\mathbb{E}_{x\sim Q}[f(x)]
$$
where $\|\cdot\|_L$ denotes the Lipschitz seminorm, so that the supremum is taken over all 1-Lipschitz functions $f.$
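The duality means every 1-Lipschitz $f$ gives a lower bound $\mathbb{E}_{x\sim P}[f(x)]-\mathbb{E}_{x\sim Q}[f(x)]\leqslant W(P,Q),$ and the best such $f$ attains $W.$ A minimal sketch on toy samples of our own choosing, where $Q$ is $P$ shifted by 3 (so $W(P,Q)=3$), evaluating the dual objective for a few hand-picked 1-Lipschitz functions; `dual_objective` is a hypothetical helper name.

```python
P = [0.0, 1.0, 2.0]  # toy samples from P
Q = [3.0, 4.0, 5.0]  # Q is P shifted by 3, so W(P, Q) = 3


def dual_objective(f, p_samples, q_samples):
    """Empirical value of E_P[f(x)] - E_Q[f(x)] for a candidate function f."""
    return sum(map(f, p_samples)) / len(p_samples) - sum(map(f, q_samples)) / len(q_samples)


# Each candidate is 1-Lipschitz: |f(x) - f(y)| <= |x - y|.
candidates = {
    "f(x) = x": lambda x: x,
    "f(x) = |x - 2.5|": lambda x: abs(x - 2.5),
    "f(x) = -x": lambda x: -x,
}
for name, f in candidates.items():
    print(name, dual_objective(f, P, Q))
# f(x) = x -3.0
# f(x) = |x - 2.5| 0.0
# f(x) = -x 3.0
```

Every value stays below $W=3,$ and $f(x)=-x$ reaches it: moving mass to the right is exactly what this $f$ penalizes.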

**Notations**

- We shall denote by $P_r$ the real distribution of the data.

- We shall denote by $\theta$ the weights of the generator $g_\theta$ and by $P_\theta$ the distribution of its outputs: $P_\theta$ is the law of $g_\theta(z)$ where $z\sim Z$ and $Z$ is a fixed noise distribution on a low-dimensional space (a standard Gaussian in the code below).

- We shall denote by $w$ the weights of the discriminator $f_w,$ called the critic in the WGAN framework. If $x$ is a data sample, then $f_w(x)$ is the score given to $x$ by $f_w.$ In the WGAN model, the score is not a probability in $(0,1)$ as for a standard GAN discriminator, but an unbounded real number.
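Sampling from $P_\theta$ never requires its density: one draws $z\sim Z$ and returns $g_\theta(z).$ As a minimal illustration under toy assumptions of our own (a one-parameter "generator" $g_\theta(z)=z+\theta$ and $Z=\mathcal{N}(0,1),$ so that $P_\theta=\mathcal{N}(\theta,1)$), the sample mean of the pushed-forward draws lands near $\theta$:

```python
import random
import statistics

random.seed(0)  # fixed seed for reproducibility


def g(z, theta):
    """Toy generator g_theta(z) = z + theta (a hypothetical stand-in for the neural network)."""
    return z + theta


theta = 2.0
# Push 10,000 draws of z ~ N(0, 1) through g_theta to obtain samples from P_theta = N(theta, 1).
samples = [g(random.gauss(0.0, 1.0), theta) for _ in range(10_000)]
print(round(statistics.mean(samples), 2))  # close to theta = 2.0
```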

**Problem Formulation**

- Our goal is to choose $\theta$ that (approximately) attains $\inf_{\theta}W(P_r,P_\theta).$ By the **Kantorovich-Rubinstein duality theorem**, this amounts to approximating
$$
\inf_{\theta}\sup_{\|f\|_L\leqslant 1} \mathbb{E}_{x\sim P_r}[f(x)]-\mathbb{E}_{z\sim Z}[f\circ g_\theta(z)].
$$

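- We thus also need to find an adequate $f.$ This is the role of the critic $f_w$ with weights $w.$ Restricting $f$ to the family of critics with clipped weights gives, up to a multiplicative constant, the approximation
$$
\sup_{\|f\|_L\leqslant 1} \mathbb{E}_{x\sim P_r}[f(x)]-\mathbb{E}_{z\sim Z}[f\circ g_\theta(z)]\approx \sup_{-c\leqslant w\leqslant c} \mathbb{E}_{x\sim P_r}[f_w(x)]-\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)]
$$
where $c>0$ is a constant and $-c\leqslant w\leqslant c$ means that every weight in $w$ lies in $[-c,c].$ This weight clipping makes every $f_w$ $K$-Lipschitz for some $K$ that depends only on $c$ and the architecture, so the right-hand side is at most $K\cdot W(P_r,P_\theta)$ (hence the multiplicative constant), and is close to it only insofar as the family $\{f_w\}$ is rich enough.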

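- Hence our problem becomes a min-max game between the generator and the critic: we look for $\theta$ and $w$ that approximate
$$
\inf_{\theta}\sup_{-c\leqslant w\leqslant c} \mathbb{E}_{x\sim P_r}[f_w(x)]-\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)]=\inf_{\theta}\left(-\inf_{-c\leqslant w\leqslant c}\Big(\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)]-\mathbb{E}_{x\sim P_r}[f_w(x)]\Big)\right).
$$
The critic's inner problem can thus be written as a minimization over $w,$ but it cannot be merged with the outer infimum into a joint minimization over $(\theta,w)$: the critic tries to make the estimated distance as large as possible, the generator as small as possible.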

- In practice, we independently sample $x_1,\ldots,x_n\sim P_r$ and $z_1,\ldots,z_m\sim Z,$ replace the expectations by empirical means, and look for $\theta$ and $w$ that approximate
$$
\boxed{\inf_{\theta}\sup_{-c\leqslant w\leqslant c} \frac{1}{n}\sum_{i=1}^nf_w(x_i)-\frac{1}{m}\sum_{i=1}^mf_w\circ g_\theta(z_i).}
$$
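To see where the multiplicative constant comes from, consider a deliberately tiny toy case of our own: one-dimensional data, a shifted generator $g_\theta(z)=z+\theta,$ and a linear critic $f_w(x)=wx,$ which is $|w|$-Lipschitz. The empirical objective is then linear in $w,$ so its supremum over $[-c,c]$ is attained at an endpoint and equals $c$ times the gap between the sample means, i.e. $c\cdot W$ for a pure shift. All names and numbers below are illustrative assumptions.

```python
def empirical_objective(w, real, fake):
    """(1/n) sum f_w(x_i) - (1/m) sum f_w(g_theta(z_i)) for the toy linear critic f_w(x) = w * x."""
    return sum(w * x for x in real) / len(real) - sum(w * y for y in fake) / len(fake)


c = 0.25                        # clipping constant (chosen so the arithmetic is exact)
theta = 3.0                     # generator parameter: g_theta(z) = z + theta
real = [0.0, 1.0, 2.0]          # x_i ~ P_r
zs = [0.0, 1.0, 2.0]            # z_i ~ Z
fake = [z + theta for z in zs]  # g_theta(z_i)

# The objective is linear in w, so the supremum over [-c, c] is attained at w = -c or w = c.
best = max(empirical_objective(w, real, fake) for w in (-c, c))
print(best)  # 0.75 = c * W(P_r, P_theta), since W = 3 for this pure shift
```

For a real multilayer critic the constant is not simply $c,$ but the mechanism is the same: clipping rescales the set of admissible critics rather than enforcing exactly 1-Lipschitz functions.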

**Critic Loss**

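Given a fixed $\theta,$ the critic minimizes the following loss over weights satisfying $-c\leqslant w\leqslant c$ (equivalently, it maximizes the empirical estimate of the Wasserstein distance, so $-L_\text{disc}(w)$ tracks that estimate up to a multiplicative constant):
$$
\boxed{L_\text{disc}(w)=\frac{1}{m}\sum_{i=1}^mf_w\circ g_\theta(z_i)-\frac{1}{n}\sum_{i=1}^nf_w(x_i).}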
$$
In other words, we have
$$
\boxed{L_\text{disc}(w)=\text{mean}(\text{fake data scores})-\text{mean}(\text{real data scores}).}
$$
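Normalizing by $n$ and $m$ matters whenever the real and fake batches have different sizes: with raw sums, identical score distributions would still produce a nonzero loss. A toy check with hypothetical scores:

```python
real_scores = [1.0, 1.0, 1.0, 1.0]  # n = 4 critic scores on real samples
fake_scores = [1.0, 1.0]            # m = 2 critic scores on fake samples, same distribution

loss_with_means = sum(fake_scores) / len(fake_scores) - sum(real_scores) / len(real_scores)
loss_with_sums = sum(fake_scores) - sum(real_scores)

print(loss_with_means)  # 0.0: no difference detected, as it should be
print(loss_with_sums)   # -2.0: spurious signal caused only by n != m
```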

**Generator Loss**

- Assume $w$ is fixed and that, up to a multiplicative constant, $W(P_r,P_\theta)$ is approximated by
$$
W(P_r,P_\theta)\approx \mathbb{E}_{x\sim P_r}[f_w(x)]-\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)].
$$

- We are thus looking to approximate
$$
\underset{\theta}{\text{arginf}}\ W(P_r,P_\theta)\approx \underset{\theta}{\text{arginf}}\ \mathbb{E}_{x\sim P_r}[f_w(x)]-\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)].
$$

- Because $\mathbb{E}_{x\sim P_r}[f_w(x)]$ is independent of $\theta,$ the problem is equivalent to solving
$$
\underset{\theta}{\text{arginf}}\ -\mathbb{E}_{z\sim Z}[f_w\circ g_\theta(z)].
$$

- In practice, we independently sample $z_1,\ldots,z_m\sim Z$ and solve
$$
\underset{\theta}{\text{arginf}}\ -\frac{1}{m}\sum_{i=1}^mf_w\circ g_\theta(z_i).
$$

- Hence the loss function for the generator is
$$
\boxed{L_\text{gen}(\theta)=-\frac{1}{m}\sum_{i=1}^mf_w\circ g_\theta(z_i).}
$$
In other words, we have
$$
\boxed{L_\text{gen}(\theta)=-\text{mean}(\text{fake data scores}).}
$$
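A toy autograd check of why minimizing $L_\text{gen}$ moves the generator toward the data, under assumptions of our own: a fixed linear critic $f_w(x)=wx$ with $w=-c$ (the weight a clipped linear critic settles on when the fake samples lie to the right of the real ones) and a shifted generator $g_\theta(z)=z+\theta.$ Here $\partial L_\text{gen}/\partial\theta=-w>0,$ so a gradient step decreases $\theta.$

```python
import torch

c = 0.25
w = torch.tensor(-c)                           # fixed critic weight: f_w(x) = w * x
theta = torch.tensor(3.0, requires_grad=True)  # generator parameter: g_theta(z) = z + theta
z = torch.tensor([0.0, 1.0, 2.0])              # z_i ~ Z (fixed toy draws)

fake_scores = w * (z + theta)                  # f_w(g_theta(z_i))
loss_gen = -fake_scores.mean()                 # L_gen = -mean(fake data scores)
loss_gen.backward()
print(theta.grad)                              # tensor(0.2500), i.e. -w

# One plain gradient-descent step (learning rate 0.1) moves theta toward the real data.
with torch.no_grad():
    theta -= 0.1 * theta.grad
print(theta.item())                            # approximately 2.975
```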

```python
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
```

```python
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

```python
# Generator class: maps 100-dimensional noise vectors z to samples g_theta(z)
class Generator(nn.Module):

    # Initialize network: 100 -> 256 -> 512 -> 1024 -> g_output_dim
    def __init__(self, g_output_dim):
        super().__init__()
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # Feed forward
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        # tanh keeps outputs in [-1, 1], so real data should be scaled to the same range
        return torch.tanh(self.fc4(x))
```

```python
# Critic class: maps a flattened sample x to a real-valued score f_w(x)
class Critic(nn.Module):

    # Initialize network: d_input_dim -> 1024 -> 512 -> 256 -> 1
    def __init__(self, d_input_dim):
        super().__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # Feed forward
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return self.fc4(x)  # No sigmoid, WGAN outputs raw scores
```

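```python
# Critic loss: mean(fake data scores) - mean(real data scores), minimized over the critic weights
def critic_loss(real_data_score, fake_data_score):
    return torch.mean(fake_data_score) - torch.mean(real_data_score)

# Generator loss: -mean(fake data scores), minimized over the generator weights
def generator_loss(fake_data_score):
    return -torch.mean(fake_data_score)
```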

```python
# Weight clipping: project every critic parameter (weights and biases) onto [-clip_value, clip_value]
def clip_weights(model, clip_value):
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(-clip_value, clip_value)
```
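Weight clipping bounds the critic's Lipschitz constant only indirectly. For an MLP like the `Critic` above, whose LeakyReLU activations are 1-Lipschitz, the product of the spectral norms of the weight matrices is an upper bound (with respect to the Euclidean norm), and clipping every entry to $[-c,c]$ bounds each spectral norm by $c\sqrt{d_\text{in}d_\text{out}}.$ The sketch below is self-contained: it rebuilds the same layer sizes with `nn.Sequential`, assumes 784-dimensional inputs (e.g. flattened 28×28 images), clips the weights, computes that bound, and checks it against difference quotients on random pairs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c = 0.01
d_in = 784  # assumed input dimension (e.g. flattened 28x28 images)

# Same layer sizes as the Critic class, rebuilt with nn.Sequential so this cell runs on its own.
critic = nn.Sequential(
    nn.Linear(d_in, 1024), nn.LeakyReLU(0.2),
    nn.Linear(1024, 512), nn.LeakyReLU(0.2),
    nn.Linear(512, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)

# Weight clipping, as in clip_weights.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-c, c)

# Lipschitz bound: product of spectral norms (biases and 1-Lipschitz activations add nothing).
lip_bound = 1.0
for layer in critic:
    if isinstance(layer, nn.Linear):
        lip_bound *= torch.linalg.matrix_norm(layer.weight, ord=2).item()

# Empirical check on random pairs: |f(x) - f(y)| / ||x - y|| never exceeds the bound.
with torch.no_grad():
    x, y = torch.randn(256, d_in), torch.randn(256, d_in)
    ratios = (critic(x) - critic(y)).abs().squeeze(1) / (x - y).norm(dim=1)
print(f"Lipschitz bound: {lip_bound:.3e}, largest observed ratio: {ratios.max().item():.3e}")
```

The bound depends on both $c$ and the layer widths, which is the multiplicative constant $K$ in the formulation; it is typically loose, so the observed ratios sit well below it.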

```python
# WGAN model training
# Defaults (RMSprop, lr = 5e-5, clip_value = 0.01, n_critic = 5) match the original WGAN paper
def train_wgan(critic, generator, real_data_loader, num_epochs, n_critic=5, lr=0.00005, clip_value=0.01):

    # Initialize optimizers
    optimizer_c = optim.RMSprop(critic.parameters(), lr=lr)  # Critic optimizer
    optimizer_g = optim.RMSprop(generator.parameters(), lr=lr)  # Generator optimizer

    # Move models to device
    critic.to(device)
    generator.to(device)
    latent_dim = generator.fc1.in_features  # Noise dimension expected by the generator (100)

    # Loop over number of epochs
    for epoch in range(num_epochs):
        for real_data in real_data_loader:  # Loop over each batch of real data
            # Loaders built on labelled datasets yield (samples, labels): keep the samples only
            if isinstance(real_data, (list, tuple)):
                real_data = real_data[0]
            batch_size = real_data.size(0)  # Get size of current batch
            real_data = real_data.reshape(batch_size, -1).to(device)  # Flatten and move to device

            # Train critic n_critic times (simplification: the original algorithm draws a
            # fresh real batch for every critic step, whereas this batch is reused)
            for _ in range(n_critic):

                # Generate fake data (detached: no generator gradients during critic updates)
                noise = torch.randn(batch_size, latent_dim, device=device)  # Create random noise
                fake_data = generator(noise).detach()  # Output fake data from generator

                # Compute critic loss
                optimizer_c.zero_grad()  # Clear previous gradients for critic
                loss_c = critic_loss(critic(real_data), critic(fake_data))  # Calculate critic loss
                loss_c.backward()  # Backpropagate loss
                optimizer_c.step()  # Update critic weights

                # Clip critic weights to keep the critic (approximately) Lipschitz
                clip_weights(critic, clip_value)

            # Train generator
            noise = torch.randn(batch_size, latent_dim, device=device)  # Create random noise
            fake_data = generator(noise)  # Output fake data from generator

            optimizer_g.zero_grad()  # Clear previous gradients for generator
            loss_g = generator_loss(critic(fake_data))  # Calculate generator loss
            loss_g.backward()  # Backpropagate (critic gradients are cleared before its next update)
            optimizer_g.step()  # Update generator weights

        # Print the losses of the epoch's last batch to track training progress
        print(f"Epoch [{epoch + 1} / {num_epochs}], Critic Loss: {loss_c.item():.4f}, Generator Loss: {loss_g.item():.4f}")
```
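`train_wgan` expects `real_data_loader` to yield batches of flattened samples scaled to $[-1,1],$ the range of the generator's `tanh` output. A minimal sketch of such a loader, using random tensors as a hypothetical stand-in for a real image dataset; a `DataLoader` built directly on a tensor yields plain tensor batches without labels.

```python
import torch
from torch.utils.data import DataLoader

torch.manual_seed(0)

# Hypothetical stand-in for real data: 1,000 grayscale 28x28 "images" with pixel values in [0, 1].
images = torch.rand(1000, 28, 28)

# Flatten each image to a 784-dimensional vector and rescale from [0, 1] to [-1, 1].
data = images.reshape(images.size(0), -1) * 2 - 1

# A DataLoader over a plain tensor returns tensor batches; drop_last keeps batch sizes equal.
loader = DataLoader(data, batch_size=64, shuffle=True, drop_last=True)

batch = next(iter(loader))
print(batch.shape)  # torch.Size([64, 784])
```

With the classes defined above, training would then be launched with something like `train_wgan(Critic(784), Generator(784), loader, num_epochs=10)`; the epoch count is an arbitrary illustration.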