Create a Python class named `FourierSeriesExpansion` that performs the following tasks:

* The `__init__` method should take a `sympy` expression $f(x)$ and the endpoints $a$ and $b$ of an interval $[a, b]$ as input parameters.

* Implement a method called `calculate_fourier_series` that uses `sympy`'s `fourier_series` function to compute the Fourier series expansion of the function on the specified interval. The method should allow the user to specify the number of terms to which the series should be truncated.

* Implement another method called `plot_function_and_series` that uses the `matplotlib` library to plot both the original function $f(x)$ and its Fourier series expansion.
* Use $f(x) = x^2\sin(x)$ and $[a, b] = [-1, 1]$ as the test function and interval, respectively.

In [None]:
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt


class FourierSeriesExpansion:
    """Compute and plot truncated Fourier series of a sympy expression.

    ``f`` is expanded over ``[a, b]`` with ``sympy.fourier_series``, i.e. it is
    treated as one period of a ``(b - a)``-periodic function.
    """

    def __init__(self, f, a, b, x=None):
        """Store the function, the expansion interval and the expansion variable.

        Args:
            f: sympy expression (or anything ``sympy.sympify`` accepts).
            a: left endpoint of the interval; a number or a sympy number such
                as ``-sp.pi``.
            b: right endpoint of the interval.
            x: optional sympy symbol to expand in. If omitted, the single free
                symbol of ``f`` is used; when ``f`` has zero or several free
                symbols, ``sp.symbols("x")`` is used.

        Raises:
            ValueError: if ``a`` and ``b`` are provably equal (sympy divides
                by the interval length when computing coefficients).
        """
        if (sp.sympify(b) - sp.sympify(a)).is_zero:
            raise ValueError(f"interval [{a}, {b}] must have non-zero length")
        self.f = sp.sympify(f)
        self.a = a
        self.b = b
        if x is None:
            free_symbols = self.f.free_symbols
            # Infer the variable so an expression in t, or in an x created with
            # assumptions (e.g. real=True), is expanded in its own symbol rather
            # than in an unrelated plain Symbol("x").
            x = next(iter(free_symbols)) if len(free_symbols) == 1 else sp.symbols("x")
        self.x = x

    def calculate_fourier_series(self, n_terms):
        """Return the Fourier series of ``self.f`` on ``[a, b]`` truncated to ``n_terms``.

        ``n_terms`` counts *non-zero* terms (sympy ``FourierSeries.truncate``
        semantics), not harmonics. For an odd function such as
        ``x**2 * sin(x)`` on ``[-1, 1]`` only sine terms are non-zero, so
        ``n_terms=3`` keeps three sine terms.

        Args:
            n_terms: number of non-zero terms to keep (a non-negative integer),
                or ``None`` to get an iterator over the terms.

        Returns:
            A sympy expression, or an iterator when ``n_terms`` is ``None``.

        Raises:
            ValueError: if ``n_terms`` is negative or not integral.
        """
        if n_terms is not None and (n_terms < 0 or n_terms != int(n_terms)):
            # sympy's truncate stops only when the kept-term count equals n
            # exactly, so a negative/fractional n never stops on an infinite series.
            raise ValueError(
                f"n_terms must be a non-negative integer or None, got {n_terms!r}"
            )
        if self.x not in self.f.free_symbols:
            # sympy.fourier_series returns f itself (not a FourierSeries, so it
            # has no .truncate) when f does not depend on x; such an f is its
            # own one-term series.
            truncated = sp.S.Zero if n_terms == 0 else self.f
            return iter((truncated,)) if n_terms is None else truncated
        series = sp.fourier_series(self.f, (self.x, self.a, self.b))
        return series.truncate(n=n_terms)

    @staticmethod
    def _as_plot_array(values, x_vals):
        """Broadcast lambdified output to the shape of the sample grid.

        ``lambdify`` of an expression that does not depend on the variable
        (a constant function or a zero-term series) returns a scalar, which
        ``plot`` rejects; broadcasting turns it into a flat line.
        """
        return np.broadcast_to(np.asarray(values, dtype=float), np.shape(x_vals))

    def plot_function_and_series(self, n_terms, num_points=400):
        """Plot ``self.f`` and its truncated Fourier series over ``[a, b]``.

        Args:
            n_terms: number of non-zero series terms, as in
                ``calculate_fourier_series``; must not be ``None`` here.
            num_points: number of sample points on ``[a, b]``.

        Raises:
            ValueError: if ``n_terms`` is ``None``, negative or not integral.
        """
        if n_terms is None:
            raise ValueError("n_terms must be a non-negative integer to plot the series")
        fourier_series = self.calculate_fourier_series(n_terms)

        # float() so symbolic endpoints such as sp.pi yield a float grid rather
        # than an object array that numpy ufuncs cannot evaluate.
        x_vals = np.linspace(float(self.a), float(self.b), num_points)
        f_lambda = sp.lambdify(self.x, self.f, ["numpy", "sympy"])
        series_lambda = sp.lambdify(self.x, fourier_series, ["numpy", "sympy"])

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(
            x_vals,
            self._as_plot_array(f_lambda(x_vals), x_vals),
            label="Original Function",
            color="blue",
        )
        ax.plot(
            x_vals,
            self._as_plot_array(series_lambda(x_vals), x_vals),
            label=f"Fourier Series (n={n_terms})",
            color="red",
            linestyle="--",
        )
        ax.set_title("Function and its Fourier Series Expansion")
        ax.set_xlabel("x")
        ax.set_ylabel("f(x)")
        ax.legend()
        ax.grid()
        plt.show()
        plt.close(fig)


# Demo: expand f(x) = x^2 sin(x) on [-1, 1] and plot its 22-term truncation.
x = sp.Symbol("x")
f = sp.sin(x) * x**2
a, b = -1, 1

fourier_expansion = FourierSeriesExpansion(f, a=a, b=b)
fourier_expansion.plot_function_and_series(n_terms=22)
