In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Generate a synthetic dataset with missing values
np.random.seed(0)
original_data = np.random.randn(100, 1)
missing_mask = np.random.choice([0, 1], size=original_data.shape, p=[0.2, 0.8])
data_with_missing = original_data * missing_mask

# Define the VAE model for data imputation
latent_dim = 2

# Encoder
encoder_inputs = keras.Input(shape=(1,))
x = layers.Dense(32, activation="relu")(encoder_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = layers.Lambda(lambda args: tf.random.normal(shape=(tf.shape(args[0])[0], latent_dim)) * tf.exp(0.5 * args[1]) + args[0], name="z")([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# Decoder
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(32, activation="relu")(latent_inputs)
outputs = layers.Dense(1)(x)
decoder = keras.Model(latent_inputs, outputs, name="decoder")

# VAE
outputs = decoder(encoder(encoder_inputs)[2])
vae = keras.Model(encoder_inputs, outputs, name="vae")

# Define VAE loss function
def vae_loss(encoder_inputs, outputs, z_mean, z_log_var):
    reconstruction_loss_fn = tf.keras.losses.MeanSquaredError()
    reconstruction_loss = reconstruction_loss_fn(encoder_inputs, outputs)
    kl_loss = -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
    return reconstruction_loss + 0.1 * kl_loss  # Adjust the weight of the KL divergence term

# Custom training step
optimizer = keras.optimizers.Adam()

@tf.function
def train_step(data):
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = encoder(data)
        reconstructed = decoder(z)
        loss = vae_loss(data, reconstructed, z_mean, z_log_var)
    gradients = tape.gradient(loss, vae.trainable_variables)
    optimizer.apply_gradients(zip(gradients, vae.trainable_variables))
    return loss

# Training loop
epochs = 100
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(data_with_missing).batch(batch_size)

for epoch in range(epochs):
    for step, batch_data in enumerate(dataset):
        loss = train_step(batch_data)
    print(f"Epoch {epoch + 1}, Loss: {loss.numpy()}")

# Perform data imputation
imputed_data = vae.predict(data_with_missing)
print("Original Data\tData with Missing Values\tImputed Data")
for original, missing, imputed in zip(original_data[:10], data_with_missing[:10], imputed_data[:10]):
    print(f"{original}\t{missing}\t{imputed}")

Epoch 1, Loss: [0.7795422  0.8250585  0.77983063 0.7819766 ]
Epoch 2, Loss: [0.8331166  0.88270205 0.83352643 0.8360044 ]
Epoch 3, Loss: [0.8654649  0.9195336  0.86602896 0.8688874 ]
Epoch 4, Loss: [1.3295064 1.388175  1.3302343 1.3334875]
Epoch 5, Loss: [0.9485084  1.0117984  0.94939595 0.95305884]
Epoch 6, Loss: [1.4464006 1.5141566 1.4474492 1.4515238]
Epoch 7, Loss: [0.6750906  0.74808997 0.6763172  0.68086594]
Epoch 8, Loss: [0.70230484 0.7811089  0.70371675 0.7087934 ]
Epoch 9, Loss: [0.76284933 0.8476394  0.76444227 0.7700832 ]
Epoch 10, Loss: [0.8820605 0.9729386 0.8838269 0.8900554]
Epoch 11, Loss: [0.8268822  0.92397124 0.8288248  0.8356193 ]
Epoch 12, Loss: [0.90702647 1.0099869  0.9091257  0.91645557]
Epoch 13, Loss: [0.82602656 0.9352707  0.8282859  0.836185  ]
Epoch 14, Loss: [0.93692356 1.0520582  0.9393321  0.94777375]
Epoch 15, Loss: [0.40103376 0.52179044 0.40357444 0.41255355]
Epoch 16, Loss: [0.45780236 0.5837267  0.4604535  0.46992877]
Epoch 17, Loss: [0.43725625 0

In [None]:
# Print some results
print("Original Data:")
print(original_data[:10])
print("Data with Missing Values:")
print(data_with_missing[:10])
print("Imputed Data:")
print(imputed_data[:10])

Original Data:
[[ 1.76405235]
 [ 0.40015721]
 [ 0.97873798]
 [ 2.2408932 ]
 [ 1.86755799]
 [-0.97727788]
 [ 0.95008842]
 [-0.15135721]
 [-0.10321885]
 [ 0.4105985 ]]
Data with Missing Values:
[[ 1.76405235]
 [ 0.40015721]
 [ 0.        ]
 [ 2.2408932 ]
 [ 1.86755799]
 [-0.97727788]
 [ 0.95008842]
 [-0.15135721]
 [-0.        ]
 [ 0.4105985 ]]
Imputed Data:
[[ 1.7015392 ]
 [-0.00750341]
 [ 0.01855113]
 [ 1.7848713 ]
 [ 1.4728615 ]
 [-0.8055777 ]
 [ 0.90160435]
 [-0.3438847 ]
 [-0.26320636]
 [ 0.47405857]]
