<a href="https://colab.research.google.com/github/aderdouri/ql_web_app/blob/master/ql_notebooks/curvestates.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import QuantLib as ql
import unittest
import math # For fabs, though assertAlmostEqual is preferred

# Helper to create QuantLib vectors/matrix from Python lists
def double_vector_from_list(lst):
    """Copy the items of *lst* into a new ``ql.DoubleVector``.

    The vector is allocated with ``len(lst)`` slots and then filled
    element by element, in iteration order.
    """
    vector = ql.DoubleVector(len(lst))
    for position, item in enumerate(lst):
        vector[position] = item
    return vector

def size_vector_from_list(lst):
    """Copy the items of *lst* into a new ``ql.SizeVector``.

    The vector is allocated with ``len(lst)`` slots and then filled
    element by element, in iteration order.
    """
    vector = ql.SizeVector(len(lst))
    for position, item in enumerate(lst):
        vector[position] = item
    return vector

def matrix_from_list_of_lists(lol):
    """Build a ``ql.Matrix`` from a rectangular list of row lists.

    Args:
        lol: Sequence of rows, each a sequence of numbers. Every row must
            have the same length as the first one. An empty sequence
            yields a 0x0 matrix.

    Returns:
        A ``ql.Matrix`` with ``len(lol)`` rows and ``len(lol[0])`` columns.

    Raises:
        ValueError: If the rows do not all have the same length.
    """
    if not lol:
        return ql.Matrix(0, 0)

    rows = len(lol)
    cols = len(lol[0])

    # Reject ragged input up front: silently dropping the tail of a longer
    # row (or failing midway on a shorter one) would hide a caller bug.
    for row_index, row in enumerate(lol):
        if len(row) != cols:
            raise ValueError(
                f"row {row_index} has {len(row)} elements, expected {cols}"
            )

    m = ql.Matrix(rows, cols)
    for i, row in enumerate(lol):
        for j, value in enumerate(row):
            # The QuantLib Python Matrix is indexed row-first (m[i] is a
            # row proxy), so elements are written as m[i][j].
            m[i][j] = value
    return m

class CommonVars:
    """Shared fixture: market data plus reference values for the tests.

    Construction reads the global QuantLib evaluation date, builds a
    ten-year semiannual time grid from it and derives forwards, discounts,
    accruals and the evolution data used by the curve-state tests.
    """

    def __init__(self):
        self.tol = 1.0e-4
        self.spanning_fwds = 1  # mirrors `const Size spanningFwds = 1;`

        self._set_expected_values()
        self._build_market_data()

    def _set_expected_values(self):
        """Store the reference values (taken from the C++ test)."""
        # Drifts shared by the LMM and CMS checks.
        self.expected_drifts_lmm_cms = [
            -0.0825792, -0.0787625, -0.0748546, -0.0708555, -0.0667655,
            -0.0625846, -0.0583128, -0.0539504, -0.0494972, -0.0449536,
            -0.0403194, -0.0355949, -0.0307801, -0.025875, -0.0208799,
            -0.0157948, -0.0106197, -0.00535471, 0.0,
        ]
        # Discount ratios shared by the LMM and CMS checks.
        self.expected_discount_ratios_lmm_cms = [
            1.58379, 1.55274, 1.52154, 1.49025, 1.45888, 1.42748, 1.39607,
            1.36468, 1.33335, 1.3021, 1.27096, 1.23996, 1.20913, 1.17848,
            1.14806, 1.11788, 1.08796, 1.05833, 1.029,
        ]
        # Forward rates shared by all three checks.
        self.expected_forward_rates_lmm_cms_cot = [
            0.04, 0.041, 0.042, 0.043, 0.044, 0.045, 0.046, 0.047, 0.048,
            0.049, 0.05, 0.051, 0.052, 0.053, 0.054, 0.055, 0.056, 0.057,
            0.058,
        ]
        self.expected_swap_annuity_cms = [
            0.776368, 0.760772, 0.745125, 0.729442, 0.713739, 0.698034,
            0.68234, 0.666673, 0.651048, 0.635479, 0.619979, 0.604563,
            0.589242, 0.574031, 0.558939, 0.54398, 0.529163, 0.5145, 0.5,
        ]
        self.expected_cot_drifts = [
            -0.0472372, -0.0447452, -0.042233, -0.0397016, -0.0371516,
            -0.034584, -0.0319995, -0.0293991, -0.0267836, -0.0241539,
            -0.0215109, -0.0188555, -0.0161887, -0.0135113, -0.0108244,
            -0.00812878, -0.00542554, -0.00271562, 0.0,
        ]
        # The coterminal discount ratios reuse the very same list object.
        self.expected_cot_discount_ratios = self.expected_discount_ratios_lmm_cms
        self.expected_cot_swap_annuity = [
            12.0934, 11.317, 10.5563, 9.81115, 9.08171, 8.36797, 7.66994,
            6.9876, 6.32092, 5.66988, 5.0344, 4.41442, 3.80986, 3.22061,
            2.64658, 2.08764, 1.54366, 1.0145, 0.5,
        ]

    def _build_market_data(self):
        """Derive the time grid, rates, discounts and evolution data."""
        self.calendar = ql.NullCalendar()
        self.day_counter = ql.SimpleDayCounter()
        self.todays_date = ql.Settings.instance().evaluationDate
        self.end_date = self.todays_date + ql.Period(10, ql.Years)

        schedule = ql.Schedule(
            self.todays_date, self.end_date, ql.Period(ql.Semiannual),
            self.calendar, ql.Following, ql.Following,
            ql.DateGeneration.Backward, False,
        )

        # One rate time per schedule date after the first one, measured as
        # a year fraction from today's date.
        rate_times = [
            self.day_counter.yearFraction(self.todays_date, schedule[k])
            for k in range(1, schedule.size())
        ]
        self.rate_times_qv = double_vector_from_list(rate_times)

        # Each forward rate spans two consecutive rate times, so there is
        # one fewer forward than there are rate times.
        self.N = len(rate_times) - 1

        self.payment_times_qv = double_vector_from_list(rate_times[1:])

        accruals = [
            later - earlier
            for earlier, later in zip(rate_times, rate_times[1:])
        ]
        self.accruals_qv = double_vector_from_list(accruals)

        # Index N addresses the last entry of the rate-time grid.
        self.numeraire = self.N

        self.pseudo_qm = ql.Matrix(self.N, self.N, 0.1)

        forwards = [0.04 + 0.001 * k for k in range(self.N)]
        self.todays_forwards_qv = double_vector_from_list(forwards)

        self.displacements_qv = double_vector_from_list([0.0] * self.N)

        # Discounts start at 0.95 and are rolled forward with simple
        # compounding over each accrual period.
        discounts = [0.95]
        for forward, accrual in zip(forwards, accruals):
            discounts.append(discounts[-1] / (1.0 + forward * accrual))
        self.todays_discounts_qv = double_vector_from_list(discounts)

        # Evolution times are all rate times except the final one.
        evolution = ql.EvolutionDescription(
            self.rate_times_qv, double_vector_from_list(rate_times[:-1])
        )
        self.taus_qv = evolution.rateTaus()
        self.first_alive_rates_sv = evolution.firstAliveRate()


class CurveStatesTests(unittest.TestCase):
    """Checks three market-model curve states against reference values.

    Each test builds a drift calculator and a curve state from the shared
    ``CommonVars`` fixture, then compares drifts, discount ratios, forward
    rates and (where applicable) swap rates and annuities, index by index,
    within ``CommonVars.tol``.
    """

    @classmethod
    def setUpClass(cls):
        """Pin the global evaluation date for the duration of the class.

        ``CommonVars`` builds its time grid from the evaluation date, so an
        unpinned date makes the fixture depend on the wall clock (a start
        on 29 February, for instance, can yield an irregular first period
        and break the 0.5-year grid the expected values assume).
        """
        settings = ql.Settings.instance()
        # NOTE: the getter returns a concrete date, so restoring it later
        # re-pins that date rather than re-enabling a floating "today".
        cls._saved_evaluation_date = settings.evaluationDate
        settings.evaluationDate = ql.Date(15, ql.May, 2003)

    @classmethod
    def tearDownClass(cls):
        """Restore the evaluation date that was active before the tests."""
        ql.Settings.instance().evaluationDate = cls._saved_evaluation_date

    def setUp(self):
        # A fresh fixture per test keeps the tests isolated from each other.
        self.vars = CommonVars()

    def _assert_close(self, actual, expected, msg):
        """Assert ``|actual - expected| <= self.vars.tol`` with *msg*."""
        self.assertAlmostEqual(actual, expected, delta=self.vars.tol, msg=msg)

    def _todays_coterminal_swap_rates(self):
        """Return today's coterminal swap rates implied by the fixture.

        Walking from the last period backwards, the annuity of swap ``k``
        is the running sum of ``accrual[j] * discount[j + 1]`` for
        ``j >= k``; the swap rate is the floating leg
        ``discount[k] - discount[N]`` divided by that annuity.
        """
        v = self.vars
        final_discount = v.todays_discounts_qv[v.N]
        swap_rates = [0.0] * v.N
        annuity = 0.0
        for k in reversed(range(v.N)):
            annuity += v.accruals_qv[k] * v.todays_discounts_qv[k + 1]
            floating_leg = v.todays_discounts_qv[k] - final_discount
            swap_rates[k] = floating_leg / annuity
        return swap_rates

    def test_lmm_curve_state(self):
        """LMM curve state: drifts, discount ratios and forward rates."""
        print("Testing Libor-market-model curve state...")
        v = self.vars

        drift_calculator = ql.LMMDriftCalculator(
            v.pseudo_qm, v.displacements_qv, v.taus_qv,
            v.numeraire, v.first_alive_rates_sv[0],  # first alive rate of step 0
        )
        curve_state = ql.LMMCurveState(v.rate_times_qv)
        curve_state.setOnForwardRates(v.todays_forwards_qv)

        drifts = ql.DoubleVector(v.N)  # output buffer filled by compute()
        drift_calculator.compute(curve_state, drifts)

        for i in range(v.N):
            with self.subTest(index=i):
                expected_drift = v.expected_drifts_lmm_cms[i]
                self._assert_close(
                    drifts[i], expected_drift,
                    f"LMM drift mismatch at index {i}: "
                    f"got {drifts[i]}, "
                    f"expected {expected_drift}")

                # Per QuantLib's CurveState interface, discountRatio(i, j)
                # is P(t, T_i) / P(t, T_j); numeraire == N selects the last
                # rate time as the denominator.
                discount_ratio = curve_state.discountRatio(i, v.numeraire)
                expected_ratio = v.expected_discount_ratios_lmm_cms[i]
                self._assert_close(
                    discount_ratio, expected_ratio,
                    f"LMM discount ratio mismatch at index {i}: "
                    f"got {discount_ratio}, "
                    f"expected {expected_ratio}")

                forward_rate = curve_state.forwardRate(i)
                expected_forward = v.expected_forward_rates_lmm_cms_cot[i]
                self._assert_close(
                    forward_rate, expected_forward,
                    f"LMM forward rate mismatch at index {i}: "
                    f"got {forward_rate}, "
                    f"expected {expected_forward}")

    def test_coterminal_swap_curve_state(self):
        """Coterminal-swap curve state: drifts, ratios, rates, annuities."""
        print("Testing coterminal-swap-market-model curve state...")
        v = self.vars

        swap_rates = self._todays_coterminal_swap_rates()
        swap_rates_qv = double_vector_from_list(swap_rates)

        drift_calculator = ql.SMMDriftCalculator(
            v.pseudo_qm, v.displacements_qv, v.taus_qv,
            v.numeraire, v.first_alive_rates_sv[0],
        )
        curve_state = ql.CoterminalSwapCurveState(v.rate_times_qv)
        curve_state.setOnCoterminalSwapRates(swap_rates_qv)

        drifts = ql.DoubleVector(v.N)  # output buffer filled by compute()
        drift_calculator.compute(curve_state, drifts)

        for i in range(v.N):
            with self.subTest(index=i):
                self._assert_close(
                    drifts[i], v.expected_cot_drifts[i],
                    f"COT drift mismatch at index {i}")

                self._assert_close(
                    curve_state.discountRatio(i, v.numeraire),
                    v.expected_cot_discount_ratios[i],
                    f"COT discount ratio mismatch at index {i}")

                self._assert_close(
                    curve_state.forwardRate(i),
                    v.expected_forward_rates_lmm_cms_cot[i],
                    f"COT forward rate mismatch at index {i}")

                # Round trip: the state must return the rates it was set on.
                self._assert_close(
                    curve_state.coterminalSwapRate(i), swap_rates[i],
                    f"COT swap rate mismatch at index {i}")

                # Arguments: (numeraire index, index of the swap rate).
                self._assert_close(
                    curve_state.coterminalSwapAnnuity(v.numeraire, i),
                    v.expected_cot_swap_annuity[i],
                    f"COT swap annuity mismatch at index {i}")

    def test_cm_swap_curve_state(self):
        """CMS curve state: drifts, ratios, rates and annuities."""
        print("Testing constant-maturity-swap-market-model curve state...")
        v = self.vars

        drift_calculator = ql.CMSMMDriftCalculator(
            v.pseudo_qm, v.displacements_qv, v.taus_qv,
            v.numeraire, v.first_alive_rates_sv[0], v.spanning_fwds,
        )
        curve_state = ql.CMSwapCurveState(v.rate_times_qv, v.spanning_fwds)
        # With spanning_fwds == 1 each constant-maturity swap covers one
        # period, so the forward rates are fed in as the CMS rates and the
        # same reference values are expected back below.
        curve_state.setOnCMSwapRates(v.todays_forwards_qv)

        drifts = ql.DoubleVector(v.N)  # output buffer filled by compute()
        drift_calculator.compute(curve_state, drifts)

        for i in range(v.N):
            with self.subTest(index=i):
                self._assert_close(
                    drifts[i], v.expected_drifts_lmm_cms[i],
                    f"CMS drift mismatch at index {i}")

                self._assert_close(
                    curve_state.discountRatio(i, v.numeraire),
                    v.expected_discount_ratios_lmm_cms[i],
                    f"CMS discount ratio mismatch at index {i}")

                self._assert_close(
                    curve_state.forwardRate(i),
                    v.expected_forward_rates_lmm_cms_cot[i],
                    f"CMS forward rate mismatch at index {i}")

                # Arguments: (swap start index, number of spanning forwards).
                self._assert_close(
                    curve_state.cmSwapRate(i, v.spanning_fwds),
                    v.expected_forward_rates_lmm_cms_cot[i],
                    f"CMS swap rate mismatch at index {i}")

                # Arguments: (numeraire index, swap start index,
                # number of spanning forwards).
                self._assert_close(
                    curve_state.cmSwapAnnuity(v.numeraire, i, v.spanning_fwds),
                    v.expected_swap_annuity_cms[i],
                    f"CMS swap annuity mismatch at index {i}")


if __name__ == "__main__":
    # Announce which QuantLib build the tests run against.
    version_banner = "Testing QuantLib " + ql.__version__
    print(version_banner)
    # The fixture reads whatever evaluation date is active in
    # ql.Settings when each test starts.
    # argv is replaced so unittest does not try to parse the host
    # process's own command line; exit=False stops unittest from calling
    # sys.exit() once the run completes.
    unittest.main(exit=False, argv=["first-arg-is-ignored"])