In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Input
from tensorflow.keras.utils import set_random_seed
from sklearn.metrics import r2_score

# ----------------- CONFIG -----------------
PARQUET_FILE   = "Exports-by-branches-of-processing-and-countries-2015-2025.parquet"

LOOKBACK       = 12    # months of history fed to the LSTM per prediction
EPOCHS         = 42
BATCH_SIZE     = 12
VAL_FRACTION   = 0.2   # trailing share of windows held out for validation
MIN_SERIES_LEN = 40    # minimum months of data per (country, branch)
MIN_NONZERO    = 12    # require at least 12 non-zero months (real exports)
SEED           = 42

# np.random.seed alone does not seed TensorFlow, so LSTM weight initialization
# and model.fit's default batch shuffling were not reproducible between runs.
# set_random_seed seeds Python's `random`, NumPy and the TF backend in one call.
# (Some GPU/cuDNN kernels may still be non-deterministic run to run.)
set_random_seed(SEED)

# ----------------- HELPERS -----------------
def make_sequences(X, y, lookback=12):
    """Cut aligned arrays into sliding windows for one-step-ahead forecasting.

    Window ``k`` is ``X[k:k + lookback]`` and its target is ``y[k + lookback]``,
    i.e. the value immediately following the window.

    Returns:
        ``(windows, targets)`` as NumPy arrays; both are empty when
        ``len(X) <= lookback``.
    """
    n_windows = len(X) - lookback
    windows = [X[start:start + lookback] for start in range(n_windows)]
    labels = [y[start + lookback] for start in range(n_windows)]
    return np.array(windows), np.array(labels)

def calculate_wape(y_true, y_pred):
    """Return the Weighted Absolute Percentage Error, in percent.

    WAPE = sum(|y_true - y_pred|) / sum(|y_true|) * 100.
    Returns NaN when the actuals sum to zero in absolute value (undefined).
    """
    actual = np.array(y_true, dtype=float)
    forecast = np.array(y_pred, dtype=float)
    total_actual = np.abs(actual).sum()
    if total_actual == 0:
        return np.nan
    abs_error = np.abs(actual - forecast).sum()
    return abs_error / total_actual * 100.0

def build_lstm_model(input_shape):
    """Build and compile a small single-output LSTM regressor.

    Stack: Input(input_shape) -> LSTM(64) -> Dense(32, relu) -> Dense(1),
    compiled with the Adam optimizer and mean-squared-error loss.

    Args:
        input_shape: ``(lookback, num_features)`` describing one input window.
    """
    stack = [
        Input(shape=input_shape),
        LSTM(64),
        Dense(32, activation='relu'),
        Dense(1),
    ]
    # Sequential's constructor adds each entry in order, same as .add().
    model = Sequential(stack)
    model.compile(optimizer='adam', loss='mse')
    return model

# ----------------- DATA LOADING -----------------
def load_all_series(country_filter="United States"):
    """
    Load the parquet and build a long DataFrame:
    Country, Branches, date, value (Fob value only).
    If country_filter is None, uses ALL countries.

    Text labels are stripped of surrounding whitespace first. Country labels in
    this dataset appear to carry inconsistent trailing spaces (e.g.
    "Afghanistan " vs "Aruba"), so without stripping an exact-match filter such
    as "United States" could silently select nothing, and labels differing only
    by padding would form separate series.
    """
    df = pd.read_parquet(PARQUET_FILE)
    print("Parquet columns:", df.columns)

    # Normalize label whitespace before any filtering / grouping.
    for col in ("Country", "Branches", "Unit", "Month"):
        df[col] = df[col].str.strip()

    # Keep only Fob value rows
    df = df[df["Unit"] == "Fob value"].copy()
    df["DATA"] = pd.to_numeric(df["DATA"], errors="coerce")
    df = df.dropna(subset=["DATA"])

    # Month "2015M01" -> datetime
    df["date"] = pd.to_datetime(
        df["Month"].str.replace("M", "-") + "-01",
        format="%Y-%m-%d",
        errors="coerce",
    )
    df = df.dropna(subset=["date"])

    if country_filter is not None:
        df = df[df["Country"] == country_filter.strip()]
        if df.empty:
            # Surface a silent no-match instead of training on nothing.
            print(f"\nWARNING: no rows match country_filter={country_filter!r}")

    # Aggregate per (Country, Branches, date)
    df = (
        df.groupby(["Country", "Branches", "date"], as_index=False)["DATA"]
          .sum()
          .rename(columns={"DATA": "value"})
    )

    print(f"\nTotal rows after filtering: {len(df)}")
    print("Example rows:")
    print(df.head())

    return df

# ----------------- TRAIN ONE SERIES -----------------
def train_lstm_for_series(sales_series):
    """
    Train one LSTM on a single monthly series and score it on a held-out tail.

    sales_series: 1D array-like of Fob value over time (monthly, sorted).

    Returns a dict with keys model, r2, wape, y_mean, y_std, n_points, or None
    when the series is too short, has too few non-zero months, is (nearly)
    constant, is constant over its training span, or is too small to split.

    Normalization statistics (y_mean / y_std) come from the training span only,
    so no information from the validation period leaks into training.
    """
    sales = np.asarray(sales_series, dtype=float)

    # length check
    if len(sales) < MIN_SERIES_LEN:
        return None

    # non-zero export requirement
    if np.count_nonzero(sales) < MIN_NONZERO:
        return None

    # constant or almost-constant series -> nothing to learn
    if np.allclose(sales, sales[0]):
        return None

    # Number of (window, target) pairs make_sequences will produce.
    n_total = max(len(sales) - LOOKBACK, 0)
    if n_total < 10:
        return None

    # Time-based split (no shuffle): first n_train windows train, rest validate.
    n_train = int(np.floor(n_total * (1.0 - VAL_FRACTION)))
    if n_train < 1 or n_train >= n_total:
        return None

    # Training windows (inputs and targets) span sales[0 : n_train + LOOKBACK].
    # Computing mean/std over the full series would leak the validation
    # period's level and scale into training.
    train_span = sales[:n_train + LOOKBACK]
    y_mean = train_span.mean()
    y_std = train_span.std()
    if y_std == 0:
        return None

    # Feature and target are the same series, so one normalization serves both.
    sales_n = ((sales - y_mean) / y_std).reshape(-1, 1)

    X, y = make_sequences(sales_n, sales_n, lookback=LOOKBACK)
    X_train, X_val = X[:n_train], X[n_train:]
    y_train = y[:n_train]
    y_val_n = y[n_train:]

    # Build & train model (validation data is only monitored, not used to stop)
    model = build_lstm_model(input_shape=(LOOKBACK, X.shape[2]))
    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val_n),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=0,  # quiet
    )

    # Predict & denormalize; compare against the raw validation targets.
    y_pred_n = model.predict(X_val, verbose=0)
    y_true = sales[LOOKBACK + n_train:].reshape(-1, 1)
    y_pred = y_pred_n * y_std + y_mean

    # Metrics
    r2   = r2_score(y_true, y_pred)
    wape = calculate_wape(y_true, y_pred)

    return {
        "model":   model,
        "r2":      float(r2),
        "wape":    float(wape),
        "y_mean":  float(y_mean),
        "y_std":   float(y_std),
        "n_points": len(sales),
    }

# ----------------- RUN ACROSS CATEGORIES -----------------
def train_lstm_across_categories(country_filter="United States", *, max_r2=0.9999):
    """
    Trains one LSTM per (Country, Branches) time series.

    Args:
      country_filter: country label to restrict to, or None for all countries.
      max_r2: fits with R² >= max_r2 (or NaN R²) are treated as suspicious.
              They are excluded both from results_df and from best-model
              selection, so the two always agree.

    Returns:
      - results_df: summary DataFrame with metrics (suspicious fits removed)
      - best_info: dict containing the best model and meta-data, or None if
                   no plausible series remained
    """
    df = load_all_series(country_filter=country_filter)

    results = []
    best_info = None

    # Get unique countries if no filter
    if country_filter is None:
        countries = df["Country"].unique()
        print(f"\nProcessing {len(countries)} countries:")
        for country in countries:
            print(f"- {country}")

    grouped = df.sort_values("date").groupby(["Country", "Branches"])
    total_groups = len(grouped)
    current_country = None

    print(f"\nNumber of (Country, Branch) series: {total_groups}")

    for idx, ((country, branch), sub) in enumerate(grouped, start=1):
        # Track country changes
        if current_country != country:
            current_country = country
            print(f"\nProcessing country: {country}")

        # NOTE(review): assumes one row per month with no gaps; missing months
        # would silently misalign the lookback windows. Verify, or reindex to
        # a monthly frequency before training.
        info = train_lstm_for_series(sub["value"].values)

        if info is not None:
            results.append({
                "country":   country,
                "branch":    branch,
                "n_points":  info["n_points"],
                "r2":        info["r2"],
                "wape":      info["wape"],
            })

            # Track best model by R², ignoring the same suspicious fits that
            # are dropped from results_df (NaN < max_r2 is False, so NaN too).
            plausible = info["r2"] < max_r2
            if plausible and (best_info is None or info["r2"] > best_info["r2"]):
                best_info = {
                    "country": country,
                    "branch":  branch,
                    **info,
                }

        # light progress indicator (counts skipped series too)
        if idx % 100 == 0:
            print(f"Processed {idx}/{total_groups} series...")

    if not results:
        print("\nNo valid series found (not enough data / mostly zero).")
        return pd.DataFrame(), None

    results_df = pd.DataFrame(results)

    # Drop suspicious (near-)perfect fits
    results_df = results_df[results_df["r2"] < max_r2]

    if results_df.empty or best_info is None:
        print(f"\nAll {len(results)} trained series were filtered out (R² >= {max_r2}).")
        return results_df, None

    # Summary: top 20 by R²
    print("\n===== TOP 20 CATEGORIES BY R² (LSTM, Fob value only) =====")
    print(
        results_df.sort_values("r2", ascending=False)
                  .head(20)
                  .to_string(index=False)
    )

    # Best model info (printed from best_info so it matches what is returned)
    print("\n===== BEST CATEGORY (HIGHEST R²) =====")
    print(f"Country: {best_info['country']}")
    print(f"Branch:  {best_info['branch']}")
    print(f"R²:      {best_info['r2']:.4f}")
    print(f"WAPE:    {best_info['wape']:.2f}%")
    print(f"Points:  {best_info['n_points']}")

    return results_df, best_info

# ----------------- RUN IT -----------------
# Run across ALL countries (country_filter=None). This is heavy: one LSTM is
# trained for every (Country, Branch) series in the dataset.
results_df, best_info = train_lstm_across_categories(country_filter=None)

# For a single, much cheaper country (label must match the Country column), do:
# results_df, best_info = train_lstm_across_categories(country_filter="United States")

2025-11-08 15:51:31.867113: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Parquet columns: Index(['Branches', 'Country', 'Month', 'Unit', 'DATA'], dtype='object')

Total rows after filtering: 1542288
Example rows:
        Country                                 Branches       date  value
0  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-01-01    0.0
1  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-02-01    0.0
2  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-03-01    0.0
3  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-04-01    0.0
4  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-05-01    0.0

Processing 253 countries:
- Afghanistan 
- Albania 
- Algeria 
- American Samoa
- Andorra 
- Angola 
- Anguilla
- Antarctica
- Antigua and Barbuda
- Argentina 
- Armenia 
- Aruba
- Australia
- Austria 
- Azerbaijan 
- Bahamas
- Bahrain 
- Bangladesh 
- Barbados
- Belarus 
- Belgium 
- Belize 
- Benin 
- Bermuda
- Bhutan 
- Bolivia 
- Bonaire, Sint Eustatius and Saba
- Bosnia and Herzegovina 
- Botswana 
- 

I0000 00:00:1762613499.899312  218248 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5562 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070, pci bus id: 0000:01:00.0, compute capability: 8.6
2025-11-08 15:51:42.438647: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91500



Processing country: Argentina 

Processing country: Armenia 

Processing country: Aruba

Processing country: Australia

Processing country: Austria 

Processing country: Azerbaijan 

Processing country: Bahamas

Processing country: Bahrain 

Processing country: Bangladesh 

Processing country: Barbados

Processing country: Belarus 

Processing country: Belgium 
Processed 1000/12144 series...

Processing country: Belize 

Processing country: Benin 

Processing country: Bermuda

Processing country: Bhutan 

Processing country: Bolivia 

Processing country: Bonaire, Sint Eustatius and Saba

Processing country: Bosnia and Herzegovina 

Processing country: Botswana 

Processing country: Bouvet Island

Processing country: Brazil 

Processing country: British Indian Ocean Territory

Processing country: Brunei Darussalam

Processing country: Bulgaria 

Processing country: Burkina Faso 

Processing country: Burundi 

Processing country: Cabo Verde

Processing country: Cambodia 

Processing cou