# Generative Adversarial Networks (GANs) for Tabular Data Generation with PyTorch

In [1]:
import torch
import torch.nn as nn

In [2]:
# Choose where tensors and models live: the GPU when PyTorch can see one,
# otherwise the CPU. Later cells move data and models here with .to(device).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Seed PyTorch's global RNG so the torch.rand / torch.randn draws and the
# model weight initialisation in later cells are reproducible across runs.
SEED = 0
# Trailing ';' suppresses the returned torch.Generator repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fb85c227410>

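## Create Synthetic Data

Each row has three columns with an exact linear dependency: the first is drawn uniformly from $[0, 1)$, the second is twice the first, and the third is twice the second. The generator has to learn this relationship.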

In [4]:
# Synthetic 3-column table with an exact linear chain between columns:
#   col1 ~ U[0, 1),  col2 = 2 * col1,  col3 = 2 * col2  (i.e. 4 * col1)
n = 1000  # number of rows

# Base column: sampled with the CPU RNG, then moved to the selected device.
first_column = torch.rand(n, 1).to(device)

# Derived columns, each double the one before it.
second_column = first_column * 2
third_column = second_column * 2

# Put the three (n, 1) columns side by side -> an (n, 3) table.
data = torch.cat([first_column, second_column, third_column], dim=1)

## Create Models

In [5]:
class Generator(nn.Module):
    """Map random noise vectors to synthetic table rows.

    A one-hidden-layer MLP: Linear -> ReLU -> Linear. The output layer has no
    activation, so generated values are unbounded real numbers.

    Args:
        noise_dim: Width of each input noise vector (default 3).
        hidden_dim: Number of units in the hidden layer (default 50).
        data_dim: Number of columns in a generated row (default 3).

    The defaults reproduce the original 3 -> 50 -> 3 network exactly: the same
    layers are created in the same order (so seeded weight initialisation is
    unchanged), and the parameters keep their ``model.*`` names in
    ``state_dict()``.
    """

    def __init__(self, noise_dim: int = 3, hidden_dim: int = 50, data_dim: int = 3):
        super().__init__()
        self.noise_dim = noise_dim
        self.data_dim = data_dim
        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x):
        """Turn noise of shape (..., noise_dim) into rows of shape (..., data_dim)."""
        return self.model(x)

In [6]:
class Discriminator(nn.Module):
    """Score table rows with the probability that they are real.

    A one-hidden-layer MLP: Linear -> ReLU -> Linear -> Sigmoid. The output is
    already a probability in (0, 1). It is therefore meant for a
    probability-based loss such as ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``.

    Args:
        data_dim: Number of columns in an input row (default 3).
        hidden_dim: Number of units in the hidden layer (default 50).

    The defaults reproduce the original 3 -> 50 -> 1 network exactly: the same
    layers are created in the same order, and the parameters keep their
    ``model.*`` names in ``state_dict()``.
    """

    def __init__(self, data_dim: int = 3, hidden_dim: int = 50):
        super().__init__()
        self.data_dim = data_dim
        self.model = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map rows of shape (..., data_dim) to probabilities of shape (..., 1)."""
        return self.model(x)

## Training Loop

In [7]:
# Initialize the models and move them to the device. The generator is built
# first, as before, so seeded weight initialisation is reproducible.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid outputs. Each network
# gets its own Adam optimizer, sharing a single learning-rate setting.
LEARNING_RATE = 1e-3
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

In [8]:
# Training the GAN: full-batch alternating updates. Each epoch does one
# discriminator step followed by one generator step.
num_epochs = 5000
log_every = 1000

# The whole table is used as a single batch, so the real/fake targets never
# change. Build them once, directly on the device, instead of every epoch.
real_data = data
batch_size = real_data.size(0)
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)

for epoch in range(num_epochs):
    # --- Train discriminator: push D(real) towards 1 and D(fake) towards 0 ---
    optimizer_d.zero_grad()

    outputs = discriminator(real_data)
    d_loss_real = criterion(outputs, real_labels)

    # Generate fake data. Noise is drawn with the CPU RNG and then moved, as
    # before, so seeded runs draw the same numbers regardless of device.
    # The width 3 must match the Generator's input layer.
    noise = torch.randn(batch_size, 3).to(device)
    fake_data = generator(noise)
    # detach(): the discriminator loss must not backpropagate into the generator.
    outputs = discriminator(fake_data.detach())
    d_loss_fake = criterion(outputs, fake_labels)

    # Backprop and optimize
    d_loss = d_loss_real + d_loss_fake
    d_loss.backward()
    optimizer_d.step()

    # --- Train generator: reward fakes that the just-updated D scores as real ---
    optimizer_g.zero_grad()
    outputs = discriminator(fake_data)
    g_loss = criterion(outputs, real_labels)
    # This backward also fills D's .grad. Those grads are cleared by
    # optimizer_d.zero_grad() at the start of the next epoch.
    g_loss.backward()
    optimizer_g.step()

    # Print losses
    if (epoch + 1) % log_every == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

Epoch [1000/5000], d_loss: 1.3752, g_loss: 0.6964
Epoch [2000/5000], d_loss: 1.3858, g_loss: 0.6953
Epoch [3000/5000], d_loss: 1.3306, g_loss: 0.7926
Epoch [4000/5000], d_loss: 1.3851, g_loss: 0.6995
Epoch [5000/5000], d_loss: 1.3839, g_loss: 0.6937


## Inference

In [9]:
# After training, generate some synthetic data. no_grad(): only forward
# passes are needed, so no autograd graph is built.
num_preview_rows = 10

with torch.no_grad():
    # Same noise shape and sampling as in training.
    test_noise = torch.randn(n, 3).to(device)
    generated_data = generator(test_noise).cpu().numpy()

# Never index past the end if fewer rows were generated than we want to show.
num_preview_rows = min(num_preview_rows, len(generated_data))

# Print the first rows of generated data
print(f"Generated Data (First {num_preview_rows} rows):")
for row in generated_data[:num_preview_rows]:
    print(row)

# To validate if relationships hold (second = 2 * first, third = 2 * second):
print(f"\nValidation (For the first {num_preview_rows} rows):")
for first, second, third in generated_data[:num_preview_rows]:
    print(f"First: {first:.4f}, Expected Second: {2*first:.4f}, Actual Second: {second:.4f}")
    print(f"Second: {second:.4f}, Expected Third: {2*second:.4f}, Actual Third: {third:.4f}\n")

# The preview covers only a handful of rows. Summarise how well each learned
# relationship holds across every generated sample.
abs_err_second = abs(generated_data[:, 1] - 2 * generated_data[:, 0])
abs_err_third = abs(generated_data[:, 2] - 2 * generated_data[:, 1])
print(f"Mean |Second - 2*First| over {len(generated_data)} rows: {abs_err_second.mean():.4f}")
print(f"Mean |Third - 2*Second| over {len(generated_data)} rows: {abs_err_third.mean():.4f}")

Generated Data (First 10 rows):
[0.85301065 1.6151371  3.277037  ]
[0.28591636 0.5468786  1.1151662 ]
[0.19395384 0.36494565 0.7560951 ]
[0.5969087 1.133051  2.3002656]
[0.71670544 1.360366   2.765015  ]
[0.8928284 1.670333  3.3818266]
[0.4100025 0.7811587 1.5839252]
[0.17398658 0.32405397 0.676839  ]
[0.06351195 0.11708003 0.25542825]
[0.43641403 0.82894945 1.7054884 ]

Validation (For the first 10 rows):
First: 0.8530, Expected Second: 1.7060, Actual Second: 1.6151
Second: 1.6151, Expected Third: 3.2303, Actual Third: 3.2770

First: 0.2859, Expected Second: 0.5718, Actual Second: 0.5469
Second: 0.5469, Expected Third: 1.0938, Actual Third: 1.1152

First: 0.1940, Expected Second: 0.3879, Actual Second: 0.3649
Second: 0.3649, Expected Third: 0.7299, Actual Third: 0.7561

First: 0.5969, Expected Second: 1.1938, Actual Second: 1.1331
Second: 1.1331, Expected Third: 2.2661, Actual Third: 2.3003

First: 0.7167, Expected Second: 1.4334, Actual Second: 1.3604
Second: 1.3604, Expected Third: 