
<h1><center id="title">DataLab Cup 3: Reverse Image Caption</center></h1>

<center id="author">Shan-Hung Wu &amp; DataLab<br/>Fall 2025</center>




<h1><center class="subtitle">Text to Image</center></h1>

<h2 id="Platform:-Kaggle">Platform: <a href="https://www.kaggle.com/competitions/2025-datalab-cup-3-reverse-image-caption/overview">Kaggle</a><a class="anchor-link" href="#Platform:-Kaggle">¶</a></h2>
<h2 id="Overview">Overview<a class="anchor-link" href="#Overview">¶</a></h2>
<p>In this work, we are interested in translating text in the form of single-sentence, human-written descriptions directly into image pixels, for example "<strong>this flower has petals that are yellow and has a ruffled stamen</strong>" or "<strong>this pink and yellow flower has a beautiful yellow center with many stamens</strong>". You have to develop a novel deep architecture and GAN formulation that effectively translates visual concepts from characters to pixels.</p>

<p>More specifically, given a set of texts, your task is to generate a reasonable 64x64x3 image illustrating each text. We use the <a href="http://www.robots.ox.ac.uk/~vgg/data/flowers/102/">Oxford-102 flower dataset</a> and its <a href="https://drive.google.com/file/d/0B0ywwgffWnLLcms2WWJQRFNSWXM/view">paired texts</a> as our training dataset.</p>

<img alt="No description has been provided for this image" src="./data/example.png"/>

<ul>
<li>7,370 training images, each annotated with up to 10 texts.</li>
<li>819 texts for testing; you must generate one 64x64x3 image for each text.</li>
</ul>




<h2 id="Conditional-GAN">Conditional GAN<a class="anchor-link" href="#Conditional-GAN">¶</a></h2>
<p>To generate an image that illustrates a given text, our model must meet several requirements:</p>

<ol>
<li>Our model should be able to understand and extract the meaning of the given text.<ul>
<li>Use an RNN or another language model, such as BERT, ELMo, or XLNet, to capture the meaning of the text.</li>
</ul>
</li>
<li>Our model should be able to generate images.<ul>
<li>Use a GAN to generate high-quality images.</li>
</ul>
</li>
<li>The generated image should illustrate the text.<ul>
<li>Use a conditional GAN to generate images conditioned on the given text.</li>
</ul>
</li>
</ol>

<p>Generative adversarial nets can be extended to a conditional model if both the generator and the discriminator are conditioned on some extra information $y$. We can perform the conditioning by feeding $y$ into both the discriminator and the generator as an additional input layer.</p>

<img alt="No description has been provided for this image" src="./data/cGAN.png" width="500"/>
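<p>A minimal NumPy sketch of the concatenation-style conditioning described above, assuming $y$ is a fixed-length vector such as a text embedding. The helper names and the dimensions are illustrative only: the generator sees $[z; y]$, and the discriminator sees $y$ tiled over every spatial position and stacked as extra channels. The model later in this notebook conditions its discriminator differently (through a projection term).</p>

```python
import numpy as np

def condition_generator_input(z, y):
    """Concatenate noise z [B, z_dim] with condition y [B, y_dim] -> [B, z_dim + y_dim]."""
    return np.concatenate([z, y], axis=1)

def condition_discriminator_input(images, y):
    """Tile y [B, y_dim] over the spatial grid of images [B, H, W, C]
    and concatenate along channels -> [B, H, W, C + y_dim]."""
    b, h, w, _ = images.shape
    y_map = np.broadcast_to(y[:, None, None, :], (b, h, w, y.shape[1]))
    return np.concatenate([images, y_map], axis=-1)

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 100))            # hypothetical noise dimension
y = rng.normal(size=(4, 16))             # hypothetical condition vector
images = rng.normal(size=(4, 64, 64, 3))

g_in = condition_generator_input(z, y)
d_in = condition_discriminator_input(images, y)
print(g_in.shape, d_in.shape)  # (4, 116) (4, 64, 64, 19)
```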

<p>There are two motivations for using some extra information in a GAN model:</p>

<ol>
<li>Improve the GAN.</li>
<li>Generate targeted images.</li>
</ol>

<p>Additional information that is correlated with the input images, such as class labels, can be used to improve the GAN. This improvement may come in the form of more stable training, faster training, and/or generated images that have better quality.</p>

<img alt="No description has been provided for this image" src="./data/GANCLS.jpg"/>
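<p>GAN-CLS, the matching-aware discriminator introduced by Reed et al. (2016) for text-to-image synthesis, trains the discriminator on three kinds of pairs: real image with matching text ($s_r$), real image with mismatched text ($s_w$), and generated image with matching text ($s_f$). Below is a minimal NumPy sketch of its losses, assuming the discriminator outputs probabilities; the function name and the <code>eps</code> guard are our own. The model later in this notebook keeps the three pairings but replaces these cross-entropy terms with a Wasserstein critic and a hinge loss.</p>

```python
import numpy as np

def gan_cls_losses(s_r, s_w, s_f, eps=1e-8):
    """GAN-CLS losses from discriminator probabilities.

    s_r: D(real image, matching text)
    s_w: D(real image, mismatched text)
    s_f: D(generated image, matching text)
    Returns (d_loss, g_loss), both to be minimized.
    """
    s_r, s_w, s_f = (np.asarray(s, dtype=np.float64) for s in (s_r, s_w, s_f))
    # D maximizes log(s_r) + (log(1 - s_w) + log(1 - s_f)) / 2; we minimize its negative.
    d_obj = np.log(s_r + eps) + 0.5 * (np.log(1.0 - s_w + eps) + np.log(1.0 - s_f + eps))
    # G maximizes log(s_f); we minimize its negative.
    g_obj = np.log(s_f + eps)
    return -d_obj.mean(), -g_obj.mean()

d_loss, g_loss = gan_cls_losses([0.5], [0.5], [0.5])
print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```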



In [None]:
import os
# Set before importing TensorFlow so the log-level setting reliably takes effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import layers
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
from tqdm import tqdm

import re
from IPython import display

In [None]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
	try:
		# Restrict TensorFlow to only use the first GPU
		tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

		# Currently, memory growth needs to be the same across GPUs
		for gpu in gpus:
			tf.config.experimental.set_memory_growth(gpu, True)
		logical_gpus = tf.config.experimental.list_logical_devices('GPU')
		print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
	except RuntimeError as e:
		# Memory growth must be set before GPUs have been initialized
		print(e)

RANDOM_SEED = 42

# Python random (the module is already imported above)
random.seed(RANDOM_SEED)

# NumPy random
np.random.seed(RANDOM_SEED)

# TensorFlow random
tf.random.set_seed(RANDOM_SEED)

BATCH_SIZE = 128


<h2 id="Preprocess-Text">Preprocess Text<a class="anchor-link" href="#Preprocess-Text">¶</a></h2>
<p>Since working with raw strings is inefficient, we have already preprocessed the data for you:</p>

<ul>
<li>Remove texts longer than <code>MAX_SEQ_LENGTH (20)</code>.</li>
<li>Remove all punctuation from the texts.</li>
<li>Store the vocabulary in <code>dictionary/vocab.npy</code>.</li>
<li>Represent each text as a sequence of integer IDs.</li>
<li>Replace rare words with the <code>&lt;RARE&gt;</code> token to reduce the vocabulary size for more efficient training.</li>
<li>Pad each text with <code>&lt;PAD&gt;</code> tokens so that all texts have length <code>MAX_SEQ_LENGTH (20)</code>.</li>
</ul>
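<p>The steps above can be sketched in plain Python as follows. This is a minimal illustration, not the script that produced the dataset: the helper <code>encode_caption</code> and the toy vocabulary are hypothetical, and lowercasing is our assumption. The real dictionaries store IDs as strings, whereas the toy one uses integers.</p>

```python
import string

MAX_SEQ_LENGTH = 20

def encode_caption(text, word2id, max_len=MAX_SEQ_LENGTH):
    """Encode a caption as a fixed-length list of IDs, or return None if it is too long.

    Strips punctuation, maps words to IDs, replaces out-of-vocabulary words
    with <RARE>, and right-pads with <PAD>. (Lowercasing is an assumption.)
    """
    words = text.lower().translate(str.maketrans('', '', string.punctuation)).split()
    if len(words) > max_len:
        return None  # texts longer than max_len are dropped
    ids = [word2id.get(w, word2id['<RARE>']) for w in words]
    return ids + [word2id['<PAD>']] * (max_len - len(ids))

# Hypothetical toy vocabulary; the real one comes from dictionary/word2Id.npy.
toy_word2id = {'<PAD>': 0, '<RARE>': 1, 'this': 2, 'flower': 3,
               'has': 4, 'yellow': 5, 'petals': 6}
print(encode_caption("This flower has yellow, ruffled petals.", toy_word2id, max_len=8))
# [2, 3, 4, 5, 1, 6, 0, 0]
```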

<p>Note that there is no need to append <code>&lt;ST&gt;</code> and <code>&lt;ED&gt;</code> tokens to each text, because we don't need to generate any sequence in this task.</p>

<p>To verify that the original text was encoded correctly, we can decode a sequence of vocabulary IDs by looking each ID up in the vocabulary dictionary:</p>

<ul>
<li><code>dictionary/word2Id.npy</code> is a NumPy array of (word, ID) pairs that maps each word to its ID.</li>
<li><code>dictionary/id2Word.npy</code> is a NumPy array of (ID, word) pairs that maps each ID back to its word.</li>
</ul>
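<p>A minimal sketch of that lookup, assuming <code>id2word</code> behaves like the <code>id2word_dict</code> built in the next cell, whose keys are strings (for example <code>id2word_dict['1']</code>). The helper <code>decode_ids</code> and the toy mapping are hypothetical.</p>

```python
def decode_ids(ids, id2word, pad_token='<PAD>'):
    """Map a sequence of IDs back to words, dropping padding tokens."""
    words = [id2word[str(i)] for i in ids]  # keys are strings, so normalize the IDs
    return ' '.join(w for w in words if w != pad_token)

# Hypothetical toy mapping with string keys, like id2Word.npy.
toy_id2word = {'0': '<PAD>', '1': '<RARE>', '2': 'this', '3': 'flower', '4': 'is', '5': 'pink'}
print(decode_ids([2, 3, 4, 5, 0, 0], toy_id2word))  # this flower is pink
```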



In [None]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))


In [None]:
# The former sent2IdList() helper is no longer needed: captions are now tokenized
# with the CLIP tokenizer. id2word_dict from the cell above is still used to decode
# the pre-tokenized captions back into raw text.

print("✓ Using CLIP tokenizer (sent2IdList removed)")


<h2 id="Dataset">Dataset<a class="anchor-link" href="#Dataset">¶</a></h2>
<p>The following files are provided for training and testing:</p>

<ul>
<li><code>./dataset/text2ImgData.pkl</code> is a pandas DataFrame with columns 'Captions' and 'ImagePath'.<ul>
<li>'Captions': a list of 1 to 10 captions, each stored as a list of word IDs.</li>
<li>'ImagePath': the path of the paired image.</li>
</ul>
</li>
<li><code>./102flowers/</code> is the directory containing all training images.</li>
<li><code>./dataset/testData.pkl</code> is a pandas DataFrame with columns 'ID' and 'Captions', containing the test data.</li>
</ul>
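<p>Because 'Captions' holds up to 10 captions per image, one option is to flatten the DataFrame so that every caption becomes its own (caption, image) training pair; the <code>dataset_generator</code> defined later instead picks one caption per image. A plain-Python sketch of the flattening, where the helper name and the sample rows are hypothetical stand-ins for <code>df['Captions']</code> and <code>df['ImagePath']</code>:</p>

```python
def expand_caption_pairs(captions_per_image, image_paths):
    """Turn per-image caption lists into one (caption, image_path) pair per caption."""
    if len(captions_per_image) != len(image_paths):
        raise ValueError("captions and image paths must have the same length")
    pairs = []
    for captions, path in zip(captions_per_image, image_paths):
        for caption in captions:
            pairs.append((caption, path))
    return pairs

# Hypothetical rows: two images with 2 and 1 captions (each a list of word IDs).
captions = [[[5, 9, 0], [5, 7, 0]], [[8, 2, 0]]]
paths = ['./102flowers/image_00001.jpg', './102flowers/image_00002.jpg']
pairs = expand_caption_pairs(captions, paths)
print(len(pairs))  # 3
```

Flattening trains on every caption each epoch, at the cost of a roughly tenfold larger epoch.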



In [None]:
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d images in the training data' % (n_images_train))


In [None]:
df.head(5)



<h2 id="Create-Dataset-by-Dataset-API">Create Dataset by Dataset API<a class="anchor-link" href="#Create-Dataset-by-Dataset-API">¶</a></h2>



In [None]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import CLIPTokenizer

# Load CLIP Tokenizer
# "openai/clip-vit-base-patch32" is a standard, powerful model
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

def preprocess_text_clip(text, max_length=77):
		encoded = tokenizer(
				text,
				padding='max_length',
				truncation=True,
				max_length=max_length,
				return_tensors='tf'
		)
		return {
				'input_ids': encoded['input_ids'],
				'attention_mask': encoded['attention_mask']
		}

In [None]:
def DiffAugment(x, policy='color,translation,cutout', channels_first=False, params=None):
		"""
		Differentiable augmentation for GANs
		
		Args:
				x: Input images [batch, H, W, C], or [batch, C, H, W] if channels_first
				policy: Comma-separated augmentation policies
				channels_first: If True, expects [batch, C, H, W]
				params: Optional dict of pre-generated augmentation parameters, so that
						real and fake batches receive identical augmentations
		
		Returns:
				Augmented images in the same layout as the input
		"""
		if policy:
				if channels_first:
						# The augmentation functions below expect [batch, H, W, C]
						x = tf.transpose(x, [0, 2, 3, 1])
				for p in policy.split(','):
						for f in AUGMENT_FNS[p]:
								x = f(x, params)  # reuse the same params on every call
				if channels_first:
						x = tf.transpose(x, [0, 3, 1, 2])
		return x


def rand_brightness(x, params=None):
		"""Random brightness adjustment"""
		if params is not None and 'brightness' in params:
				magnitude = params['brightness']
		else:
				magnitude = tf.random.uniform([], -0.5, 0.5)
		x = x + magnitude
		return x


def rand_saturation(x, params=None):
		"""Random saturation adjustment"""
		if params is not None and 'saturation' in params:
				magnitude = params['saturation']
		else:
				magnitude = tf.random.uniform([], 0.0, 2.0)
		x_mean = tf.reduce_mean(x, axis=-1, keepdims=True)
		x = (x - x_mean) * magnitude + x_mean
		return x


def rand_contrast(x, params=None):
		"""Random contrast adjustment"""
		if params is not None and 'contrast' in params:
				magnitude = params['contrast']
		else:
				magnitude = tf.random.uniform([], 0.5, 1.5)
		x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
		x = (x - x_mean) * magnitude + x_mean
		return x

def rand_translation(x, params=None, ratio=0.125):
		"""Random translation (shift) with reflect padding; uses tf.map_fn so it runs inside @tf.function"""
		batch_size = tf.shape(x)[0]
		image_size = tf.shape(x)[1]
		shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
		
		# Random translation amounts for entire batch
		if params is not None and 'translation_x' in params:
				translation_x = params['translation_x']
				translation_y = params['translation_y']
		else:
				translation_x = tf.random.uniform([batch_size], -shift, shift + 1, dtype=tf.int32)
				translation_y = tf.random.uniform([batch_size], -shift, shift + 1, dtype=tf.int32)
		
		def translate_single_image(args):
				"""Translate a single image"""
				img, tx, ty = args
				img = tf.pad(img, [[shift, shift], [shift, shift], [0, 0]], mode='REFLECT')
				img = tf.image.crop_to_bounding_box(img, shift + ty, shift + tx, image_size, image_size)
				return img
		
		# Use tf.map_fn (graph-mode compatible)
		x_translated = tf.map_fn(
				translate_single_image,
				(x, translation_x, translation_y),
				fn_output_signature=tf.TensorSpec(shape=[64, 64, 3], dtype=tf.float32),
				parallel_iterations=BATCH_SIZE
		)
		
		return x_translated


def rand_cutout(x, params=None, ratio=0.2):
		"""
		Random cutout: zero out a square patch at a random location in each image.
		
		The patch is described by a boolean coordinate mask built with tf.meshgrid
		and applied image by image with tf.map_fn.
		"""
		batch_size = tf.shape(x)[0]
		image_size = tf.shape(x)[1]
		channels = tf.shape(x)[3]
		
		# Cutout size
		cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
		
		# Random offset for cutout location
		if params is not None and 'cutout_x' in params:
				offset_x = params['cutout_x']
				offset_y = params['cutout_y']
		else:
				offset_x = tf.random.uniform([batch_size], 0, image_size - cutout_size + 1, dtype=tf.int32)
				offset_y = tf.random.uniform([batch_size], 0, image_size - cutout_size + 1, dtype=tf.int32)
		
		def cutout_single_image(args):
				"""Apply cutout to single image using simple slicing"""
				img, ox, oy = args
				
				# Create coordinate grids
				height_range = tf.range(image_size)
				width_range = tf.range(image_size)
				
				# Create 2D grids
				yy, xx = tf.meshgrid(height_range, width_range, indexing='ij')
				
				# Create mask: True where we want to KEEP pixels
				mask_y = tf.logical_or(yy < oy, yy >= oy + cutout_size)
				mask_x = tf.logical_or(xx < ox, xx >= ox + cutout_size)
				mask = tf.logical_or(mask_y, mask_x)
				
				# Expand mask to all channels
				mask = tf.expand_dims(mask, axis=-1)  # [H, W, 1]
				mask = tf.tile(mask, [1, 1, channels])  # [H, W, C]
				
				# Apply mask (convert bool to float)
				mask = tf.cast(mask, tf.float32)
				return img * mask
		
		# Use tf.map_fn
		x_cutout = tf.map_fn(
				cutout_single_image,
				(x, offset_x, offset_y),
				fn_output_signature=tf.TensorSpec(shape=[64, 64, 3], dtype=tf.float32),
				parallel_iterations=BATCH_SIZE
		)
		
		return x_cutout


# Augmentation function registry
AUGMENT_FNS = {
		'color': [rand_brightness, rand_saturation, rand_contrast],
		'translation': [rand_translation],
		'cutout': [rand_cutout],
}


print("✓ DiffAugment functions loaded")
print("  Policies available: color, translation, cutout")

In [None]:
# In this competition, every generated image must be 64x64x3
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3

def training_data_generator(caption_text, image_path):
		"""
		Updated data generator using CLIP tokenization
		
		Args:
				caption_text: Raw text string (not IDs!)
				image_path: Path to image file
		
		Returns:
				img, input_ids, attention_mask
		"""
		# ============= IMAGE PROCESSING =============
		img = tf.io.read_file(image_path)
		img = tf.image.decode_image(img, channels=3)
		img = tf.image.convert_image_dtype(img, tf.float32)  # [0, 1]
		img.set_shape([None, None, 3])
		img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
		
		# Normalize to [-1, 1] to match generator's tanh output
		img = (img * 2.0) - 1.0
		img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
		
		# ============= TEXT PROCESSING (CLIP tokenizer) =============
		def tokenize_caption_clip(text):
				"""Python function to tokenize text using CLIP tokenizer"""
				# Convert EagerTensor to bytes, then decode to string
				text = text.numpy().decode('utf-8')
				
				# Tokenize using CLIP
				encoded = tokenizer(
						text,
						padding='max_length',
						truncation=True,
						max_length=77,
						return_tensors='np'  # Use numpy arrays for TF compatibility
				)
				
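				# Cast to int32 to match Tout below (the tokenizer typically returns int64)
				return encoded['input_ids'][0].astype(np.int32), encoded['attention_mask'][0].astype(np.int32)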
		
		# Use tf.py_function to call Python tokenizer
		input_ids, attention_mask = tf.py_function(
				func=tokenize_caption_clip,
				inp=[caption_text],
				Tout=[tf.int32, tf.int32]
		)
		
		# Set shapes explicitly for CLIP
		input_ids.set_shape([77])
		attention_mask.set_shape([77])
		
		return img, input_ids, attention_mask

def dataset_generator(filenames, batch_size, data_generator):
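		"""
		Build a shuffled, batched tf.data pipeline of (image, input_ids, attention_mask).
		Pre-tokenized captions are decoded back to raw text so that they can be
		re-tokenized with the CLIP tokenizer.
		"""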
		# Load the training data
		df = pd.read_pickle(filenames)
		captions_ids = df['Captions'].values
		caption_texts = []
		
		# Decode pre-tokenized IDs back to raw text
		for i in range(len(captions_ids)):
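				# Randomly choose one of this image's captions. The choice is made once,
				# when the dataset is built, so each image keeps the same caption every epoch.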
				chosen_caption_ids = random.choice(captions_ids[i])
				
				# Decode IDs back to text using id2word_dict
				words = []
				for word_id in chosen_caption_ids:
						word = id2word_dict[str(word_id)]
						if word != '<PAD>':  # Skip padding tokens
								words.append(word)
				
				caption_text = ' '.join(words)
				caption_texts.append(caption_text)
		
		image_paths = df['ImagePath'].values
		
		# Verify same length
		assert len(caption_texts) == len(image_paths)
		
		# Create dataset from raw text and image paths
		dataset = tf.data.Dataset.from_tensor_slices((caption_texts, image_paths))
		dataset = dataset.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
		dataset = dataset.cache()
		dataset = dataset.shuffle(len(caption_texts)).batch(batch_size, drop_remainder=True)
		dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

		return dataset

In [None]:
dataset = dataset_generator(data_path + '/text2ImgData.pkl', BATCH_SIZE, training_data_generator)


<h2 id="Conditional-GAN-Model">Conditional GAN Model<a class="anchor-link" href="#Conditional-GAN-Model">¶</a></h2>
<p>As mentioned above, this task involves three models: a text encoder, a generator, and a discriminator.</p>

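<h2 id="Text-Encoder">Text Encoder<a class="anchor-link" href="#Text-Encoder">¶</a></h2>
<p>An encoder that captures the meaning of the input text. Instead of an RNN, the code below uses a pre-trained CLIP text encoder, kept frozen during training.</p>

<ul>
<li>Input: the text, as a sequence of token IDs (plus an attention mask for CLIP).</li>
<li>Output: an embedding, i.e. a hidden representation of the input text.</li>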
</ul>
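<p>The <code>ClipTextEncoder</code> below L2-normalizes CLIP's projected text features, so every embedding has unit length and the dot product of two embeddings equals their cosine similarity. A NumPy sketch of that normalization, with random vectors standing in for real CLIP embeddings (512 matches <code>EMBED_DIM</code>):</p>

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale each vector to unit L2 norm (same formula as tf.math.l2_normalize)."""
    norm = np.sqrt(np.maximum(np.sum(x * x, axis=axis, keepdims=True), eps))
    return x / norm

rng = np.random.default_rng(0)
a = l2_normalize(rng.normal(size=(2, 512)))  # stand-ins for two batches of text embeddings
b = l2_normalize(rng.normal(size=(2, 512)))

# For unit vectors, the plain dot product equals cosine similarity.
dot = np.sum(a * b, axis=1)
cos = dot / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
print(np.allclose(dot, cos))  # True
```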



In [None]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import TFCLIPModel

class ClipTextEncoder(tf.keras.Model):
		def __init__(self, output_dim=512, freeze_clip=True):
				super(ClipTextEncoder, self).__init__()
				# Load the pre-trained CLIP model; only its text tower and text projection are used.
				# output_dim is currently unused: the output size is fixed by CLIP (512 for ViT-B/32).
				self.clip = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
				
				if freeze_clip:
						self.clip.trainable = False
				# No extra projection, LayerNorm, or dropout: the encoder returns raw CLIP embeddings.

		def call(self, input_ids, attention_mask, training=False):
				# 1. Get the projected text features (aligned with CLIP image features, 512-dim)
				text_embeds = self.clip.get_text_features(
						input_ids=input_ids,
						attention_mask=attention_mask
				)
				
				# 2. L2-normalize: CLIP compares embeddings by cosine similarity,
				#    so unit-length vectors match how CLIP itself uses them.
				text_embeds = tf.math.l2_normalize(text_embeds, axis=1)
				
				return text_embeds
				


<h2 id="Generator">Generator<a class="anchor-link" href="#Generator">¶</a></h2>
<p>An image generator that produces the target image illustrating the input text.</p>

<ul>
<li>Input: the hidden representation of the input text and a random noise vector z.</li>
<li>Output: target image, which is conditioned on the given text, in size 64x64x3.</li>
</ul>
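<p>As a sanity check on the Generator below, this plain-Python sketch traces tensor shapes (batch axis omitted) using this notebook's settings: <code>Z_DIM</code> = 512, a 512-dim CLIP embedding, and a 128-unit text projection. It only reproduces the layer arithmetic and does not build the model; each nearest-neighbor upsampling doubles the spatial size, so 4 → 8 → 16 → 32 → 64.</p>

```python
def generator_shape_trace(z_dim=512, text_dim=512, text_proj_dim=128):
    """Trace tensor shapes (without the batch axis) through the Generator's layers."""
    shapes = [('text_embedding', (text_dim,)), ('text_feat', (text_proj_dim,))]
    shapes.append(('concat[noise, text_feat]', (z_dim + text_proj_dim,)))
    h, w, c = 4, 4, 512
    shapes.append(('dense + reshape', (h, w, c)))
    for channels in (256, 128, 64):  # up1/conv1, up2/conv2, up3/conv3
        h, w, c = 2 * h, 2 * w, channels
        shapes.append((f'upsample + conv -> {channels}', (h, w, c)))
    h, w = 2 * h, 2 * w  # up4, then conv4 to 3 channels and tanh
    shapes.append(('output image', (h, w, 3)))
    return shapes

for name, shape in generator_shape_trace():
    print(f'{name:28s} {shape}')
```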



In [None]:
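# Unused helper kept for reference. The DCGAN paper initializes weights from
# N(0, 0.02); this notebook uses He normal initialization instead.
def dcgan_weight_init():
		return tf.keras.initializers.HeNormal()


class Generator(tf.keras.Model):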
    def __init__(self, hparas):
        super(Generator, self).__init__()
        self.hparas = hparas
        
        # 1. Weight initializer (He normal, suited to the ReLU-family activations used here)
        init = tf.keras.initializers.HeNormal()
        
        # ---------------------------------------------------------
        # Text conditioning projection
        # Maps the 512-dim unit-length CLIP embedding to a learned 128-dim
        # feature that is concatenated with the noise vector in call().
        # ---------------------------------------------------------
        self.text_projection = tf.keras.Sequential([
            tf.keras.layers.Dense(128, kernel_initializer=init),
            tf.keras.layers.LeakyReLU(0.2)
        ])
        

        # 2. Project [noise, text_feat] to a 4x4x512 feature map
        self.dense = tf.keras.layers.Dense(4 * 4 * 512, use_bias=False, kernel_initializer=init)
        self.bn0 = tf.keras.layers.BatchNormalization()
        
        # 3. Upsample Blocks
        self.up1 = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
        self.conv1 = tf.keras.layers.Conv2D(256, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer=init)
        self.bn1 = tf.keras.layers.BatchNormalization()
        
        self.up2 = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
        self.conv2 = tf.keras.layers.Conv2D(128, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer=init)
        self.bn2 = tf.keras.layers.BatchNormalization()
        
        self.up3 = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
        self.conv3 = tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer=init)
        self.bn3 = tf.keras.layers.BatchNormalization()
        
        # 4. Final Layer
        self.up4 = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
        # use_bias=True: this layer has no BatchNorm after it, so it needs its own bias
        self.conv4 = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, padding='same', use_bias=True, kernel_initializer=init)
        
    def call(self, text_embedding, noise_z, training=True):
        # ---------------------------------------------------------
        # Step 1: Process the Text
        # text_embedding shape: [Batch, 512] (Normalized)
        # text_feat shape:      [Batch, 128] (Unbounded, Learnable)
        # ---------------------------------------------------------
        text_feat = self.text_projection(text_embedding)
        
        # Step 2: Concatenate with Noise
        # noise_z shape: [Batch, Z_DIM]   (Z_DIM = 512 in hparas)
        # x shape:       [Batch, Z_DIM + 128]
        x = tf.concat([noise_z, text_feat], axis=1)
        
        # Step 3: Project to 4x4x512
        x = self.dense(x)
        x = self.bn0(x, training=training)
        x = tf.nn.relu(x)
        x = tf.reshape(x, [-1, 4, 4, 512])
        
        # Step 4: Upsampling
        x = self.up1(x)
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        
        x = self.up2(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        
        x = self.up3(x)
        x = self.conv3(x)
        x = self.bn3(x, training=training)
        x = tf.nn.relu(x)
        
        # Step 5: Final Generation
        x = self.up4(x)
        output = self.conv4(x)
        output = tf.nn.tanh(output)
        
        return x, output


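<h2 id="Discriminator">Discriminator<a class="anchor-link" href="#Discriminator">¶</a></h2>
<p>A binary classifier that discriminates between real and fake images. The code below implements it as a WGAN-GP critic, which outputs an unbounded score rather than a probability, so the targets 1 and 0 below should be read as "high" and "low" scores:</p>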

<ol>
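<li>Real image<ul>
<li>Input: a real image and its paired text</li>
<li>Output: a floating-point number, which is expected to be 1.</li>
</ul>
</li>
<li>Fake image<ul>
<li>Input: a generated image and its paired text</li>
<li>Output: a floating-point number, which is expected to be 0.</li>
</ul>
</li>
<li>Mismatched text<ul>
<li>Input: a real image and a caption of a different image (the "wrong text" pair used by GAN-CLS)</li>
<li>Output: a floating-point number, which is expected to be 0.</li>
</ul>
</li>
</ol>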
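<p>The Critic below adds an unconditional realism term (a linear head over the flattened image features) to a text-image alignment term (the dot product of average-pooled image features with the projected text $\psi(y)$), in the style of a projection discriminator. A NumPy sketch of only this scoring arithmetic, with small random arrays standing in for learned features and weights; rolling $\psi(y)$ across the batch, as <code>train_step</code> does for mismatched text, changes only the alignment term.</p>

```python
import numpy as np

def projection_critic_score(phi_x, psi_y, w, b):
    """Projection-style critic score.

    phi_x: image features [B, H, W, C]
    psi_y: projected text features [B, C]
    w, b:  weights and bias of the linear realism head over flattened phi_x
    """
    realism = phi_x.reshape(phi_x.shape[0], -1) @ w + b   # [B]
    img_vec = phi_x.mean(axis=(1, 2))                     # global average pool -> [B, C]
    alignment = np.sum(img_vec * psi_y, axis=1)           # per-sample dot product -> [B]
    return realism + alignment

rng = np.random.default_rng(0)
phi_x = rng.normal(size=(2, 4, 4, 8))   # small stand-in for the 4x4x512 features
psi_y = rng.normal(size=(2, 8))
w, b = rng.normal(size=(4 * 4 * 8,)), 0.1

scores = projection_critic_score(phi_x, psi_y, w, b)
wrong = projection_critic_score(phi_x, np.roll(psi_y, 1, axis=0), w, b)
print(scores.shape)  # (2,)
```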



In [None]:
class Critic(tf.keras.Model):
		"""
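		WGAN-GP critic with a projection-style text term. LayerNorm is used instead
		of BatchNorm because batch statistics would couple samples and interfere
		with the per-sample gradient penalty.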
		"""
		def __init__(self, hparas):
				super(Critic, self).__init__()
				self.hparas = hparas
				init = tf.keras.initializers.HeNormal()
				
				# --- IMAGE PATH ---
				# 64 -> 32
				self.conv1 = tf.keras.layers.Conv2D(64, 4, 2, padding='same', kernel_initializer=init)
				# No normalization after the first conv layer
				
				# 32 -> 16
				self.conv2 = tf.keras.layers.Conv2D(128, 4, 2, padding='same', kernel_initializer=init)
				self.ln2 = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
				
				# 16 -> 8
				self.conv3 = tf.keras.layers.Conv2D(256, 4, 2, padding='same', kernel_initializer=init)
				self.ln3 = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
				
				# 8 -> 4
				self.conv4 = tf.keras.layers.Conv2D(512, 4, 2, padding='same', kernel_initializer=init)
				self.ln4 = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
				
				# --- TEXT PATH ---
				# Project text to match the depth of image features (512 channels)
				self.text_dense = tf.keras.layers.Dense(512, kernel_initializer=tf.keras.initializers.Orthogonal())
				self.text_dense_ln = tf.keras.layers.LayerNormalization(axis=-1)
				
				# --- SCORING ---
				# "Realism" score (f(x))
				self.flatten = tf.keras.layers.Flatten()
				self.disc_realism = tf.keras.layers.Dense(1, kernel_initializer=init)

		def call(self, img, text, training=True):
				# 1. Extract Image Features
				x = tf.nn.leaky_relu(self.conv1(img), alpha=0.2)
				x = tf.nn.leaky_relu(self.ln2(self.conv2(x), training=training), alpha=0.2)
				x = tf.nn.leaky_relu(self.ln3(self.conv3(x), training=training), alpha=0.2)
				x = tf.nn.leaky_relu(self.ln4(self.conv4(x), training=training), alpha=0.2)
				
				# x shape: [Batch, 4, 4, 512]
				
				# 2. Process Text Features (psi(y))
				psi_y = self.text_dense(text) 
				psi_y = self.text_dense_ln(psi_y, training=training)
				psi_y = tf.nn.leaky_relu(psi_y, alpha=0.2)
				
				# 3. PROJECTION SCORE (alignment)
				# Global average pooling of the image features, followed by a dot product
				# with psi(y). (Sum pooling would differ only by a constant factor of 16.)
				img_vec = tf.reduce_mean(x, axis=[1, 2]) # [Batch, 512]
				
				# Dot product
				alignment_score = tf.reduce_sum(img_vec * psi_y, axis=1, keepdims=True)
				
				# 4. REALISM SCORE
				flat_img = self.flatten(x)
				realism_score = self.disc_realism(flat_img)
				
				# 5. TOTAL SCORE
				return realism_score + alignment_score

In [None]:
hparas = {
		'MAX_SEQ_LENGTH': 77,          # CLIP Standard
		'EMBED_DIM': 512,              # CLIP Output
		'VOCAB_SIZE': 49408,           # CLIP tokenizer vocabulary size
		'RNN_HIDDEN_SIZE': 512,        # Text-embedding size (raw CLIP embedding)
		'Z_DIM': 512,
		'DENSE_DIM': 128,
		'IMAGE_SIZE': [64, 64, 3],
		
		'BATCH_SIZE': BATCH_SIZE,      # Reuse the global BATCH_SIZE
		'LR': 2e-4,
		'BETA_1': 0.0,
		'BETA_2': 0.9,
		'N_CRITIC': 5,
		'LAMBDA_GP': 10.0,
		'LAMBDA_MISMATCH': 1.0,
		
		'LR_DECAY_START': 50,
		'LR_DECAY_EVERY': 10,
		'LR_DECAY_FACTOR': 0.95,
		'LR_MIN': 1e-5,
		
		'USE_DIFFAUG': True,
		'DIFFAUG_POLICY': 'translation', # Conservative augmentation
		
		'N_EPOCH': 1000,
		'N_SAMPLE': num_training_sample,
		'PRINT_FREQ': 5
}

print("✓ Hyperparameters updated:")
print(f"  Batch size: {hparas['BATCH_SIZE']}")
print(f"  Learning rate: {hparas['LR']}")
print(f"  N_Critic: {hparas['N_CRITIC']}")
print(f"  Lambda_GP: {hparas['LAMBDA_GP']}")
print(f"  DiffAugment: {hparas['USE_DIFFAUG']} ({hparas['DIFFAUG_POLICY']})")

In [None]:
text_encoder = ClipTextEncoder(output_dim=hparas['RNN_HIDDEN_SIZE'], freeze_clip=True)
generator = Generator(hparas)
critic = Critic(hparas)


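<h2 id="Loss-Function-and-Optimization">Loss Function and Optimization<a class="anchor-link" href="#Loss-Function-and-Optimization">¶</a></h2>
<p>Although the conditional GAN model is quite complex, the loss used to train it is relatively simple. In the classic formulation, the discriminator solves a binary classification task, so cross-entropy is the natural loss. The code below instead trains a Wasserstein critic with gradient penalty (WGAN-GP), plus a hinge loss on mismatched (real image, wrong text) pairs that pushes the critic to check whether image and text actually match.</p>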
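<p>Written out, the objectives implemented by the next cells read as follows, in our notation: $e$ is the text embedding, $\bar{e}$ the embedding of a mismatched caption (the batch rolled by one), $\tilde{x} = G(z, e)$, $\hat{x} = \alpha x + (1 - \alpha)\tilde{x}$ with $\alpha \sim U(0, 1)$, margin $m = 1$, $\lambda_{\mathrm{GP}}$ = <code>LAMBDA_GP</code> and $\lambda_{\mathrm{mis}}$ = <code>LAMBDA_MISMATCH</code>. When DiffAugment is enabled, $x$ and $\tilde{x}$ stand for their augmented versions.</p>

$$
\mathcal{L}_C = \mathbb{E}\big[D(\tilde{x}, e)\big] - \mathbb{E}\big[D(x, e)\big] + \lambda_{\mathrm{GP}}\,\mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}, e) \rVert_2 - 1\big)^2\Big] + \lambda_{\mathrm{mis}}\,\mathbb{E}\Big[\max\big(0,\ m + D(x, \bar{e}) - D(x, e)\big)\Big]
$$

$$
\mathcal{L}_G = -\,\mathbb{E}\big[D(\tilde{x}, e)\big]
$$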



In [None]:
def wasserstein_loss_critic(real_scores, fake_scores):
		"""
		Wasserstein loss for critic
		Critic wants to maximize: E[critic(real)] - E[critic(fake)]
		So we minimize: E[critic(fake)] - E[critic(real)]
		"""
		return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def mismatch_loss_critic(real_scores, wrong_scores, margin=1.0):
		"""
		Hinge loss for GAN-CLS (mismatched pairs).
		Wants: real_scores > wrong_scores + margin
		Minimizes: max(0, margin + wrong_scores - real_scores)
		"""
		loss = tf.nn.relu(margin + wrong_scores - real_scores)
		return tf.reduce_mean(loss)

def wasserstein_loss_generator(fake_scores):
		"""
		Wasserstein loss for generator
		Generator wants to maximize: E[critic(fake)]
		So we minimize: -E[critic(fake)]
		"""
		return -tf.reduce_mean(fake_scores)

def gradient_penalty(critic, real_images, fake_images, text_embed, batch_size, 
                     diffaug_policy=None, aug_params=None):
    """
    Computes Gradient Penalty with CORRECT DiffAugment application.
    
    CORRECTED LOGIC:
    1. Augment Real and Fake images FIRST.
    2. Interpolate between the AUGMENTED images.
    3. Compute gradients w.r.t the INTERPOLATION.
    """
    
    # 1. Apply DiffAugment to RAW images FIRST
    # We essentially treat the augmented images as our "training data" for this step.
    if diffaug_policy is not None and aug_params is not None:
        real_images_used = DiffAugment(real_images, policy=diffaug_policy, params=aug_params)
        fake_images_used = DiffAugment(fake_images, policy=diffaug_policy, params=aug_params)
    else:
        real_images_used = real_images
        fake_images_used = fake_images

    # 2. Interpolate between the AUGMENTED images
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_images_used + (1.0 - alpha) * fake_images_used
    
    with tf.GradientTape() as gp_tape:
        # 3. Watch the INTERPOLATED (Augmented) image
        gp_tape.watch(interpolated)
            
        # 4. Critic Pass (No further augmentation needed)
        interpolated_scores = critic(interpolated, text_embed, training=True)
    
    # 5. Compute Gradients w.r.t INTERPOLATED input
    # Now we are checking the slope of the Critic on the augmented manifold.
    gradients = gp_tape.gradient(interpolated_scores, [interpolated])[0]
    
    # 6. Compute Norm
    gradients_sqr_sum = tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])
    gradients_norm = tf.sqrt(gradients_sqr_sum + 1e-12)
    
    # 7. Penalty
    return tf.reduce_mean(tf.square(gradients_norm - 1.0))
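
# To see what gradient_penalty measures, the NumPy sketch below swaps autodiff for
# central finite differences (a substitution made purely for illustration).
# For a linear critic D(x) = w^T x the gradient is w at every interpolated point,
# so the penalty is exactly (||w||_2 - 1)^2; with ||w||_2 = 2 it equals 1.
# The helper name and toy data are hypothetical.
```python
import numpy as np

def gradient_penalty_numeric(critic, real, fake, rng, h=1e-5):
    """WGAN-GP penalty computed with central finite differences instead of autodiff.

    critic: function mapping one flattened sample [D] to a scalar
    real, fake: arrays of shape [B, D]
    """
    alpha = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))
    interpolated = alpha * real + (1.0 - alpha) * fake
    penalties = []
    for x in interpolated:
        grad = np.zeros_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step[i] = h
            grad[i] = (critic(x + step) - critic(x - step)) / (2 * h)
        penalties.append((np.linalg.norm(grad) - 1.0) ** 2)
    return float(np.mean(penalties))

rng = np.random.default_rng(0)
real, fake = rng.normal(size=(3, 5)), rng.normal(size=(3, 5))
w = np.array([2.0, 0.0, 0.0, 0.0, 0.0])  # linear critic with ||w|| = 2
print(round(gradient_penalty_numeric(lambda x: x @ w, real, fake, rng), 4))  # 1.0
```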

In [None]:
# WGAN-GP: Use Adam with beta_1=0.0, beta_2=0.9
generator_optimizer = tf.keras.optimizers.Adam(
		learning_rate=hparas['LR'],
		beta_1=hparas['BETA_1'],
		beta_2=hparas['BETA_2']
)

critic_optimizer = tf.keras.optimizers.Adam(
		learning_rate=hparas['LR'],
		beta_1=hparas['BETA_1'],
		beta_2=hparas['BETA_2']
)

In [None]:
checkpoint = tf.train.Checkpoint(
		generator_optimizer=generator_optimizer,
		critic_optimizer=critic_optimizer,
		text_encoder=text_encoder,
		generator=generator,
		critic=critic
)

In [None]:
def calculate_wasserstein_distance(real_scores, fake_scores):
		"""
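		Critic-based estimate of the Wasserstein distance: E[D(real)] - E[D(fake)].
		A larger value means the critic separates real from fake more easily;
		as the generator improves, this estimate typically shrinks.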
		"""
		return tf.reduce_mean(real_scores) - tf.reduce_mean(fake_scores)

def calculate_gradient_norm(gradients):
		"""Calculate L2 norm of gradients"""
		squared_norms = [tf.reduce_sum(tf.square(g)) for g in gradients if g is not None]
		total_norm = tf.sqrt(tf.reduce_sum(squared_norms))
		return total_norm

In [None]:
@tf.function
def train_step(real_image, input_ids, attention_mask):
    batch_size = tf.shape(real_image)[0]
    
    # --- Helper to generate consistent Augment Params ---
    def get_aug_params(bs):
        if not hparas['USE_DIFFAUG']: return None
        image_size = 64
        # Must match the default ratios in rand_translation (0.125) and rand_cutout (0.2)
        shift = tf.cast(image_size * 0.125 + 0.5, tf.int32)
        cutout_size = tf.cast(image_size * 0.2 + 0.5, tf.int32)
        
        return {
            'brightness': tf.random.uniform([], -0.5, 0.5),
            'saturation': tf.random.uniform([], 0.0, 2.0),
            'contrast': tf.random.uniform([], 0.5, 1.5),
            'translation_x': tf.random.uniform([bs], -shift, shift + 1, dtype=tf.int32),
            'translation_y': tf.random.uniform([bs], -shift, shift + 1, dtype=tf.int32),
            'cutout_x': tf.random.uniform([bs], 0, image_size - cutout_size + 1, dtype=tf.int32),
            'cutout_y': tf.random.uniform([bs], 0, image_size - cutout_size + 1, dtype=tf.int32),
        }

    # ============================================================
    # 1. Train Critic
    # ============================================================
    # Initialize variables to ensure they exist after the loop
    c_loss = 0.0
    c_loss_w = 0.0
    c_loss_mismatch = 0.0
    gp = 0.0
    
    # The CLIP encoder is frozen and the captions are fixed within this step, so the
    # text embeddings are computed once and reused by all N_CRITIC critic updates.
    text_embed = text_encoder(input_ids, attention_mask, training=False)
    text_embed = tf.math.l2_normalize(text_embed, axis=1)
    text_embed = tf.stop_gradient(text_embed)
    wrong_text_embed = tf.roll(text_embed, shift=1, axis=0)  # mismatched captions

    for _ in range(hparas['N_CRITIC']):
        noise = tf.random.normal([batch_size, hparas['Z_DIM']])
        aug_params = get_aug_params(batch_size)
        
        with tf.GradientTape() as critic_tape:
            
            # --- Generate Fake (RAW) ---
            _, fake_image = generator(text_embed, noise, training=True)
            
            # --- Augment ---
            if hparas['USE_DIFFAUG']:
                real_image_aug = DiffAugment(real_image, policy=hparas['DIFFAUG_POLICY'], params=aug_params)
                fake_image_aug = DiffAugment(fake_image, policy=hparas['DIFFAUG_POLICY'], params=aug_params)
            else:
                real_image_aug = real_image
                fake_image_aug = fake_image
            
            # --- Scores ---
            real_scores = critic(real_image_aug, text_embed, training=True)
            fake_scores = critic(fake_image_aug, text_embed, training=True)
            wrong_scores = critic(real_image_aug, wrong_text_embed, training=True)
            
            # --- Losses ---
            c_loss_w = wasserstein_loss_critic(real_scores, fake_scores)
            c_loss_mismatch = mismatch_loss_critic(real_scores, wrong_scores)
            
            # --- Gradient Penalty (Pass RAW images) ---
            gp = gradient_penalty(
                critic, 
                real_image, 
                fake_image, 
                text_embed, 
                batch_size,
                diffaug_policy=hparas['DIFFAUG_POLICY'] if hparas['USE_DIFFAUG'] else None,
                aug_params=aug_params
            )
            
            c_loss = c_loss_w + hparas['LAMBDA_GP'] * gp + hparas['LAMBDA_MISMATCH'] * c_loss_mismatch
        
        # Apply Gradients
        grad_c = critic_tape.gradient(c_loss, critic.trainable_variables)
        critic_optimizer.apply_gradients(zip(grad_c, critic.trainable_variables))
    
    # ============================================================
    # 2. Train Generator
    # ============================================================
    noise = tf.random.normal([batch_size, hparas['Z_DIM']])
    gen_aug_params = get_aug_params(batch_size)
    
    with tf.GradientTape() as gen_tape:
        text_embed = text_encoder(input_ids, attention_mask, training=False)
        text_embed = tf.math.l2_normalize(text_embed, axis=1)
        text_embed = tf.stop_gradient(text_embed)
        
        _, fake_image = generator(text_embed, noise, training=True)
        
        if hparas['USE_DIFFAUG']:
            fake_image_aug = DiffAugment(fake_image, policy=hparas['DIFFAUG_POLICY'], params=gen_aug_params)
        else:
            fake_image_aug = fake_image
            
        gen_fake_scores = critic(fake_image_aug, text_embed, training=True)
        g_loss = wasserstein_loss_generator(gen_fake_scores)

    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    
    # ============================================================
    # 3. Calculate Metrics
    # ============================================================
    # real_scores/fake_scores come from the same (last) critic iteration, so the
    # estimate is not mixed with the generator step's scores on a new fake batch
    wasserstein_dist = calculate_wasserstein_distance(real_scores, fake_scores)
    
    # Calculate Gradient Norms for monitoring stability
    # Filter out None gradients (frozen layers)
    valid_grads_c = [g for g in grad_c if g is not None]
    valid_grads_g = [g for g in grad_g if g is not None]
    
    grad_norm_c = tf.linalg.global_norm(valid_grads_c)
    grad_norm_g = tf.linalg.global_norm(valid_grads_g)
    
    # Return EVERYTHING needed for the train loop logging
    return {
        'g_loss': g_loss,
        'c_loss': c_loss,
        'c_loss_wasserstein': c_loss_w,
        'c_loss_mismatch': c_loss_mismatch,
        'gp': gp,
        'wasserstein_dist': wasserstein_dist,
        'grad_norm_g': grad_norm_g,
        'grad_norm_c': grad_norm_c,
        'lr_g': generator_optimizer.learning_rate,
        'lr_c': critic_optimizer.learning_rate
    }
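<p>The critic update above adds a gradient penalty to the Wasserstein term. The sketch below shows the standard WGAN-GP penalty in NumPy under simplifying assumptions: a toy, unconditioned critic <code>D(x) = tanh(x · w)</code> whose input gradient is computed analytically, and the common sign convention <code>mean(D(fake)) - mean(D(real))</code> for the critic's Wasserstein term. It is not the notebook's <code>gradient_penalty</code> or <code>wasserstein_loss_critic</code> (their bodies are not shown in this section), it omits the text conditioning and the mismatch term, and <code>lambda_gp=10.0</code> is only a commonly used value, not this notebook's <code>LAMBDA_GP</code>.</p>

```python
import numpy as np

def toy_critic(x, w):
    """Toy critic D(x) = tanh(x . w) for a batch x of shape (batch, dim)."""
    return np.tanh(x @ w)

def toy_critic_input_grad(x, w):
    """Analytic input gradient dD/dx = (1 - tanh(x . w)**2) * w, one row per sample."""
    s = np.tanh(x @ w)
    return (1.0 - s ** 2)[:, None] * w[None, :]

def gradient_penalty_sketch(real, fake, w, rng):
    """WGAN-GP penalty: mean over samples of (||dD/dx at x_hat||_2 - 1)**2,
    where x_hat is a random point on the line between a real and a fake sample."""
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))  # one mixing weight per sample
    x_hat = eps * real + (1.0 - eps) * fake
    grad_norms = np.linalg.norm(toy_critic_input_grad(x_hat, w), axis=1)
    return float(np.mean((grad_norms - 1.0) ** 2))

def critic_loss_sketch(real, fake, w, rng, lambda_gp=10.0):
    """Critic loss to minimize: mean(D(fake)) - mean(D(real)) + lambda_gp * penalty."""
    w_term = np.mean(toy_critic(fake, w)) - np.mean(toy_critic(real, w))
    return float(w_term + lambda_gp * gradient_penalty_sketch(real, fake, w, rng))

rng = np.random.default_rng(0)
real = rng.normal(size=(8, 4))
fake = rng.normal(size=(8, 4))
print(critic_loss_sketch(real, fake, np.array([0.5, -0.5, 0.5, -0.5]), rng))
```

<p>Two sanity checks follow from the formula: a zero weight vector gives a zero gradient everywhere, so the penalty is exactly 1; a unit-norm <code>w</code> with all interpolates orthogonal to it gives a gradient norm of 1 and a penalty of 0.</p>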

In [None]:
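@tf.function
def test_step(input_ids, attention_mask, noise):
    # Encode captions with the frozen CLIP text encoder
    text_embed = text_encoder(input_ids, attention_mask, training=False)
    # L2-normalize exactly as in train_step so the generator sees the same kind of input
    text_embed = tf.math.l2_normalize(text_embed, axis=1)
    _, fake_image = generator(text_embed, noise, training=False)
    return fake_image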
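<p><code>train_step</code> L2-normalizes the text embedding before passing it to the generator, so inference should feed the generator embeddings normalized the same way. A NumPy mirror of the formula <code>tf.math.l2_normalize</code> uses, <code>x / sqrt(max(sum(x**2), epsilon))</code>, shows that normalization is idempotent: normalizing an embedding that is already unit-length changes nothing. The helper name <code>l2_normalize_np</code> is hypothetical.</p>

```python
import numpy as np

def l2_normalize_np(x, axis=1, epsilon=1e-12):
    """Scale each row to unit L2 norm: x / sqrt(max(sum(x**2), epsilon)).
    The epsilon floor keeps all-zero rows at zero instead of dividing by zero."""
    sq_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq_sum, epsilon))

emb = np.array([[3.0, 4.0], [1.0, 0.0]])
once = l2_normalize_np(emb)
twice = l2_normalize_np(once)   # normalizing again leaves the rows unchanged
print(once)                     # → rows [0.6, 0.8] and [1.0, 0.0]
```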


<h2 id="Visualiztion">Visualization<a class="anchor-link" href="#Visualiztion">¶</a></h2>
<p>During training, we visualize generated images to evaluate the quality of the generator. The following functions help with visualization.</p>



In [None]:
def merge(images, size):
		h, w = images.shape[1], images.shape[2]
		img = np.zeros((h * size[0], w * size[1], 3))
		for idx, image in enumerate(images):
				i = idx % size[1]
				j = idx // size[1]
				img[j*h:j*h+h, i*w:i*w+w, :] = image
		return img

def imsave(images, size, path):
		# Map generator outputs from [-1, 1] to [0, 1] (clipping any overshoot) before saving
		return plt.imsave(path, np.clip(merge(images, size)*0.5 + 0.5, 0.0, 1.0))

def save_images(images, size, image_path):
		return imsave(images, size, image_path)
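<p><code>merge</code> places image <code>idx</code> at row <code>idx // size[1]</code> and column <code>idx % size[1]</code>, and <code>imsave</code> then maps generator outputs from [-1, 1] to [0, 1]. A vectorized sketch of the same row-major layout using <code>reshape</code>/<code>transpose</code> (the function name <code>merge_grid</code> is hypothetical) makes the tiling explicit; unused grid cells stay black, as in <code>merge</code>.</p>

```python
import numpy as np

def merge_grid(images, rows, cols):
    """Tile a batch of shape (N, H, W, C) into a (rows*H, cols*W, C) image in
    row-major order; cells beyond N stay zero (black)."""
    n, h, w, c = images.shape
    if n > rows * cols:
        raise ValueError(f'{n} images do not fit in a {rows}x{cols} grid')
    padded = np.zeros((rows * cols, h, w, c), dtype=images.dtype)
    padded[:n] = images
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = padded.reshape(rows, cols, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, c)

# Example: 32 generated images in [-1, 1] on a 6x6 grid, rescaled to [0, 1] for saving
batch = np.random.default_rng(0).uniform(-1.0, 1.0, size=(32, 64, 64, 3))
canvas = np.clip(merge_grid(batch, 6, 6) * 0.5 + 0.5, 0.0, 1.0)
print(canvas.shape)  # → (384, 384, 3)
```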



<p>We use the same random seed and the same sentences throughout training, which makes it easier to compare the quality of generated images across epochs.</p>



In [None]:
# Create sample data for visualization during training
# IMPORTANT: All three variables must have the same batch size!

sample_size = BATCH_SIZE  # Current: 32
ni = int(np.ceil(np.sqrt(sample_size)))  # Grid size for visualization

# Create random noise seed
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)

# Define 8 diverse sample sentences
base_sentences = [
		"the flower shown has yellow anther red pistil and bright red petals.",
		"this flower has petals that are yellow, white and purple and has dark lines",
		"the petals on this flower are white with a yellow center",
		"this flower has a lot of small round pink petals.",
		"this flower is orange in color, and has petals that are ruffled and rounded.",
		"the flower has yellow petals and the center of it is brown.",
		"this flower has petals that are blue and white.",
		"these white flowers have petals that start off white in color and end in a white towards the tips."
]

# Repeat sentences to match sample_size (batch size)
sample_sentences = []
for i in range(sample_size):
		sample_sentences.append(base_sentences[i % len(base_sentences)])

# Tokenize with CLIP
sample_encoded = preprocess_text_clip(sample_sentences, max_length=77)
sample_input_ids = sample_encoded['input_ids']
sample_attention_mask = sample_encoded['attention_mask']

# Verify all dimensions match!
print(f"✓ Sample data created:")
print(f"  Batch size: {sample_size}")
print(f"  Grid size (ni): {ni} × {ni} = {ni*ni}")
print(f"  Sample sentences: {len(sample_sentences)} sentences")
print(f"  sample_seed shape: {sample_seed.shape}")
print(f"  sample_input_ids shape: {sample_input_ids.shape}")
print(f"  sample_attention_mask shape: {sample_attention_mask.shape}")

# Check for dimension mismatches
assert len(sample_sentences) == sample_size, f"Mismatch: {len(sample_sentences)} != {sample_size}"
assert sample_seed.shape[0] == sample_size, f"Mismatch: {sample_seed.shape[0]} != {sample_size}"
assert sample_input_ids.shape[0] == sample_size, f"Mismatch: {sample_input_ids.shape[0]} != {sample_size}"
assert sample_attention_mask.shape[0] == sample_size, f"Mismatch: {sample_attention_mask.shape[0]} != {sample_size}"
print("✓ All dimensions match!")
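<p>The cell above repeats the 8 base sentences to fill the batch and draws <code>sample_seed</code> once, so every visualization during a session uses identical inputs. Because <code>np.random.normal</code> is called without an explicit seed, the noise will differ after a kernel restart (for example, when resuming a run) unless NumPy's global seed is set elsewhere in the notebook. A sketch with a seeded generator, assuming a hypothetical helper <code>make_fixed_samples</code> and a hypothetical <code>Z_DIM</code> of 100:</p>

```python
import itertools
import numpy as np

def make_fixed_samples(base_sentences, sample_size, z_dim, seed=0):
    """Cycle base_sentences up to sample_size, draw seeded latent noise, and
    compute the side length of the square grid used for visualization."""
    sentences = list(itertools.islice(itertools.cycle(base_sentences), sample_size))
    rng = np.random.default_rng(seed)      # independent of NumPy's global RNG state
    noise = rng.standard_normal((sample_size, z_dim)).astype(np.float32)
    grid = int(np.ceil(np.sqrt(sample_size)))
    return sentences, noise, grid

base = ["this flower has petals that are blue and white.",
        "the petals on this flower are white with a yellow center"]
s1, z1, ni_demo = make_fixed_samples(base, 32, 100, seed=42)
s2, z2, _ = make_fixed_samples(base, 32, 100, seed=42)
print(len(s1), z1.shape, ni_demo)   # → 32 (32, 100) 6
```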

In [None]:
# Helper functions for managing training runs
import glob
import json

def list_available_runs():
		"""List all available training runs with their details"""
		run_dirs = sorted(glob.glob('runs/*/'))
		
		if not run_dirs:
				print('No training runs found in runs/ directory')
				return []
		
		print('=' * 80)
		print('Available Training Runs:')
		print('=' * 80)
		
		available_runs = []
		for run_dir in run_dirs:
				timestamp = run_dir.split('/')[-2]
				available_runs.append(timestamp)
				
				# Check for checkpoints
				checkpoint_dir = f'{run_dir}checkpoints'
				latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
				has_checkpoint = '✓' if latest_ckpt else '✗'
				
				# Check for config
				config_path = f'{run_dir}config.json'
				has_config = '✓' if os.path.exists(config_path) else '✗'
				
				# Count sample images
				sample_count = len(glob.glob(f'{run_dir}samples/*.jpg'))
				
				print(f'{timestamp}  |  Checkpoint: {has_checkpoint}  |  Config: {has_config}  |  Samples: {sample_count}')
				
				if latest_ckpt:
						print(f'  └─ Latest checkpoint: {latest_ckpt}')
		
		print('=' * 80)
		return available_runs

def load_run_config(run_timestamp):
		"""Load configuration from a previous run"""
		config_path = f'runs/{run_timestamp}/config.json'
		
		if not os.path.exists(config_path):
				raise FileNotFoundError(f'Config not found: {config_path}')
		
		with open(config_path, 'r') as f:
				config = json.load(f)
		
		print(f'✓ Loaded config from: {config_path}')
		return config

# List available runs
print('Checking for existing training runs...')
list_available_runs()
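<p><code>list_available_runs</code> extracts the timestamp with <code>run_dir.split('/')[-2]</code>, which relies on <code>glob</code> returning forward slashes and a trailing separator. A sketch that lists run directories through <code>os.listdir</code>/<code>os.path</code> instead (the helper names are hypothetical) also shows why plain string sorting works: names in the <code>'%Y%m%d-%H%M%S'</code> format sort chronologically.</p>

```python
import os
import tempfile
from datetime import datetime

RUN_FORMAT = '%Y%m%d-%H%M%S'   # same format used to name new runs

def list_run_names(runs_root):
    """Names of the subdirectories of runs_root, oldest first.
    Zero-padded '%Y%m%d-%H%M%S' names sort chronologically as plain strings."""
    names = [d for d in os.listdir(runs_root)
             if os.path.isdir(os.path.join(runs_root, d))]
    return sorted(names)

def run_start_time(name):
    """Parse a run directory name back into a datetime."""
    return datetime.strptime(name, RUN_FORMAT)

with tempfile.TemporaryDirectory() as root:
    for name in ['20251116-225453', '20250101-000000', '20251116-090000']:
        os.makedirs(os.path.join(root, name, 'checkpoints'))
    open(os.path.join(root, 'notes.txt'), 'w').close()   # stray files are ignored
    runs = list_run_names(root)

print(runs)  # → ['20250101-000000', '20251116-090000', '20251116-225453']
```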


<h2 id="Training">Training<a class="anchor-link" href="#Training">¶</a></h2>



In [None]:
from datetime import datetime

# ============================================================
# RESUME TRAINING CONFIGURATION
# ============================================================
# Set to None for new run, or specify run timestamp to resume
# Example: RESUME_RUN = '20251116-225453'
RESUME_RUN = None # ‚Üê Change this to resume from specific run

# ============================================================
# RUN DIRECTORY SETUP
# ============================================================
if RESUME_RUN:
		# Resume from existing run
		run_dir = f'runs/{RESUME_RUN}'
		
		# Verify directory exists
		if not os.path.exists(run_dir):
				raise FileNotFoundError(f'Run directory not found: {run_dir}')
		
		# Load existing config
		try:
				prev_config = load_run_config(RESUME_RUN)
				run_timestamp = prev_config.get('run_timestamp', RESUME_RUN)
				print(f'\n⟳ RESUMING training from: {run_dir}')
				print(f'  Original start: {run_timestamp}')
				
				# Warn if hyperparameters might be different
				if 'hyperparameters' in prev_config:
						prev_hparas = prev_config['hyperparameters']
						if prev_hparas.get('BATCH_SIZE') != BATCH_SIZE:
								print(f'  ⚠ WARNING: Batch size changed ({prev_hparas.get("BATCH_SIZE")} → {BATCH_SIZE})')
						if prev_hparas.get('LR') != hparas['LR']:
								print(f'  ⚠ WARNING: Learning rate changed ({prev_hparas.get("LR")} → {hparas["LR"]})')
		except Exception as e:
				print(f'⚠ Could not load previous config: {e}')
				run_timestamp = RESUME_RUN
		
		# Use existing subdirectories
		checkpoint_dir = f'{run_dir}/checkpoints'
		best_models_dir = f'{run_dir}/best_models'
		samples_dir = f'{run_dir}/samples'
		inference_dir = f'{run_dir}/inference'
		
		# Create directories if they don't exist (shouldn't happen, but safety check)
		os.makedirs(checkpoint_dir, exist_ok=True)
		os.makedirs(best_models_dir, exist_ok=True)
		os.makedirs(samples_dir, exist_ok=True)
		os.makedirs(inference_dir, exist_ok=True)
		
else:
		# Create new run
		run_timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
		run_dir = f'runs/{run_timestamp}'
		
		# All outputs for this run go in subdirectories
		checkpoint_dir = f'{run_dir}/checkpoints'
		best_models_dir = f'{run_dir}/best_models'
		samples_dir = f'{run_dir}/samples'
		inference_dir = f'{run_dir}/inference'
		
		# Create all directories
		os.makedirs(checkpoint_dir, exist_ok=True)
		os.makedirs(best_models_dir, exist_ok=True)
		os.makedirs(samples_dir, exist_ok=True)
		os.makedirs(inference_dir, exist_ok=True)
		
		print(f'✓ Created NEW run directory: {run_dir}')
		
		# Save hyperparameters
		config_to_save = {
				'run_timestamp': run_timestamp,
				'hyperparameters': hparas,
		}
		with open(f'{run_dir}/config.json', 'w') as f:
				json.dump(config_to_save, f, indent=4)
		print(f'✓ Saved configuration to: {run_dir}/config.json')


# Display directory structure
print(f'\nRun directory structure:')
print(f'  {run_dir}/')
print(f'  ├── checkpoints/  : {checkpoint_dir}')
print(f'  ├── best_models/  : {best_models_dir}')
print(f'  ├── samples/      : {samples_dir}')
print(f'  └── inference/    : {inference_dir}')
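<p>The new-run branch writes <code>hparas</code> to <code>config.json</code> with <code>json.dump</code>, which raises <code>TypeError</code> if any value is a NumPy scalar such as <code>np.float32</code>. A hedged sketch of a <code>default=</code> fallback that converts NumPy values to plain Python types; the helper name <code>to_jsonable</code> and the hyperparameter values below are hypothetical:</p>

```python
import json
import numpy as np

def to_jsonable(obj):
    """json.dump fallback: turn NumPy scalars and arrays into plain Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

# Hypothetical hyperparameters containing NumPy scalars
hparas_example = {'LR': np.float32(2e-4), 'Z_DIM': 100, 'BATCH_SIZE': np.int64(32)}
config_text = json.dumps({'run_timestamp': '20251116-225453',
                          'hyperparameters': hparas_example},
                         indent=4, default=to_jsonable)
restored = json.loads(config_text)
print(restored['hyperparameters']['BATCH_SIZE'])  # → 32
```

<p>In the cell above, the same idea would be <code>json.dump(config_to_save, f, indent=4, default=to_jsonable)</code>; the round trip also keeps the resume branch's hyperparameter comparison working on plain Python values.</p>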

In [None]:
# ============================================================
# RESTORE CHECKPOINT FOR RESUMING TRAINING
# ============================================================
if RESUME_RUN:
		# When resuming, restore the LATEST regular checkpoint (not best model)
		# This ensures training continues from where it left off
		print(f'\nRestoring checkpoint for resuming training...')
		print(f'Looking for latest checkpoint in: {checkpoint_dir}')
		
		latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
		if latest_checkpoint:
				checkpoint.restore(latest_checkpoint).expect_partial()
				ckpt_num = latest_checkpoint.split('-')[-1]
				print(f'✓ Restored latest checkpoint: {latest_checkpoint}')
				print(f'  Checkpoint number: {ckpt_num}')
				print(f'  Training will continue from this point')
				
				# Also check if best model exists
				best_checkpoint = tf.train.latest_checkpoint(best_models_dir)
				if best_checkpoint:
						print(f'\n✓ Best model also available at: {best_checkpoint}')
		else:
				print('⚠ No checkpoint found in the run directory')
				print('  Training will start from epoch 1 (this is unusual for RESUME mode)')
else:
		print('\n✓ Starting NEW training run - no checkpoint restoration needed')


In [None]:
import subprocess
import sys
import time
def train(dataset, epochs):
		global run_dir, checkpoint_dir, best_models_dir, samples_dir, inference_dir
		
		checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
		best_model_prefix = os.path.join(best_models_dir, "best_ckpt")
		
		log_dir = f'{run_dir}/logs'
		os.makedirs(log_dir, exist_ok=True)
		summary_writer = tf.summary.create_file_writer(log_dir)
	
		print(f"Run tensorboard --logdir {log_dir}")
		print(f'Model: WGAN-GP with {hparas["N_CRITIC"]} critic iterations')
		print(f'DiffAugment: {hparas["USE_DIFFAUG"]} ({hparas.get("DIFFAUG_POLICY", "N/A")})')
		
		steps_per_epoch = int(hparas['N_SAMPLE']/BATCH_SIZE)
		global_step = 0
		

		try:
				# Using sys.executable ensures we use tensorboard from the correct python env
				tensorboard_process = subprocess.Popen([
						sys.executable, "-m", "tensorboard.main", "--logdir", log_dir
				])
				print(f"✓ TensorBoard launched as a background process (PID: {tensorboard_process.pid}).")
				print("  It might take a few seconds to become available in your browser.")
		except Exception as e:
				print(f"⚠ Could not start TensorBoard automatically: {e}")
				print(f"  You can start it manually by running: tensorboard --logdir {log_dir}")

		# ========== EARLY STOPPING SETUP ==========
		best_wasserstein_dist = float('inf')
		patience = 300
		patience_counter = 0
		# ==========================================
		
		for epoch in range(epochs):
				# ========== LEARNING RATE DECAY ==========
				if epoch >= hparas['LR_DECAY_START'] and epoch % hparas['LR_DECAY_EVERY'] == 0:
						current_lr_g = generator_optimizer.learning_rate.numpy()
						current_lr_c = critic_optimizer.learning_rate.numpy()
						
						new_lr_g = max(current_lr_g * hparas['LR_DECAY_FACTOR'], hparas['LR_MIN'])
						new_lr_c = max(current_lr_c * hparas['LR_DECAY_FACTOR'], hparas['LR_MIN'])
						
						generator_optimizer.learning_rate.assign(new_lr_g)
						critic_optimizer.learning_rate.assign(new_lr_c)
						
						print(f'  📉 LR Decay: G={new_lr_g:.2e}, C={new_lr_c:.2e}')
				# ==========================================
				
				g_total_loss = 0
				c_total_loss = 0
				c_total_loss_wasserstein = 0
				gp_total = 0
				wd_total = 0
				start = time.time()
				
				pbar = tqdm(dataset, desc=f'Epoch {epoch+1}/{epochs}', 
									 total=steps_per_epoch, unit='batch')
				
				for batch_idx, (image, input_ids, attention_mask) in enumerate(pbar):
						metrics = train_step(image, input_ids, attention_mask)
						
						# Accumulate losses
						g_total_loss += metrics['g_loss']
						c_total_loss += metrics['c_loss']
						c_total_loss_wasserstein += metrics['c_loss_wasserstein']
						gp_total += metrics['gp']
						wd_total += metrics['wasserstein_dist']
						
						# Log to TensorBoard
						with summary_writer.as_default():
								tf.summary.scalar('Losses/generator_loss', metrics['g_loss'], step=global_step)
								tf.summary.scalar('Losses/critic_loss_total', metrics['c_loss'], step=global_step)
								tf.summary.scalar('Losses/critic_loss_wasserstein', metrics['c_loss_wasserstein'], step=global_step)
								tf.summary.scalar('Losses/critic_loss_mismatch', metrics['c_loss_mismatch'], step=global_step)
								tf.summary.scalar('Losses/gradient_penalty', metrics['gp'], step=global_step)
								tf.summary.scalar('Metrics/wasserstein_distance', metrics['wasserstein_dist'], step=global_step)
								
								if global_step % 50 == 0:
										tf.summary.scalar('Gradients/generator_gradient_norm', metrics['grad_norm_g'], step=global_step)
										tf.summary.scalar('Gradients/critic_gradient_norm', metrics['grad_norm_c'], step=global_step)
										# ========== LOG LEARNING RATES ==========
										tf.summary.scalar('Training/learning_rate_generator', 
																		generator_optimizer.learning_rate.numpy(), step=global_step)
										tf.summary.scalar('Training/learning_rate_critic', 
																		critic_optimizer.learning_rate.numpy(), step=global_step)
										# ========================================
						
						# Update progress bar
						pbar.set_postfix({
								'G_loss': f'{metrics["g_loss"]:.4f}',
								'C_loss': f'{metrics["c_loss"]:.4f}',
								'W_dist': f'{metrics["wasserstein_dist"]:.4f}'
						})
						
						global_step += 1
				
				pbar.close()
				
				# Print epoch summary (averaged over the batches actually processed this epoch)
				n_batches = max(pbar.n, 1)
				avg_g_loss = g_total_loss / n_batches
				avg_c_loss = c_total_loss / n_batches
				avg_c_loss_w = c_total_loss_wasserstein / n_batches
				avg_gp = gp_total / n_batches
				avg_wd = wd_total / n_batches
				epoch_time = time.time() - start
				
				print(f'Epoch {epoch+1}: G_loss={avg_g_loss:.4f}, C_loss={avg_c_loss:.4f} ' +
							f'(W={avg_c_loss_w:.4f}, GP={avg_gp:.4f}), W_dist={avg_wd:.4f}, Time={epoch_time:.2f}s')
				
				# Log epoch averages
				with summary_writer.as_default():
						tf.summary.scalar('Epoch/generator_loss_avg', avg_g_loss, step=epoch)
						tf.summary.scalar('Epoch/critic_loss_avg', avg_c_loss, step=epoch)
						tf.summary.scalar('Epoch/wasserstein_distance_avg', avg_wd, step=epoch)
				
				# ========== EARLY STOPPING CHECK ==========
				if avg_wd < best_wasserstein_dist:  # a lower epoch-average W-distance estimate counts as an improvement
						best_wasserstein_dist = avg_wd
						patience_counter = 0
						# Save best model in separate directory
						best_path = checkpoint.save(file_prefix=best_model_prefix)
						print(f'  ⭐ Best model saved! W_dist={avg_wd:.4f} → {best_path}')
				else:
						patience_counter += 1
						
				if patience_counter >= patience:
						print(f'\n⚠️ Early stopping triggered at epoch {epoch+1}')
						print(f'   Best Wasserstein distance: {best_wasserstein_dist:.4f}')
						print(f'   No improvement for {patience} epochs')
						break
				# ==========================================
				
				# Save a regular checkpoint every 5 epochs
				if (epoch + 1) % 5 == 0:
						saved_path = checkpoint.save(file_prefix=checkpoint_prefix)
						print(f'  ✓ Checkpoint saved: {saved_path}')
				
				# Visualization
				if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
						fake_image = test_step(sample_input_ids, sample_attention_mask, sample_seed)
						save_images(fake_image, [ni, ni], f'{samples_dir}/train_{epoch+1:03d}.jpg')
						
						with summary_writer.as_default():
								display_images = (fake_image + 1.0) / 2.0
								tf.summary.image('Generated_Samples', display_images, step=epoch, max_outputs=16)
						
						print(f'  ✓ Sample image saved and logged to TensorBoard')
		
		print('\n✓ Training completed!')
		print(f'All outputs saved to: {run_dir}')
		print(f'Best Wasserstein distance achieved: {best_wasserstein_dist:.4f}')
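<p>Inside <code>train</code>, the learning rates decay in steps: at the start of every epoch <code>e</code> with <code>e &gt;= LR_DECAY_START</code> and <code>e % LR_DECAY_EVERY == 0</code>, each rate is multiplied by <code>LR_DECAY_FACTOR</code> and floored at <code>LR_MIN</code>. The first decay therefore happens at the first multiple of <code>LR_DECAY_EVERY</code> that is at least <code>LR_DECAY_START</code>, not necessarily at <code>LR_DECAY_START</code> itself. A pure-Python replay of that rule (the function name and example settings are hypothetical):</p>

```python
def lr_at_epoch(lr0, epoch, decay_start, decay_every, factor, lr_min):
    """Learning rate in effect during 0-based `epoch` under the step rule in train():
    at the start of each epoch e with e >= decay_start and e % decay_every == 0,
    lr <- max(lr * factor, lr_min)."""
    lr = lr0
    for e in range(epoch + 1):
        if e >= decay_start and e % decay_every == 0:
            lr = max(lr * factor, lr_min)
    return lr

# Hypothetical settings: start at 120, decay every 50 epochs by 0.5, floor at 1e-6
for ep in (119, 149, 150, 199, 200):
    print(ep, lr_at_epoch(1e-4, ep, 120, 50, 0.5, 1e-6))
```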

In [None]:
train(dataset, hparas['N_EPOCH'])
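<p>The early-stopping logic in <code>train</code> keeps the lowest epoch-averaged Wasserstein distance seen so far, saves a best-model checkpoint whenever it improves, and stops after <code>patience</code> consecutive epochs without improvement. The replay below (the function name is hypothetical) reproduces that counter on a list of per-epoch values. Keep in mind that the Wasserstein estimate comes from a critic that keeps changing during training, so it can be a noisy proxy for sample quality.</p>

```python
def replay_early_stopping(epoch_values, patience):
    """Replay train()'s patience rule: a strictly lower value resets the counter
    (and would trigger a best-model save); otherwise the counter grows, and training
    stops once it reaches `patience`. Returns (stop_epoch or None, best_value, best_epoch)."""
    best_value, best_epoch, counter = float('inf'), None, 0
    for epoch, value in enumerate(epoch_values):
        if value < best_value:
            best_value, best_epoch, counter = value, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return epoch, best_value, best_epoch
    return None, best_value, best_epoch

print(replay_early_stopping([5.0, 4.0, 3.0, 3.5, 3.2, 3.1], patience=3))  # → (5, 3.0, 2)
print(replay_early_stopping([3.0, 2.0, 1.0], patience=2))                 # → (None, 1.0, 2)
```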



<h1><center class="subtitle">Evaluation</center></h1>

<p><code>dataset/testData.pkl</code> is a pickled pandas DataFrame containing the test captions, with columns 'ID' and 'Captions'.</p>

<ul>
<li>'ID': caption ID, used to name the generated image.</li>
<li>'Captions': the caption used as the generation condition, stored as a sequence of word IDs (decoded with <code>id2word_dict</code>).</li>
</ul>

<p>For each caption, you need to generate <strong>inference_ID.png</strong> so that the quality of the generated image can be evaluated. You must name the generated images in this format; otherwise they cannot be evaluated.</p>
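<p>The inference cell in this notebook writes files with the zero-padded pattern <code>inference_{idx:04d}.jpg</code>, while the text above names the expected files <code>inference_ID.png</code>; check which extension the evaluation script expects. A small round-trip sketch (hypothetical helper names) with the extension left as a parameter:</p>

```python
import re

def inference_filename(test_id, ext='.jpg'):
    """Zero-padded output name for one test caption, e.g. ID 7 -> 'inference_0007.jpg'."""
    return f'inference_{int(test_id):04d}{ext}'

def parse_inference_id(filename):
    """Recover the integer test ID from a name built by inference_filename()."""
    match = re.fullmatch(r'inference_(\d+)\.(?:jpg|png)', filename)
    if match is None:
        raise ValueError(f'unexpected file name: {filename}')
    return int(match.group(1))

print(inference_filename(7))                     # → inference_0007.jpg
print(inference_filename(7, ext='.png'))         # → inference_0007.png
print(parse_inference_id('inference_0819.jpg'))  # → 819
```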




<h2 id="Testing-Dataset">Testing Dataset<a class="anchor-link" href="#Testing-Dataset">¶</a></h2>
<p>If you change anything in the preprocessing of the training dataset, you must apply the same operations to the testing dataset.</p>



In [None]:
def testing_data_generator(caption_text, index):
		"""
		Updated testing data generator using CLIP tokenization
		
		Args:
				caption_text: Raw text string
				index: Test sample ID
		
		Returns:
				input_ids, attention_mask, index
		"""
		def tokenize_caption_clip(text):
				"""Python function to tokenize text using CLIP tokenizer"""
				# Convert EagerTensor to bytes, then decode to string
				text = text.numpy().decode('utf-8')
				
				# Tokenize using CLIP
				encoded = tokenizer(
						text,
						padding='max_length',
						truncation=True,
						max_length=77,
						return_tensors='np'
				)
				
				return encoded['input_ids'][0], encoded['attention_mask'][0]
		
		# Use tf.py_function to call Python tokenizer
		input_ids, attention_mask = tf.py_function(
				func=tokenize_caption_clip,
				inp=[caption_text],
				Tout=[tf.int32, tf.int32]
		)
		
		# Set shapes explicitly
		input_ids.set_shape([77])
		attention_mask.set_shape([77])
		
		return input_ids, attention_mask, index

def testing_dataset_generator(batch_size, data_generator):
		"""
		Updated testing dataset generator - decodes IDs to raw text
		"""
		data = pd.read_pickle('./dataset/testData.pkl')
		captions_ids = data['Captions'].values
		caption_texts = []
		
		# Decode pre-tokenized IDs back to text
		for i in range(len(captions_ids)):
				chosen_caption_ids = captions_ids[i]
				
				# Decode IDs back to text using id2word_dict
				words = []
				for word_id in chosen_caption_ids:
						word = id2word_dict[str(word_id)]
						if word != '<PAD>':  # Skip padding tokens
								words.append(word)
				
				caption_text = ' '.join(words)
				caption_texts.append(caption_text)
		
		index = data['ID'].values
		index = np.asarray(index)
		
		# Create dataset from raw text
		dataset = tf.data.Dataset.from_tensor_slices((caption_texts, index))
		dataset = dataset.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
		dataset = dataset.repeat().batch(batch_size)
		
		return dataset
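<p><code>testing_dataset_generator</code> rebuilds raw caption text from the stored word IDs (the keys of <code>id2word_dict</code> are strings) and drops <code>&lt;PAD&gt;</code> tokens, so that the CLIP tokenizer can be applied to text rather than to IDs. A standalone sketch of that decoding step with a hypothetical toy vocabulary:</p>

```python
def decode_caption(word_ids, id2word, pad_token='<PAD>'):
    """Map word IDs back to words (string keys, as in id2word_dict) and drop padding."""
    words = [id2word[str(word_id)] for word_id in word_ids]
    return ' '.join(word for word in words if word != pad_token)

# Hypothetical toy vocabulary; the real mapping comes from id2word_dict
toy_id2word = {'0': '<PAD>', '1': 'this', '2': 'flower', '3': 'is', '4': 'pink'}
print(decode_caption([1, 2, 3, 4, 0, 0], toy_id2word))  # → this flower is pink
```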

In [None]:
testing_dataset = testing_dataset_generator(BATCH_SIZE, testing_data_generator)


In [None]:
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = int(NUM_TEST / BATCH_SIZE)
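<p><code>EPOCH_TEST = int(NUM_TEST / BATCH_SIZE)</code> rounds down, whereas covering every test caption once needs the ceiling of that ratio. Because <code>testing_dataset</code> is built with <code>.repeat()</code>, the unused slots in the final batch wrap around to the first captions. A quick check of the arithmetic (hypothetical helper name), using the batch size of 32 noted earlier and 39 for comparison:</p>

```python
import math

def batch_plan(num_items, batch_size):
    """Batches needed to see every item once, plus how many slots in the final
    batch would be filled by wrap-around items from a repeated dataset."""
    num_batches = math.ceil(num_items / batch_size)
    wrap_around = num_batches * batch_size - num_items
    return num_batches, wrap_around

print(batch_plan(819, 32))  # → (26, 13)
print(batch_plan(819, 39))  # → (21, 0)
print(int(819 / 32))        # → 25, what int(NUM_TEST / BATCH_SIZE) gives for 32
```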



<h2 id="Inferece">Inference<a class="anchor-link" href="#Inferece">¶</a></h2>



In [None]:
# The inference directory was already created in the run-directory setup cell
# (along with checkpoints/, best_models/ and samples/), so nothing to do here

In [None]:
# Restore BEST MODEL for inference
print(f'Looking for BEST model in: {best_models_dir}')

best_checkpoint = tf.train.latest_checkpoint(best_models_dir)
if best_checkpoint:
		checkpoint.restore(best_checkpoint)
		print(f'✓ Restored BEST model: {best_checkpoint}')
		print(f'  This is the checkpoint with the lowest epoch-averaged Wasserstein distance estimate during training')
else:
		print('⚠ No best model found, trying regular checkpoints...')
		latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
		if latest_checkpoint:
				checkpoint.restore(latest_checkpoint)
				print(f'✓ Restored latest checkpoint: {latest_checkpoint}')
				print('  ⚠ WARNING: Using latest checkpoint, not best model')
		else:
				print('⚠ No checkpoint found at all, using fresh/untrained model')

In [None]:
def inference(dataset):
    """
    Generate one image per test caption with the CLIP-conditioned generator and
    save it as inference_<ID>.jpg. Fresh random noise is drawn for every batch.
    """
    # Batches needed to cover all NUM_TEST captions once (the dataset repeats forever)
    num_batches = int(np.ceil(NUM_TEST / BATCH_SIZE))

    start = time.time()
    total_images = 0

    # Progress bar for inference
    pbar = tqdm(total=NUM_TEST, desc='Generating images', unit='img')

    # Each element is (input_ids, attention_mask, idx)
    for step, (input_ids, attention_mask, idx) in enumerate(dataset):
        if step >= num_batches:
            break

        # Draw fresh noise for each batch so the test images are diverse
        batch_len = int(input_ids.shape[0])
        noise = np.random.normal(loc=0.0, scale=1.0,
                                 size=(batch_len, hparas['Z_DIM'])).astype(np.float32)

        fake_image = test_step(input_ids, attention_mask, noise)
        # Map [-1, 1] to [0, 1] and clip so plt.imsave accepts the values
        images = np.clip(fake_image.numpy() * 0.5 + 0.5, 0.0, 1.0)

        # Because the dataset repeats, the last batch may wrap around to the first
        # captions; save only images for captions that have not been generated yet
        n_valid = min(batch_len, NUM_TEST - total_images)
        for i in range(n_valid):
            plt.imsave(f'{inference_dir}/inference_{int(idx[i]):04d}.jpg', images[i])
            total_images += 1
            pbar.update(1)

    pbar.close()
    print(f'\n✓ Generated {total_images} images in {time.time()-start:.4f} sec')
    print(f'✓ Images saved to: {inference_dir}')
    print('✓ Each batch was generated with fresh random noise for better diversity')

In [None]:
inference(testing_dataset)


In [None]:
# Run evaluation script to generate score.csv
# Note: This must be run from the testing directory because inception_score.py uses relative paths
# Arguments: [inference_dir] [output_csv] [batch_size]
# Batch size must divide the 819 test images evenly (no remainder batch), e.g. 1, 3, 7, 9, 13, 21, or 39

# Save score.csv inside the run directory
print(f"Scoring images in {inference_dir} (run: {run_dir})")
!cd testing && python inception_score.py ../{inference_dir}/ ../{run_dir}/score.csv 39
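<p>The comment above requires the evaluation batch size to divide the number of test images evenly. Since 819 = 3² × 7 × 13, the valid sizes can be enumerated directly. This is a sketch with a hypothetical helper name; the divisibility requirement itself is taken from the comment, not from the source of <code>inception_score.py</code>.</p>

```python
def valid_eval_batch_sizes(num_images, max_size=None):
    """Batch sizes that split num_images into equal batches with no remainder."""
    limit = num_images if max_size is None else min(max_size, num_images)
    return [b for b in range(1, limit + 1) if num_images % b == 0]

print(valid_eval_batch_sizes(819, max_size=64))  # → [1, 3, 7, 9, 13, 21, 39, 63]
```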

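<h2 id="Visualize-Generated-Images">Visualize Generated Images<a class="anchor-link" href="#Visualize-Generated-Images">¶</a></h2>

<p>Below, we randomly sample 20 images from the generated test results to visually inspect the quality and diversity of the model's outputs.</p>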


<h1><center class="subtitle">Demo</center></h1>

<p>We demonstrate the capability of our model (TA80) to generate plausible images of flowers from detailed text descriptions.</p>



In [None]:
# Visualize 20 random generated images with their captions
import glob
from pathlib import Path

# Load test data
data = pd.read_pickle('./dataset/testData.pkl')
test_captions = data['Captions'].values
test_ids = data['ID'].values

# Get all generated images from the current inference directory
image_files = sorted(glob.glob(inference_dir + '/inference_*.jpg'))

if len(image_files) == 0:
		print(f'⚠ No images found in {inference_dir}')
		print('Please run the inference cell first!')
else:
		# Randomly sample 20 images
		np.random.seed(42)  # For reproducibility
		num_samples = min(20, len(image_files))
		sample_indices = np.random.choice(len(image_files), size=num_samples, replace=False)
		sample_files = [image_files[i] for i in sorted(sample_indices)]

		# Create 4x5 grid
		fig, axes = plt.subplots(4, 5, figsize=(20, 16))
		axes = axes.flatten()

		for idx, img_path in enumerate(sample_files):
				# Extract image ID from filename
				img_id = int(Path(img_path).stem.split('_')[1])
				
				# Find caption
				caption_idx = np.where(test_ids == img_id)[0][0]
				caption_ids = test_captions[caption_idx]
				
				# Decode caption
				caption_text = ''
				for word_id in caption_ids:
						word = id2word_dict[str(word_id)]
						if word != '<PAD>':
								caption_text += word + ' '
				
				# Load and display image
				img = plt.imread(img_path)
				axes[idx].imshow(img)
				axes[idx].set_title(f'ID: {img_id}\n{caption_text[:60]}...', fontsize=8)
				axes[idx].axis('off')

		# Hide unused subplots if less than 20 images
		for idx in range(num_samples, 20):
				axes[idx].axis('off')

		plt.tight_layout()
		plt.suptitle(f'Random Sample of {num_samples} Generated Images', fontsize=16, y=1.002)
		plt.show()

		print(f'\nTotal generated images: {len(image_files)}')
		print(f'Images directory: {inference_dir}')