# Physics Simulation Demo

Demonstrates how to use the `physics_sim` module to generate synthetic video data: sequences of rendered frames of a ball moving among barriers, with optional gravity.

In [None]:
# Clone the repo and move into the experiment directory (for Colab).
# Guarded so the cell is idempotent: re-running it, or running the notebook
# from a checkout that is already inside physics_prediction/, neither
# re-clones nor descends into a nested MNIST_AI/ directory.
import os
import subprocess

REPO_URL = "https://github.com/Caleb-Briggs/MNIST_AI.git"
EXPERIMENT_DIR = os.path.join("MNIST_AI", "experiments", "physics_prediction")

# physics_sim may be a single module file or a package directory.
if not any(os.path.exists(p) for p in ("physics_sim.py", "physics_sim")):
    if not os.path.isdir("MNIST_AI"):
        subprocess.run(["git", "clone", REPO_URL], check=True)
    os.chdir(EXPERIMENT_DIR)

print(f"Working directory: {os.getcwd()}")

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from physics_sim import (
    Ball, Barrier, PhysicsSimulation,
    generate_trajectory, generate_random_barriers,
    create_random_simulation, generate_dataset
)

## Basic Usage

In [None]:
# Minimal scene: one ball and one barrier, no gravity, wrapping boundaries.
ball = Ball(x=10, y=10, vx=1.5, vy=0.8, radius=3)
barrier = Barrier(x=25, y=15, width=15, height=3)

sim = PhysicsSimulation(
    ball=ball, barriers=[barrier],
    elasticity=0.9, gravity=0.0,
    wrap_boundaries=True, seed=42,
)

# Roll the simulation forward, then report the frame-stack size and value range.
frames = generate_trajectory(sim, num_frames=100)
print(f"Generated {len(frames)} frames, shape: {frames.shape}")
print(f"Values: min={frames.min()}, max={frames.max()}")

In [None]:
# Display evenly strided snapshots of the trajectory. The stride is derived
# from the trajectory length (100 frames / 5 panels -> every 20th frame)
# instead of being hard-coded, so lowering num_frames above can't index past
# the end of `frames`.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
stride = max(1, len(frames) // len(axes))
for i, ax in enumerate(axes):
    idx = min(i * stride, len(frames) - 1)
    ax.imshow(frames[idx], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f't={idx}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## Animation Helper

In [None]:
def animate_trajectory(frames, fps=30, figsize=(6, 6)):
    """Render a sequence of frames as an inline, looping JavaScript animation.

    Args:
        frames: Indexable sequence of 2-D images (e.g. the array returned by
            ``generate_trajectory``). Shown in grayscale with the colour range
            fixed to [0, 1].
        fps: Playback rate in frames per second; must be positive.
        figsize: Matplotlib figure size in inches.

    Returns:
        IPython ``HTML`` object containing the animation player.

    Raises:
        ValueError: if ``frames`` is empty or ``fps`` is not positive.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect('equal')
    ax.axis('off')

    im = ax.imshow(frames[0], cmap='gray', vmin=0, vmax=1, animated=True)

    def update(frame_idx):
        im.set_array(frames[frame_idx])
        return [im]

    anim = animation.FuncAnimation(
        fig, update, frames=len(frames),
        # Round rather than floor-divide, and clamp to >= 1 ms, so very high
        # fps values don't collapse to a 0 ms interval.
        interval=max(1, round(1000 / fps)), blit=True, repeat=True
    )

    # Close this specific figure so its static first frame isn't also rendered
    # below the animation; a bare plt.close() closes whichever figure is current.
    plt.close(fig)
    return HTML(anim.to_jshtml())

animate_trajectory(frames, fps=30)

## Gravity Physics

In [None]:
# Gravity enabled, non-wrapping boundaries, and a ball that starts with
# horizontal velocity only (vy=0).
ball_grav = Ball(x=10, y=50, vx=1.2, vy=0, radius=3)
barrier_platform = Barrier(x=40, y=10, width=15, height=3)

sim_grav = PhysicsSimulation(
    ball=ball_grav, barriers=[barrier_platform],
    elasticity=0.85, gravity=0.15,
    wrap_boundaries=False, seed=42,
)
frames_grav = generate_trajectory(sim_grav, num_frames=200)

# Five snapshots, one every 40 frames.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for ax, idx in zip(axes, range(0, 200, 40)):
    ax.imshow(frames_grav[idx], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f't={idx}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## Random Terrain

In [None]:
# Six random barrier layouts (no gravity), one per seed, each shown as the
# frame rendered right after construction. A dedicated name is used for the
# per-panel simulation so the Basic Usage `sim` defined earlier is not
# silently overwritten (hidden state if earlier cells are re-run later).
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

for i, ax in enumerate(axes.flat):
    terrain_sim = create_random_simulation(num_barriers=5, with_gravity=False, seed=i * 10)
    ax.imshow(terrain_sim.render(), cmap='gray', vmin=0, vmax=1)
    ax.set_title(f'Random Terrain {i+1}')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Longer run: 7 random barriers with gravity enabled.
sim_complex = create_random_simulation(num_barriers=7, with_gravity=True, seed=123)
frames_complex = generate_trajectory(sim_complex, num_frames=250)

# Spread the snapshots evenly from the first to the last frame so the end of
# the run is included (a fixed stride of 40 over six panels stops at t=200).
fig, axes = plt.subplots(1, 6, figsize=(18, 3))
snapshot_idx = np.linspace(0, len(frames_complex) - 1, num=len(axes)).round().astype(int)
for ax, idx in zip(axes, snapshot_idx):
    ax.imshow(frames_complex[idx], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f't={idx}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## Determinism Test

In [None]:
# Two simulations built from the same seed should yield identical trajectories.
sim1 = create_random_simulation(num_barriers=5, seed=999)
frames1 = generate_trajectory(sim1, num_frames=100)

sim2 = create_random_simulation(num_barriers=5, seed=999)
frames2 = generate_trajectory(sim2, num_frames=100)

# Cast before subtracting: if the renderer emits unsigned-integer frames the
# raw difference would wrap around, and boolean frames would raise on `-`.
frame_diff = np.abs(frames1.astype(np.float64) - frames2.astype(np.float64))

print(f"Frames identical: {np.array_equal(frames1, frames2)}")
print(f"Max difference: {frame_diff.max()}")

## Dataset Generation

In [None]:
# A batch of 10 gravity-enabled trajectories, 100 frames each, 5 barriers.
dataset = generate_dataset(
    num_trajectories=10, num_frames=100, num_barriers=5,
    with_gravity=True, base_seed=42,
)

print(f"Dataset shape: {dataset.shape}")
print(f"Memory: {dataset.nbytes / 2**20:.2f} MB")

# Every 10th frame of the first trajectory; presumably dataset is indexed
# (trajectory, frame, ...) — consistent with dataset[0, idx] below.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax, idx in zip(axes.flat, range(0, 100, 10)):
    ax.imshow(dataset[0, idx], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f't={idx}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## Save/Load

In [None]:
# Round-trip the complex trajectory through a compressed .npz file.
# np.savez_compressed does not create missing parent directories, so make
# sure results/ exists first (it won't on a fresh Colab clone).
os.makedirs('results', exist_ok=True)
save_path = os.path.join('results', 'example_trajectory.npz')

np.savez_compressed(save_path, frames=frames_complex)
print("Saved trajectory")

# np.load on an .npz returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is released once 'frames' has been read.
with np.load(save_path) as npz:
    loaded = npz['frames']
print(f"Loaded shape: {loaded.shape}")
print(f"Matches original: {np.array_equal(frames_complex, loaded)}")