In [None]:
import pandas as pd
import os
from glob import glob

# Directory containing the monthly CSV files; the processing loop later in
# this cell globs it for "USA_<flow>_<year>*_HS4.csv"
input_dir = 'comtrade_monthly_hs4_outputs'
# Merged per-year / per-flow files are written to this subfolder of input_dir
output_dir = 'comtrade_monthly_hs4_outputs/merged'

# Create output directory if it doesn't exist (no error when it already does)
os.makedirs(output_dir, exist_ok=True)

def read_csv_with_encoding(file_path):
    """Read a CSV into a DataFrame, repairing a header that is too short.

    The file is decoded as UTF-8 first and as latin-1 only if UTF-8 decoding
    fails. latin-1 maps every possible byte, so it must come last: placed
    first it always "succeeds" and silently mis-decodes UTF-8 text.

    If the first data row has more fields than the header row, placeholder
    names ``unnamed_col_0``, ``unnamed_col_1``, ... are appended so that
    every data column keeps its position under the right header name.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        Exception: If no candidate encoding could decode the file.
        OSError, pandas.errors.ParserError, ...: Errors unrelated to the
            encoding (missing file, malformed CSV, empty file) propagate
            unchanged instead of being reported as an encoding problem.
    """
    import csv  # local import: only needed for the header/field count

    encodings = ['utf-8', 'latin-1']
    last_error = None

    for encoding in encodings:
        try:
            # Count fields with the csv module rather than counting commas,
            # so quoted fields that contain commas are not miscounted.
            with open(file_path, 'r', encoding=encoding, newline='') as f:
                rows = (row for row in csv.reader(f) if row)
                header_fields = next(rows, [])
                data_fields = next(rows, [])

            missing_cols = len(data_fields) - len(header_fields)
            if missing_cols > 0:
                print(f"    Adding {missing_cols} placeholder column names")

                # Let pandas parse the header so its usual name handling
                # (e.g. de-duplication) still applies.
                original_columns = pd.read_csv(
                    file_path, encoding=encoding, nrows=0
                ).columns.tolist()
                new_columns = original_columns + [
                    f'unnamed_col_{i}' for i in range(missing_cols)
                ]

                # Skip the original header line and supply the full name list
                return pd.read_csv(
                    file_path, encoding=encoding, names=new_columns, skiprows=1
                )

            return pd.read_csv(file_path, encoding=encoding)

        except UnicodeDecodeError as e:
            # Only a decoding failure justifies trying the next encoding
            last_error = e

    raise Exception(
        f"Could not read file with any encoding: {file_path}"
    ) from last_error

# Define the years and trade types (file-name code -> output label)
years = [2021, 2022, 2025]
trade_types = {'M': 'import', 'X': 'export'}

# Columns shown in the post-merge sanity print; only those actually present
# are printed so a missing diagnostic column cannot abort the save.
sample_columns = ['typeCode', 'freqCode', 'refPeriodId', 'qty']

# Input files that were skipped or failed, so the final banner is accurate
problem_files = []

# Process each year and trade type
for year in years:
    for trade_code, trade_name in trade_types.items():
        print(f"\nProcessing {year} {trade_name}...")

        # Find all matching files for this year and trade type
        pattern = os.path.join(input_dir, f"USA_{trade_code}_{year}*_HS4.csv")
        files = sorted(glob(pattern))

        if not files:
            print(f"  No files found for {year} {trade_name}")
            continue

        print(f"  Found {len(files)} files")

        # Read and concatenate all files
        dfs = []
        for file in files:
            try:
                df = read_csv_with_encoding(file)

                # Check if 'qty' column exists
                if 'qty' not in df.columns:
                    print(f"    ⚠️  'qty' column not found. Available columns: {df.columns.tolist()}")
                    problem_files.append(file)
                    continue

                # Keep rows where qty >= 0 (rows with a missing qty are
                # dropped too, because NaN >= 0 evaluates to False)
                df_filtered = df[df['qty'] >= 0]
                dfs.append(df_filtered)
                print(f"    ✓ {os.path.basename(file)}: {len(df)} rows -> {len(df_filtered)} rows after filtering (qty >= 0)")

            except Exception as e:
                # Best effort: report the file and carry on with the others
                print(f"    ✗ Error reading {os.path.basename(file)}: {e}")
                problem_files.append(file)

        # Concatenate all dataframes
        if dfs:
            merged_df = pd.concat(dfs, ignore_index=True)

            # Verify columns are correct
            print(f"\n  Merged dataframe info:")
            print(f"    Total rows: {len(merged_df):,}")
            print(f"    First 5 columns: {merged_df.columns[:5].tolist()}")
            print(f"    Sample data:")
            available_sample = [col for col in sample_columns if col in merged_df.columns]
            print(merged_df[available_sample].head())

            # Save merged file
            output_file = os.path.join(output_dir, f"USA_{year}_{trade_name}.csv")
            merged_df.to_csv(output_file, index=False, encoding='utf-8')
            print(f"\n  ✓✓ SAVED: {output_file}")
        else:
            print(f"  No data to merge for {year} {trade_name}")

print("\n" + "=" * 60)
if problem_files:
    print(f"Finished with {len(problem_files)} file(s) skipped or failed:")
    for file in problem_files:
        print(f"  - {os.path.basename(file)}")
else:
    print("All files processed successfully!")
print(f"Output files saved in: {output_dir}")
print("=" * 60)

## CHINA DATA SPLITTING AND REPAIRING

In [None]:
import pandas as pd
import os

# Absolute, machine-specific paths: the combined China CSV to read and the
# folder that receives one CSV per (year, trade flow)
input_file = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\combined_df.csv'
output_dir = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\split'

# Create output directory (no error when it already exists)
os.makedirs(output_dir, exist_ok=True)

def read_csv_with_encoding(file_path):
    """Read a CSV file, trying UTF-8 first and latin-1 as a fallback.

    latin-1 maps every possible byte, so it has to be the last candidate:
    tried first it always "succeeds" and the UTF-8 attempt is never reached,
    which silently mis-decodes UTF-8 text.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        Exception: If no candidate encoding could decode the file.
        OSError, pandas.errors.ParserError, ...: Errors unrelated to the
            encoding (missing file, malformed CSV) propagate unchanged.
    """
    encodings = ['utf-8', 'latin-1']
    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file_path, encoding=encoding)
        except UnicodeDecodeError as e:
            # Only a decoding failure justifies trying the next encoding
            last_error = e
    raise Exception("Could not read file with any encoding") from last_error

# Cell body: load the combined China CSV, realign its column names, filter
# on qty and write one CSV per (refYear, flowDesc) combination.
print("=" * 60)
print("Reading and fixing China data...")
print("=" * 60)

# Read the CSV
df = read_csv_with_encoding(input_file)

# Show the raw state so the misalignment is visible before the fix
print(f"\nOriginal data:")
print(f"  Total rows: {len(df):,}")
print(f"  Columns: {len(df.columns)}")
print(f"  First 5 column names: {df.columns[:5].tolist()}")
print(f"\nFirst few rows (before fix):")
print(df.head()[df.columns[:5]])

# Fix the column shift issue
# NOTE(review): this is applied unconditionally. It presumably compensates
# for data rows having one more field than the header, in which case pandas
# treats the first field as the index and every name sits one column too far
# left -- confirm against the raw file before re-running on a corrected CSV.
# Get the column names
old_columns = df.columns.tolist()

# Relabel every column with the name of its right-hand neighbour:
# drop the first column name (typeCode) and add a placeholder at the end
new_columns = old_columns[1:] + ['unnamed_extra_col']

# Apply new column names (same number of names, data untouched)
df.columns = new_columns

print(f"\nAfter fixing column alignment:")
print(f"  First 5 column names: {df.columns[:5].tolist()}")
print(f"\nFirst few rows (after fix):")
print(df.head()[df.columns[:5]])

# Check for required columns; without them nothing is split or saved
if 'refYear' not in df.columns or 'flowDesc' not in df.columns:
    print(f"\n⚠️  Warning: Required columns not found!")
    print(f"Available columns: {df.columns.tolist()}")
else:
    # Check unique values
    print(f"\nUnique years: {sorted(df['refYear'].unique())}")
    print(f"Unique flows: {df['flowDesc'].unique()}")
    
    # Check if qty column exists
    if 'qty' in df.columns:
        # NOTE(review): the USA merge cell keeps qty >= 0 while this keeps
        # only qty > 200 -- confirm the threshold is intentional.
        print(f"\nFiltering by qty > 200...")
        df_filtered = df[df['qty'] > 200].copy()
        print(f"  Rows before filtering: {len(df):,}")
        print(f"  Rows after filtering: {len(df_filtered):,}")
    else:
        # No qty column: continue with the unfiltered data
        print(f"\n⚠️  'qty' column not found. Available columns: {df.columns.tolist()}")
        df_filtered = df.copy()
    
    # Split by year and flow
    print(f"\n" + "=" * 60)
    print("Splitting data by year and trade flow...")
    print("=" * 60)
    
    years = sorted(df_filtered['refYear'].unique())
    flows = df_filtered['flowDesc'].unique()
    
    # Every (year, flow) pair is tried; empty combinations are only reported
    for year in years:
        for flow in flows:
            # Filter data
            mask = (df_filtered['refYear'] == year) & (df_filtered['flowDesc'] == flow)
            df_subset = df_filtered[mask]
            
            if len(df_subset) > 0:
                # Create filename (china_2021_import.csv)
                flow_name = flow.lower()
                filename = f"china_{year}_{flow_name}.csv"
                output_path = os.path.join(output_dir, filename)
                
                # Save file (the DataFrame index is not written)
                df_subset.to_csv(output_path, index=False, encoding='utf-8')
                print(f"✓ Saved: {filename} ({len(df_subset):,} rows)")
            else:
                print(f"  No data for {year} {flow}")
    
    print(f"\n" + "=" * 60)
    print("All files saved successfully!")
    print(f"Output directory: {output_dir}")
    print("=" * 60)

## RENAMING AND NORMALIZATION 

In [None]:
import pandas as pd
import os
from glob import glob

# Define directories: the merged USA files (relative path) and the split
# China files (absolute, machine-specific path)
usa_dir = 'comtrade_monthly_hs4_outputs/merged'
china_dir = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\split'

# Column mapping: old_name -> new_name
# The keys double as the list of required input columns: process_file()
# rejects any file that lacks one of them and keeps only these columns.
column_mapping = {
    'period': 'month_id',
    'flowDesc': 'trade_flow_name',
    'partnerISO': 'country_id',
    'partnerDesc': 'country_name',
    'cmdCode': 'product_id_hs4',
    'primaryValue': 'trade_value',
    'qty': 'quantity',
    'cmdDesc': 'product_name_hs4'
}

# Final column order of the saved files; 'nb_product' is not in the mapping
# because process_file() computes it
final_columns = ['month_id', 'trade_flow_name', 'country_id', 'country_name', 
                 'product_id_hs4', 'trade_value', 'quantity', 'nb_product', 'product_name_hs4']

def read_csv_with_encoding(file_path):
    """Read a CSV file, trying UTF-8 first and latin-1 as a fallback.

    latin-1 maps every possible byte, so nothing listed after it could ever
    be reached; the candidate list is therefore just these two encodings.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        Exception: If no candidate encoding could decode the file.
        OSError, pandas.errors.ParserError, ...: Errors unrelated to the
            encoding (missing file, malformed CSV) propagate unchanged
            instead of being swallowed by a bare ``except``.
    """
    encodings = ['utf-8', 'latin-1']
    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file_path, encoding=encoding)
        except UnicodeDecodeError as e:
            # Only a decoding failure justifies trying the next encoding
            last_error = e
    raise Exception(f"Could not read file: {file_path}") from last_error

def process_file(input_path, output_path):
    """Normalize one trade CSV and write the result to ``output_path``.

    Steps: read the file, keep and rename the columns listed in
    ``column_mapping``, cast ``product_id_hs4`` to ``str``, add
    ``nb_product`` (number of distinct ``product_id_hs4`` values within each
    ``month_id`` / ``country_id`` group), keep only rows whose
    ``nb_product`` exceeds 200, and save the columns in ``final_columns``
    order.

    Args:
        input_path: CSV file to read.
        output_path: Destination CSV file.

    Returns:
        True when the file was written; False when a required column is
        missing or any exception occurred (the traceback is printed).
    """
    try:
        raw = read_csv_with_encoding(input_path)

        # The mapping keys are the columns every input file must provide
        required = list(column_mapping.keys())
        absent = [name for name in required if name not in raw.columns]
        if absent:
            print(f"  ⚠️  Missing columns: {absent}")
            print(f"      Available columns: {raw.columns.tolist()}")
            return False

        # Select, then rename to the normalized schema
        table = raw[required].copy().rename(columns=column_mapping)
        table['product_id_hs4'] = table['product_id_hs4'].astype(str)

        # Distinct product count per (month, country), broadcast to each row
        print(f"    Calculating nb_product (distinct products per month/country)...")
        products_by_group = table.groupby(['month_id', 'country_id'])['product_id_hs4']
        table['nb_product'] = products_by_group.transform('nunique')

        # Keep only rows from groups with more than 200 distinct products
        kept = table[table['nb_product'] > 200].copy()
        print(f"    Filtered nb_product > 200: {len(table):,} rows → {len(kept):,} rows")

        # Write the columns in their final order
        kept[final_columns].to_csv(output_path, index=False, encoding='utf-8')
        return True

    except Exception as e:
        import traceback
        print(f"  ✗ Error: {e}")
        traceback.print_exc()
        return False

# Cell body: run process_file() over every USA and China input file and
# print a short check of each file it wrote.
print("=" * 60)
print("Processing USA files...")
print("=" * 60)

# Process USA files
usa_files = glob(os.path.join(usa_dir, "USA_*.csv"))
for file_path in usa_files:
    filename = os.path.basename(file_path)
    
    # Outputs land in the same folder as the inputs, so skip files produced
    # by an earlier run of this cell (_final) or of the enrichment cell
    # (_Additional) to avoid processing them again
    if '_final' in filename:
        continue
    if '_Additional' in filename:
        continue

    
    # Create output filename: <name>.csv -> <name>_final.csv
    name_without_ext = filename.replace('.csv', '')
    output_filename = f"{name_without_ext}_final.csv"
    output_path = os.path.join(usa_dir, output_filename)
    
    print(f"\n{filename}")
    success = process_file(file_path, output_path)
    
    if success:
        # Re-read the saved file to report what was actually written
        df_check = pd.read_csv(output_path)
        print(f"  ✓ Saved: {output_filename}")
        print(f"    Final rows: {len(df_check):,}")
        print(f"    Columns: {df_check.columns.tolist()}")
        print(f"    nb_product range: {df_check['nb_product'].min()} to {df_check['nb_product'].max()}")
        print(f"    Sample data:")
        print(df_check[['month_id', 'country_id', 'product_id_hs4', 'nb_product']].head(5))

print("\n" + "=" * 60)
print("Processing China files...")
print("=" * 60)

# Process China files (same steps as the USA loop, different folder/prefix)
china_files = glob(os.path.join(china_dir, "china_*.csv"))
for file_path in china_files:
    filename = os.path.basename(file_path)
    
    # Skip outputs of earlier runs (_final) and enriched files (_Additional)
    if '_final' in filename:
        continue
    if '_Additional' in filename:
        continue

    # Create output filename: <name>.csv -> <name>_final.csv
    name_without_ext = filename.replace('.csv', '')
    output_filename = f"{name_without_ext}_final.csv"
    output_path = os.path.join(china_dir, output_filename)
    
    print(f"\n{filename}")
    success = process_file(file_path, output_path)
    
    if success:
        # Re-read the saved file to report what was actually written
        df_check = pd.read_csv(output_path)
        print(f"  ✓ Saved: {output_filename}")
        print(f"    Final rows: {len(df_check):,}")
        print(f"    Columns: {df_check.columns.tolist()}")
        print(f"    nb_product range: {df_check['nb_product'].min()} to {df_check['nb_product'].max()}")
        print(f"    Sample data:")
        print(df_check[['month_id', 'country_id', 'product_id_hs4', 'nb_product']].head(5))

# NOTE(review): printed unconditionally, including when process_file()
# returned False for some files
print("\n" + "=" * 60)
print("All files processed!")
print("=" * 60)

## ADDITIONAL DATA 

In [None]:
import pandas as pd
import os
from glob import glob

# Define directories and file paths
# usa_dir / china_dir hold the *_final.csv trade files to enrich;
# indicators_file and reer_file are the two reference tables merged onto
# them. All but usa_dir are absolute, machine-specific paths.
usa_dir = 'comtrade_monthly_hs4_outputs/merged'
china_dir = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\split'
indicators_file = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\df_long.csv'
reer_file = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\outputs\EER_COUNTRIES.csv'

def read_csv_with_encoding(file_path):
    """Read a CSV file, trying UTF-8 first and latin-1 as a fallback.

    latin-1 maps every possible byte, so nothing listed after it could ever
    be reached; the candidate list is therefore just these two encodings.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        Exception: If no candidate encoding could decode the file.
        OSError, pandas.errors.ParserError, ...: Errors unrelated to the
            encoding (missing file, malformed CSV) propagate unchanged
            instead of being swallowed by a bare ``except``.
    """
    encodings = ['utf-8', 'latin-1']
    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file_path, encoding=encoding)
        except UnicodeDecodeError as e:
            # Only a decoding failure justifies trying the next encoding
            last_error = e
    raise Exception(f"Could not read file: {file_path}") from last_error

# STEP 1 of this cell: load the long-format indicators table, keep the
# constant-price / seasonally-adjusted series of interest and pivot them to
# one row per (country_id, month_id) with one column per indicator.
print("=" * 60)
print("STEP 1: Loading and preparing indicators data...")
print("=" * 60)

# Load indicators data
df_indicators = read_csv_with_encoding(indicators_file)
print(f"\n✓ Loaded indicators data: {len(df_indicators):,} rows")

# Check what indicators are available (diagnostic only: lists indicators
# whose name contains one of the keywords below, with their row counts)
print(f"\nChecking available indicators...")
print(f"Unique indicators in data:")
for ind in df_indicators['INDICATOR'].unique():
    if 'GDP' in ind.upper() or 'GROSS' in ind.upper() or 'CONSUMPTION' in ind.upper() or 'CAPITAL' in ind.upper() or 'INVENTOR' in ind.upper():
        count = len(df_indicators[df_indicators['INDICATOR'] == ind])
        print(f"  - {ind}: {count} rows")

# Show the values the two filter columns can take
print(f"\nUnique PRICE_TYPE: {df_indicators['PRICE_TYPE'].unique()}")
print(f"Unique S_ADJUSTMENT: {df_indicators['S_ADJUSTMENT'].unique()}")

# Filter indicators - let's be more flexible with GDP name
# (case-insensitive substring match against INDICATOR)
indicator_keywords = ['GDP', 'consumption expenditure', 'capital formation', 'inventories']

df_indicators_filtered = df_indicators[
    (df_indicators['PRICE_TYPE'] == 'Constant prices') &
    (df_indicators['S_ADJUSTMENT'] == 'Seasonally adjusted (SA)') &
    (df_indicators['INDICATOR'].apply(lambda x: any(keyword.lower() in x.lower() for keyword in indicator_keywords)))
].copy()

print(f"\n✓ Filtered indicators: {len(df_indicators_filtered):,} rows")
print(f"Indicators found:")
for ind in df_indicators_filtered['INDICATOR'].unique():
    count = len(df_indicators_filtered[df_indicators_filtered['INDICATOR'] == ind])
    print(f"  - {ind}: {count} rows")

# Pivot indicators to create one column per indicator; when several rows
# share the same (ISO3, PERIOD, INDICATOR) the first VALUE is kept
print(f"\n  Pivoting indicators to wide format...")
df_indicators_wide = df_indicators_filtered.pivot_table(
    index=['ISO3', 'PERIOD'],
    columns='INDICATOR',
    values='VALUE',
    aggfunc='first'
).reset_index()

# Rename the key columns to the names used in the trade files
column_rename = {
    'ISO3': 'country_id',
    'PERIOD': 'month_id'
}

# Add renaming for each indicator found (case-sensitive substring match)
# NOTE(review): if two pivoted indicator columns match the same branch they
# are both renamed to the same target, producing duplicate column names --
# verify against the indicators actually present.
for col in df_indicators_wide.columns:
    if 'GDP' in col or 'Gross domestic product' in col:
        column_rename[col] = 'gdp_constant_sa'
    elif 'Final consumption expenditure' in col:
        column_rename[col] = 'final_consumption_constant_sa'
    elif 'Gross capital formation' in col:
        column_rename[col] = 'gross_capital_formation_constant_sa'
    elif 'Changes in inventories' in col:
        column_rename[col] = 'changes_inventories_constant_sa'

df_indicators_wide = df_indicators_wide.rename(columns=column_rename)

print(f"  Pivoted data shape: {df_indicators_wide.shape}")
print(f"  Columns: {df_indicators_wide.columns.tolist()}")

# STEP 2 of this cell: load the wide REER table (one column per month) and
# reshape it to one row per (country_id, month_id) for merging.
print("\n" + "=" * 60)
print("STEP 2: Loading and preparing REER data...")
print("=" * 60)

# Load REER data
df_reer = read_csv_with_encoding(reer_file)
print(f"\n✓ Loaded REER data: {len(df_reer)} rows")
print(f"  Columns: {df_reer.columns.tolist()[:15]}...")  # First 15 columns

# Check indicators available
print(f"\nUnique INDICATOR values:")
for ind in df_reer['INDICATOR'].unique():
    print(f"  - {ind}")

# Filter for REER only (case-insensitive; rows with a missing INDICATOR are dropped)
df_reer_filtered = df_reer[df_reer['INDICATOR'].str.contains('REER', case=False, na=False)].copy()
print(f"\n✓ Filtered for REER: {len(df_reer_filtered)} rows")

# Get the date columns (format: 2021-M01, 2021-M02, etc.)
date_columns = [col for col in df_reer_filtered.columns if '-M' in col]
if not date_columns:
    # Fail with an explicit message instead of an IndexError on date_columns[0]
    raise ValueError(f"No monthly columns (e.g. '2021-M01') found in REER file: {reer_file}")
print(f"  Found {len(date_columns)} date columns from {date_columns[0]} to {date_columns[-1]}")

# Transform from wide to long format
print(f"\n  Transforming REER data from wide to long format...")

# Melt the dataframe
df_reer_long = df_reer_filtered.melt(
    id_vars=['COUNTRY.ID'],
    value_vars=date_columns,
    var_name='period_str',
    value_name='REER'
)

# Convert period format from "2021-M01" to "202101"
# (regex=False: plain substring replacement on every pandas version)
df_reer_long['month_id'] = df_reer_long['period_str'].str.replace('-M', '', regex=False).astype(int)

# Rename COUNTRY.ID to country_id
df_reer_long = df_reer_long.rename(columns={'COUNTRY.ID': 'country_id'})

# Keep only country_id, month_id, and REER
df_reer_final = df_reer_long[['country_id', 'month_id', 'REER']].copy()

# Remove rows with missing REER values
df_reer_final = df_reer_final.dropna(subset=['REER'])

# The trade files are left-merged on (country_id, month_id); a repeated key
# here (e.g. several REER-type indicator rows for one country) would
# duplicate every matching trade row. Keep one value per key, mirroring the
# 'first' aggregation used for the other indicators, and say so.
duplicate_keys = int(df_reer_final.duplicated(subset=['country_id', 'month_id']).sum())
if duplicate_keys:
    print(f"  ⚠️  {duplicate_keys:,} duplicate (country_id, month_id) REER rows found; keeping the first of each")
    df_reer_final = df_reer_final.drop_duplicates(subset=['country_id', 'month_id'], keep='first')

print(f"  Final REER data shape: {df_reer_final.shape}")
print(f"  Sample:")
print(df_reer_final.head(10))

# STEP 3 of this cell: enrich every USA *_final.csv with the indicator
# columns and REER, saving the result as *_Additional.csv next to it.
print("\n" + "=" * 60)
print("STEP 3: Processing USA files...")
print("=" * 60)

# Process USA files
usa_files = glob(os.path.join(usa_dir, "USA_*_final.csv"))
for file_path in usa_files:
    filename = os.path.basename(file_path)
    print(f"\n{filename}")
    
    try:
        # Read file
        df = read_csv_with_encoding(file_path)
        print(f"  Original rows: {len(df):,}")
        
        # Merge with indicators; a left join keeps every trade row, with
        # NaN in the indicator columns where no match exists
        df_merged = df.merge(
            df_indicators_wide,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        # Indicator columns are recognised by the '_constant_sa' suffix given
        # to them when the indicators table was pivoted and renamed
        indicator_cols = [col for col in df_merged.columns if '_constant_sa' in col]
        rows_with_indicators = df_merged[indicator_cols].notna().any(axis=1).sum()
        print(f"  ✓ Added {len(indicator_cols)} indicator columns")
        print(f"    Rows with at least one indicator: {rows_with_indicators:,}")
        
        # Merge with REER (left join again)
        df_merged = df_merged.merge(
            df_reer_final,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        rows_with_reer = df_merged['REER'].notna().sum()
        print(f"  ✓ Added REER column")
        print(f"    Rows with REER data: {rows_with_reer:,}")
        
        # Save as <name>_Additional.csv in the same folder
        output_filename = filename.replace('_final.csv', '_Additional.csv')
        output_path = os.path.join(usa_dir, output_filename)
        df_merged.to_csv(output_path, index=False, encoding='utf-8')
        
        print(f"  ✓ Saved: {output_filename}")
        print(f"    Total columns: {len(df_merged.columns)}")
        print(f"    Indicator columns: {indicator_cols}")
        
    except Exception as e:
        # Best effort: report the failure and continue with the next file
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()

# STEP 4 of this cell: same enrichment as the USA step, applied to the
# China *_final.csv files, followed by the closing summary.
print("\n" + "=" * 60)
print("STEP 4: Processing China files...")
print("=" * 60)

# Process China files
china_files = glob(os.path.join(china_dir, "china_*_final.csv"))
for file_path in china_files:
    filename = os.path.basename(file_path)
    print(f"\n{filename}")
    
    try:
        # Read file
        df = read_csv_with_encoding(file_path)
        print(f"  Original rows: {len(df):,}")
        
        # Merge with indicators; a left join keeps every trade row, with
        # NaN in the indicator columns where no match exists
        df_merged = df.merge(
            df_indicators_wide,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        # Indicator columns are recognised by their '_constant_sa' suffix
        indicator_cols = [col for col in df_merged.columns if '_constant_sa' in col]
        rows_with_indicators = df_merged[indicator_cols].notna().any(axis=1).sum()
        print(f"  ✓ Added {len(indicator_cols)} indicator columns")
        print(f"    Rows with at least one indicator: {rows_with_indicators:,}")
        
        # Merge with REER (left join again)
        df_merged = df_merged.merge(
            df_reer_final,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        rows_with_reer = df_merged['REER'].notna().sum()
        print(f"  ✓ Added REER column")
        print(f"    Rows with REER data: {rows_with_reer:,}")
        
        # Save as <name>_Additional.csv in the same folder
        output_filename = filename.replace('_final.csv', '_Additional.csv')
        output_path = os.path.join(china_dir, output_filename)
        df_merged.to_csv(output_path, index=False, encoding='utf-8')
        
        print(f"  ✓ Saved: {output_filename}")
        print(f"    Total columns: {len(df_merged.columns)}")
        print(f"    Indicator columns: {indicator_cols}")
        
    except Exception as e:
        # Best effort: report the failure and continue with the next file
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()

# NOTE(review): this banner and the column list below are static text; they
# are printed even when a file failed above or an indicator was not found.
print("\n" + "=" * 60)
print("All files processed successfully!")
print("=" * 60)
print("\nNew columns added:")
print("  - gdp_constant_sa (if available in data)")
print("  - final_consumption_constant_sa")
print("  - gross_capital_formation_constant_sa")
print("  - changes_inventories_constant_sa")
print("  - REER (Real Effective Exchange Rate)")
print("=" * 60)

### ADDITIONAL DATA 2023 2024

In [None]:
import pandas as pd
import os
from glob import glob

# Define directories and file paths (absolute, machine-specific)
# input_dir holds the *_finale.csv files to split and also receives every
# file this cell writes; indicators_file and reer_file are the reference
# tables merged onto the trade data.
input_dir = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\processed_input_data'
indicators_file = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\Data\china and gdp data\china and gdp data\df_long.csv'
reer_file = r'C:\Users\wb636273\OneDrive - WBG\Documents\AI4TRADE\outputs\EER_COUNTRIES.csv'

def read_csv_with_encoding(file_path):
    """Read a CSV file, trying UTF-8 first and latin-1 as a fallback.

    latin-1 maps every possible byte, so nothing listed after it could ever
    be reached; the candidate list is therefore just these two encodings.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        Exception: If no candidate encoding could decode the file.
        OSError, pandas.errors.ParserError, ...: Errors unrelated to the
            encoding (missing file, malformed CSV) propagate unchanged
            instead of being swallowed by a bare ``except``.
    """
    encodings = ['utf-8', 'latin-1']
    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file_path, encoding=encoding)
        except UnicodeDecodeError as e:
            # Only a decoding failure justifies trying the next encoding
            last_error = e
    raise Exception(f"Could not read file: {file_path}") from last_error

# Preparation, part 1: load the long-format indicators table, keep the
# constant-price / seasonally-adjusted series of interest and pivot them to
# one row per (country_id, month_id) with one column per indicator.
print("=" * 60)
print("PREPARATION: Loading reference data...")
print("=" * 60)

# Load indicators data
df_indicators = read_csv_with_encoding(indicators_file)
print(f"\n✓ Loaded indicators data: {len(df_indicators):,} rows")

# Filter indicators (case-insensitive substring match against INDICATOR)
indicator_keywords = ['GDP', 'consumption expenditure', 'capital formation', 'inventories']

df_indicators_filtered = df_indicators[
    (df_indicators['PRICE_TYPE'] == 'Constant prices') &
    (df_indicators['S_ADJUSTMENT'] == 'Seasonally adjusted (SA)') &
    (df_indicators['INDICATOR'].apply(lambda x: any(keyword.lower() in x.lower() for keyword in indicator_keywords)))
].copy()

print(f"✓ Filtered indicators: {len(df_indicators_filtered):,} rows")
print(f"Indicators found: {df_indicators_filtered['INDICATOR'].unique().tolist()}")

# Pivot indicators; when several rows share the same
# (ISO3, PERIOD, INDICATOR) the first VALUE is kept
df_indicators_wide = df_indicators_filtered.pivot_table(
    index=['ISO3', 'PERIOD'],
    columns='INDICATOR',
    values='VALUE',
    aggfunc='first'
).reset_index()

# Rename the key columns to the names used in the trade files
column_rename = {
    'ISO3': 'country_id',
    'PERIOD': 'month_id'
}

# Map each pivoted indicator column to a short name (case-sensitive match)
# NOTE(review): if two pivoted indicator columns match the same branch they
# are both renamed to the same target, producing duplicate column names --
# verify against the indicators actually present.
for col in df_indicators_wide.columns:
    if 'GDP' in col or 'Gross domestic product' in col:
        column_rename[col] = 'gdp_constant_sa'
    elif 'Final consumption expenditure' in col:
        column_rename[col] = 'final_consumption_constant_sa'
    elif 'Gross capital formation' in col:
        column_rename[col] = 'gross_capital_formation_constant_sa'
    elif 'Changes in inventories' in col:
        column_rename[col] = 'changes_inventories_constant_sa'

df_indicators_wide = df_indicators_wide.rename(columns=column_rename)
print(f"✓ Pivoted indicators: {df_indicators_wide.shape}")

# Preparation, part 2: load the wide REER table (one column per month) and
# reshape it to one row per (country_id, month_id) for merging.
# Load REER data
df_reer = read_csv_with_encoding(reer_file)
print(f"\n✓ Loaded REER data: {len(df_reer)} rows")

# Filter for REER (case-insensitive; rows with a missing INDICATOR are dropped)
df_reer_filtered = df_reer[df_reer['INDICATOR'].str.contains('REER', case=False, na=False)].copy()

# Get date columns (format: 2021-M01, 2021-M02, etc.)
date_columns = [col for col in df_reer_filtered.columns if '-M' in col]
if not date_columns:
    # Without monthly columns the melt below would silently produce nothing useful
    raise ValueError(f"No monthly columns (e.g. '2021-M01') found in REER file: {reer_file}")

# Melt to long format
df_reer_long = df_reer_filtered.melt(
    id_vars=['COUNTRY.ID'],
    value_vars=date_columns,
    var_name='period_str',
    value_name='REER'
)

# Convert period format from "2021-M01" to 202101
# (regex=False: plain substring replacement on every pandas version)
df_reer_long['month_id'] = df_reer_long['period_str'].str.replace('-M', '', regex=False).astype(int)
df_reer_long = df_reer_long.rename(columns={'COUNTRY.ID': 'country_id'})
df_reer_final = df_reer_long[['country_id', 'month_id', 'REER']].dropna(subset=['REER'])

# The trade files are left-merged on (country_id, month_id); a repeated key
# here (e.g. several REER-type indicator rows for one country) would
# duplicate every matching trade row. Keep one value per key, mirroring the
# 'first' aggregation used for the other indicators, and say so.
duplicate_keys = int(df_reer_final.duplicated(subset=['country_id', 'month_id']).sum())
if duplicate_keys:
    print(f"  ⚠️  {duplicate_keys:,} duplicate (country_id, month_id) REER rows found; keeping the first of each")
    df_reer_final = df_reer_final.drop_duplicates(subset=['country_id', 'month_id'], keep='first')

print(f"✓ Prepared REER data: {df_reer_final.shape}")

# STEP 1 of this cell: split each yearly "<country>_<year>_finale.csv" into
# one "<country>_<year>_<flow>_final.csv" per trade flow, in the same folder.
print("\n" + "=" * 60)
print("STEP 1: Splitting files by trade flow...")
print("=" * 60)

# Files to process
files_to_process = [
    'USA_2023_finale.csv',
    'USA_2024_finale.csv',
    'china_2023_finale.csv',
    'china_2024_finale.csv'
]

split_files = []  # Keep track of split files for next step

for filename in files_to_process:
    file_path = os.path.join(input_dir, filename)
    
    if not os.path.exists(file_path):
        print(f"\n⚠️  File not found: {filename}")
        continue
    
    print(f"\n{filename}")
    
    try:
        # Read file
        df = read_csv_with_encoding(file_path)
        print(f"  Original rows: {len(df):,}")
        print(f"  Columns: {df.columns.tolist()}")
        
        # Check if trade_flow_name exists
        if 'trade_flow_name' not in df.columns:
            print(f"  ⚠️  'trade_flow_name' column not found!")
            continue
        
        # Get unique trade flows. Missing values are excluded: a NaN flow
        # has no .lower() and selects no rows, and hitting it used to abort
        # the remaining flows of the file.
        trade_flows = df['trade_flow_name'].dropna().unique()
        print(f"  Trade flows found: {trade_flows}")
        
        # Extract country and year from filename
        # e.g., "USA_2023_finale.csv" -> "USA", "2023"
        parts = filename.replace('_finale.csv', '').split('_')
        country = parts[0]
        year = parts[1]
        
        # Split by trade flow
        for flow in trade_flows:
            df_flow = df[df['trade_flow_name'] == flow].copy()
            
            # Create filename; str() keeps this working for non-string labels
            flow_name = str(flow).lower()
            output_filename = f"{country}_{year}_{flow_name}_final.csv"
            output_path = os.path.join(input_dir, output_filename)
            
            # Save
            df_flow.to_csv(output_path, index=False, encoding='utf-8')
            print(f"    ✓ {output_filename}: {len(df_flow):,} rows")
            
            # Track for next step
            split_files.append(output_filename)
            
    except Exception as e:
        # Best effort: report the failure and continue with the next file
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()

# STEP 2 of this cell: enrich every split file with the indicator columns
# and REER, save it as *_Additional.csv, then print a summary of what was
# actually produced.
print("\n" + "=" * 60)
print("STEP 2: Adding indicators and REER to create _Additional files...")
print("=" * 60)

additional_files = []  # _Additional files actually written
failed_files = []      # split files that were skipped or raised

for filename in split_files:
    file_path = os.path.join(input_dir, filename)
    print(f"\n{filename}")
    
    try:
        # Read file
        df = read_csv_with_encoding(file_path)
        print(f"  Original rows: {len(df):,}")
        
        # Check required columns BEFORE touching them, so a missing
        # country_id yields the explanatory message rather than a KeyError
        if 'country_id' not in df.columns or 'month_id' not in df.columns:
            print(f"  ⚠️  Required columns (country_id, month_id) not found!")
            print(f"      Available columns: {df.columns.tolist()}")
            failed_files.append(filename)
            continue
        
        # Upper-case the country codes before using them as a merge key
        df["country_id"] = df["country_id"].str.upper()
        
        # Merge with indicators; a left join keeps every trade row
        df_merged = df.merge(
            df_indicators_wide,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        # Indicator columns are recognised by their '_constant_sa' suffix
        indicator_cols = [col for col in df_merged.columns if '_constant_sa' in col]
        rows_with_indicators = df_merged[indicator_cols].notna().any(axis=1).sum()
        print(f"  ✓ Added {len(indicator_cols)} indicator columns")
        print(f"    Indicators: {indicator_cols}")
        print(f"    Rows with at least one indicator: {rows_with_indicators:,}")
        
        # Merge with REER (left join again)
        df_merged = df_merged.merge(
            df_reer_final,
            on=['country_id', 'month_id'],
            how='left'
        )
        
        rows_with_reer = df_merged['REER'].notna().sum()
        print(f"  ✓ Added REER column")
        print(f"    Rows with REER data: {rows_with_reer:,}")
        
        # Save
        output_filename = filename.replace('_final.csv', '_Additional.csv')
        output_path = os.path.join(input_dir, output_filename)
        df_merged.to_csv(output_path, index=False, encoding='utf-8')
        additional_files.append(output_filename)
        
        print(f"  ✓ Saved: {output_filename}")
        print(f"    Total columns: {len(df_merged.columns)}")
        
    except Exception as e:
        # Best effort: report the failure and continue with the next file
        failed_files.append(filename)
        print(f"  ✗ Error: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "=" * 60)
if failed_files:
    print(f"Finished with {len(failed_files)} file(s) skipped or failed:")
    for f in failed_files:
        print(f"    - {f}")
else:
    print("All files processed successfully!")
print("=" * 60)
print("\nOutput files created:")
# Report the real counts instead of assuming all eight files were produced
print(f"  Step 1 - Split by trade flow ({len(split_files)} files):")
for f in split_files:
    print(f"    - {f}")
print(f"\n  Step 2 - Added indicators and REER ({len(additional_files)} files):")
for f in additional_files:
    print(f"    - {f}")
print("=" * 60)