In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchdiffeq import odeint_adjoint as odeint  # For ODE-Net (adjoint-based)
from torchdiffeq import odeint as odeint_rk        # For RK-Net (direct backprop)

In [2]:
# -----------------------------
# Common Building Blocks
# -----------------------------

# Standard residual block.
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity skip connection.

    Stride 1 and padding 1 keep both the channel count and the spatial size
    unchanged, so the input can be added straight onto the branch output
    before the final ReLU.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **same_size)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **same_size)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        # Skip connection: add the untouched input, then activate.
        branch += x
        return self.relu(branch)

# Downsampling block: convolution with stride 2.
class Downsample(nn.Module):
    """Halve the spatial resolution while changing the channel count.

    A stride-2 3x3 convolution (padding 1) followed by BatchNorm and ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.relu(features)

In [3]:
# -----------------------------
# ODE-Based Blocks
# -----------------------------

# ODE function used for both ODE-Net and RK-Net.
class ODEFunc(nn.Module):
    """Right-hand side f(t, x) of the ODE dx/dt = f(t, x).

    Applies conv -> BN -> ReLU -> conv -> BN to the state. Channel count and
    spatial size are preserved, so the derivative has the same shape as x.
    The time argument ``t`` is accepted to match the solver's calling
    convention but is ignored, i.e. the dynamics do not depend on time.

    NOTE(review): the solver calls this module several times per forward pass;
    in train mode each call updates the BatchNorm running statistics.

    Args:
        channels: Number of feature channels in the ODE state.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, t, x):
        del t  # unused: the dynamics are time-invariant
        hidden = self.conv1(x)
        hidden = self.relu(self.bn1(hidden))
        derivative = self.bn2(self.conv2(hidden))
        return derivative

# ODEBlock uses an adjoint-based ODE solver (via odeint).
class ODEBlock(nn.Module):
    """Integrate ``odefunc`` from t=0 to t=1 and return the final state.

    Uses ``odeint`` (imported as torchdiffeq's ``odeint_adjoint``), so
    gradients are computed with the adjoint method rather than by
    backpropagating through every solver step.

    Args:
        odefunc: nn.Module whose ``forward(t, x)`` returns dx/dt.
        rtol: Optional relative tolerance for adaptive solvers.
        atol: Optional absolute tolerance for adaptive solvers.
        method: Optional torchdiffeq solver name (e.g. ``"dopri5"``, ``"rk4"``).
        options: Optional solver-specific options dict
            (e.g. ``{"step_size": 0.25}`` for fixed-step methods).

    Any argument left as ``None`` is not passed to the solver, so torchdiffeq's
    own defaults apply. Those defaults are very tight. Loosening ``rtol``/``atol``
    (e.g. to 1e-3) typically cuts the number of function evaluations
    substantially.
    """

    def __init__(self, odefunc, rtol=None, atol=None, method=None, options=None):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0.0, 1.0]).float()
        requested = {"rtol": rtol, "atol": atol, "method": method, "options": options}
        # Forward only the settings the caller actually supplied.
        self.solver_kwargs = {k: v for k, v in requested.items() if v is not None}

    def forward(self, x):
        # integration_time is a plain attribute (not a registered buffer), so
        # model.to(device) does not move it; match x's dtype/device here.
        t = self.integration_time.type_as(x)
        # The solver returns the state at every time in t; keep only t=1.
        out = odeint(self.odefunc, x, t, **self.solver_kwargs)[-1]
        return out

# RKBlock uses the standard Runge–Kutta solver (with direct backprop).
class RKBlock(nn.Module):
    """Integrate ``odefunc`` from t=0 to t=1 with gradients by direct backprop.

    Uses ``odeint_rk`` (torchdiffeq's plain ``odeint``), so autograd records
    every solver step. Memory therefore grows with the number of steps, unlike
    the adjoint variant. With no ``method`` given, torchdiffeq's default solver
    (an adaptive Runge–Kutta scheme) is used.

    Args:
        odefunc: nn.Module whose ``forward(t, x)`` returns dx/dt.
        rtol: Optional relative tolerance for adaptive solvers.
        atol: Optional absolute tolerance for adaptive solvers.
        method: Optional torchdiffeq solver name (e.g. ``"dopri5"``, ``"rk4"``).
        options: Optional solver-specific options dict
            (e.g. ``{"step_size": 0.25}`` for fixed-step methods).

    Any argument left as ``None`` is not passed to the solver, so torchdiffeq's
    own defaults apply.
    """

    def __init__(self, odefunc, rtol=None, atol=None, method=None, options=None):
        super(RKBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0.0, 1.0]).float()
        requested = {"rtol": rtol, "atol": atol, "method": method, "options": options}
        # Forward only the settings the caller actually supplied.
        self.solver_kwargs = {k: v for k, v in requested.items() if v is not None}

    def forward(self, x):
        # integration_time is a plain attribute (not a registered buffer), so
        # model.to(device) does not move it; match x's dtype/device here.
        t = self.integration_time.type_as(x)
        # The solver returns the state at every time in t; keep only t=1.
        out = odeint_rk(self.odefunc, x, t, **self.solver_kwargs)[-1]
        return out

In [4]:

# -----------------------------
# Model Architectures for MNIST
# -----------------------------

# 1. Standard ResNet for MNIST.
class ResNetMNIST(nn.Module):
    """Baseline ResNet classifier for single-channel MNIST digits.

    Pipeline: 3x3 conv stem (1 -> 16 channels), two stride-2 Downsample stages
    (16 -> 32 -> 64 channels, 28x28 -> 14x14 -> 7x7), six ResidualBlocks at
    64 channels, global average pooling, then a linear classifier.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.down1 = Downsample(16, 32)
        self.down2 = Downsample(32, 64)
        residual_stack = [ResidualBlock(64) for _ in range(6)]
        self.res_blocks = nn.Sequential(*residual_stack)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        for stage in (self.stem, self.down1, self.down2,
                      self.res_blocks, self.global_pool):
            x = stage(x)
        # (N, 64, 1, 1) -> (N, 64) for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [5]:

# 2. ODE-Net for MNIST: replaces residual blocks with an ODEBlock.
class ODENetMNIST(nn.Module):
    """ODE-Net classifier for MNIST.

    Same stem, downsampling and head as the ResNet baseline, but the stack of
    residual blocks is replaced by a single continuous-depth ODEBlock over
    ODEFunc(64), trained with adjoint gradients.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.down1 = Downsample(16, 32)
        self.down2 = Downsample(32, 64)
        # The same ODEFunc object is also held by the ODEBlock; .parameters()
        # de-duplicates shared modules, so its weights are counted once.
        self.odefunc = ODEFunc(64)
        self.odeblock = ODEBlock(self.odefunc)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        for stage in (self.stem, self.down1, self.down2,
                      self.odeblock, self.global_pool):
            x = stage(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [6]:
# 3. RK-Net for MNIST: same as ODE-Net but uses RKBlock.
class RKNetMNIST(nn.Module):
    """RK-Net classifier for MNIST.

    Architecturally identical to the ODE-Net, except the continuous-depth
    block is an RKBlock, which backpropagates directly through the solver
    instead of using the adjoint method.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        self.down1 = Downsample(16, 32)
        self.down2 = Downsample(32, 64)
        # The same ODEFunc object is also held by the RKBlock; .parameters()
        # de-duplicates shared modules, so its weights are counted once.
        self.odefunc = ODEFunc(64)
        self.rkblock = RKBlock(self.odefunc)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        for stage in (self.stem, self.down1, self.down2,
                      self.rkblock, self.global_pool):
            x = stage(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


In [7]:
# -----------------------------
# Data Loading & Training on MNIST
# -----------------------------

# Transform: convert images to tensors and normalize with the standard
# MNIST constants (mean 0.1307, std 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download MNIST (skipped when the files are already under ./data).
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train ODE-Net. You can try ResNetMNIST() or RKNetMNIST() instead.
# Seed so weight initialization and shuffle order are reproducible.
torch.manual_seed(0)
model_1 = ODENetMNIST(num_classes=10).to(device)
# Every reference below uses model_1 explicitly. Any other name would either
# raise NameError on a fresh kernel or pick up a stale model from a previous run.
optimizer = optim.Adam(model_1.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop.
num_epochs = 5
for epoch in range(num_epochs):
    model_1.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model_1(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is exact even with a short last batch.
        train_loss += loss.item() * data.size(0)
        preds = output.argmax(dim=1)
        correct += (preds == target).sum().item()
        total += data.size(0)

    train_loss /= total
    train_acc = correct / total * 100

    # Evaluate on test set.
    model_1.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model_1(data)
            loss = criterion(output, target)
            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            correct_test += (preds == target).sum().item()
            total_test += data.size(0)
    test_loss /= total_test
    test_acc = correct_test / total_test * 100

    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}%, Test Loss {test_loss:.4f}, Test Acc {test_acc:.2f}%")


KeyboardInterrupt: 

In [None]:
# Train the baseline ResNet (reuses train_loader, test_loader and device
# defined in the data-loading cell above).
# Seed so weight initialization and shuffle order are reproducible.
torch.manual_seed(0)
model_2 = ResNetMNIST(num_classes=10).to(device)
# Every reference below uses model_2 explicitly, so this cell never trains
# whichever model happened to be bound to some other name.
optimizer = optim.Adam(model_2.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop.
num_epochs = 5
for epoch in range(num_epochs):
    model_2.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model_2(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is exact even with a short last batch.
        train_loss += loss.item() * data.size(0)
        preds = output.argmax(dim=1)
        correct += (preds == target).sum().item()
        total += data.size(0)

    train_loss /= total
    train_acc = correct / total * 100

    # Evaluate on test set.
    model_2.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model_2(data)
            loss = criterion(output, target)
            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            correct_test += (preds == target).sum().item()
            total_test += data.size(0)
    test_loss /= total_test
    test_acc = correct_test / total_test * 100

    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}%, Test Loss {test_loss:.4f}, Test Acc {test_acc:.2f}%")


In [None]:
# Train RK-Net (reuses train_loader, test_loader and device defined in the
# data-loading cell above).
# Seed so weight initialization and shuffle order are reproducible.
torch.manual_seed(0)
model_3 = RKNetMNIST(num_classes=10).to(device)
# Every reference below uses model_3 explicitly, so this cell never trains
# whichever model happened to be bound to some other name.
optimizer = optim.Adam(model_3.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop.
num_epochs = 5
for epoch in range(num_epochs):
    model_3.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model_3(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is exact even with a short last batch.
        train_loss += loss.item() * data.size(0)
        preds = output.argmax(dim=1)
        correct += (preds == target).sum().item()
        total += data.size(0)

    train_loss /= total
    train_acc = correct / total * 100

    # Evaluate on test set.
    model_3.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model_3(data)
            loss = criterion(output, target)
            test_loss += loss.item() * data.size(0)
            preds = output.argmax(dim=1)
            correct_test += (preds == target).sum().item()
            total_test += data.size(0)
    test_loss /= total_test
    test_acc = correct_test / total_test * 100

    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}%, Test Loss {test_loss:.4f}, Test Acc {test_acc:.2f}%")
