In [3]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Conv2DTranspose, Conv2D, Flatten, LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import PIL
from PIL import Image

# Define paths
train_data_dir = 'grayscale_dataset/train'
test_data_dir = 'grayscale_dataset/test'
output_dir_infected = 'Generated_Images/infected'
output_dir_notinfected = 'Generated_Images/notinfected'

# Create output directories if they don't exist
os.makedirs(output_dir_infected, exist_ok=True)
os.makedirs(output_dir_notinfected, exist_ok=True)

# Parameters
img_rows, img_cols = 64, 64
channels = 1  # Grayscale images have only one channel
img_shape = (img_rows, img_cols, channels)
latent_dim = 100
batch_size = 32  # Adjust batch size

# Generator model
def build_generator():
    model = Sequential()
    model.add(Dense(128 * 16 * 16, activation="relu", input_dim=latent_dim))
    model.add(Reshape((16, 16, 128)))
    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(channels, kernel_size=5, padding="same", activation="tanh"))
    return model

# Detector model
def build_detector():
    model = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(img_rows, img_cols, channels)),
        Conv2D(64, kernel_size=(3, 3), activation='relu'),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    return model

# Combined GAN and detector model
def build_gan(generator, detector):
    detector.trainable = False
    model = Sequential([
        generator,
        detector
    ])
    return model

# Load and preprocess data
datagen = ImageDataGenerator(rescale=1./255)
train_generator = datagen.flow_from_directory(train_data_dir, target_size=(img_rows, img_cols), color_mode='grayscale', batch_size=batch_size, class_mode=None)

# Build generator, detector, and GAN
generator = build_generator()
detector = build_detector()
gan = build_gan(generator, detector)

# Compile models
generator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
detector.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
gan.compile(loss='binary_crossentropy', optimizer=Adam())


Found 1924 images belonging to 2 classes.


In [4]:
# Train the combined GAN and detector
num_epochs = 100  # Define the number of epochs
total_batches = len(train_generator)
for epoch in range(num_epochs):  # Number of epochs for training
    real_losses = []
    fake_losses = []
    gan_losses = []
    for i, batch in enumerate(train_generator):
        real_images = batch  # Get the batch

        # Adjust batch size for the last batch if necessary
        current_batch_size = real_images.shape[0]

        # Generate fake images
        noise = np.random.normal(0, 1, (current_batch_size, latent_dim))
        fake_images = generator.predict(noise, batch_size=current_batch_size)

        # Train the detector on real images
        real_labels = np.ones((current_batch_size, 1))  # Use the actual batch size of real_images
        real_loss = detector.train_on_batch(real_images, real_labels)
        real_losses.append(real_loss)

        # Train the detector on fake images
        fake_labels = np.zeros((current_batch_size, 1))
        fake_loss = detector.train_on_batch(fake_images, fake_labels)
        fake_losses.append(fake_loss)

        # Train the generator via the GAN
        noise = np.random.normal(0, 1, (current_batch_size, latent_dim))
        gan_labels = np.ones((current_batch_size, 1))
        gan_loss = gan.train_on_batch(noise, gan_labels)
        gan_losses.append(gan_loss)

        # Break the loop if all batches have been processed
        if i == total_batches - 1:
            break

    # Print losses after each epoch
    avg_real_loss = np.mean(real_losses)
    avg_fake_loss = np.mean(fake_losses)
    avg_gan_loss = np.mean(gan_losses)
    print(f'Epoch: {epoch}, Average Detector Loss (Real): {avg_real_loss}, Average Detector Loss (Fake): {avg_fake_loss}, Average GAN Loss: {avg_gan_loss}')

    # Save one generated image after each epoch
    noise = np.random.normal(0, 1, (1, latent_dim))
    generated_image = generator.predict(noise)[0]
    generated_image_normalized = np.clip(0.5 * generated_image + 0.5, 0, 1)  # Normalize to [0, 1]
    # Determine filename based on detector prediction
    if detector.predict(np.expand_dims(generated_image_normalized, axis=0)) > 0.5:
        filename = f"{output_dir_infected}/generated_image_epoch_{epoch}.png"
    else:
        filename = f"{output_dir_notinfected}/generated_image_epoch_{epoch}.png"

    # Save the generated image
    # Example using PIL library:
    

    # Denormalize the generated image
    generated_image_denormalized = (generated_image_normalized * 255).astype(np.uint8)

    # Create a PIL image
    image_to_save = Image.fromarray(generated_image_denormalized.squeeze(), mode='L')  # Squeeze to remove channel dimension

    # Save the image
    image_to_save.save(filename)

Epoch: 0, Average Detector Loss (Real): 0.47578430371206315, Average Detector Loss (Fake): 0.645172376613148, Average GAN Loss: 0.37735630645126594
Epoch: 1, Average Detector Loss (Real): 0.4775470998443541, Average Detector Loss (Fake): 1.0694704690917594, Average GAN Loss: 0.12651272069235317
Epoch: 2, Average Detector Loss (Real): 0.4758239373809002, Average Detector Loss (Fake): 1.2349795200785652, Average GAN Loss: 0.0887436274378026
Epoch: 3, Average Detector Loss (Real): 0.47939176539905737, Average Detector Loss (Fake): 1.3260392810477586, Average GAN Loss: 0.07318457985510592
Epoch: 4, Average Detector Loss (Real): 0.4792814279188875, Average Detector Loss (Fake): 1.3813475878512274, Average GAN Loss: 0.065247152244947
Epoch: 5, Average Detector Loss (Real): 0.4793337187806114, Average Detector Loss (Fake): 1.4273086551760064, Average GAN Loss: 0.059323991541979745
Epoch: 6, Average Detector Loss (Real): 0.4757433499469132, Average Detector Loss (Fake): 1.476165943458432, Aver