# WaveResNetRegressor with Context
This notebook demonstrates two ways to provide context to `WaveResNetRegressor` while keeping the input pipeline in `float32`:
1. Passing context arrays directly to `fit`/`predict` (make sure the arrays are `np.float32`).
2. Using the built-in cosine context builder to derive contextual features automatically. Changing any context option with `set_params` resets the cached builder.
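Arrays built with default NumPy calls such as `np.linspace` or `np.column_stack` on float64 inputs come out as `float64`, so an explicit cast before `fit`/`predict` is the simplest way to honour the `float32` requirement. A minimal sketch; `as_float32_2d` is a hypothetical helper, and treating a 1-D array as a single feature column is an assumption, not documented `psann` behaviour:

```python
import numpy as np


def as_float32_2d(arr):
    """Return `arr` as a C-contiguous float32 array with at least two dimensions.

    Hypothetical helper: a 1-D input is assumed to be a single feature column.
    """
    out = np.asarray(arr, dtype=np.float32)
    if out.ndim == 1:
        out = out.reshape(-1, 1)
    return np.ascontiguousarray(out)


# A float64 context array (np.linspace defaults to float64) is cast down to float32.
grid = np.linspace(0, 1, 5)
ctx64 = np.column_stack([np.sin(grid), np.cos(grid)])
ctx32 = as_float32_2d(ctx64)
print(ctx64.dtype, "->", ctx32.dtype, ctx32.shape)
```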


## Imports
We import NumPy for data generation, Matplotlib for quick visualisation, scikit-learn utilities for the train/test split and the error metric, and finally `WaveResNetRegressor` from the `psann` package.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from psann import WaveResNetRegressor


## Synthetic dataset
We generate a sinusoidal signal in which a random phase offset φ plays the role of contextual information. The target is y = sin(t + φ) + 0.3 sin(0.5t - 0.5φ) plus Gaussian noise with standard deviation 0.05. The model receives four primary features (sine and cosine of t and of 0.5t) and two context features (sin φ and cos φ).
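Encoding the phase as a (sine, cosine) pair keeps the context continuous across the wrap-around at ±π, where the raw angle would jump from π to -π, and it loses no information: `np.arctan2` recovers the angle from the pair. A small NumPy check of that property:

```python
import numpy as np

rng = np.random.default_rng(0)
phase = rng.uniform(-np.pi, np.pi, size=1000)

# Two context features per sample, as in the dataset cell: sin(phase), cos(phase).
context = np.column_stack([np.sin(phase), np.cos(phase)])

# Each row lies on the unit circle, and arctan2(sin, cos) inverts the encoding.
radius = np.sqrt(np.sum(context ** 2, axis=1))
recovered = np.arctan2(context[:, 0], context[:, 1])
print(np.max(np.abs(recovered - phase)))  # at floating-point round-off level
```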


In [None]:
rng = np.random.default_rng(7)
n_samples = 2000
t = rng.uniform(0, 6 * np.pi, size=(n_samples, 1)).astype(np.float32)
phase = rng.uniform(-np.pi, np.pi, size=(n_samples, 1)).astype(np.float32)

X = np.concatenate([
    np.sin(t),
    np.cos(t),
    np.sin(0.5 * t),
    np.cos(0.5 * t),
], axis=1).astype(np.float32)

context = np.concatenate([np.sin(phase), np.cos(phase)], axis=1).astype(np.float32)
signal = np.sin(t + phase) + 0.3 * np.sin(0.5 * t - 0.5 * phase)
y = signal.reshape(-1).astype(np.float32)
y += 0.05 * rng.standard_normal(size=y.shape).astype(np.float32)

X_train, X_test, ctx_train, ctx_test, y_train, y_test = train_test_split(
    X, context, y, test_size=0.25, random_state=42
)
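
`train_test_split` accepts several arrays and shuffles them with one shared permutation, which is what keeps each row of `X`, `context`, and `y` aligned after the split. The NumPy-only sketch below mimics that behaviour; `aligned_split` is a hypothetical helper and does not reproduce scikit-learn's exact shuffle order, only its test-size rounding and its train/test output ordering:

```python
import math

import numpy as np


def aligned_split(*arrays, test_size=0.25, seed=42):
    """Split several arrays with one shared permutation (hypothetical helper)."""
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must have the same number of rows")
    n_test = math.ceil(test_size * n)  # scikit-learn also rounds the test count up
    perm = np.random.default_rng(seed).permutation(n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    out = []
    for a in arrays:
        out.extend([a[train_idx], a[test_idx]])  # train, test for each array
    return out


# Context and target are derived from X so that alignment can be checked afterwards.
n = 2000
X = np.arange(n * 4, dtype=np.float32).reshape(n, 4)
ctx = X[:, :2] * 10
y = X[:, 0].copy()
X_tr, X_te, c_tr, c_te, y_tr, y_te = aligned_split(X, ctx, y)
print(X_tr.shape, X_te.shape)
```

With `test_size=0.25` and 2000 samples this gives 1500 training rows and 500 test rows, matching the split used in the dataset cell.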


## Training with explicit context
We supply the context arrays directly to `fit` and `predict`, and set `context_dim=2` to match the two context columns.


In [None]:
explicit_est = WaveResNetRegressor(
    hidden_layers=4,
    hidden_width=64,
    epochs=150,
    batch_size=128,
    lr=3e-3,
    context_dim=2,
    random_state=0,
    early_stopping=True,
    patience=20,
)
explicit_est.fit(X_train, y_train, context=ctx_train)

pred_exp = explicit_est.predict(X_test, context=ctx_test)
mse_exp = mean_squared_error(y_test, pred_exp)
print(f"Explicit context MSE: {mse_exp:.4f}")
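
scikit-learn's `mean_squared_error` validates the shapes it receives, but a hand-written NumPy metric does not: if `predict` returned a column of shape `(n, 1)` while `y_test` has shape `(n,)`, `np.mean((y - pred) ** 2)` would silently broadcast to an `(n, n)` matrix. Whether `WaveResNetRegressor.predict` returns 1-D or 2-D output is not shown here, so a defensive sketch flattens both sides first (`mse_flat` is a hypothetical helper):

```python
import numpy as np


def mse_flat(y_true, y_pred):
    """Mean squared error after flattening both inputs to 1-D (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(f"length mismatch: {y_true.shape} vs {y_pred.shape}")
    return float(np.mean((y_true - y_pred) ** 2))


y_true = np.array([0.0, 1.0, 2.0], dtype=np.float32)
y_pred_col = np.array([[0.0], [1.0], [4.0]], dtype=np.float32)  # (3, 1) column

naive = np.mean((y_true - y_pred_col) ** 2)  # broadcasts to (3, 3): wrong value
print(naive, mse_flat(y_true, y_pred_col))  # 4.0 1.3333333333333333
```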


### Quick visual check
The scatter plot compares predictions with targets for the first 200 test samples; points on the dashed red diagonal are perfect predictions.


In [None]:
subset = slice(0, 200)
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(y_test[subset], pred_exp[subset], alpha=0.6, label="Explicit context")
ax.plot([-2, 2], [-2, 2], color="tab:red", linestyle="--", linewidth=1)
ax.set_xlabel("Target")
ax.set_ylabel("Prediction")
ax.set_title("WaveResNetRegressor with explicit context")
ax.legend()
ax.grid(True, linestyle=":")
plt.show()
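
The scatter plot can be complemented by a few numbers computed on the same subset. A small sketch; `residual_summary` is a hypothetical helper, and the choice of bias, RMSE, maximum absolute error, and R² is only a suggestion:

```python
import numpy as np


def residual_summary(y_true, y_pred):
    """Summarise prediction residuals (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    resid = y_pred - y_true
    ss_res = np.sum(resid ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return {
        "bias": float(resid.mean()),  # systematic over/under-prediction
        "rmse": float(np.sqrt(np.mean(resid ** 2))),
        "max_abs": float(np.max(np.abs(resid))),
        "r2": float(1.0 - ss_res / ss_tot),  # coefficient of determination
    }


# Toy usage; in the notebook this would be residual_summary(y_test[subset], pred_exp[subset]).
summary = residual_summary([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 5.0])
print(summary)
```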


## Training with automatic cosine context
Here the estimator derives contextual features from the primary inputs `X` using the cosine context builder, so `fit` and `predict` are called without a `context` argument.
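The notebook does not show what the cosine builder computes internally. Purely as an assumption, a builder configured with `frequencies=[1.0, 2.0]`, `include_sin=True`, and `include_cos=True` might map every input column x to sin(f·x) and cos(f·x) for each frequency f. The function below is a hypothetical illustration of that idea, not `psann`'s implementation, so the feature count it produces need not match `auto_est.context_dim`:

```python
import numpy as np


def cosine_context(X, frequencies=(1.0, 2.0), include_sin=True, include_cos=True):
    """Hypothetical frequency encoding of every column of X (float32 output)."""
    X = np.asarray(X, dtype=np.float32)
    feats = []
    for f in frequencies:
        if include_sin:
            feats.append(np.sin(f * X))
        if include_cos:
            feats.append(np.cos(f * X))
    if not feats:
        raise ValueError("enable at least one of include_sin / include_cos")
    # Column blocks: [sin(f1 X), cos(f1 X), sin(f2 X), cos(f2 X), ...]
    return np.concatenate(feats, axis=1).astype(np.float32)


X_demo = np.zeros((3, 4), dtype=np.float32)
ctx_demo = cosine_context(X_demo)
print(ctx_demo.shape)  # (3, 16): 4 columns x 2 frequencies x (sin, cos)
```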


In [None]:
auto_est = WaveResNetRegressor(
    hidden_layers=4,
    hidden_width=64,
    epochs=150,
    batch_size=128,
    lr=3e-3,
    context_builder="cosine",
    context_builder_params={"frequencies": [1.0, 2.0], "include_sin": True, "include_cos": True},
    random_state=0,
    early_stopping=True,
    patience=20,
)
auto_est.fit(X_train, y_train)

pred_auto = auto_est.predict(X_test)
mse_auto = mean_squared_error(y_test, pred_auto)
print(f"Cosine builder MSE: {mse_auto:.4f}")
print(f"Learned context_dim: {auto_est.context_dim}")
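
The introduction notes that changing context options through `set_params` resets the cached builder. The toy class below sketches that cache-invalidation pattern in plain Python; the class, its attribute names, and the set of keys treated as context options are hypothetical and are not taken from `psann`:

```python
class CachedContextModel:
    """Toy estimator that lazily builds and caches a context builder (hypothetical)."""

    _CONTEXT_KEYS = {"context_builder", "context_builder_params"}

    def __init__(self, context_builder=None, context_builder_params=None, lr=1e-3):
        self.context_builder = context_builder
        self.context_builder_params = context_builder_params
        self.lr = lr
        self._builder = None  # cache, filled on first use
        self.builds = 0  # counts how often the builder was (re)created

    def _get_builder(self):
        if self._builder is None:
            self.builds += 1
            self._builder = (self.context_builder, dict(self.context_builder_params or {}))
        return self._builder

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        # Only context-related changes invalidate the cached builder.
        if self._CONTEXT_KEYS.intersection(params):
            self._builder = None
        return self


model = CachedContextModel("cosine", {"frequencies": [1.0, 2.0]})
model._get_builder()
model.set_params(lr=3e-3)._get_builder()  # unrelated option: cache kept
model.set_params(context_builder_params={"frequencies": [1.0]})._get_builder()
print(model.builds)  # 2
```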


## Comparison
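Both approaches feed context to the model, either supplied by hand or generated automatically. The cosine builder is a quick way to add frequency-encoded features without building arrays yourself, but those features are computed from `X` alone. In this dataset the phase is drawn independently of `t`, so no function of `X` can recover it and the phase-dependent part of the target stays unexplained; the explicit-context model should therefore reach a clearly lower MSE here. Automatic context is more likely to pay off when the useful signal is a nonlinear transform of the primary inputs rather than information missing from them.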
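A closed-form reference helps put both MSE values in perspective. Because sin(t + φ) = sin t cos φ + cos t sin φ, the dominant term of the target is bilinear in the primary features and the context, so ordinary least squares on the pairwise products X_i·c_j already captures it, while a fit that sees only `X` cannot. The sketch below regenerates data with the same recipe (in float64, NumPy only, in-sample without a split) and compares least-squares fits with and without the context and interaction features; it is a baseline illustration, not a substitute for either estimator:

```python
import numpy as np

rng = np.random.default_rng(7)
n = 2000
t = rng.uniform(0, 6 * np.pi, size=n)
phase = rng.uniform(-np.pi, np.pi, size=n)

X = np.column_stack([np.sin(t), np.cos(t), np.sin(0.5 * t), np.cos(0.5 * t)])
C = np.column_stack([np.sin(phase), np.cos(phase)])
y = np.sin(t + phase) + 0.3 * np.sin(0.5 * t - 0.5 * phase) + 0.05 * rng.standard_normal(n)


def lstsq_mse(features, target):
    """In-sample MSE of an ordinary least-squares fit with an intercept column."""
    A = np.column_stack([np.ones(len(target)), features])
    coef, *_ = np.linalg.lstsq(A, target, rcond=None)
    return float(np.mean((A @ coef - target) ** 2))


# Pairwise products X_i * C_j: 4 x 2 = 8 interaction features.
interactions = (X[:, :, None] * C[:, None, :]).reshape(n, -1)

mse_x_only = lstsq_mse(X, y)
mse_with_ctx = lstsq_mse(np.column_stack([X, C, interactions]), y)
print(f"X only: {mse_x_only:.3f}  X + context + interactions: {mse_with_ctx:.3f}")
```

Without context, the best any model can do is predict the phase-averaged target, so the X-only error stays near the variance of sin(t + φ); the remaining error with interactions comes mostly from the half-angle term, which is not exactly linear in sin φ and cos φ.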
