# Lesson 08 - Map and FlatMap

In [0]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

In this lesson, we will introduce two RDD methods: `map()` and `flatMap()`. Both of these transformations allow us to apply a custom function to each element of an RDD. 

* The `map()` transformation is used to apply a one-to-one function in which each element of the original RDD is mapped to exactly one element of the new RDD. 
* The `flatMap()` transformation is used to apply a one-to-many function in which each element of the original RDD is mapped to zero or more elements of the new RDD. 

Before discussing either of these transformations in depth, we need to review the Python concept of a lambda function.

### Python Review: Lambda Functions

A **lambda** function in Python is an anonymous (nameless) function that is defined using a single expression, typically on one line of code. Lambda functions are useful for defining functions that are intended to be passed in as the argument for some other function. 

We typically define functions in Python by using the `def` keyword. In the following cell, we use `def` to define a function `g()` that accepts a single input and then returns the square of that input.

In [0]:
def g(x):
    return x**2

print(g(4))

In addition to `def`, we can also use the `lambda` keyword to define a function in Python. The syntax for using `lambda` to define a function is as follows:

   `lambda (parameters) : (formula for return value)`

The expression above will return the described function. 

In the next cell, we will define the same function `g()` as we did above using `def`, but this time using the `lambda` keyword instead. Notice that the function returned by the `lambda` keyword is stored in the variable `g`.

In [0]:
g = lambda x : x**2

print(g(4))

Note that we did not have to assign the returned function to a variable in order to use it. In the cell below, we define the function, and then immediately pass the argument 4 to the function, getting an output of 16.

In [0]:
(lambda x : x**2)(4)

Since we did not assign the function created in the cell above to a variable, it will be forgotten once the line that defines and uses it finishes executing. If we need to re-use a function, we need to give it a name by assigning it to a variable. 

As shown in the cell below, lambda functions can be defined using more than one parameter.

In [0]:
h = lambda x, y : x + y**2
print(h(3,2))

Lambda functions are useful when you need to provide a simple function as an argument for another function. We will now illustrate this concept using the `map()` transformation.
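
As a quick warm-up that does not need Spark, the sketch below passes lambda functions to two of Python's built-in functions, `sorted()` and `map()`. The list `names` and the variable names are made up for this illustration.

```python
# Plain-Python sketch: passing lambda functions as arguments.
names = ['Washington', 'Adams', 'Jefferson', 'Madison']

# sorted() accepts a function through its `key` parameter.
# Here the lambda tells sorted() to order the strings by length.
by_length = sorted(names, key=lambda s: len(s))

# The built-in map() applies a function to each item of an iterable.
# It returns a lazy iterator, so we wrap it in list() to see the values.
squares = list(map(lambda x: x**2, [1, 2, 3, 4]))

print(by_length)
print(squares)
```

In both calls the lambda is defined right where it is needed and is never given a name.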

## The `map()` Transformation

The `map()` method of an RDD is a transformation that accepts a function `f` as a parameter. The method applies `f` to every element of the original RDD, returning an RDD containing the transformed values.

The cell below illustrates the `map()` transformation by squaring every element of an RDD.

In [0]:
#-------------------------------------------------
# Example: Using map to square elements in an RDD
#-------------------------------------------------

num_rdd = sc.parallelize([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])

sq_rdd = num_rdd.map(lambda x : x**2)

print(sq_rdd.collect())

The return value of the function `f` passed to `map()` is not limited to any particular data type. When working with text data, it is often useful to use `map()` along with the `split()` method for strings to tokenize lines of text into lists of individual words. We illustrate this in the next example.

In [0]:
#----------------------------------------------------------
# Example: Using map to tokenize string elements in an RDD
#----------------------------------------------------------

pres_rdd = sc.parallelize([
    'George Washington', 'John Adams', 'Thomas Jefferson', 
    'James Madison', 'John Quincy Adams', 'Andrew Jackson'
])

pres_rdd_tokenized = pres_rdd.map(lambda x : x.split(' '))

for pres in pres_rdd_tokenized.collect():
    print(pres)

We can modify the example above to extract only the last word out of each line.

In [0]:
#-----------------------------------------------------
# Example: Using map to select last word in a string
#-----------------------------------------------------

last_names = pres_rdd.map(lambda x : x.split(' ')[-1])

for name in last_names.collect():
    print(name)

We are also able to use pre-defined Python functions with the `map()` transformation. In the following example, we pass the `len()` function to `map()` to determine the length of each string in `pres_rdd`.

In [0]:
#---------------------------------------------------------------
# Example: Using map to determine the length of string elements
#---------------------------------------------------------------

len_rdd = pres_rdd.map(len)
print(len_rdd.collect())

## Using `map()` to Process File Input

We have seen before that when we read in an RDD from a file, the contents are read in as strings with one RDD element per line of the file. In some cases the lines of the file will represent numerical information, or individual lines might contain a mix of text and numerical information. We can use `map()` to process the information into a more desirable format.
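
The idea can be sketched in plain Python before we apply it to a real file. The lists `lines` and `mixed` below are hypothetical stand-ins for the string elements of an RDD, and `parse_mixed` is a helper invented for this illustration. With an actual RDD, the same functions could be passed to its `map()` method.

```python
# Hypothetical lines, as they might look after reading a text file.
lines = ['3.5', '10', '-2.25']

# Lines that each hold a single number can be converted directly.
values = list(map(float, lines))

# Hypothetical lines mixing text and numbers, separated by a comma.
mixed = ['apple,3', 'pear,12']

def parse_mixed(line):
    # Split the line into its two fields, then convert the count to an int.
    label, count = line.split(',')
    return [label, int(count)]

records = [parse_mixed(line) for line in mixed]

print(values)
print(records)
```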

### Diamonds Data

In the next example, we will use `map()` to process tabular data stored in the data file `diamonds.txt`. This data file contains information about 53,940 diamonds sold in the United States. You can find more information about this dataset here: [Diamonds Dataset](https://ggplot2.tidyverse.org/reference/diamonds.html). 

We will start by loading the file into an RDD and then counting the number of lines in the resulting RDD.

In [0]:
diamonds_pre = sc.textFile('/FileStore/tables/diamonds.txt')
print(diamonds_pre.count())

The data in this dataset is stored in a tab-delimited format. We can see that by looking at the strings representing the first five lines of the data file.

In [0]:
diamonds_pre.take(5)

We can display this information in a more readable format by printing each line separately.

In [0]:
for row in diamonds_pre.take(5):
    print(row)

We will use `map()` to process this data in the following way:

1. We will use the `split()` method to tokenize the strings into lists of individual values by splitting at the tab characters. 
2. We will convert each token into the appropriate data type. 

Each element in our resulting RDD will contain a list of values representing a single row in the dataset. 

Notice that the first element of the RDD contains header information providing the names of each of the columns in the dataset. This record will need to be processed in a different way than the others. 

We now write a function to process each line of text.

In [0]:
def process_row(row):
    items = row.split('\t')
  
    if 'carat' in row:
        return items
  
    return [float(items[0]), items[1], items[2], items[3], 
            float(items[4]), int(items[5]), int(items[6]), 
            float(items[7]), float(items[8]), float(items[9])]

In the next cell, we will use `map()` to apply the `process_row` function to each element of the RDD. We will then display the output for the first five rows.

In [0]:
diamonds = diamonds_pre.map(process_row)

for row in diamonds.take(5):
    print(row)

Suppose we wanted to add a new value to each list in our RDD to record the price per carat for the diamond. The following example illustrates how we might do that. Note that the element at index 0 in each list records the carat size for the diamond and the element at index 6 indicates the price.

In [0]:
def price_per_carat(row):
    if 'carat' in row: 
        return row + ['price_per_carat']  # concatenate instead of calling append(), which would modify the original list in place
    ppc = round(row[6] / row[0], 2)
    return row + [ppc]

diamonds_ppc = diamonds.map(price_per_carat)

for row in diamonds_ppc.take(5):
    print(row)

We will see other interesting applications of the `map()` transformation later.
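
The `price_per_carat` function above builds its result with `row + [...]` rather than `row.append(...)`. The plain-Python sketch below shows the difference between the two approaches. The helper names `add_with_concat` and `add_with_append` and the sample values are made up for this illustration.

```python
def add_with_concat(row, value):
    # Builds and returns a new list; the list passed in is left untouched.
    return row + [value]

def add_with_append(row, value):
    # Modifies the list passed in, then returns that same list object.
    row.append(value)
    return row

original = [0.23, 'Ideal', 326]
new_row = add_with_concat(original, 1417.39)
print(original)            # still has three items
print(new_row)             # has four items

shared = [0.23, 'Ideal', 326]
same_row = add_with_append(shared, 1417.39)
print(shared)              # now has four items as well
print(same_row is shared)  # True: both names refer to one list
```

Functions that return a new value without altering their input are generally easier to reason about when they are passed to a transformation such as `map()`.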

## The `flatMap()` Transformation

When we use `map()` to tokenize string elements within an RDD, we get back an RDD whose elements are lists of tokens. Occasionally, we will want a new RDD that contains not lists of tokens, but the individual tokens themselves. This can be accomplished using the `flatMap()` transformation. The `flatMap()` method is similar to `map()`, but the function supplied to `flatMap()` should return a list (or another iterable), and the elements of that list each become separate elements of the resulting RDD. As a result, an RDD produced by `flatMap()` will typically have more elements than the source RDD. We describe this scenario by saying that `flatMap()` is used to apply a **one-to-many transformation**. 
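
Before turning to Spark, the flattening idea can be mimicked in plain Python. This is only a local analogy, not how Spark implements it: a list comprehension plays the role of `map()`, and `itertools.chain.from_iterable` plays the role of `flatMap()`. The list `lines` is made up and includes an empty string to show that an element can contribute no tokens at all.

```python
from itertools import chain

lines = ['John Adams', '', 'John Quincy Adams']

def tokenize(s):
    # split() with no argument splits on runs of whitespace,
    # so an empty string yields an empty list.
    return s.split()

# map-like: one output element (a list of tokens) per input element.
mapped = [tokenize(s) for s in lines]

# flatMap-like: the token lists are flattened into a single sequence.
flat = list(chain.from_iterable(tokenize(s) for s in lines))

print(mapped)
print(flat)
print(len(mapped), len(flat))
```

The map-like result has the same number of elements as the input, while the flattened result has one element per token.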

We will illustrate the behavior of `flatMap()` using `pres_rdd`.

In [0]:
#------------------------------------------
# Example: Using flatMap for tokenization
#------------------------------------------

pres_tokens = pres_rdd.flatMap(lambda x : x.split(' '))

print(pres_tokens.collect())

print()

print('One token per line:')
print('=' * 20)
for token in pres_tokens.collect():
    print(token)

In [0]:
print(pres_rdd.count())
print(pres_tokens.count())

In the next example, we will use `flatMap()` to count the number of words in a file. For this example, we will use a text file containing the contents of the novel "The War of the Worlds". We will begin by viewing the first 20 lines of this text file.

In [0]:
wotw_lines = sc.textFile('/FileStore/tables/war_of_the_worlds.txt')

for row in wotw_lines.take(20):
    print(row)

We will now use `flatMap()` to tokenize the elements of this RDD, producing an RDD of individual words. We will then determine the number of lines and the number of words in this file.

In [0]:
#--------------------------------------------------------
# Example: Counting number of lines and words in a file
#--------------------------------------------------------

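# split() with no argument splits on any run of whitespace, so blank
# lines and repeated spaces do not produce empty-string "words".
wotw_words = wotw_lines.flatMap(lambda x : x.split())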
print('Number of lines:', wotw_lines.count())
print('Number of words:', wotw_words.count())
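
When counting words, the choice of separator matters. The plain-Python sketch below compares `split(' ')` with `split()` on a made-up line containing a double space and a trailing space, and on an empty line.

```python
line = 'The  War of the Worlds '

# Splitting on a single space keeps empty strings wherever
# two separators are adjacent or a separator ends the line.
on_space = line.split(' ')

# Splitting with no argument treats any run of whitespace as one
# separator and ignores leading and trailing whitespace.
on_whitespace = line.split()

print(on_space)
print(on_whitespace)

# An empty line behaves differently under the two calls as well.
print(''.split(' '))
print(''.split())
```

With `split(' ')`, every blank line in a file contributes one empty-string token, which inflates a word count.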

We can use the `distinct()` method to count the number of unique words in the novel. To ensure that two words are considered the same regardless of capitalization, we will use `map()` to first convert all of the words to lowercase.

In [0]:
#--------------------------------------------------------
# Example: Counting number of distinct words in a file
#--------------------------------------------------------

wotw_lower = wotw_words.map(lambda x : x.lower())
print(wotw_lower.distinct().count())
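
The effect of lowercasing before counting distinct values can be checked on a small scale in plain Python. In this sketch the word list is made up, and a `set` stands in for the RDD `distinct()` method.

```python
words = ['The', 'the', 'Martians', 'martians', 'THE', 'came']

# Without normalization, differently capitalized forms count separately.
distinct_raw = len(set(words))

# Lowercasing first merges forms that differ only in capitalization.
lowered = [w.lower() for w in words]
distinct_lower = len(set(lowered))

print(distinct_raw)
print(distinct_lower)
```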