In [3]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU, ReLU, Flatten, Dense, Reshape, Dropout
from tensorflow.keras import Model
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def preprocess(image):
    """Cast *image* to float32 and rescale it linearly via ``x / 127.5 - 1``.

    For 8-bit pixel data (values 0..255) this maps the range onto [-1, 1],
    which matches the ``tanh`` output range of the generator.

    Args:
        image: Array or tensor of pixel values (presumably uint8 in 0..255).

    Returns:
        A float32 tensor with the same shape as the input.
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1.0


# Only the images are used for GAN training; the class labels are discarded.
(train_images, _), (test_images, _) = tf.keras.datasets.cifar10.load_data()

# preprocess() already returns float32 tensors, so no further conversion is needed.
train_images = preprocess(train_images)
test_images = preprocess(test_images)

batch_size = 64
# Size the shuffle buffer from the data actually loaded instead of a
# hard-coded 60000: a buffer equal to the dataset size gives a full uniform
# shuffle without requesting more buffer slots than there are samples.
train_data = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(int(train_images.shape[0]))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

class DCGANGenerator(Model):
    """DCGAN generator: latent vector -> 32x32x3 image with values in [-1, 1].

    A bias-free dense projection is reshaped to a 4x4x512 feature map, then
    three stride-2 transposed convolutions double the spatial size each time
    (4 -> 8 -> 16 -> 32). A final stride-1 transposed convolution with a
    ``tanh`` activation produces the 3-channel output.
    """

    def __init__(self):
        super().__init__()
        # Settings shared by the three upsampling transposed convolutions.
        upsample = dict(strides=2, padding='same', use_bias=False)

        # Latent projection; bias is omitted because batch norm follows.
        self.dense = Dense(4 * 4 * 512, use_bias=False)
        self.bn1 = BatchNormalization()
        # A single stateless ReLU layer is reused after every batch norm.
        self.relu = ReLU()

        self.conv1 = Conv2DTranspose(256, 4, **upsample)
        self.bn2 = BatchNormalization()

        self.conv2 = Conv2DTranspose(128, 4, **upsample)
        self.bn3 = BatchNormalization()

        self.conv3 = Conv2DTranspose(64, 4, **upsample)
        self.bn4 = BatchNormalization()

        # Output layer: keeps the spatial size, squashes to [-1, 1].
        self.conv4 = Conv2DTranspose(
            3, 4, strides=1, padding='same', use_bias=False, activation='tanh'
        )

    def call(self, x, training=True):
        """Generate images from latent vectors.

        Args:
            x: Batch of latent vectors (2-D, one row per sample).
            training: Forwarded to every batch-normalization layer.

        Returns:
            Tensor of generated images produced by the final ``tanh`` layer.
        """
        hidden = self.relu(self.bn1(self.dense(x), training=training))
        hidden = tf.reshape(hidden, (-1, 4, 4, 512))

        upsampling_stages = (
            (self.conv1, self.bn2),
            (self.conv2, self.bn3),
            (self.conv3, self.bn4),
        )
        for deconv, norm in upsampling_stages:
            hidden = self.relu(norm(deconv(hidden), training=training))

        return self.conv4(hidden)

class DCGANDiscriminator(Model):
    """DCGAN discriminator: image batch -> one unnormalized score per image.

    Four stride-2 convolutions (64, 128, 256, 512 filters) with LeakyReLU(0.2)
    activations; batch normalization is applied on all but the first. The
    features are flattened, passed through dropout, and mapped by a linear
    dense layer to a single logit (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        # Settings shared by every downsampling convolution.
        downsample = dict(strides=2, padding='same')

        # First stage has no batch normalization.
        self.conv1 = Conv2D(64, 4, **downsample)
        self.lrelu1 = LeakyReLU(alpha=0.2)

        self.conv2 = Conv2D(128, 4, **downsample)
        self.bn1 = BatchNormalization()
        self.lrelu2 = LeakyReLU(alpha=0.2)

        self.conv3 = Conv2D(256, 4, **downsample)
        self.bn2 = BatchNormalization()
        self.lrelu3 = LeakyReLU(alpha=0.2)

        self.conv4 = Conv2D(512, 4, **downsample)
        self.bn3 = BatchNormalization()
        self.lrelu4 = LeakyReLU(alpha=0.2)

        # Classification head.
        self.flatten = Flatten()
        self.dropout = Dropout(0.3)
        self.fc = Dense(1)

    def call(self, x, training=True):
        """Score a batch of images.

        Args:
            x: Batch of images (presumably scaled to [-1, 1] like the
                generator output; confirm against the caller).
            training: Forwarded to the batch-normalization and dropout layers.

        Returns:
            Tensor with one raw logit per input image.
        """
        features = self.lrelu1(self.conv1(x))

        normalized_stages = (
            (self.conv2, self.bn1, self.lrelu2),
            (self.conv3, self.bn2, self.lrelu3),
            (self.conv4, self.bn3, self.lrelu4),
        )
        for conv, norm, activation in normalized_stages:
            features = activation(norm(conv(features), training=training))

        flat = self.dropout(self.flatten(features), training=training)
        return self.fc(flat)

# Shared loss criterion for both networks. from_logits=True because the
# discriminator's final Dense(1) layer emits raw scores, not probabilities.
cross_entropy = BinaryCrossentropy(
    from_logits=True,
)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator's total binary cross-entropy loss.

    Real samples are scored against a target of 1 and generated samples
    against a target of 0; the two loss terms are summed.

    Args:
        real_output: Discriminator logits for real images.
        fake_output: Discriminator logits for generated images.

    Returns:
        Scalar loss tensor.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return (
        cross_entropy(real_targets, real_output)
        + cross_entropy(fake_targets, fake_output)
    )

def generator_loss(fake_output):
    """Return the generator's loss: cross-entropy of fake logits against 1.

    The generator is rewarded when the discriminator scores its samples as
    real, so the target for every generated sample is 1.

    Args:
        fake_output: Discriminator logits for generated images.

    Returns:
        Scalar loss tensor.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

# Instantiate the two adversarial networks.
generator = DCGANGenerator()
discriminator = DCGANDiscriminator()

# One Adam optimizer per network, configured identically
# (learning rate 2e-4, first-moment decay 0.5).
generator_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)

def calculate_psnr_ssim(real_images, generated_images):
    """Return the mean PSNR and mean SSIM over paired real/generated images.

    Images are paired positionally with ``zip``, so if the two batches differ
    in length the extra images of the longer one are ignored. Each image is
    mapped from [-1, 1] to [0, 1] before the metrics are computed.

    Args:
        real_images: Iterable of channel-last images scaled to [-1, 1]
            (NumPy arrays or eager tensors).
        generated_images: Iterable of images with the same per-image shape
            and scaling.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)``. Both are NaN when there are no
        pairs to compare, because ``np.mean`` of an empty list is NaN.
    """
    psnr_values = []
    ssim_values = []
    for real, gen in zip(real_images, generated_images):
        # scikit-image operates on NumPy arrays; convert explicitly so eager
        # tensors are never handed to it directly.
        real = (np.asarray(real) + 1) / 2  # Denormalize to [0, 1]
        gen = (np.asarray(gen) + 1) / 2
        psnr_values.append(peak_signal_noise_ratio(real, gen, data_range=1.0))
        # `multichannel=True` was removed from scikit-image; `channel_axis`
        # (available since 0.19) is its replacement. -1 selects the trailing
        # colour-channel axis.
        ssim_values.append(
            structural_similarity(real, gen, channel_axis=-1, data_range=1.0)
        )
    return np.mean(psnr_values), np.mean(ssim_values)

@tf.function
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        images: Batch of real images, scaled like the generator output.

    Returns:
        Tuple ``(gen_loss, disc_loss, generated_images)``, where
        ``generated_images`` is the fake batch produced during this step
        (one fake image per real image).
    """
    noise_dim = 100
    # Size the fake batch from the incoming real batch rather than the global
    # batch_size, so a short final batch still trains on equal numbers of
    # real and generated samples.
    current_batch_size = tf.shape(images)[0]
    noise = tf.random.normal([current_batch_size, noise_dim])

    # Two tapes: each network's gradients are taken from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss, generated_images

def train(dataset, epochs):
    """Train the GAN for *epochs* passes over *dataset*.

    After each epoch the mean generator and discriminator losses are printed
    together with PSNR/SSIM computed between the last real batch of the epoch
    and the images generated during that same step.

    Args:
        dataset: Iterable yielding batches of real images.
        epochs: Number of passes over ``dataset``.

    Returns:
        Tuple ``(psnr_history, ssim_history)`` with one entry per epoch.

    Raises:
        ValueError: If ``dataset`` yields no batches during an epoch.
    """
    psnr_history = []
    ssim_history = []

    for epoch in range(epochs):
        gen_losses = []
        disc_losses = []
        for image_batch in dataset:
            gen_loss, disc_loss, generated_images = train_step(image_batch)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

        # The metrics below read `image_batch` / `generated_images` from the
        # final loop iteration; with no iterations those names do not exist,
        # so fail with a clear message instead of an UnboundLocalError.
        if not gen_losses:
            raise ValueError(
                "dataset yielded no batches; cannot train or compute metrics"
            )

        avg_gen_loss = tf.reduce_mean(gen_losses)
        avg_disc_loss = tf.reduce_mean(disc_losses)

        # Calculate PSNR and SSIM on the last batch of the epoch
        psnr, ssim = calculate_psnr_ssim(image_batch, generated_images)
        psnr_history.append(psnr)
        ssim_history.append(ssim)

        print(f'Epoch {epoch+1}, Gen Loss: {avg_gen_loss:.4f}, Disc Loss: {avg_disc_loss:.4f}, PSNR: {psnr:.2f}, SSIM: {ssim:.4f}')

    return psnr_history, ssim_history

def plot_images(images, epoch):
    """Save a 2x4 grid of up to eight images to a PNG named after *epoch*.

    Args:
        images: Sequence of image tensors exposing ``.numpy()``, with values
            in [-1, 1]; only the first eight are drawn.
        epoch: Epoch label used in the figure title and the output filename.
    """
    _, grid = plt.subplots(2, 4, figsize=(20, 10))
    for axis, picture in zip(grid.ravel(), images[:8]):
        # Map from [-1, 1] to the [0, 1] range imshow expects for floats.
        rescaled = (picture.numpy() + 1) / 2
        axis.imshow(rescaled)
        axis.axis('off')
    plt.suptitle(f"Generated Images at Epoch {epoch}")
    plt.savefig(f"generated_images_epoch_{epoch}.png")
    plt.close()

def plot_metrics(psnr_history, ssim_history):
    """Plot per-epoch PSNR and SSIM side by side and save the figure.

    The figure is written to ``psnr_ssim_history.png`` and then closed.

    Args:
        psnr_history: One PSNR value per epoch.
        ssim_history: One SSIM value per epoch; it is plotted against the
            epoch numbers derived from ``psnr_history``.
    """
    epoch_axis = range(1, len(psnr_history) + 1)

    # (series, line format, title, y-axis label) for each panel, left to right.
    panels = (
        (psnr_history, 'b-', 'PSNR over Epochs', 'PSNR'),
        (ssim_history, 'r-', 'SSIM over Epochs', 'SSIM'),
    )

    plt.figure(figsize=(12, 5))
    for position, (series, line_format, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, series, line_format)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)

    plt.tight_layout()
    plt.savefig('psnr_ssim_history.png')
    plt.close()

# Directory that receives the saved checkpoint; created if missing.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Track both networks together with their optimizer state so that a single
# checkpoint captures everything passed in here.
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
)

# Training
epochs = 100
psnr_history, ssim_history = train(dataset=train_data, epochs=epochs)

# Save model (once, after the final epoch)
checkpoint.save(file_prefix=os.path.join(checkpoint_dir, "dcgan"))

# Plot metrics
plot_metrics(psnr_history=psnr_history, ssim_history=ssim_history)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


Exception: URL fetch failure on https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz: None -- [Errno 110] Connection timed out