# SABR Volatility Model

This notebook gives an overview of the SABR volatility model. For an overview of various volatility models, please refer to [this](../marketdata/equity_volatilities.ipynb) notebook.

## Import Libraries

In [None]:
# Import libraries
import numpy as np
import datetime as dt
from dateutil.relativedelta import relativedelta
import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import axes3d
matplotlib.use('nbagg')
# %matplotlib inline
import random
import scipy as sp
import scipy.interpolate
from scipy.optimize import least_squares
from jupyter_dash import JupyterDash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import rivapy
from rivapy import marketdata as mkt_data
from rivapy import enums as enums
#reload modules
%load_ext autoreload
%autoreload 2

## Define SABR Function

The SABR model assumes that the forward rate and the instantaneous volatility are driven by two correlated Brownian motions:

$$df_t = \alpha_t f_t^\beta d W_t^1$$

$$d\alpha_t = \nu\alpha_t d W_t^2$$

$$E\bigl[d W_t^1 d W_t^2\bigr] = \rho d t$$

The expression that the implied volatility must satisfy is

$$\sigma_B(K,f) = \frac{\alpha\biggl\{1+\biggl[\frac{(1-\beta)^2}{24}\frac{\alpha^2}{(fK)^{1-\beta}}+\frac{1}{4}\frac{\rho\beta\nu\alpha}{(fK)^{(1-\beta)/2}}+\frac{2-3\rho^2}{24}\nu^2\biggr]T\biggr\}}{(fK)^{(1-\beta)/2}\biggl[1+\frac{(1-\beta)^2}{24}{ln}^2\frac{f}{K}+\frac{(1-\beta)^4}{1920}{ln}^4\frac{f}{K}\biggr]}\frac{z}{\chi(z)}$$

$$z=\frac{\nu}{\alpha}(fK)^{(1-\beta)/2}ln\frac{f}{K}$$

$$\chi(z) = ln\Biggl[\frac{\sqrt{1-2\rho z+z^2}+z-\rho}{1-\rho}\Biggr]$$

When $f = K $ (for ATM options), the above formula for implied volatility simplifies to:

$$\sigma_{ATM} = \sigma_B(f,f)=\frac{\alpha\biggl\{1+\biggl[\frac{(1-\beta)^2}{24}\frac{\alpha^2}{f^{2-2\beta}}+\frac{1}{4}\frac{\rho\beta\nu\alpha}{f^{1-\beta}}+\frac{2-3\rho^2}{24}\nu^2\biggr]T\biggr\}}{f^{1-\beta}}$$

where

> $\alpha$ is the instantaneous vol;

> $\nu$ is the vol of vol;

> $\rho$ is the correlation between the Brownian motions driving the forward rate and the instantaneous vol;

> $\beta$ is the CEV component for forward rate (determines shape of forward rates, leverage effect and backbone of ATM vol).

Source: https://bsic.it/sabr-stochastic-volatility-model-volatility-smile/

In [None]:
# Define SABR function
# https://bsic.it/sabr-stochastic-volatility-model-volatility-smile/
def SABR(f, K, T, alpha, nu, beta, rho):
    """Return the SABR (Hagan et al.) Black implied volatility.

    Args:
        f: Forward level.
        K: Strike, in the same units as ``f``.
        T: Time to expiry.
        alpha: Instantaneous volatility.
        nu: Volatility of volatility.
        beta: CEV exponent of the forward.
        rho: Correlation between the two driving Brownian motions.

    Returns:
        The implied volatility as a scalar.

    Note:
        Inputs are scalars; the at-the-money formula is selected with an
        exact ``f == K`` comparison. ``f`` and ``K`` have to be strictly
        positive for the logarithm and the fractional powers to be finite.
    """
    one_minus_beta = 1 - beta

    # (fK)^((1-beta)/2); for f == K this is f^(1-beta).
    if f == K:
        fk_mid = f ** one_minus_beta
    else:
        fk_mid = (f * K) ** (one_minus_beta / 2)

    # alpha * {1 + [...] T}: the time-dependent correction in the numerator.
    correction = (
        one_minus_beta ** 2 / 24 * alpha ** 2 / fk_mid ** 2
        + rho * beta * nu * alpha / (4 * fk_mid)
        + (2 - 3 * rho ** 2) / 24 * nu ** 2
    )
    numerator = alpha * (1 + correction * T)

    if f == K:
        return numerator / fk_mid

    log_fk = np.log(f / K)
    # The log-moneyness expansion belongs in the DENOMINATOR of the formula.
    log_bracket = (
        1
        + one_minus_beta ** 2 / 24 * log_fk ** 2
        + one_minus_beta ** 4 / 1920 * log_fk ** 4
    )

    zeta = nu / alpha * fk_mid * log_fk
    if abs(zeta) < 1e-7:
        # z / chi(z) -> 1 - rho*z/2 as z -> 0; evaluating the quotient
        # directly would be 0/0 (e.g. for nu == 0) or lose precision.
        z_over_chi = 1 - rho * zeta / 2
    else:
        chi_zeta = np.log(
            (np.sqrt(1 - 2 * rho * zeta + zeta ** 2) + zeta - rho) / (1 - rho)
        )
        z_over_chi = zeta / chi_zeta

    return numerator / (fk_mid * log_bracket) * z_over_chi

## Volatility Smile

In [None]:
# Interactive SABR volatility smile: input widgets

# Let long descriptions use their natural width instead of being truncated.
style = {'description_width': 'initial'}

FloatTextAlpha = widgets.FloatText(value=0.3, step=0.01, description='Alpha')

# Readout formats are chosen so the displayed number resolves the slider step;
# a coarser format would show a rounded value that differs from the one used.
FloatSliderNu = widgets.FloatSlider(value=0.1, min=0.0001, max=5, step=0.01, description='Nu',
                                    continuous_update=False, orientation='horizontal',
                                    readout=True, readout_format='.4f')

FloatSliderBeta = widgets.FloatSlider(value=0.1, min=0, max=1, step=0.01, description='Beta',
                                      continuous_update=False, orientation='horizontal',
                                      readout=True, readout_format='.2f')

# rho is kept strictly inside (-1, 1): the SABR formula divides by (1 - rho).
FloatSliderRho = widgets.FloatSlider(value=0.1, min=-0.999999, max=0.999999, step=0.01, description='Rho',
                                     continuous_update=False, orientation='horizontal',
                                     readout=True, readout_format='.2f')

# The lower bound is one step above zero: a strike of 0 makes log(f/K) infinite.
FloatRangeSliderStrikes = widgets.FloatRangeSlider(value=[.4, 1.6], min=0.05, max=3.0, step=0.05,
                                                   description='Strike Range:', disabled=False,
                                                   continuous_update=False, orientation='horizontal',
                                                   readout=True, readout_format='.2f', style=style)

FloatSliderExpiries = widgets.FloatSlider(value=0, min=0, max=30.0, step=.25, description='Expiry:',
                                          disabled=False, continuous_update=False, orientation='horizontal',
                                          readout=True, readout_format='.2f', style=style)

ButtonNewPlot = widgets.Button(description="New Plot")

ButtonAddTrace = widgets.Button(description="Add Trace")

# Output area the event handlers render the figure into.
OutputWidget = widgets.Output()

def create_vol_grid(alpha, nu, beta, rho, strike_range, expiry):
    """Evaluate the SABR smile on 100 equally spaced strikes.

    The forward is fixed at 1. ``strike_range`` supplies the lower and upper
    strike as its first two entries; ``expiry`` is passed to ``SABR`` as the
    time to expiry.

    Returns:
        The strike array and a list with one volatility per strike.
    """
    forward = 1
    lower, upper = strike_range[0], strike_range[1]
    strikes = np.linspace(lower, upper, num=100)

    vols = []
    for strike in strikes:
        vols.append(SABR(forward, strike, expiry, alpha, nu, beta, rho))

    return strikes, vols

def create_plot(strikes, vols, expiry, alpha, nu, beta, rho):
    """Add one smile to the module-level figure ``fig`` and display it.

    Args:
        strikes: Array of strikes (x-axis); its min/max set the x-axis range.
        vols: Volatilities matching ``strikes`` (y-axis).
        expiry, alpha, nu, beta, rho: Values shown in the hover label of the
            new trace.

    The figure ``fig`` is a global that has to exist before this is called.
    """
    # Hover precision resolves the widget steps (0.25 for the expiry, 0.01 for
    # nu/beta/rho) so the label shows the value the trace was computed with.
    hover_parts = [
        'Moneyness:  %{x: .1%}',
        '<br>Volatility: %{y: .1%}',
        '<br>Expiry: {:,.2f} Yrs'.format(expiry),
        '<br>Alpha: {:,.1%}'.format(alpha),
        '<br>Nu: {:,.4f}'.format(nu),
        '<br>Beta: {:,.2f}'.format(beta),
        '<br>Rho: {:,.2f}'.format(rho),
        '<extra></extra>',
    ]

    fig.add_trace(
        go.Scatter(
            x=strikes,
            y=vols,
            mode='lines+markers',
            hovertemplate=''.join(hover_parts),
            showlegend=False,
        )
    )

    fig.update_layout(
        title={
            'text': "<b>Volatility Smile</b>",
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top',
        },
        width=1000,
        height=500,
        xaxis_title='Moneyness',
        xaxis_tickformat='.1%',
        xaxis_range=[strikes.min(), strikes.max()],
        yaxis_title='Volatility',
        yaxis_tickformat='.1%',
        yaxis_range=[0, 1],
        font=dict(family="Courier New, monospace", size=10),
        margin=dict(l=65, r=50, b=65, t=90),
    )
    fig.show()

def plot(alpha, nu, beta, rho, strike_range, expiry):
    """Redraw the smile chart; invoked by the button event handlers."""
    # Empty the output area first so successive figures do not stack up.
    OutputWidget.clear_output()

    # Compute the smile, then hand it to the plotting routine.
    smile = create_vol_grid(alpha, nu, beta, rho, strike_range, expiry)
    create_plot(*smile, expiry, alpha, nu, beta, rho)
    

def eventhandler(change):
    """Add a smile for the current widget settings to the existing chart.

    Bound to the "Add Trace" button; ``change`` is supplied by the widget
    framework and is not used.
    """
    global fig

    # "Add Trace" can be clicked before "New Plot" ever created the figure;
    # start with an empty one instead of failing with a NameError.
    if 'fig' not in globals():
        fig = go.Figure()

    alpha = FloatTextAlpha.value
    nu = FloatSliderNu.value
    beta = FloatSliderBeta.value
    rho = FloatSliderRho.value
    strike_range = FloatRangeSliderStrikes.value
    expiry = FloatSliderExpiries.value

    # Render inside the output widget so the figure appears below the controls.
    with OutputWidget:
        plot(alpha, nu, beta, rho, strike_range, expiry)
        
def eventhandler2(change):
    """Start a fresh chart and draw the smile for the current widget settings.

    Bound to the "New Plot" button; ``change`` is supplied by the widget
    framework and is not used.
    """
    global fig

    # Replacing the global figure discards all previously added traces.
    fig = go.Figure()

    # Same order as the parameters of plot().
    sources = (
        FloatTextAlpha,
        FloatSliderNu,
        FloatSliderBeta,
        FloatSliderRho,
        FloatRangeSliderStrikes,
        FloatSliderExpiries,
    )
    settings = [source.value for source in sources]

    with OutputWidget:
        plot(*settings)

# Wire the buttons to their callbacks
ButtonNewPlot.on_click(eventhandler2)
ButtonAddTrace.on_click(eventhandler)

# Arrange the controls row by row
WidgetsGrpH1 = widgets.HBox([widgets.Label('Set Chart Area:')])
WidgetsGrpH2 = widgets.HBox([FloatRangeSliderStrikes])
WidgetsGrpH3 = widgets.HBox([widgets.Label('Set Parameters:')])
WidgetsGrpH4 = widgets.HBox([FloatTextAlpha, FloatSliderExpiries])
WidgetsGrpH5 = widgets.HBox([FloatSliderNu, FloatSliderBeta, FloatSliderRho])
WidgetsGrpH6 = widgets.HBox([ButtonNewPlot, ButtonAddTrace])
WidgetsGrpV1 = widgets.VBox([
    WidgetsGrpH1,
    WidgetsGrpH2,
    WidgetsGrpH3,
    WidgetsGrpH4,
    WidgetsGrpH5,
    WidgetsGrpH6,
])

# Show the controls followed by the area the chart is rendered into
display(WidgetsGrpV1)
display(OutputWidget)

## Volatility Surface

### Single Parameter Set

In [None]:
# Interactive SABR volatility surface: input widgets

# Let long descriptions use their natural width instead of being truncated.
style = {'description_width': 'initial'}

FloatTextAlpha = widgets.FloatText(value=0.3, step=0.01, description='Alpha')

# Readout formats are chosen so the displayed number resolves the slider step;
# a coarser format would show a rounded value that differs from the one used.
FloatSliderNu = widgets.FloatSlider(value=0.1, min=0.0001, max=5, step=0.01, description='Nu',
                                    continuous_update=False, orientation='horizontal',
                                    readout=True, readout_format='.4f')

FloatSliderBeta = widgets.FloatSlider(value=0.1, min=0, max=1, step=0.01, description='Beta',
                                      continuous_update=False, orientation='horizontal',
                                      readout=True, readout_format='.2f')

# rho is kept strictly inside (-1, 1): the SABR formula divides by (1 - rho).
FloatSliderRho = widgets.FloatSlider(value=0.1, min=-0.999999, max=0.999999, step=0.01, description='Rho',
                                     continuous_update=False, orientation='horizontal',
                                     readout=True, readout_format='.2f')

# The lower bound is one step above zero: a strike of 0 makes log(f/K) infinite.
FloatRangeSliderStrikes = widgets.FloatRangeSlider(value=[.4, 1.6], min=0.05, max=3.0, step=0.05,
                                                   description='Strike Range:', disabled=False,
                                                   continuous_update=False, orientation='horizontal',
                                                   readout=True, readout_format='.2f', style=style)

FloatRangeSliderExpiries = widgets.FloatRangeSlider(value=[0, 3], min=0, max=30.0, step=1,
                                                    description='Expiries Range:', disabled=False,
                                                    continuous_update=False, orientation='horizontal',
                                                    readout=True, readout_format='.0f', style=style)

ButtonCreatePlot = widgets.Button(description="Create Plot")

# Output area the event handler renders the surface into.
OutputWidget2 = widgets.Output()

def create_vol_grid2(alpha, nu, beta, rho, strike_range, expiry_range):
    """Evaluate SABR volatilities on a strike x expiry grid for a forward of 1.

    Args:
        alpha, nu, beta, rho: SABR parameters, used for every expiry.
        strike_range: Lower and upper strike as the first two entries.
        expiry_range: First and last expiry as the first two entries.

    Returns:
        ``(strikes, expiries, vols)`` where ``strikes`` has 100 equally spaced
        points, ``expiries`` is spaced in quarter steps and ``vols[i, j]`` is
        the volatility for ``strikes[i]`` and ``expiries[j]``.
    """
    F_0 = 1
    strikes = np.linspace(strike_range[0], strike_range[1], num=100)

    # Four points per unit of the selected SPAN (not of the upper bound), so
    # the spacing stays at 0.25 when the range does not start at zero.
    first_expiry, last_expiry = expiry_range[0], expiry_range[1]
    n_expiries = int(round((last_expiry - first_expiry) * 4)) + 1
    expiries = np.linspace(first_expiry, last_expiry, n_expiries, endpoint=True)

    vols = np.empty(shape=(strikes.size, expiries.size))
    for i, strike in enumerate(strikes):
        for j, expiry in enumerate(expiries):
            vols[i, j] = SABR(F_0, strike, expiry, alpha, nu, beta, rho)

    return strikes, expiries, vols

def create_plot2(strikes, expiries, vols):
    """Build and show a 3D volatility surface.

    ``expiries`` go on the x-axis, ``strikes`` on the y-axis and ``vols`` on
    the z-axis, with red contour lines along x and y.
    """
    red_lines = {"show": True, "size": 0.1, "color": "red"}
    hover_text = (
        'Moneyness:  %{y: .1%}'
        '<br>Maturity (yrs): %{x: .2f}'
        '<br>Volatility: %{z: .1%}<extra></extra>'
    )

    surface = go.Surface(
        x=expiries,
        y=strikes,
        z=vols,
        contours={"x": dict(red_lines), "y": dict(red_lines)},
        hovertemplate=hover_text,
        colorscale='temps',
    )
    fig1 = go.Figure(data=[surface])

    chart_title = {
        'text': "<b>Volatility Surface</b>",
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
    }
    axes = dict(
        xaxis_title='Maturity (yrs)',
        xaxis_tickformat='.2f',
        xaxis_autorange='reversed',
        yaxis_title='Moneyness',
        yaxis_tickformat='.1%',
        zaxis_title='Volatility',
        zaxis_tickformat='.1%',
    )
    fig1.update_layout(
        title=chart_title,
        width=1000,
        height=500,
        scene=axes,
        font=dict(family="Courier New, monospace", size=10),
        margin=dict(l=65, r=50, b=65, t=90),
    )
    fig1.show()

def plot2(alpha, nu, beta, rho, strike_range, expiry_range):
    """Redraw the volatility surface; invoked by the event handler."""
    # Empty the output area first so successive figures do not stack up.
    OutputWidget2.clear_output()

    # Compute the grid, then hand it to the plotting routine.
    grid = create_vol_grid2(alpha, nu, beta, rho, strike_range, expiry_range)
    create_plot2(*grid)
    

def eventhandler3(change):
    """Redraw the surface from the current widget values.

    Registered both as a value observer and as a button callback; ``change``
    is supplied by the widget framework and is not used.
    """
    # Same order as the parameters of plot2().
    sources = (
        FloatTextAlpha,
        FloatSliderNu,
        FloatSliderBeta,
        FloatSliderRho,
        FloatRangeSliderStrikes,
        FloatRangeSliderExpiries,
    )
    settings = [source.value for source in sources]

    # Render inside the output widget so the figure appears below the controls.
    with OutputWidget2:
        plot2(*settings)
        
# Redraw whenever the value of any parameter or chart-range widget changes
for _observed in (
    FloatTextAlpha,
    FloatSliderNu,
    FloatSliderBeta,
    FloatSliderRho,
    FloatRangeSliderStrikes,
    FloatRangeSliderExpiries,
):
    _observed.observe(eventhandler3, names='value')

# The button triggers the same redraw on demand
ButtonCreatePlot.on_click(eventhandler3)

# Arrange the controls row by row
WidgetsGrpH1 = widgets.HBox([widgets.Label('Set Chart Area:')])
WidgetsGrpH2 = widgets.HBox([FloatRangeSliderStrikes, FloatRangeSliderExpiries])
WidgetsGrpH3 = widgets.HBox([widgets.Label('Set Parameters:')])
WidgetsGrpH4 = widgets.HBox([FloatTextAlpha])
WidgetsGrpH5 = widgets.HBox([FloatSliderNu, FloatSliderBeta, FloatSliderRho])
WidgetsGrpH6 = widgets.HBox([ButtonCreatePlot])
WidgetsGrpV1 = widgets.VBox([
    WidgetsGrpH1,
    WidgetsGrpH2,
    WidgetsGrpH3,
    WidgetsGrpH4,
    WidgetsGrpH5,
    WidgetsGrpH6,
])

# Show the controls followed by the area the surface is rendered into
display(WidgetsGrpV1)
display(OutputWidget2)

### Expiry Dependent Parameters 

In [None]:
# Moneyness axis: 100 points from 0.4 to 1.6
strikes = np.linspace(0.4, 1.6, 100)
# Time-to-maturity axis: 0 to 3 in quarter steps (3 * 4 + 1 = 13 points, end point included)
expiries = np.linspace(0.0, 3.0, num=13)

In [None]:
# One randomly drawn SABR parameter set per expiry (unseeded, so every run differs)
F_0 = 1
array_alpha = np.random.uniform(0.1, 0.2, size=expiries.shape)
array_nu = np.random.uniform(0.0001, 1, size=expiries.shape)
array_beta = np.random.uniform(0, 1, size=expiries.shape)
# rho is drawn strictly inside (-1, 1)
array_rho = np.random.uniform(-0.999999, 0.999999, size=expiries.shape)

In [None]:
# Volatility grid: one row per strike, one column per expiry, each column
# using the parameter set drawn for that expiry
vols = np.array([
    [
        SABR(F_0, strike, expiries[col], array_alpha[col], array_nu[col], array_beta[col], array_rho[col])
        for col in range(expiries.size)
    ]
    for strike in strikes
])

In [None]:
# Plot the expiry-dependent surface as one contour line per expiry; the
# surface itself is hidden.
fig2 = go.Figure(
    data=[
        go.Surface(
            x=expiries,
            y=strikes,
            z=vols,
            hidesurface=True,
            contours={"x": {"show": True, "size": 0.1, "color": "red"}},
            hovertemplate=(
                'Moneyness:  %{y: .2%}'
                '<br>Maturity (yrs): %{x: .1f}'
                '<br>Volatility: %{z: .2f}<extra></extra>'
            ),
            colorscale='temps',
        )
    ]
)

fig2.update_layout(
    title={
        'text': "<b>Volatility Surface</b>",
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
    },
    width=1000,
    height=500,
    scene=dict(
        xaxis_title='Maturity (yrs)',
        xaxis_tickformat='.1f',
        xaxis_autorange='reversed',
        yaxis_title='Moneyness',
        yaxis_tickformat='.2%',
        zaxis_title='Volatility',
        zaxis_tickformat='.2%',
    ),
    font=dict(family="Courier New, monospace", size=10),
    margin=dict(l=65, r=50, b=65, t=90),
)
fig2.show()

### Rivapy

#### Creating Forward Curve
We create a dummy forward curve as shown in the  [forward_curve](equity_forwardcurve.ipynb) notebook which will be used in all subsequent volatility surface constructions.

In [None]:
refdate = dt.datetime(2017, 1, 1)

# Dividend table needed for the forward curve: one dividend per year, 2018-2021
object_id = "TEST_DIV"
ex_dates = [dt.datetime(year, 3, 29) for year in range(2018, 2022)]
pay_dates = [dt.datetime(year, 4, 1) for year in range(2018, 2022)]
tax_factors = [1.0] * 4
div_yield = [0, 0.005, 0.01, 0.01]
div_cash = [3.0, 2.0, 1.0, 0.0]
div_table = rivapy.marketdata.DividendTable(
    object_id, refdate, ex_dates, pay_dates, div_yield, div_cash, tax_factors
)

# Discount and borrowing curve needed for the forward curve; both are built
# from the same two dates with discount factors of 1.0
dates = [refdate + dt.timedelta(days=offset) for offset in (0, 10)]
df = [1.0, 1.0]
dc = mkt_data.DiscountCurve(
    object_id, refdate, dates, df,
    enums.InterpolationType.HAGAN_DF,
    enums.ExtrapolationType.NONE,
    enums.DayCounterType.Act365Fixed,
)
bc = mkt_data.DiscountCurve(
    object_id, refdate, dates, df,
    enums.InterpolationType.HAGAN_DF,
    enums.ExtrapolationType.NONE,
    enums.DayCounterType.Act365Fixed,
)
spot = 100.0

# Forward curve from spot, discount curve, borrowing curve and dividends
forward_curve = mkt_data.EquityForwardCurve(spot, dc, bc, div_table)

#### SABR Parametrization

In [None]:
# Times to maturity, one per row of sabr_params
ttm = [1.0 / 12.0, 1.0, 2.0, 3.0]

# One SABR parameter set per maturity.
# NOTE(review): the column order expected by VolatilityParametrizationSABR is
# not visible here — confirm against the rivapy documentation.
sabr_params = np.array(
    [
        [0.20, 0.1, 0.9, -0.80],
        [0.23, 0.1, 0.1, 0.10],
        [0.28, 0.3, 0.9, -0.75],
        [0.30, 0.3, 0.9, -0.85],
    ]
)

In [None]:
# Build the SABR parametrization from the maturity grid and the parameter matrix
sabr_param = mkt_data.VolatilityParametrizationSABR(ttm, sabr_params)
# Sample evaluation; as the last expression of the cell its result is displayed.
# NOTE(review): the arguments are presumably (time to maturity, strike/moneyness) — confirm against rivapy.
sabr_param.calc_implied_vol(1,1)

#### Volatility Surface

In [None]:
# Combine the forward curve and the SABR parametrization into a volatility surface
obj_id = 'TEST_SURFACE'
refdate = dt.datetime(2017,1,1)
vol_surf = mkt_data.VolatilitySurface(obj_id, refdate, forward_curve, enums.DayCounterType.Act365Fixed, sabr_param)
# Alternative construction kept for reference: reuses the forward curve stored in an existing surface
# vol_surface = mkt_data.VolatilitySurface(obj_id, refdate, vol_surf.getForwardCurve(), enums.DayCounterType.Act365Fixed, sabr_param)

In [None]:
# Implied volatility for an expiry 365 days after refdate at strike 100 (equal to the spot used above)
vol = vol_surf.calc_implied_vol(refdate + dt.timedelta(days=365),100,refdate)
print(vol)

In [None]:
# Evaluate the rivapy volatility surface on a moneyness x expiry grid and plot it
refdate = dt.datetime(2017, 1, 1, 0, 0, 0)
expiries = [dt.datetime(2017, 2, 1, 0, 0, 0)] + [
    dt.datetime(year, 1, 1, 0, 0, 0) for year in (2018, 2019, 2020)
]

moneyness = np.linspace(0.5, 1.5, 100)

# Plot axes: x reuses the year fractions in `ttm` (one per entry of
# `expiries`), y is the moneyness grid
y = moneyness
x = ttm

# One row per moneyness level, one column per expiry; the strike is the
# moneyness times the forward for that expiry
term_structure = [
    [
        vol_surf.calc_implied_vol(expiry, level * forward_curve.value(refdate, expiry), refdate)
        for expiry in expiries
    ]
    for level in moneyness
]

red_lines = {"show": True, "size": 0.1, "color": "red"}
fig3 = go.Figure(
    data=[
        go.Surface(
            x=x,
            y=y,
            z=term_structure,
            contours={"x": dict(red_lines), "y": dict(red_lines)},
            hovertemplate=(
                'Moneyness:  %{y: .2%}'
                '<br>Maturity (yrs): %{x: .1f}'
                '<br>Volatility: %{z: .2f}<extra></extra>'
            ),
            colorscale='temps',
        )
    ]
)

fig3.update_layout(
    title={
        'text': "<b>Volatility Surface</b>",
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
    },
    width=1000,
    height=500,
    scene=dict(
        xaxis_title='Maturity (yrs)',
        xaxis_tickformat='.1f',
        xaxis_autorange='reversed',
        yaxis_title='Moneyness',
        yaxis_tickformat='.2%',
        zaxis_title='Volatility',
        zaxis_tickformat='.2%',
    ),
    font=dict(family="Courier New, monospace", size=10),
    margin=dict(l=65, r=50, b=65, t=90),
)

fig3.show()

---