In [1]:
import numpy as np
import jax
import jax.numpy as jnp

from typing import List, Callable
from functools import partial

In [2]:
def str_to_jax_array(word: str) -> jax.Array:
    return jnp.array([ord(c) for c in word])

In [3]:
def jax_array_to_str(arr: jax.Array) -> str:
    return "".join([chr(i) for i in arr.tolist()])

In [4]:
def list_str_to_jax_array(words: List[str]):
    max_length = len(max(words, key=lambda x: len(x)))
    arr = jnp.empty(shape=(len(words), max_length), dtype=jnp.int32)
    for i, word in enumerate(words):
        arr = arr.at[i, : len(word)].set(str_to_jax_array(word))
    return arr

In [5]:
def jax_array_to_list_str(arr: jax.Array) -> List[str]:
    words = []
    for row in arr:
        if len(a := jnp.where(row == 0)[0]):
            row = row[: a[0]]
        words.append(jax_array_to_str(row))
    return words

# Ranking Word


## Initial Requirement :

Requirements for the word ranking
1.  The score of a given word is calculated by giving one point for 
each letter that is not an 'a'
2. For a given list of words, return a sorted list that starts with the 
highest-scoring word

In [6]:
def j_score(word: jax.Array):
    return jnp.sum(jnp.where((word != ord("a")) & (word != 0), 1, 0))


j_score(str_to_jax_array("rust"))

Array(4, dtype=int32)

In [7]:
def py_score(word: str) -> int:
    return len(word.replace("a", ""))


py_score("rust")

4

In [8]:
py_words = ["ada", "haskell", "scala", "java", "rust"]
jax_words = list_str_to_jax_array(py_words)

In [9]:
def j_rankedWords(words: jax.Array, key: Callable[[jax.Array], jax.Array]) -> jax.Array:
    return words[jnp.flip(jnp.apply_along_axis(key, axis=1, arr=words).argsort())]


res = j_rankedWords(jax_words, j_score)
jax_array_to_list_str(res)

['haskell', 'rust', 'scala', 'java', 'ada']

In [10]:
def py_rankedWords(words: List[str], key: Callable[[str], int]) -> List[str]:
    return sorted(words, key=key, reverse=True)


py_rankedWords(py_words, py_score)

['haskell', 'rust', 'scala', 'java', 'ada']

## Additional Requirement

1. A bonus score of 5 needs to be added to the score if the word contains a c'.
2. The old way of scoring (without the bonus) should still be supported in the code.

In [11]:
def j_score_with_bonus(word: jax.Array) -> jax.Array:
    base = j_score(word)
    if jnp.any(word == ord("c")):
        base += 5
    return base

In [12]:
def j_score_with_bonus_use_lax(word: jax.Array) -> jax.Array:
    base = j_score(word)
    bonus = jax.lax.cond(jnp.any(word == ord("c")), lambda: 5, lambda: 0)
    return base + bonus

In [13]:
res = j_rankedWords(jax_words, j_score_with_bonus_use_lax)
jax_array_to_list_str(res)

['scala', 'haskell', 'rust', 'java', 'ada']

In [14]:
def j_bonus(word: jax.Array) -> jax.Array:
    return jax.lax.cond(jnp.any(word == ord("c")), lambda: 5, lambda: 0)

In [15]:
res = j_rankedWords(jax_words, lambda w: j_score(w) + j_bonus(w))
jax_array_to_list_str(res)

['scala', 'haskell', 'rust', 'java', 'ada']

In [16]:
def py_score_with_bonus(word: str) -> int:
    base = py_score(word)
    bonus = 5 if "c" in word else 0
    return base + bonus


py_rankedWords(py_words, py_score_with_bonus)

['scala', 'haskell', 'rust', 'java', 'ada']

In [17]:
def py_bonus(word: str) -> int:
    return 5 if "c" in word else 0


py_rankedWords(py_words, lambda w: py_score(w) + py_bonus(w))

['scala', 'haskell', 'rust', 'java', 'ada']

## New requirement: Possibility of a penalty
1. A penalty score of 7 needs to be subtracted from the score if the word contains an 's'.
2. Old ways of scoring (with and without the bonus) should still be supported in the code.

In [18]:
def j_penalty(word: jax.Array):
    return jax.lax.cond(jnp.any(word == ord("s")), lambda: 7, lambda: 0)

In [19]:
j_penalty(str_to_jax_array("ada"))

Array(0, dtype=int32, weak_type=True)

In [20]:
res = j_rankedWords(jax_words, lambda w: j_score(w) + j_bonus(w) - j_penalty(w))
jax_array_to_list_str(res)
# ? scala and ada has same score

['java', 'scala', 'ada', 'haskell', 'rust']

In [21]:
def py_penalty(word: str) -> int:
    return 7 if "s" in word else 0

In [22]:
py_rankedWords(py_words, lambda w: py_score(w) + py_bonus(w) - py_penalty(w))

['java', 'ada', 'scala', 'haskell', 'rust']

In [25]:
py_words

['ada', 'haskell', 'scala', 'java', 'rust']

## New requirement: Get the scores
1. We need to know the score of each word in the list of words.
2. The function responsible for ranking should still work the same (we cannot change any existing function)

In [28]:
def j_word_score(
    score_fn: Callable[[jax.Array], jax.Array], words: jax.Array
) -> jax.Array:
    return jax.lax.map(score_fn, words)

In [29]:
j_word_score(lambda w: j_score(w) + j_bonus(w) - j_penalty(w), jax_words)

Array([ 1, -1,  1,  2, -3], dtype=int32)

In [39]:
def py_get_score(score_fn: Callable[[str], int], words: List[str]) -> List[int]:
    return list(map(score_fn, words))

In [40]:
word_score_fn = lambda w: py_score(w) + py_bonus(w) - py_penalty(w)

In [41]:
py_get_score(word_score_fn, py_words)

[1, -1, 1, 2, -3]

## New requirement: Return high-scoring words
1. We need to return a list of words that have a score higher than 1 (i.e., high score).
2. Functionalities implemented so far should still work the same (we cannot change 
any existing function).

filter is not possible in jax since jit need the static shaped array

In [51]:
def py_high_scoring_word(score_fn: Callable[[str], int], words: List[str]) -> List[str]:
    return list(filter(lambda w: score_fn(w) > 1, words))


py_high_scoring_word(word_score_fn, py_words)

['java']

## New Requirement : Different threshold
The high score threshold is 1, but there will be several game modes, each having a different threshold. For  now there will be three game modes  with high score thresholds defined at 1, 0 and 5, respectively

### Solution 1 : Add the new parameter

Problem : Lot repetition in caller side

In [52]:
def py_high_scoring_word(
    score_fn: Callable[[str], int], words: List[str], threshold: int
) -> List[str]:
    return list(filter(lambda w: score_fn(w) > threshold, words))

In [53]:
py_high_scoring_word(word_score_fn, py_words, 0)

['ada', 'scala', 'java']

In [54]:
py_high_scoring_word(word_score_fn, py_words, 1)

['java']

### Return the function

In [69]:
def py_high_scoring_word(
    score_fn: Callable[[str], int], words: List[str]
) -> Callable[[int], List[str]]:
    return lambda threshold: list(filter(lambda w: score_fn(w) > threshold, words))

In [72]:
word_with_score_higher_than = py_high_scoring_word(word_score_fn, py_words)

In [74]:
word_with_score_higher_than(1)

['java']

In [75]:
word_with_score_higher_than(0)

['ada', 'scala', 'java']

## New requirement: Return total score
1. We need to return a cumulative score of words provided as an input list.
2. Functionalities implemented so far should still work the same (we cannot 
change any existing function)

In [89]:
def j_total_score(
    score_fn: Callable[[jax.Array], jax.Array], words: jax.Array
) -> jax.Array:
    return jax.lax.reduce(jax.lax.map(score_fn, words), 0, jax.lax.add, (0,))

In [93]:
j_total_score(lambda w: j_score(w) + j_bonus(w) - j_penalty(w), jax_words)

Array(0, dtype=int32)

In [95]:
from functools import reduce
def py_total_score(score_fn:Callable[[str],int],words:List[str])->int:
    return reduce(lambda x,y:x+y,map(score_fn,words),0)

py_total_score(word_score_fn,py_words)

0

### Exercise:

#### Return a sum of all integers in the given lis

In [96]:
reduce(lambda x,y : x+y,[5, 1, 2, 4, 100],0)

112

#### Return the total length of all the words in the given list.

In [98]:
reduce(lambda acc,word:acc+len(word),("scala", "rust", "ada") ,0)

12

#### Return the number of the letter 's' found in all the words in the given list

In [102]:
reduce(lambda acc,word:acc+word.count("s"),("scala", "haskell", "rust", "ada") ,0)

3

#### Return the maximum of all integers in the given list.

In [103]:
reduce(lambda cur_max,val:max(cur_max,val),(5, 1, 2, 4, 15),0)

15