
<h1><center id="title">DataLab Cup 3: Reverse Image Caption</center></h1>

<center id="author">Shan-Hung Wu &amp; DataLab<br/>Fall 2025</center>



In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
from tqdm import tqdm

import re
from IPython import display

In [None]:
# Runtime configuration for this notebook: choose the GPU, enable on-demand
# memory allocation, seed every RNG used below and set the global batch size.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
	try:
		# Restrict TensorFlow to only use the first GPU
		tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

		# Currently, memory growth needs to be the same across GPUs
		for gpu in gpus:
			tf.config.experimental.set_memory_growth(gpu, True)
		logical_gpus = tf.config.experimental.list_logical_devices('GPU')
		print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
	except RuntimeError as e:
		# Memory growth must be set before GPUs have been initialized;
		# the error is only reported, the notebook keeps running.
		print(e)

# Single seed shared by Python, NumPy and TensorFlow below.
RANDOM_SEED = 42

# Python random
# (`random` is already imported in the import cell; this re-import is a no-op.)
import random
random.seed(RANDOM_SEED)

# NumPy random
np.random.seed(RANDOM_SEED)

# TensorFlow random
tf.random.set_seed(RANDOM_SEED)

# Number of samples per batch passed to dataset_generator further down.
BATCH_SIZE = 128


<h2 id="Preprocess-Text">Preprocess Text<a class="anchor-link" href="#Preprocess-Text">¶</a></h2>
<p>Since dealing with raw string is inefficient, we have done some data preprocessing for you:</p>

<ul>
<li>Delete text over <code>MAX_SEQ_LENGTH (20)</code>.</li>
<li>Delete all punctuation in the texts.</li>
<li>Encode each vocabulary in <code>dictionary/vocab.npy</code>.</li>
<li>Represent texts by a sequence of integer IDs.</li>
<li>Replace rare words by <code>&lt;RARE&gt;</code> token to reduce vocabulary size for more efficient training.</li>
<li>Add padding as <code>&lt;PAD&gt;</code> to each text to make sure all of them have equal length to <code>MAX_SEQ_LENGTH (20)</code>.</li>
</ul>

<p>Note that it is not necessary to append <code>&lt;ST&gt;</code> and <code>&lt;ED&gt;</code> to each text because we don't need to generate any sequence in this task.</p>

<p>To verify that the original text was encoded correctly, we can decode a sequence of vocabulary IDs by looking them up in the vocabulary dictionaries:</p>

<ul>
<li><code>dictionary/word2Id.npy</code> is a numpy array mapping word to id.</li>
<li><code>dictionary/id2Word.npy</code> is a numpy array mapping id back to word.</li>
</ul>



In [None]:
# Load the pre-built vocabulary and the two lookup tables shipped with the
# competition data, then print a few entries as a sanity check.
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

# dict(...) turns each loaded array of pairs into a Python mapping.
# Note that id2word_dict is keyed by the *string* form of the id (see the
# '1' lookup below), so callers must convert integer ids with str() first.
word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))


In [None]:
# Status message only: captions are tokenized with the CLIP tokenizer (set up in
# the dataset cell below); no id-list helper is defined in this notebook.
print("✓ Using CLIP tokenizer (sent2IdList removed)")


<h2 id="Dataset">Dataset<a class="anchor-link" href="#Dataset">¶</a></h2>
<p>For training, the following files are in dataset folder:</p>

<ul>
<li><code>./dataset/text2ImgData.pkl</code> is a pandas dataframe with attribute 'Captions' and 'ImagePath'.<ul>
<li>'Captions': A list of 1 to 10 captions, each encoded as a list of token IDs.</li>
<li>'ImagePath': Image path that store paired image.</li>
</ul>
</li>
<li><code>./102flowers/</code> is the directory containing all training images.</li>
<li><code>./dataset/testData.pkl</code> is a pandas dataframe with attributes 'ID' and 'Captions', which contains the testing data.</li>
</ul>



In [None]:
# Read the training dataframe and report how many rows it holds.
# NOTE(review): pd.read_pickle executes pickle code; only use it on trusted files.
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
# Both names hold the same row count of the dataframe.
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))


In [None]:
# Show the first five rows of the training dataframe as the cell output.
df.head(5)



<h2 id="Create-Dataset-by-Dataset-API">Create Dataset by Dataset API<a class="anchor-link" href="#Create-Dataset-by-Dataset-API">¶</a></h2>



In [None]:
# ==============================================================================
# 1. DATASET GENERATOR (Adapted for CLIP)
# ==============================================================================

# Target size of every training image after resizing (height, width, channels).
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_CHANNEL = 3
# Fixed token length every caption is padded/truncated to by the tokenizer.
MAX_SEQ_LENGTH = 77 # CLIP default

# Initialize CLIP Tokenizer
# NOTE(review): any failure here (missing package, download error, ...) is only
# printed. `tokenizer` then stays undefined and training_data_generator will
# fail later with a NameError inside tf.py_function — consider re-raising.
try:
    from transformers import CLIPTokenizer
    # Use the same model name as the vision/text models we will load later
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    print("✓ CLIP Tokenizer loaded")
except Exception as e:
    print(f"⚠ Error loading CLIP Tokenizer: {e}")

def training_data_generator(caption_text, image_path):
    """
    Map one (caption, image path) pair to tensors ready for training.

    The image is read from disk, decoded as 3-channel, converted to float32,
    resized to IMAGE_HEIGHT x IMAGE_WIDTH and rescaled from [0, 1] to [-1, 1].
    The caption is tokenized with the module-level CLIP `tokenizer`, padded or
    truncated to MAX_SEQ_LENGTH tokens.

    Args:
        caption_text: Scalar string tensor holding the raw caption.
        image_path: Scalar string tensor holding the image file path.

    Returns:
        Tuple (img, input_ids, attention_mask):
            img            - float32, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL]
            input_ids      - int32, [MAX_SEQ_LENGTH]
            attention_mask - int32, [MAX_SEQ_LENGTH]
    """
    def _encode_caption(raw_caption):
        # Runs eagerly inside tf.py_function, so .numpy() is available here.
        caption = raw_caption.numpy().decode('utf-8')
        encoded = tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            return_tensors='tf',
        )
        # The tokenizer returns a batch of one; strip the batch axis.
        return encoded['input_ids'][0], encoded['attention_mask'][0]

    # ---- image branch ----
    pixels = tf.image.decode_image(tf.io.read_file(image_path), channels=3)
    pixels = tf.image.convert_image_dtype(pixels, tf.float32)  # values in [0, 1]
    # decode_image yields an unknown static shape; resize needs a known rank.
    pixels.set_shape([None, None, 3])
    pixels = tf.image.resize(pixels, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    # Shift to [-1, 1] so real images share the range of a tanh generator output.
    pixels = (pixels * 2.0) - 1.0
    pixels.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])

    # ---- text branch ----
    # The tokenizer is plain Python, hence the py_function wrapper.
    token_ids, token_mask = tf.py_function(
        func=_encode_caption,
        inp=[caption_text],
        Tout=[tf.int32, tf.int32],
    )
    # py_function drops static shape information; restore it for batching.
    for tensor in (token_ids, token_mask):
        tensor.set_shape([MAX_SEQ_LENGTH])

    return pixels, token_ids, token_mask

def dataset_generator(filenames, batch_size, data_generator, word2Id_dict, id2word_dict, expand_captions=True):
    """
    Build a shuffled, batched tf.data pipeline of (image, input_ids, attention_mask).

    The pickled dataframe stores captions as id lists; each one is turned back
    into a plain sentence so that `data_generator` can tokenize it itself.

    Args:
        filenames: Path of the pickled dataframe with 'Captions' and 'ImagePath'.
        batch_size: Batch size; an incomplete final batch is dropped.
        data_generator: Function mapped over (caption_text, image_path) pairs.
        word2Id_dict: Accepted for interface compatibility; not used here.
        id2word_dict: Mapping from the string form of an id to its word.
        expand_captions: True -> one sample per caption; False -> one randomly
            chosen caption per image, fixed at the time of this call.

    Returns:
        A tf.data.Dataset that is shuffled, mapped, batched and prefetched.
    """
    frame = pd.read_pickle(filenames)
    caption_groups = frame['Captions'].values
    path_column = frame['ImagePath'].values

    print(f"Loading dataset from {filenames}...")

    def _ids_to_sentence(token_ids):
        # Unknown ids map to '' and are skipped together with '<PAD>'.
        looked_up = (id2word_dict.get(str(token), '') for token in token_ids)
        return ' '.join(word for word in looked_up if word and word != '<PAD>')

    if expand_captions:
        print("Expanding captions (one sample per caption)...")
        pairs = [
            (_ids_to_sentence(one_caption), img_path)
            for group, img_path in zip(caption_groups, path_column)
            for one_caption in group
        ]
    else:
        # The random pick happens once, here; re-create the dataset to get a
        # different caption selection for another epoch.
        print("Selecting one random caption per image...")
        pairs = [
            (_ids_to_sentence(random.choice(group)), img_path)
            for group, img_path in zip(caption_groups, path_column)
        ]

    sentences = np.array([sentence for sentence, _ in pairs])
    image_files = np.array([img_path for _, img_path in pairs])

    print(f"Dataset size: {len(sentences)} samples")

    pipeline = tf.data.Dataset.from_tensor_slices((sentences, image_files))
    # Shuffle on the cheap string pairs, before any image is decoded.
    pipeline = pipeline.shuffle(len(sentences))
    pipeline = pipeline.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(batch_size, drop_remainder=True)
    return pipeline.prefetch(buffer_size=tf.data.AUTOTUNE)


In [None]:
# Create the training dataset.
# expand_captions=True yields one sample per caption (more training data per
# epoch); expand_captions=False would yield one sample per image, giving
# shorter epochs. The expanded variant is used here.
dataset = dataset_generator(
    data_path + '/text2ImgData.pkl', 
    BATCH_SIZE, 
    training_data_generator,
    word2Id_dict,
    id2word_dict,
    expand_captions=True 
)



<h2 id="Conditional-GAN-Model">Conditional GAN Model<a class="anchor-link" href="#Conditional-GAN-Model">¶</a></h2>
<p>As mentioned above, there are three models in this task, text encoder, generator and discriminator.</p>

<h2 id="Text-Encoder">Text Encoder<a class="anchor-link" href="#Text-Encoder">¶</a></h2>
<p>An RNN encoder that captures the meaning of the input text.</p>

<ul>
<li>Input: text, which is a list of ids.</li>
<li>Output: embedding, or hidden representation of input text.</li>
</ul>



In [None]:
# ==============================================================================
# 1. IMPORTS & SETUP
# ==============================================================================
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
from transformers import TFCLIPVisionModel, TFCLIPTextModel, CLIPProcessor, CLIPConfig

# Report library versions for reproducibility.
print("TensorFlow Version:", tf.__version__)
# NOTE(review): this cell already imports from `transformers` above, so a
# missing package would raise before this guard is reached.
try:
    import transformers
    print("Transformers Version:", transformers.__version__)
except ImportError:
    print("Transformers not installed. Please install it.")

# ==============================================================================
# PYTORCH-TENSORFLOW COMPATIBILITY CONSTANTS
# ==============================================================================
# These constants are meant to keep this TensorFlow port numerically close to
# the PyTorch implementation of GALIP.

# 1. Optimizer epsilon: PyTorch Adam default is 1e-8, TensorFlow default is 1e-7.
ADAM_EPSILON = 1e-8  # Match PyTorch default

# 2. LayerNorm epsilon: PyTorch default is 1e-5, Keras LayerNormalization
#    default is 1e-3. This matters for any custom LayerNorm layers.
#    (Not referenced by the code visible in this part of the notebook.)
LAYER_NORM_EPSILON = 1e-5  # Match PyTorch default

# 3. Weight initialization: the layers defined below pass
#    kernel_initializer='he_uniform' explicitly instead of the Keras default
#    (Glorot uniform).
#    NOTE(review): PyTorch's default Linear/Conv2d init is
#    kaiming_uniform_(a=sqrt(5)), i.e. a bound of 1/sqrt(fan_in), which is
#    narrower than Keras he_uniform (sqrt(6/fan_in)) — confirm whether exact
#    parity of the initial scale is required.

print(f"PyTorch-compatible settings:")
print(f"  ADAM_EPSILON = {ADAM_EPSILON}")
print(f"  LAYER_NORM_EPSILON = {LAYER_NORM_EPSILON}")
print(f"  Weight init: he_uniform (Kaiming Uniform)")

# ==============================================================================
# HELPER FUNCTION: Create PyTorch-compatible Adam optimizer
# ==============================================================================
def create_pytorch_compatible_adam(learning_rate, beta_1=0.0, beta_2=0.9):
    """
    Build a Keras Adam optimizer whose epsilon is the module-level ADAM_EPSILON.

    The beta defaults (0.0, 0.9) are the GAN-style settings chosen for this
    notebook rather than the Keras defaults.

    NOTE(review): the Keras documentation describes its `epsilon` as the
    "epsilon hat" variant of the Adam paper, so equal epsilon values may still
    not make the update rule bit-identical to PyTorch — confirm if that matters.

    Args:
        learning_rate: Step size forwarded to the optimizer.
        beta_1: Exponential decay rate of the first-moment estimate.
        beta_2: Exponential decay rate of the second-moment estimate.

    Returns:
        A tf.keras.optimizers.Adam instance.
    """
    adam_settings = dict(
        learning_rate=learning_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=ADAM_EPSILON,  # 1e-8 instead of the Keras default 1e-7
    )
    return tf.keras.optimizers.Adam(**adam_settings)


In [None]:
# ==============================================================================
# VERIFY CLIP CONFIG (LayerNorm epsilon, visual_projection bias)
# ==============================================================================
# Load CLIP and verify critical configuration values

def verify_clip_config(clip_model_name="openai/clip-vit-base-patch32"):
    """
    Print and return the CLIP settings this notebook depends on.

    Two things are checked:
      * the LayerNorm epsilon of the vision and text configs against 1e-5,
      * whether the visual/text projection layers of the loaded TF model
        carry a bias term (reported only, not part of the returned dict).

    The projection layers are looked up on the model's `clip` main layer when
    it exists and on the model itself otherwise, so a different attribute
    layout reports "unknown" (None) instead of raising AttributeError.

    Args:
        clip_model_name: Hugging Face model id or local path of the CLIP model.

    Returns:
        dict with keys 'vision_layer_norm_eps', 'text_layer_norm_eps',
        'projection_dim', 'hidden_size', 'vision_eps_ok', 'text_eps_ok'.
    """
    from transformers import TFCLIPModel, CLIPConfig

    config = CLIPConfig.from_pretrained(clip_model_name)

    print("=" * 60)
    print("CLIP Configuration Verification")
    print("=" * 60)

    # Check vision config
    vision_config = config.vision_config
    print(f"\nVision Config:")
    print(f"  layer_norm_eps: {vision_config.layer_norm_eps}")
    print(f"  hidden_size: {vision_config.hidden_size}")
    print(f"  projection_dim: {config.projection_dim}")

    # Check text config
    text_config = config.text_config
    print(f"\nText Config:")
    print(f"  layer_norm_eps: {text_config.layer_norm_eps}")
    print(f"  hidden_size: {text_config.hidden_size}")

    # Compare both LayerNorm epsilons with the expected value
    expected_eps = 1e-5
    vision_ok = abs(vision_config.layer_norm_eps - expected_eps) < 1e-10
    text_ok = abs(text_config.layer_norm_eps - expected_eps) < 1e-10

    # Only show a check mark when the comparison actually succeeded.
    print(f"\n{'✓' if vision_ok else '✗'} Vision LayerNorm eps matches PyTorch: {vision_ok}")
    print(f"{'✓' if text_ok else '✗'} Text LayerNorm eps matches PyTorch: {text_ok}")

    # Load model to check the projection layers
    print("\nLoading model to verify visual_projection...")
    model = TFCLIPModel.from_pretrained(clip_model_name)
    # The TF CLIP wrapper keeps its sub-layers on a `clip` main layer; fall
    # back to the model object itself if that attribute is absent.
    projection_owner = getattr(model, 'clip', model)

    def _projection_has_bias(attr_name):
        """Return True/False for the named projection layer, None if not found."""
        projection = getattr(projection_owner, attr_name, None)
        if projection is None:
            return None
        if hasattr(projection, 'use_bias'):
            return projection.use_bias
        return getattr(projection, 'bias', None) is not None

    has_visual_bias = _projection_has_bias('visual_projection')
    has_text_bias = _projection_has_bias('text_projection')

    print(f"\nProjection Layer Bias:")
    print(f"  visual_projection.use_bias: {has_visual_bias}")
    print(f"  text_projection.use_bias: {has_text_bias}")

    print("\n" + "=" * 60)

    return {
        'vision_layer_norm_eps': vision_config.layer_norm_eps,
        'text_layer_norm_eps': text_config.layer_norm_eps,
        'projection_dim': config.projection_dim,
        'hidden_size': vision_config.hidden_size,
        'vision_eps_ok': vision_ok,
        'text_eps_ok': text_ok
    }

# Run verification once and keep the result dict for later inspection.
# This loads the full CLIP model, so the cell may take a while to execute.
clip_verification = verify_clip_config()

In [None]:
# ==============================================================================
# 2. BASIC BLOCKS (DF-GAN & GALIP Components)
# ==============================================================================

class Affine(layers.Layer):
    """
    Conditional per-channel scale and shift (port of GALIP's Affine layer).

    Two independent two-layer MLPs (Dense -> ReLU -> Dense) turn the
    conditioning vector into a scale `gamma` and a shift `beta`, each with
    `num_features` entries, which are applied as ``gamma * x + beta``.

    Initialisation makes the layer start as the identity:
        gamma's last Dense: kernel zeros, bias ones  -> gamma == 1
        beta's last Dense:  kernel zeros, bias zeros -> beta  == 0

    NOTE(review): the first Dense layers use 'he_uniform'. PyTorch's default
    Linear init is kaiming_uniform_(a=sqrt(5)), whose bound is narrower —
    confirm if exact parity with the PyTorch code is required.

    Args:
        cond_dim: Size of the conditioning vector (stored; Dense infers its
            input size on first call).
        num_features: Number of channels of the feature map to modulate.
    """
    def __init__(self, cond_dim, num_features):
        super().__init__()
        self.cond_dim = cond_dim
        self.num_features = num_features

        # Scale branch: hidden Dense, then an output Dense that starts at 1.
        self.gamma_linear1 = layers.Dense(num_features, kernel_initializer='he_uniform')
        self.gamma_linear2 = layers.Dense(
            num_features,
            kernel_initializer='zeros',
            bias_initializer='ones',
        )

        # Shift branch: hidden Dense, then an output Dense that starts at 0.
        self.beta_linear1 = layers.Dense(num_features, kernel_initializer='he_uniform')
        self.beta_linear2 = layers.Dense(
            num_features,
            kernel_initializer='zeros',
            bias_initializer='zeros',
        )

    def call(self, x, y):
        """
        Modulate a feature map with a conditioning vector.

        Args:
            x: [B, H, W, C] feature map with C == num_features.
            y: [B, cond_dim] conditioning vector.

        Returns:
            Tensor of the same shape as `x`.
        """
        per_channel = [-1, 1, 1, self.num_features]  # broadcast over H and W

        scale = self.gamma_linear2(tf.nn.relu(self.gamma_linear1(y)))
        shift = self.beta_linear2(tf.nn.relu(self.beta_linear1(y)))

        scale = tf.reshape(scale, per_channel)
        shift = tf.reshape(shift, per_channel)
        return scale * x + shift


class DFBLK(layers.Layer):
    """
    Fusion block (port of GALIP's DFBLK): two conditional affine transforms,
    each followed by LeakyReLU(0.2). The block contains no convolutions.

    Args:
        cond_dim: Size of the conditioning vector.
        in_ch: Number of channels of the feature map.
    """
    def __init__(self, cond_dim, in_ch):
        super().__init__()
        self.affine0 = Affine(cond_dim, in_ch)
        self.affine1 = Affine(cond_dim, in_ch)

    def call(self, x, y):
        """
        Args:
            x: [B, H, W, C] feature map.
            y: [B, cond_dim] conditioning vector.

        Returns:
            [B, H, W, C] modulated feature map.
        """
        out = x
        # affine0 -> LeakyReLU -> affine1 -> LeakyReLU
        for modulation in (self.affine0, self.affine1):
            out = tf.nn.leaky_relu(modulation(out, y), alpha=0.2)
        return out



class G_Block(layers.Layer):
    """
    Generator up-block (port of GALIP's G_Block).

    Steps performed in call():
        1. nearest-neighbour resize of the input to target_size x target_size
        2. residual path: fuse1 (DFBLK) -> c1 (3x3 conv) -> fuse2 (DFBLK) -> c2 (3x3 conv)
        3. shortcut path: 1x1 conv when in_ch != out_ch, identity otherwise
        4. sum of shortcut and residual

    Unlike the PyTorch signature quoted by the original author
    (cond_dim, in_ch, out_ch, imsize), the image size is not fixed at
    construction; it is passed to call() as `target_size`.

    Args:
        cond_dim: Size of the conditioning vector.
        in_ch: Channels of the incoming feature map.
        out_ch: Channels of the produced feature map.
    """
    def __init__(self, cond_dim, in_ch, out_ch):
        super().__init__()
        # A learnable shortcut is only needed when the channel count changes.
        self.learnable_sc = in_ch != out_ch

        # Both residual convolutions are 3x3, stride 1, 'same' padding.
        self.c1 = layers.Conv2D(out_ch, 3, strides=1, padding='same', kernel_initializer='he_uniform')
        self.c2 = layers.Conv2D(out_ch, 3, strides=1, padding='same', kernel_initializer='he_uniform')

        # Conditioning is fused in before each convolution.
        self.fuse1 = DFBLK(cond_dim, in_ch)
        self.fuse2 = DFBLK(cond_dim, out_ch)

        if self.learnable_sc:
            self.c_sc = layers.Conv2D(out_ch, 1, strides=1, padding='valid', kernel_initializer='he_uniform')

    def call(self, h, y, target_size):
        """
        Args:
            h: [B, H, W, in_ch] input feature map.
            y: [B, cond_dim] conditioning vector.
            target_size: int, spatial size (height == width) after resizing.

        Returns:
            [B, target_size, target_size, out_ch] feature map.
        """
        upsampled = tf.image.resize(h, [target_size, target_size], method='nearest')

        # Residual path, evaluated before the shortcut.
        residual = self.c1(self.fuse1(upsampled, y))
        residual = self.c2(self.fuse2(residual, y))

        shortcut = self.c_sc(upsampled) if self.learnable_sc else upsampled
        return shortcut + residual


class D_Block(layers.Layer):
    """
    Discriminator residual block (port of GALIP's D_Block).

    Residual path: 3x3 conv (no bias) -> LeakyReLU(0.2) -> 3x3 conv (no bias)
    -> LeakyReLU(0.2). The shortcut is a 1x1 conv when fin != fout. Two
    learnable scalars, both starting at 0, weight the optional terms:
        gamma - scales the residual path  (created when is_res is True)
        beta  - scales the CLIP features  (created when clip_feat is True)

    Kernel size 3, stride 1 and 'same' padding are hard-coded.

    Args:
        fin: Channels of the input feature map.
        fout: Channels of the output feature map.
        is_down: Accepted but not used by this implementation.
        is_res: Whether to add the gamma-weighted residual path.
        clip_feat: Whether to add beta-weighted CLIP features when supplied.
    """
    def __init__(self, fin, fout, is_down=False, is_res=True, clip_feat=False):
        super().__init__()
        self.is_res = is_res
        self.clip_feat = clip_feat
        self.learned_shortcut = (fin != fout)

        # Shared settings of the two residual convolutions.
        residual_conv_opts = dict(padding='same', use_bias=False, kernel_initializer='he_uniform')
        self.conv_r1 = layers.Conv2D(fout, 3, **residual_conv_opts)
        self.conv_r2 = layers.Conv2D(fout, 3, **residual_conv_opts)

        # The 1x1 shortcut conv is always created, but only applied when the
        # channel count changes (see call()).
        self.conv_s = layers.Conv2D(fout, 1, padding='valid', kernel_initializer='he_uniform')

        # Zero-initialised mixing weights: the block starts as a pure shortcut.
        if is_res:
            self.gamma = tf.Variable(0.0, trainable=True, name='gamma')
        if clip_feat:
            self.beta = tf.Variable(0.0, trainable=True, name='beta')

    def call(self, x, clip_f=None):
        """
        Args:
            x: Input feature map, channels-last.
            clip_f: Optional CLIP feature map added to the output; it must be
                broadcast-compatible with the output.

        Returns:
            shortcut(x) [+ gamma * residual(x)] [+ beta * clip_f]
        """
        residual = x
        for conv in (self.conv_r1, self.conv_r2):
            residual = tf.nn.leaky_relu(conv(residual), alpha=0.2)

        result = self.conv_s(x) if self.learned_shortcut else x

        if self.is_res:
            result = result + self.gamma * residual

        add_clip = self.clip_feat and clip_f is not None
        if add_clip:
            result = result + self.beta * clip_f

        return result

In [None]:
# ==============================================================================
# 3. CLIP ADAPTER (100% Faithful Replication)
# ==============================================================================
# This cell contains ONLY the CLIP_Adapter and its dependencies:
# - DFBLK (also used by G_Block, defined in Basic Blocks cell)
# - M_Block (for CLIP_Adapter)
# - CLIP_Mapper (for CLIP_Adapter)
# - CLIP_Adapter
# Matching PyTorch GALIP exactly.

class M_Block(layers.Layer):
    """
    Mapping block used by the CLIP adapter (port of GALIP's M_Block).

    Residual path: conv1 -> fuse1 (DFBLK) -> conv2 -> fuse2 (DFBLK).
    Shortcut path: 1x1 conv when in_ch != out_ch, identity otherwise.
    Output: shortcut + residual.

    Args:
        in_ch: Channels of the input feature map.
        mid_ch: Channels between the two convolutions.
        out_ch: Channels of the output feature map.
        cond_dim: Size of the conditioning vector.
        k: Kernel size of both convolutions.
        s: Stride of both convolutions.
        p: Padding of the PyTorch original. It is not used here: 'same'
           padding is applied instead, which corresponds to p == k // 2 at
           stride 1 (e.g. k=3, s=1, p=1).
    """
    def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
        super().__init__()

        # Each convolution is followed by a conditioning fusion block.
        self.conv1 = layers.Conv2D(mid_ch, k, strides=s, padding='same', kernel_initializer='he_uniform')
        self.fuse1 = DFBLK(cond_dim, mid_ch)

        self.conv2 = layers.Conv2D(out_ch, k, strides=s, padding='same', kernel_initializer='he_uniform')
        self.fuse2 = DFBLK(cond_dim, out_ch)

        # A projection on the shortcut is only needed when channels change.
        self.learnable_sc = in_ch != out_ch
        if self.learnable_sc:
            self.c_sc = layers.Conv2D(out_ch, 1, strides=1, padding='valid', kernel_initializer='he_uniform')

    def call(self, h, c):
        """
        Args:
            h: [B, H, W, in_ch] input feature map.
            c: [B, cond_dim] conditioning vector.

        Returns:
            [B, H, W, out_ch] feature map (for stride 1).
        """
        # Residual path, evaluated before the shortcut.
        residual = self.fuse1(self.conv1(h), c)
        residual = self.fuse2(self.conv2(residual), c)

        shortcut = self.c_sc(h) if self.learnable_sc else h
        return shortcut + residual


class CLIP_Mapper(layers.Layer):
    """
    Run spatial features through a frozen CLIP vision transformer with prompt
    injection (port of GALIP's CLIP_Mapper).

    Forward flow:
        1. flatten [B, H, W, D] features to a token sequence [B, H*W, D]
        2. prepend the CLIP class embedding
        3. add the leading rows of CLIP's position embedding
        4. apply CLIP's pre-LayerNorm (attribute `pre_layrnorm`)
        5. run every encoder layer; for the layers listed in PROMPT_LAYERS one
           prompt token is appended before the layer and removed afterwards
        6. drop the class token and reshape back to [B, H, W, D]

    CLIP's post-LayerNorm is deliberately not applied.

    D is read from the class embedding, so the layer is not tied to a width
    of 768. H*W + 1 must not exceed the number of position embeddings of the
    CLIP model.

    Args:
        clip_vision_model: Object exposing a `.vision_model` attribute (e.g. a
            Hugging Face TFCLIPVisionModel). Its weights are frozen here.
    """
    # Indices of the encoder layers that receive an extra prompt token; the
    # k-th entry uses prompts[:, k, :].
    PROMPT_LAYERS = (1, 2, 3, 4, 5, 6, 7, 8)

    def __init__(self, clip_vision_model):
        super(CLIP_Mapper, self).__init__()
        self.vision_model = clip_vision_model.vision_model
        # Freeze all CLIP parameters
        self.vision_model.trainable = False

    def _position_table(self):
        """Return the [num_positions, D] position-embedding matrix.

        Depending on the transformers version the attribute is either a Keras
        Embedding layer (matrix is its first weight) or the weight variable
        itself; both forms are accepted.
        """
        position_embedding = self.vision_model.embeddings.position_embedding
        layer_weights = getattr(position_embedding, 'weights', None)
        if layer_weights:
            return layer_weights[0]
        return position_embedding

    @staticmethod
    def _run_encoder_layer(encoder_layer, tokens):
        """Apply one CLIP encoder layer in inference mode and return its hidden states.

        The Hugging Face TF CLIP encoder layer declares `attention_mask` and
        `causal_attention_mask` as required arguments; the vision tower uses
        no masks, so None is passed explicitly for both.
        """
        layer_outputs = encoder_layer(
            tokens,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=False,
            training=False,  # keeps dropout disabled
        )
        return layer_outputs[0]

    def call(self, img_feats, prompts):
        """
        Args:
            img_feats: [B, H, W, D] features, channels-last, with D equal to
                the CLIP vision hidden size.
            prompts: [B, P, D] prompt tokens with P >= len(PROMPT_LAYERS).

        Returns:
            [B, H, W, D] CLIP-mapped features.
        """
        batch = tf.shape(img_feats)[0]
        height = tf.shape(img_feats)[1]
        width = tf.shape(img_feats)[2]

        embeddings = self.vision_model.embeddings
        class_embedding = embeddings.class_embedding  # [D]
        embed_dim = class_embedding.shape[-1]

        # Prompts must share the dtype of the token sequence they are joined to.
        prompts = tf.cast(prompts, img_feats.dtype)

        # [B, H, W, D] -> [B, H*W, D]
        tokens = tf.reshape(img_feats, [batch, height * width, embed_dim])

        # Prepend the class token: [B, H*W+1, D]
        cls_token = tf.cast(class_embedding, tokens.dtype)
        cls_token = tf.reshape(cls_token, [1, 1, embed_dim])
        cls_token = tf.tile(cls_token, [batch, 1, 1])
        tokens = tf.concat([cls_token, tokens], axis=1)

        # Add the position embedding for exactly the tokens present.
        pos_embed = tf.cast(self._position_table(), tokens.dtype)
        seq_len = tf.shape(tokens)[1]
        tokens = tokens + pos_embed[:seq_len, :]

        # Pre-LayerNorm ("pre_layrnorm" is the attribute's actual spelling).
        tokens = self.vision_model.pre_layrnorm(tokens)

        prompt_slot = {layer_idx: slot for slot, layer_idx in enumerate(self.PROMPT_LAYERS)}

        for layer_idx, encoder_layer in enumerate(self.vision_model.encoder.layers):
            slot = prompt_slot.get(layer_idx)
            if slot is None:
                tokens = self._run_encoder_layer(encoder_layer, tokens)
                continue

            # Append this layer's prompt as an extra last token: [B, L+1, D]
            prompt_token = tf.expand_dims(prompts[:, slot, :], 1)
            tokens = tf.concat([tokens, prompt_token], axis=1)
            tokens = self._run_encoder_layer(encoder_layer, tokens)
            # Remove the prompt token again so the sequence length is unchanged.
            tokens = tokens[:, :-1, :]

        # No post-LayerNorm: the raw transformer output is returned.
        # Drop the class token and restore the spatial layout.
        tokens = tokens[:, 1:, :]  # [B, H*W, D]
        return tf.reshape(tokens, [batch, height, width, embed_dim])


class CLIP_Adapter(layers.Layer):
    """
    Adapter that sends generator features through a prompt-conditioned CLIP ViT.

    Sub-layers (created in this order):
        f_blocks  : `map_num` M_Blocks; the first maps in_ch -> out_ch, the
                    remaining ones map out_ch -> out_ch. All receive cond_dim.
        conv_fuse : 5x5 / stride 1 / 'same' conv producing CLIP_ch channels.
        CLIP_ViT  : CLIP_Mapper built around `clip_model`.
        conv_out  : 5x5 / stride 1 / 'same' conv producing G_ch channels.
        fc_prompt : Dense layer producing CLIP_ch * 8 values, i.e. 8 prompts.

    All convolution / dense kernels use 'he_uniform' initialisation (chosen by
    the original author to mirror the PyTorch defaults of the reference code).

    Forward result: conv_out(fuse_feat + 0.1 * map_feat).
    """
    def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, clip_model):
        super().__init__()
        self.CLIP_ch = CLIP_ch

        # Input width of every M_Block: only the first one sees `in_ch`.
        block_in_chs = [in_ch] + [out_ch] * (map_num - 1)
        self.f_blocks = [
            M_Block(block_in, mid_ch, out_ch, cond_dim, k, s, p)
            for block_in in block_in_chs
        ]

        # Lift the feature map to the CLIP hidden width.
        self.conv_fuse = layers.Conv2D(
            CLIP_ch, 5, strides=1, padding='same', kernel_initializer='he_uniform'
        )

        # ViT wrapper that accepts the injected prompts.
        self.CLIP_ViT = CLIP_Mapper(clip_model)

        # Bring the fused result back to the generator channel width.
        self.conv_out = layers.Conv2D(
            G_ch, 5, strides=1, padding='same', kernel_initializer='he_uniform'
        )

        # Conditioning vector -> 8 flattened prompt vectors.
        self.fc_prompt = layers.Dense(CLIP_ch * 8, kernel_initializer='he_uniform')

    def call(self, out, c):
        """
        Args:
            out: input feature map (NetG passes [B, code_sz, code_sz, code_ch]).
            c: conditioning vector (NetG passes noise and text embedding
               concatenated on axis 1).

        Returns:
            Output of `conv_out`, i.e. a feature map with G_ch channels.
        """
        # Flat prompt vector -> [B, 8, CLIP_ch]
        flat_prompts = self.fc_prompt(c)
        prompts = tf.reshape(flat_prompts, [-1, 8, self.CLIP_ch])

        # Conditioned feature refinement through every M_Block in turn.
        feat = out
        for f_block in self.f_blocks:
            feat = f_block(feat, c)

        fuse_feat = self.conv_fuse(feat)
        map_feat = self.CLIP_ViT(fuse_feat, prompts)

        # The CLIP-mapped branch is added with a fixed 0.1 weight.
        return self.conv_out(fuse_feat + 0.1 * map_feat)

In [None]:
# ==============================================================================
# 4. MODELS (Generator, Discriminator, Encoders)
# ==============================================================================
class CLIP_Text_Encoder(layers.Layer):
    """
    Frozen wrapper around a CLIP text tower and its projection layer.

    The wrapped object must expose `.text_model` and `.text_projection`;
    both are marked non-trainable here.
    """
    def __init__(self, clip_model):
        """
        Args:
            clip_model: object providing `text_model` and `text_projection`
                attributes (the original note asks for a TFCLIPModel instance).
        """
        super().__init__()
        self.text_model = clip_model.text_model
        self.text_projection = clip_model.text_projection

        # Neither the transformer nor the projection is ever updated.
        for frozen_part in (self.text_model, self.text_projection):
            frozen_part.trainable = False

    def call(self, input_ids, attention_mask=None):
        """
        Args:
            input_ids: integer token ids, one row per caption
                (original notes: [B, 77]).
            attention_mask: optional mask forwarded unchanged to the text model.

        Returns:
            sent_emb: projection of the hidden state at the arg-max token id
                position of each row (original notes: the EOT token, [B, 512]).
            word_emb: `last_hidden_state` of the text model for every position.
        """
        encoded = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True
        )
        word_emb = encoded.last_hidden_state

        # Position of the largest token id in each row; argmax keeps the first hit.
        eot_pos = tf.cast(
            tf.argmax(tf.cast(input_ids, tf.int32), axis=-1), tf.int64
        )
        row_ids = tf.range(tf.shape(input_ids)[0], dtype=tf.int64)

        # (row, position) pairs select one hidden vector per caption.
        pooled_output = tf.gather_nd(word_emb, tf.stack([row_ids, eot_pos], axis=1))

        # The projection is a layer, so it is called rather than matmul'ed.
        sent_emb = self.text_projection(pooled_output)

        return sent_emb, word_emb

    @property
    def trainable_weights(self):
        # Nothing inside this wrapper is trainable.
        return []

    @property
    def non_trainable_weights(self):
        # Every wrapped weight is reported as frozen.
        return [*self.text_model.weights, *self.text_projection.weights]

class CLIP_Image_Encoder(layers.Layer):
    """
    Frozen CLIP vision tower that also exposes intermediate patch features.

    `call` returns a pair:
        local_features: patch tokens captured after encoder layers 1, 4 and 8
            (0-indexed), each reshaped to [B, 7, 7, 768] and stacked on
            axis 1 -> [B, 3, 7, 7, 768].
        global_emb: the final CLS token after `post_layernorm` followed by
            `visual_projection`.

    The fixed 7x7x768 reshape assumes 49 patch tokens of width 768 (original
    notes: clip-vit-base-patch32 on 224x224 inputs).
    """
    def __init__(self, clip_model):
        super().__init__()
        self.vision_model = clip_model.vision_model
        # The projection is held separately from the vision model, so it is
        # frozen explicitly as well.
        self.visual_projection = clip_model.visual_projection

        for frozen_part in (self.vision_model, self.visual_projection):
            frozen_part.trainable = False

    def transf_to_CLIP_input(self, inputs):
        """
        Convert images from the generator range to the CLIP input format.

        Args:
            inputs: channels-last image batch with values in [-1, 1].

        Returns:
            Batch resized to 224x224 (bicubic) and normalised per channel
            with the mean/std constants below.
        """
        # [-1, 1] -> [0, 1], then bicubic resize (tf.image.resize defaults to bilinear).
        unit_range = (inputs + 1.0) * 0.5
        resized = tf.image.resize(unit_range, [224, 224], method='bicubic')

        # Channel-wise normalisation statistics.
        mean = tf.constant([0.48145466, 0.4578275, 0.40821073], dtype=resized.dtype)
        std = tf.constant([0.26862954, 0.26130258, 0.27577711], dtype=resized.dtype)
        return (resized - mean) / std

    def call(self, img):
        """
        Args:
            img: channels-last image batch in [-1, 1].

        Returns:
            (local_features, global_emb) as described in the class docstring.
        """
        tokens = self.transf_to_CLIP_input(img)

        # Patch + CLS + position embeddings, followed by the pre-LayerNorm.
        tokens = self.vision_model.embeddings(tokens)
        tokens = self.vision_model.pre_layrnorm(tokens)

        tap_layers = (1, 4, 8)
        tapped = []
        for depth, block in enumerate(self.vision_model.encoder.layers):
            tokens = block(tokens, output_attentions=False)[0]
            if depth not in tap_layers:
                continue
            # Drop the CLS token and fold the patch tokens into a 7x7 map.
            patches = tokens[:, 1:, :]
            tapped.append(tf.reshape(patches, [tf.shape(patches)[0], 7, 7, 768]))

        # Post-LayerNorm is applied to the CLS token only, then projected.
        cls_token = self.vision_model.post_layernorm(tokens[:, 0, :])
        global_emb = self.visual_projection(cls_token)

        local_features = tf.stack(tapped, axis=1)
        return local_features, global_emb


class NetG(Model):
    """
    Generator: noise + condition -> image with values in (-1, 1).

    Stages:
        fc_code  : noise -> 7*7*64 code, reshaped to [B, 7, 7, 64].
        mapping  : CLIP_Adapter conditioned on concat(noise, c).
        g_blocks : four G_Blocks with channels ngf*8 -> ngf*8 -> ngf*4 ->
                   ngf*2 -> ngf, each called with its entry of
                   `target_sizes` = [8, 16, 32, 64].
        to_rgb   : LeakyReLU(0.2) + 3x3 conv to 3 channels, then tanh.
    """
    def __init__(self, ngf, nz, cond_dim, clip_model):
        super().__init__()
        self.ngf = ngf
        self.code_sz = 7
        self.code_ch = 64
        self.mid_ch = 32
        self.CLIP_ch = 768

        # Width of the conditioning vector seen by the adapter and every G_Block.
        full_cond_dim = cond_dim + nz

        self.fc_code = layers.Dense(
            self.code_sz * self.code_sz * self.code_ch,
            kernel_initializer='he_uniform'
        )

        self.mapping = CLIP_Adapter(
            in_ch=self.code_ch,
            mid_ch=self.mid_ch,
            out_ch=self.code_ch,
            G_ch=ngf * 8,
            CLIP_ch=self.CLIP_ch,
            cond_dim=full_cond_dim,
            k=3, s=1, p=1,
            map_num=4,
            clip_model=clip_model
        )

        # (input, output) channel multipliers of ngf for the four G_Blocks.
        ch_multipliers = [(8, 8), (8, 4), (4, 2), (2, 1)]
        self.g_blocks = [
            G_Block(full_cond_dim, ngf * mult_in, ngf * mult_out)
            for mult_in, mult_out in ch_multipliers
        ]
        self.target_sizes = [8, 16, 32, 64]

        self.to_rgb = tf.keras.Sequential([
            layers.LeakyReLU(0.2),
            layers.Conv2D(3, 3, padding='same', kernel_initializer='he_uniform'),
        ])

    def call(self, inputs, training=False):
        """
        Args:
            inputs: pair (noise, c); both are concatenated on axis 1 to form
                the conditioning vector.
            training: accepted for the Keras call convention; not forwarded.

        Returns:
            tanh of the `to_rgb` output.
        """
        noise, c = inputs
        cond = tf.concat([noise, c], axis=1)

        # Only the noise seeds the initial spatial code.
        code = tf.reshape(
            self.fc_code(noise),
            [-1, self.code_sz, self.code_sz, self.code_ch]
        )

        feat = self.mapping(code, cond)
        for g_block, size in zip(self.g_blocks, self.target_sizes):
            feat = g_block(feat, cond, size)

        return tf.nn.tanh(self.to_rgb(feat))


class NetD(Model):
    """
    Discriminator feature extractor over stacked CLIP local features.

    Input `h` is indexed as h[:, level]; level 0 seeds the running feature
    and levels 1..2 are merged in by the two `clip_feat=True` D_Blocks.
    The final D_Block (768 -> 512) yields the feature map consumed by NetC.
    """
    def __init__(self, ndf):
        super().__init__()
        # Two identical fusion blocks, one per additional feature level.
        self.d_blocks = [
            D_Block(768, 768, is_res=True, clip_feat=True) for _ in range(2)
        ]
        self.main = D_Block(768, 512, is_res=True, clip_feat=False)

    def call(self, h):
        """
        Args:
            h: stacked local features (original notes: [B, 3, 7, 7, 768]).

        Returns:
            Output of the final D_Block (original notes: [B, 7, 7, 512]).
        """
        fused = h[:, 0]
        for level, d_block in enumerate(self.d_blocks, start=1):
            fused = d_block(fused, h[:, level])
        return self.main(fused)


class NetC(Model):
    """
    Conditional head: scores a feature map jointly with a condition vector.

    The condition is broadcast over a 7x7 grid, concatenated on the channel
    axis and reduced by two bias-free 4x4 'valid' convolutions
    (7x7 -> 4x4 -> 1x1) with a LeakyReLU(0.2) in between.
    """

    def __init__(self, ndf, cond_dim):
        super().__init__()
        self.cond_dim = cond_dim
        # 'he_uniform' was chosen by the original author to mirror the
        # reference implementation's conv initialisation.
        reduce_to_4x4 = layers.Conv2D(
            128, 4, strides=1, padding='valid', use_bias=False,
            kernel_initializer='he_uniform'
        )
        reduce_to_1x1 = layers.Conv2D(
            1, 4, strides=1, padding='valid', use_bias=False,
            kernel_initializer='he_uniform'
        )
        self.joint_conv = tf.keras.Sequential([
            reduce_to_4x4,
            layers.LeakyReLU(0.2),
            reduce_to_1x1
        ])

    def call(self, out, cond):
        """
        Args:
            out: feature map with a 7x7 spatial grid (NetD output).
            cond: condition vectors of width `cond_dim`.

        Returns:
            Output of `joint_conv` on the channel-wise concatenation.
        """
        n_rows = tf.shape(out)[0]
        # [B, cond_dim] -> [B, 1, 1, cond_dim] -> [B, 7, 7, cond_dim]
        cond_map = tf.tile(
            tf.reshape(cond, [n_rows, 1, 1, self.cond_dim]), [1, 7, 7, 1]
        )
        return self.joint_conv(tf.concat([out, cond_map], axis=3))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import matplotlib.pyplot as plt
import numpy as np
import os

def DiffAugment(x, policy='translation'):
    """
    Apply the augmentations named in `policy` to the batch `x`.

    `policy` is searched for the substrings 'color', 'translation' and
    'cutout'; matching augmentations run in that fixed order. An empty or
    None policy returns `x` unchanged.
    """
    if not policy:
        return x

    if 'color' in policy:
        # Colour jitter is three successive steps.
        for color_op in (rand_brightness, rand_saturation, rand_contrast):
            x = color_op(x)
    if 'translation' in policy:
        x = rand_translation(x)
    if 'cutout' in policy:
        x = rand_cutout(x)
    return x

# --- Augmentation Primitives ---
def rand_brightness(x):
    """Add one uniform offset in [-0.5, 0.5) per sample, then clip to [-1, 1]."""
    per_sample = [tf.shape(x)[0], 1, 1, 1]
    offset = tf.random.uniform(per_sample, minval=-0.5, maxval=0.5)
    return tf.clip_by_value(x + offset, -1.0, 1.0)

def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a factor in [0, 2)."""
    per_sample = [tf.shape(x)[0], 1, 1, 1]
    factor = tf.random.uniform(per_sample, minval=0.0, maxval=2.0)
    # Mean over the channel axis acts as the grey reference.
    gray = tf.reduce_mean(x, axis=3, keepdims=True)
    return tf.clip_by_value((x - gray) * factor + gray, -1.0, 1.0)

def rand_contrast(x):
    """Scale each sample's deviation from its whole-image mean by a factor in [0.5, 1.5)."""
    per_sample = [tf.shape(x)[0], 1, 1, 1]
    factor = tf.random.uniform(per_sample, minval=0.5, maxval=1.5)
    # One mean per sample, taken over height, width and channels.
    image_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
    return tf.clip_by_value((x - image_mean) * factor + image_mean, -1.0, 1.0)

def rand_translation(x, ratio=0.125):
    """
    Randomly translate every image in the batch by whole pixels.

    The batch is reflect-padded by `shift = int(size * ratio)` pixels on each
    spatial side and a `size` x `size` window is then cropped at an
    independent random integer offset per sample.

    Args:
        x: channels-last image batch; images are assumed square (the size is
           read from axis 1 and used for both spatial axes).
        ratio: maximum shift as a fraction of the image size.

    Returns:
        Batch with the same spatial size as `x` (float32, as produced by
        `tf.image.crop_and_resize`).
    """
    batch_size = tf.shape(x)[0]
    img_size = tf.shape(x)[1]

    # Derive the shift from the real image size instead of assuming 64 px.
    # Prefer the static size so the padding stays a plain Python int.
    static_size = x.shape[1]
    if static_size is not None:
        shift = int(static_size * ratio)
    else:
        shift = tf.cast(tf.cast(img_size, tf.float32) * ratio, tf.int32)

    # Pad the image with reflection
    x_padded = tf.pad(x, [[0, 0], [shift, shift], [shift, shift], [0, 0]], mode='REFLECT')

    # One integer offset in [0, 2*shift] per sample and axis.
    max_offset = 2 * shift
    offsets_y = tf.random.uniform([batch_size], minval=0, maxval=max_offset + 1, dtype=tf.int32)
    offsets_x = tf.random.uniform([batch_size], minval=0, maxval=max_offset + 1, dtype=tf.int32)
    offsets_y = tf.cast(offsets_y, tf.float32)
    offsets_x = tf.cast(offsets_x, tf.float32)

    # crop_and_resize maps a normalized coordinate v to pixel v * (size - 1),
    # and samples crop_size points from the first to the last box coordinate
    # inclusive. Dividing by (padded_size - 1) and ending the box at
    # offset + img_size - 1 therefore samples exactly the integer pixels
    # offset .. offset + img_size - 1 (no rescaling, no interpolation blur).
    last_padded_idx = tf.cast(tf.maximum(img_size + 2 * shift - 1, 1), tf.float32)
    crop_extent = tf.cast(img_size - 1, tf.float32)

    y1 = offsets_y / last_padded_idx
    x1 = offsets_x / last_padded_idx
    y2 = (offsets_y + crop_extent) / last_padded_idx
    x2 = (offsets_x + crop_extent) / last_padded_idx

    boxes = tf.stack([y1, x1, y2, x2], axis=1)  # [B, 4]
    box_indices = tf.range(batch_size)          # box i is cut from image i

    x_translated = tf.image.crop_and_resize(
        x_padded,
        boxes,
        box_indices,
        crop_size=[img_size, img_size]
    )

    return x_translated

def rand_cutout(x, ratio=0.5):
    """
    Zero out one random square region per image.

    Args:
        x: channels-last image batch; images are assumed square (the size is
           read from axis 1 and used for both spatial axes).
        ratio: side of the square as a fraction of the image size; the side
           is rounded down to an even number of pixels.

    Returns:
        `x` multiplied by a per-sample keep mask (0 inside the square).
    """
    batch_size = tf.shape(x)[0]
    img_size = tf.shape(x)[1]

    # Size the cutout from the real image size instead of assuming 64 px.
    # Prefer the static size so the value stays a plain Python int.
    static_size = x.shape[1]
    if static_size is not None:
        cutout_size = int(static_size * ratio // 2) * 2
    else:
        half_side = tf.floor(tf.cast(img_size, tf.float32) * ratio / 2.0)
        cutout_size = tf.cast(half_side, tf.int32) * 2

    # Pixel coordinate grids, shape [1, H, W] so they broadcast over the batch.
    iy, ix = tf.meshgrid(tf.range(img_size), tf.range(img_size), indexing='ij')
    iy = tf.expand_dims(iy, 0)
    ix = tf.expand_dims(ix, 0)

    # Top-left corner per sample, chosen so the square fits inside the image.
    offset_x = tf.random.uniform([batch_size, 1, 1], minval=0, maxval=img_size + 1 - cutout_size, dtype=tf.int32)
    offset_y = tf.random.uniform([batch_size, 1, 1], minval=0, maxval=img_size + 1 - cutout_size, dtype=tf.int32)

    mask_x = tf.math.logical_and(ix >= offset_x, ix < offset_x + cutout_size)
    mask_y = tf.math.logical_and(iy >= offset_y, iy < offset_y + cutout_size)
    mask_box = tf.math.logical_and(mask_x, mask_y)

    # 1 outside the square, 0 inside; a trailing axis broadcasts over channels.
    mask_keep = tf.cast(tf.math.logical_not(mask_box), x.dtype)
    mask_keep = tf.expand_dims(mask_keep, -1)

    return x * mask_keep

def save_sample_images(generator, text_encoder, fixed_input_ids, fixed_noise, epoch, save_dir):
    """
    Generate images from fixed noise/text and save them as one PNG row.

    Up to the first 8 generated images are drawn side by side and written to
    `<save_dir>/epoch_<epoch:03d>.png`.

    Args:
        generator: called as generator([noise, text_embeds], training=False);
            its output is treated as images in [-1, 1].
        text_encoder: called as text_encoder(input_ids, training=False); the
            first element of its result is used as the text embedding.
        fixed_input_ids: token ids fed to the text encoder.
        fixed_noise: noise fed to the generator.
        epoch: integer used in the output file name.
        save_dir: output directory, created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Encode text
    text_embeds, _ = text_encoder(fixed_input_ids, training=False)

    # Generate
    fake_imgs = generator([fixed_noise, text_embeds], training=False)

    # Convert to [0, 1] for plotting
    fake_imgs = (fake_imgs + 1.0) * 0.5
    fake_imgs = tf.clip_by_value(fake_imgs, 0.0, 1.0).numpy()

    # Single row with a fixed number of slots; unused slots stay empty.
    n_cols = 8
    fig = plt.figure(figsize=(10, 2))
    try:
        for i in range(min(n_cols, len(fake_imgs))):
            plt.subplot(1, n_cols, i + 1)
            plt.imshow(fake_imgs[i])
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f'epoch_{epoch:03d}.png'))
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

@tf.function
def train_step(real_images, input_ids, attention_mask, 
               generator, discriminator, net_c, 
               text_encoder, image_encoder,
               g_optimizer, d_optimizer, 
               batch_size, z_dim, lambda_ma_gp=2.0, diff_augment_fn=None):
    """
    Run one discriminator update followed by one generator update.

    Data flow (matching the classes defined in this notebook):
        image_encoder(images) -> (local_features, global_emb); only the
            local features are used here.
        discriminator(local_features) -> a single feature map.
        net_c(feature_map, text_embeds) -> conditional logits.

    Discriminator loss:
        hinge(real, matching text)
        + 0.5 * (hinge(fake, matching text) + hinge(real, mismatched text))
        + lambda_ma_gp * mean((||d net_c / d feat|| - 1)^2) on features
          interpolated between real and fake.
    Generator loss: -mean(net_c(D(fake features), text)).

    Args:
        real_images, input_ids, attention_mask: one batch from the dataset.
        generator, discriminator, net_c: trainable models.
        text_encoder, image_encoder: frozen CLIP wrappers.
        g_optimizer, d_optimizer: optimizers for G and for D + net_c.
        batch_size: kept for backward compatibility; the actual batch size is
            read from `real_images` so a smaller final batch also works.
        z_dim: width of the noise vector.
        lambda_ma_gp: weight of the gradient penalty.
        diff_augment_fn: optional callable applied to real and fake images
            before CLIP encoding.

    Returns:
        dict with scalar tensors 'd_loss', 'g_loss', 'ma_gp', 'errD_real',
        'errD_fake', 'errD_mis'.
    """
    # 1. Encode Text
    text_embeds, _ = text_encoder(input_ids, attention_mask=attention_mask)

    # Size everything from the batch that actually arrived.
    n_samples = tf.shape(real_images)[0]

    # 2. Generate Fake Images (outside any tape: constants for the D step)
    noise = tf.random.normal([n_samples, z_dim])
    fake_images = generator([noise, text_embeds], training=True)

    # 3. Augment (Optional)
    if diff_augment_fn:
        real_images_aug = diff_augment_fn(real_images)
        fake_images_aug = diff_augment_fn(fake_images)
    else:
        real_images_aug = real_images
        fake_images_aug = fake_images

    # 4. Encode Images (CLIP). The encoder returns (local_features, global_emb);
    #    the discriminator consumes the stacked local features only.
    real_local_feats, _ = image_encoder(real_images_aug)
    fake_local_feats, _ = image_encoder(fake_images_aug)

    # 5. Train Discriminator
    with tf.GradientTape() as d_tape:
        # NetD returns a single feature map (no unconditional score).
        d_real_feat = discriminator(real_local_feats, training=True)
        d_fake_feat = discriminator(fake_local_feats, training=True)

        d_real_c = net_c(d_real_feat, text_embeds, training=True)
        d_fake_c = net_c(d_fake_feat, text_embeds, training=True)

        # Mismatched pairs: real features with the caption of the previous sample.
        text_embeds_mis = tf.roll(text_embeds, shift=1, axis=0)
        d_mis_c = net_c(d_real_feat, text_embeds_mis, training=True)

        # Hinge losses on the conditional logits
        errD_real = tf.reduce_mean(tf.nn.relu(1.0 - d_real_c))
        errD_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_fake_c))
        errD_mis = tf.reduce_mean(tf.nn.relu(1.0 + d_mis_c))

        # Gradient penalty on features interpolated between real and fake
        alpha = tf.random.uniform([n_samples, 1, 1, 1], 0.0, 1.0)
        interpolated_feat = alpha * d_real_feat + (1 - alpha) * d_fake_feat

        # Inner tape lives inside d_tape so the penalty itself is differentiable.
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated_feat)
            out = net_c(interpolated_feat, text_embeds, training=True)

        grads = gp_tape.gradient(out, [interpolated_feat])[0]
        grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-8)
        ma_gp = tf.reduce_mean(tf.square(grad_norm - 1.0)) * lambda_ma_gp

        d_loss = errD_real + (errD_fake + errD_mis) * 0.5 + ma_gp

    # Variables are collected after the forward pass so lazily-built layers exist.
    d_vars = discriminator.trainable_variables + net_c.trainable_variables
    d_grads = d_tape.gradient(d_loss, d_vars)
    d_optimizer.apply_gradients(zip(d_grads, d_vars))

    # 6. Train Generator (CLIP encoding is inside the tape so gradients reach G)
    with tf.GradientTape() as g_tape:
        fake_images = generator([noise, text_embeds], training=True)
        fake_images_aug = diff_augment_fn(fake_images) if diff_augment_fn else fake_images
        fake_local_feats, _ = image_encoder(fake_images_aug)

        d_fake_feat = discriminator(fake_local_feats, training=True)
        d_fake_c = net_c(d_fake_feat, text_embeds, training=True)

        g_loss = -tf.reduce_mean(d_fake_c)

    g_vars = generator.trainable_variables
    g_grads = g_tape.gradient(g_loss, g_vars)
    g_optimizer.apply_gradients(zip(g_grads, g_vars))

    return {
        'd_loss': d_loss,
        'g_loss': g_loss,
        'ma_gp': ma_gp,
        'errD_real': errD_real,
        'errD_fake': errD_fake,
        'errD_mis': errD_mis
    }


In [None]:
import subprocess
import sys
from tqdm import tqdm

def train(dataset, args):
    """
    Main training loop for GALIP with TensorBoard logging.

    Args:
        dataset: tf.data.Dataset whose elements unpack as
            (real_images, input_ids, attention_mask).
        args: dict read with the keys IMAGE_SIZE, NGF, NDF, Z_DIM, EMBED_DIM,
            LR_G, LR_D, BATCH_SIZE, MAX_EPOCH, LAMBDA_MA_GP, RUN_DIR,
            SAVE_FREQ, SAMPLE_FREQ and, optionally, USE_DIFFAUG (with
            DIFFAUG_POLICY when enabled).

    Side effects:
        Writes checkpoints, TensorBoard logs and sample images below
        args['RUN_DIR'], and launches a TensorBoard subprocess that is left
        running when this function returns.

    Returns:
        None. Returns early (after printing a message) if the CLIP models
        cannot be loaded or the dataset yields no batch.
    """

    # ==========================================================================
    # 1. Initialization & Logging Setup
    # ==========================================================================
    print(f"--- Initializing Models (Image Size: {args['IMAGE_SIZE']}) ---")

    # Load CLIP Models
    # NOTE(review): CLIP_Text_Encoder reads `.text_model` / `.text_projection`
    # and CLIP_Image_Encoder reads `.vision_model` / `.visual_projection` from
    # the object they receive - confirm these model classes expose them.
    print("--- Loading CLIP Models ---")
    try:
        clip_vision_model = TFCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        clip_text_model = TFCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        print("✓ CLIP Models loaded")
    except Exception as e:
        print(f"⚠ Error loading CLIP models: {e}")
        return

    # Initialize Encoders
    text_encoder = CLIP_Text_Encoder(clip_text_model)
    image_encoder = CLIP_Image_Encoder(clip_vision_model)

    # Initialize GAN Models
    generator = NetG(ngf=args['NGF'], nz=args['Z_DIM'], cond_dim=args['EMBED_DIM'], clip_model=clip_vision_model)
    discriminator = NetD(ndf=args['NDF'])
    net_c = NetC(ndf=args['NDF'], cond_dim=args['EMBED_DIM'])

    # Optimizers
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=args['LR_G'], beta_1=0.0, beta_2=0.9)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=args['LR_D'], beta_1=0.0, beta_2=0.9)

    # Checkpoints
    checkpoint_dir = os.path.join(args['RUN_DIR'], 'checkpoints')
    checkpoint = tf.train.Checkpoint(
        generator=generator, discriminator=discriminator, net_c=net_c,
        g_optimizer=g_optimizer, d_optimizer=d_optimizer
    )
    manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    # TensorBoard Setup
    log_dir = os.path.join(args['RUN_DIR'], 'logs')
    summary_writer = tf.summary.create_file_writer(log_dir)

    # Best effort: training continues even if TensorBoard cannot be started.
    try:
        tensorboard_process = subprocess.Popen(
            [sys.executable, "-m", "tensorboard.main", "--logdir", log_dir]
        )
        print(f"✓ TensorBoard launched (PID: {tensorboard_process.pid})")
    except Exception as e:
        print(f"⚠ Could not launch TensorBoard: {e}")

    # ==========================================================================
    # 2. DiffAugment Setup
    # ==========================================================================
    diff_augment_fn = None
    if args.get('USE_DIFFAUG', False):
        print(f"--- DiffAugment Enabled: {args['DIFFAUG_POLICY']} ---")
        def da_fn(imgs): return DiffAugment(imgs, policy=args['DIFFAUG_POLICY'])
        diff_augment_fn = da_fn
    else:
        print("--- DiffAugment Disabled ---")

    # ==========================================================================
    # 3. Training Loop
    # ==========================================================================

    # Fixed noise/text for visualization
    fixed_noise = tf.random.normal([8, args['Z_DIM']])

    # The dataset yields (img, input_ids, attention_mask); keep the first
    # 8 captions of the first batch as the fixed visualization prompts.
    fixed_input_ids = None
    for _, first_input_ids, _ in dataset.take(1):
        fixed_input_ids = first_input_ids[:8]
    if fixed_input_ids is None:
        print("⚠ Dataset yielded no batches; nothing to train on.")
        return

    # NOTE: the epoch counter is not stored in the checkpoint, so a restored
    # run still counts epochs from 0.
    start_epoch = 0
    if manager.latest_checkpoint:
        checkpoint.restore(manager.latest_checkpoint)
        print(f"Restored from {manager.latest_checkpoint}")

    print(f"Starting training for {args['MAX_EPOCH']} epochs...")

    # Running step counter for TensorBoard; avoids len(dataset), which fails
    # for datasets of unknown cardinality.
    global_step = 0

    for epoch in range(start_epoch, args['MAX_EPOCH']):
        start_time = time.time()

        # Progress bar
        pbar = tqdm(dataset, desc=f"Epoch {epoch+1}/{args['MAX_EPOCH']}")

        d_losses = []
        g_losses = []

        for real_images, input_ids, attention_mask in pbar:

            losses = train_step(
                real_images, 
                input_ids, 
                attention_mask, 
                generator, 
                discriminator, 
                net_c, 
                text_encoder, 
                image_encoder,
                g_optimizer, 
                d_optimizer, 
                args['BATCH_SIZE'], 
                args['Z_DIM'],
                lambda_ma_gp=args['LAMBDA_MA_GP'],
                diff_augment_fn=diff_augment_fn
            )

            # Pull the scalars to Python floats once per step.
            scalars = {name: float(value) for name, value in losses.items()}

            d_losses.append(scalars['d_loss'])
            g_losses.append(scalars['g_loss'])

            # Update pbar
            pbar.set_postfix({
                'D': f"{scalars['d_loss']:.4f}", 
                'G': f"{scalars['g_loss']:.4f}",
                'MA': f"{scalars['ma_gp']:.4f}"
            })

            # Log to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar('Loss/D', scalars['d_loss'], step=global_step)
                tf.summary.scalar('Loss/G', scalars['g_loss'], step=global_step)
                tf.summary.scalar('Loss/MA_GP', scalars['ma_gp'], step=global_step)
                tf.summary.scalar('Loss/D_Real', scalars['errD_real'], step=global_step)
                tf.summary.scalar('Loss/D_Fake', scalars['errD_fake'], step=global_step)
                tf.summary.scalar('Loss/D_Mis', scalars['errD_mis'], step=global_step)

            global_step += 1

        # End of Epoch
        avg_d_loss = np.mean(d_losses)
        avg_g_loss = np.mean(g_losses)
        print(f"Epoch {epoch+1} done. D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}, Time: {time.time()-start_time:.1f}s")

        # Save Checkpoint
        if (epoch + 1) % args['SAVE_FREQ'] == 0:
            save_path = manager.save()
            print(f"Saved checkpoint for epoch {epoch+1}: {save_path}")

        # Save Sample Images
        if (epoch + 1) % args['SAMPLE_FREQ'] == 0:
            save_sample_images(generator, text_encoder, fixed_input_ids, fixed_noise, epoch+1, os.path.join(args['RUN_DIR'], 'samples'))

    print("Training Complete.")


In [None]:
## Define configuration for training
import datetime
import json
import os

# Create a unique, timestamped run directory (one per execution of this cell)
run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_dir = f"./runs/{run_id}"
os.makedirs(run_dir, exist_ok=True)

# Training configuration consumed by `train(dataset, config)`
config = {
    'IMAGE_SIZE': [64, 64, 3],
    'NGF': 64,
    'NDF': 64,
    'Z_DIM': 100,
    'EMBED_DIM': 512,        # CLIP embedding dimension
    'LR_G': 0.0001,
    'LR_D': 0.0004,          # NOTE: 4x LR_G (not matched to it); lower towards LR_G if D overpowers G
    'BATCH_SIZE': BATCH_SIZE,
    'MAX_EPOCH': 600,
    'LAMBDA_MA_GP': 2.0,     # Weight of the gradient penalty in the D loss
    'RUN_DIR': run_dir,
    'SAVE_FREQ': 25,         # Checkpoint every 25 epochs to save disk space
    'SAMPLE_FREQ': 1,        # Sample every epoch
    'USE_DIFFAUG': False,    # DISABLED: To strictly match official DF-GAN and avoid MA-GP conflicts
    'DIFFAUG_POLICY': 'translation',
    # Falls back to 7370 when `num_training_sample` was not defined by an earlier cell
    'N_SAMPLE': num_training_sample if 'num_training_sample' in locals() else 7370
}

# Save config for reproducibility
with open(os.path.join(run_dir, 'config.json'), 'w') as f:
    # Keep only values of plain JSON-serializable types; anything else is dropped
    json_config = {k: v for k, v in config.items() if isinstance(v, (int, float, str, list, bool))}
    json.dump(json_config, f, indent=4)

print(f"Training Run Directory: {run_dir}")
print(f"Config: {json.dumps(json_config, indent=2)}")


In [None]:
# Start training. 'dataset' is the tf.data.Dataset object you created in the notebook;
# `train` unpacks each of its elements as (real_images, input_ids, attention_mask).
train(dataset, config)


<h2 id="Visualiztion">Visualization<a class="anchor-link" href="#Visualiztion">¶</a></h2>
<p>During training, we can visualize the generated images to evaluate the quality of the generator. The following functions help with visualization.</p>




<p>We always use the same random seed and the same sentences during training, which makes it easier to compare the quality of the generated images across epochs.</p>




<h2 id="Training">Training<a class="anchor-link" href="#Training">¶</a></h2>




<h2 id="Testing-Dataset">Testing Dataset<a class="anchor-link" href="#Testing-Dataset">¶</a></h2>
<p>If you change anything in the preprocessing of the training dataset, you must make sure the same operations are applied to the testing dataset.</p>



In [None]:
def testing_data_generator(caption_text, index):
    """
    Turn one (caption, id) test record into fixed-length token tensors.

    The caption is tokenized in Python by the notebook-level `tokenizer`
    (described in this notebook as the CLIP tokenizer), wrapped in
    tf.py_function so it can be used inside a tf.data pipeline.

    Args:
        caption_text: Scalar string tensor holding the raw caption
        index: Test sample ID, returned unchanged

    Returns:
        input_ids, attention_mask, index - the first two are int32 tensors
        whose static shape is set to [77]
    """
    token_len = 77

    def _tokenize_eagerly(text_tensor):
        """Decode the string tensor and run the tokenizer on it."""
        raw_text = text_tensor.numpy().decode('utf-8')
        enc = tokenizer(
            raw_text,
            padding='max_length',
            truncation=True,
            max_length=token_len,
            return_tensors='np',
        )
        # return_tensors='np' adds a leading batch axis; take the single row
        return enc['input_ids'][0], enc['attention_mask'][0]

    input_ids, attention_mask = tf.py_function(
        func=_tokenize_eagerly,
        inp=[caption_text],
        Tout=[tf.int32, tf.int32],
    )

    # py_function outputs carry no static shape, so restore it for batching
    for tensor in (input_ids, attention_mask):
        tensor.set_shape([token_len])

    return input_ids, attention_mask, index

def testing_dataset_generator(batch_size, data_generator):
    """
    Build the batched test dataset from './dataset/testData.pkl'.

    Every entry of the 'Captions' column (a sequence of vocabulary IDs) is
    decoded back into a space-separated string through id2word_dict, with
    '<PAD>' tokens dropped. The strings are paired with the 'ID' column,
    mapped through data_generator and batched.

    Args:
        batch_size: Number of records per batch
        data_generator: Function applied to each (caption_text, index) record

    Returns:
        A tf.data.Dataset. repeat() is applied before batch(), so the dataset
        never ends on its own; consumers must stop iterating themselves.
    """
    test_frame = pd.read_pickle('./dataset/testData.pkl')

    def _ids_to_text(id_sequence):
        """Join the words of one caption, leaving out padding."""
        decoded = (id2word_dict[str(word_id)] for word_id in id_sequence)
        return ' '.join(word for word in decoded if word != '<PAD>')

    caption_texts = [_ids_to_text(seq) for seq in test_frame['Captions'].values]
    sample_ids = np.asarray(test_frame['ID'].values)

    # Raw text + IDs -> per-record transform -> endless stream of full batches
    records = tf.data.Dataset.from_tensor_slices((caption_texts, sample_ids))
    mapped = records.map(data_generator, num_parallel_calls=tf.data.AUTOTUNE)
    return mapped.repeat().batch(batch_size)

In [None]:
def _raw_caption_generator(caption_text, index):
    """Pass a (caption string, sample ID) record through unchanged."""
    return caption_text, index


# inference() unpacks every batch as (caption strings, image IDs) and maps the
# strings to vocabulary IDs itself with word2Id_dict, so the dataset must carry
# raw text. testing_data_generator would instead emit three token tensors per
# record, which that two-way unpacking cannot consume.
testing_dataset = testing_dataset_generator(BATCH_SIZE, _raw_caption_generator)


In [None]:
# Re-read the test split to count its samples
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

# Number of test captions, and how many full BATCH_SIZE batches they fill
# (the remainder is dropped by the floor division)
NUM_TEST = captions.shape[0]
EPOCH_TEST = NUM_TEST // BATCH_SIZE



<h2 id="Inferece">Inference<a class="anchor-link" href="#Inferece">¶</a></h2>



In [None]:
# Generated test images are written to an 'inference' folder inside the run directory
inference_dir = os.path.join(config['RUN_DIR'], 'inference')
# exist_ok=True makes the cell safe to re-run without a separate existence check
os.makedirs(inference_dir, exist_ok=True)
print(f"Inference Directory: {inference_dir}")

In [None]:
def _caption_to_ids(caption):
    """Map one raw caption string to exactly MAX_SEQ_LENGTH vocabulary IDs."""
    strip_punct = str.maketrans('', '', string.punctuation)
    # int() accepts both int and str values. NOTE(review): id2word_dict is keyed
    # by str(word_id) in this notebook, so word2Id_dict values are presumably
    # strings as well - confirm.
    rare_id = int(word2Id_dict['<RARE>'])
    pad_id = int(word2Id_dict['<PAD>'])

    ids = []
    for token in caption.split():
        # Look up the token verbatim first: special tokens such as '<RARE>'
        # would otherwise lose their brackets to the punctuation stripping
        # and be looked up as the ordinary word 'rare'.
        if token in word2Id_dict:
            ids.append(int(word2Id_dict[token]))
            continue
        word = token.translate(strip_punct).lower()
        if not word:
            continue  # token consisted of punctuation only
        ids.append(int(word2Id_dict[word]) if word in word2Id_dict else rare_id)

    # Truncate, then right-pad to the fixed length
    ids = ids[:MAX_SEQ_LENGTH]
    return ids + [pad_id] * (MAX_SEQ_LENGTH - len(ids))


def _image_id_to_str(raw_id):
    """Render a sample ID (bytes from a string tensor, or an integer) for a file name."""
    if isinstance(raw_id, bytes):
        return raw_id.decode('utf-8')
    # NOTE(review): integer IDs are zero-padded to four digits on the assumption
    # that the scoring script expects that naming - confirm against inception_score.py.
    return f'{int(raw_id):04d}'


def inference(dataset, config):
    """
    Generate one image per test caption and save it under inference_dir.

    Builds a fresh generator and RNN text encoder, loads the pretrained text
    encoder weights (if the file exists) and the latest generator checkpoint of
    the run (if any), then encodes each caption, samples noise and writes the
    generated image as 'inference_<id>.jpg'.

    Relies on notebook-level names: NetG, RNN_Encoder, word2Id_dict,
    MAX_SEQ_LENGTH, NUM_TEST and inference_dir.

    Args:
        dataset: Iterable of batches whose first element is a string tensor of
            raw captions and whose last element holds the sample IDs. The
            dataset may repeat forever; iteration stops after NUM_TEST images.
        config: Dict providing 'NGF', 'Z_DIM', 'EMBED_DIM' and 'RUN_DIR'.

    Raises:
        TypeError: If the batches do not carry raw caption strings.
    """
    print("--- Starting Inference ---")

    # 1. Re-create the models so that saved weights can be loaded into them
    print("Loading models...")
    generator = NetG(ngf=config['NGF'], nz=config['Z_DIM'], cond_dim=config['EMBED_DIM'])

    # RNN text encoder; nhidden=256 and vocab_size=5429 are fixed to match training
    vocab_size = 5429
    text_encoder = RNN_Encoder(ntoken=vocab_size, ninput=300, nhidden=256, nlayers=1)

    # Pretrained text encoder weights (Keras 3 requires the .weights.h5 extension)
    damsm_weights_path = './damsm_checkpoints/text_encoder.weights.h5'

    if os.path.exists(damsm_weights_path):
        print(f"✓ Loading pretrained DAMSM weights from {damsm_weights_path}")
        # Call the encoder once so its variables exist before load_weights
        dummy_input = tf.zeros((1, MAX_SEQ_LENGTH), dtype=tf.int32)
        text_encoder(dummy_input)
        try:
            text_encoder.load_weights(damsm_weights_path)
            print("✓ Weights loaded successfully.")
        except Exception as e:
            # Deliberately best-effort: report and continue with the current weights
            print(f"⚠ Error loading weights: {e}")
    else:
        print("⚠ WARNING: No pretrained DAMSM weights found! Encoder is random.")

    # 2. Restore the generator from the latest checkpoint of this run.
    # Only the generator is tracked here, so expect_partial() silences warnings
    # about the other objects stored in the checkpoint.
    checkpoint_dir = os.path.join(config['RUN_DIR'], 'checkpoints')
    checkpoint = tf.train.Checkpoint(generator=generator)

    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        print(f"Loading weights from: {latest_ckpt}")
        status = checkpoint.restore(latest_ckpt).expect_partial()
        status.assert_existing_objects_matched()
        print("✓ Weights loaded successfully")
    else:
        print("⚠ NO CHECKPOINT FOUND! Generating with random weights (Garbage output).")

    # 3. Inference loop
    total_images = 0
    with tqdm(total=NUM_TEST, desc='Generating images', unit='img') as pbar:
        for batch in dataset:
            # The test dataset is built with repeat(), so it never ends on its
            # own: stop as soon as every test sample has been generated.
            remaining = NUM_TEST - total_images
            if remaining <= 0:
                break

            caption_texts, image_ids = batch[0], batch[-1]
            if caption_texts.dtype != tf.string:
                raise TypeError(
                    "inference() expects batches of (caption strings, image IDs); "
                    f"got a first element of dtype {caption_texts.dtype}. Build the "
                    "dataset with a generator that keeps the raw caption text."
                )

            # Trim the final batch so repeated samples are not generated again
            caption_texts = caption_texts[:remaining]
            image_ids = image_ids[:remaining]
            batch_size_curr = len(caption_texts)

            # Tokenize: raw caption strings -> fixed-length vocabulary IDs
            input_ids_list = [
                _caption_to_ids(cap.decode('utf-8')) for cap in caption_texts.numpy()
            ]
            input_ids = tf.convert_to_tensor(input_ids_list, dtype=tf.int32)

            # Encode text. Lengths count the tokens whose ID is not 0.
            # NOTE(review): padding uses word2Id_dict['<PAD>']; if that ID is not
            # 0, every length equals MAX_SEQ_LENGTH. Left unchanged so inference
            # stays consistent with training - confirm against the training code.
            cap_lens = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, 0), tf.int32), axis=1)
            _, text_embeddings = text_encoder(input_ids, cap_lens=cap_lens, training=False)

            # Generate images from noise conditioned on the text embeddings
            noise = tf.random.normal([batch_size_curr, config['Z_DIM']])
            fake_imgs = generator([noise, text_embeddings], training=False)

            # Rescale from [-1, 1] to [0, 1] for saving
            fake_imgs = (fake_imgs + 1.0) * 0.5
            fake_imgs = tf.clip_by_value(fake_imgs, 0.0, 1.0).numpy()

            # Save images
            for img, raw_id in zip(fake_imgs, image_ids.numpy()):
                img_id = _image_id_to_str(raw_id)
                save_path = os.path.join(inference_dir, f'inference_{img_id}.jpg')
                plt.imsave(save_path, img)
                total_images += 1
                pbar.update(1)

    print(f"Inference Complete. Saved {total_images} images to {inference_dir}")

In [None]:
# Run inference over the test dataset; generated images are written to inference_dir
inference(testing_dataset, config)

In [None]:
# Run evaluation script to generate score.csv
# Note: This must be run from the testing directory because inception_score.py uses relative paths
# Arguments: [inference_dir] [output_csv] [batch_size]
# Batch size must divide the 819 test images evenly to avoid a remainder:
# 1, 3, 7, 9, 13, 21, or 39 (819 is odd, so 2 is not a valid choice)

# Save score.csv inside the run directory
# NOTE(review): the '../' prefixes below assume inference_dir and run_dir are relative paths - confirm
print("running in ", inference_dir, "with", run_dir)
!cd testing && python inception_score.py ../{inference_dir}/ ../{run_dir}/score.csv 39

## Visualize Generated Images

Below we randomly sample 20 images from our generated test results to visually inspect the quality and diversity of the model's outputs.


<h1><center class="subtitle">Demo</center></h1>

<p>We demonstrate the capability of our model (TA80) to generate plausible images of flowers from detailed text descriptions.</p>



In [None]:
# Visualize up to 20 randomly chosen generated images together with their captions
import glob

# Load test data (captions are stored as sequences of vocabulary IDs)
data = pd.read_pickle('./dataset/testData.pkl')
test_captions = data['Captions'].values
test_ids = data['ID'].values

# Get all generated images from the current inference directory
image_files = sorted(glob.glob(os.path.join(inference_dir, 'inference_*.jpg')))

if len(image_files) == 0:
    print(f'⚠ No images found in {inference_dir}')
    print('Please run the inference cell first!')
else:
    # Index captions by integer ID once. This replaces a per-image array scan
    # and works whether the 'ID' column holds integers or numeric strings.
    caption_by_id = {int(sample_id): ids for sample_id, ids in zip(test_ids, test_captions)}

    # A private RandomState draws the same sample as np.random.seed(42) would,
    # without resetting the notebook-wide NumPy random state.
    rng = np.random.RandomState(42)
    num_samples = min(20, len(image_files))
    sample_indices = rng.choice(len(image_files), size=num_samples, replace=False)
    sample_files = [image_files[i] for i in sorted(sample_indices)]

    # Create 4x5 grid
    fig, axes = plt.subplots(4, 5, figsize=(20, 16))
    axes = axes.flatten()

    for ax, img_path in zip(axes, sample_files):
        # File names look like 'inference_<id>.jpg'
        img_id = int(Path(img_path).stem.split('_')[1])

        # Decode the caption; tolerate stale images whose ID is not in the test set
        caption_ids = caption_by_id.get(img_id)
        if caption_ids is None:
            caption_text = '(caption not found)'
        else:
            words = (id2word_dict[str(word_id)] for word_id in caption_ids)
            caption_text = ' '.join(word for word in words if word != '<PAD>')

        # Only add an ellipsis when the caption was actually shortened
        if len(caption_text) > 60:
            caption_text = caption_text[:60] + '...'

        # Load and display image
        ax.imshow(plt.imread(img_path))
        ax.set_title(f'ID: {img_id}\n{caption_text}', fontsize=8)
        ax.axis('off')

    # Hide unused subplots if fewer than 20 images were found
    for ax in axes[num_samples:]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f'Random Sample of {num_samples} Generated Images', fontsize=16, y=1.002)
    plt.show()

    print(f'\nTotal generated images: {len(image_files)}')
    print(f'Images directory: {inference_dir}')