In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Automatically reload imported modules (e.g. the local `rdktools` package)
# before running each cell, so edits to their source apply without a kernel restart.
%load_ext autoreload
%autoreload 2

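## Theory

Expected dot counts per motion direction for the coherent subset, the background and their total, with each bin's count modelled as binomial and shown with a ±1.96·sd band.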

In [None]:
def _rdk_direction_stats(fraction_subset, p_subset, p_background, num_dots, num_directions):
    """Expected dot count and variance per motion direction for each dot group.

    Each group ("s" = subset, "b" = background) is modelled as ``n`` independent
    dots: with probability ``p`` a dot moves in the group's coherent direction,
    otherwise in one of the ``D = num_directions`` directions chosen uniformly.
    The per-bin probabilities are therefore ``q_c = p + (1 - p) / D`` for the
    coherent bin and ``q_i = (1 - p) / D`` for every other bin (they sum to 1),
    and the count in one bin is Binomial(n, q): mean ``n q``, variance ``n q (1 - q)``.
    NOTE(review): assumes noise dots may also land on the coherent direction —
    confirm against how rdktools draws random directions.

    Args:
        fraction_subset: share of the ``num_dots`` dots that belong to the subset.
        p_subset: coherence of the subset group.
        p_background: coherence of the background group.
        num_dots: total number of dots.
        num_directions: number of discrete motion directions ``D``.

    Returns:
        ``(theta, mu_dir, sigma2_dir)``: ``theta`` holds ``num_directions + 1``
        angles from 0 to 2*pi (the last entry repeats the first bin so polar
        curves close); ``mu_dir`` / ``sigma2_dir`` map "s", "b" and "t" (total)
        to arrays aligned with ``theta``.
    """
    # round() rather than int(): int(0.57 * 100) == 56 because of float error.
    n_subset = int(round(fraction_subset * num_dots))
    group_size = {"s": n_subset, "b": num_dots - n_subset}
    coherence = {"s": p_subset, "b": p_background}
    # Index of each group's coherent bin; -2 is the last real bin because
    # index -1 is the closing copy of bin 0.
    coherent_bin = {"s": -2, "b": num_directions // 2}

    mu_dir, sigma2_dir = {}, {}
    for k in "sb":
        p = coherence[k]
        q = np.full(num_directions + 1, (1 - p) / num_directions)
        q[coherent_bin[k]] += p
        q[-1] = q[0]  # close the polar curve
        mu_dir[k] = group_size[k] * q
        sigma2_dir[k] = group_size[k] * q * (1 - q)

    # The two groups are disjoint and independent, so counts and variances add.
    # Group sizes already carry the subset/background proportions, so no extra
    # fraction weighting is applied here.
    mu_dir["t"] = mu_dir["s"] + mu_dir["b"]
    sigma2_dir["t"] = sigma2_dir["s"] + sigma2_dir["b"]

    theta = np.linspace(0, 2 * np.pi, num_directions + 1)
    return theta, mu_dir, sigma2_dir


def plot_rdk_hist(fraction_subset, p_subset, p_background, num_dots, num_directions):
    """Polar plots of the expected per-direction dot counts with ~95% bands.

    Draws three panels (subset, background, total). Each curve is normalised
    so that its smallest non-zero expected count equals 1. The shaded band is
    mean ± 1.96·sd, with the lower edge clipped at 0 because counts cannot be
    negative (a negative radius would wrap around in a polar plot).

    Args:
        fraction_subset, p_subset, p_background, num_dots, num_directions:
            see ``_rdk_direction_stats``.
    """
    theta, mu_dir, sigma2_dir = _rdk_direction_stats(
        fraction_subset, p_subset, p_background, num_dots, num_directions
    )
    titles = {"s": "subset", "b": "background", "t": "total"}

    plt.figure(figsize=(16, 8))
    for i, k in enumerate("sbt"):
        mu = mu_dir[k]
        # An empty group (fraction_subset of 0 or 1) or a fully coherent group
        # has zero-count bins; normalise by the smallest non-zero count instead
        # of dividing by zero.
        nonzero = mu[mu > 0]
        normfac = 1.0 / nonzero.min() if nonzero.size else 1.0
        half_width = 1.96 * np.sqrt(sigma2_dir[k])
        color = "C" + str(i)

        ax = plt.subplot(1, 3, i + 1, polar=True)
        ax.plot(theta, normfac * mu, color=color)
        ax.fill_between(
            theta,
            normfac * np.clip(mu - half_width, 0, None),
            normfac * (mu + half_width),
            color=color,
            alpha=0.2,
        )
        ax.set_title(titles[k] + "\n")
    plt.tight_layout()

In [None]:
import ipywidgets as widgets


@widgets.interact(
    fraction_subset=widgets.FloatSlider(value=0.1, min=0.0, max=1.0, step=0.1),
    p_subset=widgets.FloatSlider(value=0.9, min=0.0, max=1.0, step=0.1),
    p_background=widgets.FloatSlider(value=0.2, min=0.0, max=1.0, step=0.1),
    num_dots=widgets.IntSlider(value=500, min=10, max=1000, step=10),
    num_directions=widgets.IntSlider(value=180, min=10, max=360, step=10),
)
def plot_interactive(fraction_subset, p_subset, p_background, num_dots, num_directions):
    """Slider front-end for ``plot_rdk_hist``: every slider value is forwarded as-is."""
    plot_rdk_hist(
        fraction_subset=fraction_subset,
        p_subset=p_subset,
        p_background=p_background,
        num_dots=num_dots,
        num_directions=num_directions,
    )


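## Sampling from the actual RDK

Simulate an `RDK` stimulus from `rdktools` with the chosen parameters, histogram the dots' motion directions, then launch an `Experiment` with the same `Params`.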

In [None]:
from rdktools.rdk_params import Params, get_random_params
from rdktools.rdk_stimuli import RDK
from rdktools.rdk_experiment import other_angle, Experiment

In [None]:
import ipywidgets as widgets
from IPython.display import display


def get_hist(theta, nbins=50):
    """Histogram angles into a closed curve suitable for a polar plot.

    The input angles are treated as degrees (bin centres are converted with
    ``np.deg2rad``).

    Args:
        theta: array-like of angles; flattened by ``np.histogram``.
        nbins: number of equal-width bins spanning the data range.

    Returns:
        ``(angles_rad, counts)``: bin-centre angles in radians and the bin
        counts, each with the first bin repeated at the end so the plotted
        line closes on itself.
    """
    counts, edges = np.histogram(theta, bins=nbins)
    centres_deg = (edges[:-1] + edges[1:]) / 2

    closed_centres = np.append(centres_deg, centres_deg[0])
    closed_counts = np.append(counts, counts[0])

    return np.deg2rad(closed_centres), closed_counts


def plot_rdk_motions_and_launch(
    fraction_subset,
    p_subset,
    p_background,
    diffusion,
    diffuse_subset,
    name,
    num_batches,
    num_examples_per_batch,
):
    """Simulate one RDK stimulus, plot its motion-direction histograms, return its Params.

    Builds a ``Params`` from the slider values and creates ``RDK(None, params)``.
    It then samples the stimulus with angles ``[0, 90]`` (subset, background),
    steps it for a fixed number of frames, and draws three polar histograms of
    ``rdk.dot_motiondirs``:

    * "Subset Moving": all frames where ``rdk.rand < p_subset``;
    * "Subset Random": all remaining frames;
    * "Average": every frame.

    NOTE(review): presumably ``rdk.rand`` is the per-frame draw that decides
    whether the subset moves coherently. Confirm in rdktools.

    Args:
        fraction_subset: stored as ``params.SUBSET_RATIO``.
        p_subset, p_background: stored as ``params.DOT_COHERENCE`` in the order
            ``[p_background, p_subset]``; ``p_subset`` also splits the frames.
        diffusion: stored as ``params.DIFFUSION_SCALE``.
        diffuse_subset: stored as ``params.DIFFUSE_SUBSET``.
        name: stored as ``params.NAME``.
        num_batches: ``N_BATCH`` for ``Params``.
        num_examples_per_batch: ``N_TRIALS_PER_BATCH`` for ``Params``.

    Returns:
        The configured ``Params`` instance.
    """
    num_frames = 100  # frames simulated for the histograms

    params = Params(N_TRIALS_PER_BATCH=num_examples_per_batch, N_BATCH=num_batches)
    params.SUBSET_RATIO = fraction_subset
    params.DOT_COHERENCE = [p_background, p_subset]
    params.DIFFUSION_SCALE = diffusion
    params.DIFFUSE_SUBSET = diffuse_subset
    params.NAME = name

    rdk = RDK(None, params)
    # angles = (angle_s, angle_b), in degrees
    rdk.new_sample([0, 90])
    rdk.sample_dots(rdk.max_radius, rdk.ndots)

    moving_frames, random_frames, all_frames = [], [], []
    for _ in range(num_frames):
        rdk.update()
        # Copy the snapshot in case RDK updates dot_motiondirs in place.
        frame = np.array(rdk.dot_motiondirs, copy=True)
        if rdk.rand < p_subset:
            moving_frames.append(frame)
        else:
            random_frames.append(frame)
        all_frames.append(frame)

    _, axs = plt.subplots(1, 3, subplot_kw={"projection": "polar"}, figsize=(10, 5))
    panels = [
        (moving_frames, "Subset Moving"),
        (random_frames, "Subset Random"),
        (all_frames, "Average"),
    ]
    for ax, (frames, title) in zip(axs, panels):
        # A condition that never occurred leaves its panel empty rather than
        # plotting a meaningless placeholder histogram.
        if frames:
            ax.plot(*get_hist(np.concatenate(frames)), label="dots")
        ax.set_title(title)
        ax.get_yaxis().set_visible(False)

    return params


# Widgets for plot_rdk_motions_and_launch, keyed by its parameter names and
# passed to widgets.interactive. The motion angles are not exposed here; they
# are fixed inside plot_rdk_motions_and_launch.
controls = dict(
    fraction_subset=widgets.FloatSlider(value=0.1, min=0.0, max=1.0, step=0.05),
    p_subset=widgets.FloatSlider(value=0.9, min=0.0, max=1.0, step=0.05),
    p_background=widgets.FloatSlider(value=0.2, min=0.0, max=1.0, step=0.1),
    num_batches=widgets.IntSlider(value=10, min=1, max=30, step=1),
    num_examples_per_batch=widgets.IntSlider(value=3, min=1, max=10, step=1),
    diffusion=widgets.IntSlider(value=0, min=0, max=360, step=1),
    diffuse_subset=widgets.ToggleButton(value=False, description="Diffuse Subset ?"),
    name=widgets.Text(
        value="gabriel", placeholder="type your name", description="Name"
    ),
)


In [None]:
# Sliders + histogram preview; `interactive_exp.result` holds the Params
# returned by the most recent call of plot_rdk_motions_and_launch.
interactive_exp = widgets.interactive(plot_rdk_motions_and_launch, **controls)
display(interactive_exp)


def _new_experiment():
    """Build an Experiment from the Params of the latest slider update."""
    return Experiment(
        interactive_exp.result, None, save_data=False, save_gif=False, randomize=False
    )


# Built once up front; every button click replaces it with a fresh instance.
exp = _new_experiment()

button = widgets.Button(description="Launch Experiment!")
output = widgets.Output()
display(button, output)


def on_button_clicked(button):
    """Rebuild the global ``exp`` from the current sliders and run it inside ``output``.

    A new Experiment is constructed instead of re-calling ``exp.__init__`` on
    the old one. Re-initialising a live object can leave behind attributes
    that the previous ``run()`` set but ``__init__`` does not reset.
    """
    global exp
    exp = _new_experiment()
    with output:
        exp.run()
    return exp


button.on_click(on_button_clicked)

In [None]:
# Summary statistics of the two absolute-error columns of the last run.
(
    exp.results_pd
    .describe()
    .loc[:, ["absolute_error_global", "absolute_error_subset"]]
)

In [None]:
# Inspect the parameters stored on `exp` (presumably the Params passed to Experiment — confirm).
exp.params