In [12]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from categoricalGAN import Generator, Discriminator


In [13]:
# Load the Adult Income dataset
data = fetch_openml(name='adult', version=2, as_frame=True)
df = data.frame

X = df.drop(columns=['class'])
y = df['class']

In [3]:
#  categorical and numeric columns
cat_columns = X.select_dtypes(include='category').columns
num_columns = X.select_dtypes(include='number').columns


In [4]:
# One-hot encode categorical variables
ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
cat_data = ohe.fit_transform(X[cat_columns])

In [5]:
scaler = MinMaxScaler()
num_data = scaler.fit_transform(X[num_columns])

In [7]:
# Assign unique column names for categorical and numeric features
cat_columns_expanded = [f"{col}_{i}" for col, n_vals in zip(cat_columns, ohe.categories_) for i in range(len(n_vals))]
num_columns_expanded = list(num_columns)

# Convert processed data into DataFrames
cat_df = pd.DataFrame(cat_data, columns=cat_columns_expanded)
num_df = pd.DataFrame(num_data, columns=num_columns_expanded)

processed_data = torch.tensor(
    pd.concat([cat_df, num_df], axis=1).values,
    dtype=torch.float32
)


In [8]:

# GAN Training Parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 64  # Latent space dimensionality
data_dim = processed_data.shape[1]  # Number of features
batch_size = 64
epochs = 3
lr = 0.0002

# Create DataLoader
data_loader = DataLoader(TensorDataset(processed_data), batch_size=batch_size, shuffle=True)

In [14]:

# Instantiate models
generator = Generator(input_dim=latent_dim, output_dim=data_dim).to(device)
discriminator = Discriminator(input_dim=data_dim).to(device)

# Define optimizers and loss function
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)
loss_fn = nn.BCELoss()



In [15]:

# Training Loop
for epoch in range(epochs):
    for real_batch, in data_loader:
        real_batch = real_batch.to(device)
        batch_size = real_batch.size(0)

        # Train Discriminator
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_data = generator(z)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        d_real = discriminator(real_batch)
        d_fake = discriminator(fake_data.detach())

        d_loss_real = loss_fn(d_real, real_labels)
        d_loss_fake = loss_fn(d_fake, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Train Generator
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_data = generator(z)
        g_loss = loss_fn(discriminator(fake_data), real_labels)

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    # Print loss every 100 epochs
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

Epoch 0, D Loss: 0.2438, G Loss: 2.4731
Epoch 100, D Loss: 0.0000, G Loss: 19.9995


KeyboardInterrupt: 

In [17]:
# Generate synthetic data
z = torch.randn(1000, latent_dim).to(device)
synthetic_data = generator(z).detach().cpu().numpy()

# Decode the generated data
# Split into categorical and numeric
cat_synthetic = synthetic_data[:, :cat_data.shape[1]]
num_synthetic = synthetic_data[:, cat_data.shape[1]:]

# Decode categorical data
cat_synthetic_decoded = ohe.inverse_transform(cat_synthetic)
num_synthetic_decoded = scaler.inverse_transform(num_synthetic)

# Combine results into a DataFrame
synthetic_df = pd.DataFrame(num_synthetic_decoded, columns=num_columns_expanded).join(
    pd.DataFrame(cat_synthetic_decoded, columns=cat_columns)
)

synthetic_df.to_csv("synthetic_census.csv", index=False)

In [21]:

print("Synthetic Data Sample:")
synthetic_df.shape
synthetic_df

Synthetic Data Sample:


Unnamed: 0,age,fnlwgt,education-num,capital-gain,capital-loss,hours-per-week,workclass,education,marital-status,occupation,relationship,race,sex,native-country
0,57.759556,286444.96875,5.181427,88.306305,0.045228,36.817223,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
1,57.630077,288280.28125,5.320496,130.893478,0.058223,37.187393,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
2,59.936989,264911.46875,4.655873,32.699440,0.002211,36.531822,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
3,59.129276,282049.96875,4.911891,103.870430,0.005044,36.767963,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
4,57.530968,286471.84375,5.347731,181.059448,0.039062,37.393661,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,57.996986,298267.71875,5.486256,315.752808,0.056460,37.102989,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
996,57.791504,290608.96875,5.375197,162.772964,0.035775,37.192074,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
997,57.823540,287673.50000,5.331725,157.335922,0.034766,37.586899,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
998,59.026642,271351.84375,4.931422,79.738762,0.009834,37.059826,Private,9th,Divorced,Other-service,Not-in-family,White,Female,United-States
