# Differentiable Simulator for Multi-particle Dynamics

We will use the package PhiFlow (https://tum-pbs.github.io/PhiFlow/) to simulate billiard balls and their collisions. Follow the README instructions to install the package correctly.

In [1]:
from phi.torch.flow import (
    Tensor, vec, batch, instance, stack, expand, rename_dims, math,
    Sphere, PointCloud, Box, Field, dual, field,
    iterate, plot, jit_compile
)

Let’s create a cue ball and a triangle of billiard balls. Our table has dimensions 1 × 1.

In [None]:
def create_balls(cue_pos=vec(x=.1, y=.5), billiard_layers=3, radius=0.05):
    """
    Build the cue ball followed by a triangular rack of balls.

    Args:
        cue_pos: Center of the cue ball; it becomes element 0 along the `ball` dimension.
        billiard_layers: Number of rack layers; layer `i` holds `i + 1` balls.
        radius: Radius shared by all balls.

    Returns:
        Sphere with instance dimension `ball` holding 1 + layers*(layers+1)/2 elements.
    """
    # Layer i sits at x = 0.5 + 2*i*radius; within a layer, balls are 2*radius apart in y,
    # and the whole layer is offset in y by -0.7*i*radius.
    rack = [
        vec(x=layer * 2 * radius + 0.5, y=slot * 2 * radius + 0.5 - layer * radius * 0.7)
        for layer in range(billiard_layers)
        for slot in range(layer + 1)
    ]
    return Sphere(stack([cue_pos, *rack], instance('ball')), radius=radius)

# Table of size 1 x 1; this module-level `domain` is also read by `apply_bounds`
domain = Box(x=1, y=1)
# Cue ball plus a 3-layer rack (1 + 1 + 2 + 3 = 7 balls)
balls = create_balls(billiard_layers=3, radius=0.05)
# Draw table and balls in one figure
plot([domain, balls], overlay='list')

Next, we define the dynamics, which consist of linear movement and collisions. We store the velocities in the `values` of the field. The impact dynamics depend on the relative positions and velocities of the balls, which we obtain using `math.pairwise_differences`.

**Task 1: Implement the Euler integrator**
* Use the explicit Euler update $x_{t+\Delta t} = x_t + \Delta t \, v_t$ for the new positions

**Task 2: Update the points with the new positions**


In [None]:
def euler(data: Field, dt: float) -> Tensor:
    """
    Explicit (forward) Euler integrator for point positions.

    Args:
        data: Field whose `geometry` holds the elements (e.g. spheres) and whose
            `values` hold the per-element velocities.
        dt: Time step.

    Returns:
        Tensor of new element centers, ``x1 = x0 + dt * v0``.
    """
    # Current positions: the centers of the geometric elements
    x0 = data.geometry.center
    # Velocities are stored as the field values
    v0 = data.values
    return x0 + dt * v0

def advect_points(points: Field, dt: float, integrator=euler) -> Field:
    """
    Advects (transports) the sample points of a point cloud by one time step.

    Args:
        points: Field whose elements are moved; its `values` (velocities) are kept as-is.
        dt: Time step passed on to the integrator.
        integrator: Callable ``(Field, dt) -> Tensor`` returning the new element centers.
            Defaults to `euler`.

    Returns:
        Copy of `points` whose elements are re-centered at the integrated positions.
    """
    # Solve the movement ODE for one step
    new_positions = integrator(points, dt)
    # Same geometry (shape, radius), relocated to the new centers
    new_elements = points.geometry.at(new_positions)
    # Copy of the field with only `elements` replaced
    return points.with_elements(new_elements)

@jit_compile
def apply_bounds(v: Field, damping_coef=-0.50) -> Field:
    """
    Keep every ball inside the table and bounce it off the walls.

    Ball centers are clipped into the module-level `domain` box, shrunk by the ball
    radius so that the sphere surface stays inside. Wherever a center coordinate had
    to be clipped, the velocity is multiplied by `damping_coef`. The comparison is
    element-wise, so only the offending component is affected. The default (-0.5)
    reverses that component and halves it.

    NOTE(review): `domain` is read as a global inside a jit-compiled function; if it is
    reassigned after the first trace, the compiled version may keep the old bounds.
    """
    r = v.geometry.radius
    lower_limit = domain.lower + r
    upper_limit = domain.upper - r
    inside = math.clip(v.points, lower_limit, upper_limit)
    # True where the coordinate was already within bounds (nothing was clipped)
    untouched = inside == v.points
    bounced = field.where(untouched, v, v * damping_coef)
    # Move the balls onto the clipped, in-bounds positions
    return bounced.shifted_to(inside)

@jit_compile
def physics_step(v: PointCloud, dt: float, elasticity=0.5) -> PointCloud:
    """
    Simulates one step of point cloud physics, including advection and collision handling.

    The step:
      1. Advects a look-ahead copy to detect contacts.
      2. Applies impulses along the line of centers to colliding pairs and corrects
         their positions.
      3. Advects the corrected state.

        v: PointCloud representing the moving objects (velocities stored in `values`)
        dt: Time step for the simulation
        elasticity: Coefficient of restitution to model the bounciness of the objects (1.0 = perfectly elastic, 0 = inelastic)

    Returns the advected PointCloud after the collision response.
    """
    # Look ahead: positions after one free-flight step, used only to detect contacts
    v_next = advect_points(v, dt)

    # --- Collision detection ---
    # Displacement vectors between every pair of look-ahead positions; the pair partner
    # lives along the dual dimension (per the PhiML docs, vectors point towards the partner)
    x_diff = math.pairwise_differences(v_next.points)

    # Pairwise distances; eps avoids NaN gradients of sqrt at zero length (e.g. the diagonal)
    dist = math.vec_length(x_diff, eps=1e-4)

    # Relative velocities of each pair, using the velocities from before this step
    rel_v = -math.pairwise_differences(v.values)

    # Unit direction of separation (impact normal)
    dist_dir = -math.safe_div(x_diff, dist)

    # Normal component of the relative velocity; negative means the pair is approaching
    projected_v = dist_dir.vector * rel_v.vector

    # Sum of the radii of each pair: own radius + partner radius (partner along dual dim)
    radius_sum = v.geometry.radius + rename_dims(v.geometry.radius, instance, dual)

    # Impact: approaching AND in contact. Comparing against radius_sum (rather than twice
    # one ball's radius) keeps this test consistent with `impact_time` below and correct
    # for balls of unequal radii; for a shared scalar radius it equals 2 * radius.
    has_impact = (projected_v < 0) & (dist < radius_sum)

    # --- Collision response ---
    # Impulse along the impact normal. The 0.5 factor splits the momentum exchange
    # equally between both balls (corresponds to equal masses).
    impulse = -(1 + elasticity) * 0.5 * projected_v * dist_dir

    # Estimated time of contact from the current gap/overlap and the normal speed
    impact_time = math.safe_div(dist - radius_sum, projected_v)

    # Position correction: shift each colliding ball along its impulse, scaled by
    # min(impact_time - dt, 0), summed over all partners
    x_inc_contrib = math.sum(
        math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), dual
    )
    v = v.with_elements(v.geometry.shifted(x_inc_contrib))

    # Apply the summed impulses of all partners to each ball's velocity
    v += math.sum(math.where(has_impact, impulse, 0), dual)

    # Advance the corrected state by one step
    return advect_points(v, dt)

# Now, let's give the cue ball a starting velocity and run the simulation.
# Initial cue-ball velocity components
v_x = 7 
v_y = 0
# Velocity values: zero for every ball except index 0 along `ball` (the cue ball, see create_balls)
v0 = math.scatter(math.zeros(balls.shape), indices=vec(ball=0), values=vec(x=v_x, y= v_y))
# Field over the ball geometry with the velocities as values
# NOTE(review): third argument `0` presumably sets the boundary/extrapolation — confirm
initial_state = Field(balls, v0, 0)
# 128 recorded frames along batch dim `t`; each frame runs 2 physics substeps, each followed by the wall bounds
trj = iterate(lambda v: apply_bounds(physics_step(v=v, dt=0.003)), batch(t=128), initial_state, substeps=2)
plot([domain, trj.geometry], overlay='list', animate='t')

## Learning with Differentiable Physics

We will learn the initial shot of the cue ball that maximizes the dispersion of the balls across the table. To this end, we first define a loss function that quantifies the dispersion.

**Task 3: Implement the loss function**
* Use L2 loss on the pairwise distances
* The loss value for the initial position should be around 4

In [None]:
def billiards_triangle(billiard_layers=3, radius=0.05):
    """
    Build only the triangular rack of balls (no cue ball).

    The layout formula matches `create_balls`, but the balls are stacked along the
    instance dimension `balls`.

    Args:
        billiard_layers: Number of rack layers; layer `i` holds `i + 1` balls.
        radius: Radius shared by all balls.

    Returns:
        Sphere with instance dimension `balls`.
    """
    positions = [
        vec(x=row * 2 * radius + 0.5, y=col * 2 * radius + 0.5 - row * radius * 0.7)
        for row in range(billiard_layers)
        for col in range(row + 1)
    ]
    return Sphere(stack(positions, instance('balls')), radius=radius)

def loss_function(x0: Tensor, v0: Tensor, goal=vec(x=1.0, y=1.0), steps=64):
    """
    Simulate a shot and measure how dispersed the balls are at the end.

    Args:
        x0: Initial position of the controllable (cue) ball.
        v0: Initial velocity of the controllable ball.
        goal: Currently unused by this loss; kept for interface compatibility.
        steps: Number of simulation steps.

    Returns:
        (loss, trj, all_balls):
          - loss: dispersion of the final positions (larger = more spread out).
          - trj: the simulated trajectory along batch dim `t`.
          - all_balls: the initial state.
    """
    # Resting rack: all triangle balls start with zero velocity
    triangle_balls = PointCloud(billiards_triangle()) * (0, 0)
    # Cue ball: a single sphere (same instance dim name, size 1, same radius as the rack) moving with v0
    controllable_ball = PointCloud(Sphere(expand(x0, instance(triangle_balls).with_size(1)), radius=triangle_balls.geometry.radius)) * v0
    all_balls = controllable_ball & triangle_balls

    # Simulate the physics for the specified number of steps using `physics_step`
    trj = iterate(lambda v: apply_bounds(physics_step(v=v, dt=0.003)), batch(t=steps), all_balls)

    # Dispersion: L2 loss (sum of squares / 2) over all pairwise displacement vectors of
    # the final positions, i.e. half the sum of squared pairwise distances.
    balls_end_positions = trj.t[-1].points
    loss = math.l2_loss(math.pairwise_differences(balls_end_positions))

    return loss, trj, all_balls

# Evaluate the loss for a straight shot from (0.1, 0.5); output = (loss, trj, all_balls)
output = loss_function(x0=vec(x=.1, y=.5), v0=vec(x=15, y=0))
print('loss:', output[0])


Finally, let's run gradient-based optimization through the differentiable simulator. Since we maximize the dispersion, this is gradient *ascent*.

**Task 4: Implement the gradient step**
* Note that we want to maximize dispersion

In [None]:
# Returns (loss_function output, (d loss/d x0, d loss/d v0))
grad_fun = math.gradient(loss_function, 'x0,v0')
learning_rate = .01

x0 = vec(x=.1, y=.5)
v0 = vec(x=15, y=0)

for step in range(50):
    # Gradients w.r.t. both x0 and v0 are computed; only the initial velocity is optimized
    (loss, trj, balls), (_, dv0) = grad_fun(x0, v0)
    print(f"Iter={step} loss={loss:.3f}  v0={v0}  ∇={dv0}")
    # Gradient ascent: we maximize dispersion, so step along +gradient
    v0 += learning_rate * dv0

final_loss, trj, balls = loss_function(x0, v0)
print(f"Final loss: {final_loss}")
plot([domain, trj.geometry], overlay='list', animate='t')