<a href="https://colab.research.google.com/github/AiJared/stylegan/blob/main/StyleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import keras

In [3]:
from google.colab import drive
# Mount Google Drive into this Colab VM's filesystem at /content/MyDrive.
# NOTE(review): Colab typically exposes the Drive root as a "MyDrive" subfolder
# of the mount point, so files likely live under /content/MyDrive/MyDrive/...;
# confirm against any dataset paths used later in the notebook.
drive.mount('/content/MyDrive')

Mounted at /content/MyDrive


In [None]:
class StyleGAN:
  def __init__(self, img_shape=(256, 256, 3), latent_dim=512, n_styles=18):
    """
    Construct every sub-network of the model.

    Args:
      img_shape (tuple): (height, width, channels) of images. The
        discriminator is built for this input shape and the synthesis
        network's final convolution emits img_shape[2] channels.
      latent_dim (int): Width of the mapping network's Dense layers and of
        the synthesis network's input vector.
      n_styles (int): Number of style layers. Stored on the instance; not
        read by any method visible in this section.
    """
    self.img_shape, self.latent_dim, self.n_styles = img_shape, latent_dim, n_styles

    # Individual networks come first: build_generator reads mapping_network
    # and synthesis_network, and build_adversarial_model presumably reads
    # discriminator and generator, so this order must be kept.
    self.mapping_network = self.build_mapping_network()
    self.synthesis_network = self.build_synthesis_network()
    self.discriminator = self.build_discriminator()

    # Composite models assembled from the pieces above.
    self.generator = self.build_generator()
    self.adversarial_model = self.build_adversarial_model()

  def build_mapping_network(self):
    """
    Build the mapping network that turns an input noise vector into an
    intermediate latent vector w.

    The network is three fully connected layers, each latent_dim units wide
    with ReLU activation. No input size is declared, so Keras infers it on
    the first call.

    Returns:
      keras.Sequential: The (not yet built) mapping network.
    """
    mapper = keras.Sequential()
    for _ in range(3):
      mapper.add(keras.layers.Dense(self.latent_dim, activation='relu'))
    return mapper

  def build_style_block(self, out_channels, upsample=True, noise_stddev=0.1):
    """
    Build one synthesis block: [2x upsampling] -> 3x3 conv -> additive
    Gaussian noise -> LeakyReLU(0.2).

    Note: this block does NOT apply the style vector. There is no AdaIN or
    style modulation here; it only injects random noise after the
    convolution. The noise is resampled on every call, in training and
    inference alike, because the Lambda layer has no training switch.

    Args:
      out_channels (int): Number of filters in the 3x3 convolution.
      upsample (bool): If True, prepend an UpSampling2D layer that doubles
        height and width.
      noise_stddev (float): Scale applied to standard-normal noise added
        after the convolution. Defaults to 0.1, the previous hard-coded value.

    Returns:
      keras.Sequential: The style block.
    """
    block = keras.Sequential()

    if upsample:
      block.add(keras.layers.UpSampling2D(size=(2, 2)))

    block.add(keras.layers.Conv2D(out_channels, 3, padding='same'))

    def add_noise(x):
      # Sample noise in the feature map's dtype so the sum also works when x
      # is float16 (mixed precision); for float32 this matches the default.
      return x + tf.random.normal(tf.shape(x), dtype=x.dtype) * noise_stddev

    block.add(keras.layers.Lambda(add_noise))
    block.add(keras.layers.LeakyReLU(0.2))

    return block
  def build_synthesis_network(self):
    """
    Synthesis network: generate an image from an intermediate latent vector.

    A Dense layer projects the (latent_dim,) input to a 4x4x512 feature map.
    Style blocks then each double the spatial size until it reaches
    img_shape's height/width. A final 1x1 convolution with tanh maps the
    features to img_shape[2] channels with values in [-1, 1].

    The number of style blocks is derived from img_shape so that the output
    matches the input shape the discriminator is built for. A fixed stack of
    four blocks would always yield 64x64 images, which mismatches, for
    example, the default 256x256 img_shape.

    Returns:
      keras.Sequential: Model mapping (batch, latent_dim) inputs to
      (batch, height, width, img_shape[2]) images.

    Raises:
      ValueError: If img_shape is not square, or its side is not 4 * 2**k
        for some integer k >= 0 (e.g. 64, 128, 256).
    """
    height, width = self.img_shape[0], self.img_shape[1]
    if height != width:
      raise ValueError(f"img_shape must be square, got {self.img_shape}")

    scale, remainder = divmod(height, 4)
    # scale must be a positive power of two, since each block doubles the
    # 4x4 starting map.
    if remainder or scale < 1 or scale & (scale - 1):
      raise ValueError(
          f"img_shape side must be 4 * 2**k (e.g. 64, 128, 256), got {height}")
    n_blocks = scale.bit_length() - 1  # log2(height / 4)

    # Filter counts of the original 4-block (64x64) stack, kept so 64x64
    # models are unchanged. Larger outputs reuse the last (narrowest) width
    # for their extra high-resolution blocks.
    base_channels = (256, 128, 64, 16)
    channels = [base_channels[min(i, len(base_channels) - 1)]
                for i in range(n_blocks)]

    layers = [
        # starting block
        keras.layers.Dense(4 * 4 * 512, input_shape=(self.latent_dim,)),
        keras.layers.Reshape((4, 4, 512)),
    ]
    # Style blocks: 4x4 -> 8x8 -> ... -> height x width
    layers.extend(self.build_style_block(c) for c in channels)
    # Final 1x1 convolution to match image channels
    layers.append(keras.layers.Conv2D(self.img_shape[2], 1, activation='tanh'))

    return keras.Sequential(layers)

  def build_discriminator(self):
    """
    Build the discriminator that scores images as real or generated.

    The network has three 4x4 convolutions with stride 2 and 'same' padding
    (64, 128 and 256 filters), each halving height and width. The second and
    third are batch-normalized, and every convolution is followed by
    LeakyReLU(0.2). The features are then flattened into one sigmoid unit.

    Returns:
      keras.Sequential: Model mapping img_shape images to a (batch, 1)
      score in (0, 1).
    """
    stack = [
        keras.layers.Conv2D(64, 4, strides=2, padding='same',
                            input_shape=self.img_shape),
        keras.layers.LeakyReLU(0.2),
    ]

    # Deeper downsampling stages share the conv -> BN -> LeakyReLU pattern.
    for n_filters in (128, 256):
      stack.append(keras.layers.Conv2D(n_filters, 4, strides=2, padding='same'))
      stack.append(keras.layers.BatchNormalization())
      stack.append(keras.layers.LeakyReLU(0.2))

    # Classification head.
    stack.append(keras.layers.Flatten())
    stack.append(keras.layers.Dense(1, activation='sigmoid'))

    return keras.Sequential(stack)

  def build_generator(self):
    """
    Full Generator: Combines mapping and synthesis networks.
    """

    def generator(noise):
      # transform noise through mapping network
      w = self.mapping_network(noise)

      # generate image through synthesis network
      image = self.synthesis_network(w)
      return image