In [None]:
import plotly.graph_objs as go
import numpy as np

In [None]:
class XYChart:
    """Interactive Plotly line chart for one or more XY datasets.

    Wraps a ``plotly.graph_objs.FigureWidget`` and redraws it in place on every
    call to :meth:`update`. The widget is displayed as soon as the object is
    constructed.

    ``x`` and ``y`` may be a single 1D sequence, a list of sequences, or a 2D
    array; ``format_xy_data`` turns them into one 1D x/y pair per dataset and
    each pair is drawn as its own line trace.
    """

    def __init__(self, x=None, y=None, x_label='x', y_label='y', title='Plot', names=None):
        """Create the widget, draw the initial data and display it.

        All arguments are forwarded unchanged to :meth:`update`.
        """
        self.chart = go.FigureWidget()
        self.update(x=x, y=y, x_label=x_label, y_label=y_label, title=title, names=names)
        self.display()

    def update(self, x=None, y=None, x_label='x', y_label='y', title='Plot', names=None):
        """Replace every trace, the axis titles and the chart title.

        Parameters
        ----------
        x : array-like or None
            X values. A single 1D sequence is reused for every dataset; a
            multi-dataset input supplies one x vector per dataset. ``None``
            plots each dataset against its sample index.
        y : array-like or None
            Y values: 1D for a single dataset, a list / 2D array for several.
            ``None`` plots a single zero-valued placeholder.
        x_label, y_label : str
            Axis titles. Labels that are not passed revert to these defaults.
        title : str
            Chart title.
        names : sequence of str or None
            Legend entries; missing names are filled in by ``pad_names``.
        """
        # Clear current data (needed if num datasets goes down to clear extra
        # legend entries): with an empty trace list, figure.update() below adds
        # the new traces instead of merging them into the old ones.
        self.chart.data = []

        # Placeholder so an "empty" chart can still be drawn.
        # np.zeros takes the shape as a single tuple (np.zeroes does not exist).
        if y is None:
            y = [np.zeros((1, 1))]

        # Format the x and y data into lists of 1D numpy arrays, one per dataset.
        x_format, y_format = format_xy_data(x, y)

        # Pad list of names with y0, y1, etc so every dataset has a legend entry.
        num_datasets = len(y_format)
        names = pad_names(names, num_datasets)

        # One line trace per dataset.
        data = [
            dict(type='scatter',
                 x=x_format[ii],
                 y=y_format[ii],
                 mode='lines',
                 name=names[ii])
            for ii in range(num_datasets)
        ]

        layout = dict(
            title=title,
            xaxis=dict(title=x_label),
            yaxis=dict(title=y_label),
        )

        self.chart.update(dict(data=data, layout=layout))

    def display(self):
        """Render the widget in the notebook output area.

        The bare name ``display`` below resolves to IPython's builtin
        ``display`` (available inside Jupyter), not to this method; it is not
        imported in this notebook.
        """
        display(self.chart)
    

In [None]:
def check_if_1d(x):
    """Classify plotting data as 1D or multi-dataset and choose its plot axis.

    Each element of ``x`` is measured with ``len()``; elements without a
    length (scalars, 0-d arrays) count as length 1.

    Parameters
    ----------
    x : sequence, numpy array or None
        Candidate plotting data.

    Returns
    -------
    tuple of (bool, int)
        ``is_1D`` -- True when ``x`` is None or no element is longer than 1.
        ``plot_axis`` -- 0 when all elements share one length and none is
        longer than ``x`` itself (the vectors run down the columns); 1 when
        any element is longer than ``x`` or the lengths differ (each element
        is its own vector).
    """
    if x is None:
        return (True, 0)

    outer_len = len(x)
    element_lengths = []
    for idx in range(outer_len):
        try:
            element_lengths.append(len(x[idx]))
        except TypeError:
            # Unsized element: treat as a single value.
            element_lengths.append(1)

    is_1D = all(n <= 1 for n in element_lengths)
    uniform = len(set(element_lengths)) <= 1
    longer_than_outer = any(n > outer_len for n in element_lengths)

    plot_axis = 1 if (longer_than_outer or not uniform) else 0
    return (is_1D, plot_axis)

In [None]:
def format_xy_data(x_in, y_in):
    """Normalise x/y plotting data into parallel lists of 1D numpy arrays.

    ``check_if_1d`` decides the layout of ``y_in``:

    * 1D            -> a single dataset;
    * plot axis 0   -> equal-length rows, datasets are the columns;
    * plot axis 1   -> each element of ``y_in`` is its own dataset.

    Parameters
    ----------
    x_in : array-like or None
        X values: one 1D sequence reused for every dataset, one sequence per
        dataset (laid out like ``y_in``), or None to use the sample index.
    y_in : array-like
        Y values as described above.

    Returns
    -------
    tuple of (list of ndarray, list of ndarray)
        ``(x, y)`` where ``x[i]`` is the x vector for ``y[i]``. Both are
        always lists, including in the single-dataset case.

    Raises
    ------
    ValueError
        From ``reshape`` when a sequence expected to be 1D holds more than
        one value per element.
    """
    y_is_1d, plot_axis = check_if_1d(y_in)
    x_is_1d = check_if_1d(x_in)[0]

    def _flat(values):
        # Collapse one sequence (e.g. [[1], [2]] or [1, 2]) to a 1D array.
        return np.array(values).reshape(len(values))

    if y_is_1d:
        # Single dataset.
        y = [_flat(y_in)]
        if x_in is not None:
            x = [_flat(x_in)]
        else:
            # Must be wrapped in a list like every other branch; returning the
            # bare arange made x[0] the scalar 0 instead of the index vector.
            x = [np.arange(len(y[0]))]

    elif plot_axis == 0:
        # Datasets run down the columns: transpose so each column is one array.
        y = list(np.array(y_in).transpose())
        if x_in is None:
            x = [np.arange(len(y_vals)) for y_vals in y]
        elif x_is_1d:
            # One shared x vector (same array object) for every dataset.
            x = [_flat(x_in)] * len(y)
        else:
            x = list(np.array(x_in).transpose())

    else:
        # plot_axis == 1: each element is its own dataset; lengths may differ.
        y = [_flat(row) for row in y_in]
        if x_in is None:
            x = [np.arange(len(y_vals)) for y_vals in y]
        elif x_is_1d:
            x = [_flat(x_in)] * len(y)
        else:
            x = [_flat(row) for row in x_in]

    return (x, y)

In [None]:
def pad_names(names, num_datasets):
    """Return a list of legend names covering at least ``num_datasets`` entries.

    Parameters
    ----------
    names : sized iterable of str or None
        Caller-supplied names; copied, never mutated. Extra names beyond
        ``num_datasets`` are kept.
    num_datasets : int
        Number of datasets that need a name.

    Returns
    -------
    list of str
        The supplied names followed by defaults 'y0', 'y1', ... for the
        missing ones.

    NOTE(review): default numbering restarts at 'y0' rather than following
    the dataset index, e.g. names=['a'] with 3 datasets gives
    ['a', 'y0', 'y1'] -- confirm this is intended.
    """
    if names is None:
        provided, labels = 0, []
    else:
        provided, labels = len(names), list(names)

    missing = num_datasets - provided
    return labels + ['y' + str(k) for k in range(missing)]

In [None]:
#if __name__ == '__main__':
# x = np.array([[1, 2, 5, 6],[1, 2, 3, 4, 5, 6],[1, 3, 4, 5, 6]]).transpose()
# y = [[1, 2, 5, 6],[-1, -3, -4, -5, -6],[-0.9, -1, -2, -3, -4]]
# xy_chart = XYChart(x, y, x_label='bad', y_label='worse', title='worst', names=['1', 'oreo'])

In [None]:
# y = [np.random.randint(0,5,10),np.random.randint(-6,-3,15)]
# xy_chart.update(y=y, x_label='Idx', y_label='Rand', title='RANDER', names=['rand', 'oreo'])
# xy_chart.update(y=y, x_label='Idx', y_label='Rand', title='RANDER', names=['rand', 'oreo'])
# xy_chart.update(y=y, x_label='Idx', y_label='Rand', title='RANDER', names=['rand', 'oreo'])
# xy_chart.update(y=y, x_label='Idx', y_label='Rand', title='RANDER', names=['rand', 'oreo'])