In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchdiffeq import odeint_adjoint as odeint 

In [2]:
class ODEFunc(nn.Module):
    """Vector field ``dh/dt = f(h)`` computed by a one-hidden-layer MLP.

    The network maps a state whose last dimension is ``dim`` to a derivative
    of the same size. The time argument is accepted so the module matches the
    ``func(t, h)`` calling convention used by the solver, but it is ignored.

    Args:
        dim: Size of the last dimension of the state ``h``.
        hidden_dim: Width of the hidden layer. Defaults to 50, the value that
            used to be hard-coded.

    Attributes:
        net: The ``Linear -> ReLU -> Linear`` stack producing the derivative.
        nfe: Running count of ``forward`` calls. Nothing in this class resets
            it, so it keeps growing across solves unless the caller zeroes it.
    """

    def __init__(self, dim, hidden_dim=50):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )
        self.nfe = 0  # optional counter for function evaluations

    def forward(self, t, h):
        """Return the derivative of ``h``; ``t`` is unused."""
        self.nfe += 1
        return self.net(h)

In [3]:
class ODEBlock(nn.Module):
    """Integrate ``odefunc`` over a time grid and return the last solution.

    Args:
        odefunc: Module called by the solver as ``odefunc(t, h)`` to obtain
            the derivative of the state.
        integration_time: Time points passed to ``odeint``. May be a tensor
            or anything ``torch.as_tensor`` accepts (e.g. a list of floats).
            ``None`` (the default) means ``[0.0, 1.0]``.
    """

    def __init__(self, odefunc, integration_time=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        if integration_time is None:
            # Built here rather than in the signature: a default written in the
            # signature is evaluated once and then shared by every instance.
            integration_time = torch.tensor([0.0, 1.0])
        # as_tensor leaves an existing tensor untouched and converts plain
        # sequences, so forward() can always call tensor methods on it.
        self.integration_time = torch.as_tensor(integration_time)

    def forward(self, h0):
        """Solve the ODE from initial state ``h0`` and return the final state."""
        # Ensure the integration time tensor is of the same type as h0.
        t = self.integration_time.type_as(h0)
        # Solve the ODE and use the final state as the block output.
        h_T = odeint(self.odefunc, h0, t)[-1]
        return h_T

In [4]:
class NeuralODE(nn.Module):
    """Classifier built from a linear embedding, an ODE block and a linear head.

    Args:
        input_dim: Number of features in each input vector.
        hidden_dim: Size of the state that flows through the ODE block.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # Sub-modules are created in the order the data flows through them.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        dynamics = ODEFunc(hidden_dim)
        self.odeblock = ODEBlock(dynamics)
        self.fc_out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return the output of ``fc_out`` for a batch ``x`` (no softmax applied)."""
        initial_state = self.relu(self.fc_in(x))
        final_state = self.odeblock(initial_state)
        return self.fc_out(final_state)

In [5]:
# Settings for this experiment, gathered in one place so they are easy to tune.
SEED = 0
DATA_ROOT = './data'
BATCH_SIZE = 128
INPUT_DIM = 28 * 28  # one flattened 28x28 MNIST image
HIDDEN_DIM = 64
NUM_CLASSES = 10
LEARNING_RATE = 1e-3

# Seed PyTorch's global RNG before the model and loaders are created so weight
# initialisation and batch shuffling start from the same state on every run.
# (Some GPU kernels may still be nondeterministic.)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transform to convert MNIST images to flattened vectors.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # flatten the 28x28 image into a 784-dim vector
])

# Load the MNIST dataset (downloaded into DATA_ROOT on first use).
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
# Only the training data is shuffled; the test order does not affect accuracy.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Instantiate the model, optimizer and loss.
model = NeuralODE(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [00:01<00:00, 9.75MB/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 368kB/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:00<00:00, 3.30MB/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 1.05MB/s]


In [6]:
epochs = 5
for epoch in range(epochs):
    # ---- Training pass over the whole training set ----
    model.train()
    total_loss = 0.0   # sum of per-sample losses seen this epoch
    num_samples = 0    # number of training samples seen this epoch
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # `loss` is a per-batch mean (CrossEntropyLoss default reduction), so
        # weight it by the batch size. Averaging the batch means directly
        # would over-weight a smaller final batch.
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    # Per-sample average, consistent with the per-sample accuracy below.
    avg_loss = total_loss / num_samples
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # Evaluate accuracy on the test set.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            preds = output.argmax(dim=1)  # predicted class = index of the largest score
            correct += (preds == target).sum().item()
            total += target.size(0)
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

Epoch 1/5, Loss: 0.3991
Test Accuracy: 94.12%
Epoch 2/5, Loss: 0.1615
Test Accuracy: 95.57%
Epoch 3/5, Loss: 0.1163
Test Accuracy: 96.53%
Epoch 4/5, Loss: 0.0933
Test Accuracy: 97.04%
Epoch 5/5, Loss: 0.0752
Test Accuracy: 96.50%
