# Imports

In [1]:
import os
import sqlite3

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots

# Technology classification and SQL read functions

In [2]:
def map_tech_to_mode(tech):
    """Map a technology name to its transport-mode label.

    Prefixes are tried in the listed order; the first one ``tech`` starts with
    decides the label. Names matching no prefix are labelled 'Other'.
    """
    prefix_rules = (
        ('T_HDV_AJ', 'Air jet'),
        ('T_HDV_B', 'Bus'),
        ('T_HDV_R', 'Rail'),
        ('T_HDV_T', 'HD Truck'),
        ('T_HDV_W', 'Marine vessel'),
        ('T_MDV_T', 'MD Truck'),
        ('T_LDV_C_', 'LD Car'),
        ('T_LDV_LT', 'LD Truck'),
        ('T_LDV_M', 'Motorcycle'),
        ('T_IMP_', 'Fuel use'),
        ('H2_COMP_100_700', 'Fuel use'),
    )
    return next(
        (label for prefix, label in prefix_rules if tech.startswith(prefix)),
        'Other',
    )

def map_tech_to_fuel(tech):
    """Map a technology name to a fuel / powertrain label.

    Powertrain tags are tested first, in the listed order, because several of
    them are substrings of others ('PHEV35' contains 'PHEV', which contains
    'HEV'). If none matches, the first fuel-carrier code contained anywhere in
    ``tech`` decides the label. Names matching nothing are labelled 'Other'.
    """
    # (substring, label) pairs -- order matters, first match wins.
    powertrain_rules = (
        ('PHEV35', 'PHEV (35-mile AER)'),
        ('PHEV50', 'PHEV (50-mile AER)'),
        ('PHEV', 'Plug-in hybrid'),
        ('BEV150', 'BEV (150-mile AER)'),
        ('BEV200', 'BEV (200-mile AER)'),
        ('BEV300', 'BEV (300-mile AER)'),
        ('BEV400', 'BEV (400-mile AER)'),
        ('FC', 'Fuel-cell electric'),
        ('HEV', 'Hybrid'),
        ('H2', 'Hydrogen'),
        ('BEV_CHRG', 'LD BEV charger'),
        ('CHRG', 'Other charger'),
    )
    # 'RDSL' must be tested before 'DSL': every renewable-diesel name also
    # contains 'DSL', so with 'DSL' first 'Ren. Diesel' could never be returned.
    carrier_rules = (
        ('BEV', 'Battery electric'),
        ('GSL', 'Gasoline'),
        ('RDSL', 'Ren. Diesel'),
        ('DSL', 'Diesel'),
        ('CNG', 'Compressed NG'),
        ('LNG', 'Liquified NG'),
        ('JTF', 'Jet Fuel'),
        ('SPK', 'Synth. Jet Fuel'),
        ('HFO', 'Heavy Fuel Oil'),
        ('MDO', 'Marine Diesel Oil'),
        ('ELC', 'Electricity'),
        ('ETH', 'Ethanol'),
    )
    for code, label in powertrain_rules + carrier_rules:
        if code in tech:
            return label
    return 'Other'

In [3]:
def import_db_capacity_data(db_variant='vanilla4', scenario='vanilla4', path='C:/Users/rashi/ESM_databases/temoa/data_files/'):
    """Load transport capacity from an output database and derive stock-flow columns.

    Reads ``canoe_on_12d_{db_variant}.sqlite`` in directory ``path`` and combines
    ``OutputNetCapacity`` and ``OutputBuiltCapacity`` (sector 'Transport', column
    ``scenario`` equal to ``scenario``) with ``ExistingCapacity`` (techs whose name
    starts with 'T_'). Every tech is labelled via ``map_tech_to_mode`` /
    ``map_tech_to_fuel`` and capacity is summed per (mode, fuel, vintage).

    Parameters
    ----------
    db_variant : str
        Suffix of the database file name.
    scenario : str
        Scenario to select in the output tables (bound as a query parameter).
    path : str
        Directory holding the database; a trailing separator is optional.

    Returns
    -------
    pandas.DataFrame
        Columns ``mode``, ``fuel``, ``vintage``, ``net capacity``, ``new capacity``,
        ``ex capacity`` and ``retired cap``, sorted by mode, fuel, vintage and rounded
        to 2 decimals. ``vintage`` is the model period of the net capacity. Within each
        (mode, fuel) series, every row after the first gets ``ex capacity`` = previous
        net capacity and ``retired cap`` = -(previous net + new - net); the first row
        keeps ``retired cap`` = 0 and the merged existing capacity (booked in 2021).
    """
    db_file = os.path.join(path, f'canoe_on_12d_{db_variant}.sqlite')
    conn = sqlite3.connect(db_file)
    try:
        net_cap = pd.read_sql_query(
            "SELECT * FROM OutputNetCapacity WHERE sector = 'Transport' AND scenario = ?",
            conn, params=(scenario,)).drop(columns=['region', 'sector'])
        new_cap = pd.read_sql_query(
            "SELECT * FROM OutputBuiltCapacity WHERE sector = 'Transport' AND scenario = ?",
            conn, params=(scenario,)).drop(columns=['region', 'sector'])
        # '_' is a LIKE wildcard; escape it so only names starting with a literal 'T_' match.
        ex_cap = pd.read_sql_query(
            "SELECT * FROM ExistingCapacity WHERE tech LIKE 'T!_%' ESCAPE '!'",
            conn)[['tech', 'vintage', 'capacity', 'units']].copy()
    finally:
        # Release the database handle even when a query fails.
        conn.close()

    # label mode and fuel classes
    for frame in (net_cap, new_cap, ex_cap):
        frame['mode'] = frame['tech'].apply(map_tech_to_mode)
        frame['fuel'] = frame['tech'].apply(map_tech_to_fuel)

    keys = ['mode', 'fuel', 'vintage']

    # Net capacity is reported per (period, vintage): sum over vintages so each period is
    # one row, then call the period "vintage" so it lines up with the other tables.
    net_cap_group = (net_cap.groupby(['mode', 'fuel', 'period'], as_index=False)['capacity'].sum()
                     .rename(columns={'period': 'vintage', 'capacity': 'net capacity'}))
    new_cap_group = (new_cap.groupby(keys, as_index=False)['capacity'].sum()
                     .rename(columns={'capacity': 'new capacity'}))
    # All existing capacity is booked in the 2021 base year.
    # NOTE(review): grouping also by 'units' yields duplicate (mode, fuel, vintage) rows if a
    # group mixes units, which would duplicate rows in the merge below -- confirm units are uniform.
    ex_cap_group = (ex_cap.assign(vintage=2021)
                    .groupby(keys + ['units'], as_index=False)['capacity'].sum()
                    [keys + ['capacity']]
                    .rename(columns={'capacity': 'ex capacity'}))

    # Net capacity drives the rows; missing new/existing values become 0.
    merged_cap = (net_cap_group
                  .merge(new_cap_group, on=keys, how='left')
                  .merge(ex_cap_group, on=keys, how='left')
                  .fillna(0.)
                  .sort_values(by=keys))

    # Stock-flow accounting per (mode, fuel) series, rows already ordered by vintage:
    # retired_i = net_{i-1} + new_i - net_i and ex_i = net_{i-1}. The first row of a
    # series has no predecessor and keeps retired = 0 and its merged existing capacity.
    prev_net = merged_cap.groupby(['mode', 'fuel'])['net capacity'].shift()
    has_prev = prev_net.notna()
    merged_cap['retired cap'] = (
        prev_net + merged_cap['new capacity'] - merged_cap['net capacity']
    ).where(has_prev, 0.)
    merged_cap.loc[has_prev, 'ex capacity'] = prev_net[has_prev]

    merged_cap['retired cap'] = -merged_cap['retired cap']  # negative values for retired capacity
    return merged_cap.round(2)

In [4]:
# Plot colour (CSS named colour) per fuel/powertrain label. The plotting functions
# also use the key order as the category and legend order.
color_fuel_map = dict([
    ('Gasoline', 'red'),
    ('Diesel', 'brown'),
    ('Compressed NG', 'tomato'),
    ('Hybrid', 'darkorange'),
    ('PHEV (35-mile AER)', 'dodgerblue'),
    ('PHEV (50-mile AER)', 'darkblue'),
    ('BEV (150-mile AER)', 'limegreen'),
    ('BEV (200-mile AER)', 'seagreen'),
    ('BEV (300-mile AER)', 'olive'),
    ('BEV (400-mile AER)', 'darkgreen'),
    ('Plug-in hybrid', 'blue'),
    ('Battery electric', 'green'),
    ('Fuel-cell electric', 'mediumvioletred'),
])

In [5]:
def color_fuel_map_rgba(alpha=0.4):
    """Return the fuel palette as CSS ``rgba(...)`` strings with opacity ``alpha``.

    Same labels, in the same order, as ``color_fuel_map``; each RGB triplet is the
    named colour noted next to its entry.
    """
    templates = (
        ('Gasoline',           'rgba(255,   0,   0,   {alpha})'),   # red
        ('Diesel',             'rgba(165,  42,  42,  {alpha})'),    # brown
        ('Compressed NG',      'rgba(255,  99,  71,  {alpha})'),    # tomato
        ('Hybrid',             'rgba(255, 140,   0,  {alpha})'),    # darkorange
        ('PHEV (35-mile AER)', 'rgba(30, 144, 255,  {alpha})'),     # dodgerblue
        ('PHEV (50-mile AER)', 'rgba(0,    0,   139, {alpha})'),    # darkblue
        ('BEV (150-mile AER)', 'rgba(50,  205,  50,  {alpha})'),    # limegreen
        ('BEV (200-mile AER)', 'rgba(46,  139,  87,  {alpha})'),    # seagreen
        ('BEV (300-mile AER)', 'rgba(128, 128,   0,  {alpha})'),    # olive
        ('BEV (400-mile AER)', 'rgba(0,   100,   0,  {alpha})'),    # darkgreen
        ('Plug-in hybrid',     'rgba(0,    0,   255,  {alpha})'),   # blue
        ('Battery electric',   'rgba(0,   128,   0,   {alpha})'),   # green
        ('Fuel-cell electric', 'rgba(199,  21,  133,  {alpha})'),   # mediumvioletred
    )
    return {label: template.format(alpha=alpha) for label, template in templates}

In [6]:
def plot_capacity_flow(db_variant='vanilla4', scenario='vanilla4', scenario_title=None,
                       modes=['LD Car', 'LD Truck', 'MD Truck', 'HD Truck'], color_map=color_fuel_map):
    """Plot faceted bars of net, new and retired fleet capacity for one scenario.

    Facet rows are vehicle modes (ordered as ``modes``), facet columns are model
    periods. Within a facet, one bar per capacity type is stacked and coloured by
    fuel/powertrain using ``color_map``, whose key order also sets the legend order.
    Existing capacity is not drawn.

    Parameters
    ----------
    db_variant, scenario : str
        Passed to ``import_db_capacity_data`` to select the database and scenario.
    scenario_title : str, optional
        Shown in brackets in the figure title; defaults to ``scenario`` when None.
    modes : list of str
        Modes to plot.
    color_map : dict
        Fuel label -> plotly colour.
    """
    if scenario_title is None:
        # Otherwise the title f-string below would render a literal "(None)".
        scenario_title = scenario

    df = import_db_capacity_data(db_variant=db_variant, scenario=scenario)
    df = df.melt(id_vars=['mode', 'fuel', 'vintage'],
                 value_vars=['net capacity', 'ex capacity', 'new capacity', 'retired cap'],
                 var_name='cap type', value_name='capacity')
    # Ordered categorical over the palette keys; fuels missing from color_map become NaN.
    df['fuel'] = pd.Categorical(df['fuel'], categories=list(color_map.keys()), ordered=True)
    df['vintage'] = df['vintage'].astype('str')   # periods become facet labels

    # Keep non-negligible bars of the requested modes; existing capacity is not drawn.
    df_filtered = df[(abs(df['capacity']) > 1e-3) & (df['mode'].isin(modes)) & (df['cap type'] != 'ex capacity')].reset_index(drop=True)

    fig = px.bar(df_filtered, x='cap type', y='capacity', color='fuel',
                 pattern_shape='cap type', pattern_shape_sequence=["", ".", "/"],
                 facet_col='vintage', facet_col_spacing=2E-2,
                 facet_row='mode', facet_row_spacing=1.5E-2,
                 category_orders={'cap type': ["net capacity", "new capacity", "retired cap"],
                                  'vintage': sorted(df['vintage'].unique()),
                                  'fuel': list(color_map.keys()),
                                  'mode': modes},
                 template='plotly_white', orientation='v',
                 color_discrete_map=color_map,
                 text_auto='.2s', width=1200, height=900
                 )

    fig.update_layout(
        margin=dict(
            t=65, b=30),
        title=dict(
            text=f'<b>Vehicle fleet stock and flow in ON by vehicle class and powertrain ({scenario_title})</b>',
            x=0.5, y=0.98, xanchor='center', yanchor='top'),
        yaxis_title_standoff=0,
        legend_title=dict(
            text='<b>Fuel/powertrain type</b>', font=dict(size=15)),
        bargap=0.1,
        legend=dict(
            traceorder='grouped', orientation='v', yanchor='top', y=0.8, xanchor='center', x=1.15),
        font=dict(
            size=15)
        )

    fig.for_each_trace(lambda trace: trace.update(textfont=dict(size=11)))
    fig.for_each_xaxis(lambda axis: axis.update(title_text='', showticklabels=False))
    fig.for_each_yaxis(lambda axis: axis.update(title_text=''))

    # One entry per facet row: the distinct lower bounds of the y-axis domains.
    row_domains = sorted(set(fig.layout[axis].domain[0] for axis in fig.layout if axis.startswith('yaxis')))
    n_vintages = df['vintage'].nunique()
    for row_i, _ in enumerate(row_domains, start=1):
        # Share the y range within a row by matching every axis of the row to its first axis.
        # NOTE(review): assumes axes are numbered row by row with n_vintages axes per row.
        anchor_i = (row_i - 1) * n_vintages + 1
        match_anchor = 'y' if row_i == 1 else f'y{anchor_i}'
        fig.update_yaxes(matches=match_anchor, row=row_i, col=None, nticks=8, zeroline=True, zerolinecolor='black', tickformat='~s')

    # Facet labels arrive as "mode=LD Car" / "vintage=2025": keep the value in bold and
    # move row labels to the right edge and column labels below the plots.
    for annotation in fig.layout.annotations:
        is_row_label = 'mode' in annotation.text
        annotation.text = f"<b>{annotation.text.split('=')[1]}</b>"
        annotation.font.size = 16
        if is_row_label:
            annotation.x = 1.01
            annotation.xanchor = 'center'
        else:
            annotation.y = 0
            annotation.yanchor = 'top'

    # Trace names look like "<fuel>, <cap type>"; show one legend entry per fuel.
    shown_legends = set()
    for trace in fig.data:
        trace.name = trace.name.split(",")[0]
        if trace.name not in shown_legends:
            trace.showlegend = True
            shown_legends.add(trace.name)
        else:
            trace.showlegend = False

    # enforce fuel order in legend
    _order = list(color_map.keys())
    fig.data = tuple(
        sorted(fig.data, key=lambda t: (_order.index(t.name) if t.name in _order else len(_order)))
    )

    # add capacity type legends separated by a blank entry
    fig.add_trace(go.Bar(
        x=[None], y=[None],
        showlegend=True,
        legendgroup='Blank',
        legendgrouptitle=None,
        name='',
        marker_color='rgba(0,0,0,0)'
    ))

    for i, (name, pat) in enumerate([
        ("Net capacity",   ""),
        ("New capacity",   "."),
        ("Retired capacity", "/"),
    ]):
        fig.add_trace(go.Bar(
            x=[None], y=[None],
            name=name,
            showlegend=True,
            legendgroup='Capacity type',
            legendgrouptitle=dict(text="<b>Capacity type</b>") if i == 0 else None,
            marker=dict(
                color='rgba(0,0,0,0)',
                line=dict(color='black', width=1),
                pattern_shape=pat
            )
        ))

    fig.add_annotation(
        text='<b>Fleet capacity (k vehicles)</b>',
        x=-0.08, y=0.5, xref="paper", yref="paper", showarrow=False, textangle=-90, font=dict(size=17))

    fig.show()

In [7]:
# Stock and flow of the vanilla model run.
plot_capacity_flow(
    scenario_title='vanilla model',
    db_variant='vanilla4',
    scenario='vanilla4',
)

  grouped = df.groupby(required_grouper, sort=False)  # skip one_group groupers


In [8]:
# Stock and flow of the low-growth scenario.
plot_capacity_flow(
    scenario_title='Low growth scenario',
    db_variant='lowgrowth',
    scenario='lowgrowth',
)





In [9]:
# Stock and flow of the medium-growth scenario.
plot_capacity_flow(
    scenario_title='medgrowth model',
    db_variant='medgrowth',
    scenario='medgrowth',
)





In [10]:
# Stock and flow of the EV-growth scenario.
# Pass e.g. modes=['MD Truck', 'HD Truck'] to show only the truck classes.
plot_capacity_flow(
    scenario_title='Norway EV growth rates',
    db_variant='evgrowth',
    scenario='evgrowth',
)





In [11]:
# Stock and flow of the high-growth scenario.
# Pass e.g. modes=['MD Truck', 'HD Truck'] to show only the truck classes.
plot_capacity_flow(
    scenario_title='High growth model',
    db_variant='highgrowth',
    scenario='highgrowth',
)





In [12]:
def load_capacity_flow_scenarios(
        scenarios={
            # db variant : scenario
            'vanilla4': 'vanilla4',
            'baseline': 'baseline',
            'lowgrowth': 'lowgrowth',
            'medgrowth': 'medgrowth',     # used as 2025 reference
            'evgrowth': 'evgrowth',
            'highgrowth': 'highgrowth',
        },
        modes=['LD Car', 'LD Truck', 'MD Truck', 'HD Truck']) -> pd.DataFrame:
    """Stack the capacity stock-flow tables of several scenarios into one long frame.

    For every ``db_variant -> scenario`` pair, loads ``import_db_capacity_data``,
    keeps the rows of ``modes`` outside the 2021 vintage, drops ``ex capacity``,
    multiplies the three capacity columns by 1E3 and tags the rows with
    ``scenario``. The result has a fresh 0..n-1 index.
    """
    flow_columns = ['net capacity', 'new capacity', 'retired cap']
    per_scenario = []
    for db_variant, scenario in scenarios.items():
        cap = import_db_capacity_data(db_variant, scenario)
        wanted = cap['mode'].isin(modes) & (cap['vintage'] != 2021)
        cap = cap.loc[wanted].drop(columns=['ex capacity']).copy()
        # Source capacities appear to be in thousands of vehicles; scale to vehicles.
        cap[flow_columns] = cap[flow_columns] * 1E3
        cap['scenario'] = scenario
        per_scenario.append(cap)
    return pd.concat(per_scenario, ignore_index=True)

# All default scenarios stacked into one long-format frame (capacities in vehicles).
# To export: fleet_data.to_csv('fleet_data.csv', index=False)
fleet_data = load_capacity_flow_scenarios()
fleet_data

Unnamed: 0,mode,fuel,vintage,net capacity,new capacity,retired cap,scenario
0,HD Truck,Battery electric,2040,0.0,0.0,-0.0,vanilla4
1,HD Truck,Battery electric,2045,0.0,0.0,-0.0,vanilla4
2,HD Truck,Battery electric,2050,7140.0,7140.0,-0.0,vanilla4
3,HD Truck,Diesel,2025,124820.0,0.0,-9730.0,vanilla4
4,HD Truck,Diesel,2030,98060.0,0.0,-26760.0,vanilla4
...,...,...,...,...,...,...,...
1107,MD Truck,Plug-in hybrid,2030,11540.0,11120.0,-0.0,highgrowth
1108,MD Truck,Plug-in hybrid,2035,11870.0,330.0,-0.0,highgrowth
1109,MD Truck,Plug-in hybrid,2040,12200.0,330.0,-0.0,highgrowth
1110,MD Truck,Plug-in hybrid,2045,11990.0,210.0,-420.0,highgrowth


In [13]:
def plot_capacity_flow_multicat(
        future_periods=[2030, 2040, 2050],
        capacity_type='net capacity',
        baseline_scenario='baseline',
        scenarios={
            # db variant : scenario
            'baseline' : 'baseline',
            "lowgrowth" : "lowgrowth",
            "medgrowth"   : "medgrowth",
            "evgrowth" : "evgrowth",
            "highgrowth": "highgrowth"
        },
        scenario_labels={
            'baseline'          : 'Baseline',
            'lowgrowth'       : 'Low growth',
            'medgrowth'        : 'Med. growth',
            "evgrowth"        : "Norway EVs",
            'highgrowth'      : 'High growth',
        },
        modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
        color_map=color_fuel_map
    ):
    """Plot stacked fleet bars on a (period, scenario) multicategory axis, one row per mode.

    The 2025 group shows only ``baseline_scenario``; every period in ``future_periods``
    shows all scenarios in ``scenarios``. Bars are stacked by fuel in ``color_map`` key
    order and show the ``capacity_type`` column of ``load_capacity_flow_scenarios``.

    Parameters
    ----------
    future_periods : list of int
        Periods shown after 2025.
    capacity_type : str
        Column to plot, e.g. 'net capacity'.
    baseline_scenario : str
        Scenario shown for 2025.
    scenarios : dict
        ``db_variant -> scenario`` pairs to load.
    scenario_labels : dict
        ``scenario -> axis label``; scenarios without a label use their own name.
    modes_order : sequence of str
        Modes to plot, one subplot row each.
    color_map : dict
        Fuel label -> plotly colour.
    """
    n_rows = len(modes_order)

    def _label(scen):
        # Axis label of a scenario, falling back to the scenario name itself.
        return scenario_labels.get(scen, scen)

    # --------------------------------------------------------------------- #
    #  Keep 2025 (baseline scenario only) plus future_periods (all scenarios)
    # --------------------------------------------------------------------- #
    df = load_capacity_flow_scenarios(scenarios=scenarios, modes=modes_order)
    df = df[df['vintage'].isin([2025] + future_periods)].copy()

    df_2025 = df[(df['vintage'] == 2025) &
                 (df['scenario'] == baseline_scenario)].copy()

    df_future = df[df['vintage'] != 2025]
    df = pd.concat([df_2025, df_future], ignore_index=True)

    # --------------------------------------------------------------------- #
    #  Build explicit multicategory axis order: (period, scenario label)
    # --------------------------------------------------------------------- #
    period_layer = ['2025']
    scen_layer   = [_label(baseline_scenario)]
    for period in future_periods:
        for scen in tuple(scenarios.values()):
            period_layer.append(str(period))
            scen_layer.append(_label(scen))

    # make subsequent Baseline vintages different from 2025 so that category order is followed
    base_lbl = _label(baseline_scenario)
    for i, (p, s) in enumerate(zip(period_layer, scen_layer)):
        if s == base_lbl and p != '2025':
            scen_layer[i] = base_lbl + '\u200B'

    # insert a blank category between periods to separate the groups; each separator
    # gets one more zero-width space so that all separators stay distinct categories
    sep_count = 0
    new_p, new_s = [], []
    for i in range(len(period_layer)):
        new_p.append(period_layer[i])
        new_s.append(scen_layer[i])
        if i < len(period_layer)-1 and period_layer[i] != period_layer[i+1]:
            sep_count += 1
            new_p.append('\u200B' * sep_count)
            new_s.append('')

    period_layer, scen_layer = new_p, new_s
    axis_keys = list(zip(period_layer, scen_layer))

    # --------------------------------------------------------------------- #
    #  Set up multi-row figure
    # --------------------------------------------------------------------- #
    fig = make_subplots(
        rows=n_rows, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[1/n_rows]*n_rows
    )

    # --------------------------------------------------------------------- #
    #  Populate each subplot
    # --------------------------------------------------------------------- #
    for r, mode in enumerate(modes_order, 1):

        mode_df = df[df['mode'] == mode].copy()
        mode_df['period_str'] = mode_df['vintage'].astype(str)
        mode_df['scen_lbl'] = mode_df['scenario'].map(_label)

        # same zero-width-space suffix as on the axis so the rows line up with axis_keys
        mask = (mode_df['scen_lbl'] == base_lbl) & (mode_df['period_str'] != '2025')
        mode_df.loc[mask, 'scen_lbl'] += '\u200B'

        # one row per axis category (separators / missing combos -> 0), one column per fuel
        pivot = (mode_df.pivot_table(index=['period_str', 'scen_lbl'],
                                     columns='fuel',
                                     values=capacity_type,
                                     aggfunc='sum')
                 .reindex(index=pd.MultiIndex.from_tuples(axis_keys),
                          columns=list(color_map.keys()),
                          fill_value=0))

        for fuel in color_map:            # palette order
            fig.add_bar(
                x=[period_layer, scen_layer],
                y=pivot[fuel],
                name=fuel,
                marker_color=color_map[fuel],
                legendgroup=fuel,
                showlegend=(r == 1),
                text=pivot[fuel],
                texttemplate='%{text:.3s}',
                textposition='inside',
                insidetextanchor='end',
                row=r, col=1
            )

    # --------------------------------------------------------------------- #
    #  Shared labels and aesthetic tweaks
    # --------------------------------------------------------------------- #
    # NOTE(review): row numbers and dtick below are tuned for the default four modes.
    fig.update_yaxes(title_text='Fleet size (vehicles units)',
                     row=3, col=1)

    fig.update_yaxes(nticks=5, tickformat='~s')
    fig.update_yaxes(tickformat='~s',
                     dtick=2500 * 1E3,
                     row=2, col=1)

    fig.update_xaxes(
        showdividers=True,
        dividercolor='lightgrey',
        dividerwidth=1,
        tickangle=45,
    )

    fig.update_layout(
        width=375*len(future_periods), height=210*n_rows,
        margin=dict(
            t=40, b=30),
        title=dict(
            text='<b>Vehicle fleet stock and flow in ON by vehicle class, powertrain and scenario</b>',
            x=0.5, y=0.985, xanchor='center', yanchor='top'),
        yaxis_title_standoff=1,
        legend_title=dict(
            text='<b>Fuel/powertrain type</b>', font=dict(size=16)),
        barmode='stack', bargap=0.15,
        legend=dict(
            orientation='v', yanchor='top', y=0.8, xanchor='center', x=1.18, traceorder='normal'),
        font=dict(size=15),
        template='plotly_white'
        )

    fig.for_each_trace(lambda trace: trace.update(textfont=dict(size=11)))

    # Right-hand row titles, centred on each row's share of the paper height
    # (was hard-coded to 4 rows).
    for r, mode in enumerate(modes_order, start=1):
        fig.add_annotation(
            text=f'<b>{mode}</b>',
            x=1.005, xref='paper',
            y=1 - (r - 0.5) / n_rows, yref='paper',
            showarrow=False,
            textangle=90,
            xanchor='left', yanchor='middle',
            font=dict(size=16)
        )

    fig.show()

In [14]:
# Net fleet per scenario: 2025 from the baseline run, then the default future periods.
plot_capacity_flow_multicat(
    baseline_scenario='baseline',
    capacity_type='net capacity',
)

In [15]:
def plot_capacity_flow_multicat_area(
        title_text='Vehicle fleet stock in Ontario by vehicle class, powertrain and scenario',
        capacity_type='net capacity',
        scenarios={
            # db variant : scenario
            'vanilla4' : 'vanilla4',
            "baseline"   : "baseline",
            "lowgrowth" : "lowgrowth",
            'medgrowth'   : "medgrowth",
            "evgrowth" : "evgrowth",
            "highgrowth": "highgrowth"
        },
        scenario_labels={
            'vanilla4' : 'Vanilla',
            'baseline'  : 'Baseline',
            'lowgrowth'  : 'Low growth',
            'medgrowth'   : 'Med. growth',
            "evgrowth" : "Norway EVs",
            'highgrowth' : 'High growth',
        },
        modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
        color_map=color_fuel_map
    ):
    """Plot stacked-area fleet stock over periods in a (mode x scenario) grid.

    One subplot per (mode, scenario): rows follow ``modes_order``, columns follow
    ``scenarios``. Each subplot stacks one area per fuel in ``color_map`` key order,
    showing the ``capacity_type`` column of ``load_capacity_flow_scenarios``.

    Parameters
    ----------
    title_text : str
        Figure title (rendered in bold).
    capacity_type : str
        Column to plot, e.g. 'net capacity'.
    scenarios : dict
        ``db_variant -> scenario`` pairs to load; one column each.
    scenario_labels : dict
        ``scenario -> column title``; unlabelled scenarios use their own name.
    modes_order : sequence of str
        Modes to plot; one row each.
    color_map : dict
        Fuel label -> plotly colour (used as the area fill).
    """
    df = load_capacity_flow_scenarios(scenarios=scenarios, modes=modes_order)

    scenario_list = tuple(scenarios.values())
    n_rows = len(modes_order)
    n_cols = len(scenario_list)

    # Bold scenario titles on the top row only (subplot_titles is read in row-major order).
    subplot_titles = ([f'<b>{scenario_labels.get(scen, scen)}</b>' for scen in scenario_list]
                      + [''] * (n_cols * (n_rows - 1)))

    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        shared_xaxes=True,
        shared_yaxes=True,
        horizontal_spacing=0.02,
        vertical_spacing=0.03,
        subplot_titles=subplot_titles
    )

    # One stacked-area trace per fuel in every (mode, scenario) cell.
    for r_idx, mode in enumerate(modes_order, start=1):
        for c_idx, scen in enumerate(scenario_list, start=1):
            df_cell = df[(df['mode'] == mode) & (df['scenario'] == scen)].copy()
            if df_cell.empty:
                continue

            # index = vintage, columns = fuel, values = capacity_type
            pivot = (
                df_cell
                .groupby(['vintage', 'fuel'])[capacity_type]
                .sum()
                .unstack(fill_value=0)
            )
            # force all fuels in palette order (zeros where absent)
            pivot = pivot.reindex(columns=list(color_map.keys()), fill_value=0)
            pivot = pivot.sort_index()
            vintages = pivot.index.astype(str).tolist()

            # stackgroup="one" stacks the traces of this subplot on top of each other.
            for i, fuel in enumerate(color_map.keys()):
                yvals = pivot[fuel].values.tolist()
                fill_mode = 'tozeroy' if i == 0 else 'tonexty'

                fig.add_trace(
                    go.Scatter(
                        x=vintages,
                        y=yvals,
                        mode='none',
                        name=fuel,
                        stackgroup="one",
                        fill=fill_mode,
                        fillcolor=color_map[fuel],
                        line=dict(color=color_map[fuel], width=0),  # no border line
                        legendgroup=fuel,
                        showlegend=(r_idx == 1 and c_idx == 1),  # legend only once
                        hoverinfo='x+y+name'
                    ),
                    row=r_idx, col=c_idx
                )

    # Axis tweaks.
    # NOTE(review): row 2 title/dtick are tuned for the default mode order (LD Truck).
    fig.update_yaxes(title_text='Fleet size (vehicles units)',
                     row=2, col=1)

    fig.update_yaxes(nticks=5, tickformat='~s')
    fig.update_yaxes(tickformat='~s',
                     dtick=2500 * 1E3,
                     row=2, col=1)

    # Hide all x tick labels, then show them on the bottom-left subplot only
    # (the bottom row was hard-coded as row 4).
    fig.for_each_xaxis(lambda axis: axis.update(title_text='', showticklabels=False))
    fig.update_xaxes(
        showticklabels=True,
        tickangle=45,
        row=n_rows, col=1,
    )

    fig.update_layout(
        width=250 * n_cols + 200, height=200 * n_rows + 150,
        margin=dict(
            t=70, b=10),
        title=dict(
            text=f'<b>{title_text}</b>',
            x=0.5, y=0.985, xanchor='center', yanchor='top'),
        yaxis_title_standoff=1,
        legend_title=dict(
            text='<b>Fuel/powertrain type</b>', font=dict(size=16)),
        barmode='stack', bargap=0.15,
        legend=dict(
            orientation='v', yanchor='top', y=0.8, xanchor='center', x=1.12, traceorder='normal'),
        font=dict(size=15),
        template='plotly_white'
        )

    # Bold mode label on the right of each row, centred on that row's share of the height.
    for r, mode in enumerate(modes_order, start=1):
        fig.add_annotation(
            text=f'<b>{mode}</b>',
            x=1.005, xref='paper',
            y=1 - (r - 0.5) / n_rows, yref='paper',
            showarrow=False,
            textangle=90,
            xanchor='left', yanchor='middle',
            font=dict(size=17)
        )

    fig.show()


In [16]:
# Stacked-area fleet stock for the main scenarios plus the AEO-based medium growth run
# (high growth is left out here).
plot_capacity_flow_multicat_area(
    color_map=color_fuel_map_rgba(0.8),
    modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
    scenarios={
        # db variant : scenario
        'vanilla4': 'vanilla4',
        'baseline': 'baseline',
        'lowgrowth': 'lowgrowth',
        'medgrowth': 'medgrowth',
        'medgrowth_aeo': 'medgrowth_aeo',
        'evgrowth': 'evgrowth',
    },
    scenario_labels={
        'vanilla4': 'Vanilla',
        'baseline': 'Baseline',
        'lowgrowth': 'Low growth',
        'medgrowth': 'Med growth',
        'medgrowth_aeo': 'Med growth (AEO)',
        'evgrowth': 'Norway EVs',
    },
    capacity_type='net capacity',
    title_text='Vehicle fleet stock in Ontario by vehicle class, powertrain and scenario',
)

In [17]:
def _pivot_by_vintage(df_cell, value_col, fuels):
    """Sum ``value_col`` into a vintage x fuel table for one subplot cell.

    Columns are reindexed to ``fuels`` (absent fuels filled with 0) so every
    subplot stacks the same fuels in the same order; rows are sorted by vintage.
    """
    return (
        df_cell
        .groupby(['vintage', 'fuel'])[value_col]
        .sum()
        .unstack(fill_value=0)
        .reindex(columns=fuels, fill_value=0)
        .sort_index()
    )


def _zev_sales_share(df_cell, fuels):
    """Per-vintage ZEV fraction (0-1) of 'new capacity' for one subplot cell.

    A fuel counts as ZEV when its label contains 'PHEV' or 'BEV', or is one of
    'Plug-in hybrid', 'Battery electric', 'Fuel-cell electric' (labels produced
    by ``map_tech_to_fuel``; plug-in hybrids are deliberately included).
    Charger categories ('LD BEV charger', 'Other charger') are infrastructure,
    not vehicle sales, so they are excluded from both numerator and
    denominator -- note that a bare ``'BEV' in fuel`` test would otherwise
    match 'LD BEV charger'.

    Returns a Series indexed by vintage (sorted).
    """
    pivot_new = _pivot_by_vintage(df_cell, 'new capacity', fuels)
    vehicle_fuels = [f for f in pivot_new.columns if 'charger' not in f]
    zev_fuels = [
        f for f in vehicle_fuels
        if 'PHEV' in f
        or 'BEV' in f
        or f in ('Plug-in hybrid', 'Battery electric', 'Fuel-cell electric')
    ]
    total_new = pivot_new[vehicle_fuels].sum(axis=1)
    zev_new = pivot_new[zev_fuels].sum(axis=1)
    # Vintages with no new vehicle capacity get a share of 0 instead of NaN.
    return zev_new / total_new.replace(0, 1)


def plot_capacity_flow_multicat_area_share(
        title_text='Vehicle fleet stocks and sales market shares in Ontario by vehicle class, powertrain and scenario',
        capacity_type='net capacity',
        scenarios={
            # db variant : scenario
            'vanilla4': 'vanilla4',
            'baseline': 'baseline',
            "lowgrowth": "lowgrowth",
            "medgrowth": "medgrowth",
            "evgrowth": "evgrowth",
            # "highgrowth": "highgrowth"
        },
        scenario_labels={
            'vanilla4': 'Vanilla',
            'baseline': 'Baseline',
            'lowgrowth': 'Low growth',
            "medgrowth": "Med growth",
            "evgrowth": "Norway EVs",
            # 'highgrowth' : 'High growth',
        },
        modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
        color_map=color_fuel_map,
        output_path="fleet_stocks_medgrowth.svg",
):
    """Plot stacked fleet stock by fuel with a ZEV sales-share line per cell.

    Builds a grid with one row per vehicle class in ``modes_order`` and one
    column per scenario. Each cell stacks ``capacity_type`` by fuel/powertrain
    on the primary y-axis and draws the ZEV share of 'new capacity' on a
    secondary 0-100 % axis.

    Parameters
    ----------
    title_text : str
        Figure title (rendered bold).
    capacity_type : str
        Column of the loaded data to stack, e.g. 'net capacity'.
    scenarios : dict
        Passed to ``load_capacity_flow_scenarios``. Its values, in order, are
        the scenario names matched against the data's 'scenario' column and
        define the grid columns.
    scenario_labels : dict
        Scenario name -> column heading; unmapped names are shown verbatim.
    modes_order : tuple of str
        Vehicle classes (matched against the 'mode' column), one per row.
    color_map : dict
        Fuel label -> colour. Its key order sets stacking/legend order, and
        only fuels listed here are drawn.
    output_path : str or None
        File the figure is written to via kaleido after display. The default
        keeps the historical filename; pass a distinct path per call to avoid
        overwriting, or ``None`` to skip saving.
    """
    # ─── 1) Load data for the chosen scenarios & modes ──────────────────────
    df = load_capacity_flow_scenarios(scenarios=scenarios, modes=modes_order)

    scenario_names = list(scenarios.values())
    fuels = list(color_map.keys())
    n_rows = len(modes_order)
    n_cols = len(scenario_names)

    # ─── 2) Subplot titles: scenario names above row 1 only ─────────────────
    # make_subplots needs one title per cell (row-major), so later rows get "".
    subplot_titles = [
        f"<b>{scenario_labels.get(scen, scen)}</b>" if row == 0 else ""
        for row in range(n_rows)
        for scen in scenario_names
    ]

    # ─── 3) Subplots with a secondary y in every cell; primary y shared by row
    specs = [[{"secondary_y": True} for _ in range(n_cols)] for _ in range(n_rows)]
    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        shared_xaxes=True,
        shared_yaxes="rows",
        horizontal_spacing=0.005,
        vertical_spacing=0.03,
        subplot_titles=subplot_titles,
        specs=specs
    )

    # Each legend group is shown on its first trace only. Tracking this
    # explicitly (rather than "cell (1,1) only") keeps the legend when the
    # first cell happens to have no data.
    legend_shown = set()

    def _legend_once(group):
        """Return True the first time ``group`` is seen, False afterwards."""
        if group in legend_shown:
            return False
        legend_shown.add(group)
        return True

    # ─── 4) Populate each (mode x scenario) cell ────────────────────────────
    for r_idx, mode in enumerate(modes_order, start=1):
        for c_idx, scen in enumerate(scenario_names, start=1):
            df_cell = df[(df['mode'] == mode) & (df['scenario'] == scen)]
            if df_cell.empty:
                continue

            # 4a) Stacked area of capacity_type per fuel on the PRIMARY y-axis.
            pivot = _pivot_by_vintage(df_cell, capacity_type, fuels)
            vintages = pivot.index.astype(str).tolist()
            for i, fuel in enumerate(fuels):
                fig.add_trace(
                    go.Scatter(
                        x=vintages,
                        y=pivot[fuel].values.tolist(),
                        mode='none',
                        name=fuel,
                        opacity=0.5,
                        stackgroup="one",
                        fill='tozeroy' if i == 0 else 'tonexty',
                        fillcolor=color_map[fuel],
                        line=dict(color=color_map[fuel], width=0),
                        legendgroup=fuel,
                        showlegend=_legend_once(fuel),
                        hoverinfo='x+y+name'
                    ),
                    row=r_idx, col=c_idx, secondary_y=False
                )

            # 4b) ZEV sales share (fraction 0-1) as a thin black line with
            #     markers on the SECONDARY y-axis.
            share = _zev_sales_share(df_cell, fuels)
            fig.add_trace(
                go.Scatter(
                    x=share.index.astype(str).tolist(),
                    y=share.tolist(),
                    mode='lines+markers',
                    name='ZEV (%)',
                    line=dict(color='black', width=1),
                    marker=dict(symbol='circle', size=4, color='black'),
                    legendgroup='ZEV share',
                    showlegend=_legend_once('ZEV share'),
                    hoverinfo='x+y'
                ),
                row=r_idx, col=c_idx, secondary_y=True
            )

    # ─── 5) PRIMARY y-axes (fleet size) ─────────────────────────────────────
    # Title, tick spacing and SI format go on column 1 of row 2 (or row 1
    # when the grid has a single row, so the selector always matches a cell).
    title_row = min(2, n_rows)
    fig.update_yaxes(
        title_text='Fleet size (vehicle units)',
        dtick=2500 * 1E3,
        tickformat='~s',
        showgrid=True,
        secondary_y=False,
        row=title_row, col=1
    )
    # Primary tick labels only in column 1.
    for r in range(1, n_rows + 1):
        for c in range(2, n_cols + 1):
            fig.update_yaxes(showticklabels=False, secondary_y=False, row=r, col=c)

    # ─── 6) SECONDARY y-axes (ZEV share) ────────────────────────────────────
    # Common 0-105 % range (headroom so the 100 % tick is not at the very top),
    # round percentage ticks, no gridlines.
    fig.update_yaxes(
        range=[0, 1.05],
        tickvals=[0, 0.25, 0.5, 0.75, 1.0],
        tickformat='.0%',
        showgrid=False,
        secondary_y=True
    )
    # Secondary tick labels remain only in the last column, rows 2..n_rows.
    # Row 1 of the last column is blanked as well -- presumably to make room
    # for the "ZEV market share" annotation added at the end.
    fig.update_yaxes(
        title_text=None,
        showticklabels=False,
        secondary_y=True,
        row=1, col=n_cols
    )
    for r in range(1, n_rows + 1):
        for c in range(1, n_cols):
            fig.update_yaxes(showticklabels=False, secondary_y=True, row=r, col=c)

    # ─── 7) X-axes: tick labels only on the bottom-left subplot ─────────────
    fig.for_each_xaxis(lambda ax: ax.update(title_text='', showticklabels=False))
    fig.update_xaxes(
        showticklabels=True,
        tickangle=45,
        row=n_rows, col=1
    )

    # ─── 8) Overall layout, title and horizontal legend below the grid ──────
    fig.update_layout(
        width=220 * n_cols + 100,
        height=180 * n_rows + 100,
        margin=dict(t=70, b=100, l=20, r=20),
        title=dict(
            text=f'<b>{title_text}</b>',
            x=0.5, y=0.985,
            xanchor='center', yanchor='top'
        ),
        legend_title=dict(text='<b>Fuel/powertrain type</b>', font=dict(size=16)),
        legend_title_side='top right',
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=-0.23,
            xanchor='center',
            x=0.45,
            traceorder='normal'
        ),
        font=dict(size=15),
        template='plotly_white'
    )

    # ─── 9) Bold vehicle-class labels on the right of each row ──────────────
    for r_idx, mode in enumerate(modes_order, start=1):
        fig.add_annotation(
            text=f'<b>{mode}</b>',
            x=0.98, xref='paper',
            y=1 - ((r_idx - 0.5) / n_rows), yref='paper',
            showarrow=False,
            textangle=90,
            xanchor='left', yanchor='middle',
            font=dict(size=17)
        )

    # Secondary-axis caption, vertically centred on the first row.
    y_annot = 1 - (1 / n_rows) / 2
    fig.add_annotation(
        text='ZEV market share',
        x=0.945, xref='paper',
        y=y_annot + 0.02, yref='paper',
        showarrow=False,
        textangle=90,
        font=dict(size=16),
        xanchor='left', yanchor='middle'
    )

    fig.show()
    if output_path:
        fig.write_image(output_path, engine="kaleido")


In [18]:
# Stocks + ZEV share for the vanilla/baseline references and the medium-growth
# variants; every database variant is plotted under a scenario of the same name.
plot_capacity_flow_multicat_area_share(
    title_text='Vehicle fleet stocks and zero-emission vehicle (ZEV) sales market shares in Ontario by vehicle class, powertrain and scenario',
    capacity_type='net capacity',
    scenarios={
        variant: variant
        for variant in (
            'vanilla4',
            'baseline',
            'medgrowth',
            'medgrowth_life_7',
            'medgrowth_aeo',
        )
    },
    scenario_labels={
        'vanilla4': 'Vanilla',
        'baseline': 'Baseline',
        'medgrowth': 'Med. growth',
        'medgrowth_life_7': 'Med. growth (7 life)',
        'medgrowth_aeo': 'Med. growth (AEO)',
    },
    modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
    color_map=color_fuel_map_rgba(0.8),
)

In [19]:
# Stocks + ZEV share for the vanilla/master references and the baseline
# variants; every database variant is plotted under a scenario of the same name.
plot_capacity_flow_multicat_area_share(
    title_text='Vehicle fleet stocks and zero-emission vehicle (ZEV) sales market shares in Ontario by vehicle class, powertrain and scenario',
    capacity_type='net capacity',
    scenarios={
        variant: variant
        for variant in (
            'vanilla4',
            'vanilla4_master',
            'baseline_car_shift',
            'baseline_truck_usage',
            'baseline_usage_trend',
        )
    },
    scenario_labels={
        'vanilla4': 'Vanilla',
        'vanilla4_master': 'Master',
        'baseline_car_shift': 'Base (car shift)',
        'baseline_truck_usage': 'Base (truck usage)',
        'baseline_usage_trend': 'Base (usage trend)',
    },
    modes_order=('LD Car', 'LD Truck', 'MD Truck', 'HD Truck'),
    color_map=color_fuel_map_rgba(0.8),
)