In [2]:
from tensorflow.keras.datasets import cifar10
import cv2
import numpy as np

# Load CIFAR-10
(x_train, _), (x_test, _) = cifar10.load_data()

# Convert to grayscale and normalize
x_train_gray = np.array([cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in x_train]) / 255.0
x_test_gray = np.array([cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in x_test]) / 255.0

# Reshape grayscale images
x_train_gray = x_train_gray[..., np.newaxis]
x_test_gray = x_test_gray[..., np.newaxis]

# Normalize RGB images
x_train = x_train / 255.0
x_test = x_test / 255.0

print(f"Training data shape: {x_train_gray.shape}, {x_train.shape}")


2025-04-08 10:46:50.221185: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Training data shape: (50000, 32, 32, 1), (50000, 32, 32, 3)


In [22]:
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, Activation
from tensorflow.keras.models import Model

def build_generator():
    input_img = Input(shape=(32, 32, 1))

    x = Conv2D(64, (3, 3), padding='same')(input_img)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(128, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(256, (3, 3), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(3, (3, 3), padding='same')(x)
    output = Activation('tanh')(x)  # Use tanh to output values in [-1, 1]

    model = Model(input_img, output)
    return model



In [23]:
from tensorflow.keras.layers import Flatten, Dense

def build_discriminator():
    input_img = Input(shape=(32, 32, 3))

    x = Conv2D(64, (3, 3), padding='same')(input_img)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Flatten()(x)

    x = Dense(1, activation='sigmoid')(x)

    model = Model(input_img, x)
    return model


In [24]:
from tensorflow.keras.optimizers import Adam

generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(optimizer=Adam(0.0002), loss='binary_crossentropy', metrics=['accuracy'])
discriminator.trainable = False

gan_input = Input(shape=(32, 32, 1))
generated_img = generator(gan_input)
validity = discriminator(generated_img)

gan = Model(gan_input, validity)
gan.compile(optimizer=Adam(0.0002), loss='binary_crossentropy')


In [27]:
import numpy as np
import matplotlib.pyplot as plt

def train_gan(epochs, batch_size=64, sample_interval=1000):
    for epoch in range(epochs):
        
        # ---------------------
        # Train Discriminator
        # ---------------------
        # Select a random batch of grayscale images
        idx = np.random.randint(0, x_train_gray.shape[0], batch_size)
        gray_imgs = x_train_gray[idx]       # Grayscale images (input)
        color_imgs = x_train[idx]           # Real color images (ground truth)

        # Generate colorized images
        generated_imgs = generator.predict(gray_imgs)

        # Create labels for real (1) and fake (0) images
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))

        # Train the discriminator on real images
        d_loss_real = discriminator.train_on_batch(color_imgs, real_labels)
        # Train the discriminator on generated (fake) images
        d_loss_fake = discriminator.train_on_batch(generated_imgs, fake_labels)
        # Calculate average loss
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        # Train Generator
        # ---------------------
        # Train the generator to produce images that are classified as real
        g_loss = gan.train_on_batch(gray_imgs, real_labels)

        # Print progress every sample interval
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            save_generated_images(epoch)

# Helper function to save generated images during training
def save_generated_images(epoch, examples=5):
    # Generate a batch of images for testing
    idx = np.random.randint(0, x_test_gray.shape[0], examples)
    gray_imgs = x_test_gray[idx]
    generated_imgs = generator.predict(gray_imgs)

    plt.figure(figsize=(10, 4))
    for i in range(examples):
        # Original grayscale image
        plt.subplot(2, examples, i + 1)
        plt.imshow(gray_imgs[i].reshape(32, 32), cmap='gray')
        plt.axis('off')

        # Colorized (generated) image
        plt.subplot(2, examples, i + 1 + examples)
        plt.imshow((generated_imgs[i] * 127.5 + 127.5).astype(np.uint8))
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(f"../outputs/generated_{epoch}.png")
    plt.close()


In [None]:
# train_gan(epochs=10000, batch_size=64, sample_interval=1000)
train_gan(epochs=500, batch_size=64, sample_interval=100)



[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 205ms/step
0 [D loss: 0.7175, acc: 6.34%] [G loss: 0.3406]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 106ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 166ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 173ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 173ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 164ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 166ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 197ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 181ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 177ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 229ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 152ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 161ms/step
[1