In [1]:
from IPython.display import HTML
# Load the shared notebook stylesheet and let the notebook container span the full window width.
display(
    HTML("<head><link rel='stylesheet' type='text/css' href='../custom.css'></head>"),
    HTML("<style>.container { width:100% !important; }</style>"),
)

In [2]:
from bqplot import *
import bqplot as bq
import bqplot.marks as bqm
import bqplot.scales as bqs
import bqplot.axes as bqa

import ipywidgets as widgets
import numpy as np

import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

from matplotlib import rc

In [3]:
# Both axes share the same linear [0.7, 3.0] window. Axis labels are sized with
# the same 'font-size' property that tick_style uses.
scale_x = bqs.LinearScale(min=0.7, max=3.0)
scale_y = bqs.LinearScale(min=0.7, max=3.0)

axis_x = bqa.Axis(scale=scale_x,
                  tick_format='0.2f',
                  tick_style={'font-size': '15px'},
                  grid_lines='none',
                  grid_color='#8e8e8e',
                  label='v [log]',
                  label_location='middle',
                  label_style={'stroke': 'black', 'font-size': '35px'},
                  label_offset='50px')

axis_y = bqa.Axis(scale=scale_y,
                  tick_format='0.2f',
                  tick_style={'font-size': '15px'},
                  grid_lines='none',
                  grid_color='#8e8e8e',
                  orientation='vertical',
                  label='p',
                  label_location='middle',
                  label_style={'stroke': 'red', 'font-size': '35px'},
                  label_offset='50px')

# Demo data: three points drawn as a black scatter mark named 'labels'.
scat = bqm.Scatter(
    name='labels',
    x=[1.0, 2.0, 2.5],
    y=[1.0, 2.0, 2.5],
    scales={'x': scale_x, 'y': scale_y},
    visible=True,
    colors=['black'],
    names=[],
)

fig = bq.Figure(title='Van der Waals isotherms',
                marks=[scat],
                axes=[axis_x, axis_y],
                animation_duration=500,
                layout=widgets.Layout(align_self='center', width='100%'),
                legend_location='top-right',
                background_style={'fill': 'white', 'stroke': 'black'})

In [4]:
# Show the interactive bqplot figure as this cell's output.
fig

Figure(animation_duration=500, axes=[Axis(grid_color='#8e8e8e', grid_lines='none', label='v [log]', label_offs…

In [5]:
def get_limits(figure):
    """Return the data extent of ``figure`` as ``(max_x, min_x, max_y, min_y)``.

    When the first mark has a non-empty selection, only the selected points of
    that mark are measured. Otherwise the extent covers every point of every
    mark in ``figure.marks``.

    Parameters
    ----------
    figure : bqplot.Figure
        Figure whose marks expose ``x``, ``y`` and ``selected``.

    Returns
    -------
    tuple
        ``(max_x, min_x, max_y, min_y)``; note the max-before-min order.
    """
    data_mark = figure.marks[0]
    selected = data_mark.selected

    # Identity/length checks rather than `selected == None`: if `selected` is a
    # numpy array, `==` compares element-wise and the `if` would raise.
    if selected is not None and len(selected) > 0:
        sel_x = np.take(data_mark.x, selected)
        sel_y = np.take(data_mark.y, selected)
        return max(sel_x), min(sel_x), max(sel_y), min(sel_y)

    # Measure the marks of the figure that was passed in, not a notebook-level `fig`.
    marks = figure.marks
    max_x = max(max(mark.x) for mark in marks)
    min_x = min(min(mark.x) for mark in marks)
    max_y = max(max(mark.y) for mark in marks)
    min_y = min(min(mark.y) for mark in marks)
    return max_x, min_x, max_y, min_y

In [6]:
def bqplot_to_matplotlib(fig):
    """Render a bqplot figure as a static, LaTeX-typeset matplotlib figure.

    Reproduces the axis scale types (linear/log) with padded limits, the title,
    the axis labels, a legend and every Scatter/Lines mark of ``fig``. Scatter
    marks with at least three selected points are drawn from the selection only;
    Lines marks are always drawn in full.

    Side effects: switches pyplot interactive mode off and sets
    ``text.usetex=True`` and ``errorbar.capsize=3`` globally, so drawing the
    result needs a working LaTeX installation.

    Parameters
    ----------
    fig : bqplot.Figure
        Source figure; ``fig.axes[0]`` is used as the x axis and ``fig.axes[1]``
        as the y axis.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    plt.ioff()

    axis_x, axis_y = fig.axes[0], fig.axes[1]

    plt.rc('text', usetex=True)
    mpl.rcParams['errorbar.capsize'] = 3

    fig_pdf, ax0 = plt.subplots(nrows=1, ncols=1, figsize=(16, 12))
    fig_pdf.subplots_adjust(left=0.25, bottom=0.25, right=0.9, top=None, wspace=0.0, hspace=0.0)
    ax0.tick_params(axis='both', labelsize=30, pad=10, length=12)

    x_max, x_min, y_max, y_min = get_limits(fig)
    _mirror_axis_scale(ax0, 'x', axis_x.scale, x_min, x_max)
    _mirror_axis_scale(ax0, 'y', axis_y.scale, y_min, y_max)

    for mark in fig.marks:
        label = r'\textrm{%s}' % mark.name
        if type(mark) is bqm.Scatter:
            data_x, data_y = _plotted_data(mark)
            ax0.scatter(data_x,
                        data_y,
                        label=label,
                        color=mark.colors[0],
                        s=40 * mark.stroke_width,
                        alpha=_first_opacity(mark.default_opacities))
        elif type(mark) is bqm.Lines:
            ax0.plot(mark.x,
                     mark.y,
                     label=label,
                     color=mark.colors[0],
                     linewidth=mark.stroke_width,
                     alpha=_first_opacity(mark.opacities))

    # Figure-level decorations are applied once, after every mark is drawn.
    if (fig.title or '').split():
        ax0.set_title(r'\textrm{%s}' % fig.title, size=35, pad=20)
    ax0.set_xlabel(r'\textrm{%s}' % axis_x.label, size=30, labelpad=15)
    ax0.set_ylabel(r'\textrm{%s}' % axis_y.label, size=30, labelpad=15)

    # Legend anchored in figure coordinates just right of the axes (which end at right=0.9).
    ax0.legend(bbox_transform=fig_pdf.transFigure,
               bbox_to_anchor=(0.90, 0.90),
               loc='upper left',
               ncol=1,
               borderaxespad=1.5,
               frameon=False,
               fontsize=25)

    # NOTE(review): only the x axis' grid_lines setting decides whether a grid is drawn.
    if axis_x.grid_lines != 'none':
        ax0.grid(color='grey', alpha=0.5, linewidth=0.5)

    return fig_pdf


def _mirror_axis_scale(ax, which, scale, low, high):
    """Copy a bqplot scale type onto one matplotlib axis and pad its limits around [low, high].

    Log scales get a multiplicative 1.2x margin, linear scales a margin of 10% of
    the data span; any other scale type leaves the axis untouched.
    """
    if which == 'x':
        set_scale, set_lim = ax.set_xscale, ax.set_xlim
    else:
        set_scale, set_lim = ax.set_yscale, ax.set_ylim

    if type(scale) is bqs.LogScale:
        set_scale('log')
        set_lim(low / 1.2, high * 1.2)
    elif type(scale) is bqs.LinearScale:
        margin = (high - low) * 0.1
        set_lim(low - margin, high + margin)


def _plotted_data(mark):
    """Return the (x, y) data to draw for ``mark``: the selection if it has 3+ points, else everything."""
    selected = mark.selected
    # `is None` instead of `== None`: if `selected` is a numpy array, `==` is
    # element-wise and its truth value raises for more than one element.
    if selected is None or len(selected) < 3:
        return mark.x, mark.y
    return np.take(mark.x, selected), np.take(mark.y, selected)


def _first_opacity(values, default=1.0):
    """Return the first entry of an opacity list, or ``default`` when it is None or empty."""
    if values is None or len(values) == 0:
        return default
    return values[0]
   #if pdf:
   #    save_filename = pdfName.value

   #    if len(save_filename.split()) == 0:
   #        save_filename = 'pdf_file'

   #    if '.pdf' not in save_filename:
   #        save_filename = save_filename + '.pdf'

   #    plt.savefig(save_filename, format='pdf', dpi=300, bbox_inches="tight");
   #    return (save_filename)
   #
   #if png:
   #    save_filename = pngName.value
   #    if len(save_filename.split()) == 0:
   #        save_filename = 'png_file'

   #    if '.png' not in save_filename:
   #        save_filename = save_filename + '.png'
   #    plt.savefig(save_filename, format='png', dpi=300, bbox_inches="tight");
   #    return (save_filename)

In [7]:
def refresh_mpl_figure(a):
    """Redraw the matplotlib preview of ``fig`` inside the ``mpl_out`` output widget.

    Button callback (``a`` is the clicked button and is unused). The title is
    passed through ``parse_text`` only while the matplotlib figure is built and
    is then restored. Repeated refreshes therefore do not keep re-wrapping the
    bqplot title in ``$...$``.
    """
    plt.close('all')
    mpl_out.clear_output()

    raw_title = fig.title
    fig.title = parse_text(raw_title)
    try:
        with mpl_out:
            display(bqplot_to_matplotlib(fig))
    finally:
        # Keep the interactive bqplot title exactly as the user typed it.
        fig.title = raw_title

In [8]:
def parse_text(string):
    """Prepare a title/label string for matplotlib's LaTeX rendering.

    - Text without ``_`` is returned unchanged.
    - Text containing ``_`` is wrapped in ``$...$`` (math mode) unless it is
      already wrapped, so applying the function repeatedly is safe.
    - Text ending in ``_`` is rejected: an error is shown in
      ``latex_error_message`` and ``''`` is returned.

    In the first two cases ``latex_error_message`` is cleared.
    """
    if "_" not in string:
        latex_error_message.value = ""
        return string

    if string.endswith("_"):
        latex_error_message.value = "<p style='color:red'>Latex could not parse the text in the title/axes. Rewrite it and retry</p>"
        return ''

    latex_error_message.value = ""
    # Already in math mode: wrapping again would produce `$$...$$`.
    if len(string) >= 2 and string.startswith("$") and string.endswith("$"):
        return string
    return "$" + string + "$"

In [9]:
def rename_labels(change):
    """Push the header text boxes onto the figure.

    Observer shared by the title, x-label and y-label text widgets. ``change``
    is ignored; all three values are re-read on every call.
    """
    new_title = figure_title.value
    fig.title = new_title
    x_axis, y_axis = fig.axes[0], fig.axes[1]
    x_axis.label = label_x.value
    y_axis.label = label_y.value

In [10]:
def change_data_name(change):
    """Rename the mark whose name text box was edited.

    The edited widget is matched against the first widget of row 0 in each
    ``middle_block`` panel. The mark with the same index in ``fig.marks`` gets
    the new name; if no panel matches, the last index is used.
    """
    edited = change.owner
    for position, panel in enumerate(middle_block.children):
        if panel.children[0].children[0] is edited:
            break
    fig.marks[position].name = edited.value

In [11]:
def change_data_style(change):
    """Swap a mark between Scatter and Lines when its style dropdown changes.

    Observer for the dropdown at row 0, position 1 of each ``middle_block``
    panel. The mark at the same index in ``fig.marks`` is replaced by a new mark
    of the chosen type that keeps its data, scales, colours, name and stroke
    width. If no panel matches, the last index is used.
    """
    obj = change.owner

    for i, panel in enumerate(middle_block.children):
        if panel.children[0].children[1] is obj:
            break

    mark = fig.marks[i]
    marks = list(fig.marks)

    if obj.value == "Scatter":
        marks[i] = bqm.Scatter(
            x=mark.x,
            y=mark.y,
            scales=mark.scales,
            enable_move=False,
            restrict_x=False,
            restrict_y=False,
            selected_style={'opacity': '1'},
            unselected_style={'opacity': '0.2'},
            selected=None,
            colors=mark.colors,
            name=mark.name,
            stroke_width=mark.stroke_width,
        )

    elif obj.value == "Lines":
        # Sort the (x, y) points together by x so every y stays with its own x;
        # sorting x and y independently would scramble non-monotonic curves.
        points = sorted(zip(mark.x, mark.y), key=lambda point: point[0])
        marks[i] = bqm.Lines(
            x=[x for x, _ in points],
            y=[y for _, y in points],
            scales=mark.scales,
            visible=True,
            colors=mark.colors,
            name=mark.name,
            stroke_width=mark.stroke_width,
        )

    fig.marks = marks

In [12]:
def change_data_width(change):
    """Copy a width-slider value onto the matching mark's ``stroke_width``.

    The slider is matched against the first widget of row 1 in each
    ``middle_block`` panel. The mark with the same index is updated; if nothing
    matches, the last index is used.
    """
    slider = change.owner
    for idx, panel in enumerate(middle_block.children):
        if panel.children[1].children[0] is slider:
            break
    target = fig.marks[idx]
    target.stroke_width = slider.value

In [13]:
def change_data_opacity(change):
    """Copy an opacity-slider value onto the matching mark.

    The slider is matched against the first widget of row 2 in each
    ``middle_block`` panel. The mark with the same index gets the value as a
    one-element list in both ``opacities`` and ``default_opacities``; if nothing
    matches, the last index is used.
    """
    slider = change.owner
    for idx, panel in enumerate(middle_block.children):
        if panel.children[2].children[0] is slider:
            break
    target = fig.marks[idx]
    target.opacities = [slider.value]
    target.default_opacities = [slider.value]

In [14]:
def change_data_color(change):
    """Copy a colour-picker value onto the matching mark's ``colors``.

    The picker is matched against the first widget of row 3 in each
    ``middle_block`` panel. The mark with the same index gets a one-element
    colour list; if nothing matches, the last index is used.
    """
    picker = change.owner
    for idx, panel in enumerate(middle_block.children):
        if panel.children[3].children[0] is picker:
            break
    target = fig.marks[idx]
    target.colors = [picker.value]

In [15]:
# Class name of the first mark ('Scatter' / 'Lines'), the same strings the style dropdowns offer.
type(fig.marks[0]).__name__

'Scatter'

In [16]:
def return_interface(fig):
    """Return the editor widget for ``fig``.

    Reuses an interface already published as the module-level ``main_block``
    when one exists. Otherwise it builds one with ``create_interface(fig)`` and
    returns that result.
    """
    existing = globals().get('main_block')
    if existing is not None:
        return existing
    return create_interface(fig)

In [20]:
def create_interface(fig):
    """Build the interactive editor for a bqplot figure and return its root widget.

    Layout:
    - a header with text boxes for the title and both axis labels;
    - then three columns: the live ``fig`` (left), one style panel per mark
      (middle: name, mark type, width, opacity, colour), and a matplotlib
      preview with a refresh button (right).

    The widget callbacks defined in the other cells (``rename_labels``,
    ``change_data_*``, ``refresh_mpl_figure`` / ``parse_text``) read these
    widgets through module-level names, so those names are published as
    globals here.
    NOTE(review): those callbacks also act on the module-level ``fig``, not on
    this function's argument; pass that same figure.

    Parameters
    ----------
    fig : bqplot.Figure
        Figure to edit; ``fig.axes[0]`` / ``fig.axes[1]`` are the x / y axes.

    Returns
    -------
    ipywidgets.VBox
        The root container (also stored as ``main_block``).
    """
    # The callbacks in the other cells look these widgets up as module-level names.
    global main_block, middle_block, figure_title, label_x, label_y, mpl_out, latex_error_message

    # --- layout containers ---------------------------------------------------
    main_block = widgets.VBox([])

    head_block = widgets.VBox([], layout=widgets.Layout(margin='20px 20px 20px 20px', align_items='center'))
    body_block = widgets.HBox([])

    left_block = widgets.VBox([], layout=widgets.Layout(width='35%', align_items='center'))
    middle_block = widgets.VBox([], layout=widgets.Layout(width='20%', margin='50px 0 0 0'))
    right_block = widgets.VBox([], layout=widgets.Layout(width='45%', align_items='center'))

    # --- header: title and axis-label text boxes -----------------------------
    figure_title = widgets.Text(value=fig.title, placeholder='', description='Title:', disabled=False)
    label_x = widgets.Text(value=fig.axes[0].label, placeholder='', description='Axis x:', disabled=False)
    label_y = widgets.Text(value=fig.axes[1].label, placeholder='', description='Axis y:', disabled=False)
    for text_box in (figure_title, label_x, label_y):
        text_box.observe(rename_labels, 'value')

    head_block.children = (
        figure_title,
        widgets.HBox([label_x, label_y], layout=widgets.Layout(margin='10px 0 0 0')),
    )

    # --- left: the live bqplot figure -----------------------------------------
    left_block.children = (fig,)

    # --- middle: one style panel per mark -------------------------------------
    panels = []
    for mark in fig.marks:
        data_name_text = widgets.Text(value=mark.name)
        data_name_text.observe(change_data_name, 'value')

        # Preselect this mark's own class name ('Scatter' or 'Lines').
        data_style_dropdown = widgets.Dropdown(
            options=['Scatter', 'Lines'],
            description='',
            value=type(mark).__name__,
        )
        data_style_dropdown.observe(change_data_style, 'value')

        width_slider = widgets.FloatSlider(
            value=mark.stroke_width,
            min=0.0,
            max=8.0,
            step=0.2,
            description='Width:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
            layout=widgets.Layout(width='100%'),
        )
        width_slider.observe(change_data_width, 'value')

        # Start from the mark's current opacity (Scatter keeps it in
        # `default_opacities`, Lines in `opacities`); default to fully opaque.
        opacity_values = getattr(mark, 'default_opacities', None) or getattr(mark, 'opacities', None)
        opacity_slider = widgets.FloatSlider(
            value=opacity_values[0] if opacity_values else 1.0,
            min=0.0,
            max=1.0,
            step=0.1,
            description='Opacity:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.2f',
            layout=widgets.Layout(width='100%'),
        )
        opacity_slider.observe(change_data_opacity, 'value')

        data_color = widgets.ColorPicker(
            concise=False,
            description='',
            value=mark.colors[0],
            disabled=False,
        )
        data_color.observe(change_data_color, 'value')

        # The change_data_* callbacks locate a widget by its position in this panel:
        # row 0 = [name, style], row 1 = width, row 2 = opacity, row 3 = colour.
        panels.append(widgets.VBox([
            widgets.HBox([data_name_text, data_style_dropdown]),
            widgets.HBox([width_slider]),
            widgets.HBox([opacity_slider]),
            widgets.HBox([data_color]),
            widgets.HTML(value="<hr>"),
        ]))
    middle_block.children = tuple(panels)

    # --- right: matplotlib preview ---------------------------------------------
    mpl_out = widgets.Output()
    latex_error_message = widgets.HTML(value="")

    refresh_mpl_figure_button = widgets.Button(
        description='Refresh figure',
        disabled=False,
        button_style='',  # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Click me',
    )
    refresh_mpl_figure_button.on_click(refresh_mpl_figure)

    right_block.children = (mpl_out, latex_error_message, refresh_mpl_figure_button)

    body_block.children = (left_block, middle_block, right_block)
    main_block.children = (head_block, body_block)

    return main_block