In [None]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Input
from sklearn.metrics import r2_score

# ----------------- CONFIG -----------------
# Input files, resolved relative to the current working directory.
PARQUET_FILE = "Exports-by-branches-of-processing-and-countries-2015-2025.parquet"
CPI_FILE = "Inflation-Consumer price index.csv"
FX_FILE = "Exchange-rates_2015-2025.csv"

# Windowing / training hyper-parameters. Rows are monthly observations.
LOOKBACK = 12          # rows of history in each input window
EPOCHS = 42            # passes over the training windows per model
BATCH_SIZE = 12        # windows per gradient step
VAL_FRACTION = 0.2     # share of windows (the most recent ones) held out

# A series is skipped unless it has at least this many rows ...
MIN_SERIES_LEN = 40
# ... and at least this many non-zero target values.
MIN_NONZERO = 12

# Countries to model. Names must match the cleaned "Country" column of the
# parquet file exactly; the dataset uses "United States" / "United Kingdom"
# for the USA and the UK.
ALLOWED_COUNTRIES = [
    "Sweden", "France", "United States", "Portugal", "Spain",
    "Estonia", "Finland", "Belarus", "United Kingdom", "Austria",
    "Switzerland", "Luxembourg", "Slovakia",
]

# Seed every RNG that can influence a run. Keras layers draw their initial
# weights from TensorFlow's RNG, so seeding NumPy alone does not make the LSTM
# results repeatable. (GPU kernels may still add nondeterminism.)
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# ============================================================
#                       HELPERS
# ============================================================
def make_sequences(X, y, lookback=12):
    """
    Cut aligned arrays into sliding windows for one-step-ahead prediction.

    Window k is X[k : k + lookback] and its target is y[k + lookback], for
    every k that leaves a target inside the array.

    Returns:
        (windows, targets) as NumPy arrays; both are empty when
        len(X) <= lookback.
    """
    n_windows = len(X) - lookback
    windows = [X[start:start + lookback] for start in range(n_windows)]
    targets = [y[start + lookback] for start in range(n_windows)]
    return np.array(windows), np.array(targets)

def calculate_wape(y_true, y_pred):
    """
    Weighted absolute percentage error: sum|actual - forecast| / sum|actual|, in %.

    Returns NaN when every actual value is zero (the ratio is undefined).
    """
    actual = np.array(y_true, dtype=float)
    forecast = np.array(y_pred, dtype=float)

    total_actual = np.abs(actual).sum()
    if total_actual == 0:
        return np.nan

    abs_error = np.abs(actual - forecast).sum()
    return abs_error / total_actual * 100.0

def build_lstm_model(input_shape):
    """
    Build and compile a small LSTM regressor.

    Architecture: LSTM(64) -> Dense(32, relu) -> Dense(1), compiled with the
    Adam optimizer and mean-squared-error loss.

    Args:
        input_shape: (timesteps, n_features) of one input window.
    """
    stack = [
        Input(shape=input_shape),
        LSTM(64),
        Dense(32, activation='relu'),
        Dense(1),
    ]
    model = Sequential(stack)
    model.compile(optimizer='adam', loss='mse')
    return model

def train_lstm_on_dataframe(df_series, feature_cols, target_col="value"):
    """
    Train one LSTM on a single time series and score it on a chronological hold-out.

    The rows are cut into LOOKBACK-long sliding windows (see make_sequences),
    and the most recent VAL_FRACTION of the windows is held out for validation.
    Feature and target scaling statistics are fitted on the training rows
    only, so nothing from the validation period leaks into training.

    Args:
        df_series: DataFrame with at least `target_col` and every column in
            `feature_cols`; sorted by "date" here if that column exists.
        feature_cols: column names used as model inputs (may include target_col).
        target_col: column predicted one step after each window.

    Returns:
        dict with "model", "r2" and "wape" (validation metrics in original
        units), "y_mean"/"y_std" (training-row target scaling) and "n_points"
        (rows in the series); or None when the series is shorter than
        MIN_SERIES_LEN, has fewer than MIN_NONZERO non-zero targets, is
        constant, yields fewer than 10 windows, or is constant over its
        training rows.
    """
    if "date" in df_series.columns:
        df_series = df_series.sort_values("date")

    y = df_series[target_col].values.astype(float)
    if len(y) < MIN_SERIES_LEN:
        return None
    if np.count_nonzero(y) < MIN_NONZERO:
        return None
    if np.allclose(y, y[0]):
        return None

    X = df_series[feature_cols].values.astype(float)

    # make_sequences produces one window per row after the first LOOKBACK rows,
    # so the train/validation split can be fixed before any scaling is fitted.
    n_total = len(y) - LOOKBACK
    if n_total < 10:
        return None
    n_train = int(np.floor(n_total * (1.0 - VAL_FRACTION)))
    if n_train < 1 or n_train >= n_total:
        return None

    # Training windows read rows [0, n_train + LOOKBACK - 1) and predict rows
    # [LOOKBACK, n_train + LOOKBACK). Rows from n_train + LOOKBACK onward are
    # validation targets and must not influence the scaling.
    n_fit_rows = n_train + LOOKBACK
    y_fit = y[:n_fit_rows]
    X_fit = X[:n_fit_rows]

    y_mean = y_fit.mean()
    y_std = y_fit.std()
    if y_std == 0:
        # Flat over the whole training period: nothing learnable to scale.
        return None

    f_mean = X_fit.mean(axis=0)
    f_std = X_fit.std(axis=0)
    f_std[f_std == 0] = 1.0  # constant features: centre only, avoid divide-by-zero

    X_n = (X - f_mean) / f_std
    y_n = (y.reshape(-1, 1) - y_mean) / y_std

    X_seq, y_seq = make_sequences(X_n, y_n, lookback=LOOKBACK)
    X_train, X_val = X_seq[:n_train], X_seq[n_train:]
    y_train, y_val = y_seq[:n_train], y_seq[n_train:]

    model = build_lstm_model(input_shape=(LOOKBACK, X_seq.shape[2]))
    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=0
    )

    # Undo the target scaling before scoring so metrics are in data units.
    y_pred_n = model.predict(X_val, verbose=0)
    y_true = y_val * y_std + y_mean
    y_pred = y_pred_n * y_std + y_mean

    r2 = r2_score(y_true, y_pred)
    wape = calculate_wape(y_true, y_pred)

    return {
        "model":    model,
        "r2":       float(r2),
        "wape":     float(wape),
        "y_mean":   float(y_mean),
        "y_std":    float(y_std),
        "n_points": int(len(y)),
    }

# ============================================================
#                       DATA LOADING
# ============================================================
def _prepare_export_frame(df, country_filter=None, allowed_countries=None):
    """
    Clean raw export rows and aggregate them to one value per (Country, Branches, month).

    Steps:
      1. Normalise country names (non-breaking spaces -> spaces, trimmed).
      2. Keep only "Fob value" rows.
      3. Coerce DATA to numbers, dropping unparseable values.
      4. Parse "2015M01"-style months to the first day of the month.
      5. Apply the country filters.
      6. Sum DATA per (Country, Branches, date).

    Args:
        df: raw frame with columns Country, Branches, Month, Unit, DATA.
            It is not modified.
        country_filter: a single country name, a list/tuple/set of names, or None.
        allowed_countries: an extra name or list/tuple/set of names, or None.
            When both filters are given, a row must satisfy both.

    Returns:
        DataFrame with columns Country, Branches, date, value.
    """
    df = df.copy()  # never mutate the caller's frame

    df["Country"] = (
        df["Country"]
        .astype(str)
        .str.replace("\xa0", " ", regex=False)  # NBSP -> regular space
        .str.strip()
    )

    df = df[df["Unit"] == "Fob value"].copy()
    df["DATA"] = pd.to_numeric(df["DATA"], errors="coerce")
    df = df.dropna(subset=["DATA"])

    # "2015M01" -> 2015-01-01
    df["date"] = pd.to_datetime(
        df["Month"].astype(str).str.replace("M", "-") + "-01",
        format="%Y-%m-%d",
        errors="coerce",
    )
    df = df.dropna(subset=["date"])

    for wanted in (country_filter, allowed_countries):
        if wanted is None:
            continue
        if isinstance(wanted, (list, tuple, set, frozenset)):
            df = df[df["Country"].isin(wanted)]
        else:
            df = df[df["Country"] == wanted]

    return (
        df.groupby(["Country", "Branches", "date"], as_index=False)["DATA"]
          .sum()
          .rename(columns={"DATA": "value"})
    )


def load_all_series(country_filter="United States", allowed_countries=None):
    """
    Load the export parquet file and return one monthly series per (Country, Branches).

    Args:
        country_filter: a single country name, a list/tuple/set of names, or
            None for no filtering. Defaults to "United States".
        allowed_countries: optional extra list of country names to keep.
            Combined with `country_filter` (a row must satisfy both).

    Returns:
        DataFrame with columns Country, Branches, date, value.
    """
    df = pd.read_parquet(PARQUET_FILE)
    print("Parquet columns:", df.columns)

    df = _prepare_export_frame(
        df, country_filter=country_filter, allowed_countries=allowed_countries
    )

    print(f"\nTotal rows after filtering: {len(df)}")
    print("Countries present:", sorted(df["Country"].unique()))
    print("Example rows:")
    print(df.head())

    return df

def load_cpi():
    """
    Read the CPI CSV and return a clean monthly index.

    Expects a 'Month' column ("2015M01" or "2015-01" style) and a
    'Consumer price index Index' column. Rows whose date or index value
    cannot be parsed are dropped.

    Returns:
        DataFrame with columns date, cpi_index, sorted by date.
    """
    raw = pd.read_csv(CPI_FILE)
    print("Inflation columns:", raw.columns)

    month_text = raw["Month"].astype(str)
    iso_text = np.where(
        month_text.str.contains("M"),
        month_text.str.replace("M", "-") + "-01",
        month_text + "-01",
    )
    raw["date"] = pd.to_datetime(pd.Series(iso_text, index=raw.index), errors="coerce")
    raw["cpi_index"] = pd.to_numeric(raw["Consumer price index Index"], errors="coerce")

    return (
        raw.dropna(subset=["date", "cpi_index"])
           .loc[:, ["date", "cpi_index"]]
           .sort_values("date")
           .reset_index(drop=True)
    )

def load_fx():
    """
    Read the semicolon-separated FX file and return the monthly mean USD mid rate.

    The first column is parsed as a day-first date. The USD rate column is
    the first header containing both "Bandaríkjadalur" and "miðgengi". Its
    values use "." as thousands separator and "," as decimal mark.

    Raises:
        ValueError: if no USD mid-rate column is found.

    Returns:
        DataFrame with columns date (month start), usd_fx, sorted by date.
    """
    # sep=';' with the python engine avoids tokenizing errors on this file.
    raw = pd.read_csv(FX_FILE, sep=';', engine='python')
    print("FX raw columns:", raw.columns)

    day_col = raw.columns[0]
    raw[day_col] = pd.to_datetime(raw[day_col], errors="coerce", dayfirst=True)

    candidates = [
        name for name in raw.columns
        if "Bandaríkjadalur" in name and "miðgengi" in name
    ]
    if len(candidates) == 0:
        raise ValueError("Could not find USD mid column in FX file.")
    rate_col = candidates[0]

    decimal_text = (
        raw[rate_col]
        .astype(str)
        .str.replace(".", "", regex=False)   # drop thousands separators
        .str.replace(",", ".", regex=False)  # decimal comma -> point
    )
    raw["usd_fx"] = pd.to_numeric(decimal_text, errors="coerce")

    raw = raw.dropna(subset=[day_col, "usd_fx"])
    raw["year_month"] = raw[day_col].dt.to_period("M")

    monthly = raw.groupby("year_month", as_index=False)["usd_fx"].mean()
    monthly["date"] = monthly["year_month"].dt.to_timestamp()
    monthly = monthly[["date", "usd_fx"]]
    return monthly.sort_values("date").reset_index(drop=True)

# ============================================================
#        EXPORT-ONLY LSTM (ACROSS SELECTED COUNTRIES)
# ============================================================
def train_lstm_across_categories(allowed_countries=None):
    """
    Train one export-only LSTM per (Country, Branches) series and rank them by R².

    Args:
        allowed_countries: list of country names to keep, or None to use every
            country in the export file.

    Returns:
        DataFrame with columns country, branch, n_points, r2, wape. It has one
        row per series with enough data; rows with R² >= 0.9999 are removed
        (presumably degenerate fits). The frame is empty if nothing survives.
    """
    # `country_filter` accepts a list/tuple/set of names, or None for all countries.
    df = load_all_series(country_filter=allowed_countries)

    results = []
    grouped = df.sort_values("date").groupby(["Country", "Branches"])
    total_groups = len(grouped)
    current_country = None

    print(f"\nNumber of (Country, Branch) series: {total_groups}")

    for idx, ((country, branch), sub) in enumerate(grouped, start=1):
        if country != current_country:
            current_country = country
            print(f"\nProcessing country: {country}")

        info = train_lstm_on_dataframe(sub, feature_cols=["value"], target_col="value")
        if info is not None:
            results.append({
                "country":   country,
                "branch":    branch,
                "n_points":  info["n_points"],
                "r2":        info["r2"],
                "wape":      info["wape"],
            })

        # Reported regardless of whether this series was usable, so every
        # 100th series produces a progress line.
        if idx % 100 == 0:
            print(f"Processed {idx}/{total_groups} series...")

    if not results:
        print("\nNo valid series found (not enough data / mostly zero).")
        return pd.DataFrame()

    results_df = pd.DataFrame(results)
    results_df = results_df[results_df["r2"] < 0.9999]
    if results_df.empty:
        # Guard: .iloc[0] below would raise IndexError on an empty frame.
        print("\nNo series left after dropping near-perfect fits (R² >= 0.9999).")
        return results_df

    ranked = results_df.sort_values("r2", ascending=False)

    print("\n===== TOP 20 CATEGORIES BY R² (LSTM, Fob value only, SELECTED COUNTRIES) =====")
    print(ranked.head(20).to_string(index=False))

    best_row = ranked.iloc[0]
    print("\n===== BEST CATEGORY (HIGHEST R², EXPORT-ONLY) =====")
    print(f"Country: {best_row['country']}")
    print(f"Branch:  {best_row['branch']}")
    print(f"R²:      {best_row['r2']:.4f}")
    print(f"WAPE:    {best_row['wape']:.2f}%")
    print(f"Points:  {best_row['n_points']}")

    return results_df

# ============================================================
#      MACRO COMBOS (EXPORTS + CPI + FX) ON TOP N
# ============================================================
def _best_combo_per_category(macro_df):
    """
    Return the highest-R² row for each (country, branch), with every column from that row.

    `groupby(...).first()` picks the first non-null value per column
    independently. A NaN (e.g. an undefined WAPE) would then be filled from a
    different combo's row. Deduplicating the R²-sorted frame keeps each
    winning row intact. The output is ordered by (country, branch).
    """
    return (
        macro_df.sort_values("r2", ascending=False)
                .drop_duplicates(subset=["country", "branch"], keep="first")
                .sort_values(["country", "branch"])
                .reset_index(drop=True)
    )


def evaluate_top20_with_macros(results_df, exports_df, cpi_df, fx_df, top_n=20):
    """
    Re-train the top-N export series with macro features and compare feature sets.

    For each of the top_n (country, branch) pairs by R² in results_df, the
    export series is inner-joined on date with CPI and FX. Four LSTMs are then
    trained, one per feature combination:
      - exports_only
      - exports_plus_fx
      - exports_plus_cpi
      - exports_plus_fx_cpi

    Because of the inner joins, even "exports_only" here uses only the dates
    covered by CPI and FX, so its score can differ from results_df.

    Returns:
        (macro_df, best_per): every combo result (R² >= 0.9999 removed) and the
        best combo per (country, branch). Both are empty if nothing was produced.
    """
    feature_combos = {
        "exports_only":        ["value"],
        "exports_plus_fx":     ["value", "usd_fx"],
        "exports_plus_cpi":    ["value", "cpi_index"],
        "exports_plus_fx_cpi": ["value", "usd_fx", "cpi_index"],
    }

    topN = results_df.sort_values("r2", ascending=False).head(top_n)
    macro_results = []

    cpi_df = cpi_df.copy()
    fx_df = fx_df.copy()
    cpi_df["date"] = pd.to_datetime(cpi_df["date"])
    fx_df["date"] = pd.to_datetime(fx_df["date"])

    for _, row in topN.iterrows():
        country = row["country"]
        branch = row["branch"]

        print(f"\n==== Top category: {country} - {branch} ====")

        series = exports_df[
            (exports_df["Country"] == country) &
            (exports_df["Branches"] == branch)
        ].copy()
        series = series.sort_values("date")

        if series.empty:
            print("  No export data found for this (country, branch); skipping.")
            continue

        merged = series.merge(cpi_df, on="date", how="inner")
        merged = merged.merge(fx_df, on="date", how="inner")

        if len(merged) < MIN_SERIES_LEN:
            print(f"  Skipping (not enough data after joining macros): {len(merged)} points")
            continue

        for combo_name, cols in feature_combos.items():
            info = train_lstm_on_dataframe(merged, feature_cols=cols, target_col="value")
            if info is None:
                print(f"  Combo {combo_name}: skipped (not enough usable data)")
                continue

            print(f"  Combo {combo_name}: R²={info['r2']:.3f}, WAPE={info['wape']:.2f}%, n={info['n_points']}")
            macro_results.append({
                "country":       country,
                "branch":        branch,
                "feature_combo": combo_name,
                "n_points":      info["n_points"],
                "r2":            info["r2"],
                "wape":          info["wape"],
            })

    if not macro_results:
        print("\nNo macro results produced (likely not enough overlapping CPI/FX data).")
        return pd.DataFrame(), pd.DataFrame()

    macro_df = pd.DataFrame(macro_results)
    macro_df = macro_df[macro_df["r2"] < 0.9999]

    print("\n===== TOP 20 CATEGORIES BY R² (ANY MACRO COMBO, SELECTED COUNTRIES) =====")
    print(
        macro_df.sort_values("r2", ascending=False)
                .head(20)
                .to_string(index=False)
    )

    print("\n===== BEST COMBO PER (COUNTRY, BRANCH) =====")
    best_per = _best_combo_per_category(macro_df)
    print(best_per.to_string(index=False))

    return macro_df, best_per

# ============================================================
#                       MAIN
# ============================================================
if __name__ == "__main__":
    # 1) Train an export-only LSTM for every series of the selected countries.
    results_df = train_lstm_across_categories(allowed_countries=ALLOWED_COUNTRIES)
    if results_df.empty:
        raise SystemExit("No valid export-only series; stopping.")

    # 2) Reload the exports with the same country restriction, for merging with
    #    the macro data. `country_filter` accepts a list of country names.
    exports_df = load_all_series(country_filter=ALLOWED_COUNTRIES)

    # 3) Load the CPI and FX macro series.
    cpi_df = load_cpi()
    fx_df = load_fx()

    print("\nMacro feature frame head (CPI + FX merged):")
    macro_preview = (
        cpi_df.set_index("date")[["cpi_index"]]
              .join(fx_df.set_index("date")[["usd_fx"]], how="inner")
              .head()
    )
    print(macro_preview)

    # 4) Compare the macro feature combinations on the top 20 categories.
    macro_df, best_per = evaluate_top20_with_macros(
        results_df, exports_df, cpi_df, fx_df, top_n=20
    )


Parquet columns: Index(['Branches', 'Country', 'Month', 'Unit', 'DATA'], dtype='object')

Total rows after filtering: 79248
Countries present: ['Austria', 'Belarus', 'Estonia', 'Finland', 'France', 'Luxembourg', 'Portugal', 'Slovakia', 'Spain', 'Sweden', 'Switzerland', 'United Kingdom', 'United States']
Example rows:
   Country                                 Branches       date  value
0  Austria  00 Whole fish, fresh, chilled or on ice 2015-01-01    0.0
1  Austria  00 Whole fish, fresh, chilled or on ice 2015-02-01    0.0
2  Austria  00 Whole fish, fresh, chilled or on ice 2015-03-01    0.0
3  Austria  00 Whole fish, fresh, chilled or on ice 2015-04-01    0.0
4  Austria  00 Whole fish, fresh, chilled or on ice 2015-05-01    0.0

Number of (Country, Branch) series: 624

Processing country: Austria

Processing country: Belarus

Processing country: Estonia

Processing country: Finland

Processing country: France
Processed 200/624 series...

Processing country: Luxembourg

Processing coun

KeyboardInterrupt: 