# In [None]:
import os
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import matplotlib.pyplot as plt

# Class sub-folders are read from base_path; generated PNGs go under output_dir.
base_path, output_dir = 'drug discovery', 'output'
os.makedirs(output_dir, exist_ok=True)  # tolerate an already-existing folder

# Load dataset
def load_images(base_path, img_size=(256, 256)):
    """Load every readable image under ``base_path``, grouped by sub-directory.

    Each immediate sub-directory of ``base_path`` is treated as one class.
    Images are resized to ``img_size`` and scaled from [0, 255] to [0, 1].

    Args:
        base_path: Directory whose sub-directories hold the images of each class.
        img_size: (height, width) passed to ``load_img`` as ``target_size``.

    Returns:
        dict mapping class name -> float array of shape (N, height, width, 3).
        A class with no loadable images maps to an empty array of shape
        (0, height, width, 3) instead of a shapeless ``np.array([])``.
    """
    images = {}
    # os.listdir order is arbitrary; sort so class order and per-class image
    # order are reproducible across machines/filesystems.
    for cls in sorted(os.listdir(base_path)):
        cls_path = os.path.join(base_path, cls)
        if not os.path.isdir(cls_path):
            continue
        class_images = []
        for img_file in sorted(os.listdir(cls_path)):
            img_path = os.path.join(cls_path, img_file)
            # Nested directories cannot be loaded as images; only files are candidates.
            if not os.path.isfile(img_path):
                continue
            try:
                img = load_img(img_path, target_size=img_size)
            except OSError as err:
                # Non-image files (e.g. .DS_Store, Thumbs.db) fail to decode here;
                # report and skip them rather than aborting the whole load.
                print(f"Skipping unreadable image {img_path}: {err}")
                continue
            class_images.append(img_to_array(img) / 255.0)  # Normalize to [0, 1]
        if class_images:
            images[cls] = np.array(class_images)
        else:
            # load_img defaults to RGB, so an empty class still gets 3 channels.
            images[cls] = np.empty((0, *img_size, 3), dtype=np.float32)
    return images

# Read every class folder under base_path into memory and report what was found.
data = load_images(base_path)
print(f"Loaded images from {len(data)} classes: {[*data]}")

# Define Pix2Pix generator
def build_generator(input_shape=(256, 256, 3), output_channels=3):
    """Build a small encoder-decoder generator.

    Two stride-2 convolutions downsample by 4x; two stride-2 transposed
    convolutions upsample back to the input resolution.

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width should be divisible by 4 so the output matches the input size.
        output_channels: Number of channels in the generated image.

    Returns:
        A Keras ``Model`` mapping an image batch to an image batch of the same
        height/width whose values lie in (-1, 1) because of the final ``tanh``.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: H x W -> H/2 x W/2 -> H/4 x W/4
    x = layers.Conv2D(64, kernel_size=4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU()(x)

    x = layers.Conv2D(128, kernel_size=4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    # Decoder: H/4 x W/4 -> H/2 x W/2 -> H x W
    # NOTE(review): unlike the U-Net generator of the Pix2Pix paper, there are
    # no encoder->decoder skip connections here.
    x = layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    outputs = layers.Conv2DTranspose(output_channels, kernel_size=4, strides=2,
                                     padding='same', activation='tanh')(x)

    return Model(inputs, outputs)

# Define Pix2Pix discriminator
def build_discriminator(input_shape=(256, 256, 3), target_shape=None):
    """Build a conditional patch discriminator.

    The input image and the candidate target image are concatenated along the
    channel axis and reduced by two stride-2 convolutions. A final stride-1
    convolution then emits one logit per spatial location.

    Args:
        input_shape: (height, width, channels) of the conditioning image.
        target_shape: (height, width, channels) of the real/generated image;
            defaults to ``input_shape``. Height/width must match ``input_shape``.

    Returns:
        A Keras ``Model`` taking ``[input_image, target_image]`` and returning
        raw logits (no activation) of shape (H/4, W/4, 1) per sample, meant for a
        ``from_logits=True`` loss.
    """
    if target_shape is None:
        target_shape = input_shape

    inputs = layers.Input(shape=input_shape)
    targets = layers.Input(shape=target_shape)

    x = layers.Concatenate()([inputs, targets])
    x = layers.Conv2D(64, kernel_size=4, strides=2, padding='same')(x)
    x = layers.LeakyReLU()(x)

    x = layers.Conv2D(128, kernel_size=4, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    # One unactivated logit per (H/4 x W/4) patch.
    x = layers.Conv2D(1, kernel_size=4, strides=1, padding='same')(x)

    return Model([inputs, targets], x)

# Build both networks (generator first, then discriminator).
generator, discriminator = build_generator(), build_discriminator()

# Shared logits-based BCE loss; each network gets its own Adam(2e-4, beta_1=0.5).
gan_loss = BinaryCrossentropy(from_logits=True)
generator_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=2e-4, beta_1=0.5)

# Training loop
def _pix2pix_train_step(generator, discriminator, input_image, target_image, l1_lambda):
    """Apply one optimizer update to both networks for a single (input, target) pair.

    Uses the module-level ``gan_loss``, ``generator_optimizer`` and
    ``discriminator_optimizer``.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar tensors for this step.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_image = generator(input_image, training=True)

        real_output = discriminator([input_image, target_image], training=True)
        fake_output = discriminator([input_image, generated_image], training=True)

        # Adversarial term: the generator wants its output classified as real.
        adv_loss = gan_loss(tf.ones_like(fake_output), fake_output)
        # Pix2Pix reconstruction term: keeps the output pixel-wise close to the target.
        l1_loss = tf.reduce_mean(tf.abs(target_image - generated_image))
        gen_loss = adv_loss + l1_lambda * l1_loss

        disc_loss = (gan_loss(tf.ones_like(real_output), real_output) +
                     gan_loss(tf.zeros_like(fake_output), fake_output))

    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
    return gen_loss, disc_loss


def train_pix2pix(generator, discriminator, data, min_images_per_class=30, l1_lambda=100.0):
    """Train the networks class by class, saving one generated image per step.

    For every class in ``data`` the loop runs ``min_images_per_class`` training
    steps, cycling through that class's images if it has fewer than requested.
    The target is the input itself (identity mapping). The same ``generator``
    and ``discriminator`` keep training across all classes. After each step the
    current generator output is written to
    ``<output_dir>/output_<class>/<class>_<n>.png``.

    Args:
        generator: Model mapping an image batch to a generated image batch.
        discriminator: Model scoring ``[input, target]`` pairs with logits.
        data: dict of class name -> array of images (N, H, W, C).
        min_images_per_class: Number of steps / saved images per class.
        l1_lambda: Weight of the L1 reconstruction term added to the generator's
            adversarial loss, as in the Pix2Pix objective. ``0`` gives a purely
            adversarial generator loss.
    """
    for cls, images in data.items():
        total_images = len(images)
        if total_images == 0:
            # Without this guard the modulo below raises ZeroDivisionError.
            print(f"Skipping class {cls}: no images loaded")
            continue

        class_output_dir = os.path.join(output_dir, f'output_{cls}')
        os.makedirs(class_output_dir, exist_ok=True)

        generated_count = 0
        while generated_count < min_images_per_class:
            # Cycle through the original images when more outputs are requested.
            input_image = tf.expand_dims(images[generated_count % total_images], axis=0)
            target_image = input_image  # Identity mapping for simplicity

            _pix2pix_train_step(generator, discriminator, input_image, target_image, l1_lambda)

            generated_count += 1
            save_generated_image(generator, input_image, class_output_dir, cls, generated_count)

        print(f"Generated {generated_count} images for class: {cls}")

def save_generated_image(generator, input_image, output_dir, label, index):
    """Run ``generator`` on ``input_image`` and save the first sample as a PNG.

    The file is written to ``<output_dir>/<label>_<index>.png``.

    Args:
        generator: Keras model, called here with ``training=False``.
        input_image: Input batch; only the first generated sample is saved.
        output_dir: Existing directory to write into.
        label: File-name prefix (the class name in this notebook).
        index: Number appended to the file name.
    """
    # NOTE(review): training=False makes BatchNorm use its moving statistics,
    # which lag far behind after only a few steps — confirm this is intended.
    generated_image = generator(input_image, training=False)[0].numpy()
    # The generator ends in tanh and can emit values in (-1, 0). Casting those
    # straight to uint8 wraps them around to bright garbage instead of black,
    # so clip to the [0, 1] range the training images were normalized to first.
    generated_image = np.clip(generated_image, 0.0, 1.0)
    generated_image = np.rint(generated_image * 255).astype(np.uint8)
    plt.imsave(os.path.join(output_dir, f'{label}_{index}.png'), generated_image)

# Kick off training, requesting 30 generated images (one per training step) per class.
train_pix2pix(generator, discriminator, data, 30)

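# NOTE(review): the log below does not appear to come from the cell above,
# which prints "Generated <n> images for class: <cls>" and writes PNGs under
# 'output/output_<cls>/', not to a 'generated_images' folder.
# Epoch 1/10, Loss: 0.12147603929042816
# Epoch 2/10, Loss: 0.05164103294935143
# Epoch 3/10, Loss: 0.04808125082860913
# Epoch 4/10, Loss: 0.04493849802958338
# Epoch 5/10, Loss: 0.04272231381190451
# Epoch 6/10, Loss: 0.038610169533313365
# Epoch 7/10, Loss: 0.03622487625270559
# Epoch 8/10, Loss: 0.0347495397324102
# Epoch 9/10, Loss: 0.03442482241805185
# Epoch 10/10, Loss: 0.03347644079149815
# Generated images saved in 'generated_images' folder.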
