In [0]:
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.ar_model import AutoReg
from ortools.linear_solver import pywraplp
from copy import deepcopy
from math import isclose
from scipy import stats
import pandas as pd
import numpy as np
from operator import truediv as divide
from datetime import datetime


# For each supported use case, map the real node IDs to the generic aliases
# used by the solver helpers (variable names p1..p3 and s1, s2 in
# calc_constrained_prod). Each chain is listed upstream-to-downstream:
# production p1 -> stockpile s1 -> production p2 -> stockpile s2 -> production p3.
site_alias = {
    use_case: dict(zip(chain, ('p1', 's1', 'p2', 's2', 'p3')))
    for use_case, chain in (
        (
            'southernops_vc',
            ('amrun_loadandhaul', 'amrun_rom_stp', 'amrun_bene', 'amrun_p_stp', 'amrun_shiploader'),
        ),
        (
            'gove_vc',
            ('gove_loadandhaul', 'gove_rom_stp_lmb', 'gove_overland', 'gove_p_stp_lmb', 'gove_shiploader'),
        ),
    )
}

# Reverse lookup per use case: generic alias -> real node ID.
site_alias_inv = {
    use_case: dict(zip(aliases.values(), aliases.keys()))
    for use_case, aliases in site_alias.items()
}


def max_attr_value(group):
    """
    Pick the attribute whose maximum sets the histogram range for ``group``.

    Parameters:
    -----------
    group : pandas.DataFrame
        Rows belonging to one histogram group; must contain an 'attr_name'
        column. Only the value in the first row is inspected.

    Returns:
    --------
    str
        'stp_size' when the group's first 'attr_name' is 'stp_size'
        (stockpile groups are ranged by stockpile sizes), otherwise
        'unconstrained' (production groups are ranged by unconstrained
        production).
    """
    first_attr = group['attr_name'].iloc[0]
    if first_attr == 'stp_size':
        return 'stp_size'
    return 'unconstrained'


def classify_seasons(x):
    """
    Label each row of ``x`` with the Northern Australian season of its timestamp.

    Parameters:
    -----------
    x : pandas.DataFrame
        Must contain a datetime-typed 'timestamp' column.

    Returns:
    --------
    numpy.ndarray
        Object array aligned with the rows of ``x``:
        - 'Wet season' for months October through April,
        - 'Dry season' for months May through September,
        - None when the month is neither (e.g. a missing timestamp).
    """
    month = x.timestamp.dt.month
    is_wet = (month >= 10) | (month <= 4)
    is_dry = (month >= 5) & (month <= 9)

    # Rows that match neither mask (NaN month from NaT) fall through to None.
    dry_or_unknown = np.where(is_dry, 'Dry season', None)
    return np.where(is_wet, 'Wet season', dry_or_unknown)


def round_floats(df, decimals=2):
    """
    Round every float-dtype column of a DataFrame to a fixed number of decimals.

    The rounded values are written back into ``df`` (the input is modified in
    place) and the same object is returned so the call can be chained.
    Columns of any other dtype are left untouched.

    Parameters:
        df (pd.DataFrame): Frame whose float columns should be rounded.
        decimals (int, optional): Number of decimal places. Defaults to 2.

    Returns:
        pd.DataFrame: ``df`` itself, with its float columns rounded.
    """
    float_part = df.select_dtypes(include=['float'])
    df[float_part.columns] = float_part.round(decimals)
    return df


def calc_blocked_starved_days(df):
    """
    Count, per node and period, the days a node was fully or partially
    blocked (its downstream stockpile full) or starved (its upstream
    stockpile empty).

    Parameters:
    -----------
    df : pd.DataFrame
        One row per simulated day with a 'period' column ('Wet season' /
        'Dry season') and wide columns named '<attr>_<node_id>':
        'constrained_*' and 'unconstrained_*' for amrun_loadandhaul,
        amrun_bene and amrun_shiploader, and 'stp_size_*' for amrun_rom_stp
        and amrun_p_stp.

    Returns:
    --------
    pd.DataFrame
        One row per (period, node), periods ordered 'Wet season',
        'Dry season', 'Whole year' and nodes ordered amrun_loadandhaul,
        amrun_bene, amrun_shiploader, with columns:
        - "node_id": Name of the node.
        - "period": Period the counts refer to.
        - "days_fully_blocked": Days when production was fully blocked.
        - "days_partially_blocked": Days when production was partially blocked.
        - "days_fully_starved": Days when production was fully starved.
        - "days_partially_starved": Days when production was partially starved.
        - "total_num_days": Number of rows in the period.

    Notes:
    ------
    - **Fully** blocked/starved: constrained production <= 1 while the
      relevant stockpile is full/empty.
    - **Partially** blocked/starved: constrained production differs from
      unconstrained production and is > 1 while the relevant stockpile is
      full/empty.
    - "Full" means equal to the stockpile's maximum over the whole of ``df``.
      The same capacity reference is used for every period, so a season in
      which the stockpile never reaches capacity is not credited with
      blocked days at its seasonal peak.
    - "Empty" means a stockpile level of exactly 0.
    - "Whole year" covers every row whose period is not null.
    - amrun_loadandhaul has no upstream stockpile, so its starved counts are
      always 0. amrun_shiploader has no downstream stockpile, so its blocked
      counts are always 0.
    """
    # (node_id, downstream stockpile that can block it, upstream stockpile that can starve it)
    node_links = (
        ('amrun_loadandhaul', 'amrun_rom_stp', None),
        ('amrun_bene', 'amrun_p_stp', 'amrun_rom_stp'),
        ('amrun_shiploader', None, 'amrun_p_stp'),
    )

    # Capacity reference for "full", taken over all rows (not per period).
    full_level = {
        stp: df[f'stp_size_{stp}'].max()
        for stp in ('amrun_rom_stp', 'amrun_p_stp')
    }

    def count_days(rows, node_id, stp, level, partial):
        """Number of rows where node_id is limited while stockpile ``stp`` sits at ``level``."""
        if stp is None:
            return 0
        constrained = rows[f'constrained_{node_id}']
        at_level = rows[f'stp_size_{stp}'] == level
        if partial:
            limited = (rows[f'unconstrained_{node_id}'] != constrained) & (constrained > 1)
        else:
            limited = constrained <= 1
        return int((limited & at_level).sum())

    records = []
    for period in ('Wet season', 'Dry season', 'Whole year'):
        if period == 'Whole year':
            rows = df[df['period'].notna()]
        else:
            rows = df[df['period'] == period]

        for node_id, blocking_stp, starving_stp in node_links:
            block_level = full_level.get(blocking_stp)
            records.append({
                'node_id': node_id,
                'period': period,
                'days_fully_blocked': count_days(rows, node_id, blocking_stp, block_level, partial=False),
                'days_partially_blocked': count_days(rows, node_id, blocking_stp, block_level, partial=True),
                'days_fully_starved': count_days(rows, node_id, starving_stp, 0, partial=False),
                'days_partially_starved': count_days(rows, node_id, starving_stp, 0, partial=True),
                'total_num_days': len(rows),
            })

    return pd.DataFrame(records)


def _flatten_pivot_columns(d, first=0, second=1):
    """
    Flatten the MultiIndex columns left by ``pivot(...).reset_index()``.

    Columns whose tuple has an empty level (the former index columns) keep
    their first level; all others become "<level[first]>_<level[second]>".
    """
    return d.set_axis(
        [
            f"{col[first]}_{col[second]}" if isinstance(col, tuple) and '' not in col else col[0]
            for col in d.columns
        ],
        axis=1
    )


def _sim_result_to_frame(sim_result):
    """Flatten ``sim_result`` into one long frame: one row per node, attribute and day."""
    frames = []
    for node_name, node_result in sim_result.items():
        if 'stp' in node_name:
            series_by_attr = {'stp_size': node_result['sp_size']}
        else:
            production = node_result['production']
            series_by_attr = {
                # Tonnages are scaled by the node's ratio; the ratio itself is kept as-is.
                'unconstrained': production['unconstrained'] * production['ratio'],
                'constrained': production['constrained'] * production['ratio'],
                'ratio': production['ratio'],
            }
        for attr_name, values in series_by_attr.items():
            frame = pd.DataFrame(values, columns=['attr_value'])
            frames.append(
                frame.assign(
                    node_id=node_name,
                    attr_name=attr_name,
                    # Size the daily calendar to the series instead of assuming 50000 days.
                    timestamp=pd.date_range('2022-01-01', periods=len(frame)),
                ).reset_index()
            )
    if not frames:
        raise ValueError("sim_result contains no nodes to post-process")
    return pd.concat(frames)


def _inventory_summary(seasonal, attrs_dict, use_case_id):
    """
    Mean days for the amrun_p_stp stockpile to empty (at the mean constrained
    shiploader rate) and to fill up to its live size (at the mean constrained
    bene rate), for the Dry season, Wet season and Whole year.
    """
    p_stp = seasonal.query(" node_id == 'amrun_p_stp' & attr_name == 'stp_size' ")
    shiploader = seasonal.query(" node_id == 'amrun_shiploader' & attr_name == 'constrained' ")
    bene = seasonal.query(" node_id == 'amrun_bene' & attr_name == 'constrained' ")
    live_size = attrs_dict['sp_attr']['amrun_p_stp']['live_size']

    def mean_in(frame, period):
        if period != 'Whole year':
            frame = frame[frame['period'] == period]
        return frame.attr_value.mean()

    rows = []
    for period in ('Dry season', 'Wet season', 'Whole year'):
        mean_stock = mean_in(p_stp, period)
        rows.append({
            'node_id': 'amrun_p_stp',
            'value_chain_node_id': f"{use_case_id}_amrun_p_stp",
            'period': period,
            'mean_days_to_empty': mean_stock / mean_in(shiploader, period),
            'mean_days_to_full': (live_size - mean_stock) / mean_in(bene, period),
        })
    return pd.DataFrame(rows)


def _keep_full_years(yearly):
    """Keep only complete 365-day blocks (drops the trailing partial block)."""
    return yearly.query(" num_days == 365 ")


def _keep_typical_seasons(yearly):
    """Keep season-blocks having at least 95% of that season's median day count."""
    return (
        yearly
        .assign(median_num_days=lambda x: x.groupby(['period'])['num_days'].transform('median'))
        .query(" num_days >= median_num_days*0.95 ")
    )


def _annual_means(df_period, keep_years):
    """
    Mean and mean +/- 2 std of per-block production totals, per node/attribute/period.

    Rows are bucketed into consecutive 365-day blocks by their position
    ('index'). ``keep_years`` filters the per-block totals before averaging.
    Returns one row per (node_id, period) with 'annual_mean_<attr>',
    'annual_mean_upper_<attr>' and 'annual_mean_lower_<attr>' columns.
    """
    return (
        df_period
        .assign(year_index=lambda x: np.ceil((x['index'] + 1) / 365))
        .query("~attr_name.isin(['stp_size', 'ratio'])")
        .groupby(
            ['node_id', 'attr_name', 'period', 'year_index'],
            as_index=False,
            group_keys=False
        ).agg(
            # NOTE: 'mean' holds the per-block *total*; its mean across blocks is the annual mean.
            mean=pd.NamedAgg(column='attr_value', aggfunc='sum'),
            num_days=pd.NamedAgg(column='attr_value', aggfunc='count'),
        )
        .pipe(keep_years)
        .groupby(['node_id', 'attr_name', 'period'], as_index=False)
        .agg(
            annual_mean=pd.NamedAgg(column='mean', aggfunc='mean'),
            annual_mean_upper=pd.NamedAgg(column='mean', aggfunc=lambda x: np.mean(x) + 2 * np.std(x)),
            annual_mean_lower=pd.NamedAgg(column='mean', aggfunc=lambda x: np.mean(x) - 2 * np.std(x))
        ).pivot(
            index=['node_id', 'period'],
            columns='attr_name',
            values=['annual_mean', 'annual_mean_upper', 'annual_mean_lower']
        ).reset_index()
        .pipe(_flatten_pivot_columns)
    )


def _summary_stats(df_period):
    """Descriptive statistics of 'attr_value' per (node_id, attr_name, period)."""
    return (
        df_period
        .groupby(['node_id', 'attr_name', 'period'], as_index=False)
        .agg(
            mean=pd.NamedAgg(column='attr_value', aggfunc='mean'),
            std=pd.NamedAgg(column='attr_value', aggfunc='std'),
            first_quartile=pd.NamedAgg(column='attr_value', aggfunc=lambda x: np.percentile(x, 25)),
            median=pd.NamedAgg(column='attr_value', aggfunc='median'),
            third_quartile=pd.NamedAgg(column='attr_value', aggfunc=lambda x: np.percentile(x, 75)),
            max=pd.NamedAgg(column='attr_value', aggfunc='max'),
            sem=pd.NamedAgg(column='attr_value', aggfunc=stats.sem)
        )
    )


def _histogram_frame(group, range_max):
    """
    Histogram one (node, period, attribute) group into fixed-width bins from 0.

    Stockpile sizes use 20000-wide bins and production values 10000-wide
    bins. The edges extend far enough to include the reference maximum
    ``range_max[max_attr_value(group)]``, so the largest observations are
    not dropped from the top bin.
    """
    attr_name = group['attr_name'].iloc[0]
    width = 20000 if attr_name == 'stp_size' else 10000
    upper = range_max[max_attr_value(group)]
    # np.histogram's last bin is closed on the right, so ceil(upper / width)
    # bins are enough to include ``upper``. Use at least one bin.
    n_bins = max(1, int(np.ceil(upper / width)))
    counts, edges = np.histogram(
        group['attr_value'],
        bins=np.arange(n_bins + 1, dtype=float) * width
    )
    return pd.DataFrame({
        'node_id': group['node_id'].iloc[0],
        'period': group['period'].iloc[0],
        'type': attr_name,
        'count': counts,
        'bins_start': edges[:-1],
        'bins_end': edges[1:],
    })


def _histograms(df_period, range_max):
    """Histogram frames for the constrained, unconstrained and stp_size attributes of each node/period."""
    return (
        df_period
        .query("attr_name in ['constrained', 'unconstrained', 'stp_size']")
        .groupby(['node_id', 'period', 'attr_name'], as_index=False)
        [['node_id', 'period', 'attr_name', 'attr_value']]
        .apply(lambda group: _histogram_frame(group, range_max))
    )


def _add_impact_columns(d, use_case_id):
    """Add value_chain_node_id and the absolute/relative constraint impact with +/- 2 SEM bands."""
    delta = d.mean_constrained - d.mean_unconstrained
    # Standard error of a difference of two independent means.
    margin = 2 * np.sqrt(d.sem_constrained**2 + d.sem_unconstrained**2)
    return d.assign(
        value_chain_node_id=use_case_id + '_' + d.node_id,
        impact_tons=delta,
        impact_tons_upper=delta + margin,
        impact_tons_lower=delta - margin,
        impact=delta / d.mean_unconstrained,
        impact_upper=(delta + margin) / d.mean_unconstrained,
        impact_lower=(delta - margin) / d.mean_unconstrained,
    )


def post_process_sim_result(sim_result, attrs_dict, use_case_id='southernops_vc'):
    """
    Turn raw simulation output into a long-format frame plus impact and histogram summaries.

    Parameters:
        sim_result (dict): Simulation output keyed by node_id. Stockpile nodes
            (name containing 'stp') provide ``['sp_size']``. Other nodes provide
            ``['production']['unconstrained' | 'constrained' | 'ratio']``.
            Each is a daily series; day 0 is stamped 2022-01-01.
        attrs_dict (dict): Node attributes.
            ``attrs_dict['sp_attr']['amrun_p_stp']['live_size']`` is used as the
            stockpile size for the days-to-full estimate.
        use_case_id (str, optional): Prefix for the 'value_chain_node_id'
            column. Defaults to 'southernops_vc'.

    Returns:
        tuple:
            - pd.DataFrame: Long-format results with columns 'index', 'attr_value',
              'node_id', 'attr_name' and 'timestamp'. The 'unconstrained' and
              'constrained' values are multiplied by the node's 'ratio'.
            - pd.DataFrame: One row per (node_id, period) containing:
              - summary statistics pivoted by attribute (e.g. 'mean_constrained');
              - the absolute and relative impact of constraints, with +/- 2 SEM bands;
              - annual means;
              - blocked/starved day counts.
              These rows are followed by the amrun_p_stp inventory rows
              (mean days to empty / to full).
            - pd.DataFrame: Histograms per (node_id, period, type) with 'count',
              'bins_start' and 'bins_end'. Runs of consecutive empty bins are
              reduced to their first bin.

    Notes:
        The inventory summary and the blocked/starved analysis are hard-wired to
        the Amrun nodes (amrun_p_stp, amrun_bene, amrun_shiploader, amrun_rom_stp).
        Presumably only 'southernops_vc' results are supported; TODO confirm
        before using this with other use cases.
    """
    df = _sim_result_to_frame(sim_result)

    # Two views of the same long frame: one labelled by season, one as a single period.
    seasonal = df.assign(period=classify_seasons(df))
    fullyear = df.assign(period='Whole year')

    inventory_summary = _inventory_summary(seasonal, attrs_dict, use_case_id)

    annual_means = pd.concat(
        [
            _annual_means(fullyear, _keep_full_years),
            _annual_means(seasonal, _keep_typical_seasons),
        ],
        ignore_index=True
    )

    starved_blocked_days = (
        seasonal
        .pivot(
            index=['index', 'period'],
            columns=['attr_name', 'node_id'],
            values=['attr_value']
        ).reset_index()
        .pipe(_flatten_pivot_columns, 1, 2)
        .assign(
            # Judge each day against the previous day's stockpile level (day 0 starts at 0).
            stp_size_amrun_rom_stp=lambda x: x.stp_size_amrun_rom_stp.shift(1).fillna(0),
            stp_size_amrun_p_stp=lambda x: x.stp_size_amrun_p_stp.shift(1).fillna(0)
        )
        .pipe(calc_blocked_starved_days)
    )

    # Histogram range per reference attribute; computed once instead of once per group.
    histogram_range_max = {
        attr: df.query(f"attr_name == '{attr}'").attr_value.max()
        for attr in ('unconstrained', 'stp_size')
    }

    bins_all = (
        pd.concat(
            [_histograms(seasonal, histogram_range_max), _histograms(fullyear, histogram_range_max)],
            ignore_index=True
        )
        .groupby(
            ['node_id', 'period', 'type'],
            as_index=False,
            group_keys=False
        )[['node_id', 'period', 'type', 'count', 'bins_start', 'bins_end']]
        .apply(
            lambda group: (
                group
                .sort_values('bins_start', ascending=True)
                .query("~(count == 0 & count.shift() == 0)")
            )
        )
    )

    impact = (
        pd.concat(
            [_summary_stats(seasonal), _summary_stats(fullyear)],
            ignore_index=True
        ).query(" attr_name != 'stp_size' ")
        .pivot(
            index=['node_id', 'period'],
            columns='attr_name',
            values=['mean', 'first_quartile', 'median', 'third_quartile', 'max', 'std', 'sem']
        ).reset_index()
        .pipe(_flatten_pivot_columns)
        .pipe(_add_impact_columns, use_case_id)
        .merge(annual_means, on=['node_id', 'period'], how='left')
        .merge(starved_blocked_days, on=['node_id', 'period'], how='left')
    )

    impact = pd.concat([impact, inventory_summary])

    return df, impact, bins_all


def calc_constrained_prod(**kwargs):
    """
    Solve one step of the two-stockpile value-chain LP and report achievable production.

    Uses Google OR-Tools' GLOP linear solver.

    Variables (all >= 0):
        p1 <= p1_max, p2 <= p2_max, p3 <= p3_max : production of the three stages.
        p11, p12 <= p1_max : split of p1 * r1 into the first stockpile (p11)
            and straight through to stage 2 (p12).
        p13 <= p2_max : draw from the first stockpile into stage 2
            (p2 == p12 + p13).

    Objective:
        Maximize p1 + p2 + p3.

    Constraints (stockpile levels after the step stay within capacity):
        0 <= s1 + p11 - p13 <= s1_max
        0 <= s2 + p2 * r2 - p3 <= s2_max

    Parameters:
    **kwargs: Keyword arguments:
        - use_case_id (str): Only 'southernops_vc' and 'gove_vc' are supported.
        - p1_max, p2_max, p3_max (float): Production limits of the three stages.
        - r1, r2 (float): Ratios applied to the p1 and p2 output.
        - s1, s2 (float): Stockpile levels at the start of the step.
        - s1_max, s2_max (float): Stockpile capacities.

    Returns:
    dict or bool:
        - For a supported use case, a dict with:
            - 'p1', 'p2', 'p3': {'value': optimal production,
              'is_constrained': True when the value is NOT equal
              (math.isclose) to its ``*_max`` limit}.
            - 's1', 's2': stockpile levels after the step (plain floats).
        - False if the use case is not supported.

    Raises:
    RuntimeError: If the GLOP solver is unavailable or the solve does not
        reach an optimal solution.
    """
    if kwargs['use_case_id'] not in ['southernops_vc', 'gove_vc']:
        return False

    solver = pywraplp.Solver.CreateSolver('GLOP')
    if solver is None:
        raise RuntimeError("OR-Tools GLOP solver is not available")

    p1 = solver.NumVar(lb=0, ub=kwargs['p1_max'], name='p1')
    p11 = solver.NumVar(lb=0, ub=kwargs['p1_max'], name='p11')
    p12 = solver.NumVar(lb=0, ub=kwargs['p1_max'], name='p12')
    p13 = solver.NumVar(lb=0, ub=kwargs['p2_max'], name='p13')
    p2 = solver.NumVar(lb=0, ub=kwargs['p2_max'], name='p2')
    p3 = solver.NumVar(lb=0, ub=kwargs['p3_max'], name='p3')

    solver.Maximize(p1 + p2 + p3)

    solver.Add(p11 + p12 == p1 * kwargs['r1'])
    solver.Add(p12 + p13 == p2)
    solver.Add(p11 + kwargs['s1'] - p13 <= kwargs['s1_max'])
    solver.Add(p11 + kwargs['s1'] - p13 >= 0)
    solver.Add(p2 * kwargs['r2'] + kwargs['s2'] - p3 <= kwargs['s2_max'])
    solver.Add(p2 * kwargs['r2'] + kwargs['s2'] - p3 >= 0)

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        # Otherwise the solution values of a failed solve would be reported as if valid.
        raise RuntimeError(f"GLOP did not find an optimal solution (status={status})")

    return_dict = {
        var.name(): {'value': var.solution_value()} for var in solver.variables()
    }

    # Stockpile levels at the end of the step.
    return_dict['s1'] = return_dict['p11']['value'] + kwargs['s1'] - return_dict['p13']['value']
    return_dict['s2'] = return_dict['p2']['value'] * kwargs['r2'] + kwargs['s2'] - return_dict['p3']['value']

    # The flow-split helper variables are internal to the model.
    for helper_key in ('p11', 'p12', 'p13'):
        del return_dict[helper_key]

    for prod_key in ('p1', 'p2', 'p3'):
        return_dict[prod_key]['is_constrained'] = not isclose(
            kwargs[prod_key + '_max'],
            return_dict[prod_key]['value']
        )

    return return_dict


def calc_constrained_prod_norom(**kwargs):
    """
    Solve one step of the value-chain LP without a first (ROM) stockpile.

    Uses Google OR-Tools' GLOP linear solver. Stage-1 output feeds stage 2
    directly, and only the second stockpile buffers between stage 2 and
    stage 3.

    Variables (all >= 0):
        p1 <= p1_max, p2 <= p2_max, p3 <= p3_max : production of the three stages.

    Objective:
        Maximize p1 + p2 + p3.

    Constraints:
        p1 * r1 == p2
        0 <= s2 + p2 * r2 - p3 <= s2_max

    Parameters:
    **kwargs: Keyword arguments; extra keys (e.g. use_case_id) are ignored:
        - p1_max, p2_max, p3_max (float): Production limits of the three stages.
        - r1, r2 (float): Ratios applied to the p1 and p2 output.
        - s2 (float): Second-stockpile level at the start of the step.
        - s2_max (float): Second-stockpile capacity.

    Returns:
    dict:
        - 'p1', 'p2', 'p3': {'value': optimal production,
          'is_constrained': True when the value is NOT equal (math.isclose)
          to its ``*_max`` limit}.
        - 's2': second-stockpile level after the step (plain float).

    Raises:
    RuntimeError: If the GLOP solver is unavailable or the solve does not
        reach an optimal solution.
    """
    solver = pywraplp.Solver.CreateSolver('GLOP')
    if solver is None:
        raise RuntimeError("OR-Tools GLOP solver is not available")

    p1 = solver.NumVar(lb=0, ub=kwargs['p1_max'], name='p1')
    p2 = solver.NumVar(lb=0, ub=kwargs['p2_max'], name='p2')
    p3 = solver.NumVar(lb=0, ub=kwargs['p3_max'], name='p3')

    solver.Maximize(p1 + p2 + p3)

    solver.Add(p1 * kwargs['r1'] == p2)
    solver.Add(p2 * kwargs['r2'] + kwargs['s2'] - p3 <= kwargs['s2_max'])
    solver.Add(p2 * kwargs['r2'] + kwargs['s2'] - p3 >= 0)

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        # Otherwise the solution values of a failed solve would be reported as if valid.
        raise RuntimeError(f"GLOP did not find an optimal solution (status={status})")

    return_dict = {
        var.name(): {'value': var.solution_value()} for var in solver.variables()
    }

    # Second-stockpile level at the end of the step.
    return_dict['s2'] = return_dict['p2']['value'] * kwargs['r2'] + kwargs['s2'] - return_dict['p3']['value']

    for prod_key in ('p1', 'p2', 'p3'):
        return_dict[prod_key]['is_constrained'] = not isclose(
            kwargs[prod_key + '_max'],
            return_dict[prod_key]['value']
        )

    return return_dict


def _calc_attr_node_metrics(timeseries, node_id, metrics):
    """Return one node's rows for the given metric names, pivoted wide (one column per metric) by timestamp."""
    # A boolean mask instead of an interpolated query string keeps node ids containing quotes working.
    mask = (timeseries['node_id'] == node_id) & timeseries['metric'].isin(metrics)
    return (
        timeseries[mask]
        .pivot(index='timestamp', columns='metric', values='value')
        .reset_index()
    )


def _calc_attr_transformation(num_iters, node_id, timeseries):
    """Fit the statistical description (trend, seasonality, noise, ratios, downtime) of one transformation node."""
    input_output = (
        _calc_attr_node_metrics(timeseries, node_id, ['input_tons', 'output_tons'])
        # Keep only rows where output does not exceed input.
        .query("output_tons <= input_tons")
        .assign(
            # A zero output on a day with positive input is replaced by the input itself.
            output_tons=lambda x: np.where(
                (x.output_tons == 0) & (x.input_tons > 0), x.input_tons, x.output_tons
            )
        )
    )

    # Production statistics use only days where something was actually fed in.
    input_tons = (
        input_output
        .query(" input_tons > 0 ")
        .set_index('timestamp')
    )

    decomposition = seasonal_decompose(
        input_tons['input_tons'], model='additive', period=365, extrapolate_trend=90
    )
    trend = decomposition.trend
    seasonal = decomposition.seasonal
    residuals = decomposition.resid

    prev_year = datetime.today().year - 1

    # Mean of trend / 360-row rolling mean of input, over rows from the previous calendar year onward.
    mean_trend_mean_ratio = (
        input_tons
        .assign(
            moving_avg_360d=lambda x: x['input_tons'].rolling(window=360, min_periods=1).mean(),
            trend=trend
        ).query(f'timestamp.dt.year >= {prev_year}')
        .assign(trend_mean_ratio=lambda x: x.trend / x.moving_avg_360d)
        .trend_mean_ratio
        .mean()
    )

    # Mean of 360-row rolling max of input / 360-row rolling mean of trend, same window as above.
    mean_max_trend_ratio = (
        input_tons
        .assign(
            moving_max_360d=lambda x: x['input_tons'].rolling(window=360, min_periods=1).max(),
            moving_trend_360d=trend.rolling(window=360, min_periods=1).mean()
        ).query(f'timestamp.dt.year >= {prev_year}')
        .assign(max_trend_ratio=lambda x: x.moving_max_360d / x.moving_trend_360d)
        .max_trend_ratio
        .mean()
    )

    # Output/input efficiency on days with positive input.
    ratio = input_tons['output_tons'] / input_tons['input_tons']

    # Downtime series over all retained days: 1 when nothing was fed in, else 0.
    downtime_data = np.where(input_output['input_tons'] == 0, 1, 0)
    downtime_model = AutoReg(downtime_data, lags=2, trend='n', seasonal=True, period=365).fit()
    forecasted_probability = downtime_model.forecast(num_iters)

    return {
        'hist_max_prod': input_tons['input_tons'].max(),
        'smoothed_trend': trend.iloc[-60:].median(),
        'mean_trend_mean_ratio': mean_trend_mean_ratio,
        'mean_max_trend_ratio': mean_max_trend_ratio,
        'residual_mean': residuals.mean(),
        'residual_std': residuals.std(),
        'seasonal_pattern': seasonal.iloc[:365].to_numpy(),
        'ratio_mean': ratio.mean(),
        'ratio_std': ratio.std(),
        # Forecast values are clipped to valid probabilities before drawing 0/1 downtime flags.
        'generated_downtime': np.random.binomial(1, np.clip(forecasted_probability, 0, 1)),
        'size': num_iters
    }


def _calc_attr_stockpile(node_id, timeseries):
    """Summarise one stockpile node: median live capacity and median inventory."""
    capacity = _calc_attr_node_metrics(timeseries, node_id, ['max_capacity', 'min_capacity'])
    live_size = (capacity['max_capacity'] - capacity['min_capacity']).median()

    inventory_mask = (timeseries['node_id'] == node_id) & (timeseries['metric'] == 'inventory')
    sp_size = timeseries.loc[inventory_mask, 'value'].median()

    return {
        'live_size': live_size,
        'sp_size': sp_size
    }


def calc_attributes(
    num_iters: int,
    nodes: pd.DataFrame,
    gt_value_chain_timeseries: pd.DataFrame
) -> dict:
    """
    Constructs the attribute dictionary that drives the value-chain capacity simulation.

    Parameters:
    -----------
    num_iters : int
        The number of iterations (simulated days) for which attributes will be generated.
    nodes : pd.DataFrame
        Node metadata with columns `node_id` and `node_type`. Rows whose `node_type` is
        'Transformation' are treated as production nodes; every other row is treated as a stockpile.
    gt_value_chain_timeseries : pd.DataFrame
        Long-format time series with columns `node_id`, `timestamp`, `metric`, `value`. Metrics read:
        `input_tons`, `output_tons` (transformation nodes) and `max_capacity`, `min_capacity`,
        `inventory` (stockpiles).

    Returns:
    --------
    dict
        - "size": num_iters.
        - "num_nodes": Number of transformation nodes.
        - "node_names": Transformation node ids, in `nodes` order.
        - "sp_names": Stockpile node ids, in `nodes` order.
        - "node_attr": Per transformation node:
            - "hist_max_prod": Maximum observed positive `input_tons`.
            - "smoothed_trend": Median of the last 60 values of the decomposition trend.
            - "mean_trend_mean_ratio": Mean of trend / 360-row rolling mean of `input_tons`,
              over rows from the previous calendar year onward.
            - "mean_max_trend_ratio": Mean of 360-row rolling max of `input_tons` / 360-row
              rolling mean of the trend, over the same rows.
            - "residual_mean", "residual_std": Mean / standard deviation of decomposition residuals.
            - "seasonal_pattern": First 365 values of the seasonal component (NumPy array).
            - "ratio_mean", "ratio_std": Mean / standard deviation of `output_tons / input_tons`.
            - "generated_downtime": 0/1 draws whose probabilities are forecast by an AutoReg
              model of historical zero-input days.
            - "size": num_iters.
        - "sp_attr": Per stockpile node:
            - "live_size": Median of `max_capacity - min_capacity`.
            - "sp_size": Median of the `inventory` metric.

    Notes:
    ------
    - Rows with `output_tons > input_tons` are dropped. A zero `output_tons` with positive
      `input_tons` is replaced by `input_tons` before any statistic is computed.
    - Additive seasonal decomposition (period 365) uses only rows with `input_tons > 0`.
      Downtime modelling uses all retained rows.
    - `generated_downtime` uses NumPy's global random state.
    """
    attrs_dict = {
        "size": num_iters,
        "num_nodes": None,
        "node_names": [],
        "sp_names": [],
        "node_attr": {},
        "sp_attr": {}
    }

    for _, row in nodes.iterrows():
        node_id = row['node_id']
        if row['node_type'] == 'Transformation':
            attrs_dict['node_attr'][node_id] = _calc_attr_transformation(
                num_iters, node_id, gt_value_chain_timeseries
            )
            attrs_dict['node_names'].append(node_id)
        else:
            attrs_dict['sp_attr'][node_id] = _calc_attr_stockpile(node_id, gt_value_chain_timeseries)
            attrs_dict['sp_names'].append(node_id)

    attrs_dict['num_nodes'] = len(attrs_dict['node_names'])
    return attrs_dict


class SimCapacity:
    """
    A class to simulate production capacity and constraints for the nodes of a value chain.

    Construction draws unconstrained production for every transformation node. Then, one
    simulated day at a time, it passes those draws through the constrained-production linear
    program (`calc_constrained_prod`, or `calc_constrained_prod_norom` when
    `attrs_dict['exclude_rom_stp']` is truthy) to obtain constrained production and
    stockpile levels.

    Attributes:
    use_case_id (str): Use-case identifier. Selects the node-id <-> LP-variable mapping in
        `site_alias` / `site_alias_inv`.
    exclude_rom_stp (bool): Value of `attrs_dict['exclude_rom_stp']` at construction. The
        simulation itself reads the flag from the attributes dict it is given.
    attrs_dict (dict): A deep copy of the input attributes dictionary.
    sim_result (dict): The baseline simulation result computed at construction.

    Methods:
    gen_unconstrained_prod(size, smoothed_trend, mean_trend_mean_ratio, mean_max_trend_ratio,
                           seasonal_pattern, generated_downtime, ratio_mean, ratio_std,
                           residual_mean, residual_std, hist_max_prod) -> (np.ndarray, np.ndarray):
        Draws unconstrained production and output/input ratios for one node.

    gen_sim_result(attrs_dict) -> dict:
        Runs the full simulation (unconstrained and constrained) for the given attributes.

    rerun_sim_with_scenario(params=None) -> dict:
        Applies scenario changes to a copy of `attrs_dict` and reruns the simulation.
    """
    def __init__(self, attrs_dict):
        self.use_case_id = attrs_dict['use_case_id']
        self.exclude_rom_stp = attrs_dict['exclude_rom_stp']
        # Deep copy so later scenario edits never leak back into the caller's dict.
        self.attrs_dict = deepcopy(attrs_dict)
        self.sim_result = self.gen_sim_result(self.attrs_dict)

    def gen_unconstrained_prod(
        self, size: int, smoothed_trend: float, mean_trend_mean_ratio: float, mean_max_trend_ratio: float,
        seasonal_pattern: np.ndarray, generated_downtime: np.ndarray, ratio_mean: float, ratio_std: float,
        residual_mean: float, residual_std: float, hist_max_prod: float
    ) -> tuple:
        """
        Draw `size` days of unconstrained production and output/input ratios for one node.

        Production is a flat trend, plus the seasonal pattern repeated cyclically, plus
        Gaussian noise. It is clipped to [0, hist_max_prod] and forced to zero on days where
        `generated_downtime` is 1. Ratios come from a normal distribution truncated to [0, 1].

        `mean_trend_mean_ratio` and `mean_max_trend_ratio` are accepted so the method can be
        called with `**node_attr`; they are not used here.

        Returns:
        tuple(np.ndarray, np.ndarray): (production, ratio), each of length `size`.
        """
        generated_trend = np.full(size, smoothed_trend)
        # np.resize repeats the pattern cyclically to exactly `size` values. Patterns shorter
        # than 365 entries therefore no longer break the broadcast below.
        generated_season = np.resize(np.asarray(seasonal_pattern), size)
        generated_residual = stats.norm.rvs(residual_mean, residual_std, size)
        unconstrained_prod = np.clip(
            generated_trend + generated_season + generated_residual,
            a_min=0, a_max=hist_max_prod
        )

        # truncnorm takes its bounds in standard-deviation units; avoid dividing by zero.
        std_for_bounds = 1e-100 if ratio_std == 0 else ratio_std
        ratio = stats.truncnorm.rvs(
            a=(0 - ratio_mean) / std_for_bounds,
            b=(1 - ratio_mean) / std_for_bounds,
            loc=ratio_mean,
            scale=ratio_std,
            size=size
        )

        return np.where(generated_downtime[:size] == 1, 0, unconstrained_prod), ratio

    def gen_sim_result(self, attrs_dict):
        """
        Run the simulation for `attrs_dict` and return per-node results.

        Returns:
        dict: For each transformation node, `{'production': {'unconstrained', 'ratio',
        'constrained'}}` arrays of length `attrs_dict['size']`. For each stockpile,
        `{'sp_size': array}`, holding the level returned by the solver for each day, or 0 on
        days where the solver does not return that stockpile.
        """
        sim_result = {}
        for node_name in attrs_dict['node_names']:
            sim_result[node_name] = {'production': {}}

            unconstrained_prod, ratio = self.gen_unconstrained_prod(
                **attrs_dict['node_attr'][node_name]
            )

            sim_result[node_name]['production']['unconstrained'] = unconstrained_prod
            sim_result[node_name]['production']['ratio'] = ratio
            # Starts as a copy; each day's solver output overwrites the entry for that day.
            sim_result[node_name]['production']['constrained'] = unconstrained_prod.copy()

        for sp_name in attrs_dict['sp_names']:
            sim_result[sp_name] = {
                'current_sp': attrs_dict['sp_attr'][sp_name]['sp_size'],
                'sp_size': np.zeros(attrs_dict['size']),
                'live_size': attrs_dict['sp_attr'][sp_name]['live_size']
            }

        for i in range(attrs_dict['size']):
            # Build the solver kwargs from LP aliases, e.g. p1_max / r1 for producers and
            # s1 / s1_max for stockpiles.
            init_val = {}
            for key in sim_result.keys():
                if 'production' in sim_result[key].keys():
                    init_val[
                        site_alias[self.use_case_id][key] + '_max'
                    ] = sim_result[key]['production']['unconstrained'][i]

                    init_val[
                        site_alias[self.use_case_id][key].replace('p', 'r')
                    ] = sim_result[key]['production']['ratio'][i]
                else:
                    init_val[site_alias[self.use_case_id][key]] = sim_result[key]['current_sp']
                    init_val[site_alias[self.use_case_id][key] + '_max'] = sim_result[key]['live_size']

            init_val['use_case_id'] = self.use_case_id

            if attrs_dict['exclude_rom_stp']:
                constrained_prod = calc_constrained_prod_norom(**init_val)
            else:
                constrained_prod = calc_constrained_prod(**init_val)

            for key in constrained_prod.keys():
                if key.startswith('p'):
                    sim_result[
                        site_alias_inv[self.use_case_id][key]
                    ]['production']['constrained'][i] = constrained_prod[key]['value']

                else:
                    # Stockpile level carried into the next simulated day.
                    sim_result[
                        site_alias_inv[self.use_case_id][key]
                    ]['current_sp'] = constrained_prod[key]

                    sim_result[
                        site_alias_inv[self.use_case_id][key]
                    ]['sp_size'][i] = constrained_prod[key]

        # Drop working state that is not part of the returned result.
        for sp_name in attrs_dict['sp_names']:
            for key_to_del in ['live_size', 'current_sp']:
                del sim_result[sp_name][key_to_del]

        return sim_result

    @staticmethod
    def _apply_scenario_params(attrs_dict, params):
        """
        Apply scenario parameter changes to `attrs_dict` in place and return it.

        Each param dict needs 'attr2change', 'node_id', 'target_var', 'is_numeric', and either
        'target_value' or the legacy key 'taget_value'. 'exclude_rom_stp' is optional and
        defaults to the current value. Semantics:
          - node_attr with is_numeric: value is replaced by target_value.
          - node_attr without is_numeric: value += |value| * target_value (relative change).
          - any other group (e.g. sp_attr): value *= 1 + target_value.
        """
        for param in params:
            attr2change = param['attr2change']
            node_id = param['node_id']
            target_var = param['target_var']
            is_numeric = param['is_numeric']
            # Accept the correctly spelled key; fall back to the legacy misspelling.
            target_value = param['target_value'] if 'target_value' in param else param['taget_value']

            attrs_dict['exclude_rom_stp'] = param.get('exclude_rom_stp', attrs_dict['exclude_rom_stp'])

            node_entry = attrs_dict[attr2change][node_id]
            if attr2change == 'node_attr' and is_numeric:
                node_entry[target_var] = target_value
            elif attr2change == 'node_attr':
                node_entry[target_var] += abs(node_entry[target_var]) * target_value
            else:
                node_entry[target_var] *= 1 + target_value

        return attrs_dict

    def rerun_sim_with_scenario(self, params=None):
        """
        Rerun the simulation on a modified copy of `self.attrs_dict`.

        Parameters:
        params (list[dict] | None): Scenario changes, see `_apply_scenario_params`. When None,
            a default scenario sets `amrun_shiploader`'s `smoothed_trend` to 102489.89 with
            ROM-STP constraints included.

        Returns:
        dict: Simulation result in the same format as `gen_sim_result`. `self.attrs_dict`
        and `self.sim_result` are left unchanged.
        """
        if params is None:
            # Built per call instead of as a shared mutable default argument.
            params = [{
                'attr2change': 'node_attr',
                'node_id': 'amrun_shiploader',
                'target_var': 'smoothed_trend',
                'is_numeric': True,
                'taget_value': 102489.89,
                'exclude_rom_stp': False
            }]

        attrs_dict = self._apply_scenario_params(deepcopy(self.attrs_dict), params)

        return self.gen_sim_result(attrs_dict)


In [0]:
def load_arrow_table_from_foundry(foundry_client, dataset_rid, branch_name=None, columns=None):
    """
    Read a Foundry dataset in Arrow IPC stream format and return it as a pyarrow Table.

    Parameters:
    foundry_client: Client exposing `datasets.Dataset.read_table`.
    dataset_rid (str): Resource id of the dataset to read.
    branch_name (str | None): Branch to read from; passed straight through to the client.
    columns (list[str] | None): Columns to read; passed straight through to the client.

    Returns:
    pyarrow.Table: All record batches of the returned stream.
    """
    ipc_payload = foundry_client.datasets.Dataset.read_table(
        dataset_rid, format='ARROW', branch_name=branch_name, columns=columns
    )
    # Wrap the raw bytes so the IPC stream reader can consume them, then gather every batch.
    stream_reader = pa.ipc.RecordBatchStreamReader(pa.BufferReader(ipc_payload))
    return stream_reader.read_all()

In [0]:
# Node status flags from the master branch of the node-status dataset, as a Polars DataFrame.
node_status = pl.from_arrow(load_arrow_table_from_foundry(
    foundry_client,
    node_status_gdi_rid,
    "master",
    ['node_id', 'timestamp', 'is_starved', 'is_blocked', 'is_deferred_bottleneck'],
))

In [0]:
import warnings
import os
# Suppress every Python warning for the rest of the session (global process state).
warnings.filterwarnings("ignore")
# Turn off Polars' verbose logging via its environment variable.
os.environ["POLARS_VERBOSE"] = "0"
# Polars display settings: show up to 1000 characters per string value and 100 table rows.
pl.Config.set_fmt_str_lengths(1000)
pl.Config.set_tbl_rows(100)

In [0]:
def _predicate_col(column_name, negate=False):
    """Return Polars source text referencing `column_name`, quoted with repr so it stays valid under eval."""
    return f"{'~' if negate else ''}pl.col({column_name!r})"


def _group_predicates(flag, names):
    """
    Build predicate source strings for a group of neighbour nodes sharing one status flag.

    Returns (all_set, all_clear, some_not_all_set). With no neighbours the fixed fallbacks are
    all_set -> pl.lit(True), all_clear -> pl.lit(False), some_not_all_set -> pl.lit(True).
    """
    if not names:
        return "pl.lit(True)", "pl.lit(False)", "pl.lit(True)"
    columns = [f"{flag}#{name}" for name in names]
    all_set = "(" + " & ".join(_predicate_col(col) for col in columns) + ")"
    all_clear = "(" + " & ".join(_predicate_col(col, negate=True) for col in columns) + ")"
    any_set = "(" + " | ".join(_predicate_col(col) for col in columns) + ")"
    some_not_all_set = "(" + any_set + " & ~" + all_set + ")"
    return all_set, all_clear, some_not_all_set


def _when_flag(condition, alias):
    """Return Polars source text for a 0/1 column that is 1 where `condition` holds, named `alias`."""
    return f"pl.when({condition}).then(1).otherwise(0).alias({alias!r})"


def create_predicates(node_links, use_case_id = 'gudai_darri_vc'):
    """
    Generates Polars predicate expressions (as source strings) and their column names from
    the parent/child links of one use case.

    For every node, a bottleneck flag column (`is_deferred_bottleneck#<node>`) is selected and
    0/1 state columns are defined:
      - `blocked_starved#<node>` and `notblocked_notstarved#<node>`: only for nodes that have
        both parents and children.
      - `blocked#<node>` and `starved#<node>`: for every node.
    They are derived from the node's own bottleneck flag and its parents' `is_blocked#...` /
    children's `is_starved#...` columns.

    Args:
        node_links (pl.DataFrame): Columns 'parent_value_chain_node_id',
            'child_value_chain_node_id' and 'use_case_id'.
        use_case_id (str, optional): Use case to keep. Its `<use_case_id>_` prefix is stripped
            from node ids. Defaults to 'gudai_darri_vc'.

    Returns:
        tuple: (predicates, columns)
            - list of str: Polars expressions to be `eval`-ed against a wide frame whose columns
              are named '<flag>#<node>'.
            - list of str: Column names to select after applying the predicates.

    Note:
        Node ids are embedded with repr quoting, so ids containing quotes still produce valid
        source. The strings are still evaluated as Python, so only use trusted node-link data.
    """
    prefix = f'{use_case_id}_'
    node_link_dict = (
        node_links
        # Filter on the requested use case; this was previously hard-coded to 'gudai_darri_vc'.
        .filter(pl.col('use_case_id') == use_case_id)
        .with_columns(
            pl.col('parent_value_chain_node_id').str.replace(prefix, '', literal=True).alias('parent'),
            pl.col('child_value_chain_node_id').str.replace(prefix, '', literal=True).alias('child'),
        )
        .select(['parent', 'child'])
        .to_dicts()
    )

    # Unique node ids in first-seen order, so the output is deterministic.
    all_nodes = list(dict.fromkeys(
        name for link in node_link_dict for name in (link['parent'], link['child'])
    ))

    parents_of = {node: [] for node in all_nodes}
    children_of = {node: [] for node in all_nodes}
    for link in node_link_dict:
        children_of[link['parent']].append(link['child'])
        parents_of[link['child']].append(link['parent'])

    all_predicates = []
    columns_to_select = []
    for node in all_nodes:
        parents = parents_of[node]
        children = children_of[node]

        parents_blocked, parents_free, some_parents_blocked = _group_predicates('is_blocked', parents)
        children_starved, children_free, some_children_starved = _group_predicates('is_starved', children)

        bottleneck_col = f"is_deferred_bottleneck#{node}"
        is_bottleneck = f"({_predicate_col(bottleneck_col)} == pl.lit(1, dtype=pl.Int64))"
        columns_to_select.append(bottleneck_col)

        if parents and children:
            all_predicates.append(_when_flag(
                f"{is_bottleneck} & {parents_blocked} & {children_starved}",
                f'blocked_starved#{node}'
            ))
            columns_to_select.append(f'blocked_starved#{node}')

            all_predicates.append(_when_flag(
                f"{is_bottleneck} & {parents_free} & {children_free}",
                f'notblocked_notstarved#{node}'
            ))
            columns_to_select.append(f'notblocked_notstarved#{node}')

        all_predicates.append(_when_flag(
            f"{is_bottleneck} & ({some_parents_blocked} | ({parents_blocked} & {children_free}))",
            f'blocked#{node}'
        ))
        columns_to_select.append(f'blocked#{node}')

        all_predicates.append(_when_flag(
            f"{is_bottleneck} & ({some_children_starved} | ({children_starved} & {parents_free}))",
            f'starved#{node}'
        ))
        columns_to_select.append(f'starved#{node}')

    return all_predicates, columns_to_select

In [0]:
# Wide view of the node status flags for a single day (2025-03-05): one row per timestamp,
# one '<flag>#<node_id>' column per flag and node, plus a sequential "id" row index.
node_status_raw = (
    node_status
    .filter(pl.col('timestamp').dt.date() == datetime(2025, 3, 5).date())
    .pivot(
        values=['is_starved', 'is_blocked', 'is_deferred_bottleneck'],
        index='timestamp',
        on='node_id',
        separator='#',
    )
    .with_row_index("id")
)

In [0]:
# Daily per-node sums of the bottleneck flag and of the derived blocked/starved state columns.
# NOTE(review): `all_predicates` and `columns_to_select` are not defined in this cell; presumably
# they come from `create_predicates(...)` run in an earlier cell. Confirm before running.
df_agg = (
    node_status
    # .collect()
    # Wide frame: one row per timestamp, columns named '<flag>#<node_id>'.
    .pivot(
        on='node_id', 
        index='timestamp',
        values=['is_starved', 'is_blocked', 'is_deferred_bottleneck'],
        separator = '#'
    ).with_columns(
       # Predicates are Python source strings turned into Polars expressions via eval;
       # only use with trusted node-link data.
       [eval(predicate) for predicate in all_predicates]
    ).select(['timestamp'] + columns_to_select)
    # Back to long format, then split '<attr_name>#<node_id>' into two columns.
    .unpivot(
        index='timestamp',
        variable_name = 'attr_names',
        value_name = 'attr_values'
    ).with_columns(
        pl.col("attr_names").str.splitn("#", 2).struct.rename_fields(["attr_name", "node_id"]).alias("fields")
    ).unnest("fields")
    # One row per (timestamp, node_id) with one column per attribute.
    .pivot(
        on='attr_name', 
        index=['timestamp', 'node_id'],
        values='attr_values'
    ).with_columns(
        pl.col('timestamp').dt.date().alias('date')
    ).group_by(['date', 'node_id'])
    # Sum each flag column over all timestamps of the day.
    .agg(
        pl.col('is_deferred_bottleneck').sum(),
        pl.col('blocked_starved').sum(),
        pl.col('notblocked_notstarved').sum(),
        pl.col('blocked').sum(),
        pl.col('starved').sum()
    )
)

In [0]:
from copy import deepcopy

class DataTypeNotExists(Exception):
    """Raised when a rule declares a `data_type` other than TEXT, NUMERIC or SET."""


class OperatorNotExists(Exception):
    """Raised when a rule declares an `operator` that the evaluator does not support."""

# Supported rule operators, each a predicate over (value, ref_min, ref_max).
# Every operator except RANGE compares against ref_min only.
_RULE_OPERATOR_CHECKS = {
    'RANGE': lambda val, ref_min, ref_max: val >= ref_min and val <= ref_max,
    'GTE': lambda val, ref_min, ref_max: val >= ref_min,
    'LTE': lambda val, ref_min, ref_max: val <= ref_min,
    'GT': lambda val, ref_min, ref_max: val > ref_min,
    'LT': lambda val, ref_min, ref_max: val < ref_min,
    'EQUAL': lambda val, ref_min, ref_max: val == ref_min,
    'NOT EQUAL': lambda val, ref_min, ref_max: val != ref_min,
    'STARTS_WITH': lambda val, ref_min, ref_max: val.startswith(ref_min),
    'NOT STARTS_WITH': lambda val, ref_min, ref_max: not val.startswith(ref_min),
    'ENDS_WITH': lambda val, ref_min, ref_max: val.endswith(ref_min),
    'NOT ENDS_WITH': lambda val, ref_min, ref_max: not val.endswith(ref_min),
    'CONTAINS': lambda val, ref_min, ref_max: ref_min in val,
    'NOT CONTAINS': lambda val, ref_min, ref_max: ref_min not in val,
    'IN': lambda val, ref_min, ref_max: val in ref_min,
    'NOT IN': lambda val, ref_min, ref_max: val not in ref_min,
}


def evaluate_rule_truth_values(dataset, rules):
    """
    Evaluates the truth values of rules against a dataset.

    Only rules whose active window contains the first row's timestamp
    (`active_window_start <= ts < active_window_end`) are evaluated. A rule is matched to the
    first row with the same `node_id` (= rule `evaluation_node`) and `metric_id`. Rules with no
    matching row, or whose matching `metric_value` is None, evaluate to False with value None.

    Args:
        dataset (list): Rows exposing `.asDict()`, each with 'node_id', 'metric_id',
            'metric_value' and 'timestamp'.
        rules (list): Rule dicts with 'rule_id', 'data_type', 'operator', 'min_value',
            'max_value', 'evaluation_node', 'metric_id', 'active_window_start',
            'active_window_end' and, for error reporting, 'dataset'.

    Returns:
        tuple: A tuple containing:
            - rule_truth_values (dict): Rule ID -> truth value (True/False).
            - rule_values (dict): Rule ID -> evaluated metric value (None when missing).
            - rule_ref_values (dict): Rule ID -> {'ref_val_min', 'ref_val_max'}.
            - rule_operator (dict): Rule ID -> operator.
        All four are empty when `dataset` is empty.

    Raises:
        DataTypeNotExists: Unknown `data_type`.
        OperatorNotExists: Unknown `operator` (checked only when a value was found).
        ValueError: A NUMERIC rule's min/max or metric value cannot be converted to float.
    """
    dataset = [row.asDict() for row in dataset]
    if not dataset:
        # Without rows there is no timestamp to select active rules against.
        return {}, {}, {}, {}

    snapshot_ts = dataset[0]['timestamp']
    rules = [
        rule
        for rule in rules
        if rule['active_window_start'] <= snapshot_ts
        and rule['active_window_end'] > snapshot_ts
    ]

    # First row per (node_id, metric_id); later duplicates are ignored.
    first_row = {}
    for data in dataset:
        first_row.setdefault((data['node_id'], data['metric_id']), data)

    rule_truth_values = {}
    rule_values = {}
    rule_ref_values = {}
    rule_operator = {}
    for rule in rules:
        data_type = rule['data_type']
        if data_type not in ('NUMERIC', 'SET', 'TEXT'):
            raise DataTypeNotExists(
                f"{rule['data_type']} type doesn't exist. Allowed types TEXT, NUMERIC, SET"
            )

        match = first_row.get((rule['evaluation_node'], rule['metric_id']))
        raw_val = None if match is None else match['metric_value']

        if data_type == 'NUMERIC':
            try:
                ref_val_min = float(rule['min_value'])
                ref_val_max = float(rule['max_value'])
                val = None if raw_val is None else float(raw_val)
            except ValueError as err:
                raise ValueError(
                    f"Error: Cannot convert min_value '{rule['min_value']}', max_value "
                    f"'{rule['max_value']}' or metric value '{raw_val}' to float "
                    f"on dataset: '{rule['dataset']}' and metric_id: '{rule['metric_id']}'"
                ) from err
        elif data_type == 'SET':
            # SET references are ';'-separated lists.
            ref_val_min = rule['min_value'].strip().split(';')
            ref_val_max = rule['max_value'].strip().split(';')
            val = raw_val
        else:
            ref_val_min = rule['min_value']
            ref_val_max = rule['max_value']
            val = raw_val

        if val is None:
            truth_val = False
        else:
            check = _RULE_OPERATOR_CHECKS.get(rule['operator'])
            if check is None:
                raise OperatorNotExists(
                    f"""{rule['operator']} operator doesn't exist. Allowed operators
                        RANGE, GTE, LTE, GT, LT, EQUAL, NOT EQUAL, STARTS_WITH, NOT STARTS_WITH,
                        ENDS_WITH, NOT ENDS_WITH, CONTAINS, NOT CONTAINS, IN, NOT IN"""
                )
            truth_val = check(val, ref_val_min, ref_val_max)

        rule_truth_values[rule['rule_id']] = truth_val
        rule_values[rule['rule_id']] = val
        rule_ref_values[rule['rule_id']] = {'ref_val_min': ref_val_min, 'ref_val_max': ref_val_max}
        rule_operator[rule['rule_id']] = rule['operator']

    return rule_truth_values, rule_values, rule_ref_values, rule_operator

def evaluate_node_rules(rule_truth_values, rule_values, rule_ref_values, rule_operator, node_config_dict, rules):
    """
    Evaluates rules for each node and determines their states (BLOCKED, STARVED).

    For each node appearing as a parent or child in `node_config_dict`, and for each rule type
    ('BLOCKED', 'STARVED'), the `truth_value` expression of the first matching rule is
    evaluated. Inside that expression, each rule id is a name bound to its truth value. When
    both states evaluate truthy, BLOCKED is forced to False.

    Args:
        rule_truth_values (dict): Rule ID -> truth value.
        rule_values (dict): Rule ID -> evaluated metric value.
        rule_ref_values (dict): Rule ID -> reference values.
        rule_operator (dict): Rule ID -> operator.
        node_config_dict (list): Link dicts with 'parent_node_id' and 'child_node_id'.
        rules (list): Rule dicts with 'rule_id', 'node', 'rule_type', 'truth_value',
            'rule_explanation' and 'short_rule_desc'.

    Returns:
        list: One dict per node with keys 'node_id' and, per rule type T, T, T_RULES,
        T_OPERATORS, T_VALUES, T_REF_VALUES, T_RULE_DESC and T_TRUTH.
    """
    # Names visible to the truth expressions: one per evaluated rule id. An explicit mapping
    # replaces the former exec() into locals(), which no longer works on Python 3.13+
    # because locals() returns an independent snapshot there (PEP 667).
    truth_namespace = dict(rule_truth_values)

    # Rules grouped by (node, rule_type), keeping their original order.
    rules_by_node_type = {}
    for rule in rules:
        rules_by_node_type.setdefault((rule['node'], rule['rule_type']), []).append(rule)

    all_nodes = (
        {node['parent_node_id'] for node in node_config_dict}
        | {node['child_node_id'] for node in node_config_dict}
    )

    results = []

    for node_id in all_nodes:
        node_level_result = {'node_id': node_id}
        for rule_type in ['BLOCKED', 'STARVED']:
            node_rules = rules_by_node_type.get((node_id, rule_type))
            if node_rules:
                # The first rule's expression is used for the (node, type) pair.
                truth_exp = node_rules[0]['truth_value']
                rules_to_save = {rule['rule_id'] for rule in node_rules}

                # NOTE(review): truth_value strings come from rule configuration and are
                # evaluated as Python; only use with trusted rule definitions.
                node_level_result[rule_type] = eval(truth_exp, {}, truth_namespace)

                node_level_result[rule_type + '_RULES'] = str(
                    {key: value for key, value in rule_truth_values.items() if key in rules_to_save}
                )

                node_level_result[rule_type + '_OPERATORS'] = str(
                    {key: value for key, value in rule_operator.items() if key in rules_to_save}
                )

                node_level_result[rule_type + '_VALUES'] = str(
                    {key: value for key, value in rule_values.items() if key in rules_to_save}
                )

                node_level_result[rule_type + '_REF_VALUES'] = str(
                    {key: value for key, value in rule_ref_values.items() if key in rules_to_save}
                )

                # Descriptions of the rules that fired. Rules that were not evaluated (absent
                # from rule_truth_values) are treated as not fired instead of raising KeyError.
                node_level_result[rule_type + '_RULE_DESC'] = [
                    {
                        'rule_id': rule['rule_id'],
                        'long_rule_desc': rule['rule_explanation'],
                        'short_rule_desc': rule['short_rule_desc']
                    }
                    for rule in node_rules
                    if rule_truth_values.get(rule['rule_id'], False)
                ]
                node_level_result[rule_type + '_TRUTH'] = truth_exp
            else:
                node_level_result[rule_type] = False
                node_level_result[rule_type + '_RULES'] = None
                node_level_result[rule_type + '_OPERATORS'] = None
                node_level_result[rule_type + '_VALUES'] = None
                node_level_result[rule_type + '_REF_VALUES'] = None
                node_level_result[rule_type + '_RULE_DESC'] = [
                    {'rule_id': None, 'long_rule_desc': None, 'short_rule_desc': None}
                ]
                node_level_result[rule_type + '_TRUTH'] = None

        # A node cannot be both; STARVED takes precedence.
        if node_level_result['BLOCKED'] and node_level_result['STARVED']:
            node_level_result['BLOCKED'] = False

        results.append(node_level_result)

    return results

def resolve_state_conflict(results, node_config_dict):
    """
    Adjust BLOCKED/STARVED flags on node results for parent-child links with an inactive end.

    Two passes are made over ``node_config_dict``:

    1. Links whose ``is_parent_active`` equals 0: the parent result gets
       ``STARVED = False`` and ``BLOCKED = not child['STARVED']`` (blocked unless the
       child is starved).
    2. Links whose ``is_child_active`` equals 0: the child result gets
       ``BLOCKED = False`` and ``STARVED = not parent['BLOCKED']`` (starved unless the
       parent is blocked). This pass runs after the first one, so it sees any flags the
       first pass changed.

    Args:
        results (list): Node evaluation result dicts, each carrying ``node_id`` and the
            ``STARVED``/``BLOCKED`` flags. Mutated in place.
        node_config_dict (list): Link configuration dicts with ``parent_node_id``,
            ``child_node_id``, ``is_parent_active`` and ``is_child_active``.

    Returns:
        list: ``results`` itself. Every parent/child pair touched by a link is moved to
        the end of the list (child first, then parent).

    Raises:
        IndexError: If a link references a ``node_id`` with no entry in ``results``.
    """

    def take(node_id):
        # Remove and return the first entry for node_id; indexing an empty match list
        # raises IndexError when the node is absent.
        matches = [pos for pos, row in enumerate(results) if row['node_id'] == node_id]
        return results.pop(matches[0])

    # Both link lists are built up front, before any result is modified.
    parent_inactive_links = [deepcopy(cfg) for cfg in node_config_dict if cfg['is_parent_active'] == 0]
    child_inactive_links = [deepcopy(cfg) for cfg in node_config_dict if cfg['is_child_active'] == 0]

    for link in parent_inactive_links:
        parent = take(link['parent_node_id'])
        child = take(link['child_node_id'])
        parent['STARVED'] = False
        parent['BLOCKED'] = not child['STARVED']
        results.extend((child, parent))

    for link in child_inactive_links:
        parent = take(link['parent_node_id'])
        child = take(link['child_node_id'])
        child['BLOCKED'] = False
        child['STARVED'] = not parent['BLOCKED']
        results.extend((child, parent))

    return results

def identify_bottlenecks(results, branches):
    """
    Identifies bottlenecks in a value chain by analyzing node states along each branch.

    Each branch's ``path_array`` is walked pairwise from its downstream end upwards:
    ``(path[-2], path[-1])``, then ``(path[-3], path[-2])``, and so on. For every
    (parent, child) pair:

    * if the parent is BLOCKED and the child is STARVED, the parent's BLOCKED flag is
      cleared for this branch;
    * a STARVED child marks the parent ``starving_downstream`` and contributes its
      ``STARVED_RULE_DESC`` to the parent's ``bottleneck_desc``;
    * a BLOCKED parent marks the child ``blocking_upstream`` and contributes its
      ``BLOCKED_RULE_DESC`` to the child's ``bottleneck_desc``;
    * until one is found on the branch, the first node (parent checked before child)
      that is starving downstream while not STARVED, or blocking upstream while not
      BLOCKED, gets ``is_bottleneck = True``. At most one node per branch is flagged.

    A node lying on several branches yields one result entry per branch. The first
    branch consumes the original entry. Later branches start from a pristine snapshot
    of that original entry.

    Args:
        results (list): Node evaluation result dicts. Each needs ``node_id``,
            ``STARVED``, ``BLOCKED``, ``STARVED_RULE_DESC`` and ``BLOCKED_RULE_DESC``.
            The list is mutated and reordered in place.
        branches (iterable): Branch records providing ``path_id``, ``path_array`` and
            ``path_weight``. They may be plain mappings or objects exposing
            ``asDict()`` (presumably PySpark ``Row``s).

    Returns:
        list: ``results`` itself. Entries on a branch carry ``vc_path_id``,
        ``vc_path_weight``, ``is_bottleneck``, ``starving_downstream``,
        ``blocking_upstream`` and ``bottleneck_desc``.

    Raises:
        IndexError: If a path references a node with no matching entry in ``results``.
    """
    original_node_states = {}
    for raw_branch in branches:
        branch = _branch_to_dict(raw_branch)
        path_id = branch['path_id']
        path = branch['path_array']
        is_bottleneck_found = False
        processed = set()  # node ids already visited on this branch

        for parent_node_id, child_node_id in zip(path[-2::-1], path[::-1][:-1]):
            # Parent is resolved before child; the order matters when one pops from results.
            parent = _checkout_node_result(parent_node_id, path_id, results, original_node_states, processed)
            child = _checkout_node_result(child_node_id, path_id, results, original_node_states, processed)

            _prepare_path_fields(parent, branch)
            _prepare_path_fields(child, branch)

            # A starved child overrides the parent's BLOCKED flag on this branch.
            if parent['BLOCKED'] and child['STARVED']:
                parent['BLOCKED'] = False

            if child['STARVED']:
                parent['starving_downstream'] = True
                _add_bottleneck_desc(parent, child['STARVED_RULE_DESC'])

            if parent['BLOCKED']:
                child['blocking_upstream'] = True
                _add_bottleneck_desc(child, parent['BLOCKED_RULE_DESC'])

            if not is_bottleneck_found:
                for node in (parent, child):
                    if (
                        (node['starving_downstream'] and not node['STARVED'])
                        or (node['blocking_upstream'] and not node['BLOCKED'])
                    ):
                        node['is_bottleneck'] = True
                        is_bottleneck_found = True
                        break

            results.append(parent)
            results.append(child)

    return results


# Sentinel for _pop_node_result: "match any vc_path_id".
_ANY_PATH = object()


def _empty_bottleneck_desc():
    """Return a fresh placeholder ``bottleneck_desc`` list (fresh because it may be extended in place)."""
    return [{'rule_id': None, 'long_rule_desc': None, 'short_rule_desc': None}]


def _branch_to_dict(branch):
    """Convert a branch record to a dict: objects with ``asDict()`` use it, mappings go through ``dict()``."""
    as_dict = getattr(branch, 'asDict', None)
    return as_dict() if callable(as_dict) else dict(branch)


def _pop_node_result(results, node_id, path_id=_ANY_PATH):
    """Remove and return the first result for ``node_id``, optionally requiring ``vc_path_id == path_id``."""
    for idx, row in enumerate(results):
        if row['node_id'] == node_id and (
            path_id is _ANY_PATH or row.get('vc_path_id', 'empty') == path_id
        ):
            return results.pop(idx)
    raise IndexError(f"no result found for node_id={node_id!r}")


def _checkout_node_result(node_id, path_id, results, original_node_states, processed):
    """Return a working copy of ``node_id``'s result for the branch currently being walked."""
    if node_id not in original_node_states:
        # First sighting in any branch: take the entry out of results and keep a
        # pristine snapshot for later branches that pass through this node.
        original = _pop_node_result(results, node_id)
        original_node_states[node_id] = deepcopy(original)
        processed.add(node_id)
        return deepcopy(original)
    if node_id in processed:
        # Already visited earlier on this branch: continue from this branch's entry.
        return deepcopy(_pop_node_result(results, node_id, path_id))
    # Visited only on earlier branches: start from the snapshot and leave the
    # earlier branches' entries in results untouched.
    processed.add(node_id)
    return deepcopy(original_node_states[node_id])


def _prepare_path_fields(node_result, branch):
    """Stamp the branch id/weight on ``node_result`` and default its bottleneck bookkeeping fields."""
    node_result['vc_path_id'] = branch['path_id']
    node_result['vc_path_weight'] = branch['path_weight']
    node_result.setdefault('is_bottleneck', False)
    node_result.setdefault('starving_downstream', False)
    node_result.setdefault('blocking_upstream', False)
    node_result.setdefault('bottleneck_desc', _empty_bottleneck_desc())


def _add_bottleneck_desc(node_result, rule_desc):
    """Replace a placeholder ``bottleneck_desc`` with ``rule_desc``, or extend an existing one."""
    if node_result['bottleneck_desc'] == _empty_bottleneck_desc():
        node_result['bottleneck_desc'] = rule_desc
    else:
        node_result['bottleneck_desc'] += rule_desc

def evaluate_and_identify_bottlenecks(dataset, rules, node_config_dict, branches):
    """
    Run the full pipeline: rule evaluation, node-state reconciliation, bottleneck detection.

    The steps run in this order:

    1. ``evaluate_rule_truth_values(dataset, rules)`` produces per-rule truth values,
       observed values, reference values and operators.
    2. ``evaluate_node_rules`` turns those into one result dict per node.
    3. ``resolve_state_conflict`` adjusts BLOCKED/STARVED flags for links with an
       inactive parent or child.
    4. ``identify_bottlenecks`` walks each branch and flags bottleneck nodes.

    Args:
        dataset: Input data, passed unchanged to ``evaluate_rule_truth_values``.
        rules (list): Rule definition dicts, consumed by ``evaluate_rule_truth_values``
            and ``evaluate_node_rules``.
        node_config_dict (list): Node/link configuration dicts, used by
            ``evaluate_node_rules`` and ``resolve_state_conflict``.
        branches: Branch records with ``path_id``, ``path_array`` and ``path_weight``,
            passed to ``identify_bottlenecks``.

    Returns:
        list: The node result dicts returned by ``identify_bottlenecks``.
    """
    truth_by_rule, value_by_rule, ref_value_by_rule, operator_by_rule = evaluate_rule_truth_values(
        dataset, rules
    )
    node_results = evaluate_node_rules(
        truth_by_rule,
        value_by_rule,
        ref_value_by_rule,
        operator_by_rule,
        node_config_dict,
        rules,
    )
    reconciled = resolve_state_conflict(node_results, node_config_dict)
    return identify_bottlenecks(reconciled, branches)