In [9]:
import os 
import sys 
from typing import List 

# Put the directory two levels above the working directory on the import path
# (presumably so the `utils_notebook` imports further down resolve - confirm).
# The guard keeps re-running this cell from appending the same entry twice.
cur_path = os.path.abspath("../..")
if sys.path.count(cur_path) == 0:
    sys.path.append(cur_path)

from functools import cache 
import numpy as np 
import pandas as pd 
import altair as alt 
from IPython.display import clear_output
from altair import datum
from dotenv import load_dotenv
from subgrounds.subgrounds import Subgrounds, Subgraph
from subgrounds.pagination import ShallowStrategy

# Required when developing in a jupyter-notebook environment 
load_dotenv('../../../.env')

from utils_notebook.utils import ddf, load_subgraph, remove_prefix
from utils_notebook.vega import (
    output_chart, 
    apply_css, 
    stack_order_expr, 
    wide_to_longwide, 
    chart_stack_area_overlay_line_timeseries,
)
from utils_notebook.queries import adjust_precision, QueryManager
from utils_notebook.testing import validate_season_series
from utils_notebook.css import css_tooltip_timeseries_multi_colored
from utils_notebook.vega import condition_union, XAXIS_DEFAULTS

In [2]:
# Declare the expected types first (bare annotations, no assignment), then
# unpack the pair returned by the project helper `load_subgraph()`.
sg: Subgrounds
bs: Subgraph
sg, bs = load_subgraph()

In [3]:
# Query helper constructed from the Subgrounds client and subgraph loaded above.
q = QueryManager(sg, bs) 

In [4]:
# Fetch the silo daily snapshots; the cells below read its `dailyBeanMints` column.
df = q.query_silo_daily_snapshots()

In [5]:
# Running total of minted beans: cumulative sum of `dailyBeanMints` in row order.
# NOTE(review): cumsum follows the frame's row order - presumably rows arrive
# sorted by ascending season; confirm against the query helper.
df['totalBeanMints'] = df.dailyBeanMints.cumsum()

In [6]:
df.tail()

Unnamed: 0,season,dailyBeanMints,totalBeanMints
2814,7929,3033.972532,76265740.0
2815,7953,0.0,76265740.0
2816,7976,4913.763723,76270660.0
2817,8001,5066.463488,76275720.0
2818,8004,1061.324981,76276790.0


In [7]:
def possibly_override(data = None, defaults = None, override = False):
    """Combine caller-supplied keyword arguments with a set of defaults.

    Args:
        data: Mapping supplied by the caller. Any falsy value (``None``, ``{}``)
            is treated as an empty dict.
        defaults: Mapping of fallback values. Any falsy value is treated as an
            empty dict.
        override: When ``False`` (the default) the result is ``defaults``
            updated with ``data``, so ``data`` wins on shared keys. When
            ``True`` the defaults are ignored entirely.

    Returns:
        A new merged dict when ``override`` is ``False``; otherwise ``data``
        itself (or an empty dict if ``data`` was falsy).
    """
    supplied = data if data else {}
    if override:
        # Caller asked for their values only: defaults are dropped.
        return supplied
    fallback = defaults if defaults else {}
    merged = dict(fallback)
    merged.update(supplied)
    return merged


def chart(
    df: pd.DataFrame,
    timestamp_col: str,
    lmetrics: List[str],
    rmetrics: List[str] = None,
    title: str = '',
    xaxis_kwargs = None,
    xaxis_kwargs_override: bool = False,
    yaxis_left_kwargs: dict = None,
    yaxis_left_kwargs_override: bool = False,
    yaxis_right_kwargs: dict = None,
    yaxis_right_kwargs_override: bool = False,
    color_map = None,
    tooltip_formats = None,
    separate_y_axes: bool = False,
    show_exploit_rule: bool = True,
    exploit_day: int = 17, # must be either 16 or 17
    width: int = 700,
):
    """Creates a chart with a shared time axis and up to two y axes.

    Metrics in ``lmetrics`` are drawn as area marks against the left y axis and
    metrics in ``rmetrics`` as line marks against the right y axis. Both layers
    share one ordinal x encoding on ``timestamp_col`` and one color scale.

    Assumes that data is in long-wide format (i.e. df was processed with
    function wide_to_longwide). The encodings read a ``variable`` field (metric
    name), a ``value`` field (plotted quantity) and, for the tooltips, one
    field named after each metric.

    Args:
        df: Data handed to ``alt.Chart``.
        timestamp_col: Name of the column used for the x axis and date tooltip.
        lmetrics: Metric names plotted as areas on the left axis.
        rmetrics: Metric names plotted as lines on the right axis. ``None`` is
            treated as an empty list.
        title: Chart title.
        xaxis_kwargs: Keyword arguments for the x ``alt.Axis``. Merged over
            ``XAXIS_DEFAULTS`` unless ``xaxis_kwargs_override`` is true, in
            which case only these values are used.
        xaxis_kwargs_override: Ignore ``XAXIS_DEFAULTS`` when true.
        yaxis_left_kwargs: Keyword arguments for the left ``alt.Axis``.
        yaxis_left_kwargs_override: Accepted for symmetry; there are no
            defaults for the y axes, so it currently has no effect.
        yaxis_right_kwargs: Keyword arguments for the right ``alt.Axis``.
        yaxis_right_kwargs_override: See ``yaxis_left_kwargs_override``.
        color_map: Optional mapping of metric name to color. When given it must
            contain every metric in ``lmetrics + rmetrics``.
        tooltip_formats: Optional mapping of metric name to a tooltip number
            format; metrics without an entry use ``",d"``.
        separate_y_axes: Resolve the y scale and y axis independently per layer.
        show_exploit_rule: Add a dashed rule mark at the exploit date.
        exploit_day: Day of the month (in April 2022) at which the rule is
            drawn; must be 16 or 17. Validated even when the rule is hidden.
        width: Chart width in pixels.

    Returns:
        The layered Altair chart.

    Raises:
        AssertionError: If a metric appears in both ``lmetrics`` and
            ``rmetrics``, or if ``exploit_day`` is not 16 or 17.
        KeyError: If ``color_map`` is given but lacks one of the metrics.
    """
    # Validate arguments before any chart objects are constructed.
    rmetrics = rmetrics or []
    assert not set(lmetrics).intersection(set(rmetrics)), "Same metric on two axes"
    assert exploit_day in [16, 17], "exploit_day must be either 16 or 17"

    metrics = lmetrics + rmetrics
    tooltip_formats = tooltip_formats or {}
    xaxis_kwargs = possibly_override(xaxis_kwargs, XAXIS_DEFAULTS, override=xaxis_kwargs_override)
    yaxis_left_kwargs = possibly_override(yaxis_left_kwargs, None, override=yaxis_left_kwargs_override)
    yaxis_right_kwargs = possibly_override(yaxis_right_kwargs, None, override=yaxis_right_kwargs_override)

    # construct axes
    xaxis = alt.Axis(**xaxis_kwargs)
    yaxis_left = alt.Axis(**yaxis_left_kwargs)
    yaxis_right = alt.Axis(**yaxis_right_kwargs)

    # Shared x encoding channel (time axis)
    x = alt.X(f"{timestamp_col}:O", axis=xaxis)

    # Optional custom color scale; the domain fixes the legend/color order.
    if color_map:
        color_scale = alt.Scale(domain=metrics, range=[color_map[m] for m in metrics])
    else:
        color_scale = alt.Scale(domain=metrics)

    # Tooltips: the date first, then one entry per metric.
    tooltips = (
        [alt.Tooltip(f'{timestamp_col}:O', timeUnit="yearmonthdate", title="date")] +
        [alt.Tooltip(f'{m}:Q', format=tooltip_formats.get(m, ",d")) for m in metrics]
    )

    base = (
        alt.Chart(df)
        .encode(x=x)
        .properties(title=title, width=width)
    )

    cbase = (
        base
        # Stack order matters when we are using an area chart
        .transform_calculate(stack_order=stack_order_expr("variable", metrics))
        .encode(
            color=alt.Color("variable:N", scale=color_scale, legend=alt.Legend(title=None)),
            order=alt.Order('stack_order:Q', sort='ascending'),
        )
    )

    # Left axis: area marks for lmetrics. The transparent points give the
    # tooltip something to attach to along the area outline.
    left = (
        cbase
        .transform_filter(condition_union("==", "|", lmetrics))
        .mark_area(point='transparent')
        .encode(y=alt.Y("value:Q", axis=yaxis_left), tooltip=tooltips)
    )

    # Right axis: line marks for rmetrics.
    # NOTE(review): when rmetrics is empty this filter is built from an empty
    # list; confirm condition_union yields a valid expression in that case.
    right = (
        cbase
        .transform_filter(condition_union("==", "|", rmetrics))
        .mark_line()
        .encode(y=alt.Y("value:Q", axis=yaxis_right))
    )

    if show_exploit_rule:
        rule_exploit = (
            # Pivot collapses the long data to one row per timestamp, then the
            # filter keeps only the exploit date. Vega's month() is zero-based,
            # so `=== 3` selects April, not March.
            base
            .transform_pivot('variable', value='value', groupby=[timestamp_col])
            .transform_filter(f"""
            year(datum['{timestamp_col}']) === 2022 && 
            month(datum['{timestamp_col}']) === 3 && 
            date(datum['{timestamp_col}']) === {exploit_day} 
        """)
            .mark_rule(opacity=1, color='#474440', strokeDash=[2.5,1])
        )
        # Rule doesn't show up unless layered with line or area base.
        c = left + alt.layer(right, rule_exploit)
    else:
        c = left + right

    if separate_y_axes:
        c = (
            c
            .resolve_scale(y="independent")
            .resolve_axis(y="independent")
        )
    return c

In [8]:
df.head()

Unnamed: 0,season,dailyBeanMints,totalBeanMints
0,3,31.65067,31.65067
1,4,17.76026,49.41093
2,5,43.709604,93.120534
3,18,0.342173,93.462707
4,21,676.195254,769.657961


In [8]:
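# NOTE: disabled scratch example. It references `df_snaps` and `value_vars`, which are
# not defined in the cells above, so it cannot be re-enabled as-is.
# colors = {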
#     'pod listing vol': '#B5E48C', # light green yellow 
#     'pod order vol': '#52B69A', # light blue green 
#     'total bean vol': '#168AAD', # darker blue
#     'total pod vol': '#184E77', # mid blue 
# }
# c = chart_stack_area_overlay_line_timeseries(
#     df_snaps, 
#     "timestamp", 
#     value_vars, 
#     ['pod listing vol', 'pod order vol', 'total bean vol',], 
#     "Farmer's Market Volume", 
#     yaxis_area_kwargs=dict(title="Volume", format=".3~s"), 
#     color_map=colors, 
# )

# css_lines = css_tooltip_timeseries_multi_colored(value_vars, colors) 
# css = "\n".join(css_lines)

# apply_css("")
# # apply_css(css)

# c

In [9]:
# output_chart(c, css=css)