In [0]:
#In case the python environment does not contain pyspark package, pyspark will be installed through pip
%pip install pyspark

In [0]:
import pyspark

####Importing Dataset

In [0]:
#This snippet might have changes depending on where or how the code is being executed.
df1 = spark.read.format("csv").option("header", "true").load("dbfs:/FileStore/shared_uploads/akurra4@gmu.edu/global_power_plant_database.csv", inferSchema = "true")
print (df1.count())

####Removing Unnecessary Columns And Handling Missing Data

Removed the columns:<br>
country_long<br>
owner<br>
source<br>
url<br>
wepp_id<br>
geolocation_source<br>
generation_data_source<br>
estimated_generation_note_2013<br>
estimated_generation_note_2014<br>
estimated_generation_note_2015<br>
estimated_generation_note_2016<br>
estimated_generation_note_2017<br>
<br>
Also dealt whith missing data within Categorical columns by removing the records with null values<br>
As for the numeric columns, a value of 0 has been attached to the columns where there is no record with a value 0 for the feature.<br>

In [0]:
from pyspark.sql.functions import col
from pyspark.sql.functions import coalesce, lit

df = df1.drop('country_long','owner', 'source', 'url', 'wepp_id','geolocation_source','generation_data_source','estimated_generation_note_2013', 'estimated_generation_note_2014', 'estimated_generation_note_2015', 'estimated_generation_note_2016', 'estimated_generation_note_2017')
df = df.withColumnRenamed('latitude', 'LAT')
df = df.withColumnRenamed('longitude', 'LONG')
df = df.withColumn("capacity_mw", coalesce("capacity_mw", lit(0)))
df = df.withColumn("LAT", coalesce("LAT", lit(0)))
df = df.withColumn("LONG", coalesce("LONG", lit(0)))
df = df.withColumn("commissioning_year", coalesce("commissioning_year", lit(0)))
df = df.withColumn("year_of_capacity_data", coalesce("year_of_capacity_data", lit(0)))
df = df.withColumn("generation_gwh_2013", coalesce("generation_gwh_2013", lit(0)))
df = df.withColumn("generation_gwh_2014", coalesce("generation_gwh_2014", lit(0)))
df = df.withColumn("generation_gwh_2015", coalesce("generation_gwh_2015", lit(0)))
df = df.withColumn("generation_gwh_2016", coalesce("generation_gwh_2016", lit(0)))
df = df.withColumn("generation_gwh_2017", coalesce("generation_gwh_2017", lit(0)))
df = df.withColumn("generation_gwh_2018", coalesce("generation_gwh_2018", lit(0)))
df = df.withColumn("generation_gwh_2019", coalesce("generation_gwh_2019", lit(0)))
df = df.withColumn("estimated_generation_gwh_2013", coalesce("estimated_generation_gwh_2013", lit(0)))
df = df.withColumn("estimated_generation_gwh_2014", coalesce("estimated_generation_gwh_2014", lit(0)))
df = df.withColumn("estimated_generation_gwh_2015", coalesce("estimated_generation_gwh_2015", lit(0)))
df = df.withColumn("estimated_generation_gwh_2016", coalesce("estimated_generation_gwh_2016", lit(0)))
df = df.withColumn("estimated_generation_gwh_2017", coalesce("estimated_generation_gwh_2017", lit(0)))
#setting missing data in categorical columns to a string "N/A"
df = df.na.fill ("N/A")
df = df.dropna()
df.dtypes

####Splitting Test and Training Data

In [0]:
trainDF, testDF = df.randomSplit([0.8, 0.2], seed=100)
print(trainDF.cache().count()) # Cache because accessing training data multiple times
print(testDF.count())

####Feature Engineering for Categorical predictors along with the output feature

In [0]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

categoricalCols = ['country','primary_fuel','other_fuel1','other_fuel2','other_fuel3']

# The following two lines are estimators. They return functions that we will later apply to transform the dataset.
stringIndexer = StringIndexer(inputCols=categoricalCols, outputCols=[x + "Index" for x in categoricalCols]) 
stringIndexer.setHandleInvalid("keep")
encoder = OneHotEncoder(inputCols=stringIndexer.getOutputCols(), outputCols=[x + "OHE" for x in categoricalCols]) 
# Convert it to a numeric value using StringIndexer.
labelToIndex = StringIndexer(inputCol="primary_fuel", outputCol="label")

In [0]:
stringIndexerModel = stringIndexer.fit(trainDF)
display(stringIndexerModel.transform(trainDF))

country,name,gppd_idnr,capacity_mw,LAT,LONG,primary_fuel,other_fuel1,other_fuel2,other_fuel3,commissioning_year,year_of_capacity_data,generation_gwh_2013,generation_gwh_2014,generation_gwh_2015,generation_gwh_2016,generation_gwh_2017,generation_gwh_2018,generation_gwh_2019,estimated_generation_gwh_2013,estimated_generation_gwh_2014,estimated_generation_gwh_2015,estimated_generation_gwh_2016,estimated_generation_gwh_2017,primary_fuelIndex,other_fuel1Index,other_fuel2Index,other_fuel3Index,countryIndex
AFG,Kajaki Hydroelectric Power Plant Afghanistan,GEODB0040538,33.0,32.322,65.119,Hydro,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,123.77,162.9,97.39,137.76,119.5,1.0,0.0,0.0,0.0,105.0
AFG,Kandahar DOG,WKS0070144,10.0,31.67,65.795,Solar,,,,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18.43,17.48,18.25,17.7,18.29,0.0,0.0,0.0,0.0,105.0
AFG,Kandahar JOL,WKS0071196,10.0,31.623,65.792,Solar,,,,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,18.64,17.58,19.1,17.62,18.72,0.0,0.0,0.0,0.0,105.0
AFG,Naghlu Dam Hydroelectric Power Plant Afghanistan,GEODB0040534,100.0,34.641,69.717,Hydro,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,406.16,357.22,270.99,395.38,350.8,1.0,0.0,0.0,0.0,105.0
AFG,Nangarhar (Darunta) Hydroelectric Power Plant Afghanistan,GEODB0040536,11.55,34.4847,70.3633,Hydro,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,58.77,54.42,42.71,59.72,46.12,1.0,0.0,0.0,0.0,105.0
AFG,Northwest Kabul Power Plant Afghanistan,GEODB0040540,42.0,34.5638,69.1134,Gas,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,105.0
AFG,Pul-e-Khumri Hydroelectric Power Plant Afghanistan,GEODB0040537,6.0,35.9416,68.71,Hydro,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,21.99,21.19,18.4,25.34,19.74,1.0,0.0,0.0,0.0,105.0
AFG,Sarobi Dam Hydroelectric Power Plant Afghanistan,GEODB0040535,22.0,34.5865,69.7757,Hydro,,,,0.0,2017,0.0,0.0,0.0,0.0,0.0,0.0,0.0,123.23,82.87,69.15,93.83,80.0,1.0,0.0,0.0,0.0,105.0
AGO,Biopio,WRI1023002,22.8,-12.4706,13.7319,Oil,,,,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,64.92,4.0,0.0,0.0,0.0,92.0
AGO,Cambambe,WRI1023003,180.0,-9.7523,14.4809,Hydro,,,,0.0,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,573.81,732.53,709.84,604.31,678.19,1.0,0.0,0.0,0.0,92.0


####Feature Engineering on Numeric Predictors

In [0]:
from pyspark.ml.feature import VectorAssembler

# This includes both the numeric columns and the one-hot encoded binary vector columns in our dataset.
numericCols = ['capacity_mw','LAT','LONG','commissioning_year', 'generation_gwh_2014','generation_gwh_2015','generation_gwh_2016','generation_gwh_2017','generation_gwh_2018','generation_gwh_2019','estimated_generation_gwh_2014','estimated_generation_gwh_2015','estimated_generation_gwh_2016','estimated_generation_gwh_2017']
assemblerInputs = [c + "OHE" for c in categoricalCols] + numericCols
vecAssembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")

####Defing A Logistic Regression Model

In [0]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(featuresCol="features", labelCol="label", regParam=1.0)

####Setting Up a Pipeline for All The Feature Engineering and The Model, Initial Fit and Testing The Fit

In [0]:
from pyspark.ml import Pipeline

# Define the pipeline based on the stages created in previous steps.
pipeline = Pipeline(stages=[stringIndexer, encoder, labelToIndex, vecAssembler, lr])

# Define the pipeline model.
pipelineModel = pipeline.fit(trainDF)

# Apply the pipeline model to the test dataset.
predDF = pipelineModel.transform(testDF)

In [0]:
display(predDF.select("features", "label", "prediction", "probability"))

features,label,prediction,probability
"List(0, 226, List(105, 164, 178, 191, 203, 212, 213, 214, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 66.0, 34.556, 69.4787, 203.55, 146.9, 230.18, 174.91))",1.0,1.0,"List(1, 15, List(), List(0.19800127348641994, 0.45766284703697313, 0.10101689394355508, 0.0784965320711648, 0.04733988466557552, 0.04584153993880453, 0.03212035496317223, 0.024192009887372433, 0.004732209726488129, 0.004534656567371402, 0.003411733809477237, 0.0010527146370588592, 0.0010244940808143471, 2.8058976678628327E-4, 2.922654189662053E-4))"
"List(0, 226, List(92, 164, 178, 191, 203, 212, 213, 214, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 14.6, -12.4706, 13.7319, 50.79, 49.87, 80.7, 48.53))",1.0,1.0,"List(1, 15, List(), List(0.17458514247472962, 0.4172530245874325, 0.10637670248187349, 0.09076366942317234, 0.0824289639860038, 0.0482080032356935, 0.03700237376021427, 0.02646972779589083, 0.005188433585353905, 0.005025483629031493, 0.0037684791566260085, 0.0011634872879003762, 0.0011331524794417228, 3.1033186438334513E-4, 3.2302425225311574E-4))"
"List(0, 226, List(92, 167, 178, 191, 203, 212, 213, 214, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 16.26, -12.76, 15.75, 46.3))",4.0,4.0,"List(1, 15, List(), List(0.20145879777165887, 0.18962398851333587, 0.12195558140698945, 0.10289825131077084, 0.23708306669426468, 0.055259273220050024, 0.04218011957374226, 0.030204764604456594, 0.00593266020369375, 0.005745488196043348, 0.0043079643067554544, 0.001330285557037113, 0.0012955797467424186, 3.5483187142204586E-4, 3.6934702303741547E-4))"
"List(0, 226, List(92, 166, 178, 191, 203, 212, 213, 214), List(1.0, 1.0, 1.0, 1.0, 1.0, 11.68, -15.1961, 12.1522))",3.0,3.0,"List(1, 15, List(), List(0.20004038627982226, 0.19055473294385517, 0.12141309026893676, 0.24939104479486762, 0.09370854090424742, 0.05390082268010205, 0.04199562206366177, 0.02988366297295005, 0.005857076149700006, 0.005682618874455066, 0.004259971841090747, 0.0013153291123860916, 0.0012810750202534753, 3.508438379728007E-4, 3.651822556988545E-4))"
"List(0, 226, List(116, 164, 178, 191, 203, 212, 213, 214, 215, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 600.0, 42.1033, 19.8224, 1985.0, 1618.73, 1805.63, 2434.84, 1982.72))",1.0,1.0,"List(1, 15, List(), List(0.11865910249612333, 0.5838838955897625, 0.08238799492960891, 0.0657863932388768, 0.041498371238743594, 0.045356972609050604, 0.027055243258027703, 0.021285414015706546, 0.004422991830526683, 0.004058056305563312, 0.00306106327456669, 0.0011129291294286533, 9.185748088619767E-4, 2.514804604193118E-4, 2.615168147335122E-4))"
"List(0, 226, List(116, 164, 178, 191, 203, 212, 213, 214, 215, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 250.0, 42.0137, 19.6359, 1971.0, 561.94, 614.47, 897.47, 703.64))",1.0,1.0,"List(1, 15, List(), List(0.14413754142008411, 0.526487292958357, 0.09387006892796591, 0.07225754098830993, 0.04681910590560735, 0.04713928884318949, 0.030355019079479566, 0.02371709469501733, 0.004670536800911528, 0.0044326713877708, 0.003342667018724681, 0.0012121716744267711, 0.001000588997405717, 2.7372695441207657E-4, 2.846843483375606E-4))"
"List(0, 226, List(19, 164, 178, 191, 203, 212, 213, 214, 215, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 1050.0, -40.58, -70.7489, 1987.0, 3621.29, 3575.78, 3416.89, 2896.84))",1.0,1.0,"List(1, 15, List(), List(0.09783309381298033, 0.5589424097271987, 0.0782401011730874, 0.0833554635455221, 0.06817043926793888, 0.04762203815317508, 0.029584791648668105, 0.021426968937666154, 0.0048032060918002184, 0.004287854186690525, 0.0032156067538519294, 0.001005477713310006, 9.704530215453612E-4, 2.659254064271217E-4, 2.7617056013787435E-4))"
"List(0, 226, List(19, 167, 178, 191, 203, 212, 213, 214, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 25.0, -34.8386, -58.4033, 113.26))",4.0,4.0,"List(1, 15, List(), List(0.20111322288625957, 0.17581067548902385, 0.12385393956194893, 0.10916449571662018, 0.24256371410976463, 0.054095295198729174, 0.043502818974848464, 0.03038246541457843, 0.005993903582157127, 0.005789479268699182, 0.004342717862412869, 0.0013532761459631257, 0.0013050764832978306, 3.572576756833537E-4, 3.71661630013078E-4))"
"List(0, 226, List(19, 167, 188, 191, 203, 212, 213, 214, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.904, -39.2145, -70.9157, 8.62))",4.0,4.0,"List(1, 15, List(), List(0.13399032612766548, 0.12573433520156785, 0.0921123158877669, 0.11670881596105133, 0.39674889743661707, 0.04967414643468256, 0.039176161890679066, 0.027881098985215576, 0.00568209348304086, 0.005238043089574781, 0.003955177077813426, 0.0012381011208657134, 0.001193566978660669, 3.2764180417132566E-4, 3.392785206273709E-4))"
"List(0, 226, List(19, 164, 178, 191, 203, 212, 213, 214, 215, 222, 223, 224, 225), List(1.0, 1.0, 1.0, 1.0, 1.0, 50.7, -33.0452, -69.0516, 2003.0, 150.75, 188.01, 190.95, 163.07))",1.0,1.0,"List(1, 15, List(), List(0.15883875015450197, 0.40679915269714345, 0.10716347305263589, 0.10316572880460717, 0.089087688953516, 0.05141416819794953, 0.03839107149067796, 0.02757262564550994, 0.005382833179028382, 0.00522306157028618, 0.00391703790927497, 0.001216010090516172, 0.0011738866541185292, 3.210096135993691E-4, 3.3350198663430766E-4))"


####Initial Fit Evaluation

In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

bcEvaluator = BinaryClassificationEvaluator(metricName="areaUnderROC")
print(f"Area under ROC curve: {bcEvaluator.evaluate(predDF)}")

mcEvaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print(f"Accuracy: {mcEvaluator.evaluate(predDF)}")

####HyperParameter Tuning

In [0]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.25, 1.0])
             .addGrid(lr.elasticNetParam, [0.05, 0.5, 0.5])
             .build())

In [0]:
# Create a 3-fold CrossValidator
cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid, evaluator=bcEvaluator, numFolds=3, parallelism = 4)

# Run cross validations. This step takes a few minutes and returns the best model found from the cross validation.
cvModel = cv.fit(trainDF)

####Making Predictions

In [0]:
# Use the model identified by the cross-validation to make predictions on the test dataset
cvPredDF = cvModel.transform(testDF)

# Evaluate the model's performance based on area under the ROC curve and accuracy 
print(f"Area under ROC curve: {bcEvaluator.evaluate(cvPredDF)}")
print(f"Accuracy: {mcEvaluator.evaluate(cvPredDF)}")

In [0]:
cvPredDF.createOrReplaceTempView("finalPreds")

In [0]:
%sql
SELECT primary_fuel, prediction, count(*) AS count
FROM finalPreds
GROUP BY primary_fuel, prediction
ORDER BY primary_fuel

primary_fuel,prediction,count
Biomass,0.0,284
Coal,0.0,434
Coal,5.0,15
Cogeneration,0.0,8
Gas,0.0,535
Gas,3.0,231
Geothermal,0.0,43
Hydro,5.0,1
Hydro,1.0,1452
Nuclear,0.0,31


Output can only be rendered in Databricks