In [1]:
import torch
from torch import nn, optim

# Value the system output is driven towards; loss_fn penalises the squared
# distance from it.
Y_TARGET: float = 5.0
# Number of system steps simulated in each training epoch.
SEQ_LEN: int = 1000
# Width of the model input: two past outputs, two past inputs and the latest
# auxiliary state are stacked into one feature vector.
INPUT_DIM: int = 5
# Width of the two hidden layers of DenseModel.
HIDDEN_DIM: int = 64
# Learning rate handed to the Adam optimizer.
LR: float = 1e-3
# Number of training epochs run by the script entry point.
EPOCHS: int = 500


def loss_fn(y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise squared error between ``y`` and ``Y_TARGET``."""
    deviation = y - Y_TARGET
    return torch.square(deviation)


class RecursiveSystem:
    """Second-order recursive system that carries its own state.

    State kept between calls to :meth:`step`:

    * ``y_values`` -- the two most recent outputs, ordered ``[older, newer]``.
    * ``x_values`` -- the two most recent inputs, ordered ``[older, newer]``.
    * ``z_values`` -- every value of the auxiliary state seen since the last
      reset; :meth:`step` only reads the final entry.

    All state entries are 0-dim tensors living on ``device``.
    """

    def __init__(self, device: torch.device) -> None:
        """Create the system on ``device`` with an all-zero state."""
        self.device = device
        self.y_values: list[torch.Tensor]
        self.x_values: list[torch.Tensor]
        self.z_values: list[torch.Tensor]
        # Single source of truth for how the state is laid out.
        self.reset_state(0.0, 0.0, 0.0, 0.0, 0.0)

    def _scalar(self, value: float) -> torch.Tensor:
        """Wrap ``value`` as a 0-dim floating-point tensor on the device.

        The value is passed through ``float`` first so that integer arguments
        (e.g. ``0``) do not produce an integer tensor, which could not be fed
        to the floating-point model together with the rest of the state.
        """
        return torch.tensor(float(value), device=self.device)

    def reset_state(
        self, y0: float, y1: float, x0: float, x1: float, z0: float
    ) -> None:
        """Discard the current state and start again from the given values.

        Args:
            y0: Older of the two stored outputs.
            y1: Newer of the two stored outputs.
            x0: Older of the two stored inputs.
            x1: Newer of the two stored inputs.
            z0: Initial auxiliary state; the ``z_values`` history restarts
                with this single entry.
        """
        self.y_values = [self._scalar(y0), self._scalar(y1)]
        self.x_values = [self._scalar(x0), self._scalar(x1)]
        self.z_values = [self._scalar(z0)]

    def step(self, x: torch.Tensor) -> torch.Tensor:
        """Advance the system by one step and return the new output.

        The new output and auxiliary state are computed from the *stored*
        values only::

            z = z1 + 2.0 * x1 + 0.11
            y = y1 + 0.01 * y2 + 8.0 * x1 - 0.3 * x2 + 0.1 * z1

        where ``y1``/``x1`` are the newer and ``y2``/``x2`` the older stored
        entries and ``z1`` is the latest auxiliary state.  The argument ``x``
        is only recorded as the newest input, so it first influences the
        output of the following call.

        Args:
            x: Input for this step; stored as the newest entry of
                ``x_values``.

        Returns:
            The new output ``y`` after clamping to ``[-1000, 1000]``.
        """
        y2, y1 = self.y_values
        x2, x1 = self.x_values
        z1 = self.z_values[-1]

        z = z1 + 2.0 * x1 + 0.11
        y = y1 + 0.01 * y2 + 8.0 * x1 - 0.3 * x2 + 0.1 * z1

        # Clamp to avoid runaway values
        y = torch.clamp(y, -1000.0, 1000.0)
        z = torch.clamp(z, -1000.0, 1000.0)

        # Slide the two-element windows forward; z keeps its full history.
        self.y_values = [y1, y]
        self.x_values = [x1, x]
        self.z_values.append(z)

        return y


class DenseModel(nn.Module):
    """Two-hidden-layer perceptron whose output is squashed by ``tanh``.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh``.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int = 1
    ) -> None:
        super().__init__()
        # Trainable layers, created in the order they are applied.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        # Non-linearities: ReLU between layers, tanh on the way out so the
        # prediction stays within [-1, 1].
        self.activation = nn.ReLU()
        self.output_activation = nn.Tanh()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to a tanh-bounded prediction."""
        features = x
        for layer in (self.input_layer, self.hidden_layer):
            features = self.activation(layer(features))
        logits = self.output_layer(features)
        return self.output_activation(logits)


def train_one_epoch(
    model: nn.Module,
    system: RecursiveSystem,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
) -> float:
    """Roll the system out for ``SEQ_LEN`` steps and take one optimizer step.

    The system is reset to an all-zero state, then at every step the model
    sees the current system state and its prediction is fed back into the
    system.  Per-step losses are summed, averaged over the rollout, and a
    single backward pass is run on that average, so gradients flow through
    the whole rollout.

    Args:
        model: Network mapping the 5-element state vector to one input value.
        system: Simulated system; its state is reset by this call.
        optimizer: Optimizer holding the parameters of ``model``.
        device: Device on which the loss accumulator is created.
        epoch: 1-based epoch index; the first epoch prints debug lines for
            its first five steps.

    Returns:
        The mean loss over the rollout as a Python float.
    """
    system.reset_state(0.0, 0.0, 0.0, 0.0, 0.0)
    running_loss = torch.tensor(0.0, device=device)
    show_debug = epoch == 1

    for step in range(SEQ_LEN):
        # Assemble the model input from the current system state.
        older_y, newer_y = system.y_values
        older_x, newer_x = system.x_values
        latest_z = system.z_values[-1]
        features = torch.stack([older_y, newer_y, older_x, newer_x, latest_z])

        # Add a batch axis for the model, then reduce the result to 0-dim.
        control = model(features.unsqueeze(0)).squeeze()

        output = system.step(control)
        step_loss = loss_fn(output)
        running_loss = running_loss + step_loss

        # Optional debug output for the first few steps
        if show_debug and step < 5:
            print(
                f"[Step {step}] x_pred: {control.item():.4f}, "
                f"y: {output.item():.4f}, loss: {step_loss.item():.4f}"
            )

    mean_loss = running_loss / SEQ_LEN

    optimizer.zero_grad()
    mean_loss.backward()

    # Keep the update bounded: clip the global gradient norm before stepping.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    return mean_loss.item()


if __name__ == "__main__":
    # Script entry point: build the model, optimizer and system on the CPU,
    # then train for EPOCHS epochs.
    device = torch.device("cpu")
    model = DenseModel(INPUT_DIM, HIDDEN_DIM).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    system = RecursiveSystem(device=device)

    for epoch in range(1, EPOCHS + 1):
        avg_loss = train_one_epoch(model, system, optimizer, device, epoch)

        # Report only on the first epoch and on every 50th one.
        if epoch != 1 and epoch % 50 != 0:
            continue
        print(f"Epoch {epoch}/{EPOCHS}, Loss: {avg_loss:.4f}")


[Step 0] x_pred: -0.0063, y: 0.0000, loss: 25.0000
[Step 1] x_pred: -0.0068, y: -0.0396, loss: 25.3971
[Step 2] x_pred: -0.0064, y: -0.0715, loss: 25.7201
[Step 3] x_pred: -0.0062, y: -0.0908, loss: 25.9159
[Step 4] x_pred: -0.0065, y: -0.0990, loss: 26.0002
Epoch 1/500, Loss: 929833.8750
Epoch 50/500, Loss: 0.3279
Epoch 100/500, Loss: 0.3046
Epoch 150/500, Loss: 0.1056
Epoch 200/500, Loss: 0.0359
Epoch 250/500, Loss: 0.0337
Epoch 300/500, Loss: 0.0826
Epoch 350/500, Loss: 0.0658
Epoch 400/500, Loss: 0.0371
Epoch 450/500, Loss: 0.0407
Epoch 500/500, Loss: 0.0299
