<a href="https://colab.research.google.com/github/alexandrelacombeLO/python_notebooks/blob/master/traitement%20traction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install dash plotly

In [None]:
import base64
import io
from pathlib import Path

import dash
from dash import Dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
from dash import dash_table
from dash_table.Format import Format, Scheme
from dash.exceptions import PreventUpdate

import numpy as np
import pandas as pd

from sklearn.linear_model import LinearRegression
import scipy.stats

import plotly.express as px
import plotly.graph_objects as go
from plotly.validators.scatter.marker import SymbolValidator
# Keep every 12th entry of the validator's value list, restricted to entries
# below 10; the result is passed as `symbol_sequence` to the scatter plot.
raw_symbols = SymbolValidator().values
symbols = [
    candidate
    for position, candidate in enumerate(raw_symbols)
    if position % 12 == 0 and candidate < 10
]


# Dropdown choosing between the two supported chart kinds (Scatter by default).
chart_type = html.Div(
    children=[
        html.Label("Chart Type"),
        dcc.Dropdown(
            id='chartType-dropdown',
            options=[{'label': kind, 'value': kind} for kind in ('Boxplot', 'Scatter')],
            value='Scatter',
            clearable=False,
        ),
    ],
    style={'width': '10%', 'margin-right': '10px'},
)

# Debounced free-text input holding the figure title.
title = html.Div(
    children=[
        html.Label("Title"),
        dcc.Input(id='title-input', type="text", size='100', debounce=True),
    ],
    style={'width': '40%'},
)

def _labeled_dropdown(label, component_id, width, **dropdown_kwargs):
    """Return a Div holding a Label and a Dropdown, sized to `width`."""
    return html.Div(
        [
            html.Label(label),
            dcc.Dropdown(id=component_id, **dropdown_kwargs),
        ],
        style={'width': width},
    )


# Plot-argument selectors shown for every chart type.
x_axis = _labeled_dropdown("X-axis", 'xaxis-dropdown', '10%')
y_axis = _labeled_dropdown("Y-axis", 'yaxis-dropdown', '10%')
color = _labeled_dropdown("Color", 'color-dropdown', '10%')
col = _labeled_dropdown("Columns", 'col-dropdown', '10%')
row = _labeled_dropdown("Rows", 'row-dropdown', '10%')

# Selectors placed in the 'scatter_components' container of the layout.
size = _labeled_dropdown("Marker Size", 'size-dropdown', '15%')
symbol = _labeled_dropdown("Marker Symbol", 'symbol-dropdown', '15%')
text = _labeled_dropdown("Marker Text", 'text-multidropdown', '15%', multi=True, value=[])
groupby = _labeled_dropdown("Groupby", 'groupby-dropdown', '15%')

# Multi-select of extra columns to display on hover.
hover_data = _labeled_dropdown("Hover data", 'hover-multidropdown', '40%', multi=True, value=[])

# Radio buttons selecting how linear regressions are fitted ('none' by default).
reg = html.Div(
    children=[
        html.Label("Linear Regression"),
        dcc.RadioItems(
            id='reg-radioitem',
            options=[{'label': mode, 'value': mode} for mode in ('all', 'by color', 'none')],
            value='none',
            labelStyle={'display': 'inline-block', 'margin-right': '10px'},
        ),
    ],
    style={'width': '20%'},
)

# Single-option checklist: a non-empty value means error bars are requested.
errorbar = html.Div(
    children=[
        html.Label("show errorbars"),
        dcc.Checklist(
            id='errorbar-checklist',
            options=[{'label': 'Yes', 'value': 'Yes'}],
            value=[],
        ),
    ],
    style={'width': '12%'},
)

# Multi-select of categorical columns to keep alongside the group-by column.
keep_cat_col = html.Div(
    children=[
        html.Label('keep categorical columns'),
        dcc.Dropdown(id='keep_cat_col', multi=True, value=[]),
    ],
    style={'width': '20%', 'margin-right': '10px'},
)

# Horizontal range slider controlling the x-axis limits of the graph.
xaxis_range = html.Div(
    [
        html.Label("x-axis range"),
        dcc.RangeSlider(id='xaxis-range', min=0, tooltip={'placement': 'bottom'}),
    ],
    style={'font-size': '80%'},
)

# Vertical range slider controlling the y-axis limits of the graph.
yaxis_range = html.Div(
    [
        html.Label("y-axis range"),
        dcc.RangeSlider(
            id='yaxis-range',
            min=0,
            tooltip={'placement': 'bottom'},
            vertical=True,
        ),
    ],
    style={'font-size': '80%'},
)

# Stylesheet loaded by the Dash application.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

# Style of the tab bar container.
tabs_style = dict(margin='auto', height='5%', width='50%')

# Style of an unselected tab.
tab_style = dict(
    borderBottom='1px solid #d6d6d6',
    padding='6px',
    fontWeight='bold',
)

# Style of the currently selected tab.
tab_selected_style = dict(
    borderTop='1px solid #d6d6d6',
    borderBottom='1px solid #d6d6d6',
    backgroundColor='#119DFF',
    color='white',
    padding='6px',
)

app = Dash(__name__, external_stylesheets=external_stylesheets)

# Page layout: a hidden Div (dummy output of the pivot-export callback), a
# Store holding the table data shared with the plotting callbacks, and two
# tabs ('DataTable' for upload/editing, 'Plots' for the figure and its controls).
app.layout = html.Div([
    # Never displayed; only exists so `download_pivot` has an Output.
    html.Div(id='hidden-div', style={'display':'none'}),
    # Filled by `store_data`, read by every plotting callback.
    dcc.Store(id='memory'),
    dcc.Tabs(id='Tab', value='DataTable', children=[
        # ---- Tab 1: file upload, ID splitting, pivot export and the data table ----
        dcc.Tab(value='DataTable', label='DataTable', children=[
            dcc.Upload(
            id='upload-data',
            children=html.Div([
                'Drag and Drop or ',
                html.A('Select Files')
            ]),
            style={
                'width': '50%',
                'height': '40px',
                'lineHeight': '40px',
                'borderWidth': '1px',
                'borderStyle': 'dashed',
                'borderRadius': '5px',
                'textAlign': 'center',
                'margin': '1% auto'
            },
            # Several files may be uploaded at once.
            multiple=True
        ),
            html.Div([
                html.Div([
                    # Names of the columns created by splitting 'ID' on '_'.
                    dcc.Input(
                        id='nomenclature',
                        type='text',
                        placeholder='Split with \'_\' separator',
                    ),
                    html.Button(
                        children='Split',
                        id='Split',
                        n_clicks=0,
                        style={'margin': '0 15px 0 5px'}
                    )
                ]),
                html.Button(
                    children='Download pivot table',
                    id='pivotTable',
                    n_clicks=0,
                )
            ], 
                style={'display': 'flex', 'justify-content': 'flex-end', 'margin-bottom': '5px'}
            ), 
            # All table properties are filled in by the `update_output` callback.
            dash_table.DataTable(id='df')
        ],
        style=tab_style, selected_style=tab_selected_style
        ),
        # ---- Tab 2: plot controls, range sliders and the graph ----
        dcc.Tab(value='Plots', label='Plots', children=[
            html.Div(
                [
                    chart_type,
                    title
                ],
                style={'display':'flex', 'flex-flow': 'row wrap', 'justify-content': 'stretch', 'margin-bottom': '10px', 'margin-top': '10px', 'font-size': '80%'}
            ),
            html.Div(
                [
                    x_axis,
                    y_axis,
                    color,
                    col,
                    row,
                    hover_data,
                ], 
                style={'display':'flex', 'flex-flow': 'row wrap', 'justify-content': 'space-between', 'margin-bottom': '10px', 'margin-top': '10px', 'width': '90%', 'font-size': '80%'}
            ),
            # Hidden initially; its style is set by `set_dropdowns_layout`.
            html.Div(
                id='scatter_components', 
                children=[
                    groupby,
                    size,
                    symbol,
                    text,
                    reg,            
                ], style={'display': 'none'}
            ),
            # Its style is set by `set_groupby_settings`.
             html.Div(
                id='groupby_components', 
                children=[
                    keep_cat_col,
                    errorbar,
                ],
            ),
            # Vertical y-range slider laid out next to the graph.
            html.Div(children=[
                html.Div(
                    id='yaxis-range_div', 
                    children=[
                        yaxis_range
                    ]
                ),
                html.Div([
                    dcc.Graph(id='graph'),
                ]
                )
            ], style={'display':'flex', 'flex-flow': 'row wrap', 'align-items': 'center'}
            ),
            # Horizontal x-range slider placed after the graph row.
            html.Div(
                id='xaxis-range_div', 
                children=[
                    xaxis_range
                ]
            ),
        ],
        style=tab_style, selected_style=tab_selected_style
        )
    ], 
    style=tabs_style
    )
])

@app.callback(
    Output('hidden-div', 'children'),
    [
        Input('pivotTable', 'n_clicks'),
    ],
    [
        # Read as State so that only a button click triggers an export; as an
        # Input, every later change of the stored data re-wrote the file.
        State('memory', 'data'),
    ]
)
def download_pivot(n_clicks, data):
    """Export the per-ID mean of the main measurements to an Excel file.

    Args:
        n_clicks: Click count of the 'Download pivot table' button.
        data: Records currently held in the 'memory' store.

    Raises:
        PreventUpdate: Always. The callback only has a side effect (writing
            ``pivot_table.xlsx`` into the current user's ``Downloads`` folder)
            and never updates the hidden Div.
    """
    if not n_clicks or not data:
        raise PreventUpdate

    frame = pd.DataFrame(data)
    wanted = [
        'CROSS-SECTIONAL AREA',
        'ELASTIC EMOD',
        'STRESS 15%',
        'STRESS 25%',
        'POSTYIELD MOD',
        'BREAK EXT',
        'BREAK STRESS',
    ]
    # The store only contains the selected columns, so some may be absent.
    available = [name for name in wanted if name in frame.columns]
    if 'ID' not in frame.columns or not available:
        raise PreventUpdate

    pivot = pd.pivot_table(frame, index='ID', values=available, aggfunc='mean')

    # Resolve the folder for whoever runs the app instead of a fixed user profile.
    target = Path.home() / 'Downloads' / 'pivot_table.xlsx'
    target.parent.mkdir(parents=True, exist_ok=True)
    pivot.to_excel(target)
    raise PreventUpdate


def parse_contents(content, filename):
    """Decode one uploaded file and read it into a DataFrame.

    Args:
        content: Upload payload of the form ``"<header>,<base64 data>"``.
        filename: Name of the uploaded file; it must contain ``'.txt'``.

    Returns:
        ``(name, frame)`` on success, where ``name`` is the file name without
        ``'_modif'`` and ``'.txt'`` and ``frame`` is the parsed table indexed
        by its first used column. On any failure (unsupported file name,
        undecodable payload, unreadable table) an ``html.Div`` carrying an
        error message is returned instead.
    """
    read_csv_params = {
        'sep': '\t',
        'skiprows': [0, 1, 4],
        'header': 0,
        'index_col': 0,
        'usecols': ['RECORD', 'CROSS-SECTIONAL AREA', 'MEAN DIAMETER', 'MIN DIAMETER', 'MAX DIAMETER', 'ELASTIC EMOD', 'ELASTIC EXT', 'STRESS 15%', 'STRESS 25%', 'PLATEAU STRESS', 'POSTYIELD GRADIENT', 'BREAK EXT', 'BREAK STRESS', 'TOTAL WORK', 'TOUGHNESS'],
    }

    if '.txt' not in filename:
        # Previously this fell through to an UnboundLocalError on `name`.
        print(f"Unsupported file (expected .txt): {filename}")
        return html.Div([
            'There was an error processing this file.'
        ])

    try:
        # Base64 text never contains a comma, so everything after the last
        # comma is the payload even if the header itself has commas.
        content_string = content.rpartition(',')[2]
        decoded = base64.b64decode(content_string)
        frame = pd.read_csv(io.StringIO(decoded.decode('utf-8')), **read_csv_params)
    except Exception as e:
        # Boundary with user-supplied files: report and signal the failure.
        print(e)
        return html.Div([
            'There was an error processing this file.'
        ])

    name = filename.replace('_modif', '').replace('.txt', '')
    return name, frame


@app.callback(
    [
        Output('df', 'columns'),
        Output('df', 'data'),
        Output('df', 'editable'),
        Output('df', 'row_deletable'),
        Output('df', 'row_selectable'),
        Output('df', 'column_selectable'),
        Output('df', 'selected_columns'),
        Output('df', 'filter_action'),
        Output('df', 'sort_action'),
        Output('df', 'sort_mode'),
        Output('df', 'style_data_conditional'),
    ],
    [
        Input('upload-data', 'contents'),
        Input('Split', 'n_clicks')
    ],
    [
        State('upload-data', 'filename'),
        State('df', 'data'),
        State('nomenclature', 'value')
    ]
)
def update_output(list_of_contents, split_button, list_of_names, df_state, split_nomenclature):
    """Rebuild the data table after a file upload or a click on 'Split'.

    Args:
        list_of_contents: Payloads of the uploaded files.
        split_button: Click count of the 'Split' button (only used as trigger).
        list_of_names: File names matching ``list_of_contents``.
        df_state: Records currently shown in the table.
        split_nomenclature: '_'-separated names for the columns created by
            splitting the 'ID' column on '_'.

    Returns:
        The eleven DataTable properties listed in the callback outputs:
        column definitions, records, interaction flags, the list of selected
        columns (all of them) and the conditional highlighting rules.

    Raises:
        PreventUpdate: When nothing triggered the callback, when there is
            nothing to split, or when no uploaded file could be parsed.
    """
    ctx = dash.callback_context
    if not ctx.triggered:
        raise PreventUpdate

    trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]

    if trigger_id == 'Split':
        if not df_state or not split_nomenclature:
            raise PreventUpdate
        df = pd.DataFrame(df_state)
        if 'ID' not in df.columns:
            raise PreventUpdate
        parts = df['ID'].str.split(pat='_', expand=True)
        # zip() tolerates a nomenclature with more or fewer names than the ID
        # has parts; only the overlapping positions are assigned.
        for position, column_name in zip(parts.columns, split_nomenclature.split('_')):
            df[column_name] = parts[position]

    elif trigger_id == 'upload-data':
        if not list_of_contents:
            raise PreventUpdate

        parsed = [parse_contents(c, n) for c, n in zip(list_of_contents, list_of_names)]
        # parse_contents returns an error Div instead of a tuple on failure.
        dico_fichiers = {
            result[0]: result[1] for result in parsed if isinstance(result, tuple)
        }
        if not dico_fichiers:
            raise PreventUpdate

        df = (
            pd.concat(objs=dico_fichiers, names=['ID'])
            .dropna(how='all')
            .reset_index()
            .sort_values(['ID', 'RECORD'])
        )
        # Convert RECORD to int64 (the original call discarded its result);
        # keep the column untouched if the values cannot be cast.
        try:
            df = df.astype({'RECORD': 'int64'})
        except (ValueError, TypeError):
            pass

        df['ELASTIC EMOD'] = (df['ELASTIC EMOD'] / 1e6)
        df['ELLIPTICITY'] = df['MAX DIAMETER'] / df['MIN DIAMETER']
        df['MEAN DEVIATION'] = abs(df['MEAN DIAMETER'] - (df['MAX DIAMETER'] + df['MIN DIAMETER']) / 2)
        df['POSTYIELD MOD'] = df['POSTYIELD GRADIENT'] * 30
        df['BREAK RATIO'] = df['BREAK STRESS'] / df['BREAK EXT']

        df = df.drop(columns=['POSTYIELD GRADIENT'])

        # Append to whatever is already displayed, without duplicating rows.
        if df_state:
            df = pd.concat([pd.DataFrame(df_state), df]).drop_duplicates()

    else:
        raise PreventUpdate

    # Numeric display format per column name.
    column_formats = {}
    for names, fmt in (
        (['CROSS-SECTIONAL AREA', 'ELASTIC EMOD', 'POSTYIELD MOD', 'BREAK STRESS'],
         Format(scheme=Scheme.decimal_integer)),
        (['MEAN DEVIATION', 'STRESS 15%', 'STRESS 25%', 'BREAK EXT'],
         Format(precision=1, scheme=Scheme.fixed)),
        (['BREAK RATIO', 'ELLIPTICITY'],
         Format(precision=2, scheme=Scheme.fixed)),
    ):
        column_formats.update(dict.fromkeys(names, fmt))

    # Columns are selectable in both branches so a split does not remove the
    # column check boxes.
    columns = []
    for col in df.columns:
        definition = {'name': col, 'id': col, 'selectable': True}
        if col in column_formats:
            definition.update({'type': 'numeric', 'format': column_formats[col]})
        columns.append(definition)

    # Cells outside the accepted range are highlighted.
    highlight = {'backgroundColor': 'tomato', 'color': 'white'}
    style_data_conditional = [
        {
            'if': {
                'filter_query': '{ELASTIC EXT} <= 0.4 or {ELASTIC EXT} >= 4',
                'column_id': 'ELASTIC EXT'
            },
            **highlight
        },
        {
            'if': {
                'filter_query': '{MEAN DEVIATION} >= 6',
                'column_id': 'MEAN DEVIATION'
            },
            **highlight
        },
        {
            'if': {
                'filter_query': '{ELLIPTICITY} >= 2.5',
                'column_id': 'ELLIPTICITY'
            },
            **highlight
        },
    ]

    return (
        columns,
        df.to_dict('records'),
        True,
        True,
        'multi',
        'multi',
        df.columns.to_list(),  # plain list rather than a pandas Index
        'native',
        'native',
        'multi',
        style_data_conditional,
    )

@app.callback(
    Output('memory', 'data'),
    Input('df', 'derived_virtual_data'),
    State('df', 'selected_columns')
)
def store_data(data, columns):
    """Copy the table's derived virtual rows, limited to the selected columns, into the store.

    Raises:
        PreventUpdate: When the table has no rows or no column is selected.
    """
    if not (data and columns):
        raise PreventUpdate

    visible = pd.DataFrame(data)
    return visible[columns].to_dict('records')

# Show or hide the scatter-only controls and fill the x-axis options
# according to the selected chart type.
@app.callback(
    [
        Output('scatter_components', 'style'),
        Output('xaxis-dropdown', 'options'),
    ],
    Input('chartType-dropdown', 'value'),
    Input('memory', 'data'),
)
def set_dropdowns_layout(chartType, data):
    """Return the style of the scatter controls and the x-axis options.

    A boxplot hides the scatter controls and offers object/category columns
    for x; a scatter shows them and offers float64/int64 columns. Any other
    chart type falls through and returns None.

    Raises:
        PreventUpdate: When the store holds no data.
    """
    if not data:
        raise PreventUpdate

    frame = pd.DataFrame(data)

    def options_for(dtypes):
        return [{'label': c, 'value': c} for c in frame.select_dtypes(dtypes).columns.to_list()]

    if chartType == 'Boxplot':
        hidden_style = {'display': 'none'}
        return hidden_style, options_for(['object', 'category'])

    if chartType == 'Scatter':
        visible_style = {'display': 'flex', 'flex-flow': 'row wrap', 'justify-content': 'space-between', 'margin-bottom': '10px', 'margin-top': '10px', 'width': '67%','font-size': '80%'}
        return visible_style, options_for(['float64', 'int64'])

# Define callback to adapt the x-axis range slider to the selected column
@app.callback(
    [
        Output('xaxis-range_div', 'style'),
        Output('xaxis-range', 'max'),
        Output('xaxis-range', 'step'),
        Output('xaxis-range', 'value'),
        Output('xaxis-range', 'marks'),
    ],
    [
        Input('xaxis-dropdown', 'value'),
        Input('memory', 'data'),
    ]
)
def set_xrangesliders(x, data):
    """Configure the x-axis range slider for column ``x``.

    Args:
        x: Name of the column plotted on the x axis.
        data: Records held in the 'memory' store.

    Returns:
        ``(container style, max, step, value, marks)``. The slider spans 0 to
        110 % of the column maximum with eleven marks. It is hidden, with
        empty settings, when no usable numeric column is selected or its
        maximum is not a positive finite number.
    """
    hidden = ({'display': 'none'}, None, None, [None, None], {})
    if not (data and x):
        return hidden

    df = pd.DataFrame(data)
    if x not in df.columns or df[x].dtypes not in ['float64', 'int64']:
        return hidden

    upper = float(1.1 * df[x].max())
    if not np.isfinite(upper) or upper <= 0:
        return hidden

    # int(upper / 100) is 0 for small ranges, which is not a usable step;
    # fall back to a fractional step in that case.
    step = int(upper / 100) or upper / 100

    ticks = np.linspace(0, upper, 11)
    if upper >= 10:
        # Ticks are at least 1 apart, so truncated integer keys stay distinct.
        marks = {int(tick): str(int(tick)) for tick in ticks}
    else:
        marks = {}
        for tick in ticks:
            # Mark keys that fall on an integer are cast to int (work-around
            # for https://github.com/plotly/dash-core-components/issues/159).
            if abs(tick - round(tick)) < 1e-3:
                key = int(round(tick))
            else:
                key = round(float(tick), 3)
            marks[key] = f"{tick:.2f}"

    return {'width': '80%'}, upper, step, [0, upper], marks

# Define callback to adapt the y-axis range slider to the selected column
@app.callback(
    [
        Output('yaxis-range_div', 'style'),
        Output('yaxis-range', 'max'),
        Output('yaxis-range', 'step'),
        Output('yaxis-range', 'value'),
        Output('yaxis-range', 'marks'),
    ],
    [
        Input('yaxis-dropdown', 'value'),
        Input('memory', 'data'),
    ]
)
def set_yrangesliders(y, data):
    """Configure the y-axis range slider for column ``y``.

    Args:
        y: Name of the column plotted on the y axis.
        data: Records held in the 'memory' store.

    Returns:
        ``(container style, max, step, value, marks)``. The slider spans 0 to
        110 % of the column maximum with eleven marks. It is hidden, with
        empty settings, when no usable numeric column is selected or its
        maximum is not a positive finite number.
    """
    hidden = ({'display': 'none'}, None, None, [None, None], {})
    if not (data and y):
        return hidden

    df = pd.DataFrame(data)
    if y not in df.columns or df[y].dtypes not in ['float64', 'int64']:
        return hidden

    upper = float(1.1 * df[y].max())
    if not np.isfinite(upper) or upper <= 0:
        return hidden

    # int(upper / 100) is 0 for small ranges, which is not a usable step;
    # fall back to a fractional step in that case.
    step = int(upper / 100) or upper / 100

    ticks = np.linspace(0, upper, 11)
    if upper >= 10:
        # Ticks are at least 1 apart, so truncated integer keys stay distinct.
        marks = {int(tick): str(int(tick)) for tick in ticks}
    else:
        marks = {}
        for tick in ticks:
            # Mark keys that fall on an integer are cast to int (work-around
            # for https://github.com/plotly/dash-core-components/issues/159).
            if abs(tick - round(tick)) < 1e-3:
                key = int(round(tick))
            else:
                key = round(float(tick), 3)
            marks[key] = f"{tick:.2f}"

    return {'margin-right': '5%'}, upper, step, [0, upper], marks
    
    
# Fill the options of every column dropdown of the 'Plots' tab.
@app.callback(
    [
        Output('yaxis-dropdown', 'options'),
        Output('color-dropdown', 'options'),
        Output('col-dropdown', 'options'),
        Output('row-dropdown', 'options'),
        Output('hover-multidropdown', 'options'),
        Output('size-dropdown', 'options'),
        Output('symbol-dropdown', 'options'),
        Output('text-multidropdown', 'options'),
        Output('groupby-dropdown', 'options'),
    ],
    [
        Input('chartType-dropdown', 'value'),
        Input('Tab', 'value'),
        Input('memory', 'data'),
    ]
)
def set_dropdowns_options(chartType, tab, data):
    """Return the option lists for the nine dropdowns, in output order.

    Numeric (float64/int64) columns feed the y-axis and marker-size
    dropdowns, categorical (object/category) columns feed the facet, symbol,
    text and group-by dropdowns, hover data accepts both, and colour accepts
    both only for a scatter chart.

    Raises:
        PreventUpdate: When the 'Plots' tab is not active or the store is empty.
    """
    if tab != 'Plots' or not data:
        raise PreventUpdate

    frame = pd.DataFrame(data)

    def options_for(dtypes):
        return [{'label': c, 'value': c} for c in frame.select_dtypes(dtypes).columns.to_list()]

    num = options_for(['float64', 'int64'])
    cat = options_for(['object', 'category'])

    if chartType == 'Scatter':
        color_options = cat + num
    else:
        color_options = cat

    return num, color_options, cat, cat, cat + num, num, cat, cat, cat
    
# Define callback to disable groupby settings if no groupby is selected
@app.callback(
        [
            Output('groupby_components', 'style'),
            Output('keep_cat_col', 'options'),
            Output('keep_cat_col', 'value'),
        ],
        [
            Input('groupby-dropdown', 'value'),
            Input('chartType-dropdown', 'value'),
            Input('memory', 'data'),
        ]
)
def set_groupby_settings(groupby, chartType, data):
    """Show the group-by settings only for a scatter chart with a group-by column.

    Args:
        groupby: Column selected in the group-by dropdown.
        chartType: Selected chart type.
        data: Records held in the 'memory' store.

    Returns:
        ``(container style, options, value)`` for the 'keep categorical
        columns' dropdown: every object/category column except ``groupby``,
        all pre-selected. The container is hidden with empty lists when
        grouping does not apply or the store is empty.
    """
    if not (groupby and chartType == 'Scatter' and data):
        return {'display': 'none'}, [], []

    df = pd.DataFrame(data)
    # Filtering instead of list.remove(): no ValueError when `groupby` is not
    # (or no longer) one of the categorical columns.
    list_cat = [
        name
        for name in df.select_dtypes(['object', 'category']).columns.to_list()
        if name != groupby
    ]
    dict_cat = [{'label': c, 'value': c} for c in list_cat]
    visible_style = {'display': 'flex', 'flex-flow': 'row wrap', 'justify-content': 'flex-start', 'align-items': 'center', 'margin-bottom': '10px', 'margin-top': '10px', 'font-size': '80%'}
    return visible_style, dict_cat, list_cat

    
def ci95(data):
    """Return the half-width of the 95 % Student-t confidence interval of the mean.

    Missing values in ``data`` are dropped before computing the standard
    error and the degrees of freedom.
    """
    sample = data.dropna()
    degrees_of_freedom = len(sample) - 1
    t_critical = scipy.stats.t.ppf((1 + .95) / 2., degrees_of_freedom)
    return scipy.stats.sem(sample) * t_critical

# Define callback to update graph (private helpers first)
def _zone_annotations(panel_count, x_position):
    """Return the 'Break zone' / 'Alert zone' labels for every facet panel."""
    zones = (
        ('<b>Break zone</b>', 80, 'rgb(230,0,0)'),
        ('<b>Alert zone</b>', 100, 'rgb(255,170,0)'),
    )
    return [
        {
            'text': label,
            'x': x_position,
            'y': level,
            'xanchor': 'left',
            'yanchor': 'top',
            'xref': f"x{panel + 1}",
            'yref': f"y{panel + 1}",
            'font': {'size': 14, 'color': font_color},
            'showarrow': False,
        }
        for panel in range(panel_count)
        for label, level, font_color in zones
    ]


def _add_regression_line(fig, frame, x, y, line_color):
    """Fit ``y ~ x`` on ``frame``, draw the fitted line and return its R².

    Returns None, without drawing, when fewer than two complete points exist.
    """
    points = frame[[x, y]].dropna()
    if len(points) < 2:
        return None

    model = LinearRegression()
    model.fit(points[[x]], points[y])
    # Predict on a frame with the same column name as the one used to fit.
    grid = pd.DataFrame({x: np.linspace(points[x].min(), points[x].max(), 100)})
    fig.add_trace(go.Scatter(
        x=grid[x],
        y=model.predict(grid),
        line_color=line_color,
        opacity=.75,
        showlegend=False
    ))
    return model.score(points[[x]], points[y])


def _regression_label_x(fig, frame, x):
    """Return the x position of the R² label: 90 % of the axis maximum.

    Falls back to the data maximum when the figure has no explicit x range.
    """
    limits = [bound for bound in (fig.layout.xaxis.range or ()) if bound is not None]
    return .9 * max(limits) if limits else frame[x].max()


@app.callback(
    [
        Output('graph', 'figure'),
        Output('graph', 'config')
    ],
    [
        Input('chartType-dropdown', 'value'),
        Input('title-input', 'value'),
        Input('xaxis-dropdown', 'value'),
        Input('yaxis-dropdown', 'value'),
        Input('color-dropdown', 'value'),
        Input('col-dropdown', 'value'),
        Input('row-dropdown', 'value'),
        Input('hover-multidropdown', 'value'),
        Input('size-dropdown', 'value'),
        Input('symbol-dropdown', 'value'),
        Input('text-multidropdown', 'value'),
        Input('groupby-dropdown', 'value'),
        Input('reg-radioitem', 'value'),
        Input('errorbar-checklist', 'value'),
        Input('keep_cat_col', 'value'),
        Input('xaxis-range', 'value'),
        Input('yaxis-range', 'value'),
        Input('memory', 'data')
    ]
)
def update_figure(chart_type, title, x, y, color, facet_col, facet_row, hover_data, size, symbol, text, groupby, reg, errorbar, keep_cat_col, xaxis_range, yaxis_range,  df_state):
    """Build the box plot or scatter plot from the stored table.

    Args:
        chart_type: 'Boxplot' or 'Scatter'.
        title: Figure title (may be empty).
        x, y: Columns plotted on the axes; ``y`` is mandatory, ``x`` is
            mandatory and must be int64/float64 for a scatter.
        color, facet_col, facet_row, size, symbol: Optional columns mapped to
            the corresponding Plotly Express arguments.
        hover_data: Extra columns shown on hover.
        text: Columns whose values are joined with spaces into point labels.
        groupby: Column used, with ``keep_cat_col``, to average the data
            before plotting (scatter only).
        reg: 'all', 'by color' or 'none' linear regression mode.
        errorbar: Non-empty to draw 95 % confidence error bars on grouped data.
        keep_cat_col: Additional grouping columns.
        xaxis_range, yaxis_range: ``[min, max]`` values of the range sliders.
        df_state: Records held in the 'memory' store.

    Returns:
        ``(figure, config)`` for the graph component.

    Raises:
        PreventUpdate: When there is no y column or no data, or when a
            scatter is requested without a numeric x column.
    """
    if not y or not df_state:
        raise PreventUpdate

    # Cleared multi-select inputs may arrive as None.
    hover_data = hover_data or []
    text = text or []
    keep_cat_col = keep_cat_col or []

    data = pd.DataFrame(df_state)

    N_cols = 1 if facet_col is None else data[facet_col].nunique()
    N_rows = 1 if facet_row is None else data[facet_row].nunique()

    # 'Save picture' button properties
    image_height = 450 * N_rows * .75 if facet_row else 450
    config = {
        'toImageButtonOptions': {
            'format': 'png',  # one of png, svg, jpeg, webp
            'filename': 'custom_image',
            'height': image_height,
            'width': image_height * 1.6,
            'scale': 1 + max(N_rows, N_cols)  # Multiply title/legend/axis/canvas sizes by this factor
        }
    }

    plot_arguments = {
        'x': x,
        'y': y,
        'color': color,
        'facet_col': facet_col,
        'facet_row': facet_row,
        'hover_data': hover_data,
        'height': 600,
        'width': 1066
    }

    layout = {
        'title': {
            'text': title,
            'font_size': 18,
            'x': .5,
            'y': .97,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        'xaxis': {
            'range': xaxis_range,
            # `title_font_size` replaces the deprecated `titlefont_size` alias.
            'title_font_size': 16,
            'tickfont_size': 14,
            'hoverformat': '.1f'
        },
        'yaxis': {
            'range': yaxis_range,
            'title_font_size': 16,
            'tickfont_size': 14,
            'hoverformat': '.1f'
        },
        'legend': {
            'bgcolor': 'rgba(0,0,0,0)',
            'title': {'text': '', 'font_size': 16, 'side': 'top left'},
            'font_size': 16,
            'x': 1.01,
            'y': .5,
            'xanchor': 'left',
            'yanchor': 'middle',
            'orientation': 'v'
        },
        'margin': {
            'l': 20,
            'r': 20,
            'b': 20,
            't': 20 if not title else 60
        },
        'paper_bgcolor': 'rgba(0,0,0,0)',
    }

    if chart_type == 'Boxplot':
        fig = px.box(data, **plot_arguments)
        fig.update_traces(
            boxpoints='all',
            boxmean=True,  # 'sd' would also display the standard deviations
            pointpos=0,
            jitter=.75,
        )
        fig.update_layout(**layout)
        return fig, config

    if chart_type != 'Scatter' or not x or x not in data.columns:
        raise PreventUpdate
    if data[x].dtypes not in ['int64', 'float64']:
        raise PreventUpdate

    # ---- optional aggregation: mean per group (+ 95 % CI for error bars) ----
    errorbars = None
    if groupby:
        group_by = list(dict.fromkeys([groupby] + keep_cat_col))
        wanted = [x, y, color, facet_col, facet_row, size, symbol] + hover_data + group_by + text
        columns = list(dict.fromkeys(name for name in wanted if name is not None))
        subset = data[columns]
        # Only numeric columns can be averaged; recent pandas raises on the
        # others instead of silently dropping them.
        numeric_cols = [
            name for name in subset.select_dtypes('number').columns
            if name not in group_by
        ]
        grouped = subset.groupby(group_by, as_index=False, observed=True)
        if errorbar:
            errorbars = grouped[numeric_cols].agg(ci95)
        data = grouped.mean(numeric_only=True)

    point_labels = None
    if text:
        point_labels = [
            ' '.join(map(str, values))
            for values in zip(*(data[name] for name in text))
        ]

    fig = px.scatter(
        data,
        **plot_arguments,
        error_x=errorbars[x] if errorbars is not None else None,
        error_y=errorbars[y] if errorbars is not None else None,
        size=size,
        symbol=symbol,
        symbol_sequence=symbols,
        text=point_labels,
    )

    fig.update_traces(
        textposition=['bottom right', 'bottom left', 'bottom right', 'top right', 'bottom right', 'bottom left', 'top right'],
        textfont_size=16,
        error_x_thickness=1,
        error_y_thickness=1
    )

    if symbol:
        # Keep one legend entry per colour (traces drawn with symbol 0) and
        # add one black marker entry per symbol value.
        def _simplify_legend(trace):
            if trace.marker.symbol != 0:
                trace.update(showlegend=False)
            else:
                trace.update(name=trace.name.split(',')[0])

        fig.for_each_trace(_simplify_legend)
        for index, value in enumerate(data[symbol].unique()):
            fig.add_trace(go.Scatter(
                y=[None],
                mode='markers',
                # Wrap around like the symbol sequence does.
                marker=dict(symbol=symbols[index % len(symbols)], color='black'),
                name=value
            ))

    if not size:
        fig.update_traces(marker_size=10)

    fig.update_layout(**layout)

    if y == 'BREAK STRESS' and xaxis_range:
        # Shaded bands: 0-80 (break zone) and 80-100 (alert zone).
        fig.add_trace(go.Scatter(
            x=[xaxis_range[0], xaxis_range[1]],
            y=[80, 80],
            fill='tozeroy',
            fillcolor='rgba(230, 0, 0, 0.3)',
            mode='none',
            showlegend=False
        ))
        fig.add_trace(go.Scatter(
            x=[xaxis_range[0], xaxis_range[1]],
            y=[100, 100],
            fill='tonexty',  # fill area between this trace and the previous one
            fillcolor='rgba(255, 170, 0, 0.3)',
            mode='none',
            showlegend=False
        ))
        fig.update_layout(
            annotations=list(fig.layout.annotations)
            + _zone_annotations(N_cols * N_rows, xaxis_range[0])
        )

    # linear regression if asked (skipped when both facets are used)
    if reg != 'none' and not (facet_col and facet_row):
        label_x = _regression_label_x(fig, data, x)

        if color and reg == 'by color':
            palette = px.colors.qualitative.Plotly
            scores = []
            for index, level in enumerate(data[color].unique()):
                # Boolean mask instead of DataFrame.query: works for column
                # names with spaces and for values containing quotes.
                score = _add_regression_line(
                    fig,
                    data[data[color] == level],
                    x,
                    y,
                    palette[index % len(palette)],
                )
                if score is not None:
                    scores.append(score)
            if scores:
                fig.add_annotation(
                    x=label_x,
                    xanchor='right',
                    y=data[y].min(),
                    yanchor='top',
                    showarrow=False,
                    text='<br>'.join(f"R² = {value:.3f}" for value in scores)
                )
        else:
            score = _add_regression_line(fig, data, x, y, 'black')
            if score is not None:
                fig.add_annotation(
                    x=label_x,
                    xanchor='right',
                    y=data[y].min(),
                    yanchor='bottom',
                    showarrow=False,
                    text=f"R² = {score:.3f}"
                )

    return fig, config
        
# Start the Dash server only when this file is executed as a script.
if __name__ == '__main__':
    app.run_server()