
<h1><center id="title">DataLab Cup 3: Reverse Image Caption</center></h1>

<center id="author">Shan-Hung Wu &amp; DataLab<br/>Fall 2025</center>




<h1><center class="subtitle">Text to Image</center></h1>

<h2 id="Platform:-Kaggle">Platform: <a href="https://www.kaggle.com/competitions/2025-datalab-cup-3-reverse-image-caption/overview">Kaggle</a><a class="anchor-link" href="#Platform:-Kaggle">¶</a></h2>
<h2 id="Overview">Overview<a class="anchor-link" href="#Overview">¶</a></h2>
<p>In this work, we are interested in translating text in the form of single-sentence human-written descriptions directly into image pixels. For example, "<strong>this flower has petals that are yellow and has a ruffled stamen</strong>" and "<strong>this pink and yellow flower has a beautiful yellow center with many stamens</strong>". You have to develop a novel deep architecture and GAN formulation to effectively translate visual concepts from characters to pixels.</p>

<p>More specifically, given a set of texts, your task is to generate reasonable images with size 64x64x3 to illustrate the corresponding texts. Here we use <a href="http://www.robots.ox.ac.uk/~vgg/data/flowers/102/">Oxford-102 flower dataset</a> and its <a href="https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view">paired texts</a> as our training dataset.</p>

<img alt="No description has been provided for this image" src="./data/example.png"/>

<ul>
<li>7370 images as training set, where each images is annotated with at most 10 texts.</li>
<li>819 texts for testing. You must generate 1 64x64x3 image for each text.</li>
</ul>




<h2 id="Conditional-GAN">Conditional GAN<a class="anchor-link" href="#Conditional-GAN">¶</a></h2>
<p>Given a text, in order to generate the image which can illustrate it, our model must meet several requirements:</p>

<ol>
<li>Our model should have ability to understand and extract the meaning of given texts.<ul>
<li>Use RNN or other language model, such as BERT, ELMo or XLNet, to capture the meaning of text.</li>
</ul>
</li>
<li>Our model should be able to generate image.<ul>
<li>Use GAN to generate high quality image.</li>
</ul>
</li>
<li>GAN-generated image should illustrate the text.<ul>
<li>Use conditional-GAN to generate image conditioned on given text.</li>
</ul>
</li>
</ol>

<p>Generative adversarial nets can be extended to a conditional model if both the generator and discriminator are conditioned on some extra information $y$. We can perform the conditioning by feeding $y$ into both the discriminator and generator as additional input layer.</p>

<img alt="No description has been provided for this image" src="./data/cGAN.png" width="500"/>

<p>There are two motivations for using some extra information in a GAN model:</p>

<ol>
<li>Improve GAN.</li>
<li>Generate targeted image.</li>
</ol>

<p>Additional information that is correlated with the input images, such as class labels, can be used to improve the GAN. This improvement may come in the form of more stable training, faster training, and/or generated images that have better quality.</p>

<img alt="No description has been provided for this image" src="./data/GANCLS.jpg"/>



In [29]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
from tqdm import tqdm

import re
from IPython import display

In [30]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)



<h2 id="Preprocess-Text">Preprocess Text<a class="anchor-link" href="#Preprocess-Text">¶</a></h2>
<p>Since dealing with raw string is inefficient, we have done some data preprocessing for you:</p>

<ul>
<li>Delete text over <code>MAX_SEQ_LENGTH (20)</code>.</li>
<li>Delete all puntuation in the texts.</li>
<li>Encode each vocabulary in <code>dictionary/vocab.npy</code>.</li>
<li>Represent texts by a sequence of integer IDs.</li>
<li>Replace rare words by <code>&lt;RARE&gt;</code> token to reduce vocabulary size for more efficient training.</li>
<li>Add padding as <code>&lt;PAD&gt;</code> to each text to make sure all of them have equal length to <code>MAX_SEQ_LENGTH (20)</code>.</li>
</ul>

<p>It is worth knowing that there is no necessary to append <code>&lt;ST&gt;</code> and <code>&lt;ED&gt;</code> to each text because we don't need to generate any sequence in this task.</p>

<p>To make sure correctness of encoding of the original text, we can decode sequence vocabulary IDs by looking up the vocabulary dictionary:</p>

<ul>
<li><code>dictionary/word2Id.npy</code> is a numpy array mapping word to id.</li>
<li><code>dictionary/id2Word.npy</code> is a numpy array mapping id back to word.</li>
</ul>



In [31]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))


there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [32]:
# This cell previously contained sent2IdList() function
# It has been removed as we now use DistilBERT tokenizer instead
# The id2word_dict is still available from cell 6 for visualization purposes

print("✓ Using DistilBERT tokenizer (sent2IdList removed)")

✓ Using DistilBERT tokenizer (sent2IdList removed)



<h2 id="Dataset">Dataset<a class="anchor-link" href="#Dataset">¶</a></h2>
<p>For training, the following files are in dataset folder:</p>

<ul>
<li><code>./dataset/text2ImgData.pkl</code> is a pandas dataframe with attribute 'Captions' and 'ImagePath'.<ul>
<li>'Captions' : A list of text id list contain 1 to 10 captions.</li>
<li>'ImagePath': Image path that store paired image.</li>
</ul>
</li>
<li><code>./102flowers/</code> is the directory containing all training images.</li>
<li><code>./dataset/testData.pkl</code> is a pandas a dataframe with attribute 'ID' and 'Captions', which contains testing data.</li>
</ul>



In [33]:
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))


There are 7370 image in training data


In [34]:
df.head(5)


Unnamed: 0_level_0,Captions,ImagePath
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
6734,"[[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11,...",./102flowers/image_06734.jpg
6736,"[[4, 1, 5, 12, 2, 3, 11, 31, 28, 68, 106, 132,...",./102flowers/image_06736.jpg
6737,"[[9, 2, 27, 4, 1, 6, 14, 7, 12, 19, 5427, 5427...",./102flowers/image_06737.jpg
6738,"[[9, 1, 5, 8, 54, 16, 38, 7, 12, 116, 325, 3, ...",./102flowers/image_06738.jpg
6739,"[[4, 12, 1, 5, 29, 11, 19, 7, 26, 70, 5427, 54...",./102flowers/image_06739.jpg


In [35]:
# Data Augmentation Configuration
# Define this BEFORE training_data_generator to avoid reference issues
aug_config = {
    'enabled': True,                      # Master switch for augmentation
    'random_flip_horizontal': True,       # Flowers can be mirrored
    'random_flip_vertical': False,        # Flowers typically grow upward
    'random_rotation': True,              # Any rotation is valid for flowers
    'random_brightness': 0.15,            # Lighting variations (max delta)
    'random_contrast': (0.9, 1.1),        # Subtle contrast changes (lower, upper)
    'random_saturation': (0.9, 1.1),      # Color intensity (lower, upper)
    'random_hue': 0.05,                   # Small color shifts (max delta)
}

print('Data Augmentation:', 'ENABLED' if aug_config['enabled'] else 'DISABLED')
if aug_config['enabled']:
    enabled_augs = [k for k, v in aug_config.items() if k != 'enabled' and v]
    print(f'Active augmentations: {", ".join(enabled_augs)}')

Data Augmentation: ENABLED
Active augmentations: random_flip_horizontal, random_rotation, random_brightness, random_contrast, random_saturation, random_hue



<h2 id="Create-Dataset-by-Dataset-API">Create Dataset by Dataset API<a class="anchor-link" href="#Create-Dataset-by-Dataset-API">¶</a></h2>



In [36]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def preprocess_text_distilbert(text, max_length=64):

    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='tf'
    )


    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask']
    }

In [37]:
# in this competition, you have to generate image in size 64x64x3
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3

def training_data_generator(caption_text, image_path):
    """
    Updated data generator using DistilBERT tokenization
    
    Args:
        caption_text: Raw text string (not IDs!)
        image_path: Path to image file
    
    Returns:
        img, input_ids, attention_mask
    """
    # ============= IMAGE PROCESSING (same as before) =============
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)  # [0, 1]
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    
    # Data augmentation (only applied during training)
    if aug_config['enabled']:
        if aug_config['random_flip_horizontal']:
            img = tf.image.random_flip_left_right(img)
        
        if aug_config['random_flip_vertical']:
            img = tf.image.random_flip_up_down(img)
        
        if aug_config['random_rotation']:
            img = tf.image.rot90(img, k=tf.random.uniform([], 0, 4, dtype=tf.int32))
        
        if aug_config['random_brightness']:
            img = tf.image.random_brightness(img, aug_config['random_brightness'])
        
        if aug_config['random_contrast']:
            img = tf.image.random_contrast(img, 
                                          aug_config['random_contrast'][0], 
                                          aug_config['random_contrast'][1])
        
        if aug_config['random_saturation']:
            img = tf.image.random_saturation(img, 
                                            aug_config['random_saturation'][0], 
                                            aug_config['random_saturation'][1])
        
        if aug_config['random_hue']:
            img = tf.image.random_hue(img, aug_config['random_hue'])
        
        img = tf.clip_by_value(img, 0.0, 1.0)
    
    # Normalize to [-1, 1] to match generator's tanh output
    img = (img * 2.0) - 1.0
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    
    # ============= TEXT PROCESSING (NEW: Use DistilBERT tokenizer) =============
    def tokenize_caption(text):
        """Python function to tokenize text using DistilBERT tokenizer"""
        # Convert EagerTensor to bytes, then decode to string
        text = text.numpy().decode('utf-8')
        
        # Tokenize using DistilBERT
        encoded = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors='np'  # Use numpy arrays for TF compatibility
        )
        
        return encoded['input_ids'][0], encoded['attention_mask'][0]
    
    # Use tf.py_function to call Python tokenizer
    input_ids, attention_mask = tf.py_function(
        func=tokenize_caption,
        inp=[caption_text],
        Tout=[tf.int32, tf.int32]
    )
    
    # Set shapes explicitly
    input_ids.set_shape([64])
    attention_mask.set_shape([64])
    
    return img, input_ids, attention_mask

def dataset_generator(filenames, batch_size, data_generator):
    """
    Updated dataset generator to work with raw text (decoded from IDs)
    """
    # Load the training data
    df = pd.read_pickle(filenames)
    captions_ids = df['Captions'].values
    caption_texts = []
    
    # Decode pre-tokenized IDs back to raw text
    for i in range(len(captions_ids)):
        # Randomly choose one caption (list of ID lists)
        chosen_caption_ids = random.choice(captions_ids[i])
        
        # Decode IDs back to text using id2word_dict
        words = []
        for word_id in chosen_caption_ids:
            word = id2word_dict[str(word_id)]
            if word != '<PAD>':  # Skip padding tokens
                words.append(word)
        
        caption_text = ' '.join(words)
        caption_texts.append(caption_text)
    
    image_paths = df['ImagePath'].values
    
    # Verify same length
    assert len(caption_texts) == len(image_paths)
    
    # Create dataset from raw text and image paths
    dataset = tf.data.Dataset.from_tensor_slices((caption_texts, image_paths))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(len(caption_texts)).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

In [38]:
BATCH_SIZE = 16
dataset = dataset_generator(data_path + '/text2ImgData.pkl', BATCH_SIZE, training_data_generator)


<h2 id="Conditional-GAN-Model">Conditional GAN Model<a class="anchor-link" href="#Conditional-GAN-Model">¶</a></h2>
<p>As mentioned above, there are three models in this task, text encoder, generator and discriminator.</p>

<h2 id="Text-Encoder">Text Encoder<a class="anchor-link" href="#Text-Encoder">¶</a></h2>
<p>A RNN encoder that captures the meaning of input text.</p>

<ul>
<li>Input: text, which is a list of ids.</li>
<li>Output: embedding, or hidden representation of input text.</li>
</ul>



In [39]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import TFDistilBertModel

class DistillBertEncoder(tf.keras.Model):
    def __init__(self, output_dim=128, freeze_bert=True):
        super(DistillBertEncoder, self).__init__()
        
        self.distilbert = TFDistilBertModel.from_pretrained('distilbert-base-uncased')

        if(freeze_bert):
            self.distilbert.trainable = False

        self.projection = tf.keras.layers.Dense(output_dim, activation='relu')
    
        self.dropout = tf.keras.layers.Dropout(0.1)

    def call(self, input_ids, attention_mask, training=False):
        outputs = self.distilbert(input_ids, attention_mask=attention_mask, training=training)

        cls_embedding = outputs.last_hidden_state[:, 0, :]
        
        cls_embedding = self.dropout(cls_embedding, training=training)

        text_features = self.projection(cls_embedding)

        return text_features


<h2 id="Generator">Generator<a class="anchor-link" href="#Generator">¶</a></h2>
<p>A image generator which generates the target image illustrating the input text.</p>

<ul>
<li>Input: hidden representation of input text and random noise z with random seed.</li>
<li>Output: target image, which is conditioned on the given text, in size 64x64x3.</li>
</ul>



In [40]:
# Weight initialization as per DCGAN paper
def dcgan_weight_init():
    """Returns weight initializer for DCGAN: Normal(mean=0, stddev=0.02)"""
    return tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)


class Generator(tf.keras.Model):
    """
    DCGAN Generator for 64x64 images
    Uses transposed convolutions to progressively upsample from noise+text
    """
    def __init__(self, hparas):
        super(Generator, self).__init__()
        self.hparas = hparas
        
        # Weight initializer
        init = dcgan_weight_init()
        
        # Project and reshape
        # Input: [batch, z_dim + text_embed_dim] (e.g., 512 + 128 = 640)
        self.dense = tf.keras.layers.Dense(
            4 * 4 * 1024,  # Will reshape to [batch, 4, 4, 1024]
            use_bias=False,
            kernel_initializer=init
        )
        self.bn0 = tf.keras.layers.BatchNormalization()
        
        # Transposed convolutions for upsampling
        # 4x4 -> 8x8
        self.conv1 = tf.keras.layers.Conv2DTranspose(
            512, kernel_size=4, strides=2, padding='same',
            use_bias=False, kernel_initializer=init
        )
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        # 8x8 -> 16x16
        self.conv2 = tf.keras.layers.Conv2DTranspose(
            256, kernel_size=4, strides=2, padding='same',
            use_bias=False, kernel_initializer=init
        )
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        # 16x16 -> 32x32
        self.conv3 = tf.keras.layers.Conv2DTranspose(
            128, kernel_size=4, strides=2, padding='same',
            use_bias=False, kernel_initializer=init
        )
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        # 32x32 -> 64x64 (final output)
        self.conv4 = tf.keras.layers.Conv2DTranspose(
            3, kernel_size=4, strides=2, padding='same',
            use_bias=False, kernel_initializer=init
        )
        # No batch norm on output layer
        
    def call(self, text, noise_z, training=True):
        # Concatenate noise and text embeddings
        # text shape: [batch, text_embed_dim] (e.g., [16, 128])
        # noise_z shape: [batch, z_dim] (e.g., [16, 512])
        x = tf.concat([noise_z, text], axis=1)  # [batch, 640]
        
        # Project and reshape
        x = self.dense(x)  # [batch, 4*4*1024]
        x = self.bn0(x, training=training)
        x = tf.nn.relu(x)
        x = tf.reshape(x, [-1, 4, 4, 1024])  # [batch, 4, 4, 1024]
        
        # Upsample: 4x4 -> 8x8
        x = self.conv1(x)  # [batch, 8, 8, 512]
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        
        # Upsample: 8x8 -> 16x16
        x = self.conv2(x)  # [batch, 16, 16, 256]
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        
        # Upsample: 16x16 -> 32x32
        x = self.conv3(x)  # [batch, 32, 32, 128]
        x = self.bn3(x, training=training)
        x = tf.nn.relu(x)
        
        # Upsample: 32x32 -> 64x64 (final)
        x = self.conv4(x)  # [batch, 64, 64, 3]
        output = tf.nn.tanh(x)  # Output in range [-1, 1]
        
        # Return both for compatibility with existing training code
        return x, output  # logits, output


<h2 id="Discriminator">Discriminator<a class="anchor-link" href="#Discriminator">¶</a></h2>
<p>A binary classifier which can discriminate the real and fake image:</p>

<ol>
<li>Real image<ul>
<li>Input: real image and the paired text</li>
<li>Output: a floating number representing the result, which is expected to be 1.</li>
</ul>
</li>
<li>Fake Image<ul>
<li>Input: generated image and paired text</li>
<li>Output: a floating number representing the result, which is expected to be 0.</li>
</ul>
</li>
</ol>



In [41]:
class Critic(tf.keras.Model):
    """
    WGAN-GP Critic for 64x64 images
    Key differences from DCGAN Discriminator:
    1. NO batch normalization (causes issues with gradient penalty)
    2. NO sigmoid activation (outputs raw scores)
    3. Uses LeakyReLU throughout
    """
    def __init__(self, hparas):
        super(Critic, self).__init__()
        self.hparas = hparas
        
        # Weight initializer
        init = dcgan_weight_init()
        
        # Strided convolutions for downsampling
        # 64x64 -> 32x32 (NO batch norm on first layer)
        self.conv1 = tf.keras.layers.Conv2D(
            64, kernel_size=4, strides=2, padding='same',
            kernel_initializer=init
        )
        
        # 32x32 -> 16x16 (NO batch norm!)
        self.conv2 = tf.keras.layers.Conv2D(
            128, kernel_size=4, strides=2, padding='same',
            kernel_initializer=init
        )
        
        # 16x16 -> 8x8 (NO batch norm!)
        self.conv3 = tf.keras.layers.Conv2D(
            256, kernel_size=4, strides=2, padding='same',
            kernel_initializer=init
        )
        
        # 8x8 -> 4x4 (NO batch norm!)
        self.conv4 = tf.keras.layers.Conv2D(
            512, kernel_size=4, strides=2, padding='same',
            kernel_initializer=init
        )
        
        # Text conditioning layers
        self.text_dense = tf.keras.layers.Dense(
            512, kernel_initializer=init
        )
        
        # Final output layer
        self.flatten = tf.keras.layers.Flatten()
        self.final = tf.keras.layers.Dense(1, kernel_initializer=init)
        
    def call(self, img, text, training=True):
        # Image path: 64x64x3 -> 4x4x512
        x = self.conv1(img)  # [batch, 32, 32, 64]
        x = tf.nn.leaky_relu(x, alpha=0.2)
        
        x = self.conv2(x)  # [batch, 16, 16, 128]
        x = tf.nn.leaky_relu(x, alpha=0.2)
        
        x = self.conv3(x)  # [batch, 8, 8, 256]
        x = tf.nn.leaky_relu(x, alpha=0.2)
        
        x = self.conv4(x)  # [batch, 4, 4, 512]
        x = tf.nn.leaky_relu(x, alpha=0.2)
        
        # Flatten image features
        x = self.flatten(x)  # [batch, 8192]
        
        # Process text
        text_features = self.text_dense(text)  # [batch, 512]
        text_features = tf.nn.leaky_relu(text_features, alpha=0.2)
        
        # Concatenate image and text features
        combined = tf.concat([x, text_features], axis=1)  # [batch, 8704]
        
        # Final output - RAW SCORES (no sigmoid!)
        output = self.final(combined)  # [batch, 1]
        
        return output  # Return only scores, not probabilities

In [42]:
hparas = {
    'MAX_SEQ_LENGTH': 20,
    'EMBED_DIM': 256,
    'VOCAB_SIZE': len(word2Id_dict),
    'RNN_HIDDEN_SIZE': 128,
    'Z_DIM': 512,
    'DENSE_DIM': 128,
    'IMAGE_SIZE': [64, 64, 3],
    'BATCH_SIZE': BATCH_SIZE,
    'LR': 2e-4,
    'BETA_1': 0.0,        # WGAN-GP: use 0.0 instead of 0.5
    'BETA_2': 0.9,        # WGAN-GP: use 0.9
    'N_CRITIC': 5,        # NEW: critic iterations per generator iteration
    'LAMBDA_GP': 10.0,    # NEW: gradient penalty weight
    'N_EPOCH': 100,
    'N_SAMPLE': num_training_sample,
    'PRINT_FREQ': 20
}

In [43]:
text_encoder = DistillBertEncoder(output_dim=hparas['RNN_HIDDEN_SIZE'], freeze_bert=True)
generator = Generator(hparas)
critic = Critic(hparas)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.



<h2 id="Loss-Function-and-Optimization">Loss Function and Optimization<a class="anchor-link" href="#Loss-Function-and-Optimization">¶</a></h2>
<p>Although the conditional GAN model is quite complex, the loss function used to optimize the network is relatively simple. Actually, it is simply a binary classification task, thus we use cross entropy as our loss.</p>



In [44]:
def wasserstein_loss_critic(real_scores, fake_scores):
    """
    Wasserstein loss for critic
    Critic wants to maximize: E[critic(real)] - E[critic(fake)]
    So we minimize: E[critic(fake)] - E[critic(real)]
    """
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def wasserstein_loss_generator(fake_scores):
    """
    Wasserstein loss for generator
    Generator wants to maximize: E[critic(fake)]
    So we minimize: -E[critic(fake)]
    """
    return -tf.reduce_mean(fake_scores)

def gradient_penalty(critic, real_images, fake_images, text_embed, batch_size):
    """
    Gradient penalty for WGAN-GP
    
    Computes ||∇_x critic(x)||₂ for interpolated images x
    Penalty = λ * mean((||gradient|| - 1)²)
    """
    # Random weight for interpolation
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    
    # Interpolated images: x_hat = alpha * real + (1 - alpha) * fake
    interpolated = alpha * real_images + (1.0 - alpha) * fake_images
    
    # Compute critic scores on interpolated images
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        interpolated_scores = critic(interpolated, text_embed, training=True)
    
    # Compute gradients of scores w.r.t. interpolated images
    gradients = gp_tape.gradient(interpolated_scores, [interpolated])[0]
    
    # Compute L2 norm of gradients for each sample
    # gradients shape: [batch, 64, 64, 3]
    gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
    
    # Gradient penalty: mean((||gradient|| - 1)²)
    gradient_penalty = tf.reduce_mean(tf.square(gradients_norm - 1.0))
    
    return gradient_penalty

In [45]:
# WGAN-GP: Use Adam with beta_1=0.0, beta_2=0.9
generator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=hparas['LR'],
    beta_1=hparas['BETA_1'],
    beta_2=hparas['BETA_2']
)

critic_optimizer = tf.keras.optimizers.Adam(
    learning_rate=hparas['LR'],
    beta_1=hparas['BETA_1'],
    beta_2=hparas['BETA_2']
)



In [46]:
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    critic_optimizer=critic_optimizer,
    text_encoder=text_encoder,
    generator=generator,
    critic=critic
)

In [47]:
def calculate_wasserstein_distance(real_scores, fake_scores):
    """
    Approximation of Wasserstein distance
    Higher is better (critic getting better at distinguishing)
    """
    return tf.reduce_mean(real_scores) - tf.reduce_mean(fake_scores)

def calculate_gradient_norm(gradients):
    """Calculate L2 norm of gradients"""
    squared_norms = [tf.reduce_sum(tf.square(g)) for g in gradients if g is not None]
    total_norm = tf.sqrt(tf.reduce_sum(squared_norms))
    return total_norm

In [48]:
@tf.function
def train_step(real_image, input_ids, attention_mask):
    """
    WGAN-GP training step with n_critic iterations
    """
    batch_size = tf.shape(real_image)[0]
    
    # Encode text once (used for both critic and generator)
    text_embed = text_encoder(input_ids, attention_mask, training=True)
    
    # ============================================================
    # Train Critic (multiple iterations)
    # ============================================================
    for _ in range(hparas['N_CRITIC']):
        noise = tf.random.normal([batch_size, hparas['Z_DIM']], mean=0.0, stddev=1.0)
        
        with tf.GradientTape() as critic_tape:
            # Generate fake images
            _, fake_image = generator(text_embed, noise, training=True)
            
            # Get critic scores
            real_scores = critic(real_image, text_embed, training=True)
            fake_scores = critic(fake_image, text_embed, training=True)
            
            # Wasserstein loss
            c_loss_wasserstein = wasserstein_loss_critic(real_scores, fake_scores)
            
            # Gradient penalty
            gp = gradient_penalty(critic, real_image, fake_image, text_embed, batch_size)
            
            # Total critic loss
            c_loss = c_loss_wasserstein + hparas['LAMBDA_GP'] * gp
        
        # Update critic
        grad_c = critic_tape.gradient(c_loss, critic.trainable_variables)
        critic_optimizer.apply_gradients(zip(grad_c, critic.trainable_variables))
    
    # ============================================================
    # Train Generator (once per n_critic iterations)
    # ============================================================
    noise = tf.random.normal([batch_size, hparas['Z_DIM']], mean=0.0, stddev=1.0)
    
    with tf.GradientTape() as gen_tape:
        _, fake_image = generator(text_embed, noise, training=True)
        fake_scores = critic(fake_image, text_embed, training=True)
        g_loss = wasserstein_loss_generator(fake_scores)
    
    # Update generator
    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    
    # Calculate metrics
    wasserstein_dist = calculate_wasserstein_distance(real_scores, fake_scores)
    grad_norm_g = calculate_gradient_norm(grad_g)
    grad_norm_c = calculate_gradient_norm(grad_c)
    
    return {
        'g_loss': g_loss,
        'c_loss': c_loss,
        'c_loss_wasserstein': c_loss_wasserstein,
        'gp': gp,
        'wasserstein_dist': wasserstein_dist,
        'grad_norm_g': grad_norm_g,
        'grad_norm_c': grad_norm_c
    }

In [49]:
@tf.function
def test_step(input_ids, attention_mask, noise):
    # Encode text with DistilBERT (no hidden state)
    text_embed = text_encoder(input_ids, attention_mask, training=False)
    _, fake_image = generator(text_embed, noise, training=False)
    return fake_image


<h2 id="Visualiztion">Visualiztion<a class="anchor-link" href="#Visualiztion">¶</a></h2>
<p>During training, we can visualize the generated image to evaluate the quality of generator. The followings are some functions helping visualization.</p>



In [50]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img

def imsave(images, size, path):
    # getting the pixel values between [0, 1] to save it
    return plt.imsave(path, merge(images, size)*0.5 + 0.5)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)



<p>We always use same random seed and same senteces during training, which is more convenient for us to evaluate the quality of generated image.</p>



In [51]:
ni = int(np.ceil(np.sqrt(hparas['BATCH_SIZE'])))
sample_size = hparas['BATCH_SIZE']
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
# Fix: 8 sentences × 2 repetitions = 16 total (matching sample_size and sample_seed batch dimension)
sample_sentences = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/(2*ni)) + \
                   ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/(2*ni)) + \
                   ["the petals on this flower are white with a yellow center"] * int(sample_size/(2*ni)) + \
                   ["this flower has a lot of small round pink petals."] * int(sample_size/(2*ni)) + \
                   ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/(2*ni)) + \
                   ["the flower has yellow petals and the center of it is brown."] * int(sample_size/(2*ni)) + \
                   ["this flower has petals that are blue and white."] * int(sample_size/(2*ni)) +\
                   ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/(2*ni))

# Tokenize with DistilBERT (no more sent2IdList!)
sample_encoded = preprocess_text_distilbert(sample_sentences, max_length=64)
sample_input_ids = sample_encoded['input_ids']
sample_attention_mask = sample_encoded['attention_mask']

print(f"Sample sentences tokenized: {len(sample_sentences)} sentences")
print(f"Input IDs shape: {sample_input_ids.shape}")
print(f"Attention mask shape: {sample_attention_mask.shape}")

Sample sentences tokenized: 16 sentences
Input IDs shape: (16, 64)
Attention mask shape: (16, 64)


In [52]:
# Helper functions for managing training runs
import glob
import json

def list_available_runs():
    """List all available training runs with their details"""
    run_dirs = sorted(glob.glob('runs/*/'))
    
    if not run_dirs:
        print('No training runs found in runs/ directory')
        return []
    
    print('=' * 80)
    print('Available Training Runs:')
    print('=' * 80)
    
    available_runs = []
    for run_dir in run_dirs:
        timestamp = run_dir.split('/')[-2]
        available_runs.append(timestamp)
        
        # Check for checkpoints
        checkpoint_dir = f'{run_dir}checkpoints'
        latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        has_checkpoint = '✓' if latest_ckpt else '✗'
        
        # Check for config
        config_path = f'{run_dir}config.json'
        has_config = '✓' if os.path.exists(config_path) else '✗'
        
        # Count sample images
        sample_count = len(glob.glob(f'{run_dir}samples/*.jpg'))
        
        print(f'{timestamp}  |  Checkpoint: {has_checkpoint}  |  Config: {has_config}  |  Samples: {sample_count}')
        
        if latest_ckpt:
            print(f'  └─ Latest checkpoint: {latest_ckpt}')
    
    print('=' * 80)
    return available_runs

def load_run_config(run_timestamp):
    """Load configuration from a previous run"""
    config_path = f'runs/{run_timestamp}/config.json'
    
    if not os.path.exists(config_path):
        raise FileNotFoundError(f'Config not found: {config_path}')
    
    with open(config_path, 'r') as f:
        config = json.load(f)
    
    print(f'✓ Loaded config from: {config_path}')
    return config

# List available runs
print('Checking for existing training runs...')
list_available_runs()

Checking for existing training runs...
Available Training Runs:
20251116-204432  |  Checkpoint: ✗  |  Config: ✓  |  Samples: 0
20251116-204819  |  Checkpoint: ✗  |  Config: ✓  |  Samples: 0
20251116-225453  |  Checkpoint: ✗  |  Config: ✓  |  Samples: 0
20251116-230026  |  Checkpoint: ✗  |  Config: ✓  |  Samples: 0
20251116-230343  |  Checkpoint: ✗  |  Config: ✓  |  Samples: 0
20251117-012045  |  Checkpoint: ✗  |  Config: ✗  |  Samples: 0
20251117-013543  |  Checkpoint: ✗  |  Config: ✗  |  Samples: 0


['20251116-204432',
 '20251116-204819',
 '20251116-225453',
 '20251116-230026',
 '20251116-230343',
 '20251117-012045',
 '20251117-013543']


<h2 id="Training">Training<a class="anchor-link" href="#Training">¶</a></h2>



In [53]:
from datetime import datetime

# ============================================================
# RESUME TRAINING CONFIGURATION
# ============================================================
# Set to None for new run, or specify run timestamp to resume
# Example: RESUME_RUN = '20251116-225453'
RESUME_RUN = None  # ← Change this to resume from specific run

# ============================================================
# RUN DIRECTORY SETUP
# ============================================================
if RESUME_RUN:
    # Resume from existing run
    run_dir = f'runs/{RESUME_RUN}'
    
    # Verify directory exists
    if not os.path.exists(run_dir):
        raise FileNotFoundError(f'Run directory not found: {run_dir}')
    
    # Load existing config
    try:
        prev_config = load_run_config(RESUME_RUN)
        run_timestamp = prev_config.get('run_timestamp', RESUME_RUN)
        print(f'\n⟳ RESUMING training from: {run_dir}')
        print(f'  Original start: {run_timestamp}')
        
        # Warn if hyperparameters might be different
        if 'hyperparameters' in prev_config:
            prev_hparas = prev_config['hyperparameters']
            if prev_hparas.get('BATCH_SIZE') != hparas['BATCH_SIZE']:
                print(f'  ⚠ WARNING: Batch size changed ({prev_hparas.get("BATCH_SIZE")} → {hparas["BATCH_SIZE"]})')
            if prev_hparas.get('LR') != hparas['LR']:
                print(f'  ⚠ WARNING: Learning rate changed ({prev_hparas.get("LR")} → {hparas["LR"]})')
    except Exception as e:
        print(f'⚠ Could not load previous config: {e}')
        run_timestamp = RESUME_RUN
    
    # Use existing subdirectories
    checkpoint_dir = f'{run_dir}/checkpoints'
    samples_dir = f'{run_dir}/samples'
    inference_dir = f'{run_dir}/inference'
    
    # Create directories if they don't exist (shouldn't happen, but safety check)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(inference_dir, exist_ok=True)
    
else:
    # Create new run
    run_timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    run_dir = f'runs/{run_timestamp}'
    
    # All outputs for this run go in subdirectories
    checkpoint_dir = f'{run_dir}/checkpoints'
    samples_dir = f'{run_dir}/samples'
    inference_dir = f'{run_dir}/inference'
    
    # Create all directories
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(inference_dir, exist_ok=True)
    
    print(f'✓ Created NEW run directory: {run_dir}')

# Display directory structure
print(f'\nRun directory structure:')
print(f'  {run_dir}/')
print(f'  ├── checkpoints/ : {checkpoint_dir}')
print(f'  ├── samples/     : {samples_dir}')
print(f'  └── inference/   : {inference_dir}')

✓ Created NEW run directory: runs/20251117-134622

Run directory structure:
  runs/20251117-134622/
  ├── checkpoints/ : runs/20251117-134622/checkpoints
  ├── samples/     : runs/20251117-134622/samples
  └── inference/   : runs/20251117-134622/inference


In [54]:
def train(dataset, epochs):
    global run_dir, checkpoint_dir, samples_dir, inference_dir
    
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    
    log_dir = f'{run_dir}/logs'
    os.makedirs(log_dir, exist_ok=True)
    summary_writer = tf.summary.create_file_writer(log_dir)
    
    print(f'Training run: {run_dir}')
    print(f'TensorBoard logs: {log_dir}')
    print(f'Model: WGAN-GP with {hparas["N_CRITIC"]} critic iterations')
    
    steps_per_epoch = int(hparas['N_SAMPLE']/hparas['BATCH_SIZE'])
    global_step = 0
    
    for epoch in range(hparas['N_EPOCH']):
        g_total_loss = 0
        c_total_loss = 0
        c_total_loss_wasserstein = 0
        gp_total = 0
        wd_total = 0
        start = time.time()
        
        pbar = tqdm(dataset, desc=f'Epoch {epoch+1}/{hparas["N_EPOCH"]}', 
                   total=steps_per_epoch, unit='batch')
        
        for batch_idx, (image, input_ids, attention_mask) in enumerate(pbar):
            metrics = train_step(image, input_ids, attention_mask)
            
            # Accumulate losses
            g_total_loss += metrics['g_loss']
            c_total_loss += metrics['c_loss']
            c_total_loss_wasserstein += metrics['c_loss_wasserstein']
            gp_total += metrics['gp']
            wd_total += metrics['wasserstein_dist']
            
            # Log to TensorBoard every batch
            with summary_writer.as_default():
                tf.summary.scalar('Losses/generator_loss', metrics['g_loss'], step=global_step)
                tf.summary.scalar('Losses/critic_loss_total', metrics['c_loss'], step=global_step)
                tf.summary.scalar('Losses/critic_loss_wasserstein', metrics['c_loss_wasserstein'], step=global_step)
                tf.summary.scalar('Losses/gradient_penalty', metrics['gp'], step=global_step)
                tf.summary.scalar('Metrics/wasserstein_distance', metrics['wasserstein_dist'], step=global_step)
                
                if global_step % 50 == 0:
                    tf.summary.scalar('Gradients/generator_gradient_norm', metrics['grad_norm_g'], step=global_step)
                    tf.summary.scalar('Gradients/critic_gradient_norm', metrics['grad_norm_c'], step=global_step)
            
            # Update progress bar
            pbar.set_postfix({
                'G_loss': f'{metrics["g_loss"]:.4f}',
                'C_loss': f'{metrics["c_loss"]:.4f}',
                'W_dist': f'{metrics["wasserstein_dist"]:.4f}'
            })
            
            global_step += 1
        
        pbar.close()
        
        # Print epoch summary
        avg_g_loss = g_total_loss / steps_per_epoch
        avg_c_loss = c_total_loss / steps_per_epoch
        avg_c_loss_w = c_total_loss_wasserstein / steps_per_epoch
        avg_gp = gp_total / steps_per_epoch
        avg_wd = wd_total / steps_per_epoch
        epoch_time = time.time() - start
        
        print(f'Epoch {epoch+1}: G_loss={avg_g_loss:.4f}, C_loss={avg_c_loss:.4f} ' +
              f'(W={avg_c_loss_w:.4f}, GP={avg_gp:.4f}), W_dist={avg_wd:.4f}, Time={epoch_time:.2f}s')
        
        # Log epoch averages
        with summary_writer.as_default():
            tf.summary.scalar('Epoch/generator_loss_avg', avg_g_loss, step=epoch)
            tf.summary.scalar('Epoch/critic_loss_avg', avg_c_loss, step=epoch)
            tf.summary.scalar('Epoch/wasserstein_distance_avg', avg_wd, step=epoch)
        
        # Save checkpoint
        if (epoch + 1) % 50 == 0:
            saved_path = checkpoint.save(file_prefix=checkpoint_prefix)
            print(f'  ✓ Checkpoint saved: {saved_path}')
        
        # Visualization
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            fake_image = test_step(sample_input_ids, sample_attention_mask, sample_seed)
            save_images(fake_image, [ni, ni], f'{samples_dir}/train_{epoch+1:03d}.jpg')
            
            with summary_writer.as_default():
                display_images = (fake_image + 1.0) / 2.0
                tf.summary.image('Generated_Samples', display_images, step=epoch, max_outputs=16)
            
            print(f'  ✓ Sample image saved and logged to TensorBoard')
    
    print('\n✓ Training completed!')
    print(f'All outputs saved to: {run_dir}')

In [55]:
train(dataset, hparas['N_EPOCH'])


Training run: runs/20251117-134622
TensorBoard logs: runs/20251117-134622/logs
Model: WGAN-GP with 5 critic iterations


Epoch 1/100:   1%|          | 5/460 [00:19<29:32,  3.90s/batch, G_loss=-16.4033, C_loss=-87.5314, W_dist=79.2971]


KeyboardInterrupt: 


<h1><center class="subtitle">Evaluation</center></h1>

<p><code>dataset/testData.pkl</code> is a pandas dataframe containing testing text with attributes 'ID' and 'Captions'.</p>

<ul>
<li>'ID': text ID used to name generated image.</li>
<li>'Captions': text used as condition to generate image.</li>
</ul>

<p>For each captions, you need to generate <strong>inference_ID.png</strong> to evaluate quality of generated image. You must name the generated image in this format, otherwise we cannot evaluate your images.</p>




<h2 id="Testing-Dataset">Testing Dataset<a class="anchor-link" href="#Testing-Dataset">¶</a></h2>
<p>If you change anything during preprocessing of training dataset, you must make sure same operations have be done in testing dataset.</p>



In [None]:
def testing_data_generator(caption_text, index):
    """
    Updated testing data generator using DistilBERT tokenization
    
    Args:
        caption_text: Raw text string
        index: Test sample ID
    
    Returns:
        input_ids, attention_mask, index
    """
    def tokenize_caption(text):
        """Python function to tokenize text using DistilBERT tokenizer"""
        # Convert EagerTensor to bytes, then decode to string
        text = text.numpy().decode('utf-8')
        
        # Tokenize using DistilBERT
        encoded = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors='np'
        )
        
        return encoded['input_ids'][0], encoded['attention_mask'][0]
    
    # Use tf.py_function to call Python tokenizer
    input_ids, attention_mask = tf.py_function(
        func=tokenize_caption,
        inp=[caption_text],
        Tout=[tf.int32, tf.int32]
    )
    
    # Set shapes explicitly
    input_ids.set_shape([64])
    attention_mask.set_shape([64])
    
    return input_ids, attention_mask, index

def testing_dataset_generator(batch_size, data_generator):
    """
    Updated testing dataset generator - decodes IDs to raw text
    """
    data = pd.read_pickle('./dataset/testData.pkl')
    captions_ids = data['Captions'].values
    caption_texts = []
    
    # Decode pre-tokenized IDs back to text
    for i in range(len(captions_ids)):
        chosen_caption_ids = captions_ids[i]
        
        # Decode IDs back to text using id2word_dict
        words = []
        for word_id in chosen_caption_ids:
            word = id2word_dict[str(word_id)]
            if word != '<PAD>':  # Skip padding tokens
                words.append(word)
        
        caption_text = ' '.join(words)
        caption_texts.append(caption_text)
    
    index = data['ID'].values
    index = np.asarray(index)
    
    # Create dataset from raw text
    dataset = tf.data.Dataset.from_tensor_slices((caption_texts, index))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size)
    
    return dataset

In [None]:
testing_dataset = testing_dataset_generator(hparas['BATCH_SIZE'], testing_data_generator)


In [None]:
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = int(NUM_TEST / hparas['BATCH_SIZE'])



<h2 id="Inferece">Inferece<a class="anchor-link" href="#Inferece">¶</a></h2>



In [None]:
# Inference directory is already created by the train() function
# No need to create it again here

In [None]:
# Restore checkpoint from the current run directory
print(f'Looking for checkpoints in: {checkpoint_dir}')

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    checkpoint.restore(latest_checkpoint)
    # Extract checkpoint number from path (e.g., 'ckpt-2' -> 2)
    ckpt_num = latest_checkpoint.split('-')[-1]
    print(f'✓ Restored checkpoint: {latest_checkpoint}')
    print(f'  Checkpoint number: {ckpt_num}')
    
    # Try to infer which epoch this is (checkpoints saved every 50 epochs by default)
    # This is an estimate based on the training code
    estimated_epoch = int(ckpt_num) * 50
    print(f'  Estimated epoch: {estimated_epoch}')
else:
    print('⚠ No checkpoint found, using fresh/untrained model')
    print('  Training will start from epoch 1')

In [None]:
def inference(dataset):
    """
    Updated inference function for DistilBERT
    """
    # No hidden state needed for DistilBERT
    sample_size = hparas['BATCH_SIZE']
    sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
    
    step = 0
    start = time.time()
    total_images = 0
    
    # Progress bar for inference
    pbar = tqdm(total=NUM_TEST, desc='Generating images', unit='img')
    
    # Unpack 3 values: input_ids, attention_mask, idx
    for input_ids, attention_mask, idx in dataset:
        if step > EPOCH_TEST:
            break
        
        fake_image = test_step(input_ids, attention_mask, sample_seed)
        step += 1
        
        for i in range(hparas['BATCH_SIZE']):
            plt.imsave(f'{inference_dir}/inference_{idx[i]:04d}.jpg', fake_image[i].numpy()*0.5 + 0.5)
            total_images += 1
            pbar.update(1)
    
    pbar.close()
    print(f'\n✓ Generated {total_images} images in {time.time()-start:.4f} sec')
    print(f'✓ Images saved to: {inference_dir}')

In [None]:
inference(testing_dataset)


In [None]:
# Run evaluation script to generate score.csv
# Note: This must be run from the testing directory because inception_score.py uses relative paths
# Arguments: [inference_dir] [output_csv] [batch_size]
# Batch size must be 1, 2, 3, 7, 9, 21, or 39 to avoid remainder (819 test images)

# Save score.csv inside the run directory
!cd testing && python inception_score.py ../{inference_dir}/ ../{run_dir}/score.csv 39

## Visualize Generated Images

Below we randomly sample 20 images from our generated test results to visually inspect the quality and diversity of the model's outputs.


<h1><center class="subtitle">Demo</center></h1>

<p>We demonstrate the capability of our model (TA80) to generate plausible images of flowers from detailed text descriptions.</p>



In [None]:
# Visualize 20 random generated images with their captions
import glob

# Load test data
data = pd.read_pickle('./dataset/testData.pkl')
test_captions = data['Captions'].values
test_ids = data['ID'].values

# Get all generated images from the current inference directory
image_files = sorted(glob.glob(inference_dir + '/inference_*.jpg'))

if len(image_files) == 0:
    print(f'⚠ No images found in {inference_dir}')
    print('Please run the inference cell first!')
else:
    # Randomly sample 20 images
    np.random.seed(42)  # For reproducibility
    num_samples = min(20, len(image_files))
    sample_indices = np.random.choice(len(image_files), size=num_samples, replace=False)
    sample_files = [image_files[i] for i in sorted(sample_indices)]

    # Create 4x5 grid
    fig, axes = plt.subplots(4, 5, figsize=(20, 16))
    axes = axes.flatten()

    for idx, img_path in enumerate(sample_files):
        # Extract image ID from filename
        img_id = int(Path(img_path).stem.split('_')[1])
        
        # Find caption
        caption_idx = np.where(test_ids == img_id)[0][0]
        caption_ids = test_captions[caption_idx]
        
        # Decode caption
        caption_text = ''
        for word_id in caption_ids:
            word = id2word_dict[str(word_id)]
            if word != '<PAD>':
                caption_text += word + ' '
        
        # Load and display image
        img = plt.imread(img_path)
        axes[idx].imshow(img)
        axes[idx].set_title(f'ID: {img_id}\n{caption_text[:60]}...', fontsize=8)
        axes[idx].axis('off')

    # Hide unused subplots if less than 20 images
    for idx in range(num_samples, 20):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.suptitle(f'Random Sample of {num_samples} Generated Images', fontsize=16, y=1.002)
    plt.show()

    print(f'\nTotal generated images: {len(image_files)}')
    print(f'Images directory: {inference_dir}')