In [63]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Reshape, Conv2D, ReLU, LeakyReLU, UpSampling2D, BatchNormalization, Concatenate, Input, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError

import os
import cv2
import numpy as np
from skimage import color

In [60]:
train_dir = r'C:\Users\Asus\Desktop\Others\Datasets\imagenet\train'
val_dir = r'C:\Users\Asus\Desktop\Others\Datasets\imagenet\val'
output_dir = r'C:\Users\Asus\Desktop\Others\Datasets\imagenet\preproccessed'

os.makedirs(output_dir, exist_ok=True)

In [3]:
def load_and_preprocess_image(path, target_size=(256, 256)):
    # Load image
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # Resize with aspect ratio preservation
    h, w = img.shape[:2]
    scale = min(target_size[0]/h, target_size[1]/w)
    new_h, new_w = int(h * scale), int(w * scale)
    img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    
    # Pad to target size
    delta_h = target_size[0] - new_h
    delta_w = target_size[1] - new_w
    top, bottom = delta_h//2, delta_h-(delta_h//2)
    left, right = delta_w//2, delta_w-(delta_w//2)
    img_padded = cv2.copyMakeBorder(img_resized, top, bottom, left, right, 
                                   cv2.BORDER_REFLECT)
    
    # Convert to LAB color space
    img_lab = color.rgb2lab(img_padded)
    
    # Normalize LAB values
    img_lab = img_lab.astype('float32')
    img_lab[..., 0] = img_lab[..., 0] / 100.0  # L channel [0,100] -> [0,1]
    img_lab[..., 1:] = (img_lab[..., 1:] + 128) / 255.0  # ab channels [-128,127] -> [0,1]
    
    return img_lab

In [None]:
def process_dataset(directory, output_path, batch_size=32, target_size=(256, 256)):
    image_paths = [os.path.join(directory, f) for f in os.listdir(directory) 
                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    
    # Create output directories
    os.makedirs(os.path.join(output_path, 'L'), exist_ok=True)
    os.makedirs(os.path.join(output_path, 'ab'), exist_ok=True)
    
    # Create processing log
    log_file = os.path.join(output_path, 'processing_log.txt')
    processed_indices = set()
    
    # Load existing log if resuming
    if os.path.exists(log_file):
        with open(log_file, 'r') as f:
            processed_indices = {int(line.strip()) for line in f}
    
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []
        valid_indices = []
        
        # Load and preprocess batch
        for j, path in enumerate(batch_paths):
            idx = i + j
            if idx in processed_indices:
                continue
                
            try:
                img_lab = load_and_preprocess_image(path, target_size)
                batch_images.append(img_lab)
                valid_indices.append(idx)
            except Exception as e:
                print(f"Skipping {path}: {str(e)}")
                continue
        
        if not batch_images:
            continue
            
        try:
            batch_images = np.array(batch_images)
            
            # Split into L and ab channels
            L = batch_images[..., 0]
            ab = batch_images[..., 1:]
            
            # Save batch with individual error handling
            for j, idx in enumerate(valid_indices):
                try:
                    # Verify array contents before saving
                    if not np.all(np.isfinite(L[j])):
                        raise ValueError("L channel contains invalid values")
                    if not np.all(np.isfinite(ab[j])):
                        raise ValueError("ab channels contain invalid values")
                        
                    # Save with protocol=4 for better compatibility
                    np.save(os.path.join(output_path, 'L', f'{idx}.npy'), L[j], allow_pickle=False)
                    np.save(os.path.join(output_path, 'ab', f'{idx}.npy'), ab[j], allow_pickle=False)
                    
                    # Mark as processed
                    with open(log_file, 'a') as f:
                        f.write(f"{idx}\n")
                except Exception as e:
                    print(f"Failed to save image {idx}: {str(e)}")
                    # Clean up any partial files
                    for fname in [f'{idx}.npy']:
                        for subdir in ['L', 'ab']:
                            file_path = os.path.join(output_path, subdir, fname)
                            if os.path.exists(file_path):
                                os.remove(file_path)
                    continue
                
            print(f'Processed batch {i//batch_size + 1}/{(len(image_paths)//batch_size + 1)}')
            
        except Exception as batch_error:
            print(f"Batch processing failed: {str(batch_error)}")
            continue

In [10]:
# Process training data
print("Processing training data...")
process_dataset(train_dir, os.path.join(output_dir, 'train'))

# Process validation data
print("\nProcessing validation data...")
process_dataset(val_dir, os.path.join(output_dir, 'val'))

Processing training data...
Processed batch 1006/1407
Processed batch 1007/1407
Processed batch 1008/1407
Processed batch 1009/1407
Processed batch 1010/1407
Processed batch 1011/1407
Processed batch 1012/1407
Processed batch 1013/1407
Processed batch 1014/1407
Processed batch 1015/1407
Processed batch 1016/1407
Processed batch 1017/1407
Processed batch 1018/1407
Processed batch 1019/1407
Processed batch 1021/1407
Processed batch 1022/1407
Processed batch 1023/1407
Processed batch 1024/1407
Processed batch 1025/1407
Processed batch 1026/1407
Processed batch 1027/1407
Processed batch 1028/1407
Processed batch 1029/1407
Processed batch 1030/1407
Processed batch 1032/1407
Processed batch 1033/1407
Processed batch 1034/1407
Processed batch 1035/1407
Processed batch 1036/1407
Processed batch 1037/1407
Processed batch 1038/1407
Processed batch 1040/1407
Processed batch 1041/1407
Processed batch 1042/1407
Processed batch 1043/1407
Processed batch 1044/1407
Processed batch 1045/1407
Processed 

In [54]:
def data_generator(data_dir, batch_size=32):
    L_dir = os.path.join(data_dir, 'L')
    ab_dir = os.path.join(data_dir, 'ab')
    
    L_files = sorted([os.path.join(L_dir, f) for f in os.listdir(L_dir)])
    ab_files = sorted([os.path.join(ab_dir, f) for f in os.listdir(ab_dir)])
    
    while True:
        for i in range(0, len(L_files), batch_size):
            batch_L = []
            batch_ab = []
            
            for L_path, ab_path in zip(L_files[i:i+batch_size], ab_files[i:i+batch_size]):
                batch_L.append(np.load(L_path))
                batch_ab.append(np.load(ab_path))
            
            yield np.array(batch_L), np.array(batch_ab)

In [52]:
def enhanced_generator():
    inputs = Input(shape=(256, 256))
    
    x = Reshape((256, 256, 1))(inputs)
    
    # Encoder Path
    d1 = Conv2D(64, 4, strides=2, padding='same')(x)
    d1 = LeakyReLU(0.2)(d1)
    
    d2 = Conv2D(128, 4, strides=2, padding='same')(d1)
    d2 = BatchNormalization()(d2)
    d2 = LeakyReLU(0.2)(d2)
    
    # Bottleneck
    bottleneck = Conv2D(256, 3, padding='same')(d2)
    bottleneck = Conv2D(256, 3, padding='same')(bottleneck)
    
    # Decoder Path with Skip Connections
    u1 = UpSampling2D(2)(bottleneck)  
    u1 = Concatenate()([u1, d1])      # Skip connection from d1
    u1 = Conv2D(128, 3, padding='same')(u1)
    u1 = ReLU()(u1)
    
    u2 = UpSampling2D(2)(u1)      
    u2 = Concatenate()([u2, x])       # Skip connection from reshaped input
    u2 = Conv2D(64, 3, padding='same')(u2)
    u2 = ReLU()(u2)
    
    outputs = Conv2D(2, 3, activation='tanh', padding='same')(u2)
    
    return Model(inputs, outputs, name='EnhancedGenerator')

In [27]:
def build_discriminator(input_shape=(256, 256, 3)):
   
    inputs = Input(shape=input_shape)
    
    x = Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(128, 4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(256, 4, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(512, 4, strides=1, padding='same')(x)  
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    
    x = Conv2D(1, 4, strides=1, padding='same')(x)
    outputs = Activation('sigmoid')(x)
    
    return Model(inputs, outputs, name='ColorizationDiscriminator')

In [57]:
def verify_gan_components():
    generator = enhanced_generator()
    discriminator = build_discriminator()
   
    print("="*50)
    print("Generator Verification")
    print("="*50)
    
    dummy_l = np.random.random((1, 256, 256, 1))
    generated_ab = generator.predict(dummy_l)
    print(f"\nGenerator output shape: {generated_ab.shape}")
    
    print("\n" + "="*50)
    print("Discriminator Verification")
    print("="*50)
    
    real_image = np.concatenate([dummy_l, np.random.uniform(-1, 1, (1, 256, 256, 2))], axis=-1)
    real_pred = discriminator.predict(real_image)
    print(f"\nReal prediction shape: {real_pred.shape}")
    
    print("\n" + "="*50)
    print("GAN Integration Check")
    print("="*50)
    
    discriminator.trainable = False
    input_l = Input(shape=(256, 256, 1))
    gan_ab = generator(input_l)
    merged = Concatenate(axis=-1)([input_l, gan_ab]) 
    gan_output = discriminator(merged)
    gan = Model(input_l, [gan_ab, gan_output])
    
    gan_ab_pred, gan_disc_pred = gan.predict(dummy_l)
    print(f"\nGAN outputs: {gan_ab_pred.shape}, {gan_disc_pred.shape}")

verify_gan_components()



Generator Verification
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 248ms/step

Generator output shape: (1, 256, 256, 2)

Discriminator Verification
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 125ms/step

Real prediction shape: (1, 32, 32, 1)

GAN Integration Check
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 343ms/step

GAN outputs: (1, 256, 256, 2), (1, 32, 32, 1)


In [65]:
generator = enhanced_generator()
discriminator = build_discriminator(input_shape=(256, 256, 3))

# Optimizers
gen_optimizer = Adam(learning_rate=1e-4, beta_1=0.5)
disc_optimizer = Adam(learning_rate=1e-5, beta_1=0.5)

# Loss functions
bce_loss = BinaryCrossentropy()
l1_loss = MeanAbsoluteError()
lambda_l1 = 100  # L1 loss weight

# Data generator setup
data_dir = r"C:\Users\Asus\Desktop\Others\Datasets\imagenet\preproccessed\train"  # Update with your directory
batch_size = 16
train_gen = data_generator(data_dir, batch_size=batch_size)

# Calculate steps per epoch
num_train_files = len(os.listdir(os.path.join(data_dir, 'L')))
steps_per_epoch = (num_train_files + batch_size - 1) // batch_size

@tf.function
def train_step(real_L, real_ab):
    with tf.GradientTape(persistent=True) as tape:
        # Generate fake ab channels
        fake_ab = generator(real_L, training=True)
        
        # Create Lab images (real and fake)
        real_L_expanded = tf.expand_dims(real_L, axis=-1)  # (bs, 256,256,1)
        real_Lab = tf.concat([real_L_expanded, real_ab], axis=-1)
        fake_Lab = tf.concat([real_L_expanded, fake_ab], axis=-1)
        
        # Discriminator predictions
        pred_real = discriminator(real_Lab, training=True)
        pred_fake = discriminator(fake_Lab, training=True)
        
        # Discriminator loss
        disc_loss_real = bce_loss(tf.ones_like(pred_real), pred_real)
        disc_loss_fake = bce_loss(tf.zeros_like(pred_fake), pred_fake)
        disc_total_loss = (disc_loss_real + disc_loss_fake) * 0.5
        
        # Generator losses
        gen_adv_loss = bce_loss(tf.ones_like(pred_fake), pred_fake)
        gen_l1 = l1_loss(real_ab, fake_ab) * lambda_l1
        gen_total_loss = gen_adv_loss + gen_l1

    # Update discriminator
    disc_grads = tape.gradient(disc_total_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    
    # Update generator
    gen_grads = tape.gradient(gen_total_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    
    return disc_total_loss, gen_total_loss, gen_adv_loss, gen_l1

# Training loop
epochs = 100
for epoch in range(epochs):
    for step in range(steps_per_epoch):
        real_L, real_ab = next(train_gen)
        d_loss, g_loss, g_adv, g_l1 = train_step(real_L, real_ab)
        
        # Log progress every 50 steps
        if step % 50 == 0:
            print(f"Epoch {epoch+1}/{epochs} | Step {step}/{steps_per_epoch}")
            print(f"D Loss: {d_loss:.4f} | G Loss: {g_loss:.4f} [Adv: {g_adv:.4f}, L1: {g_l1:.4f}]")
    
    # Save model and sample images every 10 epochs
    if (epoch + 1) % 10 == 0:
        generator.save(f"generator_epoch_{epoch+1}.h5")
        print(f"Saved checkpoint at epoch {epoch+1}")

# Save final model
generator.save("generator_final.h5")



Epoch 1/100 | Step 0/2813
D Loss: 0.7787 | G Loss: 38.8656 [Adv: 0.7027, L1: 38.1628]


KeyboardInterrupt: 