In [None]:
import torch
import torch.nn as nn

# MLP paired with the linear state-space model of a mass-spring-damper,
# m*x'' + c*x' + k*x = u. With state [position, velocity]:
#     d/dt [x, v] = A [x, v] + B u
# The system is underdamped when c**2 < 4*k*m.
class PhysicsInformedNN(nn.Module):
    """Three-layer MLP plus the linear dynamics its predictions should respect.

    Args:
        input_dim: Width of the network input.
        hidden_dim: Width of both hidden layers.
        output_dim: Width of the network output.
        m: Mass (divides ``k`` and ``c``, so it must be non-zero).
        c: Damping coefficient.
        k: Spring constant.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, m, c, k):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

        # Buffers (not plain attributes) so that ``model.to(device)`` and
        # dtype casts move them with the module; non-persistent so the
        # ``state_dict`` keys stay exactly the parameter keys.
        self.register_buffer(
            "A",
            torch.tensor([[0, 1], [-k / m, -c / m]], dtype=torch.float32),
            persistent=False,
        )
        # Input matrix for an external force u: v' = -(k/m) x - (c/m) v + u/m.
        # Not used by the methods below, which model the free response (u = 0).
        self.register_buffer(
            "B",
            torch.tensor([[0], [1 / m]], dtype=torch.float32),
            persistent=False,
        )

    def forward(self, x):
        """Apply the MLP: two ReLU hidden layers followed by a linear output."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

    def state_derivative(self, x):
        """Return dx/dt = A x for states ``x`` of shape (..., 2).

        Works for a single state (shape (2,)) or any batch of states. ``x`` and
        ``A`` are promoted to a common dtype and ``A`` is moved to ``x``'s device.
        """
        dtype = torch.promote_types(self.A.dtype, x.dtype)
        A = self.A.to(device=x.device, dtype=dtype)
        # Row-vector form of A @ x, valid for arbitrary leading batch dims.
        return x.to(dtype) @ A.T

    def physics_loss(self, x, dx_pred=None):
        """Return the physics term for states ``x`` ([position, velocity]).

        Args:
            x: States with shape (..., 2).
            dx_pred: Optional predicted derivative with the same shape as
                ``x`` (for example ``self(x)``).

        Returns:
            If ``dx_pred`` is None, the true derivative ``A x`` with the shape
            of ``x``. This path is kept for backward compatibility. Otherwise,
            the scalar mean-squared residual ``mean((dx_pred - A x) ** 2)``,
            which is non-negative and back-propagates into ``dx_pred``.
        """
        dx = self.state_derivative(x)
        if dx_pred is None:
            return dx
        return torch.mean((dx_pred - dx) ** 2)

# Example usage
m = 1.0  # Mass
c = 0.2  # Damping coefficient
k = 1.0  # Spring constant

model = PhysicsInformedNN(input_dim=2, hidden_dim=50, output_dim=2, m=m, c=c, k=k)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Dummy batch of 10 states [position, velocity] (free response, zero input)
x = torch.randn(10, 2)

# Forward pass: the network is used as a surrogate for the state derivative
output = model(x)

# Physics-based loss: mean-squared residual between the network output and the
# true linear dynamics A x (what physics_loss(x) returns). Squaring keeps the
# loss non-negative, and using `output` makes it depend on the parameters.
loss = torch.mean((output - model.physics_loss(x)) ** 2)

# One optimisation step, so the optimizer above actually trains the model
optimizer.zero_grad()
loss.backward()
optimizer.step()

loss
