<a href="https://colab.research.google.com/github/Gregtom3/vossen_ecal_ai/blob/main/notebooks/nb03_shapeCondensation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial Overview

In this tutorial, we build a convolutional neural network (CNN) that predicts, for each pixel in an image, a pair of latent-space coordinates and a confidence value $\beta$. These 3 output values enter the object condensation loss, whose gradients are back-propagated through the network. As the loss is minimized, the outputs for pixels belonging to the same shape should move together, so that we can observe clustering in the latent space.
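For orientation, the terms of the loss can be sketched for a single event in NumPy. This is a minimal sketch that assumes the formulas of the original object condensation paper; it is not necessarily what `nb03_loss_functions.py` implements. The function name `oc_loss_sketch` and the background weight `s_b` are illustrative assumptions.

```python
import numpy as np

def oc_loss_sketch(beta, coords, obj_id, q_min=0.1, s_b=1.0):
    """Object condensation loss terms for ONE event (illustrative sketch).

    beta   : (N,)   confidence in (0, 1) for each pixel
    coords : (N, 2) latent-space coordinates for each pixel
    obj_id : (N,)   integer object id per pixel, 0 = background/noise
    """
    n = len(beta)
    q = np.arctanh(beta) ** 2 + q_min            # "charge" of each pixel
    attractive = repulsive = coward = 0.0
    object_ids = [k for k in np.unique(obj_id) if k != 0]
    for k in object_ids:
        member = obj_id == k
        # Condensation point: the member pixel with the largest charge
        alpha = np.flatnonzero(member)[np.argmax(q[member])]
        dist = np.linalg.norm(coords - coords[alpha], axis=1)
        # Members are pulled toward alpha, all other pixels are pushed away
        attractive += np.sum(q[member] * dist[member] ** 2 * q[alpha]) / n
        repulsive += np.sum(q[~member] * np.maximum(0.0, 1.0 - dist[~member]) * q[alpha]) / n
        # Penalize objects whose best pixel still has low confidence
        coward += (1.0 - beta[alpha]) / len(object_ids)
    is_noise = obj_id == 0
    # Background pixels are encouraged to have small beta
    noise = s_b * beta[is_noise].mean() if is_noise.any() else 0.0
    return attractive, repulsive, coward, noise

# Four pixels: two in object 1, one in object 2, one background pixel
beta = np.array([0.9, 0.2, 0.8, 0.1])
coords = np.array([[0.0, 0.0], [0.1, 0.0], [3.0, 0.0], [1.5, 1.5]])
obj_id = np.array([1, 1, 2, 0])
print(oc_loss_sketch(beta, coords, obj_id))
```

The four returned values correspond in spirit to the attractive, repulsive, coward and noise metrics tracked during training later in the notebook.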


# Imports

In [None]:
# Download the image-generation source code from GitHub
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
from shape_gen import generate_dataset

# Download the object condensation loss function source code from GitHub
!wget https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
from nb03_loss_functions import CustomLoss, AttractiveLossMetric, RepulsiveLossMetric, CowardLossMetric, NoiseLossMetric, condensation_loss

--2025-03-11 20:58:40--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/shape_gen.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4600 (4.5K) [text/plain]
Saving to: ‘shape_gen.py’


2025-03-11 20:58:41 (30.5 MB/s) - ‘shape_gen.py’ saved [4600/4600]

--2025-03-11 20:58:58--  https://raw.githubusercontent.com/Gregtom3/vossen_ecal_ai/main/src/nb03_loss_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7218 (7.0K) [text/plain]
Saving to: ‘nb03_loss_functions.py’


2025-03-11 20:58:58 (56.9 MB/s) - ‘n

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display,  clear_output
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split

# Data Generation

Below we provide some of the parameters for creating an array of images. By default, we produce 1,000 images, each a 32-by-32 pixel grid containing between 5 and 8 shapes. A few other parameters, such as `same_color`, `same_shape` and `shape_overlap_max`, can also be tweaked.

In [None]:
num_images           = 1000
image_width          = 32
image_height         = image_width
min_shapes           = 5
max_shapes           = 8
shape_size_range     = (5,12)

dataset = generate_dataset(num_images=num_images,
                           image_size=(image_width,image_height),
                           min_shapes=min_shapes,
                           max_shapes=max_shapes,
                           shape_size_range=shape_size_range,
                           same_color=False, # True
                           same_shape=None, # ['circle','triangle','square']
                           shape_overlap_max=0.5)

dataset = np.array(dataset)

print(dataset.shape) # (1000, 32, 32, 7)

# --> 0,1,2 = RGB
# --> 3 = x
# --> 4 = y
# --> 5 = unique_shape_id (background == 0)
# --> 6 = shape type
#     --> 0 = noise/empty
#     --> 1 = circle
#     --> 2 = square
#     --> 3 = triangle

# Set RGB of white pixels (1,1,1) to black (0,0,0)
# Only pixels where all three channels equal 1 are changed
white = (dataset[..., 0:3] == 1).all(axis=-1)
dataset[white, 0:3] = 0

(1000, 32, 32, 7)


From `dataset.shape`, we see we are dealing with an array of shape (1000, 32, 32, 7). As indicated by the comments, the first 3 features of each pixel are its RGB values. The (x, y) position of the pixel is stored as the 4th and 5th features. **The most crucial feature** to understand is the 6th, the `unique_shape_id`.

Consider the first shape in the first image. All pixels that belong to that shape have a `unique_shape_id` of 1. Pixels of the second generated shape have a `unique_shape_id` of 2, and so on. An important distinction is that no two shapes, even across different "events" (images), share the same `unique_shape_id`. All background pixels have a `unique_shape_id` of 0.

Lastly, the final feature indicates what type of shape the pixel belongs to.
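To make the feature layout concrete, the sketch below builds a tiny stand-in array with the same 7-feature layout and counts the shapes in each image from feature index 5. The array `toy` and the helper `shapes_per_image` are illustrative names and are not part of `shape_gen.py`.

```python
import numpy as np

# Stand-in for `dataset`: 2 images of 4x4 pixels with 7 features per pixel
toy = np.zeros((2, 4, 4, 7))
toy[0, 0:2, 0:2, 5] = 1   # image 0, first shape
toy[0, 2:4, 2:4, 5] = 2   # image 0, second shape
toy[1, 1:3, 1:3, 5] = 3   # image 1: ids keep counting across images

def shapes_per_image(data):
    """Count distinct non-background ids (feature index 5) in every image."""
    counts = []
    for image in data:
        ids = np.unique(image[..., 5])
        counts.append(int(np.sum(ids != 0)))
    return counts

print(shapes_per_image(toy))  # [2, 1]
```

The same helper could be applied to the real `dataset` to check that each image holds between `min_shapes` and `max_shapes` shapes, assuming no shape ends up fully hidden by overlap.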


Let's plot a sample event.



In [None]:
def plot_toy(dataset, evtnum, PLOT_TYPE):
    # Check inputs
    assert PLOT_TYPE in ['RGB', 'X', 'Y', 'uid', 'type'], "PLOT_TYPE must be one of ['RGB', 'X', 'Y', 'uid', 'type']"
    assert evtnum < len(dataset), "evtnum must be less than the number of events in the dataset"

    # Copy and process the event data
    data_reshape = deepcopy(dataset[evtnum])
    if PLOT_TYPE == 'RGB':
        image_data = data_reshape[:, :, 0:3]
    elif PLOT_TYPE == 'X':
        image_data = data_reshape[:, :, 3]
    elif PLOT_TYPE == 'Y':
        image_data = data_reshape[:, :, 4]
    elif PLOT_TYPE == 'uid':
        image_data = data_reshape[:, :, 5]
    elif PLOT_TYPE == 'type':
        image_data = data_reshape[:, :, 6]

    # Create the plot
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    im = ax.imshow(image_data)  # Capture the image object
    # Add a colorbar if the plot type is not RGB
    if PLOT_TYPE != 'RGB':
        fig.colorbar(im, ax=ax)
    # Set the title based on the widget inputs
    ax.set_title(f'Event: {evtnum} | Plot Type: {PLOT_TYPE}')
    plt.tight_layout()
    plt.show()

# Update function for the widget
def update_plot(event_num, plot_type):
    plot_toy(dataset, event_num, plot_type)

# Create the interactive widgets
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset)-1,
    step=1,
    description='Event Number:'
)
plot_type_dropdown = widgets.Dropdown(
    options=['RGB', 'X', 'Y', 'uid', 'type'],
    value='RGB',
    description='Plot Type:'
)

# Link the widgets to the update function
interactive_plot = interactive(update_plot, event_num=event_slider, plot_type=plot_type_dropdown)
display(interactive_plot)


**Note:** the background pixels are **black**, so their RGB features are (0,0,0) rather than (1,1,1). This makes it easier for the model to fit and to determine what is and isn't background.

# Creating the CNN

We define two CNNs that share the same overall structure. The first, `SmallCNN`, has considerably fewer parameters than `LargeCNN`. The larger model is expected to be more accurate but takes longer to fit.


- The model is built from convolutional blocks: three in `LargeCNN` (Block 1, Block 2, and Block 3) and two in `SmallCNN` (Block 1 and Block 2).
  - Each block processes the input image (or a concatenation of the image with previous block outputs) with several convolutional layers, activations, batch normalization, and max pooling.
  - After each block, the output is upsampled back to the original image size and concatenated with the original input, providing skip connections that preserve spatial details.

- Following the blocks, the network applies two 1x1 convolution layers (acting as fully connected layers) to combine the learned features.

- The model then splits into two branches:
  - One branch produces a 'beta' output using additional convolutional layers with a sigmoid activation.
  - The other branch predicts coordinate information with a simple convolution layer.

- Finally, the outputs from these branches are concatenated and reshaped to yield per-pixel predictions.
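The pooling, upsampling and concatenation performed around each block can be traced with plain NumPy. This sketch only follows array shapes: it assumes nearest-neighbor upsampling via `repeat` in place of the bilinear `tf.image.resize` used in the models, and a random linear map in place of the learned convolutions. `fake_block` is a hypothetical helper.

```python
import numpy as np

def fake_block(x, out_channels, rng):
    """Mimic a block's shape changes: 'conv' to out_channels, then 2x2 max pool."""
    b, h, w, c = x.shape
    # A per-pixel linear map stands in for the stacked 3x3 'same' convolutions
    feat = x @ rng.normal(size=(c, out_channels))
    # 2x2 max pooling with stride 2 halves the height and the width
    feat = feat.reshape(b, h // 2, 2, w // 2, 2, out_channels).max(axis=(2, 4))
    return feat

rng = np.random.default_rng(0)
x = rng.random((1, 32, 32, 3))                            # one RGB image

pooled = fake_block(x, 32, rng)                           # (1, 16, 16, 32)
upsampled = pooled.repeat(2, axis=1).repeat(2, axis=2)    # (1, 32, 32, 32)
skip = np.concatenate([x, upsampled], axis=-1)            # (1, 32, 32, 35)

print(pooled.shape, upsampled.shape, skip.shape)
```

The concatenated array keeps the original RGB values in its first three channels, which is what lets the next block see full-resolution detail alongside the pooled features.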


### LargeCNN

In [None]:
class LargeCNN(tf.keras.Model):
    def __init__(self):
        # Zero-argument super() avoids depending on the global name `CNN`
        super().__init__()

        # Block 1: processes the input (shape: HxWx3)
        self.block1 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 2: takes a concatenation of the input and upsampled block1 output (channels: 3 + 64)
        self.block2 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 3: takes a concatenation of the input and upsampled block2 output (channels: 3 + 64)
        self.block3 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(64, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Fully connected layers using 1x1 convolutions
        self.fc1 = tf.keras.Sequential([
            layers.Conv2D(64, kernel_size=1, padding='same', activation=None),
            layers.ELU()
        ])
        self.fc2 = tf.keras.Sequential([
            layers.Conv2D(64, kernel_size=1, padding='same', activation=None)
        ])

        # p_beta branch: produces 1 channel with sigmoid activation
        self.p_beta = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid')
        ])

        # p_ccoords branch: produces 2 channels (for example, representing coordinates)
        self.p_ccoords = layers.Conv2D(2, kernel_size=1, padding='same')


    def call(self, x):
        # x is assumed to have shape (batch, height, width, 3)

        # Block 1
        block1_out = self.block1(x)
        block1_out_up = tf.image.resize(block1_out, size=tf.shape(x)[1:3], method='bilinear')
        concat1 = tf.concat([x, block1_out_up], axis=-1)  # along channels

        # Block 2
        block2_out = self.block2(concat1)
        block2_out_up = tf.image.resize(block2_out, size=tf.shape(x)[1:3], method='bilinear')
        concat2 = tf.concat([x, block2_out_up], axis=-1)

        # Block 3
        block3_out = self.block3(concat2)
        block3_out_up = tf.image.resize(block3_out, size=tf.shape(x)[1:3], method='bilinear')
        concat3 = tf.concat([x, block3_out_up], axis=-1)

        # Fully connected layers (implemented as 1x1 convolutions)
        out = self.fc1(concat3)
        out = self.fc2(out)

        # Compute branches
        beta = self.p_beta(out) * 0.999 + 1e-9
        ccoords = self.p_ccoords(out)

        # Concatenate predictions along the channel dimension
        predictions = tf.concat([beta, ccoords], axis=-1)  # resulting shape: (batch, H, W, 3)
        predictions = tf.reshape(predictions, [-1, image_width*image_height, 3])         # reshape to (batch, H * W, 3)

        return predictions

### SmallCNN

In [None]:
class SmallCNN(tf.keras.Model):
    def __init__(self):
        # Zero-argument super() avoids depending on the global name `CNN`
        super().__init__()

        # Block 1: processes the input (shape: HxWx3)
        self.block1 = tf.keras.Sequential([
            layers.Conv2D(8, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])

        # Block 2: takes a concatenation of the input and upsampled block1 output (channels: 3 + 32)
        self.block2 = tf.keras.Sequential([
            layers.Conv2D(8, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.BatchNormalization(momentum=0.6, epsilon=1e-5),
            layers.MaxPooling2D(pool_size=2, strides=2, padding='valid')
        ])


        # Fully connected layers using 1x1 convolutions
        self.fc1 = tf.keras.Sequential([
            layers.Conv2D(32, kernel_size=1, padding='same', activation=None),
            layers.ELU()
        ])
        self.fc2 = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=1, padding='same', activation=None)
        ])

        # p_beta branch: produces 1 channel with sigmoid activation
        self.p_beta = tf.keras.Sequential([
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(16, kernel_size=3, padding='same', activation=None),
            layers.ELU(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid')
        ])

        # p_ccoords branch: produces 2 channels (for example, representing coordinates)
        self.p_ccoords = layers.Conv2D(2, kernel_size=1, padding='same')


    def call(self, x):
        # x is assumed to have shape (batch, height, width, 3)

        # Block 1
        block1_out = self.block1(x)
        block1_out_up = tf.image.resize(block1_out, size=tf.shape(x)[1:3], method='bilinear')
        concat1 = tf.concat([x, block1_out_up], axis=-1)  # along channels

        # Block 2
        block2_out = self.block2(concat1)
        block2_out_up = tf.image.resize(block2_out, size=tf.shape(x)[1:3], method='bilinear')
        concat2 = tf.concat([x, block2_out_up], axis=-1)

        # Fully connected layers (implemented as 1x1 convolutions)
        out = self.fc1(concat2)
        out = self.fc2(out)

        # Compute branches
        beta = self.p_beta(out) * 0.999 + 1e-9
        ccoords = self.p_ccoords(out)

        # Concatenate predictions along the channel dimension
        predictions = tf.concat([beta, ccoords], axis=-1)  # resulting shape: (batch, H, W, 3)
        predictions = tf.reshape(predictions, [-1, image_width*image_height, 3])         # reshape to (batch, H * W, 3)

        return predictions

# Initializing the Model

In [None]:
#!!! Define which CNN to use !!!
CNN = SmallCNN # or LargeCNN

# Define hyperparameters
# - q_min: Hyperparameter defined in object condensation
#      q_i = "Charge" of point i = arctanh(beta_i)^2 + q_min

epochs = 10
batch_size = 32
learning_rate = 0.001
q_min = 0.1

# Load in the data
X = dataset[...,0:3] # RGB of each pixel
y = dataset[...,5] # unique_shape_id of each pixel

# Reshape 'y' to be [N,H*W,1]
y = y.reshape(y.shape[0], y.shape[1]*y.shape[2], 1)

# Perform train-test splitting
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Load in the CNN model
model = CNN()

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss=CustomLoss(q_min=q_min), # from GitHub
    metrics=[
        AttractiveLossMetric(name="attractive_loss"),
        RepulsiveLossMetric(name="repulsive_loss"),
        CowardLossMetric(name="coward_loss"),
        NoiseLossMetric(name="noise_loss")
    ]
)


# Pass one event through the model initially
# This is done to print out the model summary with the proper shapes
model(X_train[0:1])
model.summary()
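
Both the labels `y` and the network output are flattened to `H*W` entries per image, so the loss relies on the two sharing the same pixel ordering. NumPy's default `reshape` and `tf.reshape` both use row-major order. The check below illustrates that ordering with made-up sizes; the names `ids`, `flat` and `k` are only for illustration.

```python
import numpy as np

H, W = 4, 5                                   # small stand-in image size
ids = np.arange(H * W).reshape(1, H, W)       # fake per-pixel ids, one image

# Same reshape as applied to `y`: [N, H, W] -> [N, H*W, 1]
flat = ids.reshape(ids.shape[0], ids.shape[1] * ids.shape[2], 1)

row, col = 2, 3
k = row * W + col                             # flat index of pixel (row, col)
print(flat.shape, bool(flat[0, k, 0] == ids[0, row, col]))
```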

# Fitting

In [None]:
model.fit(
    X_train,
    y_train,
    batch_size=batch_size,
    #validation_data=(X_test, y_test),
    epochs=epochs,
    verbose=1
)

Epoch 1/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 25s 991ms/step - attractive_loss: 0.0302 - coward_loss: 0.1594 - loss: 0.2555 - noise_loss: 0.0077 - repulsive_loss: 0.0583
Epoch 2/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 40s 964ms/step - attractive_loss: 0.0250 - coward_loss: 0.1463 - loss: 0.2315 - noise_loss: 0.0073 - repulsive_loss: 0.0529
Epoch 3/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 25s 1s/step - attractive_loss: 0.0213 - coward_loss: 0.1341 - loss: 0.2115 - noise_loss: 0.0065 - repulsive_loss: 0.0497
Epoch 4/10
25/25 ━━━━━━━━━━━━━━━━━━━━ 26s 1s/step - attractive_loss: 0.0188 - coward_loss: 0.1172 - loss: 0.1926 - noise_loss: 0.0067 - repulsive_loss: 0.0500
Epoch 5/10


KeyboardInterrupt: 

# Evaluation

In [None]:
y_pred = model.predict(X)

32/32 ━━━━━━━━━━━━━━━━━━━━ 3s 79ms/step


In [None]:
def update_plots(event_num, plot_type):
    # Create a figure with 2 subplots side by side
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # ----- Left Plot: Toy Image -----
    # Copy and process the event data
    data_reshape = deepcopy(dataset[event_num])
    if plot_type == 'RGB':
        image_data = data_reshape[:, :, 0:3]
    elif plot_type == 'X':
        image_data = data_reshape[:, :, 3]
    elif plot_type == 'Y':
        image_data = data_reshape[:, :, 4]
    elif plot_type == 'uid':
        image_data = data_reshape[:, :, 5]
    elif plot_type == 'type':
        image_data = data_reshape[:, :, 6]

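    # Make black background pixels white (only applies to 3-channel RGB images;
    # the other plot types are 2D arrays and have no axis 2)
    if plot_type == 'RGB':
        image_data[(image_data == [0, 0, 0]).all(axis=2)] = [1, 1, 1]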

    im = axs[0].imshow(image_data)
    if plot_type != 'RGB':
        fig.colorbar(im, ax=axs[0])
    axs[0].set_title(f'Event: {event_num} | Plot Type: {plot_type}')

    # ----- Right Plot: Scatter Plot -----
    # Get important arrays for this event
    # Copy so that recoloring the background below does not modify X (a view of dataset)
    colors = X[event_num][..., 0:3].reshape(-1, 3).copy()
    # Replace background ([0, 0, 0]) with white ([1, 1, 1])
    colors[(colors == [0, 0, 0]).all(axis=1)] = [1, 1, 1]

    beta = y_pred[event_num][..., 0]
    xc = y_pred[event_num][..., 1]
    yc = y_pred[event_num][..., 2]
    id_arr = y[event_num][..., 0]

    unique_id = np.unique(id_arr)

    for uid in unique_id:
        # Find indices for the current unique id
        indices = (id_arr == uid)
        # Define marker alpha for these points, ensuring a minimum value
        marker_alpha = [max(b, 0.0005) for b in beta[indices]]
        marker_size = 40

        # Assume all points for the same id share the same color: use the first one
        current_color = colors[indices][0]
        marker_edge = "black" if np.all(current_color == [1, 1, 1]) else "none"

        axs[1].scatter(xc[indices], yc[indices],
                       c=colors[indices],
                       alpha=marker_alpha,
                       s=marker_size,
                       edgecolor=marker_edge)

    axs[1].set_title(f'Event: {event_num} Scatter Plot')
    axs[1].set_xlabel("X")
    axs[1].set_ylabel("Y")
    axs[1].axis('equal')

    plt.tight_layout()
    plt.show()

# Create interactive widgets
event_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset) - 1,
    step=1,
    description='Event Number:'
)
plot_type_dropdown = widgets.Dropdown(
    options=['RGB', 'X', 'Y', 'uid', 'type'],
    value='RGB',
    description='Plot Type:'
)

# Link widgets to the update function
interactive_plot = interactive(update_plots, event_num=event_slider, plot_type=plot_type_dropdown)
display(interactive_plot)


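
Beyond inspecting the scatter plot by eye, the predictions can be turned into discrete clusters. Each row of `y_pred[event_num]` holds `(beta, x_c, y_c)` for one pixel. One option is the greedy procedure described in the object condensation paper: visit pixels in order of decreasing beta, promote an unassigned pixel to a condensation point if its beta exceeds a threshold, and assign all unassigned pixels within a latent-space radius to it. The sketch below assumes this procedure; the function name and the thresholds `t_beta` and `t_dist` are illustrative and would need tuning.

```python
import numpy as np

def greedy_clusters(pred, t_beta=0.5, t_dist=0.5):
    """Assign a cluster label to every pixel of one event.

    pred : (N, 3) array with columns (beta, x_c, y_c)
    Returns an (N,) int array; -1 marks pixels left unassigned (noise).
    """
    beta, coords = pred[:, 0], pred[:, 1:3]
    labels = np.full(len(pred), -1)
    next_label = 0
    for i in np.argsort(-beta):               # highest confidence first
        if beta[i] < t_beta:
            break                             # all remaining pixels are below threshold
        if labels[i] != -1:
            continue                          # already claimed by an earlier cluster
        dist = np.linalg.norm(coords - coords[i], axis=1)
        claim = (dist < t_dist) & (labels == -1)
        labels[claim] = next_label
        next_label += 1
    return labels

# Two tight groups plus one low-confidence outlier
pred = np.array([
    [0.9, 0.0, 0.0],
    [0.3, 0.1, 0.0],
    [0.8, 3.0, 3.0],
    [0.2, 3.1, 2.9],
    [0.1, 9.0, 9.0],
])
print(greedy_clusters(pred).tolist())  # [0, 0, 1, 1, -1]
```

Applied to this notebook, `greedy_clusters(y_pred[event_num])` would return one label per pixel, which could be reshaped to the image size and compared with the `uid` plot.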