In [1]:
import numpy as np
import itertools
import functools

In [2]:
np.set_printoptions(
    edgeitems=30, 
    linewidth=100000, 
    precision=3,
    suppress=True
)

rng = np.random.default_rng(12345)    

# Functional Programming APIs

In addition to the functions demoed here, also check out the [more-itertools](https://more-itertools.readthedocs.io/en/stable/api.html) library. To avoid verbosity, the summary below uses a curried, Haskell-like notation (`f :: a -> b -> c`), while the individual sections use the following type annotations.
  * A sequence of type `T` is written as `seq[T]`
  * An iterable of type `T` is written as `[T]`
  * A tuple is written as `(T, U, V)`

The difference between a sequence type and an iterable type is that the former implements the `__len__` and `__getitem__` methods, while the latter only needs to implement `__iter__`, which returns an iterator (an object with a `__next__` method).
  
## Summary
```python
all :: [bool] -> bool
any :: [bool] -> bool
filter :: (a -> bool) -> [a] -> [a]
filterfalse :: (a -> bool) -> [a] -> [a]
takewhile :: (a -> bool) -> [a] -> [a]
dropwhile :: (a -> bool) -> [a] -> [a]
compress :: [a] -> [bool] -> [a]
groupby :: [a] -> (a -> b) -> [(b, [a])]
product :: [t1] -> [t2] -> [t3] .. [tn] -> [(t1, t2, t3, ..., tn)]
accumulate :: [a] -> (b -> a -> b) -> b -> [b]
reduce :: (b -> a -> b) -> [a] -> b -> b
map :: (t1 -> t2 -> ... tn -> a) -> [t1] -> [t2] ... [tn] -> [a]
(simplified) map :: (a -> b) -> [a] -> [b]
starmap :: (t1 -> t2 -> ... tn -> a) -> [(t1, t2, ..., tn)] -> [a]
max/min :: [a] -> (a -> b) -> a
reversed :: seq[a] -> [a]
sorted :: [a] -> (a -> b) -> bool -> [a]
zip :: [t1] -> [t2] -> ... [tn] -> bool -> [(t1, t2, ..., tn)]
(simplified) zip :: [a] -> [b] -> bool -> [(a, b)]
tee :: [a] -> int -> seq[[a]]
pairwise :: [a] -> [(a, a)]
permutations/combinations :: [a] -> int -> [(a, a, ..., a)]
count :: number -> number -> [number]
cycle :: [a] -> [a]
repeat :: a -> int -> [a]
chain :: [t1] -> [t2] -> ... [tn] -> [tx]
```

### `all(it)` 
Will return `True` if all the elements in the iterable are truthy, and also if the iterable is empty. The elements only need to be bool-like, i.e., anything Python can test for truthiness.

```
all: [bool] -> bool
```

In [3]:
nums = rng.integers(10, 1000, 5)
print(nums)
all(map(lambda x: x > 500, nums))

[702 235 790 323 212]


False

In [4]:
assert all([1, 2, 3, 4])
assert not all([0, 1, 2, 3])

### `any(it)` 
Will return `True` if any of the elements in the iterable is truthy, and `False` if none is (including when the iterable is empty).

```
any: [bool] -> bool
```

In [5]:
nums = rng.integers(10, 1000, 5)
print(nums)
any(map(lambda x: x > 500, nums))

[799 646 679 988 397]


True
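Two edge cases are easy to miss: for an empty iterable `all` returns `True` and `any` returns `False`, and both functions short-circuit, stopping as soon as the answer is known. A minimal sketch; `check` and `seen` are hypothetical names that only record which elements got inspected, and the inputs are generator expressions so nothing is evaluated up front:

```python
seen = []

def check(x):
    # Hypothetical helper: record each inspected element, then test it.
    seen.append(x)
    return x > 0

# Empty iterables: all() is vacuously True, any() is False.
print(all([]), any([]))  # True False

seen.clear()
print(all(check(x) for x in [3, -1, 5, 7]))  # False
print(seen)  # [3, -1]: all() stopped at the first falsy result

seen.clear()
print(any(check(x) for x in [-2, 4, -6, 8]))  # True
print(seen)  # [-2, 4]: any() stopped at the first truthy result
```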

### `filter(predicate, it)`
Will only return those elements of the iterable for which the predicate function returns a truthy value. **This returns a new iterator** containing only the filtered elements.

```
filter: (T->bool, [T]) -> [T] 
```

In [6]:
nums = rng.integers(10, 1000, 5)
print(nums)
list(filter(lambda x: x > 500, nums))

[841 339 571 602 221]


[841, 571, 602]
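Two more properties of `filter`, shown in a short sketch: passing `None` as the predicate keeps just the truthy elements, and since the result is a lazy iterator it can only be consumed once.

```python
# With None as the predicate, filter keeps only the truthy elements.
print(list(filter(None, [0, 1, "", "a", None, [], [2]])))  # [1, 'a', [2]]

# The returned iterator is lazy and single-use.
evens = filter(lambda x: x % 2 == 0, range(10))
first_pass = list(evens)
second_pass = list(evens)
print(first_pass)   # [0, 2, 4, 6, 8]
print(second_pass)  # [] (already exhausted)
```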

### `itertools.filterfalse(predicate, it)`
Will return those elements from the iterable that **do not** pass the predicate, i.e., the elements for which the predicate is true are filtered out. **This will return an iterator**.

```
filterfalse: (T->bool, [T]) -> [T]
```

In [7]:
list(itertools.filterfalse(lambda x: x%2 == 0, range(10)))

[1, 3, 5, 7, 9]
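`filter` and `filterfalse` are complements, so together they split an iterable into the elements that fail and the ones that pass. The sketch below uses a hypothetical `partition` helper, similar to the `partition` recipe in the `itertools` documentation; `tee` (covered further down) lets both halves read the same input:

```python
import itertools

def partition(pred, iterable):
    # Hypothetical helper: split one input into (failing, passing) iterators.
    # tee() lets both filters read the same input without re-iterating it.
    failing, passing = itertools.tee(iterable)
    return itertools.filterfalse(pred, failing), filter(pred, passing)

is_even = lambda x: x % 2 == 0
odds, evens = (list(part) for part in partition(is_even, range(10)))
print(odds, evens)  # [1, 3, 5, 7, 9] [0, 2, 4, 6, 8]
```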

### `itertools.takewhile(predicate, it)`
Will keep selecting elements from the sequence as long as they keep passing the predicate. After a non-passing element is hit, no more elements are selected even if there are passing elements later on in the sequence.

```
takewhile: (T->bool, [T]) -> [T]
```

In [8]:
list(itertools.takewhile(lambda x: x<5, [1,4,6,4,1]))

[1, 4]

### `itertools.dropwhile(predicate, it)`
Will keep dropping elements from the iterable as long as they keep passing the predicate. Once a non-passing element is hit, that element and all the remaining elements are returned, even if some of them pass the predicate.

```
dropwhile: (T->bool, [T]) -> [T]
```

In [9]:
list(itertools.dropwhile(lambda x: x<5, [1,4,6,4,1]))

[6, 4, 1]
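A sketch of how the two functions relate: on a list, `takewhile` and `dropwhile` with the same predicate split the data at the first failing element. On a one-shot iterator there is a catch, since `takewhile` has to pull the first failing element to know when to stop, and that element is then gone:

```python
import itertools

data = [1, 4, 6, 4, 1]
is_small = lambda x: x < 5

# On a re-iterable list, takewhile + dropwhile split the data at the first failure.
head = list(itertools.takewhile(is_small, data))
tail = list(itertools.dropwhile(is_small, data))
print(head, tail)  # [1, 4] [6, 4, 1]

# On a one-shot iterator, takewhile consumes the first failing element (the 6),
# so it is missing from what is left afterwards.
it = iter(data)
head2 = list(itertools.takewhile(is_small, it))
rest = list(it)
print(head2, rest)  # [1, 4] [4, 1]
```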

### `itertools.compress(data, selectors)`
Here `data` is an iterable that yields the data elements and `selectors` is a bool-like iterable that determines whether the corresponding data element should be let through the gate or not. Iteration stops as soon as either `data` or `selectors` is exhausted.

```
compress: ([T], [bool]) -> [T]
```

In [10]:
# Any bool-like values work, e.g. selector = [0, 1, 0, 1, 0]
selector = [False, True, False, True, False]
list(itertools.compress("ABCDE", selector))

['B', 'D']
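Conceptually, `compress` is just a filter driven by a second iterable. A hypothetical `my_compress`, sketched with `zip` and a generator expression, gives the same results:

```python
import itertools

def my_compress(data, selectors):
    # Hypothetical re-implementation: keep each datum whose paired selector is truthy.
    # zip() stops at the shorter input, just like compress does.
    return (d for d, s in zip(data, selectors) if s)

selector = [0, 1, 0, 1, 0]
print(list(my_compress("ABCDE", selector)))  # ['B', 'D']
assert list(my_compress("ABCDE", selector)) == list(itertools.compress("ABCDE", selector))
```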

### `itertools.groupby(it, key=lambda x: x)`
The main input argument is the iterable. When `key` is not provided, this function behaves like the `uniq` Unix command, where consecutive equal elements are grouped together. `key` is an optional group-by function that takes an element of the input iterable and returns the value used to group the elements together. In this sense, the default `key` function is the identity function. Even when a custom key function is provided, only *consecutive* elements with the same key are grouped together. See the example with names below.

The returned iterator is an unusual one: each of its elements is a tuple `(key, grp)`. The first element is the group-by key; with the default key function this is simply the value of the element. The second element is another iterator over all the elements in the group.

```
groupby: ([T], T->K) -> [(K, [T])]
```

In [11]:
grps = itertools.groupby([1, 1, 1, 2, 2, 3, 3, 3])
print(type(grps))
for key, grp in grps:
    print(type(grp))
    print(key, list(grp))

<class 'itertools.groupby'>
<class 'itertools._grouper'>
1 [1, 1, 1]
<class 'itertools._grouper'>
2 [2, 2]
<class 'itertools._grouper'>
3 [3, 3, 3]
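Each group iterator shares the underlying input with the `groupby` object, so advancing to the next group leaves the previous group empty. A small sketch (convert a group to a `list` right away if it is needed later):

```python
import itertools

grps = itertools.groupby([1, 1, 2, 2])
key1, grp1 = next(grps)
key2, grp2 = next(grps)  # advancing groupby moves past group 1
late1 = list(grp1)
late2 = list(grp2)
print(key1, late1)  # 1 []
print(key2, late2)  # 2 [2, 2]
```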


In [12]:
names = [
    "avilay",
    "parekh",
    "anika",
    "manjit"
]
grps = itertools.groupby(names, key=lambda n: len(n))
for key, grp in grps:
    print(key, list(grp))

6 ['avilay', 'parekh']
5 ['anika']
6 ['manjit']
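To get one group per distinct key, regardless of where the elements appear, the usual approach is to sort by the same key before grouping. A sketch using the same `names` data (`by_len` is just an illustrative name):

```python
import itertools

names = ["avilay", "parekh", "anika", "manjit"]
by_len = len  # key function: group names by their length

# Sorting by the same key first makes equal keys consecutive,
# so groupby produces exactly one group per distinct key.
groups = {
    key: list(grp)
    for key, grp in itertools.groupby(sorted(names, key=by_len), key=by_len)
}
print(groups)  # {5: ['anika'], 6: ['avilay', 'parekh', 'manjit']}
```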


### `itertools.product(p, q, ...)`
Returns the Cartesian product of the input iterables as an iterator of tuples. This is equivalent to a nested for-loop in which the rightmost iterable varies fastest.

```
product: ([T1], [T2], ..., [Tn]) -> [(T1, T2, ..., Tn)]
```

In [13]:
list(itertools.product([1, 2, 3], "ABCD"))

[(1, 'A'),
 (1, 'B'),
 (1, 'C'),
 (1, 'D'),
 (2, 'A'),
 (2, 'B'),
 (2, 'C'),
 (2, 'D'),
 (3, 'A'),
 (3, 'B'),
 (3, 'C'),
 (3, 'D')]
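A quick check of the nested-loop equivalence, plus the `repeat` keyword argument, which takes the product of an iterable with itself:

```python
import itertools

xs, ys = [1, 2, 3], "ABCD"
nested = [(x, y) for x in xs for y in ys]  # outer loop over xs, inner loop over ys
assert nested == list(itertools.product(xs, ys))

# repeat=n takes the product of the input with itself n times.
bits = list(itertools.product([0, 1], repeat=2))
print(bits)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```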

### `sum(it, /, start=0)`
Will return `start` plus the sum of all the elements of the iterable, added from left to right.
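A short sketch of `sum` with and without `start` (passing `start` by keyword needs Python 3.8+); it also accepts a generator expression directly:

```python
nums = [1, 2, 3, 4, 5]
total = sum(nums)                       # 15
shifted = sum(nums, start=100)          # 115
squares = sum(x * x for x in range(4))  # 0 + 1 + 4 + 9 = 14
print(total, shifted, squares)
```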

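### `itertools.accumulate(it[, func, *, initial=None])`
Returns an iterator over the running totals of the input iterable. `func` is an optional binary function that replaces addition when combining the running value with the next element, so it can be used to compute a running product, maximum, and so on. The default `func` is plain addition (`operator.add`). If `initial` is provided, it is emitted first and used as the starting value, so the output is one element longer than the input.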

```
accumulate: ([T], (U,T->U), U) -> [U]
```

In [14]:
print(list(itertools.accumulate([1, 2, 3, 4, 5])))
print(list(itertools.accumulate([1, 2, 3, 4, 5], lambda x, y: x * y)))

[1, 3, 6, 10, 15]
[1, 2, 6, 24, 120]


In [15]:
def func(x: str, y: int) -> str:
    return x + str(y)

x = itertools.accumulate([1, 2, 3], func=func, initial="begin")
list(x)

['begin', 'begin1', 'begin12', 'begin123']
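Since `func` can be any binary function, `accumulate` also computes running maxima, minima, and so on. A sketch showing a running maximum and the effect of `initial` (which needs Python 3.8+):

```python
import itertools

data = [3, 1, 4, 1, 5, 9, 2, 6]
running_max = list(itertools.accumulate(data, max))
print(running_max)  # [3, 3, 4, 4, 5, 9, 9, 9]

# initial is emitted first, so the output is one element longer than the input.
with_initial = list(itertools.accumulate([1, 2, 3], initial=100))
print(with_initial)  # [100, 101, 103, 106]
```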

### `functools.reduce(func, it[, initial])`
Will return the final value of applying the binary function successively, from left to right, to the previously accumulated value and each element of the iterable. If the initial value is supplied, it is placed before the elements of the iterable (and is the result when the iterable is empty); otherwise the first element is used as the initial value.

```
reduce: ((U,T->U), [T], U) -> U
```

In [16]:
print(functools.reduce(lambda x, y: x + y, [10, 11, 12]))
print(functools.reduce(lambda x, y: x + y, [10, 11, 12], 100))

33
133


In [17]:
def func(x: str, y: int) -> str:
    print("x: ", type(x), x)
    print("y: ", type(y), y)
    return x + str(y)

x = functools.reduce(func, [1, 2, 3], "begin")
print(x)

x:  <class 'str'> begin
y:  <class 'int'> 1
x:  <class 'str'> begin1
y:  <class 'int'> 2
x:  <class 'str'> begin12
y:  <class 'int'> 3
begin123
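`reduce` is a left fold, which can be written out as a plain loop. The sketch below is a hypothetical `my_reduce` (not the library's implementation) that uses a sentinel so an explicit `None` initial value is still honored, followed by the empty-iterable behavior of the real `functools.reduce`:

```python
import functools
import operator

_MISSING = object()  # sentinel: distinguishes "no initial value" from initial=None

def my_reduce(func, iterable, initial=_MISSING):
    # Hypothetical re-implementation of functools.reduce as a left fold.
    it = iter(iterable)
    if initial is _MISSING:
        try:
            acc = next(it)  # no initial value: the first element seeds the fold
        except StopIteration:
            raise TypeError("reduce of empty iterable with no initial value") from None
    else:
        acc = initial
    for x in it:
        acc = func(acc, x)  # combine the running value with the next element
    return acc

print(my_reduce(operator.add, [10, 11, 12]))       # 33
print(my_reduce(operator.add, [10, 11, 12], 100))  # 133

# With an empty iterable, the real reduce needs an initial value.
print(functools.reduce(operator.add, [], 0))       # 0
try:
    functools.reduce(operator.add, [])
except TypeError as e:
    print("TypeError:", e)
```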


### `map(fn, it1, it2, it3, ...)`
Will take the $i^{th}$ element from each input iterable, pass them as arguments to the function, and yield the result. **This returns a new iterator** with the new elements. Iteration stops when the shortest iterable is exhausted; the remaining elements of the longer iterables are ignored.

For $n$ input iterables, the signature is:

```
map: ((T1, T2, ..., Tn)->U, [T1], [T2], ..., [Tn]) -> [U]
```

In most cases I'll use the following simplified map:
```
map: (T->U, [T]) -> [U]
```

In [18]:
ints = rng.integers(10, 1000, 5)
floats = rng.random(3)
print(ints, floats)
list(map(lambda x, y: x + y, ints, floats))

[194 236 676 617 942] [0.248 0.949 0.667]


[194.24824571462958, 236.94888115183332, 676.6672374531004]

In [19]:
def fn(x: int, y: bool) -> str:
    return str(10*x) if y else str(10/x)

xs = [1, 2, 3]
ys = [True, False, True]
x = map(fn, xs, ys)
print(list(x))

['10', '5.0', '30']
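A multi-iterable `map` is equivalent to a comprehension over `zip`, including stopping at the shortest input. A sketch:

```python
xs = [1, 2, 3, 4]
ys = [10, 20, 30]  # one element shorter than xs

via_map = list(map(lambda x, y: x * y, xs, ys))
via_comp = [x * y for x, y in zip(xs, ys)]
print(via_map)  # [10, 40, 90] (the 4 in xs has no partner and is ignored)
assert via_map == via_comp
```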


### `itertools.starmap(func, it)`
Will unpack each element of the iterable and pass the unpacked (sub) elements as arguments to the function, i.e., it calls `func(*elem)` for each element. For a function accepting $n$ arguments, each element of the iterable must itself be an iterable of $n$ elements.

Let's say `func` takes two parameters; then `it` can be thought of as an iterable of 2-element tuples.

```
starmap: ((T1, T2, ..., Tn)->U, [(T1, T2, ..., Tn)]) -> [U]
```

In [20]:
ints = rng.integers(10, 1000, 5)
floats = rng.random(3)
print(list(zip(ints, floats)))
args = zip(ints, floats)
list(itertools.starmap(lambda x, y: x + y, args))

[(139, 0.6974534998820221), (104, 0.3264728640701121), (273, 0.7339281633300665)]


[139.69745349988202, 104.32647286407011, 273.73392816333006]
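In other words, `starmap(f, it)` behaves like `(f(*args) for args in it)`. A sketch with the built-in `pow`, also showing that `map` gets the same result once the pairs are transposed back into columns with `zip(*...)`:

```python
import itertools

pairs = [(2, 5), (3, 2), (10, 3)]
powers = list(itertools.starmap(pow, pairs))
print(powers)  # [32, 9, 1000]

assert powers == [pow(*args) for args in pairs]  # starmap unpacks each tuple
assert powers == list(map(pow, *zip(*pairs)))    # zip(*pairs) -> (2, 3, 10), (5, 2, 3)
```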

### `max/min`
  * `max(it, *[, key, default])`
  * `max(arg1, arg2, *args[, key])` 
  * `min(it, *[, key, default])`
  * `min(arg1, arg2, *args[, key])`
  
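Will return the largest/smallest element of the provided iterable, or of the positional arguments. `key` is a one-argument function whose return value is compared instead of the element itself. `default` is the value returned when the iterable is empty; without it, an empty iterable raises a `ValueError`.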

In [21]:
strings = ["Anu", "Anika", "Anuchiku", "Baboodi"]
print(max(strings))
print(max(strings, key=lambda s: len(s)))

Baboodi
Anuchiku
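A short sketch of tie-breaking (the first maximal or minimal element encountered wins) and of `default` on an empty iterable:

```python
words = ["ab", "cd", "e"]
longest = max(words, key=len)     # 'ab': tie with 'cd', the first one encountered wins
shortest = min(words, key=len)    # 'e'
fallback = max([], default=None)  # None instead of a ValueError
print(longest, shortest, fallback)

try:
    max([])
except ValueError as e:
    print("ValueError:", e)
```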


### `reversed(seq)`
Will return an iterator over the elements of the sequence in reverse order. `seq` must either implement the `__reversed__()` method or support the sequence protocol, i.e., implement the `__len__()` and `__getitem__()` methods. **This returns a new iterator**.

```
reversed: seq[T] -> [T]
```

In [22]:
nums = rng.integers(10, 1000, 5)
print(nums)
rnums = reversed(nums)
for rnum in rnums:
    print(rnum, end=" ")


[887 776 227 717  90]
90 717 227 776 887 

In [23]:
x = reversed(nums)
print(type(x))

<class 'reversed'>
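To see the sequence-protocol requirement in action, here is a hypothetical `Squares` class that defines only `__len__` and `__getitem__`, which is enough for `reversed`; a generator, which has neither, is rejected with a `TypeError`:

```python
class Squares:
    # Hypothetical class implementing only the sequence protocol:
    # __len__ plus __getitem__ for indices 0..n-1.
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError(i)
        return i * i

backwards = list(reversed(Squares(4)))
print(backwards)  # [9, 4, 1, 0]

try:
    reversed(x for x in range(3))  # a generator is an iterator, not a sequence
except TypeError as e:
    print("TypeError:", e)
```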


### `sorted(it, /, *, key=None, reverse=False)`
Will return a new sorted list from the iterable provided. `key` is a one argument ordering function. **This returns a list**.

```
sorted: ([T], T->K, bool) -> [T]
```

Most common form is:
```
sorted: [T] -> [T]
```

In [24]:
nums = rng.integers(10, 1000, 5)
print(nums)
sorted(nums)

[397 168 746 346 478]


[168, 346, 397, 478, 746]
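`sorted` is stable, so elements with equal keys keep their original relative order, even with `reverse=True`. A sketch with made-up `(name, length)` records:

```python
records = [("anika", 5), ("avilay", 6), ("manjit", 6), ("anu", 3)]

# Sort by the number, largest first; avilay stays ahead of manjit because the sort is stable.
by_len_desc = sorted(records, key=lambda r: r[1], reverse=True)
print(by_len_desc)  # [('avilay', 6), ('manjit', 6), ('anika', 5), ('anu', 3)]
```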

### `zip(*its, strict=False)`
`zip` is a simple function: it "transposes" the values of the input iterables. I can pass as many iterables as I want; if I pass in $m$ of them, the output is a single iterator of tuples, where each tuple has $m$ elements, one from each of the input iterables. The output stops as soon as the shortest input is exhausted. Because it is just a transpose, applying `zip(*...)` twice gives back the original rows (as tuples), provided all the rows have the same length.

![zip](./imgs/zip.png)

In [1]:
it1 = zip(
    [702, 235, 790, 323, 212],
    ["Anu", "Anika", "Anuchiku", "Baboodi"]
)
for row in it1:
    print(row)

(702, 'Anu')
(235, 'Anika')
(790, 'Anuchiku')
(323, 'Baboodi')


It does not matter whether the inputs are lists, tuples, ranges, etc. As long as they are iterable, zip will work.

In [2]:
it2 = zip(
    (702, 235, 790, 323, 212),
    ("Anu", "Anika", "Anuchiku", "Baboodi")
)
for row in it2:
    print(row)

(702, 'Anu')
(235, 'Anika')
(790, 'Anuchiku')
(323, 'Baboodi')


In [3]:
it3 = zip(
    [702, "Anu"],
    [235, "Anika"],
    [790, "Anuchiku"],
    [323, "Baboodi"]
)
for row in it3:
    print(row)

(702, 235, 790, 323)
('Anu', 'Anika', 'Anuchiku', 'Baboodi')


In [4]:
m = [
    ["00", "01", "02", "03"],
    ["10", "11", "12", "13"],
    ["20", "21", "22", "23"]
]
for row in zip(*m):
    print(row)

('00', '10', '20')
('01', '11', '21')
('02', '12', '22')
('03', '13', '23')


In [5]:
mT = [
    ["00", "10", "20"],
    ["01", "11", "21"],
    ["02", "12", "22"],
    ["03", "13", "23"]
]
for row in zip(*mT):
    print(row)

('00', '01', '02', '03')
('10', '11', '12', '13')
('20', '21', '22', '23')


Because `zip` pulls from its input iterators lazily and in left-to-right order, there is a clever way to split an iterable into chunks of $n$ elements each, as follows.

Consider the simple example below where I zip the same iterator 3 times. Internally, `zip` builds each tuple as `(next(arg1), next(arg2), next(arg3))`. In this case all the args are the same iterator, so it translates to `(next(it), next(it), next(it))`, which yields `(0, 1, 2)`, `(3, 4, 5)`, and so on. The leftover `9` cannot fill a complete tuple, so it is dropped.

In [31]:
it = iter(range(10))
for chunk in zip(it, it, it):
    print(chunk)

(0, 1, 2)
(3, 4, 5)
(6, 7, 8)


I can write this more cleverly as follows:

In [32]:
it = iter(range(10))
its = [it, it, it]
for chunk in zip(*its):
    print(chunk)

(0, 1, 2)
(3, 4, 5)
(6, 7, 8)


And now let me make it even more compact -

In [33]:
it = iter(range(10))
its = [it] * 3
for chunk in zip(*its):
    print(chunk)

(0, 1, 2)
(3, 4, 5)
(6, 7, 8)


Now all together - `zip(*[it]*3)` $\equiv$ `zip(it, it, it)`

In [34]:
it = iter(range(10))
for chunk in zip(*[it]*3):
    print(chunk)

(0, 1, 2)
(3, 4, 5)
(6, 7, 8)


Now see this in action with a generator function.

In [1]:
def gen_fibs(cap):
    i, j = 0, 1
    curr = 0
    while curr < cap:
        x = i + j
        yield x
        i, j = j, x
        curr += 1
print(list(gen_fibs(10)))
chunks = zip(*[gen_fibs(10)]*3)
list(chunks)

[1, 2, 3, 5, 8, 13, 21, 34, 55, 89]


[(1, 2, 3), (5, 8, 13), (21, 34, 55)]

Here is how it works. First, `gen_fibs(10)` creates a generator, which is just an iterator over the values yielded by the `gen_fibs` function. Next, `[...]*3` puts three references to this **same** iterator into a list, so when `zip` calls `next` on its first and second arguments, it is the same underlying iterator that yields the values. Then `*[...]` unpacks the list into separate positional arguments for `zip`. Now when `zip` iterates over its arguments, it pulls values from the same underlying iterator in chunks of 3.

In [36]:
chunks_of = lambda n, seq: zip(*[iter(seq)]*n)
# islice caps the output at exactly n chunks of len(seq)//n elements each
chunks = lambda n, seq: itertools.islice(zip(*[iter(seq)]*(len(seq)//n)), n)

In [37]:
# More readable version of the above code

def chunks(seq, n_chunks):
    """Partition seq into n_chunks.
    Will partition the provided sequence into n_chunks equally sized chunks, where
    each chunk has len(seq) // n_chunks elements. Will drop any leftover elements at
    the end of the sequence. seq must support len(), e.g., a list or a NumPy array.
    """
    chunk_len = len(seq) // n_chunks
    it = iter(seq)
    its = [it] * chunk_len
    # islice caps the output at n_chunks; without it, 10 elements split into
    # 4 chunks would produce 5 chunks of 2.
    return itertools.islice(zip(*its), n_chunks)


def chunks_of(seq, chunk_len):
    """Partition into chunks of given length.
    Will partition the provided sequence into some as yet unknown number of chunks
    where each chunk has chunk_len number of elements. Will drop any last elements
    in the sequence.
    """
    it = iter(seq)
    its = [it] * chunk_len
    return zip(*its)

In [38]:
nums = rng.integers(10, 1000, 20)
print(nums)
print(list(chunks_of(nums, 3)))
print(list(chunks(nums, 6)))

[470 480 273 562 817 503 201  34 138  90 100 131 602 809 856 656 605 338 932 644]
[(470, 480, 273), (562, 817, 503), (201, 34, 138), (90, 100, 131), (602, 809, 856), (656, 605, 338)]
[(470, 480, 273), (562, 817, 503), (201, 34, 138), (90, 100, 131), (602, 809, 856), (656, 605, 338)]
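Both helpers drop the incomplete tail. If the tail should be kept, one option is to slice off `chunk_len` elements at a time with `itertools.islice`; `chunks_of_keep_tail` below is a hypothetical name for this sketch. (Python 3.12 added `itertools.batched(iterable, n)`, which behaves the same way.)

```python
import itertools

def chunks_of_keep_tail(iterable, chunk_len):
    # Hypothetical helper: like chunks_of, but the last chunk may be shorter.
    it = iter(iterable)
    # islice takes up to chunk_len items; an empty tuple means the input is exhausted.
    while chunk := tuple(itertools.islice(it, chunk_len)):
        yield chunk

print(list(chunks_of_keep_tail(range(10), 3)))
# [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]
```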


### `itertools.zip_longest(*its, fillvalue=None)`
Instead of stopping at the shortest iterable, this goes on until the longest iterable is exhausted, filling in the missing values of the shorter iterables with `fillvalue` (`None` by default).

In [39]:
list(itertools.zip_longest("ABC", range(5)))

[('A', 0), ('B', 1), ('C', 2), (None, 3), (None, 4)]
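Combining `zip_longest` with the `[it] * n` trick from the `zip` section gives chunks that pad the last one instead of dropping it. A sketch with `-1` as an arbitrary fill value:

```python
import itertools

it = iter(range(10))
padded = list(itertools.zip_longest(*[it] * 3, fillvalue=-1))
print(padded)  # [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, -1, -1)]
```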

### `itertools.tee(it, n)`
Returns a tuple of `n` independent iterators over the original iterable.

```
tee: ([T], int) -> seq[[T]]
```

The difference between `tee(it, 3)` and `[it]*3` is that `tee` gives independent iterators, while the list just holds three references to the same iterator. This means that if I exhaust the first iterator returned by `tee`, the other iterators are unaffected, but that is not the case for the list of references.

In [40]:
def nums(upper_bound):
    for i in range(upper_bound):
        yield i

it = nums(10)
its = list(itertools.tee(it, 3))
print(next(its[0]))
print(next(its[1]))
print(next(its[2]))

0
0
0


In [41]:
def nums(upper_bound):
    for i in range(upper_bound):
        yield i

it = nums(10)
its = [it] * 3
print(next(its[0]))
print(next(its[1]))
print(next(its[2]))

0
1
2
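One more caveat: once `tee` has been called, the original iterator should not be used directly, because `tee` is not told when it advances. In this sketch (`gen` is a hypothetical generator), pulling one value from the source behind `tee`'s back makes that value vanish from the copies:

```python
import itertools

def gen(n):
    yield from range(n)

src = gen(5)
a, b = itertools.tee(src, 2)
print(next(a), next(a))  # 0 1 (tee buffers these for b)
print(next(src))         # 2: advancing the original directly...
rest_b = list(b)
print(rest_b)            # [0, 1, 3, 4]: ...means the tee'd iterators never see the 2
```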


### `itertools.pairwise(it)`
Returns successive overlapping pairs taken from the input iterable.

```
pairwise: [T] -> [(T, T)]
```

In [42]:
list(itertools.pairwise("ABCDE"))

[('A', 'B'), ('B', 'C'), ('C', 'D'), ('D', 'E')]

## Permutations and Combinations
  * `itertools.permutations(it, r=None)` returns all `r`-length permutations of the elements in the iterable, or full-length permutations when `r` is `None`.
  * `itertools.combinations(it, r)` returns `r`-length combinations, emitted in lexicographic order of the input positions (so in sorted order if the input is sorted). These combinations are without replacement.
  * `itertools.combinations_with_replacement(it, r)` is the same as above, but an element can be repeated within a combination.

```
permutations: ([T], int) -> [(T, T, ..., T)]
```

In [43]:
perms = itertools.permutations(range(3))
print(type(perms))
for perm in perms:
    print(type(perm))
    print(perm)

<class 'itertools.permutations'>
<class 'tuple'>
(0, 1, 2)
<class 'tuple'>
(0, 2, 1)
<class 'tuple'>
(1, 0, 2)
<class 'tuple'>
(1, 2, 0)
<class 'tuple'>
(2, 0, 1)
<class 'tuple'>
(2, 1, 0)


In [44]:
list(itertools.combinations("ABCD", 3))

[('A', 'B', 'C'), ('A', 'B', 'D'), ('A', 'C', 'D'), ('B', 'C', 'D')]

In [45]:
list(itertools.combinations_with_replacement("ABC", 2))

[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]
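The number of results follows the usual counting formulas, $n!/(n-r)!$, $\binom{n}{r}$ and $\binom{n+r-1}{r}$, which a quick sketch can cross-check with `math.perm` and `math.comb` (Python 3.8+). Note that elements are treated as unique by position, not by value:

```python
import itertools
import math

items = "ABCDE"  # n = 5
r = 3
n_perm = len(list(itertools.permutations(items, r)))
n_comb = len(list(itertools.combinations(items, r)))
n_comb_rep = len(list(itertools.combinations_with_replacement(items, r)))
print(n_perm, n_comb, n_comb_rep)  # 60 10 35

assert n_perm == math.perm(5, 3)              # n! / (n - r)!
assert n_comb == math.comb(5, 3)              # C(n, r)
assert n_comb_rep == math.comb(5 + 3 - 1, 3)  # C(n + r - 1, r)

# Duplicate values are not merged: both 'A's count as distinct positions.
print(list(itertools.permutations("AAB", 2)))
# [('A', 'A'), ('A', 'B'), ('A', 'A'), ('A', 'B'), ('B', 'A'), ('B', 'A')]
```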

### `itertools.count(start=0, step=1)`
Starts from `start` and keeps yielding evenly spaced values, `step` apart, forever.

```
count: (Number, Number) -> [Number]
```

In [46]:
counter = itertools.count(start=10, step=2.5)
print(type(counter))
for i, val in enumerate(counter):
    if i > 5: break
    print(val)

<class 'itertools.count'>
10
12.5
15.0
17.5
20.0
22.5
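Because `count` never stops, it is usually bounded with `itertools.islice` or paired with a finite iterable in `zip`. A sketch of both:

```python
import itertools

# count() never ends, so take a fixed number of values with islice()...
first_four = list(itertools.islice(itertools.count(10, 2.5), 4))
print(first_four)  # [10, 12.5, 15.0, 17.5]

# ...or let a finite iterable decide when zip() stops.
numbered = list(zip(itertools.count(1), ["anu", "anika", "manjit"]))
print(numbered)  # [(1, 'anu'), (2, 'anika'), (3, 'manjit')]
```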


### `itertools.cycle(it)`
Cycles through the elements of the given iterable infinitely.

```
cycle: [T] -> [T]
```

In [47]:
it = itertools.cycle("ABCD")
print(type(it))
for i, x in enumerate(it):
    if i > 5: break
    print(x)
    

<class 'itertools.cycle'>
A
B
C
D
A
B
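`cycle` pairs well with `zip`, for example to hand out tasks round-robin; `zip` stops at the finite list, so the infinite iterator is safe here. A sketch with made-up task and worker names:

```python
import itertools

tasks = ["t1", "t2", "t3", "t4", "t5"]
workers = itertools.cycle(["w1", "w2"])
assignments = list(zip(tasks, workers))
print(assignments)
# [('t1', 'w1'), ('t2', 'w2'), ('t3', 'w1'), ('t4', 'w2'), ('t5', 'w1')]
```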


### `itertools.repeat(obj[, times])`
Creates an iterator that yields `obj` `times` times, or forever if `times` is not given.

```
repeat: (Any, int) -> [Any]
```

In [48]:
it = itertools.repeat("hello", 4)
print(type(it))
for x in it:
    print(x)

<class 'itertools.repeat'>
hello
hello
hello
hello
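A common use is feeding a constant argument to `map` without building a list. A sketch that squares a range with the built-in `pow`:

```python
import itertools

# repeat(2) supplies the exponent for every call; map stops when range(5) runs out.
squares = list(map(pow, range(5), itertools.repeat(2)))
print(squares)  # [0, 1, 4, 9, 16]
```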


### `itertools.chain(*its)`
Chains the input iterables together into a single iterator, yielding all the elements of the first iterable, then all of the second, and so on. Combined with `map`, this can be used to implement `flatmap`, which is not available in Python's standard library. An alternate constructor, `itertools.chain.from_iterable`, takes a single iterable of iterables and flattens it lazily.

```
chain: ([T1], [T2], ..., [Tn]) -> [Tx]
```

In [49]:
it = itertools.chain(range(5), "ABCDE")
print(type(it))
for x in it:
    print(x)

<class 'itertools.chain'>
0
1
2
3
4
A
B
C
D
E
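A `flatmap` can be sketched by mapping first and then flattening one level with `chain.from_iterable`. The `flatmap` helper is a hypothetical name, not a standard-library function:

```python
import itertools

def flatmap(func, iterable):
    # Hypothetical helper: apply func to every element, then flatten one level.
    return itertools.chain.from_iterable(map(func, iterable))

letters = list(flatmap(list, ["ab", "cde"]))
print(letters)  # ['a', 'b', 'c', 'd', 'e']

counts = list(flatmap(range, [1, 2, 3]))
print(counts)  # [0, 0, 1, 0, 1, 2]
```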


### `functools.partial(func, /, *args, **kwargs)`
Returns a new callable that behaves like `func` with the given positional and keyword arguments already filled in ("frozen"), so it can be called with just the remaining arguments.

In [52]:
from functools import partial

In [56]:
def process_1(a: int, b: int, c: int) -> int:
    print(f"a = {a}, b = {b}, c = {c}")
    return 0

In [60]:
proc = partial(process_1, 10, 20)
proc(30)

a = 10, b = 20, c = 30


0

In [59]:
proc = partial(process_1, c=30)
proc(10, 20)

a = 10, b = 20, c = 30


0

In [62]:
try:
    proc = partial(process_1, b=20)
    proc(10, 30)
except TypeError as te:
    print(te)

process_1() got multiple values for argument 'b'
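The error happens because positional arguments are bound from the left: `proc(10, 30)` puts `10` in `a` and `30` in `b`, which clashes with the frozen `b=20`. A sketch of the fix, passing the remaining argument by keyword, and of the `func`, `args` and `keywords` attributes that a `partial` object exposes:

```python
from functools import partial

def process_1(a: int, b: int, c: int) -> int:
    print(f"a = {a}, b = {b}, c = {c}")
    return 0

# Freeze b by keyword, then pass c by keyword too so nothing lands on b positionally.
proc = partial(process_1, b=20)
result = proc(10, c=30)  # prints: a = 10, b = 20, c = 30

# The frozen pieces are inspectable.
print(proc.func.__name__, proc.args, proc.keywords)  # process_1 () {'b': 20}
```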
