In [None]:

# Run this cell to start the Demo GUI
# When switching from one file to another, the loading time can become a few seconds.
root_path = "./Examples/"

%matplotlib widget

if True:
    import sys
    import os
    from pathlib import Path
    sys.path.append(str(Path(os.path.realpath("dummy.ipynb")).parent))
    sys.path.append(os.path.join(str(Path(os.path.realpath("dummy.ipynb")).parent), "FactorRotations", "utils"))

    from scipy.stats import norm
    from ipywidgets import widgets 

    import matplotlib.pyplot as plt
    import warnings
    warnings.filterwarnings( "ignore", module = "matplotlib\..*" )
    import numpy as np
    np.seterr(all="ignore")
    from demo_utils.datahandler import DataHandler
    
    pkl_samples = np.sort([p for p in os.listdir(root_path) if p.endswith(".pkl")])


    for pkl_sample in pkl_samples:
        _ = DataHandler(os.path.join(root_path, pkl_sample))

    data = DataHandler(os.path.join(root_path, pkl_samples[0]))
    
    # Prediction figure: two stacked panels (mean prediction, factor-model prediction).
    # Interactive mode is switched off while the figure is built. Presumably this
    # stops ipympl from auto-displaying it; its canvas is embedded in the widget
    # layout at the end of the cell instead.
    plt.ioff()
    sample_fig = plt.figure(1,figsize=(3,6))
    sample_fig_ax = [sample_fig.add_subplot(2,1,1), sample_fig.add_subplot(2,1,2)]
    sample_fig.canvas.resizable = False
    sample_fig_ax[0].set_title("Mean Prediction")
    sample_fig_ax[1].set_title("Factor Model Prediction")


    # Image-only panels: no axis lines or tick labels, square pixels.
    for a in sample_fig_ax:
        a.axis('off')
        a.set_xticklabels([])
        a.set_yticklabels([])
        a.set_aspect('equal')

    # Draw the initial predictions with interactive mode on, then switch it off
    # again before the next figures are created.
    plt.ion()
    # NOTE(review): im_frac / im_rounded are not referenced again in this cell;
    # show_sample() draws new images instead of updating these handles.
    im_frac = sample_fig_ax[0].imshow(data.cmap(data.get_mean_pred()))
    im_rounded = sample_fig_ax[1].imshow(data.cmap(data.get_prediction()))
    plt.ioff()

    # Twenty 1x1-inch thumbnail figures (numbers 10, 20, ..., 200), one axes each.
    # Even slots show a factor's negative image, odd slots its positive image.
    factor_figs, axs_fac = [], []
    for slot in range(20):
        thumb = plt.figure(num=(slot + 1) * 10, figsize=(1.0, 1.0))
        thumb_ax = thumb.add_subplot(1, 1, 1)
        thumb.canvas.resizable = True
        # Image-only thumbnail: hide axis decorations, keep square pixels.
        thumb_ax.axis('off')
        thumb_ax.set_xticklabels([])
        thumb_ax.set_yticklabels([])
        thumb_ax.set_aspect('equal')
        factor_figs.append(thumb)
        axs_fac.append(thumb_ax)

    # Applies to the current pyplot figure, i.e. the last thumbnail created.
    plt.subplots_adjust(wspace=0, hspace=0.3)
    plt.ion()

    # Initial rendering of the first 10 factors (negative, positive) per slot pair.
    factors_pos, factors_neg = data.plot_factors()
    for slot in range(10):
        neg_ax, pos_ax = axs_fac[2 * slot], axs_fac[2 * slot + 1]
        neg_ax.imshow(factors_neg[slot])
        pos_ax.imshow(factors_pos[slot])
    plt.ioff()
    # Wide strip of six panels for the input sample and its label(s).
    label_fig = plt.figure(num=3, figsize=(16, 2))
    axs_label = [label_fig.add_subplot(1, 6, pos) for pos in range(1, 7)]

    label_fig.suptitle('Input and Label(s)', fontsize=12, x=0.18)

    # Read-only text field; update_density() rewrites its value.
    density_text = widgets.Text(value="Pseudo-Density: 1", disabled=True)

    plt.ion()
    data.plot_sample_and_labels(axs_label)
    # Image-only panels: hide axis decorations, keep square pixels.
    for label_ax in axs_label:
        label_ax.axis('off')
        label_ax.set_xticklabels([])
        label_ax.set_yticklabels([])
        label_ax.set_aspect('equal')

    def show_sample():
        """Redraw the mean and factor-model predictions in ``sample_fig``.

        Images already on each panel are removed before the new one is drawn.
        Otherwise every call (one per slider change) stacks another AxesImage
        on the axes, and the cost of each redraw keeps growing over the session.
        """
        def _redraw(ax, rgba):
            # Drop earlier images so they don't pile up underneath the new one.
            for old_image in list(ax.images):
                old_image.remove()
            ax.imshow(rgba)

        _redraw(sample_fig_ax[0], data.cmap(data.get_mean_pred()))
        _redraw(sample_fig_ax[1], data.cmap(data.get_prediction()))
        sample_fig.canvas.draw_idle()

    def update_density():
        """Write the product of per-factor density ratios into ``density_text``.

        Each factor slider value v contributes pdf(v) / pdf(0) of a standard
        normal, so the result is 1 when every factor slider sits at 0. The
        covariance slider (``slider_list[0]``) is not included.
        """
        std_normal = norm()
        peak = std_normal.pdf(0)
        density = 1
        for factor_slider in slider_list[1:]:
            density *= std_normal.pdf(factor_slider.value) / peak
        density_text.value = f"Pseudo-Density: {np.round(density, 5)}"

    def _slider_heading_html(idx, weight):
        """Return the HTML heading "Factor idx+1:" followed by *weight* to 2 decimals."""
        spacer = "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;"
        return f"<font size='4'><center><b>Factor {idx + 1}:</b>{spacer}{weight:.2f}</center>"

    # One heading per displayed factor (10), all starting at weight 0.00.
    slider_headings = [widgets.HTML(_slider_heading_html(k, 0.0)) for k in range(10)]

    def update_slider_headings():
        """Show the current ``data.factor_weights`` value in each factor heading."""
        for k, heading in enumerate(slider_headings):
            heading.value = _slider_heading_html(k, data.factor_weights[k])

    # slider_list[0] controls the covariance scale; slider_list[k] (k >= 1) controls
    # factor k. src_dict maps each slider's model_id to that role for update().
    # NOTE(review): the covariance slider starts at 0 here but reset() sets it to 1,
    # and it is not placed in the widget layout — confirm which default is intended.
    slider_sources = ["Covariance"] + [f"Factor {k + 1}" for k in range(data.factor_model.rank)]
    slider_list = [
        widgets.FloatSlider(value=0, min=-2, max=2, readout=False) for _ in slider_sources
    ]
    src_dict = {slider.model_id: source for slider, source in zip(slider_list, slider_sources)}

    def update(change):
        """Slider observer: push the new value into ``data`` and refresh the UI.

        ``change`` is a traitlets change dict. Its ``owner`` is looked up in
        ``src_dict`` to decide between scaling the covariance and scaling a factor.
        """
        source = src_dict[change["owner"].model_id]
        value = change["new"]

        if source != "Covariance":
            # Factor ids come from the "Factor N" label and are therefore 1-based.
            # NOTE(review): resample() indexes data.factor_weights 0-based —
            # confirm scale_factor expects a 1-based id.
            data.scale_factor(int(source.split(" ")[-1]), value)
        else:
            data.scale_covariance(value)

        update_density()
        show_sample()
        update_slider_headings()


    # Attach the shared observer to every slider: the covariance slider and each
    # factor slider, including the last one (slider_list[rank]).
    for slider in slider_list:
        slider.observe(update, names='value')


    reset_btn = widgets.Button(description="Reset")

    # Example picker; options are the .pkl file names found in root_path.
    img_dd = widgets.Dropdown(options=pkl_samples, layout={'width': 'max-content'})

    rotations_dd = widgets.Dropdown(options=data.available_rotations)
    def update_rotation(change):
        """Observer for ``rotations_dd``: apply the chosen rotation and redraw.

        Calls ``reset`` (data and sliders back to defaults), redraws the
        predictions, then re-renders the first 10 factor thumbnails (even slots
        negative, odd slots positive).
        """
        is_value_change = change['type'] == 'change' and change['name'] == 'value'
        if not is_value_change:
            return

        data.rotate_factors(change["new"])
        reset(None)
        show_sample()
        rotated_pos, rotated_neg = data.plot_factors()
        for slot in range(10):
            axs_fac[2 * slot].imshow(rotated_neg[slot])
            axs_fac[2 * slot + 1].imshow(rotated_pos[slot])


    rotations_dd.observe(update_rotation, names="value")

    def update_img(change):
        """Observer for ``img_dd``: load the selected ``.pkl`` example and redraw.

        ``reset`` runs before ``data.update_example``. Setting ``rotations_dd`` to
        "Unrotated" fires ``update_rotation`` only when the dropdown held a
        different value, because traitlets notifies only on actual changes.
        NOTE(review): slider headings / density are not refreshed after the new
        example is loaded — confirm that is intended.
        """
        reset(None)
        example_path = os.path.join(root_path, change["new"])
        data.update_example(example_path)
        rotations_dd.value = "Unrotated"

        new_pos, new_neg = data.plot_factors()
        for slot in range(10):
            neg_ax = axs_fac[2 * slot]
            pos_ax = axs_fac[2 * slot + 1]
            neg_ax.imshow(new_neg[slot])
            pos_ax.imshow(new_pos[slot])

        data.plot_sample_and_labels(axs_label)
        show_sample()
    img_dd.observe(update_img, names="value")


    def reset(b):
        """Reset ``data`` and all sliders to their defaults, then refresh the UI.

        Slider values are written while the ``update`` observer is detached, so
        each assignment does not trigger its own recompute/redraw. Only the
        ``update`` handler on ``value`` is detached. The previous
        ``unobserve_all()`` also removed every other notifier on every trait of
        the widget.

        Args:
            b: the clicked ``widgets.Button`` (unused; callers in this cell pass None).
        """
        def _set_without_update(slider, value):
            try:
                slider.unobserve(update, names='value')
            except ValueError:
                pass  # `update` was not attached to this slider; nothing to detach.
            slider.value = value
            slider.observe(update, names='value')

        data.reset()
        _set_without_update(slider_list[0], 1)
        for slider in slider_list[1:]:
            _set_without_update(slider, 0)

        update_slider_headings()
        show_sample()
        update_density()

    reset_btn.on_click(reset)

    def resample(b):
        """Reset, call ``data.resample()``, and mirror the new factor weights on the sliders.

        Slider values are written with only the ``update`` observer detached
        (not ``unobserve_all()``, which strips every notifier on the widget), so
        the assignments do not each trigger a redraw.

        Args:
            b: the clicked ``widgets.Button`` (unused).
        """
        def _set_without_update(slider, value):
            try:
                slider.unobserve(update, names='value')
            except ValueError:
                pass  # `update` was not attached to this slider; nothing to detach.
            slider.value = value
            slider.observe(update, names='value')

        reset(None)
        data.resample()
        # Factor slider k+1 shows data.factor_weights[k].
        for idx, slider in enumerate(slider_list[1:]):
            _set_without_update(slider, data.factor_weights[idx])

        update_slider_headings()
        show_sample()
        update_density()

    resample_btn = widgets.Button(description="Resample")
    resample_btn.on_click(resample)

    # Strip the ipympl toolbar, header and footer from every embedded canvas
    # (factor thumbnails first, then the prediction and label figures).
    for embedded_fig in (*factor_figs, sample_fig, label_fig):
        canvas = embedded_fig.canvas
        canvas.toolbar_visible = False
        canvas.header_visible = False
        canvas.footer_visible = False

    main_v1 = widgets.HBox([rotations_dd, reset_btn, resample_btn, density_text])
    main_v3 = widgets.VBox([sample_fig.canvas])
    main_v2 = widgets.VBox([main_v1, 
        *[widgets.HBox([widgets.HBox([factor_figs[2*i].canvas, widgets.VBox([slider_headings[i], slider_list[i+1]]), factor_figs[2*i+1].canvas],layout=widgets.Layout(border='solid')), 
        widgets.HBox([factor_figs[2*i+2*5].canvas, widgets.VBox([slider_headings[i+5], slider_list[i+5+1]]), factor_figs[2*i+2*5+1].canvas],layout=widgets.Layout(border='solid'))]) for i in range(5)]])

    main = widgets.HBox([main_v2, main_v3])
widgets.VBox([img_dd, label_fig.canvas, main])