# Testing


In [None]:
import tensorflow as tf
import os
from datetime import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

def denormalize(img):
    """Convert from [-1, 1] to [0, 255] uint8."""
    img = (img + 1.0) * 127.5
    return tf.cast(tf.clip_by_value(img, 0, 255), tf.uint8)

def save_generated_image(image_tensor, epoch, index=0, output_dir="./generated"):
    os.makedirs(output_dir, exist_ok=True)
    image = denormalize(image_tensor)
    image_np = image.numpy()

    # If batch, select one image
    if len(image_np.shape) == 4:
        image_np = image_np[index]  # shape: (H, W, C)

    img_pil = Image.fromarray(image_np)
    img_pil.save(f"{output_dir}/gen_epoch_{epoch}_sample_{index}.png")


# ========== GPU Memory Growth ==========
# gpus = tf.config.list_physical_devices('GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(gpu, True)

# ========== Configuration ==========
class Experiment:
    def __init__(self):
        self.IMG_HEIGHT = 256
        self.IMG_WIDTH = 256
        self.INPUT_CHANNELS = 3  # Sentinel-2
        self.OUTPUT_CHANNELS = 3  # Sentinel-3

        self.GEN_LOSS = {"gen_ssim": 50, "gen_l1": 50, "gen_wstein": 100}
        self.DISC_LOSS = {"disc_bce": 1}

        self.output = type('', (), {})()
        self.output.LOGS = f"./logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# ========== GAN Model ==========
class ARISGAN:
    def __init__(self, experiment):
        self.exp = experiment
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.configure_optimizers()

    def build_generator(self):
        inputs = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.INPUT_CHANNELS])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
        x = tf.keras.layers.LeakyReLU()(x)

        for _ in range(4):
            x = self.residual_block(x, 64)

        x = tf.keras.layers.Conv2DTranspose(self.exp.OUTPUT_CHANNELS, 4, strides=2, padding='same', activation='tanh')(x)
        return tf.keras.Model(inputs=inputs, outputs=x)

    def build_discriminator(self):
        inp = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.INPUT_CHANNELS])
        tar = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.OUTPUT_CHANNELS])
        x = tf.keras.layers.Concatenate()([inp, tar])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1)(x)
        return tf.keras.Model(inputs=[inp, tar], outputs=x)

    def residual_block(self, x, filters):
        init = x
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.keras.layers.Add()([init, x])

    def configure_optimizers(self):
        self.g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# ========== Dataset ==========
def load_image(path, channels):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=channels)
    img = tf.image.resize(img, [256, 256])
    img = tf.cast(img, tf.float32) / 127.5 - 1.0  # Normalize to [-1, 1]
    return img

def parse_images(low_path, high_path):
    return load_image(low_path, 3), load_image(high_path, 3)

def load_your_dataset(low_res_dir='./low_res', high_res_dir='./high_res'):
    input_paths = []
    target_paths = []

    for filename in os.listdir(low_res_dir):
        if filename.endswith('.png'):
            parts = filename.split('_')
            base_id = parts[1] + "_" + parts[2].split('.')[0]
            low_path = os.path.join(low_res_dir, filename)
            high_filename = f"PLA4MS_{base_id}_AMS.png"
            high_path = os.path.join(high_res_dir, high_filename)

            if os.path.exists(high_path):
                input_paths.append(low_path)
                target_paths.append(high_path)

    dataset = tf.data.Dataset.from_tensor_slices((input_paths, target_paths))
    dataset = dataset.map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

# ========== Loss + Metrics ==========
def calculate_generator_loss(disc_fake, gen_output, target):
    adv_loss = -tf.reduce_mean(disc_fake)
    ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gen_output, target, max_val=2.0))
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return adv_loss * 100 + ssim_loss * 50 + l1_loss * 50

def calculate_discriminator_loss(disc_real, disc_fake):
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(disc_real), logits=disc_real))
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(disc_fake), logits=disc_fake))
    return real_loss + fake_loss

def compute_psnr(target, generated):
    target = (target + 1) / 2
    generated = (generated + 1) / 2
    return tf.image.psnr(target, generated, max_val=1.0)

# ========== Helper ==========
def apply_gradients(loss, variables, optimizer, tape):
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

def save_checkpoint(model, epoch):
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.generator.save_weights(f"{checkpoint_dir}/generator_{epoch}.weights.h5")
    model.discriminator.save_weights(f"{checkpoint_dir}/discriminator_{epoch}.weights.h5")

def log_metrics(gen_loss, disc_loss, psnr, epoch):
    print(f"Epoch {epoch}: Gen Loss = {gen_loss:.4f}, Disc Loss = {disc_loss:.4f}, PSNR = {psnr:.2f} dB")

# ========== Training Step ==========
@tf.function
def train_step(input_images, target_images, gan):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = gan.generator(input_images, training=True)
        disc_real = gan.discriminator([input_images, target_images], training=True)
        disc_fake = gan.discriminator([input_images, gen_output], training=True)

        gen_loss = calculate_generator_loss(disc_fake, gen_output, target_images)
        disc_loss = calculate_discriminator_loss(disc_real, disc_fake)

    apply_gradients(gen_loss, gan.generator.trainable_variables, gan.g_optimizer, gen_tape)
    apply_gradients(disc_loss, gan.discriminator.trainable_variables, gan.d_optimizer, disc_tape)

    psnr_val = compute_psnr(target_images, gen_output)
    return gen_loss, disc_loss, tf.reduce_mean(psnr_val)

# ========== Training Loop ==========
def train_arisgan(dataset, epochs):
    exp = Experiment()
    gan = ARISGAN(exp)

    print("Generator:")
    gan.generator.summary()
    print("Discriminator:")
    gan.discriminator.summary()

    train_dataset = dataset.shuffle(100).batch(2).prefetch(tf.data.AUTOTUNE)

    for epoch in range(epochs):
        total_psnr = 0
        steps = 0

        for input_images, target_images in train_dataset:
            gen_loss, disc_loss, psnr = train_step(input_images, target_images, gan)
            total_psnr += psnr.numpy()
            steps += 1

        avg_psnr = total_psnr / steps

        if epoch % 2 == 0:
            save_checkpoint(gan, epoch)
        log_metrics(gen_loss, disc_loss, avg_psnr, epoch)

# ========== Run ==========
if __name__ == "__main__":
    train_data = load_your_dataset('/home/nitin/acps/3000/Sentinel', '/home/nitin/acps/3000/Planetscope')
    train_arisgan(train_data, epochs=10)


Generator:


Discriminator:


Epoch 0: Gen Loss = 297.8257, Disc Loss = 0.1429, PSNR = 17.31 dB
Epoch 1: Gen Loss = 240.5256, Disc Loss = 0.1436, PSNR = 17.86 dB
Epoch 2: Gen Loss = 252.1759, Disc Loss = 0.3441, PSNR = 17.64 dB
Epoch 3: Gen Loss = 779.4199, Disc Loss = 0.3794, PSNR = 17.76 dB
Epoch 4: Gen Loss = 793.1877, Disc Loss = 0.0339, PSNR = 17.91 dB
Epoch 5: Gen Loss = -2.5731, Disc Loss = 1.5261, PSNR = 18.01 dB
Epoch 6: Gen Loss = 1165.1361, Disc Loss = 0.0003, PSNR = 18.29 dB
Epoch 7: Gen Loss = 974.3436, Disc Loss = 0.2995, PSNR = 17.84 dB
Epoch 8: Gen Loss = 1066.7953, Disc Loss = 0.9187, PSNR = 18.40 dB
Epoch 9: Gen Loss = 1605.1215, Disc Loss = 0.0010, PSNR = 18.71 dB


In [7]:
import tensorflow as tf
import os
from datetime import datetime
import numpy as np
import tensorflow as tf
#import matplotlib.pyplot as plt
from PIL import Image

def denormalize(img):
    """Convert from [-1, 1] to [0, 255] uint8."""
    img = (img + 1.0) * 127.5
    return tf.cast(tf.clip_by_value(img, 0, 255), tf.uint8)

def save_generated_image(image_tensor, epoch, index=0, output_dir="./generated"):
    os.makedirs(output_dir, exist_ok=True)
    image = denormalize(image_tensor)
    image_np = image.numpy()

    # If batch, select one image
    if len(image_np.shape) == 4:
        image_np = image_np[index]  # shape: (H, W, C)

    img_pil = Image.fromarray(image_np)
    img_pil.save(f"{output_dir}/gen_epoch_{epoch}_sample_{index}.png")


# ========== GPU Memory Growth ==========
# gpus = tf.config.list_physical_devices('GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(gpu, True)

# ========== Configuration ==========
class Experiment:
    def __init__(self):
        self.IMG_HEIGHT = 256
        self.IMG_WIDTH = 256
        self.INPUT_CHANNELS = 3  # Sentinel-2
        self.OUTPUT_CHANNELS = 3  # Sentinel-3

        self.GEN_LOSS = {"gen_ssim": 50, "gen_l1": 50, "gen_wstein": 100}
        self.DISC_LOSS = {"disc_bce": 1}

        self.output = type('', (), {})()
        self.output.LOGS = f"./logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# ========== GAN Model ==========
class ARISGAN:
    def __init__(self, experiment):
        self.exp = experiment
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.configure_optimizers()

    def build_generator(self):
        inputs = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.INPUT_CHANNELS])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
        x = tf.keras.layers.LeakyReLU()(x)

        for _ in range(4):
            x = self.residual_block(x, 64)

        x = tf.keras.layers.Conv2DTranspose(self.exp.OUTPUT_CHANNELS, 4, strides=2, padding='same', activation='tanh')(x)
        return tf.keras.Model(inputs=inputs, outputs=x)

    def build_discriminator(self):
        inp = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.INPUT_CHANNELS])
        tar = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.OUTPUT_CHANNELS])
        x = tf.keras.layers.Concatenate()([inp, tar])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1)(x)
        return tf.keras.Model(inputs=[inp, tar], outputs=x)

    def residual_block(self, x, filters):
        init = x
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.keras.layers.Add()([init, x])

    def configure_optimizers(self):
        self.g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# ========== Dataset ==========
def load_image(path, channels):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=channels)
    img = tf.image.resize(img, [256, 256])
    img = tf.cast(img, tf.float32) / 127.5 - 1.0  # Normalize to [-1, 1]
    return img

def parse_images(low_path, high_path):
    return load_image(low_path, 3), load_image(high_path, 3)

def load_your_dataset(low_res_dir='./low_res', high_res_dir='./high_res'):
    input_paths = []
    target_paths = []
    filenames = []  # To store filenames

    for filename in os.listdir(low_res_dir):
        if filename.endswith('.png'):
            parts = filename.split('_')
            base_id = parts[1] + "_" + parts[2].split('.')[0]
            low_path = os.path.join(low_res_dir, filename)
            high_filename = f"PLA4MS_{base_id}_AMS.png"
            high_path = os.path.join(high_res_dir, high_filename)

            if os.path.exists(high_path):
                input_paths.append(low_path)
                target_paths.append(high_path)
                filenames.append(filename)  # Add the filename to the list

    dataset = tf.data.Dataset.from_tensor_slices((input_paths, target_paths, filenames))  # Include filenames
    dataset = dataset.map(parse_images_and_filenames, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

# Update parse function to include filenames
def parse_images_and_filenames(low_path, high_path, filename):
    input_image = load_image(low_path, 3)
    target_image = load_image(high_path, 3)
    return input_image, target_image, filename


# ========== Loss + Metrics ==========
def calculate_generator_loss(disc_fake, gen_output, target):
    adv_loss = -tf.reduce_mean(disc_fake)
    ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gen_output, target, max_val=2.0))
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return adv_loss * 100 + ssim_loss * 50 + l1_loss * 50

def calculate_discriminator_loss(disc_real, disc_fake):
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(disc_real), logits=disc_real))
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(disc_fake), logits=disc_fake))
    return real_loss + fake_loss

def compute_psnr(target, generated):
    target = (target + 1) / 2
    generated = (generated + 1) / 2
    return tf.image.psnr(target, generated, max_val=1.0)

# ========== Helper ==========
def apply_gradients(loss, variables, optimizer, tape):
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

def save_checkpoint(model, epoch):
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.generator.save_weights(f"{checkpoint_dir}/generator_{epoch}.weights.h5")
    model.discriminator.save_weights(f"{checkpoint_dir}/discriminator_{epoch}.weights.h5")

def log_metrics(gen_loss, disc_loss, psnr, epoch):
    print(f"Epoch {epoch}: Gen Loss = {gen_loss:.4f}, Disc Loss = {disc_loss:.4f}, PSNR = {psnr:.2f} dB")

# ========== Training Step ==========
@tf.function
def train_step(input_images, target_images, gan):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = gan.generator(input_images, training=True)
        disc_real = gan.discriminator([input_images, target_images], training=True)
        disc_fake = gan.discriminator([input_images, gen_output], training=True)

        gen_loss = calculate_generator_loss(disc_fake, gen_output, target_images)
        disc_loss = calculate_discriminator_loss(disc_real, disc_fake)

    apply_gradients(gen_loss, gan.generator.trainable_variables, gan.g_optimizer, gen_tape)
    apply_gradients(disc_loss, gan.discriminator.trainable_variables, gan.d_optimizer, disc_tape)

    psnr_val = compute_psnr(target_images, gen_output)
    return gen_loss, disc_loss, tf.reduce_mean(psnr_val)

# ========== Training Loop ==========
def train_arisgan(dataset, epochs):
    exp = Experiment()
    gan = ARISGAN(exp)

    # print("Generator:")
    # gan.generator.summary()
    # print("Discriminator:")
    # gan.discriminator.summary()
    
    train_dataset = dataset.shuffle(100).batch(2).prefetch(tf.data.AUTOTUNE)

    for epoch in range(epochs):
        total_psnr = 0
        steps = 0

        for input_images, target_images, filenames  in train_dataset:
            gen_loss, disc_loss, psnr = train_step(input_images, target_images, gan)
            total_psnr += psnr.numpy()
            steps += 1

        avg_psnr = total_psnr / steps

        if epoch % 2 == 0 and steps == 1500:  # Save output from the first batch only
            print(f"Saving image for epoch {epoch}")
            print(filenames[0])
            gen_output = gan.generator(input_images, training=False)
            save_generated_image(gen_output, epoch, index=0)
        log_metrics(gen_loss, disc_loss, avg_psnr, epoch)

# ========== Run ==========
if __name__ == "__main__":
    train_data = load_your_dataset('/home/nitin/acps/3000/Sentinel', '/home/nitin/acps/3000/Planetscope')
    train_arisgan(train_data, epochs=10)


Saving image for epoch 0
tf.Tensor(b'SENT4MS_102_20231208.png', shape=(), dtype=string)
Epoch 0: Gen Loss = 150.9766, Disc Loss = 3.6308, PSNR = 17.28 dB
Epoch 1: Gen Loss = 300.4625, Disc Loss = 0.0955, PSNR = 18.18 dB
Saving image for epoch 2
tf.Tensor(b'SENT4MS_127_20220424.png', shape=(), dtype=string)
Epoch 2: Gen Loss = 338.5559, Disc Loss = 0.0583, PSNR = 18.04 dB


2025-04-27 11:45:09.496709: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 3: Gen Loss = 713.4654, Disc Loss = 0.3328, PSNR = 17.54 dB
Saving image for epoch 4
tf.Tensor(b'SENT4MS_126_20231210.png', shape=(), dtype=string)
Epoch 4: Gen Loss = 876.0567, Disc Loss = 0.2745, PSNR = 17.54 dB
Epoch 5: Gen Loss = 346.6915, Disc Loss = 2.1179, PSNR = 17.98 dB
Saving image for epoch 6
tf.Tensor(b'SENT4MS_110_20210228.png', shape=(), dtype=string)
Epoch 6: Gen Loss = 1200.2107, Disc Loss = 0.0000, PSNR = 17.80 dB
Epoch 7: Gen Loss = 1201.7007, Disc Loss = 0.0104, PSNR = 17.43 dB
Saving image for epoch 8
tf.Tensor(b'SENT4MS_107_20220514.png', shape=(), dtype=string)
Epoch 8: Gen Loss = 1139.0409, Disc Loss = 0.0008, PSNR = 17.94 dB
Epoch 9: Gen Loss = 128.5554, Disc Loss = 0.6073, PSNR = 18.57 dB


# Upscaling testing

In [6]:
import tensorflow as tf
import os
from datetime import datetime
import numpy as np
from PIL import Image

# ========== Helper Functions ==========
def denormalize(img):
    img = (img + 1.0) * 127.5
    return tf.cast(tf.clip_by_value(img, 0, 255), tf.uint8)

def save_generated_image(image_tensor, epoch, index=0, output_dir="./generated"):
    os.makedirs(output_dir, exist_ok=True)
    image = denormalize(image_tensor)
    image_np = image.numpy()
    if len(image_np.shape) == 4:
        image_np = image_np[index]
    img_pil = Image.fromarray(image_np)
    img_pil.save(f"{output_dir}/gen_epoch_{epoch}_sample_{index}.png")

# ========== Configuration ==========
class Experiment:
    def __init__(self):
        self.IMG_HEIGHT = 128
        self.IMG_WIDTH = 128
        self.OUTPUT_HEIGHT = 512
        self.OUTPUT_WIDTH = 512
        self.INPUT_CHANNELS = 3
        self.OUTPUT_CHANNELS = 3
        self.GEN_LOSS = {"gen_ssim": 50, "gen_l1": 50, "gen_wstein": 100}
        self.DISC_LOSS = {"disc_bce": 1}
        self.output = type('', (), {})()
        self.output.LOGS = f"./logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# ========== GAN Model ==========
class ARISGAN:
    def __init__(self, experiment):
        self.exp = experiment
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
        self.configure_optimizers()

    def build_generator(self):
        inputs = tf.keras.Input(shape=[self.exp.IMG_HEIGHT, self.exp.IMG_WIDTH, self.exp.INPUT_CHANNELS])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
        x = tf.keras.layers.LeakyReLU()(x)

        for _ in range(4):
            x = self.residual_block(x, 64)

        # Upsample to 256x256
        x = tf.keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)

        # Upsample to 512x512
        x = tf.keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.ReLU()(x)

        # Final layer for RGB
        x = tf.keras.layers.Conv2D(self.exp.OUTPUT_CHANNELS, 3, padding='same', activation='tanh')(x)
        return tf.keras.Model(inputs=inputs, outputs=x)

    def build_discriminator(self):
        
        inp = tf.keras.Input(shape=[self.exp.OUTPUT_HEIGHT, self.exp.OUTPUT_WIDTH, self.exp.OUTPUT_CHANNELS])
        x = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same')(inp)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)
        x = tf.keras.layers.LeakyReLU()(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1)(x)
        return tf.keras.Model(inputs=inp, outputs=x)



    def residual_block(self, x, filters):
        init = x
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Conv2D(filters, 3, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.keras.layers.Add()([init, x])

    def configure_optimizers(self):
        self.g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# ========== Dataset ==========
def load_image(path, channels, size):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=channels)
    img = tf.image.resize(img, size)
    img = tf.cast(img, tf.float32) / 127.5 - 1.0
    return img

def parse_images(low_path, high_path):
    input_image = load_image(low_path, 3, [128, 128])     # Low-res input
    target_image = load_image(high_path, 3, [512, 512])   # High-res target
    return input_image, target_image

def load_your_dataset(low_res_dir='./low_res', high_res_dir='./high_res'):
    input_paths = []
    target_paths = []
    for filename in os.listdir(low_res_dir):
        if filename.endswith('.png'):
            parts = filename.split('_')
            base_id = parts[1] + "_" + parts[2].split('.')[0]
            low_path = os.path.join(low_res_dir, filename)
            high_filename = f"PLA4MS_{base_id}_AMS.png"
            high_path = os.path.join(high_res_dir, high_filename)
            if os.path.exists(high_path):
                input_paths.append(low_path)
                target_paths.append(high_path)
    dataset = tf.data.Dataset.from_tensor_slices((input_paths, target_paths))
    dataset = dataset.map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

# ========== Loss + Metrics ==========
def calculate_generator_loss(disc_fake, gen_output, target):
    adv_loss = -tf.reduce_mean(disc_fake)
    ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gen_output, target, max_val=2.0))
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return adv_loss * 100 + ssim_loss * 50 + l1_loss * 50

def calculate_discriminator_loss(disc_real, disc_fake):
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(disc_real), logits=disc_real))
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(disc_fake), logits=disc_fake))
    return real_loss + fake_loss

def compute_psnr(target, generated):
    target = (target + 1) / 2
    generated = (generated + 1) / 2
    return tf.image.psnr(target, generated, max_val=1.0)

# ========== Training Helpers ==========
def apply_gradients(loss, variables, optimizer, tape):
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

def save_checkpoint(model, epoch):
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.generator.save_weights(f"{checkpoint_dir}/generator_{epoch}.weights.h5")
    model.discriminator.save_weights(f"{checkpoint_dir}/discriminator_{epoch}.weights.h5")

def log_metrics(gen_loss, disc_loss, psnr, epoch):
    print(f"Epoch {epoch}: Gen Loss = {gen_loss:.4f}, Disc Loss = {disc_loss:.4f}, PSNR = {psnr:.2f} dB")

@tf.function
def train_step(input_images, target_images, gan):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = gan.generator(input_images, training=True)
        disc_real = gan.discriminator(target_images, training=True)
        disc_fake = gan.discriminator(gen_output, training=True)
        gen_loss = calculate_generator_loss(disc_fake, gen_output, target_images)
        disc_loss = calculate_discriminator_loss(disc_real, disc_fake)
    apply_gradients(gen_loss, gan.generator.trainable_variables, gan.g_optimizer, gen_tape)
    apply_gradients(disc_loss, gan.discriminator.trainable_variables, gan.d_optimizer, disc_tape)
    psnr_val = compute_psnr(target_images, gen_output)
    return gen_loss, disc_loss, tf.reduce_mean(psnr_val)

# ========== Training Loop ==========
def train_arisgan(dataset, epochs):
    exp = Experiment()
    gan = ARISGAN(exp)
    print("Generator:")
    gan.generator.summary()
    print("Discriminator:")
    gan.discriminator.summary()

    train_dataset = dataset.shuffle(100).batch(2).prefetch(tf.data.AUTOTUNE)
    for epoch in range(epochs):
        total_psnr = 0
        steps = 0
        for input_images, target_images in train_dataset:
            gen_loss, disc_loss, psnr = train_step(input_images, target_images, gan)
            total_psnr += psnr.numpy()
            steps += 1
        avg_psnr = total_psnr / steps
        if epoch % 2 == 0:
            print(f"Saving image for epoch {epoch}")
            gen_output = gan.generator(input_images, training=False)
            save_generated_image(gen_output, epoch, index=0)
        log_metrics(gen_loss, disc_loss, avg_psnr, epoch)

# ========== Run ==========
if __name__ == "__main__":
    train_data = load_your_dataset('/home/nitin/acps/00deleteit/Sentinel', '/home/nitin/acps/00deleteit/Planetscope')
    train_arisgan(train_data, epochs=10)


Generator:


Discriminator:


ValueError: in user code:

    File "/tmp/ipykernel_3029105/3367183034.py", line 159, in train_step  *
        disc_fake = gan.discriminator(gen_output, training=True)
    File "/home/nitin/miniforge3/envs/tensorflow/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/home/nitin/miniforge3/envs/tensorflow/lib/python3.12/site-packages/keras/src/layers/input_spec.py", line 245, in assert_input_compatibility
        raise ValueError(

    ValueError: Input 0 of layer "functional_7" is incompatible with the layer: expected shape=(None, 512, 512, 3), found shape=(2, 256, 256, 3)
