##### Imports:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
# deep learning

import numpy as np
# linear algebra

from scipy.integrate import solve_ivp
# ODE solver used for dataset generation

from torch.utils.data import DataLoader
from torch.utils.data import random_split
# data handling

import matplotlib.pyplot as plt
# plotting

from tqdm import tqdm
# progress tracking

#### Dataset Generation (Synthetic, but Numerically Correct per SciPy)

- Sample generation functions

In [None]:
def generate_ode_params(n_samples=1000, seed=42, max_abs=1000.0):
    """Sample parameters for random first-order linear initial-value problems.

    Each sample is a tuple ``(a, b, c, y0)`` describing
    ``dy/dx = a*x + b*y + c`` with initial condition ``y0``. Candidates are
    rejection-sampled: one is kept only if ``solve_ode`` returns a complete,
    finite solution whose magnitude stays strictly below ``max_abs``.

    Sampling ranges:
        a  ~ U(-2, 2)
        b  ~ U(-5, -0.5)   (strictly negative -> e^(b*x) term decays)
        c  ~ U(-2, 2)
        y0 ~ U(-10, 10)

    Args:
        n_samples: Number of parameter tuples to generate.
        seed: Seed passed to ``np.random.seed``. NOTE: this seeds NumPy's
            *global* RNG, so it also affects later ``np.random`` calls.
            ``None`` leaves the global RNG state untouched.
        max_abs: Rejection threshold on ``max |y|`` over the solved curve.

    Returns:
        A list of ``n_samples`` tuples ``(a, b, c, y0)``.

    LIMITATION: restricting to decaying, moderate-magnitude solutions keeps
    the MSE loss well-scaled, but the trained model may generalize poorly to
    growing or large-magnitude solutions.
    """
    if seed is not None:
        np.random.seed(seed)

    params = []
    for _ in range(n_samples):
        while True:
            # Draw order (a, b, c, y0) is kept fixed so a given seed always
            # reproduces the same dataset.
            a = np.random.uniform(-2, 2)
            # Forcing b < 0 avoids explosive solutions: e.g. with b = 10 the
            # solution behaves like e^(10x), which blows up over x in [0, 5]
            # and would dominate any MSE loss.
            b = np.random.uniform(-5, -0.5)
            c = np.random.uniform(-2, 2)
            y0 = np.random.uniform(-10, 10)

            try:
                x, y = solve_ode(a, b, c, y0)
            except RuntimeError:
                continue  # solver failed outright: resample

            # Reject truncated (solver stopped early), non-finite, or
            # too-large solutions. The length check short-circuits before
            # np.max so an empty y cannot raise.
            if (
                len(y) == len(x)
                and np.all(np.isfinite(y))
                and np.max(np.abs(y)) < max_abs
            ):
                break

        params.append((a, b, c, y0))
    return params


def solve_ode(a, b, c, y0, x_span=(0, 5), num_points=50, method='RK45',
              rtol=1e-3, atol=1e-6):
    """Numerically solve ``dy/dx = a*x + b*y + c`` with ``y(x_span[0]) = y0``.

    Integrates with ``scipy.integrate.solve_ivp`` and samples the solution
    at ``num_points`` evenly spaced points covering ``x_span``, with both
    endpoints included. With the defaults, that is 50 points on [0, 5].

    Args:
        a, b, c: Coefficients of the right-hand side.
        y0: Initial value at ``x_span[0]``.
        x_span: ``(x_start, x_end)`` integration interval.
        num_points: Number of evaluation points.
        method, rtol, atol: Forwarded to ``solve_ivp``. The defaults equal
            SciPy's own ``solve_ivp`` defaults. Tighten ``rtol``/``atol``
            for more accurate training targets.

    Returns:
        ``(x_eval, y)``: 1-D arrays, each of length ``num_points``.

    Raises:
        RuntimeError: if the solver fails before reaching ``x_span[1]``.
            Without this check, the result would be a silently truncated
            ``y`` whose length no longer matches ``x_eval``.
    """
    def rhs(x, y):
        return a * x + b * y + c

    x_eval = np.linspace(x_span[0], x_span[1], num_points)
    sol = solve_ivp(rhs, x_span, [y0], t_eval=x_eval, method=method,
                    rtol=rtol, atol=atol)
    if not sol.success:
        raise RuntimeError(
            f"solve_ivp failed for a={a}, b={b}, c={c}, y0={y0}: {sol.message}"
        )

    # sol.y has one row per state variable; this ODE is scalar, so row 0 is y.
    return x_eval, sol.y[0]
