<a href="https://colab.research.google.com/github/aderdouri/ql_web_app/blob/master/ql_notebooks/extendedtrees.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import QuantLib as ql
import unittest
import math

# --- Helper Utilities (assuming they are defined as in previous examples) ---
def flat_rate_py(evaluation_date_or_quote, forward_rate_or_dc, day_counter_or_none=None):
    """Build a ``ql.FlatForward`` yield curve.

    Two call forms are accepted:

    * ``flat_rate_py(date, rate, day_counter)`` -- curve anchored on ``date``;
    * ``flat_rate_py(rate, day_counter)`` -- curve anchored on the current
      global evaluation date (``ql.Settings.instance().evaluationDate``).

    ``rate`` may be a ``ql.Quote`` (wrapped in a ``QuoteHandle``), a Python
    ``float`` (wrapped in a fresh ``SimpleQuote``), or any other object, which
    is handed to ``ql.FlatForward`` unchanged (e.g. an existing QuoteHandle).
    """
    if not isinstance(evaluation_date_or_quote, ql.Date):
        # Date omitted: arguments are shifted one position to the left.
        anchor = ql.Settings.instance().evaluationDate
        rate, dc = evaluation_date_or_quote, forward_rate_or_dc
    else:
        anchor, rate, dc = evaluation_date_or_quote, forward_rate_or_dc, day_counter_or_none

    if isinstance(rate, ql.Quote):
        rate = ql.QuoteHandle(rate)
    elif isinstance(rate, float):
        rate = ql.QuoteHandle(ql.SimpleQuote(rate))
    return ql.FlatForward(anchor, rate, dc)

def flat_vol_py(evaluation_date_or_quote, vol_level_or_dc, day_counter_or_none=None):
    """Build a ``ql.BlackConstantVol`` surface on a ``ql.NullCalendar``.

    Two call forms are accepted:

    * ``flat_vol_py(date, vol, day_counter)`` -- surface anchored on ``date``;
    * ``flat_vol_py(vol, day_counter)`` -- surface anchored on the current
      global evaluation date (``ql.Settings.instance().evaluationDate``).

    ``vol`` may be a ``ql.Quote`` (wrapped in a ``QuoteHandle``), a Python
    ``float`` (wrapped in a fresh ``SimpleQuote``), or any other object, which
    is handed to ``ql.BlackConstantVol`` unchanged (e.g. an existing QuoteHandle).
    """
    if not isinstance(evaluation_date_or_quote, ql.Date):
        # Date omitted: arguments are shifted one position to the left.
        anchor = ql.Settings.instance().evaluationDate
        vol, dc = evaluation_date_or_quote, vol_level_or_dc
    else:
        anchor, vol, dc = evaluation_date_or_quote, vol_level_or_dc, day_counter_or_none

    if isinstance(vol, ql.Quote):
        vol = ql.QuoteHandle(vol)
    elif isinstance(vol, float):
        vol = ql.QuoteHandle(ql.SimpleQuote(vol))
    return ql.BlackConstantVol(anchor, ql.NullCalendar(), vol, dc)

def time_to_days_py(t, basis=360):
    """Convert a year fraction ``t`` into a whole number of days.

    The result is ``t * basis`` shifted by 0.5 and truncated with ``int``.
    For non-negative inputs this rounds half-up. Because ``int`` truncates
    toward zero, negative inputs are not rounded symmetrically.
    ``basis`` is the number of days per year (360 by default).
    """
    shifted_days = t * basis + 0.5
    return int(shifted_days)

def exercise_type_to_string_py(exercise):
    """Return a label for ``exercise`` for use in failure messages.

    Only ``ql.EuropeanExercise`` is recognised ("European"); any other
    object yields "UnknownExercise".
    """
    return "European" if isinstance(exercise, ql.EuropeanExercise) else "UnknownExercise"

def payoff_type_to_string_py(payoff):
    """Return a label for ``payoff`` for use in failure messages.

    Only ``ql.PlainVanillaPayoff`` is recognised ("PlainVanilla"); any other
    object yields "UnknownPayoff".
    """
    return "PlainVanilla" if isinstance(payoff, ql.PlainVanillaPayoff) else "UnknownPayoff"

def relative_error_py(x1, x2, reference):
    """Return ``|x1 - x2|`` scaled by ``|reference|``.

    When ``|reference|`` is not above 1e-12 (including zero and NaN), the
    unscaled absolute difference is returned instead, to avoid dividing by
    a (near-)zero reference.
    """
    difference = abs(x1 - x2)
    scale = abs(reference)
    return difference / scale if scale > 1.0e-12 else difference

# --- Engine identifiers for the extended (time-dependent) tree tests ---
# Reference engine: closed-form Black-Scholes pricing.
ENGINE_EXT_ANALYTIC = "Analytic"

# Binomial tree identifiers. ExtendedTreesTests._make_option passes these
# strings verbatim to ql.BinomialVanillaEngine as the tree-type argument.
# NOTE(review): confirm that the installed QuantLib-Python build accepts these
# "Extended*" names; its wrapper may only recognise the plain tree names.
ENGINE_EXT_JR = "ExtendedJarrowRudd"
ENGINE_EXT_CRR = "ExtendedCoxRossRubinstein"
ENGINE_EXT_EQP = "ExtendedAdditiveEQPBinomialTree"
ENGINE_EXT_TGEO = "ExtendedTrigeorgis"
ENGINE_EXT_TIAN = "ExtendedTian"
ENGINE_EXT_LR = "ExtendedLeisenReimer"
ENGINE_EXT_JOSHI = "ExtendedJoshi4"

class ExtendedTreesTests(unittest.TestCase):
    """Compare time-dependent ("extended") binomial European engines with
    the analytic Black-Scholes engine.

    Each test method prices a grid of European calls and puts with one tree
    type and with ``ql.AnalyticEuropeanEngine``. It then checks that value,
    delta, gamma and theta agree within per-greek tolerances; errors are
    measured relative to the spot value.

    NOTE(review): engine identifiers are forwarded verbatim to
    ``ql.BinomialVanillaEngine``. Confirm that the installed QuantLib-Python
    build accepts the "Extended*" tree names before relying on these tests.
    """

    # Engine identifiers that _make_option routes to ql.BinomialVanillaEngine.
    _BINOMIAL_ENGINES = frozenset({
        ENGINE_EXT_JR, ENGINE_EXT_CRR, ENGINE_EXT_EQP, ENGINE_EXT_TGEO,
        ENGINE_EXT_TIAN, ENGINE_EXT_LR, ENGINE_EXT_JOSHI,
    })

    def setUp(self):
        """Save the global evaluation date and build quote-driven market data.

        Spot, dividend yield, risk-free rate and volatility are SimpleQuotes,
        so they can be changed in place. The term-structure handles are
        *relinkable* because ``_test_engine_consistency_extended`` re-points
        them at curves anchored on a different evaluation date.
        """
        self.saved_eval_date = ql.Settings.instance().evaluationDate
        self.today = ql.Date(15, ql.May, 2007)  # Arbitrary fixed date
        ql.Settings.instance().evaluationDate = self.today

        self.dc = ql.Actual360()
        self.spot_q = ql.SimpleQuote(0.0)
        self.q_rate_q = ql.SimpleQuote(0.0)
        self.r_rate_q = ql.SimpleQuote(0.0)
        self.vol_q = ql.SimpleQuote(0.0)

        self.spot_h = ql.QuoteHandle(self.spot_q)
        # Plain YieldTermStructureHandle / BlackVolTermStructureHandle have no
        # linkTo(); only the relinkable variants support the relinking done in
        # _test_engine_consistency_extended.
        self.qTS_h = ql.RelinkableYieldTermStructureHandle(
            flat_rate_py(self.today, self.q_rate_q, self.dc))
        self.rTS_h = ql.RelinkableYieldTermStructureHandle(
            flat_rate_py(self.today, self.r_rate_q, self.dc))
        self.volTS_h = ql.RelinkableBlackVolTermStructureHandle(
            flat_vol_py(self.today, self.vol_q, self.dc))

    def tearDown(self):
        """Restore the global evaluation date saved in setUp."""
        ql.Settings.instance().evaluationDate = self.saved_eval_date

    def _make_process(self, u_q_h, q_ts_h, r_ts_h, vol_ts_h):
        """Return a Black-Scholes-Merton process built from the given handles."""
        return ql.BlackScholesMertonProcess(u_q_h, q_ts_h, r_ts_h, vol_ts_h)

    def _make_option(self, payoff, exercise, u_q_h, q_ts_h, r_ts_h, vol_ts_h,
                     engine_type_str, binomial_steps=None):
        """Build a European option priced by the requested engine.

        Args:
            payoff, exercise: option terms.
            u_q_h, q_ts_h, r_ts_h, vol_ts_h: spot, dividend curve, risk-free
                curve and volatility handles for the BSM process.
            engine_type_str: ``ENGINE_EXT_ANALYTIC`` or one of the binomial
                identifiers (forwarded verbatim to ``ql.BinomialVanillaEngine``).
            binomial_steps: number of tree steps. Required for binomial
                engines and ignored by the analytic engine.

        Returns:
            A ``ql.EuropeanOption`` with the engine attached.

        Raises:
            ValueError: if the engine type is unknown, or if a binomial engine
                is requested without ``binomial_steps``.
        """
        process = self._make_process(u_q_h, q_ts_h, r_ts_h, vol_ts_h)

        if engine_type_str == ENGINE_EXT_ANALYTIC:
            engine = ql.AnalyticEuropeanEngine(process)
        elif engine_type_str in self._BINOMIAL_ENGINES:
            if binomial_steps is None:
                raise ValueError("binomial_steps must be provided for binomial engines")
            engine = ql.BinomialVanillaEngine(process, engine_type_str, binomial_steps)
        else:
            raise ValueError(f"Unknown engine type for extended trees: {engine_type_str}")

        option = ql.EuropeanOption(payoff, exercise)
        option.setPricingEngine(engine)
        return option

    def _update_market_data(self, s, q, r, v):
        """Set spot, dividend yield, risk-free rate and volatility quotes."""
        self.spot_q.setValue(s)
        self.q_rate_q.setValue(q)
        self.r_rate_q.setValue(r)
        self.vol_q.setValue(v)

    def _report_failure(self, greek_name, payoff, exercise, s, q, r, today,
                        v, expected, calculated, error, tolerance):
        """Fail the current test with a message describing the failing case.

        ``q``, ``r`` and ``v`` may be plain numbers or ``ql.Quote`` objects.
        """
        q_float = q.value() if isinstance(q, ql.Quote) else q
        r_float = r.value() if isinstance(r, ql.Quote) else r
        v_float = v.value() if isinstance(v, ql.Quote) else v
        # optionType() returns an integer code; show a readable name instead.
        option_type = payoff.optionType()
        type_name = {ql.Option.Call: "Call", ql.Option.Put: "Put"}.get(option_type, str(option_type))
        msg = (f"{exercise_type_to_string_py(exercise)} "
               f"{type_name} option with {payoff_type_to_string_py(payoff)} payoff:\n"
               f"    spot value:       {s}\n    strike:           {payoff.strike()}\n"
               f"    dividend yield:   {q_float:.6f}\n    risk-free rate:   {r_float:.6f}\n"
               f"    reference date:   {today}\n    maturity:         {exercise.lastDate()}\n"
               f"    volatility:       {v_float:.6f}\n\n"
               f"    expected   {greek_name}: {expected}\n    calculated {greek_name}: {calculated}\n"
               f"    error:            {error}\n    tolerance:        {tolerance}")
        self.fail(msg)

    def _test_engine_consistency_extended(self, engine_type_str, binomial_steps, tolerance_map):
        """Check one binomial engine against the analytic engine over a grid.

        Args:
            engine_type_str: binomial engine identifier (see ``_make_option``).
            binomial_steps: number of tree steps.
            tolerance_map: maximum allowed error keyed by "value", "delta",
                "gamma" and "theta". Errors are relative to the spot value.
        """
        print(f"Testing time-dependent {engine_type_str} binomial European engines...")

        types = [ql.Option.Call, ql.Option.Put]
        strikes = [75.0, 100.0, 125.0]
        lengths_years = [1]  # Years
        underlyings = [100.0]
        q_rates = [0.00, 0.05]
        r_rates = [0.01, 0.05, 0.15]
        vols = [0.11, 0.50, 1.20]

        current_eval_date = ql.Date(10, ql.January, 2020)  # Fixed eval date
        ql.Settings.instance().evaluationDate = current_eval_date
        # Re-anchor the curves on the new evaluation date. The same quotes
        # still drive them, so _update_market_data keeps taking effect.
        self.qTS_h.linkTo(flat_rate_py(current_eval_date, self.q_rate_q, self.dc))
        self.rTS_h.linkTo(flat_rate_py(current_eval_date, self.r_rate_q, self.dc))
        self.volTS_h.linkTo(flat_vol_py(current_eval_date, self.vol_q, self.dc))

        for opt_type in types:
            for strike in strikes:
                for length_y in lengths_years:
                    # As in the C++ test, maturity is today + timeToDays(length),
                    # i.e. 360 days per year.
                    ex_date = current_eval_date + ql.Period(time_to_days_py(length_y), ql.Days)
                    exercise = ql.EuropeanExercise(ex_date)
                    payoff = ql.PlainVanillaPayoff(opt_type, strike)

                    ref_option = self._make_option(payoff, exercise, self.spot_h, self.qTS_h, self.rTS_h,
                                                   self.volTS_h, ENGINE_EXT_ANALYTIC)
                    option_to_test = self._make_option(payoff, exercise, self.spot_h, self.qTS_h, self.rTS_h,
                                                       self.volTS_h, engine_type_str, binomial_steps)

                    for u_val in underlyings:
                        for q_val in q_rates:
                            for r_val in r_rates:
                                for v_val in vols:
                                    self._update_market_data(u_val, q_val, r_val, v_val)

                                    expected = {"value": ref_option.NPV()}
                                    value = option_to_test.NPV()
                                    calculated = {"value": value}

                                    # Greeks are compared only when the option
                                    # value is not negligible relative to spot.
                                    if value > self.spot_q.value() * 1.0e-5:
                                        expected["delta"] = ref_option.delta()
                                        expected["gamma"] = ref_option.gamma()
                                        expected["theta"] = ref_option.theta()
                                        calculated["delta"] = option_to_test.delta()
                                        calculated["gamma"] = option_to_test.gamma()
                                        calculated["theta"] = option_to_test.theta()

                                    for greek_name, calc_val in calculated.items():
                                        exp_val = expected[greek_name]
                                        tol = tolerance_map[greek_name]
                                        error = relative_error_py(exp_val, calc_val, u_val)
                                        if error > tol:
                                            self._report_failure(f"{greek_name} ({engine_type_str})", payoff, exercise,
                                                                 u_val, q_val, r_val, current_eval_date, v_val,
                                                                 exp_val, calc_val, error, tol)

    # One test per tree type: (engine, steps, tolerances).
    def test_jr_binomial_engines_extended(self):
        tol_map = {"value": 0.002, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_JR, 251, tol_map)

    def test_crr_binomial_engines_extended(self):
        tol_map = {"value": 0.02, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_CRR, 501, tol_map)

    def test_eqp_binomial_engines_extended(self):
        tol_map = {"value": 0.02, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_EQP, 501, tol_map)

    def test_tgeo_binomial_engines_extended(self):
        tol_map = {"value": 0.002, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_TGEO, 251, tol_map)

    def test_tian_binomial_engines_extended(self):
        tol_map = {"value": 0.002, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_TIAN, 251, tol_map)

    def test_lr_binomial_engines_extended(self):
        tol_map = {"value": 1.0e-6, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_LR, 251, tol_map)

    def test_joshi_binomial_engines_extended(self):
        tol_map = {"value": 1.0e-7, "delta": 1.0e-3, "gamma": 1.0e-4, "theta": 0.03}
        self._test_engine_consistency_extended(ENGINE_EXT_JOSHI, 251, tol_map)


if __name__ == '__main__':
    # Report the library version, then run the suite. Argument parsing is
    # bypassed and exit=False avoids SystemExit (safe inside a notebook).
    print(f"Testing QuantLib {ql.__version__}")
    unittest.main(argv=['first-arg-is-ignored'], exit=False)