# Multimodality In Robotics
## Predicting a continuous action when multiple correct actions exist

This notebook explores different approaches to handling multimodality in robotics.
It accompanies the video below; please watch the video first to understand the context.

In [None]:
from IPython.display import YouTubeVideo
# Embed the companion video (YouTube id "6oZe_tKE3YA") as a 640x360 player in the cell output.
YouTubeVideo("6oZe_tKE3YA", width=640, height=360)

Github repo: https://github.com/IliaLarchenko/robotics_multimodality_explained

Kaggle notebook: https://www.kaggle.com/code/ilialar/multimodality-explained

## Imports and Helper functions

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from scipy.stats import lognorm

import ipywidgets as widgets
from IPython.display import display, clear_output

# Use seaborn's "whitegrid" theme for all figures created after this point.
sns.set_theme(style="whitegrid")

# --- Data Generation ---
# Number of samples drawn for the training set and for the test/plotting set.
NUM_EXPERT_SAMPLES = 10000
N_TEST_SAMPLES = 1000

# Default parameters consumed by generate_expert_data.
DEFAULT_DATA_PARAMS = {
    's_min': -1.2,              # lower bound of the uniformly sampled state s
    's_max': 1.2,               # upper bound of the uniformly sampled state s
    'safe_boundary': 1.0,       # actions are generated at |x| = safe_boundary + margin
    'target_mean_delta': 0.2,   # mean of the log-normal margin added beyond the boundary
    'lognorm_sigma': 1.0        # sigma (shape) of that log-normal margin
}
# Module-level shortcut used by the plotting helper and the success-rate checks.
SAFE_BOUNDARY = DEFAULT_DATA_PARAMS['safe_boundary']

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model training utils

HIDDEN_DIMS = (32, 32)   # default hidden-layer widths of MLP
OUTPUT_DIM = 1           # default MLP output size
LEARNING_RATE = 1e-3     # Adam learning rate used by the training cells
EPOCHS = 100             # number of passes over the training set per model
BATCH_SIZE = 125         # DataLoader batch size

In [61]:
import warnings
# Suppress every FutureWarning for the rest of the session to keep cell output readable.
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
class MLP(nn.Module):
    """Fully connected network: (Linear -> ReLU -> Dropout(0.1)) per hidden width, then a Linear head.

    Args:
        input_dim: Number of input features.
        output_dim: Number of features produced by the final linear layer.
        hidden_dims: Width of each hidden stage, in order. An empty sequence
            yields a single linear layer from ``input_dim`` to ``output_dim``.
    """
    def __init__(self, input_dim, output_dim=OUTPUT_DIM, hidden_dims=HIDDEN_DIMS):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # Consecutive width pairs give the (in, out) size of each hidden stage.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.1)]
        # The head maps the last width (input_dim when there are no hidden stages) to output_dim.
        stages.append(nn.Linear(widths[-1], output_dim))
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the result."""
        return self.network(x)

def train_one_epoch(model, train_loader, optimizer, loss_fn):
    """Run one optimisation pass over ``train_loader`` and return the model.

    Args:
        model: Module to train; it is switched to training mode before the first batch.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor that supports ``backward()``.

    Returns:
        The same ``model`` instance, updated in place.
    """
    model.train()
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_X), batch_y)
        loss.backward()
        optimizer.step()
    # No running loss is accumulated: it was never returned or displayed, it
    # forced a host sync (loss.item()) on every batch, and its final division
    # by the dataset length failed on an empty dataset.
    return model

In [13]:
def _plot_conditional_density(ax, op, test_x, pred_x, title, model_label, x_lims, alpha, bw_adjust):
    """Draw action densities restricted to states with ``s <op> 0`` on ``ax``.

    ``test_x`` is the (possibly empty) subset of data actions; ``pred_x`` is the
    matching subset of predicted actions, or None when no predictions are plotted.
    """
    if len(test_x) > 0:
        sns.kdeplot(test_x, ax=ax, color='blue', label=f'Data (s{op}0)',
                    fill=True, alpha=alpha, bw_adjust=bw_adjust)
    if pred_x is not None:
        valid_pred = pred_x[~np.isnan(pred_x)]
        if len(valid_pred) > 0:
            sns.kdeplot(valid_pred, ax=ax, color='orange', label=f'{model_label} (s{op}0)',
                        fill=True, alpha=alpha + 0.1, bw_adjust=bw_adjust)
        else:
            # Only add text if there were supposed to be predictions
            ax.text(0.5, 0.5, f'No valid preds for s{op}0', ha='center', va='center', transform=ax.transAxes)
    ax.set_title(f'{title} - Output Dist. (s {op} 0)')
    ax.set_xlabel('Action (x)')
    ax.set_ylabel('Density')
    ax.grid(True, linestyle=':')
    ax.legend()
    ax.set_xlim(*x_lims)


def plot_predictions(s_pred_input=None, x_pred_output=None, test_data=None, title="Data/Model Predictions", model_label="Predicted"):
    """
    Plots:
    1. Scatter plot of predictions (s vs x) overlaid with test data (if provided).
       If predictions are None, only plots test data.
    2. Overall density distribution comparison (predicted x vs test x).
       If predictions are None, only plots test data distribution.
    3. Conditional density distribution for s < 0 (if test_data provided).
    4. Conditional density distribution for s > 0 (if test_data provided).

    In panels 3 and 4 the data is split by ``test_data['s']`` and the
    predictions are split by ``s_pred_input``, so each prediction is grouped by
    the state it was actually computed for.

    Args:
        s_pred_input: The input s values used for prediction (1D array) or None.
        x_pred_output: The corresponding predicted x values (1D array) or None.
        test_data: DataFrame with 's' and 'x' columns for ground truth/data points.
        title: Base title for the plots.
        model_label: Label for the predicted points/distribution.

    Raises:
        ValueError: If both prediction arrays are given but differ in shape.
    """

    # --- Fixed plotting params---
    PLOT_FIG_SIZE_WIDE = (12, 5)
    PLOT_ALPHA = 0.3
    PLOT_POINT_SIZE = 10
    bw_adjustment = 0.1

    # Predictions are only plotted when both the inputs and the outputs are given
    plot_preds = s_pred_input is not None and x_pred_output is not None
    if test_data is None and not plot_preds:
        print("Warning: No data provided to plot_predictions.")
        return

    if plot_preds:
        s_pred_input = np.asarray(s_pred_input).flatten()
        x_pred_output = np.asarray(x_pred_output).flatten()
        if s_pred_input.shape != x_pred_output.shape:
            raise ValueError("s_pred_input and x_pred_output must have the same shape.")
    else: # Adjust title if only plotting data
        title = title.replace("Model Predictions", "Data")
        model_label = "" # No model label needed

    fig, axes = plt.subplots(2, 2, figsize=(PLOT_FIG_SIZE_WIDE[0] * 1.5, PLOT_FIG_SIZE_WIDE[1] * 1.8))
    ax1, ax2, ax3, ax4 = axes.flatten()

    # --- Determine overall y-limits from all available valid data ---
    all_y_data = []
    if test_data is not None:
        all_y_data.append(test_data['x'])
    if plot_preds:
        all_y_data.append(x_pred_output)

    ymin, ymax = (-SAFE_BOUNDARY * 1.2, SAFE_BOUNDARY * 1.2) # Default limits
    if all_y_data:
        valid_y = np.concatenate(all_y_data)
        valid_y = valid_y[~np.isnan(valid_y)]
        if len(valid_y) > 0:
            ymin = min(np.min(valid_y), -SAFE_BOUNDARY) - 0.2
            ymax = max(np.max(valid_y), SAFE_BOUNDARY) + 0.2

    # --- Subplot 1: Scatter Plot (s vs x) ---
    if test_data is not None:
        ax1.scatter(test_data['s'], test_data['x'], alpha=PLOT_ALPHA,
                    s=PLOT_POINT_SIZE, label='Data Points', color='blue')
    if plot_preds:
        ax1.scatter(s_pred_input, x_pred_output, alpha=PLOT_ALPHA + 0.1, s=PLOT_POINT_SIZE,
                    label=model_label, color='orange')
    ax1.axhline(-SAFE_BOUNDARY, color='red', linestyle='--', label=f'Boundary (+/- {SAFE_BOUNDARY:.1f})')
    ax1.axhline(SAFE_BOUNDARY, color='red', linestyle='--')
    ax1.set_title(f'{title} - Scatter Plot')
    ax1.set_xlabel('State(s)')
    ax1.set_ylabel('Action (x)')
    ax1.grid(True)
    ax1.legend()
    ax1.set_ylim(ymin, ymax)

    # --- Subplot 2: Overall Density Plot (Distribution of x) ---
    if test_data is not None:
        sns.kdeplot(test_data['x'], ax=ax2, color='blue', label='Data Distribution',
                    fill=True, alpha=PLOT_ALPHA, bw_adjust=bw_adjustment)
    if plot_preds:
        valid_x_pred_output = x_pred_output[~np.isnan(x_pred_output)]
        if len(valid_x_pred_output) > 0:
            sns.kdeplot(valid_x_pred_output, ax=ax2, color='orange', label=model_label,
                        fill=True, alpha=PLOT_ALPHA + 0.1, bw_adjust=bw_adjustment)
        else:
            ax2.text(0.5, 0.5, 'No valid preds', ha='center', va='center', transform=ax2.transAxes)
    ax2.set_title(f'{title} - Overall Output Dist.')
    ax2.set_xlabel('Action (x)')
    ax2.set_ylabel('Density')
    ax2.grid(True, linestyle=':')
    ax2.legend()
    ax2.set_xlim(ymin, ymax) # Use consistent limits

    # --- Subplots 3 and 4: Conditional Density Plots (s < 0 and s > 0) ---
    # Conditional plots only make sense if we have test_data with s
    if test_data is not None:
        test_s = test_data['s'].values
        panels = (
            (ax3, '<', test_s < 0, (s_pred_input < 0) if plot_preds else None),
            (ax4, '>', test_s > 0, (s_pred_input > 0) if plot_preds else None),
        )
        for ax, op, test_mask, pred_mask in panels:
            _plot_conditional_density(
                ax, op,
                test_x=test_data.loc[test_mask, 'x'],
                # Select predictions by their own input states, not by test_data's rows
                pred_x=x_pred_output[pred_mask] if plot_preds else None,
                title=title, model_label=model_label, x_lims=(ymin, ymax),
                alpha=PLOT_ALPHA, bw_adjust=bw_adjustment,
            )
    else: # Hide axes 3 and 4 if no test_data to condition on
        ax3.axis('off')
        ax4.axis('off')

    # --- Final Touches ---
    fig.tight_layout()
    plt.show()

In [4]:
# Illustration of the diffusion process

from PIL import Image

def add_diffusion_noise(x0, b_t):
    """Blend ``x0`` with fresh Gaussian noise and clamp the result to [0, 1].

    Computes ``sqrt(1 - b_t) * x0 + sqrt(b_t) * eps``, where ``eps`` is drawn
    from NumPy's global standard-normal generator with the shape of ``x0``.
    """
    eps = np.random.randn(*x0.shape)
    signal_scale = np.sqrt(1- b_t)
    noise_scale = np.sqrt(b_t)
    blended = signal_scale * x0 + noise_scale * eps
    return np.clip(blended, 0, 1)

def get_noisy_steps(img_path, T):
    """Load an image and build a sequence of progressively noisier copies.

    Args:
        img_path: Path of the image to open with PIL.
        T: Number of noising steps to apply.

    Returns:
        List of ``T + 1`` arrays: the RGB image scaled to [0, 1], followed by
        one array per step, each produced by applying ``add_diffusion_noise``
        to the previous array.
    """
    # Open inside a context manager so the file handle is released promptly;
    # convert() returns a new, fully loaded image that outlives the handle.
    with Image.open(img_path) as source_img:
        original_img = source_img.convert("RGB")
    original_np = np.array(original_img) / 255.0  # Normalize

    b_t = np.linspace(0.1, 0.8, T)  # noise schedule
    noisy_steps = [original_np]
    for b in b_t:
        # Noise is added cumulatively: each step starts from the previous noisy image
        noisy_steps.append(add_diffusion_noise(noisy_steps[-1], b))

    return noisy_steps

def plot_noisy_steps(noisy_steps, reverse=False, title="Denoising diffusion model training"):
    """Show a row of images from a noising sequence with arrows between them.

    Args:
        noisy_steps: Sequence of images, ordered from the original image to the
            noisiest one (as returned by ``get_noisy_steps``).
        reverse: If True, display the sequence from noisiest to original.
        title: Caption drawn centred under the row.

    Raises:
        ValueError: If ``noisy_steps`` is empty.
    """
    n_images = len(noisy_steps)
    if n_images == 0:
        raise ValueError("noisy_steps must contain at least one image.")

    # squeeze=False keeps `axes` an array even when there is a single image
    fig, axes = plt.subplots(1, n_images, figsize=(18, 5), squeeze=False)
    axes = axes.flatten()
    titles = ["Original image"] + [""] * (n_images - 2) + ["Pure noise"]
    if reverse:
        titles = titles[::-1]
        noisy_steps = list(noisy_steps[::-1])

    for ax, img, panel_title in zip(axes, noisy_steps, titles):
        ax.imshow(img)
        ax.set_title(panel_title, fontsize=13)
        ax.axis("off")

    # One arrow per gap between neighbouring panels, spaced evenly across the
    # figure width (this equals the former fixed 0.20 spacing for five panels).
    arrow = "→" if reverse else "←"
    for i in range(1, n_images):
        fig.text(i / n_images, 0.07, arrow, fontsize=20, ha='center')

    # Add text under the arrows
    fig.text(0.5, 0.01, title, fontsize=13, ha='center')

    plt.tight_layout()
    plt.show()

In [5]:
def plot_evolution_distributions(intermediate_steps):
    """Plot one density histogram per step of a generative process.

    Args:
        intermediate_steps: Sequence of 1D arrays, one per step, in the order
            they should be shown. Panels are laid out in rows of at most five
            and share an x-range derived from the 1st-99th percentile of all
            values combined.
    """
    num_steps_to_plot = len(intermediate_steps)
    if num_steps_to_plot == 0:
        print("Warning: No steps provided to plot_evolution_distributions.")
        return

    cols = min(5, num_steps_to_plot) # Adjust columns for potentially more steps
    rows = (num_steps_to_plot + cols - 1) // cols
    # squeeze=False keeps `axes` a 2D array even for a single panel
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3.5, rows * 3),
                             sharex=True, sharey=False, squeeze=False)
    axes = axes.flatten()

    all_xt_values = np.concatenate(intermediate_steps)

    # Percentile-based limits keep a few extreme samples from flattening every panel
    x_min, x_max = np.percentile(all_xt_values, [1, 99])
    x_padding = max((x_max - x_min) * 0.1, 0.1)
    x_lims = (x_min - x_padding, x_max + x_padding)

    # Plot steps in the order given
    for step_label, (ax, step_data) in enumerate(zip(axes, intermediate_steps)):
        sns.histplot(step_data, ax=ax, bins=50, stat='density')
        ax.set_title(f'Step {step_label}/{num_steps_to_plot - 1}')
        ax.set_xlabel('x value')
        ax.set_xlim(x_lims)
        ax.grid(True, linestyle=':')

    # Hide the unused panels of the last row
    for ax in axes[num_steps_to_plot:]:
        ax.axis('off')

    fig.suptitle('Step by step change in the data distribution')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

## Multimodality

A simple example of multimodal distributions.

In [None]:
from scipy.stats import norm, t, lognorm, laplace, triang, beta, gamma, logistic, weibull_min

# Common x-axis
x = np.linspace(-5, 5, 1000)

def plot_dist(ax, y, title):
    """Draw the curve ``y`` against the module-level grid ``x`` on ``ax`` with a bold title."""
    ax.plot(x, y, color='orange', lw=5)
    ax.grid(True)
    ax.set_title(title, fontweight='bold')
    ax.tick_params(labelsize=10)

def plot_all(dists, title):
    """Plot up to nine ``(name, y)`` curves on a 3x3 grid under a shared heading.

    Entries beyond the ninth are ignored and unused panels stay empty, because
    the curves are paired with the grid cells one by one.
    """
    fig, axs = plt.subplots(3, 3, figsize=(15, 10))
    for panel, entry in zip(axs.flat, dists):
        name, y = entry
        plot_dist(panel, y, name)
    fig.suptitle(title, fontsize=16, fontweight='bold')
    # Leave the top 4% of the figure free for the suptitle
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

# === Unimodal distributions ===
# Each entry is (panel title, density evaluated on the shared grid x).
unimodal = [
    ("Normal (mean=0, std=1)", norm.pdf(x, 0, 1)),
    ("Student's t (df=3)", t.pdf(x, df=3)),
    ("Log-Normal", lognorm.pdf(x, s=0.5, scale=np.exp(0))),
    ("Laplace", laplace.pdf(x, 0, 1)),
    ("Triangular", triang.pdf(x, c=0.5, loc=-3, scale=6)),
    # (x + 3)/6 maps [-3, 3] onto the Beta support [0, 1]; 1/6 rescales the density accordingly
    ("Beta (a=2, b=5) scaled", beta.pdf((x + 3)/6, a=2, b=5) * (1/6)),
    # Evaluated at x + 5 so the distribution's support starts at the left edge of the grid (x = -5)
    ("Gamma (a=2)", gamma.pdf(x + 5, a=2)),
    ("Logistic", logistic.pdf(x)),
    # Same x + 5 shift as the Gamma panel
    ("Weibull (c=2)", weibull_min.pdf(x + 5, c=2))
]

plot_all(unimodal, "Unimodal Distributions")


In [None]:
# === Multimodal distributions ===
# Each entry is (panel title, density on the shared grid x). Mixtures are
# weighted sums of component densities whose weights add up to 1.
multimodal = [
    ("Mixture of 2 Gaussians (equal)", 0.5 * norm.pdf(x, -2, 0.5) + 0.5 * norm.pdf(x, 2, 0.5)),
    ("Mixture of 2 Gaussians (unequal)", 0.3 * norm.pdf(x, -2, 0.5) + 0.7 * norm.pdf(x, 2, 1)),
    ("Mixture of 3 Gaussians", 0.2 * norm.pdf(x, -3, 0.7) + 0.5 * norm.pdf(x, 1, 0.5) + 0.3 * norm.pdf(x, 2, 0.3)),
    ("Mixture of 4 Gaussians", sum([0.25 * norm.pdf(x, mu, 0.3) for mu in [-3, -1, 1, 3]])),
    # (x + 3)/6 maps [-3, 3] onto the Beta support [0, 1]; 1/6 rescales the density accordingly
    ("W-shaped Beta (2,5)+(5,2)", 0.5 * beta.pdf((x + 3)/6, 2, 5) * (1/6) + 0.5 * beta.pdf((x + 3)/6, 5, 2) * (1/6)),
    ("Mixture with far peaks", 0.5 * norm.pdf(x, -4, 0.5) + 0.5 * norm.pdf(x, 4, 0.5)),
    ("Trimodal mixed shapes", 0.4 * norm.pdf(x, -3, 0.6) + 0.4 * laplace.pdf(x, 0, 0.5) + 0.2 * norm.pdf(x, 3, 1)),
    ("Beta(0.5, 0.5) scaled", beta.pdf((x + 3)/6, 0.5, 0.5) * (1/6)),
    ("Mixture of 5 peaks", sum([0.2 * norm.pdf(x, mu, 0.25) for mu in [-4, -2, 0, 2, 4]])),
]

plot_all(multimodal, "Multimodal Distributions")

## Problem visualization

### Interactive Visualization

Use the sliders and checkboxes to visualize the robot (`s`) and the chosen action (`x`).  
- Adjust the initial position (`s`) of the robot.
- Toggle visibility of the obstacle, robot, and action squares.
- Observe collision detection status


In [None]:
# Creates an interactive visualization of the problem

# --- Configuration ---
SQUARE_SIZE = 1.0    # side length shared by the robot, obstacle and action squares
OBSTACLE_POS = 0.0   # horizontal centre of the obstacle

# Y positions for elements
# The obstacle and the action share the top row; the robot sits on the bottom row.
Y_OBSTACLE = 1
Y_ACTION = 1
Y_ROBOT = -1
# Fixed axis limits of the demo plot
PLOT_XLIM = (-2, 2)
PLOT_YLIM = (-2, 2)

# --- Plotting Functions ---

def draw_square(ax, center_x, center_y, color=None, facecolor=None, fill=True, edgecolor='black', label=None):
    """Add a SQUARE_SIZE-sided square patch centred on (center_x, center_y) to ``ax``.

    ``facecolor`` takes precedence over ``color`` when it is truthy. Unfilled
    squares get a thicker outline (2) than filled ones (1).
    """
    lower_left = (center_x - SQUARE_SIZE / 2, center_y - SQUARE_SIZE / 2)
    face = facecolor if facecolor else color
    outline_width = 1 if fill else 2
    square = plt.Rectangle(
        lower_left,
        SQUARE_SIZE,
        SQUARE_SIZE,
        facecolor=face,
        fill=fill,
        edgecolor=edgecolor,
        linewidth=outline_width,
        label=label
    )
    ax.add_patch(square)

def check_collision(action_x):
    """Return True when the action square strictly overlaps the obstacle horizontally.

    Both squares are SQUARE_SIZE wide; squares that merely touch at an edge do
    not count as colliding.
    """
    half = SQUARE_SIZE / 2
    # The overlap of two intervals runs from the larger left edge to the smaller right edge
    overlap_start = max(OBSTACLE_POS - half, action_x - half)
    overlap_end = min(OBSTACLE_POS + half, action_x + half)
    return overlap_start < overlap_end

# --- Interactive Setup ---

# Widgets
# Sliders set the robot position and the chosen action; checkboxes toggle which
# elements are drawn; the label shows the collision status text.
slider_robot_s = widgets.FloatSlider(value=0.0, min=-1.5, max=1.5, step=0.1, description='Robot s:')
# NOTE(review): the initial value 1.05 is not a multiple of the 0.1 step — confirm this is intended.
slider_action_x = widgets.FloatSlider(value=1.05, min=-1.5, max=1.5, step=0.1, description='Action x:')
checkbox_show_robot = widgets.Checkbox(value=True, description='Show Robot (s)')
checkbox_show_obstacle = widgets.Checkbox(value=False, description='Show Obstacle')
checkbox_show_action = widgets.Checkbox(value=False, description='Show Action (x)')
status_label = widgets.Label(value="")

# Output widget to hold the plot
output_plot = widgets.Output()

def update_plot(*args):
    """Clears and redraws the plot based on widget values.

    Registered as the observer callback of every control; the change
    notification passed in ``*args`` is ignored because the current values are
    read directly from the widgets. Also updates ``status_label``.
    """
    with output_plot:
        # wait=True delays clearing until the new figure is ready to be shown
        clear_output(wait=True)

        fig, ax = plt.subplots(figsize=(6, 6))

        s = slider_robot_s.value
        x_action = slider_action_x.value
        show_robot = checkbox_show_robot.value
        show_obstacle = checkbox_show_obstacle.value
        show_action = checkbox_show_action.value

        # Draw elements based on visibility flags and new Y positions
        if show_obstacle:
            draw_square(ax, OBSTACLE_POS, Y_OBSTACLE, color='red', fill=True, label='Obstacle')
        if show_robot:
            draw_square(ax, s, Y_ROBOT, color='green', fill=True, label='Robot (s)')

            # Calculate arrow start point (top-middle of robot)
            arrow_start_x = s
            arrow_start_y = Y_ROBOT + SQUARE_SIZE / 2

            # Determine arrow endpoint and calculate dx, dy
            if show_action:
                # Point to bottom-middle of action square
                arrow_end_x = x_action
                arrow_end_y = Y_ACTION - SQUARE_SIZE / 2
                arrow_dx = arrow_end_x - arrow_start_x
                arrow_dy = arrow_end_y - arrow_start_y
            else:
                # Point straight up with length 1
                arrow_dx = 0
                arrow_dy = 1.0 # Fixed length when action is hidden

            # Draw the arrow
            ax.arrow(arrow_start_x, arrow_start_y, arrow_dx, arrow_dy,
                     head_width=0.1, head_length=0.15, fc='green', ec='green', length_includes_head=True)

        if show_action:
            # Use facecolor='none' for outline, specify edgecolor explicitly
            draw_square(ax, x_action, Y_ACTION, facecolor='none', fill=False, edgecolor='green', label='Action (x)')

            # Check collision and update status ONLY if action is shown
            is_collision = check_collision(x_action)
            status = "Fail!" if is_collision else "Success!"
            status_label.value = f"Status: {status}"
            ax.text(0, PLOT_YLIM[1] * 0.9, status, ha='center', va='top', fontsize=12, color='red' if is_collision else 'darkgreen')
        else:
             # No action shown, so there is no status to report
             status_label.value = ""


        # Plot Formatting
        ax.set_xlim(PLOT_XLIM)
        ax.set_ylim(PLOT_YLIM)
        ax.set_yticks([])
        ax.set_xlabel("Horizontal Position")
        ax.set_title("Interactive Environment Demo")
        # Add horizontal lines for reference
        ax.axhline(Y_OBSTACLE, color='grey', linestyle='--', linewidth=0.5)
        ax.axhline(Y_ROBOT, color='grey', linestyle='--', linewidth=0.5)
        ax.grid(True, axis='x', linestyle=':', linewidth=0.5)
        plt.show()


# Observe changes in widgets and call update_plot
# names='value' restricts the callback to changes of each widget's value
slider_robot_s.observe(update_plot, names='value')
slider_action_x.observe(update_plot, names='value')
checkbox_show_robot.observe(update_plot, names='value')
checkbox_show_obstacle.observe(update_plot, names='value')
checkbox_show_action.observe(update_plot, names='value')

# Arrange widgets vertically
controls = widgets.VBox([
    slider_robot_s,
    slider_action_x,
    checkbox_show_robot,
    checkbox_show_obstacle,
    checkbox_show_action,
    status_label
])

# Arrange plot and controls horizontally
app_layout = widgets.HBox([output_plot, controls])

# Display the combined layout
display(app_layout)

# Initial plot draw
update_plot()

# NOTE(review): this message looks like a leftover from a standalone-script version of the demo.
print("Interactive demo created. Run this script in a Jupyter Notebook or similar environment.") 

## Generate Datasets

Generate the expert dataset used for training and a smaller test set drawn from the same distribution.

In [None]:
def generate_expert_data(num_samples=NUM_EXPERT_SAMPLES,
                         data_params=DEFAULT_DATA_PARAMS):
    """
    Simulate expert demonstrations for the 1D jumper scenario.

    Each sample pairs a state ``s`` with an action ``x`` that lands beyond the
    safe boundary on either the left or the right side.

    Args:
        num_samples: Number of (s, x) pairs to generate.
        data_params: Dict with keys 's_min', 's_max', 'safe_boundary',
            'target_mean_delta' and 'lognorm_sigma'.

    Returns:
        DataFrame with columns 's' and 'x'.
    """
    s_min, s_max, safe_boundary, target_mean_delta, lognorm_sigma = (
        data_params['s_min'],
        data_params['s_max'],
        data_params['safe_boundary'],
        data_params['target_mean_delta'],
        data_params['lognorm_sigma'],
    )

    # States are drawn uniformly between s_min and s_max
    s = np.random.uniform(s_min, s_max, num_samples)

    # The probability of jumping right grows linearly from 0 at s=-1 to 1 at s=+1
    prob_right = np.clip((s + 1) / 2, 0.0, 1.0)
    jump_right = np.random.rand(num_samples) < prob_right

    # Pick mu so that the log-normal margin has mean target_mean_delta
    lognorm_mu = np.log(target_mean_delta) - (lognorm_sigma**2 / 2)
    delta = lognorm.rvs(s=lognorm_sigma, scale=np.exp(lognorm_mu), size=num_samples)

    # Land `delta` beyond the boundary on the chosen side
    x = np.where(jump_right, safe_boundary + delta, -safe_boundary - delta)

    # Clip x to within three safe boundaries of the origin
    x = np.clip(x, -3*safe_boundary, 3*safe_boundary)

    return pd.DataFrame({'s': s, 'x': x})


# Training set; no random seed is set, so the samples differ between runs.
expert_data = generate_expert_data(num_samples=NUM_EXPERT_SAMPLES)
print(f"Generated {len(expert_data)} expert samples for training.")

# Independent draw from the same generator, used for evaluation and plotting.
test_data = generate_expert_data(num_samples=N_TEST_SAMPLES)
print(f"Generated {len(test_data)} samples for testing/plotting.")

In [None]:
# Show the generated training samples (columns 's' and 'x').
expert_data

In [None]:
# Show the generated test samples (columns 's' and 'x').
test_data

In [None]:
# No predictions are passed, so only the data itself is drawn.
plot_predictions(test_data=expert_data, title="Training Data Distribution")

In [None]:
# No predictions are passed, so only the data itself is drawn.
plot_predictions(test_data=test_data, title="Test Data Distribution")

## Simple MSE regression

In [None]:
# Double brackets keep a 2D (n, 1) layout: one feature column for s, one target column for x.
X_train = torch.tensor(expert_data[['s']].values, dtype=torch.float32).to(DEVICE)
y_train = torch.tensor(expert_data[['x']].values, dtype=torch.float32).to(DEVICE)

# This loader is reused by the MAE cell further down.
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

model_mse = MLP(input_dim=1, output_dim=1).to(DEVICE)

optimizer = optim.Adam(model_mse.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

print(f"Training model for {EPOCHS} epochs...")
for _ in range(EPOCHS):
    model_mse = train_one_epoch(model_mse, train_loader, optimizer, loss_fn)

# Switch to evaluation mode so the dropout layers are disabled at prediction time.
model_mse.eval()

In [None]:
# Test states as an (n, 1) float tensor; reused by the prediction cells of the later models.
test_s_tensor = torch.tensor(test_data[['s']].values, dtype=torch.float32).to(DEVICE)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    x_pred_tensor = model_mse(test_s_tensor)
x_pred_np = x_pred_tensor.cpu().numpy().flatten()

# Preview the first ten predicted actions.
x_pred_np[:10]

In [None]:
# Fraction of predicted actions whose magnitude exceeds the safe boundary.
print(f"Success rate (beyond {SAFE_BOUNDARY} boundary): {(np.abs(x_pred_np) > SAFE_BOUNDARY).mean()}")

In [None]:
# Overlay the MSE model's predictions (orange) on the test data (blue).
plot_predictions(
    test_data['s'], x_pred_np, test_data=test_data, 
    title="BC (MSE) Predictions vs. Test Data",
    model_label="MLP BC (MSE)"
    )

If we have a set of numbers $x_1, x_2, \dots, x_N$ and want to find the prediction $p$ that minimizes the Mean Squared Error (MSE), we define the loss as:

$$
\text{MSE}(p) = \frac{1}{N} \sum_{i=1}^{N} (x_i - p)^2
$$

The minimum is achieved when $p$ equals the **mean** of the values:

$$
p = \frac{1}{N} \sum_{i=1}^{N} x_i = \text{mean}(x)
$$

So, the best constant prediction under MSE is simply the average of the data.

In [None]:
# Look at a narrow slice of states (0.6 < s < 0.7) to see where the mean of the actions falls.
demo_data = expert_data[(expert_data['s'] > 0.6) & (expert_data['s'] < 0.7)]

print(f"Mean: {demo_data['x'].mean()}")

plt.figure(figsize=(10,6))
plt.hist(demo_data['x'], bins=100, density=True, alpha=0.7)
# Mark the mean, i.e. the constant prediction that minimizes MSE on this slice.
plt.axvline(demo_data['x'].mean(), color='g', linestyle='--', label=f'Mean={demo_data["x"].mean():.2f}')
plt.xlabel('x')
plt.ylabel('Density')
plt.title('Distribution of x values')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
demo_data = expert_data  # full dataset this time, without the 0.6 < s < 0.7 filter used above

print(f"Mean: {demo_data['x'].mean()}")
print(f"Median: {demo_data['x'].median()}")

plt.figure(figsize=(10,6))
plt.hist(demo_data['x'], bins=100, density=True, alpha=0.7)
# Mark both the mean (green) and the median (red) of the actions.
plt.axvline(demo_data['x'].mean(), color='g', linestyle='--', label=f'Mean={demo_data["x"].mean():.2f}')
plt.axvline(demo_data['x'].median(), color='r', linestyle='--', label=f'Median={demo_data["x"].median():.2f}')
plt.xlabel('x')
plt.ylabel('Density')
plt.title('Distribution of x values')
plt.legend()
plt.grid(True)
plt.show()

If we want to minimize the Mean Absolute Error (MAE) over a set of values $x_1, x_2, \dots, x_N$, we define the loss as:

$$
\text{MAE}(p) = \frac{1}{N} \sum_{i=1}^{N} |x_i - p|
$$

Unlike MSE, this loss is minimized not by the mean, but by the **median** of the values:

$$
p = \text{median}(x)
$$

So, under MAE, the best constant prediction is the median of the data — which makes it more robust to outliers compared to MSE.

## MAE regression

In [None]:
# Same architecture and data loader as the MSE model; only the loss changes to L1 (MAE).
model_mae = MLP(input_dim=1, output_dim=1).to(DEVICE)

optimizer = optim.Adam(model_mae.parameters(), lr=LEARNING_RATE)
loss_fn = nn.L1Loss()

print(f"Training model for {EPOCHS} epochs...")
for _ in range(EPOCHS):
    model_mae = train_one_epoch(model_mae, train_loader, optimizer, loss_fn)

# Switch to evaluation mode so the dropout layers are disabled at prediction time.
model_mae.eval()

In [None]:
# Predict on the test states; x_pred_np is overwritten with the MAE model's output.
with torch.no_grad():
    x_pred_tensor = model_mae(test_s_tensor)
x_pred_np = x_pred_tensor.cpu().numpy().flatten()

# Preview the first ten predicted actions.
x_pred_np[:10]

In [None]:
# Fraction of predicted actions whose magnitude exceeds the safe boundary.
print(f"Success rate (beyond {SAFE_BOUNDARY} boundary): {(np.abs(x_pred_np) > SAFE_BOUNDARY).mean()}")

In [None]:
# Overlay the MAE model's predictions (orange) on the test data (blue).
plot_predictions(
    test_data['s'], x_pred_np, test_data=test_data, 
    title="MAE trained model Predictions vs. Test Data",
    model_label="MLP BC (MAE)"
    )

### Can we use an even better loss?

In [None]:
# Plot the 6 functions that can be used as a loss function
# Note: this rebinds the module-level `x` grid that plot_dist reads.
x = np.linspace(-4, 4, 1000)

# Each entry is (LaTeX title, function of the error x).
functions = [
    (r"$x^2$", lambda x: x**2),
    (r"$|x|$", lambda x: np.abs(x)),
    (r"$1 - e^{-x^2}$", lambda x: 1 - np.exp(-x**2)),
    (r"$\min(|x|, 2)$", lambda x: np.minimum(np.abs(x), 2)),
    (r"$\log(1 + x^2)$", lambda x: np.log1p(x**2)),
    (r"$1 - e^{-|x|}$", lambda x: 1 - np.exp(-np.abs(x))),
]

fig, axs = plt.subplots(2, 3, figsize=(12, 6))

# Only the shape of each curve matters here, so grid lines and ticks are hidden.
for ax, (title, func) in zip(axs.flat, functions):
    ax.plot(x, func(x), lw=4, color='orange')
    ax.set_title(title, fontsize=14)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()


## Target tokenization and classification

In [None]:
# A single output logit is enough for the two classes (jump left / jump right).
model_clf = MLP(input_dim=1, output_dim=1).to(DEVICE)

optimizer = optim.Adam(model_clf.parameters(), lr=LEARNING_RATE)
loss_fn = nn.BCEWithLogitsLoss()

# Binary target: 1.0 when the expert action is on the right (x > 0), else 0.0.
# Note: y_train, train_dataset and train_loader from the regression cells are overwritten here.
y_train = torch.tensor(expert_data[['x']].values > 0, dtype=torch.float32).to(DEVICE)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Training model for {EPOCHS} epochs...")
for _ in range(EPOCHS):
    model_clf = train_one_epoch(model_clf, train_loader, optimizer, loss_fn)

# Switch to evaluation mode so the dropout layers are disabled at prediction time.
model_clf.eval()

In [None]:
# One representative action per class: the mean expert action on each side of zero.
mean_left = expert_data[expert_data['x'] <= 0]['x'].mean()
mean_right = expert_data[expert_data['x'] > 0]['x'].mean()
print(f"Mean left: {mean_left}, Mean right: {mean_right}")

In [None]:
with torch.no_grad():
    x_cls_pred = model_clf(test_s_tensor)
    # A positive logit selects the right class; each class is mapped to its representative action.
    x_pred_tensor = (x_cls_pred > 0) * mean_right + (x_cls_pred <= 0) * mean_left
    
x_pred_np = x_pred_tensor.cpu().numpy().flatten()

# Preview the first ten predicted actions.
x_pred_np[:10]

In [None]:
# Fraction of predicted actions whose magnitude exceeds the safe boundary.
print(f"Success rate (beyond {SAFE_BOUNDARY} boundary): {(np.abs(x_pred_np) > SAFE_BOUNDARY).mean()}")

In [None]:
# Overlay the classifier's deterministic predictions (orange) on the test data (blue).
plot_predictions(
    test_data['s'], x_pred_np, test_data=test_data, 
    title="Classification model Predictions vs. Test Data",
    model_label="Classification"
    )

In [None]:
with torch.no_grad():
    # Sigmoid turns the logit into the predicted probability of the right class.
    x_cls_pred = torch.sigmoid(model_clf(test_s_tensor))
    # One uniform draw per sample: the right class is chosen with the predicted probability.
    probs = torch.rand_like(x_cls_pred)
    x_pred_tensor = (probs < x_cls_pred) * mean_right + (probs >= x_cls_pred) * mean_left
    
x_pred_np = x_pred_tensor.cpu().numpy().flatten()

# Preview the first ten sampled actions.
x_pred_np[:10]

### Classification with Random Sampling

To introduce action diversity, instead of always taking the class with the highest probability, we sample the action according to the predicted probability distribution.  
This can be useful if you want the robot to occasionally pick less likely but still valid actions, improving behavior diversity.


In [None]:
# Overlay the sampled classifier actions (orange) on the test data (blue).
plot_predictions(
    test_data['s'], x_pred_np, test_data=test_data, 
    title="Classification model with random sampling vs. Test Data",
    model_label="Classification"
    )

## Diffusion Policy

In [None]:
# Four noising steps on the first image, shown from the original image to the noisiest one.
noisy_steps = get_noisy_steps("media/robot1.png", 4)
plot_noisy_steps(noisy_steps)

In [None]:
# Same illustration on the second image, displayed in reverse (noisiest first).
# noisy_steps now holds the second image's sequence, which the next cell reads.
noisy_steps = get_noisy_steps("media/robot2.png", 4)
plot_noisy_steps(noisy_steps, True, "Denoising diffusion model inference")

In [None]:
# First and last entries of the sequence: the clean image and its noisiest version.
original_np = noisy_steps[0]
noise_img = noisy_steps[-1]

# Plot side-by-side: noise -> image
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(noise_img)
axes[0].set_title("Random noise", fontsize=18)
axes[0].axis("off")

axes[1].imshow(original_np)
axes[1].set_title("Reasonable image", fontsize=18)
axes[1].axis("off")

# Add arrow
fig.text(0.5, 0.5, '→', fontsize=30, ha='center')

plt.tight_layout()
plt.show()

In [None]:
# Prepare data
# NOTE(review): expert_data_x is not used below (the plot reads expert_data['x'] directly) — confirm whether a later cell needs it.
expert_data_x = expert_data['x'].values

# Plot both distributions
fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

# Left: standard normal
sns.kdeplot(np.random.randn(100000), fill=True, color='blue', ax=axes[0], label='Random Noise')
axes[0].set_title("Random noise", fontsize=18)
axes[0].set_xlabel("Action (x)")
axes[0].set_ylabel("Density")
axes[0].legend()

# Right: expert data
# A small bw_adjust keeps the kernel narrow so the two separate peaks are not smoothed together.
sns.kdeplot(expert_data['x'], fill=True, color='blue', ax=axes[1], label='Data Distribution', bw_adjust=0.05)
axes[1].set_title("Reasonable actions", fontsize=18)
axes[1].set_xlabel("Action (x)")
axes[1].legend()

# Add arrow between plots
fig.text(0.53, 0.5, '→', fontsize=30, ha='center')

plt.tight_layout()
plt.show()

In the forward process we are adding noise to the data at each step.

There are different ways to add noise. For simplicity, I will use this example:

$$
x_t = \sqrt{a_t} \, x_{t-1} + \sqrt{1 - a_t} \, \varepsilon
$$

where $a_t$ is predefined by us, and $\varepsilon$ is sampled from a standard normal distribution (mean 0, standard deviation 1), which is what the training code below uses.



In practice, instead of modeling it step by step, we model how the noisy signal looks after $t$ steps:

$$
x_t = \sqrt{\bar{a}_t} \, x_0 + \sqrt{1 - \bar{a}_t} \, \varepsilon
$$

where $\bar{a}_t$ is the cumulative product of $a_t$.

$$
\bar{a}_t = \prod_{s=1}^{t} a_s
$$


In [None]:
# Noise scheduling is very important for the model to work well.

T = 49 # Number of diffusion steps

# Noise variance at each step
# Created on DEVICE: the training cell indexes the schedule with step indices
# that live on DEVICE, and a CPU tensor cannot be indexed with CUDA indices.
b_t = torch.linspace(0.0001, 0.2, T, device=DEVICE) # I use pretty high noise as we have a very low number of steps

# Per-step signal retention coefficients
a_t = 1 - b_t

# Cumulative product of signal retention coefficients
a_cumprod_t = torch.cumprod(a_t, dim=0)

print(a_t)
print(a_cumprod_t)

In [None]:
# The model takes three inputs (s, normalized step, x_t) and predicts the noise e.
model_diffusion = MLP(input_dim=3, output_dim=1).to(DEVICE)
optimizer = optim.Adam(model_diffusion.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

# Build float32 tensors on DEVICE from the underlying NumPy arrays
x_tensor = torch.tensor(expert_data['x'].values, dtype=torch.float32, device=DEVICE)
s_tensor = torch.tensor(expert_data['s'].values, dtype=torch.float32, device=DEVICE)

# The schedule must be on the same device as the step indices used to index it
# (a no-op when it is already there).
a_cumprod_device = a_cumprod_t.to(DEVICE)

print(f"Training model for {EPOCHS} epochs...")
for _ in range(EPOCHS):
    # For transparency I will regenerate the data at each epoch here so it is easier to follow the code

    # Generate a random step index from 0 to T-1 for each sample
    # We will use only one step per sample per epoch for simplicity
    t = torch.randint(0, T, (len(expert_data),), device=DEVICE)

    # Generate extra noise
    e = torch.randn(len(expert_data), device=DEVICE)

    # Generate x_t directly from x_0 using the cumulative signal retention
    a_bar = a_cumprod_device[t]
    x_t = torch.sqrt(a_bar) * x_tensor + torch.sqrt(1 - a_bar) * e

    # The input for the model will be s, t (scaled to [0, 1)), and x_t
    X_train = torch.stack([s_tensor, t.float()/T, x_t], dim=1).float()

    # And the target will be e
    y_train = e.unsqueeze(-1)

    # Create a dataset and a dataloader
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Now let's train for one epoch on the freshly generated data
    model_diffusion = train_one_epoch(model_diffusion, train_loader, optimizer, loss_fn)

# Switch to evaluation mode so the dropout layers are disabled during sampling.
model_diffusion.eval()

Now that we have a trained model, we can sample random noise from the normal distribution and turn it into an action by iteratively applying the model $T$ times.  
Each time we use the model to predict the noise and subtract it from the current state:

$$
e_t = \text{model\_diffusion}(s, t, x_t)
$$

There are different ways to reconstruct the action using the model prediction.  
I will show the one with the simplest math: DDIM with $\eta = 0$.

Let's look at the formula for the forward process:

$$
x_t = \sqrt{\bar{a}_t} \, x_0 + \sqrt{1 - \bar{a}_t} \, \varepsilon \tag{1}
$$

For $x_{t-1}$, we can rewrite (1) as:

$$
x_{t-1} = \sqrt{\bar{a}_{t-1}} \, x_0 + \sqrt{1 - \bar{a}_{t-1}} \, \varepsilon \tag{2}
$$

Now let's estimate $x_0$ from (1):

$$
x_0 = \frac{1}{\sqrt{\bar{a}_t}} \left( x_t - \sqrt{1 - \bar{a}_t} \, \varepsilon \right) \tag{3}
$$

Now we can plug (3) into (2), assuming that $\varepsilon$ is the same for both equations:

$$
x_{t-1} = \sqrt{\bar{a}_{t-1}} \, \frac{1}{\sqrt{\bar{a}_t}} \left( x_t - \sqrt{1 - \bar{a}_t} \, \varepsilon \right) + \sqrt{1 - \bar{a}_{t-1}} \, \varepsilon
$$

Simplifying:
- $\bar{a}_{t-1} / \bar{a}_t = 1/a_t$
- Rearranging terms, we get:

$$
x_{t-1} = \frac{1}{\sqrt{a_t}} \left( x_t - \left( \sqrt{1 - \bar{a}_t} - \sqrt{a_t - \bar{a}_t} \right) \varepsilon \right)
$$


In [None]:
# `predictions` collects the sample values after every denoising step; entry 0 is the pure noise.
predictions = []

noise = np.random.randn(len(test_s_tensor))
predictions.append(noise)
x_t_tensor = torch.tensor(noise, dtype=torch.float32).to(DEVICE).unsqueeze(-1)

# Walk the step indices backwards, from T-1 down to 0.
for t in range(T-1, -1, -1):
    # Same step encoding as in training: the step index divided by T.
    t_tensor = torch.ones_like(test_s_tensor) * t/T
    # Model input columns: s, normalized step, current x_t (the name X_train is reused for the inference input).
    X_train = torch.cat([test_s_tensor, t_tensor, x_t_tensor], dim=1)

    with torch.no_grad():
        e_pred = model_diffusion(X_train)

    # Deterministic update from the derivation above (DDIM with eta = 0).
    x_t_tensor = (1/ a_t[t]) ** 0.5 * (x_t_tensor - ((1 - a_cumprod_t[t])** 0.5 - (a_t[t] - a_cumprod_t[t]) ** 0.5 ) * e_pred)

    predictions.append(x_t_tensor.cpu().numpy().flatten())

# The last entry is the fully denoised action.
x_pred_np = np.array(predictions[-1])
x_pred_np[:10]

In [None]:
# Fraction of generated actions whose magnitude exceeds the safe boundary.
print(f"Success rate (beyond {SAFE_BOUNDARY} boundary): {(np.abs(x_pred_np) > SAFE_BOUNDARY).mean()}")

In [None]:
# Overlay the diffusion model's generated actions (orange) on the test data (blue).
plot_predictions(
    test_data['s'], x_pred_np, test_data=test_data, 
    title="Diffusion model Predictions vs. Test Data",
    model_label="Diffusion model"
    )

In [None]:
# One histogram per entry of `predictions`: the initial noise followed by the state after each denoising step.
plot_evolution_distributions(predictions)

## Flow Matching Policy

We can draw an intuitive (but not strictly correct) analogy between Diffusion and Flow Matching using the relationship between summation and integration:

$$
\sum_{n=1}^{N} f\left(\frac{n}{N}\right) \cdot \frac{1}{N} \approx \int_{0}^{1} f(x) \, dx
$$

$$
\text{Diffusion} \approx \text{Flow Matching}
$$
This is just a conceptual analogy:  
- Diffusion discretizes time into fixed steps and predicts denoising step-by-step.  
- Flow Matching treats time as continuous and directly learns the vector field.

So, just like Riemann sums approximate integrals, Flow Matching can be seen as a continuous-time simplification of Diffusion.

In [None]:
from IPython.display import HTML
from matplotlib import animation

# Number of animation frames, i.e. interpolation steps from pure noise (t=0) to the image (t=1).
FLOW_ANIM_FRAMES = 50

# Load and preprocess image: RGB, resized to 128x128, float32 scaled to [0, 1]
# NOTE(review): `Image` is not imported in this cell; presumably `from PIL import Image`
# ran in an earlier cell — confirm.
img_path = "media/robot2.png"
x0_img = Image.open(img_path).convert("RGB").resize((128, 128))
x0_np = np.array(x0_img).astype(np.float32) / 255.0

# Generate random noise with fixed seed
# (this reseeds NumPy's global RNG, so later np.random draws are affected as well)
np.random.seed(42)
e_np = np.random.randn(*x0_np.shape).astype(np.float32)

# The interpolation times are identical for every frame lookup, so build them once
# instead of on every call to `update`.
flow_frame_times = np.linspace(0, 1, FLOW_ANIM_FRAMES)

# Create animation
fig, ax = plt.subplots(figsize=(3, 3))
img_display = ax.imshow(np.zeros_like(x0_np), animated=True)
ax.axis("off")

def update(frame_idx):
    """Draw one animation frame.

    Args:
        frame_idx: Index of the frame, in ``range(FLOW_ANIM_FRAMES)``.

    Returns:
        A one-element list with the updated image artist, as required for blitting.
    """
    t = flow_frame_times[frame_idx]
    x_t = t * x0_np + (1 - t) * e_np  # Simplified interpolation
    x_t = np.clip(x_t, 0, 1)  # keep values in the displayable [0, 1] range
    img_display.set_array(x_t)
    return [img_display]

ani = animation.FuncAnimation(fig, update, frames=FLOW_ANIM_FRAMES, interval=50, blit=True)

In [None]:
# Render the animation as an HTML/JavaScript player in the cell output.
HTML(ani.to_jshtml())

Flow Matching feels like Diffusion stripped of a lot of complexity.

We don't need $a_t$ and $b_t$ anymore.

During training, for each sample we just sample two random numbers:
- $\varepsilon$ — noise from a normal distribution
- $t$ — time from $0$ to $1$ (uniform distribution)

Then we use a simple linear interpolation between the actual $x$ and the noise $\varepsilon$ to get the noisy input $x_t$:

$$
x_t = x \cdot t + \varepsilon \cdot (1 - t)
$$

Finally, similar to the Diffusion model, we use $s$, $t$, and $x_t$ as input to the model, but the training target is simply:

$$
\text{model}(s, t, x_t) \approx x - \varepsilon
$$


In [None]:
# Flow Matching training: the network takes (s, t, x_t) and regresses the
# velocity x - e that moves a noise sample e towards the expert action x.
model_flow_matching = MLP(input_dim=3, output_dim=1).to(DEVICE)
optimizer = optim.Adam(model_flow_matching.parameters(), lr=LEARNING_RATE)
loss_fn = nn.MSELoss()

x_tensor = torch.tensor(expert_data['x'], dtype=torch.float32, device=DEVICE)
s_tensor = torch.tensor(expert_data['s'], dtype=torch.float32, device=DEVICE)

# Number of expert samples. Read from the tensor rather than len(expert_data),
# which would count keys instead of rows if expert_data is a dict of arrays.
n_expert = x_tensor.shape[0]

print(f"Training model for {EPOCHS} epochs...")
for _ in range(EPOCHS):
    # For transparency I will regenerate the data at each epoch here so it is easier to follow the code

    # Generate a random time from 0 to 1 for each sample
    # We will use only one step per sample per epoch for simplicity
    t = torch.rand(n_expert, dtype=torch.float32, device=DEVICE)

    # Generate noise
    e = torch.randn(n_expert, dtype=torch.float32, device=DEVICE)

    # Generate x_t: linear interpolation between noise (t=0) and the expert action (t=1)
    x_t = x_tensor * t + e * (1 - t)

    # The input for the model will be s, t, and x_t
    X_train = torch.stack([s_tensor,t,x_t], dim=1).float()

    # And the target will be the velocity x - e (not the noise e itself, unlike Diffusion)
    y_train = (x_tensor - e).unsqueeze(-1)

    # Create a dataset and a dataloader
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Now let's train for one epoch on this freshly generated data
    model_flow_matching = train_one_epoch(model_flow_matching, train_loader, optimizer, loss_fn)

# Switch to eval mode before sampling
model_flow_matching.eval()

Reconstruction is also straightforward in the case of Flow Matching:

$$
x_{t+\tau} = x_t + \tau \cdot \text{model\_flow\_matching}(s, t, x_{t})
$$

Where $\tau$ is a small step that we can select based on the accuracy / speed trade-off.

In [None]:
# Flow Matching sampling: integrate the learned velocity field from noise (t=0)
# towards data (t=1) using fixed-size Euler steps.
predictions = []

noise = np.random.randn(len(test_s_tensor))
predictions.append(noise)
x_t_tensor = torch.tensor(noise, dtype=torch.float32).to(DEVICE).unsqueeze(-1)

# Number of Euler steps; the step size is tau = 1 / N_FLOW_STEPS.
# Kept under its own name so it does not overwrite the global `T` that the
# diffusion sampling cell reads if it is re-run after this one.
N_FLOW_STEPS = 49
for step in range(N_FLOW_STEPS):
    # Current time in [0, 1), broadcast to one value per test sample
    t_tensor = torch.ones_like(test_s_tensor) * step/N_FLOW_STEPS
    model_input = torch.cat([test_s_tensor, t_tensor, x_t_tensor], dim=1)

    with torch.no_grad():
        # The model was trained to predict the velocity x - e, not the noise
        v_pred = model_flow_matching(model_input)

    # Euler step: x_{t+tau} = x_t + tau * v
    x_t_tensor = x_t_tensor + v_pred * 1/N_FLOW_STEPS

    predictions.append(x_t_tensor.cpu().numpy().flatten())

# The last trajectory entry is the sample after the final step
x_pred_np = np.array(predictions[-1])
x_pred_np[:10]

In [None]:
# Fraction of predicted actions whose absolute value exceeds the safe boundary.
outside_mask = np.abs(x_pred_np) > SAFE_BOUNDARY
print(f"Success rate (beyond {SAFE_BOUNDARY} boundary): {outside_mask.mean()}")

In [None]:
# Compare the flow-matching samples against the held-out test data.
plot_predictions(
    test_data['s'],
    x_pred_np,
    title="Flow matching model Predictions vs. Test Data",
    model_label="Flow matching model",
    test_data=test_data,
)

In [None]:
# `predictions` holds the initial noise followed by the sample after every integration step.
plot_evolution_distributions(predictions)

## Reference

- ACT: Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware  
  https://arxiv.org/abs/2304.13705

- Diffusion Policy: Visuomotor Policy Learning via Action Diffusion  
  https://arxiv.org/abs/2303.04137

- Behavior Transformers: Cloning k Modes with One Stone  
  https://arxiv.org/abs/2206.11251

- VQ-BET: Behavior Generation with Latent Actions  
  https://arxiv.org/abs/2403.03181

- Denoising Diffusion Probabilistic Models  
  https://arxiv.org/abs/2006.11239

- DDIM: Denoising Diffusion Implicit Models  
  https://arxiv.org/abs/2010.02502

- Flow Matching for Generative Modeling  
  https://arxiv.org/abs/2210.02747

- MIT Course on Diffusion & Flow Matching (SDE perspective)  
  https://diffusion.csail.mit.edu/

- My DOT-Policy repo  
  https://github.com/IliaLarchenko/dot_policy