In [None]:
# Activate Spark in our Colab notebook.
import os
# Find the latest version of spark 3.0  from http://www.apache.org/dist/spark/ and enter as the spark version
# For example: 'spark-3.2.2'
spark_version = 'spark-3.2.2'
# spark_version = 'spark-3.<enter version>'
os.environ['SPARK_VERSION']=spark_version

# Install Spark and Java
!apt-get update
!apt-get install openjdk-11-jdk-headless -qq > /dev/null
!wget -q http://www.apache.org/dist/spark/$SPARK_VERSION/$SPARK_VERSION-bin-hadoop3.2.tgz
!tar xf $SPARK_VERSION-bin-hadoop3.2.tgz
!pip install -q findspark

# Set Environment Variables
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = f"/content/{spark_version}-bin-hadoop3.2"

In [None]:
# Install pytest and pytest-sugar to make our output look nice.
!pip install -q pytest pytest-sugar

In [None]:
# Create and navigate to the tdd directory.
from pathlib import Path
if Path.cwd().name != 'tests':
    %mkdir tests
    %cd tests
# Show the current working directory.  
%pwd

In [None]:
# Initialize the __init__.py file. 
# This file will be stored in our pwd (/content/tests)
%%file __init__.py
pass

In [None]:
import time

import findspark
findspark.init()

# Import other dependencies (after findspark.init() so pyspark is importable).
from pyspark import SparkFiles
from pyspark.sql import SparkSession

# getOrCreate() hands back the already-running session on later calls, so a
# single builder serves the whole notebook.
spark = SparkSession.builder.appName("sparkEnergyData").getOrCreate()

# Load the raw CSV (relative path -- expected in the current working
# directory), shorten three long column headers, then round-trip through
# Parquet so the queries below read a columnar copy.
energy = (
    spark.read.csv('organised_Gen.csv', header=True)
    .withColumnRenamed('TYPE OF PRODUCER', 'producer')
    .withColumnRenamed('ENERGY SOURCE', 'source')
    .withColumnRenamed('GENERATION (Megawatthours)', 'generated')
)
energy.write.parquet('parquet_energy', mode='overwrite')
p_energy = spark.read.parquet('parquet_energy')
p_energy.createOrReplaceTempView('p_energy_data')

start_time = time.perf_counter()

transformed_df = spark.sql("""
    SELECT
        YEAR AS year,
        MONTH AS month,
        STATE AS state,
        producer, source, generated
    FROM p_energy_data
    WHERE state = 'US-TOTAL'
      AND producer = 'Total Electric Power Industry'
      AND source = 'Total'
""")
# spark.sql() only builds a query plan; show() is the action that executes it,
# so it has to run inside the timed region for the timing to mean anything.
transformed_df.show(5)

# Yearly totals divided by 1,000,000 (generated is in MWh), 2022 excluded.
q_df = spark.sql("""
    SELECT YEAR, (SUM(generated) / 1000000)
    FROM p_energy_data
    WHERE producer = 'Total Electric Power Industry'
      AND source = 'Total'
      AND state = 'US-TOTAL'
      AND YEAR != 2022
    GROUP BY YEAR, State
    ORDER BY YEAR DESC
""")
q_df.show()

print("--- %s seconds ---" % (time.perf_counter() - start_time))


In [None]:
# Create a bank_data.py file and write the function to it. 
# This file will be stored in our pwd (/content/tests).
%%file total_energy.py

# Import findspark() and initialize. 
import findspark
findspark.init()

# Import other dependencies. 
from pyspark import SparkFiles
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("sparkEnergyData").getOrCreate()

def import_data(csv_path='organised_Gen.csv', parquet_path='parquet_energy'):
    """Load the raw energy CSV, shorten its column names and register it.

    The CSV is read with a header row and three long headers are renamed
    ('TYPE OF PRODUCER' -> producer, 'ENERGY SOURCE' -> source,
    'GENERATION (Megawatthours)' -> generated). The result is written to
    Parquet (overwriting any previous copy) and read back. The Parquet-backed
    DataFrame is then registered as the temporary view ``energy_data``, which
    transform_data() and query_data() query.

    Args:
        csv_path: Source CSV path. Defaults to 'organised_Gen.csv'.
        parquet_path: Directory the Parquet copy is written to.
            Defaults to 'parquet_energy'.

    Returns:
        The DataFrame read back from Parquet.
    """
    renames = {
        'TYPE OF PRODUCER': 'producer',
        'ENERGY SOURCE': 'source',
        'GENERATION (Megawatthours)': 'generated',
    }
    # Use the module-level session, like transform_data() and query_data().
    energy = spark.read.csv(csv_path, header=True)
    for old_name, new_name in renames.items():
        energy = energy.withColumnRenamed(old_name, new_name)
    # Round-trip through Parquet so downstream queries read a columnar copy.
    energy.write.parquet(parquet_path, mode='overwrite')
    p_energy = spark.read.parquet(parquet_path)
    p_energy.createOrReplaceTempView('energy_data')
    return p_energy

def transform_data():
    """Return the US-total, all-producer, all-source rows with short names.

    Selects year, month, state, producer, source and generated from the
    ``energy_data`` view. Only rows with state 'US-TOTAL', producer
    'Total Electric Power Industry' and source 'Total' are kept.

    If ``energy_data`` is not registered in this Spark session yet,
    import_data() is called first. Without this, the function only worked
    when an earlier test had already run import_data().

    Returns:
        A DataFrame with the six columns listed above.
    """
    if not any(t.name == 'energy_data' for t in spark.catalog.listTables()):
        import_data()
    transformed_df = spark.sql("""
    SELECT
        YEAR AS year,
        MONTH AS month,
        STATE AS state,
        producer, source, generated
    FROM energy_data
    WHERE state = 'US-TOTAL'
      AND producer = 'Total Electric Power Industry'
      AND source = 'Total'
    """)
    return transformed_df

def query_data():
    """Return yearly US generation totals (millions of MWh), 2022 excluded.

    Sums ``generated`` per YEAR over the ``energy_data`` rows with producer
    'Total Electric Power Industry', source 'Total' and state 'US-TOTAL'.
    The result is ordered newest year first. ``generated`` is the renamed
    'GENERATION (Megawatthours)' column, so dividing by 1,000,000 gives
    millions of MWh.

    If ``energy_data`` is not registered in this Spark session yet,
    import_data() is called first, so the result does not depend on test
    execution order.

    Returns:
        A two-column DataFrame: YEAR and the (unaliased) scaled sum.
    """
    if not any(t.name == 'energy_data' for t in spark.catalog.listTables()):
        import_data()
    q_df = spark.sql("""
    SELECT YEAR, (SUM(generated) / 1000000)
    FROM energy_data
    WHERE producer = 'Total Electric Power Industry'
      AND source = 'Total'
      AND state = 'US-TOTAL'
      AND YEAR != 2022
    GROUP BY YEAR, State
    ORDER BY YEAR DESC
    """)
    return q_df
    

# def transform_data_full():
#     transformed_df = spark.sql("""
#     SELECT
#         ZIPCODE,
#         ADDRESS
#     FROM ZIP_BANK_DATA
#     """)

#     return transformed_df

# def distinct_zip_codes():
#     distinct_zips = spark.sql("""
#     SELECT DISTINCT
#         ZIPCODE
#     FROM ZIP_BANK_DATA
#     """)

#     return distinct_zips

In [None]:
# Create a test_bank_data.py file and write the test functions to it. 
# This file will be stored in our pwd (/content/tests).
%%file test_total_energy.py

# From the bank_data.py file and import the import_data function. 
from total_energy import (import_data, transform_data, query_data) #transform_data_full, distinct_zip_codes

# Write the tests. 
def test_row_count_before_transform():
  df = import_data()
  assert df.count() == 496774

def test_column_count_before_transform():
  df = import_data()
  assert len(df.columns) == 7

def test_row_count_after_transform():
    df = transform_data()
    assert df.count() != 496774

def test_column_count_after_transform():
    df = transform_data()
    assert len(df.columns) != 7

def test_row_count_query():
    df = query_data()
    assert df.count() == 21

def test_column_count_query():
    df = query_data()
    assert len(df.columns) == 2

# def test_columns_in_transformed_df():
#     df = transform_data()
#     assert df.schema.names == ['year', 'month','state','producer','source','generated']

# def test_row_distinctness():
#     transformed_df = transform_data_full()
#     distinct_zips_df = distinct_zip_codes()
#     assert transformed_df.count() == distinct_zips_df.count()


In [None]:
# Run the test_import_data.py file with pytest. 
!python -m pytest test_total_energy.py