In [1]:
# 1. Setup -- preview
# Import necessary libraries and define the base path for the data files.
import pandas as pd
import geopandas as gpd
import os

In [2]:
# Define the base path to your raw data directory
# IMPORTANT: Make sure this path is correct for your system.
base_path = '/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw'

In [3]:
# 2. Load GDP Data
# Load the world GDP data from api_ny_gdp_mktp_cd_ds2.csv.
print("--- Loading GDP Data ---")
gdp_file_path = os.path.join(base_path, 'api_ny_gdp_mktp_cd_ds2.csv')

try:
    # We skip the first 4 rows which contain metadata, not the actual data.
    df_gdp = pd.read_csv(gdp_file_path, skiprows=4)
    print("Successfully loaded GDP data.")
    print("First 5 rows of GDP data:")
    print(df_gdp.head())
except FileNotFoundError:
    print(f"Error: GDP file not found at {gdp_file_path}")
except Exception as e:
    print(f"An error occurred: {e}")

print("\n" + "="*50 + "\n")

--- Loading GDP Data ---
Successfully loaded GDP data.
First 5 rows of GDP data:
                  Country Name Country Code     Indicator Name  \
0                        Aruba          ABW  GDP (current US$)   
1  Africa Eastern and Southern          AFE  GDP (current US$)   
2                  Afghanistan          AFG  GDP (current US$)   
3   Africa Western and Central          AFW  GDP (current US$)   
4                       Angola          AGO  GDP (current US$)   

   Indicator Code          1960          1961          1962          1963  \
0  NY.GDP.MKTP.CD           NaN           NaN           NaN           NaN   
1  NY.GDP.MKTP.CD  2.420993e+10  2.496326e+10  2.707802e+10  3.177483e+10   
2  NY.GDP.MKTP.CD           NaN           NaN           NaN           NaN   
3  NY.GDP.MKTP.CD  1.190511e+10  1.270803e+10  1.363092e+10  1.446926e+10   
4  NY.GDP.MKTP.CD           NaN           NaN           NaN           NaN   

           1964          1965  ...          2016          2

In [4]:
# 3. Load Countries Shapefile
# Load the country boundaries from the shapefile.
print("--- Loading Countries Shapefile ---")
shapefile_path = os.path.join(base_path, 'countries_shapefile', 'cn_primary_countries.shp')

try:
    gdf_countries = gpd.read_file(shapefile_path)
    print("Successfully loaded countries shapefile.")
    print("First 5 rows of the shapefile data:")
    print(gdf_countries.head())
except Exception as e:
    # Using a general exception as geopandas can have various backend errors
    print(f"Error loading shapefile: {e}")

print("\n" + "="*50 + "\n")

--- Loading Countries Shapefile ---
Successfully loaded countries shapefile.
First 5 rows of the shapefile data:
  ADM0_A3  abbrev continent                    formal_nam iso_a2 iso_a3  \
0     AFG    Afg.      Asia  Islamic State of Afghanistan     AF    AFG   
1     AGO    Ang.    Africa   People's Republic of Angola     AO    AGO   
2     ALB    Alb.    Europe           Republic of Albania     AL    ALB   
3     AND    And.    Europe       Principality of Andorra     AD    AND   
4     ARE  U.A.E.      Asia          United Arab Emirates     AE    ARE   

   iso_n3                   iso_short                  name  \
0       4                 Afghanistan           Afghanistan   
1      24                      Angola                Angola   
2       8                     Albania               Albania   
3      20                     Andorra               Andorra   
4     784  United Arab Emirates (the)  United Arab Emirates   

              name_sort  ...                      un_fr  

In [5]:
# 4. Load Trade Data
# Load the various trade-related datasets from the `trade_data` directory.
print("--- Loading Trade Data ---")
trade_data_path = os.path.join(base_path, 'trade_data')
trade_files = [
    'baci_hs12_y2016_v202001.csv',
    'baci_hs12_y2017_v202001.csv',
    'baci_hs12_y2018_v202001.csv',
    'country_codes_v202001.csv',
    'product_codes_hs12_v202001.csv'
]

# A dictionary to hold all the loaded trade dataframes
trade_dataframes = {}

for file in trade_files:
    file_path = os.path.join(trade_data_path, file)
    # Use a clean name for the dictionary key (e.g., 'baci_hs12_y2016_v202001')
    df_name = file.split('.')[0]
    try:
        trade_dataframes[df_name] = pd.read_csv(file_path)
        print(f"- Successfully loaded {file}")
    except FileNotFoundError:
        print(f"- Error: File not found at {file_path}")
    except Exception as e:
        print(f"- An error occurred while loading {file}: {e}")

# You can now access each dataframe from the dictionary, for example:
if 'country_codes_v202001' in trade_dataframes:
    print("\nExample: First 5 rows of country_codes_v202001.csv:")
    print(trade_dataframes['country_codes_v202001'].head())

print("\n" + "="*50 + "\n")

--- Loading Trade Data ---
- An error occurred while loading baci_hs12_y2016_v202001.csv: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.
- Successfully loaded baci_hs12_y2017_v202001.csv
- Successfully loaded baci_hs12_y2018_v202001.csv
- An error occurred while loading country_codes_v202001.csv: 'utf-8' codec can't decode byte 0xf4 in position 4141: invalid continuation byte
- Successfully loaded product_codes_hs12_v202001.csv




In [None]:
# 5. Optional: Combine yearly trade data
# If the yearly BACI trade files have the same columns, we can combine them.
print("--- Combining yearly BACI trade data ---")
baci_trade_dfs_to_combine = []
for key, df in trade_dataframes.items():
    if 'baci_hs12_y' in key:
        baci_trade_dfs_to_combine.append(df)

if baci_trade_dfs_to_combine:
    df_baci_combined = pd.concat(baci_trade_dfs_to_combine, ignore_index=True)
    print("Successfully combined yearly BACI trade data into a single DataFrame.")
    print(f"Total rows in combined data: {len(df_baci_combined)}")
    print("\nFirst 5 rows of combined BACI data:")
    print(df_baci_combined.head())
    print("\nLast 5 rows of combined BACI data:")
    print(df_baci_combined.tail())
else:
    print("No BACI dataframes were found to combine.")

In [None]:
import pandas as pd
import os

# --- Q1a. Setup and Data Loading ---
# This version adds the encoding='latin1' parameter to handle the file encoding error.

# Define the base path to your raw data directory
base_path = '/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw'
trade_data_path = os.path.join(base_path, 'trade_data')

# A dictionary to hold the loaded dataframes
dataframes = {}

# List of files to load
files_to_load = [
    'baci_hs12_y2016_v202001.csv',
    'baci_hs12_y2017_v202001.csv',
    'baci_hs12_y2018_v202001.csv',
    'country_codes_v202001.csv'
]

print("Loading data files...")
for file in files_to_load:
    file_path = os.path.join(trade_data_path, file)
    df_name = file.split('.')[0]
    try:
        # THE FIX IS HERE: Added encoding='latin1' to handle non-UTF-8 characters
        dataframes[df_name] = pd.read_csv(file_path, on_bad_lines='skip', encoding='latin1')
        print(f"- Successfully loaded {file}")
    except FileNotFoundError:
        print(f"- Error: File not found at {file_path}")
        exit()
    except Exception as e:
        print(f"An unexpected error occurred while loading {file}: {e}")
        exit()


# Combine the three years of trade data into one DataFrame
baci_dfs = [df for name, df in dataframes.items() if 'baci_hs12_y' in name]
if not baci_dfs:
    print("Error: No BACI trade data files were loaded. Cannot proceed.")
    exit()

df_trade_combined = pd.concat(baci_dfs, ignore_index=True)
print("\nCombined 2016-2018 trade data successfully.\n")

# Load country codes for mapping codes to names
df_country_codes = dataframes.get('country_codes_v202001')
if df_country_codes is None:
    print("Error: country_codes_v202001.csv could not be loaded. Cannot map country names.")
    exit()

# --- 2. Calculate Trading Partners for Each Country ---
# The logic here remains the same as before.

# Get a list of all unique country codes present in the trade data
all_country_codes = pd.unique(df_trade_combined[['i', 'j']].values.ravel('K'))

partner_counts = {}

for code in all_country_codes:
    # Find all countries this country exported to
    exports_to = set(df_trade_combined[df_trade_combined['i'] == code]['j'])
    
    # Find all countries this country imported from
    imports_from = set(df_trade_combined[df_trade_combined['j'] == code]['i'])
    
    # The set of unique partners is the union of the two sets
    unique_partners = exports_to.union(imports_from)
    
    # Count the number of partners and store it
    partner_counts[code] = len(unique_partners)

# Convert the dictionary to a pandas DataFrame for easier sorting and merging
df_partner_counts = pd.DataFrame(list(partner_counts.items()), columns=['country_code', 'partner_count'])

# Merge with the country names for a readable output
if 'country_name_full' in df_country_codes.columns and 'country_code' in df_country_codes.columns:
    df_results = pd.merge(df_partner_counts, df_country_codes[['country_code', 'country_name_full']], on='country_code', how='left')
    df_results = df_results.sort_values(by='partner_count', ascending=False).reset_index(drop=True)
    df_results = df_results.rename(columns={'country_name_full': 'country_name'})
else:
    print("Country code file does not have expected columns 'country_code' and 'country_name_full'. Cannot display full names.")
    df_results = df_partner_counts.sort_values(by='partner_count', ascending=False).reset_index(drop=True)


# --- 3. Display Descriptive Statistics ---

print("\n--- Descriptive Statistics: Trading Partners (2016-2018) ---\n")

# Top 10 Countries with the Most Trading Partners
print("Top 10 Countries with the Most Trading Partners:")
# Fill potential missing names with the code for robustness
df_results['country_name'] = df_results['country_name'].fillna('Unknown')
top_10 = df_results.head(10)
print(top_10[['country_name', 'partner_count']].to_string(index=False))

print("\n------------------------------------------------------\n")

# Bottom 10 Countries with the Fewest Trading Partners
print("Bottom 10 Countries with the Fewest Trading Partners:")
bottom_10 = df_results[df_results['partner_count'] > 0].tail(10)
print(bottom_10[['country_name', 'partner_count']].to_string(index=False))

In [None]:
import pandas as pd
import os

# --- Q1b. Setup and Data Loading (Including Product Codes) ---

# Define the base path to your raw data directory
base_path = '/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw'
trade_data_path = os.path.join(base_path, 'trade_data')

# A dictionary to hold the loaded dataframes
dataframes = {}

# List of all necessary files
files_to_load = [
    'baci_hs12_y2016_v202001.csv',
    'baci_hs12_y2017_v202001.csv',
    'baci_hs12_y2018_v202001.csv',
    'country_codes_v202001.csv',
    'product_codes_hs12_v202001.csv'
]

print("Loading data files...")
for file in files_to_load:
    file_path = os.path.join(trade_data_path, file)
    df_name = file.split('.')[0]
    try:
        # Use encoding='latin1' to prevent UnicodeDecodeError
        dataframes[df_name] = pd.read_csv(file_path, on_bad_lines='skip', encoding='latin1')
        print(f"- Successfully loaded {file}")
    except FileNotFoundError:
        print(f"- Error: File not found at {file_path}")
        exit()

# Combine the yearly trade data
baci_dfs = [df for name, df in dataframes.items() if 'baci_hs12_y' in name]
if not baci_dfs:
    print("Error: No BACI trade data files were loaded. Cannot proceed.")
    exit()
df_trade_combined = pd.concat(baci_dfs, ignore_index=True)
print("\nCombined 2016-2018 trade data successfully.")

# Prepare code-to-name mapping tables
df_country_codes = dataframes.get('country_codes_v202001')
df_product_codes = dataframes.get('product_codes_hs12_v202001')

if df_country_codes is None or df_product_codes is None:
    print("Error: Code description files could not be loaded.")
    exit()

# --- 2. Describe Overall Trade Volume ---

print("\n--- Overall Trade Volume Description (2016-2018) ---")
# The trade value 'v' is in thousands of US dollars.
total_trade_value = df_trade_combined['v'].sum()
print(f"Total Global Trade Value (sum of all flows): ${total_trade_value:,.0f} thousand USD")

print("\nDescriptive Statistics for Trade Value ('v') per flow:")
# Using describe() to get a statistical summary
description = df_trade_combined['v'].describe()
# Format the output for readability
for idx, val in description.items():
    print(f"{idx.capitalize():>8}: {val:,.2f}")


# --- 3. Identify Top 10 Partners for China and the USA ---

print("\n\n--- Top 10 Trading Partners by Total Value (2016-2018) ---")

def get_top_10_partners(country_name_str, country_codes_df, trade_df):
    """Calculates and prints the top 10 trading partners for a given country."""
    try:
        country_code = int(country_codes_df[country_codes_df['country_name_full'].str.contains(country_name_str, na=False)].iloc[0]['country_code'])
    except (IndexError, TypeError):
        print(f"Could not find country code for '{country_name_str}'.")
        return

    # Filter for all trade involving the country (as exporter 'i' or importer 'j')
    country_trade_df = trade_df[(trade_df['i'] == country_code) | (trade_df['j'] == country_code)].copy()
    
    # Determine the partner code for each transaction
    country_trade_df['partner_code'] = country_trade_df.apply(
        lambda row: row['j'] if row['i'] == country_code else row['i'],
        axis=1
    )
    
    # Group by partner and sum the trade value
    partner_trade = country_trade_df.groupby('partner_code')['v'].sum().reset_index()
    
    # Merge with country names to get partner names
    partner_trade = pd.merge(partner_trade, country_codes_df[['country_code', 'country_name_full']], left_on='partner_code', right_on='country_code', how='left')
    
    # Sort to find the top partners
    top_10 = partner_trade.sort_values(by='v', ascending=False).head(10)
    
    print(f"\nTop 10 Partners for {country_name_str}:")
    top_10['v_formatted'] = top_10['v'].apply(lambda x: f"${x:,.0f}K")
    print(top_10[['country_name_full', 'v_formatted']].rename(columns={'country_name_full': 'Partner Country', 'v_formatted': 'Total Trade Value'}).to_string(index=False))

# Run the function for China and the United States
get_top_10_partners("China", df_country_codes, df_trade_combined)
get_top_10_partners("United States of America", df_country_codes, df_trade_combined)


# --- 4. List the Five Highest-Value China Trade Flows ---

print("\n\n--- Five Highest-Value China Trade Flows (2016-2018) ---")

try:
    china_code = int(df_country_codes[df_country_codes['country_name_full'].str.contains("China", na=False)].iloc[0]['country_code'])

    # Filter for all trade involving China
    china_trade_df = df_trade_combined[(df_trade_combined['i'] == china_code) | (df_trade_combined['j'] == china_code)]

    # Group by the specific flow (exporter and importer) and sum the value
    flow_values = china_trade_df.groupby(['i', 'j'])['v'].sum().reset_index()
    
    # Sort to find the highest-value flows
    top_flows = flow_values.sort_values(by='v', ascending=False).head(5)

    # Merge to get exporter names
    top_flows = pd.merge(top_flows, df_country_codes[['country_code', 'country_name_full']], left_on='i', right_on='country_code', how='left')
    top_flows = top_flows.rename(columns={'country_name_full': 'Exporter'})
    
    # Merge to get importer names
    top_flows = pd.merge(top_flows, df_country_codes[['country_code', 'country_name_full']], left_on='j', right_on='country_code', how='left')
    top_flows = top_flows.rename(columns={'country_name_full': 'Importer'})
    
    # Format for printing
    top_flows['Total Value'] = top_flows['v'].apply(lambda x: f"${x:,.0f}K")

    print("Top 5 individual trade flows involving China (sum of all products and years):")
    print(top_flows[['Exporter', 'Importer', 'Total Value']].to_string(index=False))

except (IndexError, TypeError):
    print("Could not find country code for 'China' to analyze trade flows.")

In [None]:
import pandas as pd
import os

# --- Q1c. Improve Display Formatting ---
# Set pandas display options to make tables wider and prevent messy wrapping.
pd.set_option('display.width', 150)
pd.set_option('display.max_colwidth', 80)


# --- 1. Setup and Data Loading ---

# Define the base path to your raw data directory
base_path = '/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw'
trade_data_path = os.path.join(base_path, 'trade_data')

# A dictionary to hold the loaded dataframes
dataframes = {}

# List of all necessary files
files_to_load = [
    'baci_hs12_y2016_v202001.csv',
    'baci_hs12_y2017_v202001.csv',
    'baci_hs12_y2018_v202001.csv',
    'country_codes_v202001.csv',
    'product_codes_hs12_v202001.csv'
]

print("Loading data files...")
for file in files_to_load:
    file_path = os.path.join(trade_data_path, file)
    df_name = file.split('.')[0]
    try:
        dataframes[df_name] = pd.read_csv(file_path, on_bad_lines='skip', encoding='latin1')
        print(f"- Successfully loaded {file}")
    except FileNotFoundError:
        print(f"- Error: File not found at {file_path}")
        exit()

# Combine the yearly trade data
baci_dfs = [df for name, df in dataframes.items() if 'baci_hs12_y' in name]
df_trade_combined = pd.concat(baci_dfs, ignore_index=True)
print("\nCombined 2016-2018 trade data successfully.")

# Prepare code-to-name mapping tables
df_country_codes = dataframes.get('country_codes_v202001')
df_product_codes = dataframes.get('product_codes_hs12_v202001')

if df_country_codes is None or df_product_codes is None:
    print("Error: Code description files could not be loaded.")
    exit()

# --- 2. Calculate Top 10 Export Products (with Clean Formatting) ---

print("\n\n--- Top 10 Export Products by Value (2016-2018) ---")

def get_top_10_exports(country_name, country_codes_df, trade_df, product_codes_df):
    """Calculates and prints the top 10 export products with clean formatting."""
    try:
        country_code = int(country_codes_df[country_codes_df['country_name_full'].str.contains(country_name, na=False)].iloc[0]['country_code'])
    except (IndexError, TypeError):
        print(f"\nCould not find country code for '{country_name}'.")
        return

    exports_df = trade_df[trade_df['i'] == country_code]
    top_products = exports_df.groupby('k')['v'].sum().reset_index()
    top_10 = top_products.sort_values(by='v', ascending=False).head(10)
    
    top_10_with_names = pd.merge(top_10, product_codes_df, left_on='k', right_on='code', how='left')
    
    # FORMATTING FIX: Truncate long descriptions for cleaner output
    top_10_with_names['description'] = top_10_with_names['description'].str.slice(0, 75) + '...'
    
    print(f"\nTop 10 Exports for {country_name}:")
    top_10_with_names['v_formatted'] = top_10_with_names['v'].apply(lambda x: f"${x:,.0f}K")
    
    # Rename columns for final print
    final_df = top_10_with_names[['description', 'v_formatted']].rename(
        columns={'description': 'Product Description', 'v_formatted': 'Total Export Value'}
    )
    print(final_df.to_string(index=False))

# Run the analysis for the three countries
get_top_10_exports("China", df_country_codes, df_trade_combined, df_product_codes)
get_top_10_exports("Japan", df_country_codes, df_trade_combined, df_product_codes)
get_top_10_exports("United States of America", df_country_codes, df_trade_combined, df_product_codes)


# --- 3. Calculate Top 10 Globally Traded Goods (with Clean Formatting) ---

print("\n\n--- Top 10 Globally Traded Goods (2016-2018) ---")

global_product_trade = df_trade_combined.groupby('k').agg(
    total_value=('v', 'sum'),
    total_quantity=('q', 'sum')
).reset_index()

# Top 10 by Value
top_10_value = global_product_trade.sort_values(by='total_value', ascending=False).head(10)
top_10_value_named = pd.merge(top_10_value, df_product_codes, left_on='k', right_on='code', how='left')
# FORMATTING FIX
top_10_value_named['description'] = top_10_value_named['description'].str.slice(0, 75) + '...'
top_10_value_named['value_formatted'] = top_10_value_named['total_value'].apply(lambda x: f"${x:,.0f}K")
print("\nTop 10 Goods with Highest Global Trade Volume by VALUE:")
print(top_10_value_named[['description', 'value_formatted']].rename(columns={'description': 'Product Description', 'value_formatted': 'Total Trade Value'}).to_string(index=False))

# Top 10 by Quantity
top_10_quantity = global_product_trade.sort_values(by='total_quantity', ascending=False).head(10)
top_10_quantity_named = pd.merge(top_10_quantity, df_product_codes, left_on='k', right_on='code', how='left')
# FORMATTING FIX
top_10_quantity_named['description'] = top_10_quantity_named['description'].str.slice(0, 75) + '...'
top_10_quantity_named['quantity_formatted'] = top_10_quantity_named['total_quantity'].apply(lambda x: f"{x:,.0f} Metric Tons")
print("\nTop 10 Goods with Highest Global Trade Volume by QUANTITY:")
print(top_10_quantity_named[['description', 'quantity_formatted']].rename(columns={'description': 'Product Description', 'quantity_formatted': 'Total Trade Quantity'}).to_string(index=False))


In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# --- Q1d. Setup and Configuration ---
# Set pandas display options for cleaner output and plot style.
pd.set_option('display.width', 150)
pd.set_option('display.max_colwidth', 80)
sns.set_style("whitegrid")


# --- 1. Load All Necessary Data ---
print("--- Loading Data ---")
base_path = '/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw'
trade_data_path = os.path.join(base_path, 'trade_data')
shapefile_path = os.path.join(base_path, 'countries_shapefile', 'cn_primary_countries.shp')

# Load shapefile
try:
    gdf = gpd.read_file(shapefile_path)
    print("- Successfully loaded countries shapefile.")
except Exception as e:
    print(f"Error loading shapefile: {e}")
    exit()

# Load trade and country code data
dataframes = {}
files_to_load = [
    'baci_hs12_y2016_v202001.csv', 'baci_hs12_y2017_v202001.csv', 'baci_hs12_y2018_v202001.csv',
    'country_codes_v202001.csv'
]
for file in files_to_load:
    file_path = os.path.join(trade_data_path, file)
    try:
        dataframes[file.split('.')[0]] = pd.read_csv(file_path, encoding='latin1')
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        exit()
print("- Successfully loaded all trade and code data.")

# Combine trade data
baci_dfs = [df for name, df in dataframes.items() if 'baci_hs12_y' in name]
df_trade_combined = pd.concat(baci_dfs, ignore_index=True)
df_country_codes = dataframes.get('country_codes_v202001')


# --- 2. Calculate Geographic Distances from China ---
print("\n--- Calculating Geographic Distances ---")

# The correct column for country names is 'name' based on your output.
country_name_column = 'name' 

gdf_proj = gdf.to_crs(epsg=8857) # Reproject for accurate distance in meters
gdf_proj['centroid'] = gdf_proj.geometry.centroid

try:
    china_centroid = gdf_proj[gdf_proj[country_name_column] == 'China'].centroid.iloc[0]
except IndexError:
    print(f"FATAL ERROR: Could not find 'China' in the shapefile's '{country_name_column}' column.")
    exit()

# Calculate distance in kilometers
gdf['distance_to_china_km'] = gdf_proj.centroid.apply(lambda p: china_centroid.distance(p) / 1000)
print(f"- Calculated distances from China's centroid to {len(gdf)} other countries.")

df_distances = gdf[[country_name_column, 'distance_to_china_km']].rename(columns={country_name_column: 'country_name'})


# --- 3. Calculate China's Export Volume ---
print("\n--- Calculating China's Export Volumes ---")

try:
    china_code = int(df_country_codes[df_country_codes['country_name_full'] == 'China'].iloc[0]['country_code'])
except (IndexError, TypeError):
    print("FATAL ERROR: Could not find country code for 'China' in the country codes file.")
    exit()

china_exports = df_trade_combined[df_trade_combined['i'] == china_code]
export_volumes = china_exports.groupby('j').agg(
    total_export_value=('v', 'sum'),
    total_export_quantity=('q', 'sum')
).reset_index()

export_volumes = pd.merge(export_volumes, df_country_codes, left_on='j', right_on='country_code', how='left')
df_exports = export_volumes[['country_name_full', 'total_export_value', 'total_export_quantity']]
df_exports = df_exports.rename(columns={'country_name_full': 'country_name'})
print(f"- Aggregated export data for {len(df_exports)} partner countries.")

# --- 4. Merge Distance and Trade Data ---
print("\n--- Merging Datasets for Plotting ---")
# Before merging, we must handle potential name mismatches between the two datasets
# For example: 'United States of America' (trade file) vs. 'United States' (shapefile)
name_corrections = {
    'United States of America': 'United States',
    'Republic of Korea': 'South Korea',
    'Viet Nam': 'Vietnam'
}
df_exports['country_name'] = df_exports['country_name'].replace(name_corrections)

# Now, merge the distance data with the export data
df_final = pd.merge(df_distances, df_exports, on='country_name', how='inner')
print(f"- Successfully merged data. Found {len(df_final)} countries with both distance and trade data.")

if len(df_final) == 0:
    print("WARNING: The merge resulted in an empty DataFrame. This is caused by country name mismatches.")


# --- 5. Create and Display Scatterplots ---
print("\n--- Generating Scatterplots ---")

if not df_final.empty:
    plot_data = df_final[(df_final['total_export_value'] > 0) & (df_final['total_export_quantity'] > 0)].copy()
    plot_data['log_value'] = np.log(plot_data['total_export_value'])
    plot_data['log_quantity'] = np.log(plot_data['total_export_quantity'])

    # Plot 1: Distance vs. Log Export Value
    plt.figure(figsize=(12, 7))
    sns.regplot(data=plot_data, x='distance_to_china_km', y='log_value',
                scatter_kws={'alpha':0.6, 's':20}, line_kws={'color':'red', 'linestyle':'--'})
    plt.title('Distance vs. Log of Export Value from China (2016-2018)', fontsize=16, pad=15)
    plt.xlabel('Distance from China (km)', fontsize=12)
    plt.ylabel('Natural Log of Total Export Value (in thousands USD)', fontsize=12)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.show()

    # Plot 2: Distance vs. Log Export Quantity
    plt.figure(figsize=(12, 7))
    sns.regplot(data=plot_data, x='distance_to_china_km', y='log_quantity',
                scatter_kws={'alpha':0.6, 's':20, 'color':'green'}, line_kws={'color':'red', 'linestyle':'--'})
    plt.title('Distance vs. Log of Export Quantity from China (2016-2018)', fontsize=16, pad=15)
    plt.xlabel('Distance from China (km)', fontsize=12)
    plt.ylabel('Natural Log of Total Export Quantity (in Metric Tons)', fontsize=12)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.show()
else:
    print("\nSkipping plot generation because the merged DataFrame is empty.")