## Solve a quadratic equation, $x^2 + b x + c = 0$.


Write a function that takes $b$ and $c$, the coefficients of the monic quadratic polynomial $x^2 + b x + c$, and returns the pair of its roots. Your function should always return two values, even if the quadratic has a double root.

For example, given the quadratic $x^2 - 2x + 1$, your function should return the pair $(1, 1)$. Of course, in floating point your answers may differ slightly from exactly 1.
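Rounding can even split a double root into two nearby values. A minimal sketch, assuming the plain textbook formula and the polynomial $(x - 0.1)^2 = x^2 - 0.2x + 0.01$ (chosen here only for illustration, because 0.1 has no exact binary representation): the computed discriminant comes out as a tiny positive number instead of zero, so roots should be compared with a tolerance such as `math.isclose` rather than with `==`. The helper `naive_roots` is hypothetical.

```python
import math

def naive_roots(b, c):
    """Textbook formula for x**2 + b*x + c = 0 (real roots only; illustrative helper)."""
    d = b * b - 4 * c
    s = math.sqrt(d)
    return (-b + s) / 2, (-b - s) / 2

# (x - 0.1)**2 = x**2 - 0.2*x + 0.01 has the double root 0.1
x1, x2 = naive_roots(-0.2, 0.01)
print(x1, x2)     # two slightly different values near 0.1
print(x1 == x2)   # exact comparison fails
print(math.isclose(x1, 0.1, rel_tol=1e-6), math.isclose(x2, 0.1, rel_tol=1e-6))
```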

Your function must also handle complex roots correctly (for this, you may need the `cmath` module from the standard library).
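The difference between the two modules shows up on a negative argument: `math.sqrt` raises `ValueError`, while `cmath.sqrt` returns a complex number. Note also that `cmath.sqrt` returns a `complex` even when the result is real, which is one reason to use `math.sqrt` when the discriminant is non-negative. The helper `real_sqrt_or_none` is hypothetical, written only for this illustration.

```python
import math
import cmath

def real_sqrt_or_none(x):
    """Return math.sqrt(x), or None when math.sqrt rejects a negative x."""
    try:
        return math.sqrt(x)
    except ValueError:
        return None

print(real_sqrt_or_none(-4.0))  # None: math.sqrt raises ValueError for x < 0
print(cmath.sqrt(-4.0))         # 2j
print(cmath.sqrt(4.0))          # (2+0j): cmath returns complex even for real results
```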

Test your function on several examples against a calculation by hand. Once you are sure that it works, try the five test cases below.
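One way to get examples whose answers you already know is to choose the roots first and build the coefficients from Vieta's formulas, $b = -(x_1 + x_2)$ and $c = x_1 x_2$. The helper `coefficients_from_roots` below is a hypothetical sketch for generating test inputs; feed the resulting `b` and `c` to your solver and compare its output (with a tolerance) against the roots you started from.

```python
def coefficients_from_roots(r1, r2):
    """Return (b, c) of the monic quadratic (x - r1)(x - r2) = x**2 + b*x + c.

    By Vieta's formulas, b = -(r1 + r2) and c = r1 * r2.
    """
    b = -(r1 + r2)
    c = r1 * r2
    # For a complex-conjugate pair the imaginary parts cancel; keep real coefficients
    if isinstance(b, complex):
        b = b.real
    if isinstance(c, complex):
        c = c.real
    return b, c

# Hand-checkable cases: distinct, double, and complex-conjugate roots
cases = [(-1.0, -3.0), (1.0, 1.0), (2 + 3j, 2 - 3j)]
for r1, r2 in cases:
    b, c = coefficients_from_roots(r1, r2)
    print(r1, r2, "->", b, c)
```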

Note that the last two test cases are special: they test whether your function handles extreme cases where the textbook formula is prone to catastrophic cancellation. Make sure your function passes all five tests.
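To see the cancellation these tests target, take $b = 10^{10}$, $c = 3$: in double precision $b^2 - 4c$ rounds to $b^2$, so $-b + \sqrt{b^2 - 4c}$ evaluates to exactly zero instead of about $-6 \times 10^{-10}$. A common remedy, sketched below under the assumption of real, distinct roots and $b \neq 0$, is to compute the larger-magnitude root $q = -\tfrac{1}{2}\left(b + \operatorname{sign}(b)\sqrt{b^2 - 4c}\right)$, which adds two numbers of the same sign, and then obtain the smaller root as $c/q$. The helper names are hypothetical.

```python
import math

def naive_small_root(b, c):
    """Textbook (-b + sqrt(b**2 - 4c)) / 2: subtracts nearly equal numbers when b >> 0."""
    return (-b + math.sqrt(b * b - 4 * c)) / 2

def stable_roots(b, c):
    """Real roots of x**2 + b*x + c = 0, assuming b**2 > 4c and b != 0.

    The larger-magnitude root q is formed by adding numbers of the same sign,
    so nothing cancels; the other root follows from Vieta's formula x1 * x2 = c.
    """
    q = -(b + math.copysign(math.sqrt(b * b - 4 * c), b)) / 2
    return q, c / q

b, c = 1e10, 3.0
print(naive_small_root(b, c))   # 0.0: all significant digits cancelled
print(stable_roots(b, c))       # (-10000000000.0, -3e-10)
```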

This exercise is graded; each test case contributes 20% of the grade.

In [171]:
import math
import cmath
import numpy as np

In [188]:
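def solve_quad(b, c):
    """Solve a quadratic equation, x**2 + bx + c = 0.
    
    Parameters
    ----------
    b, c : float
       Coefficients
       
    Returns
    -------
    x1, x2 : float or complex
       Roots.
    """
    if b == 0:
        # x**2 = -c: roots are +/- sqrt(-c), real if c <= 0, purely imaginary otherwise
        root = math.sqrt(-c) if c <= 0 else cmath.sqrt(-c)
        return root, -root
    if c == 0:
        # x * (x + b) = 0
        return 0.0, -b
    if abs(c / b) < 1e-16 * abs(b):
        # |c| << b**2: by Vieta's formulas the roots are -b and -c/b to machine precision
        return -b, -c / b
    discriminant = count_discriminant(b, c)  # defined in the next cell
    if discriminant < 0:
        # Complex-conjugate pair; the real part -b/2 involves no cancellation
        sqrt_d = cmath.sqrt(discriminant)
        return (-b + sqrt_d) / 2.0, (-b - sqrt_d) / 2.0
    if discriminant == 0:
        return -b / 2.0, -b / 2.0
    # Real, distinct roots: compute the larger-magnitude root q without subtracting
    # nearly equal numbers, then get the other root from Vieta's formula x1 * x2 = c
    sqrt_d = math.sqrt(discriminant)
    q = -(b + math.copysign(sqrt_d, b)) / 2.0
    # Return in the order ((-b + sqrt_d)/2, (-b - sqrt_d)/2)
    if b > 0:
        return c / q, q
    return q, c / q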

In [189]:
def count_discriminant(b, c):
    """Return the discriminant b**2 - 4c of x**2 + bx + c."""
    return b*b - 4*c

In [190]:
from numpy import allclose

In [191]:
variants = [{'b': 4.0, 'c': 3.0},
            {'b': 2.0, 'c': 1.0},
            {'b': 0.5, 'c': 4.0},
            {'b': 1e10, 'c': 3.0},
            {'b': -1e10, 'c': 4.0},]

In [192]:
for var in variants:
    x1, x2 = solve_quad(**var)
    print(x1, x2)
    print(allclose(x1*x2, var['c']))

-1.0 -3.0
True
-1.0 -1.0
True
(-0.25+1.984313483298443j) (-0.25-1.984313483298443j)
True
-10000000000.0 -3e-10
True
10000000000.0 4e-10
True
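The loop above checks only the product $x_1 x_2 = c$, which a wrong pair of roots can satisfy: for $b = 10^{10}$, $c = 3$, the pair $(-1, -3)$ has the right product. A slightly stronger check, sketched here as a hypothetical helper `vieta_ok`, also tests the sum $x_1 + x_2 = -b$; `cmath.isclose` accepts both real and complex arguments. The tolerance values are illustrative assumptions.

```python
import cmath

def vieta_ok(b, c, x1, x2, rtol=1e-9):
    """Check both Vieta relations for x**2 + b*x + c: x1 + x2 == -b and x1 * x2 == c.

    Tolerances are relative to the coefficient being compared;
    abs_tol guards the case where that coefficient is zero.
    """
    sum_ok = cmath.isclose(x1 + x2, -b, rel_tol=rtol, abs_tol=1e-12)
    prod_ok = cmath.isclose(x1 * x2, c, rel_tol=rtol, abs_tol=1e-12)
    return sum_ok and prod_ok

print(vieta_ok(4.0, 3.0, -1.0, -3.0))     # True
print(vieta_ok(1e10, 3.0, -1.0, -3.0))    # False: right product, wrong sum
print(vieta_ok(1e10, 3.0, -1e10, 0.0))    # False: the cancelled small root
```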
