<a href="https://colab.research.google.com/github/ArthurKuhn-prog/MA-SenaDiff/blob/main/MA_senaDiff_DDIMGEN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
##importing all the libraries needed

import math
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

# Mount Google Drive (Colab only) so the dataset, checkpoints and sample
# images stored under /content/drive/MyDrive can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Hyperparameters and configuration, grouped by what they control.

# --- Optimisation ---
batch_size = 32
num_epochs = 50
learning_rate = 2e-4

# --- Diffusion process ---
total_timesteps = 1000

# --- Image geometry and pixel value range after rescaling ---
img_height = 64
img_width = 64
img_channels = 3
clip_min = -1.0
clip_max = 1.0

# --- U-Net architecture ---
norm_groups = 8
first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
# Channel count of each resolution level of the U-Net.
widths = [mult * first_conv_channels for mult in channel_multiplier]
# One flag per level: attention only at the two lowest resolutions.
has_attention = [False, False, True, True]
num_res_blocks = 2

# --- tfds dataset settings (kept for the alternative tfds.load loader) ---
dataset_name = "oxford_flowers102"
splits = ["train"]

In [None]:
# Load the training images from Google Drive.
# - label_mode=None: images only, no labels (single unlabeled class).
# - batch_size=None: yields individual images; batching happens later in the
#   tf.data preprocessing pipeline.
# - image_size reuses the configured model resolution instead of repeating it.
ds = keras.utils.image_dataset_from_directory(
  '/content/drive/MyDrive/1.MelodiaAtomizacji/senaSet/',
  label_mode=None,
  batch_size=None,
  image_size=(img_height, img_width),
  shuffle=True,
  # NOTE(review): validation_split=0.2 with subset='validation' keeps only 20%
  # of the files (2821 of 14109 in the recorded run), and that subset is what
  # the model is trained on. Use subset='training' or drop the split entirely
  # if the goal is to train on more of the data.
  validation_split=0.2,
  subset='validation',
  seed=8111994,  # a seed is required when using validation_split
)

#(ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)

Found 14109 files belonging to 1 classes.
Using 2821 files for validation.


In [None]:
#Let's start working on the DS

def augment(img):
    """Data augmentation: randomly mirror ``img`` left/right."""
    flipped = tf.image.random_flip_left_right(img)
    return flipped

def resize_and_rescale(img, size):
    """Center-crop an image to a square, resize it and rescale it to [-1.0, 1.0].

    Args:
        img: Image tensor indexed as (height, width, channels). The rescaling
            below assumes pixel values in [0, 255].
        size: Desired (height, width) of the output image.

    Returns:
        A float32 image of spatial size ``size`` with values clipped to
        [clip_min, clip_max].
    """
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    crop_size = tf.minimum(height, width)

    # Keep the largest centered square so the resize does not distort the
    # aspect ratio. The result must be assigned back to `img`; otherwise the
    # uncropped image is resized.
    img = tf.image.crop_to_bounding_box(
        img,
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # Resize the image
    img = tf.cast(img, dtype=tf.float32)
    img = tf.image.resize(img, size=size, antialias=True)

    # Map [0, 255] -> [-1, 1], then clip to guard against resize overshoot.
    img = img / 127.5 - 1.0
    img = tf.clip_by_value(img, clip_min, clip_max)
    return img

def train_preprocessing(x):
    """Training transform: crop/resize/rescale to the model resolution, then augment."""
    target_size = (img_height, img_width)
    return augment(resize_and_rescale(x, size=target_size))

# Training input pipeline, built step by step from the raw dataset `ds`:
# preprocess images in parallel, group them into fixed-size batches (the last
# partial batch is dropped), shuffle the order of those batches with a small
# buffer, and prefetch to overlap input preparation with training.
train_ds = ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.batch(batch_size, drop_remainder=True)
train_ds = train_ds.shuffle(batch_size * 2)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)



In [None]:
class GaussianDiffusion:
    """Gaussian diffusion utility: a linear beta schedule plus helpers for the
    forward (noising) and reverse (denoising) steps.

    Args:
        beta_start: Start value of the scheduled variance.
        beta_end: End value of the scheduled variance.
        timesteps: Number of time steps in the forward process.
        clip_min: Lower bound applied to the reconstructed x_0 when
            ``clip_denoised=True``.
        clip_max: Upper bound applied to the reconstructed x_0 when
            ``clip_denoised=True``.
    """

    def __init__(
        self,
        beta_start=1e-4,
        beta_end=0.02,
        timesteps=1000,
        clip_min=-1.0,
        clip_max=1.0,
    ):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max

        # Linear variance schedule. All derived quantities are computed in
        # float64 for precision and cast to float32 only when stored.
        betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 at t = 0.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Coefficients of the forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )
        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )

        # Coefficients used to recover x_0 from x_t and the predicted noise.
        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # The posterior variance is 0 at t = 0 (the start of the chain), so it
        # is clipped before taking the log.
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )
        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` for each batch element and reshape it to
        [batch_size, 1, 1, 1] so it broadcasts against 4-D image batches.

        Args:
            a: Tensor to extract from (one value per timestep).
            t: Timestep for which the coefficients are to be extracted, one
                per batch element.
            x_shape: Shape of the current batched samples; only the batch
                dimension is used.
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Return mean, variance and log-variance of q(x_t | x_0).

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: Current timestep.
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Diffuse the data: sample x_t from q(x_t | x_0) with the given noise.

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: Current timestep.
            noise: Gaussian noise to be added at the current timestep.
        Returns:
            Diffused samples at timestep ``t``.
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Reconstruct x_0 from x_t and the (predicted) noise at timestep ``t``."""
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Compute the mean and variance of the diffusion
        posterior q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Starting point (sample) for the posterior computation.
            x_t: Sample at timestep ``t``.
            t: Current timestep.
        Returns:
            Posterior mean, variance and clipped log-variance at ``t``.
        """
        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        """Mean/variance of the reverse step p(x_{t-1} | x_t).

        x_0 is first reconstructed from ``x`` and ``pred_noise``, optionally
        clipped to [clip_min, clip_max], and then plugged into the posterior.
        """
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Sample x_{t-1} from the diffusion model.

        Args:
            pred_noise: Noise predicted by the diffusion model.
            x: Samples at a given timestep for which the noise was predicted.
            t: Current timestep, one per batch element.
            clip_denoised (bool): Whether to clip the reconstructed x_0 (not
                the noise) to [clip_min, clip_max] before computing the mean.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t, clip_denoised=clip_denoised
        )
        # Dynamic shape so this also works when the batch dim is unknown
        # (e.g. inside a tf.function).
        noise = tf.random.normal(shape=tf.shape(x), dtype=x.dtype)
        # No noise is added at the last reverse step (t == 0).
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

In [None]:
def kernel_init(scale):
    """Kernel initializer used throughout the U-Net: uniform VarianceScaling
    with fan-average mode.

    ``scale`` is floored at 1e-10, presumably because VarianceScaling rejects a
    non-positive scale; ``kernel_init(0.0)`` therefore gives near-zero weights.
    """
    safe_scale = 1e-10 if scale < 1e-10 else scale
    return keras.initializers.VarianceScaling(
        safe_scale,
        mode="fan_avg",
        distribution="uniform",
    )

class AttentionBlock(layers.Layer):
    """Applies single-head self-attention across all spatial positions.

    Args:
        units: Number of units in the query/key/value/projection dense layers.
            The residual sum in ``call`` requires this to equal the input
            channel count.
        groups: Number of groups to be used for the GroupNormalization layer.
    """

    def __init__(self, units, groups=8, **kwargs):
        # Initialise the base Layer before setting any attributes on it.
        super().__init__(**kwargs)
        self.units = units
        self.groups = groups

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        # Near-zero init so the block initially contributes almost nothing.
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        height = tf.shape(inputs)[1]
        width = tf.shape(inputs)[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        x_norm = self.norm(inputs)
        q = self.query(x_norm)
        k = self.key(x_norm)
        v = self.value(x_norm)

        # Scaled dot-product score between every pair of spatial positions.
        attn_score = tf.einsum("bhwc,bHWc->bhwHW", q, k) * scale
        # Softmax over all key positions: flatten (H, W) -> H*W, normalise, restore.
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])
        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        # NOTE(review): the residual adds the *normalized* input rather than
        # the raw block input (the more common formulation). Kept as-is so
        # previously trained checkpoints behave identically; revisit when
        # training from scratch.
        return x_norm + proj

    def get_config(self):
        """Serialisable constructor arguments (enables model cloning/saving)."""
        config = super().get_config()
        config.update({"units": self.units, "groups": self.groups})
        return config

class TimeEmbedding(layers.Layer):
    """Sinusoidal embedding of integer diffusion timesteps.

    Maps a batch of timesteps of shape (batch,) to (batch, 2 * (dim // 2)):
    sines followed by cosines at frequencies exp(-k * log(10000) / (half_dim - 1))
    for k = 0 .. half_dim - 1, i.e. from 1 down to 1/10000.

    Args:
        dim: Embedding size; must be at least 4 (two frequencies per half).
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        if dim < 4:
            raise ValueError(
                "TimeEmbedding dim must be >= 4, got %r" % (dim,)
            )
        self.dim = dim
        self.half_dim = dim // 2
        # Log-spacing step between consecutive frequencies.
        log_step = math.log(10000) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -log_step)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        # Outer product: (batch, 1) * (1, half_dim) -> (batch, half_dim).
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb

    def get_config(self):
        """Serialisable constructor arguments (enables model cloning/saving)."""
        config = super().get_config()
        config.update({"dim": self.dim})
        return config

def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    """Time-conditioned residual block for the U-Net (functional-API builder).

    Returns a closure taking ``[x, t]`` (feature map and time embedding) and
    returning a feature map with ``width`` channels. Layers are created in a
    fixed order so repeated builds produce matching weight lists.
    """

    def apply(inputs):
        feat, time_emb = inputs

        # Skip path: project with a 1x1 conv only when the channel count changes.
        shortcut = feat
        if feat.shape[3] != width:
            shortcut = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(feat)

        # Time conditioning: activate, project to `width`, add H/W axes to broadcast.
        time_act = activation_fn(time_emb)
        time_bias = layers.Dense(width, kernel_initializer=kernel_init(1.0))(time_act)
        time_bias = time_bias[:, None, None, :]

        # Stage 1: norm -> activation -> 3x3 conv, then inject the time bias.
        h = activation_fn(layers.GroupNormalization(groups=groups)(feat))
        h = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )(h)
        h = layers.Add()([h, time_bias])

        # Stage 2: the final conv starts near zero, so the block starts close
        # to the identity on the skip path.
        h = activation_fn(layers.GroupNormalization(groups=groups)(h))
        h = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
        )(h)
        return layers.Add()([h, shortcut])

    return apply

def DownSample(width):
    """Return a closure that halves spatial resolution with a stride-2 3x3 conv."""

    def apply(x):
        strided_conv = layers.Conv2D(
            width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )
        return strided_conv(x)

    return apply

def UpSample(width, interpolation="nearest"):
    """Return a closure that doubles spatial resolution, then applies a 3x3 conv."""

    def apply(x):
        upsampled = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        conv = layers.Conv2D(
            width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
        )
        return conv(upsampled)

    return apply

def TimeMLP(units, activation_fn=keras.activations.swish):
    """Return a closure applying a two-layer MLP to the time embedding:
    Dense(units, activation) followed by a linear Dense(units)."""

    def apply(inputs):
        hidden_layer = layers.Dense(
            units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
        )
        hidden = hidden_layer(inputs)
        output_layer = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        return output_layer(hidden)

    return apply

def build_model(
    img_height,
    img_width,
    img_channels,
    widths,
    has_attention,
    num_res_blocks=2,
    norm_groups=32,
    interpolation="nearest",
    activation_fn=keras.activations.swish,
    stem_channels=None,
):
    """Build the noise-prediction U-Net.

    Args:
        img_height, img_width, img_channels: Shape of one input image.
        widths: Channel count per resolution level, highest resolution first.
        has_attention: One flag per level; True adds an AttentionBlock after
            every residual block of that level.
        num_res_blocks: Residual blocks per level on the down path (the up
            path uses one more per level to consume the extra skip).
        norm_groups: Groups for every GroupNormalization layer.
        interpolation: Interpolation mode used by UpSample.
        activation_fn: Activation used throughout the network.
        stem_channels: Channels of the first conv; the time embedding uses 4x
            this. Defaults to the global ``first_conv_channels``.

    Returns:
        keras.Model mapping ``[images, timesteps]`` to the predicted noise,
        with ``img_channels`` channels (same spatial size as the input when it
        is divisible by 2 ** (len(widths) - 1)).

    Raises:
        ValueError: If ``has_attention`` and ``widths`` differ in length.
    """
    if len(has_attention) != len(widths):
        raise ValueError(
            "has_attention must have one flag per width: got %d flags for %d widths"
            % (len(has_attention), len(widths))
        )
    if stem_channels is None:
        stem_channels = first_conv_channels
    last_level = len(widths) - 1

    image_input = layers.Input(
        shape=(img_height, img_width, img_channels), name="image_input"
    )
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    x = layers.Conv2D(
        stem_channels,
        kernel_size=(3, 3),
        padding="same",
        kernel_initializer=kernel_init(1.0),
    )(image_input)

    temb = TimeEmbedding(dim=stem_channels * 4)(time_input)
    temb = TimeMLP(units=stem_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # Down path: every residual output and every downsampled map is a skip.
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        # Downsample after every level except the last. Compare the level
        # index, not the width: a repeated width equal to widths[-1] would
        # otherwise skip a downsample while the up path still upsamples.
        if i != last_level:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # Middle block
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )

    # Up path: num_res_blocks + 1 blocks per level pop exactly the skips
    # pushed on the way down.
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block: the predicted noise must have as many channels as the image.
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(
        img_channels, (3, 3), padding="same", kernel_initializer=kernel_init(0.0)
    )(x)
    return keras.Model([image_input, time_input], x, name="unet")

In [None]:
class DiffusionModel(keras.Model):
    """Trains a noise-prediction U-Net and samples images from its EMA copy.

    Args:
        network: U-Net trained to predict the noise added at timestep t.
        ema_network: Same-architecture model whose weights track an
            exponential moving average of ``network``; used for sampling.
        timesteps: Number of diffusion steps; training timesteps are drawn
            uniformly from [0, timesteps).
        gdf_util: ``GaussianDiffusion`` instance providing q_sample / p_sample.
        ema: EMA decay factor.
    """

    def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = gdf_util
        self.ema = ema

    def train_step(self, images):
        """One optimisation step on a batch of images (expected in [-1, 1])."""
        batch_size = tf.shape(images)[0]

        # Sample one timestep per image uniformly.
        t = tf.random.uniform(
            minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
        )

        with tf.GradientTape() as tape:
            # Noise to add to the images in the batch.
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)

            # Diffuse the images to timestep t.
            images_t = self.gdf_util.q_sample(images, t, noise)

            # The network predicts the noise that was added.
            pred_noise = self.network([images_t, t], training=True)

            loss = self.loss(noise, pred_noise)

        gradients = tape.gradient(loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        # EMA update of every network weight into the sampling network.
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)

        return {"loss": loss}

    def generate_images(self, num_images=16):
        """Run the full reverse process from pure noise with the EMA network."""
        samples = tf.random.normal(
            shape=(num_images, img_height, img_width, img_channels), dtype=tf.float32
        )

        for t in reversed(range(0, self.timesteps)):
            tt = tf.cast(tf.fill([num_images], t), dtype=tf.int64)
            # Call the model directly: Model.predict targets large batched
            # inputs and adds per-call overhead inside this loop.
            pred_noise = self.ema_network([samples, tt], training=False)
            samples = self.gdf_util.p_sample(
                pred_noise, samples, tt, clip_denoised=True
            )

        return samples

    def plot_images(
        self,
        epoch=None,
        logs=None,
        num_rows=2,
        num_cols=4,
        figsize=(12, 5),
        save_dir='/content/drive/MyDrive/1.MelodiaAtomizacji/SenaTrainings/23.08.24 - Training 2',
    ):
        """Generate a grid of samples.

        When called as an ``on_epoch_end`` callback (``epoch`` given) the grid
        is saved to ``save_dir/training_<epoch+1>.png``; when called manually
        with ``epoch=None`` it is displayed instead.
        """
        generated_samples = self.generate_images(num_images=num_rows * num_cols)
        generated_samples = (
            tf.clip_by_value(generated_samples * 127.5 + 127.5, 0.0, 255.0)
            .numpy()
            .astype(np.uint8)
        )

        # squeeze=False keeps `ax` 2-D for any grid size, including 1x1.
        fig, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
        for i, image in enumerate(generated_samples):
            cell = ax[i // num_cols, i % num_cols]
            cell.imshow(image)
            cell.axis("off")

        fig.tight_layout()
        if epoch is None:
            plt.show()
        else:
            fig.savefig('%s/training_%03d.png' % (save_dir.rstrip('/'), epoch + 1))
        plt.close(fig)

# Checkpoint callback: saves only the weights, at the callback's default
# save frequency, when passed to `fit`.
checkWeights = tf.keras.callbacks.ModelCheckpoint(
    filepath='/content/drive/MyDrive/1.MelodiaAtomizacji/SenaTrainings/23.08.24 - Training 2/model_cp/senaDiff.ckpt',
    save_weights_only=True,
    verbose=1,
)

# The trainable U-Net and its EMA copy share one architecture definition;
# the EMA copy starts from exactly the same initial weights.
unet_config = dict(
    img_height=img_height,
    img_width=img_width,
    img_channels=img_channels,
    widths=widths,
    has_attention=has_attention,
    num_res_blocks=num_res_blocks,
    norm_groups=norm_groups,
    activation_fn=keras.activations.swish,
)
network = build_model(**unet_config)
ema_network = build_model(**unet_config)
ema_network.set_weights(network.get_weights())

# Noise schedule used both for training (q_sample) and sampling (p_sample).
gdf_util = GaussianDiffusion(timesteps=total_timesteps)

# Wrap everything in the trainable diffusion model and compile it with an
# MSE loss between the true and predicted noise.
model = DiffusionModel(
    network=network,
    ema_network=ema_network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
)
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
)


In [None]:
# Load previously trained weights if a checkpoint exists. A TF-format
# checkpoint written with the '.ckpt' prefix has a companion '.index' file;
# on a first run there is none, so training simply starts from scratch.
checkpoint_path = '/content/drive/MyDrive/1.MelodiaAtomizacji/SenaTrainings/23.08.24 - Training 2/model_cp/senaDiff.ckpt'
if tf.io.gfile.exists(checkpoint_path + '.index'):
    model.load_weights(checkpoint_path)
else:
    print('No checkpoint found at %s; starting from freshly initialised weights.' % checkpoint_path)

In [None]:
# Train the model. The batch size is set by the tf.data pipeline (train_ds),
# so it is not passed to fit().
model.fit(
  train_ds,
  epochs=num_epochs,
  callbacks=[
      keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),  # plot and save samples each epoch
      checkWeights,  # checkpoint the weights (callback defined with the model setup)
  ]
)

In [None]:
# Save the model weights now. (Merely constructing a ModelCheckpoint callback
# writes nothing; callbacks only act when passed to fit().)
model.save_weights(
  '/content/drive/MyDrive/1.MelodiaAtomizacji/SenaTrainings/23.08.24 - Training 2/model_cp/senaDiff.ckpt'
)

In [None]:
# NOTE(review): the plotting call below is disabled -- it sits inside a string
# literal, so running this cell only echoes that string. Before re-enabling
# it, check that `plot_images` accepts a missing `epoch` (it is used to build
# the output file name), or pass an epoch explicitly.
"""#Then we plot images
model.plot_images(num_rows=2, num_cols=2)"""

#Plotting the images
