In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pykalman import KalmanFilter

# Synthetic benchmark: a clean sine wave corrupted by Gaussian noise.
# The generator is seeded so that Restart & Run All reproduces the same series.
SEED = 42
rng = np.random.default_rng(SEED)

x = np.sin(np.linspace(5, 20, 500))  # noise-free reference signal
w = rng.normal(0, 0.4, len(x))       # zero-mean noise, standard deviation 0.4
x_w = x + w                          # noisy observations

# Hand-picked hyperparameters for the manually configured 1-D filter.
transition_mtx = [1]
observation_mtx = [1]
initial_state_cov = 1
transition_cov = 0.01
observation_covariance = 1
def init_kfc(transition_mtx=[1], observation_mtx=[1],
             initial_state_cov=None, transition_cov=None,
             observation_covariance=None, init_mean=None):
  """Build a ``KalmanFilter`` for the synthetic sine-wave example.

  Every argument is forwarded unchanged to the ``KalmanFilter``
  constructor; passing ``None`` leaves that setting to ``KalmanFilter``.
  The list defaults are only read here, never mutated.

  Args:
    transition_mtx: Forwarded as ``transition_matrices``.
    observation_mtx: Forwarded as ``observation_matrices``.
    initial_state_cov: Forwarded as ``initial_state_covariance``.
    transition_cov: Forwarded as ``transition_covariance``.
    observation_covariance: Forwarded as ``observation_covariance``.
    init_mean: Forwarded as ``initial_state_mean``. When ``None`` (the
      default) the first element of the module-level series ``x_w`` is
      used, so existing calls behave exactly as before.

  Returns:
    The configured ``KalmanFilter`` instance.
  """
  if init_mean is None:
    # Backward-compatible fallback: start the state at the first observation.
    init_mean = x_w[0]

  return KalmanFilter(
    transition_matrices=transition_mtx,
    observation_matrices=observation_mtx,
    transition_covariance=transition_cov,
    observation_covariance=observation_covariance,
    initial_state_mean=init_mean,
    initial_state_covariance=initial_state_cov,
  )


# Filter the noisy series with the hand-picked hyperparameters defined above.
kf = init_kfc(
    initial_state_cov=initial_state_cov,
    transition_cov=transition_cov,
    observation_covariance=observation_covariance,
)
state_means, _ = kf.filter(x_w)

# Same model with the covariances left unset, then fitted by 10 EM iterations.
kf = init_kfc()
kf = kf.em(x_w, n_iter=10)
state_means_em, _ = kf.filter(x_w)

# Compare both filtered estimates against the noise-free signal.
fig, ax = plt.subplots()
ax.plot(x, color="red", label="noise-free signal")
ax.plot(state_means, color="orange", label="filtered (manual parameters)")
ax.plot(state_means_em, color="blue", label="filtered (EM-fitted parameters)")
ax.set(title="Kalman filter on a noisy sine wave", xlabel="sample", ylabel="value")
ax.legend()
plt.show()

In [None]:
import yfinance as yf
import pandas as pd
from scipy.linalg import lstsq

In [None]:
# Run configuration shared by the cells below.
cnf = dict(
    start="2020-01-01",   # start of the requested price history
    end="2025-01-01",     # end of the requested price history
    ticker="NVDA",
    top_q=0.9,            # percentile rank at or above which Signal is 1
    bottom_q=0.1,         # percentile rank at or below which Signal is -1
    window=20,            # rolling window length passed to get_beta
    center_time=True,     # passed to get_beta
    slope_at="end",       # passed to get_beta
)

In [None]:
# Fetch the price history for the configured ticker and date range.
data = yf.Ticker(cnf["ticker"]).history(
    start=cnf["start"],
    end=cnf["end"],
)

In [None]:
def init_kf(init_mean, transition_mtx:list[float]=[1],
             observation_mtx:list[float]=[1],initial_state_cov:float=None,
             transition_cov:float=None,observation_covariance:float=None
    ):
  """Create a ``KalmanFilter`` whose initial state mean is ``init_mean``.

  Each remaining argument is handed to the ``KalmanFilter`` constructor
  under its corresponding keyword; values left as ``None`` are passed
  through as ``None``.

  Returns:
    The constructed ``KalmanFilter`` instance.
  """
  # Map this function's short parameter names onto the constructor keywords.
  settings = {
    "transition_matrices": transition_mtx,
    "observation_matrices": observation_mtx,
    "transition_covariance": transition_cov,
    "observation_covariance": observation_covariance,
    "initial_state_mean": init_mean,
    "initial_state_covariance": initial_state_cov,
  }
  return KalmanFilter(**settings)

def apply_filter(
    data: pd.DataFrame,
    transition_mtx:list[float]=[1],
    observation_mtx:list[float]=[1],
    initial_state_cov:float=1,
    transition_cov:float=0.01,
    observation_covariance:float=1,
    estimate:bool=False,
    n_iter:int=10
  ) -> tuple[np.ndarray, np.ndarray]:
  """Run a Kalman filter over ``data["LogClose"]``.

  The filter is built with ``init_kf`` and starts at the first value of
  the ``LogClose`` column.

  Args:
    data: Frame that must contain a ``LogClose`` column.
    transition_mtx: Forwarded to ``init_kf``.
    observation_mtx: Forwarded to ``init_kf``.
    initial_state_cov: Forwarded to ``init_kf``.
    transition_cov: Forwarded to ``init_kf``.
    observation_covariance: Forwarded to ``init_kf``.
    estimate: When true, ``kf.em`` is run on the series before filtering.
    n_iter: Number of EM iterations used when ``estimate`` is true.

  Returns:
    The ``(state_means, state_covs)`` pair produced by ``kf.filter``.

  Raises:
    KeyError: If ``data`` has no ``LogClose`` column.
  """
  log_close = data["LogClose"]
  kf = init_kf(
    log_close.iloc[0],
    transition_mtx=transition_mtx,
    observation_mtx=observation_mtx,
    initial_state_cov=initial_state_cov,
    transition_cov=transition_cov,
    observation_covariance=observation_covariance,
  )

  if estimate:
    # Refit the filter parameters on the series before filtering it.
    kf = kf.em(log_close, n_iter=n_iter)

  state_means, state_covs = kf.filter(log_close)
  return state_means, state_covs


def get_beta(data: pd.DataFrame, window:int=20, center_time:bool=True, slope_at:str="center") -> pd.DataFrame:
  """Rolling quadratic fit of ``data["StateMeans"]`` and its local slope.

  For every trailing run of ``window`` consecutive rows the model
  ``y = a + b*t + c*t**2`` is fitted by least squares, where ``t`` is the
  position inside the window (``0 .. window-1``, shifted to zero mean when
  ``center_time`` is true). The reported slope is the derivative
  ``b + 2*c*t`` evaluated at the requested position of the window.

  Args:
    data: Frame that must contain a ``StateMeans`` column.
    window: Number of rows per fit; at least 3.
    center_time: Subtract the mean from ``t`` before fitting. This changes
      the meaning of ``a`` and ``b`` but not the fitted curve.
    slope_at: ``"center"`` evaluates the slope at the middle of the window,
      ``"end"`` at its last row, regardless of ``center_time``.

  Returns:
    A frame indexed like ``data`` with columns ``a``, ``b``, ``c`` and
    ``Slope``. Each fit is stored on the last row of its window, so the
    first ``window - 1`` rows are NaN (all rows when the series is shorter
    than ``window``). A window containing NaN yields NaN coefficients.

  Raises:
    ValueError: If ``slope_at`` is not ``"center"`` or ``"end"``, or if
      ``window`` is smaller than 3.
    KeyError: If ``data`` has no ``StateMeans`` column.
  """
  # Validate up front so a bad argument fails before any fitting work.
  if slope_at not in ("center", "end"):
    raise ValueError("slope_at must be 'center' or 'end'")
  if window < 3:
    # Three coefficients need at least three points to be determined.
    raise ValueError("window must be at least 3 to fit a quadratic")

  y = data["StateMeans"]
  y_arr = y.values.astype(float)
  n = len(y_arr)

  t = np.arange(window, dtype=float)
  if center_time:
    t -= t.mean()
  design = np.column_stack([np.ones(window), t, t**2])

  fitted = np.full((n, 3), np.nan)
  if n >= window:
    # The design matrix is identical for every window, so its pseudo-inverse
    # is computed once and applied to all windows in a single product.
    solver = np.linalg.pinv(design)  # shape (3, window)
    windows = np.lib.stride_tricks.sliding_window_view(y_arr, window)
    fitted[window - 1:] = windows @ solver.T

  coeffs = pd.DataFrame(fitted, columns=["a", "b", "c"], index=y.index)

  # Position, in the fit's own time coordinate, at which to differentiate.
  # Taking the midpoint of t keeps "center" correct when time is not centred
  # (there t = 0 is the first row of the window, not its middle).
  if slope_at == "end":
    t_eval = t[-1]
  else:
    t_eval = (t[0] + t[-1]) / 2

  coeffs["Slope"] = coeffs["b"] + 2 * coeffs["c"] * t_eval
  return coeffs



In [None]:
# Filter the natural log of the closing price rather than the raw price.
data["LogClose"] = data["Close"].pipe(np.log)
# Default hyperparameters; only the state means are kept.
state_means, _ = apply_filter(data=data)

In [None]:
# The filter output is expected to be 2-D, (n_timesteps, 1), for this
# one-dimensional model; flatten it so the column gets one value per row.
data["StateMeans"] = np.ravel(state_means)

fig, ax = plt.subplots()
ax.plot(data["LogClose"], label="log close")
ax.plot(data["StateMeans"], label="Kalman state mean")
ax.set(
    title=f'{cnf["ticker"]}: log close vs. Kalman-filtered level',
    xlabel="date",
    ylabel="log price",
)
ax.legend()
plt.show()

In [None]:
# Presumably the number of trading days per year — TODO confirm.
ANNUALIZATION_FACTOR = 252

quad = get_beta(
    data,
    window=cnf["window"],
    center_time=cnf["center_time"],
    slope_at=cnf["slope_at"],
)
data["QuadraticSlope"] = quad["Slope"]
data["QuadraticSlopeAnnual"] = data["QuadraticSlope"] * ANNUALIZATION_FACTOR

fig, ax = plt.subplots()
ax.plot(data["QuadraticSlope"])
ax.set(
    title=f'{cnf["ticker"]}: rolling quadratic slope of the filtered log price',
    xlabel="date",
    ylabel="slope per row",
)
plt.show()

In [None]:
# Percentile rank of each slope within the whole sample.
# NOTE(review): because the rank uses every row, the signal on a given row
# also depends on later rows. Confirm this is intended before using the
# signal in a backtest.
data["q"] = data["QuadraticSlope"].rank(pct=True)

# 1 for the highest-ranked slopes, -1 for the lowest, 0 otherwise. Rows whose
# rank is NaN satisfy neither comparison and therefore stay at 0.
data["Signal"] = 0
data.loc[data["q"] >= cnf["top_q"], "Signal"] = 1
data.loc[data["q"] <= cnf["bottom_q"], "Signal"] = -1

fig, ax = plt.subplots()
ax.plot(data["Signal"])
ax.set(
    title=f'{cnf["ticker"]}: slope-quantile signal',
    xlabel="date",
    ylabel="signal",
    yticks=[-1, 0, 1],
)
plt.show()