<a href="https://colab.research.google.com/github/Gregtom3/vossen_ecal_ai/blob/main/notebooks/nb04_shape_condensation_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial Overview

In this notebook we build on the `shapeCondensation` example, which used a CNN-only architecture. Here, we use a Transformer encoder to learn the object condensation variables used for clustering. One **key takeaway** is the importance of a meaningful positional encoding.

If you are new to transformers, but not to machine learning in general, I highly recommend 3Blue1Brown's YouTube series, starting with https://www.youtube.com/watch?v=eMlx5fFNoYc&t=1396s.


**I would highly encourage the allocation of a GPU. To do so...**
1. Click `Runtime`
2. Click `Change runtime type`
3. Select one of the available GPUs

You will only be able to do this temporarily, since there are usage limits.

**Also, because self-attention in a transformer encoder scales as $O(n^2)$ in the sequence length $n$, we reduce the sequence length to 400 = 20x20, as opposed to 1024 = 32x32.**
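
To put the quadratic scaling in numbers, here is a rough back-of-the-envelope sketch. It only counts the entries of a single attention matrix (one head, one image) and ignores every other cost, so treat it as an illustration rather than a cost model. The helper name `attention_entries` is made up for this example.

```python
# Count the pairwise attention scores for one head at two grid sizes.
def attention_entries(height, width):
    seq_len = height * width      # one token per pixel
    return seq_len * seq_len      # every token attends to every token

small = attention_entries(20, 20)   # 400 tokens
large = attention_entries(32, 32)   # 1024 tokens

print(small)          # 160000
print(large)          # 1048576
print(large / small)  # 6.5536
```

Going from 20x20 to 32x32 multiplies the number of tokens by about 2.6 but the number of attention scores by about 6.6.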

# Imports

In [2]:
# Import source code from the GitHub to generate images
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
from shape_gen import generate_dataset

# Import source code from the GitHub for the object condensation loss function
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
from nb03_loss_functions import CustomLoss, AttractiveLossMetric, RepulsiveLossMetric, CowardLossMetric, NoiseLossMetric, condensation_loss

--2025-03-13 19:47:13--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4604 (4.5K) [text/plain]
Saving to: ‘shape_gen.py.1’


2025-03-13 19:47:13 (46.1 MB/s) - ‘shape_gen.py.1’ saved [4604/4604]

--2025-03-13 19:47:13--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7218 (7.0K) [text/plain]
Saving to: ‘nb03_loss_functions.py.1’


2025-03-13 19:47:13 (59.4 MB/s

In [15]:
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display,  clear_output
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
from matplotlib import patches
from ipywidgets import interact

physical_devices = tf.config.list_physical_devices('GPU')
print("GPU:", physical_devices)
print("Num GPUs:", len(physical_devices))

GPU: []
Num GPUs: 0


# Data Generation

Below we set the parameters for creating an array of images. By default, we produce 1,000 images on a 20-by-20 grid, each containing 3 to 4 shapes.

In [4]:
num_images           = 1000
image_width          = 20
image_height         = image_width
min_shapes           = 3
max_shapes           = 4
shape_size_range     = (5,5)

dataset = generate_dataset(num_images=num_images,
                           image_size=(image_width,image_height),
                           min_shapes=min_shapes,
                           max_shapes=max_shapes,
                           shape_size_range=shape_size_range,
                           same_color=False, # True
                           same_shape='square',
                           shape_overlap_max=0.5)

dataset = np.array(dataset)

print(dataset.shape) # (1000, 20, 20, 7)

# --> 0,1,2 = RGB
# --> 3 = x
# --> 4 = y
# --> 5 = unique_shape_id (background == 0)
# --> 6 = shape type
#     --> 0 = noise/empty
#     --> 1 = circle
#     --> 2 = square
#     --> 3 = triangle

# Set RGB of white pixels (1,1,1) to black (0,0,0)
white = (dataset[..., 0:3] == 1).all(axis=-1)  # True only where R, G and B are all 1
dataset[white, 0:3] = 0

(1000, 20, 20, 7)


From `dataset.shape`, we see we are dealing with a tensor of shape [1000,20,20,7]. As indicated by the comments, the first 3 features of each pixel are its RGB values. The (x,y) position of the pixel is stored as the 4th and 5th features. **The most crucial feature** to understand is the 6th, the `unique_shape_id`.

Consider the first shape in the first image. All pixels that belong to that shape have a `unique_shape_id` of 1. Pixels of the second generated shape have a `unique_shape_id` of 2, and so on. An important point is that no two shapes, even across different "events" (images), share the same `unique_shape_id`. All background pixels have a `unique_shape_id` of 0.

Lastly, the final feature indicates what type of shape the pixel belongs to.
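
As a quick illustration of the channel layout, the sketch below builds a tiny hand-made array with the same seven channels and pulls out the `unique_shape_id` channel. The 4x4 array and its values are invented for the example, and it assumes that `x` is the column index and `y` the row index; only the channel ordering follows the comments in the cell above.

```python
import numpy as np

# Hypothetical toy "dataset": 1 image, 4x4 pixels, 7 channels per pixel.
toy = np.zeros((1, 4, 4, 7))
rows, cols = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
toy[0, :, :, 3] = cols   # channel 3: x (assumed to be the column index)
toy[0, :, :, 4] = rows   # channel 4: y (assumed to be the row index)

# Paint one 2x2 red square with unique_shape_id = 1 and shape type 2 (square).
toy[0, 1:3, 1:3, 0:3] = [1.0, 0.0, 0.0]
toy[0, 1:3, 1:3, 5] = 1
toy[0, 1:3, 1:3, 6] = 2

rgb = toy[..., 0:3]      # (1, 4, 4, 3)
uid = toy[..., 5]        # (1, 4, 4)

# Count pixels per unique_shape_id; 0 is background.
ids, counts = np.unique(uid, return_counts=True)
pixels_per_id = {int(i): int(c) for i, c in zip(ids, counts)}
print(pixels_per_id)     # {0: 12, 1: 4}
```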


Let's plot a sample event.

In [5]:
def plot_toy(dataset, evtnum, PLOT_TYPE):
    # Check inputs
    assert PLOT_TYPE in ['RGB', 'X', 'Y', 'uid', 'type'], "PLOT_TYPE must be one of ['RGB', 'X', 'Y', 'uid', 'type']"
    assert evtnum < len(dataset), "evtnum must be less than the number of events in the dataset"

    # Copy and process the event data
    data_reshape = deepcopy(dataset[evtnum])
    if PLOT_TYPE == 'RGB':
        image_data = data_reshape[:, :, 0:3]
    elif PLOT_TYPE == 'X':
        image_data = data_reshape[:, :, 3]
    elif PLOT_TYPE == 'Y':
        image_data = data_reshape[:, :, 4]
    elif PLOT_TYPE == 'uid':
        image_data = data_reshape[:, :, 5]
    elif PLOT_TYPE == 'type':
        image_data = data_reshape[:, :, 6]

    # Create the plot
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    im = ax.imshow(image_data)  # Capture the image object
    # Add a colorbar if the plot type is not RGB
    if PLOT_TYPE != 'RGB':
        fig.colorbar(im, ax=ax)
    # Set the title based on the widget inputs
    ax.set_title(f'Event: {evtnum} | Plot Type: {PLOT_TYPE}')
    plt.tight_layout()
    plt.show()

# Update function for the widget
def update_plot(event_num, plot_type):
    plot_toy(dataset, event_num, plot_type)

# Create the interactive widgets
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset)-1,
    step=1,
    description='Event Number:'
)
plot_type_dropdown = widgets.Dropdown(
    options=['RGB', 'X', 'Y', 'uid', 'type'],
    value='RGB',
    description='Plot Type:'
)

# Link the widgets to the update function
interactive_plot = interactive(update_plot, event_num=event_slider, plot_type=plot_type_dropdown)
display(interactive_plot)


**Note** the background pixels are **black** so that their RGB features are (0,0,0) as opposed to (1,1,1). This gives the model an easier time fitting and determining what is and isn't background.

# Creating a Transformer-Encoder

## Summary on Encoders
The input to our full neural network is $[B, H, W, 3]$ for the RGB of each pixel, where $B$ is the `batch_size`. The first step taken by our transformer model is to reshape the input to $[B, H\times W, 3]$ and pass it through a dense layer to expand the hidden representation to $[B,H\times W,d_{model}]$, where $d_{model}$ is a hyperparameter.
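
The shape bookkeeping of this first step can be sketched in plain NumPy. The random weights `W_embed` and `b_embed` are stand-ins for whatever the dense layer would learn, and the sizes are illustrative; this is not the model code itself.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W, d_model = 2, 20, 20, 32        # illustrative sizes

images = rng.random((B, H, W, 3))        # RGB input, [B, H, W, 3]
tokens = images.reshape(B, H * W, 3)     # one token per pixel, [B, H*W, 3]

# A dense layer is a per-token affine map: the same weights act on every pixel.
W_embed = rng.normal(size=(3, d_model))  # stand-in for learned weights
b_embed = np.zeros(d_model)
embedded = tokens @ W_embed + b_embed    # [B, H*W, d_model]

print(tokens.shape)    # (2, 400, 3)
print(embedded.shape)  # (2, 400, 32)
```

The reshape is row-major, so the token at flattened index `r * W + c` is the pixel at row `r`, column `c`.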

Transformers are often used in language processing. They usually consist of encoder/decoder pairs. The purpose of an encoder is to take a sequence (for instance, words) and output another sequence of the same length. When we say *a sequence of words*, what we mean is some tensor of size $[L,d_{model}]$, where $L$ is the length of the sequence and $d_{model}$ is the feature length of a word (token). Usually, in language processing, a preprocessing step analyzes an input word and maps it to a vector. In GPT-3, each token is mapped to a vector of length 12,288. That means even a token as simple as "," is encoded as a very long vector.

The key behind an encoder is that its structure allows all tokens (words) in a sequence to influence one another. This is analogous to a fully connected graph neural network, where every vertex influences every other vertex at each layer. In a transformer, we say that each token *attends* to every other token. This is referred to as *self-attention*.

But there is one critical difference between graph neural networks and what we need from an encoder: position. In a graph neural network $g(X)$, say you have a fully connected input graph $X=[V,F]$. If you pass this input through $g$, you get some answer $y$. Because the graph is fully connected, you can shuffle along the $V$ dimension (swap the order of vertices) and get the same answer $y$, with any per-vertex outputs shuffled in the same way. A fully connected graph is, in general, insensitive to ordering.

In the context of language processing, the sentences "X = I ate the burger" and "Y = The burger ate I" mean different things because of their ordering. Without going into the mathematical details, if you passed these tokenized inputs to an encoder with no positional information, each word would receive the same output vector in both sentences; only the order of the outputs would change. That is to say, an encoder *is also* insensitive to sequence ordering, unlike CNNs. So how do we rectify this?
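
This order-insensitivity is easy to check numerically. The sketch below is a minimal single-head self-attention in NumPy with random stand-in embeddings and weights (none of these values come from the notebook); it shuffles the four "words" and confirms that every token receives the same output vector, just in the shuffled order.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention on a (L, d) sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
L, d = 4, 8                                  # four tokens: "I ate the burger"
x = rng.normal(size=(L, d))                  # random stand-in embeddings
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

perm = np.array([2, 3, 1, 0])                # "the burger ate I"
out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)

# Each token gets the same output vector; only the row order changes.
same_up_to_order = np.allclose(out[perm], out_perm)
print(same_up_to_order)  # True
```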

## Positional Encoding

The solution is *Positional Encoding*. Say we have the sentence "X = I ate the burger" where $X\in[L,d_{model}]$. Before we pass this input to the `TransformerEncoder` layer, we ADD a positional encoding tensor of the same size, $PE\in[L,d_{model}]$. You can see us do this in the `build_transformer_model()` code below. To explain why this works, I quote a comment from a Reddit user (https://www.reddit.com/r/MachineLearning/comments/cttefo/comment/exs7d08/). Below, `x` and `y` refer to the vector embeddings of two arbitrary words in our sequence.

> In attention, we basically take two word embeddings (x and y), pass one through a Query transformation matrix (Q) and the second through a Key transformation matrix (K), and compare how similar the resulting query and key vectors are by their dot product. So, basically, we want the dot product between Qx and Ky, which we write as:

>(Qx)'(Ky) = x' (Q'Ky). So equivalently we just need to learn one joint Query-Key transformation (Q'K) that transform the secondary inputs y into a new space in which we can compare x.

>By adding positional encodings e and f to x and y, respectively, we essentially change the dot product to

> (Q(x+e))' (K(y+f)) = (Qx+Qe)' (Ky+Kf) = (Qx)' Ky + (Qx)' Kf + (Qe)' Ky + (Qe)' Kf = x' (Q'Ky) + x' (Q'Kf) + e' (Q'Ky) + e' (Q'K f),

> where in addition to the original x' (Q'Ky) term, which asks the question "how much attention should we pay to word x given word y", we also have x' (Q'Kf) + e' (Q'Ky) + e' (Q'K f), which ask the additional questions, "how much attention should we pay to word x given the position f of word y", "how much attention should we pay to y given the position e of word x", and "how much attention should we pay to the position e of word x given the position f of word y".

> Essentially, the learned transformation matrix Q'K with positional encodings has to do all four of these tasks simultaneously. This is the part that may appear inefficient, since intuitively, there should be a trade-off in the ability of Q'K to do four tasks simultaneously and well.
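
The four-term expansion in the quote can be verified numerically. The following is a small NumPy check with random matrices and vectors standing in for `Q`, `K`, the word embeddings `x`, `y`, and the positional encodings `e`, `f`; the dictionary labels are just descriptive names chosen for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
Q, K = rng.normal(size=(d, d)), rng.normal(size=(d, d))
x, y = rng.normal(size=d), rng.normal(size=d)   # word embeddings
e, f = rng.normal(size=d), rng.normal(size=d)   # positional encodings

M = Q.T @ K                                      # joint query-key matrix Q'K

# Left-hand side: attention score with positional encodings added.
full = (Q @ (x + e)) @ (K @ (y + f))

# Right-hand side: the four terms from the expansion.
terms = {
    "word-word":         x @ M @ y,
    "word-position":     x @ M @ f,
    "position-word":     e @ M @ y,
    "position-position": e @ M @ f,
}

expansion_holds = np.isclose(full, sum(terms.values()))
print(expansion_holds)  # True
```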

In this work, we showcase two approaches to defining a positional encoding. One popular method uses a frequency-dependent sinusoidal encoding (see https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ for details). This feeds the network information about the relative and absolute position of tokens in the sequence.
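
One way to see the "relative position" part is that, for the sine-cosine scheme, the dot product between two positions' encodings depends only on how far apart they are. The sketch below is a standalone NumPy version of the usual formulation (sine on even feature indices, cosine on odd ones); the function name `sinusoidal_encoding` is chosen for this example.

```python
import numpy as np

def sinusoidal_encoding(max_seq_len, d_model):
    """Sine on even feature indices, cosine on odd ones; shape (max_seq_len, d_model)."""
    pos = np.arange(max_seq_len)[:, np.newaxis]
    i = np.arange(d_model)[np.newaxis, :]
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((max_seq_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])
    pe[:, 1::2] = np.cos(angles[:, 1::2])
    return pe

pe = sinusoidal_encoding(max_seq_len=400, d_model=32)

# Two pairs of positions with the same offset (4) give the same dot product,
# because sin(a)sin(b) + cos(a)cos(b) = cos(a - b) for each frequency.
offset_4_a = pe[3] @ pe[7]
offset_4_b = pe[100] @ pe[104]
print(np.isclose(offset_4_a, offset_4_b))  # True
```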

Another approach is *learned embeddings*. In a flattened 20x20 grid of pixels, pixel 0 is adjacent to both pixel 1 (its neighbor in the same row) and pixel 20 (its neighbor in the next row). A CNN has very little trouble picking up this structure, so we use a `CNNPositionalEncoding` layer that takes in a fixed `[H,W,2]` tensor, where the `2` represents the $(x,y)$ of the pixel. With repeated convolutional layers, the model can learn the relative position of tokens in our sequence based on their spatial arrangement. The key here is that **we do not use the RGB to motivate a positional encoding**. The output of `CNNPositionalEncoding`, a tensor of size $[L,d_{model}]\equiv [H\times W,d_{model}]$, represents a learned encoding of each pixel's position in 2D space (one row per pixel along the $L$ dimension) and of its position relative to other pixels.
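
The adjacency statement can be made concrete with a little index arithmetic. This sketch assumes the row-major flattening used by a plain reshape; the helper names `to_row_col` and `are_grid_neighbors` are hypothetical and only exist for this example.

```python
H, W = 20, 20

def to_row_col(flat_index, width=W):
    """Row-major mapping from a flattened token index to (row, col)."""
    return divmod(flat_index, width)

def are_grid_neighbors(a, b, width=W):
    """True if flattened indices a and b are horizontally or vertically adjacent."""
    (ra, ca), (rb, cb) = to_row_col(a, width), to_row_col(b, width)
    return abs(ra - rb) + abs(ca - cb) == 1

print(to_row_col(1))               # (0, 1)
print(to_row_col(20))              # (1, 0)
print(are_grid_neighbors(0, 1))    # True
print(are_grid_neighbors(0, 20))   # True
print(are_grid_neighbors(19, 20))  # False: consecutive tokens, but on different rows
```

The last line is the reason a purely 1D notion of position is awkward here: tokens 19 and 20 are neighbors in the sequence but sit at opposite ends of the image.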

## The Transformer Encoder

The Transformer Encoder (image below) is a compound neural network architecture that processes sequential data with the following steps:

1. **Input Representation:**
   - The input to the encoder is a batch of sequences with shape `[B, H * W, d_model]`, where:
     - `B` is the batch size.
     - `H * W` is the sequence length.
     - `d_model` is the embedding dimension for each token in the sequence.
   - Positional encoding has already been added to the input sequence to capture positional information.

2. **Multi-Head Self-Attention:**
   - The first core operation is multi-head self-attention. It allows the model to focus on different parts of the sequence when processing each token.
   - The input is transformed into three vectors: Query (Q), Key (K), and Value (V), and attention is computed to capture dependencies between tokens.
   - The result of self-attention is a weighted sum of the values, reflecting the importance of other tokens in the sequence.

3. **Add & Norm (Residual Connection):**
   - A residual connection is added to the output of the attention mechanism, followed by layer normalization.
   - This helps stabilize training and prevents vanishing/exploding gradients.

4. **Feed-Forward Neural Network:**
   - A position-wise feed-forward neural network (FFN) is applied to each token independently.
   - The FFN typically consists of two linear layers with a ReLU activation in between, transforming the input to a higher-dimensional space and then back to the original dimension.

5. **Add & Norm (Residual Connection):**
   - Another residual connection is added between the output of the FFN and the input, followed by layer normalization.

6. **Output:**
   - The final output of the encoder is a sequence of hidden states, each with shape `[B, H * W, d_model]`.



![Transformer encoder layer diagram](https://kikaben.com/transformers-encoder-decoder/images/encoder-layer-norm.png)
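
The steps listed above can be sketched as a single forward pass in NumPy. This is a deliberately stripped-down illustration, not the Keras layer defined later: it uses one attention head, random stand-in weights, no biases, no dropout, and a layer normalization without learned scale and offset.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    ez = np.exp(z)
    return ez / ez.sum(axis=-1, keepdims=True)

def encoder_layer(x, p):
    """One encoder layer on x of shape (B, L, d_model), single head."""
    q, k, v = x @ p["w_q"], x @ p["w_k"], x @ p["w_v"]
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(k.shape[-1]))  # (B, L, L)
    attn = (weights @ v) @ p["w_o"]                 # weighted sum of values
    x = layer_norm(x + attn)                        # Add & Norm
    ffn = np.maximum(x @ p["w_1"], 0.0) @ p["w_2"]  # Dense, ReLU, Dense
    return layer_norm(x + ffn), weights             # Add & Norm

rng = np.random.default_rng(0)
B, L, d_model, ff_dim = 2, 16, 8, 32
shapes = {"w_q": (d_model, d_model), "w_k": (d_model, d_model),
          "w_v": (d_model, d_model), "w_o": (d_model, d_model),
          "w_1": (d_model, ff_dim), "w_2": (ff_dim, d_model)}
params = {name: rng.normal(scale=0.1, size=s) for name, s in shapes.items()}

x = rng.normal(size=(B, L, d_model))
out, weights = encoder_layer(x, params)
print(out.shape)                               # (2, 16, 8)
print(np.allclose(weights.sum(axis=-1), 1.0))  # True
```

The output has the same shape as the input, which is what allows several encoder layers to be stacked.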


Essentially, the Transformer Encoder replaces the CNN blocks in https://github.com/Gregtom3/vossen_ecal_ai/blob/main/notebooks/nb03_shapeCondensation.ipynb. The purpose of both is to learn some hidden representation for each pixel dependent on its neighbors. In the case of the Transformer model, a neighborhood is only established with the help of the positional encoder.


## Defining Positional Encoders

Both encoders output a tensor of size $[B,H\times W,d_{model}]$.

The first encoder uses sinusoidal functions and is entirely fixed.

The second encoder starts with a fixed $[H,W,2]$ tensor that represents the grid. It passes this grid through a few convolutional layers to form $[H,W,d_{model}]$, at which point it is flattened to $[H\times W,d_{model}]$. The model effectively learns $d_{model}$ features per pixel that encode its relative and absolute position.

In [7]:
# --------------------------
# Standard Positional Encoding Layer
# --------------------------
class PositionalEncoding(layers.Layer):
    def __init__(self, max_seq_len, d_model, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(max_seq_len, d_model)

    def get_angles(self, pos, i, d_model):
        angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
        return pos * angle_rates

    def positional_encoding(self, max_seq_len, d_model):
        angle_rads = self.get_angles(np.arange(max_seq_len)[:, np.newaxis],
                                     np.arange(d_model)[np.newaxis, :],
                                     d_model)
        # Apply sin to even indices; cos to odd indices.
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
        pos_encoding = angle_rads[np.newaxis, ...]  # shape (1, max_seq_len, d_model)
        return tf.cast(pos_encoding, dtype=tf.float32)

    def call(self, x):
        seq_len = tf.shape(x)[1]
        return x + self.pos_encoding[:, :seq_len, :]

    def get_config(self):
        config = super(PositionalEncoding, self).get_config()
        config.update({
            'max_seq_len': self.max_seq_len,
            'd_model': self.d_model,
        })
        return config

# --------------------------
# CNN Positional Encoding Layer
# --------------------------
class CNNPositionalEncoding(layers.Layer):
    def __init__(self, img_height, img_width, d_model, **kwargs):
        super(CNNPositionalEncoding, self).__init__(**kwargs)
        self.img_height = img_height
        self.img_width = img_width
        self.d_model = d_model
        self.conv_layers = tf.keras.Sequential([
            layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
            layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'),
            layers.Conv2D(d_model, kernel_size=3, padding='same')  # No activation here.
        ])

    def call(self, inputs):
        # We pass 'inputs' to have the CNNPositionalEncoding appear in the model summary
        # We do not use 'inputs' to influence the CNN weights
        grid_x = tf.range(self.img_height, dtype=tf.float32)
        grid_y = tf.range(self.img_width, dtype=tf.float32)
        gx, gy = tf.meshgrid(grid_x, grid_y, indexing='ij')
        grid = tf.stack([gx, gy], axis=-1)  # (H, W, 2)
        grid = tf.expand_dims(grid, axis=0)  # (1, H, W, 2)
        pos_enc = self.conv_layers(grid)     # (1, H, W, d_model)
        # Reshape to (1, H*W, d_model)
        pos_enc = tf.reshape(pos_enc, (1, self.img_height * self.img_width, self.d_model))
        return pos_enc

    def get_config(self):
        config = super(CNNPositionalEncoding, self).get_config()
        config.update({
            'img_height': self.img_height,
            'img_width': self.img_width,
            'd_model': self.d_model,
        })
        return config

## Defining Transformer Encoder

We build the encoder from scratch, so the multi-head attention layer, the skip connections, and the feed-forward network are all explicit.

In [6]:
# --------------------------
# Transformer Encoder Block
# --------------------------
class TransformerEncoder(layers.Layer):
    def __init__(self, d_model, num_heads, ff_dim, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.d_model = d_model
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)

        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation='relu'),
            layers.Dense(d_model)
        ])
        self.dropout2 = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, x, training=False):
        attn_output = self.mha(x, x, x)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.norm2(out1 + ffn_output)

    def get_config(self):
        config = super(TransformerEncoder, self).get_config()
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'dropout_rate': self.dropout_rate,
        })
        return config

## Defining Full Network

In [9]:


# --------------------------
# Build Transformer Model
# --------------------------
def build_transformer_model(input_shape, d_model, num_layers, ff_dim, num_heads, dropout_rate=0.1,
                            positional_encoding_type="standard"):
    """
    Builds a transformer model.

    Args:
        input_shape: tuple, e.g. (H, W, 3)
        d_model: int, model dimension.
        num_layers: int, number of transformer encoder layers.
        ff_dim: int, feed-forward network dimension.
        num_heads: int, number of attention heads.
        dropout_rate: float, dropout rate.
        positional_encoding_type: str, either "standard" for sine–cosine or "cnn" for CNN-based.

    Returns:
        A tf.keras.Model.
    """
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    assert d_model % 4 == 0, "d_model must be divisible by 4"
    assert positional_encoding_type in ["standard", "cnn"], "positional_encoding_type must be either 'standard' or 'cnn'"
    assert num_layers > 0, "num_layers must be greater than 0"
    assert ff_dim > 0, "ff_dim must be greater than 0"
    assert num_heads > 0, "num_heads must be greater than 0"
    assert dropout_rate >= 0 and dropout_rate <= 1, "dropout_rate must be between 0 and 1"
    assert input_shape is not None, "input_shape must be provided"
    assert len(input_shape) == 3, "input_shape must be a tuple of length 3"


    inputs = tf.keras.Input(shape=input_shape)  # (H, W, 3)
    H, W, _ = input_shape
    num_patches = H * W

    # Reshape x to [B, H*W, 3]
    x = layers.Reshape((num_patches, 3))(inputs)

    # Pass x through a Dense layer to produce [B, H*W, d_model]
    x = layers.Dense(d_model)(x)

    # Apply dropout
    x = layers.Dropout(dropout_rate)(x)

    if positional_encoding_type == "cnn":
        # Use CNN-based positional encoding.
        cnn_pos_enc_layer = CNNPositionalEncoding(img_height=H, img_width=W, d_model=d_model, name="cnn_pos_encoding")
        pos_enc = cnn_pos_enc_layer(x)
        x = x + pos_enc
    else:
        # Use standard sine-cosine positional encoding.
        x = PositionalEncoding(max_seq_len=num_patches, d_model=d_model)(x)

    # Apply N transformer encoder layers.
    for _ in range(num_layers):
        x = TransformerEncoder(d_model, num_heads, ff_dim, dropout_rate)(x)

    # Map transformer output with a Dense layer.
    x = layers.Dense(d_model)(x)

    # Two output heads:
    # Head 1: latent-space coordinates, shape [None, H*W, 2]
    head1 = layers.Dense(d_model // 2)(x)
    head1 = layers.Dense(d_model // 4)(head1)
    head1 = layers.Dense(2)(head1)
    # Head 2: beta, shape [None, H*W, 1], with a final sigmoid activation.
    head2 = layers.Dense(d_model // 2)(x)
    head2 = layers.Dense(d_model // 4)(head2)
    head2 = layers.Dense(1, activation='sigmoid')(head2)

    # Concatenate the two heads to get shape [None, H*W, 3], ordered as [beta, x, y]
    outputs = layers.Concatenate(axis=-1)([head2, head1])

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# Initializing the Model

Feel free to try both values of `positional_encoding_type` to see the effect of using `CNNPositionalEncoding` over the standard sinusoidal `PositionalEncoding`. I find that both work, but the CNN method learns more quickly. Note that using a transformer encoder as a feature extractor for this shape clustering problem is a bit overkill.

In [28]:

# Define hyperparameters for training
epochs = 25
batch_size = 32
learning_rate = 0.001

# Define hyperparameters for transformer encoder model
input_shape = (image_height, image_width, 3)  # (H, W, 3)
d_model = 32
num_layers = 4
ff_dim = 64
num_heads = 2
dropout_rate = 0.01
positional_encoding_type = "cnn" # "cnn" or "standard"

# Define hyperparameter for object condensation
q_min = 0.1

# Load in the data
X = dataset[...,0:3] # RGB of each pixel
y = dataset[...,5] # unique_shape_id of each pixel

# Reshape 'y' to be [N,H*W,1]
y = y.reshape(y.shape[0], y.shape[1]*y.shape[2], 1)

# Perform train-test splitting
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Load in the transformer model
model = build_transformer_model(input_shape, d_model, num_layers, ff_dim, num_heads, dropout_rate, positional_encoding_type)

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=CustomLoss(q_min=q_min), # from GitHub
    metrics=[
        AttractiveLossMetric(name="attractive_loss"),
        RepulsiveLossMetric(name="repulsive_loss"),
        CowardLossMetric(name="coward_loss"),
        NoiseLossMetric(name="noise_loss")
    ]
)


# Pass one event through the model initially
# This is done to print out the model summary with the proper shapes
model(X_train[0:1])
model.summary()

# Fitting

In [29]:

# Define a checkpoint callback to save the model after each epoch.
checkpoint_callback = ModelCheckpoint(
    filepath='model_epoch_{epoch:02d}.keras',  # Model file name
    save_weights_only=False,
    verbose=1,                              # Verbosity mode.
    save_freq='epoch'                       # Save at the end of every epoch.
)

# Train the model
history = model.fit(
    X_train,
    y_train,
    # validation_data=(X_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint_callback]
)

Epoch 1/25
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - attractive_loss: 0.0364 - coward_loss: 0.8276 - loss: 1.0183 - noise_loss: 0.1420 - repulsive_loss: 0.0122
Epoch 1: saving model to model_epoch_01.keras
25/25 ━━━━━━━━━━━━━━━━━━━━ 56s 1s/step - attractive_loss: 0.0355 - coward_loss: 0.8264 - loss: 1.0166 - noise_loss: 0.1425 - repulsive_loss: 0.0122
Epoch 2/25
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - attractive_loss: 0.0359 - coward_loss: 0.5411 - loss: 0.7800 - noise_loss: 0.1818 - repulsive_loss: 0.0212
Epoch 2: saving model to model_epoch_02.keras
25/25 ━━━━━━━━━━━━━━━━━━━━ 42s 1s/step - attractive_loss: 0.0359 - coward_loss: 0.5399 - loss: 0.7776 - noise_loss: 0.1802 - repulsive_loss: 0.0216
Epoch 3/25
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - attractive_loss: 0.0567 - coward_loss: 0.3094 - loss: 0.5164 - noise_loss: 

# Evaluation

We can evaluate the model performance visually by comparing the initial shape image with its latent 2D image. In other words, we can show where each pixel in the initial image gets mapped to (by the network) in a 2D latent space. If the object condensation loss is minimized, what we should see are clusters forming in the latent space.
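
To make the latent-space picture concrete, the sketch below splits a network output of shape `[H*W, 3]` into $\beta$ and the 2D latent coordinates, then finds the highest-$\beta$ pixel of each object. The output layout `[beta, x, y]` follows the concatenation order in `build_transformer_model()`; the random predictions and the small `uid` array are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 4                                   # tiny stand-in image

# Hypothetical network output for one image: columns are [beta, x_c, y_c].
y_pred = rng.random((H * W, 3))
# Hypothetical truth: background (0) plus two shapes with ids 7 and 8.
uid = np.array([0, 0, 7, 7,
                0, 0, 7, 7,
                8, 8, 0, 0,
                8, 8, 0, 0])

beta, coords = y_pred[:, 0], y_pred[:, 1:3]

# For each object, the pixel with the highest beta acts as its representative.
representatives = {}
for obj in np.unique(uid):
    members = np.where(uid == obj)[0]
    best = members[np.argmax(beta[members])]
    representatives[int(obj)] = (int(best // W), int(best % W))  # (row, col)

print(sorted(representatives))  # [0, 7, 8]
```

In a well-trained model, the latent coordinates of all pixels of one object should sit close to the coordinates of that object's representative pixel.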

First, let's run the model saved at each training epoch over the image data.

In [30]:
# --------------------------
# Precompute predictions for all epochs.
# --------------------------
predictions = {}
print("Beginning precomputation (this may take a few minutes).")
for epoch in range(1, epochs+1):
    model_path = f'model_epoch_{epoch:02d}.keras'
    print(f"Loading and predicting with {model_path} ...")
    loaded_model = load_model(model_path, custom_objects={'CNNPositionalEncoding': CNNPositionalEncoding,
                                                          'PositionalEncoding': PositionalEncoding,
                                                          'TransformerEncoder': TransformerEncoder,
                                                          'build_transformer_model': build_transformer_model,
                                                          'CustomLoss': CustomLoss,
                                                          'AttractiveLossMetric': AttractiveLossMetric,
                                                          'RepulsiveLossMetric': RepulsiveLossMetric,
                                                          'CowardLossMetric': CowardLossMetric,
                                                          'NoiseLossMetric': NoiseLossMetric})
    predictions[epoch] = loaded_model.predict(X)
print("Precomputation complete.")


Beginning precomputation (this may take a few minutes).
Loading and predicting with model_epoch_01.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 17s 519ms/step
Loading and predicting with model_epoch_02.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 16s 480ms/step
Loading and predicting with model_epoch_03.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 16s 475ms/step
Loading and predicting with model_epoch_04.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 16s 469ms/step
Loading and predicting with model_epoch_05.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 18s 529ms/step
Loading and predicting with model_epoch_06.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 17s 485ms/step
Loading and predicting with model_epoch_07.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 15s 444ms/step
Loading and predicting with model_epoch_08.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 16s 476ms/step
Loading and predictin

In the next code cell, we show three plots. The left plot shows the image data (shapes). The middle plot shows the latent space output from the network, where the marker opacity is correlated with $\beta$. The right plot shows which pixels in the original image are clustered together, depending on the choice of `tB` and `tD`.

There are also several customizable controls, which we review here.

* Event: Allows us to switch between different input images
* Epoch: Allows us to view the latent space output per epoch
* Show Highest Beta Stars: This checkbox, when clicked, puts a star atop the pixel with the highest output $\beta$ for each shape (including background).
* Cluster: Allows us to select which cluster in the latent space to view in the right plot.
* tB: The minimum $\beta$ a point needs in order to seed a cluster
* tD: The radius in the latent space within which points are assigned to a cluster's seed point

To check out the full performance, consider picking a random event and scanning through the epochs. You will see that, over time, the model learns to take the input data (which is just RGB) and map it to clustered points in the latent space. At the latest epoch, consider cycling through the different clusters to see the shapes emerge.
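
The role of `tB` and `tD` is easiest to see in isolation. The function below is a standalone sketch of the same greedy idea used in the plotting cell: visit points from highest to lowest $\beta$, let each unassigned point with $\beta \geq$ `tB` seed a cluster, and absorb every unassigned point within distance `tD` of it. The function name `greedy_cluster` and the hand-made points are assumptions for this example.

```python
import numpy as np

def greedy_cluster(beta, coords, tB, tD):
    """Return a cluster label per point (-1 = unassigned)."""
    labels = np.full(beta.shape[0], -1)
    next_label = 0
    for idx in np.argsort(-beta):          # highest beta first
        if beta[idx] < tB:
            break                          # all remaining points are below threshold
        if labels[idx] != -1:
            continue                       # already absorbed by an earlier cluster
        dist = np.linalg.norm(coords - coords[idx], axis=1)
        members = (dist <= tD) & (labels == -1)
        labels[members] = next_label
        next_label += 1
    return labels

# Two tight hand-made groups in the latent space plus one stray low-beta point.
coords = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                   [5.0, 5.0], [5.1, 5.0],
                   [2.5, 2.5]])
beta = np.array([0.9, 0.2, 0.3, 0.8, 0.4, 0.05])

labels = greedy_cluster(beta, coords, tB=0.5, tD=0.5)
print(labels.tolist())  # [0, 0, 0, 1, 1, -1]
```

Raising `tB` reduces the number of seeds, while raising `tD` lets each seed absorb points from farther away, which can merge nearby objects.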




In [31]:
# --------------------------
# Define the interactive update function.
# --------------------------
def update_plots(event_num, training_epoch, show_stars, cluster_idx, tD, tB):
    # Retrieve precomputed predictions for the selected epoch.
    y_pred = predictions[training_epoch]

    # Create a figure with 3 subplots side by side.
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    ax_left, ax_middle, ax_right = axs

    # ----- Left Subplot: Toy Image with Scatter and Optional Stars -----
    data_reshape = deepcopy(dataset[event_num])
    image_data = data_reshape[:, :, 0:3]

    # Replace pure black pixels with white.
    image_data[(image_data == [0, 0, 0]).all(axis=2)] = [1, 1, 1]

    im_left = ax_left.imshow(image_data)
    ax_left.set_title(f'Event: {event_num}')

    # Retrieve scatter data from predictions.
    colors = X[event_num][..., 0:3].reshape(-1, 3).copy()  # copy so the recoloring below cannot modify X in place
    colors[(colors == [0, 0, 0]).all(axis=1)] = [1, 1, 1]
    beta = y_pred[event_num][..., 0]
    xc = y_pred[event_num][..., 1]
    yc = y_pred[event_num][..., 2]
    id_arr = y[event_num][..., 0]

    unique_ids = np.unique(id_arr)


    # ----- Middle Subplot: Clustering Inference Visualization -----
    # Cluster the scatter points using thresholds tD and tB.
    num_points = beta.shape[0]
    clustered = np.zeros(num_points, dtype=bool)
    clusters_info = []  # Each element will be a dict with cluster center, members, etc.
    sorted_indices = np.argsort(-beta)  # descending order

    for uid in unique_ids:
        indices = (id_arr == uid)
        marker_alpha = [max(b, 0.0005) for b in beta[indices]]
        marker_size = 40
        current_color = colors[indices][0]
        marker_edge = "black" if np.all(current_color == [1, 1, 1]) else "none"
        ax_middle.scatter(xc[indices], yc[indices],
                          c=colors[indices],
                          alpha=marker_alpha,
                          s=marker_size,
                          edgecolor=marker_edge)
        # If checkbox is checked, add a star on the left image for the highest β point in the group.
        if show_stars and np.any(indices):
            idx_in_group = np.argmax(beta[indices])
            overall_idx = np.where(indices)[0][idx_in_group]
            # Convert overall index to pixel coordinates.
            star_x = overall_idx % image_width
            star_y = overall_idx // image_width
            ax_left.scatter(star_x, star_y, marker='*', color='red', s=150,
                            edgecolor='black', linewidth=1, zorder=10)

    for idx in sorted_indices:
        if beta[idx] < tB:
            break
        if clustered[idx]:
            continue
        center_x = xc[idx]
        center_y = yc[idx]
        distances = np.sqrt((xc - center_x)**2 + (yc - center_y)**2)
        members = np.where((distances <= tD) & (~clustered))[0]
        clustered[members] = True
        clusters_info.append({
            'center_idx': idx,
            'members': members,
            'center_x': center_x,
            'center_y': center_y,
            'color': colors[idx]
        })

    # For each cluster, draw a hatched circle around the highest β point.
    for cluster in clusters_info:
        circle = patches.Circle((cluster['center_x'], cluster['center_y']),
                                tD, linewidth=2,
                                edgecolor=cluster['color'],
                                facecolor='none', hatch='//', alpha=0.5)
        ax_middle.add_patch(circle)
    ax_middle.set_title("Clustering (hatched circles)")
    ax_middle.axis('equal')

    # ----- Right Subplot: Input Image with Cluster Highlight -----
    # Create a copy of the image and set all non-background pixels to black.
    new_image = image_data.copy()
    # Assuming background is white ([1,1,1]); convert non-white pixels to black.
    mask = ~np.all(new_image == [1, 1, 1], axis=-1)
    new_image[mask] = [0, 0, 0]
    # If clusters were computed and the cluster index is valid, highlight that cluster.
    if clusters_info and (cluster_idx < len(clusters_info)):
        selected_cluster = clusters_info[cluster_idx]
        for member in selected_cluster['members']:
            # Convert member index to pixel coordinates.
            px = member % image_width
            py = member // image_width
            new_image[py, px] = [1, 0, 0]  # Red
    ax_right.imshow(new_image)
    ax_right.set_title("Input Image with Cluster Highlight")
    ax_right.axis('off')

    # Set a suptitle reflecting the training epoch used.
    fig.suptitle(f"Predictions from model at training epoch: {training_epoch}", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# --------------------------
# Create interactive widgets.
# --------------------------
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset) - 1,
    step=1,
    description='Event:'
)
epoch_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=epochs,  # Should match the total number of training epochs
    step=1,
    description='Epoch:'
)
show_stars_checkbox = widgets.Checkbox(
    value=False,
    description="Show Highest Beta Stars"
)
cluster_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=10,  # Dummy maximum; actual number of clusters may be fewer.
    step=1,
    description='Cluster:'
)
tD_slider = widgets.FloatSlider(
    value=0.25,
    min=0.0,
    max=1.0,
    step=0.01,
    description='tD:'
)
tB_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.01,
    description='tB:'
)

# Link widgets to the update function.
interactive_plot = interactive(update_plots,
                               event_num=event_slider,
                               training_epoch=epoch_slider,
                               show_stars=show_stars_checkbox,
                               cluster_idx=cluster_slider,
                               tD=tD_slider,
                               tB=tB_slider)
display(interactive_plot)




You may notice that the transformer-based model does not perform as well as the CNN-only model in https://github.com/Gregtom3/vossen_ecal_ai/blob/main/notebooks/nb03_shape_condensation.ipynb. That is partly because more tuning remains to be done and partly because of the limited training sample size. Using a transformer encoder for this task is a bit overkill, after all.

# Visualizing the Black Box

Based on the `model.summary()` output before fitting, we see several named components that make up the full network. There are `reshape` layers, `EncoderLayer` layers, and of course the `PositionalEncoder`. Since the input tensor to many of these has the form `[400,k]`, we plot the output of each layer below. The rows correspond to different pixels in the image, and the columns represent different features for the given pixel.

If you trained the model using sinusoidal positional encodings, the `PositionalEncoding` output should look very wave-like. The `CNNPositionalEncoding` output, on the other hand, looks a bit more abstract.
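
To see why a sinusoidal encoding looks wave-like, here is a minimal NumPy sketch of the standard sine/cosine formulation from the original Transformer paper. It assumes the notebook's layer follows that formulation, which is not confirmed here. The function name and the feature width `d_model=32` are hypothetical choices for illustration.

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d_model):
    """Standard sine/cosine positional encoding (illustrative sketch)."""
    positions = np.arange(seq_len)[:, np.newaxis]   # shape (seq_len, 1)
    dims = np.arange(d_model)[np.newaxis, :]        # shape (1, d_model)

    # Each pair of feature columns shares one frequency; the frequency
    # decreases geometrically as the column index grows.
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates                # shape (seq_len, d_model)

    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])  # even columns use sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])  # odd columns use cosine
    return pe

# 400 matches the 20x20 sequence length; d_model=32 is an assumed width.
pe = sinusoidal_positional_encoding(seq_len=400, d_model=32)
print(pe.shape)
```

Plotting `pe` with `plt.imshow(pe, aspect='auto')` shows fast oscillations along the rows in the first columns and progressively slower ones in the later columns, which is the wave-like pattern to look for in the layer output.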

In [16]:
# Collect the names of every layer in the model so that each output can be visualized.
layer_names = [layer.name for layer in model.layers]

# Build an intermediate model that outputs all the intermediate activations.
intermediate_model = tf.keras.Model(
    inputs=model.input,
    outputs=[model.get_layer(name).output for name in layer_names]
)

def show_intermediate(event_idx):
    # Select one sample from your data (X is assumed to be defined; shape: (num_events, 20, 20, 3)).
    sample = X[event_idx:event_idx+1]  # shape: (1, 20, 20, 3)

    # Get the outputs from each of the intermediate layers.
    outputs = intermediate_model.predict(sample)

    num_layers = len(outputs)
    # Arrange the plots in a grid. Here we use 3 rows x 5 columns since we have 15 outputs.
    fig, axes = plt.subplots(3, 5, figsize=(20, 12))
    axes = axes.flatten()

    for i, (ax, output) in enumerate(zip(axes, outputs)):
        # Remove the batch dimension; output shape becomes (400, channels).
        activation = output[0]
        im = ax.imshow(activation, aspect='auto', interpolation='nearest')
        ax.set_title(layer_names[i])
        ax.set_xlabel("Feature Dimension")
        ax.set_ylabel("Sequence Length")
        fig.colorbar(im, ax=ax)

    # Remove any unused subplots if the number of layers is less than grid size.
    for j in range(num_layers, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    plt.show()

# Create an interactive slider for event/sample index.
# Adjust the max value as X.shape[0]-1 (number of events in your dataset).
interact(show_intermediate, event_idx=widgets.IntSlider(min=0, max=X.shape[0]-1, step=1, value=0));
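
The 3 x 5 grid in `show_intermediate` is tied to a model with 15 layer outputs. If the architecture changes, the grid can be derived from the number of outputs instead. The helper below is a small sketch of that idea; the name `grid_shape` and the `max_cols` parameter are hypothetical and not part of the notebook.

```python
import math

def grid_shape(num_plots, max_cols=5):
    """Return (n_rows, n_cols) large enough to hold num_plots subplots."""
    if num_plots < 1:
        raise ValueError("num_plots must be at least 1")
    n_cols = min(num_plots, max_cols)
    n_rows = math.ceil(num_plots / n_cols)
    return n_rows, n_cols

print(grid_shape(15))  # (3, 5)
print(grid_shape(17))  # (4, 5)
print(grid_shape(3))   # (1, 3)
```

The result could then be passed as `plt.subplots(n_rows, n_cols, squeeze=False)`, which always returns a 2D array of axes, so the existing `axes.flatten()` call keeps working even for a single row.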

