In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pennylane as qml
import numpy as np
import matplotlib.pyplot as plt

### Model Definitions

In [None]:
class Discriminator(nn.Module):
    """Fully connected binary classifier used as the GAN discriminator.

    Two ``Linear`` + ``ReLU`` stages of width ``hid_size`` feed a single-unit
    ``Linear`` head whose output is squashed by a sigmoid.
    """

    def __init__(self, input_size, hid_size) -> None:
        super().__init__()
        # Modules are created in input-to-output order so the Sequential
        # indices (0, 2, 4 for the Linear layers) stay the same.
        widths = (input_size, hid_size, hid_size)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stack.extend((nn.Linear(hid_size, 1), nn.Sigmoid()))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid output."""
        scores = self.layers(x)
        return scores
    
class Generator(nn.Module):
    """Fully connected network mapping a noise vector to an output sample.

    Two ``Linear`` + ``ReLU`` stages of width ``hid_size`` are followed by a
    final ``Linear`` projection to ``out_size`` with no output activation.
    """

    def __init__(self, input_size, hid_size, out_size) -> None:
        super().__init__()
        dims = [input_size, hid_size, hid_size, out_size]
        modules = []
        for idx in range(len(dims) - 1):
            modules.append(nn.Linear(dims[idx], dims[idx + 1]))
            modules.append(nn.ReLU())
        modules.pop()  # the final projection is left without an activation
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map the input batch ``x`` through the layer stack."""
        generated = self.layers(x)
        return generated

### Dataset Generation

In [None]:
N_points = 2000 # From paper

# Fixed NumPy seed so the sampled training set is identical on every run.
np.random.seed(0)

# Each coordinate is drawn independently and uniformly from [0.25, 0.75],
# which fills the central square of the unit square.
x = np.random.uniform(low=0.25, high=0.75, size=N_points)
y = np.random.uniform(low=0.25, high=0.75, size=N_points)

# Plot the data
fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(x, y, alpha=1, s=10)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel('x')
ax.set_ylabel('y')
# Title follows N_points instead of repeating the number by hand.
ax.set_title(f"Central Square Distribution ({N_points} points)")
ax.set_aspect('equal', adjustable='box')  # Make the plot square
# No legend call: the single scatter series has no label, so a legend would
# draw nothing and only emit a "no artists with labels" warning.
plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader

class XYDistribution(Dataset):
    """Dataset that serves 2-D points built from separate x and y sequences."""

    def __init__(self, x, y):
        # Both coordinate sequences are copied into float32 tensors.
        self.x, self.y = (
            torch.tensor(coords, dtype=torch.float32) for coords in (x, y)
        )

    def __len__(self):
        """Number of samples, taken from the x coordinates."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as one tensor stacking its x and y values."""
        x_val = self.x[idx]
        y_val = self.y[idx]
        return torch.stack((x_val, y_val))

### Training

In [None]:
# Seed torch's RNG before any module is built so that weight initialisation,
# fixed_noise and every noise batch drawn later are reproducible after a
# kernel restart.
torch.manual_seed(0)

noise_size = 4  # length of the latent vector fed to the generator

# The discriminator scores 2-D points; the generator maps noise to 2-D points.
net_D = Discriminator(2, 4)
net_G = Generator(noise_size, 4, 2)

batch_size = 40

dataset = XYDistribution(x, y)
dataloader = DataLoader(dataset, batch_size=batch_size)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
# (latent noise is Gaussian with mean 0.5 and standard deviation 2)
fixed_noise = 0.5 + torch.randn(batch_size, noise_size)*2

real_label = 1.
fake_label = 0.
epochs = 1000
lr = 1e-4

# Separate Adam optimizers so each network is stepped independently.
optimizerD = optim.Adam(net_D.parameters(), lr=lr)
optimizerG = optim.Adam(net_G.parameters(), lr=lr)
    

In [None]:
# Per-batch loss histories and per-epoch generator snapshots on fixed_noise.
d_error = []
g_error = []
fk_prog = []

for _ in range(epochs):
    for batch in dataloader:
        # Size everything from the batch actually received, so a short final
        # batch does not cause a shape mismatch inside BCELoss.
        n_samples = batch.size(0)

        # First Update Discriminator with batch of Real Data
        net_D.zero_grad()
        labels_real = torch.ones((n_samples, 1))
        outputs = net_D(batch)
        loss_d_real = criterion(outputs, labels_real)
        loss_d_real.backward()

        # Update with "Fake" Generator Data
        rand_inps = 0.5 + torch.randn((n_samples, noise_size)) * 2
        gen_outp = net_G(rand_inps)
        labels_fk = torch.zeros((n_samples, 1))
        # detach(): the discriminator step must not backpropagate into the
        # generator; those gradients would be discarded anyway.
        outputs = net_D(gen_outp.detach())
        loss_d_fake = criterion(outputs, labels_fk)
        loss_d_fake.backward()

        loss_d = loss_d_fake + loss_d_real
        optimizerD.step()

        # Update the Generator Network to Maximize Discriminator Error
        net_G.zero_grad()
        # Do D forward pass again on newly updated network. The generator's
        # weights have not changed since gen_outp was computed, so the same
        # (non-detached) fakes are reused instead of a second forward pass.
        outputs = net_D(gen_outp)
        # The generator is trained against "real" targets.
        loss_g = criterion(outputs, labels_real)
        loss_g.backward()

        optimizerG.step()

        d_error.append(loss_d.item())
        g_error.append(loss_g.item())
    print("Discriminator Loss: ", loss_d.item(), "Generator Loss", loss_g.item())

    # Snapshot the generator on the fixed latent batch, without tracking grads.
    with torch.no_grad():
        fake = net_G(fixed_noise)
        fk_prog.append(fake)

In [None]:
# Sample the trained generator on fresh latent noise (mean 0.5, std 2).
# Inference only, so skip autograd bookkeeping for the 50000-sample batch.
with torch.no_grad():
    gen_output = net_G(0.5+torch.randn((50000, noise_size))*2)

# Column 0 holds the x coordinates and column 1 the y coordinates.
x_test = gen_output[:, 0].numpy()
y_test = gen_output[:, 1].numpy()


# Plot test data
def plot_xy(x_pts, y_pts, title="Test Generator Output", label=None):
    """Scatter-plot 2-D points on the unit square.

    Args:
        x_pts: x coordinates of the points to draw.
        y_pts: y coordinates of the points to draw.
        title: Figure title.
        label: Optional legend entry for the scatter series. A legend is
            drawn only when a label is given.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    # Plot the arguments, not the notebook-level x_test / y_test globals.
    ax.scatter(x_pts, y_pts, alpha=1, s=10, label=label)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')  # Make the plot square
    # An unlabeled series would make legend() warn and draw nothing.
    if label is not None:
        ax.legend()
    plt.show()

plot_xy(x_test, y_test)