In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import r2_score

# ----------------- CONFIG -----------------
PARQUET_FILE = "Exports-by-branches-of-processing-and-countries-2015-2025.parquet"

LOOKBACK       = 12    # months
EPOCHS         = 42
BATCH_SIZE     = 12
VAL_FRACTION   = 0.2
MIN_SERIES_LEN = 40    # minimum months of data per series

np.random.seed(42)

# ----------------- HELPERS -----------------
def make_sequences(X, y, lookback=12):
    """
    Turn a 2D feature array and 1D/2D target into
    (num_samples-lookback, lookback, num_features).
    """
    Xs, ys = [], []
    for i in range(len(X) - lookback):
        Xs.append(X[i:i+lookback])
        ys.append(y[i+lookback])
    return np.array(Xs), np.array(ys)

def calculate_wape(y_true, y_pred):
    y_true = np.array(y_true, dtype=float)
    y_pred = np.array(y_pred, dtype=float)
    denom = np.sum(np.abs(y_true))
    if denom == 0:
        return np.nan
    return np.sum(np.abs(y_true - y_pred)) / denom * 100.0

def build_lstm_model(input_shape):
    model = Sequential()
    model.add(LSTM(64, input_shape=input_shape))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse')
    return model

# ----------------- DATA LOADING -----------------
def load_all_series(country_filter="United States"):
    """
    Load the parquet and build a long DataFrame:
    Country, Branches, date, value (Fob value).
    If country_filter is None, uses ALL countries.
    """
    df = pd.read_parquet(PARQUET_FILE)
    print("Parquet columns:", df.columns)

    # Keep Fob value rows
    df = df[df["Unit"] == "Fob value"].copy()
    df["DATA"] = pd.to_numeric(df["DATA"], errors="coerce")
    df = df.dropna(subset=["DATA"])

    # Month "2015M01" -> proper datetime
    df["date"] = pd.to_datetime(
        df["Month"].str.replace("M", "-") + "-01",
        format="%Y-%m-%d",
        errors="coerce",
    )
    df = df.dropna(subset=["date"])

    if country_filter is not None:
        df = df[df["Country"] == country_filter]

    # Aggregate in case of duplicates per (Country, Branches, date)
    df = (
        df.groupby(["Country", "Branches", "date"], as_index=False)["DATA"]
          .sum()
          .rename(columns={"DATA": "value"})
    )

    return df

# ----------------- TRAIN ONE SERIES -----------------
def train_lstm_for_series(sales_series):
    """
    sales_series: 1D array-like of Fob value over time (monthly, sorted).
    Returns dict with r2, wape, model, etc., or None if not enough data.
    """
    sales = np.asarray(sales_series, dtype=float)

    # basic length / variance checks
    if len(sales) < MIN_SERIES_LEN:
        return None
    if np.allclose(sales, sales[0]):
        # constant series => cannot train meaningful model
        return None

    # Features = just the sales series (like your example)
    features = sales.reshape(-1, 1)
    targets  = sales.reshape(-1, 1)

    # Normalization (same logic as your script)
    y_mean = sales.mean()
    y_std  = sales.std()
    if y_std == 0:
        return None

    f_mean = features.mean(axis=0)
    f_std  = features.std(axis=0)
    f_std[f_std == 0] = 1.0

    features_n = (features - f_mean) / f_std
    targets_n  = (targets  - y_mean) / y_std

    # Make sequences
    X, y = make_sequences(features_n, targets_n, lookback=LOOKBACK)
    if len(X) < 10:
        return None

    # Time-based split (no shuffling)
    n_total = len(X)
    n_train = int(np.floor(n_total * (1.0 - VAL_FRACTION)))
    if n_train < 1 or n_train >= n_total:
        return None

    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]

    # Build & train model
    model = build_lstm_model(input_shape=(LOOKBACK, X.shape[2]))
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=0  # set to 1 if you want per-epoch logs for each category
    )

    # Denormalize predictions
    y_pred_n = model.predict(X_val, verbose=0)
    y_true = y_val * y_std + y_mean
    y_pred = y_pred_n * y_std + y_mean

    # Metrics
    r2   = r2_score(y_true, y_pred)
    wape = calculate_wape(y_true, y_pred)

    return {
        "model":   model,
        "r2":      float(r2),
        "wape":    float(wape),
        "y_mean":  float(y_mean),
        "y_std":   float(y_std),
        "n_points": len(sales),
    }

# ----------------- RUN ACROSS ALL CATEGORIES -----------------
def train_lstm_across_categories(country_filter="United States"):
    """
    Trains one LSTM per (Country, Branches) time series.
    Returns:
      - results_df: summary DataFrame with metrics
      - best_info: dict containing the best model and meta-data
    """
    df = load_all_series(country_filter=country_filter)
    print(f"\nTotal rows after filtering: {len(df)}")
    print("Example rows:")
    print(df.head())

    results = []
    best_info = None

    grouped = df.sort_values("date").groupby(["Country", "Branches"])


    act_country = None
    for (country, branch), sub in grouped:
        #Track which country where training on.
        if(country!=act_country):
            print("STARTING TO TRAIN COUNTRY:", country)
            act_country = country
        print(country, branch)

        sales_series = sub["value"].values

        info = train_lstm_for_series(sales_series)
        if info is None:
            continue

        results.append({
            "country":   country,
            "branch":    branch,
            "n_points":  info["n_points"],
            "r2":        info["r2"],
            "wape":      info["wape"],
        })

        # Track best model by R²
        if (best_info is None) or (info["r2"] > best_info["r2"]):
            best_info = {
                "country": country,
                "branch":  branch,
                **info,
            }

    if not results:
        print("\nNo valid series found (not enough data / constant series).")
        return pd.DataFrame(), None

    results_df = pd.DataFrame(results)

    # Summary: top 20 by R²
    print("\n===== TOP 20 CATEGORIES BY R² (LSTM, Fob value only) =====")
    print(
        results_df.sort_values("r2", ascending=False)
                  .head(20)
                  .to_string(index=False)
    )

    # Show best model info
    print("\n===== BEST CATEGORY (HIGHEST R²) =====")
    print(f"Country: {best_info['country']}")
    print(f"Branch:  {best_info['branch']}")
    print(f"R²:      {best_info['r2']:.4f}")
    print(f"WAPE:    {best_info['wape']:.2f}%")
    print(f"Points:  {best_info['n_points']}")

    return results_df, best_info

# ----------------- RUN IT -----------------
# For USA only:
results_df, best_info = train_lstm_across_categories(country_filter=None)

# If you ever want ALL countries, call:
# results_df, best_info = train_lstm_across_categories(country_filter=None)


Parquet columns: Index(['Branches', 'Country', 'Month', 'Unit', 'DATA'], dtype='object')

Total rows after filtering: 1542288
Example rows:
        Country                                 Branches       date  value
0  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-01-01    0.0
1  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-02-01    0.0
2  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-03-01    0.0
3  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-04-01    0.0
4  Afghanistan   00 Whole fish, fresh, chilled or on ice 2015-05-01    0.0
STARTING TO TRAIN COUNTRY: Afghanistan 
Afghanistan  00 Whole fish, fresh, chilled or on ice
Afghanistan  01 Fish fillets, fresh, chilled or on ice
Afghanistan  02 Other marine products, fresh or chilled
Afghanistan  10 Sea frozen fish, whole
Afghanistan  11 Sea frozen fish fillets, in blocks
Afghanistan  12 Sea frozen fish fillets, n.e.s.
Afghanistan  13 Other sea frozen marine products
Afghanistan  14 Who

  super().__init__(**kwargs)


Albania  50 Preserved food product
Albania  51 Food products in other containers


  super().__init__(**kwargs)


Albania  52 Non-alcoholic beverages
Albania  53 Alcoholic beverages


  super().__init__(**kwargs)


Albania  60 Leather and tanned or dressed skins
Albania  62 Wool


  super().__init__(**kwargs)


Albania  69 Other textiles


  super().__init__(**kwargs)


Albania  71 Products of power intensive plants
Albania  72 Prod. of other chemical industries


  super().__init__(**kwargs)


Albania  82 Fishing equipment


  super().__init__(**kwargs)


Albania  83 Machinery and equipment


  super().__init__(**kwargs)


Albania  89 Other manufacturing, n.e.s.


  super().__init__(**kwargs)


Albania  90 Products of mining
Albania  91 Recovered articles for recycling
Albania  92 Transport equipment
Albania  99 Other products


  super().__init__(**kwargs)


STARTING TO TRAIN COUNTRY: Algeria 
Algeria  00 Whole fish, fresh, chilled or on ice
Algeria  01 Fish fillets, fresh, chilled or on ice
Algeria  02 Other marine products, fresh or chilled
Algeria  10 Sea frozen fish, whole
Algeria  11 Sea frozen fish fillets, in blocks
Algeria  12 Sea frozen fish fillets, n.e.s.
Algeria  13 Other sea frozen marine products
Algeria  14 Whole frozen fish, n.e.s.
Algeria  15 Frozen fish fillets, in blocks
Algeria  16 Frozen fish fillets n.e.s.
Algeria  17 Minced or strained fish, frozen
Algeria  18 Frozen roes
Algeria  19 Other frozen marine products
Algeria  20 Dried-salted fish
Algeria  21 Uncured salted fish
Algeria  22 Salted fish fillets, bits etc.
Algeria  23 Salted roes
Algeria  27 Stock fish
Algeria  28 Dried fish heads
Algeria  29 Other dried, salted fish
Algeria  31 Fish meal
Algeria  32 Fish oil


  super().__init__(**kwargs)


Algeria  39 Fish processing, n.e.s.
Algeria  40 Live animals
Algeria  41 Farmed fish
Algeria  42 Freshwater fish n.e.s.
Algeria  43 Meat, fresh, chilled or frozen
Algeria  44 Meat, salted, dried or smoked
Algeria  45 Raw hides and skins
Algeria  46 Tanned or crust hides and skins
Algeria  48 Wool, not carded or combed
Algeria  49 Agricultural products n.e.s.
Algeria  50 Preserved food product
Algeria  51 Food products in other containers
Algeria  52 Non-alcoholic beverages
Algeria  53 Alcoholic beverages
Algeria  60 Leather and tanned or dressed skins
Algeria  62 Wool


  super().__init__(**kwargs)


Algeria  69 Other textiles


  super().__init__(**kwargs)


Algeria  71 Products of power intensive plants
Algeria  72 Prod. of other chemical industries


  super().__init__(**kwargs)


Algeria  82 Fishing equipment
Algeria  83 Machinery and equipment
Algeria  89 Other manufacturing, n.e.s.


  super().__init__(**kwargs)


Algeria  90 Products of mining
Algeria  91 Recovered articles for recycling
Algeria  92 Transport equipment
Algeria  99 Other products


  super().__init__(**kwargs)


STARTING TO TRAIN COUNTRY: American Samoa
American Samoa 00 Whole fish, fresh, chilled or on ice
American Samoa 01 Fish fillets, fresh, chilled or on ice
American Samoa 02 Other marine products, fresh or chilled
American Samoa 10 Sea frozen fish, whole
American Samoa 11 Sea frozen fish fillets, in blocks
American Samoa 12 Sea frozen fish fillets, n.e.s.
American Samoa 13 Other sea frozen marine products
American Samoa 14 Whole frozen fish, n.e.s.
American Samoa 15 Frozen fish fillets, in blocks
American Samoa 16 Frozen fish fillets n.e.s.
American Samoa 17 Minced or strained fish, frozen
American Samoa 18 Frozen roes
American Samoa 19 Other frozen marine products
American Samoa 20 Dried-salted fish
American Samoa 21 Uncured salted fish
American Samoa 22 Salted fish fillets, bits etc.
American Samoa 23 Salted roes
American Samoa 27 Stock fish
American Samoa 28 Dried fish heads
American Samoa 29 Other dried, salted fish
American Samoa 31 Fish meal
American Samoa 32 Fish oil
American Samo

  super().__init__(**kwargs)


Andorra  01 Fish fillets, fresh, chilled or on ice


  super().__init__(**kwargs)


Andorra  02 Other marine products, fresh or chilled
Andorra  10 Sea frozen fish, whole
Andorra  11 Sea frozen fish fillets, in blocks
Andorra  12 Sea frozen fish fillets, n.e.s.
Andorra  13 Other sea frozen marine products
Andorra  14 Whole frozen fish, n.e.s.
Andorra  15 Frozen fish fillets, in blocks
Andorra  16 Frozen fish fillets n.e.s.


  super().__init__(**kwargs)


Andorra  17 Minced or strained fish, frozen
Andorra  18 Frozen roes
Andorra  19 Other frozen marine products
Andorra  20 Dried-salted fish
Andorra  21 Uncured salted fish
Andorra  22 Salted fish fillets, bits etc.
Andorra  23 Salted roes
Andorra  27 Stock fish
Andorra  28 Dried fish heads
Andorra  29 Other dried, salted fish
Andorra  31 Fish meal
Andorra  32 Fish oil
Andorra  39 Fish processing, n.e.s.
Andorra  40 Live animals


  super().__init__(**kwargs)


Andorra  41 Farmed fish
Andorra  42 Freshwater fish n.e.s.
Andorra  43 Meat, fresh, chilled or frozen
Andorra  44 Meat, salted, dried or smoked
Andorra  45 Raw hides and skins
Andorra  46 Tanned or crust hides and skins
Andorra  48 Wool, not carded or combed
Andorra  49 Agricultural products n.e.s.
Andorra  50 Preserved food product
Andorra  51 Food products in other containers
Andorra  52 Non-alcoholic beverages
Andorra  53 Alcoholic beverages
Andorra  60 Leather and tanned or dressed skins
Andorra  62 Wool
Andorra  69 Other textiles
Andorra  71 Products of power intensive plants
Andorra  72 Prod. of other chemical industries
Andorra  82 Fishing equipment
Andorra  83 Machinery and equipment
Andorra  89 Other manufacturing, n.e.s.


  super().__init__(**kwargs)


Andorra  90 Products of mining
Andorra  91 Recovered articles for recycling
Andorra  92 Transport equipment
Andorra  99 Other products


  super().__init__(**kwargs)


STARTING TO TRAIN COUNTRY: Angola 
Angola  00 Whole fish, fresh, chilled or on ice
Angola  01 Fish fillets, fresh, chilled or on ice
Angola  02 Other marine products, fresh or chilled
Angola  10 Sea frozen fish, whole
Angola  11 Sea frozen fish fillets, in blocks
Angola  12 Sea frozen fish fillets, n.e.s.
Angola  13 Other sea frozen marine products
Angola  14 Whole frozen fish, n.e.s.
Angola  15 Frozen fish fillets, in blocks
Angola  16 Frozen fish fillets n.e.s.
Angola  17 Minced or strained fish, frozen
Angola  18 Frozen roes
Angola  19 Other frozen marine products
Angola  20 Dried-salted fish
Angola  21 Uncured salted fish
Angola  22 Salted fish fillets, bits etc.
Angola  23 Salted roes
Angola  27 Stock fish
Angola  28 Dried fish heads
Angola  29 Other dried, salted fish
Angola  31 Fish meal
Angola  32 Fish oil
Angola  39 Fish processing, n.e.s.
Angola  40 Live animals
Angola  41 Farmed fish
Angola  42 Freshwater fish n.e.s.
Angola  43 Meat, fresh, chilled or frozen
Angola  44 Meat,

  super().__init__(**kwargs)


Angola  52 Non-alcoholic beverages
Angola  53 Alcoholic beverages
Angola  60 Leather and tanned or dressed skins
Angola  62 Wool
Angola  69 Other textiles
Angola  71 Products of power intensive plants
Angola  72 Prod. of other chemical industries
Angola  82 Fishing equipment


  super().__init__(**kwargs)


Angola  83 Machinery and equipment
Angola  89 Other manufacturing, n.e.s.
Angola  90 Products of mining
Angola  91 Recovered articles for recycling
Angola  92 Transport equipment
Angola  99 Other products


  super().__init__(**kwargs)


STARTING TO TRAIN COUNTRY: Anguilla
Anguilla 00 Whole fish, fresh, chilled or on ice
Anguilla 01 Fish fillets, fresh, chilled or on ice
Anguilla 02 Other marine products, fresh or chilled
Anguilla 10 Sea frozen fish, whole
Anguilla 11 Sea frozen fish fillets, in blocks
Anguilla 12 Sea frozen fish fillets, n.e.s.
Anguilla 13 Other sea frozen marine products
Anguilla 14 Whole frozen fish, n.e.s.
Anguilla 15 Frozen fish fillets, in blocks
Anguilla 16 Frozen fish fillets n.e.s.
Anguilla 17 Minced or strained fish, frozen
Anguilla 18 Frozen roes
Anguilla 19 Other frozen marine products
Anguilla 20 Dried-salted fish
Anguilla 21 Uncured salted fish
Anguilla 22 Salted fish fillets, bits etc.
Anguilla 23 Salted roes
Anguilla 27 Stock fish
Anguilla 28 Dried fish heads
Anguilla 29 Other dried, salted fish
Anguilla 31 Fish meal
Anguilla 32 Fish oil
Anguilla 39 Fish processing, n.e.s.
Anguilla 40 Live animals
Anguilla 41 Farmed fish
Anguilla 42 Freshwater fish n.e.s.
Anguilla 43 Meat, fresh, chille

  super().__init__(**kwargs)


Antigua and Barbuda 71 Products of power intensive plants


  super().__init__(**kwargs)


Antigua and Barbuda 72 Prod. of other chemical industries


  super().__init__(**kwargs)


Antigua and Barbuda 82 Fishing equipment


  super().__init__(**kwargs)


Antigua and Barbuda 83 Machinery and equipment
Antigua and Barbuda 89 Other manufacturing, n.e.s.


  super().__init__(**kwargs)


Antigua and Barbuda 90 Products of mining
Antigua and Barbuda 91 Recovered articles for recycling
Antigua and Barbuda 92 Transport equipment
Antigua and Barbuda 99 Other products
STARTING TO TRAIN COUNTRY: Argentina 
Argentina  00 Whole fish, fresh, chilled or on ice
Argentina  01 Fish fillets, fresh, chilled or on ice
Argentina  02 Other marine products, fresh or chilled
Argentina  10 Sea frozen fish, whole
Argentina  11 Sea frozen fish fillets, in blocks
Argentina  12 Sea frozen fish fillets, n.e.s.
Argentina  13 Other sea frozen marine products
Argentina  14 Whole frozen fish, n.e.s.


  super().__init__(**kwargs)


Argentina  15 Frozen fish fillets, in blocks
Argentina  16 Frozen fish fillets n.e.s.
Argentina  17 Minced or strained fish, frozen
Argentina  18 Frozen roes
Argentina  19 Other frozen marine products
Argentina  20 Dried-salted fish
Argentina  21 Uncured salted fish
Argentina  22 Salted fish fillets, bits etc.
Argentina  23 Salted roes
Argentina  27 Stock fish
Argentina  28 Dried fish heads
Argentina  29 Other dried, salted fish
Argentina  31 Fish meal
Argentina  32 Fish oil


  super().__init__(**kwargs)


Argentina  39 Fish processing, n.e.s.
Argentina  40 Live animals
Argentina  41 Farmed fish
Argentina  42 Freshwater fish n.e.s.
Argentina  43 Meat, fresh, chilled or frozen
Argentina  44 Meat, salted, dried or smoked
Argentina  45 Raw hides and skins
Argentina  46 Tanned or crust hides and skins
Argentina  48 Wool, not carded or combed
Argentina  49 Agricultural products n.e.s.
Argentina  50 Preserved food product
Argentina  51 Food products in other containers


  super().__init__(**kwargs)


Argentina  52 Non-alcoholic beverages


  super().__init__(**kwargs)


Argentina  53 Alcoholic beverages
Argentina  60 Leather and tanned or dressed skins
Argentina  62 Wool


  super().__init__(**kwargs)


KeyboardInterrupt: 