<a href="https://colab.research.google.com/github/aderdouri/ql_web_app/blob/master/ql_notebooks/basismodels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install QuantLib-Python

Data: Class-level lists are used to store the term structure and volatility data.
Helper Methods: _get_yts, _get_optionlet_ts, _get_swaption_vts encapsulate the creation of the term structures and volatility structures, making the test methods cleaner. Note the use of ql.Relinkable...Handle to mirror the C++ pattern, although simple ql....Handle might suffice if the underlying objects aren't changed within a test.
_test_swaptioncfs_impl: The common logic for testing SwaptionCashFlows is put into a helper method called by the two specific test cases (cont_comp_spread and simple_comp_spread).
Interpolation: Cubic interpolation is used for the zero curves, and ql.LinearInterpolation is used for the correlation parameters (rho-infinity and beta).
Indices: Indices (ql.Euribor3M, ql.Euribor6M) are created using the appropriate forecasting curve handles.
Experimental Classes: ql.SwaptionCashFlows, ql.TenorOptionletVTS, ql.TenorSwaptionVTS, ql.TwoParameterCorrelation are instantiated using their Python constructors. Check the QL Python documentation for exact parameter order if unsure.
Comparison Logic: The assertion logic (assertAlmostEqual, assertLessEqual) mirrors the checks performed in the C++ tests, including the different tolerances (tol, tol2).
Settings: setUp and tearDown manage the evaluation date and the usingAtParCoupons setting.

In [None]:
import QuantLib as ql
import unittest
import math

class BasisModelsTests(unittest.TestCase):
    """Python port of the C++ basis-model tests (swaption cash flows and
    tenor-basis volatility transformations).

    Market data is stored as class-level lists; the ``_get_*`` helpers turn it
    into QuantLib term structures wrapped in relinkable handles.

    NOTE(review): ``ql.SwaptionCashFlows``, ``ql.TenorOptionletVTS``,
    ``ql.TenorSwaptionVTS`` and ``ql.TwoParameterCorrelation`` are assumed to
    be exported by the installed QuantLib Python wheel -- confirm before
    relying on these tests.
    """

    # Data equivalent to C++ global data
    termsData = [ql.Period(0, ql.Days), ql.Period(1, ql.Years), ql.Period(2, ql.Years), ql.Period(3, ql.Years),
                 ql.Period(5, ql.Years), ql.Period(7, ql.Years), ql.Period(10, ql.Years), ql.Period(15, ql.Years),
                 ql.Period(20, ql.Years), ql.Period(61, ql.Years)] # 61Y to avoid extrapolation

    discRatesData = [-0.00147407, -0.001761684, -0.001736745, -0.00119244, 0.000896055,
                     0.003537077, 0.007213824, 0.011391278, 0.013334611, 0.013982809]

    proj3mRatesData = [-0.000483439, -0.000578569, -0.000383832, 0.000272656, 0.002478699,
                       0.005100113, 0.008750643, 0.012788095, 0.014534052, 0.014942896]

    proj6mRatesData = [0.000233608, 0.000218862, 0.000504018, 0.001240556, 0.003554415,
                       0.006153921, 0.009688264, 0.013521628, 0.015136391, 0.015377704]

    capletTermsData = [ql.Period(1, ql.Years), ql.Period(2, ql.Years), ql.Period(3, ql.Years),
                       ql.Period(5, ql.Years), ql.Period(7, ql.Years), ql.Period(10, ql.Years),
                       ql.Period(15, ql.Years), ql.Period(20, ql.Years), ql.Period(25, ql.Years),
                       ql.Period(30, ql.Years)]

    capletStrikesData = [-0.0050, 0.0000, 0.0050, 0.0100, 0.0150, 0.0200, 0.0300, 0.0500]

    # One row per entry of capletTermsData, one column per entry of capletStrikesData
    capletVolsData = [
        [0.003010094, 0.002628065, 0.00456118, 0.006731268, 0.008678572, 0.010570881, 0.014149552, 0.021000638],
        [0.004173715, 0.003727039, 0.004180263, 0.005726083, 0.006905876, 0.008263514, 0.010555395, 0.014976523],
        [0.005870143, 0.005334526, 0.005599775, 0.006633987, 0.007773317, 0.009036581, 0.011474391, 0.016277549],
        [0.007458597, 0.007207522, 0.007263995, 0.007308727, 0.007813586, 0.008274858, 0.009743988, 0.012555171],
        [0.007711531, 0.007608826, 0.007572816, 0.007684107, 0.007971932, 0.008283118, 0.009268828, 0.011574083],
        [0.007619605, 0.007639059, 0.007719825, 0.007823373, 0.00800813, 0.008113384, 0.008616374, 0.009785436],
        [0.007312199, 0.007352993, 0.007369116, 0.007468333, 0.007515657, 0.00767695, 0.008020447, 0.009072769],
        [0.006905851, 0.006966315, 0.007056413, 0.007116494, 0.007259661, 0.00733308, 0.007667563, 0.008419696],
        [0.006529553, 0.006630731, 0.006749022, 0.006858027, 0.007001959, 0.007139097, 0.007390404, 0.008036255],
        [0.006225482, 0.006404012, 0.00651594, 0.006642273, 0.006640887, 0.006885713, 0.007093024, 0.00767373]
    ]

    # The same tenor list is used for both the option and the swap axis
    swaptionVTSTermsData = [ql.Period(1, ql.Years), ql.Period(5, ql.Years), ql.Period(10, ql.Years),
                            ql.Period(20, ql.Years), ql.Period(30, ql.Years)]

    swaptionVolsData = [
        [0.002616, 0.00468, 0.0056, 0.005852, 0.005823],
        [0.006213, 0.00643, 0.006622, 0.006124, 0.005958],
        [0.006658, 0.006723, 0.006602, 0.005802, 0.005464],
        [0.005728, 0.005814, 0.005663, 0.004689, 0.004276],
        [0.005041, 0.005059, 0.004746, 0.003927, 0.003608]
    ]


    def setUp(self):
        """Pin the evaluation date and remember the Ibor coupon convention."""
        self.today = ql.Date(11, ql.April, 2018)
        ql.Settings.instance().evaluationDate = self.today
        self.calendar = ql.TARGET()
        self.dayCounter = ql.Actual365Fixed()
        # NOTE(review): assumes the SWIG wrapper exposes the at-par/indexed
        # switch as static methods on IborCoupon -- confirm against the
        # installed QuantLib version.
        self._original_using_at_par_coupons = ql.IborCoupon.usingAtParCoupons()


    def tearDown(self):
        """Reset the evaluation date and restore the Ibor coupon convention."""
        ql.Settings.instance().evaluationDate = ql.Date()
        # The setting is a two-way switch rather than a boolean setter, so
        # pick the call that matches the value recorded in setUp.
        if self._original_using_at_par_coupons:
            ql.IborCoupon.createAtParCoupons()
        else:
            ql.IborCoupon.createIndexedCoupons()


    @staticmethod
    def _quote_rows(volRows):
        """Wrap every value of a 2-D list in a QuoteHandle over a SimpleQuote."""
        return [[ql.QuoteHandle(ql.SimpleQuote(v)) for v in row] for row in volRows]

    @staticmethod
    def _discounted_sum(weights, times, curve):
        """Return sum(weight * curve.discount(time)) over paired entries."""
        return sum((w * curve.discount(t) for w, t in zip(weights, times)), 0.0)

    @staticmethod
    def _flat_correlation(rhoInf, beta):
        """Build a TwoParameterCorrelation with time-constant parameters.

        Both parameters are interpolated linearly between t=0 and t=50 with
        the same value at both nodes.
        """
        corrTimes = [0.0, 50.0]
        rhoInterp = ql.LinearInterpolation(corrTimes, [rhoInf, rhoInf])
        betaInterp = ql.LinearInterpolation(corrTimes, [beta, beta])
        return ql.TwoParameterCorrelation(rhoInterp, betaInterp)

    def _get_yts(self, terms, rates, spread=0.0):
        """Helper to create a YieldTermStructure.

        Args:
            terms: periods, measured from ``self.today``, giving the curve nodes.
            rates: zero rates, one per term.
            spread: constant added to every rate before building the curve.

        Returns:
            A RelinkableYieldTermStructureHandle linked to a cubic zero curve.
        """
        nullCalendar = ql.NullCalendar()
        dates = [nullCalendar.advance(self.today, term, ql.Unadjusted) for term in terms]
        ratesPlusSpread = [r + spread for r in rates]
        # The Python bindings export concrete curve classes rather than the
        # C++ InterpolatedZeroCurve<> template; CubicZeroCurve is the
        # Cubic-interpolated one.
        ts = ql.CubicZeroCurve(dates, ratesPlusSpread, self.dayCounter, nullCalendar)
        # Return handle as in C++
        return ql.RelinkableYieldTermStructureHandle(ts)

    def _get_optionlet_ts(self):
        """Helper to create an OptionletVolatilityStructure.

        Returns:
            A RelinkableOptionletVolatilityStructureHandle linked to a
            StrippedOptionletAdapter built from the class-level caplet data
            (normal vols) on a Euribor3M index.
        """
        dates = [self.calendar.advance(self.today, term, ql.Following) for term in self.capletTermsData]
        capletVolQuotes = self._quote_rows(self.capletVolsData)

        curve3m = self._get_yts(self.termsData, self.proj3mRatesData)
        # NOTE(review): the original port questioned whether the C++ test uses
        # Euribor6M here; Euribor3M is kept to match the 3m curve -- verify
        # against the C++ source.
        index = ql.Euribor3M(curve3m)

        # settlementDays, calendar, businessDayConvention, iborIndex, optionletDates,
        # strikes, volatilities, dayCounter, volatilityType
        strippedOptionlet = ql.StrippedOptionlet(2, self.calendar, ql.Following, index, dates,
                                                 self.capletStrikesData, capletVolQuotes,
                                                 self.dayCounter, ql.Normal)

        # Wrap in adapter
        adapter = ql.StrippedOptionletAdapter(strippedOptionlet)
        return ql.RelinkableOptionletVolatilityStructureHandle(adapter)

    def _get_swaption_vts(self):
        """Helper to create a SwaptionVolatilityStructure.

        Returns:
            A RelinkableSwaptionVolatilityStructureHandle linked to a
            SwaptionVolatilityMatrix of normal vols with flat extrapolation.
        """
        swaptionVolQuotes = self._quote_rows(self.swaptionVolsData)

        # calendar, bdc, optionTenors, swapTenors, vols, dayCounter, flatExtrapolation, volType
        matrix = ql.SwaptionVolatilityMatrix(self.calendar, ql.Following,
                                             self.swaptionVTSTermsData, self.swaptionVTSTermsData,
                                             swaptionVolQuotes, self.dayCounter, True, ql.Normal)

        return ql.RelinkableSwaptionVolatilityStructureHandle(matrix)


    def _test_swaptioncfs_impl(self, contTenorSpread):
        """Implementation common to both SwaptionCFs tests.

        Builds a 5Y-into-10Y payer swaption and checks that the
        SwaptionCashFlows decomposition reproduces the exercise time and both
        swap leg NPVs, and that intermediate float weights vanish in a
        single-curve setting.

        Args:
            contTenorSpread: forwarded as the third SwaptionCashFlows argument.
        """
        usingAtParCoupons = ql.IborCoupon.usingAtParCoupons()

        # Market data
        discYTS = self._get_yts(self.termsData, self.discRatesData)
        proj6mYTS = self._get_yts(self.termsData, self.proj6mRatesData)
        euribor6m = ql.Euribor6M(proj6mYTS)

        # Swap details
        swapStart = self.calendar.advance(self.today, ql.Period(5, ql.Years), ql.Following)
        swapEnd = self.calendar.advance(swapStart, ql.Period(10, ql.Years), ql.Following)
        exerciseDate = self.calendar.advance(swapStart, ql.Period(-2, ql.Days), ql.Preceding)

        fixedSchedule = ql.Schedule(swapStart, swapEnd, ql.Period(1, ql.Years), self.calendar,
                                    ql.ModifiedFollowing, ql.ModifiedFollowing,
                                    ql.DateGeneration.Backward, False)
        floatSchedule = ql.Schedule(swapStart, swapEnd, ql.Period(6, ql.Months), self.calendar,
                                    ql.ModifiedFollowing, ql.ModifiedFollowing,
                                    ql.DateGeneration.Backward, False)

        # Use Thirty360.BondBasis for fixed leg day counter as in C++
        fixed_leg_dc = ql.Thirty360(ql.Thirty360.BondBasis)

        swap = ql.VanillaSwap(ql.Swap.Payer, 10000.0, fixedSchedule, 0.03, fixed_leg_dc,
                              floatSchedule, euribor6m, 0.0, euribor6m.dayCounter())

        swap.setPricingEngine(ql.DiscountingSwapEngine(discYTS))

        # Swaption, physically settled as in C++
        europeanExercise = ql.EuropeanExercise(exerciseDate)
        swaption = ql.Swaption(swap, europeanExercise, ql.Settlement.Physical)

        # SwaptionCashFlows
        cashFlows = ql.SwaptionCashFlows(swaption, discYTS, contTenorSpread)

        # Checks
        exerciseTime_manual = ql.Actual365Fixed().yearFraction(discYTS.referenceDate(),
                                                               swaption.exercise().dates()[0])
        self.assertAlmostEqual(exerciseTime_manual, cashFlows.exerciseTimes()[0], delta=1e-12,
                               msg="Swaption cash flow exercise time differs from manual calculation")

        tol = 1.0e-8

        # Fixed leg check; the swap's fixed-leg NPV is negated for the comparison
        fixedLegNPV_cf = self._discounted_sum(cashFlows.fixedWeights(), cashFlows.fixedTimes(), discYTS)
        fixedLegNPV_swap = -swap.fixedLegNPV()

        self.assertAlmostEqual(fixedLegNPV_cf, fixedLegNPV_swap, delta=tol,
                               msg=f"SwaptionCF fixed leg NPV ({fixedLegNPV_cf:.8f}) "
                                   f"!= swap fixed leg NPV ({fixedLegNPV_swap:.8f})")

        # Float leg check
        floatLegNPV_cf = self._discounted_sum(cashFlows.floatWeights(), cashFlows.floatTimes(), discYTS)
        floatLegNPV_swap = swap.floatingLegNPV()

        self.assertAlmostEqual(floatLegNPV_cf, floatLegNPV_swap, delta=tol,
                               msg=f"SwaptionCF float leg NPV ({floatLegNPV_cf:.8f}) "
                                   f"!= swap float leg NPV ({floatLegNPV_swap:.8f})")

        # Single curve check (expecting near-zero intermediate float weights);
        # the tolerance is much looser when indexed (non at-par) coupons are in use.
        tol2 = tol if usingAtParCoupons else 0.02
        singleCurveCashFlows = ql.SwaptionCashFlows(swaption, proj6mYTS, contTenorSpread)
        singleCurveWeights = singleCurveCashFlows.floatWeights()
        # Check intermediate weights only (first and last entries are skipped)
        for k in range(1, len(singleCurveWeights) - 1):
            weight = singleCurveWeights[k]
            self.assertAlmostEqual(weight, 0.0, delta=tol2,
                                   msg=f"SwaptionCF float weight [{k}] not zero ({weight:.4e}) "
                                       f"in single-curve setting (tol={tol2:.1e})")


    def test_swaptioncfs_cont_comp_spread(self):
        """Run the swaption cash-flow checks with contTenorSpread=True."""
        print("Testing deterministic tenor basis model with continuous compounded spreads...")
        self._test_swaptioncfs_impl(True)

    def test_swaptioncfs_simple_comp_spread(self):
        """Run the swaption cash-flow checks with contTenorSpread=False."""
        print("Testing deterministic tenor basis model with simple compounded spreads...")
        self._test_swaptioncfs_impl(False)


    def test_tenoroptionletvts(self):
        """Check the caplet/floorlet vol transformation between 3m and 6m tenors."""
        print("Testing volatility transformation for caplets/floorlets...")

        spread = 0.01
        # YTS Handles
        proj3mYTS = self._get_yts(self.termsData, self.proj3mRatesData)
        proj6mYTS = self._get_yts(self.termsData, self.proj3mRatesData, spread) # Use 3m rates + spread for 6m curve

        # Indices
        euribor3m = ql.Euribor3M(proj3mYTS)
        euribor6m = ql.Euribor6M(proj6mYTS)

        # Base Volatility (3M)
        optionletVTS3m = self._get_optionlet_ts()

        # Correlation Structure 1 (De-correlation)
        corr1 = self._flat_correlation(0.3, 0.9)
        optionletVTS6m_1 = ql.TenorOptionletVTS(optionletVTS3m, euribor3m, euribor6m, corr1)

        for term in self.capletTermsData:
            for strike in self.capletStrikesData:
                vol3m = optionletVTS3m.volatility(term, strike, True)
                vol6mShifted = optionletVTS6m_1.volatility(term, strike + spread, True)

                # Check: Shifted 6m vol should not be significantly larger than 3m vol
                # We leave 1bp tolerance due to simplified spread calculation (as per C++)
                self.assertLessEqual(vol6mShifted - vol3m, 0.0001,
                                     f"Shifted 6m vol ({vol6mShifted:.4%}) > 3m vol ({vol3m:.4%}) "
                                     f"with de-correlation. Term={term}, Strike={strike:.4f}")

        # Correlation Structure 2 (rho-infinity = beta = 0, labelled "perfect
        # correlation" by the original port)
        corr2 = self._flat_correlation(0.0, 0.0)
        optionletVTS6m_2 = ql.TenorOptionletVTS(optionletVTS3m, euribor3m, euribor6m, corr2)

        for i, term in enumerate(self.capletTermsData):
            # 10bp tol for the three shortest tenors, 1bp otherwise
            tol = 0.001 if i < 3 else 0.0001
            for strike in self.capletStrikesData:
                vol3m = optionletVTS3m.volatility(term, strike, True)
                vol6mShifted = optionletVTS6m_2.volatility(term, strike + spread, True)

                # Check: Shifted 6m vol should match 3m vol for perfect correlation
                self.assertAlmostEqual(vol6mShifted, vol3m, delta=tol,
                                       msg=f"Shifted 6m vol ({vol6mShifted:.4%}) != 3m vol ({vol3m:.4%}) "
                                           f"with perfect correlation. Term={term}, Strike={strike:.4f}, Tol={tol:.4f}")


    def test_tenorswaptionvts(self):
        """Check the swaption vol transformation 6m->3m, 6m->6m and 6m->3m->6m."""
        print("Testing volatility transformation for swaptions...")

        spread = 0.01 # Only used to build the 6m projection curve
        # YTS Handles & Indices
        discYTS = self._get_yts(self.termsData, self.discRatesData)
        proj3mYTS = self._get_yts(self.termsData, self.proj3mRatesData)
        proj6mYTS = self._get_yts(self.termsData, self.proj3mRatesData, spread) # Use 3m rates + spread for 6m curve
        euribor3m = ql.Euribor3M(proj3mYTS)
        euribor6m = ql.Euribor6M(proj6mYTS)

        # Base Swaption Volatility (6M)
        euribor6mSwVTS = self._get_swaption_vts()

        fixedLegTenor = ql.Period(1, ql.Years)
        fixedLegDC = ql.Thirty360(ql.Thirty360.BondBasis) # Match C++
        floatLegDC = ql.Thirty360(ql.Thirty360.BondBasis) # Match C++

        def transform(baseVTS, baseIndex, targIndex):
            # Every transformation in this test shares curve, tenors and day counters
            return ql.TenorSwaptionVTS(baseVTS, discYTS,
                                       baseIndex, targIndex,
                                       fixedLegTenor, fixedLegTenor,
                                       fixedLegDC, floatLegDC)

        strike = 0.01 # Example strike, the same for every grid point
        tenorGrid = [(optTenor, swpTenor)
                     for optTenor in self.swaptionVTSTermsData
                     for swpTenor in self.swaptionVTSTermsData]

        # Transformation: 6M -> 3M
        euribor3mSwVTS = transform(euribor6mSwVTS, euribor6m, euribor3m)

        # Check: 3m vol should be <= 6m vol (basis effect)
        for optTenor, swpTenor in tenorGrid:
            vol6m = euribor6mSwVTS.volatility(optTenor, swpTenor, strike, True)
            vol3m = euribor3mSwVTS.volatility(optTenor, swpTenor, strike, True)
            self.assertLessEqual(vol3m, vol6m,
                                 f"Euribor 3m vol ({vol3m:.4%}) > 6m vol ({vol6m:.4%}). "
                                 f"OptionTenor={optTenor}, SwapTenor={swpTenor}")

        # Transformation: 6M -> 6M (Identity Check)
        euribor6mSwVTS2 = transform(euribor6mSwVTS, euribor6m, euribor6m)

        tol_identity = 1.0e-8
        for optTenor, swpTenor in tenorGrid:
            vol6m = euribor6mSwVTS.volatility(optTenor, swpTenor, strike, True)
            vol6m2 = euribor6mSwVTS2.volatility(optTenor, swpTenor, strike, True)
            self.assertAlmostEqual(vol6m2, vol6m, delta=tol_identity,
                                   msg=f"Euribor 6m -> 6m vol transform failed identity check. "
                                       f"OptionTenor={optTenor}, SwapTenor={swpTenor}\n"
                                       f"Orig: {vol6m:.8f}, Transformed: {vol6m2:.8f}")

        # Transformation: 6M -> 3M -> 6M (Round Trip Check)
        # The transformed 3m surface is wrapped in a handle and mapped back to 6M
        euribor6mSwVTS3 = transform(ql.SwaptionVolatilityStructureHandle(euribor3mSwVTS),
                                    euribor3m, euribor6m)

        tol_roundtrip = 1.0e-8
        for optTenor, swpTenor in tenorGrid:
            vol6m_orig = euribor6mSwVTS.volatility(optTenor, swpTenor, strike, True)
            vol6m_roundtrip = euribor6mSwVTS3.volatility(optTenor, swpTenor, strike, True)
            self.assertAlmostEqual(vol6m_roundtrip, vol6m_orig, delta=tol_roundtrip,
                                   msg=f"Euribor 6m -> 3m -> 6m vol transform failed round trip check. "
                                       f"OptionTenor={optTenor}, SwapTenor={swpTenor}\n"
                                       f"Orig: {vol6m_orig:.8f}, RoundTrip: {vol6m_roundtrip:.8f}")


if __name__ == '__main__':
    # Script/notebook entry point: report the QuantLib version, then run
    # every test in BasisModelsTests with verbose output.
    print("Python QuantLib version:", ql.__version__)
    print("Testing Basis Models (Python)...")
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # the default loader builds the same suite on all supported versions.
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(BasisModelsTests)
    unittest.TextTestRunner(verbosity=2).run(suite)