In [17]:
from matplotlib import pyplot as plt
import jax
from jax.lax import scan
import jax.numpy as jnp
import numpy as np
import numpy.random as npr

In [2]:
def arms_race(x0, y0, *, k, a, g, l, b, h,
              num_iterations=10, step_size=0.02):
    """Simulate a Richardson-style two-party arms race with forward Euler.

    Integrates the linear system

        dx/dt = k * y - a * x + g
        dy/dt = l * x - b * y + h

    starting from ``(x0, y0)`` with a fixed step ``step_size`` over a time span
    of ``num_iterations`` time units. Despite its name, ``num_iterations`` is the
    simulated duration, not the number of Euler steps.

    Args:
        x0, y0: Initial values of x and y.
        k, l: Cross-coupling coefficients (weight of y in dx/dt, of x in dy/dt).
        a, b: Self coefficients (weight of -x in dx/dt, of -y in dy/dt).
        g, h: Constant terms of dx/dt and dy/dt.
        num_iterations: Total simulated time. The number of Euler steps is
            ``round(num_iterations / step_size)`` (clamped at 0).
        step_size: Euler step size; must be positive.

    Returns:
        ``(xs, ys)``: arrays of length ``round(num_iterations / step_size)``
        holding the state after each step. The initial state is not included.

    Raises:
        ValueError: if ``step_size`` is not positive.
    """
    if not step_size > 0:  # also rejects NaN
        raise ValueError(f"step_size must be positive, got {step_size!r}")
    # Round the ratio to an explicit step count instead of scanning over a
    # float-step ``arange``. That time grid was never read, and a float-step
    # ``arange`` can yield one element more or fewer than intended, which
    # breaks callers that reshape the output to (num_iterations, -1).
    num_steps = max(int(round(num_iterations / step_size)), 0)

    def arms_race_up(state, _):
        x, y = state
        dx = k * y - a * x + g
        dy = l * x - b * y + h
        # Simultaneous update: both derivatives use the pre-step (x, y).
        x = x + step_size * dx
        y = y + step_size * dy
        return (x, y), (x, y)

    _, (xs, ys) = scan(arms_race_up, (x0, y0), None, length=num_steps)
    return xs, ys


In [6]:
# Simulation horizon (time units) and Euler step size.
num_iterations = 30
step_size = 0.02

# Ground-truth model parameters: (k, a, g) drive x, (l, b, h) drive y.
gt_k, gt_a, gt_g = 10., 20., .1
gt_l, gt_b, gt_h = 10., 3., 6.

# Full-resolution ground-truth trajectories starting from x0 = y0 = 10.
gt_xs, gt_ys = arms_race(
    10., 10.,
    k=gt_k, a=gt_a, g=gt_g,
    l=gt_l, b=gt_b, h=gt_h,
    num_iterations=num_iterations,
    step_size=step_size,
)

In [9]:
# Decimate the Euler trajectory to one sample per unit of simulated time:
# view it as (num_iterations, steps_per_unit) and keep column 0, i.e. the state
# after the first step of each unit interval (t = step_size, 1 + step_size, ...).
# NOTE(review): if states at integer times t = 1..num_iterations were intended,
# take column -1 instead.
gt_xs = gt_xs.reshape(num_iterations, -1)[:, 0]
gt_xs

DeviceArray([8.0019999e+00, 3.4185421e+01, 1.7502908e+02, 8.7163947e+02,
             4.3170640e+03, 2.1358072e+04, 1.0564261e+05, 5.2251259e+05,
             2.5843442e+06, 1.2782128e+07, 6.3220196e+07, 3.1268598e+08,
             1.5465393e+09, 7.6491566e+09, 3.7832593e+10, 1.8711940e+11,
             9.2548923e+11, 4.5774547e+12, 2.2640002e+13, 1.1197705e+14,
             5.5383657e+14, 2.7392660e+15, 1.3548361e+16, 6.7009968e+16,
             3.3143015e+17, 1.6392473e+18, 8.1076877e+18, 4.0100469e+19,
             1.9833626e+20, 9.8096781e+20], dtype=float32)

In [10]:
# Decimate the Euler trajectory to one sample per unit of simulated time:
# view it as (num_iterations, steps_per_unit) and keep column 0, i.e. the state
# after the first step of each unit interval (t = step_size, 1 + step_size, ...).
# NOTE(review): if states at integer times t = 1..num_iterations were intended,
# take column -1 instead.
gt_ys = gt_ys.reshape(num_iterations, -1)[:, 0]
gt_ys


DeviceArray([1.1520000e+01, 7.4158813e+01, 3.7872485e+02, 1.8851034e+03,
             9.3356289e+03, 4.6185797e+04, 2.2844609e+05, 1.1299026e+06,
             5.5884910e+06, 2.7640592e+07, 1.3670992e+08, 6.7616486e+08,
             3.3442993e+09, 1.6540846e+10, 8.1810735e+10, 4.0463457e+11,
             2.0013154e+12, 9.8984725e+12, 4.8957653e+13, 2.4214371e+14,
             1.1976386e+15, 5.9234997e+15, 2.9297525e+16, 1.4490507e+17,
             7.1669796e+17, 3.5447749e+18, 1.7532391e+19, 8.6714876e+19,
             4.2889035e+20, 2.1212845e+21], dtype=float32)

In [16]:
tuple(map(np.array, (gt_xs, gt_ys)))  # both trajectories as plain NumPy arrays



(array([8.0019999e+00, 3.4185421e+01, 1.7502908e+02, 8.7163947e+02,
        4.3170640e+03, 2.1358072e+04, 1.0564261e+05, 5.2251259e+05,
        2.5843442e+06, 1.2782128e+07, 6.3220196e+07, 3.1268598e+08,
        1.5465393e+09, 7.6491566e+09, 3.7832593e+10, 1.8711940e+11,
        9.2548923e+11, 4.5774547e+12, 2.2640002e+13, 1.1197705e+14,
        5.5383657e+14, 2.7392660e+15, 1.3548361e+16, 6.7009968e+16,
        3.3143015e+17, 1.6392473e+18, 8.1076877e+18, 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20], dtype=float32),
 array([1.1520000e+01, 7.4158813e+01, 3.7872485e+02, 1.8851034e+03,
        9.3356289e+03, 4.6185797e+04, 2.2844609e+05, 1.1299026e+06,
        5.5884910e+06, 2.7640592e+07, 1.3670992e+08, 6.7616486e+08,
        3.3442993e+09, 1.6540846e+10, 8.1810735e+10, 4.0463457e+11,
        2.0013154e+12, 9.8984725e+12, 4.8957653e+13, 2.4214371e+14,
        1.1976386e+15, 5.9234997e+15, 2.9297525e+16, 1.4490507e+17,
        7.1669796e+17, 3.5447749e+18, 1.7532391e+19, 8.671487

In [29]:
# Simulate `num_datapoints` independent noisy copies of the ground-truth
# trajectories: obs[i, t] = gt[t] + noise_scale * eps, with eps ~ N(0, 1).
num_datapoints = 100
noise_scale = 3
# A seeded local Generator makes the synthetic dataset reproducible under
# Restart & Run All; the unseeded global `npr.randn` drew new noise every run.
noise_rng = np.random.default_rng(0)
obs_xs = np.array(gt_xs + noise_rng.standard_normal((num_datapoints, *np.shape(gt_xs))) * noise_scale)
obs_ys = np.array(gt_ys + noise_rng.standard_normal((num_datapoints, *np.shape(gt_ys))) * noise_scale)
# NOTE(review): the trajectories grow to ~1e20 while the noise is O(1), so in
# float32 the later columns are numerically noise-free (identical across rows).
# Consider noise relative to gt magnitude if those timesteps should be informative.

In [31]:
np.asarray(obs_xs)  # shape (num_datapoints, len(gt_xs)): one row per noisy draw

array([[7.1104202e+00, 3.5948662e+01, 1.6814308e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20],
       [8.7127953e+00, 2.8113756e+01, 1.7368900e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20],
       [5.7510138e+00, 3.4040771e+01, 1.7488641e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20],
       ...,
       [1.4673883e+01, 3.4800049e+01, 1.7454721e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20],
       [1.0552752e+01, 2.9327173e+01, 1.7764212e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20],
       [9.8908739e+00, 3.4463154e+01, 1.7695886e+02, ..., 4.0100469e+19,
        1.9833626e+20, 9.8096781e+20]], dtype=float32)