# In [None]:
import tensorflow as tf
from keras import layers
import numpy as np
import os
import matplotlib.pyplot as plt

# Define the generator network.
# Input: a 100-dim noise vector. It passes through two ReLU hidden layers and
# is projected to 28*28*3 tanh units, which are reshaped into a 28x28 image
# with 3 channels. Every value lies in (-1, 1).
generator = tf.keras.Sequential()
generator.add(layers.Dense(256, activation='relu', input_shape=(100,)))
generator.add(layers.Dense(512, activation='relu'))
generator.add(layers.Dense(28 * 28 * 3, activation='tanh'))
generator.add(layers.Reshape((28, 28, 3)))

# Define the discriminator network.
# Input: a 28x28x3 image. It is flattened and passed through two ReLU hidden
# layers (512 then 256 units). A single sigmoid unit is the output; this
# script trains it with 1 for real images and 0 for generated ones.
discriminator = tf.keras.Sequential(
    [layers.Flatten(input_shape=(28, 28, 3))]
    + [layers.Dense(units, activation='relu') for units in (512, 256)]
    + [layers.Dense(1, activation='sigmoid')]
)

# Define the GAN model: noise -> generator -> discriminator -> P(real).
# This stacked model exists only to train the generator; see the compile
# notes below.
gan_input = tf.keras.Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(gan_input, gan_output)

# Compile the discriminator while it is still trainable, so that
# discriminator.train_on_batch updates its own weights.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# Compile the GAN with the discriminator frozen. The generator step trains
# `gan` on fake images labelled "real". If the discriminator's weights were
# trainable here, that step would also teach the discriminator to accept
# fakes, undoing its own training. tf.keras keeps the `trainable` flags that
# were in effect at compile() time. So `gan` keeps the discriminator frozen,
# while the discriminator (compiled above) still trains on its own.
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Restore the flag. On Keras versions that read `trainable` at train time
# rather than at compile time, the discriminator's own training step still
# needs to see trainable weights.
# NOTE(review): on such versions the flag must also be False during each
# generator step (gan.train_on_batch). Confirm for the Keras version in use.
discriminator.trainable = True

# Load the real images as an unlabeled tf.data dataset of image batches.
data_dir = '/dataset_path'  # placeholder: set to the real image directory
batch_size = 32
img_size = (28, 28)  # must match the generator output / discriminator input
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    batch_size=batch_size,
    image_size=img_size,
    color_mode='rgb',  # 3 channels, matching the (28, 28, 3) models
    label_mode=None,  # images only; the training loop supplies real/fake labels
)

# Preprocess the data: scale pixels from [0, 255] to [-1, 1] so real images
# share the range of the generator's tanh output. With [0, 1] inputs, the
# discriminator could tell real from fake by pixel range alone: any negative
# pixel would have to be fake.
dataset = dataset.map(lambda x: x / 127.5 - 1.0)

# Train the GAN. Each step has two parts:
#   1. Update the discriminator on generated images (label 0) plus real
#      images (label 1).
#   2. Update the generator through `gan` by labelling fresh fakes as real.
for epoch in range(100):
    print('Epoch:', epoch)
    # Reset every epoch. If the dataset yields no batches, the prints below
    # show None instead of raising NameError or repeating stale values.
    discriminator_loss = None
    gan_loss = None
    for batch in dataset:
        # The last batch can be smaller than batch_size, so take the size
        # from the data.
        n_images = batch.shape[0]

        # Train the discriminator
        noise = tf.random.normal((n_images, 100))
        generated_images = generator(noise)
        real_images = batch
        combined_images = tf.concat([generated_images, real_images], axis=0)
        labels = tf.concat(
            [tf.zeros((n_images, 1)), tf.ones((n_images, 1))], axis=0
        )
        discriminator_loss = discriminator.train_on_batch(combined_images, labels)

        # Train the generator. Freeze the discriminator for this step only,
        # so the misleading "real" labels update generator weights and not
        # the discriminator's. Then unfreeze it for the next discriminator
        # step. tf.keras (Keras 2) instead applies the flags captured at
        # gan.compile() during train_on_batch, so the toggle is harmless there.
        discriminator.trainable = False
        noise = tf.random.normal((n_images, 100))
        misleading_labels = tf.ones((n_images, 1))
        gan_loss = gan.train_on_batch(noise, misleading_labels)
        discriminator.trainable = True

    # Print the losses from the last batch of the epoch
    print('Discriminator loss:', discriminator_loss)
    print('Generator loss:', gan_loss)

    # Generate a 4x4 grid of samples for visualization
    noise = tf.random.normal((16, 100))
    generated_images = (generator(noise) + 1) / 2.0  # tanh [-1, 1] -> [0, 1]
    fig, axs = plt.subplots(4, 4)
    for ax, image in zip(axs.flat, generated_images.numpy()):
        ax.imshow(image)
        ax.axis('off')
    plt.show()
    # Release the figure; otherwise 100 epochs leave 100 figures open.
    plt.close(fig)
