In [4]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout
from tensorflow.keras.models import Sequential, load_model
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [5]:
def build_generator(latent_dim):
    """Assemble the generator network that turns a noise vector into an image.

    The input is projected by a dense layer and reshaped to a 7x7x128 feature
    map, passed through two stride-2 transposed convolutions (7 -> 14 -> 28),
    and collapsed to a single sigmoid-activated channel by a final convolution.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [
        Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim),
        Reshape((7, 7, 128)),
    ]
    # Two identical upsampling stages: transposed convolution + leaky ReLU.
    for _ in range(2):
        stack.append(Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"))
        stack.append(LeakyReLU(alpha=0.2))
    # Final convolution reduces the 128 feature maps to one output channel.
    stack.append(Conv2D(1, kernel_size=7, activation="sigmoid", padding="same"))
    return Sequential(stack)

In [6]:
def build_discriminator(img_shape):
    """Assemble the discriminator network that scores an image as real or fake.

    Two stride-2 convolution stages (each followed by leaky ReLU and dropout)
    feed a flattened feature vector into a single sigmoid output unit.

    Args:
        img_shape: Shape passed as ``input_shape`` to the first convolution.

    Returns:
        An uncompiled ``Sequential`` model with one sigmoid output.
    """
    stack = [
        # Stage 1: 64 filters.
        Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # Stage 2: 128 filters.
        Conv2D(128, kernel_size=3, strides=2, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        # Classification head.
        Flatten(),
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(stack)

In [7]:
def build_gan(generator, discriminator):
    """Chain the generator into the discriminator as one stacked model.

    Args:
        generator: Model whose output is fed to the discriminator.
        discriminator: Model that scores the generator's output.

    Returns:
        An uncompiled ``Sequential`` model: generator followed by discriminator.
    """
    return Sequential([generator, discriminator])

In [8]:
# Length of the noise vector fed to the generator, and the image shape passed
# to the discriminator as its input_shape.
latent_dim = 100
img_shape = (28, 28, 1)

# Build and compile the discriminator
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Build the generator
generator = build_generator(latent_dim)

# Keep the discriminator's parameters constant during generator training.
# The flag is flipped after the discriminator's own compile() call above and
# before the stacked model's compile() below; the statement order matters.
# NOTE(review): how `trainable` interacts with an already-compiled model
# differs between Keras versions -- confirm the discriminator still updates
# in its own train_on_batch calls on the installed version.
discriminator.trainable = False

# Build and compile the GAN
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')

In [9]:
def train_gan(gan, generator, discriminator, epochs, batch_size, latent_dim,
              save_path='generator_model.h5'):
    """Alternately train the discriminator and the generator on MNIST.

    Note that each "epoch" here is a single training iteration: one random
    batch of real images, one batch of generated images, and one generator
    update through the stacked model -- not a full pass over the dataset.

    Args:
        gan: Compiled stacked model (generator followed by discriminator).
        generator: Generator model; used for inference and saved at the end.
        discriminator: Compiled discriminator. Its ``train_on_batch`` result
            is indexed as ``[loss, accuracy]``, so it must be compiled with an
            accuracy metric.
        epochs: Number of training iterations to run.
        batch_size: Number of real (and generated) images per iteration.
        latent_dim: Length of the noise vectors fed to the generator.
        save_path: Where the generator is saved after training. Defaults to
            ``'generator_model.h5'``.
    """
    # Load the MNIST training images, scale pixels to [0, 1], and add a
    # trailing channel axis.
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    X_train = X_train / 255.0
    X_train = np.expand_dims(X_train, axis=-1)

    # Fixed label vectors: 1 for real images, 0 for generated ones.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator step ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: without it, Keras versions whose predict() defaults to
        # verbose="auto" print a progress bar on every iteration and flood
        # the notebook output.
        gen_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Element-wise mean of the [loss, accuracy] pairs from both batches.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step ---
        # Fresh noise, labelled "valid": the generator is rewarded when the
        # discriminator scores its output as real.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, valid)

        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")

    # Save the generator model after training
    generator.save(save_path)

In [None]:
# Training configuration. train_gan processes one batch per "epoch", so
# `epochs` is the number of training iterations and `batch_size` the number
# of images drawn in each one.
epochs = 5000
batch_size = 32

# Run training; train_gan saves the generator to 'generator_model.h5' once
# the loop finishes.
train_gan(gan, generator, discriminator, epochs, batch_size, latent_dim)

def load_generator(model_path):
    """Load a saved Keras model from ``model_path`` and return it.

    Args:
        model_path: Path of the saved model file, forwarded to ``load_model``.

    Returns:
        The model object returned by ``load_model``.
    """
    restored = load_model(model_path)
    return restored

def generate_shape(generator, latent_dim, shape):
    """Generate one image and draw a filled white shape on top of it.

    A single noise vector is run through ``generator``; the output is scaled
    by 255 and cast to ``uint8``, then the requested shape is drawn in place
    with OpenCV at intensity 255, centred in the image.

    Args:
        generator: Model exposing ``predict``; called with a
            ``(1, latent_dim)`` noise array.
        latent_dim: Length of the noise vector.
        shape: One of ``"circle"``, ``"square"`` or ``"triangle"``.

    Returns:
        The ``uint8`` image array (first element of the prediction batch)
        with the shape drawn on it.

    Raises:
        ValueError: If ``shape`` is not one of the supported names. Raised
            before any noise is sampled or inference is run.
    """
    # Fail fast: reject an unsupported shape before paying for inference.
    if shape not in ("circle", "square", "triangle"):
        raise ValueError("Shape not recognized. Supported shapes: 'circle', 'square', 'triangle'.")

    noise = np.random.normal(0, 1, (1, latent_dim))
    # verbose=0 keeps the per-call progress bar out of the notebook output.
    generated_image = generator.predict(noise, verbose=0)[0]
    generated_image = (generated_image * 255).astype(np.uint8)

    height, width = generated_image.shape[0], generated_image.shape[1]

    if shape == "circle":
        # Centred disc with radius a quarter of the image height.
        cv2.circle(generated_image, (width // 2, height // 2), height // 4, 255, -1)
    elif shape == "square":
        # Filled rectangle spanning the middle half of each axis.
        cv2.rectangle(
            generated_image,
            (width // 4, height // 4),
            (3 * width // 4, 3 * height // 4),
            255,
            -1,
        )
    else:  # "triangle" -- the only remaining supported shape
        # Apex at top-centre, base along the lower quarter line.
        points = np.array([
            [width // 2, height // 4],
            [width // 4, 3 * height // 4],
            [3 * width // 4, 3 * height // 4],
        ], dtype=np.int32)
        cv2.fillPoly(generated_image, [points], 255)

    return generated_image

In [15]:
def display_generated_shape(generator, latent_dim, shape):
    """Generate an image with ``shape`` drawn on it and show it in grayscale.

    Args:
        generator: Model forwarded to ``generate_shape``.
        latent_dim: Length of the noise vector, forwarded to ``generate_shape``.
        shape: Shape name forwarded to ``generate_shape``; its capitalized
            form is used as the plot title.
    """
    rendered = generate_shape(generator, latent_dim, shape)
    plt.imshow(rendered, cmap='gray')
    heading = shape.capitalize()
    plt.title(heading)
    # Hide the axes; tick marks add nothing to an image preview.
    plt.axis('off')
    plt.show()

In [None]:
# Load the trained generator model from disk. This rebinds the notebook-level
# `generator` name to the reloaded model.
generator = load_generator('generator_model.h5')

# Example usage:
shape_to_draw = "square"  # Change this to "circle" or "triangle" as needed
display_generated_shape(generator, latent_dim, shape_to_draw)