# Population pyramid over time

In [20]:
import re
from modelclass import model 

In [21]:
# Load the stored population model. run=True presumably also simulates it and
# `baseline` is the resulting dataframe — confirm against ModelFlow's modelload docs.
mpopulation, baseline = model.modelload(
    'models/population',
    run=True,
)

Zipped file read:  models\population.pcim


## Code to produce the population pyramid

In [22]:
def sort_key(s):
    """
    Return the age encoded in a variable name, for use as a ``sorted`` key.

    Names are expected to contain ``AGE_<digits>`` (case-sensitive), e.g.
    ``POP__FEMALE__AGE_0``, ``POP__MALE__AGE_1``. Sorting with this key orders
    such names by ascending numeric age (so ``AGE_2`` comes before ``AGE_10``).
    Names with the same age keep their original relative order because
    ``sorted`` is stable.

    Parameters
    ----------
    s : str
        Variable/column name containing ``AGE_<digits>``.

    Returns
    -------
    int
        The age parsed from the first ``AGE_<digits>`` occurrence in ``s``.

    Raises
    ------
    ValueError
        If ``s`` contains no ``AGE_<digits>`` part.
    """
    match = re.search(r'AGE_(\d+)', s)
    if match is None:
        # Without this check the caller would get an obscure
        # "'NoneType' object has no attribute 'group'" error.
        raise ValueError(f"No 'AGE_<number>' part found in name {s!r}")
    return int(match.group(1))

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import re
import pandas as pd
import IPython
import numpy as np

def animate_population_pyramid(mmodel, interval=300, save_path=None, auto_show=True,
                               age_limits=(20, 65), fps=5):
    """
    Animated population pyramid, one frame per period, colored by 3 age groups.

    Male bars are drawn to the left of x = 0 and female bars to the right. The legend
    shows each age group's share (%) of the total (male + female) population and is
    refreshed every frame. The current period is printed in the top-left corner, and
    'Male'/'Female' labels sit above the chart.

    Parameters
    ----------
    mmodel : ModelFlow model instance
        Must contain population variables matching 'pop__male*__AGE_*' and
        'pop__female*__AGE_*', covering the same set of ages for both sexes.
        Values are read from ``mmodel.basedf``; one frame is drawn per index entry.
    interval : int
        Delay between frames (ms) of the inline/interactive animation.
    save_path : str or path-like, optional
        If provided, saves the animation. Must end in .gif (pillow writer) or
        .mp4 (ffmpeg writer); the extension check is case-insensitive.
    auto_show : bool
        If True and running under IPython, closes the static figure and displays
        the animation inline as JavaScript HTML.
    age_limits : sequence of two ints, optional
        Inclusive upper bounds of the first two age groups; must be increasing.
        Example: (20, 65) → groups 0–20, 21–65 and 66+.
    fps : int, optional
        Frame rate used only when saving to ``save_path`` (default 5).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Keep a reference to it so the animation is not garbage-collected.

    Raises
    ------
    ValueError
        If ``save_path`` has an unsupported extension, ``age_limits`` is not
        increasing, the model data has no periods, no male/female population
        columns are found, or male and female columns cover different ages.
    """
    # --- Validate arguments before any figure is created ---
    save_writer = None
    if save_path:
        lowered = str(save_path).lower()  # str() also accepts pathlib.Path
        if lowered.endswith(".gif"):
            save_writer = "pillow"
        elif lowered.endswith(".mp4"):
            save_writer = "ffmpeg"
        else:
            # Previously an unsupported extension saved nothing yet still
            # reported success.
            raise ValueError(f"save_path must end with .gif or .mp4, got {save_path!r}")

    limit1, limit2 = age_limits
    if limit1 >= limit2:
        raise ValueError(f"age_limits must be increasing, got {tuple(age_limits)!r}")

    # NOTE(review): plots basedf; confirm it holds the values to visualise (vs. lastdf).
    df = mmodel.basedf
    years = list(df.index)
    if not years:
        raise ValueError("Model data has no periods to animate.")

    # --- ModelFlow pattern filters, ordered by age ---
    male_cols_sorted = sorted(mmodel['pop__male*__AGE_*'].names, key=sort_key)
    female_cols_sorted = sorted(mmodel['pop__female*__AGE_*'].names, key=sort_key)

    if not male_cols_sorted or not female_cols_sorted:
        raise ValueError("No population columns found for male/female in model data.")

    # --- Extract ages; both sexes share one y-axis, so their ages must line up ---
    ages_sorted = [sort_key(c) for c in male_cols_sorted]
    if ages_sorted != [sort_key(c) for c in female_cols_sorted]:
        raise ValueError("Male and female population columns do not cover the same ages.")

    # --- Define age groups dynamically ---
    def age_group(age):
        """0 for ages <= limit1, 1 for ages <= limit2, otherwise 2."""
        if age <= limit1:
            return 0
        if age <= limit2:
            return 1
        return 2

    group_idx = np.array([age_group(a) for a in ages_sorted])
    group_colors = ["#1f77b4", "#2ca02c", "#d62728"]  # blue, green, red
    bar_colors = [group_colors[g] for g in group_idx]  # identical for both sexes
    group_labels = [f"0–{limit1}", f"{limit1+1}–{limit2}", f"{limit2+1}+"]

    # --- Helper to get age-ordered population arrays for one period ---
    def get_values(year):
        m = df.loc[year, male_cols_sorted].to_numpy()
        f = df.loc[year, female_cols_sorted].to_numpy()
        return m, f

    # --- Population shares (%) per age group, for the legend ---
    def group_shares(m, f):
        total = m.sum() + f.sum()
        if total == 0:
            # Avoid 0/0 (NaN shares plus a numpy RuntimeWarning) for an empty period.
            return [0.0, 0.0, 0.0]
        return [(m[group_idx == g].sum() + f[group_idx == g].sum()) / total * 100
                for g in range(3)]

    def legend_texts(shares):
        return [f"{label} ({share:.1f}%)" for label, share in zip(group_labels, shares)]

    male_init, female_init = get_values(years[0])

    # --- Figure setup ---
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlabel('Population')
    ax.set_ylabel('Age')
    ax.set_title('Population Pyramid by Age Group', fontsize=14, pad=25)
    # Male bars have negative widths; show absolute values on the x-axis.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{abs(int(x)):,}"))

    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, color=color, label=text)
        for color, text in zip(group_colors, legend_texts(group_shares(male_init, female_init)))
    ]
    legend = ax.legend(handles=legend_handles, title="Age Groups", loc="upper right")

    male_bar = ax.barh(ages_sorted, -male_init, color=bar_colors)
    female_bar = ax.barh(ages_sorted, female_init, color=bar_colors)
    year_text = ax.text(0.02, 0.95, str(years[0]), transform=ax.transAxes,
                        fontsize=16, fontweight='bold', ha='left', va='top')

    # Symmetric x-range from the largest single-age value over all periods (NaN skipped).
    max_pop = float(df[male_cols_sorted + female_cols_sorted].max().max())
    if not np.isfinite(max_pop) or max_pop <= 0:
        max_pop = 1.0  # all-zero/NaN data would otherwise give a degenerate x-range
    ax.set_xlim(-max_pop * 1.1, max_pop * 1.1)

    # --- Add vertical line at x = 0 ---
    ax.axvline(0, color="black", linewidth=1.2)

    # --- Add 'Male' and 'Female' labels ABOVE plot ---
    # Use figure coordinates to ensure fixed position
    fig.text(0.25, 0.94, "Male", ha="center", va="bottom",
             fontsize=14, fontweight='bold', color="steelblue")
    fig.text(0.75, 0.94, "Female", ha="center", va="bottom",
             fontsize=14, fontweight='bold', color="lightcoral")

    # --- Animation update ---
    def update(frame):
        year = years[frame]
        male_values, female_values = get_values(year)

        for bar, val in zip(male_bar, -male_values):
            bar.set_width(val)
        for bar, val in zip(female_bar, female_values):
            bar.set_width(val)
        year_text.set_text(str(year))

        # Reuse this frame's arrays instead of fetching the period a second time.
        for txt, text in zip(legend.get_texts(), legend_texts(group_shares(male_values, female_values))):
            txt.set_text(text)
        return []  # blit=False, so the returned artists are not used

    anim = FuncAnimation(fig, update, frames=len(years), interval=interval, blit=False, repeat=True)

    # --- Save animation if requested (writer chosen and validated above) ---
    if save_writer is not None:
        anim.save(save_path, writer=save_writer, fps=fps)
        print(f"✅ Animation saved to: {save_path}")

    # --- Inline display (Notebook 7 compatible) ---
    if auto_show and IPython.get_ipython() is not None:
        plt.close(fig)  # suppress the static figure; show only the HTML animation
        display(HTML(anim.to_jshtml()))

    return anim


#anim = animate_population_pyramid(mpopulation,save_path='graph/test1.mp4')

## Now make the chart 

In [25]:
# Build the pyramid animation; with the default auto_show=True it is rendered inline,
# and the returned animation object is kept in `anim`.
anim = animate_population_pyramid(mmodel=mpopulation)

✅ Animation saved to: graph/test1.mp4
