In [None]:
import tensorflow as tf

In [None]:
class ResNetBlock(tf.keras.layers.Layer):
  """Residual block: Conv2D -> GroupNorm -> GELU -> Conv2D -> GroupNorm, plus a skip.

  The skip connection is the identity when the input already has `filters`
  channels. Otherwise the input is projected with a 1x1 convolution so that
  the residual sum is well defined (the plain identity skip would fail with a
  shape error in that case).

  Args:
    filters: Number of output channels of both convolutions.
    kernel_size: Kernel size passed to both `Conv2D` layers.
    **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filters, kernel_size, **kwargs):
    super(ResNetBlock, self).__init__(**kwargs)
    self.filters = filters
    self.conv2d = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    # GroupNormalization is left at its default group count (32 per the Keras
    # docs), so `filters` has to be divisible by that number.
    self.group_norm = tf.keras.layers.GroupNormalization(axis=-1)
    self.gelu = tf.keras.layers.Activation('gelu')
    self.conv2d_2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
    self.group_norm_2 = tf.keras.layers.GroupNormalization(axis=-1)
    # Created once here instead of instantiating a new Add layer on every call.
    self.add = tf.keras.layers.Add()
    # Optional 1x1 projection for the skip path; decided in build() once the
    # input channel count is known.
    self.shortcut = None

  def build(self, input_shape):
    in_channels = input_shape[-1]
    if in_channels is not None and in_channels != self.filters:
      self.shortcut = tf.keras.layers.Conv2D(self.filters, 1, padding='same')
    super(ResNetBlock, self).build(input_shape)

  def call(self, inputs):
    """Applies the block to a channels-last feature map and returns the residual sum."""
    x = self.conv2d(inputs)
    x = self.group_norm(x)
    x = self.gelu(x)
    x = self.conv2d_2(x)
    x = self.group_norm_2(x)

    skip = inputs if self.shortcut is None else self.shortcut(inputs)
    return self.add([skip, x])


In [None]:
class DownSampleBlock(tf.keras.layers.Layer):
  """2x2 max-pool followed by two ResNet blocks, conditioned on an embedding.

  Args:
    filters: Channel count for the ResNet blocks and the embedding projection.
    kernel_size: Kernel size forwarded to the ResNet blocks.
    **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filters, kernel_size, **kwargs):
    super(DownSampleBlock, self).__init__(**kwargs)
    self.max_pool = tf.keras.layers.MaxPool2D((2, 2))
    self.resnet_block = ResNetBlock(filters, kernel_size)
    self.resnet_block_2 = ResNetBlock(filters, kernel_size)
    self.linear = tf.keras.layers.Dense(filters)

  def call(self, inputs):
    """Runs the block.

    Args:
      inputs: A pair `(x, y)`. `x` is the feature map that is pooled and
        convolved; `y` is a conditioning tensor that goes through SiLU and a
        Dense projection before being added to the features.
        NOTE(review): `y` is presumably a per-sample embedding of shape
        (batch, dim) -- confirm against the caller.

    Returns:
      The processed feature map with the projected embedding added.
    """
    x, y = inputs
    x = self.max_pool(x)
    x = self.resnet_block(x)
    x = self.resnet_block_2(x)

    y = tf.keras.activations.silu(y)
    y = self.linear(y)
    # A rank-2 embedding (batch, filters) does not broadcast correctly against
    # a (batch, H, W, filters) feature map: trailing-axis alignment would pair
    # the embedding's batch axis with the width axis. Insert singleton spatial
    # axes so each sample's embedding is added at every spatial position.
    if len(y.shape) == 2:
      y = y[:, tf.newaxis, tf.newaxis, :]

    return x + y

In [None]:
class SelfAttentionBlock(tf.keras.layers.Layer):
  """Multi-head self-attention over the positions of a channels-last feature map.

  The input is flattened to a (batch, positions, channels) sequence, passed
  through layer norm -> multi-head attention -> residual -> layer norm ->
  Dense/GELU/Dense -> residual, and reshaped back to the input shape.

  Args:
    filters: Width of both Dense layers. Because the Dense output is added to
      the attention output and reshaped to the input shape, this must equal
      the number of input channels.
    num_heads: Number of attention heads. This used to be read from a
      module-level name that the block itself never defined.
    embed_dim: Per-head size passed as `key_dim` to `MultiHeadAttention`.
      Defaults to `filters // num_heads` (at least 1).
    **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filters, num_heads=4, embed_dim=None, **kwargs):
    super(SelfAttentionBlock, self).__init__(**kwargs)
    if embed_dim is None:
      embed_dim = max(filters // num_heads, 1)
    self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.linear1 = tf.keras.layers.Dense(filters)
    self.gelu = tf.keras.layers.Activation('gelu')
    self.linear2 = tf.keras.layers.Dense(filters)

  def call(self, inputs):
    """Applies self-attention and returns a tensor with the same shape as `inputs`."""
    # Use the dynamic shape so the block also works when the batch size is
    # unknown (None) at trace time; the channel count is taken statically
    # because the Dense / LayerNormalization layers need it to build.
    input_shape = tf.shape(inputs)
    channels = inputs.shape[-1]

    # (batch, ..., channels) -> (batch, positions, channels). Keras
    # MultiHeadAttention expects batch-first input, so no transpose is applied;
    # swapping the first two axes would make it attend across the batch.
    x = tf.reshape(inputs, [input_shape[0], -1, channels])

    x_norm = self.layernorm1(x)
    x_mha = self.mha(query=x_norm, key=x_norm, value=x_norm)

    x_mha_out = x_mha + x
    x_mha_out = self.layernorm2(x_mha_out)

    x_ffn = self.linear1(x_mha_out)
    x_ffn = self.gelu(x_ffn)
    x_ffn = self.linear2(x_ffn)

    x_out = x_ffn + x_mha_out
    x_out = tf.reshape(x_out, input_shape)

    return x_out

In [None]:
class UpsampleBlock(tf.keras.layers.Layer):
  """Concatenates two feature maps, applies two ResNet blocks, adds an embedding.

  Args:
    filters: Channel count for the ResNet blocks and the embedding projection.
    kernel_size: Kernel size forwarded to the ResNet blocks.
    **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, filters, kernel_size, **kwargs):
    super(UpsampleBlock, self).__init__(**kwargs)
    # Mirrors the 2x2 pooling of DownSampleBlock; only used when the two
    # feature maps arrive at different resolutions (see call()).
    self.upsample = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
    self.resnet_block = ResNetBlock(filters, kernel_size)
    self.resnet_block_2 = ResNetBlock(filters, kernel_size)
    self.linear = tf.keras.layers.Dense(filters)

  def call(self, inputs):
    """Runs the block.

    Args:
      inputs: A triple `(x, y, z)`. `x` and `y` are channels-last feature maps
        concatenated along the channel axis; `z` is a conditioning tensor that
        goes through SiLU and a Dense projection before being added.
        NOTE(review): presumably `x` is the lower-resolution decoder input,
        `y` the encoder skip connection and `z` a (batch, dim) embedding --
        confirm against the caller.

    Returns:
      The processed feature map with the projected embedding added.
    """
    x, y, z = inputs

    # If the static shapes show that y is exactly twice the spatial size of x,
    # upsample x first; concatenating directly would fail with a shape error.
    # Inputs that already match (or have unknown sizes) are left untouched.
    x_size = (x.shape[1], x.shape[2])
    y_size = (y.shape[1], y.shape[2])
    if None not in x_size + y_size and (2 * x_size[0], 2 * x_size[1]) == y_size:
      x = self.upsample(x)

    x_concat = tf.concat([x, y], axis=-1)

    x_resnet_1 = self.resnet_block(x_concat)
    x_resnet_2 = self.resnet_block_2(x_resnet_1)

    z_silu = tf.keras.activations.silu(z)
    z_linear = self.linear(z_silu)
    # A rank-2 embedding (batch, filters) would broadcast its batch axis
    # against the width axis of the feature map; add singleton spatial axes so
    # it is applied per sample at every position.
    if len(z_linear.shape) == 2:
      z_linear = z_linear[:, tf.newaxis, tf.newaxis, :]

    return x_resnet_2 + z_linear