In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import seaborn as sns

In [None]:
# Global seed for the legacy np.random stream; every later random draw in this
# notebook (trajectory generation, sample picking) flows from it.
seed = 111
np.random.seed(seed)
sns.set(style="white")

# xkcd colour names for shading discrete labels: label k is drawn with colors[k].
color_names = [
    "red", "windows blue", "amber", "faded green", "dusty purple",
    "orange", "clay", "pink", "greyish", "light cyan",
    "steel blue", "pastel purple", "mint", "salmon",
]
colors = sns.xkcd_palette(color_names)

In [None]:
def _plot_segments(ax, segmentation, ymin, ymax):
    s_seq = np.squeeze(segmentation)
    z_cps = np.concatenate(
        ([0], np.where(np.diff(s_seq))[0]+1, [s_seq.size]))
    for start, stop in zip(z_cps[:-1], z_cps[1:]):
        stop = min(s_seq.size, stop+1)
        ax.axvspan(
            start, stop-1, ymin=ymin, ymax=ymax,
            alpha=.8, facecolor=colors[s_seq[start]])

In [None]:
def make_seq(h=10, T=100, noise=0.1, rng=None):
    """Simulate one noisy 1-D bouncing-ball trajectory with direction labels.

    The ball starts at a random integer position in ``[0, h)`` and has a
    constant speed drawn uniformly from ``[-0.5, 0.5)``. After each step, if
    the position has left ``[0, h]``, the velocity is negated. The position
    itself is not clamped, so observations can lie slightly outside ``[0, h]``.

    Parameters
    ----------
    h : int
        Upper wall; the lower wall is at 0.
    T : int
        Number of time steps.
    noise : float
        Standard deviation of the Gaussian noise added to each observed position.
    rng : RandomState-like, optional
        Object exposing ``randint``, ``uniform`` and ``randn`` (e.g.
        ``np.random.RandomState``). Defaults to the global ``np.random`` stream,
        so the notebook-level seed still controls the output.

    Returns
    -------
    seq : ndarray of float, shape (T, 1)
        Noisy observed positions.
    z_seq : ndarray of int, shape (T,)
        ``z_seq[t]`` is 1 if the velocity applied after step ``t`` is
        non-negative, else 0.
    """
    if rng is None:
        rng = np.random
    # Draw order (randint, uniform, then randn per step) matches the original
    # implementation, so seeded outputs are unchanged.
    y = rng.randint(h)
    v = rng.uniform(-0.5, 0.5)
    seq = []
    z_seq = []
    for _ in range(T):
        z_seq.append(0 if v < 0 else 1)
        seq.append(y + rng.randn(1) * noise)
        y += v
        if y > h or y < 0:  # left the box: reflect the velocity
            v = -v
    # reshape/dtype keep the documented shapes and dtypes even when T == 0.
    return np.array(seq).reshape(T, 1), np.array(z_seq, dtype=int)

In [None]:
# Sanity check: draw one fresh trajectory and shade its direction labels
# along the bottom 5% of the axes.
y_seq, z_seq = make_seq()
fig, ax = plt.subplots(figsize=(12, 4))
# Derive the time axis from the data rather than hard-coding T=100.
ax.scatter(np.arange(len(y_seq)), y_seq)
ax.set_ylim([-1, 11.])  # make_seq's default h=10 plus margin for wall overshoot
ax.set(xlabel="t", ylabel="y", title="Sample trajectory")
_plot_segments(ax, z_seq, 0., 0.05)
plt.show()

In [None]:
# Training set: 100,000 independent trajectories of the default length,
# one make_seq() call per sample, in order.
train_pairs = [make_seq() for _ in tqdm(range(100000))]
data_y = np.asarray([obs for obs, _ in train_pairs])
data_z = np.asarray([lab for _, lab in train_pairs])
del train_pairs
data_y.shape, data_z.shape

In [None]:
# Inspect a random training trajectory: noisy observations plus the connecting
# line, with direction labels shaded along the bottom 5% of the axes.
# NOTE: this randint draws from the global np.random stream, so running or
# skipping this cell shifts later random draws (e.g. the test-set generation).
idx = np.random.randint(data_y.shape[0])
y_seq, z_seq = data_y[idx], data_z[idx]
fig, ax = plt.subplots(figsize=(12, 4))
t_axis = np.arange(len(y_seq))  # derived from the data, not a hard-coded T=100
ax.scatter(t_axis, y_seq, s=25, marker='+')
ax.plot(t_axis, y_seq)
ax.set_ylim([-1, 11.])  # make_seq's default h=10 plus margin for wall overshoot
ax.set(xlabel="t", ylabel="y", title=f"Training trajectory #{idx}")
_plot_segments(ax, z_seq, 0., 0.05)
plt.show()

In [None]:
# Persist the training split; np.load('bouncing_ball.npz') exposes arrays 'y' and 'z'.
np.savez('bouncing_ball.npz', **{'y': data_y, 'z': data_z})

In [None]:
# Test set: 1,000 longer trajectories (T=150), drawn from the same global
# random stream right after the training set.
test_pairs = [make_seq(T=150) for _ in tqdm(range(1000))]
data_y = np.asarray([obs for obs, _ in test_pairs])
data_z = np.asarray([lab for _, lab in test_pairs])
del test_pairs
data_y.shape, data_z.shape

In [None]:
# Persist the test split; np.load('bouncing_ball_test.npz') exposes arrays 'y' and 'z'.
np.savez('bouncing_ball_test.npz', **{'y': data_y, 'z': data_z})