**Import Libraries and Setup Environment**

In [51]:
import os
import time
import numpy as np
import tensorflow as tf
import json
import itertools
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from typing import Dict, List, Tuple, Optional, Union, Callable
from scipy import linalg
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
from dataclasses import dataclass
from tensorflow.keras import Model

In [52]:
print(f"TensorFlow version: {tf.__version__}")
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {len(gpus) > 0}")
if len(gpus) > 0:
    print(f"GPU Details: {gpus}")

TensorFlow version: 2.18.0
GPU Available: True
GPU Details: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [53]:
np.random.seed(42)
tf.random.set_seed(42)

In [54]:
@dataclass
class GANConfig:
    z_dim: int = 128
    base_size: int = 4
    learning_rate_g: float = 1e-4
    learning_rate_d: float = 5e-5
    beta1: float = 0.5
    beta2: float = 0.999
    generator_steps: int = 2
    discriminator_steps: int = 1
    use_feature_matching: bool = True
    use_historical_averaging: bool = True
    use_minibatch_discrimination: bool = True
    feature_matching_weight: float = 10.0
    historical_averaging_weight: float = 0.5
    label_smoothing: float = 0.2
    generator_target_prob: float = 0.9
    sample_freq: int = 1
    save_freq: int = 10
    checkpoint_dir: str = "./checkpoints"
    sample_dir: str = "./samples"
    log_dir: str = "./logs"
    batch_size: int = 64
    image_size: int = 32
    channels: int = 3
    use_wgan_gp: bool = True
    gp_lambda: float = 10.0
    use_spectral_norm: bool = True

    @classmethod
    def from_json(cls, json_path):
        with open(json_path, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

    def to_json(self, json_path):
        config_dict = {k: v for k, v in self.__dict__.items()}
        with open(json_path, 'w') as f:
            json.dump(config_dict, f, indent=2)

    def __str__(self):
        return '\n'.join(f"{k}: {v}" for k, v in self.__dict__.items())

**Implementing the Core TransGAN Architectures**

This technique addresses the issue that traditional batch normalization makes the output of a layer dependent on all other instances in the same batch. Virtual batch normalization normalizes samples against a fixed reference batch for more stable training.

In [55]:
class VirtualBatchNormalization(layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon
        self.reference_batch_set = False

    def build(self, input_shape):
        self.ndim = len(input_shape)
        shape = [1] * self.ndim
        shape[-1] = input_shape[-1]

        self.gamma = self.add_weight(
            shape=(input_shape[-1],),
            initializer=tf.random_normal_initializer(1.0, 0.02),
            name='gamma',
            trainable=True
        )
        self.beta = self.add_weight(
            shape=(input_shape[-1],),
            initializer=tf.zeros_initializer(),
            name='beta',
            trainable=False
        )

        self.ref_mean = self.add_weight(
            shape=shape,
            initializer=tf.zeros_initializer(),
            name='ref_mean',
            trainable=False
        )
        self.ref_var = self.add_weight(
            shape=shape,
            initializer=tf.ones_initializer(),
            name='ref_var',
            trainable=False
        )

        super(VirtualBatchNormalization, self).build(input_shape)

    def _get_axis(self):
        return list(range(self.ndim - 1))

    def set_reference_batch(self, x):
        axes = self._get_axis()
        mean = tf.reduce_mean(x, axis=axes, keepdims=True)
        var = tf.reduce_mean(tf.square(x - mean), axis=axes, keepdims=True)

        self.ref_mean.assign(mean)
        self.ref_var.assign(var)
        self.reference_batch_set = True

    def call(self, inputs, set_reference=False, **kwargs):
        if set_reference or not self.reference_batch_set:
            axes = self._get_axis()
            ref_mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
            ref_var = tf.reduce_mean(tf.square(inputs - ref_mean), axis=axes, keepdims=True)

            self.ref_mean.assign(ref_mean)
            self.ref_var.assign(ref_var)
            self.reference_batch_set = True

            batch_mean = ref_mean
            batch_var = ref_var
        else:
            axes = self._get_axis()
            batch_mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
            batch_var = tf.reduce_mean(tf.square(inputs - batch_mean), axis=axes, keepdims=True)

            batch_mean = 0.5 * (batch_mean + self.ref_mean)
            batch_var = 0.5 * (batch_var + self.ref_var)

        batch_var = tf.maximum(batch_var, self.epsilon)

        x_norm = (inputs - batch_mean) / tf.sqrt(batch_var)

        gamma_reshaped = self.gamma
        beta_reshaped = self.beta

        if self.ndim > 2:
            gamma_reshaped = tf.reshape(self.gamma, [1] * (self.ndim - 1) + [self.gamma.shape[0]])
            beta_reshaped = tf.reshape(self.beta, [1] * (self.ndim - 1) + [self.beta.shape[0]])

        return x_norm * gamma_reshaped + beta_reshaped

**EmbeddingProjection Layer**:

The EmbeddingProjection layer is designed to project random noise (latent vectors) into a structured feature space, transforming a 1D vector into a spatial feature map. This transformation is a crucial step in the generator of TransGAN, where convolutional layers are replaced with transformer-based architectures.

In [56]:
class EmbeddingProjection(tf.keras.Model):
    def __init__(self, in_features, out_features, base_size):
        super(EmbeddingProjection, self).__init__()
        self.in_features = in_features
        self.base_size = base_size
        self.out_features = out_features

        total_output_size = (self.base_size ** 2) * self.out_features

        self.fc = tf.keras.layers.Dense(
            total_output_size,
            input_shape=(in_features,),
            kernel_initializer=tf.keras.initializers.HeNormal(),
            bias_initializer='zeros'
        )

    def call(self, latents, training=False):
        flattened = self.fc(latents)
        batch_size = tf.shape(latents)[0]
        return tf.reshape(flattened, (batch_size, self.base_size ** 2, self.out_features))

def spectral_norm_wrapper(layer):
    if hasattr(layer, 'kernel'):
        layer.kernel = spectral_normalization(layer.kernel)
    return layer

def spectral_normalization(w, power_iterations=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.random.normal([1, w_shape[-1]])

    for _ in range(power_iterations):
        v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
        u = tf.math.l2_normalize(tf.matmul(v, w))

    sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
    w_norm = w / sigma
    w_norm = tf.reshape(w_norm, w_shape)

    return w_norm

**GridAttention Layer**

The GridAttention layer is a self-attention mechanism designed for TransGAN. It follows the multi-head self-attention (MHSA) mechanism, where the input to this layer is a set of patch embeddings, then transformed into query ($Q$), key ($K$), and value ($V$) tensors, and the output is an updated set of embeddings after applying attention which is computed using dot-product interactions. By applying attention within structured feature maps, this layer enables spatially-aware self-attention, improving information flow and enabling the generator to capture local and global dependencies efficiently. Compared with normal multi-head self-attention, the grid attention mechanism limits the calculations to local regions to reduce the computational cost.

In [57]:
class GridAttention(Model):
    def __init__(self, embed_dim, num_heads, window_size, qkv_bias=True, qk_scale=None, attn_drop=0.1, proj_drop=0.1, noise_enabled=True, use_spectral_norm=False):
        super(GridAttention, self).__init__()

        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5
        self.window_size = window_size
        self.noise_enabled = noise_enabled
        self.use_spectral_norm = use_spectral_norm
        self.qkv = layers.Dense(embed_dim * 3, use_bias=qkv_bias)
        self.attn_drop = layers.Dropout(attn_drop)
        self.proj = layers.Dense(embed_dim)
        self.proj_drop = layers.Dropout(proj_drop)

        if self.use_spectral_norm:
            self.qkv = spectral_norm_wrapper(self.qkv)
            self.proj = spectral_norm_wrapper(self.proj)

        if self.noise_enabled:
            self.noise_strength = self.add_weight(shape=(), initializer="zeros", trainable=True)

        if self.window_size:
            coords_h = tf.range(window_size)
            coords_w = tf.range(window_size)
            coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing='ij'))
            coords_flatten = tf.reshape(coords, (2, -1))
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
            relative_coords = tf.transpose(relative_coords, (1, 2, 0))
            relative_coords = tf.cast(relative_coords, dtype=tf.int32)
            relative_coords = relative_coords + (window_size - 1)
            relative_coords = relative_coords[:, :, 0] * (2 * window_size - 1) + relative_coords[:, :, 1]

            self.relative_position_index = tf.Variable(relative_coords, trainable=False, dtype=tf.int32)

            bias_table_shape = ((2 * window_size - 1) * (2 * window_size - 1), num_heads)
            self.relative_position_bias_table = self.add_weight(shape=bias_table_shape, initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), trainable=True)

    def call(self, patch_embeddings, training=False):
        batch_size, n_tokens, embedding_dim = tf.shape(patch_embeddings)[0], tf.shape(patch_embeddings)[1], tf.shape(patch_embeddings)[2]

        if self.noise_enabled and training:
            noise = tf.random.normal((batch_size, n_tokens, 1)) * self.noise_strength
            patch_embeddings = patch_embeddings + noise

        qkv = self.qkv(patch_embeddings)
        qkv = tf.reshape(qkv, (batch_size, n_tokens, 3, self.num_heads, embedding_dim // self.num_heads))
        qkv = tf.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = tf.matmul(q, k, transpose_b=True) * self.scale

        if self.window_size:
            try:
                relative_position_bias_index = tf.reshape(self.relative_position_index, [-1])
                relative_position_bias = tf.gather(self.relative_position_bias_table, relative_position_bias_index)
                relative_position_bias = tf.reshape(relative_position_bias, (self.window_size * self.window_size, self.window_size * self.window_size, -1))
                relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))
                attn = attn + tf.expand_dims(relative_position_bias, axis=0)
            except Exception as e:
                print(f"Warning: Error in window attention: {e}")

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        aggregated = tf.matmul(attn, v)
        aggregated = tf.transpose(aggregated, (0, 2, 1, 3))
        aggregated = tf.reshape(aggregated, (batch_size, n_tokens, embedding_dim))

        aggregated = self.proj(aggregated)
        aggregated = self.proj_drop(aggregated, training=training)

        return aggregated

The GridAttention Layer above works by:

1. Projects a sequence of patch embeddings (input)  into Query ($Q$), Key ($K$), and Value ($V$) tensors using a fully connected layer (Dense) to compute $Q$, $K$, and $V$ simultaneously.

2. Computes attention scores using scaled dot-product attention
   
   2.1 Performs matrix multiplication between $Q$ and $K$ to compute raw attention scores.

   2.2 Incorporates relative positional bias (if window_size is specified).

3. Retrieves learnable positional biases and adds them to the attention scores

4. Applies softmax normalization to obtain attention weights and uses attention weights to aggregate values across all patches.

5. Applies final projection and dropout

**EncoderBlock Layer**

The EncoderBlock is a transformer-based processing unit designed to refine patch embeddings using self-attention and feedforward layers. It operates as a basic building block for deep vision models, like GridAttention. By applying attention-based refinement and feedforward transformations, this layer helps encode spatial and contextual relationships, making it a crucial component for deep vision transformers.

In [58]:
class EncoderBlock(Model):
    def __init__(self, d_model, n_heads, d_feedforward, window_size, dropout_rate=0.1, activation=None, noise_enabled=True, use_spectral_norm=False):
        super(EncoderBlock, self).__init__()
        self.attention = GridAttention(d_model, n_heads, window_size,
                                      noise_enabled=noise_enabled,
                                      attn_drop=dropout_rate,
                                      proj_drop=dropout_rate,
                                      use_spectral_norm=use_spectral_norm)
        self.feedforward_dense1 = layers.Dense(d_feedforward)
        self.activation = activation if callable(activation) else tf.nn.gelu
        self.dropout1 = layers.Dropout(dropout_rate)
        self.feedforward_dense2 = layers.Dense(d_model)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

        if use_spectral_norm:
            self.feedforward_dense1 = spectral_norm_wrapper(self.feedforward_dense1)
            self.feedforward_dense2 = spectral_norm_wrapper(self.feedforward_dense2)

    def call(self, patch_embeddings, training=False):
        attn_out = self.attention(self.norm1(patch_embeddings), training=training)
        patch_embeddings = patch_embeddings + attn_out
        ff_out = self.feedforward_dense1(self.norm2(patch_embeddings))
        ff_out = self.activation(ff_out)
        ff_out = self.dropout1(ff_out, training=training)
        ff_out = self.feedforward_dense2(ff_out)

        patch_embeddings = patch_embeddings + ff_out
        return patch_embeddings

This layer is structured similarly to a Transformer encoder:

1. First, applies self-attention to model relationships between patches.

2. Second, passes the embeddings through a feedforward network (MLP) for additional transformation.

3. Third, incorporates Layer Normalization (LayerNorm) and residual connections, ensuring stable gradient flow and improved convergence.

4. Forth, Optional dropout and activation functions provide regularization and non-linearity, enhancing generalization.

The input is a set of patch embeddings, and the output is an updated set of embeddings that have been refined through attention and feedforward transformations.

**StageBlock Layer**

The StageBlock serves as a higher-level processing unit in a transformer-based architecture. It is composed of multiple EncoderBlock layers, stacked together to progressively refine patch embeddings. Each EncoderBlock in the StageBlock applies self-attention and feedforward transformations, helping the model capture both local and global dependencies. By stacking multiple encoder blocks, StageBlock allows deeper feature extraction and hierarchical representation learning.

The input is a set of patch embeddings, and the output is a more refined set of embeddings after passing through multiple EncoderBlock layers.

In [59]:
class StageBlock(Model):
    def __init__(self, depth, num_heads, d_embeddings, d_ratio, window_size, activation=None, noise_enabled=True, use_spectral_norm=False):
        super(StageBlock, self).__init__()
        self.blocks = [
            EncoderBlock(
                d_model=d_embeddings,
                n_heads=num_heads,
                d_feedforward=d_embeddings * d_ratio,
                window_size=window_size,
                activation=activation,
                noise_enabled=noise_enabled,
                use_spectral_norm=use_spectral_norm
            )
            for _ in range(depth)
        ]

    def call(self, patch_embeddings, training=False):
        for block in self.blocks:
            patch_embeddings = block(patch_embeddings, training=training)
        return patch_embeddings

StageBlock Lay works by:

1. Receives a sequence of patch embeddings (patch_embeddings) as input

2. Initializes multiple EncoderBlock layers
   
   2.1 Each block consists of: (1) Multi-head self-attention (via GridAttention). (2) Feedforward transformations (MLP). (3) Residual connections and normalization for stable training.

3. Sequentially processes embeddings through multiple encoder blocks

4. Outputs the final refined embeddings which retain hierarchical and contextual information

In [60]:
def pixel_shuffle(input_tensor, scale_factor):
    batch_size, height, width, channels = input_tensor.shape

    if channels % (scale_factor ** 2) != 0:
        raise ValueError("Channels must be divisible by scale_factor^2.")

    new_channels = channels // (scale_factor ** 2)
    reshaped = tf.reshape(input_tensor, (batch_size, height, width, scale_factor, scale_factor, new_channels))
    transposed = tf.transpose(reshaped, [0, 1, 3, 2, 4, 5])
    output = tf.reshape(transposed, (batch_size, height * scale_factor, width * scale_factor, new_channels))

    return output

**Resampling Layer**

The Resampling layer is designed to adjust the spatial resolution of feature maps in TransGAN. It can perform both upsampling and downsampling, depending on the scale_factor. By dynamically adjusting feature map resolution, this layer allows the generator to process multi-scale representations effectively, ensuring smooth feature transitions during upsampling and downsampling.

If scale_factor > 1, it upsamples the feature map using Pixel Shuffle, which redistributes channel data into spatial dimensions.

If scale_factor < 1, it downsamples the feature map using Average Pooling, reducing spatial resolution while preserving information.

This layer ensures that the generator maintains spatial consistency as features are refined and upsampled to higher resolutions.

In [61]:
class Resampling(layers.Layer):
    def __init__(self, scale_factor):
        super(Resampling, self).__init__()
        self.scale_factor = scale_factor
        self.is_upsampling = scale_factor > 1

        if not self.is_upsampling:
            self.resampling = layers.AveragePooling2D(pool_size=int(1 / scale_factor))

    def call(self, embeddings, size=None):
        batch_size = tf.shape(embeddings)[0]
        embedding_dim = embeddings.shape[-1]

        if isinstance(size, tf.Tensor):
            size = tf.get_static_value(size)

            if size is None:
                size = tf.cast(tf.sqrt(tf.cast(tf.shape(embeddings)[1], tf.float32)), tf.int32)

        if self.is_upsampling:
            tf.debugging.assert_equal(
                embedding_dim % (self.scale_factor ** 2),
                0,
                message=f"Embedding dim {embedding_dim} must be divisible by scale_factor^2 ({self.scale_factor ** 2})!"
            )

            reduced_embedding_dim = embedding_dim // (self.scale_factor ** 2)
        else:
            reduced_embedding_dim = embedding_dim * (self.scale_factor ** 2)

        feature_maps = tf.reshape(embeddings, (batch_size, size, size, embedding_dim))

        if self.is_upsampling:
            resampled = pixel_shuffle(feature_maps, self.scale_factor)
            new_size = size * self.scale_factor
        else:
            resampled = self.resampling(feature_maps)
            new_size = size // int(1 / self.scale_factor)

        n_tokens_new = new_size * new_size

        output = tf.reshape(resampled, (batch_size, n_tokens_new, -1))

        return output, new_size

Resampling Lay works by:

1. Receives a set of patch embeddings and a scale_factor as input and Determines whether to upsample or downsample
   
   1.1 If scale_factor > 1, $Pixel Shuffle$ function is used to increase resolution which can converts the channel dimension into spatial resolution. Then the output resolution increases from ($H$, $W$) to ($H$ * $scale\_ factor$, $W$ * $scale\_ factor$)
   
   1.2 If scale_factor < 1, $Average Pooling$ is used to reduce resolution which can reduce resolution from ($H$, $W$) to ($H$ / $scale_factor$, $W$ / $scale_factor$). Besides, layers.AveragePooling2D will be used to preserve feature integrity while reducing size

2. Reshapes the output to maintain token structure.
   
   2.1 The spatially adjusted feature map is reshaped back into a structured token sequence.

   2.2 The number of tokens updates based on the new spatial size.

**Generator**

The Generator model is designed as a transformer-based generator for image synthesis, following the structure of TransGAN. Unlike traditional CNN-based GANs, this generator leverages self-attention mechanisms instead of convolutional layers to progressively refine feature maps. By leveraging transformers, progressive upsampling, and spatial self-attention, this generator enables high-quality image synthesis while preserving long-range dependencies and global coherence.

The key components of this generator include:

1. Embedding Projection: Maps a random latent vector to an initial spatial feature map.

2. Virtual Batch Normalization (VBN): Stabilizes feature distributions across batches.

3. Positional Encoding: Provides spatial awareness to the self-attention mechanism.

3. Transformer Encoder Blocks: Captures global dependencies and refines feature representations.

4. Resampling Layer: Upsamples feature maps progressively to increase resolution.

5. Output Projection: Uses a Conv2D layer to generate the final image, applying a tanh activation to normalize pixel values.

This generator takes a random latent vector as input and outputs a synthetic image with values in the range [-1, 1].

In [62]:
class Generator(Model):
    def __init__(self, config):
        super(Generator, self).__init__()

        self.base_size = config.base_size
        self.embed_dim = 1024
        self.output_size = config.image_size
        self.n_colors = config.channels
        self.z_dim = config.z_dim
        self.use_spectral_norm = config.use_spectral_norm

        self.n_upsamples = int(np.log2(self.output_size // self.base_size))
        depths = [4, 3, 2]
        if len(depths) < self.n_upsamples:
            depths = list(depths) + [2] * (self.n_upsamples - len(depths))

        self.embedding_projection = EmbeddingProjection(
            self.z_dim,
            self.embed_dim,
            self.base_size
        )

        self.vbn = VirtualBatchNormalization()
        self.positional_embeddings = []
        for index in range(self.n_upsamples + 1):
            size = self.base_size * (2 ** index)
            dim = self.embed_dim // (4 ** index) if index > 0 else self.embed_dim
            pos_embed = self.add_weight(
                shape=(1, size ** 2, dim),
                initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
                trainable=True,
                name=f'pos_embedding_{index}'
            )
            self.positional_embeddings.append(pos_embed)

        self.upsampling = Resampling(scale_factor=2)
        self.stages = []
        for index in range(self.n_upsamples + 1):
            depth = depths[index] if index < len(depths) else 2
            dim = self.embed_dim // (4 ** index) if index > 0 else self.embed_dim

            stage = StageBlock(
                depth=depth,
                num_heads=4,
                d_embeddings=dim,
                d_ratio=4,
                window_size=self.base_size * (2 ** index),
                noise_enabled=True,
                use_spectral_norm=self.use_spectral_norm
            )
            self.stages.append(stage)

        self.pre_output = layers.Conv2D(
            filters=self.n_colors * 4,
            kernel_size=3,
            padding="same",
            activation="relu",
            data_format="channels_last"
        )

        self.output_projection = layers.Conv2D(
            self.n_colors,
            kernel_size=1,
            activation="tanh",
            data_format="channels_last"
        )

        if self.use_spectral_norm:
            self.pre_output = spectral_norm_wrapper(self.pre_output)
            self.output_projection = spectral_norm_wrapper(self.output_projection)

        self.reference_batch = None
        self.leaky_relu = layers.LeakyReLU(0.2)

    def call(self, latents, training=False, set_reference=False):
        patch_embeddings = self.embedding_projection(latents)
        size = self.base_size

        if set_reference or self.reference_batch is None:
            self.reference_batch = patch_embeddings

        patch_embeddings = self.vbn(patch_embeddings, set_reference=set_reference)

        for index, stage in enumerate(self.stages):
            if index < len(self.positional_embeddings):
                patch_embeddings += self.positional_embeddings[index]

            patch_embeddings = stage(patch_embeddings, training=training)

            if index < self.n_upsamples:
                patch_embeddings, size = self.upsampling(patch_embeddings, size=size)

        batch_size = tf.shape(patch_embeddings)[0]
        feature_maps = tf.reshape(patch_embeddings, (batch_size, size, size, -1))
        feature_maps = self.pre_output(feature_maps)
        feature_maps = self.leaky_relu(feature_maps)

        output = self.output_projection(feature_maps)
        output = tf.clip_by_value(output, -1.0, 1.0)

        return output

In [63]:
class MinibatchDiscrimination(layers.Layer):
    def __init__(self, num_kernels=100, dim_per_kernel=5):
        super(MinibatchDiscrimination, self).__init__()
        self.num_kernels = num_kernels
        self.dim_per_kernel = dim_per_kernel

    def build(self, input_shape):
        self.W = self.add_weight(
            shape=(input_shape[-1], self.num_kernels * self.dim_per_kernel),
            initializer="random_normal",
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.W)
        x = tf.reshape(x, (-1, self.num_kernels, self.dim_per_kernel))

        diff = tf.expand_dims(x, axis=3) - tf.expand_dims(tf.transpose(x, [1, 2, 0]), axis=0)
        abs_diff = tf.reduce_sum(tf.abs(diff), axis=2)
        minibatch_features = tf.reduce_sum(tf.exp(-abs_diff), axis=2)

        return tf.concat([inputs, minibatch_features], axis=1)

**TransformerBlock**

The TransformerBlock is an encoder module inspired by the Transformer architecture, primarily used in Vision Transformers (ViTs), TransGAN, and NLP models. It processes input embeddings through self-attention and feedforward layers, refining feature representations while maintaining global contextual dependencies. By leveraging self-attention, feedforward transformations, and residual connections, the TransformerBlock allows the model to learn complex feature relationships, making it an essential building block for vision transformers, TransGAN, and other deep learning architectures. The TransformerBlock is crucial for enabling the generator or other transformer-based architectures to effectively model long-range dependencies.

In [64]:
class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim=512, dropout_rate=0.1, use_spectral_norm=False):
        super(TransformerBlock, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)

        self.ff1 = layers.Dense(ff_dim, activation="gelu")
        self.ff2 = layers.Dense(embed_dim)

        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout2 = layers.Dropout(dropout_rate)

        if use_spectral_norm:
            self.ff1 = spectral_norm_wrapper(self.ff1)
            self.ff2 = spectral_norm_wrapper(self.ff2)

    def call(self, inputs, training=False):
        attn_output = self.attention(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(inputs + attn_output)

        ffn_output = self.ff1(out1)
        ffn_output = self.ff2(ffn_output)
        ffn_output = self.dropout2(ffn_output, training=training)

        return self.norm2(out1 + ffn_output)

This block contains:

1. Multi-Head Self-Attention (MHA): Captures relationships between input tokens.

2. Feedforward Network (FFN): Enhances feature representations after attention.

3. Layer Normalization (LayerNorm): Stabilizes training by normalizing activations.

4. Residual Connections: Helps retain original information and improve gradient flow.

5. Dropout Regularization: Prevents overfitting by adding noise during training.

**TransformerDiscriminator Layer**

The TransformerDiscriminator is a transformer-based discriminator designed to distinguish between real and generated images. Unlike traditional CNN-based discriminators, this model processes images using self-attention mechanisms instead of convolutional layers, allowing it to capture global dependencies and long-range interactions. By leveraging self-attention, transformer-based processing, and minibatch discrimination, the TransformerDiscriminator provides a powerful alternative to traditional CNN-based discriminators, making it highly effective in GAN architectures like TransGAN.

This discriminator follows a Vision Transformer (ViT)-like structure, consisting of:

1. Patch Embedding Layer: Converts input images into a sequence of patch tokens.

2. Transformer Encoder Blocks: Processes the patch embeddings using self-attention.

3. Minibatch Discrimination (optional): Helps prevent mode collapse by introducing diversity-sensitive features.

4. Normalization & MLP Classification Head: Outputs a real/fake probability score using a sigmoid activation function.

The model takes an image as input and outputs a single probability indicating whether the input is real or fake.

In [65]:
class TransformerDiscriminator(tf.keras.Model):
    def __init__(self, config):
        super(TransformerDiscriminator, self).__init__()

        self.img_size = config.image_size
        self.patch_size = 4
        self.use_minibatch = config.use_minibatch_discrimination
        self.embed_dim = 512
        self.use_spectral_norm = config.use_spectral_norm

        if self.use_spectral_norm:
            self.patch_embedding = spectral_norm_wrapper(layers.Conv2D(
                self.embed_dim,
                kernel_size=self.patch_size,
                strides=self.patch_size,
                padding="same"
            ))
        else:
            self.patch_embedding = layers.Conv2D(
                self.embed_dim,
                kernel_size=self.patch_size,
                strides=self.patch_size,
                padding="same"
            )

        self.flatten = layers.Reshape((-1, self.embed_dim))

        depth = 3
        self.transformer_blocks = [
            TransformerBlock(
                self.embed_dim,
                num_heads=4,
                dropout_rate=0.1,
                use_spectral_norm=self.use_spectral_norm
            )
            for _ in range(depth)
        ]

        if self.use_minibatch:
            self.minibatch_layer = MinibatchDiscrimination(num_kernels=100, dim_per_kernel=5)
            final_dim = self.embed_dim + 100
        else:
            final_dim = self.embed_dim

        self.norm = layers.LayerNormalization(epsilon=1e-6)

        if self.use_spectral_norm:
            self.mlp_head = spectral_norm_wrapper(layers.Dense(1))
        else:
            self.mlp_head = layers.Dense(1)

    def call(self, inputs, training=False, return_features=False):
        batch_size = tf.shape(inputs)[0]
        inputs = (inputs - 0.5) * 2
        x = self.patch_embedding(inputs)
        x = tf.reshape(x, [batch_size, -1, self.embed_dim])

        for block in self.transformer_blocks:
            x = block(x, training=training)

        features = tf.reduce_mean(x, axis=1)

        if self.use_minibatch:
            features = self.minibatch_layer(features)

        x = self.norm(features)
        output = self.mlp_head(x)

        if return_features:
            return output, features
        else:
            return output

Feature Matching & Historical Averaging

In [66]:
class FeatureMatching:
    def __call__(self, real_features, fake_features):
        real_mean = tf.reduce_mean(real_features, axis=0)
        fake_mean = tf.reduce_mean(fake_features, axis=0)
        return tf.reduce_mean(tf.square(real_mean - fake_mean))

class HistoricalAveraging:
    def __init__(self, beta=0.99):
        self.beta = beta
        self.parameter_history = {}

    def initialize_if_needed(self, model):
        for weight in model.trainable_weights:
            if weight.name not in self.parameter_history:
                self.parameter_history[weight.name] = weight.numpy()

    def __call__(self, model, weight=0.01):
        if not self.parameter_history:
            return tf.constant(0.0)

        total_loss = 0.0
        for curr_weight in model.trainable_weights:
            name = curr_weight.name
            if name not in self.parameter_history:
                continue

            hist_tensor = tf.convert_to_tensor(self.parameter_history[name], dtype=curr_weight.dtype)

            if hist_tensor.shape != curr_weight.shape:
                continue

            diff = curr_weight - hist_tensor
            loss = tf.reduce_sum(tf.square(diff))
            total_loss += loss

        return weight * total_loss

    def update_history(self, model):
        for weight in model.trainable_weights:
            name = weight.name
            if name in self.parameter_history:
                hist_tensor = tf.convert_to_tensor(self.parameter_history[name], dtype=weight.dtype)
                if hist_tensor.shape == weight.shape:
                    self.parameter_history[name] = (
                        self.beta * hist_tensor.numpy() +
                        (1 - self.beta) * weight.numpy()
                    )

WGAN-GP loss

In [67]:
class WGANGP:
    def __init__(self, discriminator, lambda_gp=10.0):
        self.discriminator = discriminator
        self.lambda_gp = lambda_gp

    def gradient_penalty(self, real_images, fake_images):
        batch_size = tf.shape(real_images)[0]

        if real_images.shape[1:] != fake_images.shape[1:]:
            real_images = tf.image.resize(real_images, (fake_images.shape[1], fake_images.shape[2]))

        epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + epsilon * (fake_images - real_images)

        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            interpolated_output = self.discriminator(interpolated, training=True)

        gradients = tape.gradient(interpolated_output, interpolated)
        gradient_norm = tf.norm(tf.reshape(gradients, [batch_size, -1]), axis=1)
        gradient_penalty = tf.reduce_mean(tf.square(gradient_norm - 1.0))

        return gradient_penalty * self.lambda_gp * 0.1

    def discriminator_loss(self, real_images, fake_images):
        real_output = self.discriminator(real_images, training=True)
        fake_output = self.discriminator(fake_images, training=True)

        wasserstein_distance = tf.reduce_mean(fake_output - real_output) * 0.01
        gp = self.gradient_penalty(real_images, fake_images)

        return wasserstein_distance + gp

    def generator_loss(self, fake_images):
        fake_output = self.discriminator(fake_images, training=True)
        return -tf.reduce_mean(fake_output) * 0.01

**ImprovedTransGAN Model**

The ImprovedTransGAN class implements a Transformer-based Generative Adversarial Network (GAN) with several enhancements for stability and performance. Unlike traditional convolutional GANs, this model leverages self-attention mechanisms to capture long-range dependencies in image generation. By incorporating these techniques, ImprovedTransGAN achieves more stable training, higher-quality image synthesis, and improved generalization compared to traditional GANs.

TransGAN Architecture:

1. GridAttention-Based Generator: Uses transformers(GridAttention Block layers) instead of convolutional layers for better spatial understanding.

2. Transformer-Based Discriminator: Processes image patches as tokens for superior global feature extraction.


This implementation includes three major improvements over the TransGAN:

1. Feature Matching – Improves training stability by aligning real and generated feature distributions.

2. Historical Averaging – Stabilizes parameter updates by maintaining a history of model weights.

3. Virtual Batch Normalization (VBN) – Normalizes activations using a fixed reference batch for improved generalization.

In [68]:
class ImprovedTransGAN:
    def __init__(self, config):
        self.config = config

        self.generator = Generator(config)
        self.discriminator = TransformerDiscriminator(config)

        self.feature_matching = FeatureMatching()
        self.historical_averaging = HistoricalAveraging(beta=0.99)
        self.wgan_gp = WGANGP(self.discriminator, lambda_gp=config.gp_lambda)

        os.makedirs(config.sample_dir, exist_ok=True)
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        os.makedirs(config.log_dir, exist_ok=True)

        self.gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=config.learning_rate_g,
            beta_1=config.beta1,
            beta_2=config.beta2
        )

        self.disc_optimizer = tf.keras.optimizers.Adam(
            learning_rate=config.learning_rate_d,
            beta_1=config.beta1,
            beta_2=config.beta2
        )

        self.fixed_noise = tf.random.normal([32, config.z_dim])

        self.gen_losses = []
        self.disc_losses = []
        self.real_scores = []
        self.fake_scores = []
        self.real_accs = []
        self.fake_accs = []

        print("Improved TransGAN model initialized!")

    def generator_loss(self, fake_output, real_features=None, fake_features=None):
        if self.config.use_wgan_gp:
            gen_loss = self.wgan_gp.generator_loss(fake_output)
        else:
            target = tf.ones_like(fake_output) * self.config.generator_target_prob
            gen_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(target, fake_output, from_logits=True)
            )

        if self.config.use_feature_matching and real_features is not None and fake_features is not None:
            fm_loss = self.feature_matching(real_features, fake_features)
            gen_loss += self.config.feature_matching_weight * fm_loss

        if self.config.use_historical_averaging:
            ha_loss = self.historical_averaging(self.generator, self.config.historical_averaging_weight)
            gen_loss += ha_loss

        return gen_loss

    def discriminator_loss(self, real_output, fake_output, real_images=None, fake_images=None):
        if self.config.use_wgan_gp:
            return self.wgan_gp.discriminator_loss(real_images, fake_images)
        else:
            real_labels = tf.ones_like(real_output) * (1.0 - self.config.label_smoothing)
            fake_labels = tf.zeros_like(fake_output)

            real_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(real_labels, real_output, from_logits=True)
            )
            fake_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(fake_labels, fake_output, from_logits=True)
            )
            disc_loss = real_loss + fake_loss

            if self.config.use_historical_averaging:
                ha_loss = self.historical_averaging(self.discriminator, self.config.historical_averaging_weight)
                disc_loss += ha_loss

            return disc_loss

    def is_training_stable(self, metrics):
        """Check if training metrics are within reasonable bounds"""
        gen_loss = metrics["gen_loss"].numpy()
        disc_loss = metrics["disc_loss"].numpy()

        if (np.isnan(gen_loss) or np.isnan(disc_loss) or
            abs(gen_loss) > 50 or abs(disc_loss) > 50):
            print(f"Unstable training detected: G={gen_loss:.4f}, D={disc_loss:.4f}")

            if abs(disc_loss) > 50:
                for layer in self.discriminator.layers:
                    if hasattr(layer, 'kernel_initializer') and hasattr(layer, 'kernel'):
                        if 'embedding' not in layer.name and 'norm' not in layer.name:
                            w_init = layer.kernel_initializer
                            layer.kernel.assign(w_init(shape=layer.kernel.shape))
                print("Discriminator weights reset due to instability")

            return False

        return True


    @tf.function
    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.config.discriminator_steps):
            with tf.GradientTape() as disc_tape:
                noise = tf.random.normal([batch_size, self.config.z_dim])
                fake_images = self.generator(noise, training=True)

                if self.config.use_wgan_gp:
                    d_loss = self.wgan_gp.discriminator_loss(real_images, fake_images)
                    real_output = self.discriminator(real_images, training=True)
                    fake_output = self.discriminator(fake_images, training=True)
                else:
                    real_output, real_features = self.discriminator(real_images, training=True, return_features=True)
                    fake_output, fake_features = self.discriminator(fake_images, training=True, return_features=True)
                    d_loss = self.discriminator_loss(real_output, fake_output)

            disc_gradients = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
            disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 1.0)
            self.disc_optimizer.apply_gradients(zip(disc_gradients, self.discriminator.trainable_variables))

        for _ in range(self.config.generator_steps):
            with tf.GradientTape() as gen_tape:
                noise = tf.random.normal([batch_size, self.config.z_dim])
                fake_images = self.generator(noise, training=True)

                if self.config.use_wgan_gp:
                    g_loss = self.wgan_gp.generator_loss(fake_images)
                else:
                    fake_output, fake_features = self.discriminator(fake_images, training=True, return_features=True)
                    _, real_features = self.discriminator(real_images, training=False, return_features=True)
                    g_loss = self.generator_loss(fake_output, real_features, fake_features)

            gen_gradients = gen_tape.gradient(g_loss, self.generator.trainable_variables)
            gen_gradients, _ = tf.clip_by_global_norm(gen_gradients, 1.0)
            self.gen_optimizer.apply_gradients(zip(gen_gradients, self.generator.trainable_variables))

        if self.config.use_wgan_gp:
            real_output = self.discriminator(real_images, training=False)
            fake_output = self.discriminator(fake_images, training=False)

        if self.config.use_wgan_gp:
            real_sigmoid = tf.sigmoid(real_output * 0.1)
            fake_sigmoid = tf.sigmoid(fake_output * 0.1)
        else:
            real_sigmoid = tf.sigmoid(real_output)
            fake_sigmoid = tf.sigmoid(fake_output)

        real_acc = tf.reduce_mean(tf.cast(real_sigmoid > 0.5, tf.float32))
        fake_acc = tf.reduce_mean(tf.cast(fake_sigmoid < 0.5, tf.float32))

        real_score = tf.reduce_mean(real_sigmoid)
        fake_score = tf.reduce_mean(fake_sigmoid)

        return {
            "gen_loss": g_loss,
            "disc_loss": d_loss,
            "real_score": real_score,
            "fake_score": fake_score,
            "real_acc": real_acc,
            "fake_acc": fake_acc
        }

    def train(self, dataset, epochs):
        start_time = time.time()

        if self.config.use_historical_averaging:
            self.historical_averaging.initialize_if_needed(self.generator)
            self.historical_averaging.initialize_if_needed(self.discriminator)

        for epoch in range(epochs):
            epoch_start = time.time()

            progress_bar = tqdm(total=len(list(dataset)))
            progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")

            epoch_gen_losses = []
            epoch_disc_losses = []
            epoch_real_scores = []
            epoch_fake_scores = []
            epoch_real_accs = []
            epoch_fake_accs = []

        for batch in dataset:
            for attempt in range(3):
                try:
                    metrics = self.train_step(batch)

                    if not self.is_training_stable(metrics):
                        if attempt < 2:
                            print(f"Attempt {attempt+1} failed, reducing learning rates...")
                            self.disc_optimizer.learning_rate.assign(self.disc_optimizer.learning_rate * 0.1)
                            self.gen_optimizer.learning_rate.assign(self.gen_optimizer.learning_rate * 0.5)
                            print(f"New rates: G={self.gen_optimizer.learning_rate.numpy():.6f}, D={self.disc_optimizer.learning_rate.numpy():.6f}")
                            continue
                        else:
                            print("Skipping this batch after 3 failed attempts")
                            break

                    try:
                        epoch_gen_losses.append(float(metrics["gen_loss"]))
                        epoch_disc_losses.append(float(metrics["disc_loss"]))
                        epoch_real_scores.append(float(metrics["real_score"]))
                        epoch_fake_scores.append(float(metrics["fake_score"]))
                        epoch_real_accs.append(float(metrics["real_acc"]))
                        epoch_fake_accs.append(float(metrics["fake_acc"]))
                        progress_bar.update(1)
                        desc = f"Epoch {epoch+1}/{epochs} - "
                        desc += f"G: {float(metrics['gen_loss']):.4f}, D: {float(metrics['disc_loss']):.4f}, "
                        desc += f"D(x): {float(metrics['real_score']):.4f}, D(G(z)): {float(metrics['fake_score']):.4f}"
                        progress_bar.set_description(desc)

                        break

                    except (ValueError, TypeError) as e:
                        print(f"Error processing metrics: {e}")
                        continue

                except tf.errors.ResourceExhaustedError as e:
                    print(f"Out of memory error: {e}")
                    tf.keras.backend.clear_session()
                    continue
                except Exception as e:
                    print(f"Unexpected error during training: {e}")
                    continue

            avg_gen_loss = np.clip(np.mean(epoch_gen_losses), -1000, 1000)
            avg_disc_loss = np.clip(np.mean(epoch_disc_losses), -1000, 1000)
            avg_real_score = np.clip(np.mean(epoch_real_scores), 0, 1)
            avg_fake_score = np.clip(np.mean(epoch_fake_scores), 0, 1)
            avg_real_acc = np.clip(np.mean(epoch_real_accs), 0, 1)
            avg_fake_acc = np.clip(np.mean(epoch_fake_accs), 0, 1)

            self.gen_losses.append(avg_gen_loss)
            self.disc_losses.append(avg_disc_loss)
            self.real_scores.append(avg_real_score)
            self.fake_scores.append(avg_fake_score)
            self.real_accs.append(avg_real_acc)
            self.fake_accs.append(avg_fake_acc)

            if (epoch + 1) % self.config.sample_freq == 0:
                self.generate_and_save_images(epoch + 1)

            if (epoch + 1) % self.config.save_freq == 0:
                self.save_checkpoint(epoch + 1)

            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
            print(f"Generator Loss: {avg_gen_loss:.4f}")
            print(f"Discriminator Loss: {avg_disc_loss:.4f}")
            print(f"D(x): {avg_real_score:.4f}, D(G(z)): {avg_fake_score:.4f}")
            print(f"Real Acc: {avg_real_acc:.4f}, Fake Acc: {avg_fake_acc:.4f}")

            if avg_fake_acc < 0.1 or avg_real_acc > 0.95:
                print("Warning: Potential discriminator overpowering detected.")
                print("Adjusting learning rates...")
                self.disc_optimizer.learning_rate.assign(self.disc_optimizer.learning_rate * 0.8)
                self.gen_optimizer.learning_rate.assign(self.gen_optimizer.learning_rate * 1.2)
                print(f"New learning rates - G: {self.gen_optimizer.learning_rate.numpy():.6f}, D: {self.disc_optimizer.learning_rate.numpy():.6f}")

            if avg_fake_acc > 0.9 or avg_real_acc < 0.1:
                print("Warning: Potential generator overpowering detected.")
                self.disc_optimizer.learning_rate.assign(self.disc_optimizer.learning_rate * 1.2)
                self.gen_optimizer.learning_rate.assign(self.gen_optimizer.learning_rate * 0.8)
                print(f"New learning rates - G: {self.gen_optimizer.learning_rate.numpy():.6f}, D: {self.disc_optimizer.learning_rate.numpy():.6f}")

            print("-" * 80)

        total_time = time.time() - start_time
        print(f"Training completed in {total_time/60:.2f} minutes")

        self.plot_training_history()

    def generate_and_save_images(self, epoch):
        predictions = self.generator(self.fixed_noise, training=False)

        fig, axes = plt.subplots(4, 4, figsize=(8, 8))

        for i in range(16):
            plt.subplot(4, 4, i+1)
            img = (predictions[i].numpy() + 1) / 2.0
            img = np.clip(img, 0, 1)
            plt.imshow(img)
            plt.axis("off")

        plt.suptitle(f"Generated Images - Epoch {epoch}")
        plt.tight_layout()
        plt.savefig(f"{self.config.sample_dir}/epoch_{epoch}.png")
        plt.close()

    def plot_training_history(self):
        epochs = range(1, len(self.gen_losses) + 1)

        plt.figure(figsize=(15, 10))
        plt.subplot(2, 2, 1)
        if max(max(self.gen_losses), max(self.disc_losses)) > 1000:
            plt.semilogy(epochs, self.gen_losses, label="Generator Loss")
            plt.semilogy(epochs, self.disc_losses, label="Discriminator Loss")
        else:
            plt.plot(epochs, self.gen_losses, label="Generator Loss")
            plt.plot(epochs, self.disc_losses, label="Discriminator Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Generator and Discriminator Loss")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 2)
        plt.plot(epochs, self.real_scores, label="D(x) - Real")
        plt.plot(epochs, self.fake_scores, label="D(G(z)) - Fake")
        plt.xlabel("Epochs")
        plt.ylabel("Score")
        plt.title("Discriminator Scores")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 3)
        plt.plot(epochs, self.real_accs, label='Real Accuracy')
        plt.plot(epochs, self.fake_accs, label='Fake Accuracy')
        plt.title('Discriminator Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 4)
        predictions = self.generator(self.fixed_noise, training=False)

        grid_size = 4
        image_size = predictions.shape[1]
        composite_image = np.zeros((grid_size * image_size, grid_size * image_size, 3))

        for i in range(grid_size):
            for j in range(grid_size):
                idx = i * grid_size + j
                if idx < predictions.shape[0]:
                    img = (predictions[idx].numpy() + 1) / 2.0
                    img = np.clip(img, 0, 1)
                    composite_image[i * image_size:(i + 1) * image_size, j * image_size:(j + 1) * image_size, :] = img

        plt.imshow(composite_image)
        plt.axis('off')
        plt.title('Generated Images')

        plt.suptitle('GAN Training Progress', fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f"{self.config.log_dir}/training_history.png")
        plt.close()

    def save_checkpoint(self, epoch):
        checkpoint_dir = os.path.join(self.config.checkpoint_dir, f"checkpoint_epoch_{epoch}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.generator.save_weights(os.path.join(checkpoint_dir, "generator.weights.h5"))
        self.discriminator.save_weights(os.path.join(checkpoint_dir, "discriminator.weights.h5"))

        np.save(os.path.join(checkpoint_dir, "gen_opt_weights.npy"),
                [var.numpy() for var in self.gen_optimizer.variables()])
        np.save(os.path.join(checkpoint_dir, "disc_opt_weights.npy"),
                [var.numpy() for var in self.disc_optimizer.variables()])

    def load_checkpoint(self, epoch):
        checkpoint_dir = os.path.join(self.config.checkpoint_dir, f"checkpoint_epoch_{epoch}")

        self.generator.load_weights(os.path.join(checkpoint_dir, "generator.weights.h5"))
        self.discriminator.load_weights(os.path.join(checkpoint_dir, "discriminator.weights.h5"))

        gen_opt_path = os.path.join(checkpoint_dir, "gen_opt_weights.npy")
        disc_opt_path = os.path.join(checkpoint_dir, "disc_opt_weights.npy")

        if os.path.exists(gen_opt_path) and os.path.exists(disc_opt_path):
            gen_opt_weights = np.load(gen_opt_path, allow_pickle=True)
            disc_opt_weights = np.load(disc_opt_path, allow_pickle=True)

            dummy_g_grads = [tf.zeros_like(var) for var in self.generator.trainable_variables]
            dummy_d_grads = [tf.zeros_like(var) for var in self.discriminator.trainable_variables]

            self.gen_optimizer.apply_gradients(zip(dummy_g_grads, self.generator.trainable_variables))
            self.disc_optimizer.apply_gradients(zip(dummy_d_grads, self.discriminator.trainable_variables))

            for i, var in enumerate(self.gen_optimizer.variables()):
                if i < len(gen_opt_weights):
                    var.assign(gen_opt_weights[i])

            for i, var in enumerate(self.disc_optimizer.variables()):
                if i < len(disc_opt_weights):
                    var.assign(disc_opt_weights[i])

        print(f"Checkpoint loaded from epoch {epoch}")

The ImprovedTransGAN Model Runs like these:

🔹 Step 1: Model Initialization

Sets up the generator, discriminator, optimizers, additional stabilization techniques. And loads advanced training features.

1.6 Creates directories for model checkpoints and sample images

1.7 Generates fixed noise samples for evaluation during training.

🔹 Step 2: Training Loop Execution

1. Processes real images in mini-batches.

2. Performs a discriminator update: (1) Uses multi-head self-attention to analyze image patches. (2)Includes Minibatch Discrimination (optional) to prevent mode collapse.

3. Performs multiple generator updates (generator_steps): (1)Uses self-attention instead of CNN layers. (2)Projects latent noise into a structured spatial feature map. (3)Incorporates Virtual Batch Normalization (VBN) for training stability.

4. Configures advanced training techniques: (1)Feature Matching: Ensures feature distributions of real and fake images are similar. (2)Historical Averaging: Uses past model weights to stabilize learning.

5. Sets up optimizers (Adam) for both networks: (1)Separate learning rates for generator and discriminator. (2)Uses momentum parameters (beta1, beta2) for stable updates.

6. Logs training performance.

7. Generates sample images every sample_freq epochs.

8. Saves checkpoints every save_freq epochs.

🔹 Step 3: Evaluating Model Performance

1. Plots training loss curves and discriminator confidence scores.

2. Generates synthetic images to monitor improvements.

**FIDEvaluator**

The FIDEvaluator class is designed to evaluate the performance of a Generative Adversarial Network (GAN) by calculating the Fréchet Inception Distance (FID). The FID score is a widely used metric to measure the similarity between real and generated images, helping assess the quality of synthetic data. By leveraging InceptionV3 to extract meaningful image features, it provides a robust and reliable measure of image realism. This metric is fundamental in improving and benchmarking GAN-based image generation models.

The key idea behind FID is to compare the distributions of real and generated images in a high-level feature space extracted using a pre-trained InceptionV3 model.

In [69]:
class FIDEvaluator:
    def __init__(self, config):
        self.config = config
        self.batch_size = min(config.batch_size, 16)

        try:
            self.inception_model = tf.keras.applications.InceptionV3(
                include_top=False,
                pooling='avg',
                weights='imagenet',
                input_shape=(299, 299, 3)
            )
        except (ImportError, ValueError) as e:
            print(f"Error loading InceptionV3: {e}")
            print("Using a simplified feature extractor instead.")
            self.inception_model = self.create_simplified_feature_extractor()

        self.real_features = None
        self.real_mean = None
        self.real_cov = None

        print("FID Evaluator initialized")

    def create_simplified_feature_extractor(self):
        """Create a simplified feature extractor if InceptionV3 can't be loaded"""
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(299, 299, 3)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D()
        ])
        return model

    def preprocess_images(self, images):
        """Preprocess images for feature extraction"""
        images = (images + 1) / 2.0
        images = tf.clip_by_value(images, 0.0, 1.0)
        images = tf.image.resize(images, (299, 299))
        if isinstance(self.inception_model, tf.keras.applications.InceptionV3):
            images = tf.keras.applications.inception_v3.preprocess_input(images * 255.0)
        return images

    def extract_features(self, images):
        """Extract features using feature extractor model"""
        try:
            features = self.inception_model.predict(images, batch_size=self.batch_size, verbose=0)
            return features
        except Exception as e:
            print(f"Error in feature extraction: {e}")
            return np.zeros((images.shape[0], self.inception_model.output_shape[-1]))

    def compute_real_statistics(self, dataset, num_samples=1000):
        """Compute statistics for real images"""
        print(f"Computing real data statistics using {num_samples} samples...")

        real_images = []
        for batch in dataset:
            for img in batch:
                real_images.append(img.numpy())
                if len(real_images) >= num_samples:
                    break
            if len(real_images) >= num_samples:
                break

        real_images = np.array(real_images[:num_samples])
        processed_images = self.preprocess_images(real_images)

        features = []
        batch_size = self.batch_size
        for i in range(0, len(processed_images), batch_size):
            batch = processed_images[i:i+batch_size]
            try:
                batch_features = self.extract_features(batch)
                features.append(batch_features)
            except Exception as e:
                print(f"Error processing batch {i//batch_size}: {e}")
                features.append(np.zeros((len(batch), self.inception_model.output_shape[-1])))

        self.real_features = np.concatenate(features, axis=0)
        self.real_mean = np.mean(self.real_features, axis=0)
        self.real_cov = np.cov(self.real_features, rowvar=False)

        print(f"Real data statistics calculated from {len(self.real_features)} images")

        return self.real_mean, self.real_cov

    def calculate_fid(self, generator, num_samples=1000):
        """Calculate FID score for generated images with improved error tracing"""
        if self.real_mean is None or self.real_cov is None:
            print("ERROR: Real statistics not computed. Running compute_real_statistics first.")
            self.compute_real_statistics(dataset, num_samples=500)
            if self.real_mean is None:
                return float('nan')

        print(f"Calculating FID score using {num_samples} generated samples...")

        if num_samples > 500:
            print("Reducing sample size to 500 for memory efficiency")
            num_samples = 500

        z_dim = self.config.z_dim
        batch_size = 8

        fake_features = []
        for i in range(0, num_samples, batch_size):
            current_batch_size = min(batch_size, num_samples - i)
            try:
                z = tf.random.normal([current_batch_size, z_dim])
                with tf.device('/cpu:0'):
                    generated_batch = generator(z, training=False)
                    processed_images = self.preprocess_images(generated_batch)
                    batch_features = self.extract_features(processed_images)
                    fake_features.append(batch_features)
                    print(f"Processed batch {i//batch_size + 1}/{(num_samples + batch_size - 1)//batch_size}")
            except Exception as e:
                print(f"Error in batch {i//batch_size + 1}: {e}")

        if not fake_features:
            print("ERROR: No valid features extracted")
            return float('nan')

        fake_features = np.concatenate(fake_features, axis=0)
        fake_mean = np.mean(fake_features, axis=0)
        fake_cov = np.cov(fake_features, rowvar=False)

        if np.isnan(fake_mean).any() or np.isnan(fake_cov).any():
            print("ERROR: NaN values in fake statistics")
            return float('nan')

        try:
            mean_diff_squared = np.sum((self.real_mean - fake_mean) ** 2)
            print(f"Mean difference squared: {mean_diff_squared}")

            eps = 1e-6
            real_cov_reg = self.real_cov + np.eye(self.real_cov.shape[0]) * eps
            fake_cov_reg = fake_cov + np.eye(fake_cov.shape[0]) * eps
            covmean = linalg.sqrtm(real_cov_reg.dot(fake_cov_reg))
            if np.iscomplexobj(covmean):
                covmean = covmean.real

            trace_term = np.trace(real_cov_reg + fake_cov_reg - 2 * covmean)
            print(f"Trace term: {trace_term}")

            fid = mean_diff_squared + trace_term
            print(f"Raw FID score: {fid}")

            if not np.isfinite(fid):
                print("ERROR: FID is not finite")
                return float('nan')

            return fid
        except Exception as e:
            print(f"ERROR in FID calculation: {e}")
            traceback.print_exc()
            return float('nan')

In [70]:
def run_transgan_experiment(config, dataset, epochs=10, evaluate_every=5):
    """Run a TransGAN experiment with FID evaluation and error handling"""
    print("=== Starting TransGAN Experiment ===")
    print(f"Configuration: {config}")
    print(f"Training for {epochs} epochs, evaluating every {evaluate_every} epochs")

    fid_evaluator = FIDEvaluator(config)

    try:
        fid_evaluator.compute_real_statistics(dataset, num_samples=500)
    except Exception as e:
        print(f"Error computing real statistics: {e}")
        print("Continuing without FID evaluation.")
        fid_evaluator = None

    gan = ImprovedTransGAN(config)
    fid_scores = []

    os.makedirs(config.log_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"\n=== Epoch {epoch+1}/{epochs} ===")

        try:
            gan.train(dataset, 1)
        except Exception as e:
            print(f"Error during training epoch {epoch+1}: {e}")
            print("Trying to recover and continue training...")
            gan.save_checkpoint(epoch + 1)
            continue

        if fid_evaluator is not None and ((epoch + 1) % evaluate_every == 0 or epoch == epochs - 1):
            print(f"Evaluating FID after epoch {epoch+1}")
            try:
                fid = fid_evaluator.calculate_fid(gan.generator, num_samples=500)
                fid_scores.append((epoch + 1, fid))
            except Exception as e:
                print(f"Error calculating FID: {e}")

    print("\n=== Final Evaluation ===")
    if fid_evaluator is not None:
        try:
            final_fid = fid_evaluator.calculate_fid(gan.generator, num_samples=500)
        except Exception as e:
            print(f"Error calculating final FID: {e}")
            final_fid = 999.99
    else:
        final_fid = 999.99

    if len(fid_scores) > 0:
        plt.figure(figsize=(10, 6))
        epochs_, scores = zip(*fid_scores)
        plt.plot(epochs_, scores, 'o-', linewidth=2)
        plt.title('FID Score Over Training (Lower is Better)')
        plt.xlabel('Epochs')
        plt.ylabel('FID Score')
        plt.grid(True, alpha=0.3)
        plt.savefig(f"{config.log_dir}/fid_history.png")
        plt.close()

    final_samples = 36
    noise = tf.random.normal([final_samples, config.z_dim])
    generated_images = gan.generator(noise, training=False)

    plt.figure(figsize=(10, 10))
    for i in range(final_samples):
        plt.subplot(6, 6, i+1)
        img = (generated_images[i].numpy() + 1) / 2.0
        img = np.clip(img, 0, 1)
        plt.imshow(img)
        plt.axis('off')
    plt.suptitle(f'Final Generated Samples (FID: {final_fid:.4f})')
    plt.tight_layout()
    plt.savefig(f"{config.log_dir}/final_samples.png")
    plt.close()

    return gan, fid_scores, final_fid

In [71]:
def compare_transgan_improvements(dataset, base_epochs=5):
    """Compare different TransGAN improvement techniques with enhanced error handling"""
    print("=== TransGAN Improvement Comparison Study ===")

    configs = [
        {
            "name": "baseline",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": False,
                "use_spectral_norm": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "wgan_gp",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": True,
                "use_spectral_norm": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "spectral_norm",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": False,
                "use_spectral_norm": True,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "feature_matching",
            "updates": {
                "use_feature_matching": True,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": False,
                "use_spectral_norm": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "historical_avg",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": True,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": False,
                "use_spectral_norm": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "minibatch_disc",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": True,
                "use_wgan_gp": False,
                "use_spectral_norm": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "label_smoothing",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "use_wgan_gp": False,
                "use_spectral_norm": False,
                "label_smoothing": 0.2
            }
        },
        {
            "name": "all_improvements",
            "updates": {
                "use_feature_matching": True,
                "use_historical_averaging": True,
                "use_minibatch_discrimination": True,
                "use_wgan_gp": True,
                "use_spectral_norm": True,
                "label_smoothing": 0.2
            }
        }
    ]

    base_config = GANConfig()

    try:
        fid_evaluator = FIDEvaluator(base_config)
        fid_evaluator.compute_real_statistics(dataset, num_samples=500)
    except Exception as e:
        print(f"Error initializing FID evaluator: {e}")
        print("Continuing without FID evaluation.")
        fid_evaluator = None

    results = {}

    for config_info in configs:
        name = config_info["name"]
        updates = config_info["updates"]

        print(f"\n\n=== Training Configuration: {name} ===")
        print("Settings:")

        config = GANConfig()
        for k, v in updates.items():
            setattr(config, k, v)

        config.sample_dir = f"samples/{name}"
        config.checkpoint_dir = f"checkpoints/{name}"
        config.log_dir = f"logs/{name}"

        os.makedirs(config.sample_dir, exist_ok=True)
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        os.makedirs(config.log_dir, exist_ok=True)

        try:
            gan = ImprovedTransGAN(config)
            gan.train(dataset, base_epochs)

            if fid_evaluator is not None:
                try:
                    fid_score = fid_evaluator.calculate_fid(gan.generator, num_samples=500)
                except Exception as e:
                    print(f"Error calculating FID for {name}: {e}")
                    fid_score = 999.99
            else:
                fid_score = 999.99

            results[name] = {
                "fid": fid_score,
                "gan": gan,
                "config": config,
                "gen_loss": gan.gen_losses[-1] if gan.gen_losses else None,
                "disc_loss": gan.disc_losses[-1] if gan.disc_losses else None,
                "real_score": gan.real_scores[-1] if gan.real_scores else None,
                "fake_score": gan.fake_scores[-1] if gan.fake_scores else None
            }

            np.savez(
                f"{config.log_dir}/metrics.npz",
                gen_losses=np.array(gan.gen_losses),
                disc_losses=np.array(gan.disc_losses),
                real_scores=np.array(gan.real_scores),
                fake_scores=np.array(gan.fake_scores),
                fid=np.array([fid_score])
            )

        except Exception as e:
            print(f"Error during training of {name} configuration: {e}")
            print("Skipping to next configuration.")
            continue

    visualize_comparison_results(results)

    return results

In [72]:
def visualize_comparison_results(results):
    """Create comparative visualizations of different configurations with error handling"""
    if not results:
        print("No results to visualize.")
        return

    print("Creating comparative visualizations...")

    default_colors = ['gray', 'blue', 'green', 'orange', 'purple', 'red', 'cyan', 'magenta', 'brown', 'pink']
    colors = {name: default_colors[i % len(default_colors)] for i, name in enumerate(results.keys())}

    # --- FID Bar Chart ---
    try:
        plt.figure(figsize=(12, 6))

        names = list(results.keys())
        fid_scores = [results[name]["fid"] for name in names]

        valid_indices = [i for i, score in enumerate(fid_scores) if score < 999]
        if len(valid_indices) > 0:
            sorted_indices = np.argsort([fid_scores[i] for i in valid_indices])
            sorted_names = [names[valid_indices[i]] for i in sorted_indices]
            sorted_fids = [fid_scores[valid_indices[i]] for i in sorted_indices]
            sorted_colors = [colors[name] for name in sorted_names]

            plt.barh(sorted_names, sorted_fids, color=sorted_colors)
            plt.title('FID Score Comparison (Lower is Better)')
            plt.xlabel('FID Score')
            plt.grid(axis='x', alpha=0.3)

            for i, v in enumerate(sorted_fids):
                plt.text(v + 0.5, i, f"{v:.2f}", va='center')

            plt.tight_layout()
            plt.savefig("fid_comparison.png", dpi=150)
        else:
            print("No valid FID scores to visualize.")
        plt.close()
    except Exception as e:
        print(f"Error creating FID chart: {e}")

    # --- Loss Curve Comparison ---
    try:
        plt.figure(figsize=(15, 6))

        plt.subplot(1, 2, 1)
        for name in results:
            if "gan" in results[name] and hasattr(results[name]["gan"], "gen_losses"):
                losses = results[name]["gan"].gen_losses
                losses = np.array([x for x in losses if np.isfinite(x)])
                if len(losses) > 0:
                    plt.plot(range(1, len(losses) + 1), losses, label=name, color=colors[name], linewidth=2)

        plt.title('Generator Loss Comparison')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        for name in results:
            if "gan" in results[name] and hasattr(results[name]["gan"], "disc_losses"):
                losses = results[name]["gan"].disc_losses
                losses = np.array([x for x in losses if np.isfinite(x)])
                if len(losses) > 0:
                    plt.plot(range(1, len(losses) + 1), losses, label=name, color=colors[name], linewidth=2)

        plt.title('Discriminator Loss Comparison')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.suptitle('Training Loss Comparison Across Techniques', fontsize=16)
        plt.tight_layout()
        plt.savefig("loss_comparison.png", dpi=150)
        plt.close()
    except Exception as e:
        print(f"Error creating loss comparison chart: {e}")

    # --- Discriminator Score Comparison ---
    try:
        plt.figure(figsize=(15, 6))

        plt.subplot(1, 2, 1)
        for name in results:
            if "gan" in results[name] and hasattr(results[name]["gan"], "real_scores"):
                scores = results[name]["gan"].real_scores
                scores = np.array([x for x in scores if np.isfinite(x)])
                if len(scores) > 0:
                    plt.plot(range(1, len(scores) + 1), scores, label=name, color=colors[name], linewidth=2)

        plt.title('D(x) - Real Score Comparison')
        plt.xlabel('Epochs')
        plt.ylabel('Score')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        for name in results:
            if "gan" in results[name] and hasattr(results[name]["gan"], "fake_scores"):
                scores = results[name]["gan"].fake_scores
                scores = np.array([x for x in scores if np.isfinite(x)])
                if len(scores) > 0:
                    plt.plot(range(1, len(scores) + 1), scores, label=name, color=colors[name], linewidth=2)

        plt.title('D(G(z)) - Fake Score Comparison')
        plt.xlabel('Epochs')
        plt.ylabel('Score')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.suptitle('Discriminator Score Comparison', fontsize=16)
        plt.tight_layout()
        plt.savefig("score_comparison.png", dpi=150)
        plt.close()
    except Exception as e:
        print(f"Error creating score comparison chart: {e}")

    for name, result in results.items():
        if "gan" not in result:
            continue

        try:
            config = result["config"]
            z_dim = config.z_dim
            fixed_noise = tf.random.normal([16, z_dim])

            generated_images = result["gan"].generator(fixed_noise, training=False)

            plt.figure(figsize=(8, 8))
            plt.suptitle(f"{name.replace('_', ' ').title()} (FID: {result['fid']:.2f})", fontsize=14)

            for j in range(16):
                plt.subplot(4, 4, j+1)
                img = (generated_images[j].numpy() + 1) / 2.0
                img = np.clip(img, 0, 1)
                plt.imshow(img)
                plt.axis('off')

            save_path = os.path.join(config.log_dir, f"samples_{name}.png")
            plt.tight_layout()
            plt.savefig(save_path, dpi=150)
            plt.close()
        except Exception as e:
            print(f"Error generating sample images for {name}: {e}")

    print("\nSummary of Results:")
    print("-" * 100)
    print(f"{'Configuration':<20} {'FID Score':<12} {'Gen Loss':<12} {'Disc Loss':<12} {'D(x)':<12} {'D(G(z))':<12}")
    print("-" * 100)

    valid_results = {k: v for k, v in results.items() if v.get('fid', 999) < 999}
    if valid_results:
        sorted_results = sorted(valid_results.items(), key=lambda x: x[1]["fid"])
    else:
        sorted_results = list(results.items())

    for name, result in sorted_results:
        print(f"{name:<20} ", end="")

        if 'fid' in result and result['fid'] < 999:
            print(f"{result['fid']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('gen_loss') is not None and np.isfinite(result['gen_loss']):
            print(f"{result['gen_loss']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('disc_loss') is not None and np.isfinite(result['disc_loss']):
            print(f"{result['disc_loss']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('real_score') is not None and np.isfinite(result['real_score']):
            print(f"{result['real_score']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('fake_score') is not None and np.isfinite(result['fake_score']):
            print(f"{result['fake_score']:<12.4f}")
        else:
            print(f"{'N/A':<12}")

    print("-" * 100)


In [73]:
def analyze_mode_collapse(results, dataset, num_samples=200):
    """Analyze mode collapse across configurations using feature diversity with error handling"""
    if not results:
        print("No results to analyze for mode collapse.")
        return {}

    print("Analyzing mode collapse across configurations...")

    real_images = []
    try:
        for batch in dataset:
            for img in batch:
                real_images.append(img.numpy())
            if len(real_images) >= num_samples:
                break

        real_images = np.array(real_images[:num_samples])
    except Exception as e:
        print(f"Error collecting real images: {e}")
        return {}

    try:
        inception_model = tf.keras.applications.InceptionV3(
            include_top=False,
            pooling='avg',
            weights='imagenet'
        )
    except:
        print("Using a simplified feature extractor for mode collapse analysis.")
        inception_model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.GlobalAveragePooling2D()
        ])

    def preprocess_images(images):
        images = (images + 1) / 2.0
        images = tf.clip_by_value(images, 0.0, 1.0)
        images = tf.image.resize(images, (128, 128))
        return images

    try:
        processed_real = preprocess_images(real_images)
        real_features = inception_model.predict(processed_real, batch_size=16, verbose=0)
    except Exception as e:
        print(f"Error extracting real features: {e}")
        return {}

    feature_stats = {}
    for name, result in results.items():
        if "gan" not in result:
            continue

        try:
            config = result["config"]
            z = tf.random.normal([num_samples, config.z_dim])
            fake_images = result["gan"].generator(z, training=False).numpy()

            processed_fake = preprocess_images(fake_images)
            fake_features = inception_model.predict(processed_fake, batch_size=16, verbose=0)

            feature_stats[name] = {
                "features": fake_features,
                "mean": np.mean(fake_features, axis=0),
                "std": np.std(fake_features, axis=0),
                "min": np.min(fake_features, axis=0),
                "max": np.max(fake_features, axis=0)
            }
        except Exception as e:
            print(f"Error calculating feature stats for {name}: {e}")

    if not feature_stats:
        print("No valid feature statistics to analyze.")
        return {}

    feature_diversity = {}
    for name, stats in feature_stats.items():
        avg_std = np.mean(stats["std"])
        avg_range = np.mean(stats["max"] - stats["min"])

        feature_diversity[name] = {
            "avg_std": avg_std,
            "avg_range": avg_range
        }

    real_std = np.mean(np.std(real_features, axis=0))
    real_range = np.mean(np.max(real_features, axis=0) - np.min(real_features, axis=0))

    try:
        plt.figure(figsize=(12, 6))

        names = list(feature_diversity.keys())
        avg_stds = [feature_diversity[name]["avg_std"] for name in names]

        sorted_indices = np.argsort(avg_stds)[::-1]
        sorted_names = [names[i] for i in sorted_indices]
        sorted_stds = [avg_stds[i] for i in sorted_indices]

        default_colors = ['gray', 'blue', 'green', 'orange', 'purple', 'red', 'cyan', 'magenta']
        colors = {name: default_colors[i % len(default_colors)] for i, name in enumerate(sorted_names)}
        sorted_colors = [colors[name] for name in sorted_names]

        plt.bar(sorted_names, sorted_stds, color=sorted_colors)
        plt.axhline(y=real_std, color='black', linestyle='--', label=f'Real Data ({real_std:.4f})')

        plt.title('Feature Diversity Comparison (Higher is Better)')
        plt.ylabel('Average Feature Standard Deviation')
        plt.legend()
        plt.grid(axis='y', alpha=0.3)
        plt.xticks(rotation=45)

        for i, v in enumerate(sorted_stds):
            plt.text(i, v + 0.001, f"{v:.4f}", ha='center')

        plt.tight_layout()
        plt.savefig("feature_diversity.png", dpi=150)
        plt.close()
    except Exception as e:
        print(f"Error creating feature diversity chart: {e}")

    print("\nMode Collapse Analysis:")
    print("-" * 80)
    print(f"{'Configuration':<20} {'Feature Std':<15} {'% of Real':<15} {'Mode Collapse':<15}")
    print("-" * 80)

    sorted_results = sorted(feature_diversity.items(), key=lambda x: x[1]["avg_std"], reverse=True)

    for name, stats in sorted_results:
        avg_std = stats["avg_std"]
        pct_of_real = (avg_std / real_std) * 100

        if pct_of_real >= 90:
            collapse_status = "Minimal"
        elif pct_of_real >= 70:
            collapse_status = "Minor"
        elif pct_of_real >= 50:
            collapse_status = "Moderate"
        elif pct_of_real >= 30:
            collapse_status = "Significant"
        else:
            collapse_status = "Severe"

        print(f"{name:<20} {avg_std:<15.4f} {pct_of_real:<15.2f}% {collapse_status:<15}")

    print("-" * 80)
    print(f"Real Data Reference: {real_std:.4f}")
    print("-" * 80)

    return feature_diversity

def generate_summary_report(results, dataset_name, epochs):
    """Generate a comprehensive summary report of experiment results with error handling"""
    if not results:
        print("No results to generate summary report.")
        return

    print("\n=== Comprehensive Experiment Summary ===")
    print(f"Dataset: {dataset_name}")
    print(f"Training epochs per configuration: {epochs}")
    print("-" * 80)

    valid_results = {k: v for k, v in results.items() if v.get('fid', 999) < 999}

    if valid_results:
        best_config = min(valid_results.items(), key=lambda x: x[1]["fid"])
        worst_config = max(valid_results.items(), key=lambda x: x[1]["fid"])

        print(f"Best configuration: {best_config[0]} (FID: {best_config[1]['fid']:.4f})")
        print(f"Worst configuration: {worst_config[0]} (FID: {worst_config[1]['fid']:.4f})")

        improvement = (worst_config[1]['fid'] - best_config[1]['fid']) / worst_config[1]['fid'] * 100
        print(f"Improvement: {improvement:.2f}%")
    else:
        print("No valid FID scores available for comparison.")

    print("\n=== Impact of Individual Techniques ===")

    if "baseline" in results:
        baseline_fid = results["baseline"].get("fid", 999)

        techniques = [
            "feature_matching",
            "wgan_gp",
            "spectral_norm",
            "minibatch_disc",
            "historical_avg",
            "label_smoothing"
        ]

        technique_improvements = []

        for technique in techniques:
            if technique in results and results[technique].get("fid", 999) < 999:
                technique_fid = results[technique]["fid"]
                improvement = (baseline_fid - technique_fid) / baseline_fid * 100
                technique_improvements.append((technique, improvement))

                print(f"{technique.replace('_', ' ').title()}:")
                print(f"  FID: {technique_fid:.4f}")
                print(f"  Improvement over baseline: {improvement:.2f}%")

        if "all_improvements" in results and results["all_improvements"].get("fid", 999) < 999:
            all_fid = results["all_improvements"]["fid"]
            all_improvement = (baseline_fid - all_fid) / baseline_fid * 100

            print(f"\nAll Improvements Combined:")
            print(f"  FID: {all_fid:.4f}")
            print(f"  Improvement over baseline: {all_improvement:.2f}%")

        print("\n=== Conclusions ===")

        if technique_improvements:
            technique_improvements.sort(key=lambda x: x[1], reverse=True)

            print("Technique effectiveness ranking:")
            for i, (technique, improvement) in enumerate(technique_improvements):
                print(f"  {i+1}. {technique.replace('_', ' ').title()}: {improvement:.2f}% improvement")
    else:
        print("Baseline configuration not available for comparison.")

    print("-" * 80)
    print("Experiment complete!")


In [74]:
config = GANConfig(
    z_dim=64,
    base_size=4,
    image_size=32,
    learning_rate_g=5e-5,
    learning_rate_d=1e-5,
    beta1=0.5,
    beta2=0.999,
    generator_steps=1,
    discriminator_steps=1,
    use_feature_matching=True,
    use_historical_averaging=True,
    use_minibatch_discrimination=True,
    use_wgan_gp=False,
    use_spectral_norm=True,
    feature_matching_weight=0.1,
    historical_averaging_weight=0.01,
    label_smoothing=0.1,
    generator_target_prob=0.9,
    sample_freq=1,
    save_freq=5,
    checkpoint_dir="./checkpoints",
    sample_dir="./samples",
    log_dir="./logs",
    batch_size=32,
    channels=3
)

print("Loading dataset...")
(x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
x_train = x_train[:10000]
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)

def preprocess_image(img):
    img = tf.cast(img, tf.float32)
    img = (img - 127.5) / 127.5
    return img

dataset = train_dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(10000)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
print(f"Dataset prepared with {len(x_train)} images")

os.makedirs(config.sample_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)

Loading dataset...
Dataset prepared with 10000 images


In [77]:
print("\n=== Starting Comparison Experiment ===")
results = compare_transgan_improvements(dataset, base_epochs=3)
analyze_mode_collapse(results, dataset, num_samples=200)
generate_summary_report(results, "cifar10", 3)


=== Starting Comparison Experiment ===
=== TransGAN Improvement Comparison Study ===
FID Evaluator initialized
Computing real data statistics using 500 samples...
Error initializing FID evaluator: isinstance() arg 2 must be a type, a tuple of types, or a union
Continuing without FID evaluation.


=== Training Configuration: baseline ===
Settings:
Improved TransGAN model initialized!


  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 89.04s
Generator Loss: 0.4451
Discriminator Loss: 1.9545
D(x): 0.8309, D(G(z)): 0.9865
Real Acc: 1.0000, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
--------------------------------------------------------------------------------
Epoch 3/3 completed in 89.79s
Generator Loss: 0.3963
Discriminator Loss: 3.0682
D(x): 0.4709, D(G(z)): 0.9070
Real Acc: 0.5156, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000144, D: 0.000032
--------------------------------------------------------------------------------
Epoch 3/3 completed in 90.55s
Generator Loss: 0.3866
Discriminator Loss: 3.7770
D(x): 0.3914, D(G(z)): 0.9261
Real Acc: 0.3646, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000173, D: 0.000026
--------------------------------------------------------------------------------
Epoch 3/3 completed in 91.30s
Generator Loss: 0.3805
Discriminator Loss: 3.6812
D(x): 0.4093, D(G(z)

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 102.17s
Generator Loss: 0.0006
Discriminator Loss: 0.3658
D(x): 0.4660, D(G(z)): 0.4991
Real Acc: 0.0312, Fake Acc: 0.4688
New learning rates - G: 0.000080, D: 0.000060
--------------------------------------------------------------------------------
Epoch 3/3 completed in 102.95s
Generator Loss: -0.0211
Discriminator Loss: 0.3533
D(x): 0.4796, D(G(z)): 0.5528
Real Acc: 0.2188, Fake Acc: 0.2344
--------------------------------------------------------------------------------
Epoch 3/3 completed in 103.73s
Generator Loss: -0.0385
Discriminator Loss: 0.3436
D(x): 0.4804, D(G(z)): 0.5942
Real Acc: 0.2500, Fake Acc: 0.1562
--------------------------------------------------------------------------------
Epoch 3/3 completed in 104.51s
Generator Loss: -0.0456
Discriminator Loss: 0.3344
D(x): 0.4778, D(G(z)): 0.6112
Real Acc: 0.2656, Fake Acc: 0.1172
--------------------------------------------------------------------------------
Epoch 3/3 completed in 105.30s
Generator Lo

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 84.78s
Generator Loss: 0.3691
Discriminator Loss: 2.5534
D(x): 0.4158, D(G(z)): 0.7914
Real Acc: 0.0938, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
New learning rates - G: 0.000096, D: 0.000048
--------------------------------------------------------------------------------
Epoch 3/3 completed in 85.52s
Generator Loss: 0.3538
Discriminator Loss: 3.5535
D(x): 0.3991, D(G(z)): 0.8183
Real Acc: 0.1406, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000115, D: 0.000038
--------------------------------------------------------------------------------
Epoch 3/3 completed in 86.24s
Generator Loss: 0.3912
Discriminator Loss: 3.5173
D(x): 0.4153, D(G(z)): 0.7706
Real Acc: 0.2188, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000138, D: 0.000031
--------------------------------------------------------------------------------
Epoch 3/3 completed in 86.97s
Generator Loss: 0.3749
D

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 86.86s
Generator Loss: 1.1786
Discriminator Loss: 1.0488
D(x): 0.6514, D(G(z)): 0.8226
Real Acc: 0.9688, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
--------------------------------------------------------------------------------
Epoch 3/3 completed in 87.60s
Generator Loss: 1.1783
Discriminator Loss: 1.2845
D(x): 0.5960, D(G(z)): 0.7649
Real Acc: 0.7656, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000144, D: 0.000032
--------------------------------------------------------------------------------
Epoch 3/3 completed in 88.33s
Generator Loss: 1.0792
Discriminator Loss: 1.5519
D(x): 0.4869, D(G(z)): 0.6953
Real Acc: 0.5417, Fake Acc: 0.0208
Adjusting learning rates...
New learning rates - G: 0.000173, D: 0.000026
--------------------------------------------------------------------------------
Epoch 3/3 completed in 89.06s
Generator Loss: 1.1228
Discriminator Loss: 1.7803
D(x): 0.5111, D(G(z)

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 86.20s
Generator Loss: 0.4914
Discriminator Loss: 3.0877
D(x): 0.0709, D(G(z)): 0.9920
Real Acc: 0.0000, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
New learning rates - G: 0.000096, D: 0.000048
--------------------------------------------------------------------------------
Epoch 3/3 completed in 86.92s
Generator Loss: 0.4350
Discriminator Loss: 3.1419
D(x): 0.4947, D(G(z)): 0.8839
Real Acc: 0.5000, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000115, D: 0.000038
--------------------------------------------------------------------------------
Epoch 3/3 completed in 87.66s
Generator Loss: 0.4028
Discriminator Loss: 3.0338
D(x): 0.6311, D(G(z)): 0.8714
Real Acc: 0.6667, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000138, D: 0.000031
--------------------------------------------------------------------------------
Epoch 3/3 completed in 88.37s
Generator Loss: 0.3898
D

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 82.43s
Generator Loss: 1.0495
Discriminator Loss: 1.5508
D(x): 0.2743, D(G(z)): 0.3275
Real Acc: 0.0312, Fake Acc: 1.0000
New learning rates - G: 0.000080, D: 0.000060
--------------------------------------------------------------------------------
Epoch 3/3 completed in 83.16s
Generator Loss: 0.9279
Discriminator Loss: 1.1924
D(x): 0.5123, D(G(z)): 0.3838
Real Acc: 0.5156, Fake Acc: 0.9219
New learning rates - G: 0.000064, D: 0.000072
--------------------------------------------------------------------------------
Epoch 3/3 completed in 83.87s
Generator Loss: 0.9211
Discriminator Loss: 1.0896
D(x): 0.6463, D(G(z)): 0.3856
Real Acc: 0.6771, Fake Acc: 0.9375
New learning rates - G: 0.000051, D: 0.000086
--------------------------------------------------------------------------------
Epoch 3/3 completed in 84.59s
Generator Loss: 1.0670
Discriminator Loss: 1.0033
D(x): 0.7180, D(G(z)): 0.3377
Real Acc: 0.7578, Fake Acc: 0.9531
New learning rates - G: 0.000041, D: 0.

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 81.69s
Generator Loss: 0.4005
Discriminator Loss: 1.4770
D(x): 0.3364, D(G(z)): 0.9773
Real Acc: 0.0625, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
New learning rates - G: 0.000096, D: 0.000048
--------------------------------------------------------------------------------
Epoch 3/3 completed in 82.41s
Generator Loss: 0.4383
Discriminator Loss: 2.2566
D(x): 0.6018, D(G(z)): 0.8221
Real Acc: 0.5312, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000115, D: 0.000038
--------------------------------------------------------------------------------
Epoch 3/3 completed in 83.14s
Generator Loss: 0.4561
Discriminator Loss: 2.4202
D(x): 0.5883, D(G(z)): 0.7656
Real Acc: 0.5625, Fake Acc: 0.0000
Adjusting learning rates...
New learning rates - G: 0.000138, D: 0.000031
--------------------------------------------------------------------------------
Epoch 3/3 completed in 83.87s
Generator Loss: 0.4541
D

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]

  0%|          | 0/312 [00:00<?, ?it/s]



Epoch 3/3 completed in 91.18s
Generator Loss: -0.0014
Discriminator Loss: 0.1949
D(x): 0.5035, D(G(z)): 0.5037
Real Acc: 0.7812, Fake Acc: 0.2812
--------------------------------------------------------------------------------
Epoch 3/3 completed in 91.93s
Generator Loss: -0.0057
Discriminator Loss: 0.2306
D(x): 0.5069, D(G(z)): 0.5154
Real Acc: 0.8750, Fake Acc: 0.1406
--------------------------------------------------------------------------------
Epoch 3/3 completed in 92.69s
Generator Loss: -0.0113
Discriminator Loss: 0.2097
D(x): 0.5107, D(G(z)): 0.5303
Real Acc: 0.9167, Fake Acc: 0.0938
Adjusting learning rates...
New learning rates - G: 0.000120, D: 0.000040
--------------------------------------------------------------------------------
Epoch 3/3 completed in 93.44s
Generator Loss: -0.0167
Discriminator Loss: 0.1878
D(x): 0.5143, D(G(z)): 0.5443
Real Acc: 0.9375, Fake Acc: 0.0703
Adjusting learning rates...
New learning rates - G: 0.000144, D: 0.000032
-------------------------