In [59]:
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Get the Spark session for this notebook, reusing an already-active one if present.
spark = (
    SparkSession.builder
    .appName("COVID-19 Linear Regression")
    .getOrCreate()
)

23/09/09 10:30:54 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [89]:
# Read the COVID-19 CSV from Google Cloud Storage, taking column names from the
# header row and letting Spark infer column types.
covid_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("gs://6410381/covidData/covid19.csv")
)

In [90]:
# Display the DataFrame repr: its column names and inferred types.
covid_data

DataFrame[date: timestamp, new_confirmed: int]

In [91]:
# First five rows as a list of Row objects, to eyeball the parsed values.
covid_data.head(5)

[Row(date=datetime.datetime(2020, 1, 1, 0, 0), new_confirmed=0),
 Row(date=datetime.datetime(2020, 1, 2, 0, 0), new_confirmed=0),
 Row(date=datetime.datetime(2020, 1, 3, 0, 0), new_confirmed=0),
 Row(date=datetime.datetime(2020, 1, 4, 0, 0), new_confirmed=0),
 Row(date=datetime.datetime(2020, 1, 5, 0, 0), new_confirmed=0)]

In [92]:
# Print the inferred schema as a tree, including each column's nullability.
covid_data.printSchema()

root
 |-- date: timestamp (nullable = true)
 |-- new_confirmed: integer (nullable = true)



In [93]:
# Number of rows loaded from the CSV.
covid_data.count()

991

In [94]:
# dropna() returns a new DataFrame and leaves covid_data untouched, so the result
# must be bound to a name; calling it without assignment has no effect.
covid_clean = covid_data.dropna()
covid_clean

DataFrame[date: timestamp, new_confirmed: int]

In [95]:
# Row count after dropping rows containing any null. Compare it with the raw count
# to see how many rows were incomplete.
covid_data.dropna().count()

991

In [96]:
features_df = covid_data.select(['date', 'new_confirmed']).na.drop()

In [97]:
training_df, test_df = features_df.randomSplit([0.8, 0.2], seed=12)


In [98]:
featureAssembler = VectorAssembler(inputCols=['new_confirmed'], outputCol='features')


In [99]:
lr = LinearRegression(labelCol='new_confirmed', featuresCol='features')


In [100]:
from pyspark.ml import Pipeline

pipeline_lr = Pipeline(stages=[featureAssembler, lr])


In [101]:
lrModel = pipeline_lr.fit(training_df)


23/09/09 10:59:51 WARN Instrumentation: [822038ed] regParam is zero, which might cause numerical instability and overfitting.


In [102]:
# Apply the fitted pipeline to the held-out rows and compare predictions with labels.
predictions = lrModel.transform(test_df)
predictions.select(['new_confirmed', 'prediction']).show()

# Quantify held-out error; the per-call param map overrides metricName without
# mutating the evaluator.
evaluator = RegressionEvaluator(labelCol='new_confirmed', predictionCol='prediction')
{metric: evaluator.evaluate(predictions, {evaluator.metricName: metric})
 for metric in ('rmse', 'mae', 'r2')}


+-------------+------------------+
|new_confirmed|        prediction|
+-------------+------------------+
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            1|0.9999999999999998|
|            1|0.9999999999999998|
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            0|               0.0|
|            1|0.9999999999999998|
|            6| 5.999999999999998|
|           12|11.999999999999996|
|           75| 74.99999999999999|
|          152|151.99999999999997|
|          181|180.99999999999997|
|          116|115.99999999999997|
|          157|156.99999999999997|
+-------------+------------------+
only showing top 20 rows

