<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Neural_Differential_Equations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torchdiffeq

In [None]:
import torch
import torch.nn as nn
import torchdiffeq

# Define the ODE function
class ODEFunc(nn.Module):
    """Learned vector field dy/dt = f(t, y): a two-layer MLP with a ReLU.

    Args:
        dim: Size of the state vector y (input and output width). Default 2.
        hidden_dim: Width of the hidden layer. Default 50.
    """

    def __init__(self, dim=2, hidden_dim=50):
        super().__init__()
        # Layer creation order (fc1 then fc2) and attribute names are kept as
        # before so default initialisation consumes the RNG identically and
        # state_dict keys stay "fc1.*" / "fc2.*".
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, t, y):
        """Return the derivative dy/dt evaluated at state ``y``.

        Args:
            t: Current time. Unused (the field does not depend on t) but
                required by the ``func(t, y)`` signature that
                ``torchdiffeq.odeint`` calls.
            y: State tensor whose last dimension is ``dim``.

        Returns:
            Tensor with the same shape as ``y``: the derivative of y.
        """
        return self.fc2(self.relu(self.fc1(y)))

# Define the ODE Block
class ODEBlock(nn.Module):
    """Integrate a vector field from ``t0`` to ``t1`` and return the end state.

    Args:
        odefunc: Callable ``f(t, y)`` returning dy/dt; passed directly to
            ``torchdiffeq.odeint``.
        t0: Start of the integration interval. Default 0.0.
        t1: End of the integration interval. Default 1.0.
        **odeint_kwargs: Extra keyword arguments forwarded unchanged to
            ``torchdiffeq.odeint`` (e.g. ``rtol``, ``atol``, ``method``).
    """

    def __init__(self, odefunc, t0=0.0, t1=1.0, **odeint_kwargs):
        super().__init__()
        self.odefunc = odefunc
        # Stored as plain floats (not buffers) so state_dict keys are unchanged.
        self.t0 = float(t0)
        self.t1 = float(t1)
        self.odeint_kwargs = odeint_kwargs

    def forward(self, x):
        """Solve the ODE with initial state ``x`` and return the state at ``t1``.

        Args:
            x: Initial state (batch of vectors); must be a floating-point tensor.

        Returns:
            The solution at time ``t1``, with the same shape as ``x``.
        """
        # Build the time grid on x's device and in x's dtype instead of a fixed
        # CPU float32 tensor, so CUDA / float64 states get a matching grid
        # rather than relying on torchdiffeq to coerce it.
        t = torch.tensor([self.t0, self.t1], dtype=x.dtype, device=x.device)
        # odeint stacks the solution at every time in t along dim 0;
        # index -1 selects the state at t1.
        return torchdiffeq.odeint(self.odefunc, x, t, **self.odeint_kwargs)[-1]

# Example usage
if __name__ == "__main__":
    # Wire the learned vector field into a block that integrates it over time.
    vector_field = ODEFunc()
    block = ODEBlock(vector_field)

    # One initial state of width 2, i.e. a tensor of shape (1, 2).
    initial_state = torch.tensor([[2.0, 0.0]], dtype=torch.float32)

    # Integrate and report the state at the final time point.
    print("Output:", block(initial_state))