In [1]:
# Activate Spark in our Colab notebook.
import os
# Find the latest version of spark 3.0  from http://www.apache.org/dist/spark/ and enter as the spark version
# For example: 'spark-3.2.2'
spark_version = 'spark-3.2.2'
# spark_version = 'spark-3.<enter version>'
os.environ['SPARK_VERSION']=spark_version

# Install Spark and Java
!apt-get update
!apt-get install openjdk-11-jdk-headless -qq > /dev/null
!wget -q http://www.apache.org/dist/spark/$SPARK_VERSION/$SPARK_VERSION-bin-hadoop3.2.tgz
!tar xf $SPARK_VERSION-bin-hadoop3.2.tgz
!pip install -q findspark

# Set Environment Variables
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = f"/content/{spark_version}-bin-hadoop3.2"

Hit:1 http://archive.ubuntu.com/ubuntu bionic InRelease
Get:2 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]
Get:3 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease [1,581 B]
Get:5 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]
Get:6 http://pp
[... output truncated ...]

In [2]:
# Install pytest and pytest-sugar to make our output look nice.
!pip install -q pytest pytest-sugar

In [3]:
# Create and navigate to the tdd directory.
from pathlib import Path
if Path.cwd().name != 'tests':
    %mkdir tests
    %cd tests
# Show the current working directory.  
%pwd

/content/tests


'/content/tests'

In [4]:
# Write an __init__.py (its only content is `pass`) so the directory is a regular Python package.
# This file will be stored in our pwd (/content/tests).
%%file __init__.py
pass

Writing __init__.py


In [13]:
# Create a heart_health.py file that will contain our functions.
# This file will be stored in our pwd (/content/tests).
%%file heart_health.py

# Import findspark() and initialize. 
import findspark
findspark.init()

# Import other dependencies. 
from pyspark import SparkFiles
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("sparkHealthData").getOrCreate()

def import_data(url="https://2u-data-curriculum-team.s3.amazonaws.com/nflx-data-science-adv/week-6/heart_health.csv"):
    """Load the heart-health CSV, register it as the HEART_HEALTH view, and return it.

    Parameters
    ----------
    url : str, optional
        Location of the CSV file. Defaults to the course's heart_health.csv.
        The downstream queries in this module expect STATE and Death_Rate
        columns, so another file only works if it has the same layout.

    Returns
    -------
    pyspark.sql.DataFrame
        The CSV contents, using the first row as column names. No schema
        inference is requested, so Spark's default (string columns) applies.

    Side effects
    ------------
    Adds the file to the SparkContext with ``addFile`` and creates (or
    replaces) the ``HEART_HEALTH`` temporary view. ``get_states()`` and
    ``get_states_with_above_avg_death_rate()`` query that view.
    """
    from posixpath import basename
    from urllib.parse import urlparse

    spark.sparkContext.addFile(url)
    # SparkFiles.get() looks the downloaded file up by name, i.e. the last
    # component of the URL path; derive it from the URL so the two never drift.
    file_name = basename(urlparse(url).path)
    df = spark.read.csv(SparkFiles.get(file_name), sep=",", header=True)
    df.createOrReplaceTempView('HEART_HEALTH')

    return df

# 1. Write a function called "get_states()" that selects the distinct states from the "HEART_HEALTH" view.
def get_states():
    """Return a DataFrame with one row per distinct STATE in the HEART_HEALTH view.

    The HEART_HEALTH temporary view is created by import_data(), not here, so
    import_data() must have been called on this SparkSession beforehand.
    """
    distinct_states_query = """
    SELECT DISTINCT
        STATE
    FROM HEART_HEALTH
    """
    return spark.sql(distinct_states_query)

# 2. Write a function called "get_states_with_above_avg_death_rate" that selects each state and its average
# death rate from the "HEART_HEALTH" view, grouped by state, where the average death rate is greater than 400.
def get_states_with_above_avg_death_rate():
    """Return STATE and AVG(Death_Rate) per state, keeping only averages above 400.

    Reads the HEART_HEALTH temporary view, which import_data() creates, so
    import_data() must have been called on this SparkSession beforehand.
    The average column is not aliased, so Spark generates its name.
    """
    above_avg_query = """
    SELECT
        STATE,
        AVG(Death_Rate)
    FROM HEART_HEALTH
    GROUP BY STATE
    HAVING AVG(Death_Rate) > 400
    """
    return spark.sql(above_avg_query)

Overwriting heart_health.py


Test Suite

In [16]:
# Create a test_heart_health.py file that will contain the test functions.
# This file will be stored in our pwd (/content/tests).
%%file test_heart_health.py

# Import the functions to test: import_data, get_states, get_states_with_above_avg_death_rate.
from heart_health import import_data, get_states, get_states_with_above_avg_death_rate

# 1. Write a test that checks that the import_data function returns 799 rows.
def test_row_count_source():
    """The source CSV should load as exactly 799 rows."""
    assert import_data().count() == 799

# 2. Write a test that checks that the import_data function returns 9 columns.
def test_column_count_source():
    """The source CSV should load with exactly 9 columns."""
    assert len(import_data().columns) == 9

# 3. Write a test that ensures that we are only getting the "STATE" column returned from the get_states() function.
def test_get_states():
    """get_states() should return a single column named STATE."""
    # get_states() queries the HEART_HEALTH view, which only import_data()
    # creates. Register it here so the test also passes when run on its own
    # (e.g. `pytest -k test_get_states`) instead of relying on test order.
    import_data()
    df = get_states()
    assert df.schema.names == ['STATE']

# 4. Write a test that ensures that we are only getting five distinct states
# after the transformation in the get_states() function.
def test_row_count_get_states():
    """get_states() should return exactly five distinct states."""
    # Register the HEART_HEALTH view that get_states() reads, so this test does
    # not depend on an earlier test having called import_data().
    import_data()
    df = get_states()
    assert df.count() == 5

# 5. Write a test to ensure that we only get 5 rows from the get_states_with_above_avg_death_rate() function.
def test_row_count_avg_death_rate():
    """get_states_with_above_avg_death_rate() should return exactly five rows."""
    # Register the HEART_HEALTH view the query reads, so this test does not
    # depend on an earlier test having called import_data().
    import_data()
    df = get_states_with_above_avg_death_rate()
    assert df.count() == 5

# 6. Write a test to ensure that we only get 2 columns from the get_states_with_above_avg_death_rate() function.
def test_column_count_avg_death_rate():
    """get_states_with_above_avg_death_rate() should return STATE plus one average column."""
    # Register the HEART_HEALTH view the query reads, so this test does not
    # depend on an earlier test having called import_data().
    import_data()
    df = get_states_with_above_avg_death_rate()
    assert len(df.columns) == 2

# 7. Write a test to check that no states were removed because their average death rate was not above 400.
def test_get_states_vs_avg_death_rate_count():
    """Every distinct state should survive the HAVING AVG(Death_Rate) > 400 filter.

    Both queries produce one row per distinct STATE before filtering, so equal
    counts mean the HAVING clause removed no state.
    """
    # Register the HEART_HEALTH view both queries read, so this test does not
    # depend on an earlier test having called import_data().
    import_data()
    df_get_states = get_states()
    df_avg_death_rate = get_states_with_above_avg_death_rate()
    assert df_get_states.count() == df_avg_death_rate.count()

Overwriting test_heart_health.py


In [17]:
# Run the test_heart_health.py file with pytest in a subprocess, so it imports the freshly written
# files rather than modules cached in this kernel. `python -m` also puts the current directory
# (/content/tests) on sys.path, which is what lets `from heart_health import ...` resolve.
!python -m pytest test_heart_health.py

Test session starts (platform: linux, Python 3.7.13, pytest 3.6.4, pytest-sugar 0.9.5)
rootdir: /content/tests, inifile:
plugins: typeguard-2.7.1, sugar-0.9.5

 test_heart_health.py ✓✓✓✓✓✓✓                                    100% ██████████

Results (26.76s):
       7 passed
