# **<u>Part 1</u>:** Idiomatic Python

### 
<img src="img/idiom.png" alt="Drawing" width="1200"/>

### 

### **Idiomatic Python** refers to the use of constructs
### that are characteristic of the Python programming language
### and typically do not exist in the exact same form in other languages.
### 
### **In other words**: idiomatic Python refers to a "pythonic" coding
### style that requires in-depth knowledge of the language's syntax,
### built-in modules and coding conventions.
### 
### Idiomatic Python code follows the conventions and best practices of the Python language. 
### It emphasises **readability**, **simplicity** and achieving **efficient, clear, and concise code.**
### 
### Here is an inspirational quote:
### 
### _"Code is read much more often than it is written."_
### 
### - Guido van Rossum (as paraphrased in PEP 8)
### 
### 
<hr style="border:1px solid blue">

### 

## <u>Some general remarks about the lessons</u>:
## Try not to focus on the details too much. Try to see the bigger picture.
## If you see something that you think will help you, you can always
## come back to the notebook and try to understand it better !
### 
<hr style="border:1px solid blue">

### 

# <u>Lesson 1</u>: what makes code readable ?

### 

### <u>Task</u>: take the union of two meshes.
### Given two sets of elements and their corresponding points,
### create a union mesh from the two input meshes without duplicate elements and points.

In [None]:
# make two triangular meshes characterised by their element indices and points

import numpy as np

elems0 = np.array([ [0, 1, 4],
                    [4, 1, 5],
                    [1, 2, 5],
                    [5, 2, 6],
                    [2, 3, 6],
                    [6, 3, 7],
                    [4, 5, 8],
                    [8, 5, 9],
                    [5, 6, 9],
                    [9, 6, 10],
                    [6, 7, 10],
                    [10, 7, 11] ])

points0 = np.stack(list(map(np.ravel, np.meshgrid( np.linspace(0, 3, 4),
                                                   np.linspace(0, 2, 3) ))), axis=1)


# same mesh but shifted by +2 in the x direction
elems1 = elems0
points1 = points0 + np.array([[2, 0]])

In [None]:
from matplotlib import pyplot as plt

def plot_meshes(list_of_elements, list_of_points):

    fig, ax = plt.subplots()
    for elems, points in zip(list_of_elements, list_of_points):
        ax.triplot(*points.T, elems, alpha=0.5)

    plt.show()

### 
### we plot the two meshes we just created

In [None]:
plot_meshes([elems0, elems1], [points0, points1])

### 
<hr style="border:1px solid blue">

### 

### The absolute bloody beginner solution

#### (seeing it breaks my heart)

In [None]:
elements = [elems0, elems1]
points = [points0, points1]

map_point_index = {}
index = 0
new_elements = []
new_points = []
seen = set()

for i in range(2):
    myelems = elements[i]
    mypoints = points[i]
    my_new_elems = []
    for j in range(len(myelems)):
        my_new_elem = []
        myelement = myelems[j]
        for k in range(len(myelement)):
            myindex = myelement[k]
            mypoint = tuple(mypoints[myindex])
            if mypoint not in map_point_index:
                map_point_index[mypoint] = index
                index += 1
                new_points.append(mypoint)
            my_new_elem.append(map_point_index[mypoint])
        my_identifier = tuple(sorted(my_new_elem))
        if my_identifier not in seen:
            my_new_elems.append(my_new_elem)
            seen.add(my_identifier)
    new_elements.append(my_new_elems)

new_elements = np.concatenate(new_elements)
new_points = np.array(new_points)

plot_meshes([new_elements], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elements))

### 

### The somewhat more idiomatic solution

#### (still not good code)

In [None]:
map_point_index = {}
index = 0
new_elements = []
new_points = []
seen = set()

for myelems, mypoints in zip([elems0, elems1], [points0, points1]):
    my_new_elems = []
    for myelement in myelems:
        my_new_elem = []
        for myindex in myelement:
            mypoint = tuple(mypoints[myindex])
            if mypoint not in map_point_index:
                map_point_index[mypoint] = index
                index += 1
                new_points.append(mypoint)
            my_new_elem.append(map_point_index[mypoint])
        my_identifier = tuple(sorted(my_new_elem))
        if my_identifier not in seen:
            my_new_elems.append(my_new_elem)
            seen.add(my_identifier)
    new_elements.append(my_new_elems)

new_elements = np.concatenate(new_elements)
new_points = np.array(new_points)

plot_meshes([new_elements], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elements))

### 

### A good non-numpy solution:

In [None]:
from itertools import count
from collections import defaultdict

### This needs further explanation, I guess ...
map_point_index = defaultdict(count().__next__)
###

seen = set()

new_elems = []
for elems, points in zip([elems0, elems1], [points0, points1]):
    for elem in elems:
        new_elem = [map_point_index[point] for point in map(tuple, points[elem])]
        # Two elements [0, 2, 1] and [2, 1, 0] represent the same element.
        # Sort to avoid adding them twice.
        if ( elem_identifier := tuple(sorted(new_elem)) ) not in seen:
            new_elems.append(new_elem)
            seen.add(elem_identifier)

new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))

plot_meshes([new_elems], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
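
### 
### The line `defaultdict(count().__next__)` may deserve a closer look. The following is a minimal sketch that uses made-up string keys instead of mesh points: `count()` yields `0, 1, 2, ...`, and the `defaultdict` calls its bound `__next__` method whenever a missing key is looked up, so every new key receives the next free index.

```python
from collections import defaultdict
from itertools import count

# count() yields 0, 1, 2, ...; its __next__ method is a zero-argument
# callable, which is what defaultdict expects as its default factory.
index_of = defaultdict(count().__next__)

# Illustrative keys (not taken from the mesh example).
labels = ['x', 'y', 'x', 'z', 'y']

# Looking up a missing key calls the factory and stores the result;
# looking up a known key simply returns the stored index.
indices = [index_of[label] for label in labels]

print(indices)         # [0, 1, 0, 2, 1]
print(dict(index_of))  # {'x': 0, 'y': 1, 'z': 2}
```

### In the mesh example the keys are point tuples, so identical points from both meshes end up with the same index.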

### 

### A fancy but somewhat overengineered solution (use it to impress your friends)

In [None]:
from itertools import count
from collections import defaultdict

map_point_index = defaultdict(count().__next__)
seen = set()

new_elems = []
for elems, points in zip([elems0, elems1], [points0, points1]):
    for elem in elems:
        new_elem = [map_point_index[point] for point in map(tuple, points[elem])]
        (identifier := tuple(sorted(new_elem))) in seen or new_elems.append(new_elem) or seen.add(identifier)

new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))

plot_meshes([new_elems], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))

### 
### The numpy solution 
#### (achieves the best readability provided comments are added, a little less memory efficient, completely avoids indentation) 

In [None]:
from itertools import count

# get all unique points of the two sets of points
new_points = np.unique(np.concatenate([points0, points1]), axis=0)

# map each unique point to an index
map_point_index = dict(zip(map(tuple, new_points), count()))

# map both meshes' elements' points to the new index
mapped_elems = np.apply_along_axis(lambda x: map_point_index[tuple(x)],
                                   axis=-1,
                                   arr=np.concatenate([points0[elems0], points1[elems1]]))

# find the indices of the first occurrence of the transformed elements that are unique
_, unique_indices = np.unique(np.sort(mapped_elems, axis=1), return_index=True, axis=0)

# keep only the unique occurrences
new_elems = mapped_elems[unique_indices]

plot_meshes([new_elems], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))

### 
## What makes python code readable ?

### What we know so far:

* Using idioms (zip, enumerate, itertools, collections, comprehensions, := ) to reduce boilerplate.
* Fewer (but not too few) lines of code.
* Using => => => indentation instead of => <= => <=.
* Descriptive variable names.
* A program flow that almost reads like English.
* Comments and using library functionality (for instance numpy) where possible.

### 
<hr style="border:1px solid blue">

### 
## <u>Some open questions</u>

### 
### <u>Open question</u>: which one is more pythonic ?

In [None]:
def _first_order_derivative(func):
    # return derivative
    pass

# version 0 or version 1 ?

def nth_derivative_v0(func, n=1):
    n = int(n)
    assert n >= 0
    if n == 0:
        return func
    else:
        return nth_derivative_v0(_first_order_derivative(func), n=n-1)
    

def nth_derivative_v1(func, n=1):
    assert (n := int(n)) >= 0
    if n == 0:
        return func
    return nth_derivative_v1(_first_order_derivative(func), n=n-1)

### 

### <u>Open question</u>: which meshgrid call signature do you find more readable ?

In [None]:
import numpy as np

### forget about this part
def make_meshgrid0(dims):
    return np.stack(np.meshgrid(*map(np.arange, dims)), axis=-1).reshape(-1, len(dims))

def make_meshgrid1(*dims):
    return np.stack(np.meshgrid(*map(np.arange, dims)), axis=-1).reshape(-1, len(dims))
###


# THIS PART
# option a or b ?
a = make_meshgrid0([3, 2, 2])

b = make_meshgrid1(3, 2, 2)

# check if they're really the same
print('a equals b: ', (a == b).all())

# both generalise to arbitrary dimensions
a = make_meshgrid0([3, 2])

b = make_meshgrid1(3, 2, 2, 4)

### 
### <u>Open question</u>: which one do you think is more pythonic ?

In [None]:
# Find the greatest divisor of an integer (excluding self) and return it.
# Also, print all other (smaller) divisors, if any.

# greatest divisor excluding `val` itself
def greatest_divisor_v0(val):
    *other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]
    if other_divisors:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor


def greatest_divisor_v1(val):
    divisors = [i for i in range(1, val) if val % i == 0]
    other_divisors, greatest_divisor = divisors[:-1], divisors[-1]
    if len(other_divisors) > 0:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor


print('Greatest divisor of 10: ', greatest_divisor_v0(10), '\n')
print('Greatest divisor of 10: ', greatest_divisor_v1(10))

### 
## What makes code more pythonic ?
##  (usually, that means more readable)

### <u>Some additional insights</u>:

* Early handling of special cases, avoiding if-else indentation when redundant (example 1).
* Avoiding nested parentheses: `f(a, b, c)` instead of `f([a, b, c])` (example 2).
* Using star syntax `*args` where possible (examples 2 + 3).
* `if object` rather than `if (object has property)`. For instance `if other_divisors` instead of `if len(other_divisors) > 0` (example 3).
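
### A caveat on the last point, sketched under the assumption that NumPy is installed: empty built-in containers are falsy, but a NumPy array with more than one element refuses to be converted to a truth value.

```python
import numpy as np

# Empty built-in containers are all falsy ...
empty_containers = [[], (), {}, set(), '', range(0)]
all_falsy = not any(empty_containers)

# ... and non-empty ones are truthy, whatever they contain.
all_truthy = all([[0], (None,), {'a': 0}, {0}, ' ', range(1)])

# NumPy arrays with more than one element raise instead of guessing.
try:
    bool(np.array([1, 2]))
    array_is_ambiguous = False
except ValueError:
    array_is_ambiguous = True

print(all_falsy, all_truthy, array_is_ambiguous)  # True True True
```

### For arrays, `if arr.size` or `if len(arr)` states the intent explicitly.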
### 

<hr style="border:1px solid blue">

### 
### Now that we have an idea of what makes code pythonic, let's move on to the next lesson.

# <u>Lesson 2</u>: 
## `*args, **kwargs`, star syntax variable unpacking

In [None]:
%reset -f

### 
### What does this dummy function do ?

In [None]:
# pass whatever arguments into this function and print the input it receives
def dummy_function(*args, **kwargs):
    print('Received the following args and kwargs: \n')
    print('args: ', args, '\n')
    print('kwargs: ', kwargs, '\n \n')
    

# give the dummy function random positional-only, keyword-only and mixed arguments.

### We pass only positional arguments

In [None]:
print('Only positional arguments: \n')
dummy_function(1, 2, 'a', [1, 2, 3])

### Only keyword arguments

In [None]:
print('Only keyword arguments: \n')
dummy_function(a=1, b=2, c='a', d=[1, 2, 3])

### Mixed

In [None]:
print('Both positional and keyword arguments: \n')
dummy_function(1, 'a', (5, 6), a=5, b='a', c=[5, 6])

### 
### <u> Observation</u>:
### for a function of the form `f(*args, **kwargs)`:
### 1. All positional arguments passed are available inside the function as a tuple called `args`.
### 2. All keyword arguments passed are available as a dictionary called `kwargs` of the form `{str(keyword): passed_value}`.

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 2.1</u>: 
### write a function that takes positional arguments (numbers) and adds them.
### For example:  `f(1, 2) = 3`; `f(4, 5, 6) = 15`; `f(1, 1, 2, 3) = 7`
### By convention: `f() = 0`

In [None]:
#def f( ? ? ? ):
    # your code here
    #return  ? ? ?
    
print(f())
print( f(1, 2) )
print( f(4, 5, 6) )
print( f(1, 1, 2, 3) )

### solution:

In [None]:
def f(*args):
    return sum(args)

print(f())
print( f(1, 2) )
print( f(4, 5, 6) )
print( f(1, 1, 2, 3) )

### 
<hr style="border:1px solid blue">

### 
## <u> Exercise 2.2 </u>: 
### write a function that takes only keyword arguments `subject = (value, weight)`
### and computes the weighted average of all passed subjects
### $\text{weighted-average} = \frac{\sum_i \text{value}_i \, \times \, \text{weight}_i}{\sum_i \text{weight}_i}$

#### `weighted_average(math=(80, 4), english=(60, 2), geography=(50, 1)) = (320 + 120 + 50) / (4 + 2 + 1) = 70`

In [None]:
# def weighted_average(? ? ?):
#     your code here
#     return ???



print(weighted_average(math=(80, 4), english=(60, 2), geography=(50, 1)))

### 
### solution (not very pythonic !!):

In [None]:
# we'll learn how to write this function better, don't worry.
def weighted_average(**kwargs):
    weighted_values = 0
    weights = 0
    # subject: english
    # values_weights: (60, 2)
    for subject, values_weights in kwargs.items():
        weighted_values += values_weights[0] * values_weights[1]
        weights += values_weights[1]
    return weighted_values / weights

print(weighted_average(math=(80, 4), english=(60, 2), geography=(50, 1)))


# more pythonic solution, maybe too difficult right now
def weighted_average_py(**kwargs):
    return sum(a * b for a, b in kwargs.values()) / sum(b for a, b in kwargs.values())

### 
<hr style="border:1px solid blue">

### 
### The `*args, **kwargs` syntax can also be used in the opposite way.
### Observe the following:

In [None]:
def compute_area(length, width, unit='meters'):
    area = length * width
    return f"The area is {area} square {unit}."

kwargs = {'unit': 'kilometers'}

compute_area(2, 5, **kwargs)

### 
### We are allowed to drop a `{str(key): value}`  dictionary into a function that accepts keyword arguments via double dereferencing.
#### If the dict's keys are a subset of the function's parameter names, and every parameter without a default value still receives a value, it will work.

#### `test = f(a=5, b=2)` is equivalent to `kwargs = {'a': 5, 'b': 2}` and then `test = f(**kwargs)`
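
### 
### The subset rule can be tried out with a small sketch. It reuses a shortened copy of `compute_area`; the key `'colour'` is a made-up name that is deliberately not a parameter of the function.

```python
def compute_area(length, width, unit='meters'):
    return f"The area is {length * width} square {unit}."

# The keys are a subset of the parameter names: this works.
ok = compute_area(2, 5, **{'unit': 'kilometers'})

# Positional parameters can be supplied by name, too.
also_ok = compute_area(**{'length': 2, 'width': 5})

# 'colour' is not a parameter of compute_area: this raises a TypeError.
try:
    compute_area(2, 5, **{'colour': 'red'})
except TypeError as err:
    message = str(err)

print(ok)
print(also_ok)
print(message)
```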

### 
<hr style="border:1px solid blue">

### 
## <u> Application: keyword argument forwarding</u>

![scipykwargs](img/scipykwargs.png)

### 
## <u>Exercise 2.3</u>:
### write a function that uses `scipy.optimize.minimize` to find the argmin of $a x^2 + b x + c$.
### The function must be able to forward relevant keyword arguments to `scipy.optimize.minimize`.
### Find the argmin of the quadratic function using `method='SLSQP'`.

### 
### <u>The wrong way</u>:
#### (doing this will result in capital punishment)

In [None]:
from scipy.optimize import minimize


# uuuuugh, lots of boilerplate >.<
def minimize_quadratic_function(a=1, b=2, c=3, args=(), 
                                               method=None, 
                                               jac=None,
                                               hess=None,
                                               hessp=None,
                                               bounds=None,
                                               constraints=(),
                                               tol=None,
                                               callback=None,
                                               options=None):
    fun = lambda x: a * x ** 2 + b * x + c
    x0 = 0
    return minimize(fun, x0, args=args,
                             method=method,
                             jac=jac,
                             hess=hess,
                             hessp=hessp,
                             bounds=bounds,
                             constraints=constraints,
                             tol=tol,
                             callback=callback,
                             options=options).x[0]

# min 0.5 x**2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(a=.5, b=1, c=1, method='SLSQP'))

### 
### Implement it correctly:

In [None]:
from scipy.optimize import minimize

# def minimize_quadratic_function(a=1, b=2, c=3, ????):
#     fun = lambda x: a * x ** 2 + b * x + c
#     x0 = 0
#     return ???

# min 0.5 x**2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(a=.5, b=1, c=1, method='SLSQP'))

### 
### solution:

In [None]:
from scipy.optimize import minimize

def minimize_quadratic_function(a=1, b=2, c=3, **scipykwargs):
    fun = lambda x: a * x ** 2 + b * x + c
    x0 = 0
    return minimize(fun, x0, **scipykwargs).x[0]

# min 0.5 x**2 + x + 1 is assumed at x = -1
print('The minimum of 0.5 x^2 + x + 1 is assumed at x =',
       minimize_quadratic_function(a=.5, b=1, c=1, method='SLSQP'))    

### 
<hr style="border:1px solid blue">

### 
### Observe the following:

In [None]:
def abc_formula(a, b, c):
    "Return the roots of a quadratic function f(x) = ax^2 + bx + c"
    # square root of the discriminant b^2 - 4ac
    sqrt_discriminant = (b ** 2 - 4 * a * c)**.5
    return (-b - sqrt_discriminant) / (2 * a), (-b + sqrt_discriminant) / (2 * a)


def abc_monic(*args):
    "Same as abc_formula but a = 1."
    # abc_monic(2, 3) => args = (2, 3)
    # we drop args back into `abc_formula`
    # abc_formula(1, *args) is then the same as abc_formula(1, 2, 3)
    return abc_formula(1, *args)


print('The roots of f = x^2 + 3 * x + 1 are {:.7f} and {:.7f}'.format(*abc_monic(3, 1)))

### 
#### `test = f(1, 2, 3)` $\Longleftrightarrow$ `args = (1, 2, 3)` and then `test = f(*args)`

#### intuition: the `*` dereferencing ~ *removes the outer parentheses* ~
#### For tuple inputs: `f(*(1, 2, 3))` $\Longleftrightarrow$ `f(1, 2, 3)`
#### same for lists `f(*[1, 2, 3])` $\Longleftrightarrow$ `f(1, 2, 3)`
#### (and the same for any object that can be **iterated** over)
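
### 
### A short sketch of the last remark, with a hypothetical helper `f` that simply returns the positional arguments it received: strings, ranges, generators and dictionaries (which iterate over their keys) can all be dereferenced.

```python
def f(*args):
    # return the received positional arguments unchanged
    return args

from_string = f(*'abc')
from_range = f(*range(3))
from_generator = f(*(i * i for i in range(4)))
from_dict = f(*{'a': 1, 'b': 2})  # iterating a dict yields its keys

print(from_string)     # ('a', 'b', 'c')
print(from_range)      # (0, 1, 2)
print(from_generator)  # (0, 1, 4, 9)
print(from_dict)       # ('a', 'b')
```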

### 
<hr style="border:1px solid blue">

### 
## <u> Exercise 2.4 </u> :
### Given a function `f(*args)` that sums its positional arguments.
### Write a function `g(head, tail)` that accepts two tuples / lists of numbers and sums all of them.
### `g([1, 2, 3], (4, 5, 6)) => f(1, 2, 3, 4, 5, 6) = 21` (note that one is a list and one is a tuple)

In [None]:
def f(*args):
    return sum(args)

#def g(???):
#    ???

print( g([1, 2, 3], (4, 5, 6)) )

### solution:

In [None]:
def f(*args):
    return sum(args)

#we can also just dereference both head and tail
def g(head, tail):
    # head = [1, 2, 3]
    # tail = (4, 5, 6)
    # => f(*head, *tail) = f(*[1, 2, 3], *(4, 5, 6)) = f(1, 2, 3, 4, 5, 6)
    return f(*head, *tail)

print(g([1, 2, 3], (4, 5, 6)))

### 
<hr style="border:1px solid blue">

### 
### Having gained an intuition for `*args, **kwargs`, we are now in a position to learn **star syntax variable unpacking**

### what does this do ?

In [None]:
tail = (4, 5, 6)

joined_tuple = (1, 2, 3, *tail)

print(joined_tuple)

### 
### $\implies$ We can drop tuples' / lists' contents into tuples / lists via dereferencing
### 
<hr style="border:1px solid blue">

### 
### <u>Further examples</u>:

In [None]:
# multiple list / tuple unpacking
print( [1, 2, *[3, 4, 5], 6, 7, *(8, 9, 10)] )

# we can also dereference ranges into tuples / lists
print( (*[1, 2, 3], *range(4, 11)) )

### 
## <u>Exercise 2.5</u>:
### Looking at the above, how do you think you can merge two disjoint dictionaries ?
### <u>Template</u>:

In [None]:
dict0 = {'a': 0, 'b': 1}
dict1 = {'c': 2, 'd': 3}

merged_dict = ### Your code here
print(f"The merged dict is given by {merged_dict}.\n")

### <u>Solution</u>:

In [None]:
dict0 = {'a': 0, 'b': 1}
dict1 = {'c': 2, 'd': 3}

merged_dict = {**dict0, **dict1}
print(f"The merged dict is given by {merged_dict}.\n")
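
### 
### The exercise assumes disjoint dictionaries. A small sketch with made-up solver settings shows what happens when keys overlap: the value from the dictionary unpacked last wins, and neither input is modified.

```python
# Illustrative dictionaries with one overlapping key ('maxiter').
defaults = {'tol': 1e-6, 'maxiter': 100}
overrides = {'maxiter': 500, 'verbose': True}

# Later unpackings overwrite earlier ones.
merged = {**defaults, **overrides}

print(merged)    # {'tol': 1e-06, 'maxiter': 500, 'verbose': True}
print(defaults)  # {'tol': 1e-06, 'maxiter': 100}
```

### This makes `{**defaults, **overrides}` a compact way to apply user settings on top of default values. From Python 3.9 onwards, `defaults | overrides` gives the same result.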

### 
<hr style="border:1px solid blue">

### 
### What do you guys think this does ?

### `(a, *tail) = [1, 2, 3, 4]`
### What is `a` ? what is `tail` ?
### 
### <u>the answer</u>:

In [None]:
(a, *tail) = [1, 2, 3, 4]

print(a)
print(tail)

### 
### we can forego the outer parentheses (both left and / or right) when no confusion is possible

In [None]:
a, *tail = 1, 2, 3, 4

print(a)
print(tail)
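
### 
### The starred name does not have to come last. A brief sketch with illustrative values: it may sit in the middle or at the front, it always collects into a `list`, and it may end up empty.

```python
# starred name in the middle
first, *middle, last = [1, 2, 3, 4, 5]

# starred name at the front; any iterable can be unpacked, here a string
*head, final = 'abc'

# the starred name may collect nothing at all
only, *rest = [42]

print(first, middle, last)  # 1 [2, 3, 4] 5
print(head, final)          # ['a', 'b'] c
print(only, rest)           # 42 []
```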

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 2.6</u>
### Given a `tuple` of tuples of the form
### `tot = (a, b, c), (d, e, f)`, 
### unpack `a, d, e, f` into variables of the same name and create a list `bc` containing `b` and `c`
### **IN ONE LINE OF CODE**
### 
### The wrong solution (don't do it like this)

In [None]:
tot = (1, 2, 3), (4, 5, 6)

a, d, e, f = tot[0][0], tot[1][0], tot[1][1], tot[1][2]
bc = [tot[0][1], tot[0][2]]

print(a, bc, d, e, f)

### 
### Implement it correctly:

In [None]:
tot = (1, 2, 3), (4, 5, 6)

### your code here

print(a, bc, d, e, f)

### 
### solution:

In [None]:
tot = (1, 2, 3), (4, 5, 6)

(a, *bc), (d, e, f) = tot

print(a, bc, d, e, f)

### 
<hr style="border:1px solid blue">

### 
### The same kind of unpacking works for numpy arrays

In [None]:
import numpy as np

# array of shape A.shape == (2, 3)
A = np.arange(6).reshape(2, 3)

print('A: \n\n', A, '\n\n')


# unpacking the two rows of A into their constituents again yields numbers (integers in this case)
(a0, a1, a2), (b0, b1, b2) = A
print('(a0, a1, a2, b0, b1, b2): ','({}, {}, {}, {}, {}, {})'.format(a0, a1, a2, b0, b1, b2), '\n')

### 
<hr style="border:1px solid blue">

### 

## <u>Application</u>
### Given a ($2$D) triangular element in $\mathbb{R}^3$ characterised by its vertices `A = [a, b, c]` (a, b, c row vectors, counterclockwise)
### compute the surface area of the triangle.

### 

### <u>the wrong solution</u>:

In [None]:
import numpy as np

def triangle_surface_area(A: np.ndarray) -> float:
    # make sure A has the correct shape
    assert A.shape == (3, 3)
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    
    a, b, c = A[0, :], A[1, :], A[2, :]
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor by taking the outer product
    G = J @ J.T
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant)
    a00, a01, a10, a11 = G[0, 0], G[0, 1], G[1, 0], G[1, 1]
    return .5 * ((a00 * a11 - a01 * a10)**.5)


# a = [0, 0, 0]
# b = [1, 0, 1]
# c = [0, 1, 0]
A = np.array([
              [0, 0, 0],
              [1, 0, 1],
              [0, 1, 0],
             ])

print('The surface area reads: ', triangle_surface_area(A))

### 
### <u>The correct solution</u>:

In [None]:
import numpy as np

def triangle_surface_area(A: np.ndarray) -> float:
    # make sure A has the correct shape
    assert A.shape == (3, 3)
    
    # unpack the rows
    a, b, c = A
    
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor by taking the outer product
    G = J @ J.T
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant)
    (a00, a01), (a10, a11) = G
    return .5 * ((a00 * a11 - a01 * a10)**.5)


# a = [0, 0, 0]
# b = [1, 0, 1]
# c = [0, 1, 0]
A = np.array([
              [0, 0, 0],
              [1, 0, 1],
              [0, 1, 0],
             ])

print('The surface area reads: ', triangle_surface_area(A))

### 
### 
### 
### Why did it not work ???

In [None]:
import numpy as np

# J inside of `triangle_surface_area` is of shape J.shape == (3, 2)
J = np.random.randn(3, 2)

print((J @ J.T).shape)

### 
### `G = J @ J.T` has shape (3, 3) which is wrong. It should have been `G = J.T @ J` of shape `(2, 2)`.
### These errors happen easily and go unnoticed for a long long time ...
### Our idiomatic version caught this error because the unpacking implicitly assumed `G` to be of shape `G.shape == (2, 2)`.
### 
### <u>The correct solution</u> (this time for real):

In [None]:
import numpy as np

def triangle_surface_area(A: np.ndarray) -> float:
    # make sure A has the correct shape
    assert A.shape == (3, 3)
    
    # unpack the rows
    a, b, c = A
    
    # create triangle's jacobian J = [b - a; c - a] (column vectors)
    J = np.stack([b - a, c - a], axis=1)
    
    # create the local metric tensor (Gram matrix) G = J^T J
    G = J.T @ J
    
    # unpack G's entries to compute the surface area as 1/2 sqrt(determinant)
    (a00, a01), (a10, a11) = G
    return .5 * ((a00 * a11 - a01 * a10)**.5)


A = np.array([
              [0, 0, 0],
              [1, 0, 1],
              [0, 1, 0],
             ])

print('The surface area reads: ', triangle_surface_area(A))

### 
### Note that the non-pythonic implementation did not catch the error and gave the wrong (but reasonable-looking) result `area = 0.5`.
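
### 
### One way to double-check the corrected result is to compare it against the cross-product formula $\frac{1}{2} \| (b - a) \times (c - a) \|$. The sketch below is only a sanity check; both function names are illustrative.

```python
import numpy as np

def triangle_area_cross(A):
    # area = 1/2 * || (b - a) x (c - a) ||
    a, b, c = A
    return 0.5 * np.linalg.norm(np.cross(b - a, c - a))

def triangle_area_gram(A):
    # area = 1/2 * sqrt(det(J^T J)) with J = [b - a, c - a] as columns
    a, b, c = A
    J = np.stack([b - a, c - a], axis=1)
    (g00, g01), (g10, g11) = J.T @ J
    return 0.5 * (g00 * g11 - g01 * g10) ** 0.5

A = np.array([[0, 0, 0],
              [1, 0, 1],
              [0, 1, 0]])

area_cross = triangle_area_cross(A)
area_gram = triangle_area_gram(A)

# both evaluate to sqrt(2) / 2, not 0.5
print(area_cross, area_gram)
```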
### 
<hr style="border:1px solid blue">

### 
### We can use star syntax in combination with a `zip(...)` statement.
### 
## <u>Task</u>:
### You are given an array of elements and a number of local system matrix iterables.
### Your task is to write a FEM assembly routine.
### Write a routine that iterates simultaneously over the elements and system matrix iterables
### and places their sum at the correct position in the matrix.

In [None]:
from scipy.sparse import lil_matrix
import numpy as np

# the mesh defined by its elements (the corresponding points are irrelevant)
elements = np.array([ [0, 1, 4],
                      [4, 1, 5],
                      [1, 2, 5],
                      [5, 2, 6],
                      [2, 3, 6],
                      [6, 3, 7],
                      [4, 5, 8],
                      [8, 5, 9],
                      [5, 6, 9],
                      [9, 6, 10],
                      [6, 7, 10],
                      [10, 7, 11] ])


# a fake local mass matrix iterator
def mass_matrix_iter():
    while True:
        yield 1 + np.abs(np.random.randn(1)) * np.array([ [2, 1, 1],
                                                          [1, 2, 1],
                                                          [1, 1, 2] ])
# a fake stiffness matrix iterator
def stiffness_matrix_iter():
    while True:
        yield 1 + np.abs(np.random.randn(1)) * np.array([ [2, -1, -1],
                                                          [-1, 2, -1],
                                                          [-1, -1, 2] ])
        
A = lil_matrix((elements.max() + 1,)*2)


# iterate simultaneously over the current element and the
# two (or any number for that matter) local system matrix iterators
for tri, *system_matrices in zip(elements, mass_matrix_iter(), stiffness_matrix_iter()):
    
    # add the sum to the right position in the matrix
    # (accumulate with +=, since neighbouring elements share nodes)
    A[np.ix_(tri, tri)] += np.add.reduce(system_matrices)
    
    
print(A.todense())
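
### 
### The `for tri, *system_matrices in zip(...)` pattern can be studied in isolation. A minimal sketch with made-up lists: the starred name collects everything after the first entry of each zipped tuple, and `zip` stops at its shortest input, which is why it can safely be combined with endless generators.

```python
from itertools import count

names = ['a', 'b', 'c']
xs, ys, zs = [1, 2, 3], [10, 20, 30], [100, 200, 300]

totals = {}
for name, *values in zip(names, xs, ys, zs):
    # `values` is a list holding everything after the first entry,
    # for example [1, 10, 100] in the first iteration
    totals[name] = sum(values)

print(totals)  # {'a': 111, 'b': 222, 'c': 333}

# zip stops at the shortest input, so pairing with an endless counter is safe
numbered = list(zip(names, count()))
print(numbered)  # [('a', 0), ('b', 1), ('c', 2)]
```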

### 
<hr style="border:1px solid blue">

### 
### `*args, **kwargs` and star syntax unpacking are among the most powerful tools for writing more readable and maintainable code.
### 
# <u> What we have learned: </u>
### 1. Functions `f(*args)`, `f(**kwargs)` allow for a variable number of arguments.
### 2. Star syntax tends to be more readable (if familiar) by avoiding boilerplate.
### 3. Star syntax can help you write safer code.
### 

### Star syntax can be used for:
* argument forwarding;
* list / tuple packing / unpacking;
* and many more ...

### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 3</u>: Advanced uses of Python dictionaries

In [None]:
%reset -f

### 
### We have seen how dictionaries play an important role in handling keyword arguments.
### In what follows, we discuss some advanced uses of python dictionaries.

### 
### Python dictionaries can be utilised to avoid if-else clauses via tokenization.
### 
### <u>Task</u>:
### Write a function `solve(A, b, method='direct', **solverkwargs)` that solves $A x = b$ for $x$ using
* `method='direct'`: `sparse.linalg.spsolve`
* `method='cg'`: `sparse.linalg.cg`
* `method='gmres'`: `sparse.linalg.gmres`
* `method='bicgstab'`: `sparse.linalg.bicgstab`

### 
### A straightforward but cumbersome solution:

In [None]:
from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np

# I see a big big big big boilerplate ...
def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    if method == 'direct':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.spsolve(A, b, **solverkwargs)
    elif method == 'cg':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.cg(A, b, **solverkwargs)
    elif method == 'gmres':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.gmres(A, b, **solverkwargs)
    elif method == 'bicgstab':
        print(f'Solving sparse linear system with solver method: `{method}`.')
        return splinalg.bicgstab(A, b, **solverkwargs)
    else:
        raise ValueError(f'Unknown method name {method}.')
        

diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)

# solve with bicgstab and solver tolerance tol=1e-7
print(solve(A, b, method='bicgstab', tol=1e-7))

### 
### A pythonic solution:

In [None]:
from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np

# I see only clean code =)

def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    # Get solver from token using a dict. Return None if token is not found.
    solver = { 'bicgstab': splinalg.bicgstab,
               'direct'  : splinalg.spsolve,
               'gmres'   : splinalg.gmres, 
               'cg'      : splinalg.cg        }.get(method, None)
    
    if solver is None:  # token not found, raise error.
        raise ValueError(f'Unknown method name {method}.')
        
    print(f'Solving sparse linear system with solver method: `{method}`.')

    return solver(A, b, **solverkwargs)
        

diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)

# solve with bicgstab and solver tolerance tol=1e-7
print(solve(A, b, method='bicgstab', tol=1e-7))

### 
### The `dict.get(key, default_value)` method can be used to return a default value in case a key has not been found.
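
### 
### A minimal sketch of the lookup variants, using a made-up table of unit conversion factors: `.get` returns `None` or the given default for unknown keys, plain indexing raises a `KeyError`, and `.get` never inserts anything into the dictionary.

```python
# Illustrative lookup table (not part of the solver example).
to_metres = {'m': 1.0, 'km': 1000.0, 'cm': 0.01}

known = to_metres.get('km')            # key present: stored value
missing = to_metres.get('mile')        # key absent, no default: None
fallback = to_metres.get('mile', 1.0)  # key absent, default given

# plain indexing raises instead of returning a default
try:
    to_metres['mile']
    raised = False
except KeyError:
    raised = True

print(known, missing, fallback, raised)  # 1000.0 None 1.0 True
print('mile' in to_metres)               # False
```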
### 
<hr style="border:1px solid blue">

### 
### In the following, we discuss `dict.setdefault`.
### 
### Let us see what this does:

In [None]:
test = {'a': 5, 'b': 10}

print('test: ', test, '\n')

test.setdefault('a', 20)
print('test after the first setdefault operation: ', test, '\n')

test.setdefault('c', 15)
print('test after the second setdefault operation: ', test, '\n')

### 
### The first `.setdefault` didn't do anything because the key was already present.
### `dict.setdefault(key, value)` only changes the `dict` if `key not in dict`.
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 3.1</u>:
### Using the `solve(A, b, ...)` function, write a function `solve_SPD(A, b, **kwargs)`
### which calls the `solve` function but uses the `'cg'` method unless another method token is passed.

In [None]:
# again, the solve(...) method from above, to be used in your implementation

from scipy.sparse import linalg as splinalg, spmatrix, diags
import numpy as np


def solve(A: spmatrix, b: np.ndarray, method: str = 'direct', **solverkwargs):
    # Get solver from token using a dict. Return None if token is not found.
    solver = { 'bicgstab': splinalg.bicgstab,
               'direct'  : splinalg.spsolve,
               'gmres'   : splinalg.gmres, 
               'cg'      : splinalg.cg        }.get(method, None)
    
    if solver is None:  # token not found, raise error.
        raise ValueError(f'Unknown method name {method}.')
        
    print(f'Solving sparse linear system with solver method: `{method}`.')
    
    return solver(A, b, **solverkwargs)

In [None]:
# Incorrect solution: this function does not have the requested signature
# `solve_SPD(A, b, **kwargs)`. Because `method` is now a named parameter,
# we have to forward `method=method` to solve(...) explicitly, which is easy
# to forget, especially once more deviating default values are added.
def solve_SPD(A, b, method='cg', **kwargs):
    return solve(A, b, method=method, **kwargs)

### 
### Your solution here:

In [None]:
def solve_SPD(A, b, **kwargs):
    # your code here
    pass


diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)

# solve the system but don't specify the solver, should solve with `cg`
print(solve_SPD(A, b))

# solve with `bicgstab`
print(solve_SPD(A, b, method='bicgstab'))

### 
### Solution:

In [None]:
def solve_SPD(A, b, **kwargs):
    kwargs.setdefault('method', 'cg')
    return solve(A, b, **kwargs)


diagonals = 2 * np.ones(10), -np.ones(9), -np.ones(9)
A = diags(diagonals, [0, -1, 1])
b = np.ones(10)

# solve the system but don't specify the solver
print(solve_SPD(A, b), '\n')

# solve with bicgstab
print(solve_SPD(A, b, method='bicgstab'))

### 
<hr style="border:1px solid blue">

### 
### By invoking `var = dict.setdefault(key, value)`, we can optionally
### capture in `var` whatever the dictionary contains at `key` after the `.setdefault` operation.

In [None]:
test = {'a': 5, 'b': 10}

var = test.setdefault('a', 20)

print(var)

var = test.setdefault('c', 30)

print(var)

### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 3.2</u>:
### you are given a tuple of edges `edges = ((i0, i1), (i2, i3), ...)` representing the edges of a directed graph.
### An edge `edge = (i0, i1)` points from node `i0` to node `i1`. The nodes are not necessarily numbered from `0` to `N` but can be anything.
### Write code that counts the number of unique edges incident to each node. Some edges may be duplicated in `edges`.
### **Remember**: you have no idea what indices the nodes have.

In [None]:
edges = (1, 12), (4, 12), (4, 1), (13, 6), \
        (12, 8), (8, 4), (6, 8), (1, 12), \
        (6, 4), (12, 8), (13, 6), (3, 13), (3, 4), (3, 1)

In [None]:
# let's draw the graph first.
# you need networkx installed `pip install networkx`
import networkx as nx
from matplotlib import pyplot as plt

G = nx.DiGraph()
G.add_edges_from(edges)

pos = nx.planar_layout(G)

nx.draw_networkx(G, arrows=True, pos=pos)
plt.show()

### 
### A crude (non-pythonic) solution:

In [None]:
map_node_root_vertex = {}

for edge in edges:
    # Make sure your implementation does NOT do this.
    v0 = edge[0]  # get the root node
    v1 = edge[1]  # get the incident node
    if v1 not in map_node_root_vertex:
        map_node_root_vertex[v1] = set()
    map_node_root_vertex[v1].add(v0)  # add the root vertex v0 to the set of roots pointing to v1

# the number of incident edges to `node` is the number of unique roots pointing to `node`
map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
### Your solution using dict.setdefault

In [None]:
map_node_root_vertex = {}

# your code here

map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
### solution:

In [None]:
map_node_root_vertex = {}

# .setdefault(key, default_value) returns whatever is at dict[key] after the .setdefault operation.
# => .setdefault(v1, set()) returns either an existing or a new set that we immediately add the root vertex to.
for v0, v1 in edges:
    map_node_root_vertex.setdefault(v1, set()).add(v0)
    
map_node_nedges = {node: len(root_verts) for node, root_verts in map_node_root_vertex.items()}

print(map_node_nedges)

### 
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. In `Python`, dictionaries play a central role. Advanced use of them is considered **good practice**.
### 2. Making it a habit to use methods like `dict.setdefault` and 
### `dict.get(item, default_value)` will make your code safer and more readable.
### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 4</u>: Python _"truthiness"_ 

In [None]:
%reset -f

### 
### ALL Python types (`dict`, `set`, `list`, ..., custom types) can be converted into a boolean.
### The Python developers have chosen intuitive rules as to whether a built-in type should convert
### into `True` or `False`.
### Let's see if our intuition is correct.

### 
## <u>Exercise 4.1</u>: Does `bool(var)` convert to `True` or `False` ?

In [None]:
# what do you think ? (you'll get most of them right)

variables = [ 
              True,          # I guess we can all agree what this one is converted into ;)
              False,         # idem
              None,          # None, Python's null object
              {},            # empty dict
              {1: 2},        # nonempty dict
              [],            # empty list
              [1, 2],        # nonempty list
              tuple(),       # empty tuple
              (1, 2),        # nonempty tuple
              0,             # zero integer
              0.0,           # zero float
              1.0,           # positive float
              -1.,           # negative float
              "",            # empty string
              "Connie",      # nonempty string
            ]


for var in variables:
    print('var: {}, bool(var): {}'.format(var, bool(var)), '\n')

### 
<hr style="border:1px solid blue">


### 
### Now that we have an idea of how various variable types transform into a boolean,
### we are in the position to understand the `if other_divisors:` statement from **Lesson 1**.

In [None]:
# Find the greatest divisor of an integer (excluding self) and return it.
# Also, print all other (smaller) divisors, if any.

def greatest_divisor_excluding_self(val):
    """ excluding `val` itself """
    *other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]
    if other_divisors:
        print("For {}, I also found the divisors {}.".format(val, other_divisors))
    return greatest_divisor

print('The greatest divisor (excl self) of 10 equals:', greatest_divisor_excluding_self(10))

### 
### First, we break down the line `*other_divisors, greatest_divisor = [i for i in range(1, val) if val % i == 0]`
* the rhs creates a monotone increasing list with at least one element `1`.
* the left hand side peels off the last (and hence largest) element in that list, while collecting all remaining divisors in the list `other_divisors`
* if the only divisor is `[1]`, `other_divisors == []`, else `other_divisors != []`.

### Then, the line `if other_divisors:`
* under the hood, Python evaluates this line as `if bool(other_divisors):`
* from before, we know that if `other_divisors == []`, `if other_divisors:` becomes `if False:` and the if clause is ignored.
* On the other hand, if `other_divisors != []` (nonempty), then `if other_divisors:` converts to `if True:` and the if clause's code is executed.
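### The same conversion applies to custom types. A minimal sketch, assuming three hypothetical classes `Mesh`, `Flag` and `Plain`:
### `bool(obj)` first looks for `__bool__`, then falls back to `__len__`, and treats the object as `True` if neither is defined.

```python
class Mesh:
    """Hypothetical container: its truthiness follows from __len__."""
    def __init__(self, elements):
        self.elements = list(elements)

    def __len__(self):
        return len(self.elements)


class Flag:
    """Hypothetical type that defines its truthiness explicitly."""
    def __init__(self, active):
        self.active = active

    def __bool__(self):
        return self.active


class Plain:
    """Defines neither __bool__ nor __len__, so instances are always truthy."""


print(bool(Mesh([])))           # False, because len(...) == 0
print(bool(Mesh([(0, 1, 4)])))  # True
print(bool(Flag(False)))        # False
print(bool(Plain()))            # True
```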
### 

<hr style="border:1px solid blue">

### 
## <u>Exercise 4.2</u>:
### write a function `average(*numbers)` that computes the average of the numbers in `numbers` and returns `0.0` if no numbers have been passed.
### Use Python truthiness !

In [None]:
def average(*numbers):
    # your code here
    pass


print(average(1, 2, 3))
print(average())

### 
### 2 equivalent solutions:

In [None]:
def average_v0(*numbers):
    if numbers:
        # without this check, numbers == () would lead to a division by zero
        return sum(numbers) / len(numbers)
    return 0.0
    
    
print(average_v0(1, 2, 3))
print(average_v0())

# one liner
def average_v1(*numbers):
    return sum(numbers) / len(numbers) if numbers else 0.0
    
    
print(average_v1(1, 2, 3))
print(average_v1())

### 
<hr style="border:1px solid blue">

### 
### There are some interesting use cases of Python '_truthiness_' that most people are unaware of.

### 
### Let us look at the following truth table:
### `bool0 or bool1`
### `True or True` => `True`
### `True or False` => `True`
### `False or True` => `True`
### `False or False` => `False`
### Observation: if `bool0 is True` we return `bool0`. Else, we return `bool1`.
***
### Python takes it a step further. Consider `object0 or object1`
### python's behaviour: if `bool(object0) is True: return object0`,
###                     if `bool(object0) is False: return object1`.
#### Convince yourself that this also reproduces the expected behaviour when both `object0` and `object1` are booleans (obviously, `bool(False) = False` and `bool(True) = True`).
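### A short sketch of this rule in action. The helper `expensive_default` is hypothetical; it records its calls so we can see
### that the right-hand side of `or` is only evaluated if the left-hand side is falsy (short-circuiting).

```python
calls = []

def expensive_default():
    """Hypothetical fallback that records every call."""
    calls.append('called')
    return 'default'


# `or` returns one of its operands, not necessarily a bool
print([] or 'fallback')      # fallback
print([1, 2] or 'fallback')  # [1, 2]
print(0 or None)             # None (both are falsy, the second one is returned)

# left operand is truthy => the right operand is never evaluated
name = 'Connie' or expensive_default()
calls_after_first = list(calls)

# left operand is falsy => the right operand is evaluated and returned
other = '' or expensive_default()

print(name, other, calls_after_first, calls)
```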
### 

<hr style="border:1px solid blue">

### 
## <u>Exercise 4.3</u>:
### Write a function `average(*numbers)` that uses an `or` statement to avoid division by zero if `numbers` is empty (`numbers == ()`).

In [None]:
def average(*numbers):
    # ONE LINE of code here.
    pass


print(average(1, 2, 3))
print(average())

### 
### solution:

In [None]:
def average(*numbers):
    return sum(numbers) / (len(numbers) or 1)


print(average(1, 2, 3))
print(average())

### 
<hr style="border:1px solid blue">

### 
### Python truthiness comes in handy in the following setting:
### Suppose you are trying to find the zero of a function using `scipy.optimize.minimize`.
### 
### The function itself is the integral over $(0, 1)$ of an integrand 
### and for integrating you use `scipy.integrate.fixed_quad`.
### Find $\min_{a_0, \ldots, a_N} \left(\int_{(0, 1)} f(a_0, \ldots, a_N, x) \, \mathrm{d}x \right)^2$
### 
### Now you have two functions that take optional keyword arguments. How do you handle that ?

### Let's look at a crude solution (don't focus on the details, just boilerplate):

In [None]:
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


# why is writing minimize_kwargs={} as a default a bad idea ?
def find_root_of_integral(integrand: Callable, x0: np.ndarray,  minimize_kwargs=None,
                                                               integrate_kwargs=None) -> np.ndarray:
    
    ### big boilerplate
    if minimize_kwargs is None:  # minimize_kwargs not passed => set to empty dict
        minimize_kwargs = {}

    if integrate_kwargs is None:  # idem
        integrate_kwargs = {}
    ###
    
    f = lambda a: fixed_quad(lambda x: integrand(a, x), 0, 1, **integrate_kwargs)[0]**2
    
    return minimize(f, x0=x0, **minimize_kwargs).x[0]


# find the parameter `a` such that a x^2 -3 x + 1 integrates to 0 over [0, 1]
integrand = lambda a, x: a * x**2 - 3 * x + 1

# the function is quadratic so gaussian integration of n=2 suffices.
root = find_root_of_integral(integrand, 0.0, minimize_kwargs={'method': 'SLSQP'},
                                             integrate_kwargs={'n': 2})

# print the root 
print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))

# plot
xi = np.linspace(0, 1, 101)
f = [integrand(root, x) for x in xi]

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()
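### 
### The comment in the cell above asks why `minimize_kwargs={}` would be a poor default. A minimal sketch of the underlying issue,
### using the hypothetical helpers `append_bad` and `append_good`: a default value is created once, when the function is defined,
### so a mutable default is shared between all calls. For a dict that is only unpacked via `**` this happens to be harmless,
### but it breaks as soon as somebody mutates the default.

```python
def append_bad(item, bucket=[]):
    # the very same list object is reused by every call that omits `bucket`
    bucket.append(item)
    return bucket


def append_good(item, bucket=None):
    # create a fresh list per call if none was passed
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


first = append_bad(1)
second = append_bad(2)

print(second)           # [1, 2]  <- the 1 from the first call is still there
print(first is second)  # True, both names refer to the shared default

print(append_good(1))   # [1]
print(append_good(2))   # [2]
```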

### 
## <u>Exercise 4.4</u>:
### Find a solution that substantially reduces boilerplate in the `find_root_of_integral` function.

In [None]:
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


def find_root_of_integral(integrand: Callable, x0: np.ndarray,  minimize_kwargs=None,
                                                               integrate_kwargs=None) -> np.ndarray:
    
    ### Your code here
    
    # f = ???
    
    return # ???


# find the parameter `a` such that a x^2 -3 x + 1 integrates to 0 over [0, 1]
integrand = lambda a, x: a * x**2 - 3 * x + 1

# the function is quadratic so gaussian integration of n=2 suffices.
root = find_root_of_integral(integrand, 0, minimize_kwargs={'method': 'SLSQP'}, integrate_kwargs={'n': 2})

# print the root 
print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))

# plot
xi = np.linspace(0, 1, 101)
f = [integrand(root, x) for x in xi]

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()

### 
### solution:

In [None]:
from scipy.optimize import minimize
from scipy.integrate import fixed_quad
import numpy as np
from matplotlib import pyplot as plt
from typing import Callable


# The entire thing can be done in 2 lines of code.
def find_root_of_integral(integrand: Callable, x0: np.ndarray,  minimize_kwargs=None,
                                                               integrate_kwargs=None) -> np.ndarray:
    
    f = lambda a: fixed_quad(lambda x: integrand(a, x), 0, 1, **integrate_kwargs or {})[0]**2
    
    # in a call, ** unpacks the whole expression that follows, so **kwargs or {} means **(kwargs or {})
    return minimize(f, x0=x0, **minimize_kwargs or {}).x[0]


# find the parameter `a` such that a x^2 -3 x + 1 integrates to 0 over [0, 1]
integrand = lambda a, x: a * x**2 - 3 * x + 1

# the function is quadratic so gaussian integration of n=2 suffices.
root = find_root_of_integral(integrand, 0, minimize_kwargs={'method': 'SLSQP'}, integrate_kwargs={'n': 2})

# print the root 
print("The function's integral with a={} evaluates to {}.".format(root, fixed_quad(lambda x: integrand(root, x), *[0, 1], n=2)[0]))

# plot
xi = np.linspace(0, 1, 101)
f = [integrand(root, x) for x in xi]

plt.plot(xi, f)
plt.plot(xi, np.zeros_like(xi), '--', c='b')
plt.show()

<hr style="border:1px solid blue">

### 
### <u>Final note</u> (for those who are interested):
### There is an analogous `object0 and object1`, which mirrors `object0 or object1`: it returns `object0` if `bool(object0) is False` and `object1` otherwise.
### It is a bit less useful in practice.
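### A small sketch of `and` with non-boolean operands. The variable `config` is hypothetical; the pattern shown is a guard
### that only accesses an object if it is truthy.

```python
print([] and 'not returned')  # [] (first operand is falsy, so it is returned)
print([1, 2] and 'second')    # second

# guard pattern: only call .get if config is not None / not empty
config = None
method_without_config = config and config.get('method')

config = {'method': 'cg'}
method_with_config = config and config.get('method')

print(method_without_config)  # None, and no AttributeError was raised
print(method_with_config)     # cg
```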

### 
### <u>Final brain teaser</u>:
### Do you remember this code snippet from **Lesson 1** ? Can you explain it now ? 
#### hint: a function that does not return anything, returns None instead. Don't feel bad if it's still confusing ;)

```python
from itertools import count
from collections import defaultdict

map_point_index = defaultdict(count().__next__)
seen = set()

new_elems = []
for elems, points in zip([elems0, elems1], [points0, points1]):
    for elem in elems:
        new_elem = [map_point_index[point] for point in map(tuple, points[elem])]
        ### Try to explain this line
        (identifier := tuple(sorted(new_elem))) in seen \
        or seen.add(identifier) or new_elems.append(new_elem)

new_elems = np.array(new_elems)
new_points = np.stack(list(map_point_index.keys()))

plot_meshes([new_elems], [new_points])

print('Number of points: ', len(new_points))
print('Number of elements: ', len(new_elems))
```

### 
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. All `Python` built-in types can be converted to a `bool`.
### 2. Relying on truthiness avoids boilerplate.
### 3. With a little practice, it helps you avoid handling special cases separately
### (one function body handles all inputs equally gracefully).
### 
<hr style="border:1px solid blue">

### 
### We come to this session's most important
# <u>Lesson 5</u>: Advanced iteration and the use of itertools.

In [None]:
%reset -f

### 
### Python's most important iteration feature is the `zip` statement
### (Most of you have seen it).
### 
### What happens if the iterables we iterate over have unequal length ?

In [None]:
from itertools import repeat

iter0 = [0, 1, 2, 3, 4, 5]  # length 6
iter1 = range(20, 40)  # consumed after a max of 20 iterations
iter2 = repeat("Connie")  # this guy is repeated an infinite number of times

for elem0, elem1, elem2 in zip(iter0, iter1, iter2):
    print(elem0, elem1, elem2)

### 
### <u>We conclude</u>: the `zip(...)` loop terminates after the shortest iterable has been consumed.
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 5.1</u>:
### Given two polynomials represented by their weights
### `pol0 = [a0, a1, ..., aN]`
### `pol1 = [b0, b1, ..., bM]`
### potentially `N != M`.
### Write a function that returns their sum.
### Note that you have to be able to handle the case `len(pol0) != len(pol1)`.
### Hint: (in case you didn't know) $(a_0 + a_1 x + \ldots + a_N x^N) + (b_0 + b_1 x + \ldots + b_M x^M) = (a_0 + b_0) + (a_1 + b_1) x + \ldots$

In [None]:
def add_two_polynomials(pol0, pol1):
    # your code here
    pass

# 1 + 2 x + 3 x^2
pol0 = [1, 2, 3]

# 3 x + 2 x^2 + x^3
pol1 = (0, 3, 2, 1)  # >>note that pol1 is a tuple, not a list<<

# should be [1, 5, 5, 1]
print(add_two_polynomials(pol0, pol1))

### 
### solution (gosh, it's ugly, but it gets the job done):

In [None]:
def add_two_polynomials(pol0, pol1) -> list:
    # we convert both to lists, just to be sure
    pol0 = list(pol0)
    pol1 = list(pol1)
    # if pol0 is longer than pol1, reverse order
    if len(pol0) > len(pol1):
        pol1, pol0 = pol0, pol1
    # from here on out, we assume len(pol0) <= len(pol1)
    diff = len(pol1) - len(pol0)
    
    # add a tail of zeros. Note that pol0 is a list, we made sure of that
    pol0 = pol0 + [0] * diff
    return [a + b for a, b in zip(pol0, pol1)]

# 1 + 2 x + 3 x^2
pol0 = [1, 2, 3]

# 3 x + 2 x^2 + x^3, note that pol1 is a tuple, not a list
pol1 = (0, 3, 2, 1)

print(add_two_polynomials(pol0, pol1))

### 
<hr style="border:1px solid blue">

### 
### Let us see what this does:

In [None]:
from itertools import zip_longest

pol0 = [1, 2, 3]
pol1 = (0, 3, 2, 1)

for elem0, elem1 in zip_longest(pol0, pol1, fillvalue=0):
    print(elem0, elem1)

### I probably won't have to explain why this is handy now ;)
#### Note that the signature is `zip_longest(*iterables, fillvalue=None)` 
#### (accepts as many iterables as we want, same as `zip(...)`)
### 

<hr style="border:1px solid blue">

### 
## <u>Exercise 5.2</u>:
### Rewrite the sum of polynomials using zip_longest.
### To challenge yourself, write a function `add_polynomials(*polynomials)` that accepts an arbitrary number of polynomials for addition.

In [None]:
def add_polynomials(*polynomials):
    # your code here
    pass

# note that some are lists, others tuples ...
pol0 = (1, 2, 3)
pol1 = [0, 3, 2, 1]
pol2 = [-1, 2]
pol3 = [-1, -2, 5, 7, 10]

# should give [-1, 5, 10, 8, 10]
print(add_polynomials(pol0, pol1, pol2, pol3))

### 
### solution(s):

In [None]:
# one line of code
def add_polynomials(*polynomials):
    return [sum(weights) for weights in zip_longest(*polynomials, fillvalue=0)]

# note that some are lists, others tuples ...
pol0 = (1, 2, 3)
pol1 = [0, 3, 2, 1]
pol2 = [-1, 2]
pol3 = [-1, -2, 5, 7, 10]

# should give [-1, 5, 10, 8, 10]
print(add_polynomials(pol0, pol1, pol2, pol3))


# for good measure
def add_polynomials_map(*polynomials):
    return list(map(sum, zip_longest(*polynomials, fillvalue=0)))

# should give [-1, 5, 10, 8, 10]
print(add_polynomials_map(pol0, pol1, pol2, pol3))

### 
<hr style="border:1px solid blue">

### 
### The difference between an ugly implementation and one that is orders of magnitude simpler
### often just amounts to using the right function from the `itertools` module.
### 
<hr style="border:1px solid blue">

### 
### A very useful iterator from the `itertools` module is `itertools.count`.
### 

In [None]:
from itertools import count

# standard implementation
i = 0
while i < 10:
    # do something
    print(i)
    i += 1
    

print('\n')
    

# A more pythonic solution:
for i in count():
    # do something
    print(i)
    if i == 9: break

### 
### <u>Task</u>:
### given a set of unique edges `(i0, i1), (i1, i2), ...`, 
### create a dictionary that assigns an index to each edge.

In [1]:
# we make a few edges first

from itertools import pairwise

edges = tuple(map(tuple, pairwise(range(11))))  # the map statement will come in handy later
print(edges)

((0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10))


### 
### The naive solution

In [None]:
index = 0
map_edge_index = {}

for edge in edges:
    map_edge_index[edge] = index
    index += 1
    
print(map_edge_index)

### 
### A more pythonic solution

In [None]:
from itertools import count
map_edge_index = {}
counter = count()

for edge in edges:
    map_edge_index[edge] = next(counter)  # return the current index, starting at 0 and immediately increase it by one
    
print(map_edge_index)

### 
### The optimal solution

In [None]:
from itertools import count

# one line of code !
map_edge_index = dict(zip(edges, count()))

print(map_edge_index)

### (Another option is a dict comprehension)
### 
### Let's break down what happened in the last example:
* We can instantiate a `dict` via `(key, value)` pairs. For instance `dict([(1, 2), (2, 3)]) == {1: 2, 2: 3}`.
* The `zip` statement can be used to produce these pairs.
* We remember from before that `zip` terminates as soon as the shortest iterable has been consumed. 
* Here: `edges` has finite length (unlike count) and is therefore consumed first.
* The `zip(edges, count())` statement produces the pairs `(edge0, 0), (edge1, 1), ...`
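### The dict comprehension mentioned above could look as follows. This is a sketch on a small hand-written `edges` tuple;
### it also shows that `count` accepts a `start` and a `step`, in case the indices should not begin at `0`.

```python
from itertools import count

edges = ((0, 1), (1, 2), (2, 3))

via_zip = dict(zip(edges, count()))
via_enumerate = {edge: index for index, edge in enumerate(edges)}

print(via_zip == via_enumerate)  # True

# start at 10 and advance in steps of 5
shifted = dict(zip(edges, count(start=10, step=5)))
print(shifted)  # {(0, 1): 10, (1, 2): 15, (2, 3): 20}
```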
### 

<hr style="border:1px solid blue">

### 
### We have also seen the `map(...)` statement several times now, let's see what it does.

In [None]:
# the map statement can be used to apply a function to all inputs

# print the squares of 0...9 using `map(...)`.
for val in map(lambda x: x**2, range(10)):
    print(val)

### The `map(func, arguments)` statement creates an iterator that applies `func` to each `arg in arguments`.
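### Since `map` returns an iterator, nothing is computed until we iterate over it, and it can only be consumed once.
### A minimal sketch (the name `squares` is just for illustration):

```python
squares = map(lambda x: x**2, range(4))

first_pass = list(squares)   # consumes the iterator
second_pass = list(squares)  # nothing left to consume

print(first_pass)   # [0, 1, 4, 9]
print(second_pass)  # []
```

### If the results are needed more than once, store them in a `list` or `tuple` first.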
### 
## <u>Exercise 5.3</u>:
### Suppose you are given the edges as before but this time, they are given as a list of lists.
### Your task is to assign to each edge an index, as before. However, since lists are mutable (they can be changed) and therefore not hashable,
### they cannot be used as keys in a dictionary. Therefore, write a one-liner that uses dict, zip, map and count
### to assign an index to each edge converted to a tuple.
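### A quick sketch of what goes wrong with a list as a key, using a single made-up edge `edge_as_list`.
### The exact wording of the error message depends on the Python version, so it is printed rather than assumed.

```python
edge_as_list = [0, 1]
lookup = {}

try:
    lookup[edge_as_list] = 0  # lists are unhashable => TypeError
except TypeError as err:
    error_message = str(err)
    print('TypeError:', error_message)

# converting to a tuple makes the edge usable as a key
lookup[tuple(edge_as_list)] = 0
print(lookup)  # {(0, 1): 0}
```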

In [None]:
# make a bunch of edges
from itertools import pairwise
edges = list(map(list, pairwise(range(11))))

print(edges)

### 
### A working but verbose solution first:

In [None]:
from itertools import count

map_edge_index = {}
counter = count()

for edge in edges:
    map_edge_index[tuple(edge)] = next(counter)
    
print(map_edge_index)

### 
### do it better:

In [None]:
from itertools import count

# your one liner here
# map_edge_index = ???

print(map_edge_index)

### solution:

In [None]:
from itertools import count

map_edge_index = dict(zip(map(tuple, edges), count()))
# or 
map_edge_index = {tuple(edge): i for i, edge in enumerate(edges)}

print(map_edge_index)

### 
<hr style="border:1px solid blue">

### 
### The `map(...)` statement also accepts more than one input.
### For instance, `map(func, iter0, iter1, ...)` calls `func(object0, object1, ...)` where `object_i` is generated by `iter_i`.
### Example:

In [None]:
import numpy as np

def abc_formula(a, b, c):
    "Return the roots of a quadratic equation of the form ax^2 + bx + c"
    sqrt_discriminant = (b ** 2 - 4 * a * c)**.5  # square root of the discriminant
    return (-b - sqrt_discriminant) / (2 * a), (-b + sqrt_discriminant) / (2 * a)

N = 10

# Three random complex arrays containing N entries each.
# (np.complex_ was removed in NumPy 2.0, np.complex128 works in all versions)
a, b, c = np.abs(np.random.randn(3, N)).astype(np.complex128)

print("Finding the complex roots of: \n{}\n".
      format('\n'.join([f'{i}: {a0:.5g} x^2 + {b0:.5g} x + {c0:.5g}'
                        for i, (a0, b0, c0) in enumerate(zip(*map(np.real, (a, b, c))))])))

print('The complex roots are: \n')


# pythonic solution:
for i, (root0, root1) in enumerate( map(abc_formula, a, b, c) ):
    # equivalent to abc_formula(a[0], b[0], c[0])
    # abc_formula(a[1], b[1], c[1])
    # ...
    print(f"{i}:", f'({root0.real:.5g} + {root0.imag:.5g}j) and ({root1.real:.5g} + {root1.imag:.5g}j)')

### 
<hr style="border:1px solid blue">

### 
### 
### <u>A little intermezzo</u>:
### 
### A highly useful but quite unknown Python feature is the `for-else` clause.
### 
### <u>Task</u>:
### Given a list of words `list_of_words` write a function that finds the first word
### that starts with an `'a'` (case-insensitive) and prints it.
### If no such word is found, raise an error.

In [None]:
# we create our own exception that indicates no word was found
class NoWordFoundException(Exception):
    pass


# first we do a very bad implementation
def find_word_that_starts_with_a_non_pythonic(list_of_words):
    """Given an iterable `list_of_words` containing strings, print the first
       word that starts with an `a` or raise an Error if not found. Case-insensitive."""
    found = False
    for word in list_of_words:
        if word != '':  # if empty string, I can't index into the string, i.e., ''[0] gives an error.
            if word[0] == 'a' or word[0] == 'A':
                print('Found the word `{}`.'.format(word))
                found = True
                break
    if found is False:  # use truthiness ;)
        raise NoWordFoundException("Couldn't find a word that starts with `a`.")


words_with_a = ['Connie', 'is', 'a', 'bear']
words_without_a = ['This', 'is', 'no', 'suitcase', 'curseword']


for words in (words_with_a, words_without_a):
    try:
        find_word_that_starts_with_a_non_pythonic(words)
    except NoWordFoundException:
        print('The list of words {} contains no word that starts on an `a`.'.format(words))

### Now the pythonic implementation

In [None]:
# now the pythonic implementation

class NoWordFoundException(Exception):
    pass


def find_word_that_starts_with_a_pythonic(list_of_words):
    for word in list_of_words:
        if word.lower().startswith('a'):  # also safe for the empty string
            print('Found the word `{}`.'.format(word))
            break
    else:  # no break occurred
        raise NoWordFoundException("Couldn't find a word that starts with `a`.")
        
        
words_with_a = ['Connie', 'is', 'a', 'bear']
words_without_a = ['This', 'is', 'no', 'suitcase', 'curseword']

for words in (words_with_a, words_without_a):
    try:
        find_word_that_starts_with_a_pythonic(words)
    except NoWordFoundException:
        print('The list of words {} contains no word that starts on an `a`.'.format(words))

### 
### The `else` part of the `for-else` clause is only entered if the for loop
### is completed (not terminated early)
### **Ways to terminate early**: the use of `break` (break out of the loop) or `return`.
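### One detail worth knowing: the `else` branch also runs if the loop body is never entered because the iterable is empty.
### A minimal sketch with a hypothetical helper `find_first_even`:

```python
def find_first_even(numbers):
    """Return the first even number in `numbers` or None if there is none."""
    for number in numbers:
        if number % 2 == 0:
            result = number
            break
    else:
        # reached if the loop ran to completion OR never started
        result = None
    return result


print(find_first_even([1, 3, 4, 5]))  # 4
print(find_first_even([1, 3, 5]))     # None
print(find_first_even([]))            # None
```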
### 
<hr style="border:1px solid blue">

### 
## <u>Exercise 5.4</u>:
### Use the for-else construct to write a function `quasi_newton(func, x0, maxiter=10, eps=1e-6, tol=1e-5)` that
### uses a quasi-Newton method with finite-difference step size `eps > 0` to find a function's root and return it.
### If `maxiter` is exceeded, we raise an error.
### **FYI**: $\quad x_{n+1} = x_n - \tfrac{f(x_n)}{f^\prime(x_n)}$, $\quad$ with $\quad f^\prime(x_n) \approx \frac{f(x_n + \varepsilon) - f(x_n)}{\varepsilon}$.

In [None]:
from typing import Callable
from numbers import Number


# throw this error if convergence is not reached
class FailedToConvergeError(Exception):
    pass


def quasi_newton(func: Callable, x0: Number, maxiter: int = 10, eps: float = 1e-6, tol: float = 1e-5):
    assert (eps := float(eps)) > 0
    assert (tol := float(tol)) > 0
    assert (maxiter := int(maxiter)) > 0
    
    # your code here
    
    
maxiter = 20
    
# clearly, one root of   f(x) = 10x^3 -5x^2 + 6x   is x = 0
func = lambda x: 10 * x**3 - 5 * x**2 + 6 * x

for x0 in (5, 10000):
    try:
        print('Found the root: ', quasi_newton(func, x0, maxiter=maxiter))
    except FailedToConvergeError:
        print('The quasi-Newton method failed to converge after {} iterations.'.format(maxiter))
    except Exception as ex:
        raise Exception from ex

### solution:

In [None]:
from typing import Callable


# throw this error if convergence is not reached
class FailedToConvergeError(Exception):
    pass


def quasi_newton(func: Callable, x0, maxiter: int = 10, eps: float = 1e-6, tol: float = 1e-5):
    assert (eps := float(eps)) > 0
    assert (tol := float(tol)) > 0
    assert (maxiter := int(maxiter)) > 0
    
    x, y = x0, func(x0)
    
    for _ in range(maxiter):
        fprime = (func(x + eps) - y) / eps  # approximate derivative
        x -= y / fprime
        if abs(( y := func(x) )) < tol:  # define and use y := func(x) in the same line
            return x
    else:
        raise FailedToConvergeError(f"Failed to converge after {maxiter} iterations.")
    
    
maxiter = 20
    
# clearly, one root of   f(x) = 10x^3 -5x^2 + 6x   is x = 0
func = lambda x: 10 * x**3 - 5 * x**2 + 6 * x

# The first one should converge, the second one will not
for x0 in (5, 10000):
    try:
        print('Found the root: ', quasi_newton(func, x0, maxiter=maxiter))
    except FailedToConvergeError:
        print('The quasi-Newton method failed to converge after {} iterations.'.format(maxiter))
    except Exception as ex:
        raise Exception from ex

### 
<hr style="border:1px solid blue">

### 
### We have seen how the `itertools` module provides tools for substantially reducing the required number of lines of code.
### Besides that, it also provides an infrastructure for writing code that is more scalable in the number of inputs.
### 
### We will now discuss the `itertools.product` iterator which can make it laughably easy to make a 
### function scale to arbitrary dimensionality.

### 
### <u>Task</u>:
### You are given a tuple of arbitrary length $N$ where each element is itself a tuple of arbitrary length
### containing numbers. Write a function that yields all length $N$ combinations that take one element from each tuple.
### For instance: `(1, 2), (3, 4) -> (1, 3), (1, 4), (2, 3), (2, 4)`

### If we knew the number $N$ in advance, we could approach this problem like so:

In [None]:
elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4, 5), (7, 8)

def generate_combinations0(*elements):
    yield ()

def generate_combinations1(*elements):
    elems0, = elements
    for elem in elems0:
        yield (elem,)

def generate_combinations2(*elements):
    assert len(elements) == 2
    for elem0 in elements[0]:
        for elem1 in elements[1]:
            yield (elem0, elem1)
            
def generate_combinations3(*elements):
    assert len(elements) == 3
    for elem0 in elements[0]:
        for elem1 in elements[1]:
            for elem2 in elements[2]:
                yield (elem0, elem1, elem2)
                
# and so on ...

# you can dereference iterators. They'll be consumed.
print(*generate_combinations0(*elements0))
print(*generate_combinations1(*elements1))
print(*generate_combinations2(*elements2))
print(*generate_combinations3(*elements3))

### We can see how it would be kinda cumbersome to write a separate function for handling each number $N$ of input tuples.
### Let us solve this using a (pythonic) C-style implementation that calls itself.

In [None]:
elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4, 5), (7, 8)

# cumbersome but at least pythonic ;)
# you don't need to understand this function (unless you're interested)
def generate_combinations(*elements):
    if len(elements) == 0:
        yield ()
        return
    head, *tail = elements
    for elem0 in head:
        for rest in generate_combinations(*tail):
            yield (elem0,) + rest


print(*generate_combinations(*elements0))
print(*generate_combinations(*elements1))
print(*generate_combinations(*elements2))
print(*generate_combinations(*elements3))

### This one is indeed capable of handling all input lengths. However, is it very readable ?
### Maybe this example is still manageable but as soon as you go to more complex stuff, you'll easily lose track.
### 
### Instead, we may generate all combinations using `itertools.product(*elements)`:

In [None]:
from itertools import product

elements0 = ()
elements1 = (1, 2),
elements2 = (1, 2), (3, 4)
elements3 = (1, 2, 3), (3, 4, 5), (7, 8)

# we can simply replace generate_combinations -> product

print(*product(*elements0))
print(*product(*elements1))
print(*product(*elements2))
print(*product(*elements3))

### Laughably easy ;-)
### `itertools.product` generates the elements in the order they would be generated by the use
### of nested for loops.
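### A short sketch that checks this ordering against explicit nested loops. It also uses the `repeat` keyword of `itertools.product`,
### which takes the product of an iterable with itself. The names `corners` and `nested` are made up for illustration.

```python
from itertools import product

# product of (0, 1) with itself: the corners of the unit square
corners = list(product((0, 1), repeat=2))

# the same thing written as nested for loops
nested = [(i, j) for i in (0, 1) for j in (0, 1)]

print(corners)            # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(corners == nested)  # True

# scaling to three dimensions only requires changing `repeat`
n_grid_points = len(list(product(range(3), repeat=3)))
print(n_grid_points)      # 27
```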
### 
<hr style="border:1px solid blue">


### 
## <u>Exercise 5.5</u>:
### You are given $N$ lists containing strings. Each string in the list represents a basis function (in mathematical typesetting).
### Create a list containing the basis functions resulting from a tensor product of all basis functions.
### Note that if you have `str0` and `str1`, the product is `'{} * {}'.format(str0, str1)` (don't forget the ` * `).

In [None]:
import numpy as np


# First 5 complex Fourier basis functions over x0 = [0, 2]
fourier_basis0 = ['exp({} * pi * 1j * x0)'.format(n) if n else '1' for n in range(5)]

# First 3 canonical polynomial basis functions over x1
pol_basis1 = ['x1**{}'.format(n) if n else '1' for n in range(3)]

# gaussian basis functions exp(-(x - a)^2) for various a over x2
gauss_basis2 = ['exp(-(x2 - {:.5g})**2)'.format(a) for a in np.linspace(0, 1, 3) ]

univariate_bases = [fourier_basis0, pol_basis1, gauss_basis2]
for i, basis in enumerate(univariate_bases):
    print("univariate basis number {}: \n\n{}\n\n".format(i, '\n \n'.join(basis)))

In [None]:
# YOUR ONE LINER HERE
# trivariate_basis = ???

print('Trivariate basis function \n')
for i, func in enumerate(trivariate_basis):
    print(f'v{i}: {func}', '\n')

### solution:

In [None]:
from itertools import product

trivariate_basis = list(map(' * '.join, product(*univariate_bases)))

print('Trivariate basis function \n')
for i, func in enumerate(trivariate_basis):
    print(f'v{i}: {func}', '\n')

### This is actually useful in the symbolic math library `sympy`:

In [None]:
from sympy import lambdify, symbols
x0, x1, x2 = symbols('x0 x1 x2')
func = trivariate_basis[35]

# convert string to an actual function
f = lambdify([x0, x1, x2], func)

# we can actually evaluate that function.
print('v35 in (x0, x1, x2) = (1/7, .5, 0): ', f(1/7, .5, 0))

### 
<hr style="border:1px solid blue">

### 
### <u>One last thing</u>:
### Suppose you have a function of the form `f(*args)` and you would like to pass
### dereferenced lists of varying length to `f` in a for loop using the `map(...)` function.

In [None]:
# example

def square(*args):
    return [i**2 for i in args]

inputs = [1, 2], [1, 2, 3], [0], [7, 8, 9, 10]

for squares in map(lambda x: square(*x), inputs):
    print(squares)

### We can use the `itertools.starmap` function which dereferences automatically:

In [None]:
from itertools import starmap

for squares in starmap(square, inputs):
    print(squares)
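
### As a rough sketch of what `starmap` does (the list `pairs` is an illustrative assumption):
### it calls the function with each item unpacked, much like a generator expression with `f(*item)`.

```python
from itertools import starmap

pairs = [(2, 3), (3, 2), (10, 0)]

# starmap(f, iterable) calls f(*item) for every item ...
powers = list(starmap(pow, pairs))

# ... which matches this generator expression
powers_genexp = list(pow(*pair) for pair in pairs)

assert powers == powers_genexp
print(powers)
```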

### 
<hr style="border:1px solid blue">

### 

# What we have learned:
### 1. For / while loops have an optional else clause ;) (you can use it to impress your friends).
### 2. `itertools` helps you make loops more concise and readable.
### It can make generalisations to arbitrary input numbers laughably easy and is worth checking out.

### 
<hr style="border:1px solid blue">

### 

# <u>Lesson 6</u>: The use of `functools` and decorators


### In this lesson, we learn how to modify the behavior of functions without using too much boilerplate.
### 
### We have seen how we can write functions that accept lists, tuples and even `np.ndarray`s and
### handle them all equally well. This is called **type agnosticism**, i.e., one function with the same name
### that handles all kinds of different input types
### (returning differing output types from that function is considered bad design though).
### It can be very useful.
### 
### A good example is the `add_polynomials` function.

In [None]:
from itertools import zip_longest
import numpy as np


# It doesn't matter what type a polynomial
# has as long as we can iterate over it.
def add_polynomials(*polynomials):
    return tuple(map(sum, zip_longest(*polynomials, fillvalue=0)))


# np.array, list, tuple are all fine
pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)

print(add_polynomials(pol0, pol1, pol2))

### 
### In more general cases, it may be useful to handle only one type of input inside of the function.
### Type agnosticism is then easily achieved by converting the inputs to the desired type.
### If an input can't be converted, too bad, we get a runtime error early on. But that's **desired behaviour**.
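### 
### A small sketch of the "convert first, fail early" idea. The helpers `as_float_tuple` and `evaluate`
### are hypothetical and only serve as illustration:

```python
def as_float_tuple(polynomial):
    """Convert any iterable of numbers to a tuple of floats (hypothetical helper)."""
    return tuple(map(float, polynomial))


def evaluate(polynomial, x):
    """Evaluate sum_i c_i * x**i. Internally, only tuples of floats are handled."""
    coeffs = as_float_tuple(polynomial)  # the conversion happens before any real work
    return sum(c * x**i for i, c in enumerate(coeffs))


print(evaluate([1, 2, 3], 2))  # list input
print(evaluate(range(3), 2))   # any iterable of numbers works

try:
    evaluate(5, 2)  # an int is not iterable, so the conversion fails right away
except TypeError as err:
    print('Failed early:', err)
```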
### 
<hr style="border:1px solid blue">

### 
### For reasons that will become apparent shortly, in many cases it is desirable to convert the input(s)
### of a function to tuples **before** they are passed to the function.

In [None]:
from more_itertools import convolve  # multiplying two polynomials is the same as convolving their weights
import numpy as np


def convert_to_tuples(*polynomials):
    "([1, 2, 3], (2, 3, 4)) -> ((1, 2, 3), (2, 3, 4))"
    return tuple(map(tuple, polynomials))


def multiply_polynomials(*polynomials):
    ### special cases first
    if not polynomials:
        return ()
    if len(polynomials) == 1:
        return polynomials[0]
    ###
    if len(polynomials) == 2:
        return tuple(convolve(*polynomials))
    return multiply_polynomials(multiply_polynomials(*polynomials[:2]), *polynomials[2:])


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


# we have to convert first (annoying)
polynomials = convert_to_tuples(pol0, pol1, pol2)
print("The product of all polynomials is given by: \n", multiply_polynomials(*polynomials))

### 
<hr style="border:1px solid blue">

### 
### The `multiply_polynomials` function can be significantly simplified using the `functools.reduce` function:
### `reduce(lambda x, y: x + y, [1, 2, 3, 4, 5])` $\implies$ `((((1 + 2) + 3) + 4) + 5)`
### 
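
### To get a feeling for `reduce`, here is a minimal sketch using the standard `operator` module
### (the list `numbers` is just an example). Note the optional third argument, an initial value,
### without which `reduce` raises a `TypeError` on an empty sequence.

```python
from functools import reduce
from operator import add, mul

numbers = [1, 2, 3, 4, 5]

total = reduce(add, numbers)    # ((((1 + 2) + 3) + 4) + 5)
product_ = reduce(mul, numbers)  # ((((1 * 2) * 3) * 4) * 5)

# with an initial value, an empty input is fine and simply returns that value
empty_total = reduce(add, [], 0)

print(total, product_, empty_total)
```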

In [None]:
from more_itertools import convolve  # multiplying two polynomials is the same as convolving their weights
from functools import reduce
import numpy as np


def convert_to_tuples(*polynomials):
    "([1, 2, 3], (2, 3, 4)) -> ((1, 2, 3), (2, 3, 4))"
    return tuple(map(tuple, polynomials))


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)  # apply `_multiply_pols` cumulatively, from left to right



pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


# we have to convert first (annoying)
polynomials = convert_to_tuples(pol0, pol1, pol2)
print("The product of all polynomials is given by: \n", multiply_polynomials(*polynomials))

### 
<hr style="border:1px solid blue">

### 

### That's better !
### However, it's annoying to always have to convert the inputs to tuples before passing them to the function.
### Can we do it smarter ?
### Maybe like this ?

In [None]:
from more_itertools import convolve
from functools import reduce
import numpy as np


def convert_to_tuples(*polynomials):
    "([1, 2, 3], (2, 3, 4)) -> ((1, 2, 3), (2, 3, 4))"
    return tuple(map(tuple, polynomials))


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


def _multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


multiply_polynomials = lambda *polynomials: _multiply_polynomials(*convert_to_tuples(*polynomials))


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)

# we no longer have to convert the inputs ourselves
print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))

### Now we have created a function that automatically converts the input to tuples
### which is then forwarded to the actual function which then does the work.
### It's better but I still see a lot of boilerplate.
### 
<hr style="border:1px solid blue">

### 
### Let's try something else. How about we define a **function that takes a function**
### and returns a **new function** ?
### Check this out:

In [None]:
from more_itertools import convolve
from functools import reduce
import numpy as np
from typing import Callable


# function that takes a function and returns a new function
def convert_input_to_tuples(f: Callable) -> Callable:
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))  # convert into tuple and drop into `f`
    return converted_function


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


multiply_polynomials = convert_input_to_tuples(multiply_polynomials)


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)

# we no longer have to convert the inputs ourselves
print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))

### 
### The line `multiply_polynomials = convert_input_to_tuples(multiply_polynomials)`
### is equivalent to the following **decorator**:
### 

In [None]:
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable


def convert_input_to_tuples(f: Callable) -> Callable:
    def converted_function(*polynomials):
        return f(*map(tuple, polynomials))
    return converted_function


### This is what I call clean code


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


@convert_input_to_tuples
def multiply_polynomials(*polynomials):
    # bool(polynomials) False => return polynomials = ()
    return polynomials and reduce(_multiply_pols, polynomials)


@convert_input_to_tuples
def other_heavy_computation(*polynomials):
    pass


@convert_input_to_tuples
def yet_another_heavy_computation(*polynomials):
    pass


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))

### 
<hr style="border:1px solid blue">

### 

## <u>Exercise 6.1</u>:
### Extend the above to not only convert the polynomials to tuples but also sort the tuples
### to make the function agnostic to the order in which arguments are passed.
### This will come in handy very soon.
### 
### Do not write one decorator but two separate ones.
### 
### The `sorted` function can sort a list / tuple of tuples containing numbers
### `sorted( [(1, 2, 3), (0, 1, 2)] ) = [(0, 1, 2), (1, 2, 3)]`

In [None]:
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable


# function that takes a function and returns a new function
def convert_input_to_tuples(f: Callable) -> Callable:
    def wrapper(*polynomials):
        return f(*map(tuple, polynomials))
    return wrapper


# your code here
def sort_input(f: Callable) -> Callable:
    # your code here
    pass


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


# which decorators to add ?
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))

### solution:

In [None]:
from more_itertools import convolve
from functools import reduce
import numpy as np
from time import sleep
from typing import Callable
from functools import wraps


def convert_input_to_tuples(f: Callable) -> Callable:
    
    @wraps(f)  # wrapper will retain f's docstring etc
    def wrapper(*polynomials):
        return f(*map(tuple, polynomials))
    
    return wrapper


def sort_input(f: Callable) -> Callable:
    
    @wraps(f)
    def wrapper(*polynomials):
        return f(*sorted(polynomials))
    
    return wrapper


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


@convert_input_to_tuples  # first convert to tuples
@sort_input  # then sort those tuples and forward to function
def multiply_polynomials(*polynomials):
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


pol0 = np.array([1.0, 2.0, 4.0, -1.0])
pol1 = [1, 2, 3]
pol2 = (0, 0, 0, 0, 0, 1)


print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))
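
### What does `@wraps(f)` buy us ? A minimal sketch comparing a decorator without and with `wraps`
### (the decorators and the two `add_*` functions below are made up for illustration):

```python
from functools import wraps


def plain_decorator(f):
    def wrapper(*args):
        return f(*args)
    return wrapper


def wrapping_decorator(f):
    @wraps(f)  # copy name, docstring, ... from `f` onto `wrapper`
    def wrapper(*args):
        return f(*args)
    return wrapper


@plain_decorator
def add_plain(a, b):
    "Add two numbers."
    return a + b


@wrapping_decorator
def add_wrapped(a, b):
    "Add two numbers."
    return a + b


# without `wraps`, we see the metadata of the inner `wrapper`
print(add_plain.__name__, add_plain.__doc__)

# with `wraps`, the decorated function keeps its own name and docstring
print(add_wrapped.__name__, add_wrapped.__doc__)
```

### Both versions compute the same result. Only the metadata (useful for `help(...)` and debugging) differs.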

### 
<hr style="border:1px solid blue">

### 
### A popular application of decorators is function caching.
### Suppose you have a function that performs a heavy computation.
### It is possible that the same computation has to be performed more than once.
### It would be nice if the function "remembered" the output given its inputs.
### 
### <u>Task</u>: Write a decorator that does exactly that.

In [None]:
from more_itertools import convolve
from functools import reduce, wraps
from typing import Callable
from time import sleep
import numpy as np


def convert_input_to_tuples(f: Callable) -> Callable:

    @wraps(f)  # make sure the new function has the same name, docstring, ... as `f`
    def wrapper(*polynomials):
        return f(*map(tuple, polynomials))

    return wrapper


def sort_input(f: Callable) -> Callable:

    @wraps(f)
    def wrapper(*polynomials):
        return f(*sorted(polynomials))

    return wrapper


# Functions are first class citizens in python.
# we can edit them by adding attributes etc just like other objects.
def cache_input(f: Callable) -> Callable:
    f._cache = {}  # add an empty dictionary to the function

    @wraps(f)
    def wrapper(*polynomials):
        # try to return existing entry from `_cache` and if missing,
        # add the entry using dict.setdefault and return it.
        try:
            return f._cache[polynomials]
        except KeyError:
            return f._cache.setdefault(polynomials, f(*polynomials))

    return wrapper


def _multiply_pols(pol0, pol1):
    return tuple(convolve(pol0, pol1))


@convert_input_to_tuples  # convert input to tuples so we can use input as keys in a dict
@sort_input  # sort input for commutative operations
@cache_input  # remember outputs for inputs that have been seen before.
def multiply_polynomials(*polynomials):
    sleep(3)  # fake a heavy computation by freezing everything for 3 seconds
    if not polynomials:
        return ()
    return reduce(_multiply_pols, polynomials)


pol0 = [1, 2, 3]
pol1 = (0, 0, 0, 0, 0, 1)
pol2 = np.array([1.0, 2.0, 4.0, -1.0])

# the decorators take care of converting, sorting and caching
print("The product of all polynomials is given by: \n", multiply_polynomials(pol0, pol1, pol2))

### This took a few seconds to compute.
### To prove to you that the function remembers its input, we call it again.
### The output should be immediate.
### Since we are sorting the input, the order in which we pass the polynomials should not matter.

In [None]:
# call the function again but with pol1, pol2, pol0 rather than pol0, pol1, pol2
print( multiply_polynomials(pol1, pol2, pol0) )

### 
<hr style="border:1px solid blue">

### 

### The `cache_input` decorator we wrote is a simplistic representation of
### the `functools.lru_cache` decorator.
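### The `functools.lru_cache` decorator optionally takes a `maxsize` keyword argument
### which specifies the maximum number of input / output pairs to memoize (default: 128).
### Once the cache is full, the least recently used pair is discarded first.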
### As in our self-written decorator, the input arguments need to be **hashable**.
### Example:

In [None]:
from functools import lru_cache

# Compute the n-th term in the Fibonacci sequence
# without and with caching. To see how often the function
# is called with a specific input, use a print statement.

def fibonacci(n):
    """
        Compute the n-th term in the Fibonacci sequence.
        Starting on (n0, n1) = (0, 1), n_{i+1} = n_{i} + n_{i-1}.
    """
    print('Computing the fibonacci term with n={}'.format(n))
    assert n >= 0
    if n < 2:
        return n
    return fibonacci(n-2) + fibonacci(n-1)


@lru_cache(maxsize=2)
def fibonacci_cached(n):
    """
        Compute the n-th term in the Fibonacci sequence.
        Starting on (n0, n1) = (0, 1), n_{i+1} = n_{i} + n_{i-1}.
        Cache the input / output pairs.
    """
    print('Computing the fibonacci term with n={}'.format(n))
    assert n >= 0
    if n < 2:
        return n
    return fibonacci_cached(n-2) + fibonacci_cached(n-1)


print('Computing the eighth term in the Fibonacci sequence without caching. \n')
print('Eighth Fibonacci term:', fibonacci(8))
print('\n')

print('Computing the eighth term in the Fibonacci sequence with caching. \n')
print('Eighth Fibonacci term:', fibonacci_cached(8))
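
### Functions decorated with `lru_cache` expose a `cache_info()` method that reports hits and misses.
### The sketch below (the function `square` is a made-up example) also shows what happens
### when an input is not hashable:

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # None: the cache may grow without bound
def square(n):
    return n ** 2


square(3)  # computed (miss)
square(3)  # served from the cache (hit)
square(4)  # computed (miss)

info = square.cache_info()
print(info)  # named tuple with fields hits, misses, maxsize, currsize

try:
    square([1, 2])  # a list is not hashable and cannot serve as a cache key
except TypeError as err:
    print('Cannot cache:', err)
```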

### 
<hr style="border:1px solid blue">

### 
# What we have learned:
### 1. Decoration `@decorator` is syntactic sugar for `fn = decorator(fn)`, i.e., a simple function call.
### 2. `Python` functions are `first class citizens`. You can add attributes to them. They can be passed to functions.
### 3. `functools` provides tools for virtually all functional programming concepts `reduce, partial, wraps, ...`
### 4. `caching` is done by adding a `hashmap` (`dict`) to a function.
### 
<hr style="border:1px solid blue">

### 
# <u>Lesson 7</u>:
## The concept of `hashing`.

### 
### We have seen extensive use of `Python` `dictionaries`, so-called hashmaps.
### The question is: 1. what types can we use as keys and 2. what values can we put in a hashmap ?
### <u>Answer to 2. </u>: everything.
### <u>Answer to 1. </u>: all built-in types that are `immutable` (cannot be changed during their lifetime).
### 
### a. Numeric types: `int`, `float`, `complex`, `bool` (`bool` is a subclass of `int`).
### b. Collections: `tuple` (as long as the items are themselves hashable), `frozenset`.
### c. Strings and bytes: `str`, `bytes`.
### d. `NoneType`.
### ... and a few more.
### 
### Characteristic for `immutable` built-ins is that they have a so-called `hash` value.
### A `hash` is an easy-to-compute `int` value we assign to an immutable object.
### Assigning a `hash` value to an object is a science in its own right ...
### The `Python` built-in function `hash` does exactly that:

In [None]:
print(f'Hash of `5`: {hash(5)} (not surprising).\n')
print(f'Hash of `3.14`: {hash(3.14)}.\n')
print(f'The hash of (3.14, 5, 6.1): {hash((3.14, 5, 6.1))}. \n')

### A good `hash` function makes so-called **hash collisions**, i.e., the same `hash`
### for two objects that are not equal, unlikely. However, collisions are
### not completely avoidable. Conversely, objects that compare equal **must** have the same `hash`:

In [None]:
print(f'The hash of `5.0` is the same as the hash of `5` ? {hash(5.0) == hash(5)}.\n')
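
### A short sketch of the consequences for `dict` lookups. That `hash(-1)` and `hash(-2)` coincide
### is an implementation detail of CPython (an assumption about your interpreter), so no output is promised for that line.

```python
# 5 == 5.0 and hash(5) == hash(5.0), so both refer to the same dict entry
lookup = {5: 'five'}
print(lookup[5.0])

# assigning via 5.0 overwrites the value but keeps the original key 5
lookup[5.0] = 'five point zero'
print(lookup)

# a genuine collision (in CPython): two unequal objects, possibly the same hash
print(hash(-1) == hash(-2), -1 == -2)

# the dict still keeps unequal keys apart because it also tests for equality
colliding = {-1: 'minus one', -2: 'minus two'}
print(colliding[-1], colliding[-2])
```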

### 
### When we invoke `dict[key]`, here is what happens under the hood:
### 
### 1. `Python` computes the `hash` of `key`.
### 
### 2. The `hash` value is used to determine the index in the `dict`'s
### underlying data structure (an `array`) at which the key/value pair is stored.
### The `dict` data structure is such that the average lookup time is $\mathcal{O}(1)$,
### regardless of the number of `dict` entries.
### 
### 3. It is tested if `key` is indeed equal to the `key` that `Python` stores at the location
### corresponding to `hash(key)`. This is necessary to handle `hash` collisions.
### 
### 4. In case of equality, the value stored at the array location is returned.
### Else a `KeyError` is raised.
### 
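
### The four steps can be sketched with a toy hashmap. `ToyDict` is a hypothetical class for illustration only.
### It resolves collisions with small lists ("buckets"), whereas the real `dict` uses a more refined
### scheme (open addressing) and resizes itself as it grows.

```python
class ToyDict:
    """Toy hashmap with a fixed number of buckets (not how CPython lays out memory)."""

    def __init__(self, n_buckets=8):
        self._buckets = [[] for _ in range(n_buckets)]

    def _bucket(self, key):
        # steps 1 and 2: compute the hash and turn it into an array index
        return self._buckets[hash(key) % len(self._buckets)]

    def __setitem__(self, key, value):
        bucket = self._bucket(key)
        for i, (stored_key, _) in enumerate(bucket):
            if stored_key == key:  # key already present: overwrite the value
                bucket[i] = (stored_key, value)
                return
        bucket.append((key, value))

    def __getitem__(self, key):
        for stored_key, value in self._bucket(key):
            # step 3: an equal hash is not enough, the keys must be equal too
            if stored_key == key:
                return value  # step 4: found
        raise KeyError(key)  # step 4: not found


toy = ToyDict()
toy['a'] = 1
toy[(1, 2)] = 'tuple key'
toy[-1] = 'minus one'
toy[-2] = 'minus two'

print(toy['a'], toy[(1, 2)], toy[-1], toy[-2])
```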
### To see why Python discourages the use of `mutable` types in hashmaps,
### here is a minimal example:
### 

In [None]:
# class that acts like a list but has the added functionality
# of possessing a `hash`.
class HashableList(list):
    def __hash__(self):
        return hash(tuple(self))
    

# make a hashable list and an empty hashmap
myhashablelist = HashableList([1, 2, 3])
myhashmap = {}

# map `myhashablelist` to some value
myhashmap[myhashablelist] = 5

# mutate the hashable list inplace
myhashablelist.append(4)

# try to retrieve the stored value (this raises a KeyError)
myhashmap[myhashablelist]

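### Here is what happens: 
### 1. `myhashmap` stores the key/value pair at the location determined by
### `hash(HashableList([1, 2, 3]))`. The stored `key` is a reference to `myhashablelist`.
### 2. `key` is mutated inplace. Its `hash` changes, but the entry stays where it is.
### 3. The location belonging to the old `hash` now holds the mutated `HashableList([1, 2, 3, 4])`.
### 4. Looking up the mutated list uses the new `hash`, which points to a location
### where (in all likelihood) nothing is stored: `KeyError`.
### 5. Looking up `HashableList([1, 2, 3])` finds the old location but fails too
### because `HashableList([1, 2, 3]) == HashableList([1, 2, 3, 4])` is `False`.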
### 
### Et voila, you've broken your `hashmap`.
### 
<hr style="border:1px solid blue">

### 
### Can we still hash `mutable` types ?
### Indirectly, yes, and we have done it before.
### 
### One way to circumvent this issue is by converting a mutable type to an immutable
### one before using it as a `key` in a `hashmap` (or cached function, for that matter).
### 
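
### In its simplest form (a sketch with made-up data), the conversion is a call to `tuple` or `frozenset`.
### The key is then an immutable snapshot that later mutations of the original object cannot touch.

```python
cache = {}

coefficients = [1, 2, 3]   # a list is mutable and therefore not hashable
key = tuple(coefficients)  # immutable snapshot of the current content
cache[key] = sum(coefficients)

coefficients.append(4)     # mutating the list does not affect the stored key

print(cache[(1, 2, 3)])               # the entry is still retrievable
print(tuple(coefficients) in cache)   # the mutated list is simply a new, unseen key

# for unordered collections, use frozenset instead of set
cache[frozenset({'a', 'b'})] = 'set-like key'
print(cache[frozenset({'b', 'a'})])
```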

In [None]:
import numpy as np
from typing import Tuple
from functools import lru_cache, wraps
from time import sleep


def serialize_array(arr: np.ndarray) -> Tuple[bytes, Tuple[int, ...]]:  # convert array to hashable type
    arr = np.asarray(arr, dtype=float)
    return (arr.tobytes(), arr.shape)  # a tuple containing two hashable types


def deserialize_array(serialized_array: Tuple[bytes, Tuple[int, ...]]):  # undo the conversion
    byte_array, shape = serialized_array
    return np.frombuffer(byte_array).reshape(shape)  # you can uniquely recreate the array from the input


### decorators for converting back and forth
def serialize_inputs(fn):
    @wraps(fn)
    def wrapper(*args):
        return fn(*map(serialize_array, args))
    return wrapper


def deserialize_inputs(fn):
    @wraps(fn)
    def wrapper(*args):
        return fn(*map(deserialize_array, args))
    return wrapper
###


@serialize_inputs  # convert to hashable type
@lru_cache  # invoke caching on the hashable type
@deserialize_inputs  # convert back to original array to be used in the function as usual.
def heavy_computation(arr0: np.ndarray, arr1: np.ndarray) -> np.ndarray:
    # do a heavy computation
    sleep(1)
    print(f'Performing a heavy computation with inputs {arr0} and {arr1} ....\n')
    return arr0 ** 3 + arr1 ** 4


arr0 = np.arange(3)
arr1 = np.arange(3, 6)

print(f'result: {heavy_computation(arr0, arr1)}.\n\n')

arr0[-1] = 10
print(f'Changed `arr0` inplace, which is now {arr0}.\n')

print(f'result: {heavy_computation(arr0, arr1)}.\n\n')

arr0[-1] = 2
print(f'Changed arr0 back to the original array: {arr0}.\n')

print(f'result: {heavy_computation(arr0, arr1)} (I used cached result).\n\n')

### Convince yourself that this implementation is `safe`.
### 
### Note that there is a slight overhead resulting from converting
### the arrays back and forth.
### Make sure that your computation is sufficiently heavy for this to be worth it.
### 
<hr style="border:1px solid blue">

### 
### Topics we couldn't cover due to time constraints:
### 1. Use of collections: namedtuple, ChainMap, defaultdict (useful alternative to `dict.setdefault`) <- this is definitely worth checking out.
### 2. dataclass - a nice gimmick. Helpful in some cases.
### 3. Iterating over `numpy.ndarray`'s - not super important because you should avoid loops in numpy ;-)
### 4. Writing functions that can handle several input types (achieving `multiple dispatch`).
### <- the `type agnosticism` we built into some of our functions is similar to traditional `multiple dispatch`. Have a look if you're interested.
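### 
### As a small teaser for topic 1, here is a sketch comparing `dict.setdefault` with `collections.defaultdict`
### (the word list is made up for illustration):

```python
from collections import defaultdict

words = ['apple', 'avocado', 'banana', 'blueberry', 'cherry']

# with a plain dict, `setdefault` creates the list on first use
by_letter = {}
for word in words:
    by_letter.setdefault(word[0], []).append(word)

# `defaultdict(list)` creates the missing list automatically
by_letter_dd = defaultdict(list)
for word in words:
    by_letter_dd[word[0]].append(word)

assert by_letter == dict(by_letter_dd)
print(by_letter)
```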