## Loading environment variables and libraries

In [8]:
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from webdriver_manager.chrome import ChromeDriverManager
import json, pandas as pd, time, re
from pathlib import Path
from datetime import datetime
import os
from dotenv import load_dotenv
from pymongo import MongoClient
from tqdm import tqdm
import json, re
from pathlib import Path
import pandas as pd
import numpy as np

### Scraper for energy data from africa-energy-portal.org 

In [10]:

# ---------------------------------------------
# Scraper configuration
# ---------------------------------------------
URL = "https://africa-energy-portal.org/database"

# Only responses whose URL contains this fragment (case-insensitive) are
# treated as chart data. Adjust if DevTools shows a different endpoint path.
ENDPOINT_PATTERN = re.compile(r"get-database-data", flags=re.IGNORECASE)

# Folder that receives the captured JSON payloads; created up front.
OUT_DIR = Path("scraped_json")
OUT_DIR.mkdir(exist_ok=True)

# Wait timeouts; WAIT_LONG is the default timeout of the wait helpers below.
WAIT_SHORT, WAIT_LONG = 3, 25

# XPath registry for every element the scraper touches. Entries containing
# "{text}" are templates that are filled in with str.format(text=...).
# Most paths are positional (div[1]/div/...), so they must be re-checked
# whenever the page layout changes.
XPATHS = {
    # Year and Country dropdowns: the link that opens the list, and the "All" option inside it.
    "year_dropdown": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[1]/div/div[2]/div[1]/div/div[1]/div/a',
    "year_all_option": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[1]/div/div[2]/div[1]/div/div[1]/div/div/div[1]/div/label/span',

    "country_dropdown": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[1]/div/div[2]/div[1]/div/div[3]/div/a',
    "country_all_option": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[1]/div/div[2]/div[1]/div/div[3]/div/div/div[1]/div/label/span',

    # Sector tabs by visible text. This matches ANY element on the page whose
    # normalized text equals {text}; it is not scoped to a container.
    "sector_button_by_text": '//*[normalize-space()="{text}"]',

    # Panel containers (used to confirm the sector panel is visible)
    "electricity_container": '//*[@id="electricity"]',
    "energy_container": '//*[@id="energy"]',

    # Submenu tabs, scoped to the sector container by id
    "electricity_tab_by_text": '//*[@id="electricity"]//*[normalize-space()="{text}"]',
    "energy_tab_by_text": '//*[@id="energy"]//*[normalize-space()="{text}"]',

    # "All" checkbox per sector
    "electricity_all_checkbox": '//*[@id="electricity"]/div/div[1]/div/div/label/span',
    "energy_all_checkbox": '//*[@id="energy"]/div/div[1]/div/div/label/span',

    # Apply button. NOTE(review): both entries are the SAME global XPath under
    # block-newdatabaseblock, so they are not scoped to the sector container.
    "electricity_apply": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[2]/div/div[1]/div/div[3]/div/a[1]',
    "energy_apply": '//*[@id="block-newdatabaseblock"]/div[1]/div/div/div[2]/div/div[1]/div/div[3]/div/a[1]',
}

# Scraping plan per sector: which XPATHS keys to use for the panel, its
# submenu tabs, the "All" checkbox and the Apply button, plus the submenu
# tabs to visit (in this order).
SECTORS = {
    "Electricity": dict(
        container_key="electricity_container",
        tab_xpath_tpl="electricity_tab_by_text",
        all_checkbox="electricity_all_checkbox",
        apply_xpath="electricity_apply",
        submenus=["Access", "Supply", "Technical"],
    ),
    "Energy": dict(
        container_key="energy_container",
        tab_xpath_tpl="energy_tab_by_text",
        all_checkbox="energy_all_checkbox",
        apply_xpath="energy_apply",
        submenus=["Access", "Efficiency"],
    ),
}

# ---------------------------------------------
# Step 1: Setup brave/chrome driver (your working launcher)
# ---------------------------------------------
def setup_brave_driver(
    brave_path=r"c:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe",
    driver_version="142.0.7444.135",
):
    """Launch Brave through ChromeDriver with network capture enabled.

    Args:
        brave_path: Path to the browser binary. Defaults to the standard
            Windows install location of Brave.
        driver_version: ChromeDriver version requested from
            webdriver-manager; keep it in step with the installed browser.

    Returns:
        A started WebDriver with performance logging on, the CDP Network
        domain enabled, the HTTP cache disabled and a 90 s page-load timeout.

    Raises:
        Whatever the driver download, browser launch or CDP setup raises.
        If setup fails after the browser has started, the browser is quit
        before the exception propagates so no orphan process is left behind.
    """
    chrome_options = Options()
    chrome_options.binary_location = brave_path
    for flag in (
        "--start-maximized",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
    ):
        chrome_options.add_argument(flag)
    # Keep window open during debug (optional): chrome_options.add_experimental_option("detach", True)

    # Performance logging is what exposes Network.* events via get_log("performance").
    chrome_options.set_capability("goog:loggingPrefs", {"performance": "ALL"})

    driver_path = ChromeDriverManager(driver_version=driver_version).install()
    driver = webdriver.Chrome(service=Service(driver_path), options=chrome_options)

    try:
        # The Network domain must be enabled for Network.getResponseBody to work.
        driver.execute_cdp_cmd("Network.enable", {})
        driver.execute_cdp_cmd("Network.setCacheDisabled", {"cacheDisabled": True})
        driver.set_page_load_timeout(90)
    except Exception:
        driver.quit()
        raise
    return driver

# ---------------------------------------------
# Small helpers
# ---------------------------------------------
def wait_click(driver, xpath, timeout=WAIT_LONG):
    """Wait until the element at ``xpath`` is clickable, click it, and return it.

    A native click is attempted first; if it raises for any reason the click
    is re-issued through JavaScript on the same element.
    """
    locator = (By.XPATH, xpath)
    target = WebDriverWait(driver, timeout).until(EC.element_to_be_clickable(locator))
    try:
        target.click()
    except Exception:
        # Fallback for clicks the native call rejects.
        driver.execute_script("arguments[0].click();", target)
    return target

def visible(driver, xpath, timeout=WAIT_LONG):
    """Block until the element at ``xpath`` is visible and return it."""
    locator = (By.XPATH, xpath)
    waiter = WebDriverWait(driver, timeout)
    return waiter.until(EC.visibility_of_element_located(locator))

def safe_click_by_text(driver, template_key, text, timeout=WAIT_LONG):
    """Fill the ``{text}`` slot of the XPATHS template ``template_key`` and click the match."""
    resolved_xpath = XPATHS[template_key].format(text=text)
    return wait_click(driver, resolved_xpath, timeout=timeout)

def select_all_in_select2(driver, dropdown_xpath, all_option_xpath):
    """Open a dropdown and click its "All" option.

    Short fixed pauses are taken before clicking the option and after it,
    on top of the explicit waits done by the click/visibility helpers.
    """
    pause_before_option, pause_after_option = 0.2, 0.3

    # Open the list and wait for the "All" entry to show up.
    wait_click(driver, dropdown_xpath)
    visible(driver, all_option_xpath)

    time.sleep(pause_before_option)
    wait_click(driver, all_option_xpath)
    time.sleep(pause_after_option)

def drain_perf_logs(driver):
    """Read and discard the pending performance log entries (best effort).

    Called before clicking Apply so that only events logged afterwards are
    parsed. Any error raised while reading the log is ignored.
    """
    try:
        _ = driver.get_log("performance")
    except Exception:
        return None
    return None

def collect_matching_response_events(driver):
    """Return the params of logged responses whose URL matches ENDPOINT_PATTERN.

    Reads the driver's performance log, keeps only
    ``Network.responseReceived`` events, and returns their ``params`` dicts
    in log order.
    """
    hits = []
    for entry in driver.get_log("performance"):
        # Each log entry wraps the CDP event as a JSON string under "message".
        envelope = json.loads(entry.get("message", "{}"))
        event = envelope.get("message", {})
        if event.get("method") != "Network.responseReceived":
            continue
        params = event.get("params", {})
        response_url = params.get("response", {}).get("url", "")
        if ENDPOINT_PATTERN.search(response_url):
            hits.append(params)
    return hits

def get_body_by_request_id(driver, request_id):
    """Fetch the response body for ``request_id`` via CDP.

    Returns the "body" field of the Network.getResponseBody result, or None
    when the field is absent or the lookup raises for any reason.
    """
    cdp_params = {"requestId": request_id}
    try:
        return driver.execute_cdp_cmd("Network.getResponseBody", cdp_params).get("body")
    except Exception:
        return None

def _parse_json_lenient(body):
    """Parse ``body`` as JSON, tolerating non-JSON text around the payload.

    The text is tried as-is first. If that fails, the outermost ``{...}`` and
    ``[...]`` spans are tried (whichever opens first goes first), so that
    both object and array payloads survive a wrapper or prefix.

    Returns the parsed value, or None when nothing parses.
    """
    try:
        return json.loads(body)
    except json.JSONDecodeError:
        pass

    spans = []
    for opener, closer in (("{", "}"), ("[", "]")):
        start, end = body.find(opener), body.rfind(closer)
        if start != -1 and end > start:
            spans.append((start, end))

    # Sorting by start position tries the outermost candidate first.
    for start, end in sorted(spans):
        try:
            return json.loads(body[start:end + 1])
        except json.JSONDecodeError:
            continue
    return None


def click_apply_and_capture(driver, apply_xpath, sector_name, submenu_name, wait_after=2.8):
    """Click Apply and return the parsed JSON of the resulting data response.

    Args:
        driver: Active WebDriver with performance logging enabled.
        apply_xpath: XPath of the Apply button to click.
        sector_name: Sector label, used only in warning messages.
        submenu_name: Submenu label, used only in warning messages.
        wait_after: Seconds to wait after the click before reading the log.

    Returns:
        The decoded JSON payload of the LAST matching response, or None when
        no matching response was logged, its body could not be fetched, or
        the body is not parseable as JSON. A warning is printed in each case.
    """
    # Drop older events so only traffic triggered by this click is inspected.
    drain_perf_logs(driver)
    wait_click(driver, apply_xpath)
    # allow network calls to complete
    time.sleep(wait_after)

    events = collect_matching_response_events(driver)
    if not events:
        # One extra grace period for slow responses.
        time.sleep(1.5)
        events = collect_matching_response_events(driver)
    if not events:
        print(f"[warn] No matching JSON for {sector_name} → {submenu_name}")
        return None

    req_id = events[-1].get("requestId")
    body = get_body_by_request_id(driver, req_id)
    if not body:
        print(f"[warn] No body for requestId={req_id}")
        return None

    parsed = _parse_json_lenient(body)
    if parsed is None:
        print(f"[warn] Response for {sector_name} → {submenu_name} is not valid JSON")
    return parsed

def save_json_blob(json_data, sector, submenu):
    """Write ``json_data`` to OUT_DIR as ``<sector>_<submenu>_<timestamp>.json``.

    The timestamp is the local time at second resolution. Returns the path
    of the file that was written.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = OUT_DIR / f"{sector}_{submenu}_{stamp}.json"
    serialized = json.dumps(json_data, ensure_ascii=False, indent=2)
    target.write_text(serialized, encoding="utf-8")
    print(f"[saved] {target}")
    return target

# ---------------------------------------------
# Step 2: Full flow
# ---------------------------------------------
def scrape_portal():
    """Drive the portal UI end to end and save one JSON file per sector/submenu.

    Flow: open URL, select "All" in the Year and Country dropdowns, then for
    every sector in SECTORS and every submenu of that sector: open the tab,
    tick the "All" checkbox, click Apply, capture the data response and save
    it, then untick "All" again. The browser is always quit on exit, whether
    or not an error occurred.
    """
    driver = setup_brave_driver()
    try:
        print("Opening page…")
        driver.get(URL)

        # Let initial scripts settle a bit
        time.sleep(2)

        # Year=All, Country=All (set once, before iterating the sectors)
        print("Selecting Year=All")
        select_all_in_select2(driver, XPATHS["year_dropdown"], XPATHS["year_all_option"])
        print("Selecting Country=All")
        select_all_in_select2(driver, XPATHS["country_dropdown"], XPATHS["country_all_option"])

        # Iterate sectors
        for sector_name, cfg in SECTORS.items():
            print(f"Sector → {sector_name}")
            safe_click_by_text(driver, "sector_button_by_text", sector_name)
            # Wait for the sector's panel before touching its submenu tabs
            visible(driver, XPATHS[cfg["container_key"]])
            time.sleep(0.3)

            # Iterate submenus
            for submenu in cfg["submenus"]:
                print(f"  Submenu → {submenu}")
                safe_click_by_text(driver, cfg["tab_xpath_tpl"], submenu)
                time.sleep(0.2)

                # Tick All
                # NOTE(review): the same checkbox XPath is used for every submenu
                # of a sector; confirm it resolves to the active tab's checkbox.
                wait_click(driver, XPATHS[cfg["all_checkbox"]])
                time.sleep(0.25)

                # Apply and capture JSON
                data = click_apply_and_capture(
                    driver,
                    XPATHS[cfg["apply_xpath"]],
                    sector_name,
                    submenu
                )
                if data is not None:
                    save_json_blob(data, sector_name, submenu)
                else:
                    print(f"[warn] No JSON captured for {sector_name} → {submenu}")

                # Untick All (reset) so the next submenu starts unticked
                wait_click(driver, XPATHS[cfg["all_checkbox"]])
                time.sleep(0.2)

        print("Done.")
    finally:
        # Comment this out while debugging if you want the window to stay open
        driver.quit()

# ---------------------------------------------
# Step 3: Optional — flatten the saved JSONs to CSV
# (depends on the schema you see in your files)
# ---------------------------------------------
def json_folder_to_csv(folder=OUT_DIR, out_csv="africa_energy_data.csv"):
    """Flatten every ``*.json`` file in ``folder`` into one CSV.

    Expected list payload: ``[{"_id": indicator, "data": [{...}, ...]}, ...]``.
    Each data item becomes one row, tagged with its indicator and with the
    sector/submenu taken from the first two ``_``-separated parts of the
    file name. A dict payload becomes a single row tagged with ``__file``.
    Anything else (non-dict blocks or items, scalar payloads) is skipped
    instead of raising.

    Args:
        folder: Directory containing the saved JSON files.
        out_csv: Path of the CSV to write.

    Returns:
        The flattened DataFrame, or None when no rows were produced (no CSV
        is written in that case).
    """
    records = []
    for fp in sorted(Path(folder).glob("*.json")):
        blob = json.loads(fp.read_text(encoding="utf-8"))

        # Sector/submenu come from the file name and are the same for every
        # row of the file, so work them out once per file.
        name_parts = fp.stem.split("_")
        file_tags = {}
        if len(name_parts) >= 2:
            file_tags = {"sector": name_parts[0], "submenu": name_parts[1]}

        if isinstance(blob, list):
            for block in blob:
                if not isinstance(block, dict):
                    continue
                indicator = block.get("_id")
                for item in block.get("data", []) or []:
                    if not isinstance(item, dict):
                        continue
                    row = dict(item)
                    row["indicator"] = indicator
                    row.update(file_tags)
                    records.append(row)
        elif isinstance(blob, dict):
            # handle single-object responses too (copy; don't mutate the parsed blob)
            records.append({**blob, "__file": fp.name})
        else:
            print(f"[skip] {fp.name}: unsupported JSON top-level type {type(blob).__name__}")

    if not records:
        print("No JSON files parsed.")
        return None

    df = pd.DataFrame(records)
    df.to_csv(out_csv, index=False)
    print(f"[csv] Saved {out_csv} ({len(df)} rows)")
    return df

# ---------------------------------------------
# Run
# ---------------------------------------------
# Entry point: launches the browser and runs the full scrape when this code
# is executed as the main module.
if __name__ == "__main__":
    scrape_portal()
    # Optionally combine saved JSONs after a run:
    # json_folder_to_csv()


Opening page…
Selecting Year=All
Selecting Country=All
Sector → Electricity
  Submenu → Access


KeyboardInterrupt: 

### Data preprocessing and cleaning

In [None]:
# Convert the scraped JSON data into a preprocessed wide CSV (one column per year)

# -------------------- CONFIG --------------------
READ_FROM_FLAT_CSV = False   # False: read the scraped JSON files; True: read FLAT_CSV instead
# Input for the flat-CSV branch. It was referenced but never defined, so
# enabling READ_FROM_FLAT_CSV raised NameError. The default matches the
# default out_csv of json_folder_to_csv.
# NOTE(review): confirm that file's column names match what the cleaning steps expect.
FLAT_CSV = "africa_energy_data.csv"
IN_DIR = Path("scraped_json")     # where the JSON files are
OUT_CSV = "aep_preprocessed_wide_2000_2022.csv"
SOURCE_LINK = "https://africa-energy-portal.org/database"

year_cols = [str(y) for y in range(2000, 2023)]          # 2000..2022
early_year_cols = [str(y) for y in range(2000, 2012)]    # 2000..2011

# -------------------- LOAD ----------------------
def load_from_flat(csv_path: str) -> pd.DataFrame:
    """Read an already-flattened CSV, keeping ``country_code`` as a string column."""
    return pd.read_csv(
        csv_path,
        dtype={"country_code": "string"},
        low_memory=False,
    )

def load_from_json(dir_path: Path) -> pd.DataFrame:
    """Load every ``*.json`` file under ``dir_path`` into one long DataFrame.

    A list payload is read as blocks of ``{"_id": ..., "data": [...]}``; each
    data item yields one row with a fixed set of columns, plus the block's
    ``_id`` as ``metric`` and the file name as ``__file``. A dict payload
    yields one row made of its own keys plus ``__file``. Other payload types
    are ignored.
    """
    # output column -> key read from each raw data item
    item_fields = {
        "country_code": "id",
        "country_name": "name",
        "year": "year",
        "value": "score",
        "unit": "unit",
        "region_name": "region_name",
        "indicator_topic": "indicator_topic",    # Access/Supply/Technical/Efficiency
        "indicator_group": "indicator_group",    # Electricity/Energy
        "indicator_name": "indicator_name",
        "indicator_source": "indicator_source",
    }

    rows = []
    for json_file in sorted(dir_path.glob("*.json")):
        payload = json.loads(json_file.read_text(encoding="utf-8"))

        if isinstance(payload, dict):
            rows.append({**payload, "__file": json_file.name})
            continue
        if not isinstance(payload, list):
            continue

        for block in payload:
            metric_full = block.get("_id")  # e.g. "... (Millions of people)"
            for item in block.get("data", []) or []:
                row = {column: item.get(key) for column, key in item_fields.items()}
                row["metric"] = metric_full  # keep _id as metric
                row["__file"] = json_file.name
                rows.append(row)

    return pd.DataFrame.from_records(rows)

# Pick the input: a flat CSV (FLAT_CSV must be defined in the CONFIG section
# before enabling this branch) or the scraped JSON files.
if READ_FROM_FLAT_CSV:
    df = load_from_flat(FLAT_CSV)
    if "metric" not in df.columns:
        # Fall back to the indicator name as the metric label. The earlier
        # np.where(...) returned this same value in both branches, so its
        # condition (and its dependency on a "unit" column) had no effect.
        # NOTE(review): if the metric was meant to be "<indicator_name> (<unit>)",
        # build that here instead.
        df["metric"] = df["indicator_name"].astype(str)
else:
    df = load_from_json(IN_DIR)

# -------------------- CLEAN / TYPES --------------
# Map the portal's field names onto the column names used downstream.
column_renames = {
    "country_name": "country",
    "indicator_group": "sector",
    "indicator_topic": "sub_sector",
    "indicator_source": "source",
}
df = df.rename(columns=column_renames)

# Unparseable years become <NA> (nullable integer); unparseable values become NaN.
year_numeric = pd.to_numeric(df["year"], errors="coerce")
df["year"] = year_numeric.astype("Int64")

raw_values = df.get("value", df.get("score"))
df["value"] = pd.to_numeric(raw_values, errors="coerce")

def strip_unit_parenthetical(s: str) -> str:
    """Remove one trailing ``(...)`` group, plus surrounding whitespace, from ``s``.

    Non-string inputs (e.g. None or NaN) are returned unchanged.
    """
    if isinstance(s, str):
        without_suffix = re.sub(r"\s*\([^)]*\)\s*$", "", s)
        return without_suffix.strip()
    return s

# Indicator name without its trailing "(unit)" suffix.
df["sub_sub_sector"] = df["indicator_name"].map(strip_unit_parenthetical)

# Only when no unit is available at all: recover it from the trailing
# parenthetical of the metric label.
unit_unavailable = "unit" not in df.columns or df["unit"].isna().all()
if unit_unavailable:
    df["unit"] = df["metric"].str.extract(r"\(([^()]*)\)\s*$", expand=False)

df["source_link"] = SOURCE_LINK

# -------------------- COUNTRY SERIAL -------------
# Number the distinct country names 1..N in sorted order.
country_order = (
    df["country"]
    .dropna()
    .drop_duplicates()
    .sort_values()
    .reset_index(drop=True)
)
serial_map = dict(zip(country_order, range(1, len(country_order) + 1)))
df["country_serial"] = df["country"].map(serial_map).astype("Int64")

# -------------------- BUILD WIDE ------------------
# Columns that identify one output row; every year becomes its own column.
id_cols = [
    "country", "country_serial", "metric", "unit", "sector", "sub_sector",
    "sub_sub_sector", "source_link", "source",
]
tidy = df[id_cols + ["year", "value"]].copy()

# When several values share the same identifiers and year, the first one wins.
wide = tidy.pivot_table(
    index=id_cols,
    columns="year",
    values="value",
    aggfunc="first",
).reset_index()

# Year labels come out of the pivot as integers: rename them to strings, then
# add any year of the target range that is missing as an all-NaN column.
wide = wide.rename(columns={int(y): y for y in year_cols})
for y in year_cols:
    if y not in wide.columns:
        wide[y] = np.nan

# -------------------- NULL HANDLING ----------------
# 1) Rows with no data at all in 2000..2011 get 0 for every one of those years.
no_early_data = wide[early_year_cols].isna().all(axis=1)
wide.loc[no_early_data, early_year_cols] = 0.0

# 2) Remaining gaps: interpolate along each row (i.e. across the year columns),
#    in both directions so leading and trailing gaps are filled as well.
years_df = wide[year_cols].astype(float)
years_interp = years_df.interpolate(axis=1, limit_direction="both")

# Fallback for anything still NaN: forward fill, then backward fill, per row.
years_filled = years_interp.ffill(axis=1).bfill(axis=1)

wide[year_cols] = years_filled

# -------------------- ORDER / SAVE ----------------
# Identifier columns first, then the year columns in ascending order.
final_cols = [
    "country", "country_serial", "metric", "unit", "sector", "sub_sector",
    "sub_sub_sector", "source_link", "source",
    *year_cols,
]
wide = (
    wide.loc[:, final_cols]
    .sort_values(["country_serial", "metric"])
    .reset_index(drop=True)
)

wide.to_csv(OUT_CSV, index=False)
print(f"Saved -> {OUT_CSV}  (rows: {len(wide)}, cols: {len(wide.columns)})")

# Optional quick check: list only the years that still contain nulls.
null_counts = wide[year_cols].isna().sum()
print("Remaining nulls per year (should be few or none):")
print(null_counts[null_counts > 0].to_string())


Saved -> aep_preprocessed_wide_2000_2022.csv  (rows: 1781, cols: 32)
Remaining nulls per year (should be few or none):
Series([], )


In [33]:
# Reload the wide table written by the preprocessing step. A path relative to
# the working directory is used (the same file name the pipeline writes and
# the MongoDB loader reads) instead of a machine-specific absolute path.
df = pd.read_csv("aep_preprocessed_wide_2000_2022.csv")
df.head()

Unnamed: 0,country,country_serial,metric,unit,sector,sub_sector,sub_sub_sector,source_link,source,2000,2001,2002,2003,2004,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022
0,Algeria,1,Electricity export (GWh),GWh,Electricity,Supply,Electricity export,https://africa-energy-portal.org/database,AFREC Database,319.0,196.0,259.0,212.0,197.0,275.0,298.0,273.0,323.0,362.0,803.0,799.0,985.0,384.0,877.0,641.0,507.0,918.764526,833.876404,975.834595,1128.864014,1529.344971,1529.344971
1,Algeria,1,Electricity final consumption (GWh),GWh,Electricity,Supply,Electricity final consumption,https://africa-energy-portal.org/database,AFREC Database,18592.0,19664.0,20739.0,22699.0,23608.0,26656.0,26456.0,27991.0,29953.0,27911.0,33470.0,35867.0,40777.0,40188.0,45751.0,47956.96875,52288.0,56376.101562,58152.601562,59053.726562,60044.328125,62502.121094,62502.121094
2,Algeria,1,Electricity final consumption per capita (KWh),KWh per capita,Electricity,Supply,Electricity final consumption per capita,https://africa-energy-portal.org/database,AFREC Database,596.209656,622.433105,648.194031,700.510559,719.07428,800.758484,783.233521,816.062256,859.219421,786.984375,926.69397,974.128967,1085.480591,1048.239624,1169.704102,1202.787354,1287.689941,1364.439453,1384.320435,1383.671143,1385.640747,1421.494507,1421.494507
3,Algeria,1,Electricity generated from biofuels and waste ...,GWh,Electricity,Supply,Electricity generated from biofuels and waste,https://africa-energy-portal.org/database,AFREC Database,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Algeria,1,Electricity generated from fossil fuels (GWh),GWh,Electricity,Supply,Electricity generated from fossil fuels,https://africa-energy-portal.org/database,AFREC Database,25358.0,26556.0,27591.0,29571.0,30634.0,33360.0,35008.0,36970.0,39753.0,42427.0,42663.0,51062.40625,56776.0,59560.0,63988.0,68576.0,70663.0,75381.96875,75879.992188,79441.46875,83299.226562,84709.507812,84709.507812


In [28]:
# Column dtypes and non-null counts of the reloaded table
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1781 entries, 0 to 1780
Data columns (total 32 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   country         1781 non-null   object 
 1   country_serial  1781 non-null   int64  
 2   metric          1781 non-null   object 
 3   unit            1781 non-null   object 
 4   sector          1781 non-null   object 
 5   sub_sector      1781 non-null   object 
 6   sub_sub_sector  1781 non-null   object 
 7   source_link     1781 non-null   object 
 8   source          1781 non-null   object 
 9   2000            1781 non-null   float64
 10  2001            1781 non-null   float64
 11  2002            1781 non-null   float64
 12  2003            1781 non-null   float64
 13  2004            1781 non-null   float64
 14  2005            1781 non-null   float64
 15  2006            1781 non-null   float64
 16  2007            1781 non-null   float64
 17  2008            1781 non-null   f

In [29]:
# Last 5 rows of the dataframe (complements the head() preview above)
df.tail()

Unnamed: 0,country,country_serial,metric,unit,sector,sub_sector,sub_sub_sector,source_link,source,2000,2001,2002,2003,2004,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018,2019,2020,2021,2022
1776,Zimbabwe,54,Population with access to electricity-Rural (m...,Millions of people,Electricity,Access,Population with access to electricity-Rural,https://africa-energy-portal.org/database,Tracking SDG7/WBG,0.591027,0.618687,0.588757,0.728449,0.822765,0.912299,0.815137,1.121662,1.243138,1.803378,1.486813,1.358807,2.351611,1.916208,0.800974,1.162808,2.394648,2.710675,3.026958,3.318449,4.444583,3.863523,4.20057
1777,Zimbabwe,54,Population with access to electricity-Urban (m...,Millions of people,Electricity,Access,Population with access to electricity-Urban,https://africa-energy-portal.org/database,Tracking SDG7/WBG,3.527871,3.610542,3.686423,3.706049,3.726027,3.759051,4.06709,3.837064,3.868404,4.190422,3.992765,3.949866,4.121153,4.195922,4.176997,4.154192,4.469254,4.56248,4.651661,4.759478,4.873023,4.986096,5.036577
1778,Zimbabwe,54,Population without access to electricity-Natio...,Millions of people,Electricity,Access,Population without access to electricity-National,https://africa-energy-portal.org/database,Tracking SDG7/WBG,8.103352,8.13694,8.225349,8.199402,8.228717,8.26868,8.242042,8.371184,8.446927,7.8168,8.606742,9.077977,8.238065,8.942379,10.433708,10.46045,9.286457,9.256744,9.23464,9.219563,8.362863,9.21083,9.200273
1779,Zimbabwe,54,Population without access to electricity-Rural...,Millions of people,Electricity,Access,Population without access to electricity-Rural,https://africa-energy-portal.org/database,Tracking SDG7/WBG,7.505235,7.521943,7.588432,7.549385,7.572804,7.613867,7.856543,7.710327,7.76511,7.397531,7.923398,8.278125,7.52911,8.222462,9.601296,9.505162,8.539822,8.490465,8.438792,8.407521,7.535427,8.362817,8.26403
1780,Zimbabwe,54,Population without access to electricity-Urban...,Millions of people,Electricity,Access,Population without access to electricity-Urban,https://africa-energy-portal.org/database,Tracking SDG7/WBG,0.598118,0.614993,0.636913,0.650014,0.655915,0.654815,0.385497,0.660856,0.681817,0.419268,0.683341,0.799851,0.708952,0.719917,0.832404,0.955293,0.74664,0.766286,0.795852,0.812047,0.827437,0.848017,0.936249


In [30]:
# Number of missing values per column
df.isna().sum()

country           0
country_serial    0
metric            0
unit              0
sector            0
sub_sector        0
sub_sub_sector    0
source_link       0
source            0
2000              0
2001              0
2002              0
2003              0
2004              0
2005              0
2006              0
2007              0
2008              0
2009              0
2010              0
2011              0
2012              0
2013              0
2014              0
2015              0
2016              0
2017              0
2018              0
2019              0
2020              0
2021              0
2022              0
dtype: int64

In [37]:
# Summary statistics for all columns (numeric and non-numeric), transposed so
# that each column of the dataframe becomes one row of the summary
df.describe(include='all').T

Unnamed: 0,count,unique,top,freq,mean,std,min,25%,50%,75%,max
country,1781.0,54.0,Algeria,33.0,,,,,,,
country_serial,1781.0,,,,27.499158,15.594497,1.0,14.0,27.0,41.0,54.0
metric,1781.0,33.0,Electricity export (GWh),54.0,,,,,,,
unit,1781.0,5.0,GWh,648.0,,,,,,,
sector,1781.0,2.0,Electricity,1674.0,,,,,,,
sub_sector,1781.0,4.0,Supply,756.0,,,,,,,
sub_sub_sector,1781.0,33.0,Electricity export,54.0,,,,,,,
source_link,1781.0,1.0,https://africa-energy-portal.org/database,1781.0,,,,,,,
source,1781.0,3.0,AFREC Database,756.0,,,,,,,
2000,1781.0,,,,948.184949,8690.085254,-6492.155762,0.0,1.778905,78.741936,210384.0


### Loading the preprocessed CSV data into a MongoDB database

In [None]:

# -------------------- LOAD ENV VARIABLES --------------------
# Connection settings are read from the environment (after loading a .env
# file) so that no credentials are stored in the notebook.
load_dotenv()

MONGO_URI = os.environ.get("MONGO_URI")
MONGO_DB = os.environ.get("MONGO_DB", "aep_database")
MONGO_COLLECTION = os.environ.get("MONGO_COLLECTION", "aep_data")
CSV_PATH = "aep_preprocessed_wide_2000_2022.csv"

# The URI has no default: stop here when it is unset or empty.
if MONGO_URI is None or MONGO_URI == "":
    raise ValueError("Missing MONGO_URI in .env file")

# -------------------- CONNECT TO MONGODB --------------------
client = MongoClient(MONGO_URI)
try:
    # MongoClient connects lazily, so ping first: a bad URI or unreachable
    # server then fails here, before anything is deleted.
    client.admin.command("ping")
    db = client[MONGO_DB]
    collection = db[MONGO_COLLECTION]
    print(f"Connected to MongoDB: {MONGO_DB}.{MONGO_COLLECTION}")

    # -------------------- READ CSV --------------------
    print("Reading CSV file...")
    df = pd.read_csv(CSV_PATH)

    # Replace NaN with None for MongoDB compatibility. The cast to object is
    # required: in a float column None would be coerced straight back to NaN.
    df = df.astype(object).where(pd.notnull(df), None)

    # -------------------- CONVERT & INSERT --------------------
    records = df.to_dict("records")

    # Full refresh: existing documents are removed before the insert.
    collection.delete_many({})
    print("Cleared existing collection")

    # Insert in batches
    batch_size = 1000
    print(f"Inserting {len(records)} records in batches of {batch_size}...")
    for i in tqdm(range(0, len(records), batch_size)):
        batch = records[i:i + batch_size]
        collection.insert_many(batch)

    print(f"Successfully inserted {len(records)} records into MongoDB")

    # -------------------- VERIFY --------------------
    count = collection.count_documents({})
    print(f"Collection now contains {count} documents")
finally:
    # Release the connection even when reading or inserting fails.
    client.close()


Connected to MongoDB: african_energy_data.energy_data
Reading CSV file...
Cleared existing collection
Inserting 1781 records in batches of 1000...


100%|██████████| 2/2 [00:10<00:00,  5.01s/it]


Successfully inserted 1781 records into MongoDB
Collection now contains 1781 documents
