<a href="https://colab.research.google.com/github/arthurpham/google_colab/blob/main/TARF_MC_Performance_TQF2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install --upgrade tensorflow --user
!pip install tf-quant-finance==0.0.1.dev32
!pip install QuantLib-Python

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tf-quant-finance==0.0.1.dev32
  Downloading tf_quant_finance-0.0.1.dev32-py2.py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 9.2 MB/s 
Installing collected packages: tf-quant-finance
Successfully installed tf-quant-finance-0.0.1.dev32
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting QuantLib-Python
  Downloading QuantLib_Python-1.18-py2.py3-none-any.whl (1.4 kB)
Collecting QuantLib
  Downloading QuantLib-1.26-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (17.8 MB)
[K     |████████████████████████████████| 17.8 MB 129 kB/s 
[?25hInstalling collected packages: QuantLib, QuantLib-Python
Successfully installed QuantLib-1.26 QuantLib-Python-1.18


In [2]:
import os
# reduce number of threads
os.environ['TF_NUM_INTEROP_THREADS'] = '1'
os.environ['TF_NUM_INTRAOP_THREADS'] = '1'

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tf_quant_finance as tff 
import tensorflow as tf
import functools
import pandas as pd
import time
import QuantLib as ql

In [3]:
!nvidia-smi

Tue May 31 15:26:40 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
spot = 18.0
strike = 20.0
K_lower = 15.0
K_upper = 20.0
K_knockout = 30.0
tarf_target = 5.0
step_up_ratio = 2.0

r = 0.0
volatility = 0.5

In [5]:
#@title Set up parameters

dtype = tf.float64 #@param
num_samples = 200000 #@param
num_timesteps = 53 #@param

# expiries =tf.constant( [0.0, 0.5, 1.0], dtype=dtype) # This can be a rank 1 Tensor
dt = 1. / num_timesteps
# times = [1.0]
times = tf.linspace(tf.constant(0.0, dtype=dtype), tf.constant(1.0, dtype=dtype), num_timesteps)
rate = tf.constant(r, dtype=dtype)
dividend = tf.constant(0.0, dtype=dtype)
sigma = tf.constant(volatility, dtype=dtype)
spot = tf.constant(spot, dtype=dtype)
strikes = tf.constant(strike, dtype=dtype)

def set_up_pricer(times, watch_params=False):
    """Set up European option pricing function under Black-Scholes model.
    
    Args:
        expiries: List of expiries at which to to sample the trajectories.
        watch_params: A Python bool. When `True`, gradients of the price function wrt the inputs
          are computed more efficiently. 
    Returns:
     A callable that accepts a rank 1 tensor of strikes, and scalar values for 
     the spots and  volatility values. The callable outputs prices of
     the European call options on the grid `expiries x strikes`.
    """
    def price_eu_options(strikes, spot, sigma, rate, dividend):
        # Define drift and volatility functions. 
        def drift_fn(t, x):
          del t, x
          return rate - 0.5 * sigma**2
        def vol_fn(t, x):
          del t, x
          return tf.reshape(sigma, [1, 1])
        # Use GenericItoProcess class to set up the Ito process
        process = tff.models.GenericItoProcess(
            dim=1,
            drift_fn=drift_fn,
            volatility_fn=vol_fn,
            dtype=dtype)
        log_spot = tf.math.log(tf.reduce_mean(spot))
        if watch_params:
            watch_params_list = [sigma, rate, dividend]
        else:
            watch_params_list = None
        paths = process.sample_paths(
            times=times, num_samples=num_samples,
            initial_state=log_spot, 
            watch_params=watch_params_list,
            # Select a random number generator
            random_type=tff.math.random.RandomType.SOBOL, #PSEUDO_ANTITHETIC
            time_step=dt)
        
        @tf.function
        def tarf_payoff(element):
#             tf.print(element, summarize=-1)
            total = tf.constant(0.0, dtype=tf.float64)
            discounted_payoff = tf.constant(0.0, dtype=tf.float64)
            df = tf.constant(1.0, dtype=tf.float64)
            is_active = True
            for cur_spot in element:
                if is_active:
                    cashflow = tf.constant(0.0, dtype=tf.float64)
                    add_cashflow = False
                    if K_knockout <= cur_spot:
                        # early termination
                        is_active = False
                    if K_upper <= cur_spot: # cur_spot < K_knockout
                        cashflow = cur_spot - strike
                        add_cashflow = True
                    if cur_spot < K_lower:
                        cashflow = step_up_ratio*(cur_spot - strike)
                        add_cashflow = True

                    if add_cashflow:
                        if total + cashflow >= tarf_target:
                            cashflow = tarf_target - total
                            total += cashflow
                            discounted_payoff += df*cashflow
                            is_active = False
                        else:
                            total += cashflow
                            discounted_payoff += df*cashflow
                     
            return discounted_payoff

        reshaped_paths = tf.reshape(tf.math.exp(paths), [num_samples, num_timesteps])
        payoffs = tf.vectorized_map(tarf_payoff, reshaped_paths)
        prices = tf.reduce_mean(payoffs)

        return prices
    return price_eu_options    

price_eu_options = tf.function(set_up_pricer(times, watch_params=False),
                               input_signature=[
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64)
                               ])


In [6]:
def set_up_pricer_xla(times, watch_params=False):
    """Set up European option pricing function under Black-Scholes model.
    
    Args:
        expiries: List of expiries at which to to sample the trajectories.
        watch_params: A Python bool. When `True`, gradients of the price function wrt the inputs
          are computed more efficiently. 
    Returns:
     A callable that accepts a rank 1 tensor of strikes, and scalar values for 
     the spots and  volatility values. The callable outputs prices of
     the European call options on the grid `expiries x strikes`.
    """
    def price_eu_options(strikes, spot, sigma, rate, dividend):
        # Define drift and volatility functions. 
        def drift_fn(t, x):
          del t, x
          return rate - dividend - 0.5 * sigma**2
        def vol_fn(t, x):
          del t, x
          return tf.reshape(sigma, [1, 1])
        # Use GenericItoProcess class to set up the Ito process
        process = tff.models.GenericItoProcess(
            dim=1,
            drift_fn=drift_fn,
            volatility_fn=vol_fn,
            dtype=dtype)
        log_spot = tf.math.log(tf.reduce_mean(spot))
        if watch_params:
            watch_params_list = [sigma, rate, dividend]
        else:
            watch_params_list = None
        paths = process.sample_paths(
            times=times, num_samples=num_samples,
            initial_state=log_spot, 
            watch_params=watch_params_list,
            # Select a random number generator
            random_type=tff.math.random.RandomType.SOBOL, #PSEUDO_ANTITHETIC
            time_step=dt)
        
        @tf.function
        def tarf_payoff(paths):
            # Shape [num_timesteps, num_samples]
            paths = tf.transpose(paths)
            cur_spot = paths[0]
            total = tf.zeros_like(cur_spot)
            discounted_payoff = tf.zeros_like(cur_spot)
            df = tf.constant(1.0, dtype=tf.float64)
            is_active = tf.ones([num_samples], dtype=tf.bool)
            i = tf.constant(0, dtype=tf.int32)
            # Explicitly define the while_loop 
            def cond(i, is_active, total, discounted_payoff):
                return i < num_timesteps

            def body(i, is_active, total, discounted_payoff):
                # Here Tensors are of shape `[num_samples]`
                cur_spot = paths[i]

                cashflow = tf.zeros_like(cur_spot)
                new_is_active = K_knockout > cur_spot
                add_cashflow = tf.where(tf.logical_or(K_upper <= cur_spot, cur_spot < K_lower),
                    True, 
                    False)

                new_cashflow = tf.where(K_upper <= cur_spot,
                    cur_spot - strike,
                    cashflow
                )
                new_cashflow = tf.where(cur_spot < K_lower,
                                    step_up_ratio*(cur_spot - strike),
                                    new_cashflow)
                new_is_active = tf.where(add_cashflow,
                    tf.where(total + new_cashflow >= tarf_target,
                            False, new_is_active),
                    new_is_active)

                new_cashflow = tf.where(add_cashflow,
                    tf.where(total + new_cashflow >= tarf_target,
                            tarf_target - total, new_cashflow),
                    new_cashflow
                    )

                new_total = tf.where(add_cashflow,
                    total + new_cashflow,
                    total
                    )
                new_discounted_payoff = tf.where(add_cashflow,
                    discounted_payoff + df * new_cashflow,
                    discounted_payoff)
                # Update values only if active
                new_cashflow = tf.where(is_active, 
                                        new_cashflow,
                                        cashflow)
                new_total = tf.where(is_active, new_total, total)
                new_discounted_payoff = tf.where(is_active, 
                                        new_discounted_payoff,
                                        discounted_payoff)
                new_is_active = tf.where(is_active, new_is_active, is_active)

                return (i + 1, new_is_active, new_total, new_discounted_payoff)
                
            _, is_active, total, discounted_payoff = tf.while_loop(
                cond, body, (i, is_active, total, discounted_payoff),
                maximum_iterations=num_timesteps,
            )
            return discounted_payoff

        reshaped_paths = tf.reshape(tf.math.exp(paths), [num_samples, num_timesteps])
        payoffs = tarf_payoff(reshaped_paths)
        prices = tf.reduce_mean(payoffs)

        return prices
    return price_eu_options

price_eu_options_xla = tf.function(set_up_pricer_xla(times, watch_params=False),
                               input_signature=[
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64),
                                                tf.TensorSpec([], dtype=tf.float64)
                               ], jit_compile=True)    

price_eu_options_xla2 = set_up_pricer_xla(times, watch_params=True)

In [7]:
# device = "/gpu:0"
# with tf.device(device):
#     tarf_price = price_eu_options(strikes, spot, sigma)
#     print('price', tarf_price)

In [8]:
# device = "/gpu:0"
# with tf.device(device):
#     tarf_price = price_eu_options_xla(strikes, spot, sigma)
#     print('price', tarf_price)

In [9]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_price = price_eu_options(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_price = price_eu_options(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('price', tarf_price)

------------------------
TQF gpu TARF
wall time + tracing:  23.417309045791626
options per second + tracing:  0.04270345486941046
wall time:  2.5204977989196777
options per second:  0.3967470237143689
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
------------------------
TQF cpu TARF
wall time + tracing:  8.546136856079102
options per second + tracing:  0.11701193379423507
wall time:  5.958648681640625
options per second:  0.1678232856857513
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)


In [10]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_price = price_eu_options_xla(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_price = price_eu_options_xla(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF XLA'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('price', tarf_price)

------------------------
TQF gpu TARF XLA
wall time + tracing:  2.962099313735962
options per second + tracing:  0.33759840372764044
wall time:  0.022627830505371094
options per second:  44.193366207274416
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
------------------------
TQF cpu TARF XLA
wall time + tracing:  2.3186111450195312
options per second + tracing:  0.4312926736973725
wall time:  1.8334612846374512
options per second:  0.5454164799546013
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)


In [11]:
@tf.function(jit_compile=False,
             input_signature=[tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64)
                               ])
def greeks_fn(strikes, spot, sigma, rate, dividend):
    with tf.GradientTape() as tape:
      tape.watch([spot, sigma, rate, dividend])
      prices = price_eu_options(strikes, spot, sigma, rate, dividend)
    return prices, tape.gradient(prices, [spot, sigma, rate, dividend])

@tf.function(jit_compile=True,
             input_signature=[tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64)
                               ])
def greeks_fn_xla(strikes, spot, sigma, rate, dividend):
    with tf.GradientTape() as tape:
      tape.watch([spot, sigma, rate, dividend])
      prices = price_eu_options_xla2(strikes, spot, sigma, rate, dividend)
    return prices, tape.gradient(prices, [spot, sigma, rate, dividend])

In [12]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_price, tarf_greeks = greeks_fn(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_price, tarf_greeks = greeks_fn(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF price+delta+vega'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('price', tarf_price)
        print('greeks', tarf_greeks)

  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)


------------------------
TQF gpu TARF price+delta+vega
wall time + tracing:  8.679392337799072
options per second + tracing:  0.11521543917826635
wall time:  5.569403648376465
options per second:  0.17955243741248844
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=21.37942877385511>, <tf.Tensor: shape=(), dtype=float64, numpy=-331.9822946422317>, <tf.Tensor: shape=(), dtype=float64, numpy=217.8267631339107>, None]
------------------------
TQF cpu TARF price+delta+vega
wall time + tracing:  11.1425940990448
options per second + tracing:  0.08974570832529251
wall time:  8.111238718032837
options per second:  0.12328573165733718
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=21.37942877385511>, <tf.Tensor: shape=(), dtype=float64, numpy=-331.98229464223164>, <tf.Tensor: shape=(), dtype=float64, numpy=217.8267631339

In [13]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_price, tarf_greeks = greeks_fn_xla(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_price, tarf_greeks = greeks_fn_xla(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF XLA price+delta+vega'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('price', tarf_price)
        print('greeks', tarf_greeks)

------------------------
TQF gpu TARF XLA price+delta+vega
wall time + tracing:  3.4292311668395996
options per second + tracing:  0.2916105538961394
wall time:  0.06168770790100098
options per second:  16.210684981274422
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=21.37942877385511>, <tf.Tensor: shape=(), dtype=float64, numpy=-331.9822946422317>, <tf.Tensor: shape=(), dtype=float64, numpy=217.8267631339107>, <tf.Tensor: shape=(), dtype=float64, numpy=-217.8267631339107>]
------------------------
TQF cpu TARF XLA price+delta+vega
wall time + tracing:  7.095331430435181
options per second + tracing:  0.1409377433322613
wall time:  5.588995933532715
options per second:  0.1789230144184263
------------------------
price tf.Tensor(-246.927357070608, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=21.37942877385511>, <tf.Tensor: shape=(), dtype=float64, numpy=-331.98229464

In [14]:
# devices = ['gpu'] # , 'cpu'

# for device in devices:
#     with tf.device('/{}:0'.format(device)):
#         for spot in [15.0, 18.2, 20.0, 25.0]:
#             # spot = tf.constant(spot, dtype=dtype)
#             t = time.time()
#             tarf_price, tarf_greeks = greeks_fn_xla(strikes, tf.convert_to_tensor(spot, dtype=dtype), sigma, rate, dividend)
#             time_tqf0 = time.time() - t

#             t = time.time()
#             tarf_price, tarf_greeks = greeks_fn_xla(strikes, tf.convert_to_tensor(spot, dtype=dtype), sigma, rate, dividend)
#             time_tqf = time.time() - t

#             print('------------------------')
#             print('TQF {} TARF XLA price+delta+vega spot: {}'.format(device, spot))
#             print('wall time + tracing: ', time_tqf0)
#             print('options per second + tracing: ', 1.0/time_tqf0)
#             print('wall time: ', time_tqf)
#             print('options per second: ', 1.0/time_tqf)
#             print('------------------------')
#             print('price', tarf_price)
#             print('greeks', tarf_greeks)
#             # print(greeks_fn_xla.pretty_printed_concrete_signatures())            

In [15]:
devices = ['gpu'] # , 'cpu'

for device in devices:
    with tf.device('/{}:0'.format(device)):
        for spot in [tf.convert_to_tensor(x, tf.float64) for x in [15.0, 18.2, 20.0, 25.0]]:
            t = time.time()
            tarf_price, tarf_greeks = greeks_fn_xla(strikes, spot, sigma, rate, dividend)
            time_tqf0 = time.time() - t

            t = time.time()
            tarf_price, tarf_greeks = greeks_fn_xla(strikes, spot, sigma, rate, dividend)
            time_tqf = time.time() - t

            print('------------------------')
            print('TQF {} TARF XLA price+delta+vega spot: {}'.format(device, spot))
            print('wall time + tracing: ', time_tqf0)
            print('options per second + tracing: ', 1.0/time_tqf0)
            print('wall time: ', time_tqf)
            print('options per second: ', 1.0/time_tqf)
            print('------------------------')
            print('price', tarf_price)
            print('greeks', tarf_greeks)
            # print(greeks_fn_xla.pretty_printed_concrete_signatures())            

------------------------
TQF gpu TARF XLA price+delta+vega spot: 15.0
wall time + tracing:  1.2935822010040283
options per second + tracing:  0.7730471238888714
wall time:  0.05926394462585449
options per second:  16.873665874136563
------------------------
price tf.Tensor(-489.18585120429816, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=47.21793991971346>, <tf.Tensor: shape=(), dtype=float64, numpy=-375.1297930074455>, <tf.Tensor: shape=(), dtype=float64, numpy=358.638751028938>, <tf.Tensor: shape=(), dtype=float64, numpy=-358.638751028938>]
------------------------
TQF gpu TARF XLA price+delta+vega spot: 18.2
wall time + tracing:  1.2760756015777588
options per second + tracing:  0.7836526290163257
wall time:  0.05937552452087402
options per second:  16.841956480711858
------------------------
price tf.Tensor(-234.51510001171167, shape=(), dtype=float64)
greeks [<tf.Tensor: shape=(), dtype=float64, numpy=20.189230768587276>, <tf.Tensor: shape=(), dtype=

In [16]:
print(greeks_fn_xla.pretty_printed_concrete_signatures())  

greeks_fn_xla(strikes, spot, sigma, rate, dividend)
  Args:
    strikes: float64 Tensor, shape=()
    spot: float64 Tensor, shape=()
    sigma: float64 Tensor, shape=()
    rate: float64 Tensor, shape=()
    dividend: float64 Tensor, shape=()
  Returns:
    (<1>, [<2>, <3>, <4>, <5>])
      <1>: float64 Tensor, shape=()
      <2>: float64 Tensor, shape=()
      <3>: float64 Tensor, shape=()
      <4>: float64 Tensor, shape=()
      <5>: float64 Tensor, shape=()


In [17]:
@tf.function(jit_compile=False,
             input_signature=[tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64)
                               ])
def delta_fn(strikes, spot, sigma, rate, dividend):
    fn = lambda spot: price_eu_options(strikes, spot, sigma, rate, dividend)
    return tff.math.fwd_gradient(fn, spot, use_gradient_tape=True)

@tf.function(jit_compile=True,
             input_signature=[tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64)
                               ])
def delta_fn_xla(strikes, spot, sigma, rate, dividend):
    fn = lambda spot: price_eu_options_xla(strikes, spot, sigma, rate, dividend)
    return tff.math.fwd_gradient(fn, spot, use_gradient_tape=True)

In [18]:
@tf.function(jit_compile=False,
             input_signature=[tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64),
                            tf.TensorSpec([], dtype=tf.float64)
                               ])
def vega_fn(strikes, spot, sigma, rate, dividend):
    fn = lambda sigma: price_eu_options(strikes, spot, sigma, rate, dividend)
    return tff.math.fwd_gradient(fn, sigma, use_gradient_tape=True)

In [19]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_price = price_eu_options(strikes, spot, sigma, rate, dividend)
        tarf_delta = delta_fn(strikes, spot, sigma, rate, dividend)
        tarf_vega = vega_fn(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_price = price_eu_options(strikes, spot, sigma, rate, dividend)
        tarf_delta = delta_fn(strikes, spot, sigma, rate, dividend)
        tarf_vega = vega_fn(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF price+delta+vega'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('price', tarf_price)
        print('delta', tarf_delta)
        print('vega', tarf_vega)

  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)
  "shape. This may consume a large amount of memory." % value)


------------------------
TQF gpu TARF price+delta+vega
wall time + tracing:  18.301156520843506
options per second + tracing:  0.05464135552641619
wall time:  6.987584590911865
options per second:  0.14311096874599727
------------------------
price tf.Tensor(4.951110958164982, shape=(), dtype=float64)
delta tf.Tensor(0.00300543832659926, shape=(), dtype=float64)
vega tf.Tensor(-0.11210128640365602, shape=(), dtype=float64)
------------------------
TQF cpu TARF price+delta+vega
wall time + tracing:  50.33869290351868
options per second + tracing:  0.019865434367091005
wall time:  53.090919971466064
options per second:  0.018835612578148093
------------------------
price tf.Tensor(4.951110958164982, shape=(), dtype=float64)
delta tf.Tensor(0.003005438326599259, shape=(), dtype=float64)
vega tf.Tensor(-0.11210128640365602, shape=(), dtype=float64)


In [20]:
devices = ['gpu', 'cpu']

for device in devices:
    with tf.device('/{}:0'.format(device)):
        t = time.time()
        tarf_delta = delta_fn_xla(strikes, spot, sigma, rate, dividend)
        time_tqf0 = time.time() - t

        t = time.time()
        tarf_delta = delta_fn_xla(strikes, spot, sigma, rate, dividend)
        time_tqf = time.time() - t

        print('------------------------')
        print('TQF {} TARF'.format(device))
        print('wall time + tracing: ', time_tqf0)
        print('options per second + tracing: ', 1.0/time_tqf0)
        print('wall time: ', time_tqf)
        print('options per second: ', 1.0/time_tqf)
        print('------------------------')
        print('delta', tarf_delta)

InternalError: ignored