# **<u>Part 1</u>:** Idiomatic Python: Writing Pythonic Code

### 
<img src="img/idiom.png" alt="Drawing" width="1200"/>

### 

### <u>**Idiomatic Python**</u> refers to constructs and coding practices that are unique
### to the Python programming language and often lack direct equivalents in other languages.

### 

### <u>**In other words**</u>: idiomatic Python embraces a "pythonic" style that demonstrates a deep
### understanding of the language's syntax, built-in modules, and conventions.

### 

### <u>Idiomatic Python emphasizes</u>:
### - **Readability**: Code that is easy to understand and maintain.
### - **Simplicity**: Clear and straightforward solutions.
### - **Efficiency**: Writing concise code without sacrificing clarity.

### 

### <u>Here is an inspirational quote</u>:

### _"Code is read much more often than it is written."_
###   $\qquad$ – Guido van Rossum (via PEP 8)

### 

### <u>Why Learn Idiomatic Python?</u> 
### Writing Pythonic code not only makes your work more elegant and maintainable but also
### helps it integrate seamlessly with other Python projects and libraries. 
### In this lecture, we'll explore how to write code that is not only functional but also 
### embodies the best practices of the Python community.
### 
<hr style="border:1px solid blue">

### 

## <u>Some general remarks about the lessons</u>:
### Try not to focus on the details too much. Try to see the bigger picture.
### If you see something that you think will help you, you can always
### come back to the notebook and try to understand it better!
### 
<hr style="border:1px solid blue">

### 

## <u>How to Navigate These Notebooks</u>:
### The notebooks consist of code cells that teach us Python programming concepts.
### Sometimes these cells can be a bit long.
### To help us navigate them, code snippets will typically be provided with tags.
### 
### The four tags are:
### - `[IMPORT]`: Imports relevant modules / packages such as `Numpy`. <u>Straightforward</u>.
### - `[HELPER]`: Defining utility functions. <u>Details can be safely ignored</u>.
### - `[SETUP]`: Defining variables. <u>Usually straightforward. If unclear, don't focus</u>.
### - `[FOCUS]`: As the name suggests. <u> Focussing on this part is essential</u>. 

### 

### I did my best not to repeat a tag within a cell, but some tags may still appear more than once.

### 

### <u>A basic example:</u>
### 

In [None]:
""" [IMPORT] """
import numpy as np
from matplotlib import pyplot as plt


""" [HELPER] """
def plot_xy(x, y):
    fig, ax = plt.subplots()
    ax.set_title("Plot x vs y")
    ax.plot(x, y)
    plt.show()


""" [SETUP] """
x = np.linspace(0, 2 * np.pi, 101)


""" [FOCUS] """
# the wrong way first
y_wrong = np.zeros(x.shape)
for i in range(len(x)):
    y_wrong[i] = np.sin(x[i])

# now the right way
y_right = np.sin(x)


""" [HELPER] """
plot_xy(x, y_right)

### 
### Sometimes the tags are accompanied by a short explanation. 
### This is done when the explanation is easier to understand than reading the code.
### 

In [None]:
""" [IMPORT] """
import numpy as np
from matplotlib import pyplot as plt


""" [HELPER] - func to plot `x` vs `y` """
def plot_xy(x, y):
    fig, ax = plt.subplots()
    ax.set_title("Plot x vs y")
    ax.plot(x, y)
    plt.show()


""" [SETUP] - create a linspace `x` from 0 to 2π """
x = np.linspace(0, 2 * np.pi, 101)


""" [FOCUS] """
# the wrong way first
y_wrong = np.zeros(x.shape)
for i in range(len(x)):
    y_wrong[i] = np.sin(x[i])

# now the right way
y_right = np.sin(x)


""" [HELPER] - plot the result """
plot_xy(x, y_right)

### 
### Besides helping us navigate the notebooks, the tags also ensure that yours truly, the lecturer,
### does not forget to mention important stuff ;-)
### 
### 
# <u>Lesson 1</u>: What makes code readable?

### 

### <u>Task</u>: take the union of two meshes.
### Given two sets of elements and their corresponding points,
### create a union mesh from the two input meshes without duplicate elements and points.

In [None]:
# PREPARATORY CELL FOR THE CONCRETE IMPLEMENTATIONS BELOW


""" [IMPORT] """
import numpy as np
from matplotlib import pyplot as plt


""" [HELPER] - function to plot one or several meshes """
def plot_meshes(list_of_elements, list_of_points):

    fig, ax = plt.subplots()
    for elems, points in zip(list_of_elements, list_of_points):
        ax.triplot(*points.T, elems, alpha=0.5)

    plt.show()


""" [SETUP] - create mesh defined by element indices `elems0` and point coordinates `points0` """
elems0 = np.array([ [0, 1, 4],
                    [4, 1, 5],
                    [1, 2, 5],
                    [5, 2, 6],
                    [2, 3, 6],
                    [6, 3, 7],
                    [4, 5, 8],
                    [8, 5, 9],
                    [5, 6, 9],
                    [9, 6, 10],
                    [6, 7, 10],
                    [10, 7, 11] ])

points0 = np.stack(list(map(np.ravel, np.meshgrid( np.linspace(0, 3, 4),
                                                   np.linspace(0, 2, 3) ))), axis=1)


""" [SETUP] - create second mesh `(elems1, points1)` by shifting `points0` +2 in the x-direction """
elems1 = elems0
points1 = points0 + np.array([[2, 0]])


""" [HELPER] - plot both meshes """
plot_meshes([elems0, elems1], [points0, points1])

### 
<hr style="border:1px solid blue">

### In the following, we look at a few implementations that represent different levels of Python expertise.

### 

### 1. The absolute bloody beginner solution

#### (seeing it breaks my heart)

In [None]:
""" [SETUP] - define utility containers / variables """
elements = [elems0, elems1]
points = [points0, points1]

map_point_index = {}
index = 0
new_elements = []
new_points = []
seen = set()


""" [FOCUS] - on the structure, not the details """
for i in range(len(elements)):
    myelems = elements[i]
    mypoints = points[i]
    my_new_elems = []
    for j in range(len(myelems)):
        my_new_elem = []
        myelement = myelems[j]
        for k in range(len(myelement)):
            myindex = myelement[k]
            mypoint = tuple(mypoints[myindex])
            if (mypoint in map_point_index) == False:
                map_point_index[mypoint] = index
                index = index + 1
                new_points.append(mypoint)
            my_new_elem.append(map_point_index[mypoint])
        my_identifier = tuple(sorted(my_new_elem))
        if (my_identifier in seen) == False:
            my_new_elems.append(my_new_elem)
            seen.add(my_identifier)
    new_elements.append(my_new_elems)


""" [SETUP] - convert to `numpy` arrays for plotting """
new_elements = np.concatenate(new_elements)
new_points = np.array(new_points)


""" [HELPER] - plot the mesh union, print number of cells and points """
plot_meshes([new_elements], [new_points])


print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elements))

### 

### 2. The somewhat more idiomatic solution

#### (still not good code)

In [None]:
""" [SETUP] - define utility containers / variables """
elements = [elems0, elems1]
points = [points0, points1]

map_point_index = {}
index = 0
new_elements = []
new_points = []
seen = set()


""" [FOCUS] """
for myelems, mypoints in zip(elements, points):
    my_new_elems = []
    
    for myelement in myelems:
        my_new_elem = []
        
        for myindex in myelement:
            mypoint = tuple(mypoints[myindex])
            
            if mypoint not in map_point_index:
                map_point_index[mypoint] = index
                index += 1
                new_points.append(mypoint)
                
            my_new_elem.append(map_point_index[mypoint])
            
        my_identifier = tuple(sorted(my_new_elem))
        
        if my_identifier not in seen:  # reads like plain English
            my_new_elems.append(my_new_elem)
            seen.add(my_identifier)
            
    new_elements.append(my_new_elems)


""" [SETUP] - convert to `numpy` arrays for plotting """
new_elements = np.concatenate(new_elements)
new_points = np.array(new_points)


""" [HELPER] - plot and print """
plot_meshes([new_elements], [new_points])


print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elements))

### 
<hr style="border:1px solid blue">

## <u>Exercise 1.1</u>:
### What does implementation 2 do better than 1? What do you still find difficult to read?

<hr style="border:1px solid blue">

### 

### 3. A good solution that does not use `Numpy` in `[FOCUS]`

In [None]:
""" [IMPORT] - import additional functionality from Python built-in modules """
from itertools import count
from collections import defaultdict


""" [SETUP] - define utility containers """
elements = [elems0, elems1]
points = [points0, points1]

map_point_index = defaultdict(count().__next__)  # This needs further explanation, I guess ...
seen = set()
new_elems = []


""" [FOCUS] """
for myelems, mypoints in zip(elements, points):
    for elem in myelems:
        new_elem = [map_point_index[point] for point in map(tuple, mypoints[elem])]
        # [0, 2, 1] and [2, 1, 0] are the same element, sort to avoid counting twice.
        if ( sorted_elem := tuple(sorted(new_elem)) ) not in seen:
            new_elems.append(new_elem)
            seen.add(sorted_elem)


""" [SETUP] - convert to `numpy` arrays for plotting """
new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))


""" [HELPER] - plot and print """
plot_meshes([new_elems], [new_points])


print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
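#### A short sketch of the `defaultdict(count().__next__)` trick used in impl. 3: `count()` yields 0, 1, 2, ... and its bound
#### `__next__` method is a zero-argument callable, which is exactly what `defaultdict` expects as its default factory.
#### The point coordinates below are made up for illustration.

```python
from collections import defaultdict
from itertools import count

# Looking up an unseen key calls the factory once and stores the result,
# so every new key receives the next free integer index.
map_point_index = defaultdict(count().__next__)

first = map_point_index[(0.0, 0.0)]   # new key      -> 0
second = map_point_index[(1.0, 0.0)]  # new key      -> 1
again = map_point_index[(0.0, 0.0)]   # existing key -> 0, counter not advanced
print(first, second, again)           # 0 1 0

# The same behaviour written out by hand (roughly what impl. 2 does).
explicit_index = {}
for point in [(0.0, 0.0), (1.0, 0.0), (0.0, 0.0)]:
    if point not in explicit_index:
        explicit_index[point] = len(explicit_index)

print(dict(map_point_index) == explicit_index)  # True
```

#### Because dicts preserve insertion order (Python 3.7+), `map_point_index.keys()` lists the points in index order,
#### which is what impl. 3 relies on when it builds `new_points`.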

### 
<hr style="border:1px solid blue">

## <u>Exercise 1.2</u>:
### Besides fewer lines of code and the use of built-in module functionality, what is better about the **structure** of impl. 3 compared to 1 and 2?

<hr style="border:1px solid blue">

### 

### 4. A fancy but somewhat overengineered solution 
### (use it to impress your friends)

In [None]:
""" [IMPORT] - import additional functionality from Python built-in modules """
from itertools import count
from collections import defaultdict


""" [SETUP] - define utility containers """
elements = [elems0, elems1]
points = [points0, points1]

map_point_index = defaultdict(count().__next__)
seen = set()
new_elems = []


""" [FOCUS] - shorter doesn't automatically mean better """
for myelems, mypoints in zip(elements, points):
    for elem in myelems:
        new_elem = [map_point_index[point] for point in map(tuple, mypoints[elem])]
        (sorted_elem := tuple(sorted(new_elem))) in seen or new_elems.append(new_elem) or seen.add(sorted_elem)


""" [SETUP] - convert to `numpy` arrays for plotting """
new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))


""" [HELPER] - plot and print """
plot_meshes([new_elems], [new_points])


print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
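#### Why the one-liner in impl. 4 works: `or` evaluates left to right and stops at the first truthy operand,
#### while both `list.append` and `set.add` return `None` (falsy). A minimal sketch with made-up string items:

```python
seen = set()
kept = []

for item in ["a", "b", "a", "c", "b"]:
    # True  -> `or` short-circuits and nothing else runs (duplicate skipped)
    # False -> append runs and returns None (falsy), so `or` moves on to seen.add
    item in seen or kept.append(item) or seen.add(item)

print(kept)  # ['a', 'b', 'c']

# The explicit version used by impl. 3, for comparison
kept_explicit = []
seen_explicit = set()
for item in ["a", "b", "a", "c", "b"]:
    if item not in seen_explicit:
        kept_explicit.append(item)
        seen_explicit.add(item)
```

#### The explicit `if` does not require the reader to know what `append` and `add` return, which is one argument for preferring it.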

### 
### 5. The `Numpy` solution 
#### (a little less memory efficient, but completely avoids indentation; comments help a lot here without hindering readability)

In [None]:
""" [IMPORT] """
from itertools import count


""" [FOCUS] - the body of the function uses Numpy functionality """
# get all unique points of the two sets of points
new_points = np.unique(np.concatenate([points0, points1]), axis=0)

# map each unique point to an index
map_point_index = dict(zip(map(tuple, new_points), count()))

# map both meshes' elements' points to the new index
mapped_elems = np.apply_along_axis(lambda x: map_point_index[tuple(x)],
                                   axis=-1,
                                   arr=np.concatenate([points0[elems0], points1[elems1]]))

# find the indices of the first occurrence of the transformed elements that are unique
_, unique_indices = np.unique(np.sort(mapped_elems, axis=1), return_index=True, axis=0)

# keep only the unique occurrences
new_elems = mapped_elems[unique_indices]


""" [HELPER] - plot and print """
plot_meshes([new_elems], [new_points])


print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
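#### The deduplication step of impl. 5 in isolation, on three made-up triangles: sorting each row makes `[0, 2, 1]` and `[2, 1, 0]`
#### identical, and `return_index=True` gives the position of each unique row's first occurrence. Sorting those indices
#### (an extra step that impl. 5 does not take) keeps the surviving elements in their original order.

```python
import numpy as np

triangles = np.array([[0, 2, 1],
                      [3, 4, 5],
                      [2, 1, 0]])   # same element as row 0, different vertex order

# canonical form of every element: its vertex indices in sorted order
canonical = np.sort(triangles, axis=1)

# position of the first occurrence of each unique canonical row
_, first_occurrence = np.unique(canonical, axis=0, return_index=True)

# sort the positions so the kept elements appear in their original order
unique_triangles = triangles[np.sort(first_occurrence)]
print(unique_triangles)
```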

### 
<hr style="border:1px solid blue">

## <u>Exercise 1.3</u>:
### Which of these five is the "best" implementation?
### **HINT No. 1**: It's neither 1. nor 2. ...
### **HINT No. 2**: The answer to this question, to some extent, boils down to your programming "philosophy".

<hr style="border:1px solid blue">

### 
## What makes Python code readable?

### What we know so far:

* Using idioms (`zip`, `enumerate`, `itertools`, `collections`, comprehensions, `:=`) to reduce boilerplate.
* Fewer (but not too few) lines of code.
* Indentation that only moves inward (=> => =>) instead of repeatedly moving in and back out (=> <= => <=).
* Descriptive variable names.
* A program flow that almost reads like English.
* Comments, and library functionality (for instance `Numpy`) where possible.

### 
<hr style="border:1px solid blue">

### 
## <u>Some open questions</u>

### 
### <u>Open question</u>: which one is more pythonic?

In [1]:
""" [HELPER] - Fake first-order derivative function """
def _first_order_derivative(func):
    # return derivative
    pass
    

""" [FOCUS] - version 0 or 1 ? """
def nth_derivative_v0(func, n=1):
    n = int(n)
    assert n >= 0
    if n == 0:
        return func
    else:
        return nth_derivative_v0(_first_order_derivative(func), n=n-1)
    

def nth_derivative_v1(func, n=1):
    assert (n := int(n)) >= 0
    if n == 0:
        return func
    return nth_derivative_v1(_first_order_derivative(func), n=n-1)
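#### One point worth weighing for `v1`: `assert` statements are removed when Python runs with `-O`, and with them any assignment
#### made via `:=` inside the assert. In `v0`, `n = int(n)` is a separate statement, so only the check disappears.
#### The sketch below imitates `python -O` via the `optimize` argument of the built-in `compile`; `to_order` is a made-up name.

```python
source = """
def to_order(n):
    assert (n := int(n)) >= 0
    return n
"""

# optimize=0: asserts are kept (plain `python`)
debug_namespace = {}
exec(compile(source, "<v1-style>", "exec", optimize=0), debug_namespace)

# optimize=1: asserts are stripped (what `python -O` does)
optimized_namespace = {}
exec(compile(source, "<v1-style>", "exec", optimize=1), optimized_namespace)

print(repr(debug_namespace["to_order"]("3")))      # 3   -> converted to int
print(repr(optimized_namespace["to_order"]("3")))  # '3' -> conversion skipped
```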

### 

### <u>Open question</u>: which meshgrid call signature do you find more readable ?

In [2]:
""" [IMPORT] """
import numpy as np


""" [HELPER] - Two methods that make a meshgrid """
def make_meshgrid0(dims):
    return np.stack(np.meshgrid(*map(np.arange, dims)), axis=-1).reshape(-1, len(dims))

def make_meshgrid1(*dims):
    return np.stack(np.meshgrid(*map(np.arange, dims)), axis=-1).reshape(-1, len(dims))


""" [FOCUS] - a or b ? """
a = make_meshgrid0([3, 2, 2])
b = make_meshgrid1(3, 2, 2)


""" [HELPER] - ensure a and b are really the same """
print('a equals b: ', (a == b).all())


# both generalise to arbitrary dimensions
a = make_meshgrid0([3, 2])
b = make_meshgrid1(3, 2, 2, 4)

a equals b:  True


### 
### <u>Open question</u>: which one do you think is more pythonic?

In [3]:
# Find the greatest divisor of an integer (excluding self) and return it.
# Also, print all other (smaller) divisors, if any.


""" [FOCUS] - two versions for greatest divisor excluding `val` itself """
def greatest_divisor_v0(val):
    *other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]
    if other_divisors:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor


def greatest_divisor_v1(val):
    divisors = [i for i in range(1, val) if val % i == 0]
    other_divisors, greatest_divisor = divisors[:-1], divisors[-1]
    if len(other_divisors) > 0:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor


""" [HELPER] - ensure same output """
print('Greatest divisor of 10: ', greatest_divisor_v0(10), '\n')
print('Greatest divisor of 10: ', greatest_divisor_v1(10))

For 10, I also found the divisors [1, 2].
Greatest divisor of 10:  5 

For 10, I also found the divisors [1, 2].
Greatest divisor of 10:  5
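#### Both versions share an edge case: for `val = 1` the list of divisors is empty. Starred unpacking then raises a `ValueError`,
#### while `divisors[-1]` raises an `IndexError`. A minimal check, using a hypothetical helper `divisors_below`:

```python
def divisors_below(val):
    # hypothetical helper: all divisors of `val` smaller than `val`
    return [i for i in range(1, val) if val % i == 0]


# v0-style: the starred assignment needs at least one element for `greatest`
try:
    *other_divisors, greatest = divisors_below(1)
except ValueError as err:
    v0_error = err

# v1-style: indexing an empty list
try:
    greatest = divisors_below(1)[-1]
except IndexError as err:
    v1_error = err

print(type(v0_error).__name__, "|", type(v1_error).__name__)
```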


### 
## What makes code more pythonic?
## (usually, that means more readable)

### <u>Some additional insights</u>:

* Early handling of special cases, avoiding if-else indentation when redundant (example 1).
* Avoiding nested parentheses: `f(a, b, c)` instead of `f([a, b, c])` (example 2).
* Using star syntax `*args` where possible (examples 2 + 3).
* `if object` rather than `if (object has property)`. For instance `if other_divisors` instead of `if len(other_divisors) > 0` (example 3).
### 

<hr style="border:1px solid blue">

### 
### Now that we have an idea of what makes code pythonic, let's move on to the next lesson.

# <u>Lesson 2</u>: 
## `*args, **kwargs`, star syntax variable unpacking

In [None]:
%reset -f

### 
### A common mistake (even among experienced programmers!) is clumsy argument forwarding:
### re-declaring every parameter of the wrapped function, including its defaults. Here is an example:
### 

In [None]:
""" [IMPORT] """
import numpy as np
from scipy.optimize import root_scalar


""" [FOCUS] """
def find_root_of_quadratic_func(a, b, c, args=(),
                                         method=None,
                                         bracket=None,
                                         fprime=None,
                                         fprime2=None,
                                         x0=0,
                                         x1=None,
                                         xtol=None,
                                         rtol=None,
                                         maxiter=None,
                                         options=None):
    
    func = lambda x: a * x**2 + b * x + c
    
    return root_scalar(func, args=args,
                             method=method,
                             bracket=bracket,
                             fprime=fprime,
                             fprime2=fprime2,
                             x0=x0,
                             x1=x1,
                             xtol=xtol,
                             rtol=rtol,
                             maxiter=maxiter,
                             options=options)


""" [HELPER] - find root of f(x) = x^2 - 2x + 1 """
print("The root of f(x) = x^2 - 2x + 1 equals x = {0:.6f}".format(find_root_of_quadratic_func(1, -2, 1, maxiter=20).root))

### 
### This lesson will teach us how to avoid making mistakes like the above and utilize idiomatic
### Python **star syntax** for writing shorter and more maintainable code.
### 
### What does this dummy function do?

In [4]:
""" [HELPER] - f(*args, **kwargs), print whatever `args` and `kwargs` are """

def dummy_function(*args, **kwargs):
    print('Received the following args and kwargs: \n')
    print('args: ', args, '\n')
    print('kwargs: ', kwargs, '\n')

### 
### 1. We pass only positional arguments

In [5]:
""" [FOCUS] - pass pos. only """

dummy_function(1, 2, 'a', [1, 2, 3])

Received the following args and kwargs: 

args:  (1, 2, 'a', [1, 2, 3]) 

kwargs:  {} 



### 
### 2. Only keyword arguments

In [6]:
""" [FOCUS] - keyword only """

dummy_function(a=1, b=2, c='a', d=[1, 2, 3])

Received the following args and kwargs: 

args:  () 

kwargs:  {'a': 1, 'b': 2, 'c': 'a', 'd': [1, 2, 3]} 



### 
### 3. Mixed

In [7]:
""" [FOCUS] - mixed pos. and keyword """

dummy_function(1, 'a', (5, 6), a=5, b='a', c=[5, 6])

Received the following args and kwargs: 

args:  (1, 'a', (5, 6)) 

kwargs:  {'a': 5, 'b': 'a', 'c': [5, 6]} 



### 
### <u>Observation</u>:
### For a function of the form `f(*args, **kwargs)`:
### 1. inside the function, all passed positional arguments are available as a tuple called `args`;
### 2. all passed keyword arguments are available as a dictionary `kwargs` of the form `{str(keyword): passed_value}`.
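#### The names `args` and `kwargs` are only a convention: the `*` and `**` do the work. Both can also be combined with ordinary
#### parameters, and anything declared after `*rest` can only be passed by keyword. A sketch with made-up names:

```python
def describe(first, *rest, sep=", ", **options):
    # first   : an ordinary positional parameter
    # rest    : the remaining positional arguments, collected in a tuple
    # sep     : keyword-only, because it is declared after *rest
    # options : all other keyword arguments, collected in a dict
    parts = [str(first)] + [str(item) for item in rest]
    suffix = f" {options}" if options else ""
    return sep.join(parts) + suffix


print(describe(1, 2, 3))               # 1, 2, 3
print(describe(1, 2, 3, sep=" | "))    # 1 | 2 | 3
print(describe("x", color="red"))      # x {'color': 'red'}
```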

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 2.1</u>: 
### Write a function that takes any number of positional arguments (numbers) and adds them.
### For example:  `f(1, 2) = 3`; `f(4, 5, 6) = 15`; `f(1, 1, 2, 3) = 7`
### By convention: `f() = 0`

In [None]:
""" [FOCUS] - your code here """
# def f( ? ? ? ):
    # your code here
    # return  ? ? ?


""" [HELPER] - verify result """
print(f())
print( f(1, 2) )
print( f(4, 5, 6) )
print( f(1, 1, 2, 3) )

### solution:

In [None]:
""" [FOCUS] """
def f(*args):
    return sum(args)


""" [HELPER] """
print(f())
print( f(1, 2) )
print( f(4, 5, 6) )
print( f(1, 1, 2, 3) )

### 
<hr style="border:1px solid blue">

### 
## <u> Exercise 2.2 </u>: 
### Write a function that takes only keyword arguments of the form
### `university_name = (grade0, grade1, grade2, ...)`
### and prints the average grade of each university, in alphabetical order.
### Swiss grading scheme: 1 (worst) to 6 (best).

### For instance:
### `average_grade(Delft=(3, 5, 3, 5), ETHZ=(5, 5, 2, 2, 4), EPFL=(4, 5, 3), Pavia=(6, 5, 6, 5, 5))`

In [None]:
""" [FOCUS] """
# def average_grade(? ? ?):
#     your code here
#     return ???


""" [HELPER] - plot average """
average_grade(Delft=(3, 5, 3, 5),
              ETHZ=(5, 5, 2, 2, 4),
              EPFL=(4, 5, 3),
              Pavia=(6, 5, 6, 5, 5))

### 
### Solution:

In [None]:
""" [FOCUS] """
def average_grade(**kwargs):
    map_univ_average = {univ: sum(grades) / len(grades) for univ, grades in kwargs.items()}
    for univ, average in sorted(map_univ_average.items(), key=lambda x: x[0]):  # sort by university name
        print(f"University name: {univ}, average grade: {average}")

""" [HELPER] - plot average """
average_grade(Delft=(3, 5, 3, 5),
              ETHZ=(5, 5, 2, 2, 4),
              EPFL=(4, 5, 3),
              Pavia=(6, 5, 6, 5, 5))

### 
<hr style="border:1px solid blue">

### 
### The `*args, **kwargs` syntax can also be used in the opposite way.
### Observe the following:

In [None]:
""" [HELPER] """
def compute_area(length, width, unit='meters'):
    area = length * width
    return f"The area is {area} square {unit}."


""" [SETUP] """
kwargs = {'unit': 'kilometers'}


""" [FOCUS] """
compute_area(2, 5, **kwargs)

### 
### We can drop a `{str(key): value}` dictionary into a function call via `**` (double-star) unpacking.
#### This works as long as every key matches one of the function's parameter names (or the function itself accepts `**kwargs`) and all required arguments end up being supplied.

#### `test = f(a=5, b=2)` is equivalent to `kwargs = {'a': 5, 'b': 2}` and then `test = f(**kwargs)`
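#### A small sketch of both sides of that rule, redefining `compute_area` from above so the cell runs on its own (the dictionaries
#### are made up): a dict may supply every argument, but a key that matches no parameter raises a `TypeError`.

```python
def compute_area(length, width, unit='meters'):
    area = length * width
    return f"The area is {area} square {unit}."


# every argument supplied by keyword via ** unpacking
all_params = {'length': 2, 'width': 5, 'unit': 'kilometers'}
print(compute_area(**all_params))        # The area is 10 square kilometers.

# positional arguments and ** unpacking can be mixed
print(compute_area(3, **{'width': 4}))   # The area is 12 square meters.

# a key without a matching parameter is rejected
try:
    compute_area(2, 5, **{'units': 'kilometers'})  # typo: 'units'
except TypeError as err:
    print("TypeError:", err)
```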

### 
<hr style="border:1px solid blue">

### 
## <u> Application: keyword argument forwarding</u>

![scipykwargs](img/scipykwargs.png)

### 
## <u>Exercise 2.3</u>:
### Write a function that uses `scipy.optimize.minimize` to find the argmin of $a x^2 + b x + c$.
### The function must be able to forward relevant keyword arguments to `scipy.optimize.minimize`.
### Find the argmin of the quadratic function using `method='SLSQP'`.

### 
### <u>The wrong way</u>:
#### (doing this will result in capital punishment)

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize


""" [FOCUS] - uuuuuuugh, lots of boilerplate >,< """
def minimize_quadratic_function(a, b, c, args=(), 
                                         method=None, 
                                         jac=None,
                                         hess=None,
                                         hessp=None,
                                         bounds=None,
                                         constraints=(),
                                         tol=None,
                                         callback=None,
                                         options=None):
    fun = lambda x: a * x ** 2 + b * x + c
    x0 = 0  # initial guess
    return minimize(fun, x0, args=args,
                             method=method,
                             jac=jac,
                             hess=hess,
                             hessp=hessp,
                             bounds=bounds,
                             constraints=constraints,
                             tol=tol,
                             callback=callback,
                             options=options).x[0]


""" [HELPER] verify that the func finds the correct minimizer """
# min of f(x) = 0.5 x^2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(.5, 1, 1, method='SLSQP'))

### 
### Implement it correctly:

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize


""" [FOCUS] """
# def minimize_quadratic_function(a=1, b=2, c=3, ????):
#     fun = lambda x: a * x ** 2 + b * x + c
#     x0 = 0
#     return ???


""" [HELPER] verify that the func finds the correct minimizer """
# min of f(x) = 0.5 x^2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(a=.5, b=1, c=1, method='SLSQP'))

### 
### solution:

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize


""" [FOCUS] """
def minimize_quadratic_function(a=1, b=2, c=3, **scipykwargs):
    fun = lambda x: a * x ** 2 + b * x + c
    x0 = 0  # initial guess
    return minimize(fun, x0, **scipykwargs).x[0]


""" [HELPER] verify that the func. finds the correct minimizer """
# min of f(x) = 0.5 x^2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(a=.5, b=1, c=1, method='SLSQP'))    
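#### The same forwarding pattern works with any callable. A minimal, SciPy-free sketch (the wrapper name `sorted_words` is made up):
#### the wrapper handles its own argument and hands everything else to the built-in `sorted` unchanged, so it never has to
#### repeat `sorted`'s parameters or their defaults.

```python
def sorted_words(text, **sorted_kwargs):
    # split the text ourselves, forward every other option to `sorted`
    return sorted(text.split(), **sorted_kwargs)


print(sorted_words("banana apple Cherry"))                 # ['Cherry', 'apple', 'banana']
print(sorted_words("banana apple Cherry", key=str.lower))  # ['apple', 'banana', 'Cherry']
print(sorted_words("banana apple Cherry", reverse=True))   # ['banana', 'apple', 'Cherry']
```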

### 
<hr style="border:1px solid blue">

### 
### Observe the following:

In [8]:
""" [HELPER] - Roots of f(x) = ax^2 + bx + c """
def abc_formula(a, b, c):
    sqrt_discriminant = (b ** 2 - 4 * a * c) ** .5  # complex if b^2 < 4ac
    return (-b - sqrt_discriminant) / (2 * a), (-b + sqrt_discriminant) / (2 * a)


""" [FOCUS] - same function but with a = 1 """
def abc_monic_wrong(b, c):
    return abc_formula(1, b, c)


def abc_monic_right(*args):
    # abc_monic_right(2, 3) => args = (2, 3)
    
    # we drop args back into `abc_formula`
    # abc_formula(1, *args) is then the same as abc_formula(1, 2, 3)
    
    return abc_formula(1, *args)


""" [HELPER] - verify """
print('The roots of f = x^2 + 3 * x + 1 read: {:.7g} and {:.7g}'.format(*abc_monic_right(3, 1)))

The roots of f = x^2 + 3 * x + 1 read: -2.618034 and -0.381966


### 
#### `test = f(1, 2, 3)` $\Longleftrightarrow$ `args = (1, 2, 3)` and then `test = f(*args)`

#### Intuition: `*` unpacking *removes the outer brackets / parentheses*.
#### For tuple inputs: `f(*(1, 2, 3))` $\Longleftrightarrow$ `f(1, 2, 3)`
#### same for lists `f(*[1, 2, 3])` $\Longleftrightarrow$ `f(1, 2, 3)`
#### (and the same for any object that can be **iterated** over)
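#### A quick check of the last point, using a hypothetical helper `as_args` that simply returns what it receives: strings, ranges,
#### generators and dictionaries (which iterate over their keys) can all be unpacked with `*`.

```python
def as_args(*args):
    # hypothetical helper: return the collected positional arguments unchanged
    return args


print(as_args(*"abc"))                       # ('a', 'b', 'c')
print(as_args(*range(3)))                    # (0, 1, 2)
print(as_args(*(i ** 2 for i in range(4))))  # (0, 1, 4, 9)
print(as_args(*{'x': 1, 'y': 2}))            # ('x', 'y')  -> keys only
print(as_args(*[1, 2], *(3,), 4))            # (1, 2, 3, 4)
```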

### 
<hr style="border:1px solid blue">

### 
## <u> Exercise 2.4 </u> :
### Given a function `f(*args)` that sums its positional arguments,
### write a function `g(head, tail)` that accepts two tuples / lists of numbers and sums all of them.
### `g([1, 2, 3], (4, 5, 6)) => f(1, 2, 3, 4, 5, 6) = 21` (note that one is a list and one is a tuple)

In [None]:
""" [HELPER] - f(i0, i1, ...) returns i0 + i1 + ... """
def f(*args):
    return sum(args)


""" [FOCUS] - your code here """
#def g(???):
#    ???


""" [HELPER] - verify result """
print( g([1, 2, 3], (4, 5, 6)) )

### solution:

In [None]:
""" [HELPER] """
def f(*args):
    return sum(args)
    

""" [FOCUS] """
def g(head, tail):
    # head = [1, 2, 3]
    # tail = (4, 5, 6)
    # => f(*head, *tail) = f(*[1, 2, 3], *(4, 5, 6)) = f(1, 2, 3, 4, 5, 6)
    return f(*head, *tail)


""" [HELPER] - verify """
print(g([1, 2, 3], (4, 5, 6)))

### 
<hr style="border:1px solid blue">

### 
### Having gained an intuition for `*args, **kwargs`, we are now in a position to learn **star syntax variable unpacking**.

### What does this do?

In [None]:
""" [SETUP] """
tail = (4, 5, 6)


""" [FOCUS] """
joined_tuple = (1, 2, 3, *tail)


""" [HELPER] - print """
print(joined_tuple)

### 
### $\implies$ We can drop the contents of tuples / lists into other tuples / lists via `*` unpacking.
### 
### We may even completely avoid the parentheses. 
### If so, `Python` will automatically infer them and create a `tuple`.
### 

In [None]:
""" [SETUP] """
tail = 4, 5, 6    # round brackets inferred


""" [HELPER] - print tail """
print("tail: ", tail)


""" [FOCUS] """
joined_tuple = 1, 2, 3, *tail    # round brackets inferred


""" [HELPER] """
print("joined tuple: ", joined_tuple)

###
### It is important to know when you can omit the parentheses.
### If you can omit them, you typically should (it makes your code more readable).
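#### A few cases where the parentheses cannot be dropped, sketched with a made-up helper `count_args`: the empty tuple, a tuple
#### passed as a single function argument, and tuples inside a list. Note also that the comma, not the parentheses, creates a tuple.

```python
def count_args(*args):
    # hypothetical helper: number of positional arguments received
    return len(args)


empty = ()          # the empty tuple always needs ()
single = 1,         # a trailing comma makes a one-element tuple
not_a_tuple = (1)   # just the integer 1 in parentheses

print(type(single).__name__, type(not_a_tuple).__name__)  # tuple int

print(count_args(1, 2))     # 2 -> two separate arguments
print(count_args((1, 2)))   # 1 -> one tuple argument, () required
print([(1, 2), (3, 4)])     # a list of two tuples, () required
```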
### 
<hr style="border:1px solid blue">

### 
### <u>Further examples</u>:

In [None]:
""" [FOCUS] """


# multiple list / tuple unpacking
print( [1, 2, *[3, 4, 5], 6, 7, *(8, 9, 10)] )


# we can also unpack ranges into tuples / lists
print( (*[1, 2, 3], *range(4, 11)) )

### 
## <u>Exercise 2.5</u>:
### Looking at the above, how do you think you can merge two disjoint dictionaries?
### <u>Template</u>:

In [None]:
""" [SETUP] - create two disjoint dicts """
dict0 = {'a': 0, 'b': 1}
dict1 = {'c': 2, 'd': 3}


""" [FOCUS] """
merged_dict = ...  # your code here (replace the Ellipsis)


""" [HELPER] - validate """
print(f"The merged dict is given by {merged_dict}.\n")

### <u>Solution</u>:

In [None]:
""" [STEUP] - create two disjoint dicts """
dict0 = {'a': 0, 'b': 1}
dict1 = {'c': 2, 'd': 3}


""" [FOCUS] """
merged_dict = {**dict0, **dict1}


""" [HELPER] - validate """
print(f"The merged dict is given by {merged_dict}.\n")

### 
<hr style="border:1px solid blue">

### 
### What do you think this does?

### `(a, *tail) = [1, 2, 3, 4]`
### What is `a`? What is `tail`?
### 
### <u>the answer</u>:

In [None]:
""" [SETUP] """
(a, *tail) = [1, 2, 3, 4]


""" [HELPER] - validate """
print(a)
print(tail)

### 
### As before, we can forgo the outer parentheses (on the left- and / or right-hand side) when no confusion is possible.

In [None]:
""" [SETUP] """
a, *tail = 1, 2, 3, 4


""" [HELPER] - validate """
print(a)
print(tail)
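#### A few more properties of starred assignment (values made up): the starred name may sit anywhere, it always becomes a `list`
#### (even when empty or when the right-hand side is a tuple or string), and only one starred name is allowed per level.

```python
first, *middle, last = range(5)
print(first, middle, last)   # 0 [1, 2, 3] 4

head, *rest = (42,)
print(head, rest)            # 42 []

*init, final = "abc"
print(init, final)           # ['a', 'b'] c

# two starred names on the same level are a SyntaxError
try:
    compile("a, *b, *c = range(5)", "<demo>", "exec")
    two_stars_allowed = True
except SyntaxError as err:
    two_stars_allowed = False
    print("SyntaxError:", err.msg)
```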

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 2.6</u>
### Given a `tuple` of tuples of the form
### `tot = ((a, b, c), (d, e, f))`
### or more compactly:
### `tot = (a, b, c), (d, e, f)`, 
### unpack `a, d, e, f` into variables of the same name and create a list `bc` containing `b` and `c`
### **IN ONE LINE OF CODE**
### 
### The wrong solution (don't do it like this)

In [None]:
""" [SETUP] - define tuple of tuples """
tot = (1, 2, 3), (4, 5, 6)


""" [FOCUS] - the wrong way """
a, d, e, f = tot[0][0], tot[1][0], tot[1][1], tot[1][2]
bc = [tot[0][1], tot[0][2]]


""" [HELPER] - validate """
print(a, bc, d, e, f)

### 
### Implement it correctly:

In [None]:
""" [SETUP] """
tot = (1, 2, 3), (4, 5, 6)


""" [FOCUS] - do it right """
### your code here


""" [HELPER] - validate """
print(a, bc, d, e, f)

### 
### solution:

In [None]:
""" [SETUP] """
tot = (1, 2, 3), (4, 5, 6)


""" [FOCUS] - the right way """
(a, *bc), (d, e, f) = tot


""" [HELPER] - validate """
print(a, bc, d, e, f)
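#### The same nested patterns also work as loop and comprehension targets. A sketch with made-up `(point, label)` pairs:

```python
labelled_points = [((0, 0), "origin"), ((1, 2), "p"), ((3, 4), "q")]

# the loop target mirrors the structure of each item: ((x, y), label)
for (x, y), label in labelled_points:
    print(f"{label}: x={x}, y={y}")

# the same pattern inside a comprehension
coordinate_sums = [x + y for (x, y), _ in labelled_points]
print(coordinate_sums)   # [0, 3, 7]

# combined with enumerate and a starred target
rows = [(1, 2, 3), (4, 5, 6)]
heads_and_tails = [(i, first, others) for i, (first, *others) in enumerate(rows)]
print(heads_and_tails)   # [(0, 1, [2, 3]), (1, 4, [5, 6])]
```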

### 
<hr style="border:1px solid blue">

### 
### The same kind of unpacking works for `Numpy` arrays: iterating over an array yields its rows (the sub-arrays along the first axis).

In [None]:
""" [IMPORT] """
import numpy as np


""" [SETUP] - create array of shape A.shape == (2, 3) """
A = np.arange(6).reshape(2, 3)


""" [HELPER] - print array """
print('A: \n\n', A, '\n\n')


""" [FOCUS] - unpack A """

# round brackets, square brackets, both are fine
[a0, a1, a2], (b0, b1, b2) = A 

print(f"(a0, a1, a2, b0, b1, b2): ({a0}, {a1}, {a2}, {b0}, {b1}, {b2}). \n")

### 
<hr style="border:1px solid blue">

### 

## <u>Application</u>
### Given a ($2$D) triangular element in $\mathbb{R}^3$ characterised by its vertices `A = [a, b, c]` (`a`, `b`, `c` row vectors, counterclockwise),
### compute the surface area of the triangle.

### 

### <u>the wrong solution</u>:

In [None]:
""" [IMPORT] """
import numpy as np


""" [SETUP] - create three points as matrix rows """
A = np.array([
              [0, 0, 0],    # a
              [1, 0, 1],    # b
              [0, 1, 0],    # c
             ])


""" [FOCUS] """
def triangle_surface_area(A: np.ndarray) -> float:
    assert A.shape == (3, 3)
    
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    a = A[0, :]
    b = A[1, :]
    c = A[2, :]
    
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor by taking the outer product
    G = J @ J.T
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant(G))
    a00 = G[0, 0]
    a01 = G[0, 1]
    a10 = G[1, 0]
    a11 = G[1, 1]
    
    return .5 * ((a00 * a11 - a01 * a10)**.5)


""" [HELPER] - validate """
print('The surface area reads: ', triangle_surface_area(A))

### 
### <u>The correct solution</u>:

In [None]:
""" [IMPORT] """
import numpy as np


""" [SETUP] - create three points as matrix rows """
A = np.array([
              [0, 0, 0],    # a
              [1, 0, 1],    # b
              [0, 1, 0],    # c
             ])


""" [FOCUS] """
def triangle_surface_area(A: np.ndarray) -> float:
    assert A.shape == (3, 3)
    
    # unpack the rows
    a, b, c = A
    
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor by taking the outer product
    G = J @ J.T
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant(G))
    (a00, a01), (a10, a11) = G
    return .5 * ((a00 * a11 - a01 * a10)**.5)


""" [HELPER] - validate """
print('The surface area reads: ', triangle_surface_area(A))

### 
### 
### 
### Why did it not work ???

In [None]:
""" [IMPORT] """
import numpy as np


""" [FOCUS] """

# J inside of `triangle_surface_area` has shape J.shape == (3, 2)
J = np.random.randn(3, 2)

print("J @ J.T has shape:", (J @ J.T).shape)

### 
### `G = J @ J.T` has shape `(3, 3)`, which is wrong. It should have been `G = J.T @ J`, of shape `(2, 2)`.
### Errors like this happen easily and can go unnoticed for a long, long time ...
### Our idiomatic version caught the error: the unpacking `(a00, a01), (a10, a11) = G` implicitly
### assumes `G.shape == (2, 2)` and raises a `ValueError` (too many values to unpack) otherwise.
### 
### <u>The correct solution</u> (this time for real):

In [None]:
""" [IMPORT] """
import numpy as np


""" [SETUP] - create three points as matrix rows """
A = np.array([
              [0, 0, 0],    # a
              [1, 0, 1],    # b
              [0, 1, 0],    # c
             ])


""" [FOCUS] """
def triangle_surface_area(A: np.ndarray) -> float:
    assert A.shape == (3, 3)
    
    # unpack the rows
    a, b, c = A
    
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor G = J^T J (Gram matrix of the columns of J)
    G = J.T @ J
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant)
    (a00, a01), (a10, a11) = G
    return .5 * ((a00 * a11 - a01 * a10)**.5)


""" [HELPER] - validate """
print('The surface area reads: ', triangle_surface_area(A))

### 
### Note that the non-pythonic implementation did not catch the error and gave
### the wrong (but reasonable-looking) result `area = 0.5`.
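### As an independent sanity check (a sketch, not part of the original lesson), the area of a triangle in $\mathbb{R}^3$
### also equals $\frac{1}{2}\|(b - a) \times (c - a)\|$. The helper name `triangle_area_cross` is hypothetical;
### it only uses `np.cross` and `np.linalg.norm` and should agree with the corrected `J.T @ J` version.

```python
import numpy as np


def triangle_area_cross(A: np.ndarray) -> float:
    """Hypothetical helper: area of the triangle whose vertices are the rows of A."""
    a, b, c = A  # unpacking also checks that A has exactly three rows
    return 0.5 * float(np.linalg.norm(np.cross(b - a, c - a)))


A = np.array([[0, 0, 0],    # a
              [1, 0, 1],    # b
              [0, 1, 0]])   # c

print(triangle_area_cross(A))  # approximately 0.7071, i.e. sqrt(2) / 2
```

### Both formulas give $\sqrt{2}/2$ here, while the buggy `J @ J.T` version returned `0.5`.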
### 
<hr style="border:1px solid blue">

### 
### We can use star syntax in combination with `zip(...)`.
### 
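### A minimal, self-contained sketch of the pattern used in the task below: in a `for` target, `first, *rest`
### collects everything after the first item of each tuple produced by `zip(...)` into a list.
### The names `labels`, `xs` and `ys` are made up for illustration.

```python
labels = ["p", "q"]
xs = [1, 2]
ys = [10, 20]

collected = []
for label, *values in zip(labels, xs, ys):
    # each zip item is a tuple (label, x, y); `values` becomes the list [x, y]
    collected.append((label, values, sum(values)))

print(collected)  # [('p', [1, 10], 11), ('q', [2, 20], 22)]
```

### Adding a third iterable to `zip(...)` would simply make `values` one element longer, without touching the loop body.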
## <u>Task</u>:
### You are given an array of elements and several iterators over local system matrices.
### Your task is to write a FEM assembly routine:
### iterate simultaneously over the elements and the local system matrix iterators
### and add the sum of the local matrices at the correct position in the global matrix.

In [None]:
""" [IMPORT] """
from scipy.sparse import lil_matrix
import numpy as np


""" [HELPER] - fake local mass and stiffness matrix iterators """
def mass_matrix_iter():
    while True:
        yield 1 + np.abs(np.random.randn(1)) * np.array([ [2, 1, 1],
                                                          [1, 2, 1],
                                                          [1, 1, 2] ])
        
def stiffness_matrix_iter():
    while True:
        yield 1 + np.abs(np.random.randn(1)) * np.array([ [2, -1, -1],
                                                          [-1, 2, -1],
                                                          [-1, -1, 2] ])


""" [SETUP] - create mesh element array and empty sparse matrix of appropriate size """

# only elements, points are irrelevant for this exercise
elements = np.array([ [0, 1, 4],
                      [4, 1, 5],
                      [1, 2, 5],
                      [5, 2, 6],
                      [2, 3, 6],
                      [6, 3, 7],
                      [4, 5, 8],
                      [8, 5, 9],
                      [5, 6, 9],
                      [9, 6, 10],
                      [6, 7, 10],
                      [10, 7, 11] ])

A = lil_matrix((elements.max() + 1,)*2)


""" [FOCUS] - iterate simultaneously over elements and two (or more) local system matrix iterators """

for tri, *system_matrices in zip(elements, mass_matrix_iter(),
                                           stiffness_matrix_iter()):
    
    # add the sum to the right position in the matrix
    A[np.ix_(tri, tri)] += np.add.reduce(system_matrices)
    
    
print(A.todense())

### 
<hr style="border:1px solid blue">

### 
### `*args, **kwargs` and star-syntax unpacking are among the most powerful tools
### for writing readable and maintainable code.
### 
# <u> What we have learned: </u>
### 1. Functions `f(*args)`, `f(**kwargs)` allow for a variable number of arguments.
### 2. Star syntax tends to be more readable (if familiar) by avoiding boilerplate.
### 3. Star syntax can help you write safer code.
### 

### Star syntax can be used for:
* argument forwarding;
* list / tuple packing / unpacking;
* and many more ...

### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 3</u>: Advanced uses of Python dictionaries

In [None]:
%reset -f

### 
### We have seen how dictionaries play an important role in handling keyword arguments.
### In what follows, we discuss some advanced uses of Python dictionaries.

### 
### Python dictionaries can be used to replace if-elif chains with a lookup table keyed by a string token (a so-called dispatch table).
### 
### <u>Task</u>:
### Write a function `solve(A, b, method='direct', **solverkwargs)` that solves $A x = b$ for $x$ using
* `method='direct'`: `sparse.linalg.spsolve`
* `method='cg'`: `sparse.linalg.cg`
* `method='gmres'`: `sparse.linalg.gmres`
* `method='bicgstab'`: `sparse.linalg.bicgstab`

### 
### A straightforward but cumbersome solution:

In [None]:
""" [IMPORT] """
from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np


""" [SETUP] - define A and b """
diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)


""" [FOCUS] - I see a big big big boilerplate ... """
def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    if method == 'direct':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.spsolve(A, b, **solverkwargs)
    elif method == 'cg':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.cg(A, b, **solverkwargs)
    elif method == 'gmres':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.gmres(A, b, **solverkwargs)
    elif method == 'bicgstab':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.bicgstab(A, b, **solverkwargs)
    else:
        raise ValueError(f'Unknown method name {method}.')


""" [HELPER] - solve with bicgstab and solver tolerance tol=1e-7 """
print("solution of Ax = b:", solve(A, b, method='bicgstab', rtol=1e-7))

### 
### A pythonic solution:

In [None]:
""" [IMPORT] """
from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np

        
""" [SETUP] - define A and b """
diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)


""" [FOCUS] - no boilerplate """
def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    
    solver = { 'bicgstab': splinalg.bicgstab,
               'direct'  : splinalg.spsolve,
               'gmres'   : splinalg.gmres, 
               'cg'      : splinalg.cg        }.get(method, None)  # solver = None if token not found
    
    if solver is None:  # token not found, raise error.
        raise ValueError(f'Unknown method name {method}.')
        
    print(f'Solving sparse linear system with solver method: `{method}`.')

    return solver(A, b, **solverkwargs)


""" [HELPER] - validate """
print("solution of Ax = b:", solve(A, b, method='bicgstab', rtol=1e-7))

### 
### The `dict.get(key, default_value)` method can be used to return a default value in case a key has not been found.
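### A short illustration of `dict.get` (the dict `solvers` below holds placeholder strings, not the SciPy functions):
### without a second argument the default is `None`, and unlike `d[key]`, no `KeyError` is raised.

```python
solvers = {"cg": "splinalg.cg", "direct": "splinalg.spsolve"}

print(solvers.get("cg"))                 # splinalg.cg
print(solvers.get("minres"))             # None (implicit default)
print(solvers.get("minres", "unknown"))  # unknown

try:
    solvers["minres"]                    # plain indexing raises instead
except KeyError as err:
    print("KeyError:", err)
```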
### 
<hr style="border:1px solid blue">

### 
### In the following, we discuss `dict.setdefault`.
### 
### Let us see what this does:

In [None]:
""" [SETUP] - define dict and print it """

test_dict = {'a': 5, 'b': 10}
print('test: ', test_dict, '\n')


""" [FOCUS] """

test_dict.setdefault('a', 20)
print('test after the first setdefault operation: ', test_dict, '\n')

test_dict.setdefault('c', 15)
print('test after the second setdefault operation: ', test_dict, '\n')

### 
### The first `.setdefault` didn't do anything because the key `'a'` was already present.
### `dict.setdefault(key, value)` only changes the `dict` if `key not in dict`.
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 3.1</u>:
### Using the `solve(A, b, ...)` function, write a function `solve_SPD(A, b, **kwargs)`
### which calls the `solve` function but uses the `'cg'` method unless another method token is passed.

In [None]:
# again, the solve(...) function from above, to be used in your implementation

""" [IMPORT] """
from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np


""" [HELPER] - solve method from above """
def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    
    solver = { 'bicgstab': splinalg.bicgstab,
               'direct'  : splinalg.spsolve,
               'gmres'   : splinalg.gmres, 
               'cg'      : splinalg.cg        }.get(method, None)  # solver = None if token not found
    
    if solver is None:  # token not found, raise error.
        raise ValueError(f'Unknown method name {method}.')
        
    print(f'Solving sparse linear system with solver method: `{method}`.')

    return solver(A, b, **solverkwargs)

In [None]:
""" [FOCUS] - Incorrect solution ! """

# We have to forward method=method in solve(...) explicitly.
# Easy to forget ... especially if you add more deviating default values.

def solve_SPD(A, b, method='cg', **kwargs):
    return solve(A, b, method=method, **kwargs)

### 
### Your solution here:

In [None]:
""" [SETUP] - make A and b """
diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)


""" [FOCUS] - your solution """
def solve_SPD(A, b, **kwargs):
    # your code here
    pass


""" [HELPER] - validate """

# don't specify the solver => should solve with `cg`
print(solve_SPD(A, b))

# use `bicgstab`
print(solve_SPD(A, b, method='bicgstab'))

### 
### Solution:

In [None]:
""" [SETUP] - make A and b """
diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)


""" [FOCUS] - your solution """
def solve_SPD(A, b, **kwargs):
    kwargs.setdefault('method', 'cg')
    return solve(A, b, **kwargs)


""" [HELPER] - validate """

# don't specify the solver => should solve with `cg`
print(solve_SPD(A, b))

# use `bicgstab`
print(solve_SPD(A, b, method='bicgstab'))

### 
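### Why is `kwargs.setdefault('method', 'cg')` safe? Inside the function, `**kwargs` is a new dictionary built for
### each call, so `setdefault` does not modify a dictionary the caller unpacked with `**`.
### A minimal sketch with a hypothetical stand-in `fake_solve` (it only returns the method name, it solves nothing):

```python
def fake_solve(A, b, method="direct", **solverkwargs):
    # hypothetical stand-in for `solve`: report which method would be used
    return method


def solve_SPD(A, b, **kwargs):
    kwargs.setdefault("method", "cg")
    return fake_solve(A, b, **kwargs)


options = {"rtol": 1e-7}
print(solve_SPD(None, None, **options))                  # cg
print(solve_SPD(None, None, method="gmres", **options))  # gmres
print(options)                                           # {'rtol': 1e-07}, unchanged
```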
<hr style="border:1px solid blue">

### 
### By invoking `var = dict.setdefault(key, value)`, we can optionally
### capture in `var` whatever the dictionary contains at `key` after the `.setdefault` operation.

In [None]:
""" [SETUP] """
test = {'a': 5, 'b': 10}


""" [FOCUS] """

var = test.setdefault('a', 20)
print(var)

var = test.setdefault('c', 30)
print(var)

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 3.2</u>:
### You are given a tuple of edges `edges = ((i0, i1), (i2, i3), ...)` representing the edges of a directed graph.
### An edge `edge = (i0, i1)` points from node `i0` to node `i1`. 
### The nodes are not necessarily numbered from `0` to `N`; their labels can be anything.
### Write code that counts, for each node, the number of unique edges pointing into it. Some edges may be duplicated in `edges`.
### **Remember**: you do not know in advance which labels the nodes have.

In [None]:
""" [SETUP] - make a bunch of edges """
edges = (1, 12), (4, 12), (4, 1), (13, 6), \
        (12, 8), (8, 4), (6, 8), (1, 12), \
        (6, 4), (12, 8), (13, 6), (3, 13), (3, 4), (3, 1)


""" [HELPER] - plot the (directed) graph """

# you need networkx: `pip install networkx`
import networkx as nx
from matplotlib import pyplot as plt

G = nx.DiGraph()
G.add_edges_from(edges)

pos = nx.planar_layout(G)

nx.draw_networkx(G, arrows=True, pos=pos)
plt.show()

### 
### A crude (non-pythonic) solution:

In [None]:
""" [SETUP] """
map_node_root_vertex = {}


""" [FOCUS] - make sure your implementation does NOT do this """
for edge in edges:
    v0 = edge[0]  # get the root node
    v1 = edge[1]  # get the incident node
    if v1 not in map_node_root_vertex:
        map_node_root_vertex[v1] = set()
    map_node_root_vertex[v1].add(v0)  # add the root vertex v0 to the set of nodes pointing into v1


""" [HELPER] - number of incident edges to `node` = number of unique roots pointing to `node` """

map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
### Your solution using `dict.setdefault`:

In [None]:
""" [SETUP] """
map_node_root_vertex = {}


""" [FOCUS] """
# your code here


""" [HELPER] - count number of incident edges and print result """
map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
### solution:

In [None]:
""" [SETUP] """
map_node_root_vertex = {}


""" [FOCUS] """
for v0, v1 in edges:
    map_node_root_vertex.setdefault(v1, set()).add(v0)  # if no set at v1 yet, create one, return set and add to it


""" [HELPER] - count number of incident edges and print result """
map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
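### An alternative from the standard library (not required for the exercise): `collections.defaultdict(set)` creates
### the empty set on first access, so the loop body shrinks further. A sketch using a small, made-up edge tuple
### `demo_edges` (not the `edges` above):

```python
from collections import defaultdict

demo_edges = (1, 12), (4, 12), (1, 12), (12, 8)

roots = defaultdict(set)  # missing keys get a fresh empty set
for v0, v1 in demo_edges:
    roots[v1].add(v0)

counts = {node: len(r) for node, r in roots.items()}
print(counts)  # {12: 2, 8: 1}
```

### One design trade-off: with a `defaultdict`, merely *reading* `roots[key]` also inserts an empty set,
### whereas `dict.setdefault` only inserts where you explicitly call it.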
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. In `Python`, dictionaries play a central role. Advanced use of them is considered **good practice**.
### 2. Making it a habit to use methods like `dict.setdefault` and 
### `dict.get(item, default_value)` will make your code safer and more readable.
### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 4</u>: Python _"truthiness"_ 

In [None]:
%reset -f

### 
### **ALL** Python types (`dict`, `set`, `list`, ..., custom types) can be converted into a boolean.
### This can be done by invoking `bool(var)`.
### 
### Instances of custom classes (that do not derive from built-ins) convert to `True`, unless the class defines `__bool__` or `__len__`.
### 
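### For a custom class, the truth test first looks for `__bool__`, then falls back to `__len__`
### (nonzero length means `True`); if neither is defined, instances are truthy. A minimal sketch with hypothetical classes:

```python
class Plain:
    pass


class Bag:
    def __init__(self, *items):
        self.items = list(items)

    def __len__(self):
        # used by bool() because Bag defines no __bool__
        return len(self.items)


class AlwaysFalse:
    def __bool__(self):
        return False


print(bool(Plain()))        # True
print(bool(Bag()))          # False (len == 0)
print(bool(Bag(1, 2)))      # True
print(bool(AlwaysFalse()))  # False
```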
### The Python developers have chosen intuitive rules as to whether a built-in type should convert
### into `True` or `False`.
### 
### Let's see if our intuition is correct.

### 
## <u>Exercise 4.1</u>: Does `bool(var)` convert to `True` or `False` ?

In [None]:
# what do you think ? (you'll get most of them right)


""" [FOCUS] """
variables = [ 
              True,          # I guess we can all agree what this one is converted into ;)
              False,         # idem
              None,          # the `None` singleton (Python's null value)
              {},            # empty dict
              {1: 2},        # nonempty dict
              [],            # empty list
              [1, 2],        # nonempty list
              tuple(),       # empty tuple
              (1, 2),        # nonempty tuple
              0,             # zero integer
              0.0,           # zero float
              1.0,           # positive float
              -1.,           # negative float
              "",            # empty string
              "Connie",      # nonempty string
            ]


""" [HELPER] - validate """
for var in variables:
    print('var: {}, bool(var): {}'.format(var, bool(var)), '\n')

### 
<hr style="border:1px solid blue">


### 
### Now that we have an idea of how various variable types transform into a boolean,
### we are in a position to understand the `if other_divisors:` statement from **Lesson 1**.

In [None]:
""" [FOCUS] - greatest divisor (excl. self), print all other divisors, if any """
def greatest_divisor_excluding_self(val):
    *other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]
    
    if other_divisors:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
        
    return greatest_divisor


""" [HELPER] - validate """
print('The greatest divisor (excl self) of 10 equals:', greatest_divisor_excluding_self(10))

### 
### First, we break down the line:
### 
### `*other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]`
### 
* the right-hand side creates a monotonically increasing list of all divisors of `val` smaller than `val`; for `val >= 2`, it contains at least the element `1`.
* the left-hand side peels off the last (and hence largest) element of that list, while collecting all remaining divisors in the list `other_divisors`.
* if the only divisor is `1` (i.e. `val` is prime), `other_divisors == []`; otherwise `other_divisors != []`.

### 
### Then, the line `if other_divisors:`
### 
* in the next line `if other_divisors:`, Python, under the hood, converts this line to `if bool(other_divisors):`
* from before, we know that if `other_divisors == []`, `if other_divisors:` becomes `if False:` and the if clause is ignored.
* On the other hand, if `other_divisors != []` (nonempty), then `if other_divisors:` converts to `if True:` and the if clause's code is executed.
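### To see both branches, and one edge case the function does not handle, here is a self-contained copy of the function
### from above. For `val = 1` the list comprehension is empty and the unpacking raises `ValueError`,
### since there is no element left to bind to `greatest_divisor`.

```python
def greatest_divisor_excluding_self(val):
    *other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]
    if other_divisors:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor


print(greatest_divisor_excluding_self(7))   # prime: other_divisors == [], only 1 is printed
print(greatest_divisor_excluding_self(12))  # first reports [1, 2, 3, 4], then prints 6

try:
    greatest_divisor_excluding_self(1)
except ValueError as err:
    print("val=1:", err)
```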
### 

<hr style="border:1px solid blue">

### 
## <u>Exercise 4.2</u>:
### Write a function `average(*numbers)` that computes the average of the numbers
### in `numbers` and returns `0.0` if no numbers have been passed.
### Use Python truthiness !

In [None]:
""" [FOCUS] - your code here """
def average(*numbers):
    # your code here
    pass


""" [HELPER] - validate """
print(average(1, 2, 3))
print(average())

### 
### 2 equivalent solutions:

In [None]:
""" [FOCUS] """
def average_v0(*numbers):
    if numbers:
        # without this check, `numbers == ()` would lead to a division by zero
        return sum(numbers) / len(numbers)
    return 0.0


""" [HELPER] - validate """    
print(average_v0(1, 2, 3))
print(average_v0())





""" [FOCUS] - one liner """
def average_v1(*numbers):
    return sum(numbers) / len(numbers) if numbers else 0.0
    

""" [HELPER] - validate """
print(average_v1(1, 2, 3))
print(average_v1())

### 
<hr style="border:1px solid blue">

### 
### There are some interesting, less well-known use cases of Python '_truthiness_'.

### 
### Let us look at the following truth table:

| `bool0` | `bool1` | `bool0 or bool1` |
|:-------:|:-------:|:----------------:|
| `True`  | `True`  | `True`           |
| `True`  | `False` | `True`           |
| `False` | `True`  | `True`           |
| `False` | `False` | `False`          |

### Observation: if `bool0 is True`, we return `bool0`. Otherwise, we return `bool1`.
### 
***
### 
### Python takes it a step further. Consider `object0 or object1`.
### Python's behaviour: if `bool(object0) is True`, return `object0`;
###                     if `bool(object0) is False`, return `object1`.
### In the first case, `object1` is not even evaluated (short-circuit evaluation).
### 
### Convince yourself that this also reproduces the expected behaviour when both `object0` and `object1`
### are booleans (obviously, `bool(False) is False` and `bool(True) is True`).
### 
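### A few quick checks of this rule: `or` returns one of its operands (not necessarily a `bool`), and the right operand
### is only evaluated when the left one is falsy. The function `loud` is a hypothetical helper that records when it is called.

```python
calls = []


def loud(value):
    calls.append(value)  # record that this operand was evaluated
    return value


print([] or "default")  # default
print(0 or 5)           # 5
print("a" or "b")       # a
print(None or [])       # [] (both falsy: the second operand is returned)

result = "first" or loud("second")
print(result, calls)    # first []  -> loud(...) was never called
```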

<hr style="border:1px solid blue">

### 
## <u>Exercise 4.3</u>:
### Write a function `average(*numbers)` that uses an `or` statement to avoid division by zero if `numbers == ()`.

In [None]:
""" [FOCUS] - your one liner here """
def average(*numbers):
    # ONE LINE of code here.
    pass


""" [HELPER] - validate """
print(average(1, 2, 3))
print(average())

### 
### solution:

In [None]:
""" [FOCUS] - one liner solution """
def average(*numbers):
    return sum(numbers) / (len(numbers) or 1)


""" [HELPER] - validate """
print(average(1, 2, 3))
print(average())

### 
<hr style="border:1px solid blue">

### 
### Python truthiness comes in handy in the following setting:
### Suppose you are trying to find the zero of a function using `scipy.optimize.minimize`.
### 
### The function itself is the integral over $(0, 1)$ of an integrand 
### and for integrating you use `scipy.integrate.fixed_quad`.
### Find $\min_{a_0, \ldots, a_N} \left(\int_{(0, 1)} f(a_0, \ldots, a_N, x) \, \mathrm{d}x \right)^2$
### 
### Now you have two functions that take optional keyword arguments. How do you handle that ?

### Let's look at a crude solution (don't focus on the details, just boilerplate):

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


""" [SETUP] """

integrand = lambda a, x: a * x**2 - 3 * x + 1   # integrand(a, x) = a x^2 - 3x + 1



""" [FOCUS] - why is saying `minimize_kwargs={}` a shit idea ? """

def find_root_of_integral(integrand: Callable, x0: float,  minimize_kwargs=None,
                                                           integrate_kwargs=None) -> np.ndarray:
    
    ### boilerplate
    
    if minimize_kwargs is None:
        minimize_kwargs = {}

    if integrate_kwargs is None:
        integrate_kwargs = {}
        
    ###

    # for given `a`, integrate over `x = [0, 1]` and square
    f = lambda a: fixed_quad(lambda x: integrand(a, x), 0, 1, **integrate_kwargs)[0]**2

    # minimize over `a`
    return minimize(f, x0=x0, **minimize_kwargs).x[0]



""" [HELPER] - find root of integral and plot """

root = find_root_of_integral(integrand, 0.0, minimize_kwargs={'method': 'SLSQP'},
                                             integrate_kwargs={'n': 2})  # quadratic func, n=2 suffices


print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))


xi = np.linspace(0, 1, 101)
f = integrand(root, xi)

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()

### 
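### Why is `minimize_kwargs={}` a bad idea? Default values are evaluated once, when the `def` statement runs,
### so every call that relies on the default shares the same dict. Even if `find_root_of_integral` never mutates it,
### a later change that does (e.g. a `setdefault` inside the function) would leak state into subsequent calls.
### A minimal sketch with a hypothetical function `configure`:

```python
def configure(options={}):            # the default dict is created only once
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options


print(configure())  # {'calls': 1}
print(configure())  # {'calls': 2}  <- state leaked from the previous call


def configure_safe(options=None):
    options = {} if options is None else options  # fresh dict on every call
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options


print(configure_safe())  # {'calls': 1}
print(configure_safe())  # {'calls': 1}
```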
## <u>Exercise 4.4</u>:
### Find a solution that substantially reduces boilerplate in the `find_root_of_integral` function.

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


""" [SETUP] """

integrand = lambda a, x: a * x**2 - 3 * x + 1   # integrand(a, x) = a x^2 - 3x + 1



""" [FOCUS] - your code here """
def find_root_of_integral(integrand: Callable, x0: float,  minimize_kwargs=None,
                                                           integrate_kwargs=None) -> float:
    
    ### Your code here
    
    # f = ???
    
    return # ???



""" [HELPER] - find root of integral and plot """

root = find_root_of_integral(integrand, 0.0, minimize_kwargs={'method': 'SLSQP'},
                                             integrate_kwargs={'n': 2})  # quadratic func, n=2 suffices


print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))


xi = np.linspace(0, 1, 101)
f = integrand(root, xi)

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()

### 
### solution:

In [None]:
""" [IMPORT] """
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


""" [SETUP] """

integrand = lambda a, x: a * x**2 - 3 * x + 1   # integrand(a, x) = a x^2 - 3x + 1



""" [FOCUS] - the entire thing can be done in 2 lines of code """
def find_root_of_integral(integrand: Callable, x0: float,  minimize_kwargs=None,
                                                           integrate_kwargs=None) -> float:

    # in a call, `**` unpacks the entire expression that follows it:
    # `**kwargs or {}` is the same as `**(kwargs or {})`
    
    f = lambda a: fixed_quad(lambda x: integrand(a, x), 0, 1, **integrate_kwargs or {})[0]**2
    
    return minimize(f, x0=x0, **minimize_kwargs or {}).x[0]



""" [HELPER] - find root of integral and plot """

root = find_root_of_integral(integrand, 0.0, minimize_kwargs={'method': 'SLSQP'},
                                             integrate_kwargs={'n': 2})  # quadratic func, n=2 suffices


print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))


xi = np.linspace(0, 1, 101)
f = integrand(root, xi)

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()

<hr style="border:1px solid blue">

### 
### <u>Final note</u> (for those who are interested):
### There is a counterpart, `object0 and object1`: it returns `object0` if `bool(object0) is False`
### and `object1` otherwise (again without evaluating `object1` in the first case). It is a bit less useful in practice.

### 
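### A few examples of `and`, which returns `object0` if it is falsy and `object1` otherwise:

```python
print([] and "never reached")  # []
print(3 and 4)                 # 4
print("x" and "")              # "x" is truthy, so "" is returned (prints an empty line)

items = [5, 7]
first = items and items[0]     # guards the indexing against an empty list
print(first)                   # 5
```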
### <u>Final brain teaser</u>:
### Do you remember this (overengineered) code snippet from **Lesson 1**? Can you explain it now?
#### Hint: a function that does not return anything returns `None` instead. Don't feel bad if it's still confusing ;)

```python
""" [IMPORT] - import additional functionality from Python built-in modules """
from itertools import count
from collections import defaultdict
import numpy as np


""" [SETUP] - define utility containers """
elements = [elems0, elems1]
points = [points0, points1]
map_point_index = defaultdict(count().__next__)
seen = set()
new_elems = []


for myelems, mypoints in zip(elements, points):
    for elem in myelems:
        new_elem = [map_point_index[point] for point in map(tuple, mypoints[elem])]

        """ [FOCUS] """
        (sorted_elem := tuple(sorted(new_elem))) in seen or new_elems.append(new_elem) or seen.add(sorted_elem)


""" [SETUP] - convert to `numpy` arrays """
new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))


""" [HELPER] - plot the mesh union, print number of cells and points """
plot_meshes([new_elems], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
```

### 
### 
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. All `Python` built-in types can be converted to a `bool`.
### 2. Relying on truthiness (and on `or` returning one of its operands) avoids boilerplate.
### 3. With a little practice, it helps you avoid handling special cases separately
### (one function body handles all inputs equally gracefully).
### 
<hr style="border:1px solid blue">

### 
### We come to this session's most important
# <u>Lesson 5</u>: Advanced iteration and the use of itertools.

In [None]:
%reset -f

### 
### Python's most important iteration tool is the built-in `zip` function
### (most of you have seen it).
### 
### What happens if the iterables we iterate over have unequal length ?

In [None]:
""" [IMPORT] """
from itertools import repeat


""" [SETUP] - important """
iter0 = [0, 1, 2, 3, 4, 5]  # length 6
iter1 = range(20, 40)       # consumed after a max of 20 iterations
iter2 = repeat("Connie")    # this guy is repeated an infinite number of times


""" [FOCUS] """
for elem0, elem1, elem2 in zip(iter0, iter1, iter2):
    print(elem0, elem1, elem2)

### 
### <u>We conclude</u>: the `zip(...)` loop terminates after the shortest iterable has been consumed.
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 5.1</u>:
### Given two polynomials represented by their weights
### `pol0 = [a0, a1, ..., aN]`
### `pol1 = [b0, b1, ..., bM]`
### potentially `N != M`.
### Write a function that returns their sum.
### Note that you have to be able to handle the case `len(pol0) != len(pol1)`.
### HINT (in case you didn't know):
### $(a_0 + a_1 x + \ldots + a_N x^N) + (b_0 + b_1 x + \ldots + b_M x^M) = (a_0 + b_0) + (a_1 + b_1) x + \ldots$

In [None]:
""" [SETUP] - define two pols; `pol0` is a list, `pol1` a tuple """

pol0 = [1, 2, 3]     # 1 + 2x + 3x^2
pol1 = (0, 3, 2, 1)  #     3x + 2x^2 + x^3


""" [FOCUS] - your code here """
def add_two_polynomials(pol0, pol1):
    # your code here
    pass


""" [HELPER] - validate - should give [1, 5, 5, 1] """
print(add_two_polynomials(pol0, pol1))

### 
### Two solutions WITHOUT the use of itertools, one **bad**, one **better**:

In [None]:
""" [SETUP] """

pol0 = [1, 2, 3]
pol1 = (0, 3, 2, 1)


""" [FOCUS] - a cumbersome solution """
def add_two_polynomials_bad(pol0, pol1) -> list:
    # we convert both to lists, to handle just one type
    pol0 = list(pol0)
    pol1 = list(pol1)
    
    # make sure `len(pol1) >= len(pol0)`
    if len(pol0) > len(pol1):
        pol1, pol0 = pol0, pol1

    diff = len(pol1) - len(pol0)  # length difference >= 0
    
    # add a tail of zeros. Note that pol0 is a list, we made sure of that
    pol0 = pol0 + [0] * diff
    
    return [a + b for a, b in zip(pol0, pol1)]


""" [FOCUS] - a somewhat more elegant solution (still a bit cumbersome) """
def add_two_polynomials_better(pol0, pol1) -> list:
    # convert to lists, we will learn about `map`, don't worry
    pol0, pol1 = map(list, (pol0, pol1))
    
    diff = len(pol1) - len(pol0)  # take difference in length, can be < 0

    # [0] * n == [] if n <= 0
    pol0.extend([0] * diff)  # using extend is more efficient
    pol1.extend([0] * -diff)

    return [a + b for a, b in zip(pol0, pol1)]



""" [HELPER] - validate both implementations """
print(add_two_polynomials_bad(pol0, pol1))
print(add_two_polynomials_better(pol0, pol1))

### 
<hr style="border:1px solid blue">

### 
### Let us see what this does:

In [None]:
""" [IMPORT] """
from itertools import zip_longest


""" [SETUP] """
pol0 = [1, 2, 3]
pol1 = (0, 3, 2, 1)


""" [FOCUS] """
for elem0, elem1 in zip_longest(pol0, pol1, fillvalue=0):
    print(elem0, elem1)

### 
### I probably won't have to explain why this is handy now ;)
### 
### Here is an implementation of `add_two_polynomials`:

In [None]:
""" [IMPORT] """
from itertools import zip_longest


""" [SETUP] """
pol0 = [1, 2, 3]     # 1 + 2x + 3x^2
pol1 = (0, 3, 2, 1)  #     3x + 2x^2 + x^3


""" [FOCUS] - one liner """
def add_two_polynomials(pol0, pol1):
    return [a + b for a, b in zip_longest(pol0, pol1, fillvalue=0)]


""" [HELPER] - validate """
print(add_two_polynomials(pol0, pol1))

###
### Note that the signature is `zip_longest(*iterables, fillvalue=None)`.
### It accepts as many iterables as we want, just like `zip(...)`.
### 
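### With the default `fillvalue=None`, missing positions are padded with `None`, which is why `fillvalue=0`
### was passed for the polynomial sums above:

```python
from itertools import zip_longest

print(list(zip_longest("ab", [1])))               # [('a', 1), ('b', None)]
print(list(zip_longest("ab", [1], fillvalue=0)))  # [('a', 1), ('b', 0)]
```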

<hr style="border:1px solid blue">

### 
## <u>Exercise 5.2</u>:
### Write a function `add_polynomials(*polynomials)` 
### that accepts an arbitrary number of polynomials for addition.

In [None]:
""" [SETUP] - bunch of pols. Some tuple, some list """
pol0 = (1, 2, 3)
pol1 = [0, 3, 2, 1]
pol2 = [-1, 2]
pol3 = [-1, -2, 5, 7, 10]


""" [FOCUS] - your code here """
def add_polynomials(*polynomials):
    # your code here
    pass


""" [HELPER] - validate - should give [-1, 5, 10, 8, 10] """
print(add_polynomials(pol0, pol1, pol2, pol3))

### 
### solution(s):

In [None]:
""" [SETUP] - bunch of pols. Some tuple, some list """
pol0 = (1, 2, 3)
pol1 = [0, 3, 2, 1]
pol2 = [-1, 2]
pol3 = [-1, -2, 5, 7, 10]


""" [FOCUS] - one liner """
def add_polynomials(*polynomials):
    return [sum(weights) for weights in zip_longest(*polynomials, fillvalue=0)]
    

""" [HELPER] - validate """
print(add_polynomials(pol0, pol1, pol2, pol3))




""" [FOCUS] - another one liner, for good measure """
def add_polynomials_map(*polynomials):
    return list(map(sum, zip_longest(*polynomials, fillvalue=0)))


""" [HELPER] - validate """
print(add_polynomials_map(pol0, pol1, pol2, pol3))

### 
<hr style="border:1px solid blue">

### 
### The difference between an ugly implementation and one that is orders of magnitude simpler
### often just amounts to using the right function from the `itertools` module.
### 
<hr style="border:1px solid blue">

### 
### A very useful iterator from the `itertools` module is `itertools.count`.
### 

In [None]:
""" [IMPORT] """
from itertools import count


""" [FOCUS] - the obvious way to do it """
i = 0
while i < 10:
    print(i)
    i += 1


print('\n')
    

""" [FOCUS] - a more pythonic solution """
for i in count():
    print(i)
    if i == 9: break

### 
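### Two related tools, shown as a sketch: `count(start, step)` does not have to start at `0`,
### and `itertools.islice` takes the first `n` items of an infinite iterator without an explicit `break`.

```python
from itertools import count, islice

print(list(islice(count(), 10)))     # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], same as list(range(10))
print(list(islice(count(5, 2), 4)))  # [5, 7, 9, 11]
```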
### <u>Task</u>:
### given a set of unique edges `(i0, i1), (i1, i2), ...`, 
### create a dictionary that assigns an index to each edge.

In [None]:
# we make a few edges first

""" [IMPORT] - pairwise for making a bunch of edges """
from itertools import pairwise


""" [SETUP] - make a few edges and print them """
edges = tuple(map(tuple, pairwise(range(11))))  # the map statement will come in handy later
print("The edges are:", edges)

### 
### The naive solution

In [None]:
""" [SETUP] """
index = 0
map_edge_index = {}


""" [FOCUS] """
for edge in edges:
    map_edge_index[edge] = index
    index += 1


""" [HELPER] - validate """
print(map_edge_index)

### 
### A more pythonic solution

In [None]:
""" [IMPORT] """
from itertools import count


""" [SETUP] """
map_edge_index = {}
counter = count()


""" [FOCUS] """
for edge in edges:
    map_edge_index[edge] = next(counter)  # return current index and increment


""" [HELPER] - validate """
print(map_edge_index)

### 
### The optimal solution

In [None]:
""" [FOCUS] - one line of code ! """
map_edge_index = dict(zip(edges, count()))


""" [HELPER] - validate """
print(map_edge_index)

### 
### (Another option is a dict comprehension)
### 
### Let's break down what happened in the last example:
* We can instantiate a `dict` via `(key, value)` pairs. For instance `dict([(1, 2), (2, 3)]) == {1: 2, 2: 3}`.
* The `zip` statement can be used to produce these pairs.
* Recall that `zip` terminates as soon as the shortest iterator is exhausted.
* Here, `edges` has finite length (unlike the infinite `count()`) and is therefore exhausted first.
* The `zip(edges, count())` statement produces the pairs `(edge0, 0), (edge1, 1), ...`
### 
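### Here is a small check of the points above. It uses a short, hypothetical `sample_edges` tuple:

```python
from itertools import count

sample_edges = ((0, 1), (1, 2), (2, 3))

# zip stops once `sample_edges` is exhausted, so the infinite count() is harmless
via_zip = dict(zip(sample_edges, count()))

# the equivalent dict comprehension based on enumerate
via_comprehension = {edge: i for i, edge in enumerate(sample_edges)}

print(via_zip)                       # {(0, 1): 0, (1, 2): 1, (2, 3): 2}
print(via_zip == via_comprehension)  # True
```
### 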

<hr style="border:1px solid blue">

### 
### We have also seen the `map(...)` function several times now; let's see what it does.

In [None]:
# the map function can be used to apply a function to all inputs


""" [FOCUS] - print the squares of 0...9 using `map(...)` """
for val in map(lambda x: x**2, range(10)):
    print(val)

### 
### The built-in `map(func, arguments)` returns an iterator that applies `func` to each `arg in arguments`.
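### Note that `map` is lazy. It returns an iterator that computes values only on demand,
### and you can consume that iterator only once. A small illustration:

```python
squares = map(lambda x: x**2, range(5))

print(squares)  # a map object; nothing has been computed yet

first_pass = list(squares)   # consumes the iterator
second_pass = list(squares)  # the iterator is already exhausted

print(first_pass)   # [0, 1, 4, 9, 16]
print(second_pass)  # []
```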
### 
### We take the example with the edges one step further.
### 
## <u>Exercise 5.3</u>:
### You are given the edges as before, but this time as a **list of lists**.
### Your task is to assign an index to each edge, as before. 
### However, lists are **mutable** and therefore **unhashable**, so they cannot be used as keys in a `dict`. 
### 
### Therefore, write a one-liner that uses `dict`, `zip`, `map` and `count` 
### to assign an index to each edge after converting it to a `tuple`.

In [None]:
""" [IMPORT] """
from itertools import pairwise


""" [SETUP] - make a bunch of list edges """
edges = list(map(list, pairwise(range(11))))


""" [HELPER] - print edges """
print(edges)

### 
### The verbose (non-pythonic) solution first:

In [None]:
""" [IMPORT] """
from itertools import count


""" [SETUP] """
map_edge_index = {}
counter = count()


""" [FOCUS] """
for edge in edges:
    map_edge_index[tuple(edge)] = next(counter)


""" [HELPER] - validate """
print(map_edge_index)

### 
### do it better:

In [None]:
""" [IMPORT] """
from itertools import count


""" [FOCUS] - your one liner here """
# map_edge_index = ???


""" [HELPER] - validate """
print(map_edge_index)

### solution:

In [None]:
""" [IMPORT] """
from itertools import count


""" [FOCUS] - two equivalent solutions """

map_edge_index = dict(zip(map(tuple, edges), count()))

# or 
map_edge_index = {tuple(edge): i for i, edge in enumerate(edges)}


""" [HELPER] - validate """
print(map_edge_index)

### 
<hr style="border:1px solid blue">

### 
### The `map(...)` function also accepts more than one iterable.
### For instance, `map(func, iter0, iter1, ...)` calls `func(object0, object1, ...)`, where `object_i` is generated by `iter_i`.
### Example:

In [None]:
""" [IMPORT] """
import numpy as np


""" [HELPER] - roots of the quadratic function ax^2 + bx + c """
def abc_formula(a, b, c):
    sqrt_discriminant = (b ** 2 - 4 * a * c) ** .5  # square root of the discriminant
    return (-b - sqrt_discriminant) / (2 * a), (-b + sqrt_discriminant) / (2 * a)


""" [SETUP] - three arrays `a, b, c` containing `N == 10` entries each """
N = 10
a, b, c = np.abs(np.random.randn(3, N)).astype(np.complex128)


""" [HELPER] - print quadratic functions we're taking the root of """
print("Finding the complex roots of:\n \n{}\n\n".
      format('\n'.join([f'f(x) = {a0:.5g} x^2 + {b0:.5g} x + {c0:.5g}'
                        for a0, b0, c0 in zip(*map(np.real, (a, b, c)))])))


print('The complex roots are: \n')



""" [FOCUS] - pythonic solution for finding all complex roots """

for root0, root1 in map(abc_formula, a, b, c):
    
    # equivalent to: 
    
    # root0, root1 = abc_formula(a[0], b[0], c[0])
    # root0, root1 = abc_formula(a[1], b[1], c[1])
    # root0, root1 = abc_formula(a[2], b[2], c[2])
    # ...
    
    print(f'({root0.real:.5g} + {root0.imag:.5g}j) and ({root1.real:.5g} + {root1.imag:.5g}j)')

### 
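### Like `zip`, `map` stops as soon as the shortest of its iterables is exhausted.
### Here is a minimal example with the built-in `pow`:

```python
# map with several iterables calls pow(base, exponent) element-wise
powers = list(map(pow, [2, 3, 4], [5, 2, 3]))
print(powers)  # [32, 9, 64]

# like zip, map stops once the shortest iterable is exhausted
truncated = list(map(pow, [2, 3, 4], [5, 2]))
print(truncated)  # [32, 9]
```
### 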
<hr style="border:1px solid blue">

### 
### 
### <u>A little intermezzo</u>:
### 
### A highly useful but little-known Python feature is the `for-else` clause.
### 
### <u>Task</u>:
### Given a list of words `list_of_words`, write a function that finds the first word
### that starts with an `'a'` (case-insensitive) and prints it.
### If no such word is found, raise an error.

In [None]:
""" [HELPER] - custom exception that indicates no word was found """
class NoWordFoundException(Exception):
    pass


""" [SETUP] """
words_with_a = ['Connie', 'is', 'a', 'bear']
words_without_a = ['This', 'is', 'no', 'suitcase', 'curseword']


""" [FOCUS] - bad implementation first """
def find_word_that_starts_with_a_non_pythonic(list_of_words):
    found = False
    
    for word in list_of_words:
        if word != '':  # Can't index into empty string, i.e., ''[0] gives an error.
            if word[0] == 'a' or word[0] == 'A':
                print('Found the word `{}`.'.format(word))
                found = True
                break
                
    if found is False:  # use truthiness ;)
        raise NoWordFoundException("Couldn't find a word that starts with `a` or `A`.")


""" [HELPER] - validate, first one finds a word, second raises an error """
for words in (words_with_a, words_without_a):
    try:
        find_word_that_starts_with_a_non_pythonic(words)
    except NoWordFoundException as ex:
        raise NoWordFoundException from ex

### 
### Now the pythonic implementation:

In [None]:
""" [HELPER] - create custom exception that indicates no word was found """
class NoWordFoundException(Exception):
    pass


""" [SETUP] """
words_with_a = ['Connie', 'is', 'a', 'bear']
words_without_a = ['This', 'is', 'no', 'suitcase', 'curseword']


""" [FOCUS] - the pythonic implementation """
def find_word_that_starts_with_a_pythonic(list_of_words):
    for word in list_of_words:
        if word[:1].lower().startswith('a'):
            print('Found the word `{}`.'.format(word))
            break
    else:  # no break occurred
        raise NoWordFoundException("Couldn't find a word that starts with `a` or `A`.")


""" [HELPER] - first one finds a word, second raises an error """
for words in (words_with_a, words_without_a):
    try:
        find_word_that_starts_with_a_pythonic(words)
    except NoWordFoundException as ex:
        raise NoWordFoundException from ex

### 
### The `else` part of the `for-else` clause is only entered if the for loop
### completes normally, i.e., it is not terminated early.
### **Ways to terminate early**: `break` (break out of the loop), `return` (inside a function), or an exception propagating out of the loop.
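### 
### Here is a compact sketch of both cases. The helper `first_negative` is hypothetical and written only for this illustration.
### The same `else` clause also exists for `while` loops:

```python
def first_negative(values):
    """Return the first negative number in `values`, or None if there is none."""
    for v in values:
        if v < 0:
            result = v
            break
    else:  # only reached if the loop finished without `break`
        result = None
    return result


print(first_negative([3, 1, -4, 1]))  # -4
print(first_negative([3, 1, 4]))      # None

# a `while` loop with an `else` clause behaves the same way
n = 0
while n < 3:
    n += 1
else:  # no `break` happened, so this runs
    print("finished without break, n =", n)  # finished without break, n = 3
```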
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 5.4</u>:
### Use a `for-else` clause to write a function 
### 
### `quasi_newton(func, x0, maxiter=10, eps=1e-6, tol=1e-5)`
### 
### that uses a quasi-Newton method with finite-difference step size `eps > 0` 
### to find a root of `func` and return it as soon as `abs(func(x)) < tol`.
### 
### If the method has not converged after `maxiter` iterations, raise an error.
### 
### **FYI**: $\quad x_{n+1} = x_n - \tfrac{f(x_n)}{f^\prime(x_n)}$, $\quad$ with $\quad f^\prime(x_n) \approx \frac{f(x_n + \varepsilon) - f(x_n)}{\varepsilon}$.

In [None]:
""" [IMPORT] """
from numbers import Number  # needed for the `x0: Number` annotation below
from typing import Callable


""" [HELPER] - custom convergence not reached error """
class FailedToConvergeError(Exception):
    pass


""" [SETUP] """
maxiter = 20
func = lambda x: 10 * x**3 - 5 * x**2 + 6 * x  # one root is x == 0


""" [FOCUS] - your code here """
def quasi_newton(func: Callable, x0: Number, maxiter: int = 10, eps: float = 1e-6, tol: float = 1e-5):
    assert (eps := float(eps)) > 0
    assert (tol := float(tol)) > 0
    assert (maxiter := int(maxiter)) > 0
    
    # your code here
    

""" [HELPER] - validate, one should converge, the other one not """
for x0 in (5, 10000):
    try:
        print('Found the root: ', quasi_newton(func, x0, maxiter=maxiter))
    except FailedToConvergeError:
        print('The quasi-Newton method failed to converge after {} iterations.'.format(maxiter))

### 
### solution:

In [None]:
""" [IMPORT] """
from typing import Callable


""" [HELPER] - custom convergence not reached error """
class FailedToConvergeError(Exception):
    pass


""" [SETUP] """
maxiter = 20
func = lambda x: 10 * x**3 - 5 * x**2 + 6 * x  # one root is x == 0


""" [FOCUS] """
def quasi_newton(func: Callable, x0, maxiter: int = 10, eps: float = 1e-6, tol: float = 1e-5):
    assert (eps := float(eps)) > 0
    assert (tol := float(tol)) > 0
    assert (maxiter := int(maxiter)) > 0
    
    x, y = x0, func(x0)
    
    for _ in range(maxiter):
        fprime = (func(x + eps) - y) / eps  # approximate derivative
        x -= y / fprime
        
        if abs(y := func(x)) < tol:  # walrus operator: assign y = func(x) and test it in the same line
            return x
    else:
        raise FailedToConvergeError(f"Failed to converge after {maxiter} iterations.")
        

""" [HELPER] - validate, one should converge, the other one not """
for x0 in (5, 10000):
    try:
        print('Found the root: ', quasi_newton(func, x0, maxiter=maxiter))
    except FailedToConvergeError:
        print('The quasi-Newton method failed to converge after {} iterations.'.format(maxiter))
    except Exception as ex:
        raise Exception from ex

### 
<hr style="border:1px solid blue">

### 
### We have seen how the `itertools` module provides tools 
### for substantially reducing the required number of lines of code.
### Besides that, it also provides the infrastructure for writing code 
### that is more scalable in the number of inputs.
### 
### We will now discuss the `itertools.product` iterator which can make it
### laughably easy to make a function scale to arbitrary dimensionality.

### 
### <u>Task</u>:
### We are given a tuple of arbitrary length $N$ whose elements are themselves tuples of numbers
### (each of arbitrary length).
### Write a generator that yields all length-$N$ combinations that take one element from each tuple.
### For instance, for `N == 2`: `(1, 2), (3, 4) -> (1, 3), (1, 4), (2, 3), (2, 4)`
### 

### If we knew the number $N$ in advance, we could approach this problem like so:

In [None]:
""" [SETUP] """
elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4), (7, 8)


""" [FOCUS] - iterators returning the combinations for various `N` """

# N = 0
def generate_combinations0(*elements):
    assert not elements
    yield ()

# N = 1
def generate_combinations1(*elements):
    elems0, = elements
    for elem in elems0:
        yield elem,

# N = 2
def generate_combinations2(*elements):
    assert len(elements) == 2
    for elem0 in elements[0]:
        for elem1 in elements[1]:
            yield (elem0, elem1)

# N = 3
def generate_combinations3(*elements):
    assert len(elements) == 3
    for elem0 in elements[0]:
        for elem1 in elements[1]:
            for elem2 in elements[2]:
                yield (elem0, elem1, elem2)
                
# and so on ...


""" [HELPER] - validate; unpacking an iterator with `*` consumes it. """
print(f"All combinations of {elements0}: \n\n", *generate_combinations0(*elements0), '\n\n')
print(f"All combinations of {elements1}: \n\n", *generate_combinations1(*elements1), '\n\n')
print(f"All combinations of {elements2}: \n\n", *generate_combinations2(*elements2), '\n\n')
print(f"All combinations of {elements3}: \n\n", *generate_combinations3(*elements3), '\n\n')

### 
### Writing a separate function for each number $N$ of input tuples would clearly be cumbersome.
### 
### Let us solve this using a (pythonic) recursive implementation, i.e., a function that calls itself.

In [None]:
""" [SETUP] """
elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4), (7, 8)


""" [FOCUS] - iterator returning the combinations for any `N` """

# cumbersome but at least pythonic ;)
# you don't need to understand this function (unless you're interested)

def generate_combinations(*elements):
    if len(elements) == 0:
        yield ()
        return
    head, *tail = elements
    for elem0 in head:
        for rest in generate_combinations(*tail):
            yield (elem0,) + rest


""" [HELPER] - validate; unpacking an iterator with `*` consumes it. """
print(f"All combinations of {elements0}: \n\n", *generate_combinations(*elements0), '\n\n')
print(f"All combinations of {elements1}: \n\n", *generate_combinations(*elements1), '\n\n')
print(f"All combinations of {elements2}: \n\n", *generate_combinations(*elements2), '\n\n')
print(f"All combinations of {elements3}: \n\n", *generate_combinations(*elements3), '\n\n')

### 
### This one can indeed handle any input length. However, is it very readable?
### This example may still be manageable, but once things get more complex, you'll easily lose track.
### 
### Instead, we may generate all combinations using `itertools.product(*elements)`:

In [None]:
""" [IMPORT] """
from itertools import product


""" [SETUP] """
elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4), (7, 8)


""" [FOCUS] + validate """
# we can simply replace generate_combinations -> product
print(f"All combinations of {elements0}: \n\n", *product(*elements0), '\n\n')
print(f"All combinations of {elements1}: \n\n", *product(*elements1), '\n\n')
print(f"All combinations of {elements2}: \n\n", *product(*elements2), '\n\n')
print(f"All combinations of {elements3}: \n\n", *product(*elements3), '\n\n')

### Laughably easy ;-)
### `itertools.product` yields the tuples in the same order as the equivalent nested for loops,
### i.e., the rightmost element advances fastest.
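### 
### A quick check that `product` matches explicit nested loops, plus its `repeat` keyword argument
### (the sample tuples are arbitrary):

```python
from itertools import product

digits, letters, flags = (1, 2), ("x", "y"), (True, False)

# explicit nested loops: the innermost (rightmost) loop advances fastest
nested = [(num, char, flag) for num in digits for char in letters for flag in flags]

# product yields the same tuples in the same order
same_order = list(product(digits, letters, flags)) == nested
print(same_order)  # True

# `repeat=n` takes the product of an iterable with itself n times
pairs = list(product((0, 1), repeat=2))
print(pairs)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```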
### 
<hr style="border:1px solid blue">


### 
## <u>Exercise 5.5</u>:
### You are given $N$ lists of strings. Each string represents a univariate basis function (written as a mathematical expression).
### Create a list containing the basis functions that result from the tensor product of all the univariate bases.
### 
### Note that if you have `str0` and `str1`, the product is `'{} * {}'.format(str0, str1)` (don't forget the ` * `).

In [None]:
""" [IMPORT] """
import numpy as np


""" [SETUP] - first 5 Fourier basis functions over x0 = [0, 2] """

# 1
# exp(1 * pi * 1j * x0)
# exp(2 * pi * 1j * x0)
# exp(3 * pi * 1j * x0)
# exp(4 * pi * 1j * x0)

fourier_basis0 = ['exp({} * pi * 1j * x0)'.format(n) if n else '1' for n in range(5)]

# ------


""" [SETUP] - First 3 canonical polynomial basis functions over x1 """

# 1
# x1 ** 1
# x1 ** 2

pol_basis1 = ['x1**{}'.format(n) if n else '1' for n in range(3)]

# ------


""" [SETUP] - Gaussian basis functions exp(-(x2 - a)^2) for various a over x2 """

# exp(-(x2 - 0.0)^2)
# exp(-(x2 - 0.5)^2)
# exp(-(x2 - 1.0)^2)

gauss_basis2 = ['exp(-(x2 - {:.5g})**2)'.format(a) for a in np.linspace(0, 1, 3) ]

# ------


""" [HELPER] - print them """
univariate_bases = [fourier_basis0, pol_basis1, gauss_basis2]
for i, basis in enumerate(univariate_bases):
    print("univariate basis number {}: \n\n{}\n\n".format(i, '\n \n'.join(basis)))

### 
### template:

In [None]:
""" [IMPORT] """
from itertools import product


""" [FOCUS] - your one liner here """
# trivariate_basis = ???


print('Trivariate basis function \n')


""" [HELPER] - validate by printing all funcs """
for i, func in enumerate(trivariate_basis):
    print(f'v{i}: {func}', '\n')

### 
### solution:

In [None]:
""" [IMPORT] """
from itertools import product


""" [FOCUS] """
trivariate_basis = list(map(' * '.join, product(*univariate_bases)))


print('Trivariate basis function \n')


""" [HELPER] - validate by printing all funcs """
for i, func in enumerate(trivariate_basis):
    print(f'v{i}: {func}', '\n')

### This is actually useful in the symbolic math library `sympy`:

In [None]:
""" [IMPORT] - sympy stuff """
from sympy import lambdify, symbols


""" [SETUP] - define symbols and pick one trivariate basis function """
x0, x1, x2 = symbols('x0 x1 x2')
func = trivariate_basis[35]


""" [FOCUS] """

# convert string to an actual function
f = lambdify([x0, x1, x2], func)

# we can actually evaluate that function.
print('v35 at (x0, x1, x2) = (1/7, .5, 0): ', f(1/7, .5, 0))

### 
<hr style="border:1px solid blue">

### 
### <u>One last thing</u>:
### Suppose you have a function of the form `f(*args)` and you would like to pass
### unpacked lists of varying length to `f` in a for loop using the `map(...)` function.

In [None]:
# example

""" [HELPER] - square arbitrary no. of pos. arguments """
def square(*args) -> list:
    return [i**2 for i in args]


""" [SETUP] """
inputs = [1, 2], [1, 2, 3], [0], [7, 8, 9, 10]


""" [FOCUS] """
for squares in map(lambda x: square(*x), inputs):
    print(squares)

### 
### We can use the `itertools.starmap` function, which automatically unpacks each element into positional arguments:

In [None]:
""" [IMPORT] """
from itertools import starmap


""" [FOCUS] """
for squares in starmap(square, inputs):
    print(squares)

### 
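### In other words, `starmap(f, iterable)` behaves like `map(lambda args: f(*args), iterable)`.
### A quick comparison using the built-in `pow`:

```python
from itertools import starmap

pairs = [(2, 5), (3, 2), (10, 3)]

# starmap(f, pairs) calls f(*pair) for each pair
with_starmap = list(starmap(pow, pairs))

# the equivalent formulation with map and a lambda that unpacks
with_map = list(map(lambda p: pow(*p), pairs))

print(with_starmap)  # [32, 9, 1000]
```
### 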
<hr style="border:1px solid blue">

### 

# What we have learned:
### 1. `for` / `while` loops have an optional `else` clause ;) (you can use it to impress your friends).
### 2. `itertools` helps you make loops more concise and readable.
### It can make generalisations to an arbitrary number of inputs laughably easy and is worth checking out.

### 
<hr style="border:1px solid blue">

### 

# <u>Lesson 6</u>: The use of `functools` and decorators


### In this lesson, we learn how to modify the behavior of functions without too much boilerplate.
### 
### We have seen how we can write functions that accept `list`s, `tuple`s, maybe even `np.ndarray`s, and
### handle them all equally well. 
### 
### This is called **type agnosticism**, i.e., one function with the same name
### that handles all kinds of different input types
### (returning differing output types from that function is considered bad design though).
### It can be very useful.
### 
### A good example is the `add_polynomials` function.

In [None]:
""" [IMPORT] """
from itertools import zip_longest
import numpy as np


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)  # object so they remain python floats, not numpy
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [FOCUS] - np.array, list, tuple are all fine """
def add_polynomials(*polynomials):
    return tuple(map(sum, zip_longest(*polynomials, fillvalue=0)))


""" [HELPER] - validate """
print(add_polynomials(pol0, pol1, pol2))

### 
### In more general cases, it may be useful to handle only one type of input inside of the function.
### 
### Type agnosticism is then easily achieved by converting the inputs to the desired type.
### 
### If an input can't be converted, too bad, we get a runtime error early on. 
### But that's **desired behaviour**.
### 
<hr style="border:1px solid blue">

### 
### For reasons that will become apparent shortly, in many cases it is desirable to convert the input(s)
### of a function to `tuple`s **before** they are passed to the function.
###
### To run the cell below, make sure you have `more_itertools` installed:
### `pip install more_itertools`

In [None]:
""" [IMPORT] """
from more_itertools import convolve  # multiplying two polynomials is the same as convolving their weights
import numpy as np


""" [HELPER] - convert arbitrary number of polynomials to tuples """
def convert_to_tuples(*polynomials):
    """
      [1, 2, 3], (2, 3, 4) -> (1, 2, 3), (2, 3, 4)
      [1, 2, 3], [2, 3, 4] -> (1, 2, 3), (2, 3, 4)
      (1, 2, 3), (2, 3, 4) -> (1, 2, 3), (2, 3, 4)
    """
    return tuple(map(tuple, polynomials))


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [FOCUS] - multiply any number of polynomials, cumbersome implementation """
def multiply_polynomials(*polynomials):
    
    ### special cases first
    
    if not polynomials:  # len == 0
        return ()
    if len(polynomials) == 1:
        return polynomials[0]
        
    ###

    first_two = tuple(convolve(*polynomials[:2]))  # multiply first two pols
    return multiply_polynomials(first_two, *polynomials[2:])  # multiply by the rest


""" [FOCUS] - convert to tuples and multiply """
polynomials = convert_to_tuples(pol0, pol1, pol2)  # we have to convert first (annoying)
print("The product of all polynomials is given by: \n\n", multiply_polynomials(*polynomials), '\n\n')

### 
<hr style="border:1px solid blue">

### 
### The `multiply_polynomials` function can be significantly simplified using the `functools.reduce` function:
### `reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])` $\implies$ `(((1 + 2) + 3) + 4) + 5`
### 
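### Here is the same fold written with the `operator` module instead of a lambda.
### The optional third argument of `reduce` is an initial value, which also makes an empty input work:

```python
from functools import reduce
from operator import add, mul

numbers = [1, 2, 3, 4, 5]

# reduce folds a two-argument function over the iterable from left to right
total = reduce(add, numbers)        # (((1 + 2) + 3) + 4) + 5
factorial_5 = reduce(mul, numbers)  # (((1 * 2) * 3) * 4) * 5
print(total, factorial_5)  # 15 120

# with an initial value, an empty iterable simply returns that value
print(reduce(add, [], 0))  # 0

# without an initial value, an empty iterable raises a TypeError
try:
    reduce(add, [])
except TypeError as err:
    print("TypeError:", err)
```
### 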

In [None]:
""" [IMPORT] """
from more_itertools import convolve  # multiplying two polynomials is the same as convolving their weights
from functools import reduce
import numpy as np


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] """
def convert_to_tuples(*polynomials):
    "([1, 2, 3], (2, 3, 4)) -> ((1, 2, 3), (2, 3, 4))"
    return tuple(map(tuple, polynomials))


""" [HELPER] - multiply two polynomials """
def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


""" [FOCUS] - multiply any number of polynomials using reduce """
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)  # fold `_multiply_pols` over the tuple, left to right


""" [FOCUS] - convert to tuples and multiply """
polynomials = convert_to_tuples(pol0, pol1, pol2)
print("The product of all polynomials is given by: \n\n", multiply_polynomials(*polynomials), '\n\n')

### 
<hr style="border:1px solid blue">

### 

### That's better!
### However, it's annoying to always have to convert the inputs to tuples before passing them to the function.
### Can we do this smarter?
### Maybe like this?

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np


""" [SETUP] - define polynomials """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] """
def convert_to_tuples(*polynomials):
    "([1, 2, 3], (2, 3, 4)) -> ((1, 2, 3), (2, 3, 4))"
    return tuple(map(tuple, polynomials))


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))

def _multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


""" [FOCUS] - compose `_multiply_polynomials` with tuple conversion """
multiply_polynomials = lambda *polynomials: _multiply_polynomials(*convert_to_tuples(*polynomials))


""" [HELPER] - no manual tuple conversion needed anymore """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### Now we have created a function that automatically converts the input to tuples
### and forwards them to the actual function, which does the work.
### It's better, but there is still a lot of boilerplate.
### 
<hr style="border:1px solid blue">

### 
### Let's try something else. How about we define a **function that takes a function**
### and returns a **new function** ?
### Check this out:

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np
from typing import Callable


""" [SETUP] - define polynomials """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] - multiply polynomials """
def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


""" [FOCUS] - take function and create adjusted function """
def convert_input_to_tuples(f: Callable) -> Callable:
    
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))  # convert into tuple and forward to `f`
        
    return converted_function


""" [FOCUS] - overwrite function """
multiply_polynomials = convert_input_to_tuples( multiply_polynomials )


""" [HELPER] - validate """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### 
### The line
### `multiply_polynomials = convert_input_to_tuples(multiply_polynomials)`
### is equivalent to the following **decorator**:
### 

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable


### This is what I call clean code


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] - multiply two polynomials """
def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


""" [HELPER] - take function and create adjusted function """
def convert_input_to_tuples(f: Callable) -> Callable:
    
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))
        
    return converted_function


""" [FOCUS] - 'decorate' functions to achieve same effect """
@convert_input_to_tuples
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)  # can you remove two lines of code using an `and` expression? ;-)


@convert_input_to_tuples
def other_heavy_computation(*polynomials):
    pass


@convert_input_to_tuples
def yet_another_heavy_computation(*polynomials):
    pass


""" [HELPER] - validate """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### <u> We conclude </u>:
### 
### The code snippet:
```python
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


multiply_polynomials = convert_input_to_tuples( multiply_polynomials )
```
### 
### is equivalent to
```python
@convert_input_to_tuples
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)
```
###
### You tell me which one is cleaner ;-)
### 
<hr style="border:1px solid blue">

### 

## <u>Exercise 6.1</u>:
### Extend the above to not only convert the polynomials to tuples but also sort the tuples
### to make the function agnostic to the order in which arguments are passed.
### This will come in handy very soon.
### 
### Do not write one decorator but two separate ones.
### 
### HINT: The `sorted` function can sort a list / tuple of tuples containing numbers:
### `sorted([(1, 2, 3), (0, 1, 2)]) == [(0, 1, 2), (1, 2, 3)]`
###
### Template:

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] """
def convert_input_to_tuples(f: Callable) -> Callable:
    
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))
        
    return converted_function


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


""" [FOCUS] - your decorator here """
def sort_input(f: Callable) -> Callable:
    # ???
    pass


""" [FOCUS] - which decorators do we add ? """

# add decorators here
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


""" [HELPER] - validate """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### solution:

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable
from functools import wraps


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] """
def convert_input_to_tuples(f: Callable) -> Callable:
    
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))
        
    return converted_function


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


""" [FOCUS] - decorator that sorts input """
def sort_input(f: Callable) -> Callable:
    
    @wraps(f)
    def wrapper(*polynomials):
        return f(*sorted(polynomials))
    
    return wrapper


""" [FOCUS] """
@convert_input_to_tuples  # first convert to tuples,
@sort_input               # then sort and forward to function
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


""" [HELPER] - validate """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### 
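### The solution above uses `functools.wraps`, which copies metadata such as `__name__` and `__doc__`
### from the wrapped function onto the wrapper. Here is a small comparison.
### The decorators `without_wraps` and `with_wraps` are illustrative only:

```python
from functools import wraps


def without_wraps(f):
    def wrapper(*args):
        return f(*args)
    return wrapper


def with_wraps(f):
    @wraps(f)  # copy __name__, __doc__, ... from f onto wrapper
    def wrapper(*args):
        return f(*args)
    return wrapper


@without_wraps
def double(x):
    """Return 2 * x."""
    return 2 * x


@with_wraps
def triple(x):
    """Return 3 * x."""
    return 3 * x


print(double.__name__, double.__doc__)  # wrapper None
print(triple.__name__, triple.__doc__)  # triple Return 3 * x.
```
### 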
<hr style="border:1px solid blue">

### 
### A popular application of decorators is function caching.
### Suppose you have a function that performs a heavy computation.
### 
### It is possible that the same computation has to be performed more than once.
### It would be nice if the function "remembered" the output given its inputs.
###
### Now we also understand why we converted the inputs to `tuple`s:
### the underlying data structure for caching is a `dict`, and a `dict` cannot use unhashable objects such as `list`s as keys!
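### 
### To see why, here is what happens when a `list` is used as a `dict` key (a minimal sketch; the names are illustrative):

```python
cache = {(1, 2, 3): "tuple keys are fine"}  # tuples of numbers are hashable

key_list = [1, 2, 3]
try:
    cache[key_list]  # lists are mutable and therefore unhashable
except TypeError as err:
    print("TypeError:", err)  # TypeError: unhashable type: 'list'

# converting to a tuple first makes the lookup work
print(cache[tuple(key_list)])  # tuple keys are fine
```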
### 
### <u>Task</u>: Write a decorator that does exactly that.

In [None]:
""" [IMPORT] """
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable
from functools import wraps


""" [SETUP] """
pol0 = np.array([1.0, 2.0, 4.0, -1.0], dtype=object)
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


""" [HELPER] - convert to tuple decorator + sort input decorator """
def convert_input_to_tuples(f: Callable) -> Callable:

    @wraps(f)  # make sure the new function has the same name, docstring, ... as `f`
    def wrapper(*polynomials):
        return f(*map(tuple, polynomials))

    return wrapper


def sort_input(f: Callable) -> Callable:

    @wraps(f)  # retain `f`'s docstring
    def wrapper(*polynomials):
        return f(*sorted(polynomials))

    return wrapper


""" [HELPER] - multiply two pols """
def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))



""" [FOCUS] - add a 'cache' to a function decorator """
def cache_input(f: Callable) -> Callable:
    f._cache = {}  # add empty dict to f

    @wraps(f)
    def wrapper(*polynomials):
        try:
            return f._cache[polynomials]
            
        except KeyError:  # key not found => compute, remember, return
            return f._cache.setdefault(polynomials, f(*polynomials))

    return wrapper


""" [FOCUS] - combine decorators """

@convert_input_to_tuples  # convert input to tuples
@sort_input               # sort input for commutative operations
@cache_input              # remember input output pairs
def multiply_polynomials(*polynomials):
    
    sleep(3)  # fake heavy computation
    
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)



""" [HELPER] - validate """
print("The product of all polynomials is given by: \n\n", multiply_polynomials(pol0, pol1, pol2), '\n\n')

### 
### This took a few seconds to compute.
### 
### To prove to you that the function remembers its input, we call it again.
### **The output should be immediate**.
### 
### Since we are sorting the input, the order in which we pass the polynomials should not matter.

In [None]:
""" [FOCUS] - call the function with `pol1, pol2, pol0` rather than `pol0, pol1, pol2` """
print( multiply_polynomials(pol1, pol2, pol0) )
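### 
### The order of stacked decorators matters. They are applied bottom-up, i.e.
### `multiply_polynomials = convert_input_to_tuples(sort_input(cache_input(multiply_polynomials)))`.
### At call time the arguments therefore pass through them top-down: first converted, then sorted, then looked up in the cache.
### The sketch below uses hypothetical toy names (`cache`, `sort_args`, `_add`). It shows why the cache must sit *below* the sorting step:

```python
from functools import wraps

calls = []  # records every argument tuple that reaches the undecorated function body


def cache(f):
    """Toy memoizer in the spirit of `cache_input`: a dict attached to `f`."""
    f._cache = {}

    @wraps(f)
    def wrapper(*args):
        if args not in f._cache:          # compute only on a cache miss
            f._cache[args] = f(*args)
        return f._cache[args]

    return wrapper


def sort_args(f):
    """Toy counterpart of `sort_input`."""
    @wraps(f)
    def wrapper(*args):
        return f(*sorted(args))

    return wrapper


def _add(*numbers):
    calls.append(numbers)
    return sum(numbers)


# Same as writing @sort_args above @cache above `def _add`: sorting happens before the cache lookup.
add = sort_args(cache(_add))
print(add(3, 1, 2), add(2, 3, 1))                     # 6 6
print(calls)                                          # [(1, 2, 3)]: the second call was a cache hit

# Reversed order: the cache now sees the *unsorted* arguments.
calls.clear()
add_reversed = cache(sort_args(_add))
print(add_reversed(3, 1, 2), add_reversed(2, 3, 1))   # 6 6
print(calls)                                          # [(1, 2, 3), (1, 2, 3)]: two misses
```

With the reversed order, every permutation of the arguments becomes its own cache key, and the sorting no longer saves any work.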

### 
<hr style="border:1px solid blue">

### 

### The `cache_input` decorator we wrote is a simplified version of
### the `functools.lru_cache` decorator.
### `functools.lru_cache` optionally takes a `maxsize` keyword argument (default `128`; `None` means unbounded),
### which limits the number of input / output pairs to memoize. Once the cache is full,
### the least recently used (LRU) pair is discarded, hence the name.
### 
### As in our self-written decorator, the input arguments need to be **hashable**.
### Example:

In [None]:
""" [IMPORT] - lru_cache """
from functools import lru_cache



""" [FOCUS] - Compute Fibonacci[n] without and with caching """

# Fibonacci: [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, ...]


def fibonacci(n):
    print('Computing the Fibonacci term with n={}'.format(n))
    
    if n < 2:
        return n
        
    return fibonacci(n-2) + fibonacci(n-1)


@lru_cache(maxsize=2)
def fibonacci_cached(n):
    print('Computing the Fibonacci term with n={}'.format(n))
    
    if n < 2:
        return n
        
    return fibonacci_cached(n-2) + fibonacci_cached(n-1)



""" [HELPER] - validate """

print('Computing the eighth term in the Fibonacci sequence without caching. \n')
print('Eighth Fibonacci term:', fibonacci(8))

print('\n')

print('Computing the eighth term in the Fibonacci sequence with caching. \n')
print('Eighth Fibonacci term:', fibonacci_cached(8))
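### 
### In the cell above, the uncached `fibonacci(8)` prints 67 "Computing ..." lines, whereas `fibonacci_cached(8)` prints only 9.
### Even with `maxsize=2`, every term is computed exactly once, because each call only revisits the two most recently computed terms.
### Functions wrapped by `lru_cache` also expose `cache_info()` (hit / miss statistics) and `cache_clear()`.
### Below is a minimal sketch. The names `fib` and `fib_unbounded` are ours, and the printing is dropped:

```python
from functools import lru_cache


@lru_cache(maxsize=2)
def fib(n):
    return n if n < 2 else fib(n - 2) + fib(n - 1)


print(fib(8))             # 21
print(fib.cache_info())   # CacheInfo(hits=6, misses=9, maxsize=2, currsize=2)

fib.cache_clear()         # empties the cache and resets the statistics
print(fib.cache_info())   # CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)


# maxsize=None gives an unbounded cache (functools.cache, Python 3.9+, is equivalent)
@lru_cache(maxsize=None)
def fib_unbounded(n):
    return n if n < 2 else fib_unbounded(n - 2) + fib_unbounded(n - 1)


print(fib_unbounded(30))                   # 832040
print(fib_unbounded.cache_info().misses)   # 31: each of the terms 0..30 is computed once
```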

### 
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. Decoration `@decorator` is syntactic sugar for `fn = decorator(fn)`, i.e., a simple function call. Stacked decorators are applied bottom-up.
### 2. `Python` functions are `first class citizens`: you can add attributes to them, pass them to functions and return them from functions.
### 3. `functools` provides tools for many functional programming patterns: `reduce`, `partial`, `wraps`, `lru_cache`, ...
### 4. `caching` can be accomplished by attaching a `hashmap` (`dict`) of input / output pairs to a function.
### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 7</u>:
## The concept of `hashing`.

### 
### We have made extensive use of `Python` `dictionaries`, which are implemented as hashmaps (hash tables).
### 
### The questions are:
### 1. Which types can we use as keys? 2. Which values can we put into a hashmap?
### <u>Answer to 2. </u>: everything.
### <u>Answer to 1. </u>: all built-in types that are `immutable` (cannot be changed during their lifetime), for instance:
### 
### a. Numeric types: `int`, `float`, `complex`, `bool` (`bool` is a subclass of `int`).
### b. Text and binary data: `str`, `bytes`.
### c. Collections: `tuple` (as long as the items are themselves hashable), `frozenset`.
### d. `NoneType` (the type of `None`).
### ... and a few more.
### 
### A characteristic of `immutable` built-ins is that they have a so-called `hash` value.
### 
### A `hash` is an `int` that is cheap to compute from an object's value.
### Designing good hash functions is a science of its own ...
### 
### The `Python` built-in function `hash` does exactly that:

In [None]:
""" [FOCUS] """
print(f'Hash of `5`: {hash(5)} (not surprising).\n')

print(f'Hash of `3.14`: {hash(3.14)}.\n')

print(f'Hash of `3.15`: {hash(3.15)} (outcome is quite different from `hash(3.14)`).\n')

print(f'The hash of (3.14, 5, 6.1): {hash((3.14, 5, 6.1))}. \n')
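### 
### Mutable built-ins (`list`, `dict`, `set`) have no hash at all: calling `hash` on them raises a `TypeError`.
### The same happens for a `tuple` that contains a mutable item, since a tuple's hash is computed from its items.
### A small helper makes this visible (`is_hashable` is our own, hypothetical name). It simply tries `hash`.
### A plain `isinstance(obj, collections.abc.Hashable)` check would not work here. It only looks for a `__hash__` method, so it reports `True` for `(1, [2])`.

```python
def is_hashable(obj) -> bool:
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:   # unhashable objects raise TypeError
        return False
    return True


examples = [5, 3.14, 1 + 2j, True, 'text', b'bytes', None,
            (1, 2), frozenset({1, 2}),   # immutable: hashable
            (1, [2]),                    # tuple with a mutable item: not hashable
            [1, 2], {1: 2}, {1, 2}]      # mutable built-ins: not hashable

for obj in examples:
    print(f'{obj!r:>18}: hashable = {is_hashable(obj)}')
```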

### 
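### The `Python` `hash` function is designed to make **hash collisions**, i.e. two
### different objects receiving the same `hash`, rare. They cannot be avoided completely, though.
### 
### A related rule: objects that compare equal must have equal hashes. Since `5 == 5.0`,
### `hash(5.0) == hash(5)` holds **by design** (this is not a collision):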
### 

In [None]:
""" [FOCUS] """

print(f'The hash of `5.0` is the same as the hash of `5` ? {hash(5.0) == hash(5)}.\n')
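### 
### Because `5 == 5.0` and their hashes agree, `5` and `5.0` address the same `dict` entry (as do `1` and `True`).
### A genuine collision, by contrast, does not affect correctness: the equality check tells the keys apart.
### The collision used below is a CPython implementation detail. `-1` is reserved internally as an error code,
### so `hash(-1)` is `-2`. Other Python implementations may behave differently.

```python
d = {5: 'int five'}
print(d[5.0])               # int five: 5.0 == 5 and hash(5.0) == hash(5)

d = {1: 'one', True: 'true'}
print(d)                    # {1: 'true'}: True == 1, so the second pair only updates the value

print(hash(-1), hash(-2))   # -2 -2 on CPython: a genuine collision
d = {-1: 'minus one', -2: 'minus two'}
print(d[-1], d[-2])         # minus one minus two: both entries survive the collision
```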

###
### When we invoke `d[key]` on a dictionary `d`, here is what happens under the hood (simplified):
### 
### 1. Python computes the `hash` of `key` by calling `hash(key)`.
### 
### 2. The resulting `hash` value is mapped to an index in the underlying array
### (sometimes called a "bucket array") using a transformation such as a modulo operation.
### This index points to where Python expects to find the key-value pair.
### 
### 3. If the slot at that index is empty (no key-value pair stored), a `KeyError` is raised.
### 
### 4. If the slot contains a key-value pair, Python checks whether the input `key` is
### equal to the stored key, i.e., whether `input_key == stored_key` is `True`.
### 
### 5. If the keys are equal, the associated value is returned. Otherwise:
###    - The slot is occupied by a different key (due to a `hash collision`), so
###      Python probes further slots in a fixed sequence until it finds the key.
###      If it reaches an empty slot instead, a `KeyError` is raised.
###
### Note: Python's `dict` achieves $\mathcal{O}(1)$ lookups on average (and amortized $\mathcal{O}(1)$ insertions).
### However, many hash collisions slow lookups down, and an insertion occasionally triggers a costly resize of the underlying array.
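### 
### The five lookup steps can be sketched as a toy hash table. This sketch makes simplifying assumptions of our own:
### a fixed-size table, no deletions, no resizing, and *linear* probing. CPython's `dict` uses a more elaborate,
### perturbation-based probe sequence. The class name `ToyHashMap` is hypothetical.

```python
class ToyHashMap:
    """Minimal open-addressing hash table: fixed size, linear probing, no deletions."""

    _EMPTY = None  # marker for an unused slot (used slots hold (hash, key, value) tuples)

    def __init__(self, size: int = 8):
        self._slots = [self._EMPTY] * size

    def _probe(self, key):
        """Yield (hash, index) for every slot to try, starting at hash(key) % size."""
        h = hash(key)                              # step 1: compute the hash
        start = h % len(self._slots)               # step 2: map the hash to an index
        for offset in range(len(self._slots)):     # step 5: on a collision, try the next slot
            yield h, (start + offset) % len(self._slots)

    def __setitem__(self, key, value):
        for h, i in self._probe(key):
            slot = self._slots[i]
            if slot is self._EMPTY or (slot[0] == h and slot[1] == key):
                self._slots[i] = (h, key, value)   # insert a new pair or overwrite an existing one
                return
        raise RuntimeError('table is full')        # a real dict would have resized long before

    def __getitem__(self, key):
        for h, i in self._probe(key):
            slot = self._slots[i]
            if slot is self._EMPTY:                # step 3: empty slot => key is absent
                raise KeyError(key)
            stored_hash, stored_key, value = slot
            if stored_hash == h and stored_key == key:   # step 4: compare hashes, then keys
                return value
        raise KeyError(key)                        # probed every slot without success


table = ToyHashMap()
table[5] = 'five'
table[-1] = 'minus one'
table[-2] = 'minus two'      # on CPython hash(-1) == hash(-2): these two keys collide
print(table[5.0])            # five: equal hashes and 5.0 == 5
print(table[-1], table[-2])  # minus one minus two
```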
###
### What makes a good hash function?
### 
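### 1. **Deterministic**: The same input always produces the same `hash`. For `str` and `bytes` this holds within one interpreter run: their hashes are salted per process unless `PYTHONHASHSEED` is set.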
### 2. **Fast**: The `hash` should be cheap to compute.
### 3. **Low Collisions**: It should minimize `hash collisions` (different keys mapping to the same index).
###
### Python's `hash` function is robust and efficient for general-purpose use, but you can
### implement custom `__hash__` and `__eq__` (checking equality of two objects) methods if needed.
### The two must agree: objects that compare equal must have the same `hash`.
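### 
### Here is a minimal sketch of a user-defined key type (the class `Point` is our own illustration).
### `__hash__` is derived from exactly the data `__eq__` compares, so equal points get equal hashes.
### Note that a class which defines `__eq__` but not `__hash__` gets `__hash__ = None` and is therefore unhashable.
### The coordinates sit in a tuple behind read-only properties, since a key must not change while it is stored in a `dict`.

```python
class Point:
    """2D point usable as a dict key; its coordinates are read-only."""

    def __init__(self, x, y):
        self._coords = (x, y)   # leading underscore: not meant to be changed after construction

    @property
    def x(self):
        return self._coords[0]

    @property
    def y(self):
        return self._coords[1]

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented  # let Python try the other operand's __eq__
        return self._coords == other._coords

    def __hash__(self):
        # agrees with __eq__: equal points -> equal coordinate tuples -> equal hashes
        return hash(self._coords)

    def __repr__(self):
        return f'Point({self.x}, {self.y})'


distances = {Point(0, 0): 0.0, Point(3, 4): 5.0}
print(distances[Point(3, 4)])            # 5.0: a new but equal Point finds the entry
print(len({Point(1, 2), Point(1, 2)}))   # 1: a set keeps only one of two equal points
```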
###

### To see why Python's built-in `mutable` types (`list`, `dict`, `set`) cannot be used as keys in a `hashmap`,
### here is a minimal example that sidesteps this restriction:
### 

In [None]:
""" [HELPER] - behaves like a list but also has a hash value """
class HashableList(list):
    def __hash__(self):
        return hash(tuple(self))
    

""" [SETUP] """
hlist = HashableList([1, 2, 3])
hashmap = {}


""" [FOCUS] """

hashmap[hlist] = 5
print("The hashmap BEFORE mutating `hlist`:", hashmap, '\n')

hlist.append(4)
print("The hashmap AFTER mutating `hlist`:", hashmap)

### 
### The output suggests that invoking `hashmap[hlist]`, 
### with `hlist == [1, 2, 3, 4]`, should return `5`.
### 
### Let's see if that's the case:

In [None]:
""" [FOCUS] - try to retrieve the value of key [1, 2, 3, 4] """
hashmap[hlist]

### 
### Bummer !
### 
### Let's try to use the old key `HashableList([1, 2, 3])` in `hashmap`.
### 

In [None]:
""" [FOCUS] - try to retrieve based on the old key """
hashmap[ HashableList([1, 2, 3]) ]

###
### That's very unfortunate !
###
### Can we understand why we get these errors ?
### 
### <u> Here is what happens </u>: 
### 1. `hashmap` maps `hash(hlist)` (where `hlist = [1, 2, 3]`)
### to a bucket containing the `key == [1, 2, 3]`.
### 
### 2. `key == [1, 2, 3]` is mutated in place to `key == [1, 2, 3, 4]`.
### 
### 3. Now, the old hash value `hash([1, 2, 3])` maps to a bucket containing `key == [1, 2, 3, 4]`.
### The hash value no longer matches the mutated object `[1, 2, 3, 4]`
### because it is based on the content `[1, 2, 3]` at the time it was added to `hashmap`.
### 
### 4. Trying to look up `[1, 2, 3, 4]` in `hashmap` raises a `KeyError`
### because `hash([1, 2, 3, 4])` maps to an empty bucket. That's not so surprising.
### 
### 5. Conversely, trying to look up `[1, 2, 3]` in `hashmap` raises a `KeyError` 
### even though the bucket corresponding to `hash([1, 2, 3])` is nonempty.
### This is because the object stored in this bucket is `[1, 2, 3, 4]`, 
### and `[1, 2, 3] != [1, 2, 3, 4]` (equality check fails).
### 
### Et voilà, you've broken your `hashmap` by using a mutable object as a key: the entry stays unreachable for as long as the key remains mutated.
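### 
### Two follow-up observations, sketched with a fresh copy of the toy class (the variable names are ours).
### First, plain lists refuse to be hashed in the first place.
### Second, an entry keyed by a mutated `HashableList` becomes reachable again once the mutation is undone, because content, hash and equality then match the stored key again.

```python
class HashableList(list):
    """Same toy class as above: a list whose hash is computed from its current content."""
    def __hash__(self):
        return hash(tuple(self))


try:
    hash([1, 2, 3])
except TypeError as err:
    print('TypeError:', err)                 # e.g. TypeError: unhashable type: 'list'

key_list = HashableList([1, 2, 3])
mapping = {key_list: 5}

key_list.append(4)                           # the entry is now unreachable (see above) ...
key_list.pop()                               # ... until the mutation is undone
print(mapping[key_list])                     # 5
print(mapping[HashableList([1, 2, 3])])      # 5
```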
### 
<hr style="border:1px solid blue">

### 
### Can we still hash `mutable` types ?
### Indirectly, yes, and we have done it before.
### 
### One way to circumvent this issue is by converting a mutable type to an immutable
### one before using it as a `key` in a `hashmap` (or cached function, for that matter).
### For this we need to make a copy of the data.
### 

In [None]:
""" [IMPORT] """
import numpy as np
from typing import Tuple
from functools import lru_cache, wraps
from time import sleep



""" [HELPER] - convert np.ndarray to hashable type (serialize) and undo conversion (deserialize) """
def serialize_array(arr: np.ndarray) -> Tuple[bytes, Tuple[int, ...]]:
    arr = np.asarray(arr, dtype=float)
    return (arr.tobytes(), arr.shape)  # a tuple containing two hashable types


def deserialize_array(serialized_array: Tuple[bytes, Tuple[int, ...]]):
    byte_array, shape = serialized_array
    return np.frombuffer(byte_array).reshape(shape)  # you can uniquely recreate the array from the input


""" [HELPER] - decorators for serializing and deserializing array input(s) """
def serialize_inputs(fn):
    
    @wraps(fn)
    def wrapper(*args):
        return fn(*map(serialize_array, args))
        
    return wrapper


def deserialize_inputs(fn):
    
    @wraps(fn)
    def wrapper(*args):
        return fn(*map(deserialize_array, args))
        
    return wrapper



""" [SETUP] - two arrays """
arr0 = np.arange(3, dtype=float)
arr1 = np.arange(3, 6, dtype=float)



""" [FOCUS] """
@serialize_inputs    # convert to hashable type
@lru_cache           # invoke caching on the hashable type
@deserialize_inputs  # convert back to numpy array and forward to function
def heavy_computation(arr0: np.ndarray, arr1: np.ndarray) -> np.ndarray:
    print(f'Performing a heavy computation with inputs {arr0} and {arr1} ....\n')
    sleep(1)
    return arr0 ** 3 + arr1 ** 4



""" [HELPER] - validate """
print(f'result: {heavy_computation(arr0, arr1)}.\n\n')

arr0[-1] = 10
print(f'Changed `arr0` in place, which is now {arr0}.\n')

print(f'result: {heavy_computation(arr0, arr1)}.\n\n')

arr0[-1] = 2
print(f'Changed arr0 back to the original array: {arr0}.\n')

print(f"result: {heavy_computation(arr0, arr1)} (Used cached result, the function didn't print).\n\n")

### Convince yourself that this implementation is `safe`.
### 
### Note that there is a slight overhead resulting from converting
### the arrays back and forth.
### Make sure that your computation is sufficiently heavy for this to be worth it.
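### 
### A few lines support the safety argument (a minimal sketch reusing the recipe of `serialize_array`).
### The key is a **copy** of the data taken at call time, so later in-place changes to the array cannot alter a stored key.
### The shape is part of the key because arrays of different shapes can have identical bytes.
### One caveat remains: on a cache hit, the *same* result array is returned again, so a caller that mutates it
### would corrupt the cache. Marking the result read-only (as `frozen_cached_property` does in the final example) guards against that.

```python
import numpy as np

arr = np.arange(3, dtype=float)
key = (arr.tobytes(), arr.shape)              # same recipe as `serialize_array`

arr[-1] = 10.0                                # mutate the array in place ...
print(key == (arr.tobytes(), arr.shape))      # False: the stored key is an independent snapshot

arr[-1] = 2.0                                 # ... and restore it
print(key == (arr.tobytes(), arr.shape))      # True: same content => same key => cache hit

# Same bytes, different shapes: the shape must be part of the key.
print(np.zeros((2, 3)).tobytes() == np.zeros((3, 2)).tobytes())   # True

# Protect a (cached) result from being modified by callers.
result = arr ** 3
result.flags.writeable = False
try:
    result[0] = -1.0
except ValueError as err:
    print('ValueError:', err)                 # the array is read-only
```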
### 
<hr style="border:1px solid blue">

### 
### Topics we couldn't cover due to time constraints:
### 1. Use of collections: namedtuple, ChainMap, defaultdict (useful alternative to `dict.setdefault`) <- this is definitely worth checking out.
### 2. dataclass - a nice gimmick. Helpful in some cases.
### 3. Iterating over `numpy.ndarray`'s - not super important because you should avoid loops in numpy ;-)
### 4. Writing functions that can handle several input types (achieving `multiple dispatch`).
### <- the `type agnosticism` we built into some of our functions is similar to traditional `multiple dispatch`. Have a look if you're interested.
### 
<hr style="border:1px solid blue">

### 
### To conclude, here are a pythonic and a non-pythonic implementation of the same class.
### <u>Pythonic</u>:

In [None]:
# %load mess/mesh_neat.py
import numpy as np
from numpy import newaxis as _
import meshio

from functools import cached_property, wraps


def frozen_cached_property(fn):
  @cached_property
  @wraps(fn)
  def wrapper(self, *args, **kwargs):
    ret = fn(self, *args, **kwargs)
    ret.flags.writeable = False
    return ret
  return wrapper


class Triangulation:

  def __init__(self, mesh):
    assert isinstance(mesh, meshio._mesh.Mesh)
    self.mesh = mesh

  @frozen_cached_property
  def triangles(self):
    return self.mesh.cells_dict['triangle']

  @frozen_cached_property
  def lines(self):
    return self.mesh.cells_dict['line']

  @frozen_cached_property
  def normals(self):
    ts = (self.points[self.lines] * np.array([-1, 1])[_, :, _]).sum(1)
    ns = ts[:, ::-1] * np.array([[1, -1]])
    return ns / np.linalg.norm(ns, ord=2, axis=1)[:, _]

  @frozen_cached_property
  def points(self):
    return self.mesh.points[:, :2]

  def points_iter(self):
    for tri in self.triangles:
      yield self.points[tri]

  @frozen_cached_property
  def BK(self):
    a, b, c = self.points[self.triangles.T]
    return np.stack([b - a, c - a], axis=2)

  @frozen_cached_property
  def detBK(self):
    return np.abs(np.linalg.det(self.BK))

  @frozen_cached_property
  def detBK_boundary(self):
    a, b = self.points[self.lines.T]
    return np.linalg.norm(b - a, ord=2, axis=1)

  @frozen_cached_property
  def BKinv(self):
    (a, b), (c, d) = self.BK.T
    return np.rollaxis(np.stack([[d, -b], [-c, a]], axis=1), -1) / self.detBK[:, _, _]

  @frozen_cached_property
  def boundary_indices(self):
    return np.unique(self.lines)  # np.unique flattens its input and returns the sorted unique values

### non-pythonic:

In [None]:
# %load mess/mesh_mess.py
import numpy as np
import meshio


class Triangulation:

  def __init__(self, mesh):
    assert isinstance(mesh, meshio._mesh.Mesh)
    self.mesh = mesh

  @property
  def triangles(self):
    if not hasattr(self, '_triangles'):
      self._triangles = self.mesh.cells_dict['triangle']
      self._triangles.flags.writeable = False
      return self._triangles
    else:
      return self._triangles

  @property
  def lines(self):
    if not hasattr(self, '_lines'):
      self._lines = self.mesh.cells_dict['line']
      self._lines.flags.writeable = False
      return self._lines
    else:
      return self._lines

  @property
  def normals(self):
    if not hasattr(self, '_normals'):
      ns = []
      for i in range(len(self.lines)):
        line = self.lines[i]
        t = self.points[line[1]] - self.points[line[0]]
        norm = np.linalg.norm(t)
        ns.append( np.array([t[1], - t[0]]) / norm )  # (t1, -t0) is perpendicular to the tangent t
      self._normals = np.array(ns)
      self._normals.flags.writeable = False
      return self._normals
    else:
      return self._normals

  @property
  def points(self):
    if not hasattr(self, '_points'):
      points = self.mesh.points[:, :2]
      points.flags.writeable = False
      self._points = points
      return self._points
    else:
      return self._points

  def points_iter(self):
    for i in range(len(self.triangles)):
      tri = self.triangles[i]
      ret = np.empty((3, 2), dtype=float)
      for j in range(3):
        ret[j, :] = self.points[tri[j], :]
      yield ret

  @property
  def BK(self):
    if not hasattr(self, '_BK'):
      BK = np.empty((len(self.triangles), 2, 2), dtype=float)
      for i in range(len(BK)):
        a = self.points[self.triangles[i, 0]]
        b = self.points[self.triangles[i, 1]]
        c = self.points[self.triangles[i, 2]]
        BK[i, :, 0] = b - a
        BK[i, :, 1] = c - a
      BK.flags.writeable = False
      self._BK = BK
      return self._BK
    else:
      return self._BK

  @property
  def detBK(self):
    if not hasattr(self, '_detBK'):
      BK = self.BK
      self._detBK = np.empty((len(BK),), dtype=float)
      for i in range(len(self._detBK)):
        self._detBK[i] = abs(np.linalg.det(BK[i, :, :]))
      self._detBK.flags.writeable = False
      return self._detBK
    else:
      return self._detBK

  @property
  def BKinv(self):
    if not hasattr(self, '_BKinv'):
      self._BKinv = np.empty((len(self.triangles), 2, 2), dtype=float)
      for i in range(len(self._BKinv)):
        self._BKinv[i, 0, 0] = self.BK[i, 1, 1] / self.detBK[i]
        self._BKinv[i, 0, 1] = -self.BK[i, 0, 1] / self.detBK[i]
        self._BKinv[i, 1, 0] = -self.BK[i, 1, 0] / self.detBK[i]
        self._BKinv[i, 1, 1] = self.BK[i, 0, 0] / self.detBK[i]
      self._BKinv.flags.writeable = False
      return self._BKinv
    else:
      return self._BKinv

  @property
  def boundary_indices(self):
    if not hasattr(self, '_boundary_indices'):
      self._boundary_indices = np.sort(np.unique(self.lines.reshape((self.lines.shape[0] * self.lines.shape[1],))))
      self._boundary_indices.flags.writeable = False
      return self._boundary_indices
    else:
      return self._boundary_indices

### 
### quite a difference ;-)
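### 
### The key difference is the caching pattern. `functools.cached_property` (Python 3.8+) computes the value on first access
### and stores it in the instance's `__dict__`, replacing every `if not hasattr(self, '_x'): ...` block.
### Below is a self-contained sketch that reuses `frozen_cached_property` on a hypothetical toy class `Squares` (no `meshio` needed):

```python
from functools import cached_property, wraps

import numpy as np


def frozen_cached_property(fn):
    """cached_property whose (NumPy array) value is made read-only on first access."""
    @cached_property
    @wraps(fn)
    def wrapper(self):
        ret = fn(self)
        ret.flags.writeable = False
        return ret
    return wrapper


class Squares:
    """Toy stand-in for `Triangulation`."""

    def __init__(self, n):
        self.n = n
        self.evaluations = 0   # counts how often the property body actually runs

    @frozen_cached_property
    def values(self):
        self.evaluations += 1
        return np.arange(self.n) ** 2


sq = Squares(4)
print(sq.values, sq.values)   # [0 1 4 9] [0 1 4 9]
print(sq.evaluations)         # 1: the second access reads the cached attribute
print('values' in vars(sq))   # True: the result lives in the instance __dict__

try:
    sq.values[0] = 42
except ValueError as err:
    print('ValueError:', err)  # frozen: the cached array is read-only
```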