# In [None]:  (Jupyter cell marker)
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, LeakyReLU, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import cifar10 # You can replace this with your satellite image dataset

# Generator
def build_generator(input_shape, num_residual_blocks=16, filters=64, out_channels=None):
    """Build a fully convolutional residual generator.

    The network keeps the spatial resolution of its input (every convolution
    uses stride 1 with 'same' padding). It has these stages:
    a stem convolution, a stack of ``residual_block`` units, a post-residual
    convolution + batch norm joined to the stem output by a long skip
    connection, and a final ``tanh`` projection to image channels.

    Args:
        input_shape: Shape of one input sample, e.g. ``(256, 256, 3)``.
        num_residual_blocks: Number of residual blocks in the trunk (default 16).
        filters: Feature width used throughout the trunk (default 64).
        out_channels: Channels in the generated image. Defaults to
            ``input_shape[-1]`` so the output has the same layout as the input
            (and as the discriminator's expected input).

    Returns:
        A ``Model`` whose output has the input's height/width,
        ``out_channels`` channels, and values in [-1, 1] (tanh).
    """
    if out_channels is None:
        out_channels = input_shape[-1]

    inputs = Input(shape=input_shape)

    # Stem: lift the image into `filters` feature channels.
    x = Conv2D(filters, (3, 3), strides=1, padding='same')(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    # Long skip source. It must be the stem output, not `inputs`: `inputs` has
    # input_shape[-1] channels while the trunk has `filters`, and Add()
    # requires identical shapes.
    skip = x

    # Residual trunk (each block preserves shape).
    for _ in range(num_residual_blocks):
        x = residual_block(x, filters)

    # Post-residual conv + BN, then merge with the stem features.
    x = Conv2D(filters, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([skip, x])

    # Project to image channels; tanh matches the [-1, 1] data scaling.
    outputs = Conv2D(out_channels, (3, 3), strides=1, padding='same', activation='tanh')(x)

    return Model(inputs, outputs)

def residual_block(x, filters):
    """Apply a two-convolution residual unit and add the result back onto ``x``.

    The residual branch applies these layers in order:
    Conv(3x3) -> BatchNorm -> LeakyReLU(0.2) -> Conv(3x3) -> BatchNorm.
    Its output is summed with the unchanged input. Because the sum is an
    identity ``Add``, ``x`` must already have ``filters`` channels.
    """
    shortcut = x
    h = x
    for stage in range(2):
        h = Conv2D(filters, (3, 3), strides=1, padding='same')(h)
        h = BatchNormalization()(h)
        # Only the first conv/BN pair is followed by a non-linearity.
        if stage == 0:
            h = LeakyReLU(alpha=0.2)(h)
    return Add()([shortcut, h])

# Discriminator
def build_discriminator(input_shape):
    """Build the critic network that scores an image with one scalar.

    Four downsampling stages (3x3 conv, stride 2, 'same' padding, each followed
    by LeakyReLU(0.2)) widen the features 64 -> 128 -> 256 -> 512. The result
    is flattened and mapped by a linear ``Dense(1)`` with no activation, so
    the score is unbounded.
    """
    image_in = Input(shape=input_shape)

    features = image_in
    for width in (64, 128, 256, 512):
        features = Conv2D(width, (3, 3), strides=2, padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    flat = tf.keras.layers.Flatten()(features)
    score = tf.keras.layers.Dense(1)(flat)
    return Model(image_in, score)

# Loss function
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the elementwise product of targets and critic scores.

    With targets of -1 for samples that should score high and +1 for samples
    that should score low, minimizing this value widens the gap between the
    two groups' scores.
    """
    signed_scores = tf.multiply(y_true, y_pred)
    return tf.reduce_mean(signed_scores)

# Shape of one generator input and critic input sample: (height, width, channels).
input_shape = (256, 256, 3)

# Adam hyperparameters shared by both optimizers. Pass `learning_rate`: the
# legacy `lr` keyword is deprecated in recent Keras and rejected by Keras 3.
LEARNING_RATE = 0.0002
BETA_1 = 0.5

# Build and compile the discriminator (critic) while it is still trainable,
# so that discriminator.train_on_batch updates its weights.
discriminator = build_discriminator(input_shape)
discriminator.compile(
    loss=wasserstein_loss,
    optimizer=Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1),
)

# Build the generator
generator = build_generator(input_shape)

# Combined model: generator -> frozen critic. Freezing the critic before this
# compile means combined.train_on_batch is meant to update only the generator.
# NOTE(review): this relies on the classic Keras 2 behaviour, where trainable
# state is fixed per model at compile time. Under Keras 3, confirm the critic
# still updates.
discriminator.trainable = False
input_image = Input(shape=input_shape)
generated_image = generator(input_image)
validity = discriminator(generated_image)
combined = Model(input_image, validity)
combined.compile(
    loss=wasserstein_loss,
    optimizer=Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1),
)

# Train the GAN

# Directory that sampled generator outputs are written to.
SAMPLE_DIR = "generated_images"
# WGAN weight-clipping bound. Clipping the critic's weights keeps it
# (approximately) Lipschitz; without a constraint the Wasserstein loss is
# unbounded and the critic can drive it to -infinity by scaling its outputs.
CLIP_VALUE = 0.01


def save_generated_images(epoch, num_images=5, output_dir=SAMPLE_DIR):
    """Sample the generator on Gaussian noise and write each output as a PNG.

    Files are named ``{output_dir}/{epoch}_{i}.png``. Generator outputs (tanh,
    range [-1, 1]) are mapped to [0, 1], clipped, then quantized to uint8.

    Args:
        epoch: Training step used as the file-name prefix.
        num_images: Number of samples to write (default 5).
        output_dir: Destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    noise = np.random.normal(0, 1, (num_images, *input_shape)).astype("float32")
    images = np.clip(generator.predict(noise, verbose=0) * 0.5 + 0.5, 0.0, 1.0)
    for i, image in enumerate(images):
        # tf.io is used instead of matplotlib, which this file never imports.
        png = tf.io.encode_png(np.round(image * 255).astype(np.uint8))
        tf.io.write_file(os.path.join(output_dir, f"{epoch}_{i}.png"), png)


# Load your satellite image dataset or replace this with your data loading mechanism.
(x_train, _), (_, _) = cifar10.load_data()
# Scale uint8 pixels to [-1, 1] to match the generator's tanh output.
# Cast first: float32 halves memory compared with the float64 that plain
# division would produce.
x_train = x_train.astype("float32") / 127.5 - 1.0

epochs = 10000  # number of training steps; each step uses one mini-batch
batch_size = 32

# Critic targets: real = -1, fake = +1. The generator is trained toward the
# "real" target (-1).
real_labels = -np.ones((batch_size, 1), dtype="float32")
fake_labels = np.ones((batch_size, 1), dtype="float32")
target_hw = tuple(input_shape[:2])

for epoch in range(epochs):
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images = x_train[idx]
    # The data (e.g. 32x32 CIFAR-10) may not match the models' input size.
    # Resize per batch: resizing all 50k images to 256x256 up front would
    # need roughly 39 GB of float32.
    if real_images.shape[1:3] != target_hw:
        real_images = tf.image.resize(real_images, target_hw).numpy()

    # Train discriminator (critic)
    noise = np.random.normal(0, 1, (batch_size, *input_shape)).astype("float32")
    fake_images = generator.predict(noise, verbose=0)
    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Enforce the WGAN Lipschitz constraint by clipping every critic weight.
    for layer in discriminator.layers:
        layer.set_weights([np.clip(w, -CLIP_VALUE, CLIP_VALUE) for w in layer.get_weights()])

    # Train generator (through the frozen critic)
    noise = np.random.normal(0, 1, (batch_size, *input_shape)).astype("float32")
    g_loss = combined.train_on_batch(noise, real_labels)

    # Print progress
    print(f"Epoch {epoch}/{epochs}, D Loss: {d_loss}, G Loss: {g_loss}")

    # Save generated images
    if epoch % 100 == 0:
        save_generated_images(epoch)
