In [None]:
import math, os, time, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def seed_everything(seed=42):
    """Seed the RNGs used in this notebook so repeated runs draw the same numbers.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator and
    the generators of all CUDA devices (a no-op when CUDA is unavailable).
    Note: this does not switch cuDNN to deterministic kernels.
    """
    import random
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


seed_everything(123)

In [None]:
def generate_data(data: str, batch_size: int = 200, device: str = "cpu") -> torch.Tensor:
    """
    Generate synthetic 2D datasets without rewards.

    Parameters
    ----------
    data : {"rings", "8gaussians", "2spirals", "checkerboard"}
        Name of the toy distribution to sample from.
    batch_size : int
        Number of points to return (odd sizes are supported for every dataset).
    device : str
        Device on which the samples are created.

    Returns
    -------
    X : torch.FloatTensor of shape (batch_size, 2)
        Samples in random order.

    Raises
    ------
    ValueError
        If ``data`` is not one of the supported dataset names.
    """
    def torch_linspace_exclusive(start, stop, steps, device="cpu"):
        # `steps` evenly spaced points on [start, stop): the endpoint is dropped so
        # a full circle does not repeat its first angle.
        return torch.linspace(start, stop, steps + 1, device=device)[:-1]

    def shuffled(X):
        # Random row permutation so the generation order (ring by ring, spiral
        # arm by arm) does not leak into the returned batch.
        perm = torch.randperm(X.size(0), device=device)
        return X[perm].float()

    if data == "rings":
        # Four concentric rings with radii 1.0, 0.75, 0.5, 0.25 (before the x3
        # scale). The three outer rings get batch_size // 4 points each and the
        # innermost ring absorbs the remainder.
        n_outer = batch_size // 4
        counts = (n_outer, n_outer, n_outer, batch_size - 3 * n_outer)
        radii = (1.00, 0.75, 0.50, 0.25)
        rings = []
        for n, r in zip(counts, radii):
            angle = torch_linspace_exclusive(0, 2 * np.pi, n, device=device)
            rings.append(torch.stack([torch.cos(angle), torch.sin(angle)], dim=1) * r)

        X = torch.cat(rings, dim=0) * 3.0
        X = X + torch.randn_like(X) * 0.08  # small Gaussian noise
        return shuffled(X)

    elif data == "8gaussians":
        scale = 4.0
        centers = torch.tensor([
            [0, 1],
            [-1/np.sqrt(2),  1/np.sqrt(2)],
            [-1, 0],
            [-1/np.sqrt(2), -1/np.sqrt(2)],
            [0, -1],
            [ 1/np.sqrt(2), -1/np.sqrt(2)],
            [1, 0],
            [ 1/np.sqrt(2),  1/np.sqrt(2)]
        ], dtype=torch.float32, device=device) * scale

        idx = torch.randint(0, 8, (batch_size,), device=device)
        X = torch.randn(batch_size, 2, device=device) * 0.5
        X = (X + centers[idx]) / 1.414  # ~sqrt(2)
        return shuffled(X)

    elif data == "2spirals":
        # ceil(batch_size / 2) points per arm, so an odd batch_size still yields
        # batch_size samples once the one surplus point is trimmed below.
        half = (batch_size + 1) // 2
        n = torch.sqrt(torch.rand(half, 1, device=device)) * (3 * np.pi)

        d1x = -torch.cos(n) * n + torch.rand(half, 1, device=device) * 0.5
        d1y =  torch.sin(n) * n + torch.rand(half, 1, device=device) * 0.5
        spiral1 = torch.cat([d1x, d1y], dim=1)

        spiral2 = -spiral1  # second arm is the point reflection of the first
        X = torch.cat([spiral1, spiral2], dim=0) / 3.0
        X = X + torch.randn_like(X) * 0.1

        # Trim after shuffling so the dropped point is random, not always arm 2's last.
        return shuffled(X)[:batch_size]

    elif data == "checkerboard":
        # x1 ~ Uniform([-2, 2])
        x1 = torch.rand(batch_size, device=device) * 4 - 2
        # x2 in [0, 1) shifted down by 0 or 2, then up by 1 when floor(x1) is odd,
        # which places points only on alternating squares.
        x2_offset = torch.rand(batch_size, device=device) - (
            torch.randint(0, 2, (batch_size,), device=device, dtype=torch.float32) * 2
        )
        x2 = x2_offset + (torch.floor(x1) % 2)

        X = torch.stack([x1, x2], dim=1) * 2
        return shuffled(X)

    else:
        raise ValueError(f"Unknown dataset type: {data}")

In [None]:
# --- Choose dataset & dataloader ---
data_type = "checkerboard"   # "rings" | "8gaussians" | "2spirals" | "checkerboard"
batch_size = 1024
dataset_size = 50000

# Sample one fixed training set on the CPU; the DataLoader reshuffles it every epoch.
X_real = generate_data(data_type, batch_size=dataset_size, device="cpu")
ds = TensorDataset(X_real)  # single tensor of shape (N, 2)
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=device.type == "cuda",  # pinned host memory only helps GPU transfers
)

# Quick look at (up to) the first 5000 real points.
sub = X_real[:5000].numpy()
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(sub[:, 0], sub[:, 1], s=2, alpha=0.6)
ax.set_aspect('equal', 'box')
ax.set_xlim(-4.3, 4.3)
ax.set_ylim(-4.3, 4.3)
ax.set_title(f"Real samples • {data_type}")
plt.show()

In [None]:
class MLPGenerator(nn.Module):
    """Fully connected generator: z -> [Linear, ReLU] x len(hidden) -> Linear.

    All layers live in one ``nn.Sequential`` called ``net``; linear layers sit at
    the even indices, so state-dict keys read ``net.0.weight``, ``net.2.weight``, ...
    """

    def __init__(self, z_dim=16, hidden=(256, 256, 256), out_dim=2):
        super().__init__()
        widths = (z_dim, *hidden)
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.ReLU(inplace=True)])
        # Final projection has no activation: outputs are unbounded coordinates.
        blocks.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*blocks)

    def forward(self, z):
        """Map latent vectors (last dim ``z_dim``) to samples (last dim ``out_dim``)."""
        return self.net(z)

class MLPDiscriminator(nn.Module):
    """Fully connected critic: x -> [Linear, LeakyReLU(0.2)] x len(hidden) -> Linear(1).

    Returns raw logits (no sigmoid) with the trailing singleton dim squeezed,
    so a (B, in_dim) input gives a (B,) output.
    """

    def __init__(self, in_dim=2, hidden=(256, 256, 256)):
        super().__init__()
        widths = (in_dim, *hidden)
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)])
        blocks.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Return one logit per input row (last dimension removed)."""
        logits = self.net(x)
        return logits.squeeze(-1)

def weights_init(m):
    """Kaiming-normal init (LeakyReLU slope 0.2) for ``nn.Linear`` weights, zero biases.

    Intended for ``module.apply(weights_init)``; modules that are not
    ``nn.Linear`` are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight, a=0.2, nonlinearity='leaky_relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)

z_dim = 16  # latent size shared by G and every sampling call below

# Construct G before D, then initialise G before D: this keeps the RNG draw order fixed.
G = MLPGenerator(z_dim=z_dim).to(device)
D = MLPDiscriminator().to(device)
for _model in (G, D):
    _model.apply(weights_init)


def sum_params(m):
    """Total count of scalar parameters in module ``m``."""
    return sum(p.numel() for p in m.parameters())


print(f"G params: {sum_params(G):,} | D params: {sum_params(D):,}")

In [None]:
lr = 2e-4
beta1, beta2 = 0.5, 0.999
epochs = 500
n_critic = 1        # discriminator updates per generator update (must be >= 1)
viz_every = 20      # snapshot cadence in epochs (epoch 1 is always shown)
viz_lim = 5.0       # half-width of the square plotting window
grid_n = 200        # resolution of the discriminator probability grid

# d_loss is only bound inside the critic loop; with n_critic < 1 the logging below
# would die with a confusing NameError, so fail fast with a clear message instead.
if n_critic < 1:
    raise ValueError(f"n_critic must be >= 1, got {n_critic}")

opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))
bce_logits = nn.BCEWithLogitsLoss()


def sample_noise(n, z_dim, device):
    """Return an (n, z_dim) batch of standard-normal latent vectors on `device`."""
    return torch.randn(n, z_dim, device=device)


# One entry per generator step; d_losses holds the loss of the last critic step.
g_losses, d_losses = [], []

# Fixed latent batch so snapshots from different epochs are directly comparable.
viz_noise = sample_noise(5000, z_dim, device)

# Loop-invariant plotting inputs: built once instead of at every snapshot.
real_viz = X_real[:5000].numpy()
grid_xs = np.linspace(-viz_lim, viz_lim, grid_n)
grid_ys = np.linspace(-viz_lim, viz_lim, grid_n)
grid_xx, grid_yy = np.meshgrid(grid_xs, grid_ys)
grid_t = torch.from_numpy(
    np.stack([grid_xx.ravel(), grid_yy.ravel()], axis=1).astype(np.float32)
).to(device)


def _plot_progress(epoch):
    """Plot real data, D's P(real) surface, and G's samples side by side."""
    G.eval()
    with torch.no_grad():
        fake_viz = G(viz_noise).cpu().numpy()
        # meshgrid + ravel orders grid rows y-major, so view(grid_n, grid_n) gives
        # Z[y, x] as contourf expects.
        probs = torch.sigmoid(D(grid_t)).view(grid_n, grid_n).cpu().numpy()

    fig, (ax_real, ax_disc, ax_fake) = plt.subplots(1, 3, figsize=(15, 4))
    ax_real.scatter(real_viz[:, 0], real_viz[:, 1], s=2, alpha=0.6)
    ax_real.set_title(f"Real • {data_type}")

    im = ax_disc.contourf(grid_xs, grid_ys, probs, levels=20, alpha=0.9)
    fig.colorbar(im, ax=ax_disc, fraction=0.046, pad=0.04)
    ax_disc.set_title("Discriminator Contour")

    ax_fake.scatter(fake_viz[:, 0], fake_viz[:, 1], s=2, alpha=0.6)
    ax_fake.set_title(f"GAN samples @ epoch {epoch}")

    for ax in (ax_real, ax_disc, ax_fake):
        ax.set_xlim(-viz_lim, viz_lim)
        ax.set_ylim(-viz_lim, viz_lim)
        ax.set_aspect('equal', 'box')
    fig.tight_layout()
    plt.show()


for epoch in range(1, epochs + 1):
    G.train(); D.train()
    pbar = tqdm(dl, desc=f"[Epoch {epoch}/{epochs}]", leave=False)
    for (real_batch,) in pbar:
        real = real_batch.to(device, non_blocking=True)     # (B,2)
        B = real.size(0)

        # ===== 1) Update Discriminator =====
        for _ in range(n_critic):
            z = sample_noise(B, z_dim, device)
            with torch.no_grad():
                fake_detached = G(z)                        # no graph back into G

            D_real = D(real)                                # logits
            D_fake = D(fake_detached)

            d_loss = bce_logits(D_real, torch.ones_like(D_real)) + \
                     bce_logits(D_fake, torch.zeros_like(D_fake))

            opt_D.zero_grad(set_to_none=True)
            d_loss.backward()
            opt_D.step()

        # ===== 2) Update Generator (non-saturating) =====
        z = sample_noise(B, z_dim, device)
        fake = G(z)
        D_fake_for_G = D(fake)
        g_loss = bce_logits(D_fake_for_G, torch.ones_like(D_fake_for_G))

        # This backward also fills D's .grad, but opt_D.zero_grad(set_to_none=True)
        # clears it before D's next backward, so D's update is unaffected.
        opt_G.zero_grad(set_to_none=True)
        g_loss.backward()
        opt_G.step()

        # Read each loss once: every .item() forces a host/device sync.
        d_val, g_val = d_loss.item(), g_loss.item()
        g_losses.append(g_val); d_losses.append(d_val)
        pbar.set_postfix(d=f"{d_val:.3f}", g=f"{g_val:.3f}")

    if epoch == 1 or epoch % viz_every == 0:
        _plot_progress(epoch)


fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(d_losses, label="D"); ax.plot(g_losses, label="G")
ax.legend(); ax.set_title("Training losses"); fig.tight_layout(); plt.show()

In [None]:
# G.load_state_dict(torch.load(f"./weight/ganG_{data_type}.pth", map_location=device)); G.eval()

@torch.no_grad()
def gan_sample(n=5000, z_dim=16, device=device, generator=None):
    """
    Draw ``n`` samples from a generator and return them as a NumPy array.

    Parameters
    ----------
    n : int
        Number of samples to draw.
    z_dim : int
        Latent dimensionality; must match the generator's input width.
    device : torch.device or str
        Device on which the latent noise is drawn (should match the generator's).
    generator : nn.Module, optional
        Model that maps (n, z_dim) noise to samples. Defaults to the notebook's
        global ``G``, which is looked up at call time.

    Returns
    -------
    np.ndarray of shape (n, out_dim)
    """
    if generator is None:
        generator = G
    was_training = generator.training
    generator.eval()
    try:
        z = torch.randn(n, z_dim, device=device)
        # Already under no_grad, so no .detach() is needed before .cpu().
        return generator(z).cpu().numpy()
    finally:
        # Restore the caller's mode so sampling mid-training does not silently
        # leave the model in eval mode.
        generator.train(was_training)

fake = gan_sample(n=10000, z_dim=z_dim, device=device)

# Scatter of the trained generator's output on the same +/-5 window used during training.
fig, ax = plt.subplots(figsize=(4.5, 4.5))
ax.scatter(fake[:, 0], fake[:, 1], s=2, alpha=0.6)
ax.set_aspect('equal', 'box')
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_title("Samples from trained GAN")
plt.show()
