In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models


In [None]:
def residual_block(x, filters, kernel_size=3, stride=1):
    """Residual block: Conv-BN-PReLU-Conv-BN added to an identity/projection shortcut.

    Args:
        x: Input Keras tensor (channels-last).
        filters: Number of output channels of both convolutions.
        kernel_size: Spatial size of the convolution kernels.
        stride: Stride of the FIRST convolution only; values > 1 downsample.

    Returns:
        Keras tensor equal to ``residual(x) + shortcut``. When ``stride == 1`` and
        ``x`` already has ``filters`` channels, the shortcut is ``x`` itself.
        Otherwise it is a strided 1x1 convolution + BN, so both branches have the
        same shape.
    """
    shortcut = x
    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.PReLU(shared_axes=[1, 2])(x)
    # Only the first conv strides. Striding twice would shrink the residual branch
    # by stride**2 while the shortcut shrinks by at most `stride`.
    x = layers.Conv2D(filters, kernel_size, strides=1, padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Project the shortcut whenever the residual branch changes spatial size or
    # channel count; otherwise `layers.add` fails on mismatched shapes.
    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, strides=stride, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    return layers.add([x, shortcut])


In [None]:
def build_generator(input_shape=(512, 512, 3), num_res_blocks=6):
    """Build the image-to-image generator: stem -> 2x downsample -> residual tower -> 2x upsample.

    Args:
        input_shape: (height, width, channels) of the input image. Height and width
            should be divisible by 4. Then the two stride-2 'same' convolutions are
            exactly undone by the two stride-2 transposed convolutions, and the
            output has the input's spatial size.
        num_res_blocks: Number of 256-filter residual blocks at the bottleneck.

    Returns:
        A Keras Model mapping an image of ``input_shape`` to a 3-channel image with
        values in [-1, 1] (tanh output).
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, 7, padding='same')(inputs)
    x = layers.PReLU(shared_axes=[1, 2])(x)

    # Downsample. Each strided conv is followed by BN + PReLU; without a
    # nonlinearity between them, consecutive convs collapse into one linear map.
    for filters in (128, 256):
        x = layers.Conv2D(filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.PReLU(shared_axes=[1, 2])(x)

    # Residual blocks at the bottleneck resolution.
    for _ in range(num_res_blocks):
        x = residual_block(x, 256)

    # Upsample back to the input resolution, with the same BN + PReLU pattern.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.PReLU(shared_axes=[1, 2])(x)

    # Output layer: tanh matches the [-1, 1] normalization of the training images.
    outputs = layers.Conv2D(3, 7, activation='tanh', padding='same')(x)
    return models.Model(inputs, outputs)


In [None]:
def build_discriminator(input_shape=(512, 1024, 3)):
    """Build the binary discriminator: four stride-2 conv stages and a dense sigmoid head.

    The default 512x1024x3 input matches how this notebook feeds the discriminator.
    A 512x512 brightfield image and a 512x512 fluorescent image (real or generated)
    are concatenated side by side along the width axis.

    Args:
        input_shape: (height, width, channels) of the discriminator input.

    Returns:
        A Keras Model producing one sigmoid score in (0, 1) per input.
    """
    inputs = layers.Input(shape=input_shape)

    # First stage has no BatchNorm.
    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # Remaining stages: Conv(stride 2) -> BN -> LeakyReLU.
    for filters in (128, 256, 512):
        x = layers.Conv2D(filters, 4, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    # NOTE(review): with the default input, the feature map is 32x64x512. Flatten
    # therefore gives ~1M features, so the Dense(1) head alone has ~1M weights.
    # A PatchGAN-style conv head would be much lighter.
    x = layers.Flatten()(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)
    return models.Model(inputs, outputs)


In [None]:
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator on its own so it can be updated directly with
# train_on_batch on real/fake pairs.
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Combined GAN: brightfield -> generator -> discriminator. The discriminator is
# frozen before `gan` is compiled so that training `gan` updates the generator.
# NOTE(review): the standalone discriminator must still train. That relies on it
# having been compiled *before* `trainable` was flipped; confirm the installed
# Keras version honours the compile-time trainable state for
# `discriminator.train_on_batch`.
discriminator.trainable = False
gan_input = layers.Input(shape=(512, 512, 3))
generated_image = generator(gan_input)
# Place input and generated image side by side (axis=2 is width; axis 0 is the
# batch) to form the discriminator's 512x1024x3 input. A Keras layer is used
# instead of tf.concat because Keras 3 rejects raw TF ops on symbolic KerasTensors.
paired_images = layers.Concatenate(axis=2)([gan_input, generated_image])
gan_output = discriminator(paired_images)
gan = models.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')


In [None]:
def load_and_preprocess_image(path, size=(512, 512)):
    """Read a PNG image and return it as a float32 tensor scaled to [-1, 1].

    Args:
        path: Scalar string tensor (or Python str) holding the file path.
        size: (height, width) to resize to. Defaults to 512x512, the input size the
            models in this notebook are built for.

    Returns:
        A float32 tensor of shape (size[0], size[1], 3) with values in [-1, 1].
    """
    raw_bytes = tf.io.read_file(path)
    # decode_png yields uint8 in [0, 255]; channels=3 forces RGB even for
    # grayscale inputs.
    image = tf.image.decode_png(raw_bytes, channels=3)
    image = tf.image.resize(image, size)  # returns float32
    # Map [0, 255] -> [-1, 1] to match the generator's tanh output range.
    return (image - 127.5) / 127.5

def _paired_image_dataset(bf_paths, fl_paths, batch_size):
    """Build a batched, prefetched tf.data pipeline of (brightfield, fluorescent) image pairs."""
    # Force tf.string: an empty Python list would otherwise be inferred as float32,
    # and tf.io.read_file would reject it when the map function is traced.
    ds = tf.data.Dataset.from_tensor_slices((
        tf.constant(bf_paths, dtype=tf.string),
        tf.constant(fl_paths, dtype=tf.string),
    ))
    # Parallel map keeps element order (tf.data is deterministic by default).
    ds = ds.map(
        lambda bf, fl: (load_and_preprocess_image(bf), load_and_preprocess_image(fl)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


def prepare_datasets(brightfield_paths, fluorescent_paths, train_size=0.7, val_size=0.2,
                     batch_size=1):
    """Split paired image paths into train/val/test tf.data pipelines.

    The split is positional (no shuffling):
    - the first ``train_size`` fraction of pairs goes to training;
    - the next ``val_size`` fraction goes to validation;
    - the remainder goes to test.

    ``brightfield_paths[i]`` is paired with ``fluorescent_paths[i]``, so both
    sequences must be in matching order.

    Args:
        brightfield_paths: Sequence of brightfield image paths (str or os.PathLike).
        fluorescent_paths: Sequence of fluorescent image paths, same length and order.
        train_size: Fraction of pairs used for training.
        val_size: Fraction of pairs used for validation.
        batch_size: Batch size for all three datasets (default 1).

    Returns:
        (train_ds, val_ds, test_ds): batched datasets of (brightfield, fluorescent)
        float32 image tensors in [-1, 1]. A split may be empty for small inputs.

    Raises:
        ValueError: if the path sequences differ in length, the split fractions are
            negative or sum to more than 1, or ``batch_size`` is not positive.
    """
    import os

    bf_paths = [os.fspath(p) for p in brightfield_paths]
    fl_paths = [os.fspath(p) for p in fluorescent_paths]

    if len(bf_paths) != len(fl_paths):
        raise ValueError(
            f"brightfield_paths and fluorescent_paths must have the same length "
            f"(got {len(bf_paths)} and {len(fl_paths)})"
        )
    # Small tolerance so float sums like 0.7 + 0.3 are not rejected.
    if train_size < 0 or val_size < 0 or train_size + val_size > 1 + 1e-9:
        raise ValueError(
            f"train_size and val_size must be non-negative and sum to at most 1 "
            f"(got {train_size} and {val_size})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer (got {batch_size})")

    total_images = len(bf_paths)
    train_end = int(train_size * total_images)
    val_end = int((train_size + val_size) * total_images)

    train_ds = _paired_image_dataset(bf_paths[:train_end], fl_paths[:train_end], batch_size)
    val_ds = _paired_image_dataset(bf_paths[train_end:val_end], fl_paths[train_end:val_end],
                                   batch_size)
    test_ds = _paired_image_dataset(bf_paths[val_end:], fl_paths[val_end:], batch_size)
    return train_ds, val_ds, test_ds


In [None]:
import matplotlib.pyplot as plt

def train(gan, generator, discriminator, train_dataset, val_dataset, epochs=50):
    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        # Training
        for train_x, train_y in train_dataset:
            # Training step code here

        # Validation
        for val_x, val_y in val_dataset:
            # Validation step code here

        # Record the average losses for plotting
        history['train_loss'].append(np.mean(train_losses))
        history['val_loss'].append(np.mean(val_losses))

        # Visualization of the loss trend every epoch
        plt.figure(figsize=(10, 5))
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.title('Loss Trend')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend()
        plt.grid(True)
        plt.show()

        # Optionally save model every few epochs
        if (epoch + 1) % 10 == 0:
            generator.save(f'generator_epoch_{epoch+1}.h5')
            discriminator.save(f'discriminator_epoch_{epoch+1}.h5')

    return history


In [None]:
# Collect the paired PNGs and split them into train/val/test pipelines.
# NOTE(review): pairing relies on the sorted filenames of the two directories
# lining up one-to-one; confirm the naming scheme guarantees this.
BRIGHTFIELD_DIR = '/path/to/brightfield'
FLUORESCENT_DIR = '/path/to/fluorescent'
brightfield_paths = sorted(tf.io.gfile.glob(f'{BRIGHTFIELD_DIR}/*.png'))
fluorescent_paths = sorted(tf.io.gfile.glob(f'{FLUORESCENT_DIR}/*.png'))
assert brightfield_paths, f'no PNG files found in {BRIGHTFIELD_DIR}'
assert len(brightfield_paths) == len(fluorescent_paths), 'brightfield/fluorescent counts differ'
train_dataset, val_dataset, test_dataset = prepare_datasets(brightfield_paths, fluorescent_paths)

In [None]:
# Requires `train_dataset` / `val_dataset` from the data-loading cell and the
# compiled `gan`, `generator` and `discriminator` from the model-building cell.
train(
    gan=gan,
    generator=generator,
    discriminator=discriminator,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=50,
)