In [1]:
import pandas as pd
import plotly.express as px
import numpy as np

In [26]:
# --- Load and clean ---
gni_df = pd.read_csv("labelled_gni_data.csv")
pop_df = pd.read_csv("population_data.csv")

# Pick the most recent available GDP per capita value
gni_df["gni"] = gni_df["2024"].fillna(gni_df["2023"]).fillna(gni_df["2022"])

# Pick the most recent available population value
pop_df["population"] = pop_df["2024"].fillna(pop_df["2023"]).fillna(pop_df["2022"])

#calculate gni per capita
df = pd.merge(gni_df[["Country Name", "Country Code", "gni"]], pop_df[["Country Code", "population"]], on="Country Code")
df["gni_per_capita"] = df["gni"] / df["population"]

# convert to monthly as the LLM cost is monthly
df["gni_per_capita"] = df["gni_per_capita"] / 12

# Keep only valid ISO3 country codes (drop aggregates)
df = df[df["Country Code"].str.len() == 3]

# Drop countries with no GDP data
df = df.dropna(subset=["gni_per_capita"])

# --- Compute LLM cost burden ---
df["LLM_cost_burden_percent"] = (20 / df["gni_per_capita"]) * 100

# Log transform (add small constant to avoid log(0))
df["LLM_cost_burden_log"] = np.log10(df["LLM_cost_burden_percent"] + 1e-6)

In [27]:
df.head()

Unnamed: 0,Country Name,Country Code,gni,population,gni_per_capita,LLM_cost_burden_percent,LLM_cost_burden_log
0,Aruba,ABW,4409839000.0,107624.0,3414.541096,0.58573,-0.232302
1,Africa Eastern and Southern,AFE,3358020000000.0,769294618.0,363.755307,5.498202,0.740221
2,Afghanistan,AFG,91710010000.0,42647492.0,179.201656,11.160611,1.047688
3,Africa Western and Central,AFW,2833010000000.0,521764076.0,452.473019,4.420153,0.645437
4,Angola,AGO,290360000000.0,37885849.0,638.672943,3.131493,0.495752


In [30]:
# --- Plot ---
fig = px.choropleth(
    df,
    locations="Country Code",
    color="LLM_cost_burden_log",
    hover_name="Country Name",
    hover_data={"LLM_cost_burden_percent": ":.2f"},
    color_continuous_scale="Viridis",   # good for log data
    title="Log-scaled LLM Cost Monthly Burden (% of GNI per Capita Monthly for $20)",
    width=1800,
    height=1080,
)

fig.update_layout(
    geo=dict(showframe=False, showcoastlines=True, projection_type="equirectangular")
)

fig.show()