In [17]:
# Code from Dr. Toone's eclipse GAN 2
# Outputs are from my personal server instead of the incredibly slow Google Colab
# Images are from my trip to Arkansas for the 2024 eclipse
# !git clone https://gitea.jptechtips.com/JP-Garcia/eclipse_data.git data

In [41]:
import os
import numpy as np
import matplotlib.pyplot as plt
#np.object = np.object_ #trying to fix import errors, ended up installing numpy 1.21
#np.bool = np.bool_
#np.int = np.int_
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, BatchNormalization, LeakyReLU
from tensorflow.keras.mixed_precision import experimental as mixed_precision

In [42]:
def load_images_as_dataset(directory, size=(32,32), batch_size=32):
    """Load every readable image in `directory` into a shuffled, batched tf.data pipeline.

    Side effects (kept as globals for other notebook cells):
      * ``images`` is set to the numpy array of normalized images;
      * ``total_num_images`` is set to the number of images loaded.

    Args:
        directory: folder to scan (non-recursive). Entries that are not regular
            files (e.g. a cloned repo's ``.git`` folder) are skipped, and files
            that fail to load with an ``OSError`` (e.g. ``readme.txt``) are
            reported and skipped.
        size: (height, width) every image is resized to.
        batch_size: images per batch; the final partial batch is dropped.

    Returns:
        A ``tf.data.Dataset`` of image batches scaled to [-1, 1].

    Raises:
        ValueError: if no image could be loaded from ``directory``.
    """
    global total_num_images
    global images
    loaded = []
    # sorted(): os.listdir order is arbitrary; sorting makes `images` reproducible.
    for filename in sorted(os.listdir(directory)):
        img_path = os.path.join(directory, filename)
        if not os.path.isfile(img_path):
            print("skipping non-file:", filename)
            continue
        try:
            img = img_to_array(load_img(img_path, target_size=size))
        except OSError:
            print("exception noted:", filename)
            continue
        loaded.append((img - 127.5) / 127.5)  # Normalize the images to [-1, 1]
    total_num_images = len(loaded)
    images = np.array(loaded)
    if total_num_images == 0:
        # Fail loudly here instead of with an opaque error from shuffle(0) below.
        raise ValueError(f"no loadable images found in {directory!r}")
    dataset = tf.data.Dataset.from_tensor_slices(images)
    # Buffer = whole dataset for a full shuffle; tf.data reshuffles on each new iteration by default.
    dataset = dataset.shuffle(len(images)).batch(batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    print("dataset loaded!")
    return dataset

def make_generator_model():
    """Build the generator: 100-d noise vector -> 32x32x3 image in [-1, 1].

    A bias-free dense projection is reshaped to an 8x8x256 feature map. Three
    5x5 transposed convolutions then follow, with strides 1, 2 and 2
    (8x8 -> 8x8 -> 16x16 -> 32x32). Hidden blocks use BatchNorm + LeakyReLU.
    The last layer uses tanh, so outputs share the [-1, 1] range of the
    normalized real images.
    """
    net = Sequential()
    # Project the noise vector and reshape it into a small spatial feature map.
    net.add(Dense(8 * 8 * 256, use_bias=False, input_shape=(100,)))
    net.add(BatchNormalization())
    net.add(LeakyReLU())
    net.add(Reshape((8, 8, 256)))
    # Hidden upsampling blocks, as (filters, stride) pairs.
    for filters, stride in ((128, 1), (64, 2)):
        net.add(Conv2DTranspose(filters, (5, 5), strides=(stride, stride), padding='same', use_bias=False))
        net.add(BatchNormalization())
        net.add(LeakyReLU())
    # Output layer: 3 channels squashed into [-1, 1].
    net.add(Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return net

def make_discriminator_model():
    """Build the discriminator: 32x32x3 image -> one real/fake score.

    It has one strided 5x5 convolution (64 filters) with LeakyReLU, then a
    flatten and Dense(1). The Dense layer has no activation, so the output is
    a raw logit and the loss must use ``from_logits=True``.
    """
    net = Sequential()
    net.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[32, 32, 3]))
    net.add(LeakyReLU())
    net.add(Flatten())
    net.add(Dense(1))  # raw logit, no sigmoid
    return net


def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real images are pushed toward label 1 and generated images toward
    label 0; the two terms are summed. Uses the notebook-level
    ``cross_entropy`` loss object.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Generator loss: cross-entropy of the discriminator's scores on fakes against label 1 ("real")."""
    wanted = tf.ones_like(fake_output)  # the generator wants every fake judged real
    return cross_entropy(wanted, fake_output)

def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Args:
        images: batch of real images; reshaped to (-1, 32, 32, 3) so the
            batch dimension is inferred.

    Reads the notebook-level ``generator``, ``discriminator`` and their
    optimizers, and updates both models' weights in place. Returns nothing.
    """
    real_batch = tf.reshape(images, (-1, 32, 32, 3))
    # One 100-d noise vector per real image, so real and fake batches match in size.
    latent = tf.random.normal([len(real_batch), 100])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(latent, training=True)
        logits_real = discriminator(real_batch, training=True)
        logits_fake = discriminator(fakes, training=True)
        g_loss = generator_loss(logits_fake)
        d_loss = discriminator_loss(logits_real, logits_fake)

    # Both gradients come from the same forward pass and are computed before
    # either optimizer step is applied.
    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables
    g_grads = g_tape.gradient(g_loss, g_vars)
    d_grads = d_tape.gradient(d_loss, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

def train(dataset, epochs, batch_size, sample_every=5):
    """Run the adversarial training loop for `epochs` passes over `dataset`.

    Args:
        dataset: batched ``tf.data.Dataset`` of real images. It is iterated
            afresh each epoch. A dataset built with ``shuffle()`` (as
            load_images_as_dataset does) reshuffles on every iteration by
            default, so it is not rebuilt here.
        epochs: number of passes over ``dataset``.
        batch_size: kept for backward compatibility only. The batch size is
            already part of ``dataset``, and this argument is not used.
        sample_every: save a row of generated samples via ``show_gan`` every
            this many epochs (0 disables periodic samples). The final epoch is
            always sampled.
    """
    for epoch in range(epochs):
        print("\\/ "*20)
        print(f"Epoch {epoch} starting ...")
        print("Batch", end=' ')
        for batchi, image_batch in enumerate(dataset):
            if (batchi % 100 == 0) and (batchi != 0):  # start a new progress line every 100 batches
                print("\nBatch", end=' ')
            if batchi % 10 == 0:
                print(f"{batchi}, ", end='')
            train_step(image_batch)
        print("complete!")
        # The dataset is deliberately not rebuilt per epoch. Rebuilding it from
        # the global `images` ignored the caller's `dataset` after the first
        # epoch and dropped its prefetch.
        is_last = epoch == epochs - 1
        if is_last or (sample_every and epoch % sample_every == 0):
            show_gan(10, epoch)


def show_gan(num_images, cnt, dir="generated_images/"):
    """Sample `num_images` images from the global generator and save them as one PNG row.

    Args:
        num_images: number of samples to draw; each gets its own subplot.
        cnt: integer tag used in the file name, zero-padded to 4 digits
            (e.g. ``gen_0005.png``). train() passes the epoch number.
        dir: output directory, created if missing. "" means the current
            working directory.
    """
    noise = tf.random.normal([num_images, 100])
    generated_images = generator(noise, training=False)

    # Rescale the generator's tanh output from [-1, 1] to [0, 1] for imshow.
    generated_images = ((generated_images + 1) / 2).numpy()

    # One column per image (2 inches each). squeeze=False keeps `axes` 2-D
    # even when num_images == 1.
    fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
    for ax, img in zip(axes[0], generated_images):
        ax.imshow(img)
        ax.axis('off')  # Turn off axis labels
    if dir:
        os.makedirs(dir, exist_ok=True)
    # `{cnt:04d}`: a space inside the braces (`{cnt : 04d}`) is read as the
    # sign option and puts a space in the file name.
    fig.savefig(os.path.join(dir, f"gen_{cnt:04d}.png"), dpi=300)
    plt.close(fig)

In [26]:
# Create fresh, untrained models. Re-running this cell discards any training so far.
# train_step/show_gan and the loss functions read these names as globals.
generator = make_generator_model()
discriminator = make_discriminator_model()

# The discriminator ends in Dense(1) with no activation (raw logits), hence from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# A separate optimizer per network, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [34]:
# Directory containing your images ("data" is the target of the commented-out git clone in the first cell)
directory = "data"
batch_size = 32 # default=32
total_num_images = 0 # will be updated by load func
# load_images_as_dataset also sets the globals `images` and `total_num_images`.
dataset = load_images_as_dataset(directory, batch_size=batch_size)
print(dataset, total_num_images)

exception noted: readme.txt
exception noted: .git
dataset loaded!


In [38]:
%%capture output
# Stdout from this cell is captured into `output` so the per-batch progress
# log doesn't flood the notebook; output.show() replays it.
train(dataset, 200, batch_size)  # Train for 200 epochs

In [39]:
# Replay the stdout captured by `%%capture output` in the training cell.
output.show()

\/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ 
Epoch 0 starting ...
Batch 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 
Batch 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 
Batch 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 
Batch 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 
Batch 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 
Batch 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 
Batch 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 
Batch 700, 710, 720, 730, 740, 750, 760, 770, 780, 790, 
Batch 800, 810, 820, 830, 840, 850, 860, 870, 880, 890, 
Batch 900, 910, 920, 930, 940, 950, 960, 970, 980, 990, 
Batch 1000, 1010, 1020, 1030, 1040, 1050, 1060, complete!
\/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ \/ 
Epoch 1 starting ...
Batch 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 
Batch 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 
Batch 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 
Batch 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 
Ba

In [43]:
# Save a 9-sample row from the trained generator, tagged 199 (the last of the
# 200 epochs), into the current working directory ("").
show_gan(9, 199, "")