In [1]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, os
import jax.numpy as jnp, jax
from jax import grad, jit, vmap
import tqdm

from utils.newton import *

***Prepare data***


In [2]:
data = pd.read_csv("data.csv")
data = data.rename({"Unnamed: 0":"date"}, axis=1)
data.date = pd.to_datetime(data.date)
data.head()

Unnamed: 0,date,open,high,low,close,adjclose,volume,ticker
0,1999-11-18,32.546494,35.765381,28.612303,31.473534,26.889658,62546380,A
1,1999-11-19,30.713518,30.758226,28.478184,28.880545,24.674322,15234146,A
2,1999-11-22,29.551144,31.473534,28.657009,31.473534,26.889658,6577870,A
3,1999-11-23,30.400572,31.205294,28.612303,28.612303,24.445148,5975611,A
4,1999-11-24,28.701717,29.998213,28.612303,29.372318,25.094481,4843231,A


In [3]:
ts_data = data.set_index(["ticker", "date"])
ts_data.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,open,high,low,close,adjclose,volume
ticker,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
A,1999-11-18,32.546494,35.765381,28.612303,31.473534,26.889658,62546380
A,1999-11-19,30.713518,30.758226,28.478184,28.880545,24.674322,15234146
A,1999-11-22,29.551144,31.473534,28.657009,31.473534,26.889658,6577870
A,1999-11-23,30.400572,31.205294,28.612303,28.612303,24.445148,5975611
A,1999-11-24,28.701717,29.998213,28.612303,29.372318,25.094481,4843231


In [4]:
target_data = ts_data.adjclose.unstack(0).shift(-1).stack()
target_data = pd.DataFrame(target_data.rename("target"))
target_data = target_data.swaplevel(0,1).sort_index()
target_data.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,target
ticker,date,Unnamed: 2_level_1
A,1999-11-17,26.889658
A,1999-11-18,24.674322
A,1999-11-19,26.889658
A,1999-11-22,24.445148
A,1999-11-23,25.094481


**Implementation**

In [5]:
tickers = data.ticker.unique()
tckr = tickers[0]
tckr

Z = target_data.loc[tckr]
X = ts_data.loc[tckr]
days = Z.index
Z, X = jnp.array(Z), jnp.array(X)

N, Nz = Z.shape
_, Nx = X.shape

In [13]:
rng_key = jax.random.PRNGKey(88)

P0 = jax.random.normal(rng_key, (Nz, Nz))
P0 = jnp.matmul(P0, P0.T)

K = 100
sigma = 1e-3

R = sigma * jnp.eye(Nx)
Q = sigma * jnp.eye(Nz)
Qi = jnp.linalg.inv(Q)
Ri = jnp.linalg.inv(R)

z0 = Z[0]

D0 = jax.random.normal(rng_key, (Nz, Nz))
H0 = jax.random.normal(rng_key, (Nx, Nz))

In [20]:
z = jax.random.normal(rng_key, (K+1, Nz))
P = jax.random.normal(rng_key, (K+1, Nz, Nz))

z_s = jax.random.normal(rng_key, (K+1, Nz))
P_s = jax.random.normal(rng_key, (K+1, Nz, Nz))
G_s = jax.random.normal(rng_key, (K+1, Nz, Nz))


for id_day, day in tqdm.tqdm(enumerate(days)):
    if id_day >= K+1 and id_day+K <= N:
        
        X_obs = X[id_day:id_day+K, :]
        Z_obs = X[id_day:id_day+K, :]
        
        for k in range(K):
            z_m = D0 @ z[k-1, :]
            P_m = D0 @ P[k-1, :, :] @ D0.T + Q

            S = H0 @ P_m @ H0.T + R
            G = P_m @ H0.T @ jnp.linalg.inv(S)
            z.at[k].set(z_m + G @ (X_obs[k] - H0 @ z_m))
            P.at[k].set(P_m - G @ S @ G.T)
        
        z_s = jnp.copy(z)
        P_s = jnp.copy(P)
        
        
        for kk in range(0, K):
            k = K-2 - kk
            
            z_m = D0 @ z[k]
            P_m = D0 @ P[k] @ D0.T + Q
            
            G_s.at[k].set(P[k] @ D0.T @ jnp.linalg.inv(P_m))
            z_s.at[k].set(z[k, :] + G_s[k] @ (z_s[k+1] - z_m))
            P_s.at[k].set(P[k] + G_s[k] @ (P_s[k+1] - P_m) @ G_s[k].T)
        
        SS = jnp.zeros((Nz, Nz))
        FF = jnp.zeros((Nz, Nz))
        B = jnp.zeros((Nx, Nz))
        C = jnp.zeros((Nz, Nz))
        DD = jnp.zeros((Nx, Nx))
        for k in range(K):
            SS += P_s[k] + z_s[k, :, None] * z_s[k, None, :]
            FF += P_s[k-1] + z_s[k-1, :, None] * z_s[k-1, None, :]
            B += X_obs[k, :, None] * z_s[k, None, :]
            C += P_s[k] @ G_s[k-1].T + z_s[k, :, None] * z_s[k-1, None, :]
            DD += X_obs[k, :, None] * X_obs[k, None, :]
        SS *= 1/K
        FF *= 1/K
        B *= 1/K
        C *= 1/K
        DD *= 1/K
        
        
        f_D = lambda D: jnp.trace(Qi @ (SS - C @ D.T - D @ C.T + D @ FF @ D.T))
        f_H = lambda H: jnp.trace(Ri @ (DD - B @ H.T - H @ B.T + H @ SS @ H.T))
        
        f_D = jit(f_D)
        df_D = grad(f_D)
        ddf_D = jax.jacfwd(df_D)
        D0 = solve_newton(D0, f_D, df_D, ddf_D, eps=1e-2, verbose=True)
        print(D0)
        f_H = jit(f_H)
        df_H = grad(f_H)
        ddf_H = jax.jacfwd(df_H)
        H0 = solve_newton(H0, f_H, df_H, ddf_H, eps=1e-2, verbose=True)
        
        break
        #if id_day > 200: break


0it [00:00, ?it/s]

newton score 6.865185e-26
[[0.13313237]]
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803


101it [00:04, 25.15it/s]

newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
newton score 0.0011115803
SNAKE EATER



