In [None]:
import torch

class HamiltonianSystem(torch.nn.Module):
    """Hamiltonian system whose equations of motion come from autograd.

    The default Hamiltonian is a harmonic oscillator,
    H(q, p) = p**2 / (2m) + k * q**2 / 2. Hamilton's equations are obtained
    by differentiating ``hamiltonian`` with ``torch.autograd``, so a
    subclass only needs to override ``hamiltonian`` to model another system.

    Args:
        k: Spring constant used by the default Hamiltonian (default 1.0).
        m: Mass used by the default Hamiltonian (default 1.0).
    """

    def __init__(self, k=1.0, m=1.0):
        super().__init__()
        self.k = k  # spring constant
        self.m = m  # mass

    def hamiltonian(self, q, p):
        """Return H(q, p) = p**2 / (2m) + k * q**2 / 2.

        The expression is element-wise, so ``q`` and ``p`` may be 0-d
        tensors or same-shaped (broadcastable) tensors.
        """
        return p**2 / (2 * self.m) + self.k * q**2 / 2

    def dynamics(self, q, p):
        """Return ``(dq/dt, dp/dt)`` from Hamilton's equations.

        dq/dt = ∂H/∂p and dp/dt = -∂H/∂q, computed with autograd using
        ``create_graph=True`` so the results can themselves be
        differentiated (e.g. for higher-order derivatives).

        Note: ``q`` and ``p`` are flagged in place with
        ``requires_grad_(True)``.

        Args:
            q: Position tensor (0-d or batched).
            p: Momentum tensor, same shape as ``q``.

        Returns:
            Tuple ``(dq_dt, dp_dt)``; ``dq_dt`` is shaped like ``p`` and
            ``dp_dt`` like ``q``.
        """
        q.requires_grad_(True)
        p.requires_grad_(True)

        H = self.hamiltonian(q, p)

        # autograd.grad needs a scalar output. When each element of H only
        # depends on the matching elements of q and p (true for the default
        # Hamiltonian), the gradient of H.sum() is exactly the per-element
        # partials, so batched q/p work; for 0-d H, sum() is a no-op.
        dH_dq, dH_dp = torch.autograd.grad(H.sum(), (q, p), create_graph=True)

        dq_dt = dH_dp  # ∂H/∂p
        dp_dt = -dH_dq  # -∂H/∂q
        return dq_dt, dp_dt

    def trajectory(self, q0, p0, T=10, dt=0.1):
        """Integrate the system over a time span ``T`` with explicit Euler.

        Takes ``floor(T / dt)`` steps, with a tiny tolerance so that float
        round-off (e.g. ``0.3 / 0.1 == 2.9999999999999996``) does not drop
        the final step. Explicit Euler is not symplectic, so H drifts; for
        the default oscillator it grows by a factor (1 + dt**2 * k / m) per
        step when k == m == 1.

        Args:
            q0: Initial position (not modified).
            p0: Initial momentum (not modified).
            T: Total integration time.
            dt: Step size.

        Returns:
            List of detached ``(q, p)`` tensor pairs, one per step. The
            initial state is not included.
        """
        n_steps = int(T / dt + 1e-9)

        # Only detached snapshots are returned, so keep the working state
        # off q0/p0's autograd graph; otherwise every step would be chained
        # onto one ever-growing graph for no benefit.
        q, p = q0.detach().clone(), p0.detach().clone()
        states = []

        for _ in range(n_steps):
            dq_dt, dp_dt = self.dynamics(q, p)

            # Euler step; detach so the next step starts from a fresh leaf.
            q = (q + dq_dt * dt).detach()
            p = (p + dp_dt * dt).detach()

            # Clone because dynamics() flips requires_grad on q/p in place
            # during the next step; the stored snapshot must stay untouched.
            states.append((q.clone(), p.clone()))

        return states


# Example usage
import matplotlib.pyplot as plt

torch.manual_seed(42)  # nothing below is random; kept for reproducibility

# Initial conditions (position and momentum) as leaf tensors so the
# dynamics can be differentiated with respect to them.
q0 = torch.tensor(1.0, requires_grad=True)
p0 = torch.tensor(0.0, requires_grad=True)

# Instantiate the system
hamiltonian_system = HamiltonianSystem()

# First derivatives from Hamilton's equations (velocity and force).
dq_dt, dp_dt = hamiltonian_system.dynamics(q0, p0)

# Acceleration d²q/dt² via the chain rule along the flow:
#   d²q/dt² = ∂(dq/dt)/∂q · dq/dt + ∂(dq/dt)/∂p · dp/dt
# Differentiating dq/dt w.r.t. q0 alone gives only the first partial, which
# autograd reports as None here because dq/dt = p/m does not depend on q.
dv_dq, dv_dp = torch.autograd.grad(
    dq_dt, (q0, p0), create_graph=True, allow_unused=True
)
dv_dq = torch.zeros_like(q0) if dv_dq is None else dv_dq
dv_dp = torch.zeros_like(p0) if dv_dp is None else dv_dp
d2q_dt2 = dv_dq * dq_dt + dv_dp * dp_dt

# Simulate trajectory
trajectory = hamiltonian_system.trajectory(q0, p0, T=5, dt=0.1)

# Print results
print("dq/dt (velocity):", dq_dt.item())
print("dp/dt (force):", dp_dt.item())
print("d²q/dt² (acceleration):", d2q_dt2.item())

# Plot the trajectory: one point per Euler step (initial state not included).
q_values = [q.item() for q, _ in trajectory]
p_values = [p.item() for _, p in trajectory]

plt.plot(range(len(q_values)), q_values, label="q (position)")
plt.plot(range(len(p_values)), p_values, label="p (momentum)")
plt.xlabel("Time step")
plt.ylabel("Value")
plt.legend()
plt.show()