In [4]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
from torchdiffeq import odeint_adjoint as odeint

In [11]:
class ODEFunc(nn.Module):
    """Learned vector field for the ODE ``dh/dt = f(t, h)``.

    A two-layer MLP (``dim -> 50 -> dim`` with a ReLU in between) maps the
    current state ``h`` to its time derivative. The field is autonomous:
    ``t`` is accepted because the solver passes it, but it is not used.

    Attributes:
        net: the MLP computing ``dh/dt`` from ``h``.
        nfe: running count of forward calls (solver function evaluations);
            it is never reset by this class.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [nn.Linear(dim, 50), nn.ReLU(), nn.Linear(50, dim)]
        self.net = nn.Sequential(*layers)
        self.nfe = 0

    def forward(self, t, h):
        self.nfe = self.nfe + 1
        dh_dt = self.net(h)
        return dh_dt

In [12]:
class ODEBlock(nn.Module):
    """Integrate ``odefunc`` over ``integration_time`` and return the final state.

    Args:
        odefunc: callable module evaluated as ``odefunc(t, h)`` returning
            ``dh/dt`` (torchdiffeq's adjoint solver expects an ``nn.Module``).
        integration_time: 1-D sequence/tensor of time points to integrate
            over. Defaults to ``[0.0, 1.0]``. Only the solution at the last
            time point is returned by ``forward``.

    Raises:
        ValueError: if ``integration_time`` is not 1-D.
    """

    def __init__(self, odefunc, integration_time=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        if integration_time is None:
            # Build the default inside __init__ so each instance owns its own
            # tensor (a tensor default in the signature would be shared by
            # every ODEBlock ever created).
            integration_time = [0.0, 1.0]
        times = torch.as_tensor(integration_time)
        if times.dim() != 1:
            raise ValueError(
                f"integration_time must be 1-D, got shape {tuple(times.shape)}"
            )
        # Non-persistent buffer: follows .to()/.cuda() with the module but adds
        # no key to state_dict, so existing checkpoints still load.
        self.register_buffer("integration_time", times, persistent=False)

    def forward(self, h0):
        # Match dtype/device of the state so the solver sees consistent tensors.
        t = self.integration_time.type_as(h0)
        # odeint returns the solution at every time point; keep the final state.
        h_T = odeint(self.odefunc, h0, t)[-1]
        return h_T

In [13]:
class NeuralODE(nn.Module):
    """Classifier built around a continuous-depth hidden layer.

    Pipeline: ``Linear(input_dim -> hidden_dim)`` + ReLU produces the initial
    state ``h0``; an ``ODEBlock`` wrapping an ``ODEFunc`` evolves it to ``h_T``;
    ``Linear(hidden_dim -> num_classes)`` turns ``h_T`` into unnormalised logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # Construction order is kept (encoder, ODE, decoder) so parameter
        # ordering and seeded weight initialisation stay the same.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.odeblock = ODEBlock(ODEFunc(hidden_dim))
        self.fc_out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        encoded = self.fc_in(x)
        initial_state = self.relu(encoded)
        final_state = self.odeblock(initial_state)  # solver step (adjoint backprop)
        return self.fc_out(final_state)

In [16]:
# Configuration for this toy run.
SEED = 0
N_SAMPLES, INPUT_DIM, HIDDEN_DIM, NUM_CLASSES = 128, 2, 32, 2
N_EPOCHS, LR, LOG_EVERY = 100, 1e-3, 10

# Seed before building the model (weight init) and sampling data so that
# Restart & Run All reproduces the logged losses.
torch.manual_seed(SEED)

model = NeuralODE(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Synthetic data: Gaussian inputs, with labels drawn independently of `x`.
# There is no real signal, so any loss decrease is memorisation of this batch.
x = torch.randn(N_SAMPLES, INPUT_DIM)
y = torch.randint(0, NUM_CLASSES, (N_SAMPLES,))

# The evaluation cell switches to eval mode; make training mode explicit so
# re-running this cell after it is well-defined.
model.train()

# Simple full-batch training loop.
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    pred = model(x)
    loss = criterion(pred, y)
    loss.backward()
    optimizer.step()
    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")


Epoch 0, Loss: 0.6934
Epoch 10, Loss: 0.6779
Epoch 20, Loss: 0.6699
Epoch 30, Loss: 0.6605
Epoch 40, Loss: 0.6466
Epoch 50, Loss: 0.6272
Epoch 60, Loss: 0.6017
Epoch 70, Loss: 0.5683
Epoch 80, Loss: 0.5327
Epoch 90, Loss: 0.4913


In [20]:
# NOTE: `x` and `y` are the exact batch the model was trained on (and `y` was
# drawn independently of `x`), so this measures training accuracy /
# memorisation, not generalisation to unseen data.
model.eval()  # standard toggle; this model has no dropout or batch-norm layers
with torch.no_grad():
    pred = model(x)                               # raw logits
    predicted_labels = pred.argmax(dim=1)         # class with the highest logit
    accuracy = (predicted_labels == y).float().mean().item()  # fraction correct
print(f"Training accuracy: {accuracy:.2%}")

Accuracy: 75.00%
