In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import r2_score

# ----------------- CONFIG -----------------
# Parquet file read by load_all_series().
PARQUET_FILE = "Exports-by-branches-of-processing-and-countries-2015-2025.parquet"

# Windowing / training settings shared by every per-series model.
LOOKBACK = 12  # months
EPOCHS = 42
BATCH_SIZE = 1
VAL_FRACTION = 0.2
MIN_SERIES_LEN = 40  # minimum months of data per series

# Seeds NumPy's global RNG only; no TensorFlow/Keras seed is set in this cell.
np.random.seed(42)

# ----------------- HELPERS -----------------
def make_sequences(X, y, lookback=12):
    """
    Slice aligned arrays into sliding windows for sequence models.

    For every start index i in range(len(X) - lookback), the window
    X[i:i+lookback] is paired with the label y[i+lookback].

    X:        indexable array of per-step features (2D in this file).
    y:        indexable array of per-step targets, aligned with X.
    lookback: number of consecutive steps in each window.

    Returns (windows, labels) as NumPy arrays. For a 2D X the windows have
    shape (len(X) - lookback, lookback, num_features). When len(X) <= lookback
    no window exists and both arrays are empty with shape (0,).
    """
    n_windows = len(X) - lookback
    windows = [X[start:start + lookback] for start in range(n_windows)]
    labels = [y[start + lookback] for start in range(n_windows)]
    return np.array(windows), np.array(labels)

def calculate_wape(y_true, y_pred):
    """
    Weighted absolute percentage error, in percent:
    100 * sum(|y_true - y_pred|) / sum(|y_true|).

    y_true: array-like of actual values (any shape).
    y_pred: array-like of predicted values with the same number of elements.

    Both inputs are flattened before comparison so that a 1D vector and an
    (n, 1) column are compared element by element instead of being broadcast
    against each other into an (n, n) matrix.

    Returns a float, or NaN when sum(|y_true|) is zero (the ratio is undefined).
    Raises ValueError (from NumPy) if the flattened lengths differ and cannot
    be broadcast.
    """
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()

    denom = np.sum(np.abs(actual))
    if denom == 0:
        return np.nan
    return np.sum(np.abs(actual - predicted)) / denom * 100.0

def build_lstm_model(input_shape):
    """
    Build and compile the LSTM regressor trained on each series.

    Architecture: Input(input_shape) -> LSTM(64) -> Dense(32, relu) -> Dense(1),
    compiled with the Adam optimizer and mean-squared-error loss.

    input_shape: (timesteps, num_features) of one input window, without the
                 batch dimension.

    Returns the compiled Sequential model.
    """
    # Imported locally so this function brings its own dependency into scope.
    from tensorflow.keras.layers import Input

    # An explicit Input layer replaces the deprecated `input_shape=` argument
    # on the first layer, which makes Keras emit a UserWarning for every model.
    model = Sequential([
        Input(shape=input_shape),
        LSTM(64),
        Dense(32, activation="relu"),
        Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model

# ----------------- DATA LOADING -----------------
def load_all_series(country_filter="United States"):
    """
    Read PARQUET_FILE and return its Fob-value observations in long format.

    Result columns: Country, Branches, date, value (summed Fob value).

    country_filter: keep only rows whose Country equals this value;
                    pass None to keep ALL countries.
    """
    raw = pd.read_parquet(PARQUET_FILE)
    print("Parquet columns:", raw.columns)

    # Only "Fob value" rows are used; DATA entries that are not numeric
    # become NaN and are discarded.
    fob = raw.loc[raw["Unit"] == "Fob value"].copy()
    fob["DATA"] = pd.to_numeric(fob["DATA"], errors="coerce")
    fob = fob.dropna(subset=["DATA"])

    # "2015M01" -> "2015-01-01" -> Timestamp; months that fail to parse
    # become NaT and are discarded.
    month_text = fob["Month"].str.replace("M", "-") + "-01"
    fob["date"] = pd.to_datetime(month_text, format="%Y-%m-%d", errors="coerce")
    fob = fob.dropna(subset=["date"])

    if country_filter is not None:
        fob = fob[fob["Country"] == country_filter]

    # Collapse duplicate (Country, Branches, date) rows by summing them.
    keys = ["Country", "Branches", "date"]
    totals = fob.groupby(keys, as_index=False)["DATA"].sum()
    return totals.rename(columns={"DATA": "value"})

# ----------------- TRAIN ONE SERIES -----------------
def train_lstm_for_series(sales_series):
    """
    Train one LSTM on a single univariate series and score it on a hold-out tail.

    The series is turned into LOOKBACK-long windows, split chronologically
    (first 1 - VAL_FRACTION of the windows for training, the rest for
    validation), and the model predicts the value that follows each window.

    sales_series: 1D array-like of Fob value over time (monthly, sorted).

    Returns a dict with keys:
      "model"    - the trained Keras model
      "r2"       - R² on the validation windows (original units)
      "wape"     - WAPE in percent on the validation windows (may be NaN)
      "y_mean"   - mean used for normalization (training portion only)
      "y_std"    - std used for normalization (training portion only)
      "n_points" - length of the input series
    or None when the series cannot be used: shorter than MIN_SERIES_LEN,
    contains NaN/inf, constant, fewer than 10 windows, an empty train or
    validation split, or a constant training portion.
    """
    sales = np.asarray(sales_series, dtype=float).ravel()

    # basic length / validity / variance checks
    if len(sales) < MIN_SERIES_LEN:
        return None
    if not np.all(np.isfinite(sales)):
        # NaN/inf would propagate through training and make r2_score raise
        return None
    if np.allclose(sales, sales[0]):
        # constant series => cannot train meaningful model
        return None

    # Time-based split (no shuffling), decided before normalization so the
    # scaling statistics can be restricted to the training portion.
    n_total = len(sales) - LOOKBACK  # number of (window, target) pairs
    if n_total < 10:
        return None
    n_train = int(np.floor(n_total * (1.0 - VAL_FRACTION)))
    if n_train < 1 or n_train >= n_total:
        return None

    # Training windows and their targets only touch sales[:n_train + LOOKBACK].
    # Computing mean/std on that slice keeps validation values out of the
    # preprocessing (no look-ahead leakage into the reported metrics).
    train_part = sales[:n_train + LOOKBACK]
    y_mean = train_part.mean()
    y_std = train_part.std()
    if y_std == 0:
        return None

    # The single feature is the series itself, so one normalized column
    # serves as both model input and target.
    series_n = ((sales - y_mean) / y_std).reshape(-1, 1)

    X, y = make_sequences(series_n, series_n, lookback=LOOKBACK)
    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]

    # Build & train model
    model = build_lstm_model(input_shape=(LOOKBACK, X.shape[2]))
    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=0  # set to 1 if you want per-epoch logs for each category
    )

    # Denormalize; flatten so both arrays are 1D with matching length
    # before they reach the metric functions.
    y_pred_n = model.predict(X_val, verbose=0)
    y_true = (y_val * y_std + y_mean).ravel()
    y_pred = (y_pred_n * y_std + y_mean).ravel()

    # Metrics
    r2 = r2_score(y_true, y_pred)
    wape = calculate_wape(y_true, y_pred)

    return {
        "model":   model,
        "r2":      float(r2),
        "wape":    float(wape),
        "y_mean":  float(y_mean),
        "y_std":   float(y_std),
        "n_points": len(sales),
    }

# ----------------- RUN ACROSS ALL CATEGORIES -----------------
def train_lstm_across_categories(country_filter="United States"):
    """
    Fit a separate LSTM for every (Country, Branches) series and rank them.

    country_filter: forwarded to load_all_series(); None means no country
                    filtering.

    Prints the row count and a preview of the loaded data, the 20 series
    with the highest R², and the details of the single best series.

    Returns (results_df, best_info):
      - results_df: one row per trained series with country, branch,
        n_points, r2 and wape
      - best_info: the result dict of the highest-R² series (including its
        model) plus "country" and "branch"
    If no series could be trained, returns (empty DataFrame, None).
    """
    series_df = load_all_series(country_filter=country_filter)
    print(f"\nTotal rows after filtering: {len(series_df)}")
    print("Example rows:")
    print(series_df.head())

    summary_rows = []
    best_info = None

    # Sort by date first so each group's values arrive in chronological order.
    chronological = series_df.sort_values("date")
    for (country, branch), group in chronological.groupby(["Country", "Branches"]):
        info = train_lstm_for_series(group["value"].values)
        if info is None:
            # series rejected by train_lstm_for_series
            continue

        summary_rows.append(
            {
                "country": country,
                "branch": branch,
                "n_points": info["n_points"],
                "r2": info["r2"],
                "wape": info["wape"],
            }
        )

        # Only the best model (by R²) is kept; ties keep the earlier series.
        is_new_best = best_info is None or info["r2"] > best_info["r2"]
        if is_new_best:
            best_info = {"country": country, "branch": branch, **info}

    if not summary_rows:
        print("\nNo valid series found (not enough data / constant series).")
        return pd.DataFrame(), None

    results_df = pd.DataFrame(summary_rows)

    # Summary: top 20 by R²
    top20 = results_df.sort_values("r2", ascending=False).head(20)
    print("\n===== TOP 20 CATEGORIES BY R² (LSTM, Fob value only) =====")
    print(top20.to_string(index=False))

    # Best model details
    best_lines = [
        "\n===== BEST CATEGORY (HIGHEST R²) =====",
        f"Country: {best_info['country']}",
        f"Branch:  {best_info['branch']}",
        f"R²:      {best_info['r2']:.4f}",
        f"WAPE:    {best_info['wape']:.2f}%",
        f"Points:  {best_info['n_points']}",
    ]
    for line in best_lines:
        print(line)

    return results_df, best_info

# ----------------- RUN IT -----------------
# country_filter=None disables the country filter, so every country present
# in the parquet file is processed (one model is trained per series).
results_df, best_info = train_lstm_across_categories(country_filter=None)

# To restrict the run to the USA only, call instead:
# results_df, best_info = train_lstm_across_categories(country_filter="United States")


Parquet columns: Index(['Branches', 'Country', 'Month', 'Unit', 'DATA'], dtype='object')

Total rows after filtering: 6096
Example rows:
         Country                                 Branches       date    value
0  United States  00 Whole fish, fresh, chilled or on ice 2015-01-01  33.4505
1  United States  00 Whole fish, fresh, chilled or on ice 2015-02-01  43.0561
2  United States  00 Whole fish, fresh, chilled or on ice 2015-03-01  44.1183
3  United States  00 Whole fish, fresh, chilled or on ice 2015-04-01  38.7160
4  United States  00 Whole fish, fresh, chilled or on ice 2015-05-01  39.5563

Total rows after filtering: 6096
Example rows:
         Country                                 Branches       date    value
0  United States  00 Whole fish, fresh, chilled or on ice 2015-01-01  33.4505
1  United States  00 Whole fish, fresh, chilled or on ice 2015-02-01  43.0561
2  United States  00 Whole fish, fresh, chilled or on ice 2015-03-01  44.1183
3  United States  00 Whole fish, fr

  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__init__(**kwargs)
  super().__in


===== TOP 20 CATEGORIES BY R² (LSTM, Fob value only) =====
      country                                  branch  n_points        r2       wape
United States                       69 Other textiles       127  0.459971  31.914287
United States                                 62 Wool       127  0.381982  43.519749
United States                         40 Live animals       127  0.282638  91.592215
United States        44 Meat, salted, dried or smoked       127  0.162553  88.100985
United States                          41 Farmed fish       127  0.107005  51.762494
United States         49 Agricultural products n.e.s.       127  0.061769  91.498939
United States 00 Whole fish, fresh, chilled or on ice       127  0.007946  20.104810
United States     13 Other sea frozen marine products       127  0.000000        NaN
United States           48 Wool, not carded or combed       127  0.000000        NaN
United States                          23 Salted roes       127  0.000000        NaN
Unite