In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss

In [None]:
# --- Contract / market inputs (consumed by the Black-Scholes formula below) ---
S0, K = 100, 100      # spot price and strike
T = 1                 # time to maturity (presumably in years -- confirm)
r, sig = 0.05, 0.2    # risk-free rate and diffusion volatility

# --- Jump component inputs (consumed by the jump-diffusion series below) ---
lam = 0.3                    # jump intensity used in the Poisson weights
mu_J, del_J = -0.1, 0.2      # jump-size parameters: k = exp(mu_J + del_J**2/2) - 1
n_max = 20                   # last index n included when summing the series

In [None]:
from dataclasses import dataclass

@dataclass
class Params:
  S0: float
  K: float
  T: int
  r: float
  sig: float
  lam: float
  mu_J: float
  del_J: float
  n_max: int|None = None

  def estimate_k(self):
    self.k = np.exp(self.mu_J + 0.5*self.del_J**2)-1

In [None]:
# Bundle the notebook-level inputs into a Params instance, then display k.
params = Params(
  S0=S0, K=K, T=T, r=r, sig=sig,
  lam=lam, mu_J=mu_J, del_J=del_J,
  n_max=n_max,
)
params.estimate_k()  # stores the result on params.k
params.k

In [None]:
class BSModel(Params):
  """Black-Scholes pricer for European calls and puts.

  Inherits the input fields from ``Params``. Each pricing call stores the
  intermediate terms on the instance as ``self.d1`` and ``self.d2``.
  """

  def __init__(self, S0, K, T, r, sig, lam, mu_J, del_J, n_max=None):
    """Store the inputs through ``Params.__init__`` and compute ``self.k``."""
    super().__init__(S0, K, T, r, sig, lam, mu_J, del_J, n_max)
    self.estimate_k()

  def estimate_d1(self, r:float=None, sig:float=None):
    """Compute and store ``self.d1``.

    d1 = (ln(S0/K) + (r + sig**2 / 2) * T) / (sig * sqrt(T))

    Args:
      r: rate to use; ``self.r`` when None.
      sig: volatility to use; ``self.sig`` when None.
    """
    sig = self.sig if sig is None else sig
    r = self.r if r is None else r

    # Drop any stale value first: if the computation below raises, d1 stays None.
    self.d1 = None
    # The resolved r/sig drive the whole formula so per-call overrides take
    # effect, and the complete product sig*sqrt(T) is the denominator.
    self.d1 = (np.log(self.S0/self.K) + (r + 0.5*sig**2)*self.T) \
              / (sig*np.sqrt(self.T))

  def estimate_d2(self, r:float=None, sig:float=None):
    """Compute and store ``self.d2 = self.d1 - sig * sqrt(T)``.

    ``estimate_d1`` must have been called first (with the same ``sig``).

    Args:
      r: unused; accepted so the signature mirrors ``estimate_d1``.
      sig: volatility to use; ``self.sig`` when None.

    Raises:
      AssertionError: if ``self.d1`` is missing or None.
    """
    # getattr keeps the intended message even when d1 was never assigned.
    assert getattr(self, "d1", None) is not None, "d1 must be computed first"
    sig = self.sig if sig is None else sig

    self.d2 = None
    self.d2 = self.d1 - sig*np.sqrt(self.T)

  def estimate_bs(self, r:float=None, sig:float=None, contract:str="call"):
    """Return the Black-Scholes price of a European call or put.

    Args:
      r: rate to use; ``self.r`` when None.
      sig: volatility to use; ``self.sig`` when None.
      contract: ``"call"`` or ``"put"``.

    Returns:
      call: S0*N(d1) - K*exp(-r*T)*N(d2)
      put:  K*exp(-r*T)*N(-d2) - S0*N(-d1)

    Raises:
      AssertionError: if ``contract`` is neither ``"call"`` nor ``"put"``.
    """
    assert contract in ["call", "put"], "contract must be either call or put"
    # coef flips the sign of both terms and of the cdf arguments for a put.
    coef = 1 if contract=="call" else -1
    r = self.r if r is None else r

    self.estimate_d1(r=r, sig=sig)
    self.estimate_d2(r=r, sig=sig)

    discount = np.exp(-r*self.T)
    return coef*self.S0*ss.norm.cdf(coef*self.d1) - coef*self.K*discount*ss.norm.cdf(coef*self.d2)

  

In [None]:
# Black-Scholes price of the call for the notebook-level inputs.
mod = BSModel(
  S0=S0, K=K, T=T, r=r, sig=sig,
  lam=lam, mu_J=mu_J, del_J=del_J,
  n_max=n_max,
)
mod.estimate_bs(contract="call")

In [None]:
import math

class JDModel(BSModel):
  """Jump-diffusion pricer: a Poisson-weighted series of Black-Scholes prices.

  Term n of the series prices the option with ``estimate_bs`` using the
  adjusted rate ``r_n`` and volatility ``sig_n``, and weights it with a
  Poisson probability. The series is truncated at ``n = n_max``.
  """

  def __init__(self, S0, K, T, r, sig, lam, mu_J, del_J, n_max):
    """Forward all inputs to ``BSModel.__init__``."""
    super().__init__(S0, K, T, r, sig, lam, mu_J, del_J, n_max)

  def estimate_P_n(self, n:int, lam:float=None):
    """Poisson probability ``exp(-lam*T) * (lam*T)**n / n!``.

    Args:
      n: non-negative integer index.
      lam: intensity to use; ``self.lam`` when None.
    """
    lam = self.lam if lam is None else lam
    return (np.exp(-lam*self.T)*(lam*self.T)**n)/math.factorial(n)

  def estimate_r_n(self, n):
    """Adjusted rate for term n: ``r - lam*k + n/T * ln(1 + k)``."""
    return self.r - self.lam*self.k+n/self.T*np.log(1+self.k)

  def estimate_sig_n(self, n):
    """Adjusted volatility for term n: ``sqrt(sig**2 + n/T * del_J**2)``."""
    return np.sqrt(self.sig**2 + n/self.T*self.del_J**2)

  def estimate_jd(self, contract:str="call", debug:bool=False):
    """Return the jump-diffusion price as the truncated weighted series.

    Args:
      contract: ``"call"`` or ``"put"``, passed through to ``estimate_bs``.
      debug: when True, print the weight, rate, price and volatility of each term.

    Raises:
      TypeError: if ``n_max`` is None.
    """
    if self.n_max is None:
      raise TypeError("n_max must be an integer to truncate the jump series")

    # Because r_n already carries the -lam*k + n*ln(1+k)/T adjustment, the
    # matching Poisson weights use the intensity lam*(1+k), not lam itself.
    lam_adj = self.lam*(1 + self.k)

    price_sum = 0
    for n in range(self.n_max+1):
      P_n = self.estimate_P_n(n, lam=lam_adj)
      r_n = self.estimate_r_n(n)
      sig_n = self.estimate_sig_n(n)
      C_BS = self.estimate_bs(contract=contract, r=r_n, sig=sig_n)
      price_sum += P_n*C_BS
      if debug:
        print(f"P_n: {P_n:.3f}; r_n: {r_n:.3f}; C_BS: {C_BS:.3f}; sigma: {sig_n:.3f}  for n = {n}")

    return price_sum

In [None]:
# Jump-diffusion price with the series cut off at n = 3; debug=True prints each term.
# Note: this rebinds `mod`, which held the BSModel instance above.
mod = JDModel(
  S0=S0, K=K, T=T, r=r, sig=sig,
  lam=lam, mu_J=mu_J, del_J=del_J,
  n_max=3,
)
mod.estimate_jd(contract="call", debug=True)