<a href="https://colab.research.google.com/github/Tiru-Kaggundi/Trade_AI/blob/main/China_import_2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Install dependencies (first run only)
!pip -q install pycountry openpyxl

import re
import pycountry
import pandas as pd
import numpy as np
from pathlib import Path

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/6.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m3.2/6.3 MB[0m [31m96.3 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m6.3/6.3 MB[0m [31m125.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m78.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
#@title Mount Google Drive and set base paths
# Colab-only: exposes the user's Google Drive under /content/drive.
from google.colab import drive

drive.mount('/content/drive')

# Project root on Drive; raw input and interim output both live under <root>/data.
BASE_DIR = Path('/content/drive/MyDrive/ai4trade')
RAW_PATH = BASE_DIR.joinpath('data', 'raw', 'imports_2025_chn.xlsx')
OUT_PATH = BASE_DIR.joinpath('data', 'interim', 'CHN_imports_2025.parquet')

print(f"Reading from: {RAW_PATH}")
print(f"Will write to: {OUT_PATH}")

Mounted at /content/drive
Reading from: /content/drive/MyDrive/ai4trade/data/raw/imports_2025_chn.xlsx
Will write to: /content/drive/MyDrive/ai4trade/data/interim/CHN_imports_2025.parquet


In [None]:
#@title Utility functions

def clean_hs6(x) -> str:
    """
    Coerce an HS product code into a clean 6-char zero-padded string.

    - Integral numbers are first converted through ``int``. This covers floats such as
      10121.0, which pandas produces when it reads the code column as numeric (e.g.
      because the column contains blanks). Without it, the '.0' would leak in as an
      extra digit: '101210' instead of '010121'.
    - Text input has commas, apostrophes, spaces and dots removed, and only digits are
      kept. Examples: "'010121" -> "010121", "8542.31" -> "854231".
    - The result is zfill(6) and truncated to 6 chars, so longer national codes keep
      their HS6 prefix.

    Returns np.nan for missing input or for input that contains no digits.
    """
    if pd.isna(x):
        return np.nan
    # str(10121.0) == '10121.0'; removing the dot would append a spurious '0'.
    # The is_integer() guard also keeps inf/non-integral floats out of int().
    if isinstance(x, (float, np.floating)) and float(x).is_integer():
        x = int(x)
    s = str(x)
    s = s.replace(",", "").replace("’", "").replace("'", "").replace(" ", "").replace(".", "")
    s = re.sub(r"[^0-9]", "", s)  # keep only digits
    if s == "":
        return np.nan
    return s.zfill(6)[:6]

def derive_hs4(hs6: str) -> str:
    """
    Return the HS4 heading of an HS6 code, i.e. its first four characters.

    Missing input (NaN/None) is passed through as np.nan.
    """
    return np.nan if pd.isna(hs6) else str(hs6)[:4]

def parse_month_header(colname: str):
    """
    Parse a month column header into the first day of that month.

    Recognised forms:
      - Trademap's own period token 'YYYY-Mmm', e.g. 'Imported value in 2025-M01'
        (pandas cannot parse this).
      - Anything pd.to_datetime infers, plus the explicit formats 'Jan 2025',
        'Jan-2025', '2025-01', '2025 Jan', 'January 2025' and 'January-2025'.

    Returns pd.Timestamp (first day of month), or None when nothing matches. On None,
    the caller falls back to positional month assignment.
    """
    # Trademap monthly headers embed the period as '<year>-M<month>'.
    m = re.search(r"(\d{4})\s*-?\s*M(\d{1,2})\b", str(colname))
    if m and 1 <= int(m.group(2)) <= 12:
        return pd.Timestamp(year=int(m.group(1)), month=int(m.group(2)), day=1)

    # Fast paths with pandas: inferred format first, then explicit formats.
    for fmt in [None, '%b %Y', '%b-%Y', '%Y-%m', '%Y %b', '%B %Y', '%B-%Y']:
        try:
            dt = pd.to_datetime(colname, format=fmt, errors='raise')
            return pd.Timestamp(year=dt.year, month=dt.month, day=1)
        except Exception:
            # Best-effort parsing: try the next format.
            pass
    # Last resort: month order fallback handled by the caller.
    return None

def is_iso3(code: str) -> bool:
    """
    True when ``code`` (case-insensitive) is an ISO 3166-1 alpha-3 code known to pycountry.

    Any failure during the lookup is reported as False rather than raised. This includes
    non-string input and lookup errors.
    """
    try:
        country = pycountry.countries.get(alpha_3=code.upper())
    except Exception:
        return False
    return country is not None

def extract_partner_from_sheet(sheet_name: str, prefix="China-") -> str:
    """
    Extract the upper-cased partner code from a sheet name of the form '<prefix>XXX'.

    When the name does not start with ``prefix``, the function tolerates variants such
    as 'China – XXX' or 'China_XXX' (spaces, dashes, underscores). Failing that, it
    uses the text after the last '-'. No ISO3 validation is done here; callers check
    the result with is_iso3.
    """
    if sheet_name.startswith(prefix):
        return sheet_name.replace(prefix, "").strip().upper()

    variant = re.match(r"China[\s\-–—_]+([A-Za-z]{3})$", sheet_name.strip())
    if variant:
        return variant.group(1).upper()
    return sheet_name.split("-")[-1].strip().upper()

In [None]:
#@title Read Excel (all sheets) and validate partner ISO3 codes
# sheet_name=None returns a dict {sheet_name: DataFrame} covering every sheet in the workbook.
xlsx = pd.read_excel(RAW_PATH, sheet_name=None, header=0, engine='openpyxl')

# Partner code per sheet, plus whether pycountry recognises it (looked up once per sheet).
sheet_partners = {sheet: extract_partner_from_sheet(sheet) for sheet in xlsx.keys()}
iso3_ok = {sheet: is_iso3(code) for sheet, code in sheet_partners.items()}
invalid = [(sheet, code) for sheet, code in sheet_partners.items() if not iso3_ok[sheet]]

print("Detected sheets and partners:")
for s, c in sheet_partners.items():
    print(f"  {s}  ->  {c} (ISO3={'OK' if iso3_ok[s] else 'INVALID'})")

if invalid:
    msg = "Some sheet names don't map to valid ISO3 partners:\n" + \
          "\n".join([f"  - Sheet '{s}' → '{c}' (INVALID)" for s, c in invalid]) + \
          "\n\nPlease rename those sheets to 'China-XXX' with a valid ISO3 code and rerun."
    raise ValueError(msg)

# Two sheets resolving to the same partner would both be normalized and concatenated
# later, silently double-counting that partner's trade. Fail fast instead.
sheets_by_partner = {}
for s, c in sheet_partners.items():
    sheets_by_partner.setdefault(c, []).append(s)
duplicate_partners = {c: sheets for c, sheets in sheets_by_partner.items() if len(sheets) > 1}
if duplicate_partners:
    raise ValueError(
        "Several sheets map to the same partner (their data would be double-counted):\n"
        + "\n".join(f"  - '{c}' from sheets {sheets}" for c, sheets in duplicate_partners.items())
        + "\n\nPlease merge or remove the duplicate sheets and rerun."
    )

Detected sheets and partners:
  China-USA  ->  USA (ISO3=OK)
  China-HKG  ->  HKG (ISO3=OK)
  China-JPN  ->  JPN (ISO3=OK)
  China-VNM  ->  VNM (ISO3=OK)
  China-KOR  ->  KOR (ISO3=OK)
  China-IND  ->  IND (ISO3=OK)
  China-RUS  ->  RUS (ISO3=OK)
  China-DEU  ->  DEU (ISO3=OK)
  China-NLD  ->  NLD (ISO3=OK)
  China-MYS  ->  MYS (ISO3=OK)
  China-MEX  ->  MEX (ISO3=OK)
  China-THA  ->  THA (ISO3=OK)
  China-SGP  ->  SGP (ISO3=OK)
  China-GBR  ->  GBR (ISO3=OK)
  China-AUS  ->  AUS (ISO3=OK)
  China-TWN  ->  TWN (ISO3=OK)
  China-IDN  ->  IDN (ISO3=OK)
  China-BRA  ->  BRA (ISO3=OK)
  China-ARE  ->  ARE (ISO3=OK)
  China-PHL  ->  PHL (ISO3=OK)
  China-SAU  ->  SAU (ISO3=OK)
  China-CAN  ->  CAN (ISO3=OK)
  China-ITA  ->  ITA (ISO3=OK)
  China-FRA  ->  FRA (ISO3=OK)
  China-KAZ  ->  KAZ (ISO3=OK)
  China-CHE  ->  CHE (ISO3=OK)
  China-CHL  ->  CHL (ISO3=OK)
  China-IRQ  ->  IRQ (ISO3=OK)
  China-ZAF  ->  ZAF (ISO3=OK)
  China-OMN  ->  OMN (ISO3=OK)
  China-PER  ->  PER (ISO3=OK)
  China-Q

In [None]:
#@title Function to normalize a single partner sheet
def normalize_partner_sheet(df_raw: pd.DataFrame, partner_iso3: str,
                            fallback_start: str = '2025-01',
                            fallback_end: str = '2025-08') -> pd.DataFrame:
    """
    Convert one wide partner sheet into the project's long format. The sheet has a
    product code in the first column and one column per month.

    Steps:
      1) Drop an obvious 'Total' line if present.
      2) Clean HS6 codes (rows whose code has no digits are dropped); derive HS4.
      3) Identify month columns and melt to long. Headers parse_month_header cannot
         read are assigned months by column position, starting at ``fallback_start``.
      4) Attach identifiers: origin='CHN', destination=partner_iso3, trade_flow='Import'.
      5) Remove zero / non-numeric value rows after hs4 derivation.

    NOTE(review): the rows appear to be China's imports *from* ``partner_iso3`` (sheet
    'China-XXX', trade_flow 'Import'), yet origin is 'CHN' and destination is the
    partner. Downstream cells group by 'destination' as the exporting partner, so the
    labelling is kept as-is. Confirm the intended convention.

    Args:
        df_raw: sheet as returned by pd.read_excel. The first column holds the product
            code; every other column is treated as a month column.
        partner_iso3: ISO3 code written to the 'destination' column.
        fallback_start, fallback_end: inclusive 'YYYY-MM' range used for positional
            assignment of unparseable headers. The defaults keep the original Jan..Aug
            2025 assumption.

    Returns:
        DataFrame with columns origin, destination, hs6, hs4, trade_flow, month, value,
        containing only rows with value > 0.

    Raises:
        ValueError: if an unparseable month column lies beyond the fallback range, so
            no month can be assigned to it.
    """
    first_col = df_raw.columns[0]

    # Drop rows whose first column is literally 'Total' (case-insensitive).
    is_total = df_raw[first_col].astype(str).str.contains(r'^total$', case=False, na=False)
    # .copy() so the column assignment below writes to an independent frame rather than
    # a slice of df_raw (avoids SettingWithCopyWarning).
    df = df_raw.loc[~is_total].copy()

    # Clean HS6; codes without any digits become NaN and are dropped.
    df[first_col] = df[first_col].map(clean_hs6)
    df = df[df[first_col].notna()]

    # Build month mapping from headers.
    month_cols = [c for c in df.columns if c != first_col]
    parsed = {c: parse_month_header(str(c)) for c in month_cols}
    unparsed = [c for c in month_cols if parsed[c] is None]
    if unparsed:
        # Positional fallback: the i-th non-code column is the i-th month of the range.
        # NOTE(review): if only some headers parse, a positional month may duplicate one
        # already assigned by parsing. Headers are assumed to be all-or-nothing.
        months = pd.period_range(fallback_start, fallback_end, freq='M')
        fallback = {c: pd.Timestamp(m.start_time) for c, m in zip(month_cols, months)}
        beyond_range = [c for c in unparsed if c not in fallback]
        if beyond_range:
            raise ValueError(
                f"Sheet for {partner_iso3}: could not parse month header(s) {beyond_range!r} "
                f"and they fall outside the positional fallback range "
                f"{fallback_start}..{fallback_end}."
            )
        for c in unparsed:
            parsed[c] = fallback[c]

    # Melt to long and map each source column to its month.
    df_long = df.melt(id_vars=[first_col], value_vars=month_cols,
                      var_name='month_col', value_name='value')
    df_long['month'] = df_long['month_col'].map(parsed)
    df_long = df_long.drop(columns=['month_col'])

    # Trademap text cells can carry thousands separators ("1,234"). pd.to_numeric would
    # turn them into NaN -> 0, and the row would then be dropped as a zero, so strip them.
    raw_values = df_long['value']
    if raw_values.dtype == object:
        raw_values = raw_values.map(
            lambda v: v.replace(',', '').strip() if isinstance(v, str) else v
        )
    df_long['value'] = pd.to_numeric(raw_values, errors='coerce').fillna(0.0)

    # Identifiers
    df_long = df_long.rename(columns={first_col: 'hs6'})
    df_long['hs4'] = df_long['hs6'].map(derive_hs4)
    df_long['origin'] = 'CHN'
    df_long['destination'] = partner_iso3
    df_long['trade_flow'] = 'Import'

    # Remove zero rows *after* hs4 derivation to avoid bloat per project rule, then
    # select the final columns (copy so the dtype fixes below act on an owned frame).
    df_long = df_long.loc[
        df_long['value'] > 0,
        ['origin', 'destination', 'hs6', 'hs4', 'trade_flow', 'month', 'value'],
    ].copy()

    # Final dtypes: month normalised to the first day of the month; codes as strings.
    df_long['month'] = pd.to_datetime(df_long['month']).dt.to_period('M').dt.to_timestamp()
    df_long['hs6'] = df_long['hs6'].astype(str)
    df_long['hs4'] = df_long['hs4'].astype(str)

    return df_long

In [None]:
#@title Normalize all partner sheets and concatenate
# One long-format frame per sheet (partner code resolved earlier), stacked into one table.
all_parts = [
    normalize_partner_sheet(df_raw, sheet_partners[sheet_name])
    for sheet_name, df_raw in xlsx.items()
]
df_all = pd.concat(all_parts, axis=0, ignore_index=True)

print("Preview:")
display(df_all.head(10))
print(df_all.dtypes)
print(f"Rows after zero-drop: {len(df_all):,}")

Preview:


Unnamed: 0,origin,destination,hs6,hs4,trade_flow,month,value
0,CHN,USA,854231,8542,Import,2025-01-01,1091959000
1,CHN,USA,854239,8542,Import,2025-01-01,168780000
2,CHN,USA,880240,8802,Import,2025-01-01,329769000
3,CHN,USA,841191,8411,Import,2025-01-01,313125000
4,CHN,USA,271112,2711,Import,2025-01-01,890427000
5,CHN,USA,300490,3004,Import,2025-01-01,105188000
6,CHN,USA,848620,8486,Import,2025-01-01,255541000
7,CHN,USA,870323,8703,Import,2025-01-01,124109000
8,CHN,USA,300215,3002,Import,2025-01-01,115496000
9,CHN,USA,854233,8542,Import,2025-01-01,50731000


origin                 object
destination            object
hs6                    object
hs4                    object
trade_flow             object
month          datetime64[ns]
value                   int64
dtype: object
Rows after zero-drop: 287,162


In [None]:
#@title Partner filter: ≥100 unique HS4s (positive) on average per month
MIN_AVG_MONTHLY_HS4 = 100  # eligibility threshold (average unique HS4s per month)

# Ensure months are only Jan..Aug 2025 (in case file had extras). Everything below,
# INCLUDING the final partner filter, uses this windowed frame, so the date-bound
# sanity checks hold.
mask_2025_jan_aug = (df_all['month'] >= '2025-01-01') & (df_all['month'] <= '2025-08-31')
df_ja = df_all.loc[mask_2025_jan_aug].copy()

# Months in scope = every month present in the windowed data.
months_in_scope = pd.DatetimeIndex(df_ja['month'].unique()).sort_values()

# Count unique hs4 per (destination, month) with value>0
uniq_hs4 = (
    df_ja[df_ja['value'] > 0]
    .groupby(['destination', 'month'])['hs4']
    .nunique()
)
# Zero-value rows were dropped upstream, so a month in which a partner shipped nothing
# has no row at all. Reindex onto the full partner x month grid with 0. The average is
# then over all months in scope, not just the partner's active months.
full_grid = pd.MultiIndex.from_product(
    [uniq_hs4.index.unique('destination'), months_in_scope],
    names=['destination', 'month'],
)
monthly_counts = uniq_hs4.reindex(full_grid, fill_value=0).reset_index(name='uniq_hs4')

avg_counts = (
    monthly_counts
    .groupby('destination')['uniq_hs4']
    .mean()
    .reset_index(name='avg_monthly_uniq_hs4')
)

eligible_partners = avg_counts.loc[
    avg_counts['avg_monthly_uniq_hs4'] >= MIN_AVG_MONTHLY_HS4, 'destination'
].tolist()
print("Eligible partners (avg ≥ 100 HS4s):", eligible_partners)

df_filtered = df_ja[df_ja['destination'].isin(eligible_partners)].copy()
print(f"Rows after partner filter: {len(df_filtered):,}  (from {len(df_all):,})")

Eligible partners (avg ≥ 100 HS4s): ['ARE', 'AUS', 'BRA', 'CAN', 'CHE', 'CHL', 'DEU', 'FRA', 'GBR', 'HKG', 'IDN', 'IND', 'IRL', 'ITA', 'JPN', 'KOR', 'MEX', 'MYS', 'NLD', 'NZL', 'PER', 'PHL', 'RUS', 'SGP', 'THA', 'TWN', 'USA', 'VNM', 'ZAF']
Rows after partner filter: 283,772  (from 287,162)


In [None]:
#@title Sanity checks
# Schema: exactly the seven output columns, in any order.
assert set(df_filtered.columns) == {'origin', 'destination', 'hs6', 'hs4', 'trade_flow', 'month', 'value'}

# Only China imports in 2025 Jan..Aug
assert df_filtered['origin'].eq('CHN').all()
assert df_filtered['trade_flow'].eq('Import').all()
assert df_filtered['month'].min() >= pd.Timestamp('2025-01-01')
assert df_filtered['month'].max() <= pd.Timestamp('2025-08-31')

# HS code sanity: fixed-width product codes
assert (df_filtered['hs6'].str.len() == 6).all()
assert (df_filtered['hs4'].str.len() == 4).all()

print("All checks passed.")

All checks passed.


In [None]:
#@title Write parquet to ai4trade/data/interim
# parents/exist_ok make the folder creation safe on first and repeat runs alike.
OUT_PATH.parent.mkdir(exist_ok=True, parents=True)
df_filtered.to_parquet(OUT_PATH, index=False)  # row index is not meaningful; drop it
print(f"Saved: {OUT_PATH}")

Saved: /content/drive/MyDrive/ai4trade/data/interim/CHN_imports_2025.parquet


In [None]:
#@title Calculate average monthly exports and unique HS4s for eligible partners
# Zero rows were dropped upstream, so a month in which a partner shipped nothing has no
# row at all. A plain groupby-mean would average only over active months and overstate
# partners with gaps. Reindex every partner onto the full set of months in the data,
# filling gaps with 0.
months_in_scope = pd.DatetimeIndex(df_filtered['month'].unique()).sort_values()

monthly_by_partner = (
    df_filtered
    .groupby(['destination', 'month'])
    .agg(monthly_export_value=('value', 'sum'),
         monthly_unique_hs4=('hs4', 'nunique'))
)
full_grid = pd.MultiIndex.from_product(
    [monthly_by_partner.index.unique('destination'), months_in_scope],
    names=['destination', 'month'],
)
monthly_by_partner = monthly_by_partner.reindex(full_grid, fill_value=0)

avg_exports = (
    monthly_by_partner.groupby('destination')['monthly_export_value']
    .mean().reset_index(name='avg_monthly_exports_to_chn')
)

avg_unique_hs4 = (
    monthly_by_partner.groupby('destination')['monthly_unique_hs4']
    .mean().reset_index(name='avg_monthly_unique_hs4')
)

# Merge the two dataframes
partner_summary = pd.merge(avg_exports, avg_unique_hs4, on='destination')

# Sort by average monthly exports in descending order
partner_summary_sorted = partner_summary.sort_values(by='avg_monthly_exports_to_chn', ascending=False)

print("Average monthly exports to China and average number of unique HS4s for eligible partners:")
display(partner_summary_sorted)

Average monthly exports to China and average number of unique HS4s for eligible partners:


Unnamed: 0,destination,avg_monthly_exports_to_chn,avg_monthly_unique_hs4
25,TWN,18392100000.0,727.5
15,KOR,14587270000.0,797.75
14,JPN,12939970000.0,905.25
26,USA,12037450000.0,906.375
1,AUS,10236570000.0,431.875
22,RUS,9564752000.0,341.5
2,BRA,9002591000.0,329.375
17,MYS,7744482000.0,526.375
6,DEU,7560687000.0,848.875
27,VNM,7484745000.0,586.625


In [None]:
# Top 3 HS4 items exported to China by each eligible partner.
# Step 1: total value per (partner, HS4) across all months in df_filtered.
hs4_exports = (
    df_filtered
    .groupby(['destination', 'hs4'])['value']
    .sum()
    .reset_index(name='total_export_value_to_chn_by_hs4')
)

# Step 2: rank within each partner, largest first. method='first' breaks ties by row
# order, so ranks are unique within a partner.
hs4_exports['rank'] = (
    hs4_exports
    .groupby('destination')['total_export_value_to_chn_by_hs4']
    .rank(method='first', ascending=False)
)

# Step 3: keep ranks 1..3.
top_hs4_exports = hs4_exports.loc[hs4_exports['rank'].le(3)]

print("Top 3 HS4 items of export to China for each eligible partner country:")
display(top_hs4_exports.sort_values(by=['destination', 'rank']))

Top 3 HS4 items of export to China for each eligible partner country:


Unnamed: 0,destination,hs4,total_export_value_to_chn_by_hs4,rank
44,ARE,2709,12252413000,1.0
46,ARE,2711,3298616000,2.0
96,ARE,3901,1432401000,3.0
495,AUS,2601,46690047000,1.0
512,AUS,2711,7600697000,2.0
...,...,...,...,...
18885,VNM,8471,7209714000,2.0
18887,VNM,8473,5604994000,3.0
19233,ZAF,7108,5940300000,1.0
19078,ZAF,2610,2908908000,2.0
