In [289]:
import tensorflow as tf

In [290]:
class TimeEmbedding(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(TimeEmbedding, self).__init__()
        self.dim = dim
        self.half_dim = dim // 2
        self.emb = tf.math.log(10000.0) / (self.half_dim - 1)
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)

    def call(self, timesteps):
        timesteps = tf.cast(timesteps, tf.float32)
        emb = timesteps[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb

In [291]:
class ResNetBlock(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size):
        super(ResNetBlock, self).__init__()
        self.conv2d = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.group_norm = tf.keras.layers.GroupNormalization(axis=-1)
        self.gelu = tf.keras.layers.Activation('gelu')
        self.conv2d_2 = tf.keras.layers.Conv2D(filters, kernel_size, padding='same')
        self.group_norm_2 = tf.keras.layers.GroupNormalization(axis=-1)
        self.adjust_dim = tf.keras.layers.Conv2D(filters, 1, padding='same')

    def call(self, inputs):
        if tf.rank(inputs) == 3:
          inputs = tf.expand_dims(inputs, axis=1)
        x = self.conv2d(inputs)
        x = self.group_norm(x)
        x = self.gelu(x)
        x = self.conv2d_2(x)
        x = self.group_norm_2(x)

        inputs = self.adjust_dim(inputs)

        return tf.keras.layers.Add()([inputs, x])

In [292]:
class DownSampleBlock(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size):
        super(DownSampleBlock, self).__init__()
        self.max_pool = tf.keras.layers.MaxPool2D((2, 2))
        self.resnet_block = ResNetBlock(filters, kernel_size)
        self.resnet_block_2 = ResNetBlock(filters, kernel_size)
        self.linear = tf.keras.layers.Dense(filters)

    def call(self, inputs):
        x, y = inputs
        x = self.max_pool(x)
        x = self.resnet_block(x)
        x = self.resnet_block_2(x)

        y = tf.keras.activations.silu(y)
        y = self.linear(y)
        y = tf.reshape(y, [y.shape[0], 1, 1, y.shape[-1]])

        return x + y

In [293]:
class SelfAttentionBlock(tf.keras.layers.Layer):
    def __init__(self, filters, num_heads=8, embed_dim=64):
        super(SelfAttentionBlock, self).__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.linear1 = tf.keras.layers.Dense(filters)
        self.gelu = tf.keras.layers.Activation('gelu')
        self.linear2 = tf.keras.layers.Dense(filters)

    def call(self, inputs):
        x = tf.reshape(inputs, shape=(inputs.shape[0], -1, inputs.shape[-1]))
        x_norm = self.layernorm1(x)
        x_mha = self.mha(query=x_norm, key=x_norm, value=x_norm)
        x_mha_out = x_mha + x
        x_mha_out = self.layernorm2(x_mha_out)
        x_ffn = self.linear1(x_mha_out)
        x_ffn = self.gelu(x_ffn)
        x_ffn = self.linear2(x_ffn)
        return x_ffn + x_mha_out

In [294]:
class UpsampleBlock(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size):
        super(UpsampleBlock, self).__init__()
        self.upsample = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')
        self.resnet_block = ResNetBlock(filters, kernel_size)
        self.resnet_block_2 = ResNetBlock(filters, kernel_size)
        self.linear = tf.keras.layers.Dense(filters)
        self.reduce_skip = tf.keras.layers.Conv2D(filters, 1, padding='same')

    def call(self, inputs):
        x, skip, t = inputs

        # Upsample x
        x = self.upsample(x)

        # Ensure skip has the right shape
        skip = self.reduce_skip(skip)

        # Check the shapes for debugging
        print(f"x shape after upsampling: {x.shape}")
        print(f"skip shape after reduction: {skip.shape}")

        # Concatenate x and skip
        if skip.shape[1] != x.shape[1] or skip.shape[2] != x.shape[2]:
            skip = tf.image.resize(skip, [x.shape[1], x.shape[2]])
        print(f"After resizing skip shape: {skip.shape}")

        x = tf.image.resize(x, [skip.shape[1], skip.shape[2]])

        x_concat = tf.concat([x, skip], axis=-1)

        x_resnet_1 = self.resnet_block(x_concat)
        x_resnet_2 = self.resnet_block_2(x_resnet_1)

        t = tf.keras.activations.silu(t)
        t = self.linear(t)

        t = tf.expand_dims(t, axis=1)
        t = tf.expand_dims(t, axis=1)

        t = tf.tile(t, [1, 8, 8, 1])
        print(f"x_resnet_2: {x_resnet_2.shape}")
        print(f"t: {t.shape}")
        # Add t to x_resnet_2
        x_resnet_2 = tf.image.resize(x_resnet_2, [t.shape[1], t.shape[2]])
        return x_resnet_2 + t

In [295]:
class DiffusionModel(tf.keras.Model):
    def __init__(self, img_size=32, base_filters=64, time_embedding_dim=256):
        super(DiffusionModel, self).__init__()
        self.time_embedding = TimeEmbedding(time_embedding_dim)
        self.first_resnet = ResNetBlock(base_filters * 8, 3)
        self.init_conv = tf.keras.layers.Conv2D(base_filters, 3, padding='same')

        self.down1 = DownSampleBlock(base_filters * 2, 3)
        self.down2 = DownSampleBlock(base_filters * 4, 3)
        self.down3 = DownSampleBlock(base_filters * 8, 3)

        self.mid_resnet1 = ResNetBlock(base_filters * 8, 3)
        self.mid_attention = SelfAttentionBlock(base_filters * 8)
        self.mid_resnet2 = ResNetBlock(base_filters * 8, 3)

        self.up1 = UpsampleBlock(base_filters * 4, 3)
        self.up2 = UpsampleBlock(base_filters * 2, 3)
        self.up3 = UpsampleBlock(base_filters, 3)

        self.final_norm = tf.keras.layers.GroupNormalization(axis=-1)
        self.final_conv = tf.keras.layers.Conv2D(3, 3, padding='same')

    def call(self, x, timesteps):
        t = self.time_embedding(timesteps)

        x = self.init_conv(x)
        x = self.first_resnet(x)
        skip1 = x

        x = self.down1([x, t])
        skip2 = x

        x = self.down2([x, t])
        skip3 = x

        x = self.down3([x, t])
        x = self.mid_attention(x)
        x = self.mid_resnet1(x)

        x = self.up1([x, skip3, t])
        x = self.up2([x, skip2, t])
        x = self.up3([x, skip1, t])

        x = self.final_norm(x)
        x = self.final_conv(x)

        return x

In [296]:
model = DiffusionModel()
batch_size = 4
img = tf.random.normal((batch_size, 32, 32, 3))
t = tf.random.uniform((batch_size,), maxval=1000, dtype=tf.int32)
prediction = model(img, t)
print(f"Output shape: {prediction.shape}")


1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling ResNetBlock.call().

[1mUsing a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/ten

x shape after upsampling: (4, 2, 32, 512)
skip shape after reduction: (4, 8, 8, 256)
After resizing skip shape: (4, 2, 32, 256)
x shape after upsampling: (4, 2, 32, 512)
skip shape after reduction: (4, 8, 8, 256)
After resizing skip shape: (4, 2, 32, 256)
x_resnet_2: (4, 2, 32, 256)
t: (4, 8, 8, 256)


1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling ResNetBlock.call().

[1mUsing a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/ten

x shape after upsampling: (4, 16, 16, 256)
skip shape after reduction: (4, 16, 16, 128)
After resizing skip shape: (4, 16, 16, 128)
x shape after upsampling: (4, 16, 16, 256)
skip shape after reduction: (4, 16, 16, 128)
After resizing skip shape: (4, 16, 16, 128)
x_resnet_2: (4, 16, 16, 128)
t: (4, 8, 8, 128)


1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling ResNetBlock.call().

[1mUsing a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/ten

x shape after upsampling: (4, 16, 16, 128)
skip shape after reduction: (4, 32, 32, 64)
After resizing skip shape: (4, 16, 16, 64)
x shape after upsampling: (4, 16, 16, 128)
skip shape after reduction: (4, 32, 32, 64)
After resizing skip shape: (4, 16, 16, 64)
x_resnet_2: (4, 16, 16, 64)
t: (4, 8, 8, 64)
Output shape: (4, 8, 8, 3)


1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Exception encountered when calling ResNetBlock.call().

[1mUsing a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/ten