In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to local modules are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import yfinance as yf
import polars as pl
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import flax.nnx as nnx
from huggingface_hub import HfApi
from model import Model
from dotenv import load_dotenv

# Read variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the Hugging Face credentials used by
# HfApi below; confirm which variables are expected.
load_dotenv()

True

In [3]:
# Hugging Face Hub client; no token is passed explicitly here. It is handed to
# Model.load_from_hf when the model is loaded.
api = HfApi()

In [4]:
import jax


# Load the model through the project's Model.load_from_hf helper.
# NOTE(review): the two strings are presumably the Hub repo id and the
# model/variant name inside that repo; confirm against model.py.
model = Model.load_from_hf(api, "Kicel/economics3a", "itransformer")

In [5]:
# Stock table that is filtered per symbol further down.
test_set = pl.read_parquet("test-stocks.parquet")

# One-device mesh with a single axis named "a". P(None) leaves the leading
# array dimension unpartitioned, so arrays placed with `shard` are not split.
mesh = jax.make_mesh(axis_shapes=(1,), axis_names=("a",))
shard = jax.NamedSharding(mesh, P(None))

In [6]:
from consts import LAG

@nnx.vmap(in_axes=(None, None, 0))
@nnx.jit
def predict(model, arr, i):
  """Forecast from the LAG-row window of `arr` that starts at row `i`.

  The vmap above maps only over `i` (`in_axes=(None, None, 0)`), so passing a
  vector of start offsets evaluates one window per offset while `model` and
  `arr` are shared across all of them.

  Args:
    model: object exposing `forecast`; passed through un-mapped.
    arr: array sliced along its leading axis; passed through un-mapped.
    i: start offset of the window (a vector of offsets at the call site).

  Returns:
    Whatever `model.forecast` returns for the window, stacked along a new
    leading axis by the vmap.
  """
  # The start offset is a traced value under vmap/jit, hence a dynamic slice
  # rather than plain Python slicing. Note that dynamic_slice clamps
  # out-of-range starts instead of raising.
  window = jax.lax.dynamic_slice_in_dim(arr, start_index=i, slice_size=LAG, axis=0)
  return model.forecast(window)

In [7]:
# Put the model into evaluation mode before forecasting.
# NOTE(review): presumably Model is an nnx.Module, where eval() disables
# training-only behaviour such as dropout; confirm in model.py.
model.eval()

In [38]:
# Ticker to analyse: used to filter `test_set` by its `symbol` column and,
# lower-cased, as the file name of the parquet written at the end.
stock_name = "SHOP"

In [39]:
from data_loading import preprocess_history

# Rows of the test table that belong to the selected ticker.
history = test_set.filter(pl.col("symbol") == stock_name)

# Turn the price history into model inputs, convert to a JAX array and place
# it according to the sharding defined earlier.
history_features = preprocess_history(history)
arr = history_features.to_jax().to_device(shard)

In [40]:
# One start offset for every full LAG-row window that fits inside `arr`.
n_windows = arr.shape[0] - LAG + 1
starts = jnp.arange(n_windows)

# `predict` is vmapped over its last argument, so this evaluates all windows
# in a single call.
preds = predict(model, arr, starts)

In [41]:
import numpy as np

In [42]:
# There are fewer prediction rows than history rows: each prediction needs a
# full LAG-row window, and the original code additionally hard-coded 20 extra
# leading rows (presumably rows consumed by preprocess_history; TODO confirm).
# Derive the padding from the actual lengths instead of that magic number, so
# the predictions stay right-aligned with the end of `history` even if the
# preprocessing changes.
pad_rows = history.height - preds.shape[0]
assert pad_rows >= 0, "more prediction rows than history rows"

# Left-pad with NaN so rows without a forecast are explicit in the output.
preds_padded = jnp.pad(
    preds, [(pad_rows, 0), (0, 0)], mode="constant", constant_values=jnp.nan
)

df = history.with_columns(
    pred=np.array(preds_padded)  # type: ignore
)
df

symbol,date,open,close,high,low,volume,pred
str,date,f64,f64,f64,f64,i64,"array[f32, 8]"
"""SHOP""",2024-01-02,76.440002,73.830002,76.629997,72.910004,13134800,"[NaN, NaN, … NaN]"
"""SHOP""",2024-01-03,72.080002,71.82,72.989998,71.18,9649900,"[NaN, NaN, … NaN]"
"""SHOP""",2024-01-04,71.629997,73.419998,74.019997,70.614998,11927400,"[NaN, NaN, … NaN]"
"""SHOP""",2024-01-05,73.32,74.510002,75.946999,73.0,9830200,"[NaN, NaN, … NaN]"
"""SHOP""",2024-01-08,74.779999,77.690002,77.900002,74.720001,8232000,"[NaN, NaN, … NaN]"
…,…,…,…,…,…,…,…
"""SHOP""",2024-10-31,79.75,78.209999,80.389999,77.120003,4637000,"[77.618973, 77.536812, … 78.218941]"
"""SHOP""",2024-11-01,79.050003,78.989998,79.889999,77.940002,6189900,"[79.124924, 79.18158, … 79.714035]"
"""SHOP""",2024-11-04,78.709999,78.440002,79.059998,77.699997,3584500,"[78.617661, 78.729416, … 79.368546]"
"""SHOP""",2024-11-05,78.550003,79.57,80.040001,78.230003,4652300,"[80.10891, 80.053238, … 80.364639]"


In [43]:
from pathlib import Path

# Persist the per-ticker predictions. Create the target directory first so the
# write does not fail when the output directory does not exist yet (e.g. on a
# fresh checkout).
out_path = Path(f"analysis/i/{stock_name.lower()}.parquet")
out_path.parent.mkdir(parents=True, exist_ok=True)
df.write_parquet(out_path)