
<h1><center id="title">DataLab Cup 3: Reverse Image Caption</center></h1>

<center id="author">Shan-Hung Wu &amp; DataLab<br/>Fall 2025</center>



In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals
import os
# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native runtime is loaded, so the
# variable has to be set *before* `import tensorflow`; setting it afterwards may
# not silence the startup messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import layers
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
from tqdm import tqdm

import re
from IPython import display

In [2]:
# --- Device setup: expose only the first GPU and let its memory grow on demand ---
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        # Limit TensorFlow to the first physical GPU.
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

        # Memory growth must be configured identically on every GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Raised when the GPUs were already initialised before this cell ran.
        print(err)

# --- Reproducibility: seed every RNG the notebook uses ---
RANDOM_SEED = 42
random.seed(RANDOM_SEED)         # Python's built-in `random`
np.random.seed(RANDOM_SEED)      # NumPy global RNG
tf.random.set_seed(RANDOM_SEED)  # TensorFlow global seed

BATCH_SIZE = 128


<h2 id="Preprocess-Text">Preprocess Text<a class="anchor-link" href="#Preprocess-Text">¶</a></h2>
<p>Since working with raw strings is inefficient, we have already done some data preprocessing for you:</p>

<ul>
<li>Delete texts longer than <code>MAX_SEQ_LENGTH (20)</code>.</li>
<li>Delete all punctuation in the texts.</li>
<li>Encode each word using the vocabulary in <code>dictionary/vocab.npy</code>.</li>
<li>Represent texts by a sequence of integer IDs.</li>
<li>Replace rare words by <code>&lt;RARE&gt;</code> token to reduce vocabulary size for more efficient training.</li>
<li>Add padding as <code>&lt;PAD&gt;</code> to each text to make sure all of them have equal length to <code>MAX_SEQ_LENGTH (20)</code>.</li>
</ul>

<p>Note that it is not necessary to append <code>&lt;ST&gt;</code> and <code>&lt;ED&gt;</code> to each text, because we do not need to generate any sequences in this task.</p>

<p>To verify that the original text was encoded correctly, we can decode a sequence of vocabulary IDs by looking each ID up in the vocabulary dictionaries:</p>

<ul>
<li><code>dictionary/word2Id.npy</code> is a numpy array mapping word to id.</li>
<li><code>dictionary/id2Word.npy</code> is a numpy array mapping id back to word.</li>
</ul>



In [3]:
dictionary_path = './dictionary'


def _load_dictionary_file(name):
    """Return the NumPy array stored at `<dictionary_path>/<name>.npy`."""
    return np.load(f'{dictionary_path}/{name}.npy')


vocab = _load_dictionary_file('vocab')
print('there are {} vocabularies in total'.format(len(vocab)))

# dict() consumes each stored row as a (key, value) pair, giving two lookup tables.
word2Id_dict = dict(_load_dictionary_file('word2Id'))
id2word_dict = dict(_load_dictionary_file('id2Word'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))


there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [4]:
# Status banner; the check mark was previously stored as mis-decoded bytes ("âœ“").
print("✓ Using CLIP tokenizer (sent2IdList removed)")

✓ Using CLIP tokenizer (sent2IdList removed)



<h2 id="Dataset">Dataset<a class="anchor-link" href="#Dataset">¶</a></h2>
<p>For training, the following files are in dataset folder:</p>

<ul>
<li><code>./dataset/text2ImgData.pkl</code> is a pandas dataframe with attributes 'Captions' and 'ImagePath'.<ul>
<li>'Captions': a list of 1 to 10 captions, each encoded as a list of word IDs.</li>
<li>'ImagePath': the path of the paired image.</li>
</ul>
</li>
<li><code>./102flowers/</code> is the directory containing all training images.</li>
<li><code>./dataset/testData.pkl</code> is a pandas dataframe with attributes 'ID' and 'Captions', which contains the testing data.</li>
</ul>



In [5]:
data_path = './dataset'
# NOTE(review): read_pickle can execute arbitrary code; only load trusted course data.
df = pd.read_pickle(f'{data_path}/text2ImgData.pkl')
# One row per training image (see the 'Captions' / 'ImagePath' description above).
n_images_train = num_training_sample = len(df)
print('There are %d image in training data' % n_images_train)


There are 7370 image in training data


In [6]:
# Preview the first five rows of the caption/image table.
df.head(n=5)


Unnamed: 0_level_0,Captions,ImagePath
ID,Unnamed: 1_level_1,Unnamed: 2_level_1
6734,"[[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11,...",./102flowers/image_06734.jpg
6736,"[[4, 1, 5, 12, 2, 3, 11, 31, 28, 68, 106, 132,...",./102flowers/image_06736.jpg
6737,"[[9, 2, 27, 4, 1, 6, 14, 7, 12, 19, 5427, 5427...",./102flowers/image_06737.jpg
6738,"[[9, 1, 5, 8, 54, 16, 38, 7, 12, 116, 325, 3, ...",./102flowers/image_06738.jpg
6739,"[[4, 12, 1, 5, 29, 11, 19, 7, 26, 70, 5427, 54...",./102flowers/image_06739.jpg



<h2 id="Create-Dataset-by-Dataset-API">Create Dataset by Dataset API<a class="anchor-link" href="#Create-Dataset-by-Dataset-API">¶</a></h2>



In [7]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import CLIPTokenizer

# Pre-trained CLIP BPE tokenizer matching the "openai/clip-vit-base-patch32" checkpoint.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

def preprocess_text_clip(text, max_length=77):
    """Tokenize `text` with the CLIP tokenizer into fixed-length TensorFlow tensors.

    Every sequence is padded to, or truncated at, `max_length` tokens.

    Args:
        text: A caption string, or a list of caption strings.
        max_length: Target sequence length (77 is CLIP's context length).

    Returns:
        dict with 'input_ids' and 'attention_mask' as TF tensors.
    """
    batch = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='tf',
    )
    return {key: batch[key] for key in ('input_ids', 'attention_mask')}

In [8]:
# in this competition, you have to generate image in size 64x64x3
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3
# CLIP context length: every caption is padded/truncated to this many tokens.
CLIP_MAX_LENGTH = 77

def training_data_generator(caption_text, image_path):
    """
    Map one (caption, image path) pair to model-ready tensors (tf.data map fn).

    Args:
        caption_text: Scalar tf.string tensor with the raw caption text (not word IDs).
        image_path: Scalar tf.string tensor with the path of the paired image file.

    Returns:
        img: float32 [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL], scaled to [-1, 1].
        input_ids: int32 [CLIP_MAX_LENGTH] CLIP token IDs.
        attention_mask: int32 [CLIP_MAX_LENGTH] CLIP attention mask.
    """
    # ============= IMAGE PROCESSING =============
    img = tf.io.read_file(image_path)
    # expand_animations=False guarantees a 3-D [H, W, C] result even for GIF files,
    # which would otherwise decode to 4-D and break the set_shape below.
    img = tf.image.decode_image(img, channels=IMAGE_CHANNEL, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> [0, 1]
    img.set_shape([None, None, IMAGE_CHANNEL])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])

    # Normalize to [-1, 1] to match the generator's tanh output.
    img = (img * 2.0) - 1.0
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])

    # ============= TEXT PROCESSING (CLIP tokenizer) =============
    def tokenize_caption_clip(text):
        """Eager helper run through tf.py_function: tokenize one caption with CLIP."""
        # EagerTensor -> bytes -> str
        text = text.numpy().decode('utf-8')
        encoded = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=CLIP_MAX_LENGTH,
            return_tensors='np',
        )
        # Cast explicitly so the dtype always matches Tout below, whatever integer
        # type the tokenizer hands back.
        return (encoded['input_ids'][0].astype(np.int32),
                encoded['attention_mask'][0].astype(np.int32))

    # The HF tokenizer is plain Python, so it has to run via tf.py_function.
    input_ids, attention_mask = tf.py_function(
        func=tokenize_caption_clip,
        inp=[caption_text],
        Tout=[tf.int32, tf.int32]
    )

    # py_function loses static shapes; restore them for downstream layers.
    input_ids.set_shape([CLIP_MAX_LENGTH])
    attention_mask.set_shape([CLIP_MAX_LENGTH])

    return img, input_ids, attention_mask

def _decode_caption_ids(caption_ids, skip_tokens=('<PAD>',)):
    """Decode one caption from vocabulary IDs back to a space-joined sentence.

    Looks every ID up in the module-level `id2word_dict` (keyed by the ID as a
    string) and drops the words listed in `skip_tokens` (by default only the
    `<PAD>` filler). Other special tokens such as `<RARE>` are kept verbatim.
    """
    words = (id2word_dict[str(word_id)] for word_id in caption_ids)
    return ' '.join(word for word in words if word not in skip_tokens)


def dataset_generator(filenames, batch_size, data_generator, use_all_captions=False):
    """
    Build the training tf.data pipeline from the pickled caption/image table.

    Each table row holds several captions (as word-ID lists) for one image. The
    IDs are decoded back to raw text so the CLIP tokenizer can re-tokenize them.

    By default, one caption per image is drawn with `random.choice` *once*, when
    this function runs. That same caption is then reused for the image in every
    epoch. Pass `use_all_captions=True` to emit one (caption, image) sample per
    caption instead.

    Args:
        filenames: Path to the pickled DataFrame with 'Captions' and 'ImagePath'.
        batch_size: Batch size; the incomplete final batch is dropped.
        data_generator: Map fn `(caption_text, image_path) -> model inputs`.
        use_all_captions: Emit every caption instead of one sampled caption.
            Because of `.cache()`, each image is then held in memory once per caption.

    Returns:
        A cached, shuffled, batched and prefetched tf.data.Dataset.
    """
    # NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
    df = pd.read_pickle(filenames)

    caption_texts = []
    image_paths = []
    for captions, image_path in zip(df['Captions'].values, df['ImagePath'].values):
        chosen = captions if use_all_captions else [random.choice(captions)]
        for caption_ids in chosen:
            caption_texts.append(_decode_caption_ids(caption_ids))
            image_paths.append(image_path)

    dataset = tf.data.Dataset.from_tensor_slices((caption_texts, image_paths))
    dataset = dataset.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache after decoding so images and tokens are computed only once.
    dataset = dataset.cache()
    dataset = dataset.shuffle(len(caption_texts)).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

In [9]:
# Training pipeline: (image, input_ids, attention_mask) batches of BATCH_SIZE.
dataset = dataset_generator(
    filenames=data_path + '/text2ImgData.pkl',
    batch_size=BATCH_SIZE,
    data_generator=training_data_generator,
)


<h2 id="Conditional-GAN-Model">Conditional GAN Model<a class="anchor-link" href="#Conditional-GAN-Model">¶</a></h2>
<p>As mentioned above, there are three models in this task: a text encoder, a generator and a discriminator.</p>

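<h2 id="Text-Encoder">Text Encoder<a class="anchor-link" href="#Text-Encoder">¶</a></h2>
<p>A text encoder that captures the meaning of the input text. The original template used an RNN encoder; this notebook instead uses the frozen, pre-trained CLIP text encoder (<code>openai/clip-vit-base-patch32</code>).</p>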

<ul>
<li>Input: text, tokenized by the CLIP tokenizer into token IDs plus an attention mask.</li>
<li>Output: embedding, or hidden representation of the input text (an L2-normalized CLIP text feature vector).</li>
</ul>



In [10]:
# IMPORTANT: Import TensorFlow FIRST before transformers
import tensorflow as tf
from transformers import TFCLIPTextModel, TFCLIPModel

class ClipTextEncoder(tf.keras.Model):
    """Pre-trained CLIP text tower producing unit-length sentence embeddings.

    Args:
        output_dim: Accepted for backward compatibility but not used. The embedding
            size is whatever `get_text_features` returns for the loaded checkpoint.
        freeze_clip: If True, the CLIP weights are marked non-trainable.
    """

    def __init__(self, output_dim=512, freeze_clip=True):
        super().__init__()
        # Full CLIP model; only its text path (get_text_features) is used in `call`.
        self.clip = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        if freeze_clip:
            self.clip.trainable = False
        # No projection / LayerNorm / Dropout on top: the raw CLIP embedding is returned.

    def call(self, input_ids, attention_mask, training=False):
        """Embed tokenized captions.

        Args:
            input_ids: CLIP token IDs.
            attention_mask: Matching CLIP attention mask.
            training: Ignored; the encoder behaves identically in both modes.

        Returns:
            CLIP projected text features, L2-normalised along axis 1.
        """
        features = self.clip.get_text_features(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # CLIP compares embeddings by cosine similarity, so project onto the unit sphere.
        return tf.math.l2_normalize(features, axis=1)

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, initializers

# ==============================================================================
# 1. HELPER LAYERS & BLOCKS
# ==============================================================================

class Affine(layers.Layer):
    """
    Text-conditioned, channel-wise affine modulation (TF port of DF-GAN's PyTorch `Affine`).

    Two small MLPs predict a per-channel scale (`weight`, gamma) and shift
    (`bias`, beta) from the condition vector `y`. The layer returns
    `weight * x + bias`.

    Args:
        cond_dim: Size of the condition vector. Not referenced in this class; the
            Dense layers infer their input size when first called.
        num_features: Output width of both MLPs. It must equal the channel count of
            `x` (x.shape[-1]); otherwise the reshape in `call` produces the wrong
            batch dimension or fails.
    """
    def __init__(self, cond_dim, num_features):
        super(Affine, self).__init__()

        # PyTorch: self.fc_gamma = nn.Sequential(...)
        # Linear -> ReLU -> Linear
        # The last Dense starts with kernel=0 and bias=1, so gamma == 1 for every
        # input at initialisation (an identity scale).
        self.fc_gamma = tf.keras.Sequential([
            layers.Dense(num_features, activation='relu'),
            layers.Dense(num_features, kernel_initializer='zeros', bias_initializer='ones')
        ])

        # PyTorch: self.fc_beta = nn.Sequential(...)
        # Linear -> ReLU -> Linear
        # The last Dense starts with kernel=0 and bias=0, so beta == 0 at initialisation.
        self.fc_beta = tf.keras.Sequential([
            layers.Dense(num_features, activation='relu'),
            layers.Dense(num_features, kernel_initializer='zeros', bias_initializer='zeros')
        ])

    def call(self, inputs):
        """
        Args:
            inputs: Pair `[x, y]`, where x is a feature map [B, H, W, C] (NHWC) and
                y is the condition vector [B, cond_dim].

        Returns:
            Tensor shaped like x: gamma(y) * x + beta(y), broadcast over H and W.
        """
        # x: [B, H, W, C], y: [B, Cond_Dim]
        x, y = inputs 

        # --- Gamma & Beta Calculation ---
        weight = self.fc_gamma(y)
        bias = self.fc_beta(y)

        # --- Reshape for Broadcasting ---
        # PyTorch: weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        # TF (NHWC): We need [B, 1, 1, C] to broadcast over H, W
        weight = tf.reshape(weight, [-1, 1, 1, x.shape[-1]])
        bias = tf.reshape(bias, [-1, 1, 1, x.shape[-1]])

        return weight * x + bias


class DFBLK(layers.Layer):
    """
    Deep text-image fusion block (TF port of DF-GAN's PyTorch `DFBLK`).

    Structure: Affine -> LeakyReLU(0.2) -> Affine -> LeakyReLU(0.2). Both affine
    layers are conditioned on the same vector `y`.

    Args:
        cond_dim: Condition vector size (forwarded to Affine, which does not use it).
        in_ch: Channel count of the feature map being modulated.
    """
    def __init__(self, cond_dim, in_ch):
        super(DFBLK, self).__init__()
        self.affine0 = Affine(cond_dim, in_ch)
        self.affine1 = Affine(cond_dim, in_ch)
        # One activation layer object, shared by both steps (it has no weights).
        self.act = layers.LeakyReLU(0.2)

    def call(self, inputs):
        """
        Args:
            inputs: Pair `[x, y]`: feature map [B, H, W, in_ch] and condition [B, cond_dim].

        Returns:
            Modulated feature map with the same shape as x.
        """
        x, y = inputs
        
        h = self.affine0([x, y])
        h = self.act(h)
        
        h = self.affine1([h, y])
        h = self.act(h)
        
        return h


class G_Block(layers.Layer):
    """
    Residual generator block (TF port of DF-GAN's PyTorch `G_Block`).

    Optionally upsamples 2x (UpSampling2D), then returns shortcut + residual:
      - shortcut: identity, or a 1x1 conv when in_ch != out_ch;
      - residual: DFBLK -> 3x3 conv -> DFBLK -> 3x3 conv, with both DFBLKs
        conditioned on `y`.

    Args:
        cond_dim: Condition vector size (noise + text embedding in NetG).
        in_ch: Input channel count.
        out_ch: Output channel count.
        upsample: If True, spatial size is doubled before both paths.
    """
    def __init__(self, cond_dim, in_ch, out_ch, upsample=True):
        super(G_Block, self).__init__()
        self.upsample = upsample
        self.learnable_sc = (in_ch != out_ch)
        
        # PyTorch: self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.c1 = layers.Conv2D(out_ch, 3, strides=1, padding='same')
        
        # PyTorch: self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.c2 = layers.Conv2D(out_ch, 3, strides=1, padding='same')
        
        # fuse1 modulates the input (in_ch); fuse2 modulates after c1 (out_ch).
        self.fuse1 = DFBLK(cond_dim, in_ch)
        self.fuse2 = DFBLK(cond_dim, out_ch)
        
        if self.learnable_sc:
            # PyTorch: self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
            self.c_sc = layers.Conv2D(out_ch, 1, strides=1, padding='valid')

        # Created unconditionally; only used when self.upsample is True.
        self.upsample_layer = layers.UpSampling2D(size=(2, 2))

    def call(self, inputs):
        """
        Args:
            inputs: Pair `[x, y]`: feature map [B, H, W, in_ch] and condition [B, cond_dim].

        Returns:
            Feature map [B, 2H, 2W, out_ch] if upsample else [B, H, W, out_ch].
        """
        x, y = inputs
        
        # --- Upsample ---
        if self.upsample:
            x = self.upsample_layer(x)

        # --- Shortcut Path ---
        h_sc = x
        if self.learnable_sc:
            h_sc = self.c_sc(h_sc)

        # --- Residual Path ---
        # 1. fuse1 (DFBLK)
        h_res = self.fuse1([x, y])
        # 2. c1 (Conv)
        h_res = self.c1(h_res)
        # 3. fuse2 (DFBLK)
        h_res = self.fuse2([h_res, y])
        # 4. c2 (Conv)
        h_res = self.c2(h_res)

        return h_sc + h_res


class D_Block(layers.Layer):
    """
    Residual downsampling discriminator block (TF port of DF-GAN's PyTorch `D_Block`).

    Returns `shortcut(x) + gamma * residual(x)`:
      - residual: 4x4 conv stride 2 -> LeakyReLU -> 3x3 conv -> LeakyReLU;
      - shortcut: optional 1x1 conv (when fin != fout), then optional 2x2 average pool;
      - gamma: a single learnable scalar shared by all channels.

    Args:
        fin: Input channel count.
        fout: Output channel count.
        downsample: Whether the shortcut is average-pooled. NOTE(review): the
            residual path always strides by 2, so downsample=False makes the two
            paths' spatial sizes disagree and the addition fails. Keep it True.
    """
    def __init__(self, fin, fout, downsample=True):
        super(D_Block, self).__init__()
        self.downsample = downsample
        self.learned_shortcut = (fin != fout)
        
        # PyTorch: self.conv_r = nn.Sequential(...)
        # 1. Conv 4x4, stride 2, pad 1 (PyTorch) -> Downsamples by 2
        self.conv_r_1 = layers.Conv2D(fout, 4, strides=2, padding='same', use_bias=False)
        self.act_1 = layers.LeakyReLU(0.2)
        
        # 2. Conv 3x3, stride 1, pad 1
        self.conv_r_2 = layers.Conv2D(fout, 3, strides=1, padding='same', use_bias=False)
        self.act_2 = layers.LeakyReLU(0.2)
        
        # PyTorch: self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
        # Created unconditionally; only called when learned_shortcut is True.
        self.conv_s = layers.Conv2D(fout, 1, strides=1, padding='valid')
        
        # PyTorch: self.gamma = nn.Parameter(torch.zeros(1))
        # Deviation from the PyTorch reference: gamma starts at 0.1 instead of 0, so
        # the residual branch contributes from the first step.
        # NOTE(review): an earlier note claimed a 0.0-initialised tf.Variable "gets
        # stuck" and caused 4x4 block artifacts. That rationale is unverified:
        # d(loss)/d(gamma) = sum(res * upstream grad) is generally non-zero at gamma = 0.
        self.gamma = tf.Variable(0.1, trainable=True, dtype=tf.float32)
        
        # PyTorch: F.avg_pool2d(x, 2)
        self.avg_pool = layers.AveragePooling2D(pool_size=2, strides=2)

    def call(self, inputs):
        """
        Args:
            inputs: Feature map [B, H, W, fin] (NHWC).

        Returns:
            Feature map [B, H/2, W/2, fout].
        """
        x = inputs
        
        # --- Residual Path (conv_r) ---
        res = self.conv_r_1(x)
        res = self.act_1(res)
        res = self.conv_r_2(res)
        res = self.act_2(res)
        
        # --- Shortcut Path ---
        # PyTorch: if self.learned_shortcut: x = self.conv_s(x)
        # PyTorch: if self.downsample: x = F.avg_pool2d(x, 2)
        
        if self.learned_shortcut:
            x = self.conv_s(x)
            
        if self.downsample:
            x = self.avg_pool(x)
            
        return x + self.gamma * res


# ==============================================================================
# 2. MAIN NETWORKS (NetG, NetD, NetC)
# ==============================================================================

class NetG(Model):
    """
    DF-GAN generator (TF port of the PyTorch `NetG`) for 64x64 output.

    Noise -> Dense -> 4x4x(8*ngf) seed. Four upsampling G_Blocks, each conditioned
    on concat(noise, text embedding), take it 4 -> 8 -> 16 -> 32 -> 64. A final
    LeakyReLU -> 3x3 conv -> tanh maps it to `ch_size` channels in [-1, 1].

    Args:
        ngf: Base generator channel width.
        nz: Noise dimension.
        cond_dim: Text-embedding dimension; must match the text encoder's output size.
        imsize: Not used; the 4x4 seed plus four fixed upsampling blocks give 64x64.
        ch_size: Number of output image channels.
    """
    def __init__(self, ngf=32, nz=100, cond_dim=256, imsize=64, ch_size=3):
        super(NetG, self).__init__()
        self.ngf = ngf
        self.nz = nz
        
        # PyTorch: self.fc = nn.Linear(nz, ngf*8*4*4)
        self.fc = layers.Dense(ngf * 8 * 4 * 4)
        
        # PyTorch: get_G_in_out_chs(ngf, 64) -> [(8,8), (8,4), (4,2), (2,1)]
        # Each block is conditioned on noise + text, hence cond_dim + nz.
        # Block 1: 4x4 -> 8x8 (8*ngf -> 8*ngf)
        self.block1 = G_Block(cond_dim + nz, ngf * 8, ngf * 8, upsample=True)
        # Block 2: 8x8 -> 16x16 (8*ngf -> 4*ngf)
        self.block2 = G_Block(cond_dim + nz, ngf * 8, ngf * 4, upsample=True)
        # Block 3: 16x16 -> 32x32 (4*ngf -> 2*ngf)
        self.block3 = G_Block(cond_dim + nz, ngf * 4, ngf * 2, upsample=True)
        # Block 4: 32x32 -> 64x64 (2*ngf -> 1*ngf)
        self.block4 = G_Block(cond_dim + nz, ngf * 2, ngf * 1, upsample=True)
        
        # PyTorch: self.to_rgb = nn.Sequential(...)
        self.to_rgb_act = layers.LeakyReLU(0.2)
        self.to_rgb_conv = layers.Conv2D(ch_size, 3, strides=1, padding='same')
        self.to_rgb_out = layers.Activation('tanh')

    def call(self, inputs):
        """
        Args:
            inputs: Pair `[noise, c]`: noise [B, nz] and text embedding [B, cond_dim].

        Returns:
            Images [B, 64, 64, ch_size] with values in [-1, 1].
        """
        # noise: [B, nz], c: [B, cond_dim]
        noise, c = inputs
        
        # PyTorch: out = self.fc(noise)
        out = self.fc(noise)
        
        # PyTorch: out = out.view(noise.size(0), 8*self.ngf, 4, 4)
        # NHWC reshape: the element order differs from PyTorch's NCHW view, so fc
        # weights from a PyTorch checkpoint would not transfer 1:1.
        out = tf.reshape(out, [-1, 4, 4, self.ngf * 8])
        
        # PyTorch: cond = torch.cat((noise, c), dim=1)
        cond = tf.concat([noise, c], axis=1)
        
        # PyTorch: loop over GBlocks
        out = self.block1([out, cond])
        out = self.block2([out, cond])
        out = self.block3([out, cond])
        out = self.block4([out, cond])
        
        # PyTorch: out = self.to_rgb(out)
        out = self.to_rgb_act(out)
        out = self.to_rgb_conv(out)
        out = self.to_rgb_out(out)
        
        return out


class NetD(Model):
    """
    DF-GAN discriminator feature extractor (TF port of the PyTorch `NetD`).

    A 3x3 conv lifts the image to `ndf` channels. Four D_Blocks then each halve the
    spatial size. The output is a feature map, not a score; the text-conditioned
    score is computed by NetC.

    Args:
        ndf: Base discriminator channel width.
        imsize: Not used; the network depth is fixed at four blocks.
        ch_size: Not used; the first conv infers its input channels.
    """
    def __init__(self, ndf=64, imsize=64, ch_size=3):
        super(NetD, self).__init__()
        
        # PyTorch: self.conv_img = nn.Conv2d(ch_size, ndf, 3, 1, 1)
        self.conv_img = layers.Conv2D(ndf, 3, strides=1, padding='same')
        
        # PyTorch: get_D_in_out_chs(ndf, 64) -> [(1,2), (2,4), (4,8), (8,8)]
        # Block 1: 64x64 -> 32x32 (ndf -> 2*ndf)
        self.block1 = D_Block(ndf, ndf * 2, downsample=True)
        # Block 2: 32x32 -> 16x16 (2*ndf -> 4*ndf)
        self.block2 = D_Block(ndf * 2, ndf * 4, downsample=True)
        # Block 3: 16x16 -> 8x8 (4*ndf -> 8*ndf)
        self.block3 = D_Block(ndf * 4, ndf * 8, downsample=True)
        # Block 4: 8x8 -> 4x4 (8*ndf -> 8*ndf)
        self.block4 = D_Block(ndf * 8, ndf * 8, downsample=True)

    def call(self, inputs):
        """
        Args:
            inputs: Images [B, H, W, C] (NHWC).

        Returns:
            Features [B, H/16, W/16, 8*ndf], i.e. [B, 4, 4, 8*ndf] for 64x64 input.
        """
        x = inputs
        
        out = self.conv_img(x)
        
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        
        return out


class NetC(Model):
    """
    Text-conditioned scoring head (TF port of DF-GAN's PyTorch `NetC`).

    Tiles the sentence embedding over a fixed 4x4 grid, concatenates it with the
    discriminator features, and reduces the result to one unnormalised score
    per sample.

    Args:
        ndf: Base channel width; the first conv outputs 2*ndf channels.
        cond_dim: Text-embedding size; must equal the last dimension of `y`.
    """
    def __init__(self, ndf=64, cond_dim=256):
        super(NetC, self).__init__()
        self.cond_dim = cond_dim
        
        # PyTorch: nn.Conv2d(ndf*8+cond_dim, ndf*2, 3, 1, 1, bias=False)
        # Input channels: 8*ndf (from NetD) + cond_dim
        self.joint_conv_1 = layers.Conv2D(ndf * 2, 3, strides=1, padding='same', use_bias=False)
        self.act = layers.LeakyReLU(0.2)
        # 4x4 'valid' conv collapses the 4x4 map to 1x1.
        self.joint_conv_2 = layers.Conv2D(1, 4, strides=1, padding='valid', use_bias=False)

    def call(self, inputs):
        """
        Args:
            inputs: Pair `[out, y]`: NetD features that must be spatially 4x4 (the
                tile below is hard-coded to 4x4), and text embedding [B, cond_dim].

        Returns:
            Scores of shape [B, 1].
        """
        # out: [B, 4, 4, 8*ndf], y: [B, cond_dim]
        out, y = inputs
        
        # PyTorch: y = y.view(-1, self.cond_dim, 1, 1)
        # PyTorch: y = y.repeat(1, 1, 4, 4)
        y = tf.reshape(y, [-1, 1, 1, self.cond_dim])
        y = tf.tile(y, [1, 4, 4, 1])
        
        # PyTorch: h_c_code = torch.cat((out, y), 1)
        h_c_code = tf.concat([out, y], axis=-1)
        
        # PyTorch: out = self.joint_conv(h_c_code)
        out = self.joint_conv_1(h_c_code)
        out = self.act(out)
        out = self.joint_conv_2(out)
        
        # Output is [B, 1, 1, 1]
        return tf.reshape(out, [-1, 1])

In [12]:
import tensorflow as tf

def matching_aware_gradient_penalty(netD, netC, real_images, text_embeddings, p=6.0):
    """
    Matching-Aware Gradient Penalty (MA-GP) from DF-GAN.

    Penalises the gradient of the matched (real image, matching text) score with
    respect to BOTH modalities:

        MA-GP = 2.0 * mean_b( || [dS/dx_b , dS/de_b] ||_2 ** p )

    Args:
        netD: Discriminator feature extractor (image -> features).
        netC: Scoring head ([features, text embedding] -> score).
        real_images: Real images [B, H, W, C].
        text_embeddings: Matching text embeddings [B, cond_dim].
        p: Power applied to the gradient norm (official DF-GAN uses 6).

    Returns:
        Scalar penalty tensor (mean over the batch, times 2.0 as in the reference).
    """
    # Watch both inputs: the official DF-GAN penalty covers image AND text gradients.
    with tf.GradientTape() as tape:
        tape.watch(real_images)
        tape.watch(text_embeddings)
        features = netD(real_images, training=True)
        pred_real = netC([features, text_embeddings], training=True)

    grad_img, grad_text = tape.gradient(pred_real, [real_images, text_embeddings])

    # Flatten per sample and join: [B, dim_img + dim_text].
    batch = tf.shape(grad_img)[0]
    grad_all = tf.concat(
        [tf.reshape(grad_img, [batch, -1]), tf.reshape(grad_text, [batch, -1])],
        axis=1,
    )

    # ||g||^p == (||g||^2)^(p/2). Working from the squared norm avoids the sqrt in
    # tf.norm, whose gradient is NaN when a sample's gradient vector is exactly
    # zero. That matters because the caller differentiates this penalty again
    # (second-order gradients).
    grad_sq_norms = tf.reduce_sum(tf.square(grad_all), axis=1)
    return 2.0 * tf.reduce_mean(tf.pow(grad_sq_norms, p / 2.0))

def discriminator_hinge_loss(real_score, fake_score, wrong_score=None):
    """
    Discriminator hinge loss.

        L_D = E[relu(1 - D(real, text))] + E[relu(1 + D(fake, text))]
              (+ E[relu(1 + D(real, mismatched_text))] when wrong_score is given)

    Matched real pairs are pushed above +1; fake and mismatched pairs below -1.
    """
    def _mean_hinge(margin):
        return tf.reduce_mean(tf.nn.relu(margin))

    loss = _mean_hinge(1.0 - real_score) + _mean_hinge(1.0 + fake_score)
    if wrong_score is not None:
        loss = loss + _mean_hinge(1.0 + wrong_score)
    return loss

def generator_hinge_loss(fake_score):
    """
    Generator objective paired with the hinge discriminator: L_G = -E[D(fake, text)].

    Minimising it pushes the scores of generated images upward.
    """
    mean_fake_score = tf.reduce_mean(fake_score)
    return -mean_fake_score

def logit_loss(output, negative=False):
    """
    Binary cross-entropy on raw logits, an alternative to the hinge loss.

    Args:
        output: Unnormalised scores from the discriminator head, e.g. [B, 1].
        negative: False -> targets are 1 (real); True -> targets are 0 (fake).

    Returns:
        Scalar mean BCE. from_logits=True fuses the sigmoid for numerical stability.
    """
    make_targets = tf.zeros_like if negative else tf.ones_like
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return bce(make_targets(output), output)

In [13]:
import tensorflow as tf


# Single hinge term used by train_step.
def hinge_loss(output, negative=False):
    """Mean hinge penalty: relu(1 - output) for positives, relu(1 + output) if negative."""
    margin = 1.0 + output if negative else 1.0 - output
    return tf.reduce_mean(tf.nn.relu(margin))

@tf.function
def train_step(
    real_images, 
    input_ids, 
    attention_mask, 
    generator, 
    discriminator, 
    net_c, 
    text_encoder, 
    g_optimizer, 
    d_optimizer, 
    batch_size, 
    z_dim,
    lambda_ma_gp=2.0,
    diff_augment_fn=None  # optional differentiable augmentation
):
    """
    Execute one DF-GAN training step: one discriminator update, then one generator update.

    Discriminator loss (hinge + MA-GP):
        errD_real + (errD_fake + errD_mis) / 2 + MA-GP
    Generator loss: -mean(score of fake images with matching text).

    Args:
        real_images: Real image batch (NHWC).
        input_ids, attention_mask: CLIP token IDs / mask for the matching captions.
        generator, discriminator, net_c: NetG, NetD and NetC models.
        text_encoder: Text encoder called with training=False; its variables are
            not in either update list below.
        g_optimizer, d_optimizer: Optimizers for G and for (D + NetC).
        batch_size: Used to size the noise; must equal the real batch size so
            real/fake tensors line up.
        z_dim: Noise dimension.
        lambda_ma_gp: NOTE(review): not referenced anywhere in this function; the
            MA-GP term is added unweighted (its 2.0 factor lives inside
            matching_aware_gradient_penalty).
        diff_augment_fn: Optional callable applied to real and fake images before
            the discriminator. The `is not None` check is plain Python, so it is
            resolved when tf.function traces. When set, MA-GP is computed on the
            augmented real images.

    Returns:
        dict of scalar tensors: d_loss, g_loss, ma_gp, errD_real, errD_fake, errD_mis.
    """
    
    # 1. Encode Text
    # text_embeddings: [B, Cond_Dim]
    text_embeddings = text_encoder(input_ids, attention_mask, training=False)

    # 2. Train Discriminator (NetD + NetC)
    with tf.GradientTape() as d_tape:
        # --- Apply DiffAugment to Real Images ---
        if diff_augment_fn is not None:
            real_images = diff_augment_fn(real_images)
        
        # --- A. Real Image + Matching Text ---
        real_features = discriminator(real_images, training=True)
        real_score = net_c([real_features, text_embeddings], training=True)
        errD_real = hinge_loss(real_score, negative=False)
        
        # --- B. Real Image + Mismatched Text ---
        # Roll the text by one position so each image is paired with another sample's caption.
        mismatched_text = tf.roll(text_embeddings, shift=1, axis=0)
        # Note: PyTorch shifts features, we shift text. Either way the pairs are mismatched.
        wrong_score = net_c([real_features, mismatched_text], training=True)
        errD_mis = hinge_loss(wrong_score, negative=True)
        
        # --- C. Fake Image + Matching Text ---
        noise = tf.random.normal([batch_size, z_dim])
        fake_images = generator([noise, text_embeddings], training=True)
        
        # --- Apply DiffAugment to Fake Images ---
        if diff_augment_fn is not None:
            fake_images = diff_augment_fn(fake_images)
        
        fake_features = discriminator(fake_images, training=True)
        fake_score = net_c([fake_features, text_embeddings], training=True)
        errD_fake = hinge_loss(fake_score, negative=True)
        
        # --- D. Matching Aware Gradient Penalty ---
        # PyTorch: errD_MAGP = MA_GP(imgs, sent_emb, pred_real)
        # The penalty runs its own inner tape (an extra D/NetC forward pass). It is
        # called inside d_tape so the penalty itself is differentiated w.r.t. D's weights.
        errD_MAGP = matching_aware_gradient_penalty(
            discriminator, net_c, real_images, text_embeddings, p=6.0
        )
        
        # --- E. Total D Loss ---
        # PyTorch: errD = errD_real + (errD_fake + errD_mis)/2.0 + errD_MAGP
        # Fake and mismatched terms are averaged, as in the DF-GAN reference.
        d_loss = errD_real + (errD_fake + errD_mis) / 2.0 + errD_MAGP

    # Calculate and Apply Gradients for D (NetD and NetC share one optimizer)
    d_vars = discriminator.trainable_variables + net_c.trainable_variables
    d_grads = d_tape.gradient(d_loss, d_vars)
    d_optimizer.apply_gradients(zip(d_grads, d_vars))

    # 3. Train Generator
    with tf.GradientTape() as g_tape:
        # Re-generate noise/images for the G update (fresh noise, post-update D)
        noise = tf.random.normal([batch_size, z_dim])
        fake_images = generator([noise, text_embeddings], training=True)
        
        # --- Apply DiffAugment to Fake Images (for G loss) ---
        if diff_augment_fn is not None:
            fake_images = diff_augment_fn(fake_images)
        
        fake_features = discriminator(fake_images, training=True)
        fake_score = net_c([fake_features, text_embeddings], training=True)
        
        # G Loss: -mean(fake_score)
        g_loss = generator_hinge_loss(fake_score)

    # Calculate and Apply Gradients for G (only generator variables are updated)
    g_vars = generator.trainable_variables
    g_grads = g_tape.gradient(g_loss, g_vars)
    g_optimizer.apply_gradients(zip(g_grads, g_vars))

    return {
        "d_loss": d_loss,
        "g_loss": g_loss,
        "ma_gp": errD_MAGP,
        "errD_real": errD_real,
        "errD_fake": errD_fake,
        "errD_mis": errD_mis
    }

In [14]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from transformers import TFCLIPModel
import matplotlib.pyplot as plt
import numpy as np
import os

def DiffAugment(x, policy='translation'):
    """
    Differentiable augmentation (TensorFlow port of DiffAugment).

    `policy` is a string. Each of 'color', 'translation' and 'cutout' that occurs
    in it (substring match, e.g. 'color,translation,cutout') enables that
    transform. Transforms always run in the order color -> translation -> cutout.
    An empty or None policy returns `x` untouched.
    """
    if not policy:
        return x
    if 'color' in policy:
        for color_op in (rand_brightness, rand_saturation, rand_contrast):
            x = color_op(x)
    if 'translation' in policy:
        x = rand_translation(x)
    if 'cutout' in policy:
        x = rand_cutout(x)
    return x

# --- Augmentation Primitives ---
def rand_brightness(x):
    """Shift each image by one offset drawn from U[-0.5, 0.5), then clip to [-1, 1].

    NOTE(review): pixels pushed past +-1 are clipped and receive zero gradient.
    """
    per_image_offset = tf.random.uniform([tf.shape(x)[0], 1, 1, 1], minval=-0.5, maxval=0.5)
    return tf.clip_by_value(x + per_image_offset, -1.0, 1.0)

def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by one U[0, 2) factor per image; clip to [-1, 1]."""
    factor = tf.random.uniform([tf.shape(x)[0], 1, 1, 1], minval=0.0, maxval=2.0)
    channel_mean = tf.reduce_mean(x, axis=3, keepdims=True)  # per-pixel mean over channels
    saturated = (x - channel_mean) * factor + channel_mean
    return tf.clip_by_value(saturated, -1.0, 1.0)

def rand_contrast(x):
    """Scale each image's deviation from its global mean by one U[0.5, 1.5) factor; clip to [-1, 1]."""
    factor = tf.random.uniform([tf.shape(x)[0], 1, 1, 1], minval=0.5, maxval=1.5)
    image_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)  # one mean per image
    contrasted = (x - image_mean) * factor + image_mean
    return tf.clip_by_value(contrasted, -1.0, 1.0)

def rand_translation(x, ratio=0.125):
    """
    Randomly translate each image by an integer number of pixels (DiffAugment 'translation').

    Each image is reflect-padded by `shift` pixels on every side, and an
    image-sized window is cropped at a random integer offset per sample. The
    result is a shift in [-shift, shift] pixels along y and x independently.

    Args:
        x: Batch of square images, NHWC.
        ratio: Maximum shift as a fraction of the image side.

    Returns:
        Tensor with the same shape as `x`.
    """
    batch_size = tf.shape(x)[0]
    img_size = tf.shape(x)[1]
    # Derive the pad width from the static image size (this was hard-coded to 64).
    # Fall back to 64 if the size is not known at trace time.
    static_size = x.shape[1] if x.shape[1] is not None else 64
    shift = int(static_size * ratio)

    x_padded = tf.pad(x, [[0, 0], [shift, shift], [shift, shift], [0, 0]], mode='REFLECT')

    # Random integer top-left corner per sample in [0, 2 * shift].
    max_offset = 2 * shift
    offsets_y = tf.random.uniform([batch_size], minval=0, maxval=max_offset + 1, dtype=tf.int32)
    offsets_x = tf.random.uniform([batch_size], minval=0, maxval=max_offset + 1, dtype=tf.int32)
    offsets_y = tf.cast(offsets_y, tf.float32)
    offsets_x = tf.cast(offsets_x, tf.float32)

    # crop_and_resize maps a normalised coordinate c to pixel c * (padded_size - 1).
    # Normalising by (padded_size - 1), and ending the box at offset + img_size - 1,
    # puts every sample on an integer pixel. The result is an exact crop, not a
    # slightly rescaled, bilinearly blurred one (which normalising by padded_size
    # produced).
    denom = tf.cast(img_size + 2 * shift - 1, tf.float32)
    span = tf.cast(img_size - 1, tf.float32)
    boxes = tf.stack(
        [offsets_y / denom, offsets_x / denom,
         (offsets_y + span) / denom, (offsets_x + span) / denom],
        axis=1,
    )  # [B, 4] as [y1, x1, y2, x2]

    return tf.image.crop_and_resize(
        x_padded,
        boxes,
        tf.range(batch_size),
        crop_size=[img_size, img_size]
    )

def rand_cutout(x, ratio=0.5):
    """
    Zero out one random square patch per image (DiffAugment 'cutout').

    The patch side is `ratio` of the image side, rounded down to an even number
    of pixels. The patch always lies fully inside the image.

    Args:
        x: Batch of square images, NHWC.
        ratio: Patch side as a fraction of the image side.

    Returns:
        `x` with the patch set to 0 (mid-grey for images scaled to [-1, 1]).
    """
    batch_size = tf.shape(x)[0]
    img_size = tf.shape(x)[1]
    # Size the patch from the actual (static) image size (this was hard-coded to 64).
    # Fall back to 64 if the size is not known at trace time.
    static_size = x.shape[1] if x.shape[1] is not None else 64
    cutout_size = int(static_size * ratio // 2) * 2

    # Pixel coordinate grids, [1, H, W], broadcast against per-sample offsets.
    rows, cols = tf.meshgrid(tf.range(img_size), tf.range(img_size), indexing='ij')
    rows = rows[tf.newaxis]
    cols = cols[tf.newaxis]

    # Random top-left corner per sample, shaped [B, 1, 1] to broadcast over H, W.
    max_start = img_size + 1 - cutout_size
    offset_x = tf.random.uniform([batch_size, 1, 1], minval=0, maxval=max_start, dtype=tf.int32)
    offset_y = tf.random.uniform([batch_size, 1, 1], minval=0, maxval=max_start, dtype=tf.int32)

    inside_patch = ((cols >= offset_x) & (cols < offset_x + cutout_size)
                    & (rows >= offset_y) & (rows < offset_y + cutout_size))  # [B, H, W]

    keep = tf.cast(tf.logical_not(inside_patch), x.dtype)[..., tf.newaxis]  # [B, H, W, 1]
    return x * keep

def save_sample_images(generator, text_encoder, fixed_ids, fixed_mask, fixed_noise, epoch, save_dir,
                       num_images=8):
    """
    Generate images from fixed noise/captions and save them as a one-row strip.

    Reusing the same noise and captions on every call makes samples from
    different epochs directly comparable.

    Args:
        generator: Model called as generator([noise, text_embeds], training=False).
        text_encoder: Model called as text_encoder(ids, mask, training=False).
        fixed_ids, fixed_mask: Token ids and attention mask of the fixed captions.
        fixed_noise: Fixed latent noise; its first dimension must match the captions.
        epoch: Epoch number used in the file name (epoch_XXX.png).
        save_dir: Output directory; created if it does not exist.
        num_images: Number of columns in the strip; at most this many images are
            drawn (default 8, the previous hard-coded value).

    Raises:
        ValueError: if num_images < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    os.makedirs(save_dir, exist_ok=True)

    # Encode text, then generate
    text_embeds = text_encoder(fixed_ids, fixed_mask, training=False)
    fake_imgs = generator([fixed_noise, text_embeds], training=False)

    # Convert [-1, 1] -> [0, 1] for plotting and clip small overshoots
    fake_imgs = tf.clip_by_value((fake_imgs + 1.0) * 0.5, 0.0, 1.0).numpy()

    n_show = min(num_images, len(fake_imgs))
    # 1.25 inches per column keeps the previous 10x2 figure for 8 images
    fig = plt.figure(figsize=(1.25 * num_images, 2))
    for i in range(n_show):
        ax = fig.add_subplot(1, num_images, i + 1)
        ax.imshow(fake_imgs[i])
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, f'epoch_{epoch:03d}.png'))
    # Close explicitly: this runs every SAMPLE_FREQ epochs and figures would accumulate
    plt.close(fig)

In [15]:
import tensorflow as tf
import os
import time
import datetime
import subprocess
import sys
from tqdm import tqdm

def train(dataset, args):
    """
    Main training loop for DF-GAN with TensorBoard logging and LR decay.

    Builds NetG, NetD, NetC and a CLIP text encoder, then runs
    `args['N_EPOCH']` epochs of `train_step`. Along the way it:
      - logs per-step and per-epoch losses to <RUN_DIR>/logs (TensorBoard),
      - checkpoints G/D/NetC and both optimizers to <RUN_DIR>/checkpoints every
        SAVE_FREQ epochs,
      - writes fixed-noise samples to <RUN_DIR>/samples every SAMPLE_FREQ epochs.

    Args:
        dataset: tf.data.Dataset yielding (real_images, input_ids, attention_mask).
        args: Config dict.
            Required keys: IMAGE_SIZE, NGF, NDF, Z_DIM, EMBED_DIM, LR_G, LR_D,
            BATCH_SIZE, N_EPOCH, N_SAMPLE, RUN_DIR, SAVE_FREQ, SAMPLE_FREQ.
            Optional keys (default):
              - USE_DIFFAUG (False); DIFFAUG_POLICY is required when it is enabled
              - LAMBDA_GP (2.0)
              - LR_DECAY_START (50), LR_DECAY_EVERY (10), LR_DECAY_FACTOR (0.95), LR_MIN (1e-6)
              - LAUNCH_TENSORBOARD (True)

    Raises:
        ValueError: if `dataset` yields no batches.
    """

    # ==========================================================================
    # 1. Initialization & Logging Setup
    # ==========================================================================
    print(f"--- Initializing Models (Image Size: {args['IMAGE_SIZE']}) ---")

    # Models
    generator = NetG(ngf=args['NGF'], nz=args['Z_DIM'], cond_dim=args['EMBED_DIM'])
    discriminator = NetD(ndf=args['NDF'])
    net_c = NetC(ndf=args['NDF'], cond_dim=args['EMBED_DIM'])

    print("--- Loading CLIP Text Encoder ---")
    text_encoder = ClipTextEncoder()

    # Optimizers: Adam with beta_1=0.0, beta_2=0.9 for both networks
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=args['LR_G'], beta_1=0.0, beta_2=0.9)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=args['LR_D'], beta_1=0.0, beta_2=0.9)

    # Checkpoints: G, D, NetC and both optimizers (the CLIP text encoder is not saved)
    checkpoint_dir = os.path.join(args['RUN_DIR'], 'checkpoints')
    checkpoint = tf.train.Checkpoint(
        generator=generator, discriminator=discriminator, net_c=net_c,
        g_optimizer=g_optimizer, d_optimizer=d_optimizer
    )
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    # TensorBoard Setup
    log_dir = os.path.join(args['RUN_DIR'], 'logs')
    summary_writer = tf.summary.create_file_writer(log_dir)

    # The server is a child process that is never terminated here, so every call
    # to train() starts another one. Set LAUNCH_TENSORBOARD=False to skip it.
    if args.get('LAUNCH_TENSORBOARD', True):
        try:
            tensorboard_process = subprocess.Popen(
                [sys.executable, "-m", "tensorboard.main", "--logdir", log_dir]
            )
            print(f"✓ TensorBoard launched (PID: {tensorboard_process.pid})")
        except Exception as e:
            # Best effort: training proceeds without a live dashboard
            print(f"⚠ Could not launch TensorBoard: {e}")

    # ==========================================================================
    # 2. DiffAugment Setup
    # ==========================================================================
    diff_augment_fn = None
    if args.get('USE_DIFFAUG', False):
        print(f"--- DiffAugment Enabled: {args['DIFFAUG_POLICY']} ---")
        def da_fn(imgs): return DiffAugment(imgs, policy=args['DIFFAUG_POLICY'])
        diff_augment_fn = da_fn
    else:
        print("--- DiffAugment Disabled ---")

    # ==========================================================================
    # 3. Fixed Sample for Visualization
    # ==========================================================================
    # Up to 16 captions from the first batch plus fixed noise, reused at every
    # sampling epoch so results are comparable. The noise is sized to the
    # captions actually available so the generator always gets matching batches.
    first_batch = next(iter(dataset.take(1)), None)
    if first_batch is None:
        raise ValueError("train(): `dataset` yielded no batches.")
    _, ids_vis, mask_vis = first_batch
    num_vis = min(16, int(ids_vis.shape[0]))
    fixed_ids = ids_vis[:num_vis]
    fixed_mask = mask_vis[:num_vis]
    fixed_noise = tf.random.normal([num_vis, args['Z_DIM']])

    # ==========================================================================
    # 4. Training Loop
    # ==========================================================================
    print(f"--- Starting Training for {args['N_EPOCH']} Epochs ---")
    # Only used as the progress-bar total (floor division: if the dataset keeps a
    # partial final batch, the bar runs one step past this value).
    steps_per_epoch = args['N_SAMPLE'] // args['BATCH_SIZE']
    global_step = 0

    for epoch in range(args['N_EPOCH']):
        start_time = time.time()

        # --- Learning Rate Decay (step schedule, floored at LR_MIN) ---
        if epoch >= args.get('LR_DECAY_START', 50) and epoch % args.get('LR_DECAY_EVERY', 10) == 0:
            decay = args.get('LR_DECAY_FACTOR', 0.95)
            min_lr = args.get('LR_MIN', 1e-6)

            new_lr_g = max(g_optimizer.learning_rate.numpy() * decay, min_lr)
            new_lr_d = max(d_optimizer.learning_rate.numpy() * decay, min_lr)

            g_optimizer.learning_rate.assign(new_lr_g)
            d_optimizer.learning_rate.assign(new_lr_d)
            print(f"📉 LR Decay: G={new_lr_g:.2e}, D={new_lr_d:.2e}")

        # --- Epoch Loop ---
        if epoch == 0: print("Note: First step takes longer due to XLA compilation...")

        pbar = tqdm(dataset, total=steps_per_epoch, desc=f"Epoch {epoch+1}")
        epoch_metrics = {'g_loss': 0.0, 'd_loss': 0.0}
        # Explicit counter instead of relying on the loop variable after the loop
        num_steps = 0

        for real_images, input_ids, attention_mask in pbar:
            losses = train_step(
                real_images, input_ids, attention_mask,
                generator, discriminator, net_c, text_encoder,
                g_optimizer, d_optimizer, args['BATCH_SIZE'], args['Z_DIM'],
                lambda_ma_gp=args.get('LAMBDA_GP', 2.0),
                diff_augment_fn=diff_augment_fn
            )

            # Accumulate for epoch stats
            epoch_metrics['g_loss'] += losses['g_loss']
            epoch_metrics['d_loss'] += losses['d_loss']

            # TensorBoard Logging (Step-wise)
            with summary_writer.as_default():
                tf.summary.scalar('Loss/D_Total', losses['d_loss'], step=global_step)
                tf.summary.scalar('Loss/G_Total', losses['g_loss'], step=global_step)
                tf.summary.scalar('Loss/MA_GP', losses['ma_gp'], step=global_step)
                tf.summary.scalar('Loss/D_Real', losses['errD_real'], step=global_step)
                tf.summary.scalar('Loss/D_Fake', losses['errD_fake'], step=global_step)
                tf.summary.scalar('Loss/D_Mis', losses['errD_mis'], step=global_step)

                if global_step % 100 == 0:
                    tf.summary.scalar('LR/Generator', g_optimizer.learning_rate, step=global_step)
                    tf.summary.scalar('LR/Discriminator', d_optimizer.learning_rate, step=global_step)

            if num_steps % 50 == 0:
                pbar.set_postfix({
                    'D': f"{losses['d_loss']:.3f}",
                    'G': f"{losses['g_loss']:.3f}",
                    'GP': f"{losses['ma_gp']:.3f}"
                })
            num_steps += 1
            global_step += 1

        # --- End of Epoch ---
        # Guard against an empty epoch (previously a NameError on the unbound `step`)
        denom = max(num_steps, 1)
        avg_g = epoch_metrics['g_loss'] / denom
        avg_d = epoch_metrics['d_loss'] / denom
        print(f"Time: {time.time()-start_time:.1f}s | Avg G: {avg_g:.4f} | Avg D: {avg_d:.4f}")

        # Log Epoch Averages
        with summary_writer.as_default():
            tf.summary.scalar('Epoch/G_Loss', avg_g, step=epoch)
            tf.summary.scalar('Epoch/D_Loss', avg_d, step=epoch)

        # Save Checkpoint
        if (epoch + 1) % args['SAVE_FREQ'] == 0:
            manager.save()

        # Save & Log Samples
        if (epoch + 1) % args['SAMPLE_FREQ'] == 0:
            # Generate samples for TensorBoard
            text_embeds = text_encoder(fixed_ids, fixed_mask, training=False)
            fake_imgs = generator([fixed_noise, text_embeds], training=False)

            # Save to disk
            save_sample_images(generator, text_encoder, fixed_ids, fixed_mask, fixed_noise,
                               epoch+1, os.path.join(args['RUN_DIR'], 'samples'))

            # Log to TensorBoard: convert [-1, 1] -> [0, 1] and clip
            with summary_writer.as_default():
                imgs_vis = tf.clip_by_value((fake_imgs + 1.0) * 0.5, 0.0, 1.0)
                tf.summary.image('Generated Samples', imgs_vis, step=epoch, max_outputs=16)

    # Make sure the final epoch's summaries reach disk
    summary_writer.flush()
    print(f"\n✓ Training Completed. Results in {args['RUN_DIR']}")

In [None]:
## Define configuration for training
import datetime
import json
import os

# Every run gets its own timestamped directory (YYYYmmdd-HHMMSS) under ./runs.
run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_dir = f"./runs/{run_id}"
os.makedirs(run_dir, exist_ok=True)

# Hyperparameters consumed by train() and inference().
config = {
    # --- shapes / model widths ---
    'IMAGE_SIZE': [64, 64, 3],
    'NGF': 64,               # NetG(ngf=...)
    'NDF': 64,               # NetD / NetC (ndf=...)
    'Z_DIM': 100,            # noise dimension fed to NetG
    'EMBED_DIM': 512,        # text-embedding dimension (cond_dim of NetG / NetC)
    # --- optimisation ---
    'LR_G': 0.0001,
    'LR_D': 0.0004,
    'BATCH_SIZE': BATCH_SIZE,
    'N_EPOCH': 600,          # long schedule for good results
    'LAMBDA_GP': 2.0,        # passed to train_step as lambda_ma_gp
    # --- bookkeeping ---
    'RUN_DIR': run_dir,
    'SAVE_FREQ': 25,         # checkpoint every 25 epochs to limit disk usage
    'SAMPLE_FREQ': 5,        # write sample images every 5 epochs
    # --- augmentation ---
    'USE_DIFFAUG': True,     # DiffAugment on: critical for Oxford-102
    'DIFFAUG_POLICY': 'translation',
    # Training-set size; 7370 is used when `num_training_sample` was not defined earlier
    'N_SAMPLE': num_training_sample if 'num_training_sample' in locals() else 7370,
}

# Persist the JSON-serializable part of the config next to the run outputs.
with open(os.path.join(run_dir, 'config.json'), 'w') as f:
    json_config = {
        key: value
        for key, value in config.items()
        if isinstance(value, (int, float, str, list, bool))
    }
    json.dump(json_config, f, indent=4)

print(f"Training Run Directory: {run_dir}")
print(f"Config: {json.dumps(json_config, indent=2)}")

Training Run Directory: ./runs/20251125-002901
Config: {
  "IMAGE_SIZE": [
    64,
    64,
    3
  ],
  "NGF": 32,
  "NDF": 64,
  "Z_DIM": 100,
  "EMBED_DIM": 512,
  "LR_G": 0.0001,
  "LR_D": 0.0004,
  "BATCH_SIZE": 128,
  "N_EPOCH": 100,
  "LAMBDA_GP": 2.0,
  "RUN_DIR": "./runs/20251125-002901",
  "SAVE_FREQ": 5,
  "SAMPLE_FREQ": 1,
  "USE_DIFFAUG": false,
  "DIFFAUG_POLICY": "translation",
  "N_SAMPLE": 7370
}


In [None]:
# Train on the tf.data.Dataset built earlier in the notebook, using the config above.
train(dataset=dataset, args=config)

--- Initializing Models (Image Size: [64, 64, 3]) ---
--- Loading CLIP Text Encoder ---


--- Initializing Models (Image Size: [64, 64, 3]) ---
--- Loading CLIP Text Encoder ---


All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.
All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


--- Initializing Models (Image Size: [64, 64, 3]) ---
--- Loading CLIP Text Encoder ---


All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.
All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


âœ“ TensorBoard launched (PID: 19448)
--- DiffAugment Disabled ---


--- Initializing Models (Image Size: [64, 64, 3]) ---
--- Loading CLIP Text Encoder ---


All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.
All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


âœ“ TensorBoard launched (PID: 19448)
--- DiffAugment Disabled ---


  import pkg_resources


--- Initializing Models (Image Size: [64, 64, 3]) ---
--- Loading CLIP Text Encoder ---


All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.
All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


âœ“ TensorBoard launched (PID: 19448)
--- DiffAugment Disabled ---


  import pkg_resources


--- Starting Training for 100 Epochs ---
Note: First step takes longer due to XLA compilation...


Epoch 1:   0%|          | 0/57 [00:00<?, ?it/s]
NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.2 at http://localhost:6007/ (Press CTRL+C to quit)
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.2 at http://localhost:6007/ (Press CTRL+C to quit)
Epoch 1:   4%|â–Ž         | 2/57 [00:53<23:43, 25.89s/it, D=2.919, G=0.502, GP=0.922]


<h2 id="Visualiztion">Visualization<a class="anchor-link" href="#Visualiztion">¶</a></h2>
<p>During training, we can visualize the generated images to evaluate the quality of the generator. The following functions help with visualization.</p>



In [None]:
def merge(images, size):
	"""
	Tile a batch of images into a single grid image, filling rows left to right.

	Args:
		images: Array indexed as [N, H, W, C], or [N, H, W] for single-channel images.
		size: (rows, cols) of the grid.

	Returns:
		float64 array of shape (rows*H, cols*W, C_out). C_out is 3 for 1- or
		3-channel input (single-channel images are broadcast to all three
		channels); any other channel count (e.g. RGBA) is kept. Grid cells
		without an image stay 0.

	Raises:
		ValueError: if `images` is not 3-D/4-D or has more images than grid cells.
	"""
	images = np.asarray(images)
	if images.ndim == 3:
		# [N, H, W] -> [N, H, W, 1] so it broadcasts to RGB below
		images = images[..., np.newaxis]
	if images.ndim != 4:
		raise ValueError(f"expected images of shape [N, H, W(, C)], got {images.shape}")

	n, h, w, c = images.shape
	rows, cols = size[0], size[1]
	if n > rows * cols:
		raise ValueError(f"cannot place {n} images in a {rows}x{cols} grid")

	out_channels = 3 if c in (1, 3) else c
	img = np.zeros((h * rows, w * cols, out_channels))
	for idx, image in enumerate(images):
		row, col = divmod(idx, cols)
		img[row*h:(row+1)*h, col*w:(col+1)*w, :] = image
	return img

def imsave(images, size, path):
	"""
	Save a batch of images as one grid image file.

	Args:
		images: Image batch, assumed to be in [-1, 1]; it is mapped to [0, 1] before saving.
		size: (rows, cols) of the grid, as accepted by `merge`.
		path: Output file path passed to `plt.imsave`.

	Returns:
		The return value of `plt.imsave`.
	"""
	grid = merge(images, size) * 0.5 + 0.5
	# plt.imsave rejects float RGB data outside [0, 1]; clip small overshoots
	return plt.imsave(path, np.clip(grid, 0.0, 1.0))

def save_images(images, size, image_path):
	"""Alias of `imsave`: write `images` as a size[0] x size[1] grid to `image_path`."""
	result = imsave(images, size, image_path)
	return result



<p>We always use the same random seed and the same sentences during training, which makes it easier to compare the quality of the generated images across epochs.</p>



In [None]:
# Create fixed sample data for visualization during training.
# IMPORTANT: the noise, input ids and attention mask must share the same batch size.

sample_size = BATCH_SIZE  # one sample per batch slot (BATCH_SIZE is 128 in the setup cell at the top)
ni = int(np.ceil(np.sqrt(sample_size)))  # side of the smallest square grid holding sample_size images

# Fixed random noise, one row per sample.
# NOTE(review): Z_DIM is read from `hparas`, which is not defined in the cells shown
# around here (the training config above is `config`); confirm it exists earlier.
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)

# 8 diverse sample captions, cycled below to fill the batch
base_sentences = [
	"the flower shown has yellow anther red pistil and bright red petals.",
	"this flower has petals that are yellow, white and purple and has dark lines",
	"the petals on this flower are white with a yellow center",
	"this flower has a lot of small round pink petals.",
	"this flower is orange in color, and has petals that are ruffled and rounded.",
	"the flower has yellow petals and the center of it is brown.",
	"this flower has petals that are blue and white.",
	"these white flowers have petals that start off white in color and end in a white towards the tips."
]

# Repeat the captions round-robin until there are sample_size of them
sample_sentences = [base_sentences[k % len(base_sentences)] for k in range(sample_size)]

# Tokenize with CLIP (padded/truncated to 77 tokens)
sample_encoded = preprocess_text_clip(sample_sentences, max_length=77)
sample_input_ids = sample_encoded['input_ids']
sample_attention_mask = sample_encoded['attention_mask']


<h2 id="Training">Training<a class="anchor-link" href="#Training">¶</a></h2>




<h2 id="Testing-Dataset">Testing Dataset<a class="anchor-link" href="#Testing-Dataset">¶</a></h2>
<p>If you change anything in the preprocessing of the training dataset, you must make sure the same operations are applied to the testing dataset.</p>



In [None]:
def testing_data_generator(caption_text, index):
	"""
	Map one (caption, ID) test pair to CLIP token tensors inside a tf.data pipeline.

	Args:
		caption_text: Scalar tf.string tensor holding the raw caption.
		index: Test sample ID, passed through unchanged.

	Returns:
		(input_ids, attention_mask, index). input_ids and attention_mask are
		int32 tensors of static shape [77] (padded/truncated to max_length=77).
	"""
	def tokenize_caption_clip(text):
		"""Tokenize one caption with the global CLIP `tokenizer` (runs eagerly via tf.py_function)."""
		# EagerTensor -> bytes -> str
		text = text.numpy().decode('utf-8')

		encoded = tokenizer(
			text,
			padding='max_length',
			truncation=True,
			max_length=77,
			return_tensors='np'
		)

		# Cast explicitly to the int32 declared in Tout below. The tokenizer's NumPy
		# output dtype is not guaranteed to be int32 (it follows NumPy's default
		# integer type), so don't rely on an implicit conversion.
		input_ids = encoded['input_ids'][0].astype(np.int32)
		attention_mask = encoded['attention_mask'][0].astype(np.int32)
		return input_ids, attention_mask

	# tf.py_function lets the Python tokenizer run inside the tf.data map
	input_ids, attention_mask = tf.py_function(
		func=tokenize_caption_clip,
		inp=[caption_text],
		Tout=[tf.int32, tf.int32]
	)

	# py_function loses static shape information; restore it for batching
	input_ids.set_shape([77])
	attention_mask.set_shape([77])

	return input_ids, attention_mask, index

def testing_dataset_generator(batch_size, data_generator):
	"""
	Build the batched, infinitely repeating test dataset from ./dataset/testData.pkl.

	Captions in the pickle are stored as word-ID sequences. They are decoded back
	to text through `id2word_dict` ('<PAD>' tokens dropped, words joined with single
	spaces) so that `data_generator` can re-tokenize them with CLIP.

	Args:
		batch_size: Batch size of the returned dataset.
		data_generator: Map function taking (caption_text, index).

	Returns:
		tf.data.Dataset of batched `data_generator` outputs. It repeats forever,
		so consumers must stop iterating on their own.
	"""
	test_df = pd.read_pickle('./dataset/testData.pkl')

	def decode_caption(word_ids):
		# IDs -> words, skipping padding
		words = (id2word_dict[str(word_id)] for word_id in word_ids)
		return ' '.join(word for word in words if word != '<PAD>')

	caption_texts = [decode_caption(word_ids) for word_ids in test_df['Captions'].values]
	sample_ids = np.asarray(test_df['ID'].values)

	return (
		tf.data.Dataset.from_tensor_slices((caption_texts, sample_ids))
		.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
		.repeat()
		.batch(batch_size)
	)

In [None]:
# Batched test dataset (repeats forever; the consumer decides when to stop).
testing_dataset = testing_dataset_generator(batch_size=BATCH_SIZE, data_generator=testing_data_generator)


In [None]:
# Size of the test set, used to bound the inference loop.
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values
NUM_TEST = len(captions)

# Number of *full* batches (floor division): a trailing partial batch is not counted.
EPOCH_TEST = NUM_TEST // BATCH_SIZE



<h2 id="Inferece">Inference<a class="anchor-link" href="#Inferece">¶</a></h2>



In [None]:
# Generated test images are written to <RUN_DIR>/inference.
inference_dir = os.path.join(config['RUN_DIR'], 'inference')
os.makedirs(inference_dir, exist_ok=True)
print(f"Inference Directory: {inference_dir}")

In [None]:
def inference(dataset, config, num_test=None, output_dir=None):
    """
    Generate one image per test caption using the latest checkpointed generator.

    Rebuilds NetG and the CLIP text encoder and restores the generator weights
    from <RUN_DIR>/checkpoints. If no checkpoint is found, it prints a warning
    and keeps the random weights. It then writes `inference_<ID>.jpg` for each
    test sample.

    Args:
        dataset: Batched dataset yielding (input_ids, attention_mask, idx). It may
            repeat forever; iteration stops once `num_test` images are written.
        config: Training config dict (reads NGF, Z_DIM, EMBED_DIM, RUN_DIR).
        num_test: Number of images to generate; defaults to the global NUM_TEST.
        output_dir: Directory for the images; defaults to the global inference_dir.
    """
    if num_test is None:
        num_test = NUM_TEST
    if output_dir is None:
        output_dir = inference_dir

    print("--- Starting Inference ---")

    # 1. Re-initialize models so the saved weights can be loaded into them
    print("Loading models...")
    generator = NetG(ngf=config['NGF'], nz=config['Z_DIM'], cond_dim=config['EMBED_DIM'])
    text_encoder = ClipTextEncoder()

    # 2. Load checkpoint. Only the generator is needed; expect_partial() silences
    # warnings about the discriminator/NetC/optimizer values also stored there.
    checkpoint_dir = os.path.join(config['RUN_DIR'], 'checkpoints')
    checkpoint = tf.train.Checkpoint(generator=generator)

    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        print(f"Loading weights from: {latest_ckpt}")
        status = checkpoint.restore(latest_ckpt).expect_partial()
        status.assert_existing_objects_matched()
        print("✓ Weights loaded successfully")
    else:
        print("⚠ NO CHECKPOINT FOUND! Generating with random weights (Garbage output).")

    # 3. Inference loop
    total_images = 0
    pbar = tqdm(total=num_test, desc='Generating images', unit='img')
    try:
        for input_ids, attention_mask, idx in dataset:
            remaining = num_test - total_images
            if remaining <= 0:
                break
            current_batch_size = input_ids.shape[0]

            noise = tf.random.normal([current_batch_size, config['Z_DIM']])
            text_embeds = text_encoder(input_ids, attention_mask, training=False)
            fake_imgs = generator([noise, text_embeds], training=False)

            # Rescale [-1, 1] -> [0, 1] and clip so plt.imsave accepts the floats
            imgs = np.clip((fake_imgs.numpy() + 1.0) * 0.5, 0.0, 1.0)

            # Save only what is still needed. The test dataset repeats, so the
            # last batch can wrap around to the first IDs; saving those again
            # used to overwrite earlier files and overshoot the progress bar.
            for i in range(min(current_batch_size, remaining)):
                img_idx = idx[i].numpy()
                img_path = os.path.join(output_dir, f'inference_{img_idx:04d}.jpg')
                plt.imsave(img_path, imgs[i])

                total_images += 1
                pbar.update(1)
    finally:
        pbar.close()

    print(f"✓ Generated {total_images} images to {output_dir}")

In [None]:
# Generate the test images with the latest checkpoint of this run.
inference(dataset=testing_dataset, config=config)

In [None]:
# Run evaluation script to generate score.csv
# Note: This must be run from the testing directory because inception_score.py uses relative paths
# Arguments: [inference_dir] [output_csv] [batch_size]
# Batch size must be 1, 2, 3, 7, 9, 21, or 39 to avoid remainder (819 test images)

# Save score.csv inside the run directory
print("running in ", inference_dir, "with", run_dir)
!cd testing && python inception_score.py ../{inference_dir}/ ../{run_dir}/score.csv 39

## Visualize Generated Images

Below we randomly sample 20 images from our generated test results to visually inspect the quality and diversity of the model's outputs.


<h1><center class="subtitle">Demo</center></h1>

<p>We demonstrate the capability of our model (TA80) to generate plausible images of flowers from detailed text descriptions.</p>



In [None]:
# Visualize 20 random generated images with their captions
import glob

# Load test data (captions are stored as word-ID sequences)
data = pd.read_pickle('./dataset/testData.pkl')
test_captions = data['Captions'].values
test_ids = data['ID'].values

# All generated images from the current inference directory
image_files = sorted(glob.glob(os.path.join(inference_dir, 'inference_*.jpg')))

if len(image_files) == 0:
	print(f'⚠ No images found in {inference_dir}')
	print('Please run the inference cell first!')
else:
	# Local seeded generator: reproducible selection without resetting NumPy's
	# global RNG, which np.random.seed(42) used to do as a side effect.
	rng = np.random.default_rng(42)
	num_samples = min(20, len(image_files))
	sample_indices = rng.choice(len(image_files), size=num_samples, replace=False)
	sample_files = [image_files[i] for i in sorted(sample_indices)]

	# 4x5 grid
	fig, axes = plt.subplots(4, 5, figsize=(20, 16))
	axes = axes.flatten()

	for ax, img_path in zip(axes, sample_files):
		# File names look like inference_<ID>.jpg
		img_id = int(Path(img_path).stem.split('_')[1])

		# First test row with this ID, then decode its caption (dropping padding)
		caption_idx = np.where(test_ids == img_id)[0][0]
		words = (id2word_dict[str(word_id)] for word_id in test_captions[caption_idx])
		caption_text = ' '.join(word for word in words if word != '<PAD>')

		ax.imshow(plt.imread(img_path))
		ax.set_title(f'ID: {img_id}\n{caption_text[:60]}...', fontsize=8)
		ax.axis('off')

	# Hide unused subplots when fewer than 20 images exist
	for ax in axes[num_samples:]:
		ax.axis('off')

	plt.tight_layout()
	plt.suptitle(f'Random Sample of {num_samples} Generated Images', fontsize=16, y=1.002)
	plt.show()

	print(f'\nTotal generated images: {len(image_files)}')
	print(f'Images directory: {inference_dir}')