Skip to content

Commit 7c35995

Browse files
Add Differentiable Physics: Mass-Spring System example (#1359)
* Add Differentiable Physics: Mass-Spring System example * Add differentiable_physics to run_all() in test script * Add visualization and update training code in mass_spring.py * Finalize differentiable_physics with visualization and CI integration * Finalize differentiable_physics with visualization and CI integration * Finalize differentiable_physics with the updates * Update requirements.txt for differentiable_physics * Update run_python_examples.sh to test differentiable_physics in CI * Add mass spring example and update requirements * Add mass spring example and update requirements * Updated README and visualization from corporate ID (abhitorch81) * Update readme.md --------- Co-authored-by: Abhishek Nandy <abhishek.nandy81@gmail.com>
1 parent 6f61614 commit 7c35995

File tree

5 files changed

+231
-0
lines changed

5 files changed

+231
-0
lines changed

differentiable_physics/mass_spring.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
import argparse
5+
import matplotlib.pyplot as plt
6+
import os
7+
8+
9+
class MassSpringSystem(nn.Module):
10+
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"):
11+
super().__init__()
12+
self.device = device
13+
self.mass = mass
14+
self.springs = springs
15+
self.dt = dt
16+
self.gravity = gravity
17+
18+
# Particle 0 is fixed at the origin
19+
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device)
20+
21+
# Remaining particles are trainable
22+
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device))
23+
24+
# Velocities
25+
self.velocities = torch.zeros(num_particles, 2, device=device)
26+
27+
def forward(self, steps):
28+
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0)
29+
velocities = self.velocities
30+
31+
for _ in range(steps):
32+
forces = torch.zeros_like(positions)
33+
34+
# Compute spring forces
35+
for (i, j, rest_length, stiffness) in self.springs:
36+
xi, xj = positions[i], positions[j]
37+
dir_vec = xj - xi
38+
dist = dir_vec.norm()
39+
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6)
40+
forces[i] += force
41+
forces[j] -= force
42+
43+
# Apply gravity
44+
forces[:, 1] -= self.gravity * self.mass
45+
46+
# Semi-implicit Euler integration
47+
acceleration = forces / self.mass
48+
velocities = velocities + acceleration * self.dt
49+
positions = positions + velocities * self.dt
50+
51+
# Fix particle 0 at origin
52+
positions[0] = self.initial_position_0
53+
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device)
54+
55+
return positions
56+
57+
58+
def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"):
59+
plt.figure(figsize=(6, 4))
60+
plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x')
61+
plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o')
62+
plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*')
63+
plt.title("Mass-Spring System Positions")
64+
plt.xlabel("X")
65+
plt.ylabel("Y")
66+
plt.legend()
67+
plt.grid(True)
68+
plt.tight_layout()
69+
plt.savefig(save_path)
70+
print(f"Saved visualization to {os.path.abspath(save_path)}")
71+
plt.close()
72+
73+
74+
def train(args):
75+
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76+
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
77+
print(f"Using device: {device}")
78+
system = MassSpringSystem(
79+
num_particles=args.num_particles,
80+
springs=[(0, 1, 1.0, args.stiffness)],
81+
mass=args.mass,
82+
dt=args.dt,
83+
gravity=args.gravity,
84+
device=device,
85+
)
86+
87+
optimizer = optim.Adam(system.parameters(), lr=args.lr)
88+
target_positions = torch.tensor(
89+
[[0.0, 0.0], [1.0, 0.0]], device=device
90+
)
91+
92+
for epoch in range(args.epochs):
93+
optimizer.zero_grad()
94+
final_positions = system(args.steps)
95+
loss = (final_positions - target_positions).pow(2).mean()
96+
loss.backward()
97+
optimizer.step()
98+
99+
if (epoch + 1) % args.log_interval == 0:
100+
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}")
101+
102+
# Visualization
103+
initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy()
104+
visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy())
105+
106+
print("\nTraining completed.")
107+
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}")
108+
print(f"Target positions:\n{target_positions.cpu().numpy()}")
109+
110+
111+
def evaluate(args):
112+
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113+
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
114+
print(f"Using device: {device}")
115+
system = MassSpringSystem(
116+
num_particles=args.num_particles,
117+
springs=[(0, 1, 1.0, args.stiffness)],
118+
mass=args.mass,
119+
dt=args.dt,
120+
gravity=args.gravity,
121+
device=device,
122+
)
123+
124+
with torch.no_grad():
125+
final_positions = system(args.steps)
126+
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}")
127+
128+
129+
def parse_args():
130+
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System")
131+
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs")
132+
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass")
133+
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
134+
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration")
135+
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle")
136+
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant")
137+
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system")
138+
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval")
139+
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs")
140+
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength")
141+
return parser.parse_args()
142+
143+
144+
def main():
145+
args = parse_args()
146+
147+
if args.mode == "train":
148+
train(args)
149+
elif args.mode == "eval":
150+
evaluate(args)
151+
152+
153+
if __name__ == "__main__":
154+
main()
17.3 KB
Loading

differentiable_physics/readme.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Differentiable Physics: Mass-Spring System
2+
3+
This example demonstrates a simple differentiable **mass-spring system** using PyTorch.
4+
5+
A set of particles is connected via springs and evolves over time under the influence of:
6+
- **Spring forces** (via Hooke’s Law)
7+
- **Gravity** (acting in the negative Y-direction)
8+
9+
The system is fully differentiable, enabling **gradient-based optimization** of the **initial positions** of the particles so that their **final positions** match a desired **target configuration**.
10+
11+
This idea is inspired by differentiable simulation frameworks such as those presented in recent research (see reference below).
12+
13+
---
14+
15+
## Files
16+
17+
- `mass_spring.py` — Implements the simulation, training loop, and evaluation logic.
18+
- `README.md` — Description, instructions, and visualization output.
19+
- `mass_spring_viz.png` — Output visualization of the final vs target configuration.
20+
21+
---
22+
23+
## Key Concepts
24+
25+
| Term | Description |
26+
|-------------------|-----------------------------------------------------------------------------|
27+
| Initial Position | Learnable 2D coordinates (x, y) of each particle before simulation begins. |
28+
| Target Position | Desired final 2D position after simulation. Used to compute loss. |
29+
| Gravity | Constant force `[0, -9.8]` pulling particles downward in Y direction. |
30+
| Spring Forces | Modeled using Hooke’s Law. Particles connected by springs exert forces. |
31+
| Dimensionality | All particle positions and forces are 2D vectors. |
32+
33+
---
34+
35+
## Requirements
36+
37+
- Python 3.8+
38+
- PyTorch ≥ 2.0
39+
40+
Install requirements (if needed):
41+
42+
pip install -r requirements.txt
43+
44+
45+
## Usage
46+
47+
First, ensure PyTorch is installed.
48+
49+
#### Train the system
50+
51+
52+
python mass_spring.py --mode train
53+
54+
55+
![Mass-Spring System Visualization](mass_spring_viz.png)
56+
57+
*Mass-Spring System Visualization comparing final vs target positions.*
58+
59+
60+
61+
## References
62+
63+
[1] Sanchez-Gonzalez, A. et al. (2020).
64+
Learning to Simulate Complex Physics with Graph Networks.
65+
arXiv preprint arXiv:2002.09405.
66+
Available: https://arxiv.org/abs/2002.09405
67+
68+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
torch>=2.6
2+
matplotlib
3+

run_python_examples.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ function gat() {
170170
uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
171171
}
172172

173+
function differentiable_physics() {
174+
uv run mass_spring.py --mode train --epochs 5 --steps 3 || error "differentiable_physics example failed"
175+
}
176+
177+
173178
eval "base_$(declare -f stop)"
174179

175180
function stop() {
@@ -223,6 +228,7 @@ function run_all() {
223228
run fx
224229
run gcn
225230
run gat
231+
run differentiable_physics
226232
}
227233

228234
# by default, run all examples

0 commit comments

Comments
 (0)