In [None]:
import sys
sys.path.append('../')

In [None]:

from Scripts.pinn_model import TrafficFlowForLWR_PINN, TrafficFlowForARZ_PINN
from Scripts.physics import lwr_pde_residual, physics_loss_calculator
import torch
import matplotlib.pyplot as plt

In [None]:
# Collocation grid shared by both training cells below: N_X points on x in
# [0, X_MAX] and N_T points on t in [0, T_MAX].
# NOTE(review): units are not stated anywhere here; presumably metres and
# seconds (a 10 km road over 1 h) -- confirm.
X_MAX, N_X = 10000, 100
T_MAX, N_T = 3600, 50

x_range = torch.linspace(0, X_MAX, N_X)
t_range = torch.linspace(0, T_MAX, N_T)
# indexing='ij' puts x on axis 0 and t on axis 1, so after flattening the
# t coordinate varies fastest (first N_T rows all have x == 0).
X, T = torch.meshgrid(x_range, t_range, indexing='ij')
# Column vectors of shape (N_X * N_T, 1), one row per collocation point.
x = X.reshape(-1, 1)
t = T.reshape(-1, 1)
# NOTE(review): x and t are created with requires_grad=False. If
# lwr_pde_residual / physics_loss_calculator take derivatives w.r.t. x and t
# via torch.autograd.grad, they must enable grad on their inputs themselves
# -- confirm in Scripts/physics.py.
# NOTE(review): coordinates go up to 1e4 / 3.6e3 unscaled; unless the PINN
# models normalise inputs internally, consider scaling x and t to [0, 1].


def ground_truth_rho(x, t):
    """Synthetic density target: sin(x / 1000) * exp(-t / 3600).

    Works elementwise on tensors of any matching shape. Over x in [0, 10000]
    the sine argument spans [0, 10] rad, so this field takes negative values:
    it is a manufactured test signal, not a physically admissible density.
    """
    return torch.sin(x / 1000) * torch.exp(-t / 3600)


def ground_truth_u(x, t):
    """Synthetic velocity target: cos(x / 1000) * exp(-t / 3600).

    Works elementwise on tensors of any matching shape.
    """
    return torch.cos(x / 1000) * torch.exp(-t / 3600)

In [None]:
""" ### LWR Model Approximation using PINNs ###"""

# Hyperparameters for the LWR run.
LWR_SEED = 0
LWR_EPOCHS = 1000
LWR_LOG_EVERY = 100
LWR_LR = 1e-3

# Seed immediately before building the model so that re-running this cell on
# its own reproduces the same initial weights and loss curve.
torch.manual_seed(LWR_SEED)

# Single-output network; its output is fit to ground_truth_rho below.
model = TrafficFlowForLWR_PINN(hidden_layers=8, neurons_per_layer=20, outputs=1)
optimizer = torch.optim.Adam(model.parameters(), lr=LWR_LR)

# Relative weights of the data-fit and PDE-residual terms.
alpha = 1
beta = 1

# x and t are fixed for the whole run, so the target is evaluated once
# instead of on every epoch.
rho_target = ground_truth_rho(x, t)

for epoch in range(LWR_EPOCHS):
    optimizer.zero_grad()

    # Data term: mean squared error against the synthetic density.
    data_loss = torch.mean((model(x, t) - rho_target) ** 2)
    # Physics term: mean squared LWR residual at the same collocation points.
    physics_loss = torch.mean(lwr_pde_residual(model, x, t) ** 2)

    # TODO: add boundary- and initial-condition loss terms.
    total_loss = alpha * data_loss + beta * physics_loss

    total_loss.backward()
    optimizer.step()

    # Log periodically and on the last epoch so the final loss is always shown.
    if epoch % LWR_LOG_EVERY == 0 or epoch == LWR_EPOCHS - 1:
        print(f"Epoch {epoch}, Loss: {total_loss.item():.6f}")


In [None]:
# Hyperparameters for the ARZ run.
ARZ_SEED = 0
ARZ_EPOCHS = 1000
ARZ_LOG_EVERY = 100
ARZ_LR = 1e-3

# Seed immediately before building the model so that re-running this cell on
# its own reproduces the same initial weights and loss curve.
torch.manual_seed(ARZ_SEED)

# Initialize ARZ model: column 0 of its output is fit to rho, column 1 to u.
model = TrafficFlowForARZ_PINN(hidden_layers=8, neurons_per_layer=20)
optimizer = torch.optim.Adam(model.parameters(), lr=ARZ_LR)

# Loss weights, mirroring the LWR cell (unrelated to the beta1/beta2
# coefficients passed to physics_loss_calculator).
alpha = 1
beta = 1

# Targets depend only on the fixed grid, so compute them once. reshape keeps
# them as (N, 1) columns matching the per-output prediction slices.
rho_true = ground_truth_rho(x, t).reshape(-1, 1)
u_true = ground_truth_u(x, t).reshape(-1, 1)

for epoch in range(ARZ_EPOCHS):
    optimizer.zero_grad()

    # Get predictions
    pred = model(x, t)  # shape (N, 2)
    rho_pred, u_pred = pred[:, 0:1], pred[:, 1:2]  # each (N, 1)

    # Data loss: per-field mean squared error.
    data_loss_rho = torch.mean((rho_pred - rho_true) ** 2)
    data_loss_u = torch.mean((u_pred - u_true) ** 2)

    # Physics loss.
    # NOTE(review): the meaning of beta1/beta2/gamma1/gamma2/tau is defined in
    # Scripts.physics and not visible here -- confirm before tuning.
    physics_loss = physics_loss_calculator(x, t, model, beta1=1.0, beta2=1.0,
                                           gamma1=1.0, gamma2=1.0, tau=0.02)

    total_loss = alpha * (data_loss_rho + data_loss_u) + beta * physics_loss

    total_loss.backward()
    optimizer.step()

    # Log periodically and on the last epoch so the final loss is always shown.
    if epoch % ARZ_LOG_EVERY == 0 or epoch == ARZ_EPOCHS - 1:
        print(f"Epoch {epoch}, Loss: {total_loss.item():.6f}")
