In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchdiffeq import odeint

# Seed both the torch and NumPy RNGs so the data and the weight
# initialisation below are repeatable from run to run.
torch.manual_seed(42)
np.random.seed(42)

# Training set: n_samples (x, y) pairs drawn uniformly from [0, 1) x [0, 1).
# The tuple is evaluated left to right, so x is drawn before y and the
# seeded RNG stream is consumed in the same order as separate statements.
n_samples = 100
x_train, y_train = torch.rand(n_samples, 1), torch.rand(n_samples, 1)

# Visualise the raw training pairs.
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(x_train.numpy(), y_train.numpy(), alpha=0.6, label='Training data')
ax.set(xlabel='x', ylabel='y', title='Training Data Points')
ax.legend()
ax.grid(True)
plt.show()

# Define the neural network for the ODE
class ODEFunc(nn.Module):
    """Right-hand side ``du/dt = f(u)`` of the neural ODE.

    ``f`` is a one-hidden-layer MLP (Linear -> Tanh -> Linear) applied to the
    last axis of ``u``.  The dynamics are autonomous: ``forward`` receives the
    time ``t`` because ``odeint`` always passes it, but ``t`` is not used.

    Args:
        dim: Width of the state's last axis, i.e. the MLP's input and output
            size. Defaults to 1.
        hidden_dim: Width of the hidden layer. Defaults to 20.
    """

    def __init__(self, dim=1, hidden_dim=20):
        super().__init__()
        # Layers are created in the same order as before, so with the default
        # arguments the seeded weight initialisation is unchanged.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, t, u):
        """Return ``du/dt`` at state ``u``; the time ``t`` is ignored."""
        return self.net(u)

# Define the Neural ODE model
class NeuralODE(nn.Module):
    """Map ``x`` to ``u(t_end) + x``, integrating ``du/dt = odefunc(t, u)``.

    Every sample starts from ``u(0) = 0`` (zero initial condition, as per the
    problem statement).  The state is integrated with ``odeint`` from 0 to
    ``t_end``, and the input ``x`` is added back as a ResNet-style residual.

    NOTE(review): ``x`` is used only for the shape of ``u0`` and for the
    residual; it is never passed to ``odefunc``.  When ``odefunc`` acts
    row-wise on ``u`` and does not see ``x`` (as the MLP in ``ODEFunc``
    does), every sample follows the same trajectory from the same zero start.
    ``out[-1]`` is then identical across the batch, so the model reduces to
    ``x`` plus a learned constant offset.  Conditioning the dynamics on ``x``
    (e.g. augmenting the state with ``x``) requires widening ``odefunc``'s
    input accordingly.

    Args:
        odefunc: Callable ``(t, u) -> du/dt`` accepted by ``odeint``.
        t_end: Final integration time. Defaults to 1.0.
        method: ``odeint`` solver name. Defaults to ``'dopri5'``.
        rtol: Relative tolerance forwarded to ``odeint``. Defaults to 1e-3.
        atol: Absolute tolerance forwarded to ``odeint``. Defaults to 1e-4.
    """

    def __init__(self, odefunc, *, t_end=1.0, method='dopri5', rtol=1e-3, atol=1e-4):
        super().__init__()
        self.odefunc = odefunc
        # Registered as a buffer rather than a plain attribute so that
        # ``model.to(device)`` moves it together with the parameters.
        # ``persistent=False`` keeps it out of ``state_dict``, so checkpoint
        # keys are the same as before.
        self.register_buffer(
            'integration_time',
            torch.tensor([0.0, float(t_end)], dtype=torch.float32),
            persistent=False,
        )
        self.method = method
        self.rtol = rtol
        self.atol = atol

    def forward(self, x):
        """Return ``u(t_end) + x`` for a batch ``x``."""
        # Zero initial condition, one state row per input row.
        u0 = torch.zeros_like(x)
        # ``out`` holds the solution at each time in ``integration_time``;
        # ``out[-1]`` is the state at ``t_end``.
        out = odeint(self.odefunc, u0, self.integration_time,
                     method=self.method, rtol=self.rtol, atol=self.atol)
        return out[-1] + x  # Residual connection similar to ResNet

# Build the model (ODEFunc's weights come from the seeded torch RNG) and
# fit it full-batch with Adam on the mean-squared error.
odefunc = ODEFunc()
model = NeuralODE(odefunc)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

n_epochs = 500
losses = []

for epoch in range(n_epochs):
    # One gradient step using every training point at once.
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # The loss tensor is not modified by the step, so read it once and
    # reuse the Python float for both the history and the log line.
    epoch_loss = loss.item()
    losses.append(epoch_loss)
    if not epoch % 50:
        print(f'Epoch {epoch}, Loss: {epoch_loss:.4f}')

# Training curve: one MSE value per epoch, in epoch order.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses)
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax.grid(True)
plt.show()

# Evaluate the trained model (no autograd graph needed) on 100 evenly
# spaced points covering [0, 1], shaped as a (100, 1) column.
x_test = torch.linspace(0, 1, 100).unsqueeze(-1)
with torch.no_grad():
    y_pred = model(x_test)

# Overlay the learned curve on the training scatter.
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(x_train.numpy(), y_train.numpy(), alpha=0.6, label='Training data')
ax.plot(x_test.numpy(), y_pred.numpy(), 'r-', linewidth=2, label='Learned function')
ax.set(xlabel='x', ylabel='y', title='Training Data and Learned Function')
ax.legend()
ax.grid(True)
plt.show()

: 