# Stock Value Prediction

In [129]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd

from flax import linen
from stockdex import Ticker
from turbanet import TurbaTrainState, mse

In [130]:
# Number of consecutive price points that form one input sequence.
TIME_WINDOW = 20
# Hidden-state width given to both the Flax and the PyTorch LSTM.
HIDDEN_SIZE = 32
# Number of training iterations run for every model.
EPOCHS = 10_000
# Learning rate passed to the Adam optimizers.
LR = 1e-5

In [131]:
# Number of tickers to collect, and the number of models trained (one per ticker).
SWARM_SIZE = 100

## Data Collection

In [132]:
# Load the table of candidate stocks; its "Ticker" column is read when sampling symbols.
df = pd.read_csv("../../data/input/stock_info.csv")

In [133]:
# Create sequences of price data of length TIME_WINDOW
def create_sequences(data, time_window) -> tuple[np.ndarray, np.ndarray]:
    """Split a 1-D series into sliding input windows and next-step targets.

    Args:
        data: Sequence of values (list, ndarray or pandas Series). It is converted
            to an ndarray first, so elements are always selected by position and
            never by a pandas index label.
        time_window: Number of consecutive values in each input window.

    Returns:
        A pair ``(sequences, results)`` where ``sequences`` has shape
        ``(n, time_window, 1)`` and ``results`` has shape ``(n, 1)`` with
        ``n = max(len(data) - time_window, 0)``. ``results[k]`` is the value that
        immediately follows window ``k``.
    """
    values = np.asarray(data)
    n_windows = len(values) - time_window

    # Not enough points for even one (window, target) pair.
    if n_windows <= 0:
        return (
            np.empty((0, time_window, 1), dtype=values.dtype),
            np.empty((0, 1), dtype=values.dtype),
        )

    # Row k of the index grid holds the positions k, k + 1, ..., k + time_window - 1.
    window_index = np.arange(n_windows)[:, None] + np.arange(time_window)
    sequences = values[window_index]
    results = values[time_window:]
    return sequences.reshape(-1, time_window, 1), results.reshape(-1, 1)

In [134]:
def get_ticker_data(ticker, date_range="1y", data_granularity="1d"):
    """Fetch closing prices for ``ticker`` and turn them into training sequences.

    The prices are requested through ``Ticker(...).yahoo_api_price`` with the given
    range and granularity, scaled by their maximum, and windowed with
    ``create_sequences`` using ``TIME_WINDOW``.

    Returns:
        ``None`` when the ``close`` column is empty, otherwise the
        ``(X_data, y_data)`` pair produced by ``create_sequences``.
    """
    price_frame = Ticker(ticker=ticker).yahoo_api_price(
        range=date_range, dataGranularity=data_granularity
    )
    closing = price_frame["close"]

    # Nothing to build sequences from.
    if closing.empty:
        return None

    # Scale so the largest closing price becomes 1.
    closing /= np.max(closing)
    return create_sequences(closing, TIME_WINDOW)

In [135]:
def get_random_ticker_data(stock_df, num_samples, date_range="1y", data_granularity="1d"):
    """Randomly sample tickers until ``num_samples`` of them yield usable sequences.

    Only tickers whose sequence count equals the longest one seen so far are kept,
    so the collected per-ticker arrays can be stacked into rectangular arrays.

    Args:
        stock_df: DataFrame with a ``Ticker`` column to sample symbols from.
        num_samples: Number of tickers to collect.
        date_range: Forwarded to ``get_ticker_data``.
        data_granularity: Forwarded to ``get_ticker_data``.

    Returns:
        ``(tickers, X_data, y_data)``: the list of symbols and the stacked input
        and target arrays, one leading entry per ticker.

    Note:
        The loop only ends once enough tickers are collected; it keeps retrying
        while fetches fail. Sampling is with replacement, so a symbol can repeat.
    """
    tickers = []
    X_data = []
    y_data = []
    max_seq_len = 0
    while len(tickers) < num_samples:
        ticker = stock_df.sample(1).Ticker.values[0]
        try:
            ticker_data = get_ticker_data(ticker, date_range, data_granularity)
        except Exception:
            # Best effort: a failed fetch just means another symbol is tried.
            continue

        # If the data returned is None, try again
        if ticker_data is None:
            continue

        X_ticker, y_ticker = ticker_data
        seq_len = X_ticker.shape[0]

        # Shorter histories cannot be stacked with what is already collected.
        if seq_len < max_seq_len:
            continue

        if seq_len > max_seq_len:
            # A longer history was found: drop everything collected with a shorter one.
            max_seq_len = seq_len
            keep = [idx for idx, seqs in enumerate(X_data) if seqs.shape[0] >= max_seq_len]
            tickers = [tickers[idx] for idx in keep]
            X_data = [X_data[idx] for idx in keep]
            y_data = [y_data[idx] for idx in keep]

        tickers.append(ticker)
        X_data.append(X_ticker)
        y_data.append(y_ticker)

        print(f"Tickers Found: {len(tickers)}/{num_samples}")

    return tickers, np.array(X_data), np.array(y_data)

In [None]:
# Collect SWARM_SIZE random tickers, requesting the "1y" range at "1d" granularity.
TICKERS, X_data, y_data = get_random_ticker_data(
    df, SWARM_SIZE, date_range="1y", data_granularity="1d"
)


In [None]:
# Display the collected ticker symbols.
TICKERS

In [None]:
# Display the shapes of the stacked input and target arrays.
X_data.shape, y_data.shape

## Turba

In [139]:
class TurbaLSTM(linen.Module):
    """LSTM regressor: scans an LSTM cell along axis 1 and maps the last output to one value."""

    # Width passed to the LSTM cell.
    features: int

    @linen.compact
    def __call__(self, x):
        # Lift the cell with linen.scan so it steps along axis 1 with broadcast parameters.
        scanned_cell = linen.scan(
            linen.OptimizedLSTMCell,
            in_axes=1,
            out_axes=1,
            variable_broadcast="params",
            split_rngs={"params": False},
        )
        recurrent = scanned_cell(self.features)

        # The initial carry is sized from the shape of a single step of the input.
        step_shape = x[:, 0].shape
        state = recurrent.initialize_carry(jax.random.PRNGKey(0), step_shape)
        _, outputs = recurrent(state, x)

        # Only the output at the last position feeds the dense head.
        last_output = outputs[:, -1]
        return linen.Dense(1)(last_output)

In [140]:
# Adam optimizer handed to the Turba swarm, using the global learning rate.
optimizer = optax.adam(learning_rate=LR)

In [22]:
# Build the swarm train state from the LSTM module, optimizer, swarm size and data.
# NOTE(review): X_data is presumably used to initialise parameters from its shape; confirm in turbanet.
turba_model = TurbaTrainState.swarm(TurbaLSTM(features=HIDDEN_SIZE), optimizer, SWARM_SIZE, X_data)

In [None]:
# Train the Turba model and time the whole run
start = time.time()
turba_losses = []
for epoch in range(EPOCHS):
    # Train
    turba_model, loss, pred = turba_model.train(X_data, y_data, mse)

    # Logging: every 100 epochs print the mean of the most recent (up to 100) losses
    turba_losses.append(loss)
    if epoch % 100 == 0:
        print(f"Epoch {epoch} Losses: {jnp.array(turba_losses[-100:]).mean(axis=0)}")

# JAX dispatches work asynchronously; wait for the outstanding steps to finish so
# the elapsed time below covers the computation and not only its scheduling.
jax.block_until_ready(turba_losses)
print(f"Turba time: {time.time() - start}")

In [24]:
# Stack the per-epoch losses into one array for plotting.
# NOTE(review): presumably shaped (epochs, swarm); confirm against TurbaTrainState.train.
turba_losses = jnp.array(turba_losses)

In [25]:
# Run the swarm on the training inputs to get the predictions plotted below.
turba_predictions = turba_model.predict(X_data)

In [None]:
# Loss curves recorded during Turba training, labelled with the ticker symbols
loss_fig, loss_ax = plt.subplots(figsize=(32, 16))
loss_ax.plot(turba_losses, label=TICKERS)
loss_ax.set_title("Loss", fontsize=32)
loss_ax.set_xlabel("Epochs", fontsize=24)
loss_ax.set_ylabel("Loss", fontsize=24)
loss_ax.legend(fontsize=24)
plt.show()

In [None]:
# Grid of per-ticker panels: Turba predictions against the ground truth
fig = plt.figure(figsize=(32, 16))

# Side length of the smallest square grid with room for one panel per model
grid_side = int(np.ceil(np.sqrt(SWARM_SIZE)))

for i in range(SWARM_SIZE):
    ax = fig.add_subplot(grid_side, grid_side, i + 1)

    # Panel heading and axis label
    ax.set_title(TICKERS[i])
    ax.set_ylabel("Price")

    # Target series first, then the model output
    ax.plot(y_data[i], label="Ground Truth")
    ax.plot(turba_predictions[i], label="Turba Prediction")

    ax.legend(loc="upper left")

plt.show()


## Torch

In [28]:
import torch
import torch.nn as nn


class TorchLSTM(nn.Module):
    """Single-layer LSTM over a one-feature sequence, followed by a linear head.

    Args:
        hidden_size: Width of the LSTM hidden state.
    """

    def __init__(self, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one prediction per batch row for ``x`` of shape (batch, seq_len, 1)."""
        batch_size = x.size(0)

        # Zero initial hidden and cell states, created with the input's dtype and
        # device so the module also works after .double() / .to(device).
        h_0 = x.new_zeros(1, batch_size, self.hidden_size)
        c_0 = x.new_zeros(1, batch_size, self.hidden_size)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h_0, c_0))  # Shape: (batch, seq_len, hidden_size)

        # Take the last time step output
        final_output = out[:, -1, :]  # Shape: (batch, hidden_size)

        # Fully connected layer to map hidden state to final output
        return self.fc(final_output)  # Shape: (batch, 1)

In [29]:
# One independent PyTorch LSTM per swarm member, with the same hidden size as the Turba model.
torch_models = [TorchLSTM(hidden_size=HIDDEN_SIZE) for _ in range(SWARM_SIZE)]

In [None]:
# Train the PyTorch models one after another and time the whole run
start = time.time()

torch_loss = torch.nn.MSELoss()
torch_losses = []
for idx, torch_model in enumerate(torch_models):
    torch_model.train()
    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=LR)

    # Inputs and targets never change between epochs, so convert them once per
    # model instead of once per epoch inside the timed loop.
    inputs = torch.tensor(X_data[idx], dtype=torch.float32)
    targets = torch.tensor(y_data[idx], dtype=torch.float32)

    torch_losses.append([])
    for epoch in range(EPOCHS):
        # Train
        torch_optimizer.zero_grad()
        y_pred = torch_model(inputs)
        loss = torch_loss(y_pred, targets)
        loss.backward()
        torch_optimizer.step()

        # Logging: every 100 epochs print the mean of the most recent (up to 100) losses
        torch_losses[idx].append(loss.item())
        if epoch % 100 == 0:
            print(f"Model {idx} - Epoch {epoch} Loss: {np.mean(torch_losses[idx][-100:])}")

print(f"torch time: {time.time() - start}")

In [31]:
# Predict on the training inputs with every trained PyTorch model
torch_predictions = []
# Inference only: disable autograd so no graph is recorded for these forward passes.
with torch.no_grad():
    for idx, torch_model in enumerate(torch_models):
        torch_model.eval()
        torch_predictions.append(torch_model(torch.Tensor(X_data[idx])).numpy())

In [32]:
# Transpose from (models, epochs) to (epochs, models) so each column is one model's loss curve.
torch_losses = np.array(torch_losses).T
# Stack the per-model prediction arrays into a single array.
torch_predictions = np.array(torch_predictions)

In [None]:
# Loss curves recorded during PyTorch training, labelled with the ticker symbols
loss_fig, loss_ax = plt.subplots(figsize=(32, 16))
loss_ax.plot(torch_losses, label=TICKERS)
loss_ax.set_title("Loss", fontsize=32)
loss_ax.set_xlabel("Epochs", fontsize=24)
loss_ax.set_ylabel("Loss", fontsize=24)
loss_ax.legend(fontsize=24)
plt.show()

In [None]:
# Grid of per-ticker panels: PyTorch predictions against the ground truth
fig = plt.figure(figsize=(32, 16))

# Rows and columns of the smallest square grid with room for one panel per model
n_rows = n_cols = int(np.ceil(np.sqrt(SWARM_SIZE)))

for i in range(SWARM_SIZE):
    ax = fig.add_subplot(n_rows, n_cols, i + 1)

    # Panel heading and axis label
    ax.set_title(TICKERS[i])
    ax.set_ylabel("Price")

    # Target series first, then the model output
    ax.plot(y_data[i], label="Ground Truth")
    ax.plot(torch_predictions[i], label="Torch Prediction")

    ax.legend(loc="upper left")

plt.show()

## Final Data

In [None]:
# Grid of per-ticker panels: ground truth with both the PyTorch and the Turba predictions
fig = plt.figure(figsize=(32, 20))

# Side length of the smallest square grid with room for one panel per model
side = int(np.ceil(np.sqrt(SWARM_SIZE)))

for i in range(SWARM_SIZE):
    ax = fig.add_subplot(side, side, i + 1)

    # Panel heading and axis label
    ax.set_title(TICKERS[i])
    ax.set_ylabel("Price")

    # Draw order is kept: target, then PyTorch, then Turba
    ax.plot(y_data[i], label="Ground Truth")
    ax.plot(torch_predictions[i], label="Torch Prediction")
    ax.plot(turba_predictions[i], label="Turba Prediction")

    ax.legend(loc="upper left")

plt.show()