# Parsing

In [75]:
import pandas as pd
import numpy as np

# Workbook that every parser in this cell reads (machine-specific absolute path).
file_path = "/Users/leoss/Desktop/base_dataset.xlsx"

# Accumulators: parsed long-format frames, and audit copies of dropped rows
# (filled by track_exclusion).
all_data, all_excluded = [], []

def track_exclusion(df, reason, sheet_name, resource, metric):
    """Record rows a parser is about to drop so they can be audited later.

    A copy of ``df`` is tagged with ``_exclusion_reason``, ``_sheet``,
    ``_resource`` and ``_metric`` columns and appended to the module-level
    ``all_excluded`` list. Empty frames are ignored.

    Returns the number of rows recorded (0 for an empty frame).
    """
    n_rows = len(df)
    if not n_rows:
        return 0
    tagged = df.copy()
    tags = (
        ('_exclusion_reason', reason),
        ('_sheet', sheet_name),
        ('_resource', resource),
        ('_metric', metric),
    )
    for column, value in tags:
        tagged[column] = value
    all_excluded.append(tagged)
    return n_rows

# Parses most of the sheets of interest (countries as rows, years as columns).
def parse_simple(sheet_name, resource, metric, skip_rows=2):
    """Parse a country-by-year sheet of the workbook into long format.

    The first column is treated as the country; every numeric header in
    (1800, 2030) is a year column. Rows with no country, rows whose country
    text matches the notes pattern, and melted rows lacking a numeric Year,
    Value or Country are dropped, each group recorded via ``track_exclusion``.

    Returns a DataFrame with columns Country, Year, Value, Resource, Metric.
    """
    raw = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=skip_rows,
                        keep_default_na=False, na_values=[''])
    raw = raw.rename(columns={raw.columns[0]: 'Country'})
    years = [col for col in raw.columns
             if isinstance(col, (int, float)) and 1800 < col < 2030]
    wide = raw[['Country'] + years]

    missing_country = wide['Country'].isna()
    track_exclusion(wide[missing_country], 'Country is NA', sheet_name, resource, metric)
    wide = wide[~missing_country]

    # Footnotes and source lines live in the country column as well.
    notes_regex = r'Source|Note|Please|Less than|n/a|includes|Excludes|Commercial|Differences|Annual|methodology|^USSR'
    is_note = wide['Country'].astype(str).str.contains(notes_regex, case=False, na=True, regex=True)
    track_exclusion(wide[is_note], f'Regex filter: {notes_regex}', sheet_name, resource, metric)
    wide = wide[~is_note]

    long_df = wide.melt(id_vars='Country', var_name='Year', value_name='Value')
    for numeric_col in ('Year', 'Value'):
        long_df[numeric_col] = pd.to_numeric(long_df[numeric_col], errors='coerce')
    long_df['Resource'] = resource
    long_df['Metric'] = metric

    key_cols = ['Year', 'Value', 'Country']
    incomplete = long_df[key_cols].isna().any(axis=1)
    track_exclusion(long_df[incomplete], 'dropna (Year/Value/Country NA after melt)', sheet_name, resource, metric)

    return long_df.dropna(subset=key_cols)

# Minerals are formatted differently than oil and gas; extracts both prod and reserves from them
def parse_minerals(sheet_name, resource):
    """Parse a mineral "P-R" sheet into long-format Production and Reserves rows.

    Production: every numeric header in (1990, 2030) is a year column and is
    melted to one row per (Country, Year).

    Reserves: the first column whose header contains "at end of"
    (case-insensitive) holds a single snapshot value per country. Its year
    is read from that header (e.g. "At end of 2024"). It falls back to 2024
    when the header has no 4-digit year.

    For both metrics, rows with no country, rows matching the notes/aggregates
    pattern, and unusable values are dropped. Each dropped group is recorded
    via ``track_exclusion``.

    Returns a DataFrame with columns Country, Year, Value, Resource, Metric.
    Returns an empty DataFrame when the sheet has no year columns.
    """
    import re

    df = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=2,
                       keep_default_na=False, na_values=[''])

    first_col = df.columns[0]
    df = df.rename(columns={first_col: 'Country'})
    year_cols = [c for c in df.columns if isinstance(c, (int, float)) and 1990 < c < 2030]

    # Reserves column: header reads "At end of <year>".
    reserves_col = next((col for col in df.columns if 'at end of' in str(col).lower()), None)
    if not year_cols:
        print(f"WARNING: No year columns in {sheet_name}")
        return pd.DataFrame()

    # The same notes/aggregates filter is applied to production and reserves rows.
    regex_pattern = r'Source|Note|less than|Rest of World|Total World|^$|Mine|Thousand|n/a'

    # PRODUCTION DATA
    df_prod = df[['Country'] + year_cols].copy()

    excluded_na = df_prod[df_prod['Country'].isna()]
    track_exclusion(excluded_na, 'Country is NA', sheet_name, resource, 'Production')
    df_prod = df_prod[df_prod['Country'].notna()]

    mask = df_prod['Country'].astype(str).str.contains(regex_pattern, case=False, na=True, regex=True)
    track_exclusion(df_prod[mask], f'Regex filter: {regex_pattern}', sheet_name, resource, 'Production')
    df_prod = df_prod[~mask]

    # Long format: one row per (Country, Year).
    df_prod_long = df_prod.melt(id_vars='Country', var_name='Year', value_name='Value')
    df_prod_long['Year'] = pd.to_numeric(df_prod_long['Year'], errors='coerce')
    df_prod_long['Value'] = pd.to_numeric(df_prod_long['Value'], errors='coerce')
    df_prod_long['Resource'] = resource
    df_prod_long['Metric'] = 'Production'
    excluded_dropna = df_prod_long[df_prod_long[['Year', 'Value', 'Country']].isna().any(axis=1)]
    track_exclusion(excluded_dropna, 'dropna (Year/Value/Country NA after melt)', sheet_name, resource, 'Production')

    df_prod_long = df_prod_long.dropna(subset=['Year', 'Value', 'Country'])

    # RESERVES DATA
    df_reserves_long = pd.DataFrame()
    if reserves_col is not None:
        df_res = df[['Country', reserves_col]].copy()
        df_res.columns = ['Country', 'Value']

        excluded_na = df_res[df_res['Country'].isna()]
        track_exclusion(excluded_na, 'Country is NA', sheet_name, resource, 'Reserves')
        df_res = df_res[df_res['Country'].notna()]

        mask = df_res['Country'].astype(str).str.contains(regex_pattern, case=False, na=True, regex=True)
        track_exclusion(df_res[mask], f'Regex filter: {regex_pattern}', sheet_name, resource, 'Reserves')
        df_res = df_res[~mask].copy()

        df_res['Value'] = pd.to_numeric(df_res['Value'], errors='coerce')
        # NaN > 0 is False, so one test drops both non-numeric and non-positive values.
        keep = df_res['Value'] > 0
        track_exclusion(df_res[~keep], 'Value is NA or <= 0 after numeric conversion', sheet_name, resource, 'Reserves')
        df_res = df_res[keep].copy()

        # Use the snapshot year stated in the header instead of assuming it.
        year_match = re.search(r'\b(\d{4})\b', str(reserves_col))
        df_res['Year'] = int(year_match.group(1)) if year_match else 2024
        df_res['Resource'] = resource
        df_res['Metric'] = 'Reserves'

        df_reserves_long = df_res
        print(f"Extracted {len(df_reserves_long)} RESERVES rows for {resource}")
    else:
        print(f"No reserves column found for {resource}")

    return pd.concat([df_prod_long, df_reserves_long], ignore_index=True)


# Parses reserves time series (country rows x year columns).
def parse_reserves_history(sheet_name, resource):
    """Parse a "Proved reserves history" sheet into long-format Reserves rows.

    The first column is the country; numeric headers in (1800, 2030) are
    year columns. Returns an empty DataFrame (after printing a warning) when
    no year columns exist. Rows with no country, note/total rows, and melted
    rows whose value is not numeric are dropped and recorded via
    ``track_exclusion``.

    Returns a DataFrame with columns Country, Year, Value, Resource, Metric.
    """
    frame = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=4,
                          keep_default_na=False, na_values=[''])
    frame = frame.rename(columns={frame.columns[0]: 'Country'})

    years = [c for c in frame.columns
             if isinstance(c, (int, float)) and 1800 < c < 2030]
    if len(years) == 0:
        print(f" WARNING: No year columns in {sheet_name}")
        return pd.DataFrame()

    frame = frame[['Country'] + years]
    no_country = frame['Country'].isna()
    track_exclusion(frame[no_country], 'Country is NA', sheet_name, resource, 'Reserves')
    frame = frame[~no_country]

    pattern = r'Source|Note|Please|Less than|^$|methodology|Reserves-to|Total proved|^USSR|of which'
    is_note = frame['Country'].astype(str).str.contains(pattern, case=False, na=True, regex=True)
    track_exclusion(frame[is_note], f'Regex filter: {pattern}', sheet_name, resource, 'Reserves')
    frame = frame[~is_note]

    melted = frame.melt(id_vars='Country', var_name='Year', value_name='Value')
    melted['Year'] = pd.to_numeric(melted['Year'], errors='coerce')
    melted['Value'] = pd.to_numeric(melted['Value'], errors='coerce')
    melted['Resource'] = resource
    melted['Metric'] = 'Reserves'

    has_value = melted['Value'].notna()
    track_exclusion(melted[~has_value], 'Value is NA after numeric conversion', sheet_name, resource, 'Reserves')
    return melted[has_value]

#Parser of coal reserves
def parse_coal_reserves():
    """Parse the 'Coal - Reserves' sheet into one Reserves row per country.

    The sheet has a multi-row header, so it is read with ``header=None``.
    Data rows start at row index 7 (the first is Canada). Column 0 holds the
    country and column 3 the total reserves; they are selected by position,
    so extra or missing trailing columns do not break parsing.

    The reserves year is read from an "At end (of) <year>" phrase in the
    header rows, falling back to 2020. Invalid country cells, notes/totals,
    and non-numeric or non-positive values are dropped and recorded via
    ``track_exclusion``.

    Returns a DataFrame with columns Country, Value, Year, Resource, Metric.
    """
    import re

    sheet_name = 'Coal - Reserves'
    df_raw = pd.read_excel(file_path, sheet_name=sheet_name, header=None)

    # Header rows carry the snapshot date (e.g. "At end 2020").
    header_text = ' '.join(df_raw.iloc[:7].astype(str).to_numpy().ravel())
    year_match = re.search(r'at end (?:of )?(\d{4})', header_text, flags=re.IGNORECASE)
    reserves_year = int(year_match.group(1)) if year_match else 2020

    # Row 7 is the first data row (Canada); col 0 = country, col 3 = total reserves.
    df = df_raw.iloc[7:, [0, 3]].copy()
    df.columns = ['Country', 'Value']
    df['Country'] = df['Country'].astype(str).str.strip()

    # After astype(str), empty cells appear as the string 'nan'.
    invalid = df['Country'].isna() | df['Country'].isin(['nan', ''])
    track_exclusion(df[invalid], 'Country is NA', sheet_name, 'Coal', 'Reserves')
    df = df[~invalid]

    # Remove aggregates and notes.
    regex_pattern = r'Source|Note|Please|Less than|methodology|Total |^$|nan'
    mask = df['Country'].str.contains(regex_pattern, case=False, na=True, regex=True)
    track_exclusion(df[mask], f'Regex filter: {regex_pattern}', sheet_name, 'Coal', 'Reserves')
    df = df[~mask].copy()
    df['Value'] = pd.to_numeric(df['Value'], errors='coerce')

    value_na = df['Value'].isna()
    track_exclusion(df[value_na], 'Value is NA after numeric conversion', sheet_name, 'Coal', 'Reserves')
    df = df[~value_na]

    non_positive = df['Value'] <= 0
    track_exclusion(df[non_positive], 'Value <= 0', sheet_name, 'Coal', 'Reserves')
    df = df[~non_positive].copy()

    df['Year'] = reserves_year
    df['Resource'] = 'Coal'
    df['Metric'] = 'Reserves'

    return df
#OIL PRICES PARSER
def parse_oil_prices():
    """Parse the crude-oil price history into long-format 'Price' rows.

    Reads 'Oil crude prices since 1861' (skipping 3 header rows) and keeps
    the first and third columns as Year and Value. Rows where either is not
    numeric are recorded via ``track_exclusion`` and dropped. Remaining rows
    are tagged Country='Global', Resource='Oil', Metric='Price'.
    """
    sheet = 'Oil crude prices since 1861'
    prices = pd.read_excel(file_path, sheet_name=sheet, skiprows=3,
                           keep_default_na=False, na_values=[''])
    prices = prices.iloc[:, [0, 2]].copy()
    prices.columns = ['Year', 'Value']
    for column in ('Year', 'Value'):
        prices[column] = pd.to_numeric(prices[column], errors='coerce')

    incomplete = prices.isna().any(axis=1)
    track_exclusion(prices[incomplete], 'Year or Value is NA', sheet, 'Oil', 'Price')

    prices = prices.dropna()
    prices['Country'] = 'Global'
    prices['Resource'] = 'Oil'
    prices['Metric'] = 'Price'
    return prices

# RUN ALL PARSERS
# Every parser call is wrapped in try/except so one malformed sheet does not
# abort the run; successfully parsed frames are collected in all_data.
print("LOADING DATA - FIXED PARSER")


simple_sheets = [
    ('Oil Production - barrels', 'Oil', 'Production'),
    ('Oil Consumption - barrels', 'Oil', 'Consumption'),
    ('Gas Production - EJ', 'Natural Gas', 'Production'),
    ('Gas Consumption - EJ', 'Natural Gas', 'Consumption'),
    ('Coal Production - EJ', 'Coal', 'Production'),
    ('Coal Consumption - EJ', 'Coal', 'Consumption'),
]

for sheet, resource, metric in simple_sheets:
    try:
        df = parse_simple(sheet, resource, metric, skip_rows=2)
        all_data.append(df)
        print(f"✓ {sheet}: {len(df):,} rows")
    except Exception as e:
        print(f"✗ {sheet}: {e}")
for sheet, resource in [('Oil - Proved reserves history', 'Oil'),
                        ('Gas - Proved reserves history', 'Natural Gas')]:
    try:
        df = parse_reserves_history(sheet, resource)
        all_data.append(df)
        print(f"✓ {sheet}: {len(df):,} rows")
    except Exception as e:
        print(f"✗ {sheet}: {e}")

print("MINERALS - PRODUCTION AND RESERVES")
minerals = [
    ('Lithium P-R', 'Lithium'),
    ('Cobalt P-R', 'Cobalt'),
    ('Nickel P-R', 'Nickel'),
    ("Tin P-R", "Tin"),
    ("Bauxite P-R", "Bauxite"),
    ('Natural Graphite P-R', 'Natural Graphite'),
    ('Copper P-R', 'Copper'),
    ('Aluminium P-R', 'Aluminium'),
    ('Zinc P-R', 'Zinc'),
    ('Manganese P-R', 'Manganese'),
    ('Rare Earth metals P-R', 'Rare Earth'),
    ('Platinum Group Metals P-R', 'Platinum Group'),
    ('Vanadium P-R', 'Vanadium'),
]

for sheet, resource in minerals:
    try:
        df = parse_minerals(sheet, resource)
    except Exception as e:
        print(f"✗ {sheet}: {e}")
        continue
    if df.empty:
        # parse_minerals returns an empty frame (after printing a warning) when
        # the sheet has no year columns; there is no 'Metric' column to count.
        print(f"✗ {sheet}: no rows extracted")
        continue
    all_data.append(df)
    prod_count = int((df['Metric'] == 'Production').sum())
    res_count = int((df['Metric'] == 'Reserves').sum())
    print(f"✓ {sheet}: {prod_count:,} Production + {res_count:,} Reserves = {len(df):,} total")

#Coal Reserves
print("COAL RESERVES (FIXED PARSER)")

try:
    df = parse_coal_reserves()
    all_data.append(df)
    print(f"✓ Coal - Reserves: {len(df):,} rows")
except Exception as e:
    print(f"✗ Coal - Reserves: {e}")
    import traceback
    traceback.print_exc()

try:
    df = parse_oil_prices()
    all_data.append(df)
    print(f"✓ Oil Prices: {len(df):,} rows")
except Exception as e:
    print(f"✗ Oil Prices: {e}")

# COMBINE AND EXPORT
# pd.concat([]) fails with an opaque "No objects to concatenate"; say why instead.
if not all_data:
    raise RuntimeError("No sheets were parsed successfully - nothing to combine or export")
df_combined = pd.concat(all_data, ignore_index=True)

print("\nFINAL SUMMARY")
print(f"\nTotal rows: {len(df_combined):,}")
print("\nBy Metric:")
print(df_combined['Metric'].value_counts())
print("\nReserves by Resource:")
reserves_df = df_combined[df_combined['Metric'] == 'Reserves']
print(reserves_df['Resource'].value_counts())

# Verify Coal reserves extracted correctly
print("\n✓ Coal Reserves sample:")
coal_res = df_combined[(df_combined['Resource'] == 'Coal') & (df_combined['Metric'] == 'Reserves')]
print(coal_res.head(10))

output_path = "natural_resources_combined.csv"
df_combined.to_csv(output_path, index=False)
print(f"\n✓ Combined data exported to: {output_path}")


LOADING DATA - FIXED PARSER
✓ Oil Production - barrels: 4,140 rows
✓ Oil Consumption - barrels: 6,120 rows
✓ Gas Production - EJ: 3,630 rows
✓ Gas Consumption - EJ: 6,120 rows
✓ Coal Production - EJ: 2,200 rows
✓ Coal Consumption - EJ: 6,120 rows
✓ Oil - Proved reserves history: 2,720 rows
✓ Gas - Proved reserves history: 2,634 rows
MINERALS - PRODUCTION AND RESERVES
Extracted 8 RESERVES rows for Lithium
✓ Lithium P-R: 240 Production + 8 Reserves = 248 total
Extracted 13 RESERVES rows for Cobalt
✓ Cobalt P-R: 390 Production + 13 Reserves = 403 total
Extracted 7 RESERVES rows for Nickel
✓ Nickel P-R: 84 Production + 7 Reserves = 91 total
Extracted 6 RESERVES rows for Tin
✓ Tin P-R: 77 Production + 6 Reserves = 83 total
Extracted 7 RESERVES rows for Bauxite
✓ Bauxite P-R: 76 Production + 7 Reserves = 83 total
Extracted 12 RESERVES rows for Natural Graphite
✓ Natural Graphite P-R: 360 Production + 12 Reserves = 372 total
Extracted 7 RESERVES rows for Copper
✓

# Clustering

In [76]:

from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
import warnings
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# CONFIGURATION
@dataclass
class Config:
    """Settings for the resource-dependency clustering run.

    Attributes (constructor order):
        resources_path: path to the combined resources CSV; must exist.
        gdp_path: path to the GDP CSV; must exist.
        output_dir: output directory; created (with parents) when non-empty.
        n_clusters: number of K-means clusters, validated to lie in [2, 10].
        random_state: seed handed to the clustering model.
        dominance_threshold: reserve share (%) used to flag dominant countries.
        min_share_display: shares must exceed this % to appear in hover text.

    Raises:
        FileNotFoundError: if either input path does not exist.
        ValueError: if n_clusters is outside [2, 10].
    """

    resources_path: str
    gdp_path: str
    output_dir: str = ""
    n_clusters: int = 5
    random_state: int = 42
    dominance_threshold: float = 15.0
    min_share_display: float = 1.0  # Min % share to show in hover

    def __post_init__(self):
        required_inputs = (
            ("Resources", self.resources_path),
            ("GDP", self.gdp_path),
        )
        for label, path in required_inputs:
            if not Path(path).exists():
                raise FileNotFoundError(f"{label} file not found: {path}")
        # The output directory is created before n_clusters is validated.
        if self.output_dir:
            Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        if not (2 <= self.n_clusters <= 10):
            raise ValueError("n_clusters must be between 2 and 10")

#COUNTRY CODE MAPPINGS

# Maps country names as spelled in the source sheets (including footnoted
# variants such as "Brazil1"/"India2" and trailing-space variants) to ISO 3166-1
# alpha-3 codes, which are used to join the GDP data on "Country Code".
COUNTRY_CODE_MAP: dict[str, str] = {
    # North America
    "US": "USA", "United States": "USA", "USA": "USA",
    "Canada": "CAN", "Mexico": "MEX",

    # Central & South America
    "Brazil": "BRA", "Brazil1": "BRA", "Argentina": "ARG", "Chile": "CHL",
    "Peru": "PER", "Colombia": "COL", "Venezuela": "VEN", "Ecuador": "ECU",
    "Bolivia": "BOL", "Paraguay": "PRY", "Uruguay": "URY", "Guyana": "GUY",
    "Suriname": "SUR", "Trinidad & Tobago": "TTO", "Trinidad and Tobago": "TTO",
    "Cuba": "CUB", "Dominican Republic": "DOM", "Jamaica": "JAM",
    "Haiti": "HTI", "Panama": "PAN", "Costa Rica": "CRI", "Guatemala": "GTM",
    "Honduras": "HND", "El Salvador": "SLV", "Nicaragua": "NIC",

    # Western Europe
    "UK": "GBR", "United Kingdom": "GBR", "Germany": "DEU", "France": "FRA",
    "Italy": "ITA", "Spain": "ESP", "Netherlands": "NLD", "Belgium": "BEL",
    "Norway": "NOR", "Sweden": "SWE", "Finland": "FIN", "Denmark": "DNK",
    "Austria": "AUT", "Switzerland": "CHE", "Ireland": "IRL", "Portugal": "PRT",
    "Greece": "GRC", "Luxembourg": "LUX", "Iceland": "ISL", "Cyprus": "CYP",

    # Eastern Europe
    "Poland": "POL", "Czech Republic": "CZE", "Czechia": "CZE",
    "Romania": "ROU", "Hungary": "HUN", "Slovakia": "SVK", "Slovak Republic": "SVK",
    "Bulgaria": "BGR", "Croatia": "HRV", "Slovenia": "SVN", "Serbia": "SRB",
    "Ukraine": "UKR", "Belarus": "BLR", "Lithuania": "LTU", "Latvia": "LVA",
    "Estonia": "EST", "North Macedonia": "MKD", "Bosnia and Herzegovina": "BIH",

    # Russia & Central Asia
    "Russia": "RUS", "Russian Federation": "RUS", "Kazakhstan": "KAZ",
    "Uzbekistan": "UZB", "Turkmenistan": "TKM", "Tajikistan": "TJK",
    "Kyrgyzstan": "KGZ", "Kyrgyz Republic": "KGZ", "Azerbaijan": "AZE",
    "Georgia": "GEO", "Armenia": "ARM",

    # East Asia
    "China": "CHN", "China ": "CHN", "China Hong Kong SAR": "HKG",
    "Hong Kong": "HKG", "Japan": "JPN", "South Korea": "KOR", "Korea": "KOR",
    "Republic of Korea": "KOR", "North Korea": "PRK", "Taiwan": "TWN",
    "Mongolia": "MNG",

    # South & Southeast Asia
    "India": "IND", "India2": "IND", "Indonesia": "IDN", "Philippines": "PHL",
    "Vietnam": "VNM", "Viet Nam": "VNM", "Thailand": "THA", "Malaysia": "MYS",
    "Singapore": "SGP", "Myanmar": "MMR", "Burma": "MMR", "Cambodia": "KHM",
    "Bangladesh": "BGD", "Pakistan": "PAK", "Sri Lanka": "LKA", "Nepal": "NPL",
    "Afghanistan": "AFG", "Brunei": "BRN", "Brunei Darussalam": "BRN", "Laos": "LAO",

    # Middle East
    "Iran": "IRN", "Iraq": "IRQ", "Saudi Arabia": "SAU", "UAE": "ARE",
    "United Arab Emirates": "ARE", "Kuwait": "KWT", "Qatar": "QAT", "Oman": "OMN",
    "Yemen": "YEM", "Jordan": "JOR", "Lebanon": "LBN", "Syria": "SYR",
    # "Türkiye" was previously stored mis-encoded ("TÃ¼rkiye") and could never match.
    "Israel": "ISR", "Turkey": "TUR", "Türkiye": "TUR", "Bahrain": "BHR",

    # Africa
    "South Africa": "ZAF", "Nigeria": "NGA", "Egypt": "EGY", "Algeria": "DZA",
    "Morocco": "MAR", "Tunisia": "TUN", "Libya": "LBY", "Sudan": "SDN",
    "South Sudan": "SSD", "Ethiopia": "ETH", "Kenya": "KEN", "Tanzania": "TZA",
    "Uganda": "UGA", "Ghana": "GHA", "Ivory Coast": "CIV", "Cote d'Ivoire": "CIV",
    "Senegal": "SEN", "Cameroon": "CMR", "Angola": "AGO", "Mozambique": "MOZ",
    "Zambia": "ZMB", "Zimbabwe": "ZWE", "Botswana": "BWA", "Namibia": "NAM",
    "Gabon": "GAB", "Congo": "COG", "Republic of Congo": "COG",
    "DR Congo": "COD", "Democratic Republic of Congo": "COD", "DRC": "COD",
    "Congo Dem. Rep.": "COD", "Mali": "MLI", "Burkina Faso": "BFA", "Niger": "NER",
    "Chad": "TCD", "Mauritania": "MRT", "Madagascar": "MDG", "Malawi": "MWI",
    "Rwanda": "RWA", "Equatorial Guinea": "GNQ", "Benin": "BEN", "Togo": "TGO",
    "Guinea": "GIN", "Liberia": "LBR",

    # Oceania
    "Australia": "AUS", "Australia ": "AUS", "New Zealand": "NZL",
    "Papua New Guinea": "PNG", "Fiji": "FJI", "New Caledonia": "NCL",
}

# Exclude aggregates and notes.
# Regex matched against the Country column (case-insensitively, with
# str.contains in DataLoader.load_resources) so that only individual countries
# remain. It drops regional/economic groupings, subtotal rows, the Country="Global"
# rows produced by the oil-price parser, and category rows such as
# "Orinoco Belt" / "Oil Sands".
# NOTE(review): several alternatives are unanchored ("Non-", "World", "Global",
# "Rest of") and "^Central " also matches country names beginning with
# "Central " (e.g. "Central African Republic") — confirm that is intended.
AGGREGATES_PATTERN = (
    r"^Total|^Other |OPEC|OECD|European Union|of which|^CIS$|Non-|World|Rest of|"
    r"^Middle East$|^Asia Pacific$|^Africa$|^Europe$|^Americas$|^America$|Global|"
    r"Orinoco Belt|Oil Sands|^Central |Eastern Africa|Western Africa|Middle Africa"
    )

#DATA LOADER
class DataLoader:
    """Handles loading and initial cleaning of resource and GDP data."""

    def __init__(self, config: Config):
        self.config = config

    def load_resources(self) -> pd.DataFrame:
        """Load and clean resource production/reserves data.

        Reads ``config.resources_path`` and strips surrounding whitespace from
        ``Country``. Known name variants are rewritten (e.g. footnoted
        "Brazil1" -> "Brazil", "Russian Federation" -> "Russia"). Rows whose
        country matches ``AGGREGATES_PATTERN`` (case-insensitive) are dropped.
        """
        df = (
            pd.read_csv(self.config.resources_path)
            .assign(Country=lambda x: x["Country"].str.strip().replace(
                {"Russian Federation": "Russia", "Brazil1": "Brazil",
                 "India2": "India", "Democratic Republic of Congo": "DR Congo",
                 "China ": "China", "Australia ": "Australia"}
            ))
            .query("~Country.str.contains(@AGGREGATES_PATTERN, case=False, na=False)")
        )
        logger.info(f"Loaded {len(df):,} resource records for {df['Country'].nunique()} countries")
        return df

    def load_gdp(self) -> pd.DataFrame:
        """Return GDP (current US$) per country for the latest year that has GDP data.

        Expects a long-format file with ``Year``, ``Country Code``,
        ``Country Name``, ``Series Name`` and ``Value`` columns. ``Value`` is
        coerced to numeric, so non-numeric placeholders become NaN instead of
        breaking the ``GDP > 0`` comparison. The year used is the most recent
        one with at least one positive GDP value. This is not simply the most
        recent year in the file, which may be present but still empty.

        Returns:
            DataFrame with ``Country Code``, ``Country Name`` and ``GDP``
            (positive, non-null values only).

        Raises:
            ValueError: if the GDP series is missing or has no positive values.
        """
        gdp_series = "GDP (current US$)"
        df_raw = pd.read_csv(self.config.gdp_path)
        # NOTE(review): World Bank DataBank exports typically use ".." for
        # missing values; presumably the case here — coercion handles it either way.
        df_raw["Value"] = pd.to_numeric(df_raw["Value"], errors="coerce")

        is_gdp = df_raw["Series Name"] == gdp_series
        if not is_gdp.any():
            raise ValueError("GDP column not found in dataset")
        valid_gdp = df_raw[is_gdp & (df_raw["Value"] > 0)]
        if valid_gdp.empty:
            raise ValueError(f"No positive '{gdp_series}' values found in dataset")
        latest_year = valid_gdp["Year"].max()

        df = (
            df_raw[df_raw["Year"] == latest_year]
            .pivot_table(
                index=["Country Code", "Country Name"],
                columns="Series Name",
                values="Value",
                aggfunc="first"
            )
            .reset_index()
        )

        df = (
            df[["Country Code", "Country Name", gdp_series]]
            .rename(columns={gdp_series: "GDP"})
            .query("GDP.notna() & GDP > 0")
        )

        logger.info(f"Loaded GDP for {len(df)} countries (year {latest_year})")
        return df
    
#DATA PROCESSOR
class DataProcessor:
    """Reshapes resource data, joins GDP, and derives share, dominance and clustering features."""

    def __init__(self, config: Config):
        self.config = config
        # Energy resources; every other *_Prod column is aggregated as a mineral.
        self.energy_resources = ["Oil", "Natural Gas", "Coal"]

    def pivot_by_metric(self, df: pd.DataFrame, metric: str, suffix: str) -> pd.DataFrame:
        """Pivot one metric (e.g. Production/Reserves) to one row per country.

        For each (Country, Resource) pair the most recent value is kept. Rows
        are sorted by Year descending and ``groupby(...).first()`` takes the
        first non-null value. Resource columns are renamed to
        ``f"{resource}_{suffix}"``; ``Country`` stays a regular column.
        """
        df_metric = (
            df[df["Metric"] == metric]
            .sort_values("Year", ascending=False)
            .groupby(["Country", "Resource"])
            .first()
            .reset_index()
            .pivot_table(index="Country", columns="Resource", values="Value", aggfunc="first")
            .reset_index()
        )

        cols = {c: f"{c}_{suffix}" for c in df_metric.columns if c != "Country"}
        return df_metric.rename(columns=cols)

    def merge_with_gdp(self, df: pd.DataFrame, df_gdp: pd.DataFrame) -> pd.DataFrame:
        """Attach ISO-3 codes via ``COUNTRY_CODE_MAP`` and left-join GDP on them.

        Rows whose country has no code, or that end up without positive GDP,
        are dropped. If several source names map to the same code, a warning
        is logged. Each name stays a separate row, so that country would be
        counted more than once in later share calculations.
        """
        df = df.copy()
        df["Country Code"] = df["Country"].map(COUNTRY_CODE_MAP)
        mapped = df["Country Code"].notna().sum()
        unmapped = df[df["Country Code"].isna()]["Country"].unique()
        logger.info(f"Country mapping: {mapped} mapped, {len(unmapped)} unmapped")
        if len(unmapped) > 0:
            logger.info(f"  Unmapped: {list(unmapped)[:5]}...")

        # Different spellings of one country would otherwise silently become duplicate rows.
        names_per_code = (
            df.dropna(subset=["Country Code"])
            .groupby("Country Code")["Country"]
            .nunique()
        )
        shared_codes = names_per_code[names_per_code > 1].index.tolist()
        if shared_codes:
            logger.warning(f"Several country names share a country code (duplicate rows): {shared_codes}")

        df = df.merge(df_gdp[["Country Code", "GDP"]], on="Country Code", how="left")
        n_before = len(df)
        df = df.query("GDP.notna() & GDP > 0")
        logger.info(f"Countries with GDP: {len(df)} (dropped {n_before - len(df)})")

        return df

    def calculate_global_shares(self, df: pd.DataFrame) -> tuple[pd.DataFrame, list[str], list[str]]:
        """Add each country's percentage share of every ``*_Prod`` and ``*_Res`` column.

        Shares are relative to the column total over the rows present in
        ``df`` (the countries that survived earlier filtering). They are not
        relative to an independent world total. Columns whose total is not
        positive get no share column.

        Returns:
            (df with ``<col>_Share`` columns added, production share column
            names, reserve share column names)
        """
        df = df.copy()
        prod_cols = [c for c in df.columns if c.endswith("_Prod")]
        res_cols = [c for c in df.columns if c.endswith("_Res")]

        prod_share_cols, res_share_cols = [], []
        for cols, share_list in ((prod_cols, prod_share_cols), (res_cols, res_share_cols)):
            for col in cols:
                total = df[col].sum()
                # Skip empty/all-zero columns to avoid dividing by zero.
                if total > 0:
                    share_col = f"{col}_Share"
                    df[share_col] = (df[col] / total) * 100
                    share_list.append(share_col)

        logger.info(f"Calculated shares: {len(prod_share_cols)} production, {len(res_share_cols)} reserves")
        return df, prod_share_cols, res_share_cols

    def flag_dominant_countries(
        self, df: pd.DataFrame, res_share_cols: list[str]
    ) -> pd.DataFrame:
        """Flag countries holding at least ``dominance_threshold`` % of any reserve.

        Adds per-row columns:
            Dominant_Reserves: list of "<resource>: <share>%" strings (share >= threshold)
            Dominant_Resources: list of the corresponding resource names
            Max_Concentration: largest non-null reserve share (0 if none)
            Dominant_Count: number of dominant resources
            Is_Dominant: True when Dominant_Count > 0
        """
        df = df.copy().reset_index(drop=True)
        threshold = self.config.dominance_threshold

        dominant_reserves = []
        dominant_resources = []
        max_concentrations = []
        dominant_counts = []
        is_dominant = []

        for _, row in df.iterrows():
            dominated = []
            resources = []
            max_share = 0

            for col in res_share_cols:
                share = row.get(col, 0)
                if pd.notna(share):
                    max_share = max(max_share, share)
                    if share >= threshold:
                        resource = col.replace("_Res_Share", "")
                        dominated.append(f"{resource}: {share:.1f}%")
                        resources.append(resource)

            dominant_reserves.append(dominated)
            dominant_resources.append(resources)
            max_concentrations.append(max_share)
            dominant_counts.append(len(dominated))
            is_dominant.append(len(dominated) > 0)

        # Assign lists directly (index was reset above) to avoid concat alignment issues.
        df["Dominant_Reserves"] = dominant_reserves
        df["Dominant_Resources"] = dominant_resources
        df["Max_Concentration"] = max_concentrations
        df["Dominant_Count"] = dominant_counts
        df["Is_Dominant"] = is_dominant

        n_dominant = sum(is_dominant)
        logger.info(f"High concentration (>{threshold}% reserves): {n_dominant} countries")
        for _, row in df[df["Is_Dominant"]].iterrows():
            logger.info(f"  🔴 {row['Country']}: {', '.join(row['Dominant_Reserves'][:3])}")

        return df

    def create_clustering_features(self, df: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]:
        """Create production-per-GDP features for clustering.

        Uses Oil/Natural Gas/Coal production columns when present, plus a
        ``Minerals_Prod`` column that sums every other ``*_Prod`` column.
        Missing or zero production is overwritten with 0.001 in those
        production columns, so every country gets a small positive, non-NaN
        feature. Each feature is ``production / (GDP / 1e9)``.

        Returns:
            (df with ``<col>_perGDP`` columns added, list of feature names)
        """
        df = df.copy()

        # Base production columns for clustering.
        base_prod = [c for c in ["Oil_Prod", "Natural Gas_Prod", "Coal_Prod"] if c in df.columns]

        # Aggregate mineral production.
        # NOTE(review): mineral columns may be in different units; summing them mixes scales — confirm.
        mineral_cols = [
            c for c in df.columns
            if c.endswith("_Prod") and not any(e in c for e in self.energy_resources)
        ]
        if mineral_cols:
            df["Minerals_Prod"] = df[mineral_cols].sum(axis=1, skipna=True)
            base_prod.append("Minerals_Prod")

        # Give missing/zero production a small positive value so features are never NaN or 0.
        for col in base_prod:
            df[col] = df[col].fillna(0.001).replace(0, 0.001)

        # Production per billion of GDP.
        clustering_features = []
        for col in base_prod:
            feat = f"{col}_perGDP"
            df[feat] = df[col] / (df["GDP"] / 1e9)
            clustering_features.append(feat)

        logger.info(f"Clustering features: {clustering_features}")
        return df, clustering_features

# CLUSTERING ENGINE
class ClusteringEngine:
    """Runs K-means on log-scaled, standardised features and projects them with PCA.

    After ``fit`` the fitted ``scaler``, ``kmeans`` and ``pca`` objects are
    available as attributes. ``loadings`` is also set: a features x (PC1, PC2)
    DataFrame.
    """

    def __init__(self, config: Config):
        self.config = config
        self.scaler = StandardScaler()
        self.kmeans = None
        self.pca = None
        self.loadings = None

    def fit(self, df: pd.DataFrame, features: list[str]) -> pd.DataFrame:
        """Cluster the rows of ``df`` on ``features`` and return a labelled copy.

        Rows with no non-null feature are dropped and remaining NaNs become 0.
        The features are passed through ``np.log1p`` and standardised before
        K-means (``n_init=50``) and a 2-component PCA are fitted.

        Returns:
            The filtered copy with ``n_valid``, ``cluster``, ``PC1`` and
            ``PC2`` columns added.
        """
        out = df.copy()
        out["n_valid"] = out[features].notna().sum(axis=1)
        out = out.query("n_valid >= 1")
        out[features] = out[features].fillna(0)
        scaled = self.scaler.fit_transform(np.log1p(out[features].values))

        k = self.config.n_clusters
        self.kmeans = KMeans(n_clusters=k, random_state=self.config.random_state, n_init=50)
        out["cluster"] = self.kmeans.fit_predict(scaled)
        logger.info(f"Clustering complete: k={k}")
        logger.info(f"Distribution: {out['cluster'].value_counts().sort_index().to_dict()}")

        # 2-D projection for plotting, plus per-feature loadings.
        self.pca = PCA(n_components=2)
        projected = self.pca.fit_transform(scaled)
        out["PC1"] = projected[:, 0]
        out["PC2"] = projected[:, 1]
        self.loadings = pd.DataFrame(self.pca.components_.T, columns=["PC1", "PC2"], index=features)

        return out

# VISUALIZER
class Visualizer:
    """Create the Plotly figures for the clustering analysis.

    Cluster ``c`` is drawn with ``self.colors[c % len(self.colors)]`` in both
    the choropleth map and the PCA biplot, so a cluster keeps the same color
    across figures.
    """

    def __init__(self, config: Config):
        self.config = config
        self.colors = px.colors.qualitative.Bold

    def _cluster_color(self, cid) -> str:
        """Return the palette color for cluster ``cid``, wrapping if k exceeds the palette."""
        return self.colors[int(cid) % len(self.colors)]

    def create_hover_text(self, row: pd.Series, features: list[str], share_cols: list[str]) -> str:
        """Build the HTML hover label for one country on the choropleth map.

        Args:
            row: One row of the clustered table. ``Country``, ``cluster`` and
                ``GDP`` are required. ``Is_Dominant``, ``Dominant_Reserves`` and
                the feature/share columns are read when present.
            features: Production/GDP feature columns. Positive values are listed.
            share_cols: Global-share columns. Names containing ``_Prod_Share``
                or ``_Res_Share`` above ``min_share_display`` are listed.

        Returns:
            The label lines joined with ``<br>``.
        """
        threshold = self.config.dominance_threshold
        min_share = self.config.min_share_display

        lines = [f"<b>{row['Country']}</b>"]
        # Guard against NaN: a missing flag must not count as dominant
        # (bool(nan) is True).
        is_dominant = row.get("Is_Dominant", False)
        if pd.notna(is_dominant) and is_dominant:
            lines.append(f"🔴 HIGH RESERVES CONCENTRATION (>{threshold}%)")
            for dom in row.get("Dominant_Reserves", []):
                lines.append(f"   • {dom}")

        lines.extend([
            f"Cluster: {int(row['cluster'])}",
            f"GDP: ${row['GDP']/1e9:,.0f}B",
            "",
            "<b>Production / GDP:</b>",
        ])

        for f in features:
            val = row.get(f, 0)
            if pd.notna(val) and val > 0:
                name = f.replace("_Prod_perGDP", "")
                lines.append(f"  {name}: {val:.2f}")

        lines.extend(["", "<b>Global Shares:</b>"])

        # Production shares (context only)
        for col in sorted(c for c in share_cols if "_Prod_Share" in c):
            val = row.get(col, 0)
            if pd.notna(val) and val > min_share:
                resource = col.replace("_Prod_Share", "")
                lines.append(f"  📈 {resource} Prod: {val:.1f}%")

        # Reserve shares (these determine dominance)
        for col in sorted(c for c in share_cols if "_Res_Share" in c):
            val = row.get(col, 0)
            if pd.notna(val) and val > min_share:
                resource = col.replace("_Res_Share", "")
                if val >= threshold:
                    lines.append(f"🔴 <b>🪨 {resource} Res: {val:.1f}%</b>")
                else:
                    lines.append(f"  🪨 {resource} Res: {val:.1f}%")

        return "<br>".join(lines)

    def choropleth_map(
        self,
        df: pd.DataFrame,
        features: list[str],
        share_cols: list[str]
    ) -> go.Figure:
        """Create a choropleth map colored by cluster, with dominant countries outlined in red.

        Only rows with a ``Country Code`` are drawn. The title's ``k`` is the
        number of distinct clusters in ``df``.
        """
        df_map = df[df["Country Code"].notna()].copy()
        threshold = self.config.dominance_threshold
        n_clusters = df["cluster"].nunique()
        # Iterate over the ids actually present. range(nunique()) would skip a
        # cluster whenever the ids are not contiguous among mappable rows
        # (e.g. {0, 1, 3} -> range(3) drops cluster 3).
        cluster_ids = sorted(int(c) for c in df_map["cluster"].unique())
        # .eq(True) treats a missing flag as non-dominant instead of failing on ~NaN.
        is_dominant = df_map["Is_Dominant"].eq(True)

        # Generate hover text
        df_map["hover_text"] = df_map.apply(
            lambda row: self.create_hover_text(row, features, share_cols), axis=1
        )

        fig = go.Figure()

        # Add non-dominant countries first (thin white border), then dominant
        # ones (thick red border), so the red outlines are drawn on top.
        for dominant in (False, True):
            for cid in cluster_ids:
                subset = df_map[(df_map["cluster"] == cid) & (is_dominant == dominant)]
                if subset.empty:
                    continue
                color = self._cluster_color(cid)
                if dominant:
                    name = f"Cluster {cid} 🔴 >{threshold}% Reserves"
                    border = dict(color="red", width=4)
                else:
                    name = f"Cluster {cid}"
                    border = dict(color="white", width=0.5)
                fig.add_trace(go.Choropleth(
                    locations=subset["Country Code"],
                    z=[cid] * len(subset),
                    colorscale=[[0, color], [1, color]],
                    showscale=False,
                    customdata=subset["hover_text"].values,
                    hovertemplate="%{customdata}<extra></extra>",
                    name=name,
                    marker=dict(line=border),
                ))

        fig.update_geos(
            projection_type="winkel tripel",
            showcountries=True, countrycolor="lightgray",
            showcoastlines=True, coastlinecolor="darkgray",
            showocean=True, oceancolor="aliceblue",
            showland=True, landcolor="whitesmoke"
        )

        fig.update_layout(
            title=dict(
                text=(f"Resource Production Clustering (k={n_clusters})<br>"
                      f"<sup>🔴 Red Border = >{threshold}% of global RESERVES</sup>"),
                x=0.5, font=dict(size=18)
            ),
            width=1200, height=700,
            legend=dict(orientation="h", yanchor="bottom", y=0.01, xanchor="center", x=0.5),
            margin=dict(l=0, r=0, t=80, b=0)
        )

        fig.add_annotation(
            x=0.02, y=0.02, xref="paper", yref="paper",
            text=f"🔴 RED BORDER = >{threshold}% of global RESERVES",
            showarrow=False, font=dict(size=12, color="red"),
            bgcolor="white", bordercolor="red", borderwidth=2, borderpad=4
        )

        return fig

    def pca_biplot(
        self,
        df: pd.DataFrame,
        loadings: pd.DataFrame,
        pca: PCA
    ) -> go.Figure:
        """Create a PCA biplot: countries colored by cluster plus feature loading arrows.

        Args:
            df: Clustered table with ``PC1``, ``PC2``, ``cluster``,
                ``Is_Dominant`` and ``Country`` columns.
            loadings: Features x ``["PC1", "PC2"]`` loading matrix.
            pca: Fitted PCA; its explained-variance ratios label the axes.
        """
        df_plot = df.copy()
        df_plot["cluster"] = df_plot["cluster"].astype(str)
        threshold = self.config.dominance_threshold

        # Pin each cluster to the palette entry used by choropleth_map.
        # color_discrete_sequence alone assigns colors by order of first
        # appearance, which need not match cluster ids.
        cluster_ids = sorted(df["cluster"].unique())
        color_map = {str(cid): self._cluster_color(cid) for cid in cluster_ids}

        fig = px.scatter(
            df_plot, x="PC1", y="PC2", color="cluster",
            symbol="Is_Dominant",
            symbol_map={True: "star", False: "circle"},
            hover_name="Country",
            color_discrete_map=color_map,
            category_orders={"cluster": [str(cid) for cid in cluster_ids]},
        )

        # Style markers
        fig.update_traces(marker=dict(size=10), selector=dict(marker_symbol="circle"))
        fig.update_traces(
            marker=dict(size=18, line=dict(width=2, color="darkred")),
            selector=dict(marker_symbol="star")
        )

        # Add loading vectors from the origin. axref/ayref must be the data
        # axes: with the default pixel reference, ax=0/ay=0 gives a
        # zero-length arrow. Plotly places annotation text at the arrow tail,
        # so the label is added separately at the arrow head.
        scale = 3  # stretch unit-length loadings to the scale of the scores
        for feat in loadings.index:
            head_x = loadings.loc[feat, "PC1"] * scale
            head_y = loadings.loc[feat, "PC2"] * scale
            fig.add_annotation(
                x=head_x, y=head_y, xref="x", yref="y",
                ax=0, ay=0, axref="x", ayref="y",
                showarrow=True, arrowhead=2, arrowwidth=2,
                text="",
            )
            fig.add_annotation(
                x=head_x, y=head_y, xref="x", yref="y",
                showarrow=False,
                text=feat.replace("_Prod_perGDP", "/GDP"),
                font=dict(size=10),
                yshift=10,
            )

        fig.update_layout(
            title=f"PCA Biplot: Production/GDP Clusters (★ = >{threshold}% Global Reserves)",
            xaxis_title=f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)",
            yaxis_title=f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)",
            height=700, width=900
        )

        return fig

    def cluster_heatmap(self, df: pd.DataFrame, features: list[str]) -> go.Figure:
        """Create a heatmap of cluster-mean features, z-scored across clusters."""
        centroids = df.groupby("cluster")[features].mean()
        # Standardize each feature across the cluster means so clusters are
        # compared relative to one another.
        centroid_z = pd.DataFrame(
            StandardScaler().fit_transform(centroids),
            index=[f"Cluster {i}" for i in centroids.index],
            columns=[f.replace("_Prod_perGDP", "/GDP") for f in features]
        )

        fig = go.Figure(data=go.Heatmap(
            z=centroid_z.values,
            x=centroid_z.columns,
            y=centroid_z.index,
            colorscale="RdBu_r",
            zmid=0,
            text=np.round(centroid_z.values, 2),
            texttemplate="%{text}",
            textfont=dict(size=12),
            hovertemplate="Cluster: %{y}<br>Feature: %{x}<br>Z-score: %{z:.2f}<extra></extra>"
        ))

        fig.update_layout(
            title="Cluster Profiles by Production/GDP (Z-Scores)",
            xaxis_title="Feature",
            yaxis_title="Cluster",
            height=400, width=800
        )

        return fig

# MAIN ANALYSIS 
class ResourceClusteringAnalysis:
    """Main analysis pipeline coordinating all components.

    Call :meth:`run` first. :meth:`visualize`, :meth:`show_figures` and
    :meth:`export_results` read the results that ``run`` stores on the
    instance, and raise ``RuntimeError`` if it has not been called.
    """

    def __init__(self, config: Config):
        self.config = config
        self.loader = DataLoader(config)
        self.processor = DataProcessor(config)
        self.clusterer = ClusteringEngine(config)
        self.visualizer = Visualizer(config)

        # Results (populated by run())
        self.df_final: pd.DataFrame = None
        self.clustering_features: list[str] = None
        self.prod_share_cols: list[str] = None
        self.res_share_cols: list[str] = None

    def _require_results(self) -> None:
        """Raise RuntimeError if run() has not produced results yet."""
        if self.df_final is None:
            raise RuntimeError("Run analysis first with .run()")

    def run(self) -> pd.DataFrame:
        """Execute the full pipeline and return the clustered table.

        The steps are: load, pivot/merge, global shares and dominance flags,
        feature creation, clustering. Intermediate column lists and the final
        table are stored on the instance.
        """
        logger.info("=" * 60)
        logger.info("NATURAL RESOURCES CLUSTERING ANALYSIS")
        logger.info("=" * 60)

        # Load data
        logger.info("\n[1] LOADING DATA")
        df_resources = self.loader.load_resources()
        df_gdp = self.loader.load_gdp()

        # Process data
        logger.info("\n[2] PROCESSING DATA")
        df_prod = self.processor.pivot_by_metric(df_resources, "Production", "Prod")
        df_res = self.processor.pivot_by_metric(df_resources, "Reserves", "Res")

        df = df_prod.merge(df_res, on="Country", how="outer")
        df = self.processor.merge_with_gdp(df, df_gdp)

        # Calculate shares and flag dominance
        logger.info("\n[3] CALCULATING GLOBAL SHARES")
        df, self.prod_share_cols, self.res_share_cols = self.processor.calculate_global_shares(df)
        df = self.processor.flag_dominant_countries(df, self.res_share_cols)

        # Create clustering features
        logger.info("\n[4] CREATING CLUSTERING FEATURES")
        df, self.clustering_features = self.processor.create_clustering_features(df)

        # Perform clustering
        logger.info("\n[5] CLUSTERING")
        self.df_final = self.clusterer.fit(df, self.clustering_features)

        self._print_cluster_summary()

        return self.df_final

    def _print_cluster_summary(self):
        """Log members, high-reserve countries and mean features for each cluster."""
        logger.info("\n" + "=" * 60)
        logger.info("CLUSTER SUMMARY")
        logger.info("=" * 60)

        df = self.df_final
        # .eq(True) treats a missing flag as non-dominant.
        is_dominant = df["Is_Dominant"].eq(True)
        # Iterate the ids actually present rather than range(nunique()),
        # which assumes contiguous ids starting at 0.
        for c in sorted(df["cluster"].unique()):
            mask = df["cluster"] == c
            countries = df.loc[mask, "Country"].tolist()
            dominant = df.loc[mask & is_dominant, "Country"].tolist()

            logger.info(f"\nCluster {c} ({len(countries)} countries):")
            logger.info(f"  {', '.join(sorted(countries)[:10])}{'...' if len(countries) > 10 else ''}")

            if dominant:
                logger.info(f"  🔴 High reserves: {', '.join(dominant)}")

            means = df.loc[mask, self.clustering_features].mean()
            prod_str = ", ".join(
                f"{f.replace('_Prod_perGDP', '')}: {means[f]:.2f}"
                for f in self.clustering_features if means[f] > 0.01
            )
            logger.info(f"  Avg prod/GDP: {prod_str}")

    def visualize(self) -> dict[str, go.Figure]:
        """Build all figures, keyed ``"choropleth"``, ``"biplot"`` and ``"heatmap"``."""
        self._require_results()

        all_share_cols = self.prod_share_cols + self.res_share_cols

        figures = {
            "choropleth": self.visualizer.choropleth_map(
                self.df_final, self.clustering_features, all_share_cols
            ),
            "biplot": self.visualizer.pca_biplot(
                self.df_final, self.clusterer.loadings, self.clusterer.pca
            ),
            "heatmap": self.visualizer.cluster_heatmap(
                self.df_final, self.clustering_features
            ),
        }

        return figures

    def show_figures(self):
        """Build and display all figures."""
        for fig in self.visualize().values():
            fig.show()

    def export_results(self, output_path: str):
        """Export identifiers, cluster labels, dominance info, PCs and features to CSV.

        Args:
            output_path: Destination CSV path (written without the index).
        """
        self._require_results()

        cols_to_export = [
            "Country", "Country Code", "GDP", "cluster",
            "Is_Dominant", "Dominant_Count", "Max_Concentration",
            "PC1", "PC2"
        ] + self.clustering_features

        self.df_final[cols_to_export].to_csv(output_path, index=False)
        logger.info(f"Exported results to {output_path}")
# RUN
def main():
    """Run the clustering analysis with the default configuration and show its figures.

    Returns:
        The completed ResourceClusteringAnalysis. The clustered table is
        available as ``analysis.df_final``.
    """
    # Configuration - update paths as needed
    config = Config(
        resources_path="/Users/leoss/Desktop/natural_resources_combined.csv",
        gdp_path="/Users/leoss/Desktop/Capstone/Code/Data/comprehensive_dataset.csv",
        n_clusters=5,  # Use 5-6 clusters
        dominance_threshold=15.0,
    )

    # Run analysis; results are stored on the instance (analysis.df_final)
    analysis = ResourceClusteringAnalysis(config)
    analysis.run()
    # Show visualizations
    analysis.show_figures()
    # Export results (optional)
    # analysis.export_results("clustering_results.csv")
    return analysis


if __name__ == "__main__":
    # Keep a module-level handle for interactive inspection after the run.
    analysis = main()

NATURAL RESOURCES CLUSTERING ANALYSIS

[1] LOADING DATA
Loaded 27,305 resource records for 110 countries
Loaded GDP for 142 countries (year 2021.0)

[2] PROCESSING DATA
Country mapping: 89 mapped, 0 unmapped
Countries with GDP: 82 (dropped 7)

[3] CALCULATING GLOBAL SHARES
Calculated shares: 16 production, 16 reserves
High concentration (>15.0% reserves): 16 countries
  🔴 Argentina: Lithium: 15.4%
  🔴 Australia: Bauxite: 16.8%, Cobalt: 18.8%, Copper: 17.3%
  🔴 Brazil: Natural Graphite: 22.1%, Rare Earth: 25.2%
  🔴 Burma: Tin: 21.5%
  🔴 Chile: Copper: 32.8%, Lithium: 35.7%
  🔴 China: Aluminium: 73.5%, Manganese: 19.3%, Natural Graphite: 24.2%
  🔴 DR Congo: Cobalt: 66.3%
  🔴 Guinea: Bauxite: 35.5%
  🔴 Indonesia: Nickel: 55.7%
  🔴 Iran: Natural Gas: 17.9%
  🔴 Peru: Copper: 17.3%
  🔴 Russia: Coal: 15.4%, Natural Gas: 20.8%, Platinum Group: 94.7%
  🔴 Saudi Arabia: Oil: 21.0%
  🔴 South Africa: Manganese: 38.7%
  🔴 Turkey: Natural Graphite: 20.6%
 