# vjp jvp

In [18]:
import jax
import jax.numpy as jnp

def output(a, b, c):
    """Evaluate g = e / (d + e), where d = b + c and e = a * c."""
    total = b + c
    product = a * c
    return product / (total + product)

# Point at which the derivatives are evaluated
a = 2.0
b = 3.0
c = 4.0

# Input tangents: the seed direction for the forward-mode sweep
a_dot = 1.0
b_dot = 1.0
c_dot = 1.0

# Forward-mode (JVP) directional derivative of g at (a, b, c)
def compute_forward_gradient(a, b, c, a_dot, b_dot, c_dot):
    """Return g_dot, the JVP of ``output`` along (a_dot, b_dot, c_dot).

    The three scalars are packed into one array so that a single
    primal/tangent pair is handed to ``jax.jvp``.
    """
    primal = jnp.array([a, b, c])
    tangent = jnp.array([a_dot, b_dot, c_dot])

    def packed_output(vec):
        # Unpack the array back into the three scalar arguments
        x, y, z = vec
        return output(x, y, z)

    # jax.jvp returns (primal_out, tangent_out); only the tangent is needed
    return jax.jvp(packed_output, (primal,), (tangent,))[1]

# Run the forward-mode sweep at the example point and show the output tangent
g_dot = compute_forward_gradient(a, b, c, a_dot, b_dot, c_dot)
print("Forward mode gradient g_dot:", g_dot)

# Reverse-mode (VJP) sensitivities of a, b, c for a given output cotangent
def compute_reverse_gradient(a, b, c, g_bar):
    """Return (a_bar, b_bar, c_bar): the VJP of ``output`` at (a, b, c) applied to g_bar."""
    # jax.vjp gives (primal_out, pullback); the primal value is not needed.
    # ``output`` already takes (a, b, c) positionally, so it is passed directly.
    _, pullback = jax.vjp(output, a, b, c)
    grad_a, grad_b, grad_c = pullback(g_bar)
    return grad_a, grad_b, grad_c

# Output cotangent: sensitivity of the loss with respect to g
g_bar = 1.0  # Example sensitivity of the loss with respect to g

# Reverse-mode sweep: input sensitivities for that cotangent
a_bar, b_bar, c_bar = compute_reverse_gradient(a, b, c, g_bar)
print("Reverse mode gradients a_bar, b_bar, c_bar:", a_bar, b_bar, c_bar)

# Dot-product (adjoint) test: the input-side pairing
# a_dot*a_bar + b_dot*b_bar + c_dot*c_bar must equal the output-side
# pairing g_dot*g_bar up to rounding.
lhs = a_dot * a_bar + b_dot * b_bar + c_dot * c_bar
rhs = g_dot * g_bar

print(f"Left-hand side (Input Sensitivity Product): {lhs}")
print(f"Right-hand side (Output Sensitivity Product): {rhs}")

Forward mode gradient g_dot: 0.115555555
Reverse mode gradients a_bar, b_bar, c_bar: 0.124444455 -0.035555556 0.026666671
Left-hand side (Input Sensitivity Product): 0.11555556952953339
Right-hand side (Output Sensitivity Product): 0.11555555462837219


# checkpointing

## no checkpointing

\begin{equation*}
   S_{t+1} = S_{t} \exp\left(\left(r-\frac{1}{2}\sigma^2\right)\delta t + \sigma\sqrt{\delta t}\,\phi\right)
\end{equation*}

where $\phi \sim N(0,1)$.


In [1]:
import jax
import jax.numpy as jnp
from jax import random
  

def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with the exact log-normal scheme.

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The last entry of the simulated path, ``S0 * exp(log_path[-1])``.
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    time_steps = jnp.linspace(dt, T, num_steps)
    log_path = (mu - 0.5 * sigma**2) * time_steps + sigma * increments.cumsum(axis=0)
    S_path = S0 * jnp.exp(log_path)
    return S_path[-1]

def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the terminal price w.r.t. S0, mu and sigma.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)

LHS: 2.567828732334462
RHS: 2.567828732334464


Another approximation (explicit Euler step):
\begin{equation*}
   S_{t+1} = S_{t} (1+r{\delta}t+{\sigma}{\phi}{\sqrt{\delta t}})
\end{equation*}
This follows from discretising $dS = rS\,dt + \sigma S\,dW$ with $dW \sim {\phi}{\sqrt{\delta t}}$.

In [71]:
import jax
import jax.numpy as jnp
from jax import random, checkpoint
jax.config.update("jax_enable_x64", True)  # force 64-bit accuracy


def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with explicit Euler steps.

    Every step applies ``S <- S * (1 + mu*dt + sigma*dW)`` with one increment
    dW per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead.)

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The price after the last step (S0 itself when there are no steps).
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    # Only the terminal value is returned, so carry one running price
    # instead of collecting the whole path in a list.
    S = S0
    for i in range(num_steps):
        S = S * (1 + mu * dt + sigma * increments[i])
    return S


def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the terminal price w.r.t. S0, mu and sigma.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)

LHS: 0.11919670816963901
RHS: 0.11919670816968986


## checkpointing using jax

In [2]:
import jax
import jax.numpy as jnp
from jax import random, checkpoint
jax.config.update("jax_enable_x64", True)  # force 64-bit accuracy


def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with the exact log-normal scheme.

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The last entry of the simulated path, ``S0 * exp(log_path[-1])``.
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    time_steps = jnp.linspace(dt, T, num_steps)
    log_path = (mu - 0.5 * sigma**2) * time_steps + sigma * increments.cumsum(axis=0)
    S_path = S0 * jnp.exp(log_path)
    return S_path[-1]

# Checkpointed (rematerialising) wrapper around simulate_gbm. T and dt
# (positions 3 and 4) are marked static because simulate_gbm turns T / dt
# into a Python integer step count, which needs concrete values, not tracers.
ckpt_simulate_gbm = checkpoint(simulate_gbm, static_argnums=(3, 4))

def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the checkpointed terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return ckpt_simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the checkpointed terminal price.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return ckpt_simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)

LHS: 2.567828732334462
RHS: 2.567828732334464


### using for loop here 

Testing with the other simulation formula from above (the explicit Euler step), implemented with a Python for loop.

In [68]:
import jax
import jax.numpy as jnp
from jax import random, checkpoint
jax.config.update("jax_enable_x64", True)  # force 64-bit accuracy


def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with explicit Euler steps.

    Every step applies ``S <- S * (1 + mu*dt + sigma*dW)`` with one increment
    dW per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead.)

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The price after the last step (S0 itself when there are no steps).
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    # Only the terminal value is returned, so carry one running price
    # instead of collecting the whole path in a list.
    S = S0
    for i in range(num_steps):
        S = S * (1 + mu * dt + sigma * increments[i])
    return S

# Checkpointed (rematerialising) wrapper around simulate_gbm. T and dt
# (positions 3 and 4) are marked static because simulate_gbm turns T / dt
# into a Python integer step count, which needs concrete values, not tracers.
ckpt_simulate_gbm = checkpoint(simulate_gbm, static_argnums=(3, 4))

def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the checkpointed terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return ckpt_simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the checkpointed terminal price.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return ckpt_simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)

LHS: 0.11919670816963901
RHS: 0.11919670816966921


## not using jax.checkpoint

### 

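The code below is closer to one-step checkpointing: only the inputs and the random increments are kept, and the whole path is recomputed when the VJP is taken.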

In [55]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vjp, jvp

jax.config.update("jax_enable_x64", True) 

def compute_increments(dt, num_steps, key):
    """Draw ``num_steps`` Brownian increments: standard normals scaled by sqrt(dt)."""
    standard_normals = random.normal(key, (num_steps,))
    scale = jnp.sqrt(dt)
    return standard_normals * scale

def compute_final_price(S0, mu, sigma, increments, time_steps):
    """Propagate S0 through one explicit Euler step per increment.

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        increments: 1-D sequence of Brownian increments, one per step.
        time_steps: time grid for the steps. Both callers in this file build
            it as ``linspace(dt, T, num_steps)``, so its first entry is taken
            as the step size dt -- confirm before passing a different grid.

    Returns:
        The price after the last step (S0 itself when there are no increments).
    """
    # Take the step count and step size from the arguments rather than from
    # module-level T and dt, so the result always matches the data passed in.
    num_steps = len(increments)
    if num_steps == 0:
        return S0
    dt = time_steps[0]
    S = S0
    for i in range(num_steps):
        S = S * (1 + mu * dt + sigma * increments[i])
    return S

def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Simulate one GBM path and return its final price.

    Builds the time grid and the Brownian increments for T / dt steps and
    hands them to ``compute_final_price``.
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    time_steps = jnp.linspace(dt, T, num_steps)
    increments = compute_increments(dt, num_steps, key)
    return compute_final_price(S0, mu, sigma, increments, time_steps)

# Forward pass for manual checkpointing
def forward_and_save_intermediates(S0, mu, sigma, T, dt, key):
    """Collect what the backward pass needs in order to recompute the price.

    Returns:
        (S0, mu, sigma, increments, time_steps). No prices are stored here;
        only the inputs, the random increments and the time grid.
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    time_steps = jnp.linspace(dt, T, num_steps)
    increments = compute_increments(dt, num_steps, key)
    # Save increments and time_steps for recomputation
    return S0, mu, sigma, increments, time_steps

def manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key):
    """Gradient of the final price w.r.t. (S0, mu, sigma) from saved inputs.

    The increments and time grid come from ``forward_and_save_intermediates``;
    the price is then recomputed from them inside a single ``vjp`` call that
    is seeded with an output cotangent of 1.0.
    """
    saved = forward_and_save_intermediates(S0, mu, sigma, T, dt, key)
    S0, mu, sigma, increments, time_steps = saved

    def price_from_saved(s0, drift, vol):
        # Recompute the final price from the stored increments and grid
        return compute_final_price(s0, drift, vol, increments, time_steps)

    # vjp returns (primal_out, pullback); only the pullback is used
    _, pullback = vjp(price_from_saved, S0, mu, sigma)
    return pullback(jnp.array(1.0))

# Example parameters for the gradient check
S0 = 100.0  # Initial price
mu = 0.05   # Drift
sigma = 0.2 # Volatility
T = 1.0     # Time horizon
dt = 0.01   # Time step
key = random.PRNGKey(0)

# Gradients of the final price with respect to (S0, mu, sigma), in that order
gradients = manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key)

In [56]:
# Display the gradient tuple, ordered as (S0, mu, sigma)
gradients

(Array(0.87897801, dtype=float64, weak_type=True),
 Array(88.03136753, dtype=float64, weak_type=True),
 Array(-88.79114883, dtype=float64, weak_type=True))

This is the LHS of the dot-product test: S0_bar$*$S0_dot + mu_bar$*$mu_dot + sigma_bar$*$sigma_dot

The result is the same as in section 2.2.1 (the for-loop version under `jax.checkpoint`).

In [57]:
# Input-side pairing of the dot-product test.
# NOTE(review): a_dot, b_dot, c_dot are not assigned in this cell; presumably
# they are still defined from an earlier cell -- confirm.
sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array(gradients))

Array(0.11919671, dtype=float64)

### jax in reverse mode

In [58]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vjp, jvp

jax.config.update("jax_enable_x64", True) 

def compute_increments(dt, num_steps, key):
    """Draw ``num_steps`` Brownian increments: standard normals scaled by sqrt(dt)."""
    standard_normals = random.normal(key, (num_steps,))
    scale = jnp.sqrt(dt)
    return standard_normals * scale

def compute_step_price(S0, mu, sigma, increments, t0, T, dt):
    """Euler-step the price from time t0 to T and return every visited price.

    The number of steps is ``round((T - t0) / dt)`` and step k uses
    ``increments[k]``. The returned list starts with S0 and gains one price
    per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead of the Euler factor.)
    """
    step_count = round((T-t0)/ dt)
    if not step_count:
        # Debug output when the interval contains no step
        print(T,t0)
    prices = [S0]
    for k in range(step_count):
        growth = 1+mu*dt+sigma*increments[k]
        prices.append(prices[-1]*growth)
    return prices

# Forward pass for manual checkpointing
def forward_and_save_intermediates(S0, mu, sigma, T, dt, key):
    """Run the forward simulation and keep every state for the backward sweep.

    Returns (S0, mu, sigma, increments, state_price), where state_price is the
    list of prices produced by ``compute_step_price`` over [0, T].
    """
    n_steps = round(T / dt)
    noise = compute_increments(dt, n_steps, key)
    # The stored path is what the backward loop reads its states from
    path = compute_step_price(S0, mu, sigma, noise, 0, T, dt)
    return S0, mu, sigma, noise, path

def manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key):
    """Gradient of the final price w.r.t. (S0, mu, sigma) via a step-by-step reverse sweep.

    The forward pass stores the price before every step. The backward loop
    then differentiates one Euler step at a time with ``vjp``, chaining the
    price cotangent and accumulating the mu and sigma contributions.

    Returns:
        (S0_grad, mu_grad, sigma_grad).
    """
    S0, mu, sigma, increments, state_price = forward_and_save_intermediates(S0, mu, sigma, T, dt, key)
    num_steps = round(T / dt)
    g_bar = jnp.array(1.0)  # cotangent of the final price
    mu_grad = 0.0
    sigma_grad = 0.0
    for i in reversed(range(num_steps)):
        S_i = state_price[i]  # stored price at the start of step i

        def step_fn(S, m, s):
            # Single Euler step i, starting from S
            return compute_step_price(S, m, s, increments[i:i+1], dt*i, dt*(i+1), dt)[-1]

        _, vjp_fun = vjp(step_fn, S_i, mu, sigma)
        # Pull the running price cotangent back through step i
        g_bar, mu_bar, sigma_bar = vjp_fun(g_bar)

        mu_grad += mu_bar
        sigma_grad += sigma_bar

    # After the sweep g_bar is d(final price)/d(S0). Returning it directly
    # also covers zero steps, where the final price is S0 and the gradient
    # is therefore 1.0 rather than 0.0.
    return g_bar, mu_grad, sigma_grad

# Example parameters for the gradient check
S0 = 100.0  # Initial price
mu = 0.05   # Drift
sigma = 0.2 # Volatility
T = 1.0     # Time horizon
dt = 0.01   # Time step
key = random.PRNGKey(0)

# Gradients of the final price with respect to (S0, mu, sigma), in that order
gradients = manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key)

In [59]:
# Display the gradient tuple, ordered as (S0, mu, sigma)
gradients

(Array(0.87897801, dtype=float64, weak_type=True),
 Array(88.03136753, dtype=float64, weak_type=True),
 Array(-88.79114883, dtype=float64, weak_type=True))

In [61]:
# Input-side pairing of the dot-product test.
# NOTE(review): a_dot, b_dot, c_dot are not assigned in this cell; presumably
# they are still defined from an earlier cell -- confirm.
sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array(gradients))

Array(0.11919671, dtype=float64)

### totally manual calculation in reverse mode

In [65]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vjp, jvp

jax.config.update("jax_enable_x64", True) 

def compute_increments(dt, num_steps, key):
    """Draw ``num_steps`` Brownian increments: standard normals scaled by sqrt(dt)."""
    standard_normals = random.normal(key, (num_steps,))
    scale = jnp.sqrt(dt)
    return standard_normals * scale

def compute_step_price(S0, mu, sigma, increments, t0, T, dt):
    """Euler-step the price from time t0 to T and return every visited price.

    The number of steps is ``round((T - t0) / dt)`` and step k uses
    ``increments[k]``. The returned list starts with S0 and gains one price
    per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead of the Euler factor.)
    """
    step_count = round((T-t0)/ dt)
    if not step_count:
        # Debug output when the interval contains no step
        print(T,t0)
    prices = [S0]
    for k in range(step_count):
        growth = 1+mu*dt+sigma*increments[k]
        prices.append(prices[-1]*growth)
    return prices

# Forward pass for manual checkpointing
def forward_and_save_intermediates(S0, mu, sigma, T, dt, key):
    """Run the forward simulation and keep every state for the backward sweep.

    Returns (S0, mu, sigma, increments, state_price), where state_price is the
    list of prices produced by ``compute_step_price`` over [0, T].
    """
    n_steps = round(T / dt)
    noise = compute_increments(dt, n_steps, key)
    # The stored path is what the backward loop reads its states from
    path = compute_step_price(S0, mu, sigma, noise, 0, T, dt)
    return S0, mu, sigma, noise, path

def manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key):
    """Gradient of the final price w.r.t. (S0, mu, sigma) with hand-written adjoints.

    The forward pass stores the price before every step. The backward loop
    walks the steps in reverse and applies the adjoint of the Euler step
    ``S_next = S_i * (1 + mu*dt + sigma*dW_i)`` by hand.

    Returns:
        (S0_grad, mu_grad, sigma_grad).
    """
    S0, mu, sigma, increments, state_price = forward_and_save_intermediates(S0, mu, sigma, T, dt, key)
    num_steps = round(T / dt)
    g_bar = jnp.array(1.0)  # cotangent of the final price
    mu_grad = 0.0
    sigma_grad = 0.0
    for i in reversed(range(num_steps)):
        S_i = state_price[i]  # stored price at the start of step i
        # Parameter contributions use the cotangent of S_next, i.e. g_bar
        # before it is pulled back through this step.
        mu_bar = S_i*dt*g_bar
        sigma_bar = S_i*increments[i]*g_bar
        # d(S_next)/d(S_i) is the Euler growth factor
        g_bar = (1+mu*dt+sigma*increments[i])*g_bar

        mu_grad += mu_bar
        sigma_grad += sigma_bar

    # After the sweep g_bar is d(final price)/d(S0). Returning it directly
    # also covers zero steps, where the final price is S0 and the gradient
    # is therefore 1.0 rather than 0.0.
    return g_bar, mu_grad, sigma_grad

# Example parameters for the gradient check
S0 = 100.0  # Initial price
mu = 0.05   # Drift
sigma = 0.2 # Volatility
T = 1.0     # Time horizon
dt = 0.01   # Time step
key = random.PRNGKey(0)

# Gradients of the final price with respect to (S0, mu, sigma), in that order
gradients = manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key)

In [66]:
# Display the gradient tuple, ordered as (S0, mu, sigma)
gradients

(Array(0.87897801, dtype=float64),
 Array(88.03136753, dtype=float64),
 Array(-88.79114883, dtype=float64))

In [67]:
# Input-side pairing of the dot-product test.
# NOTE(review): a_dot, b_dot, c_dot are not assigned in this cell; presumably
# they are still defined from an earlier cell -- confirm.
sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array(gradients))

Array(0.11919671, dtype=float64)

# control flow

In [46]:
def control_flow(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=1):
    """Hand-written forward and reverse sweeps through a branch, checked with the dot-product test.

    The function differentiated is g = e / f with d = b + c, e = a * c, and
    f = d - e when e > d, otherwise f = d + e. Prints the input-side pairing,
    the output-side pairing and their absolute difference; returns None.
    """
    # Primal sweep. `sign` records the branch taken, so that f = d + sign*e.
    d = b + c
    e = a * c
    sign = -1 if e > d else 1
    f = d + sign * e
    g = e / f  # primal output; not needed by the derivative sweeps below

    # Tangent (forward-mode) sweep along the same branch
    d_dot = b_dot + c_dot
    e_dot = a_dot * c + a * c_dot
    f_dot = d_dot + sign * e_dot
    g_dot = 1.0 / (f * f) * (e_dot * f - e * f_dot)

    # Adjoint (reverse-mode) sweep: statements are visited in reverse order
    f_bar = -e / (f * f) * g_bar
    d_bar = f_bar
    e_bar = 1.0 / f * g_bar + sign * f_bar
    a_bar = e_bar * c
    b_bar = d_bar
    c_bar = e_bar * a + d_bar  # c feeds both e = a*c and d = b+c

    # Dot-product test: both pairings must agree
    LHS = a_dot * a_bar + b_dot * b_bar + c_dot * c_bar
    RHS = g_dot * g_bar
    print(LHS,RHS,f'error:{abs(LHS-RHS)}')
    

In [47]:
# Run the hand-written check at a point where e = a*c = 8 exceeds d = b+c = 7,
# so the `e > d` branch is the one exercised
control_flow(a=2,b=3,c=4, a_dot=1, b_dot=1, c_dot=1, g_bar=1)

26.0 26.0 error:0.0


In [5]:
import jax
import jax.numpy as jnp

def output(a, b, c):
    """Return g = e / f, where d = b + c, e = a * c and f = d - e if e > d else d + e."""
    d = b + c
    e = a * c
    # The comparison is resolved by a Python conditional, so only the branch
    # actually taken is evaluated
    f = d - e if e > d else d + e
    return e / f

# Point at which the derivatives are evaluated
a = 2.0
b = 3.0
c = 4.0

# Input tangents: the seed direction for the forward-mode sweep
a_dot = 1.0
b_dot = 1.0
c_dot = 1.0

# Forward-mode (JVP) directional derivative of g at (a, b, c)
def compute_forward_gradient(a, b, c, a_dot, b_dot, c_dot):
    """Return g_dot, the JVP of ``output`` along (a_dot, b_dot, c_dot).

    The three scalars are packed into one array so that a single
    primal/tangent pair is handed to ``jax.jvp``.
    """
    primal = jnp.array([a, b, c])
    tangent = jnp.array([a_dot, b_dot, c_dot])

    def packed_output(vec):
        # Unpack the array back into the three scalar arguments
        x, y, z = vec
        return output(x, y, z)

    # jax.jvp returns (primal_out, tangent_out); only the tangent is needed
    return jax.jvp(packed_output, (primal,), (tangent,))[1]

# Run the forward-mode sweep at the example point and show the output tangent
g_dot = compute_forward_gradient(a, b, c, a_dot, b_dot, c_dot)
print("Forward mode gradient g_dot:", g_dot)

# Reverse-mode (VJP) sensitivities of a, b, c for a given output cotangent
def compute_reverse_gradient(a, b, c, g_bar):
    """Return (a_bar, b_bar, c_bar): the VJP of ``output`` at (a, b, c) applied to g_bar."""
    # jax.vjp gives (primal_out, pullback); the primal value is not needed.
    # ``output`` already takes (a, b, c) positionally, so it is passed directly.
    _, pullback = jax.vjp(output, a, b, c)
    grad_a, grad_b, grad_c = pullback(g_bar)
    return grad_a, grad_b, grad_c

# Output cotangent: sensitivity of the loss with respect to g
g_bar = 1.0  # Example sensitivity of the loss with respect to g

# Reverse-mode sweep: input sensitivities for that cotangent
a_bar, b_bar, c_bar = compute_reverse_gradient(a, b, c, g_bar)
print("Reverse mode gradients a_bar, b_bar, c_bar:", a_bar, b_bar, c_bar)


# Dot-product (adjoint) test: the input-side pairing (lhs) should match the
# output-side pairing (rhs) up to rounding
lhs = a_dot * a_bar + b_dot * b_bar + c_dot * c_bar
rhs = g_dot * g_bar

print(f"Left-hand side (Input Sensitivity Product): {lhs}")
print(f"Right-hand side (Output Sensitivity Product): {rhs}")

Forward mode gradient g_dot: 26.0
Reverse mode gradients a_bar, b_bar, c_bar: 28.0 -8.0 6.0
Left-hand side (Input Sensitivity Product): 26.0
Right-hand side (Output Sensitivity Product): 26.0


# performance comparison

Set T = 1000 and dt = 0.01 (the values used in the code below) and measure the execution time.

In [3]:
import time

## manual checkpointing

In [9]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vjp, jvp

jax.config.update("jax_enable_x64", True) 

def compute_increments(dt, num_steps, key):
    """Draw ``num_steps`` Brownian increments: standard normals scaled by sqrt(dt)."""
    standard_normals = random.normal(key, (num_steps,))
    scale = jnp.sqrt(dt)
    return standard_normals * scale

def compute_step_price(S0, mu, sigma, increments, t0, T, dt):
    """Euler-step the price from time t0 to T and return every visited price.

    The number of steps is ``round((T - t0) / dt)`` and step k uses
    ``increments[k]``. The returned list starts with S0 and gains one price
    per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead of the Euler factor.)
    """
    step_count = round((T-t0)/ dt)
    if not step_count:
        # Debug output when the interval contains no step
        print(T,t0)
    prices = [S0]
    for k in range(step_count):
        growth = 1+mu*dt+sigma*increments[k]
        prices.append(prices[-1]*growth)
    return prices

# Forward pass for manual checkpointing
def forward_and_save_intermediates(S0, mu, sigma, T, dt, key):
    """Run the forward simulation and keep every state for the backward sweep.

    Returns (S0, mu, sigma, increments, state_price), where state_price is the
    list of prices produced by ``compute_step_price`` over [0, T].
    """
    n_steps = round(T / dt)
    noise = compute_increments(dt, n_steps, key)
    # The stored path is what the backward loop reads its states from
    path = compute_step_price(S0, mu, sigma, noise, 0, T, dt)
    return S0, mu, sigma, noise, path

def manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key):
    """Gradient of the final price w.r.t. (S0, mu, sigma) with hand-written adjoints.

    The forward pass stores the price before every step. The backward loop
    walks the steps in reverse and applies the adjoint of the Euler step
    ``S_next = S_i * (1 + mu*dt + sigma*dW_i)`` by hand.

    Returns:
        (S0_grad, mu_grad, sigma_grad).
    """
    S0, mu, sigma, increments, state_price = forward_and_save_intermediates(S0, mu, sigma, T, dt, key)
    num_steps = round(T / dt)
    g_bar = jnp.array(1.0)  # cotangent of the final price
    mu_grad = 0.0
    sigma_grad = 0.0
    for i in reversed(range(num_steps)):
        S_i = state_price[i]  # stored price at the start of step i
        # Parameter contributions use the cotangent of S_next, i.e. g_bar
        # before it is pulled back through this step.
        mu_bar = S_i*dt*g_bar
        sigma_bar = S_i*increments[i]*g_bar
        # d(S_next)/d(S_i) is the Euler growth factor
        g_bar = (1+mu*dt+sigma*increments[i])*g_bar

        mu_grad += mu_bar
        sigma_grad += sigma_bar

    # After the sweep g_bar is d(final price)/d(S0). Returning it directly
    # also covers zero steps, where the final price is S0 and the gradient
    # is therefore 1.0 rather than 0.0.
    return g_bar, mu_grad, sigma_grad

S0 = 100.0  # Initial price
mu = 0.05   # Drift
sigma = 0.2 # Volatility
T = 1000.0     # Time horizon
dt = 0.01   # Time step
key = random.PRNGKey(0)

# perf_counter is the high-resolution clock intended for interval timing
start_time = time.perf_counter()
gradients = manual_checkpoint_gbm_grad(S0, mu, sigma, T, dt, key)
# JAX dispatches work asynchronously; wait for the results before stopping
# the clock so the measurement covers the whole computation
gradients = jax.block_until_ready(gradients)
end_time = time.perf_counter()
execution_time = end_time - start_time

# Input-side pairing of the dot-product test.
# NOTE(review): a_dot, b_dot, c_dot are not assigned in this cell; presumably
# they are still defined from an earlier cell -- confirm.
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array(gradients))
print('LHS:',lhs)
print(f"Execution time: {execution_time} seconds")

LHS: 4.318122574486458e+16
Execution time: 12.584457874298096 seconds


## jax.checkpointing

In [10]:
import jax
import jax.numpy as jnp
from jax import random, checkpoint
jax.config.update("jax_enable_x64", True)  # force 64-bit accuracy


def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with explicit Euler steps.

    Every step applies ``S <- S * (1 + mu*dt + sigma*dW)`` with one increment
    dW per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead.)

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The price after the last step (S0 itself when there are no steps).
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    # Only the terminal value is returned, so carry one running price
    # instead of collecting the whole path in a list.
    S = S0
    for i in range(num_steps):
        S = S * (1 + mu * dt + sigma * increments[i])
    return S

# Checkpointed (rematerialising) wrapper around simulate_gbm. T and dt
# (positions 3 and 4) are marked static because simulate_gbm turns T / dt
# into a Python integer step count, which needs concrete values, not tracers.
ckpt_simulate_gbm = checkpoint(simulate_gbm, static_argnums=(3, 4))

def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the checkpointed terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return ckpt_simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the checkpointed terminal price.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return ckpt_simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1000.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives.
# NOTE: the timed region covers both the JVP sweep and the VJP sweep.
start_time = time.perf_counter()  # high-resolution clock for interval timing
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)
# JAX dispatches work asynchronously; wait for the results before stopping
# the clock so the measurement covers the whole computation
jax.block_until_ready((g_dot, a_bar, b_bar, c_bar))
end_time = time.perf_counter()
execution_time = end_time - start_time

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)

print(f"Execution time: {execution_time} seconds")

LHS: 4.318122574486458e+16
RHS: 4.318122574486346e+16
Execution time: 616.5304710865021 seconds


## no checkpointing

In [11]:
import jax
import jax.numpy as jnp
from jax import random, checkpoint
jax.config.update("jax_enable_x64", True)  # force 64-bit accuracy


def simulate_gbm(S0, mu, sigma, T, dt, key):
    """Return the terminal value of a GBM path built with explicit Euler steps.

    Every step applies ``S <- S * (1 + mu*dt + sigma*dW)`` with one increment
    dW per step. (The exact log-normal step would multiply by
    ``exp((mu - 0.5*sigma**2)*dt + sigma*dW)`` instead.)

    Args:
        S0: initial value.
        mu: drift.
        sigma: volatility.
        T: time horizon; T / dt fixes the number of steps.
        dt: time-step size.
        key: JAX PRNG key used to draw the normal increments.

    Returns:
        The price after the last step (S0 itself when there are no steps).
    """
    # round() rather than int(): T / dt often lands a hair below the intended
    # integer (0.3 / 0.1 == 2.9999999999999996) and int() would drop a step.
    num_steps = round(T / dt)
    # Brownian increments: standard normal draws scaled by sqrt(dt)
    increments = random.normal(key, (num_steps,)) * jnp.sqrt(dt)
    # Only the terminal value is returned, so carry one running price
    # instead of collecting the whole path in a list.
    S = S0
    for i in range(num_steps):
        S = S * (1 + mu * dt + sigma * increments[i])
    return S


def forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot):
    """Forward-mode (JVP) derivative of the terminal price.

    a_dot, b_dot and c_dot are the tangents seeding S0, mu and sigma
    respectively; T, dt and key are held fixed.
    """
    def terminal_price(params):
        # params is the packed vector [S0, mu, sigma]
        return simulate_gbm(params[0], params[1], params[2], T, dt, key)

    primal = jnp.array([S0, mu, sigma])
    tangent = jnp.array([a_dot, b_dot, c_dot])  # one seed per packed input
    # jax.jvp returns (primal_out, tangent_out); keep only the tangent
    _, price_dot = jax.jvp(terminal_price, (primal,), (tangent,))
    return price_dot

def reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar):
    """Reverse-mode (VJP) sensitivities of the terminal price w.r.t. S0, mu and sigma.

    Args:
        S0, mu, sigma: the differentiated inputs.
        T, dt, key: simulation settings, held fixed.
        g_bar: cotangent of the output (sensitivity of the loss w.r.t. the
            terminal price).

    Returns:
        The tuple (S0_bar, mu_bar, sigma_bar) produced by the VJP function.
    """
    def gbm_func(S0, mu, sigma):
        return simulate_gbm(S0, mu, sigma, T, dt, key)

    _, vjp_fun = jax.vjp(gbm_func, S0, mu, sigma)
    # Apply the caller's cotangent as given; hard-coding 1.0 here would
    # silently ignore the g_bar argument.
    return vjp_fun(g_bar)

# Model parameters: initial price, drift, volatility
S0, mu, sigma = 100.0, 0.05, 0.2
# Time horizon and step size
T, dt = 1000.0, 0.01
key = random.PRNGKey(0)
# Output cotangent seeding the reverse sweep
g_bar = 1.0
# Input tangents seeding the forward sweep (for S0, mu, sigma)
a_dot, b_dot, c_dot = 1.0,1.0,1.0

# Calculate forward and reverse mode derivatives.
# NOTE: the timed region covers both the JVP sweep and the VJP sweep.
start_time = time.perf_counter()  # high-resolution clock for interval timing
g_dot = forward_mode_gbm(S0, mu, sigma, T, dt, key, a_dot, b_dot, c_dot)
a_bar, b_bar, c_bar = reverse_mode_gbm(S0, mu, sigma, T, dt, key, g_bar)
# JAX dispatches work asynchronously; wait for the results before stopping
# the clock so the measurement covers the whole computation
jax.block_until_ready((g_dot, a_bar, b_bar, c_bar))
end_time = time.perf_counter()
execution_time = end_time - start_time

# Dot-product (adjoint) test: the tangent/adjoint pairing on the inputs (lhs)
# should equal the pairing on the output (rhs) up to rounding
lhs = sum(jnp.array([a_dot, b_dot, c_dot]) * jnp.array([a_bar, b_bar, c_bar]))
rhs = g_dot * g_bar  # output tangent times output cotangent

print("LHS:", lhs)
print("RHS:", rhs)
print(f"Execution time: {execution_time} seconds")

LHS: 4.318122574486458e+16
RHS: 4.318122574486579e+16
Execution time: 239.1335198879242 seconds
