**Import Libraries and Setup Environment**

In [1]:
import os
import time
import numpy as np
import tensorflow as tf
import json
import itertools
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from typing import Dict, List, Tuple, Optional, Union, Callable
from scipy import linalg
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.decomposition import PCA
from dataclasses import dataclass
from tensorflow.keras import Model

In [2]:
print(f"TensorFlow version: {tf.__version__}")
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {len(gpus) > 0}")
if len(gpus) > 0:
    print(f"GPU Details: {gpus}")

TensorFlow version: 2.18.0
GPU Available: True
GPU Details: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
np.random.seed(42)
tf.random.set_seed(42)

In [4]:
# Configuration class for TransGAN
@dataclass
class GANConfig:
    """Configuration for Improved TransGAN training and architecture"""
    z_dim: int =  128                             # Dimension of the random noise vector input to the generator
    base_size: int = 4                           # Initial spatial resolution (height and width) of the feature map
    learning_rate_g: float = 1e-4                # Learning rate for the generator optimizer
    learning_rate_d: float = 1e-4                # Learning rate for the discriminator optimizer
    beta1: float = 0.5                           # First momentum parameter for Adam optimizer
    beta2: float = 0.999                         # Second momentum parameter for Adam optimizer
    generator_steps: int = 1                     # Number of generator updates per iteration
    discriminator_steps: int = 1                 # Number of discriminator updates per iteration
    use_feature_matching: bool = False           # Whether to use feature matching loss
    use_historical_averaging: bool = False       # Whether to use historical averaging of parameters
    use_minibatch_discrimination: bool = False   # Whether to use minibatch discrimination
    feature_matching_weight: float = 1.0         # Weight for the feature matching loss
    historical_averaging_weight: float = 0.1     # Weight for historical averaging penalty
    label_smoothing: float = 0.0                 # Amount of label smoothing for discriminator
    generator_target_prob: float = 0.9           # Target probability for generator labels
    sample_freq: int = 1                         # Frequency of generating sample images during training
    save_freq: int = 10                          # Frequency of saving model checkpoints
    checkpoint_dir: str = "./checkpoints"        # Directory to save model checkpoints
    sample_dir: str = "./samples"                # Directory to save generated samples
    log_dir: str = "./logs"                      # Directory to save training logs
    batch_size: int = 64                         # Number of samples processed in each training step
    image_size: int = 32                         # Output image size
    channels: int = 3                           # Number of channels in output image

    @classmethod
    def from_json(cls, json_path):
        """Load configuration from a JSON file"""
        with open(json_path, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

    def to_json(self, json_path):
        """Save configuration to a JSON file"""
        config_dict = {k: v for k, v in self.__dict__.items()}
        with open(json_path, 'w') as f:
            json.dump(config_dict, f, indent=2)

    def __str__(self):
        """String representation of the configuration"""
        return '\n'.join(f"{k}: {v}" for k, v in self.__dict__.items())

**Implementing the Core TransGAN Architectures**

This technique addresses the issue that traditional batch normalization makes the output of a layer dependent on all other instances in the same batch. Virtual batch normalization normalizes samples against a fixed reference batch for more stable training.

In [6]:
# Virtual Batch Normalization Layer
class VirtualBatchNormalization(layers.Layer):
    """Virtual Batch Normalization Layer"""

    def __init__(self, epsilon=1e-5, **kwargs):
        super(VirtualBatchNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon
        self.reference_batch_set = False

    def build(self, input_shape):
        self.ndim = len(input_shape)
        shape = [1] * self.ndim
        shape[-1] = input_shape[-1]

        self.gamma = self.add_weight(
            shape=(input_shape[-1],),
            initializer=tf.random_normal_initializer(1.0, 0.02),
            name='gamma',
            trainable=True
        )
        self.beta = self.add_weight(
            shape=(input_shape[-1],),
            initializer=tf.zeros_initializer(),
            name='beta',
            trainable=False
        )

        self.ref_mean = self.add_weight(
            shape=shape,
            initializer=tf.zeros_initializer(),
            name='ref_mean',
            trainable=False
        )
        self.ref_var = self.add_weight(
            shape=shape,
            initializer=tf.ones_initializer(),
            name='ref_var',
            trainable=False
        )

        super(VirtualBatchNormalization, self).build(input_shape)

    def _get_axis(self):
        return list(range(self.ndim - 1))

    def set_reference_batch(self, x):
        axes = self._get_axis()
        mean = tf.reduce_mean(x, axis=axes, keepdims=True)
        var = tf.reduce_mean(tf.square(x - mean), axis=axes, keepdims=True)

        self.ref_mean.assign(mean)
        self.ref_var.assign(var)
        self.reference_batch_set = True

    def call(self, inputs, set_reference=False, **kwargs):
        if set_reference or not self.reference_batch_set:
            axes = self._get_axis()
            ref_mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
            ref_var = tf.reduce_mean(tf.square(inputs - ref_mean), axis=axes, keepdims=True)

            self.ref_mean.assign(ref_mean)
            self.ref_var.assign(ref_var)
            self.reference_batch_set = True

            batch_mean = ref_mean
            batch_var = ref_var
        else:
            axes = self._get_axis()
            batch_mean = tf.reduce_mean(inputs, axis=axes, keepdims=True)
            batch_var = tf.reduce_mean(tf.square(inputs - batch_mean), axis=axes, keepdims=True)

            batch_mean = 0.5 * (batch_mean + self.ref_mean)
            batch_var = 0.5 * (batch_var + self.ref_var)

        batch_var = tf.maximum(batch_var, self.epsilon)

        x_norm = (inputs - batch_mean) / tf.sqrt(batch_var)

        gamma_reshaped = self.gamma
        beta_reshaped = self.beta

        if self.ndim > 2:
            gamma_reshaped = tf.reshape(self.gamma, [1] * (self.ndim - 1) + [self.gamma.shape[0]])
            beta_reshaped = tf.reshape(self.beta, [1] * (self.ndim - 1) + [self.beta.shape[0]])

        return x_norm * gamma_reshaped + beta_reshaped

**EmbeddingProjection Layer**:

The EmbeddingProjection layer is designed to project random noise (latent vectors) into a structured feature space, transforming a 1D vector into a spatial feature map. This transformation is a crucial step in the generator of TransGAN, where convolutional layers are replaced with transformer-based architectures.

In [7]:
class EmbeddingProjection(tf.keras.Model):
    """Projects a latent vector to a spatial grid of embeddings"""
    def __init__(self, in_features, out_features, base_size):
        super(EmbeddingProjection, self).__init__()
        self.in_features = in_features
        self.base_size = base_size
        self.out_features = out_features

        total_output_size = (self.base_size ** 2) * self.out_features
        self.fc = tf.keras.layers.Dense(
            total_output_size,
            input_shape=(in_features,),
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros'
        )

    def call(self, latents, training=False):
        flattened = self.fc(latents)
        batch_size = tf.shape(latents)[0]
        return tf.reshape(flattened, (batch_size, self.base_size ** 2, self.out_features))

**GridAttention Layer**

The GridAttention layer is a self-attention mechanism designed for TransGAN. It follows the multi-head self-attention (MHSA) mechanism, where the input to this layer is a set of patch embeddings, then transformed into query ($Q$), key ($K$), and value ($V$) tensors, and the output is an updated set of embeddings after applying attention which is computed using dot-product interactions. By applying attention within structured feature maps, this layer enables spatially-aware self-attention, improving information flow and enabling the generator to capture local and global dependencies efficiently. Compared with normal multi-head self-attention, the grid attention mechanism limits the calculations to local regions to reduce the computational cost.

In [8]:
class GridAttention(Model):
    """Grid Attention Layer"""
    def __init__(self, embed_dim, num_heads, window_size, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., noise_enabled=False):
        super(GridAttention, self).__init__()

        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5
        self.window_size = window_size
        self.noise_enabled = noise_enabled

        self.qkv = layers.Dense(embed_dim * 3, use_bias=qkv_bias)
        self.attn_drop = layers.Dropout(attn_drop)
        self.proj = layers.Dense(embed_dim)
        self.proj_drop = layers.Dropout(proj_drop)

        if self.noise_enabled:
            self.noise_strength = self.add_weight(shape=(), initializer="zeros", trainable=True)

        if self.window_size:
            coords_h = tf.range(window_size)
            coords_w = tf.range(window_size)
            coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing='ij'))  # shape: (2, Wh, Ww)
            coords_flatten = tf.reshape(coords, (2, -1))  # shape: (2, Wh*Ww)
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # shape: (2, Wh*Ww, Wh*Ww)
            relative_coords = tf.transpose(relative_coords, (1, 2, 0))  # shape: (Wh*Ww, Wh*Ww, 2)
            relative_coords = tf.cast(relative_coords, dtype=tf.int32)
            relative_coords = relative_coords + (window_size - 1)
            relative_coords = relative_coords[:, :, 0] * (2 * window_size - 1) + relative_coords[:, :, 1]

            self.relative_position_index = tf.Variable(relative_coords, trainable=False, dtype=tf.int32)

            bias_table_shape = ((2 * window_size - 1) * (2 * window_size - 1), num_heads)
            self.relative_position_bias_table = self.add_weight(shape=bias_table_shape, initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02), trainable=True)

    def call(self, patch_embeddings, training=False):
        batch_size, n_tokens, embedding_dim = tf.shape(patch_embeddings)[0], tf.shape(patch_embeddings)[1], tf.shape(patch_embeddings)[2]

        if self.noise_enabled and training:
            noise = tf.random.normal((batch_size, n_tokens, 1)) * self.noise_strength
            patch_embeddings = patch_embeddings + noise

        qkv = self.qkv(patch_embeddings)  # (batch_size, n_tokens, 3 * embedding_dim)
        qkv = tf.reshape(qkv, (batch_size, n_tokens, 3, self.num_heads, embedding_dim // self.num_heads))
        qkv = tf.transpose(qkv, (2, 0, 3, 1, 4))  # (3, batch_size, num_heads, n_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = tf.matmul(q, k, transpose_b=True) * self.scale

        if self.window_size:
            try:
                relative_position_bias_index = tf.reshape(self.relative_position_index, [-1])
                relative_position_bias = tf.gather(self.relative_position_bias_table, relative_position_bias_index)
                relative_position_bias = tf.reshape(relative_position_bias, (self.window_size * self.window_size, self.window_size * self.window_size, -1))
                relative_position_bias = tf.transpose(relative_position_bias, (2, 0, 1))  # (num_heads, Wh*Ww, Wh*Ww)
                attn = attn + tf.expand_dims(relative_position_bias, axis=0)  # Broadcast across batch_size
            except Exception as e:
                print(f"Warning: Error in window attention: {e}")
                # Continue without relative position bias if there's an error

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        aggregated = tf.matmul(attn, v)
        aggregated = tf.transpose(aggregated, (0, 2, 1, 3))  # (batch_size, n_tokens, num_heads, head_dim)
        aggregated = tf.reshape(aggregated, (batch_size, n_tokens, embedding_dim))

        aggregated = self.proj(aggregated)
        aggregated = self.proj_drop(aggregated, training=training)

        return aggregated

The GridAttention Layer above works by:

1. Projects a sequence of patch embeddings (input)  into Query ($Q$), Key ($K$), and Value ($V$) tensors using a fully connected layer (Dense) to compute $Q$, $K$, and $V$ simultaneously.

2. Computes attention scores using scaled dot-product attention
   
   2.1 Performs matrix multiplication between $Q$ and $K$ to compute raw attention scores.

   2.2 Incorporates relative positional bias (if window_size is specified).

3. Retrieves learnable positional biases and adds them to the attention scores

4. Applies softmax normalization to obtain attention weights and uses attention weights to aggregate values across all patches.

5. Applies final projection and dropout

**EncoderBlock Layer**

The EncoderBlock is a transformer-based processing unit designed to refine patch embeddings using self-attention and feedforward layers. It operates as a basic building block for deep vision models, like GridAttention. By applying attention-based refinement and feedforward transformations, this layer helps encode spatial and contextual relationships, making it a crucial component for deep vision transformers.

In [9]:
class EncoderBlock(Model):
    """Transformer Encoder Block"""
    def __init__(self, d_model, n_heads, d_feedforward, window_size, dropout_rate=0, activation=None, noise_enabled=False):
        super(EncoderBlock, self).__init__()
        self.attention = GridAttention(d_model, n_heads, window_size, noise_enabled=noise_enabled)
        self.feedforward_dense1 = layers.Dense(d_feedforward)
        self.activation = activation if callable(activation) else tf.nn.gelu
        self.dropout1 = layers.Dropout(dropout_rate)
        self.feedforward_dense2 = layers.Dense(d_model)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, patch_embeddings, training=False):
        attn_out = self.attention(self.norm1(patch_embeddings), training=training)
        patch_embeddings = patch_embeddings + attn_out
        ff_out = self.feedforward_dense1(self.norm2(patch_embeddings))
        ff_out = self.activation(ff_out)
        ff_out = self.dropout1(ff_out, training=training)
        ff_out = self.feedforward_dense2(ff_out)

        patch_embeddings = patch_embeddings + ff_out
        return patch_embeddings

This layer is structured similarly to a Transformer encoder:

1. First, applies self-attention to model relationships between patches.

2. Second, passes the embeddings through a feedforward network (MLP) for additional transformation.

3. Third, incorporates Layer Normalization (LayerNorm) and residual connections, ensuring stable gradient flow and improved convergence.

4. Forth, Optional dropout and activation functions provide regularization and non-linearity, enhancing generalization.

The input is a set of patch embeddings, and the output is an updated set of embeddings that have been refined through attention and feedforward transformations.

**StageBlock Layer**

The StageBlock serves as a higher-level processing unit in a transformer-based architecture. It is composed of multiple EncoderBlock layers, stacked together to progressively refine patch embeddings. Each EncoderBlock in the StageBlock applies self-attention and feedforward transformations, helping the model capture both local and global dependencies. By stacking multiple encoder blocks, StageBlock allows deeper feature extraction and hierarchical representation learning.

The input is a set of patch embeddings, and the output is a more refined set of embeddings after passing through multiple EncoderBlock layers.

In [10]:
class StageBlock(Model):
    """Transformer Stage Block"""
    def __init__(self, depth, num_heads, d_embeddings, d_ratio, window_size, activation=None, noise_enabled=False):
        super(StageBlock, self).__init__()
        self.blocks = [
            EncoderBlock(
                d_model=d_embeddings,
                n_heads=num_heads,
                d_feedforward=d_embeddings * d_ratio,
                window_size=window_size,
                activation=activation,
                noise_enabled=noise_enabled
            )
            for _ in range(depth)
        ]

    def call(self, patch_embeddings, training=False):
        for block in self.blocks:
            patch_embeddings = block(patch_embeddings, training=training)
        return patch_embeddings

StageBlock Lay works by:

1. Receives a sequence of patch embeddings (patch_embeddings) as input

2. Initializes multiple EncoderBlock layers
   
   2.1 Each block consists of: (1) Multi-head self-attention (via GridAttention). (2) Feedforward transformations (MLP). (3) Residual connections and normalization for stable training.

3. Sequentially processes embeddings through multiple encoder blocks

4. Outputs the final refined embeddings which retain hierarchical and contextual information

In [11]:
def pixel_shuffle(input_tensor, scale_factor):
    """Rearranges elements in a tensor from depth to spatial dimensions"""
    batch_size, height, width, channels = input_tensor.shape

    if channels % (scale_factor ** 2) != 0:
        raise ValueError("Channels must be divisible by scale_factor^2.")

    new_channels = channels // (scale_factor ** 2)
    reshaped = tf.reshape(input_tensor, (batch_size, height, width, scale_factor, scale_factor, new_channels))
    transposed = tf.transpose(reshaped, [0, 1, 3, 2, 4, 5])
    output = tf.reshape(transposed, (batch_size, height * scale_factor, width * scale_factor, new_channels))

    return output

**Resampling Layer**

The Resampling layer is designed to adjust the spatial resolution of feature maps in TransGAN. It can perform both upsampling and downsampling, depending on the scale_factor. By dynamically adjusting feature map resolution, this layer allows the generator to process multi-scale representations effectively, ensuring smooth feature transitions during upsampling and downsampling.

If scale_factor > 1, it upsamples the feature map using Pixel Shuffle, which redistributes channel data into spatial dimensions.

If scale_factor < 1, it downsamples the feature map using Average Pooling, reducing spatial resolution while preserving information.

This layer ensures that the generator maintains spatial consistency as features are refined and upsampled to higher resolutions.

In [12]:
class Resampling(layers.Layer):
    """Upsamples or downsamples feature maps"""
    def __init__(self, scale_factor):
        super(Resampling, self).__init__()
        self.scale_factor = scale_factor
        self.is_upsampling = scale_factor > 1

        if not self.is_upsampling:
            self.resampling = layers.AveragePooling2D(pool_size=int(1 / scale_factor))

    def call(self, embeddings, size=None):
        """Upsamples or downsamples based on scale factor"""
        batch_size = tf.shape(embeddings)[0]
        embedding_dim = embeddings.shape[-1]

        if isinstance(size, tf.Tensor):
            size = tf.get_static_value(size)

            if size is None:
                size = tf.cast(tf.sqrt(tf.cast(tf.shape(embeddings)[1], tf.float32)), tf.int32)

        if self.is_upsampling:
            tf.debugging.assert_equal(
                embedding_dim % (self.scale_factor ** 2),
                0,
                message=f"Embedding dim {embedding_dim} must be divisible by scale_factor^2 ({self.scale_factor ** 2})!"
            )

            reduced_embedding_dim = embedding_dim // (self.scale_factor ** 2)
        else:
            reduced_embedding_dim = embedding_dim * (self.scale_factor ** 2)

        feature_maps = tf.reshape(embeddings, (batch_size, size, size, embedding_dim))

        if self.is_upsampling:

            resampled = pixel_shuffle(feature_maps, self.scale_factor)
            new_size = size * self.scale_factor
        else:

            resampled = self.resampling(feature_maps)
            new_size = size // int(1 / self.scale_factor)

        n_tokens_new = new_size * new_size

        output = tf.reshape(resampled, (batch_size, n_tokens_new, -1))

        return output, new_size

Resampling Lay works by:

1. Receives a set of patch embeddings and a scale_factor as input and Determines whether to upsample or downsample
   
   1.1 If scale_factor > 1, $Pixel Shuffle$ function is used to increase resolution which can converts the channel dimension into spatial resolution. Then the output resolution increases from ($H$, $W$) to ($H$ * $scale\_ factor$, $W$ * $scale\_ factor$)
   
   1.2 If scale_factor < 1, $Average Pooling$ is used to reduce resolution which can reduce resolution from ($H$, $W$) to ($H$ / $scale_factor$, $W$ / $scale_factor$). Besides, layers.AveragePooling2D will be used to preserve feature integrity while reducing size

2. Reshapes the output to maintain token structure.
   
   2.1 The spatially adjusted feature map is reshaped back into a structured token sequence.

   2.2 The number of tokens updates based on the new spatial size.

**Generator**

The Generator model is designed as a transformer-based generator for image synthesis, following the structure of TransGAN. Unlike traditional CNN-based GANs, this generator leverages self-attention mechanisms instead of convolutional layers to progressively refine feature maps. By leveraging transformers, progressive upsampling, and spatial self-attention, this generator enables high-quality image synthesis while preserving long-range dependencies and global coherence.

The key components of this generator include:

1. Embedding Projection: Maps a random latent vector to an initial spatial feature map.

2. Virtual Batch Normalization (VBN): Stabilizes feature distributions across batches.

3. Positional Encoding: Provides spatial awareness to the self-attention mechanism.

3. Transformer Encoder Blocks: Captures global dependencies and refines feature representations.

4. Resampling Layer: Upsamples feature maps progressively to increase resolution.

5. Output Projection: Uses a Conv2D layer to generate the final image, applying a tanh activation to normalize pixel values.

This generator takes a random latent vector as input and outputs a synthetic image with values in the range [-1, 1].

In [13]:
class Generator(Model):
    """Transformer-based generator"""
    def __init__(self, config):
        super(Generator, self).__init__()

        self.base_size = config.base_size
        self.embed_dim = 1024
        self.output_size = config.image_size
        self.n_colors = config.channels
        self.z_dim = config.z_dim

        self.n_upsamples = int(np.log2(self.output_size // self.base_size))

        depths = [5, 4, 2]
        if len(depths) < self.n_upsamples:
            depths = list(depths) + [2] * (self.n_upsamples - len(depths))

        self.embedding_projection = EmbeddingProjection(
            self.z_dim,
            self.embed_dim,
            self.base_size
        )

        self.vbn = VirtualBatchNormalization()
        self.positional_embeddings = []
        for index in range(self.n_upsamples + 1):
            size = self.base_size * (2 ** index)
            dim = self.embed_dim // (4 ** index) if index > 0 else self.embed_dim
            pos_embed = self.add_weight(
                shape=(1, size ** 2, dim),
                initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
                trainable=True,
                name=f'pos_embedding_{index}'
            )
            self.positional_embeddings.append(pos_embed)

        self.upsampling = Resampling(scale_factor=2)
        self.stages = []
        for index in range(self.n_upsamples + 1):
            depth = depths[index] if index < len(depths) else 2
            dim = self.embed_dim // (4 ** index) if index > 0 else self.embed_dim

            stage = StageBlock(
                depth=depth,
                num_heads=4,
                d_embeddings=dim,
                d_ratio=4,
                window_size=self.base_size * (2 ** index)
            )
            self.stages.append(stage)

        self.output_projection = layers.Conv2D(
            self.n_colors,
            kernel_size=1,
            activation="tanh",
            data_format="channels_last"
        )

        self.reference_batch = None

    def call(self, latents, training=False, set_reference=False):
        patch_embeddings = self.embedding_projection(latents)
        size = self.base_size

        if set_reference or self.reference_batch is None:
            self.reference_batch = patch_embeddings

        patch_embeddings = self.vbn(patch_embeddings, set_reference=set_reference)

        for index, stage in enumerate(self.stages):
            if index < len(self.positional_embeddings):
                patch_embeddings += self.positional_embeddings[index]

            patch_embeddings = stage(patch_embeddings, training=training)

            if index < self.n_upsamples:
                patch_embeddings, size = self.upsampling(patch_embeddings, size=size)

        batch_size = tf.shape(patch_embeddings)[0]
        feature_maps = tf.reshape(patch_embeddings, (batch_size, size, size, -1))

        output = self.output_projection(feature_maps)
        output = tf.clip_by_value(output, -1.0, 1.0)

        return output

In [14]:
class MinibatchDiscrimination(layers.Layer):
    """Minibatch Discrimination layer to prevent mode collapse"""
    def __init__(self, num_kernels=100, dim_per_kernel=5):
        super(MinibatchDiscrimination, self).__init__()
        self.num_kernels = num_kernels
        self.dim_per_kernel = dim_per_kernel

    def build(self, input_shape):
        self.W = self.add_weight(
            shape=(input_shape[-1], self.num_kernels * self.dim_per_kernel),
            initializer="random_normal",
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.W)
        x = tf.reshape(x, (-1, self.num_kernels, self.dim_per_kernel))

        diff = tf.expand_dims(x, axis=3) - tf.expand_dims(tf.transpose(x, [1, 2, 0]), axis=0)
        abs_diff = tf.reduce_sum(tf.abs(diff), axis=2)
        minibatch_features = tf.reduce_sum(tf.exp(-abs_diff), axis=2)

        return tf.concat([inputs, minibatch_features], axis=1)

**TransformerBlock**

The TransformerBlock is an encoder module inspired by the Transformer architecture, primarily used in Vision Transformers (ViTs), TransGAN, and NLP models. It processes input embeddings through self-attention and feedforward layers, refining feature representations while maintaining global contextual dependencies. By leveraging self-attention, feedforward transformations, and residual connections, the TransformerBlock allows the model to learn complex feature relationships, making it an essential building block for vision transformers, TransGAN, and other deep learning architectures. The TransformerBlock is crucial for enabling the generator or other transformer-based architectures to effectively model long-range dependencies.

In [15]:
class TransformerBlock(layers.Layer):
    """Transformer Encoder Block for Discriminator"""
    def __init__(self, embed_dim, num_heads, ff_dim=512, dropout_rate=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.norm1 = layers.LayerNormalization()
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="gelu"),
            layers.Dense(embed_dim)
        ])
        self.norm2 = layers.LayerNormalization()
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs, training=False):
        attn_output = self.attention(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.norm2(out1 + ffn_output)

This block contains:

1. Multi-Head Self-Attention (MHA): Captures relationships between input tokens.

2. Feedforward Network (FFN): Enhances feature representations after attention.

3. Layer Normalization (LayerNorm): Stabilizes training by normalizing activations.

4. Residual Connections: Helps retain original information and improve gradient flow.

5. Dropout Regularization: Prevents overfitting by adding noise during training.

**TransformerDiscriminator Layer**

The TransformerDiscriminator is a transformer-based discriminator designed to distinguish between real and generated images. Unlike traditional CNN-based discriminators, this model processes images using self-attention mechanisms instead of convolutional layers, allowing it to capture global dependencies and long-range interactions. By leveraging self-attention, transformer-based processing, and minibatch discrimination, the TransformerDiscriminator provides a powerful alternative to traditional CNN-based discriminators, making it highly effective in GAN architectures like TransGAN.

This discriminator follows a Vision Transformer (ViT)-like structure, consisting of:

1. Patch Embedding Layer: Converts input images into a sequence of patch tokens.

2. Transformer Encoder Blocks: Processes the patch embeddings using self-attention.

3. Minibatch Discrimination (optional): Helps prevent mode collapse by introducing diversity-sensitive features.

4. Normalization & MLP Classification Head: Outputs a real/fake probability score using a sigmoid activation function.

The model takes an image as input and outputs a single probability indicating whether the input is real or fake.

In [16]:
class TransformerDiscriminator(tf.keras.Model):
    """Transformer-based Discriminator"""
    def __init__(self, config):
        super(TransformerDiscriminator, self).__init__()

        self.img_size = config.image_size
        self.patch_size = 4
        self.use_minibatch = config.use_minibatch_discrimination
        self.embed_dim = 512
        self.patch_embedding = layers.Conv2D(
            self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="same"
        )
        self.flatten = layers.Reshape((-1, self.embed_dim))

        depth = 4
        self.transformer_blocks = [
            TransformerBlock(self.embed_dim, num_heads=4)
            for _ in range(depth)
        ]

        if self.use_minibatch:
            self.minibatch_layer = MinibatchDiscrimination(num_kernels=100, dim_per_kernel=5)
            final_dim = self.embed_dim + 100
        else:
            final_dim = self.embed_dim

        self.norm = layers.LayerNormalization()
        self.mlp_head = layers.Dense(1, activation="sigmoid")

    def call(self, inputs, training=False, return_features=False):
        batch_size = tf.shape(inputs)[0]
        inputs = (inputs - 0.5) * 2
        x = self.patch_embedding(inputs)
        x = tf.reshape(x, [batch_size, -1, self.embed_dim])

        for block in self.transformer_blocks:
            x = block(x, training=training)

        features = tf.reduce_mean(x, axis=1)

        if self.use_minibatch:
            features = self.minibatch_layer(features)

        x = self.norm(features)
        output = self.mlp_head(x)

        if return_features:
            return output, features
        else:
            return output

Feature Matching & Historical Averaging

In [17]:
# Feature Matching
class FeatureMatching:
    def __call__(self, real_features, fake_features):
        real_mean = tf.reduce_mean(real_features, axis=0)
        fake_mean = tf.reduce_mean(fake_features, axis=0)
        return tf.reduce_mean(tf.square(real_mean - fake_mean))

# Historical Averaging
class HistoricalAveraging:
    def __init__(self, beta=0.99):
        self.beta = beta
        self.parameter_history = {}

    def initialize_if_needed(self, model):
        """Initialize parameter history for a model"""
        for weight in model.trainable_weights:
            if weight.name not in self.parameter_history:
                self.parameter_history[weight.name] = weight.numpy()

    def __call__(self, model, weight=0.01):
        """Compute historical averaging loss for regularization"""
        if not self.parameter_history:
            return tf.constant(0.0)

        total_loss = 0.0
        for curr_weight in model.trainable_weights:
            name = curr_weight.name
            if name not in self.parameter_history:
                continue

            hist_tensor = tf.convert_to_tensor(self.parameter_history[name], dtype=curr_weight.dtype)

            if hist_tensor.shape != curr_weight.shape:
                continue

            diff = curr_weight - hist_tensor
            loss = tf.reduce_sum(tf.square(diff))
            total_loss += loss

        return weight * total_loss

    def update_history(self, model):
        """Update parameter history after training steps"""
        for weight in model.trainable_weights:
            name = weight.name
            if name in self.parameter_history:
                hist_tensor = tf.convert_to_tensor(self.parameter_history[name], dtype=weight.dtype)
                if hist_tensor.shape == weight.shape:
                    self.parameter_history[name] = (
                        self.beta * hist_tensor.numpy() +
                        (1 - self.beta) * weight.numpy()
                    )

WGAN-GP loss

In [18]:
class WGANGP:
    def __init__(self, discriminator, lambda_gp=10.0):
        self.discriminator = discriminator
        self.lambda_gp = lambda_gp

    def gradient_penalty(self, real_images, fake_images):
        """Calculate gradient penalty for WGAN-GP"""

        real_images = tf.image.resize(real_images, (fake_images.shape[1], fake_images.shape[2]))

        batch_size = tf.shape(real_images)[0]
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = alpha * real_images + (1 - alpha) * fake_images

        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            interpolated_output = self.discriminator(interpolated, training=True)

        gradients = tape.gradient(interpolated_output, interpolated)
        gradients_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-8)
        gradient_penalty = tf.reduce_mean((gradients_norm - 1.0) ** 2)

        return gradient_penalty

    def discriminator_loss(self, real_images, fake_images):
        """WGAN-GP discriminator loss"""
        real_images = tf.image.resize(real_images, (fake_images.shape[1], fake_images.shape[2]))

        real_output = self.discriminator(real_images, training=True)
        fake_output = self.discriminator(fake_images, training=True)
        real_loss = tf.reduce_mean(real_output)
        fake_loss = tf.reduce_mean(fake_output)
        gp = self.gradient_penalty(real_images, fake_images)

        return fake_loss - real_loss + self.lambda_gp * gp

    def generator_loss(self, fake_images):
        """WGAN-GP generator loss"""
        fake_output = self.discriminator(fake_images, training=True)
        return -tf.reduce_mean(fake_output)

**ImprovedTransGAN Model**

The ImprovedTransGAN class implements a Transformer-based Generative Adversarial Network (GAN) with several enhancements for stability and performance. Unlike traditional convolutional GANs, this model leverages self-attention mechanisms to capture long-range dependencies in image generation. By incorporating these techniques, ImprovedTransGAN achieves more stable training, higher-quality image synthesis, and improved generalization compared to traditional GANs.

TransGAN Architecture:

1. GridAttention-Based Generator: Uses transformers(GridAttention Block layers) instead of convolutional layers for better spatial understanding.

2. Transformer-Based Discriminator: Processes image patches as tokens for superior global feature extraction.


This implementation includes three major improvements over the TransGAN:

1. Feature Matching – Improves training stability by aligning real and generated feature distributions.

2. Historical Averaging – Stabilizes parameter updates by maintaining a history of model weights.

3. Virtual Batch Normalization (VBN) – Normalizes activations using a fixed reference batch for improved generalization.

In [19]:
class ImprovedTransGAN:
    """Improved TransGAN Model

    This implementation fixes shape issues and provides improved stability through:
    1. Fixed embedding projection with proper shape management
    2. Error handling in attention mechanisms
    3. Robust historical averaging
    4. Proper parameter initialization
    5. WGAN-GP loss option for improved stability
    """

    def __init__(self, config):
        self.config = config

        self.generator = Generator(config)
        self.discriminator = TransformerDiscriminator(config)

        self.feature_matching = FeatureMatching()
        self.historical_averaging = HistoricalAveraging(beta=0.99)

        os.makedirs(config.sample_dir, exist_ok=True)
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        os.makedirs(config.log_dir, exist_ok=True)

        self.gen_optimizer = tf.keras.optimizers.Adam(
            learning_rate=config.learning_rate_g,
            beta_1=config.beta1,
            beta_2=config.beta2
        )

        self.disc_optimizer = tf.keras.optimizers.Adam(
            learning_rate=config.learning_rate_d,
            beta_1=config.beta1,
            beta_2=config.beta2
        )

        self.fixed_noise = tf.random.normal([32, config.z_dim])

        self.gen_losses = []
        self.disc_losses = []
        self.real_scores = []
        self.fake_scores = []
        self.real_accs = []
        self.fake_accs = []

        print("TransGAN model initialized!")

    def generator_loss(self, fake_output, real_features=None, fake_features=None):
        """Standard generator loss with optional feature matching"""
        target = tf.ones_like(fake_output) * self.config.generator_target_prob
        gen_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(target, fake_output, from_logits=False)
        )

        if self.config.use_feature_matching and real_features is not None and fake_features is not None:
            fm_loss = self.feature_matching(real_features, fake_features)
            gen_loss += self.config.feature_matching_weight * fm_loss

        if self.config.use_historical_averaging:
            ha_loss = self.historical_averaging(self.generator, self.config.historical_averaging_weight)
            gen_loss += ha_loss

        return gen_loss

    def discriminator_loss(self, real_output, fake_output):
        """Standard discriminator loss with label smoothing"""
        real_labels = tf.ones_like(real_output) * (1.0 - self.config.label_smoothing)
        fake_labels = tf.zeros_like(fake_output)

        real_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(real_labels, real_output, from_logits=False)
        )
        fake_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(fake_labels, fake_output, from_logits=False)
        )
        disc_loss = real_loss + fake_loss

        if self.config.use_historical_averaging:
            ha_loss = self.historical_averaging(self.discriminator, self.config.historical_averaging_weight)
            disc_loss += ha_loss

        return disc_loss

    @tf.function
    def train_step(self, real_images):
        """Single training step"""
        batch_size = tf.shape(real_images)[0]

        for _ in range(self.config.discriminator_steps):
            with tf.GradientTape() as disc_tape:

                noise = tf.random.normal([batch_size, self.config.z_dim])
                fake_images = self.generator(noise, training=True)
                real_output, real_features = self.discriminator(real_images, training=True, return_features=True)
                fake_output, fake_features = self.discriminator(fake_images, training=True, return_features=True)
                d_loss = self.discriminator_loss(real_output, fake_output)

            disc_gradients = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.disc_optimizer.apply_gradients(zip(disc_gradients, self.discriminator.trainable_variables))

            if self.config.use_historical_averaging:
                tf.py_function(
                    func=lambda: self.historical_averaging.update_history(self.discriminator),
                    inp=[],
                    Tout=[]
                )

        for _ in range(self.config.generator_steps):
            with tf.GradientTape() as gen_tape:
                noise = tf.random.normal([batch_size, self.config.z_dim])
                fake_images = self.generator(noise, training=True)
                fake_output, fake_features = self.discriminator(fake_images, training=True, return_features=True)
                _, real_features = self.discriminator(real_images, training=False, return_features=True)
                g_loss = self.generator_loss(fake_output, real_features, fake_features)

            gen_gradients = gen_tape.gradient(g_loss, self.generator.trainable_variables)
            self.gen_optimizer.apply_gradients(zip(gen_gradients, self.generator.trainable_variables))

            if self.config.use_historical_averaging:
                tf.py_function(
                    func=lambda: self.historical_averaging.update_history(self.generator),
                    inp=[],
                    Tout=[]
                )

        real_sigmoid = real_output
        fake_sigmoid = fake_output

        real_acc = tf.reduce_mean(tf.cast(real_sigmoid > 0.5, tf.float32))
        fake_acc = tf.reduce_mean(tf.cast(fake_sigmoid < 0.5, tf.float32))

        real_score = tf.reduce_mean(real_sigmoid)
        fake_score = tf.reduce_mean(fake_sigmoid)

        return {
            "gen_loss": g_loss,
            "disc_loss": d_loss,
            "real_score": real_score,
            "fake_score": fake_score,
            "real_acc": real_acc,
            "fake_acc": fake_acc
        }

    def train(self, dataset, epochs):
        """Train the GAN model for multiple epochs"""
        start_time = time.time()

        if self.config.use_historical_averaging:
            self.historical_averaging.initialize_if_needed(self.generator)
            self.historical_averaging.initialize_if_needed(self.discriminator)

        for epoch in range(epochs):
            epoch_start = time.time()

            progress_bar = tqdm(total=len(list(dataset)))
            progress_bar.set_description(f"Epoch {epoch+1}/{epochs}")

            epoch_gen_losses = []
            epoch_disc_losses = []
            epoch_real_scores = []
            epoch_fake_scores = []
            epoch_real_accs = []
            epoch_fake_accs = []

            for batch in dataset:
                metrics = self.train_step(batch)

                if self.config.use_historical_averaging:
                    self.historical_averaging.update_history(self.generator)
                    self.historical_averaging.update_history(self.discriminator)

                epoch_gen_losses.append(metrics["gen_loss"].numpy())
                epoch_disc_losses.append(metrics["disc_loss"].numpy())
                epoch_real_scores.append(metrics["real_score"].numpy())
                epoch_fake_scores.append(metrics["fake_score"].numpy())
                epoch_real_accs.append(metrics["real_acc"].numpy())
                epoch_fake_accs.append(metrics["fake_acc"].numpy())

                progress_bar.update(1)
                desc = f"Epoch {epoch+1}/{epochs} - "
                desc += f"G: {metrics['gen_loss']:.4f}, D: {metrics['disc_loss']:.4f}, "
                desc += f"D(x): {metrics['real_score']:.4f}, D(G(z)): {metrics['fake_score']:.4f}"
                progress_bar.set_description(desc)

            progress_bar.close()

            avg_gen_loss = np.mean(epoch_gen_losses)
            avg_disc_loss = np.mean(epoch_disc_losses)
            avg_real_score = np.mean(epoch_real_scores)
            avg_fake_score = np.mean(epoch_fake_scores)
            avg_real_acc = np.mean(epoch_real_accs)
            avg_fake_acc = np.mean(epoch_fake_accs)

            self.gen_losses.append(avg_gen_loss)
            self.disc_losses.append(avg_disc_loss)
            self.real_scores.append(avg_real_score)
            self.fake_scores.append(avg_fake_score)
            self.real_accs.append(avg_real_acc)
            self.fake_accs.append(avg_fake_acc)

            if (epoch + 1) % self.config.sample_freq == 0:
                self.generate_and_save_images(epoch + 1)

            if (epoch + 1) % self.config.save_freq == 0:
                self.save_checkpoint(epoch + 1)

            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
            print(f"Generator Loss: {avg_gen_loss:.4f}")
            print(f"Discriminator Loss: {avg_disc_loss:.4f}")
            print(f"D(x): {avg_real_score:.4f}, D(G(z)): {avg_fake_score:.4f}")
            print(f"Real Acc: {avg_real_acc:.4f}, Fake Acc: {avg_fake_acc:.4f}")
            print("-" * 80)

        total_time = time.time() - start_time
        print(f"Training completed in {total_time/60:.2f} minutes")

        self.plot_training_history()

    def generate_and_save_images(self, epoch):
        """Generate and save sample images"""
        predictions = self.generator(self.fixed_noise, training=False)

        fig, axes = plt.subplots(4, 4, figsize=(8, 8))

        for i in range(16):
            plt.subplot(4, 4, i+1)
            img = (predictions[i].numpy() + 1) / 2.0
            plt.imshow(img)
            plt.axis("off")

        plt.suptitle(f"Generated Images - Epoch {epoch}")
        plt.tight_layout()
        plt.savefig(f"{self.config.sample_dir}/epoch_{epoch}.png")
        plt.close()

    def plot_training_history(self):
        """Plot training metrics history"""
        epochs = range(1, len(self.gen_losses) + 1)

        plt.figure(figsize=(15, 10))
        plt.subplot(2, 2, 1)
        plt.plot(epochs, self.gen_losses, label="Generator Loss")
        plt.plot(epochs, self.disc_losses, label="Discriminator Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Generator and Discriminator Loss")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 2)
        plt.plot(epochs, self.real_scores, label="D(x) - Real")
        plt.plot(epochs, self.fake_scores, label="D(G(z)) - Fake")
        plt.xlabel("Epochs")
        plt.ylabel("Score")
        plt.title("Discriminator Scores")
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 3)
        plt.plot(epochs, self.real_accs, label='Real Accuracy')
        plt.plot(epochs, self.fake_accs, label='Fake Accuracy')
        plt.title('Discriminator Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 2, 4)
        predictions = self.generator(self.fixed_noise, training=False)

        grid_size = 4
        image_size = predictions.shape[1]
        composite_image = np.zeros((grid_size * image_size, grid_size * image_size, 3))

        for i in range(grid_size):
            for j in range(grid_size):
                idx = i * grid_size + j
                if idx < predictions.shape[0]:
                    img = (predictions[idx].numpy() + 1) / 2.0
                    composite_image[i * image_size:(i + 1) * image_size, j * image_size:(j + 1) * image_size, :] = img

        plt.imshow(composite_image)
        plt.axis('off')
        plt.title('Generated Images')

        plt.suptitle('GAN Training Progress', fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(f"{self.config.log_dir}/training_history.png")
        plt.close()

    def save_checkpoint(self, epoch):
        """Save model checkpoint"""
        checkpoint_dir = os.path.join(self.config.checkpoint_dir, f"checkpoint_epoch_{epoch}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.generator.save_weights(os.path.join(checkpoint_dir, "generator.weights.h5"))
        self.discriminator.save_weights(os.path.join(checkpoint_dir, "discriminator.weights.h5"))

    def load_checkpoint(self, epoch):
        """Load model checkpoint"""
        checkpoint_dir = os.path.join(self.config.checkpoint_dir, f"checkpoint_epoch_{epoch}")

        self.generator.load_weights(os.path.join(checkpoint_dir, "generator.weights.h5"))
        self.discriminator.load_weights(os.path.join(checkpoint_dir, "discriminator.weights.h5"))

        print(f"Checkpoint loaded from epoch {epoch}")

The ImprovedTransGAN Model Runs like these:

🔹 Step 1: Model Initialization

Sets up the generator, discriminator, optimizers, additional stabilization techniques. And loads advanced training features.

1.6 Creates directories for model checkpoints and sample images

1.7 Generates fixed noise samples for evaluation during training.

🔹 Step 2: Training Loop Execution

1. Processes real images in mini-batches.

2. Performs a discriminator update: (1) Uses multi-head self-attention to analyze image patches. (2)Includes Minibatch Discrimination (optional) to prevent mode collapse.

3. Performs multiple generator updates (generator_steps): (1)Uses self-attention instead of CNN layers. (2)Projects latent noise into a structured spatial feature map. (3)Incorporates Virtual Batch Normalization (VBN) for training stability.

4. Configures advanced training techniques: (1)Feature Matching: Ensures feature distributions of real and fake images are similar. (2)Historical Averaging: Uses past model weights to stabilize learning.

5. Sets up optimizers (Adam) for both networks: (1)Separate learning rates for generator and discriminator. (2)Uses momentum parameters (beta1, beta2) for stable updates.

6. Logs training performance.

7. Generates sample images every sample_freq epochs.

8. Saves checkpoints every save_freq epochs.

🔹 Step 3: Evaluating Model Performance

1. Plots training loss curves and discriminator confidence scores.

2. Generates synthetic images to monitor improvements.

**FIDEvaluator**

The FIDEvaluator class is designed to evaluate the performance of a Generative Adversarial Network (GAN) by calculating the Fréchet Inception Distance (FID). The FID score is a widely used metric to measure the similarity between real and generated images, helping assess the quality of synthetic data. By leveraging InceptionV3 to extract meaningful image features, it provides a robust and reliable measure of image realism. This metric is fundamental in improving and benchmarking GAN-based image generation models.

The key idea behind FID is to compare the distributions of real and generated images in a high-level feature space extracted using a pre-trained InceptionV3 model.

In [20]:
class FIDEvaluator:
    """Evaluates GAN performance using Fréchet Inception Distance (FID)"""
    def __init__(self, config):
        self.config = config
        self.batch_size = config.batch_size

        self.inception_model = tf.keras.applications.InceptionV3(
            include_top=False,
            pooling='avg',
            weights='imagenet',
            input_shape=(299, 299, 3)
        )

        self.real_features = None
        self.real_mean = None
        self.real_cov = None

        print("FID Evaluator initialized")

    def preprocess_images(self, images):
        """Preprocess images for InceptionV3 model"""
        images = (images + 1) / 2.0
        images = tf.image.resize(images, (299, 299))
        images = tf.keras.applications.inception_v3.preprocess_input(images * 255.0)
        return images

    def extract_features(self, images):
        """Extract features using InceptionV3 model"""
        features = self.inception_model.predict(images, batch_size=self.batch_size)
        return features

    def compute_real_statistics(self, dataset, num_samples=10000):
        """Compute statistics for real images"""
        print(f"Computing real data statistics using {num_samples} samples...")

        real_images = []
        for batch in dataset:
            real_images.extend([img.numpy() for img in batch])
            if len(real_images) >= num_samples:
                break

        real_images = np.array(real_images[:num_samples])
        processed_images = self.preprocess_images(real_images)

        features = []
        for i in range(0, len(processed_images), self.batch_size):
            batch = processed_images[i:i+self.batch_size]
            batch_features = self.extract_features(batch)
            features.append(batch_features)

        self.real_features = np.concatenate(features, axis=0)
        self.real_mean = np.mean(self.real_features, axis=0)
        self.real_cov = np.cov(self.real_features, rowvar=False)

        print(f"Real data statistics calculated from {len(self.real_features)} images")

        return self.real_mean, self.real_cov

    def calculate_fid(self, generator, num_samples=10000):
        """Calculate FID score for generated images"""
        if self.real_mean is None or self.real_cov is None:
            raise ValueError("Real statistics not computed. Run compute_real_statistics first.")

        print(f"Calculating FID score using {num_samples} generated samples...")

        z_dim = self.config.z_dim
        batch_size = self.batch_size
        num_batches = int(np.ceil(num_samples / batch_size))

        fake_images = []
        for i in range(num_batches):
            current_batch_size = min(batch_size, num_samples - i * batch_size)
            z = tf.random.normal([current_batch_size, z_dim])
            generated_batch = generator(z, training=False)
            fake_images.append(generated_batch)

        fake_images = np.concatenate(fake_images, axis=0)[:num_samples]
        processed_images = self.preprocess_images(fake_images)

        fake_features = []
        for i in range(0, len(processed_images), batch_size):
            batch = processed_images[i:i+batch_size]
            batch_features = self.extract_features(batch)
            fake_features.append(batch_features)

        fake_features = np.concatenate(fake_features, axis=0)
        fake_mean = np.mean(fake_features, axis=0)
        fake_cov = np.cov(fake_features, rowvar=False)
        mean_diff_squared = np.sum((self.real_mean - fake_mean) ** 2)
        covmean = linalg.sqrtm(self.real_cov.dot(fake_cov))

        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = mean_diff_squared + np.trace(self.real_cov + fake_cov - 2 * covmean)

        print(f"FID Score: {fid:.4f} (lower is better)")

        return fid

In [21]:
def run_transgan_experiment(config, dataset, epochs=10, evaluate_every=5):
    """Run a TransGAN experiment with FID evaluation"""
    print("=== Starting TransGAN Experiment ===")
    print(f"Configuration: {config}")
    print(f"Training for {epochs} epochs, evaluating every {evaluate_every} epochs")

    fid_evaluator = FIDEvaluator(config)
    fid_evaluator.compute_real_statistics(dataset, num_samples=1000)

    gan = ImprovedTransGAN(config)
    fid_scores = []

    os.makedirs(config.log_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"\n=== Epoch {epoch+1}/{epochs} ===")
        gan.train(dataset, 1)

        if (epoch + 1) % evaluate_every == 0 or epoch == epochs - 1:
            print(f"Evaluating FID after epoch {epoch+1}")
            fid = fid_evaluator.calculate_fid(gan.generator, num_samples=1000)
            fid_scores.append((epoch + 1, fid))

    print("\n=== Final Evaluation ===")
    final_fid = fid_evaluator.calculate_fid(gan.generator, num_samples=2000)

    if len(fid_scores) > 0:
        plt.figure(figsize=(10, 6))
        epochs_, scores = zip(*fid_scores)
        plt.plot(epochs_, scores, 'o-', linewidth=2)
        plt.title('FID Score Over Training (Lower is Better)')
        plt.xlabel('Epochs')
        plt.ylabel('FID Score')
        plt.grid(True, alpha=0.3)
        plt.savefig(f"{config.log_dir}/fid_history.png")
        plt.close()

    final_samples = 36
    noise = tf.random.normal([final_samples, config.z_dim])
    generated_images = gan.generator(noise, training=False)

    plt.figure(figsize=(10, 10))
    for i in range(final_samples):
        plt.subplot(6, 6, i+1)
        img = (generated_images[i].numpy() + 1) / 2.0
        plt.imshow(img)
        plt.axis('off')
    plt.suptitle(f'Final Generated Samples (FID: {final_fid:.4f})')
    plt.tight_layout()
    plt.savefig(f"{config.log_dir}/final_samples.png")
    plt.close()

    return gan, fid_scores, final_fid

In [22]:
# Function to run comparison experiments
def compare_transgan_improvements(dataset, base_epochs=5):
    """Compare different TransGAN improvement techniques"""
    print("=== TransGAN Improvement Comparison Study ===")

    configs = [
        {
            "name": "baseline",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "feature_matching",
            "updates": {
                "use_feature_matching": True,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "historical_avg",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": True,
                "use_minibatch_discrimination": False,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "minibatch_disc",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": True,
                "label_smoothing": 0.0
            }
        },
        {
            "name": "label_smoothing",
            "updates": {
                "use_feature_matching": False,
                "use_historical_averaging": False,
                "use_minibatch_discrimination": False,
                "label_smoothing": 0.1
            }
        },
        {
            "name": "all_improvements",
            "updates": {
                "use_feature_matching": True,
                "use_historical_averaging": True,
                "use_minibatch_discrimination": True,
                "label_smoothing": 0.1
            }
        }
    ]

    base_config = GANConfig()
    fid_evaluator = FIDEvaluator(base_config)
    fid_evaluator.compute_real_statistics(dataset, num_samples=1000)
    results = {}

    for config_info in configs:
        name = config_info["name"]
        updates = config_info["updates"]

        print(f"\n\n=== Training Configuration: {name} ===")
        print("Settings:")

        config = GANConfig()
        for k, v in updates.items():
            setattr(config, k, v)

        config.sample_dir = f"samples/{name}"
        config.checkpoint_dir = f"checkpoints/{name}"
        config.log_dir = f"logs/{name}"

        os.makedirs(config.sample_dir, exist_ok=True)
        os.makedirs(config.checkpoint_dir, exist_ok=True)
        os.makedirs(config.log_dir, exist_ok=True)

        gan = ImprovedTransGAN(config)
        gan.train(dataset, base_epochs)

        fid_score = fid_evaluator.calculate_fid(gan.generator, num_samples=1000)

        results[name] = {
            "fid": fid_score,
            "gan": gan,
            "config": config,
            "gen_loss": gan.gen_losses[-1] if gan.gen_losses else None,
            "disc_loss": gan.disc_losses[-1] if gan.disc_losses else None,
            "real_score": gan.real_scores[-1] if gan.real_scores else None,
            "fake_score": gan.fake_scores[-1] if gan.fake_scores else None
        }

        np.savez(
            f"{config.log_dir}/metrics.npz",
            gen_losses=np.array(gan.gen_losses),
            disc_losses=np.array(gan.disc_losses),
            real_scores=np.array(gan.real_scores),
            fake_scores=np.array(gan.fake_scores),
            fid=np.array([fid_score])
        )

    visualize_comparison_results(results)

    return results

In [23]:
# Visualization functions
def visualize_comparison_results(results):
    """Create comparative visualizations of different configurations"""
    print("Creating comparative visualizations...")

    default_colors = itertools.cycle(['gray', 'blue', 'green', 'orange', 'purple', 'red', 'cyan', 'magenta'])
    colors = {name: next(default_colors) for name in results}

    # --- FID Bar Chart ---
    plt.figure(figsize=(12, 6))

    names = list(results.keys())
    fid_scores = [results[name]["fid"] for name in names]

    sorted_indices = np.argsort(fid_scores)
    sorted_names = [names[i] for i in sorted_indices]
    sorted_fids = [fid_scores[i] for i in sorted_indices]
    sorted_colors = [colors[name] for name in sorted_names]

    plt.barh(sorted_names, sorted_fids, color=sorted_colors)
    plt.title('FID Score Comparison (Lower is Better)')
    plt.xlabel('FID Score')
    plt.grid(axis='x', alpha=0.3)

    for i, v in enumerate(sorted_fids):
        plt.text(v + 0.5, i, f"{v:.2f}", va='center')

    plt.tight_layout()
    plt.savefig("fid_comparison.png", dpi=150)
    plt.close()

    # --- Loss Curve Comparison ---
    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    for name in results:
        if hasattr(results[name]["gan"], "gen_losses"):
            plt.plot(results[name]["gan"].gen_losses, label=name, color=colors[name], linewidth=2)

    plt.title('Generator Loss Comparison')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    for name in results:
        if hasattr(results[name]["gan"], "disc_losses"):
            plt.plot(results[name]["gan"].disc_losses, label=name, color=colors[name], linewidth=2)

    plt.title('Discriminator Loss Comparison')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.suptitle('Training Loss Comparison Across Techniques', fontsize=16)
    plt.tight_layout()
    plt.savefig("loss_comparison.png", dpi=150)
    plt.close()

    # --- Discriminator Score Comparison ---
    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    for name in results:
        if hasattr(results[name]["gan"], "real_scores"):
            plt.plot(results[name]["gan"].real_scores, label=name, color=colors[name], linewidth=2)

    plt.title('D(x) - Real Score Comparison')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    for name in results:
        if hasattr(results[name]["gan"], "fake_scores"):
            plt.plot(results[name]["gan"].fake_scores, label=name, color=colors[name], linewidth=2)

    plt.title('D(G(z)) - Fake Score Comparison')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.suptitle('Discriminator Score Comparison', fontsize=16)
    plt.tight_layout()
    plt.savefig("score_comparison.png", dpi=150)
    plt.close()

    for name, result in results.items():
        config = result["config"]
        z_dim = config.z_dim
        fixed_noise = tf.random.normal([16, z_dim])

        generated_images = result["gan"].generator(fixed_noise, training=False)

        plt.figure(figsize=(8, 8))
        plt.suptitle(f"{name.replace('_', ' ').title()} (FID: {result['fid']:.2f})", fontsize=14)

        for j in range(16):
            plt.subplot(4, 4, j+1)
            img = (generated_images[j].numpy() + 1) / 2.0
            plt.imshow(img)
            plt.axis('off')

        save_path = os.path.join(config.log_dir, f"samples_{name}.png")
        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
        plt.close()

    print("\nSummary of Results:")
    print("-" * 100)
    print(f"{'Configuration':<20} {'FID Score':<12} {'Gen Loss':<12} {'Disc Loss':<12} {'D(x)':<12} {'D(G(z))':<12}")
    print("-" * 100)

    sorted_results = sorted(results.items(), key=lambda x: x[1]["fid"])

    for name, result in sorted_results:
        print(f"{name:<20} {result['fid']:<12.4f} ", end="")

        if result.get('gen_loss') is not None:
            print(f"{result['gen_loss']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('disc_loss') is not None:
            print(f"{result['disc_loss']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('real_score') is not None:
            print(f"{result['real_score']:<12.4f} ", end="")
        else:
            print(f"{'N/A':<12} ", end="")

        if result.get('fake_score') is not None:
            print(f"{result['fake_score']:<12.4f}")
        else:
            print(f"{'N/A':<12}")

    print("-" * 100)

In [24]:
# Mode collapse analysis
def analyze_mode_collapse(results, dataset, num_samples=1000):
    """Analyze mode collapse across configurations using feature diversity"""
    print("Analyzing mode collapse across configurations...")

    real_images = []
    for batch in dataset:
        for img in batch:
            real_images.append(img.numpy())
        if len(real_images) >= num_samples:
            break

    real_images = np.array(real_images[:num_samples])

    inception_model = tf.keras.applications.InceptionV3(
        include_top=False,
        pooling='avg',
        weights='imagenet'
    )

    def preprocess_images(images):
        images = (images + 1) / 2.0
        images = tf.image.resize(images, (299, 299))
        images = tf.keras.applications.inception_v3.preprocess_input(images * 255.0)
        return images

    processed_real = preprocess_images(real_images)
    real_features = inception_model.predict(processed_real, batch_size=32, verbose=0)

    feature_stats = {}
    for name, result in results.items():
        config = result["config"]
        z = tf.random.normal([num_samples, config.z_dim])
        fake_images = result["gan"].generator(z, training=False).numpy()

        processed_fake = preprocess_images(fake_images)
        fake_features = inception_model.predict(processed_fake, batch_size=32, verbose=0)

        feature_stats[name] = {
            "features": fake_features,
            "mean": np.mean(fake_features, axis=0),
            "std": np.std(fake_features, axis=0),
            "min": np.min(fake_features, axis=0),
            "max": np.max(fake_features, axis=0)
        }

    feature_diversity = {}
    for name, stats in feature_stats.items():
        avg_std = np.mean(stats["std"])
        avg_range = np.mean(stats["max"] - stats["min"])

        feature_diversity[name] = {
            "avg_std": avg_std,
            "avg_range": avg_range
        }

    real_std = np.mean(np.std(real_features, axis=0))
    real_range = np.mean(np.max(real_features, axis=0) - np.min(real_features, axis=0))
    plt.figure(figsize=(12, 6))

    names = list(feature_diversity.keys())
    avg_stds = [feature_diversity[name]["avg_std"] for name in names]

    sorted_indices = np.argsort(avg_stds)[::-1]
    sorted_names = [names[i] for i in sorted_indices]
    sorted_stds = [avg_stds[i] for i in sorted_indices]

    colors = {
        'baseline': 'gray',
        'feature_matching': 'blue',
        'minibatch_disc': 'green',
        'historical_avg': 'purple',
        'label_smoothing': 'orange',
        'all_improvements': 'red'
    }
    sorted_colors = [colors.get(name, 'skyblue') for name in sorted_names]

    plt.bar(sorted_names, sorted_stds, color=sorted_colors)
    plt.axhline(y=real_std, color='black', linestyle='--', label=f'Real Data ({real_std:.4f})')

    plt.title('Feature Diversity Comparison (Higher is Better)')
    plt.ylabel('Average Feature Standard Deviation')
    plt.legend()
    plt.grid(axis='y', alpha=0.3)

    for i, v in enumerate(sorted_stds):
        plt.text(i, v + 0.001, f"{v:.4f}", ha='center')

    plt.tight_layout()
    plt.savefig("feature_diversity.png", dpi=150)
    plt.close()

    print("\nMode Collapse Analysis:")
    print("-" * 80)
    print(f"{'Configuration':<20} {'Feature Std':<15} {'% of Real':<15} {'Mode Collapse':<15}")
    print("-" * 80)

    sorted_results = sorted(feature_diversity.items(), key=lambda x: x[1]["avg_std"], reverse=True)

    for name, stats in sorted_results:
        avg_std = stats["avg_std"]
        pct_of_real = (avg_std / real_std) * 100

        if pct_of_real >= 90:
            collapse_status = "Minimal"
        elif pct_of_real >= 70:
            collapse_status = "Minor"
        elif pct_of_real >= 50:
            collapse_status = "Moderate"
        elif pct_of_real >= 30:
            collapse_status = "Significant"
        else:
            collapse_status = "Severe"

        print(f"{name:<20} {avg_std:<15.4f} {pct_of_real:<15.2f}% {collapse_status:<15}")

    print("-" * 80)
    print(f"Real Data Reference: {real_std:.4f}")
    print("-" * 80)

    return feature_diversity

In [25]:
# Main experiment runner
def run_complete_experiment(epochs=10, dataset_name='cifar10'):
    """Run complete TransGAN experiment with various configurations"""
    print(f"=== Starting Complete TransGAN Experiment ===")
    print(f"Dataset: {dataset_name}")
    print(f"Training epochs: {epochs}")

    config = GANConfig(
        z_dim=128,
        base_size=4,
        image_size=32,
        learning_rate_g=1e-4,
        learning_rate_d=1e-4,
        beta1=0.5,
        beta2=0.999,
        generator_steps=1,
        discriminator_steps=1,
        use_feature_matching=False,
        use_historical_averaging=False,
        use_minibatch_discrimination=False,
        feature_matching_weight=1.0,
        historical_averaging_weight=0.1,
        label_smoothing=0.0,
        generator_target_prob=0.9,
        sample_freq=1,
        save_freq=10,
        checkpoint_dir="./checkpoints",
        sample_dir="./samples",
        log_dir="./logs",
        batch_size=64,
        channels=3
    )

    if dataset_name == 'cifar10':

        (x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
        x_train = x_train[:int(0.2 * len(x_train))]
        dataset = tf.data.Dataset.from_tensor_slices(x_train)

        def preprocess_image(img):
            img = tf.cast(img, tf.float32)
            img = (img - 127.5) / 127.5
            return img

        dataset = dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.shuffle(10000)
        dataset = dataset.batch(config.batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    start_time = time.time()
    results = compare_transgan_improvements(dataset, base_epochs=epochs)
    analyze_mode_collapse(results, dataset)
    total_time = time.time() - start_time
    print(f"\nTotal experiment time: {total_time/60:.2f} minutes")
    generate_summary_report(results, dataset_name, epochs)

    return results

In [26]:
# Generate summary report
def generate_summary_report(results, dataset_name, epochs):
    """Generate a comprehensive summary report of experiment results"""
    print("\n=== Comprehensive Experiment Summary ===")
    print(f"Dataset: {dataset_name}")
    print(f"Training epochs per configuration: {epochs}")
    print("-" * 80)

    best_config = min(results.items(), key=lambda x: x[1]["fid"])
    worst_config = max(results.items(), key=lambda x: x[1]["fid"])

    print(f"Best configuration: {best_config[0]} (FID: {best_config[1]['fid']:.4f})")
    print(f"Worst configuration: {worst_config[0]} (FID: {worst_config[1]['fid']:.4f})")
    print(f"Improvement: {(worst_config[1]['fid'] - best_config[1]['fid']) / worst_config[1]['fid'] * 100:.2f}%")

    print("\n=== Impact of Individual Techniques ===")

    baseline_fid = results["baseline"]["fid"]

    techniques = [
        "feature_matching",
        "minibatch_disc",
        "historical_avg",
        "label_smoothing"
    ]

    technique_improvements = []

    for technique in techniques:
        if technique in results:
            technique_fid = results[technique]["fid"]
            improvement = (baseline_fid - technique_fid) / baseline_fid * 100
            technique_improvements.append((technique, improvement))

            print(f"{technique.replace('_', ' ').title()}:")
            print(f"  FID: {technique_fid:.4f}")
            print(f"  Improvement over baseline: {improvement:.2f}%")

    if "all_improvements" in results:
        all_fid = results["all_improvements"]["fid"]
        all_improvement = (baseline_fid - all_fid) / baseline_fid * 100

        print(f"\nAll Improvements Combined:")
        print(f"  FID: {all_fid:.4f}")
        print(f"  Improvement over baseline: {all_improvement:.2f}%")

    print("\n=== Conclusions ===")

    technique_improvements.sort(key=lambda x: x[1], reverse=True)

    print("Technique effectiveness ranking:")
    for i, (technique, improvement) in enumerate(technique_improvements):
        print(f"  {i+1}. {technique.replace('_', ' ').title()}: {improvement:.2f}% improvement")

    print("-" * 80)
    print("Experiment complete!")

In [28]:
# Create configuration for the model
config = GANConfig(
    z_dim=128,
    base_size=4,
    image_size=32,
    learning_rate_g=1e-4,
    learning_rate_d=1e-4,
    beta1=0.5,
    beta2=0.999,
    generator_steps=1,
    discriminator_steps=1,
    use_feature_matching=False,
    use_historical_averaging=False,
    use_minibatch_discrimination=False,
    feature_matching_weight=1.0,
    historical_averaging_weight=0.1,
    label_smoothing=0.0,
    generator_target_prob=0.9,
    sample_freq=1,
    save_freq=10,
    checkpoint_dir="./checkpoints",
    sample_dir="./samples",
    log_dir="./logs",
    batch_size=64,
    channels=3
)

model = ImprovedTransGAN(config)
(x_train, _), (_, _) = tf.keras.datasets.cifar10.load_data()
x_train = x_train[:int(0.2 * len(x_train))]
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)

def preprocess_image(img):
    img = tf.cast(img, tf.float32)
    img = (img - 127.5) / 127.5
    return img

dataset = train_dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(10000)
dataset = dataset.batch(config.batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

model.train(dataset, epochs=10)

noise = tf.random.normal([16, config.z_dim])
images = model.generator(noise, training=False)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


TransGAN model initialized!
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 0us/step


  0%|          | 0/156 [00:00<?, ?it/s]



Epoch 1/10 completed in 105.32s
Generator Loss: 4.5171
Discriminator Loss: 0.5116
D(x): 0.8967, D(G(z)): 0.0397
Real Acc: 0.9105, Fake Acc: 0.9820
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 2/10 completed in 44.89s
Generator Loss: 6.3445
Discriminator Loss: 0.0365
D(x): 0.9902, D(G(z)): 0.0032
Real Acc: 0.9928, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 3/10 completed in 44.91s
Generator Loss: 8.2195
Discriminator Loss: 0.0004
D(x): 0.9998, D(G(z)): 0.0001
Real Acc: 0.9999, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 4/10 completed in 44.87s
Generator Loss: 8.1648
Discriminator Loss: 0.0004
D(x): 0.9998, D(G(z)): 0.0001
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 5/10 completed in 45.56s
Generator Loss: 9.4377
Discriminator Loss: 0.0001
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 6/10 completed in 44.80s
Generator Loss: 9.9682
Discriminator Loss: 0.0000
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 7/10 completed in 44.87s
Generator Loss: 10.3802
Discriminator Loss: 0.0000
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 8/10 completed in 44.76s
Generator Loss: 10.5922
Discriminator Loss: 0.0000
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 9/10 completed in 44.81s
Generator Loss: 10.9098
Discriminator Loss: 0.0000
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------


  0%|          | 0/156 [00:00<?, ?it/s]

Epoch 10/10 completed in 45.71s
Generator Loss: 11.0591
Discriminator Loss: 0.0000
D(x): 1.0000, D(G(z)): 0.0000
Real Acc: 1.0000, Fake Acc: 1.0000
--------------------------------------------------------------------------------
Training completed in 8.51 minutes


In [30]:
results = {
    "baseline": {
        "gan": model,
        "config": config,
        "gen_loss": model.gen_losses[-1] if model.gen_losses else None,
        "disc_loss": model.disc_losses[-1] if model.disc_losses else None,
        "real_score": model.real_scores[-1] if model.real_scores else None,
        "fake_score": model.fake_scores[-1] if model.fake_scores else None
    }
}

# Create and initialize FID evaluator with smaller batch size
fid_evaluator = FIDEvaluator(config)
fid_evaluator.batch_size = 16  # Smaller batch size to avoid memory issues

# Compute real statistics with fewer samples
print("Computing real data statistics...")
fid_evaluator.compute_real_statistics(dataset, num_samples=200)

# Calculate FID with smaller sample size
print("Calculating FID score...")
fid = fid_evaluator.calculate_fid(model.generator, num_samples=200)
results["baseline"]["fid"] = fid

# Use existing visualization function
print("Creating visualizations...")
visualize_comparison_results(results)

# Use modified mode collapse analysis with smaller sample size
def analyze_mode_collapse_small(results, dataset, num_samples=50):
    """Modified version with fewer samples to avoid memory issues"""
    return analyze_mode_collapse(results, dataset, num_samples)

# Call the modified function
print("Analyzing mode collapse...")
feature_diversity = analyze_mode_collapse_small(results, dataset)

# Generate summary report
print("Generating summary report...")
generate_summary_report(results, "cifar10", 10)

FID Evaluator initialized
Computing real data statistics...
Computing real data statistics using 200 samples...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 7s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 54ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 52ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step
[1m1/1[0m [32m━━━━━━━━━