In [6]:
# Report the installed TensorFlow version so the environment is recorded with the run.
import tensorflow as tf
print(tf.__version__)


2.20.0


In [7]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:04_Central_Daylight_Time_2018
Cuda compilation tools, release 10.0, V10.0.130


In [8]:
import tensorflow as tf

# List the GPUs TensorFlow can see and switch each one to on-demand memory
# allocation instead of reserving all device memory up front.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
    # Report inside the loop so every configured device is listed, not just
    # the last one iterated.
    print("Memory growth set for GPU:", gpu)


Num GPUs Available:  0


In [9]:
# Same memory-growth setup as above, but tolerant of the RuntimeError that
# set_memory_growth raises (the message is printed instead of propagating).
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
    else:
        print("Memory growth enabled for GPU")


In [10]:
from tensorflow.keras.layers import Layer

class PixelShuffle(Layer):
    """Keras layer wrapping ``tf.nn.depth_to_space`` with a fixed block size.

    Args:
        scale: Block size handed to ``tf.nn.depth_to_space``.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, scale, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, x):
        """Rearrange channel data of ``x`` into spatial blocks of size ``scale``."""
        return tf.nn.depth_to_space(x, block_size=self.scale)

    def get_config(self):
        """Return the base layer config extended with the ``scale`` argument."""
        return {**super().get_config(), "scale": self.scale}


In [11]:
### MODEL
import os
import glob
import random
import time
import math
import numpy as np
from PIL import Image

import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import VGG19

# ---------------------------
# Mixed precision for speed
# ---------------------------
# Sets the process-wide Keras dtype policy; it affects every layer that is
# constructed after this line runs (including the models built below).
# NOTE(review): the GPU check earlier in this notebook printed 0 GPUs, so
# float16 compute presumably brings no speed-up in this environment -- confirm.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')

# ---------------------------
# Residual Block
# ---------------------------
def ResidualBlock(x_input, filters=64, kernel_size=3):
    """Apply Conv-BN-PReLU-Conv-BN to ``x_input`` and add the input back.

    Args:
        x_input: Tensor entering the block; also used as the skip connection.
        filters: Number of filters for both convolutions.
        kernel_size: Kernel size for both convolutions.

    Returns:
        The element-wise sum of ``x_input`` and the transformed branch.
    """
    branch_layers = (
        layers.Conv2D(filters, kernel_size, padding='same'),
        layers.BatchNormalization(),
        layers.PReLU(shared_axes=[1,2]),
        layers.Conv2D(filters, kernel_size, padding='same'),
        layers.BatchNormalization(),
    )
    residual = x_input
    for layer in branch_layers:
        residual = layer(residual)
    return layers.Add()([x_input, residual])

# ---------------------------
# Generator
# ---------------------------
def build_generator(num_res_blocks=16, upscaling_factor=4):
    """Build the super-resolution generator model.

    Architecture: 9x9 conv + PReLU, ``num_res_blocks`` residual blocks, a
    conv + batch-norm joined to the stem by a long skip connection, one
    conv + PixelShuffle(2) + PReLU stage per factor of two, and a final 9x9
    conv with ``tanh`` activation.

    Args:
        num_res_blocks: Number of ``ResidualBlock`` stages in the trunk.
        upscaling_factor: Total spatial upscaling. Must be a power of two
            (1, 2, 4, 8, ...) because each upsampling stage doubles resolution.

    Returns:
        A Keras ``Model`` named ``'Generator'`` accepting inputs of shape
        ``(None, None, 3)``.

    Raises:
        ValueError: If ``upscaling_factor`` is not a positive power of two.
    """
    # Validate up front: truncating log2 would otherwise silently build a
    # smaller upscaler (e.g. factor 3 -> a single 2x stage).
    num_upsample = int(round(math.log2(upscaling_factor)))
    if num_upsample < 0 or 2 ** num_upsample != upscaling_factor:
        raise ValueError(
            f"upscaling_factor must be a power of two >= 1, got {upscaling_factor!r}"
        )

    inputs = layers.Input(shape=(None, None, 3))
    x = layers.Conv2D(64, 9, padding='same')(inputs)
    x_ = layers.PReLU(shared_axes=[1,2])(x)

    x_res = x_
    for _ in range(num_res_blocks):
        x_res = ResidualBlock(x_res, 64)

    x = layers.Conv2D(64, 3, padding='same')(x_res)
    x = layers.BatchNormalization()(x)
    # Long skip connection around the whole residual trunk.
    x = layers.Add()([x_, x])

    x_up = x
    for _ in range(num_upsample):
        # 256 channels = 64 output channels * (2 * 2) for the PixelShuffle(2) below.
        x_up = layers.Conv2D(256, 3, padding='same')(x_up)
        x_up = PixelShuffle(2)(x_up)
        x_up = layers.PReLU(shared_axes=[1,2])(x_up)

    # Force the output layer to float32 so the produced image is full
    # precision even when a float16 compute policy is active globally.
    outputs = layers.Conv2D(
        3, 9, padding='same', activation='tanh', dtype='float32'
    )(x_up)
    return Model(inputs, outputs, name='Generator')

# ---------------------------
# Discriminator
# ---------------------------
def build_discriminator(input_shape=(None, None, 3)):
    """Build the discriminator that scores images as real or generated.

    Args:
        input_shape: Shape of a single input image (height, width, channels).

    Returns:
        A Keras ``Model`` named ``'Discriminator'`` whose output is a single
        un-activated logit per image.
    """
    def disc_block(x, filters, stride):
        """Conv(3x3, given stride) -> BatchNorm -> LeakyReLU(0.2)."""
        x = layers.Conv2D(filters, 3, strides=stride, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)
        return x

    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, 3, strides=1, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # (filters, stride) per block: width doubles while stride-2 blocks halve
    # the spatial resolution.
    block_specs = (
        (64, 2),
        (128, 1), (128, 2),
        (256, 1), (256, 2),
        (512, 1), (512, 2),
    )
    for filters, stride in block_specs:
        x = disc_block(x, filters, stride)

    # Global pooling keeps the head independent of the input image size.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1024)(x)
    x = layers.LeakyReLU(0.2)(x)
    # Emit logits in float32 so the loss sees full-precision values even when
    # a float16 compute policy is active globally.
    outputs = layers.Dense(1, dtype='float32')(x)  # logits
    return Model(inputs, outputs, name='Discriminator')

# ---------------------------
# VGG19 feature extractor
# ---------------------------
def build_vgg19_feature_extractor():
    """Return a frozen VGG19 sub-model that outputs ``block5_conv4`` activations.

    The backbone is loaded with ImageNet weights, without its classification
    head, and accepts images of any spatial size.
    """
    backbone = VGG19(include_top=False, weights='imagenet', input_shape=(None, None, 3))
    deep_features = backbone.get_layer('block5_conv4').output
    extractor = Model(inputs=backbone.input, outputs=deep_features)
    extractor.trainable = False
    return extractor


In [12]:
import os
import time
import math
import glob
import random
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import VGG19
from tensorflow.image import psnr, ssim
from tqdm import tqdm 

# ---------------------------
# IMAGE UTILITIES
# ---------------------------
def load_image(path):
    """Read an image file and return it as an RGB NumPy array.

    Args:
        path: Path of the image file to open.

    Returns:
        ``np.ndarray`` produced from the image after conversion to RGB.
    """
    # The context manager closes the underlying file deterministically instead
    # of leaving it to garbage collection.
    with Image.open(path) as img:
        return np.array(img.convert('RGB'))

def preprocess(img):
    """Map pixel values from [0, 255] to [-1, 1] and return them as float32."""
    scaled = img / 127.5
    centered = scaled - 1.0
    return centered.astype(np.float32)

def deprocess(img):
    """Map values from [-1, 1] back to uint8 pixels in [0, 255].

    Args:
        img: Array-like of values nominally in [-1, 1]; out-of-range values
            are clipped.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``img``.
    """
    # Upcast first so low-precision inputs (e.g. float16) are scaled in float32.
    pixels = (np.asarray(img, dtype=np.float32) + 1.0) * 127.5
    # Round to nearest before the integer cast: a plain cast truncates, which
    # turns values such as 0.99999 into 0 and breaks the preprocess round trip.
    return np.rint(pixels).clip(0, 255).astype(np.uint8)

def make_pairs(lr_dir, hr_dir):
    """Pair every entry of ``lr_dir`` with the same-named entry of ``hr_dir``.

    Args:
        lr_dir: Directory holding the low-resolution files.
        hr_dir: Directory holding the high-resolution files.

    Returns:
        List of ``(lr_path, hr_path)`` tuples, ordered by sorted low-res path.
        Low-res entries without a counterpart in ``hr_dir`` are skipped.
    """
    candidates = (
        (lr_path, os.path.join(hr_dir, os.path.basename(lr_path)))
        for lr_path in sorted(glob.glob(os.path.join(lr_dir, '*')))
    )
    return [
        (lr_path, hr_path)
        for lr_path, hr_path in candidates
        if os.path.exists(hr_path)
    ]

# ---------------------------
# DATASET LOADER
# ---------------------------
def dataset_from_pairs(pairs, batch_size):
    """Create a shuffled, batched ``tf.data.Dataset`` of (low-res, high-res) images.

    Args:
        pairs: Sequence of ``(lr_path, hr_path)`` tuples.
        batch_size: Number of image pairs per batch.

    Returns:
        A dataset yielding ``(lr_batch, hr_batch)`` float32 tensors.

    Raises:
        ValueError: If ``pairs`` is empty.
    """
    if not len(pairs):
        raise ValueError("No image pairs found!")

    # Both elements are float32 RGB images of unconstrained height and width.
    image_spec = tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32)

    def _iter_examples():
        # Files are decoded lazily, one pair at a time, as the dataset is iterated.
        for lr_path, hr_path in pairs:
            yield preprocess(load_image(lr_path)), preprocess(load_image(hr_path))

    examples = tf.data.Dataset.from_generator(
        _iter_examples,
        output_signature=(image_spec, image_spec),
    )
    shuffled = examples.shuffle(buffer_size=max(1, len(pairs)))
    batched = shuffled.batch(batch_size)
    return batched.prefetch(tf.data.AUTOTUNE)

# ---------------------------
# LOSSES
# ---------------------------
# Shared loss objects, created once and reused by the loss functions below.
mse = tf.keras.losses.MeanSquaredError()   # used for the VGG feature distance
mae = tf.keras.losses.MeanAbsoluteError()  # used for the pixel-wise content loss
# from_logits=True: the inputs are raw, un-activated scores (no sigmoid applied).
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def content_loss(hr, sr):
    """Mean absolute error between the target ``hr`` and the prediction ``sr``."""
    pixel_error = mae(hr, sr)
    return pixel_error

def adversarial_loss(real_logits, fake_logits):
    """Compute binary cross-entropy GAN losses from discriminator logits.

    Args:
        real_logits: Discriminator logits for real images.
        fake_logits: Discriminator logits for generated images.

    Returns:
        Tuple ``(g_loss, d_loss)``: the generator loss (generated images
        labelled 1) and the discriminator loss (real labelled 1 plus
        generated labelled 0).
    """
    real_targets = tf.ones_like(real_logits)
    fake_targets = tf.zeros_like(fake_logits)
    d_real_term = bce(real_targets, real_logits)
    d_fake_term = bce(fake_targets, fake_logits)
    generator_term = bce(tf.ones_like(fake_logits), fake_logits)
    return generator_term, d_real_term + d_fake_term

def perceptual_loss(vgg, hr, sr):
    """Mean squared error between VGG feature maps of ``hr`` and ``sr``.

    Args:
        vgg: Callable feature extractor applied to both images.
        hr: Target images scaled to [-1, 1].
        sr: Generated images scaled to [-1, 1].

    Returns:
        Scalar loss tensor.
    """
    def to_vgg_input(img):
        # Back to [0, 255] in float32, then apply the VGG19 ImageNet input
        # transform so the pretrained filters see the input distribution they
        # were trained on.
        pixels = (tf.cast(img, tf.float32) + 1.0) * 127.5
        return tf.keras.applications.vgg19.preprocess_input(pixels)

    hr_feat = vgg(to_vgg_input(hr))
    sr_feat = vgg(to_vgg_input(sr))
    return mse(hr_feat, sr_feat)

# ---------------------------
# TRAINING LOOP
# ---------------------------
def train(
    train_lr_dir='C:\\Users\\bhatt\\Machine Learning\\SolaRess\\new_dataset\\training\\low_res',
    train_hr_dir='C:\\Users\\bhatt\\Machine Learning\\SolaRess\\new_dataset\\training\\high_res',
    val_lr_dir='C:\\Users\\bhatt\\Machine Learning\\SolaRess\\new_dataset\\validation\\low_res',
    val_hr_dir='C:\\Users\\bhatt\\Machine Learning\\SolaRess\\new_dataset\\validation\\high_res',
    epochs=1,
    batch_size=2,
    lr_g=1e-4,
    lr_d=1e-4,
    checkpoint_dir='checkpoints',
    sample_dir='samples',
    upscaling_factor=4
):
    """Train the generator/discriminator pair and validate after each epoch.

    Each step updates the discriminator first and then the generator, whose
    loss is ``pixel + 1e-2 * perceptual + 1e-3 * adversarial``. After every
    epoch up to 10 validation pairs are scored with PSNR/SSIM, the first three
    super-resolved images are written to ``sample_dir``, and both models are
    saved to ``checkpoint_dir``.

    Args:
        train_lr_dir: Directory of low-resolution training images.
        train_hr_dir: Directory of high-resolution training images.
        val_lr_dir: Directory of low-resolution validation images.
        val_hr_dir: Directory of high-resolution validation images.
        epochs: Number of passes over the training set.
        batch_size: Image pairs per training batch.
        lr_g: Adam learning rate for the generator.
        lr_d: Adam learning rate for the discriminator.
        checkpoint_dir: Output directory for model checkpoints (created if missing).
        sample_dir: Output directory for validation samples (created if missing).
        upscaling_factor: Upscaling factor passed to ``build_generator``.

    Raises:
        ValueError: Propagated from ``dataset_from_pairs`` when no training
            pairs are found.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)

    # Build models
    G = build_generator(upscaling_factor=upscaling_factor)
    D = build_discriminator()
    VGG = build_vgg19_feature_extractor()

    # Optimizers
    # NOTE(review): a global 'mixed_float16' policy is set earlier in this
    # notebook, but gradients below are taken without loss scaling; consider
    # wrapping these in a LossScaleOptimizer -- confirm against the installed
    # Keras version's API.
    opt_g = tf.keras.optimizers.Adam(lr_g, beta_1=0.9)
    opt_d = tf.keras.optimizers.Adam(lr_d, beta_1=0.9)

    # Load dataset
    train_pairs = make_pairs(train_lr_dir, train_hr_dir)
    val_pairs = make_pairs(val_lr_dir, val_hr_dir)
    train_dataset = dataset_from_pairs(train_pairs, batch_size)

    # The dataset keeps the final partial batch, so round up. Passed to tqdm
    # because a generator-backed dataset cannot report its own length.
    steps_per_epoch = math.ceil(len(train_pairs) / batch_size)

    for epoch in range(1, epochs+1):
        start_time = time.time()
        running_g_loss = 0.0
        running_d_loss = 0.0
        # Explicit step counter: tqdm's ``n`` attribute is only refreshed when
        # the bar redraws, so it is not a reliable divisor for the averages.
        num_batches = 0

        # Training progress bar
        train_iter = tqdm(
            train_dataset,
            total=steps_per_epoch,
            desc=f"Epoch {epoch}/{epochs} [Training]",
            unit="batch",
        )
        for lr_batch, hr_batch in train_iter:
            # Train Discriminator
            with tf.GradientTape() as tape_d:
                sr_batch = G(lr_batch, training=True)
                real_logits = D(hr_batch, training=True)
                fake_logits = D(sr_batch, training=True)
                _, d_loss_val = adversarial_loss(real_logits, fake_logits)

            grads_d = tape_d.gradient(d_loss_val, D.trainable_variables)
            opt_d.apply_gradients(zip(grads_d, D.trainable_variables))

            # Train Generator (fresh forward pass against the updated discriminator)
            with tf.GradientTape() as tape_g:
                sr_batch = G(lr_batch, training=True)
                fake_logits = D(sr_batch, training=True)
                pixel_loss = content_loss(hr_batch, sr_batch)
                perc_loss = perceptual_loss(VGG, hr_batch, sr_batch)
                # Reuse the shared loss object instead of building one per step.
                adv_loss = bce(tf.ones_like(fake_logits), fake_logits)
                g_loss_val = pixel_loss + 1e-2*perc_loss + 1e-3*adv_loss

            grads_g = tape_g.gradient(g_loss_val, G.trainable_variables)
            opt_g.apply_gradients(zip(grads_g, G.trainable_variables))

            # Update running losses
            running_g_loss += float(g_loss_val.numpy())
            running_d_loss += float(d_loss_val.numpy())
            num_batches += 1

            # Update tqdm postfix
            train_iter.set_postfix({
                "G_loss": running_g_loss/num_batches,
                "D_loss": running_d_loss/num_batches
            })

        # Validation progress bar (small sample)
        val_iter = tqdm(val_pairs[:10], desc=f"Epoch {epoch}/{epochs} [Validation]", unit="img")
        psnr_list, ssim_list = [], []
        for i, (lr_path, hr_path) in enumerate(val_iter):
            lr_img = preprocess(load_image(lr_path))[np.newaxis, ...]
            hr_img = load_image(hr_path)
            sr_img = G(lr_img, training=False).numpy()[0]
            sr_img_np = deprocess(sr_img)

            # Metrics need identical shapes; resize the prediction to the
            # reference size when they differ.
            if sr_img_np.shape != hr_img.shape:
                sr_img_np = np.array(
                    Image.fromarray(sr_img_np).resize((hr_img.shape[1], hr_img.shape[0]), Image.BICUBIC)
                )

            psnr_list.append(tf.image.psnr(sr_img_np, hr_img, max_val=255).numpy())
            ssim_list.append(tf.image.ssim(sr_img_np, hr_img, max_val=255).numpy())

            # Save first 3 validation samples
            if i < 3:
                Image.fromarray(sr_img_np).save(os.path.join(sample_dir, f'epoch{epoch}_sample{i}.png'))

            val_iter.set_postfix({
                "PSNR": np.mean(psnr_list),
                "SSIM": np.mean(ssim_list)
            })

        # Epoch summary. Averages use the counted batches (len() is undefined
        # for a generator-backed dataset); metrics are NaN without validation data.
        mean_g_loss = running_g_loss / max(num_batches, 1)
        mean_d_loss = running_d_loss / max(num_batches, 1)
        mean_psnr = np.mean(psnr_list) if psnr_list else float('nan')
        mean_ssim = np.mean(ssim_list) if ssim_list else float('nan')
        print(f"\nEpoch {epoch} complete! Time: {time.time()-start_time:.1f}s | "
            f"Train G_loss={mean_g_loss:.4f}, "
            f"D_loss={mean_d_loss:.4f} | "
            f"Validation PSNR={mean_psnr:.2f}, SSIM={mean_ssim:.4f}")

        # Save checkpoints
        G.save(os.path.join(checkpoint_dir, f'G_epoch{epoch}.h5'))
        D.save(os.path.join(checkpoint_dir, f'D_epoch{epoch}.h5'))

    print("Training Complete!")

In [None]:
# Run training with the default arguments defined on train().
train()




Epoch 1/1 [Training]: 0batch [00:00, ?batch/s]