In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import pickle

In [2]:
jnp.asarray([1, 2, 3]).shape

(3,)

In [6]:
jnp.concatenate((jnp.asarray([1, 2]), jnp.asarray([3, 4])))

Array([1, 2, 3, 4], dtype=int32)

In [26]:
def _jax_optimize(r:jnp.ndarray, s:jnp.ndarray, rho:jnp.ndarray, target_return=None, risk_free_rate=0.0, max_epochs:int = 1000, lr:float = 0.01, device = "cpu") -> np.ndarray:

    print("Using JAX for optimization...")

    n_assets = len(r)
    
    if device == "cpu":
        jax.config.update("jax_platform_name", "cpu")
    elif device == "gpu" or device == "cuda":
        jax.config.update("jax_platform_name", "gpu")
    else:
        raise ValueError("Device must be either 'cpu' or 'gpu' or 'cuda'.")

    # mistakes in implementation?
    @jax.jit
    def neg_sharp_ratio(weights:jnp.ndarray):

        weights = jnp.clip(weights, min=0.0)
        weights = weights / jnp.sum(weights)

        return - (weights.T @ r - risk_free_rate)\
              / jnp.sqrt(
                  weights.T @ (rho * (s[:, None] @ s[None, :])) @ weights
              )

    grad_fn = jax.grad(neg_sharp_ratio)

    weights = jnp.array(np.random.uniform(0.0, 1.0, n_assets))

    for epoch in range(max_epochs):

        loss = neg_sharp_ratio(weights)
        
        grad = grad_fn(weights)
        weights = weights - lr * grad

        if (epoch+1) % (max_epochs // 10) == 0:
            print(f"Epoch {epoch+1}/{max_epochs}, Loss: {loss:.4f}")

    weights = jnp.clip(weights, min=0.0)
    weights = weights / jnp.sum(weights)
    return np.array(weights)

In [27]:
_, r, s = pickle.load(open("../results/stock/statistics.pkl", "rb"))
rho = pickle.load(open("../results/stock/correlation_matrix.pkl", "rb"))

opt_weights = _jax_optimize(r, s, rho, risk_free_rate=0, max_epochs=1000, lr=0.1, device="gpu")

Using JAX for optimization...
Epoch 100/1000, Loss: -0.4717
Epoch 200/1000, Loss: -0.4762
Epoch 300/1000, Loss: -0.4807
Epoch 400/1000, Loss: -0.4851
Epoch 500/1000, Loss: -0.4895
Epoch 600/1000, Loss: -0.4940
Epoch 700/1000, Loss: -0.4983
Epoch 800/1000, Loss: -0.5027
Epoch 900/1000, Loss: -0.5071
Epoch 1000/1000, Loss: -0.5115


In [22]:
weights = np.random.uniform(0.0, 1.0, len(r))
risk_free_rate = 0.04

(weights.T @ r - risk_free_rate) , np.sqrt(weights.T @ (rho * (s[:, None] @ s[None, :]))) @ weights

  (weights.T @ r - risk_free_rate) , np.sqrt(weights.T @ (rho * (s[:, None] @ s[None, :]))) @ weights


(np.float64(204.84745448262635), np.float64(nan))

In [25]:
weights.T @ (rho * (s[:, None] @ s[None, :])) @ weights

np.float64(202294.4203934044)