In [1]:
import os
import zipfile
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers, optimizers
from keras.callbacks import TensorBoard
import time
from PIL import Image
import pickle

In [2]:
def load_dataset(filenames_file_path, embeddings_file_path, image_size, dataset_dir,
                 class_info_file_path=None, images_root="./data/coco2014", rng=None):
    """
    Load COCO images, class IDs and one text embedding per image.

    Args:
        filenames_file_path: pickle holding a sequence of image filenames; only the
            basename of each entry is used to locate the image on disk.
        embeddings_file_path: pickle holding one embedding stack per filename (a
            list is converted to a NumPy array). Each per-image stack is indexed
            along its first axis to pick one embedding, which is then squeezed.
        image_size: (width, height) passed to PIL's ``Image.resize``.
        dataset_dir: only used to choose the split — if the string contains
            "train", images are read from ``<images_root>/train2014``, otherwise
            from ``<images_root>/val2014``.
        class_info_file_path: optional pickle of class IDs, one per filename. If it
            is None or does not exist on disk, every image gets class ID 0.
        images_root: directory that contains ``train2014``/``val2014``. Defaults
            to the location that used to be hard-coded here.
        rng: optional ``random.Random``-like object used to pick each image's
            embedding. Defaults to the global ``random`` module; pass a seeded
            ``random.Random`` for reproducible selection.

    Returns:
        (X, y, embeddings) as NumPy arrays. X holds RGB images scaled to [-1, 1]
        as float32. An image that fails to load is skipped (with a printed
        message) together with its label and embedding, so the arrays stay aligned.

    Raises:
        ValueError: if the number of embeddings or class IDs differs from the
            number of filenames.

    NOTE: ``pickle.load`` can execute arbitrary code — only load trusted files.
    """
    if rng is None:
        rng = random

    # Load filenames
    with open(filenames_file_path, 'rb') as f:
        filenames = pickle.load(f)

    # Load class IDs if available
    if class_info_file_path is not None and os.path.exists(class_info_file_path):
        with open(class_info_file_path, 'rb') as f:
            class_ids = pickle.load(f)
        # A short list would otherwise raise IndexError per image, which the
        # per-image handler below would turn into silently dropped images.
        if len(class_ids) != len(filenames):
            raise ValueError(f"Number of class IDs ({len(class_ids)}) does not match number of filenames ({len(filenames)})")
    else:
        # If no class info provided, assign a default class ID of 0
        class_ids = [0] * len(filenames)

    # Load embeddings
    with open(embeddings_file_path, 'rb') as f:
        all_embeddings = pickle.load(f)
    # Convert to NumPy array if it's a list
    if isinstance(all_embeddings, list):
        all_embeddings = np.array(all_embeddings)

    # Verify that the number of embeddings matches the number of filenames
    if len(all_embeddings) != len(filenames):
        raise ValueError(f"Number of embeddings ({len(all_embeddings)}) does not match number of filenames ({len(filenames)})")

    print(f"Number of images with embeddings: {len(all_embeddings)}")
    if len(all_embeddings) > 0:  # indexing [0] would fail on an empty dataset
        print(f"Shape of embeddings for first image: {all_embeddings[0].shape}")

    # Choose the split directory from the dataset_dir string.
    split = "train2014" if "train" in dataset_dir else "val2014"
    images_dir = os.path.join(images_root, split)

    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    X, y, embeddings = [], [], []
    for index, filename in enumerate(filenames):
        # Bound before the try so the error message never names a stale file.
        base_filename = filename
        try:
            base_filename = os.path.basename(filename)
            img_name = os.path.join(images_dir, base_filename)

            # Load and resize; the context manager closes the file handle.
            with Image.open(img_name) as src:
                img = src.convert('RGB')
            img = img.resize(image_size, Image.LANCZOS if hasattr(Image, 'LANCZOS') else Image.ANTIALIAS)
            # float32 matches Keras' default dtype and halves memory vs float64.
            img = np.asarray(img, dtype=np.float32) / 127.5 - 1.0  # Normalize to [-1, 1]

            # Randomly select one embedding from this image's stack
            per_image = all_embeddings[index]
            embedding_ix = rng.randint(0, per_image.shape[0] - 1)
            embedding = per_image[embedding_ix].squeeze()

            label = class_ids[index]
        except Exception as e:
            print(f"Error processing image {base_filename}: {e}")
            continue

        # Append only after every value was computed, so the lists stay aligned.
        X.append(img)
        y.append(label)
        embeddings.append(embedding)

    return np.array(X), np.array(y), np.array(embeddings)

def load_filenames(file_path):
    """Return the object (expected: a list of filenames) unpickled from ``file_path``.

    NOTE: unpickling can run arbitrary code; only use with trusted files.
    """
    with open(file_path, "rb") as pickled:
        return pickle.load(pickled)

def load_class_ids(file_path):
    """Return the object (expected: a sequence of class IDs) unpickled from ``file_path``.

    NOTE: unpickling can run arbitrary code; only use with trusted files.
    """
    with open(file_path, "rb") as pickled:
        return pickle.load(pickled)

def load_bounding_boxes(cub_dataset_dir):
    """
    Stub for bounding-box loading.

    ``cub_dataset_dir`` is accepted but never read. A real implementation would
    return a mapping such as {filename: (x, y, width, height)}. This stub always
    returns a new, empty dict.
    """
    boxes = dict()
    return boxes

def load_embeddings(embeddings_file_path):
    """Return the embeddings object unpickled from ``embeddings_file_path``, unmodified.

    NOTE: unpickling can run arbitrary code; only use with trusted files.
    """
    with open(embeddings_file_path, "rb") as pickled:
        return pickle.load(pickled)

def get_img(img_path, bounding_box, image_size):
    """
    Load an image as RGB and resize it.

    Args:
        img_path: path of the image file.
        bounding_box: currently ignored; kept for interface compatibility (it
            could be used to crop the image before resizing).
        image_size: (width, height) passed to PIL's ``Image.resize``.

    Returns:
        The resized PIL image in RGB mode.
    """
    # Image.ANTIALIAS was removed in Pillow 10 (AttributeError); LANCZOS is the
    # same filter. Fall back to ANTIALIAS only on very old Pillow releases.
    resample = Image.LANCZOS if hasattr(Image, "LANCZOS") else Image.ANTIALIAS
    # convert() returns an in-memory copy, so the source file can be closed here.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    return img.resize(image_size, resample)



In [3]:
# 1) Conditioning Augmentation (CA) Model

#    Takes a text embedding and predicts mu/logvar; the reparameterization that produces c is done outside the model

def build_ca_model(embedding_dim=1024, condition_dim=128):
    """
    Build the Conditioning Augmentation network.

    A 256-unit ReLU hidden layer feeds two parallel linear heads that predict the
    mean and log-variance of the conditioning distribution.

    Args:
        embedding_dim: length of the incoming text embedding.
        condition_dim: width of each predicted head.

    Returns:
        A ``keras.Model`` mapping an (embedding_dim,) input to ``[mu, logvar]``.
        Sampling c = mu + exp(logvar / 2) * eps is left to the caller.
    """
    text_in = layers.Input(shape=(embedding_dim,))
    hidden = layers.Dense(256, activation="relu")(text_in)
    # Created in order: mu head first, then logvar head.
    heads = [layers.Dense(condition_dim)(hidden) for _ in range(2)]
    return keras.Model(inputs=text_in, outputs=heads)


In [4]:
#------------------------------------------------------------------------------
# 2) KL Divergence Loss for the CA Model
#------------------------------------------------------------------------------

def KL_loss(y_true, y_pred):
    """
    Batch-mean KL divergence of N(mu, exp(logvar)) from the standard normal.

    ``y_pred`` is expected to be ``[mu, logvar]`` concatenated along the last
    axis, so that axis is split in half (static shape is used). ``y_true`` is
    ignored; it exists only to satisfy the Keras loss signature.

    Returns:
        Scalar tensor: batch mean of -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    """
    split_at = y_pred.shape[-1] // 2
    mu, logvar = y_pred[:, :split_at], y_pred[:, split_at:]
    per_sample = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
    return tf.reduce_mean(per_sample)

In [5]:
#------------------------------------------------------------------------------
# 3) Embedding Compressor (optional)
#    Sometimes used to reduce embedding dim (e.g., 1024 -> 128)
#------------------------------------------------------------------------------

def build_embedding_compressor_model(embedding_dim=1024, condition_dim=128):
    """
    Build a one-layer projection from a text embedding down to condition_dim.

    Returns:
        A ``keras.Model`` applying Dense(condition_dim, relu) to an
        (embedding_dim,) input.
    """
    raw_embedding = layers.Input(shape=(embedding_dim,))
    compressed = layers.Dense(condition_dim, activation="relu")(raw_embedding)
    return keras.Model(inputs=raw_embedding, outputs=compressed)

In [6]:
#------------------------------------------------------------------------------
# 4) Stage 1 Generator
#    Takes random noise z + condition c, outputs a 64x64 image
#------------------------------------------------------------------------------

def build_stage1_generator(z_dim=100, condition_dim=128):
    """
    Build the Stage-I generator.

    The input is the concatenation [z, c] of length z_dim + condition_dim. A dense
    projection is reshaped to a 4x4x256 map, then four stride-2 transposed
    convolutions upsample 4 -> 8 -> 16 -> 32 -> 64. The intermediate blocks use
    BatchNorm + ReLU; the last block emits 3 channels squashed by tanh to [-1, 1].

    Returns:
        A ``keras.Model`` producing (64, 64, 3) images.
    """
    noise_and_cond = layers.Input(shape=(z_dim + condition_dim,))

    h = layers.Dense(4 * 4 * 256, use_bias=False)(noise_and_cond)
    h = layers.Reshape((4, 4, 256))(h)
    h = layers.BatchNormalization()(h)
    h = layers.ReLU()(h)

    # 4x4 -> 8x8 -> 16x16 -> 32x32, halving the channel count each step
    for filters in (128, 64, 32):
        h = layers.Conv2DTranspose(filters, (4, 4), strides=(2, 2), padding="same", use_bias=False)(h)
        h = layers.BatchNormalization()(h)
        h = layers.ReLU()(h)

    # 32x32 -> 64x64 RGB in [-1, 1]
    h = layers.Conv2DTranspose(3, (4, 4), strides=(2, 2), padding="same", use_bias=False)(h)
    image = layers.Activation("tanh")(h)

    return keras.Model(inputs=noise_and_cond, outputs=image)


In [7]:
#------------------------------------------------------------------------------
# 5) Stage 1 Discriminator
#    Takes a 64x64 image + condition vector, outputs an 8x8 map of per-patch real/fake probabilities
#------------------------------------------------------------------------------

def build_stage1_discriminator(condition_dim):
    """
    Build the Stage-I conditional discriminator (PatchGAN style).

    The 64x64x3 image is downsampled by three stride-2 convolutions to an
    8x8x256 feature map. The condition vector passes through a Dense layer, is
    tiled to 8x8 and concatenated channel-wise with the image features. A 3x3 conv
    block and a final 4x4 conv + sigmoid then give one real/fake probability per
    patch.

    Args:
        condition_dim: length of the condition vector input.

    Returns:
        A ``keras.Model`` named "stage1_discriminator" taking ``[image, condition]``
        and returning an (8, 8, 1) probability map. The patch scores are NOT
        averaged, so training targets must also have shape (8, 8, 1).
    """
    input_img = layers.Input(shape=(64, 64, 3))

    feats = input_img
    # (filters, apply BatchNorm): 64x64 -> 32x32 -> 16x16 -> 8x8
    for n_filters, with_bn in ((64, False), (128, True), (256, True)):
        feats = layers.Conv2D(n_filters, (4, 4), strides=(2, 2), padding="same")(feats)
        if with_bn:
            feats = layers.BatchNormalization()(feats)
        feats = layers.LeakyReLU(0.2)(feats)

    # Condition: project, then tile to the 8x8 spatial grid of the image features
    input_cond = layers.Input(shape=(condition_dim,))
    cond_map = layers.Dense(condition_dim)(input_cond)
    cond_map = layers.Reshape((1, 1, condition_dim))(cond_map)
    cond_map = layers.UpSampling2D(size=(8, 8))(cond_map)

    joint = layers.Concatenate(axis=-1)([feats, cond_map])  # (8, 8, 256 + condition_dim)
    joint = layers.Conv2D(256, (3, 3), strides=(1, 1), padding="same")(joint)
    joint = layers.BatchNormalization()(joint)
    joint = layers.LeakyReLU(0.2)(joint)

    patch_scores = layers.Conv2D(1, (4, 4), strides=(1, 1), padding="same")(joint)
    patch_probs = layers.Activation("sigmoid")(patch_scores)  # (8, 8, 1)

    return keras.Model([input_img, input_cond], patch_probs, name="stage1_discriminator")


In [8]:
#------------------------------------------------------------------------------
# 6) Adversarial Model
#    Wires up Generator + Discriminator for generator training
#------------------------------------------------------------------------------

def build_adversarial_model(gen_model, dis_model, ca_model, z_dim, condition_dim, embedding_dim=None):
    """
    Build the combined model used to train the generator (and the CA network).

    Wiring: embedding -> ca_model -> (mu, logvar) -> c = mu + exp(logvar / 2) * eps;
    [z, c] -> gen_model -> fake image; [fake image, c] -> dis_model.

    Side effect: sets ``dis_model.trainable = False`` so that training this model
    does not update discriminator weights.

    Args:
        gen_model: generator taking a (z_dim + condition_dim,) vector.
        dis_model: discriminator taking ``[image, condition]``.
        ca_model: model mapping an embedding to ``[mu, logvar]``.
        z_dim: noise vector length.
        condition_dim: length of c (and of mu / logvar).
        embedding_dim: text-embedding length. Defaults to the CA model's input
            size so it cannot drift from how ca_model was built (it used to be
            hard-coded to 1024).

    Returns:
        ``keras.Model`` with inputs ``[z, embedding]`` and outputs
        ``[dis_output, concat(mu, logvar)]``. The second output is meant for a
        KL loss that splits its last axis in half.
    """
    # Freeze discriminator weights during generator training
    dis_model.trainable = False

    if embedding_dim is None:
        embedding_dim = int(ca_model.inputs[0].shape[-1])

    # Inputs
    z_input = layers.Input(shape=(z_dim,))
    embedding_input = layers.Input(shape=(embedding_dim,))

    # CA model -> mu, logvar
    mu, logvar = ca_model(embedding_input)

    # epsilon ~ N(0, I), same shape as mu
    def random_normal_output_shape(input_shape):
        return input_shape

    epsilon = layers.Lambda(
        lambda x: K.random_normal(shape=K.shape(x)),
        output_shape=random_normal_output_shape
    )(mu)

    # Reparameterization trick: c = mu + sigma * epsilon with sigma = exp(logvar / 2)
    c = layers.Lambda(
        lambda inputs: inputs[0] + K.exp(inputs[1] / 2) * inputs[2],
        output_shape=(condition_dim,)
    )([mu, logvar, epsilon])

    # Combine z and c, generate, and score with the frozen discriminator
    z_c = layers.Concatenate()([z_input, c])
    fake_img = gen_model(z_c)
    dis_output = dis_model([fake_img, c])

    # Second output packs (mu, logvar) for the KL term
    model = keras.Model(
        [z_input, embedding_input], [dis_output, layers.Concatenate()([mu, logvar])]
    )

    return model


In [9]:
if __name__ == "__main__":

    # 1) Setup & Hyperparameters

    # Seed every RNG used downstream (embedding choice in load_dataset, noise,
    # shuffling, weight initialization) so Restart & Run All is reproducible.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    data_dir = "./data/coco/"
    train_dir = os.path.join(data_dir, "train")
    test_dir = os.path.join(data_dir, "val")  # Using val as test

    # Example: extract train2014 if train_dir is missing or empty.
    # isdir is checked first because os.listdir raises on a missing directory.
    # NOTE(review): the archive path has no ".zip" suffix, and load_dataset reads
    # images from ./data/coco2014/ rather than train_dir — confirm both.
    if not os.path.isdir(train_dir) or not os.listdir(train_dir):
        with zipfile.ZipFile(os.path.join(data_dir, "train2014"), "r") as zip_ref:
            zip_ref.extractall(train_dir)

    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)

    image_size = (64, 64)  # passed to PIL resize as (width, height); models expect 64x64
    batch_size = 32
    z_dim = 100
    epochs = 20
    embedding_dim = 1024
    condition_dim = 128

    # File paths for training
    embeddings_file_path_train = os.path.join(train_dir, "char-CNN-RNN-embeddings.pickle")
    filenames_file_path_train = os.path.join(train_dir, "filenames.pickle")

    # File paths for the test (val) split
    embeddings_file_path_test = os.path.join(test_dir, "char-CNN-RNN-embeddings.pickle")
    filenames_file_path_test = os.path.join(test_dir, "filenames.pickle")

    # Optimizers
    dis_optimizer = optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
    gen_optimizer = optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)

    # 2) Load Data
    # No class_info_file_path is passed, so load_dataset assigns class ID 0 to
    # every image; pass one to load_dataset to use real class IDs.

    print("Loading training data...")
    X_train, y_train, embeddings_train = load_dataset(
        filenames_file_path=filenames_file_path_train,
        embeddings_file_path=embeddings_file_path_train,
        image_size=image_size,
        dataset_dir=train_dir,
    )

    print("Loading test data...")
    X_test, y_test, embeddings_test = load_dataset(
        filenames_file_path=filenames_file_path_test,
        embeddings_file_path=embeddings_file_path_test,
        image_size=image_size,
        dataset_dir=test_dir,
    )

    # 3) Build & Compile Models

    print("Building models...")

    ca_model = build_ca_model(embedding_dim=embedding_dim, condition_dim=condition_dim)

    # NOTE(review): not referenced by the Stage-I training loop below.
    embedding_compressor_model = build_embedding_compressor_model(
        embedding_dim=embedding_dim, condition_dim=condition_dim
    )

    stage1_gen = build_stage1_generator(z_dim=z_dim, condition_dim=condition_dim)
    stage1_dis = build_stage1_discriminator(condition_dim=condition_dim)

    # Compile the discriminator BEFORE build_adversarial_model sets
    # trainable = False on it. The usual Keras GAN pattern relies on compile
    # capturing the trainable state — verify with the installed Keras version.
    stage1_dis.compile(
        loss="binary_crossentropy", optimizer=dis_optimizer, metrics=["accuracy"]
    )

    # Adversarial model (for training Generator + CA)
    adversarial_model = build_adversarial_model(
        gen_model=stage1_gen,
        dis_model=stage1_dis,
        ca_model=ca_model,
        z_dim=z_dim,
        condition_dim=condition_dim,
    )
    # BCE on the patch map; the custom KL_loss on the packed (mu, logvar) output
    adversarial_model.compile(
        loss=["binary_crossentropy", KL_loss],
        loss_weights=[1.0, 2.0],
        optimizer=gen_optimizer,
    )

    # 4) Training Loop

   # Update label shapes for PatchGAN
real_labels = np.ones((batch_size, 8, 8, 1), dtype=np.float32) * 0.9  # Shape: (32, 8, 8, 1)
fake_labels = np.zeros((batch_size, 8, 8, 1), dtype=np.float32)        # Shape: (32, 8, 8, 1)

# Weâ€™ll assume embeddings_train.shape[0] == X_train.shape[0]
num_batches = X_train.shape[0] // batch_size

for epoch in range(epochs):
    print(f"================== Epoch {epoch+1}/{epochs} ==================")
    np.random.shuffle(indices := np.arange(X_train.shape[0]))

    for batch_i in range(num_batches):
        # ---------------------------
        # 4.1) Get real batch
        # ---------------------------
        batch_indices = indices[batch_i * batch_size : (batch_i + 1) * batch_size]
        real_imgs = X_train[batch_indices]
        real_embeddings = embeddings_train[batch_indices]

        # ---------------------------
        # 4.2) Sample random noise
        # ---------------------------
        z_noise = np.random.normal(0, 1, (batch_size, z_dim))

        # ---------------------------
        # 4.3) Generate fake images
        # ---------------------------
        mu, logvar = ca_model.predict_on_batch(real_embeddings)
        epsilon = np.random.normal(0, 1, (batch_size, condition_dim))
        c = mu + np.exp(logvar / 2) * epsilon  # reparameterize

        gen_input = np.concatenate([z_noise, c], axis=1)
        fake_imgs = stage1_gen.predict_on_batch(gen_input)

        # ---------------------------
        # 4.4) Train Discriminator
        # ---------------------------
        # 4.4.1) Train on real
        d_loss_real = stage1_dis.train_on_batch([real_imgs, c], real_labels)

        # 4.4.2) Train on fake
        d_loss_fake = stage1_dis.train_on_batch([fake_imgs, c], fake_labels)

        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------------
        # 4.5) Train Generator
        # ---------------------------
        valid_y = np.ones((batch_size, 8, 8, 1), dtype=np.float32)  # Update for generator training
        dummy_kl = np.zeros((batch_size, condition_dim * 2), dtype=np.float32)

        g_loss = adversarial_model.train_on_batch(
            [z_noise, real_embeddings], [valid_y, dummy_kl]
        )

        # Print every few batches
        if batch_i % 50 == 0:
            print(
                f"Batch {batch_i}/{num_batches} | D loss: {d_loss[0]:.4f} | G loss: {g_loss[0]:.4f} (KL: {g_loss[1]:.4f})"
            )

    # ---------------------------
    # 4.6) Save Weights Periodically
    # ---------------------------
    model_save_dir = "model_weights/stage1"
    os.makedirs(model_save_dir, exist_ok=True)


    if (epoch + 1) % 2 == 0:
        stage1_gen.save_weights(os.path.join(model_save_dir, f"stage1_gen_epoch_{epoch+1}.weights.h5"))
        stage1_dis.save_weights(os.path.join(model_save_dir, f"stage1_dis_epoch_{epoch+1}.weights.h5"))

# 5) Final Save
stage1_gen.save_weights(os.path.join(model_save_dir, "stage1_gen_final.weights.h5"))
stage1_dis.save_weights(os.path.join(model_save_dir, "stage1_dis_final.weights.h5"))
print("Training complete!")

Loading training data...
Number of images with embeddings: 999
Shape of embeddings for first image: (10, 1, 1024)
Loading test data...
Number of images with embeddings: 200
Shape of embeddings for first image: (10, 1, 1024)
Building models...




Batch 0/31 | D loss: 0.6991 | G loss: 0.8577 (KL: 0.6475)
Batch 0/31 | D loss: 0.7277 | G loss: 0.6813 (KL: 0.5916)
Batch 0/31 | D loss: 0.7321 | G loss: 0.6317 (KL: 0.5723)
Batch 0/31 | D loss: 0.7383 | G loss: 0.6059 (KL: 0.5612)
Batch 0/31 | D loss: 0.7449 | G loss: 0.5893 (KL: 0.5534)
Batch 0/31 | D loss: 0.7515 | G loss: 0.5783 (KL: 0.5480)
Batch 0/31 | D loss: 0.7580 | G loss: 0.5697 (KL: 0.5435)
Batch 0/31 | D loss: 0.7637 | G loss: 0.5622 (KL: 0.5390)
Batch 0/31 | D loss: 0.7687 | G loss: 0.5556 (KL: 0.5347)
Batch 0/31 | D loss: 0.7731 | G loss: 0.5514 (KL: 0.5323)
Batch 0/31 | D loss: 0.7772 | G loss: 0.5474 (KL: 0.5298)
Batch 0/31 | D loss: 0.7809 | G loss: 0.5443 (KL: 0.5279)
Batch 0/31 | D loss: 0.7845 | G loss: 0.5414 (KL: 0.5261)
Batch 0/31 | D loss: 0.7875 | G loss: 0.5392 (KL: 0.5247)
Batch 0/31 | D loss: 0.7904 | G loss: 0.5370 (KL: 0.5232)
Batch 0/31 | D loss: 0.7927 | G loss: 0.5347 (KL: 0.5216)
Batch 0/31 | D loss: 0.7948 | G loss: 0.5326 (KL: 0.5201)
Batch 0/31 | D