## Final aggregation of GKG data

In [None]:
import os
import re
import numpy as np
import pandas as pd
pd.set_option("display.max_columns", None)

# ============================================================
# 1) Deduplication helper
# ============================================================

def dedupe_parallel_arrays(df, list_cols):
    """
    Drop repeated entries from the first column named in ``list_cols`` and
    keep the other listed columns aligned with it.

    For every row, the first occurrence of each value in the primary column
    is kept; the surviving positions are then used to subset every column in
    ``list_cols``. Cells may hold Python lists or numpy arrays (arrays are
    converted to lists). Rows whose primary cell is neither are returned
    untouched, and non-list cells in the other columns are left as they are.
    """
    key_col = list_cols[0]

    def as_list(value):
        # numpy arrays become plain lists; everything else passes through
        return value.tolist() if isinstance(value, np.ndarray) else value

    def dedupe_row(row):
        primary_values = as_list(row[key_col])
        if not isinstance(primary_values, list):
            return row

        # dict insertion order is the order of first appearance
        first_position = {}
        for position, item in enumerate(primary_values):
            first_position.setdefault(item, position)
        survivors = list(first_position.values())

        # Subset every parallel column to the surviving positions
        for name in list_cols:
            cell = as_list(row[name])
            if not isinstance(cell, list):
                continue
            row[name] = [cell[position] for position in survivors]

        return row

    return df.apply(dedupe_row, axis=1)

# ============================================================
# 2) Load mapped GDELT dataset
# ============================================================

# Parquet file holding the mapped GDELT GKG events
mapped_path = "../data/gdelt/gkg/4_aggregated/gkg_final.parquet"

# Stop with an explicit message when the input file is absent
if os.path.exists(mapped_path):
    df_final = pd.read_parquet(mapped_path)
else:
    raise FileNotFoundError(f"Missing mapped GDELT file: {mapped_path}")

print("✅ Loaded GDELT mapped dataset:", len(df_final))

# ============================================================
# 3) Fill missing ADMIN0 using lookup
# ============================================================

# Reference combinations of (ADMIN0, ADMIN1, ADMIN2), taken only from rows
# whose ADMIN0 is present.
locations = df_final[['ADMIN0', 'ADMIN1', 'ADMIN2']]
locations = locations[locations['ADMIN0'].notna()].drop_duplicates()

# A row with neither ADMIN1 nor ADMIN2 carries no sub-national information.
# Keeping such rows would add a (missing, missing) key to the lookup, and rows
# lacking all three admin levels could then inherit whichever country happened
# to be written last for that key.
locations = locations[locations['ADMIN1'].notna() | locations['ADMIN2'].notna()]

# (ADMIN1, ADMIN2) -> ADMIN0. If the same pair occurs under several ADMIN0
# values, the last one encountered wins.
admin_lookup = {
    (admin1, admin2): admin0
    for admin0, admin1, admin2 in locations.itertuples(index=False, name=None)
}


def fill_admin0(row):
    """Return the row's ADMIN0, falling back to the (ADMIN1, ADMIN2) lookup.

    Rows that already have an ADMIN0 keep it. Otherwise the pair is looked up
    in ``admin_lookup``; ``None`` is returned when the pair is unknown.
    """
    if pd.isna(row['ADMIN0']):
        return admin_lookup.get((row['ADMIN1'], row['ADMIN2']), None)
    return row['ADMIN0']


df_final['ADMIN0'] = df_final.apply(fill_admin0, axis=1)

# ============================================================
# 4) Deduplicate SOURCEURL + parallel array columns
# ============================================================

# The FIRST entry is the column dedupe_parallel_arrays deduplicates on; the
# remaining columns are filtered to the surviving positions. The article URL
# (DocumentIdentifier) therefore has to lead the list: with DATE first,
# distinct articles that share a timestamp were collapsed into one.
list_columns = [
    "DocumentIdentifier",
    "DATE",
    "V2Themes",
    "Amounts"
]

df_final = dedupe_parallel_arrays(df_final, list_columns)
print("✅ Deduplication complete")

# ============================================================
# 5) Load scraped flat dataset & build feature maps
# ============================================================

df_scraped = pd.read_parquet(
    "../data/gdelt/gkg/5_modelled/gkg_exploded_with_counts_and_topics.parquet"
)
print("✅ Loaded scraped URLs dataset:", len(df_scraped))

# Build set of valid scraped URLs (constant-time membership tests below)
scraped_urls = set(df_scraped["url"].tolist())


def _keep_scraped(urls):
    """Return the URLs of one event that are present in ``scraped_urls``.

    Cells that are not a list or numpy array (for example a missing value)
    yield an empty list instead of raising, matching how ``map_list`` treats
    the same column.
    """
    if not isinstance(urls, (list, np.ndarray)):
        return []
    return [u for u in urls if u in scraped_urls]


# Add column showing which URLs survived filtering
df_final["valid_DocumentIdentifier"] = df_final["DocumentIdentifier"].apply(_keep_scraped)

# ——— Features to remap ———
# Article-level columns of df_scraped that are carried over to the event rows
feature_cols = [
    'clean_text',
    'NER_admin0', 'NER_admin1', 'NER_admin2',
    'compound_score', 'neg_score', 'neu_score', 'pos_score',
    'fatalities_freq', 'displaced_freq', 'detained_freq',
    'injured_freq', 'sexual_violence_freq', 'torture_freq',
    'economic_shocks_freq', 'agriculture_freq', 'weather_freq',
    'food_insecurity_freq',
    'fatalities_count', 'displaced_count', 'detained_count',
    'injured_count', 'sexual_violence_count', 'torture_count',
    'sentiment.compound', 'sentiment.neg', 'sentiment.neu', 'sentiment.pos',
    'pred_impact_type', 'pred_urgency',
    'pred_resource_food', 'pred_resource_water', 'pred_resource_cash_aid',
    'pred_resource_healthcare', 'pred_resource_shelter',
    'pred_resource_livelihoods', 'pred_resource_education',
    'pred_resource_infrastructure', 'pred_resource_none'
]

# Build dicts for URL → feature.
# The frame is indexed by URL once, rather than once per feature column.
# If a URL appears more than once, to_dict() keeps the value of its last row.
scraped_by_url = df_scraped.set_index("url")
feature_maps = {
    col: scraped_by_url[col].to_dict()
    for col in feature_cols
}

def map_list(url_list, mapping):
    """Look up every URL of ``url_list`` in ``mapping``, skipping unknown URLs.

    The mapped values are returned in input order. Input that is neither a
    list nor a numpy array yields an empty list.
    """
    if isinstance(url_list, (list, np.ndarray)):
        return [mapping[url] for url in url_list if url in mapping]
    return []

# ============================================================
# 6) Map all scraped features back into events
# ============================================================

# One "<feature>_list" column per scraped feature, holding the feature value
# of every article of the event that has an entry in the URL map.
for feature_name, url_to_value in feature_maps.items():
    list_col = feature_name + "_list"
    df_final[list_col] = df_final["DocumentIdentifier"].apply(
        lambda urls, lookup=url_to_value: map_list(urls, lookup)
    )

print("✅ All article-level features mapped back into event dataset")

# ==========================================
# Only keep event rows where at least one valid clean_text exists,
# or whose CS_score is one of 1-5.
has_clean_text = df_final["clean_text_list"].apply(
    # ">= 1": a single scraped article is enough to keep the event
    lambda texts: isinstance(texts, list) and len(texts) >= 1
)
has_cs_score = df_final["CS_score"].isin([1, 2, 3, 4, 5])

# .copy() so the column assignments further down write to an independent
# frame rather than to a slice of the unfiltered one.
df_final = df_final[has_clean_text | has_cs_score].copy()

# Upper bound applied to every element of the per-article count lists
CAPS = {
    "fatalities_count_list": 50,
    "injured_count_list": 100,
    "displaced_count_list": 500,
    "detained_count_list": 50,
    "sexual_violence_count_list": 10,
    "torture_count_list": 10
}


def _cap_values(values, limit):
    """Clip each element of a list at ``limit``; non-list cells pass through."""
    if not isinstance(values, list):
        return values
    return [min(item, limit) for item in values]


for col, cap in CAPS.items():
    df_final[col] = df_final[col].apply(_cap_values, args=(cap,))

# Number of distinct ADMIN1 / ADMIN2 values (unique() counts a missing value
# as one distinct entry)
print(len(df_final['ADMIN1'].unique()))
print(len(df_final['ADMIN2'].unique()))

# Create the target directory first so the write does not fail on a fresh checkout
final_path = "../data/gdelt/gkg/6_final/gkg_dataset_v2.parquet"
os.makedirs(os.path.dirname(final_path), exist_ok=True)
df_final.to_parquet(final_path)
df_final.head()

## Validation and Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------
# FEWSNET PERIOD LIST
# -----------------------
# Period codes (YYYYMM strings) that get outlined on the heatmap below
fewsnet_periods = [
    '201602','201606','201610','201702','201706','201710','201802','201806','201810',
    '201902','201906','201910','202002','202006','202010','202102','202106','202110',
    '202202','202206','202210','202302','202306','202310','202402', '202406'
]

# -----------------------
# 1. BUILD df_heat PIVOT TABLE
# -----------------------
# One row per (event, article). Events with an empty URL list explode to a
# missing value; dropping those keeps them out of both the total and the pivot.
exploded_articles = (
    df_final
        .explode("valid_DocumentIdentifier")
        .dropna(subset=["valid_DocumentIdentifier"])
)

print("Number of unique articles in df_final:")
# nunique() on the cleaned frame: the previous len(unique()) counted the
# missing value produced by empty lists as one extra "article".
print(exploded_articles["valid_DocumentIdentifier"].nunique())

# Rows = country, columns = period, values = number of distinct articles
df_heat = (
    exploded_articles
        .groupby(["ADMIN0", "period"])["valid_DocumentIdentifier"]
        .nunique()
        .unstack(fill_value=0)
        .sort_index()
        .sort_index(axis=1)
)
df_heat.index = df_heat.index.str.title()
# -----------------------
# 2. PLOT HEATMAP
# -----------------------
plt.figure(figsize=(40, 18))

# Styling of the annotated country x period grid
heatmap_options = {
    "cmap": "YlOrRd",
    "linewidths": 0.2,
    "linecolor": "gray",
    "annot": True,
    "fmt": "d",
    "annot_kws": {"size": 8},
    "cbar_kws": {"shrink": 0.5},
}
ax = sns.heatmap(df_heat, **heatmap_options)

# -----------------------
# 3. DRAW THICK OUTLINES AROUND FEWSNET COLUMNS
# -----------------------
# Column labels as strings: fewsnet_periods holds string codes, so comparing
# them with the raw labels would never match if the period column is numeric.
x_labels = [str(label) for label in df_heat.columns]

for p in fewsnet_periods:
    if p not in x_labels:
        continue
    col_idx = x_labels.index(p)

    # Draw a thick rectangle around the whole column
    ax.add_patch(plt.Rectangle(
        (col_idx, 0),                # (x, y) corner of the column in data coords
        1,                           # width (1 column)
        df_heat.shape[0],            # height (# of rows)
        fill=False,
        edgecolor='black',           # outline color
        linewidth=3                  # thickness
    ))

# -----------------------
# 4. FORMATTING & LABELS
# -----------------------
ax.set_title(
    "Unique Article Counts by Country and Period (GDELT GKG)\nwith FEWSNET Periods Outlined",
    fontsize=26, fontweight='bold', pad=30
)

ax.set_xlabel("Period (YYYYMM)", fontsize=20, labelpad=20)
ax.set_ylabel("Country (ADMIN0)", fontsize=20, labelpad=20)

# Pin one tick to the centre of every cell before relabelling. seaborn's
# automatic tick selection can leave fewer ticks than cells, in which case
# passing the full label list would not line up with the existing ticks.
ax.set_xticks([i + 0.5 for i in range(len(x_labels))])
ax.set_xticklabels(x_labels, fontsize=10, rotation=90)
ax.set_yticks([i + 0.5 for i in range(len(df_heat.index))])
ax.set_yticklabels(df_heat.index, fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt

# -------------------------------------------------------
# 1. AGGREGATE ARTICLE COUNTS BY COUNTRY (ADMIN0)
# -------------------------------------------------------
# One row per (event, article), without events that have no valid article
country_articles = (
    df_final
        .explode("valid_DocumentIdentifier")
        .dropna(subset=["valid_DocumentIdentifier"])
)

# Distinct articles per country
df_country = (
    country_articles
        .groupby("ADMIN0")["valid_DocumentIdentifier"]
        .nunique()
        .reset_index(name="article_count")
)

# Normalise the country names to stripped Title Case
df_country["ADMIN0"] = df_country["ADMIN0"].str.strip().str.title()

# -------------------------------------------------------
# 2. NATURAL EARTH NAME MAPPING (INCLUDES S. Sudan FIX)
# -------------------------------------------------------
# Title-cased ADMIN0 name -> country name as used in the Natural Earth layer
ne_mapping = {
    "Burkina Faso": "Burkina Faso",
    "Burundi": "Burundi",
    "Cameroon": "Cameroon",
    "Central African Republic": "Central African Rep.",
    "Chad": "Chad",
    "Democratic Republic Of The Congo": "Dem. Rep. Congo",
    "Ethiopia": "Ethiopia",
    "Guinea": "Guinea",
    "Kenya": "Kenya",
    "Liberia": "Liberia",
    "Madagascar": "Madagascar",
    "Malawi": "Malawi",
    "Mali": "Mali",
    "Mauritania": "Mauritania",
    "Mozambique": "Mozambique",
    "Niger": "Niger",
    "Nigeria": "Nigeria",
    "Sierra Leone": "Sierra Leone",
    "Somalia": "Somalia",

    # ⭐ Correct name in Natural Earth
    "South Sudan": "S. Sudan",

    "Sudan": "Sudan",
    "Uganda": "Uganda",
    "Yemen": "Yemen",
    "Zambia": "Zambia",
    "Zimbabwe": "Zimbabwe",
}

df_country["NE_name"] = df_country["ADMIN0"].map(ne_mapping)

# Countries missing from ne_mapping get no NE_name and therefore cannot be
# matched to a shape later on; report them instead of dropping them silently.
unmapped_countries = df_country.loc[df_country["NE_name"].isna(), "ADMIN0"].tolist()
if unmapped_countries:
    print("⚠️ No Natural Earth name mapping for:", unmapped_countries)

# -------------------------------------------------------
# 3. LOAD WORLD SHAPEFILE (Africa + Yemen)
# -------------------------------------------------------
# Natural Earth 1:110m countries archive, used only when the copy bundled with
# geopandas is unavailable.
# NOTE(review): URL and column names (NAME / CONTINENT) should be confirmed
# against the Natural Earth download before relying on the fallback.
NATURAL_EARTH_URL = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"

try:
    world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
except AttributeError:
    # geopandas >= 1.0 removed the bundled datasets; read the source archive
    # and expose the two columns used below under their former names.
    world = gpd.read_file(NATURAL_EARTH_URL).rename(
        columns={"NAME": "name", "CONTINENT": "continent"}
    )

world["name"] = world["name"].str.strip().str.title()

# All African countries plus Yemen
target = list(world[world["continent"] == "Africa"]["name"]) + ["Yemen"]
subset = world[world["name"].isin(target)].copy()

# -------------------------------------------------------
# 4. MERGE GEOMETRY WITH YOUR DATA
# -------------------------------------------------------
# Left join keeps every shape; shapes without a match have no article_count
geo_df = subset.merge(df_country, how="left", left_on="name", right_on="NE_name")
geo_df["has_data"] = ~geo_df["article_count"].isna()

# -------------------------------------------------------
# 5. PLOT (GREY = NO DATA)
# -------------------------------------------------------
fig, ax = plt.subplots(figsize=(18, 16))

# Border styling shared by both layers
border_style = {"edgecolor": "black", "linewidth": 0.7}

countries_without_data = geo_df[~geo_df["has_data"]]
countries_with_data = geo_df[geo_df["has_data"]]

# Layer 1: shapes without any article count, in flat grey
countries_without_data.plot(color="lightgrey", ax=ax, **border_style)

# Layer 2: shapes with an article count, coloured by that count
countries_with_data.plot(
    column="article_count",
    cmap="YlOrRd",
    legend=True,
    legend_kwds={"label": "Number of Articles", "shrink": 0.6},
    ax=ax,
    **border_style
)

ax.set_title(
    "Geographic Heatmap of GDELT GKG Article Counts by Country (Africa + Yemen)\nGrey = No Articles in Dataset",
    fontsize=18, fontweight="bold",
    pad=20
)

ax.axis("off")

# -------------------------------------------------------
# 6. LABELS
# -------------------------------------------------------
# Write each country's name at the centroid of its shape
for country_name, shape in zip(geo_df["name"], geo_df["geometry"]):
    center = shape.centroid
    ax.annotate(
        country_name,
        xy=(center.x, center.y),
        ha="center",
        fontsize=9,
        color="black"
    )

plt.show()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
import statsmodels.formula.api as smf

# ======================================================
# 0. EXTRACT UNIQUE ENTITIES
# ======================================================
countries = df_final["ADMIN0"].unique()
period_values = df_final["period"].astype(int)
periods = sorted(period_values.unique())

# Full country x period grid, one row per combination
panel_index = pd.MultiIndex.from_product(
    [countries, periods], names=["ADMIN0", "period"]
)
df_panel = panel_index.to_frame(index=False)

# ======================================================
# 1. ARTICLE COUNTS PER COUNTRY-PERIOD
# ======================================================
article_rows = (
    df_final
        .explode("valid_DocumentIdentifier")
        .dropna(subset=["valid_DocumentIdentifier"])
)
df_articles_period = (
    article_rows
        .groupby(["ADMIN0", "period"])["valid_DocumentIdentifier"]
        .nunique()
        .reset_index(name="article_count")
)

# Same integer period key as the panel
df_articles_period["period"] = df_articles_period["period"].astype(int)

# ======================================================
# 2. FOOD SECURITY PER COUNTRY-PERIOD
# ======================================================
# Highest CS_score recorded per country and period
df_fs_period = (
    df_final
        .groupby(["ADMIN0", "period"])["CS_score"]
        .max()
        .reset_index()
)
df_fs_period["period"] = df_fs_period["period"].astype(int)

# Previous CS_score per country, taken over the rows with a score of 1-5 only
scored_rows = df_fs_period[df_fs_period["CS_score"].isin([1, 2, 3, 4, 5])].copy()
scored_rows["prev_cs"] = scored_rows.groupby("ADMIN0")["CS_score"].shift(1)
df_fs_period_lag = scored_rows.drop(columns="CS_score")

# ======================================================
# 3. MERGE BOTH INTO THE COMPLETE PANEL
# ======================================================
panel_keys = ["ADMIN0", "period"]
df_period = df_panel
for extra in (df_fs_period, df_articles_period, df_fs_period_lag):
    # left joins keep every row of the country x period grid
    df_period = df_period.merge(extra, on=panel_keys, how="left")

def compute_articles_between(g):
    """Sum ``article_count`` between consecutive rows that have a CS_score.

    ``g`` is expected to be one country's rows, already sorted by period.
    For every row with a non-missing ``CS_score`` except the first, the
    result holds the total ``article_count`` of the rows after the previous
    scored row up to and including the current one. All other positions
    stay NaN.

    Missing ``article_count`` values are treated as 0, so a period without
    any article no longer turns the total of the following interval into NaN.

    Returns a float Series indexed like ``g``.
    """
    # index labels of the rows that carry a CS_score
    fs_idx = g.index[g["CS_score"].notna()].tolist()

    results = pd.Series(index=g.index, dtype="float")

    # Running total in row order; the difference between two positions is the
    # sum over the rows after the first and up to the second.
    running_total = g["article_count"].fillna(0).cumsum()

    # Walk over pairs of consecutive scored rows
    for prev_i, curr_i in zip(fs_idx, fs_idx[1:]):
        # exclusive of prev, inclusive of curr; assigned ONLY to the current row
        results.loc[curr_i] = running_total.loc[curr_i] - running_total.loc[prev_i]

    return results

# Per-country article totals between consecutive scored periods; the result is
# aligned back to df_period through the shared row index.
ordered_panel = df_period.sort_values(["ADMIN0", "period"])
df_period["articles_between"] = (
    ordered_panel
    .groupby("ADMIN0", group_keys=False)
    .apply(compute_articles_between)
)

# Keep only rows whose CS_score is one of 1-5
valid_scores = [1, 2, 3, 4, 5]
df_valid = df_period.copy()
df_valid = df_valid[df_valid["CS_score"].isin(valid_scores)]

# ======================================================
# 6. USE ACTUAL FOOD SECURITY LEVEL (CS_score)
# ======================================================
# Rows lacking either the article sum or the score cannot be correlated
df_corr = df_valid.dropna(subset=["articles_between", "CS_score"])

# ======================================================
# 7. CORRELATION (raw, without controls)
# ======================================================
print("\nRaw correlation between CS_score and articles_between:")
raw_corr = df_corr[["CS_score", "articles_between"]].corr()
print(raw_corr)

# ======================================================
# 8. SCATTERPLOT WITH REGRESSION LINE (raw)
# ======================================================
plt.figure(figsize=(10,7))

# Marker and trend-line styling
point_style = {"s": 60, "alpha": 0.7}
trend_style = {"color": "red"}

sns.regplot(
    x="CS_score",
    y="articles_between",
    data=df_corr,
    scatter_kws=point_style,
    line_kws=trend_style
)
plt.xlabel("Food Security Level (CS_score)")
plt.ylabel("Articles Between Periods")
plt.title("Articles Between FEWSNET Periods vs Food Security Level (Raw Relationship)")
plt.tight_layout()
plt.show()

# ======================================================
# 9. OLS REGRESSION WITH FIXED EFFECTS
# ======================================================
# Country FE + Period FE
# CS_score plus one categorical term per country and per period
fe_formula = "articles_between ~ CS_score + C(ADMIN0) + C(period)"
model_fe = smf.ols(formula=fe_formula, data=df_corr).fit()

print("\nOLS with Country & Period Fixed Effects (articles_between ~ CS_score + FE):")
fe_summary = model_fe.summary()
print(fe_summary)

## ADMIN1 distribution

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# --- Compute unique article counts per ADMIN1 ---
# Distinct articles per ADMIN1 unit
admin1_articles = df_final.explode("valid_DocumentIdentifier")
admin1_counts = (
    admin1_articles
        .groupby("ADMIN1")["valid_DocumentIdentifier"]
        .nunique()
        .reset_index(name="article_count")
)

# Integer counts for the bucketing below
admin1_counts["article_count"] = admin1_counts["article_count"].astype(int)

# Right-closed bucket edges: (-1, 0], (0, 5], (5, 10], ...
bins   = [-1, 0, 5, 10, 20, 50, float("inf")]
labels = ["0", "1–5", "6–10", "11–20", "21–50", "50+"]

admin1_counts["bucket"] = pd.cut(
    admin1_counts["article_count"],
    bins=bins,
    labels=labels,
    include_lowest=True
)

# Number of ADMIN1 units per bucket, in bucket order
bucket_sizes = admin1_counts["bucket"].value_counts()
distribution = bucket_sizes.reindex(labels).reset_index()
distribution.columns = ["bucket", "num_admin1"]
print(distribution)
# --- Plot ---
plt.figure(figsize=(12,6))
# One bar per bucket; height = number of ADMIN1 units in that bucket
bar_ax = sns.barplot(data=distribution, x="bucket", y="num_admin1", color="steelblue")
bar_ax.set_title("ADMIN1 Distribution by Article Count Buckets (GDELT GKG)")
bar_ax.set_xlabel("Article Count Bucket")
bar_ax.set_ylabel("Number of ADMIN1 Regions")
plt.tight_layout()
plt.show()

## ADMIN2 check

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# --- Compute unique article counts per ADMIN2 ---
admin2_counts = (
    df_final
        .explode("valid_DocumentIdentifier")
        .groupby("ADMIN2")["valid_DocumentIdentifier"]
        .nunique()
        .reset_index(name="article_count")
)

# --- Ensure article_count is integer ---
admin2_counts["article_count"] = admin2_counts["article_count"].astype(int)

# --- Bucket definitions (right-closed intervals) ---
bins   = [-1, 0, 5, 10, 20, 50, float("inf")]
labels = ["0", "1–5", "6–10", "11–20", "21–50", "50+"]

admin2_counts["bucket"] = pd.cut(
    admin2_counts["article_count"],
    bins=bins,
    labels=labels,
    include_lowest=True
)

# --- Distribution: count ADMIN2 units per bucket ---
distribution = (
    admin2_counts["bucket"]
    .value_counts()
    .reindex(labels)
    .reset_index()
)

# Column named after the level it actually counts (was "num_admin1")
distribution.columns = ["bucket", "num_admin2"]

# --- Plot ---
plt.figure(figsize=(12,6))
sns.barplot(data=distribution, x="bucket", y="num_admin2", color="steelblue")
plt.title("ADMIN2 Distribution by Article Count Buckets (GDELT GKG)")
plt.xlabel("Article Count Bucket")
plt.ylabel("Number of ADMIN2 districts")
plt.tight_layout()
plt.show()


## Heatmap for each topic metric

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# ---------------------------------------------------------
# 0. CHOOSE OUTPUT FOLDER
# ---------------------------------------------------------
# The default location is unchanged; set the GKG_CHARTS_DIR environment
# variable to write the charts somewhere else (e.g. on another machine).
output_folder = os.environ.get(
    "GKG_CHARTS_DIR", "/Users/marco.bertetti/Desktop/charts/nlp/gkg"
)
os.makedirs(output_folder, exist_ok=True)

# ---------------------------------------------------------
# 1. LIST OF ALL METRICS YOU WANT TO INCLUDE
# ---------------------------------------------------------
# Per-article list columns that are summed per event below
metric_list_columns = [
    'fatalities_freq_list', 'fatalities_count_list', 
    'displaced_freq_list', 'displaced_count_list',
    'detained_freq_list', 'detained_count_list',
    'injured_freq_list', 'injured_count_list',
    'sexual_violence_freq_list', 'sexual_violence_count_list',
    'torture_freq_list', 'torture_count_list',
    'economic_shocks_freq_list', 'agriculture_freq_list',
    'weather_freq_list', 'food_insecurity_freq_list',
    'pred_resource_food_list', 'pred_resource_water_list', 'pred_resource_cash_aid_list',
    'pred_resource_healthcare_list', 'pred_resource_shelter_list',
    'pred_resource_livelihoods_list', 'pred_resource_education_list',
    'pred_resource_infrastructure_list'
]

# ---------------------------------------------------------
# 2. SAFE SUM FUNCTION
# ---------------------------------------------------------
def safe_sum(x):
    """Sum the elements of a list, counting unconvertible elements as 0.

    Each element is passed through ``float()``; elements that ``float()``
    rejects (``TypeError``, ``ValueError``, ``OverflowError``) contribute 0.
    A NaN element converts without error and therefore propagates into the
    total. Anything that is not a list returns 0.
    """
    if not isinstance(x, list):
        return 0

    total = 0
    for v in x:
        try:
            total += float(v)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures are skipped; a bare ``except`` would
            # also swallow KeyboardInterrupt and SystemExit.
            continue
    return total

# ---------------------------------------------------------
# 3. APPLY SAFE SUM TO ALL LIST COLUMNS
# ---------------------------------------------------------
# Collapse every per-article list column into one per-event total, stored
# under the column name without "_list".
all_metrics = []
for list_col in metric_list_columns:
    metric_name = list_col.replace("_list", "")
    df_final[metric_name] = df_final[list_col].apply(safe_sum)
    all_metrics.append(metric_name)

print("Total metrics:", len(all_metrics))
print(all_metrics)

# ---------------------------------------------------------
# 4. GENERATE HEATMAP FOR EACH METRIC (SAVE FIRST, THEN SHOW)
# ---------------------------------------------------------
for metric in all_metrics:

    # Rows = country, columns = period, values = metric total
    df_heat = (
        df_final
            .groupby(["ADMIN0", "period"])[metric]
            .sum()
            .unstack(fill_value=0)
            .sort_index()
            .sort_index(axis=1)
    )

    df_heat.index = df_heat.index.str.title()

    plt.figure(figsize=(35, 16))
    ax = sns.heatmap(
        df_heat,
        cmap="YlOrRd",
        linewidths=0.2,
        linecolor="gray",
        annot=True,
        fmt="g",
        annot_kws={"fontsize": 7, "color": "black"},
        cbar_kws={"shrink": 0.5}
    )

    ax.set_title(f"{metric.upper()} by Country and Period", fontsize=26, fontweight='bold', pad=20)
    ax.set_xlabel("Period (YYYYMM)", fontsize=20)
    ax.set_ylabel("Country (ADMIN0)", fontsize=20)

    # Pin one tick to the centre of every cell before relabelling. seaborn's
    # automatic tick selection can leave fewer ticks than cells, in which case
    # passing the full label list would not line up with the existing ticks.
    ax.set_xticks([i + 0.5 for i in range(df_heat.shape[1])])
    ax.set_xticklabels(df_heat.columns, fontsize=10, rotation=90)
    ax.set_yticks([i + 0.5 for i in range(df_heat.shape[0])])
    ax.set_yticklabels(df_heat.index, fontsize=14)

    plt.tight_layout()

    # ---- SAVE BEFORE SHOW ----
    file_path = os.path.join(output_folder, f"{metric.upper()}.png")
    plt.savefig(file_path, dpi=300)
    print(f"Saved: {file_path}")

    # ---- NOW SHOW ----
    plt.show()

    # Release the figure so figures do not accumulate across iterations
    plt.close()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os

# ---------------------------------------------------------
# 0. CHOOSE OUTPUT FOLDER
# ---------------------------------------------------------
# The default location is unchanged; set the GKG_CHARTS_DIR environment
# variable to write the charts below a different base directory.
charts_root = os.environ.get(
    "GKG_CHARTS_DIR", "/Users/marco.bertetti/Desktop/charts/nlp/gkg"
)
output_folder = os.path.join(charts_root, "impact_urgency")
os.makedirs(output_folder, exist_ok=True)

# ---------------------------------------------------------
# 1. Convert pred_impact_type_list into category frequency columns
# ---------------------------------------------------------
def count_categories(list_value):
    """Tally how often each item occurs; non-list input gives an empty Counter."""
    return Counter(list_value) if isinstance(list_value, list) else Counter()

# Impact type categories → expand to columns
# Both frames are built on df_final's own index. pd.concat(axis=1) aligns on
# index labels, and df_final's labels are no longer 0..n-1 once rows have been
# filtered out; a fresh default index would attach counts to the wrong events.
impact_counts = df_final["pred_impact_type_list"].apply(count_categories)
impact_df = pd.DataFrame(impact_counts.tolist(), index=df_final.index).fillna(0)
impact_df.columns = [f"impact_{c}" for c in impact_df.columns]

# ---------------------------------------------------------
# 2. Convert pred_urgency_list into category frequency columns
# ---------------------------------------------------------
urgency_counts = df_final["pred_urgency_list"].apply(count_categories)
urgency_df = pd.DataFrame(urgency_counts.tolist(), index=df_final.index).fillna(0)
urgency_df.columns = [f"urgency_{c}" for c in urgency_df.columns]

# ---------------------------------------------------------
# 3. Merge back into df_final
# ---------------------------------------------------------
df_final2 = pd.concat([df_final, impact_df, urgency_df], axis=1)

# All expanded metric names
categorical_metrics = list(impact_df.columns) + list(urgency_df.columns)

print("Extracted categorical metrics:")
print(categorical_metrics)

# ---------------------------------------------------------
# 4. Generate heatmaps for each categorical metric
# ---------------------------------------------------------
for metric in categorical_metrics:

    # Rows = country, columns = period, values = category count
    df_heat = (
        df_final2
            .groupby(["ADMIN0", "period"])[metric]
            .sum()
            .unstack(fill_value=0)
            .sort_index()
            .sort_index(axis=1)
    )

    # Capitalize country names
    df_heat.index = df_heat.index.str.title()

    # ---- PLOT ----
    plt.figure(figsize=(35, 16))
    ax = sns.heatmap(
        df_heat,
        cmap="Blues",
        linewidths=0.2,
        linecolor="gray",
        annot=True,                      # SHOW NUMBERS
        fmt="g",
        annot_kws={"fontsize": 7, "color": "black"},   # ALWAYS BLACK TEXT
        cbar_kws={"shrink": 0.5}
    )

    ax.set_title(f"{metric.upper()} by Country and Period",
                 fontsize=26, fontweight='bold', pad=20)
    ax.set_xlabel("Period (YYYYMM)", fontsize=20)
    ax.set_ylabel("Country (ADMIN0)", fontsize=20)

    # Pin one tick to the centre of every cell before relabelling. seaborn's
    # automatic tick selection can leave fewer ticks than cells, in which case
    # passing the full label list would not line up with the existing ticks.
    ax.set_xticks([i + 0.5 for i in range(df_heat.shape[1])])
    ax.set_xticklabels(df_heat.columns, fontsize=10, rotation=90)
    ax.set_yticks([i + 0.5 for i in range(df_heat.shape[0])])
    ax.set_yticklabels(df_heat.index, fontsize=14)

    plt.tight_layout()

    # ---- SAVE BEFORE SHOW ----
    file_path = os.path.join(output_folder, f"{metric.upper()}.png")
    plt.savefig(file_path, dpi=300)
    print(f"Saved: {file_path}")

    # ---- SHOW IN NOTEBOOK ----
    plt.show()

    # Release the figure so figures do not accumulate across iterations
    plt.close()