In [1]:
%config IPCompleter.greedy=True

from os import path, getcwd
from pathlib import Path
import sys
from typing import Optional, Union
my_project_path = path.normpath(Path(getcwd()).parent)
sys.path.insert(1,my_project_path)

import numpy as np

from src.processing.processing_methods import *
from src.graphs.interactive_widgets_configurations import *
from src.readers.operations_variable import VariableOperations
from src.graphs.constants.constants import JupyterWidgetsConstants

import ipywidgets as widgets
import matplotlib.pyplot as plt
plt.ion()

# Figure size passed to plt.subplots(figsize=...) for single diagrams.
PLOT_SIZE_SINGLE = (8, 4)
# Widget constants (widget type / configuration keys) used by create_widgets.
CW: JupyterWidgetsConstants = JupyterWidgetsConstants()


def create_widgets(
        func_configurations: dict,
        func_defaults: dict,
        x_available: Optional[list] = None,
) -> dict:
    """ Switch-case widget generator.

    Generate one widget per entry of ``func_configurations``; each widget's
    initial ``value`` is taken from ``func_defaults``.
    For the widget type constants refer to class: JupyterWidgetsConstants
            src/graphs/constants/constants
    Widget documentation:
            https://ipywidgets.readthedocs.io/en/8.0.2/examples/Widget%20List.html

    Args:
        func_configurations (dict): Argument name -> widget description holding
            the ``CW.WIDGET_TYPE`` and ``CW.WIDGET_CONFIGURATION`` entries.
            The passed dict (and its nested configurations) is not modified.
        func_defaults (dict): Widgets defaults stored in method definition.
            Arguments without an entry get ``value=None``.
        x_available (list, optional): X-axis column names available in dataset (domain switch).
            <Empty>/None: No x-axis selector is generated.
            <items>: Generate an "x" Dropdown with passed options, placed first.

    Returns:
        (dict): Container with generated widget objects. Arguments whose widget
            could not be constructed are reported on stdout and left out.

    Raises:
        NotImplementedError: A demanded widget type is not supported / recognized.
    """
    # Dispatch table: widget type constant -> ipywidgets class.
    widget_classes = {
        CW.INT_SLIDER: widgets.IntSlider,
        CW.INT_RANGE_SLIDER: widgets.IntRangeSlider,
        CW.FLOAT_SLIDER: widgets.FloatSlider,
        CW.FLOAT_LOG_SLIDER: widgets.FloatLogSlider,
        CW.FLOAT_RANGE_SLIDER: widgets.FloatRangeSlider,
        CW.DROPDOWN: widgets.Dropdown,
        CW.CHECKBOX: widgets.Checkbox,
    }

    widgets_stack = {}
    if x_available:
        widgets_stack["x"] = widgets.Dropdown(options=x_available, value=x_available[0])

    # Create widgets defined in configuration
    for arg, arg_configuration in func_configurations.items():
        widget_type = arg_configuration[CW.WIDGET_TYPE]
        # Work on a copy so the caller's configuration is never mutated and
        # defaults injected here cannot leak into later calls.
        configuration = dict(arg_configuration[CW.WIDGET_CONFIGURATION])
        configuration["value"] = func_defaults.get(arg)

        widget_class = widget_classes.get(widget_type)
        if widget_class is None:
            raise NotImplementedError(f"Demanded widget <{widget_type}> not supported / recognized.")
        try:
            widgets_stack[arg] = widget_class(**configuration)
        except Exception as e:
            # Best-effort: report the broken widget and continue with the rest.
            # NOTE(review): if the returned stack is later mapped positionally
            # onto func_configurations, a skipped widget shifts values — confirm.
            print(f"Arg: {arg}\n Config: {configuration}\nErr msg: {e}")

    return widgets_stack


def parse_x_label(label_x: str) -> str:
    """ Translate a dataset x-column name into a diagram x-axis label.

    Args:
        label_x (str): Column name to translate.
    Return:
        (str): Human-readable axis label, or ``label_x`` itself when the
            column name has no known translation.
    """
    known_labels = (
        ("pixels", "Pixels #"),
        ("wavenumber", "Wavenumber [cm-1]"),
        ("wavelength", "Wavelength [nm]"),
        ("chem_shift", "Chemical shift [ppm]"),
        ("timestamp", "Time [Unix]"),
        ("frequency", "Frequency [Hz]"),
    )
    # Scan with == so matching works for any value comparable to str
    # (no hashing requirement on label_x).
    for column_name, axis_label in known_labels:
        if label_x == column_name:
            return axis_label
    return label_x

    
def unpack_folds(
        dataset,
        folds_names: list,
        data_x_column_name: str,
        data_y_column_name: str,
):
    """ Extract x/y data of the selected folds from a nested dataset.

    Args:
        dataset: Mapping fold name -> fold container; every fold container
            must support ``.get(column_name)``.
        folds_names (list): List of folds to extract.
            <Empty>/None: Use the first enumerated fold.
            <one item>: Use one fold.
            <multiple items>: Use multiple folds (drawn on one diagram).
        data_x_column_name (str): X-data column name in each fold container.
        data_y_column_name (str): Y-data column name in each fold container.

    Returns:
        tuple: ``(x, y)``. For zero or one requested fold these are the fold's
            own column values (``None`` for a missing column). For several
            folds they are lists with one entry per fold, in ``folds_names`` order.

    Raises:
        KeyError: A requested fold name is not present in ``dataset``.
        ValueError: ``folds_names`` is empty and ``dataset`` holds no folds.
    """
    #TODO Rebuild function using DataFormat wrappers pack/unpack
    #TODO Add colouring for multiple fold + add color gen. inside plot function

    # If fold names weren't passed, process the first iterated fold of the dataset.
    if not folds_names:
        try:
            fold_data = next(iter(dataset.values()))
        except StopIteration:
            # Previously this fell through and returned None, which made the
            # caller's tuple unpacking fail with an obscure TypeError.
            raise ValueError("Dataset contains no folds to unpack.") from None
        return fold_data.get(data_x_column_name), fold_data.get(data_y_column_name)

    if len(folds_names) == 1:
        fold_data = dataset[folds_names[0]]
        return fold_data.get(data_x_column_name), fold_data.get(data_y_column_name)

    x_extended = [dataset[fold_name].get(data_x_column_name) for fold_name in folds_names]
    y_extended = [dataset[fold_name].get(data_y_column_name) for fold_name in folds_names]
    return x_extended, y_extended

        
def _diagram_draw(
        x: Union[np.ndarray, list],
        y: Union[np.ndarray, list],
        label_x: str = "Wavenumber [cm-1]",
        label_y: str = "Amplitude #",
        title: str = "Spectra"
):
    """ Basic plotting method: draw every signal of ``y`` on a new figure.

    Args:
        x: X-axis data. When ``VariableOperations.is_matrix_1d(x)`` is true it is
            shared by every signal; otherwise ``x[idx]`` is the x-axis of ``y[idx]``.
        y: Iterable of signals to plot.
        label_x (str): X-axis description.
        label_y (str): Y-axis description.
        title (str): Plot title.

    Returns:
        figure, axes objects.
    """
    fig, ax = plt.subplots(figsize=PLOT_SIZE_SINGLE)
    shared_x = VariableOperations.is_matrix_1d(x)
    for idx, signal in enumerate(y):
        ax.plot(x if shared_x else x[idx], signal, "-", color="k", linewidth=0.1)
    ax.set(
        xlabel=label_x,
        ylabel=label_y,
        title=title
    )
    # Explicit True: a bare ax.grid() toggles visibility and would hide the
    # grid when rcParams["axes.grid"] is already enabled.
    ax.grid(True)
    fig.tight_layout()
    return fig, ax


def diagram_draw(
        dataset: dict,
        title: str,
        data_x_column_name: str,
        data_y_column_name: str,
        folds_names: Optional[list] = None,
):
    """ Create pyplot objects for the selected folds of a dataset.

    The x-axis label is derived from ``data_x_column_name`` via parse_x_label;
    the y-axis label is left at the _diagram_draw default.

    Args:
        dataset (dict): Nested dataset.
        title (str): Plot title.
        data_x_column_name (str): X-data column name in dataset container.
        data_y_column_name (str): Y-data column name in dataset container.
        folds_names (list, optional): List of folds to be drawn.
            <Empty>/None: Draw first enumerated fold.
            <one item>: Draw one fold.
            <multiple items>: Draw multiple folds on one diagram.
    Returns:
        figure, axes objects.
    """
    label_x = parse_x_label(data_x_column_name)
    x, y = unpack_folds(
        dataset=dataset,
        folds_names=folds_names,
        data_x_column_name=data_x_column_name,
        data_y_column_name=data_y_column_name,
    )
    return _diagram_draw(x=x, y=y, title=title, label_x=label_x)
    

def _diagram_update(
        ax,
        x: Union[np.ndarray, list],
        y: Union[np.ndarray, list],
        label_x: str = "Wavenumber [cm-1]",
        label_y: str = "Amplitude #",
        title: str = "Spectra"
):
    """ Clear and replot axis container by passed data.

    Args:
        ax: Axis container to redraw.
        x: X-axis data. When ``VariableOperations.is_matrix_1d(x)`` is true it is
            shared by every signal; otherwise ``x[idx]`` is the x-axis of ``y[idx]``.
        y: Iterable of signals to plot.
        label_x (str): X-axis description.
        label_y (str): Y-axis description.
        title (str): Plot title.

    Returns:
        Updated axis container.
    """
    # Clearing also resets grid visibility to the rcParams default, so the
    # grid is re-enabled below to match the diagrams created by _diagram_draw.
    ax.clear()
    shared_x = VariableOperations.is_matrix_1d(x)
    for idx, signal in enumerate(y):
        ax.plot(x if shared_x else x[idx], signal, "-", color="k", linewidth=0.1)
    ax.set(
        xlabel=label_x,
        ylabel=label_y,
        title=title
    )
    ax.grid(True)
    return ax

    
def diagram_update(
        ax,
        dataset: dict,
        title: str,
        data_x_column_name: str,
        data_y_column_name: str,
        folds_names: Optional[list] = None
) -> None:
    """ Update pyplot axes object with the selected folds of a dataset.

    Args:
        ax: Axis container.
        dataset (dict): Nested dataset.
        title (str): Plot title.
        data_x_column_name (str): X-data column name in dataset container; also
            translated into the x-axis label via parse_x_label.
        data_y_column_name (str): Y-data column name in dataset container.
        folds_names (list, optional): List of folds to be drawn.
            <Empty>/None: Draw first enumerated fold.
            <one item>: Draw one fold.
            <multiple items>: Draw multiple folds on one diagram.

    Returns:
        -
    """
    label_x = parse_x_label(data_x_column_name)
    x, y = unpack_folds(
        dataset=dataset,
        folds_names=folds_names,
        data_x_column_name=data_x_column_name,
        data_y_column_name=data_y_column_name,
    )
    _diagram_update(ax=ax, x=x, y=y, title=title, label_x=label_x)


def diagram_interact(
        ax,
        dataset,
        processing_func,
        class_name,
        func_name,
        func_instruction,
        func_configurations,
        data_x_column_name: str = "axis_wavenumber",
        data_y_column_name: str = "spectra",
        folds_names: Optional[list] = None,
        x_available: Optional[list] = None,
        **args
) -> None:
    """ Interact plot update method.
    Apply processing method with arguments based on the values set on the widgets, then update the graph with the processed spectra.

    Args:
        ax: Axes object that will be updated.
        dataset: Dataset to process.
        processing_func: Function to be applied; called as
            ``processing_func(data=..., instruction=..., **kwargs)``.
        class_name: The class name to which the processing method belongs.
        func_name: The processing method name.
        func_instruction: The processing method instruction.
        func_configurations: Widgets generation configurations stored in:
            plotter_interactive_preprocessor_config. Only its keys (in order)
            are used here, as the processing method's argument names.
        data_x_column_name (str): X-data column name in dataset container.
        data_y_column_name (str): Y-data column name in dataset container.
        folds_names (list, optional): List of folds to be drawn.
            <Empty>/None: Draw first enumerated fold.
            <one item>: Draw one fold.
            <multiple items>: Draw multiple folds on one diagram.
        x_available (list, optional): X-axis column names available in dataset (domain switch).
            <Empty>/None: Use <data_x_column_name> argument.
            <items>: The first widget value selects the x-axis column.
        **args: Widget values keyed by widget name, in widget order.

    Returns:
        -
        Processing errors are printed and the plot is left unchanged.
    """
    func_arguments = list(args.values())

    if x_available:
        # The domain-switch value is expected first; presumably the widgets come
        # from create_widgets, which inserts the "x" dropdown first — confirm.
        data_x_column_name = func_arguments.pop(0)

    plot_title = f"Processing method: {class_name}.{func_name}\nArgs: {func_arguments}"

    # Widget values are mapped positionally onto the configured argument names.
    kwargs_updated = {key: func_arguments[idx] for idx, key in enumerate(func_configurations)}
    try:
        dataset_processed = processing_func(
            data=dataset, instruction=func_instruction, **kwargs_updated
        )
    except Exception as e:
        # Best-effort for interactive use: report and keep the previous plot.
        print(f"PROCESSING ERR! Error msg:\n{e}")
        return
    diagram_update(
        ax,
        dataset=dataset_processed,
        title=plot_title,
        data_x_column_name=data_x_column_name,
        data_y_column_name=data_y_column_name,
        folds_names=folds_names
    )



NameError: name 'np' is not defined