In [None]:
import numexpr as ne
import numpy as np
import pandas as pd
from xarray import open_dataset

from mosartwmpy import Model

from mosartwmpy.config.parameters import Parameters
from mosartwmpy.grid.grid import Grid
from mosartwmpy.state.state import State


In [None]:
grid = Grid.from_files('./speed_test_grid.zip')

# Load the saved model state. The context manager closes the netCDF file even if
# the DataFrame conversion raises (the previous explicit close() was skipped then).
with open_dataset('./speed_test_state.nc') as state_dataset:
    state = State.from_dataframe(state_dataset.to_dataframe())

parameters = Parameters()

# Time step handed to `numexpr`: 3 * 60 * 60 (3 hours, presumably in seconds)
# divided by 5 and then by 3, i.e. 720.0.
# NOTE(review): the meaning of the /5 and /3 divisors (sub-stepping?) is not
# stated here — confirm against the model configuration.
delta_t = 3 * 60 * 60 / 5 / 3


In [None]:
# full: baseline benchmark of a complete model run. This definition must be
# active (not commented out) for the `%timeit ... full(state, grid)` cell to run.
def full(state, grid):
    """Initialize a mosartwmpy ``Model`` from ``../config.yaml`` and run it to its end time.

    Args:
        state: initial model state passed to ``Model.initialize``.
        grid: model grid passed to ``Model.initialize``.

    NOTE(review): if ``initialize``/``update_until`` modify ``state`` or ``grid``
    in place, later cells see the modified objects — confirm.
    """
    mosart = Model()
    mosart.initialize('../config.yaml', state=state, grid=grid)
    mosart.update_until(mosart.get_end_time())


In [None]:
# Benchmark a full model run: 7 repeats (-r 7) of 10 loops each (-n 10).
# NOTE(review): `full` must be defined (i.e. not commented out) in the previous
# cell, otherwise this raises NameError on a fresh kernel.
%timeit -n 10 -r 7 full(state, grid)

In [None]:
# numexpr: pandas + pre-compiled numexpr implementation, benchmarked in a later cell.
def numexpr(state: State, grid: Grid, parameters: Parameters, delta_t: float) -> None:
    """Allocate part of the channel outflow at reservoir cells to grid cells with unmet demand.

    Uses pandas for the reservoir/cell bookkeeping and the module-level numexpr
    kernels (``calculate_flow_volume``, ``remove_flow``, ``divide``,
    ``subtract``, ``add``) for element-wise arithmetic. Updates ``state`` in
    place and returns nothing.

    Args:
        state: its ``channel_outflow_downstream``, ``grid_cell_unmet_demand``
            and ``grid_cell_supply`` arrays (aligned with the grid cells) are
            read and overwritten.
        grid: provides ``reservoir_id`` (non-finite, presumably NaN, where a
            cell has no reservoir), ``id``, and ``reservoir_to_grid_mapping``
            (a DataFrame with a ``reservoir_id`` column; presumably maps each
            reservoir to the cells it may supply).
        parameters: provides ``reservoir_flow_volume_ratio`` (scales the volume
            taken from the outflow at reservoir cells) and
            ``reservoir_supply_iterations`` (number of allocation passes).
        delta_t: time step used to convert between outflow rate and volume.

    NOTE(review): the joins between ``cells`` (indexed by ``grid.id``),
    ``demand`` (joined to a RangeIndex frame of unmet demand) and
    ``pd.DataFrame(grid.id)`` (RangeIndex) only line up if ``grid.id`` equals
    the positional cell index — confirm.
    """
    
    # True for cells whose reservoir_id is finite, i.e. cells holding a reservoir.
    has_reservoir = np.isfinite(grid.reservoir_id)
    
    # -(ratio * delta_t * outflow) at reservoir cells, 0 elsewhere.
    # NOTE(review): this is positive only if downstream outflow is negative — confirm sign convention.
    flow_volume = calculate_flow_volume(has_reservoir, parameters.reservoir_flow_volume_ratio, delta_t, state.channel_outflow_downstream)
    
    # Take that volume out of the channel: outflow += flow_volume / delta_t at reservoir cells.
    state.channel_outflow_downstream = remove_flow(has_reservoir, state.channel_outflow_downstream, flow_volume, delta_t)
    
    # One row per grid id with unmet demand > 0; accumulates the supply delivered to it.
    cells = pd.DataFrame({'id': grid.id[state.grid_cell_unmet_demand > 0]}).set_index('id')
    cells['supply'] = 0
    
    # Attach each cell's unmet demand to the mapping rows (index = cell, column reservoir_id),
    # then drop rows whose cell has no demand.
    demand = grid.reservoir_to_grid_mapping.join(pd.DataFrame(state.grid_cell_unmet_demand, columns=['grid_cell_demand']))
    demand = demand[demand.grid_cell_demand.gt(0)]
    
    # Per reservoir (index reservoir_id): total demand of the cells it serves plus its flow_volume.
    # Only reservoirs that appear in `demand` are included.
    reservoir_demand_flow = demand.groupby('reservoir_id')[['grid_cell_demand']].sum().rename(columns={'grid_cell_demand': 'reservoir_demand'}).join(pd.DataFrame({'flow_volume': flow_volume, 'reservoir_id': grid.reservoir_id}).dropna().set_index('reservoir_id'))
    
    # Each pass handles exactly one of the three cases below; `_` is the pass number.
    for _ in np.arange(parameters.reservoir_supply_iterations):
        
        if _ == 0:
            # Same object, not a copy: columns assigned to `case` below are also added to reservoir_demand_flow.
            case = reservoir_demand_flow
        else:
            # subset reservoir list to speed up calculation: keep only reservoirs still serving a cell
            # with demand, and recompute their demand from what remains in `demand`.
            # NOTE(review): `case` is a filtered subset here, so these .loc writes may emit SettingWithCopyWarning.
            case = reservoir_demand_flow[np.isin(reservoir_demand_flow.index.astype(int).values, demand.reservoir_id.unique())]
            case.loc[:, 'reservoir_demand'] = case.join(demand.groupby('reservoir_id')[['grid_cell_demand']].sum()).grid_cell_demand.fillna(0)
        
        # ratio of available flow to remaining total demand (may be inf/nan where reservoir_demand is 0)
        case.loc[:, 'demand_fraction'] = divide(case.flow_volume.values, case.reservoir_demand.values)
        
        # case 1: some reservoir has more flow than demand (fraction > 1). Every cell served by such a
        # reservoir is fully supplied, its demand split equally among the such reservoirs serving it.
        # Reservoirs with fraction <= 1 are left for a later pass.
        if case.demand_fraction.gt(1).any():
            case = demand[np.isin(demand.reservoir_id.values, case[case.demand_fraction.gt(1)].index.astype(int).values)]
            case.loc[:, 'condition_count'] = case.groupby(case.index)['reservoir_id'].transform('count')
            case.loc[:, 'supply'] = divide(case.grid_cell_demand, case.condition_count)
            taken_from_reservoir = reservoir_demand_flow.join(case.groupby('reservoir_id').supply.sum()).supply.fillna(0).values
            reservoir_demand_flow.loc[:, 'reservoir_demand'] -= taken_from_reservoir
            reservoir_demand_flow.loc[:, 'flow_volume'] -= taken_from_reservoir
            # all demand was supplied to these cells
            cells.loc[:, 'supply'] += cells.join(case.groupby(case.index)[['grid_cell_demand']].first()).grid_cell_demand.fillna(0)
            demand = demand[~demand.index.isin(case.index.unique())]
        
        else:
            # sum demand fraction: attach each reservoir's demand_fraction to its demand rows and
            # total the fractions per cell (0 for cells whose reservoirs are not in `case`).
            case = demand.merge(case, how='left', left_on='reservoir_id', right_index=True)
            case.loc[:, 'demand_fraction_sum'] = case.groupby(case.index).demand_fraction.transform('sum').fillna(0).values
            
            # case 2: cells whose fraction sum >= 1 are fully supplied, each reservoir contributing
            # demand * fraction / fraction_sum.
            if case.demand_fraction_sum.ge(1).any():
                case = case[case.demand_fraction_sum.ge(1)]
                case.loc[:, 'supply'] = case.grid_cell_demand.values  * case.demand_fraction.values / case.demand_fraction_sum.values
                taken_from_reservoir = reservoir_demand_flow.join(case.groupby('reservoir_id')['supply'].sum()).supply.fillna(0).values
                reservoir_demand_flow.loc[:, 'reservoir_demand'] = subtract(reservoir_demand_flow.reservoir_demand.values, taken_from_reservoir)
                reservoir_demand_flow.loc[:, 'flow_volume'] = subtract(reservoir_demand_flow.flow_volume.values, taken_from_reservoir)
                # all demand was supplied to these cells
                cells.loc[:, 'supply'] += cells.join(case.groupby(case.index)[['grid_cell_demand']].first()).grid_cell_demand.fillna(0)
                demand = demand[~demand.index.isin(case.index.unique())]
                
            else:
                # case 3: no cell can be fully supplied. Each reservoir gives demand * fraction to each
                # cell it serves, and the delivered amount is subtracted from the remaining demand.
                case = case[case.demand_fraction_sum.gt(0)]
                case.loc[:, 'supply'] = case.grid_cell_demand.values * case.demand_fraction.values
                taken_from_reservoir = reservoir_demand_flow.join(case.groupby('reservoir_id')['supply'].sum()).supply.fillna(0).values
                reservoir_demand_flow.loc[:, 'reservoir_demand'] -= taken_from_reservoir
                reservoir_demand_flow.loc[:, 'flow_volume'] -= taken_from_reservoir
                # not all demand was supplied to these cells
                supplied = cells[[]].join(case.groupby(case.index)[['supply']].sum()).supply.fillna(0)
                cells.loc[:, 'supply'] += supplied
                demand.loc[:, 'grid_cell_demand'] -= demand[[]].join(supplied).fillna(0).supply.values
    
    # Map the accumulated supply back onto grid order (0 where nothing was supplied)
    # and move it from unmet demand to supply.
    supplied = pd.DataFrame(grid.id).join(cells).supply.fillna(0).values
    state.grid_cell_supply = add(state.grid_cell_supply, supplied)
    state.grid_cell_unmet_demand = subtract(state.grid_cell_unmet_demand, supplied)
    
    # Return the unallocated flow_volume to the channel: outflow -= residual / delta_t at reservoir cells.
    # NOTE(review): reservoirs absent from reservoir_demand_flow (no cell demand) get 0 here via fillna,
    # so the volume removed from them by remove_flow above is never returned — confirm this is intended.
    state.channel_outflow_downstream[:] -= pd.DataFrame(grid.reservoir_id, columns=['reservoir_id']).merge(reservoir_demand_flow.flow_volume, how='left', left_on='reservoir_id', right_index=True).flow_volume.fillna(0).values / delta_t


# Pre-compiled numexpr kernels used by the `numexpr` function. Each kernel is
# compiled once with an explicit (name, dtype) signature. Callers pass arguments
# positionally in the order of that signature.

_FLOW_VOLUME_EXPR = (
    'where(has_reservoir,'
    '-(reservoir_flow_volume_ratio * delta_t * channel_outflow_downstream),'
    '0)'
)

_REMOVE_FLOW_EXPR = (
    'where(has_reservoir,'
    'channel_outflow_downstream + flow_volume / delta_t,'
    'channel_outflow_downstream)'
)

# Signature shared by the element-wise binary float64 kernels below.
_FLOAT64_AB = (('a', np.float64), ('b', np.float64))

# -(ratio * delta_t * outflow) where has_reservoir is True, 0 elsewhere.
calculate_flow_volume = ne.NumExpr(
    _FLOW_VOLUME_EXPR,
    (
        ('has_reservoir', bool),
        ('reservoir_flow_volume_ratio', np.float64),
        ('delta_t', np.float64),
        ('channel_outflow_downstream', np.float64),
    ),
)

# outflow + flow_volume / delta_t where has_reservoir is True, outflow unchanged elsewhere.
remove_flow = ne.NumExpr(
    _REMOVE_FLOW_EXPR,
    (
        ('has_reservoir', bool),
        ('channel_outflow_downstream', np.float64),
        ('flow_volume', np.float64),
        ('delta_t', np.float64),
    ),
)

divide = ne.NumExpr('a / b', _FLOAT64_AB)    # element-wise a / b
subtract = ne.NumExpr('a - b', _FLOAT64_AB)  # element-wise a - b
add = ne.NumExpr('a + b', _FLOAT64_AB)       # element-wise a + b


In [None]:
# Benchmark the pandas/numexpr implementation: 7 repeats (-r 7) of 10 loops each (-n 10).
# NOTE(review): `numexpr` overwrites state.channel_outflow_downstream, state.grid_cell_supply
# and state.grid_cell_unmet_demand, so every timed loop starts from the state left by the
# previous loop. Timings may drift as unmet demand is consumed.
%timeit -n 10 -r 7 numexpr(state, grid, parameters, delta_t)
