In [2]:
import os
import zipfile
import torch
import tensorflow as tf
from torchvision import transforms 
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tensorflow.keras import layers,models


In [3]:
# Location of the zipped dataset. Set the ABSTRACT_ART_ZIP environment variable
# to point elsewhere instead of editing this machine-specific default path.
path = os.environ.get('ABSTRACT_ART_ZIP', 'C:\\Users\\ashmi\\Downloads\\Compressed\\archive_3.zip')

# Directory the archive is unpacked into; ImageFolder below reads from here.
extract = 'abstract_art_512'

# Only unpack when the target directory is missing or empty, so re-running this
# cell (or Restart & Run All) does not rewrite every image file each time.
if not (os.path.isdir(extract) and os.listdir(extract)):
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(extract)


# Per-image preprocessing. ToTensor already scales pixels to [0, 1], which is
# the range the generator's final sigmoid layer produces. An extra
# Normalize(mean=0.5, std=0.5) would shift real images to [-1, 1] and let the
# discriminator separate real from fake by value range alone.
transform = transforms.Compose([
    transforms.Resize((28, 28)),  # Resize images to 28x28 to match the generator output
    transforms.ToTensor(),        # PIL image -> float tensor (channels, height, width) in [0, 1]
])

# ImageFolder expects one sub-directory per class under `extract`; the class
# labels are discarded by the training loop.
dataset = ImageFolder(root=extract, transform=transform)
batch_size = 64
# drop_last keeps every batch at exactly `batch_size` images. A smaller final
# batch would hand train_step a new input shape (and force a tf.function retrace).
assert len(dataset) >= batch_size, "dataset has fewer images than one batch; drop_last would yield no batches"
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


In [4]:
def generator_rex(latent_dim):
    """Build the image generator.

    A latent vector of length ``latent_dim`` is projected by a dense layer to a
    7x7x128 feature map. Two stride-2 transposed convolutions then upsample it
    (7 -> 14 -> 28). A final stride-1 transposed convolution reduces it to 3
    channels. The sigmoid on that last layer keeps every output pixel in [0, 1].

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model producing (None, 28, 28, 3).
    """
    upsampling_stack = [
        layers.Dense(7 * 7 * 128, input_dim=latent_dim),
        layers.Reshape((7, 7, 128)),
        # Each stride-2, 'same'-padded transposed conv doubles height and width.
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', activation='relu'),
        layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', activation='relu'),
        # Stride 1: spatial size stays 28x28, channels drop to 3 (RGB).
        layers.Conv2DTranspose(3, (7, 7), activation='sigmoid', padding='same'),
    ]
    return tf.keras.Sequential(upsampling_stack)

# Example usage: build a generator for 100-dimensional noise vectors and print
# its architecture (per-layer output shapes and parameter counts).
latent_dim = 100
generator = generator_rex(latent_dim=latent_dim)

# Gen Arch
generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 6272)              633472    
                                                                 
 reshape (Reshape)           (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      262272    
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 64)       131136    
 ranspose)                                                       
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 3)        9411      
 ranspose)                                                       
                                                        

In [9]:


def discriminator_rex(input_shape):
    """Build the real-vs-fake image classifier.

    The network has three stride-2 convolution stages with 64, 128 and 256
    filters. Each stage is followed by LeakyReLU(0.2) and 25% dropout, and the
    second and third stages also get batch normalization. The flattened
    features feed a single sigmoid unit that outputs a score in (0, 1).

    Args:
        input_shape: shape of one input image, passed to the first Conv2D
            (channels-last, e.g. (28, 28, 3)).

    Returns:
        An uncompiled ``models.Sequential`` model.
    """
    net = models.Sequential()
    for stage, filters in enumerate((64, 128, 256)):
        # Only the first layer declares the input shape.
        first_layer_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        net.add(layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same', **first_layer_kwargs))
        if stage > 0:
            net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU(alpha=0.2))
        net.add(layers.Dropout(0.25))

    net.add(layers.Flatten())
    net.add(layers.Dense(1, activation='sigmoid'))
    return net

# Shape of one input image for the discriminator: 28x28 RGB, channels-last,
# matching the generator output. Assigned here so this cell runs on a fresh
# kernel without depending on a later cell having been executed first.
input_shape = (28, 28, 3)

# Create discriminator model
discriminator = discriminator_rex(input_shape)

# NOTE(review): the custom train_step later in the notebook applies its own
# optimizer and loss (and re-creates the model), so this compile() is not used
# by GAN training. It only matters for calling discriminator.fit/evaluate directly.
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Disc
discriminator.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_3 (Conv2D)           (None, 14, 14, 64)        1792      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 dropout_3 (Dropout)         (None, 14, 14, 64)        0         
                                                                 
 conv2d_4 (Conv2D)           (None, 7, 7, 128)         73856     
                                                                 
 batch_normalization_2 (Batc  (None, 7, 7, 128)        512       
 hNormalization)                                                 
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                      

In [8]:
def genzloss(fake_output):
    """Generator loss.

    Computes the binary cross-entropy between the discriminator's scores on
    generated images and the "real" label 1. Minimizing it pushes the
    generator to fool the discriminator.

    Args:
        fake_output: discriminator sigmoid outputs for generated images
            (presumably shape (batch, 1) — TODO confirm).

    Returns:
        Scalar tensor: the mean loss over the batch. Returning the mean rather
        than the per-sample vector matches how ``discloss`` reduces its terms.
        It also gives GradientTape a scalar target instead of an implicit sum.
    """
    per_sample = tf.keras.losses.binary_crossentropy(tf.ones_like(fake_output), fake_output)
    return tf.reduce_mean(per_sample)

# Define discriminator loss function
def discloss(real_output, fake_output):
    """Discriminator loss.

    The loss is the sum of two batch-mean binary cross-entropies: real images
    scored against label 1, and generated images scored against label 0.

    Args:
        real_output: discriminator outputs for real images.
        fake_output: discriminator outputs for generated images.

    Returns:
        Scalar tensor: mean(real BCE) + mean(fake BCE).
    """
    bce = tf.keras.losses.binary_crossentropy
    loss_on_real = tf.reduce_mean(bce(tf.ones_like(real_output), real_output))
    loss_on_fake = tf.reduce_mean(bce(tf.zeros_like(fake_output), fake_output))
    return loss_on_real + loss_on_fake

# Seed both frameworks before any weights or noise are created. This makes
# weight initialisation, noise sampling and the DataLoader shuffle order (drawn
# from torch's global RNG) repeatable between runs on the same setup.
SEED = 42
tf.random.set_seed(SEED)
torch.manual_seed(SEED)

# Define optimizers: Adam with learning_rate=2e-4 and beta_1=0.5 for both networks.
grex_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
drex_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

# Instantiate the generator and discriminator that train_step uses.
# These replace any instances built in earlier cells.
latent_dim = 100
input_shape = (28, 28, 3)
generator = generator_rex(latent_dim)
discriminator = discriminator_rex(input_shape)


@tf.function
def train_step(images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        images: batch of real images, expected channels-last, e.g.
            (batch, 28, 28, 3) as prepared by the training loop (TODO confirm
            if the input pipeline changes).

    Returns:
        (gen_loss, disc_loss) tensors for monitoring. Existing callers that
        ignore the return value are unaffected.
    """
    # Size the noise from the actual real batch rather than the global
    # batch_size, so a smaller batch gets an equally sized fake batch.
    current_batch = tf.shape(images)[0]
    noise = tf.random.normal([current_batch, latent_dim])

    # Two tapes record the same forward pass so each network is updated with
    # gradients of its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = genzloss(fake_output)
        disc_loss = discloss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    grex_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    drex_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss






In [10]:
num_epochs = 100

# Train the GAN
for epoch in range(num_epochs):
    for images, _ in dataloader:
        # The DataLoader yields torch tensors shaped (batch, channels, height, width).
        # Move channels last for Keras and hand TensorFlow an explicit tensor
        # rather than relying on implicit torch -> TF conversion.
        images = tf.convert_to_tensor(images.permute(0, 2, 3, 1).numpy())
        train_step(images)
    print(f'epoch {epoch + 1}/{num_epochs} done')