In [1]:
import tensorflow as tf
from models import utils as mutils
import jax
from jax import numpy as jnp

from matplotlib import pyplot as plt
from matplotlib.widgets import Button, Slider



In [11]:
# Sample N_POINTS points uniformly from the unit disk by rejection sampling,
# label each point 1 if its squared norm is <= LABEL_R2 (radius <= sqrt(0.5)),
# else 0, and save the result as a tf.data dataset.
N_POINTS = 10_000
LABEL_R2 = 0.5
SEED = 0

tf.random.set_seed(SEED)  # make the sampled dataset reproducible

# Vectorized rejection sampling: draw a batch of candidates from the square
# [-1, 1]^2, keep those inside the closed unit disk, and repeat until enough
# points have been accepted (about pi/4 of each batch survives).
accepted, n_accepted = [], 0
while n_accepted < N_POINTS:
    candidates = tf.random.uniform((N_POINTS, 2), -1, 1)
    inside = tf.boolean_mask(candidates, tf.reduce_sum(candidates**2, axis=-1) <= 1)
    accepted.append(inside)
    n_accepted += int(inside.shape[0])
flat = tf.concat(accepted, axis=0)[:N_POINTS]  # (N_POINTS, 2)

points = tf.reshape(flat, (N_POINTS, 1, 1, 2))  # per-element shape (1, 1, 2)
labels = tf.cast(tf.reduce_sum(flat**2, axis=-1) <= LABEL_R2, tf.int32)
dataset = {'image': points, 'label': labels}
ds = tf.data.Dataset.from_tensor_slices(dataset)
# NOTE(review): tf.data.experimental.save is deprecated in newer TF releases in
# favour of `ds.save(path)`; kept for compatibility with older TF versions.
tf.data.experimental.save(ds, 'disk/data')
ds.element_spec

{'image': TensorSpec(shape=(1, 1, 2), dtype=tf.float32, name=None),
 'label': TensorSpec(shape=(), dtype=tf.int32, name=None)}

In [2]:
# Reload the dataset saved earlier. The spec passed here must describe the
# saved elements exactly: a (1, 1, 2) float32 image and a scalar int32 label.
spec = dict(
    image=tf.TensorSpec(shape=(1, 1, 2), dtype=tf.float32, name=None),
    label=tf.TensorSpec(shape=(), dtype=tf.int32, name=None),
)
new_ds = tf.data.experimental.load('disk/data', spec)

In [3]:
# Pull the first N_PLOT images from the reloaded dataset and flatten them into
# an (N, 2) array of (x, y) coordinates for plotting.
N_PLOT = 1000
data = [d['image'] for d in new_ds.take(N_PLOT)]
data_np = tf.stack(data, axis=0).numpy()  # (N, 1, 1, 2)
# Squeeze only the two singleton spatial axes. A bare squeeze would also drop
# the sample axis when N == 1, which breaks the `[:, 0]` indexing downstream.
data_np = jnp.squeeze(data_np, axis=(1, 2))



In [4]:
# Enable matplotlib's interactive GUI integration (the captured output shows it
# selected Qt5Agg). An interactive backend is required for the Slider widget
# in the plotting cell to respond to input.
%matplotlib

Using matplotlib backend: Qt5Agg


In [12]:
# Interactive view: scale the sampled disk points by a radius chosen with a
# vertical slider and draw them against the unit circle.
RADIUS_INIT = 0.6

fig, ax = plt.subplots(figsize=(5, 5))
# Leave room on the left so the vertical slider does not cover the y tick labels.
fig.subplots_adjust(left=0.25)


def xs(r):
    """x-coordinates of the plotted points scaled by radius ``r``."""
    return data_np[:, 0] * r


def ys(r):
    """y-coordinates of the plotted points scaled by radius ``r``."""
    return data_np[:, 1] * r


circle = plt.Circle((0, 0), 1, color='r', fill=False)
ax.add_patch(circle)
line, = ax.plot(xs(RADIUS_INIT), ys(RADIUS_INIT), 'r+')
ax.set_aspect('equal')  # keep the unit circle circular
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set(title="Disk samples scaled by radius", xlabel="x", ylabel="y")
ax.grid(True)

axrad = fig.add_axes([0.1, 0.25, 0.0225, 0.63])
rad_slider = Slider(
    ax=axrad,
    label="Radius",
    valmin=0.,
    valmax=1.,
    valinit=RADIUS_INIT,
    orientation="vertical",
)


def update(val):
    """Slider callback: rescale the plotted points to radius ``val``."""
    line.set_data(xs(val), ys(val))
    fig.canvas.draw_idle()


rad_slider.on_changed(update)

plt.show()