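The Lorenz-attractor demo below is adapted from the ipywidgets documentation example "Lorenz Differential Equations":
https://ipywidgets.readthedocs.io/en/stable/examples/Lorenz%20Differential%20Equations.html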


In [2]:
%matplotlib widget

from ipywidgets import interact, interactive
from IPython.display import clear_output, display, HTML

import numpy as np
from scipy import integrate

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import cnames
from matplotlib import animation

import math

UsageError: Line magic function `%` not found.


In [None]:
def solve_lorenz(N=10, angle=0.0, max_time=8.0, sigma=10.0, beta=8./3, rho=28.0,
                 seed=1):
    """Integrate N Lorenz-system trajectories and plot them in a 3D figure.

    Parameters
    ----------
    N : int
        Number of trajectories. Each starts at a point drawn uniformly from
        [-15, 15) in every coordinate.
    angle : float
        Azimuth passed to ``ax.view_init``; the elevation is fixed at 30.
    max_time : float
        End of the integration interval. The time grid has
        ``int(250 * max_time)`` evenly spaced points on ``[0, max_time]``.
    sigma, beta, rho : float
        Lorenz-system parameters.
    seed : int
        Seed passed to ``np.random.seed`` before drawing the start points
        (default 1, the previously hard-coded value). Note that this reseeds
        NumPy's global RNG.

    Returns
    -------
    t : ndarray, shape (n_t,)
        The time grid.
    x_t : ndarray, shape (N, n_t, 3)
        ``x_t[i, k]`` is the (x, y, z) state of trajectory ``i`` at ``t[k]``.
    """
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')
    ax.axis('off')

    # Fixed limits so every call frames the plot identically.
    ax.set_xlim((-25, 25))
    ax.set_ylim((-35, 35))
    ax.set_zlim((5, 55))

    def lorenz_deriv(x_y_z, t0, sigma=sigma, beta=beta, rho=rho):
        """Compute the time-derivative of a Lorenz system.

        ``t0`` is unused (the system has no explicit time dependence) but is
        required by odeint's ``func(y, t)`` calling convention.
        """
        x, y, z = x_y_z
        return [sigma * (y - x), x * (rho - z) - y, x * y - beta * z]

    # Random starting points, uniformly distributed from -15 to 15.
    np.random.seed(seed)
    x0 = -15 + 30 * np.random.random((N, 3))

    # Solve for the trajectories. Preallocating keeps x_t at shape
    # (N, len(t), 3) even for N == 0, where np.asarray([]) would give (0,).
    t = np.linspace(0, max_time, int(250 * max_time))
    x_t = np.empty((N, t.size, 3))
    for i, x0i in enumerate(x0):
        x_t[i] = integrate.odeint(lorenz_deriv, x0i, t)

    # One viridis colour per trajectory; transpose each (n_t, 3) trajectory
    # to (3, n_t) so it unpacks into x, y, z coordinate arrays.
    colors = plt.cm.viridis(np.linspace(0, 1, N))
    for (x, y, z), color in zip(x_t.transpose(0, 2, 1), colors):
        ax.plot(x, y, z, '-', c=color, linewidth=2)

    ax.view_init(30, angle)
    plt.show()

    return t, x_t

In [None]:
# Five trajectories viewed at azimuth 0; keep the time grid and states for reuse.
t, x_t = solve_lorenz(N=5, angle=0)


In [None]:
# Linear model and the per-sample least-squares pieces used by gradient descent.
def f_x_ab(x, a, b):
    """Linear model prediction ``a*x + b``."""
    return a * x + b


def err_i(x, y, a, b):
    """Residual ``y - (a*x + b)`` of a single sample."""
    return y - f_x_ab(x, a, b)


def err2_i(x, y, a, b):
    """Squared residual of a single sample."""
    return err_i(x, y, a, b) ** 2


def de_da(x, y, a, b):
    """Derivative of ``err2_i`` with respect to ``a``: ``-2 * residual * x``."""
    return -2.0 * err_i(x, y, a, b) * x


def de_db(x, y, a, b):
    """Derivative of ``err2_i`` with respect to ``b``: ``-2 * residual``."""
    return -2.0 * err_i(x, y, a, b)

In [None]:
# Display the residual helper's repr (confirms it is defined in this kernel).
err_i

In [None]:
# Training samples (x_i, y_i) for the linear fit below.
X = [0, 1, 2, 3, 6, 10]
Y = [2.1, 4.2, 5.8, 8, 14, 18]
# Later cells read X[i] and Y[i] for i in range(N) with N taken from X, so
# mismatched lengths would raise IndexError or silently drop trailing Y values.
assert len(X) == len(Y), "X and Y must have the same number of samples"
N = len(X)
N

In [None]:
# Use an explicit new figure: with `%matplotlib widget` earlier figures can stay
# open, so a bare plt.plot() may draw onto the current axes instead (possibly
# the 3D axes created by solve_lorenz).
fig_data, ax_data = plt.subplots()
ax_data.plot(X, Y, '-o')
ax_data.set(xlabel='x', ylabel='y', title='Training data');

In [None]:
# Squared residual of each sample for a hand-picked guess of (a, b).
a, b = 1.8, 2.2

for i in range(N):
    xi, yi = X[i], Y[i]
    e = err2_i(xi, yi, a, b)
    print(e)

In [None]:
## Fit (a, b) by full-batch gradient descent on the summed squared error.
a = -1.0
b = 1.0
alpha = 0.001     # learning rate
errMax = 1e-3     # stop once the batch squared error drops below this
max_iter = 200    # upper bound on gradient steps

for k in range(max_iter):
    # Loss and gradient summed over all samples at the CURRENT (a, b).
    grad_a = 0.
    grad_b = 0.
    err_sq_batch = 0.
    for i in range(N):
        xi = X[i]
        yi = Y[i]
        err_sq_batch += err2_i(xi, yi, a, b)
        grad_a += de_da(xi, yi, a, b)
        grad_b += de_db(xi, yi, a, b)

    # Log before stepping so a, b, the gradients and err_sq_batch printed
    # together all refer to the same parameter values.
    if k % 20 == 0:
        print("")
        print("a = " + str(a))
        print("b = " + str(b))
        print("grad_a = " + str(grad_a))
        print("grad_b = " + str(grad_b))
        print("err_sq_batch = " + str(err_sq_batch))

    # Check convergence before stepping, so on break (a, b) are exactly the
    # parameters whose error fell below errMax (not one step past them).
    if err_sq_batch < errMax:
        print("break. err_batch=" + str(err_sq_batch))
        break

    a -= alpha * grad_a
    b -= alpha * grad_b

print("")
print("a = " + str(a))
print("b = " + str(b))



In [None]:
# Per-sample residuals of the fitted model, plus their signed and absolute sums.
sum_err = 0.0      # signed sum: positive and negative residuals cancel out
sum_abs_err = 0.0  # sum of |residual|, which is what the "abs:" label reports
for i in range(N):
    xi = X[i]
    yi = Y[i]
    e = err_i(xi, yi, a, b)
    sum_err += e
    sum_abs_err += abs(e)
    print(e)
print("sum:")
print(sum_err)
print("abs:")
print(sum_abs_err)