<a href="https://colab.research.google.com/github/Tiru-Kaggundi/Trade_AI/blob/main/data_aggregation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

import os, pandas as pd

BASE_DIR     = "/content/drive/MyDrive/ai4trade"
RAW_DIR      = f"{BASE_DIR}/data/raw"       # where CHN_2023.parquet ... USA_2025.parquet are
INTERIM_DIR  = f"{BASE_DIR}/data/interim"   # national-level outputs will be saved here
os.makedirs(INTERIM_DIR, exist_ok=True)

# EXACT files we produced earlier
FILE_MAP = [
    ("CHN_2023.parquet", "CHN"),
    ("CHN_2024.parquet", "CHN"),
    ("CHN_2025.parquet", "CHN"),
    ("USA_2023.parquet", "USA"),
    ("USA_2024.parquet", "USA"),
    ("USA_2025.parquet", "USA"),
]

print("RAW_DIR:", RAW_DIR)
print("INTERIM_DIR:", INTERIM_DIR)
print("Available:", sorted(os.listdir(RAW_DIR)))


Mounted at /content/drive
RAW_DIR: /content/drive/MyDrive/ai4trade/data/raw
INTERIM_DIR: /content/drive/MyDrive/ai4trade/data/interim
Available: ['CHN_2023.parquet', 'CHN_2024.parquet', 'CHN_2025.parquet', 'USA_2023.parquet', 'USA_2024.parquet', 'USA_2025.parquet', 'parquets_old', 'trade_s_chn_m_hs_2023.csv.zip', 'trade_s_chn_m_hs_2024.csv.zip', 'trade_s_chn_m_hs_2025.csv.zip', 'trade_s_usa_state_m_hs_2023.csv.zip', 'trade_s_usa_state_m_hs_2024.csv.zip', 'trade_s_usa_state_m_hs_2025.csv.zip']


In [6]:
import numpy as np
from IPython.display import display

# -------------------------------------------------------
# Canonical output columns:
# origin, destination, hs6, hs4, trade_flow, month, value
# -------------------------------------------------------

def _parse_month_from_month_id(series):
    """
    Converts month_id like 202503 or '2025-03' into a Timestamp
    representing the first day of that month.
    """
    s = series.astype(str).str.strip()

    # Ensure pandas Series stays intact (avoid numpy array)
    mask = s.str.match(r"^\d{6}$")
    s = s.where(~mask, s.str.slice(0, 4) + "-" + s.str.slice(4, 6))

    month = pd.to_datetime(s, errors="coerce")
    # Convert to first day of month
    month = month.dt.to_period("M").dt.to_timestamp()
    return month


def _final_type_enforce(df):
    """Enforce datatypes, padding, and sanity checks."""
    # HS codes as zero-padded strings
    df["hs6"] = df["hs6"].astype(str).str.replace(r"\.0+$", "", regex=True).str.zfill(6)
    df["hs4"] = df["hs4"].astype(str).str.replace(r"\.0+$", "", regex=True).str.zfill(4)

    # Normalize month to month-start Timestamp
    df["month"] = pd.to_datetime(df["month"]).dt.to_period("M").dt.to_timestamp()

    # Trade flow normalization (title-case)
    df["trade_flow"] = df["trade_flow"].astype(str).str.strip().str.title()

    # Destination upper-case
    df["destination"] = df["destination"].astype(str).str.upper()

    # Value numeric
    df["value"] = pd.to_numeric(df["value"], errors="coerce")

    # --- Assertions ---
    assert df["hs6"].str.len().eq(6).all(), "HS6 must be 6 chars."
    assert df["hs4"].str.len().eq(4).all(), "HS4 must be 4 chars."
    assert df["month"].notna().all(), "Month parse failed."
    assert df["trade_flow"].isin(["Export","Import"]).all(), f"Unexpected trade_flow: {df['trade_flow'].unique()}"
    return df


def standardize_CHN(df_raw):
    """
    China schema (from README):
    month_id, province_id, province_name, trade_flow_name,
    country_id, country_name, product_id, product_name, trade_value
    """
    req = ["month_id","trade_flow_name","country_id","product_id","trade_value"]
    missing = [c for c in req if c not in df_raw.columns]
    if missing:
        raise KeyError(f"Missing columns in CHN file: {missing}")

    out = pd.DataFrame({
        "origin":      "CHN",
        "destination": df_raw["country_id"],
        "hs6":         df_raw["product_id"],
        "hs4":         df_raw["product_id"].astype(str).str.replace(r"\\.0+$", "", regex=True).str[:4],
        "trade_flow":  df_raw["trade_flow_name"],
        "month":       _parse_month_from_month_id(df_raw["month_id"]),
        "value":       df_raw["trade_value"],
    })

    # Normalize trade_flow plural/singular
    flow_map = {"Exports": "Export", "Export": "Export",
                "Imports": "Import", "Import": "Import"}
    out["trade_flow"] = out["trade_flow"].map(flow_map)

    out = out.dropna(subset=["destination","month","value"])
    out = _final_type_enforce(out)
    return out


def standardize_USA(df_raw):
    """
    USA schema (from README):
    month_id, state_id, state_name, trade_flow_name,
    country_id, country_name, product_id, product_name, trade_value
    """
    req = ["month_id","trade_flow_name","country_id","product_id","trade_value"]
    missing = [c for c in req if c not in df_raw.columns]
    if missing:
        raise KeyError(f"Missing columns in USA file: {missing}")

    out = pd.DataFrame({
        "origin":      "USA",
        "destination": df_raw["country_id"],
        "hs6":         df_raw["product_id"],
        "hs4":         df_raw["product_id"].astype(str).str.replace(r"\\.0+$", "", regex=True).str[:4],
        "trade_flow":  df_raw["trade_flow_name"],
        "month":       _parse_month_from_month_id(df_raw["month_id"]),
        "value":       df_raw["trade_value"],
    })

    flow_map = {"Exports": "Export", "Export": "Export",
                "Imports": "Import", "Import": "Import"}
    out["trade_flow"] = out["trade_flow"].map(flow_map)

    out = out.dropna(subset=["destination","month","value"])
    out = _final_type_enforce(out)
    return out


def monthly_totals(df_std):
    """Monthly export/import totals for verification."""
    return (df_std.groupby(["trade_flow","month"], as_index=False)["value"]
                 .sum()
                 .sort_values(["trade_flow","month"]))


def national_aggregate(df_std):
    """Aggregate provinces/states → national totals."""
    return (df_std.groupby(["origin","destination","hs6","hs4","trade_flow","month"], as_index=False)["value"]
                 .sum())


In [7]:
verification_csvs = []
cleaned_paths = []

for fname, country in FILE_MAP:
    in_path = os.path.join(RAW_DIR, fname)
    year    = "".join([c for c in fname if c.isdigit()])[:4]
    out_nat = os.path.join(INTERIM_DIR, f"{country}_{year}_national.parquet")
    out_tot = os.path.join(INTERIM_DIR, f"{country}_{year}_monthly_totals_raw.csv")

    print(f"\n=== {fname} ({country}) ===")
    df_raw = pd.read_parquet(in_path, engine="pyarrow")
    print("Raw shape:", df_raw.shape)

    # Standardize by dataset
    if country == "CHN":
        df_std = standardize_CHN(df_raw)
    else:
        df_std = standardize_USA(df_raw)

    print("Standardized shape:", df_std.shape)
    display(df_std.head(3))

    # A) Monthly totals BEFORE aggregation
    totals_before = monthly_totals(df_std)
    display(totals_before.head(6))
    totals_before.to_csv(out_tot, index=False)
    print("Saved monthly totals (raw) →", out_tot)
    verification_csvs.append(out_tot)

    # B) Aggregate to national level (provinces/states collapsed)
    df_nat = national_aggregate(df_std)
    print("National shape:", df_nat.shape)

    # C) Re-verify totals AFTER aggregation (must match)
    totals_after = monthly_totals(df_nat)
    chk = totals_before.merge(totals_after, on=["trade_flow","month"], how="outer", suffixes=("_before","_after")).fillna(0)
    mism = (chk["value_before"].round(2) != chk["value_after"].round(2))
    if mism.any():
        print("⚠️ Mismatch in monthly totals after aggregation. Showing rows:")
        display(chk[mism].head(20))
        raise AssertionError("Post-aggregation monthly totals do not equal pre-aggregation totals.")
    print("✅ Monthly totals verified equal (pre vs post).")

    # D) Save the national-level parquet with canonical columns
    df_nat = df_nat[["origin","destination","hs6","hs4","trade_flow","month","value"]]
    df_nat.to_parquet(out_nat, index=False, engine="pyarrow")
    print("✅ Wrote:", out_nat)
    cleaned_paths.append(out_nat)

print("\nAll verification CSVs:")
for p in verification_csvs:
    print("  ", p)



=== CHN_2023.parquet (CHN) ===
Raw shape: (19121752, 13)
Standardized shape: (19121752, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,CHN,JPN,710410,7104,Export,2023-01-01,646
1,CHN,JPN,710491,7104,Export,2023-01-01,3426
2,CHN,JPN,711311,7113,Export,2023-01-01,26


Unnamed: 0,trade_flow,month,value
0,Export,2023-01-01,292275367741
1,Export,2023-02-01,214025743000
2,Export,2023-03-01,315588909730
3,Export,2023-04-01,295417464107
4,Export,2023-05-01,283483634249
5,Export,2023-06-01,285321612403


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/CHN_2023_monthly_totals_raw.csv
National shape: (4407542, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/CHN_2023_national.parquet

=== CHN_2024.parquet (CHN) ===
Raw shape: (19624760, 13)
Standardized shape: (19624760, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,CHN,BGD,30199,3019,Import,2024-01-01,93356
1,CHN,BGD,190590,1905,Import,2024-01-01,800
2,CHN,BGD,382319,3823,Import,2024-01-01,739529


Unnamed: 0,trade_flow,month,value
0,Export,2024-01-01,307734572740
1,Export,2024-02-01,220279290654
2,Export,2024-03-01,279681311209
3,Export,2024-04-01,292453428286
4,Export,2024-05-01,302347868345
5,Export,2024-06-01,307854344985


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/CHN_2024_monthly_totals_raw.csv
National shape: (4439253, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/CHN_2024_national.parquet

=== CHN_2025.parquet (CHN) ===
Raw shape: (4916482, 13)
Standardized shape: (4916482, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,CHN,IRL,710491,7104,Import,2025-01-01,22635
1,CHN,ITA,710692,7106,Import,2025-01-01,3650
2,CHN,ITA,711311,7113,Import,2025-01-01,733


Unnamed: 0,trade_flow,month,value
0,Export,2025-01-01,324767552883
1,Export,2025-02-01,215172210554
2,Export,2025-03-01,313911773459
3,Import,2025-01-01,185971625273
4,Import,2025-02-01,183453687281
5,Import,2025-03-01,211269326485


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/CHN_2025_monthly_totals_raw.csv
National shape: (1095572, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/CHN_2025_national.parquet

=== USA_2023.parquet (USA) ===
Raw shape: (22451773, 9)
Standardized shape: (22451773, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,USA,CHN,852869,8528,Import,2023-01-01,264068
1,USA,CHN,852910,8529,Import,2023-01-01,65531
2,USA,CHN,852990,8529,Import,2023-01-01,1023090


Unnamed: 0,trade_flow,month,value
0,Export,2023-01-01,165486160894
1,Export,2023-02-01,158941789023
2,Export,2023-03-01,184458472459
3,Export,2023-04-01,162387696038
4,Export,2023-05-01,166677068442
5,Export,2023-06-01,167232234117


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/USA_2023_monthly_totals_raw.csv
National shape: (4216054, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/USA_2023_national.parquet

=== USA_2024.parquet (USA) ===
Raw shape: (22467146, 9)
Standardized shape: (22467146, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,USA,CAN,540753,5407,Import,2024-01-01,25721
1,USA,CAN,540773,5407,Import,2024-01-01,1900
2,USA,CAN,540793,5407,Import,2024-01-01,19092


Unnamed: 0,trade_flow,month,value
0,Export,2024-01-01,160582317040
1,Export,2024-02-01,167420050869
2,Export,2024-03-01,179325914652
3,Export,2024-04-01,171673234988
4,Export,2024-05-01,173123964953
5,Export,2024-06-01,174359922265


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/USA_2024_monthly_totals_raw.csv
National shape: (4194566, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/USA_2024_national.parquet

=== USA_2025.parquet (USA) ===
Raw shape: (2036113, 9)
Standardized shape: (2036113, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,USA,JPN,880720,8807,Import,2025-01-01,50188
1,USA,JPN,880730,8807,Import,2025-01-01,123876704
2,USA,JPN,900211,9002,Import,2025-01-01,2875


Unnamed: 0,trade_flow,month,value
0,Export,2025-01-01,164025679587
1,Export,2025-02-01,167608804810
2,Import,2025-01-01,317306228983
3,Import,2025-02-01,288187133875


Saved monthly totals (raw) → /content/drive/MyDrive/ai4trade/data/interim/USA_2025_monthly_totals_raw.csv
National shape: (458523, 7)
✅ Monthly totals verified equal (pre vs post).
✅ Wrote: /content/drive/MyDrive/ai4trade/data/interim/USA_2025_national.parquet

All verification CSVs:
   /content/drive/MyDrive/ai4trade/data/interim/CHN_2023_monthly_totals_raw.csv
   /content/drive/MyDrive/ai4trade/data/interim/CHN_2024_monthly_totals_raw.csv
   /content/drive/MyDrive/ai4trade/data/interim/CHN_2025_monthly_totals_raw.csv
   /content/drive/MyDrive/ai4trade/data/interim/USA_2023_monthly_totals_raw.csv
   /content/drive/MyDrive/ai4trade/data/interim/USA_2024_monthly_totals_raw.csv
   /content/drive/MyDrive/ai4trade/data/interim/USA_2025_monthly_totals_raw.csv


In [9]:
frames = []
for p in cleaned_paths:
    df = pd.read_parquet(p, engine="pyarrow")
    # Re-enforce types (belt & suspenders)
    df["hs6"] = df["hs6"].astype(str).str.zfill(6)
    df["hs4"] = df["hs4"].astype(str).str.zfill(4)
    df["month"] = pd.to_datetime(df["month"]).dt.to_period("M").dt.to_timestamp()
    df["trade_flow"] = df["trade_flow"].astype(str).str.title()
    df["destination"] = df["destination"].astype(str).str.upper()
    df["value"] = pd.to_numeric(df["value"], errors="coerce")
    frames.append(df)

combined = pd.concat(frames, ignore_index=True)

# Final checks
assert combined["hs6"].str.len().eq(6).all()
assert combined["hs4"].str.len().eq(4).all()
assert combined["trade_flow"].isin(["Export","Import"]).all()
assert combined["month"].notna().all()

combined_path = os.path.join(INTERIM_DIR, "harmonized_trade_data.parquet")
combined.to_parquet(combined_path, index=False, engine="pyarrow")
print("✅ Harmonized dataset →", combined_path)
print("Shape:", combined.shape)
display(combined.sample(min(5, len(combined))))


✅ Harmonized dataset → /content/drive/MyDrive/ai4trade/data/interim/harmonized_trade_data.parquet
Shape: (18811510, 7)


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
13486339,USA,SGP,731519,7315,Export,2023-05-01,0
14954871,USA,CHN,291816,2918,Import,2024-09-01,101959
5340368,CHN,DEU,740940,7409,Import,2024-12-01,2686563
1458266,CHN,GMB,392490,3924,Export,2023-11-01,103987
18042923,USA,TUR,690390,6903,Export,2024-07-01,0


In [10]:
# Totals by origin-year-flow
overview = (combined
            .assign(year = combined["month"].dt.year)
            .groupby(["origin","year","trade_flow"], as_index=False)["value"].sum()
            .sort_values(["origin","year","trade_flow"]))
display(overview.head(20))

# First few monthly totals per origin (just to eyeball trends)
monthly = (combined.groupby(["origin","trade_flow","month"], as_index=False)["value"].sum()
                    .sort_values(["origin","trade_flow","month"]))
display(monthly.head(24))


Unnamed: 0,origin,year,trade_flow,value
0,CHN,2023,Export,3414265469240
1,CHN,2023,Import,2563583755105
2,CHN,2024,Export,3580262274749
3,CHN,2024,Import,2587149291237
4,CHN,2025,Export,853851536896
5,CHN,2025,Import,580694639039
6,USA,2023,Export,2019159258568
7,USA,2023,Import,3084110146752
8,USA,2024,Export,2064516641219
9,USA,2024,Import,3267388705988


Unnamed: 0,origin,trade_flow,month,value
0,CHN,Export,2023-01-01,292275367741
1,CHN,Export,2023-02-01,214025743000
2,CHN,Export,2023-03-01,315588909730
3,CHN,Export,2023-04-01,295417464107
4,CHN,Export,2023-05-01,283483634249
5,CHN,Export,2023-06-01,285321612403
6,CHN,Export,2023-07-01,281761164276
7,CHN,Export,2023-08-01,284790921653
8,CHN,Export,2023-09-01,299129465971
9,CHN,Export,2023-10-01,274826898113
