In [2]:
def compute_pairwise_distance(x, y):
    """Return the matrix of squared Euclidean distances between rows of x and y.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, so the cross
    term is a single matrix multiplication.

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) whose (i, j) entry is the squared distance
        between x[i] and y[j], clamped to be non-negative.
    """
    # Squared norms of each row vector, shaped for broadcasting
    x_sq = (x ** 2).sum(dim=1, keepdim=True)        # (n, 1)
    y_sq = (y ** 2).sum(dim=1, keepdim=True).t()    # (1, m)

    # Pairwise squared distances using matrix multiplication for the cross term
    sq_dists = x_sq + y_sq - 2 * torch.mm(x, y.t())

    # Floating-point cancellation in the expansion can leave tiny negative
    # values (typically for identical or near-identical rows). A squared
    # distance is never negative, and kernels that raise (1 + d / c) to a
    # fractional power would produce NaN if d dropped far enough below zero.
    return sq_dists.clamp(min=0)


def rational_quadratic_kernel_matrix(x, y, alphas=None):
    """Evaluate a mixture of rational-quadratic kernels between rows of x and y.

    For pairwise squared distances d, the result is the sum over every alpha
    of (1 + d / (2 * alpha)) ** (-alpha).

    Args:
        x: first batch of row vectors.
        y: second batch of row vectors.
        alphas: iterable of mixture parameters; [0.2, 1, 5] when None.

    Returns:
        The summed kernel matrix, one entry per (row of x, row of y) pair.
    """
    # Fall back to the default mixture when the caller gives none
    mixture_alphas = [0.2, 1, 5] if alphas is None else alphas

    # All pairwise squared distances between x and y
    sq_dists = compute_pairwise_distance(x, y)

    # One rational-quadratic kernel per alpha, accumulated into a single matrix
    return sum((1 + sq_dists / (2 * a)) ** (-a) for a in mixture_alphas)


In [6]:
# Generator maps a noise vector to a sample in the data space
class Generator(nn.Module):
    """Three-layer MLP turning noise vectors into data-space samples.

    Args:
        noise_dim: size of the input noise vector.
        output_dim: size of each generated sample.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, noise_dim=2, output_dim=2, hidden_dim=16):
        super().__init__()
        layers = [
            # noise -> hidden representation
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            # hidden -> hidden, adds modeling capacity
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # hidden -> sample in the data space
            nn.Linear(hidden_dim, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run the noise batch z through the MLP and return the generated samples."""
        generated = self.model(z)
        return generated


# Critic maps an input sample to a feature embedding (used for scoring / losses)
class Critic(nn.Module):
    """Three-layer MLP embedding samples into a feature vector.

    The output is a feature_dim-sized vector per sample, not a scalar score.

    Args:
        input_dim: size of each input sample.
        feature_dim: size of the output feature vector.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, input_dim=2, feature_dim=16, hidden_dim=32):
        super().__init__()
        layers = [
            # input -> hidden representation
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # hidden -> hidden, deeper feature extraction
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # hidden -> feature embedding
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the sample batch x through the MLP and return its feature embeddings."""
        features = self.model(x)
        return features


In [3]:
def mmd_loss(real_features, fake_features, kernel=rational_quadratic_kernel_matrix, sigmas=None):
    """Biased MMD estimate between two batches of feature vectors.

    Args:
        real_features: features of the real batch (m rows).
        fake_features: features of the generated batch (n rows).
        kernel: callable invoked as kernel(a, b, sigmas) that returns a kernel
            matrix between the rows of a and b.
        sigmas: forwarded unchanged as the kernel's third positional argument
            (for the default kernel this is its `alphas` parameter).

    Returns:
        mean(K_XX) + mean(K_YY) - 2 * mean(K_XY), with diagonal terms included.
    """
    # Batch sizes
    num_real = real_features.size(0)
    num_fake = fake_features.size(0)

    # Real-real, fake-fake and real-fake similarities
    k_real_real = kernel(real_features, real_features, sigmas)
    k_fake_fake = kernel(fake_features, fake_features, sigmas)
    k_real_fake = kernel(real_features, fake_features, sigmas)

    # Normalised sums; the diagonals are kept, which makes the estimate biased
    within_real = k_real_real.sum() / (num_real * num_real)
    within_fake = k_fake_fake.sum() / (num_fake * num_fake)
    cross = 2 * k_real_fake.sum() / (num_real * num_fake)

    return within_real + within_fake - cross


def gradient_penalty(critic, real_data, fake_data, lambda_gp=1.0):
    """Gradient penalty evaluated on random interpolations of real and fake samples.

    Each interpolated sample is eps * real + (1 - eps) * fake with one uniform
    eps per sample. The critic output is summed over all of its elements, the
    gradient of that sum with respect to the interpolated inputs is taken, and
    the penalty is lambda_gp * mean((per-sample gradient L2 norm - 1) ** 2).

    Args:
        critic: callable mapping a batch of samples to a tensor of outputs.
        real_data: tensor of real samples, batch on dimension 0.
        fake_data: tensor of generated samples, broadcast-compatible with real_data.
        lambda_gp: multiplier applied to the penalty.

    Returns:
        Scalar tensor holding the penalty; it remains differentiable with
        respect to the critic parameters.
    """
    # Number of samples in the batch
    batch_size = real_data.size(0)

    # One mixing weight per sample, shaped (batch, 1, ..., 1) so it broadcasts
    # over every non-batch dimension whatever the data rank is. It is created
    # on the data's own device rather than on a notebook-level `device` global.
    eps_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real_data.device)

    # Interpolated samples where the gradient constraint is enforced
    interpolated = epsilon * real_data + (1 - epsilon) * fake_data
    interpolated.requires_grad_(True)

    # Sum outputs so autograd gives gradients for the whole batch in one call
    crit_sum = critic(interpolated).sum()

    # create_graph=True keeps the gradient itself differentiable so the penalty
    # can be backpropagated into the critic (it also retains the graph).
    gradients = torch.autograd.grad(
        outputs=crit_sum,
        inputs=interpolated,
        create_graph=True,
    )[0]

    # Flatten per sample (reshape also handles non-contiguous gradients),
    # then take the per-sample L2 norm
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)

    # Penalty term encouraging the gradient norm to stay close to 1
    return lambda_gp * ((grad_norm - 1) ** 2).mean()


In [9]:
# Initialize the two networks used by the MMD GAN.
# noise_dim is kept as a notebook-level variable because later cells pass it
# to sample_noise() when drawing generator inputs.
noise_dim = 2
# NOTE(review): `device` is not defined in the cells shown here — it must be set
# by an earlier cell before this one runs.
G = Generator(noise_dim=noise_dim).to(device)
# Critic() uses its default input_dim=2, matching Generator's default output_dim=2.
C = Critic().to(device)

In [None]:
# Draw a batch of generator outputs; no gradients are needed for plotting
with torch.no_grad():
    z = sample_noise(500, noise_dim)
    gen_samples = G(z).cpu().numpy()

# Overlay the generated points on the real data in a single axes
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(X_real[:, 0].cpu(), X_real[:, 1].cpu(), alpha=0.5, label="Real Data")
ax.scatter(gen_samples[:, 0], gen_samples[:, 1], alpha=0.5, c='red', label="Generated Data")
ax.set_title("Real vs Generated Samples (MMD GAN)")
ax.legend()
ax.axis('off')
plt.show()


In [None]:
import shutil
from IPython.display import FileLink

# Pack everything under samples/2d into 2d.zip in the working directory
shutil.make_archive(base_name='2d', format='zip', root_dir='samples/2d')

# Last expression of the cell: rendered as a link to the archive
FileLink("2d.zip")

In [None]:
# Trigger a browser download of the archive when running on Google Colab.
# The google.colab package is only importable on Colab runtimes, so skip the
# download elsewhere instead of aborting a Run All with ModuleNotFoundError.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; skipping browser download of 2d.zip")
else:
    files.download('2d.zip')