In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Define generator model
class Generator(nn.Module):
    """Fully connected network that proposes a value for every feature.

    A batch of rows with ``input_dim`` columns is mapped to a batch of the
    same shape. Two hidden layers of width 128 use ReLU; the final sigmoid
    keeps every output in the [0, 1] range.

    Args:
        input_dim: Number of feature columns in both the input and the output.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_dim = 128
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        # The attribute stays named ``fc`` so state_dict keys remain
        # "fc.<index>.weight" / "fc.<index>.bias".
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the batch ``x``."""
        return self.fc(x)


# Define discriminator model
class Discriminator(nn.Module):
    """Fully connected network that scores each feature of a row.

    The output has the same width as the input (``input_dim``), and the
    closing sigmoid turns every entry into a value in the [0, 1] range, so
    there is one score per feature rather than one per row.

    Args:
        input_dim: Number of feature columns in both the input and the output.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 128, 128)
        stack = []
        # Hidden part: Linear -> ReLU for each consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output head: project back to input_dim so the output matches the
        # input element for element.
        stack.append(nn.Linear(widths[-1], input_dim))
        stack.append(nn.Sigmoid())
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Return the per-feature scores for the batch ``x``."""
        return self.fc(x)


# Utility function to introduce missing values
def introduce_missing_values(data, missing_rate=0.2):
    """Zero out randomly chosen entries of ``data`` and report which were kept.

    Every entry is kept independently with probability ``1 - missing_rate``
    (one Bernoulli draw per entry from NumPy's global random state). The
    input array itself is left untouched.

    Args:
        data: Array to corrupt; must provide ``.copy()`` and ``.shape``.
        missing_rate: Probability that any single entry is dropped.

    Returns:
        A ``(data_with_missing, mask)`` pair. ``data_with_missing`` is a copy
        of ``data`` with dropped entries set to 0. ``mask`` has the same shape
        and holds 1 where the entry was kept and 0 where it was dropped.
    """
    corrupted = data.copy()
    keep_probability = 1 - missing_rate
    mask = np.random.binomial(1, keep_probability, data.shape)
    dropped = mask == 0
    corrupted[dropped] = 0
    return corrupted, mask


# Hyperparameters
input_dim = 5    # number of feature columns
batch_size = 64  # NOTE(review): not referenced by the code visible in this cell
epochs = 10000   # number of training iterations
lr = 1e-3        # learning rate shared by both Adam optimizers

# Synthetic dataset: 1000 rows of uniform values in [0, 1), then a copy in
# which roughly 20% of the entries are zeroed. mask is 1 where the entry was
# kept and 0 where it was dropped. NumPy is seeded so the data is repeatable.
np.random.seed(0)
data = np.random.rand(1000, input_dim)
data_with_missing, mask = introduce_missing_values(data, missing_rate=0.2)

# float32 tensor versions of the three arrays above
data_tensor, data_missing_tensor, mask_tensor = (
    torch.FloatTensor(array) for array in (data, data_with_missing, mask)
)

# Networks (generator is built first, then the discriminator)
G, D = Generator(input_dim), Discriminator(input_dim)

# One Adam optimizer per network, same learning rate
g_optimizer, d_optimizer = (
    optim.Adam(network.parameters(), lr=lr) for network in (G, D)
)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs
bce_loss = nn.BCELoss()

# Training loop: each iteration uses the whole dataset and performs one
# discriminator update followed by one generator update.
for epoch in range(epochs):
    # Generator proposes a value for every entry. G's weights do not change
    # until g_optimizer.step() at the end of the iteration, so this single
    # forward pass serves both the discriminator and the generator updates.
    G_sample = G(data_missing_tensor)

    # Keep the observed entries and fill the missing ones with generated values
    data_hat = data_missing_tensor * mask_tensor + G_sample * (1 - mask_tensor)

    # --- Discriminator update ---
    # detach() cuts the graph back to G. Without it D_loss.backward() would
    # also backpropagate through the generator, and those gradients are thrown
    # away by g_optimizer.zero_grad() below.
    D_prob = D(data_hat.detach())
    real_data_prob = D(data_tensor)

    # D is trained to reproduce the mask on the imputed rows and to output 1
    # everywhere on the complete rows.
    D_loss = (
        bce_loss(D_prob, mask_tensor)
        + bce_loss(real_data_prob, torch.ones_like(real_data_prob))
    )

    d_optimizer.zero_grad()
    D_loss.backward()
    d_optimizer.step()

    # --- Generator update ---
    # Score the same imputed batch with the freshly updated discriminator;
    # this time the graph back to G is kept so G receives gradients.
    D_prob = D(data_hat)

    # G is trained to make D output 1 for every entry.
    G_loss = bce_loss(D_prob, torch.ones_like(D_prob))

    # G_loss.backward() also leaves gradients on D's parameters; they are
    # cleared by d_optimizer.zero_grad() at the next iteration.
    g_optimizer.zero_grad()
    G_loss.backward()
    g_optimizer.step()

    if epoch % 1000 == 0:
        print(f'Epoch [{epoch}/{epochs}] - D Loss: {D_loss.item()}, G Loss: {G_loss.item()}')

# Impute missing values with the trained generator.
# no_grad: only the forward values are needed here, not a gradient graph.
with torch.no_grad():
    generated = G(data_missing_tensor).numpy()

# Observed entries (mask == 1) are taken from the data as-is; only the dropped
# entries (mask == 0) are replaced by the generator's output.
imputed_data = mask * data_with_missing + (1 - mask) * generated

# Show the first five rows of each version for a quick visual comparison.
for title, rows in (
    ("Original Data:", data),
    ("Data with Missing Values:", data_with_missing),
    ("Imputed Data:", imputed_data),
):
    print(title)
    print(rows[:5])

Epoch [0/10000] - D Loss: 1.3789055347442627, G Loss: 0.6722905039787292
Epoch [1000/10000] - D Loss: 0.6118221879005432, G Loss: 0.10541795939207077
Epoch [2000/10000] - D Loss: 0.6141912341117859, G Loss: 0.10770110040903091
Epoch [3000/10000] - D Loss: 0.5924171209335327, G Loss: 0.10522453486919403
Epoch [4000/10000] - D Loss: 0.5922989249229431, G Loss: 0.1070709377527237
Epoch [5000/10000] - D Loss: 0.5815091133117676, G Loss: 0.10766290128231049
Epoch [6000/10000] - D Loss: 0.5929068922996521, G Loss: 0.10830793529748917
Epoch [7000/10000] - D Loss: 0.5765272378921509, G Loss: 0.10705052316188812
Epoch [8000/10000] - D Loss: 0.579983115196228, G Loss: 0.10592798888683319
Epoch [9000/10000] - D Loss: 0.5512576103210449, G Loss: 0.10675715655088425
Original Data:
[[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ]
 [0.64589411 0.43758721 0.891773   0.96366276 0.38344152]
 [0.79172504 0.52889492 0.56804456 0.92559664 0.07103606]
 [0.0871293  0.0202184  0.83261985 0.77815675 0