<a href="https://colab.research.google.com/github/Gregtom3/vossen_ecal_ai/blob/main/notebooks/nb03_shape_condensation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Author: Gregory Matousek

Contact: gregory.matousek@duke.edu

# Tutorial Overview

In this tutorial, we build a convolutional neural network (CNN) to perform a clustering task. For each pixel in the image, the network predicts two latent-space coordinates and a confidence value $\beta$. We train it by minimizing the object condensation loss on these 3 output values, back-propagating gradients through the network. Once the loss is minimized, pixels that belong to the same object should cluster together in the latent space.

For more details on object condensation, see https://arxiv.org/abs/2002.03605




# Imports

In [1]:
# Import source code from the GitHub to generate images
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
from shape_gen import generate_dataset

# Import source code from the GitHub for the object condensation loss function
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
from nb03_loss_functions import CustomLoss, AttractiveLossMetric, RepulsiveLossMetric, CowardLossMetric, NoiseLossMetric, condensation_loss

--2025-03-14 13:51:21--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4604 (4.5K) [text/plain]
Saving to: ‘shape_gen.py’


2025-03-14 13:51:21 (5.83 MB/s) - ‘shape_gen.py’ saved [4604/4604]

--2025-03-14 13:51:39--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7218 (7.0K) [text/plain]
Saving to: ‘nb03_loss_functions.py’


2025-03-14 13:51:39 (53.7 MB/s) - ‘n

In [11]:
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display,  clear_output
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
from matplotlib import patches

# Data Generation

Below we set some of the parameters used to create the array of images. By default, we produce 1,000 images, each a 32-by-32 pixel grid containing 5 to 8 shapes. A few other parameters, such as `same_color`, `same_shape`, and `shape_overlap_max`, can also be tweaked.

In [3]:
num_images           = 1000
image_width          = 32
image_height         = image_width
min_shapes           = 5
max_shapes           = 8
shape_size_range     = (5,12)

dataset = generate_dataset(num_images=num_images,
                           image_size=(image_width,image_height),
                           min_shapes=min_shapes,
                           max_shapes=max_shapes,
                           shape_size_range=shape_size_range,
                           same_color=False, # True
                           same_shape=None, # ['circle','triangle','square']
                           shape_overlap_max=0.5)

dataset = np.array(dataset)

print(dataset.shape) # (1000, 32, 32, 7)

# --> 0,1,2 = RGB
# --> 3 = x
# --> 4 = y
# --> 5 = unique_shape_id (background == 0)
# --> 6 = shape type
#     --> 0 = noise/empty
#     --> 1 = circle
#     --> 2 = square
#     --> 3 = triangle

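# Set RGB of fully white pixels (1,1,1) to black (0,0,0).
# The mask requires all three channels to equal 1, so colored pixels
# with a single saturated channel (e.g. pure red) are left untouched.
white = (dataset[..., 0:3] == 1).all(axis=-1)
dataset[white, 0:3] = 0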

(1000, 32, 32, 7)


From `dataset.shape`, we see that we are dealing with a tensor of shape [1000,32,32,7]. As indicated by the comments, the first 3 features of each pixel are its RGB values. The (x,y) position of the pixel is stored as the 4th and 5th features. **The most crucial feature** to understand is the 6th, the `unique_shape_id`.

Consider the first shape in the first image. All pixels that belong to that shape have a `unique_shape_id` of 1, pixels of the second generated shape have a `unique_shape_id` of 2, and so on. Importantly, no two shapes share a `unique_shape_id`, even across different "events" (images). All background pixels have a `unique_shape_id` of 0.

Lastly, the final feature indicates what type of shape the pixel belongs to.
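
The channel layout can be explored on a small stand-in array. The sketch below does not call `generate_dataset`; it builds a hypothetical 4-by-4 event with the same seven-channel layout (the x = column, y = row convention is an assumption) and counts the pixels that belong to each `unique_shape_id`.

```python
import numpy as np

# Hypothetical stand-in for one event: shape (H, W, 7), same channel layout as the dataset.
H, W = 4, 4
event = np.zeros((H, W, 7))

# Channels 3 and 4 hold the pixel position (assumed here: x = column, y = row).
ys, xs = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
event[..., 3] = xs
event[..., 4] = ys

# Paint a 2x2 red "square" with unique_shape_id 1 and shape type 2 (square).
event[0:2, 0:2, 0:3] = [1.0, 0.0, 0.0]
event[0:2, 0:2, 5] = 1
event[0:2, 0:2, 6] = 2

# Count pixels per unique_shape_id; id 0 is the background.
ids, counts = np.unique(event[..., 5].astype(int), return_counts=True)
pixels_per_id = dict(zip(ids.tolist(), counts.tolist()))
print(pixels_per_id)  # {0: 12, 1: 4}
```

The same `np.unique` call on `dataset[evtnum, ..., 5]` should list the shape ids present in a real event.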


Let's plot a sample event.



In [4]:
def plot_toy(dataset, evtnum, PLOT_TYPE):
    # Check inputs
    assert PLOT_TYPE in ['RGB', 'X', 'Y', 'uid', 'type'], "PLOT_TYPE must be one of ['RGB', 'X', 'Y', 'uid', 'type']"
    assert evtnum < len(dataset), "evtnum must be less than the number of events in the dataset"

    # Copy and process the event data
    data_reshape = deepcopy(dataset[evtnum])
    if PLOT_TYPE == 'RGB':
        image_data = data_reshape[:, :, 0:3]
    elif PLOT_TYPE == 'X':
        image_data = data_reshape[:, :, 3]
    elif PLOT_TYPE == 'Y':
        image_data = data_reshape[:, :, 4]
    elif PLOT_TYPE == 'uid':
        image_data = data_reshape[:, :, 5]
    elif PLOT_TYPE == 'type':
        image_data = data_reshape[:, :, 6]

    # Create the plot
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    im = ax.imshow(image_data)  # Capture the image object
    # Add a colorbar if the plot type is not RGB
    if PLOT_TYPE != 'RGB':
        fig.colorbar(im, ax=ax)
    # Set the title based on the widget inputs
    ax.set_title(f'Event: {evtnum} | Plot Type: {PLOT_TYPE}')
    plt.tight_layout()
    plt.show()

# Update function for the widget
def update_plot(event_num, plot_type):
    plot_toy(dataset, event_num, plot_type)

# Create the interactive widgets
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset)-1,
    step=1,
    description='Event Number:'
)
plot_type_dropdown = widgets.Dropdown(
    options=['RGB', 'X', 'Y', 'uid', 'type'],
    value='RGB',
    description='Plot Type:'
)

# Link the widgets to the update function
interactive_plot = interactive(update_plot, event_num=event_slider, plot_type=plot_type_dropdown)
display(interactive_plot)


**Note** the background pixels are **black** so that their RGB features are (0,0,0) as opposed to (1,1,1). This gives the model an easier time fitting and determining what is and isn't background.

# Creating the CNN

We define two CNNs that share the same overall structure. SmallCNN has considerably fewer parameters than LargeCNN; LargeCNN should be more accurate but takes longer to fit.


- The model is built from main convolutional blocks: three in LargeCNN (Block 1, Block 2, and Block 3) and two in SmallCNN (Block 1 and Block 2).
  - Each block processes the input image (or a concatenation of the image with previous block outputs) with several convolutional layers, activations, batch normalization, and max pooling.
  - After each block, the output is upsampled back to the original image size and concatenated with the original input, providing skip connections that preserve spatial details.

- Following the blocks, the network applies two 1x1 convolution layers (acting as fully connected layers) to combine the learned features.

- The model then splits into two branches:
  - One branch produces a 'beta' output using additional convolutional layers with a sigmoid activation.
  - The other branch predicts coordinate information with a simple convolution layer.

- Finally, the outputs from these branches are concatenated and reshaped to yield per-pixel predictions.
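
The shape bookkeeping behind one block and its skip connection can be checked with plain NumPy. The sketch below is not the Keras model: the helper names are hypothetical, random arrays stand in for the convolution outputs, and nearest-neighbour upsampling replaces the bilinear `tf.image.resize` used in the notebook. It only illustrates why the input to the next block has 3 + 64 channels in LargeCNN.

```python
import numpy as np

def max_pool_2x2(x):
    """2x2 max pooling with stride 2 on an (H, W, C) array (H and W must be even)."""
    h, w, c = x.shape
    return x.reshape(h // 2, 2, w // 2, 2, c).max(axis=(1, 3))

def upsample_nearest(x, out_h, out_w):
    """Nearest-neighbour upsampling to (out_h, out_w); the notebook uses bilinear resizing."""
    h, w, _ = x.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return x[rows][:, cols]

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))       # RGB input
features = rng.random((32, 32, 64))   # stand-in for a block's conv output (64 channels)

pooled = max_pool_2x2(features)                     # pooling halves H and W
restored = upsample_nearest(pooled, 32, 32)         # back to the input resolution
skip = np.concatenate([image, restored], axis=-1)   # skip connection along channels

print(pooled.shape, restored.shape, skip.shape)  # (16, 16, 64) (32, 32, 64) (32, 32, 67)
```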


### LargeCNN

In [5]:
class LargeCNN(tf.keras.Model):
    def __init__(self, **kwargs):
        # Accept Keras kwargs (e.g. name) so that from_config's cls(**config) works
        super(LargeCNN, self).__init__(**kwargs)

        # Block 1: processes the input (shape: HxWx3)
        self.block1 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 2: takes a concatenation of the input and upsampled block1 output (channels: 3 + 64)
        self.block2 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 3: takes a concatenation of the input and upsampled block2 output (channels: 3 + 64)
        self.block3 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Fully connected layers using 1x1 convolutions
        self.fc1 = tf.keras.Sequential([
            layers.Conv2D(64, kernel_size=1, padding='same', activation=None),
            layers.ELU()
        ])
        self.fc2 = tf.keras.Sequential([
            layers.Conv2D(64, kernel_size=1, padding='same', activation=None)
        ])

        # p_beta branch: produces 1 channel with sigmoid activation
        self.p_beta = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid')
        ])

        # p_ccoords branch: produces 2 channels (for example, representing coordinates)
        self.p_ccoords = layers.Conv2D(2, kernel_size=1, padding='same')


    def call(self, x):
        # x is assumed to have shape (batch, height, width, 3)

        # Block 1
        block1_out = self.block1(x)
        block1_out_up = tf.image.resize(block1_out, size=tf.shape(x)[1:3], method='bilinear')
        concat1 = tf.concat([x, block1_out_up], axis=-1)  # along channels

        # Block 2
        block2_out = self.block2(concat1)
        block2_out_up = tf.image.resize(block2_out, size=tf.shape(x)[1:3], method='bilinear')
        concat2 = tf.concat([x, block2_out_up], axis=-1)

        # Block 3
        block3_out = self.block3(concat2)
        block3_out_up = tf.image.resize(block3_out, size=tf.shape(x)[1:3], method='bilinear')
        concat3 = tf.concat([x, block3_out_up], axis=-1)

        # Fully connected layers (implemented as 1x1 convolutions)
        out = self.fc1(concat3)
        out = self.fc2(out)

        # Compute branches
        beta = self.p_beta(out) * 0.999 + 1e-9
        ccoords = self.p_ccoords(out)

        # Concatenate predictions along the channel dimension
        predictions = tf.concat([beta, ccoords], axis=-1)  # resulting shape: (batch, H, W, 3)
        predictions = tf.reshape(predictions, [-1, image_width*image_height, 3])         # reshape to (batch, H * W, 3)

        return predictions

    def get_config(self):
        config = super(LargeCNN, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

### SmallCNN

In [6]:
class SmallCNN(tf.keras.Model):
    def __init__(self, **kwargs):
        super(SmallCNN, self).__init__(**kwargs)

        # Block 1: processes the input (shape: HxWx3)
        self.block1 = tf.keras.Sequential([
            layers.Conv2D(8, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 2: takes a concatenation of the input and upsampled block1 output (channels: 3 + 32)
        self.block2 = tf.keras.Sequential([
            layers.Conv2D(8, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])


        # Fully connected layers using 1x1 convolutions
        self.fc1 = tf.keras.Sequential([
            layers.Conv2D(32, kernel_size=1, padding='same', activation=None),
            layers.ELU()
        ])
        self.fc2 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=1, padding='same', activation=None)
        ])

        # p_beta branch: produces 1 channel with sigmoid activation
        self.p_beta = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid')
        ])

        # p_ccoords branch: produces 2 channels (for example, representing coordinates)
        self.p_ccoords = layers.Conv2D(2, kernel_size=1, padding='same')


    def call(self, x):
        # x is assumed to have shape (batch, height, width, 3)

        # Block 1
        block1_out = self.block1(x)
        block1_out_up = tf.image.resize(block1_out, size=tf.shape(x)[1:3], method='bilinear')
        concat1 = tf.concat([x, block1_out_up], axis=-1)  # along channels

        # Block 2
        block2_out = self.block2(concat1)
        block2_out_up = tf.image.resize(block2_out, size=tf.shape(x)[1:3], method='bilinear')
        concat2 = tf.concat([x, block2_out_up], axis=-1)

        # Fully connected layers (implemented as 1x1 convolutions)
        out = self.fc1(concat2)
        out = self.fc2(out)

        # Compute branches
        beta = self.p_beta(out) * 0.999 + 1e-9
        ccoords = self.p_ccoords(out)

        # Concatenate predictions along the channel dimension
        predictions = tf.concat([beta, ccoords], axis=-1)  # resulting shape: (batch, H, W, 3)
        predictions = tf.reshape(predictions, [-1, image_width*image_height, 3])         # reshape to (batch, H * W, 3)

        return predictions

    def get_config(self):
        config = super(SmallCNN, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Initializing the Model

Next, we create the model by calling either SmallCNN or LargeCNN. There are several hyperparameters we can tweak, such as the number of epochs to train, the batch_size, and the learning_rate.

The hyperparameter `q_min` sets the minimum charge of each point in the latent space. The attraction/repulsion strength between two points scales with the product `q_i * q_j`. Without a floor, the network could drive every `q_i` to 0, which would trivially minimize the attractive and repulsive loss terms. Setting `q_min > 0` prevents this.
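
As a small illustration, the charge in the object condensation paper (arXiv:2002.03605) is $q_i = \mathrm{arctanh}^2(\beta_i) + q_{\min}$. The sketch below evaluates that formula with NumPy; the helper name `charge` is hypothetical, and it is an assumption that `CustomLoss` uses exactly this definition, since its source is not shown here.

```python
import numpy as np

def charge(beta, q_min):
    """Charge as defined in the object condensation paper: arctanh(beta)^2 + q_min."""
    return np.arctanh(beta) ** 2 + q_min

beta = np.array([0.0, 0.5, 0.9, 0.99])

q_no_floor = charge(beta, q_min=0.0)   # a point with beta = 0 carries no charge at all
q_floor = charge(beta, q_min=0.1)      # every point keeps at least q_min

print(q_no_floor)
print(q_floor)
```

With `q_min=0.0` the first charge is exactly 0, so that point would contribute nothing to the potential terms; with `q_min=0.1` it still does.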

We then load the data by slicing the dataset. The CNN output has shape $[N,H\times W,3]$, so we reshape the label tensor `y` to $[N,H\times W,1]$ so that its first two dimensions match. This is important; otherwise the `condensation_loss` will throw an error.
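
The reshape flattens each image in row-major order, so a flat index `k` corresponds to row `k // W` and column `k % W`. The sketch below checks this on a hypothetical toy label array (sizes and values are illustrative only).

```python
import numpy as np

N, H, W = 2, 3, 4
# Hypothetical label images: each pixel stores its own flat index, which makes the mapping easy to check.
y_img = np.arange(N * H * W).reshape(N, H, W) % (H * W)

# Same reshape as in the tutorial: (N, H, W) -> (N, H*W, 1)
y_flat = y_img.reshape(y_img.shape[0], y_img.shape[1] * y_img.shape[2], 1)

# Flat index k maps back to (row, col) = (k // W, k % W).
k = 6
row, col = k // W, k % W
print(y_flat.shape, (row, col), y_img[0, row, col], y_flat[0, k, 0])  # (2, 12, 1) (1, 2) 6 6
```

The evaluation cell relies on the same mapping when it converts a flat pixel index back to image coordinates.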

A basic train/test split is performed, and then the model is compiled. By default, we compile using the `adam` optimizer. The loss is the `CustomLoss` class defined in the GitHub repository (see the imports). This class takes `y_true` and `y_pred` (from the CNN) and passes them to the functions that compute the object condensation loss. We use custom metrics to print out the individual loss components during training.
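
The four metric names map naturally onto the terms of the loss in the object condensation paper. The NumPy sketch below follows the paper's formulas for a single event. It is not the code in `nb03_loss_functions.py`, which may normalize or weight the terms differently; the function name `condensation_terms` is hypothetical, and `s_B` is the paper's noise-suppression weight.

```python
import numpy as np

def condensation_terms(beta, coords, ids, q_min=0.1, s_B=1.0):
    """Return (attractive, repulsive, coward, noise) for one event.

    beta:   (N,) confidences in (0, 1)
    coords: (N, D) latent-space coordinates
    ids:    (N,) integer object ids, with 0 meaning background/noise
    """
    q = np.arctanh(beta) ** 2 + q_min
    n = len(beta)
    attractive = repulsive = coward = 0.0
    object_ids = [k for k in np.unique(ids) if k != 0]
    for k in object_ids:
        members = ids == k
        # Condensation point: the member of object k with the largest charge.
        alpha = np.flatnonzero(members)[np.argmax(q[members])]
        dist = np.linalg.norm(coords - coords[alpha], axis=1)
        # Members are pulled toward the condensation point (quadratic potential).
        attractive += np.sum(q[members] * dist[members] ** 2 * q[alpha]) / n
        # Non-members within unit distance are pushed away (hinge potential).
        repulsive += np.sum(q[~members] * np.maximum(0.0, 1.0 - dist[~members]) * q[alpha]) / n
        # Each object should have one point with beta close to 1.
        coward += (1.0 - beta[alpha]) / len(object_ids)
    noise_mask = ids == 0
    # Background points are penalized for having large beta.
    noise = s_B * beta[noise_mask].mean() if noise_mask.any() else 0.0
    return attractive, repulsive, coward, noise

# Two tight, well-separated objects plus one low-beta background point.
beta = np.array([0.9, 0.2, 0.9, 0.2, 0.05])
coords = np.array([[0.0, 0.0], [0.0, 0.0], [5.0, 5.0], [5.0, 5.0], [-5.0, 5.0]])
ids = np.array([1, 1, 2, 2, 0])
att, rep, cow, noi = condensation_terms(beta, coords, ids)

# Same labels, but every point collapsed onto the origin: repulsion becomes nonzero.
_, rep_collapsed, _, _ = condensation_terms(beta, np.zeros((5, 2)), ids)
print(att, rep, cow, noi, rep_collapsed)
```

In the well-separated case the attractive and repulsive terms vanish, leaving only the coward and noise penalties, which is the behavior a trained model should approach.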

Lastly, we pass a single event through the model. This builds the model, so that `model.summary()` can report the output shapes and the number of parameters.



In [7]:
#!!! Define which CNN to use !!!
CNN = SmallCNN # or LargeCNN

# Define hyperparameters
# - q_min: Hyperparameter defined in object condensation
#      q_i = "Charge" of point i = arctanh(beta_i)^2 + q_min

epochs = 10
batch_size = 32
learning_rate = 0.001
q_min = 0.1

# Load in the data
X = dataset[...,0:3] # RGB of each pixel
y = dataset[...,5] # unique_shape_id of each pixel

# Reshape 'y' to be [N,H*W,1]
y = y.reshape(y.shape[0], y.shape[1]*y.shape[2], 1)

# Perform train-test splitting
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Load in the CNN model
model = CNN()

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=CustomLoss(q_min=q_min), # from GitHub
    metrics=[
        AttractiveLossMetric(name="attractive_loss"),
        RepulsiveLossMetric(name="repulsive_loss"),
        CowardLossMetric(name="coward_loss"),
        NoiseLossMetric(name="noise_loss")
    ]
)


# Pass one event through the model initially
# This is done to print out the model summary with the proper shapes
model(X_train[0:1])
model.summary()

# Fitting

Next we fit the model to the input dataset. We define a checkpoint to save the model state at the end of each epoch. We do this to make some pretty plots later on, allowing us to see how the object condensation latent space develops over time.
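
The checkpoint file names come from a Python format string that Keras fills in with the 1-based epoch number, as the training log confirms (`model_epoch_01.keras` after epoch 1). A minimal sketch of the naming, using plain string formatting rather than the callback itself:

```python
# Same pattern as the `filepath` argument of the checkpoint callback.
pattern = 'model_epoch_{epoch:02d}.keras'

# `:02d` zero-pads the epoch number to two digits.
paths = [pattern.format(epoch=e) for e in range(1, 4)]
print(paths)  # ['model_epoch_01.keras', 'model_epoch_02.keras', 'model_epoch_03.keras']
```

This is why the evaluation loop later iterates over `range(1, epochs + 1)` when loading the saved models.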

In [8]:

# Define a checkpoint callback to save the model after each epoch.
checkpoint_callback = ModelCheckpoint(
    filepath='model_epoch_{epoch:02d}.keras',  # Model file name
    save_weights_only=False,
    verbose=1,                              # Verbosity mode.
    save_freq='epoch'                       # Save at the end of every epoch.
)

# Train the model
history = model.fit(
    X_train,
    y_train,
    #validation_data=(X_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    callbacks=[checkpoint_callback]
)

Epoch 1/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - attractive_loss: 0.0304 - coward_loss: 0.4505 - loss: 0.8770 - noise_loss: 0.3204 - repulsive_loss: 0.0758
Epoch 1: saving model to model_epoch_01.keras
25/25 ━━━━━━━━━━━━━━━━━━━━ 42s 1s/step - attractive_loss: 0.0307 - coward_loss: 0.4494 - loss: 0.8710 - noise_loss: 0.3155 - repulsive_loss: 0.0755
Epoch 2/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 981ms/step - attractive_loss: 0.0436 - coward_loss: 0.3023 - loss: 0.4566 - noise_loss: 0.0435 - repulsive_loss: 0.0672
Epoch 2: saving model to model_epoch_02.keras
25/25 ━━━━━━━━━━━━━━━━━━━━ 39s 986ms/step - attractive_loss: 0.0435 - coward_loss: 0.3015 - loss: 0.4552 - noise_loss: 0.0432 - repulsive_loss: 0.0670
Epoch 3/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 0s 975ms/step - attractive_loss: 0.0363 - coward_loss: 0.2241 - loss: 0.3407 - noi

# Evaluation

We can evaluate the model performance visually by comparing the initial shape image with its latent 2D image. In other words, we can show where each pixel in the initial image gets mapped to (by the CNN) in a 2D latent space. If the object condensation loss is minimized, what we should see are clusters forming in the latent space.

First, let's compute predictions for the full image dataset using the model saved at each training epoch.

In [9]:
# --------------------------
# Precompute predictions for all epochs.
# --------------------------
predictions = {}
print("Beginning precomputation (this may take a few minutes).")
for epoch in range(1, epochs + 1):
    model_path = f'model_epoch_{epoch:02d}.keras'
    print(f"Loading and predicting with {model_path} ...")
    # Register whichever CNN class was selected above (SmallCNN or LargeCNN).
    loaded_model = load_model(model_path, custom_objects={CNN.__name__: CNN,
                                                          'CustomLoss': CustomLoss,
                                                          'AttractiveLossMetric': AttractiveLossMetric,
                                                          'RepulsiveLossMetric': RepulsiveLossMetric,
                                                          'CowardLossMetric': CowardLossMetric,
                                                          'NoiseLossMetric': NoiseLossMetric})
    predictions[epoch] = loaded_model.predict(X)
print("Precomputation complete.")


Beginning precomputation (this may take a few minutes).
Loading and predicting with model_epoch_01.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 80ms/step
Loading and predicting with model_epoch_02.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 82ms/step
Loading and predicting with model_epoch_03.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 82ms/step
Loading and predicting with model_epoch_04.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 4s 107ms/step
Loading and predicting with model_epoch_05.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 81ms/step
Loading and predicting with model_epoch_06.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 81ms/step
Loading and predicting with model_epoch_07.keras ...
32/32 ━━━━━━━━━━━━━━━━━━━━ 4s 112ms/step
Loading and predicting with model_epoch_08.keras ...


The next cell shows three plots. The left plot shows the image data (shapes). The middle plot shows the latent-space output of the CNN, where marker opacity increases with $\beta$. The right plot highlights which pixels of the original image are clustered together for the chosen thresholds `tB` and `tD`.

The cell also provides several interactive controls, which we review here.

* Event: Allows us to switch between different input images
* Epoch: Allows us to view the latent space output per epoch
* Show Highest Beta Stars: This checkbox, when clicked, puts a star atop the pixel with the highest output $\beta$ for each shape (including background).
* Cluster: Allows us to select which cluster in the latent space to view in the right plot.
* tB: Minimum $\beta$ a point needs in order to seed a new cluster
* tD: Radius in the latent space within which unassigned points join a seed's cluster

To see the full performance, pick a random event and scan through the epochs. You will see that, over time, the model learns to take the input data (which is just RGB) and map it to clustered points in the latent space. At the latest epoch, cycle through the different clusters to see the shapes emerge.
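
The thresholds `tB` and `tD` drive a greedy clustering pass over the latent space, which the plotting cell performs inline. The standalone NumPy sketch below mirrors that logic on hypothetical toy values; the function name `greedy_cluster` and the `-1` label for unassigned points are illustrative choices, not part of the notebook.

```python
import numpy as np

def greedy_cluster(beta, coords, tB, tD):
    """Assign points to clusters seeded by high-beta points.

    Returns an (N,) integer array holding the cluster index of each point,
    or -1 for points that were not assigned to any cluster.
    """
    labels = np.full(len(beta), -1)
    next_label = 0
    for idx in np.argsort(-beta):      # visit points from highest to lowest beta
        if beta[idx] < tB:
            break                      # all remaining points are below the seed threshold
        if labels[idx] != -1:
            continue                   # already absorbed by an earlier cluster
        dist = np.linalg.norm(coords - coords[idx], axis=1)
        members = (dist <= tD) & (labels == -1)
        labels[members] = next_label
        next_label += 1
    return labels

beta = np.array([0.95, 0.30, 0.80, 0.20, 0.10])
coords = np.array([[0.0, 0.0], [0.1, 0.0], [2.0, 2.0], [2.1, 2.0], [5.0, 5.0]])
labels = greedy_cluster(beta, coords, tB=0.5, tD=0.25)
print(labels.tolist())  # [0, 0, 1, 1, -1]
```

Raising `tB` reduces the number of seeds, while raising `tD` lets each seed absorb points from farther away; the last point stays unassigned because it is neither a seed nor within `tD` of one.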




In [12]:
# --------------------------
# Define the interactive update function.
# --------------------------
def update_plots(event_num, training_epoch, show_stars, cluster_idx, tD, tB):
    # Retrieve precomputed predictions for the selected epoch.
    y_pred = predictions[training_epoch]

    # Create a figure with 3 subplots side by side.
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    ax_left, ax_middle, ax_right = axs

    # ----- Left Subplot: Toy Image with Scatter and Optional Stars -----
    data_reshape = deepcopy(dataset[event_num])
    image_data = data_reshape[:, :, 0:3]

    # Replace pure black pixels with white.
    image_data[(image_data == [0, 0, 0]).all(axis=2)] = [1, 1, 1]

    im_left = ax_left.imshow(image_data)
    ax_left.set_title(f'Event: {event_num}')

    # Retrieve scatter data from predictions.
    # Copy so that recoloring the background below does not modify X (and dataset) in place.
    colors = X[event_num][..., 0:3].reshape(-1, 3).copy()
    colors[(colors == [0, 0, 0]).all(axis=1)] = [1, 1, 1]
    beta = y_pred[event_num][..., 0]
    xc = y_pred[event_num][..., 1]
    yc = y_pred[event_num][..., 2]
    id_arr = y[event_num][..., 0]

    unique_ids = np.unique(id_arr)


    # ----- Middle Subplot: Clustering Inference Visualization -----
    # Cluster the scatter points using thresholds tD and tB.
    num_points = beta.shape[0]
    clustered = np.zeros(num_points, dtype=bool)
    clusters_info = []  # Each element will be a dict with cluster center, members, etc.
    sorted_indices = np.argsort(-beta)  # descending order

    for uid in unique_ids:
      indices = (id_arr == uid)
      marker_alpha = [max(b, 0.0005) for b in beta[indices]]
      marker_size = 40
      current_color = colors[indices][0]
      marker_edge = "black" if np.all(current_color == [1, 1, 1]) else "none"
      ax_middle.scatter(xc[indices], yc[indices],
                      c=colors[indices],
                      alpha=marker_alpha,
                      s=marker_size,
                      edgecolor=marker_edge)
      # If checkbox is checked, add a star on the left image for the highest β point in the group.
      if show_stars and np.any(indices):
          idx_in_group = np.argmax(beta[indices])
          overall_idx = np.where(indices)[0][idx_in_group]
          # Convert overall index to pixel coordinates.
          star_x = overall_idx % image_width
          star_y = overall_idx // image_width
          ax_left.scatter(star_x, star_y, marker='*', color='red', s=150,
                          edgecolor='black', linewidth=1, zorder=10)

    for idx in sorted_indices:
        if beta[idx] < tB:
            break
        if clustered[idx]:
            continue
        center_x = xc[idx]
        center_y = yc[idx]
        distances = np.sqrt((xc - center_x)**2 + (yc - center_y)**2)
        members = np.where((distances <= tD) & (~clustered))[0]
        clustered[members] = True
        clusters_info.append({
            'center_idx': idx,
            'members': members,
            'center_x': center_x,
            'center_y': center_y,
            'color': colors[idx]
        })

    # For each cluster, draw a hatched circle around the highest β point.
    for cluster in clusters_info:
        circle = patches.Circle((cluster['center_x'], cluster['center_y']),
                                tD, linewidth=2,
                                edgecolor=cluster['color'],
                                facecolor='none', hatch='//', alpha=0.5)
        ax_middle.add_patch(circle)
        # Optionally, mark the cluster center.
        ax_middle.scatter(cluster['center_x'], cluster['center_y'],
                          color=cluster['color'], s=100, marker='o')
    ax_middle.set_title("Clustering (hatched circles)")
    ax_middle.axis('equal')

    # ----- Right Subplot: Input Image with Cluster Highlight -----
    # Create a copy of the image and set all non-background pixels to black.
    new_image = image_data.copy()
    # Assuming background is white ([1,1,1]); convert non-white pixels to black.
    mask = ~np.all(new_image == [1, 1, 1], axis=-1)
    new_image[mask] = [0, 0, 0]
    # If clusters were computed and the cluster index is valid, highlight that cluster.
    if clusters_info and (cluster_idx < len(clusters_info)):
        selected_cluster = clusters_info[cluster_idx]
        for member in selected_cluster['members']:
            # Convert member index to pixel coordinates.
            px = member % image_width
            py = member // image_width
            new_image[py, px] = [1, 0, 0]  # Red
    ax_right.imshow(new_image)
    ax_right.set_title("Input Image with Cluster Highlight")
    ax_right.axis('off')

    # Set a suptitle reflecting the training epoch used.
    fig.suptitle(f"Predictions from model at training epoch: {training_epoch}", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# --------------------------
# Create interactive widgets.
# --------------------------
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset) - 1,
    step=1,
    description='Event:'
)
epoch_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=epochs,  # Should match the total number of training epochs
    step=1,
    description='Epoch:'
)
show_stars_checkbox = widgets.Checkbox(
    value=False,
    description="Show Highest Beta Stars"
)
cluster_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=10,  # Dummy maximum; actual number of clusters may be fewer.
    step=1,
    description='Cluster:'
)
tD_slider = widgets.FloatSlider(
    value=0.25,
    min=0.0,
    max=1.0,
    step=0.01,
    description='tD:'
)
tB_slider = widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.01,
    description='tB:'
)

# Link widgets to the update function.
interactive_plot = interactive(update_plots,
                               event_num=event_slider,
                               training_epoch=epoch_slider,
                               show_stars=show_stars_checkbox,
                               cluster_idx=cluster_slider,
                               tD=tD_slider,
                               tB=tB_slider)
display(interactive_plot)


