# In [None]:
# advanced_ts_attention_forecasting.py
# Single-file implementation:
# - loads AirPassengers (statsmodels) or generates synthetic fallback
# - preprocessing (scaling, windowing)
# - custom additive temporal attention layer
# - Attention-LSTM and baseline LSTM models (Keras)
# - rolling-origin evaluation and SARIMA baseline
# - plotting attention heatmap and saving a small text report

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K, Input, Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
import statsmodels.api as sm
from statsmodels.tsa.statespace.sarimax import SARIMAX
import math
import time

# 1) Load AirPassengers
# The download is best-effort: any failure (no network, unexpected frame
# layout, ...) falls back to a synthetic trend + seasonality series so the
# rest of the script can still run.
try:
    dataset = sm.datasets.get_rdataset("AirPassengers").data
    if 'value' in dataset.columns:
        ts = dataset['value'].astype(float)
    else:
        # Unexpected column naming: use the last numeric column of the frame
        # that was just downloaded instead of reaching for another loader.
        ts = dataset.select_dtypes(include='number').iloc[:, -1].astype(float)
except Exception as e:
    print("Could not load AirPassengers; generating synthetic demo series. Error:", e)
    # 'MS' (month start) yields 1949-01-01, 1949-02-01, ... and is accepted by
    # both old and new pandas releases, unlike the month-end alias 'M'.
    rng = pd.date_range("1949-01-01", periods=144, freq='MS')
    ts = pd.Series(100 + 0.5*np.arange(144) + 10*np.sin(np.arange(144)/12*2*np.pi) + np.random.randn(144)*5, index=rng)

# ensure monthly datetime index if not present
if not isinstance(ts, pd.Series):
    ts = pd.Series(ts)
if not isinstance(ts.index, pd.DatetimeIndex):
    # Any non-datetime index (e.g. the integer index of the downloaded frame)
    # is replaced by monthly timestamps starting in January 1949.
    ts.index = pd.date_range("1949-01-01", periods=len(ts), freq='MS')

print(f"Loaded series length: {len(ts)}; start: {ts.index[0].date()}")

# 2) Preprocessing helper
def create_windows(series, input_len, output_len, stride=1):
    """Slice a 1-D sequence into supervised (input window, target window) pairs.

    Each window starts ``stride`` positions after the previous one; a window
    consists of ``input_len`` consecutive values followed immediately by
    ``output_len`` target values.

    Args:
        series: 1-D sequence of values (anything ``np.asarray`` accepts).
        input_len: number of past values in every input window.
        output_len: number of future values in every target window.
        stride: step between consecutive window starts; must be >= 1.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(n_windows, input_len, 1)`` (a
        trailing feature axis is added) and ``y`` has shape
        ``(n_windows, output_len)``. When the series is too short for a single
        window both arrays are empty but keep those trailing dimensions.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError("stride must be a positive integer")

    values = np.asarray(series)
    span = input_len + output_len
    starts = range(0, len(values) - span + 1, stride)

    if len(starts) == 0:
        # Keep the documented rank so downstream shape checks stay meaningful.
        empty_X = np.empty((0, input_len, 1), dtype=values.dtype)
        empty_y = np.empty((0, output_len), dtype=values.dtype)
        return empty_X, empty_y

    X = np.stack([values[s:s + input_len] for s in starts])
    y = np.stack([values[s + input_len:s + span] for s in starts])
    return X[..., np.newaxis], y  # add feature dim

# 3) Custom additive temporal attention layer (Keras)
class TemporalSelfAttention(Layer):
    """Additive attention that re-weights a sequence along its time axis.

    Every timestep ``x_t`` receives the score ``v . tanh(x_t W + b)``; the
    scores are softmax-normalised across timesteps and each timestep is then
    multiplied by its weight. The sequence length is preserved (no pooling is
    done here).

    Args:
        return_attention: when True, ``call`` returns
            ``(weighted_sequence, attention_weights)`` instead of only the
            weighted sequence.
    """

    def __init__(self, return_attention=False, **kwargs):
        super().__init__(**kwargs)
        self.return_attention = return_attention

    def build(self, input_shape):
        # input_shape is indexed as (batch, timesteps, features).
        self.timesteps = input_shape[1]
        self.features = input_shape[2]
        n_feat = self.features
        # Weights are created in the order W, b, v; get_weights()/set_weights()
        # follow that order.
        self.W = self.add_weight(name='W_att', shape=(n_feat, n_feat),
                                 initializer='glorot_uniform')
        self.b = self.add_weight(name='b_att', shape=(n_feat,), initializer='zeros')
        self.v = self.add_weight(name='v_att', shape=(n_feat,), initializer='glorot_uniform')
        super().build(input_shape)

    def call(self, inputs):
        # inputs: (batch, T, F)
        hidden = K.tanh(K.dot(inputs, self.W) + self.b)                  # (batch, T, F)
        logits = K.squeeze(K.dot(hidden, K.expand_dims(self.v)), axis=-1)  # (batch, T)
        alphas = K.softmax(logits, axis=-1)                              # (batch, T)
        # Broadcast the per-timestep weight over the feature axis.
        reweighted = inputs * K.expand_dims(alphas, axis=-1)             # (batch, T, F)
        if not self.return_attention:
            return reweighted
        return reweighted, alphas

    def get_config(self):
        # Extend the base config so the layer can be re-created from it.
        return {**super().get_config(), "return_attention": self.return_attention}

# 4) Model builders
def build_attention_lstm(input_len, output_len, n_features=1, lstm_units=64, dense_units=32, learning_rate=1e-3):
    """Build and compile the attention-augmented bidirectional LSTM forecaster.

    Pipeline: Input -> Bidirectional LSTM (full sequence) ->
    TemporalSelfAttention -> global average over time -> Dense(relu) ->
    Dense(output_len).

    Args:
        input_len: number of timesteps in each input window.
        output_len: number of values predicted per window.
        n_features: number of features per timestep.
        lstm_units: units of each LSTM direction.
        dense_units: width of the hidden dense layer.
        learning_rate: Adam learning rate.

    Returns:
        A compiled Keras ``Model`` (MSE loss). The attention layer instance is
        additionally exposed as ``model._att_layer`` so its weights can be
        read back later.
    """
    seq_in = Input(shape=(input_len, n_features))
    encoded = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=True))(seq_in)

    attention = TemporalSelfAttention(return_attention=False)
    pooled = layers.GlobalAveragePooling1D()(attention(encoded))

    hidden = layers.Dense(dense_units, activation='relu')(pooled)
    forecast = layers.Dense(output_len)(hidden)

    net = Model(seq_in, forecast)
    net.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse')
    net._att_layer = attention
    return net

def build_baseline_lstm(input_len, output_len, n_features=1, lstm_units=64, dense_units=32, learning_rate=1e-3):
    """Build and compile the plain LSTM baseline forecaster.

    Pipeline: Input -> LSTM (last state only) -> Dense(relu) ->
    Dense(output_len), compiled with Adam and MSE loss.

    Args:
        input_len: number of timesteps in each input window.
        output_len: number of values predicted per window.
        n_features: number of features per timestep.
        lstm_units: units of the LSTM layer.
        dense_units: width of the hidden dense layer.
        learning_rate: Adam learning rate.

    Returns:
        A compiled Keras ``Model``.
    """
    seq_in = Input(shape=(input_len, n_features))
    summary = layers.LSTM(lstm_units)(seq_in)
    hidden = layers.Dense(dense_units, activation='relu')(summary)
    forecast = layers.Dense(output_len)(hidden)

    net = Model(seq_in, forecast)
    net.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse')
    return net

# 5) Rolling-origin evaluation routine (fits model on expanding window and forecasts horizon)
def rolling_origin_forecast(series, input_len, output_len, model_builder, scaler=None, initial_train_size=100, max_splits=5, fit_kwargs=None, **model_kwargs):
    """Evaluate a neural forecaster with an expanding-window (rolling-origin) scheme.

    For every split the model is rebuilt, trained on ``series[:idx]`` and asked
    to predict the next ``output_len`` values from the last ``input_len``
    training values. The origin ``idx`` then advances by ``output_len``.

    Args:
        series: 1-D numeric sequence (unscaled).
        input_len: length of each input window.
        output_len: forecast horizon per split.
        model_builder: callable ``(input_len, output_len, **model_kwargs)``
            returning a compiled model exposing ``fit`` and ``predict``.
        scaler: optional scaler with ``fit_transform`` / ``inverse_transform``;
            it is refit on every split's training data. A fresh
            ``MinMaxScaler`` per split is used when omitted.
        initial_train_size: number of observations in the first training set.
        max_splits: upper bound on the number of splits evaluated.
        fit_kwargs: extra keyword arguments for ``model.fit``. They override
            the defaults used here (``epochs=50``, ``batch_size=16``,
            ``verbose=0`` and an early-stopping callback).
        **model_kwargs: forwarded to ``model_builder``.

    Returns:
        ``(metrics, models)``: a list of per-split dicts with keys
        ``split_start``, ``train_end``, ``rmse``, ``mae``, ``mape`` and a list
        of ``(model, scaler)`` tuples, one per split.

    Raises:
        ValueError: if a training set is too short to yield one window.
    """
    import copy  # local import: only used to snapshot a caller-supplied scaler

    values = np.asarray(series, dtype=float).reshape(-1)
    n = len(values)
    metrics = []
    models = []
    idx = initial_train_size
    splits = 0
    while idx + output_len <= n and splits < max_splits:
        train = values[:idx]
        test = values[idx:idx + output_len]
        if len(train) < input_len + output_len:
            raise ValueError(
                f"training set of {len(train)} points cannot provide a window of "
                f"input_len + output_len = {input_len + output_len}"
            )

        local_scaler = MinMaxScaler(feature_range=(0, 1)) if scaler is None else scaler
        # Fit on training data only so the forecast horizon never leaks into the scaling.
        train_scaled = local_scaler.fit_transform(train.reshape(-1, 1)).flatten()
        X_train, y_train = create_windows(train_scaled, input_len, output_len)

        model = model_builder(input_len, output_len, **model_kwargs)
        es = EarlyStopping(patience=10, restore_best_weights=True, monitor='loss', verbose=0)
        # Merge instead of passing both explicit keywords and **fit_kwargs:
        # a key such as 'verbose' in fit_kwargs would otherwise be a duplicate
        # keyword argument and raise TypeError.
        fit_args = {'epochs': 50, 'batch_size': 16, 'callbacks': [es], 'verbose': 0}
        fit_args.update(fit_kwargs or {})
        model.fit(X_train, y_train, **fit_args)

        # The forecast input is simply the tail of the (already scaled) training data.
        last_input = train_scaled[-input_len:].reshape(1, input_len, 1)
        y_pred_scaled = model.predict(last_input, verbose=0).flatten()
        y_pred = local_scaler.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        rmse = math.sqrt(mean_squared_error(test, y_pred))
        mae = mean_absolute_error(test, y_pred)
        # NOTE: MAPE is undefined (inf/nan) when the test window contains zeros.
        mape = np.mean(np.abs((test - y_pred) / test)) * 100
        metrics.append({'split_start': idx-input_len, 'train_end': idx-1, 'rmse': rmse, 'mae': mae, 'mape': mape})

        # A shared, caller-supplied scaler is refit on the next split; store a
        # snapshot so each returned tuple keeps the scaling of its own split.
        stored_scaler = local_scaler if scaler is None else copy.deepcopy(local_scaler)
        models.append((model, stored_scaler))
        idx += output_len
        splits += 1
    return metrics, models

# 6) SARIMA rolling benchmark
def sarima_rolling(series, order=(1,1,1), seasonal_order=(1,1,1,12), initial_train=84, output_len=12, max_splits=3):
    """Rolling-origin benchmark using a SARIMAX model refit on every split.

    Each split fits SARIMAX on ``series[:idx]``, forecasts ``output_len`` steps
    and scores them against ``series[idx:idx + output_len]``; ``idx`` then
    advances by ``output_len``.

    Args:
        series: 1-D numeric sequence.
        order: non-seasonal ``(p, d, q)`` order passed to SARIMAX.
        seasonal_order: seasonal ``(P, D, Q, s)`` order passed to SARIMAX.
        initial_train: number of observations in the first training set.
        output_len: forecast horizon per split.
        max_splits: upper bound on the number of splits evaluated.

    Returns:
        ``(metrics, models)``: per-split dicts with keys ``split_start``,
        ``train_end``, ``rmse``, ``mae``, ``mape`` and the fitted results
        objects. A split whose fit or scoring fails is reported with NaN
        metrics and ``None`` as its model.
    """
    # Work on a plain float array: with pandas input, `test - pred` would align
    # on index labels and can silently produce NaN when the labels differ.
    values = np.asarray(series, dtype=float).reshape(-1)
    n = len(values)
    idx = initial_train
    metrics = []
    models = []
    splits = 0
    while idx + output_len <= n and splits < max_splits:
        train = values[:idx]
        test = values[idx:idx + output_len]
        try:
            mod = SARIMAX(train, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
            res = mod.fit(disp=False)
            pred = np.asarray(res.get_forecast(steps=output_len).predicted_mean, dtype=float)
            rmse = math.sqrt(mean_squared_error(test, pred))
            mae = mean_absolute_error(test, pred)
            mape = np.mean(np.abs((test - pred) / test)) * 100
        except Exception as e:
            # Best-effort benchmark: record the failed split and keep going.
            print("SARIMA fit failed at idx", idx, "error:", e)
            rmse = mae = mape = np.nan
            res = None
        metrics.append({'split_start': idx-output_len, 'train_end': idx-1, 'rmse': rmse, 'mae': mae, 'mape': mape})
        models.append(res)
        idx += output_len
        splits += 1
    return metrics, models

# 7) Experiment config (change these for longer/stronger runs)
series_values = ts.values.astype(float)
INPUT_LEN = 24       # months of history fed to the neural models
OUTPUT_LEN = 12      # forecast horizon per split
INITIAL_TRAIN = 84   # observations in the first training set
MAX_SPLITS = 3

# Demo hyperparams (one small configuration, forwarded to the model builders)
chosen_hparams = {
    'lstm_units': 32,
    'dense_units': 16,
    'learning_rate': 1e-3
}

print("Starting rolling-origin evaluations (demo config). This may take several minutes depending on your machine.")

start = time.time()
# Keras training output is already silenced inside rolling_origin_forecast,
# so no fit_kwargs are passed here.
att_metrics, att_models = rolling_origin_forecast(series_values, INPUT_LEN, OUTPUT_LEN, build_attention_lstm,
                                                  initial_train_size=INITIAL_TRAIN, max_splits=MAX_SPLITS,
                                                  **chosen_hparams)
base_metrics, base_models = rolling_origin_forecast(series_values, INPUT_LEN, OUTPUT_LEN, build_baseline_lstm,
                                                    initial_train_size=INITIAL_TRAIN, max_splits=MAX_SPLITS,
                                                    **chosen_hparams)
sarima_metrics, sarima_models = sarima_rolling(series_values, order=(2,1,2), seasonal_order=(1,1,1,12),
                                              initial_train=INITIAL_TRAIN, output_len=OUTPUT_LEN, max_splits=MAX_SPLITS)

elapsed = time.time() - start
print(f"Done. Time elapsed: {elapsed:.1f}s")
for heading, per_split in (("Attention-LSTM metrics per split:", att_metrics),
                           ("Baseline LSTM metrics per split:", base_metrics),
                           ("SARIMA metrics per split:", sarima_metrics)):
    print(heading)
    for m in per_split:
        print(m)

# 8) Aggregate metrics
def aggregate_metrics(metrics_list):
    """Average the per-split error metrics.

    Args:
        metrics_list: list of dicts carrying ``rmse``, ``mae`` and ``mape``
            entries (one dict per split).

    Returns:
        Dict with keys ``rmse_mean``, ``mae_mean`` and ``mape_mean``. NaN
        entries are skipped by the mean; a metric that is absent altogether
        (e.g. an empty ``metrics_list``) is reported as NaN instead of
        raising ``KeyError``.
    """
    df = pd.DataFrame(metrics_list)
    return {
        f'{name}_mean': df[name].mean() if name in df else float('nan')
        for name in ('rmse', 'mae', 'mape')
    }

# Collapse each model's per-split metrics into means and tabulate them,
# one row per model, indexed by the model label.
agg_att = aggregate_metrics(att_metrics)
agg_base = aggregate_metrics(base_metrics)
agg_sarima = aggregate_metrics(sarima_metrics)

labelled_aggregates = (
    ('Attention-LSTM', agg_att),
    ('Baseline LSTM', agg_base),
    ('SARIMA', agg_sarima),
)
results_table = pd.DataFrame(
    [{'model': label, **agg} for label, agg in labelled_aggregates]
).set_index('model')

print("\nAggregated metrics (mean across splits):")
print(results_table.round(3))

# 9) Attention extraction & visualization (for last fitted Attention-LSTM)
if att_models:
    sample_model, sample_scaler = att_models[-1]
    # The first Bidirectional layer of the model is taken as the encoder.
    encoder_layer = next(
        (lyr for lyr in sample_model.layers if isinstance(lyr, layers.Bidirectional)),
        None,
    )
    if encoder_layer is None or not hasattr(sample_model, '_att_layer'):
        print("Could not find encoder/attention layer for visualization.")
    else:
        # Sub-model mapping the raw input window to the encoder's output sequence.
        sub = Model(sample_model.input, encoder_layer.output)

        # Rebuild the input window that was used for the last split's forecast.
        train_end_idx = att_metrics[-1]['train_end'] + 1
        window_start = train_end_idx - INPUT_LEN
        scaled_full = sample_scaler.transform(series_values.reshape(-1, 1)).flatten()
        last_input = scaled_full[window_start:train_end_idx].reshape(1, INPUT_LEN, 1)
        enc_out = sub.predict(last_input)  # (1, T, F)

        # A second attention layer that also returns its weights: call it once
        # so its variables exist, then copy the trained weights into it.
        att_recreate = TemporalSelfAttention(return_attention=True)
        att_recreate(enc_out)
        att_recreate.set_weights(sample_model._att_layer.get_weights())
        _, att_w = att_recreate(enc_out)
        att_w = att_w.numpy()  # (1, T)

        # One-row heatmap of the attention weights over the input window.
        plt.figure(figsize=(9, 2))
        plt.imshow(att_w, aspect='auto')
        plt.title('Attention weights (last split sample)')
        plt.xlabel('time step (older -> newer)')
        plt.colorbar(label='weight')
        plt.show()

# 10) Save a small textual report
# The title contains a non-ASCII dash, so the file is written explicitly as
# UTF-8 rather than with the platform's default encoding.
report_lines = [
    "Advanced Time Series Forecasting — Experiment Report (demo)\n",
    "Dataset: AirPassengers (or synthetic fallback)\n\n",
    "Aggregated metrics (mean across splitting):\n",
    results_table.to_csv() + "\n",
    "Chosen hyperparameters:\n",
]
report_lines.extend(f"{k}: {v}\n" for k, v in chosen_hparams.items())
report_lines.append("\nNotes on interpretability: TemporalSelfAttention is additive attention across encoder timesteps; the heatmap above highlights which historical months the model emphasized.\n")
with open("ts_forecasting_report.txt", "w", encoding="utf-8") as f:
    f.writelines(report_lines)
print("Report saved to ts_forecasting_report.txt")
