In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.stat import Correlation
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.feature import VectorAssembler, StandardScaler, OneHotEncoder, StringIndexer
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import pyspark.sql.types as Types
from pyspark.sql.functions import col

In [2]:
# Create the notebook's Spark session (getOrCreate() reuses an already-running one).
spark = (
    SparkSession.builder
    .appName('test')
    .getOrCreate()
)


In [3]:
# Bare expression: the notebook renders the session object as this cell's output.
spark


In [4]:
# Load the raw CSV: the first row supplies the column names and Spark infers
# each column's type from the data.
country_df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv("./Datasets/covid_19_india.csv")
)

In [5]:
# Print the leading rows of the raw frame to eyeball its columns and values.
country_df.show()

+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+
|Sno|      Date|   Time|State/UnionTerritory|ConfirmedIndianNational|ConfirmedForeignNational|Cured|Deaths|Confirmed|
+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+
|  1|2020-01-30|6:00 PM|              Kerala|                      1|                       0|    0|     0|        1|
|  2|2020-01-31|6:00 PM|              Kerala|                      1|                       0|    0|     0|        1|
|  3|2020-02-01|6:00 PM|              Kerala|                      2|                       0|    0|     0|        2|
|  4|2020-02-02|6:00 PM|              Kerala|                      3|                       0|    0|     0|        3|
|  5|2020-02-03|6:00 PM|              Kerala|                      3|                       0|    0|     0|        3|
|  6|2020-02-04|6:00 PM|              Kerala|           

In [6]:
# Columns carried forward into the analysis.
impCols = [
    'Sno',
    'Date',
    'Time',
    'State/UnionTerritory',
    'ConfirmedIndianNational',
    'ConfirmedForeignNational',
    'Cured',
    'Deaths',
    'Confirmed',
]

# Keep only those columns and discard every row that has a null in any of them.
cleanedData = country_df.select(impCols).dropna()


In [150]:
# Scatter plot of Deaths (x) against Confirmed (y) for the cleaned rows.
# Only the two plotted columns are collected: toPandas() pulls every selected
# column to the driver, and the figure never uses the others.
fig = px.scatter(
    cleanedData.select('Deaths', 'Confirmed').toPandas(),
    x='Deaths',
    y='Confirmed',
    title='Confirmed vs. Deaths',
)
fig.show()

In [116]:


# Keep only rows where both Cured and Confirmed are strictly positive.
filteredData = (
    cleanedData
    .filter(col("Cured") > 0)
    .filter(col("Confirmed") > 0)
)


In [118]:
# Pack the single predictor column, Deaths, into a vector column named `features`.
assembler = VectorAssembler(
    outputCol='features',
    inputCols=['Deaths'],
)

In [119]:
# Scale `features` to unit standard deviation without centering (withMean=False),
# writing the result to `scaled_features`.
scaler = StandardScaler(
    inputCol="features",
    outputCol="scaled_features",
    withMean=False,
    withStd=True,
)

In [120]:
# Fit the assemble -> scale stages. fit() returns the fitted pipeline model, so
# despite its name `pipeline` is the fitted object later used for transform().
# NOTE(review): the scaler statistics are learned from all of filteredData, i.e.
# before the train/test split made further down — confirm that is acceptable.
pipeline = (
    Pipeline(stages=[assembler, scaler])
    .fit(filteredData)
)

In [121]:
# Apply the fitted stages, which adds the `features` and `scaled_features` columns.
data = pipeline.transform(filteredData)

In [122]:
# Preview five transformed rows to check the new vector columns.
data.show(5)

+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+--------+---------------+
|Sno|      Date|   Time|State/UnionTerritory|ConfirmedIndianNational|ConfirmedForeignNational|Cured|Deaths|Confirmed|features|scaled_features|
+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+--------+---------------+
| 38|2020-03-03|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|   [0.0]|          [0.0]|
| 41|2020-03-04|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|   [0.0]|          [0.0]|
| 48|2020-03-05|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|   [0.0]|          [0.0]|
| 54|2020-03-06|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|   [0.0]|          [0.0]|

In [139]:
# The regression only needs the scaled feature vector and the Cured label.
finalDf = data.select('scaled_features', 'Cured')

# Preview the first five rows of the modelling frame.
finalDf.show(n=5)

+---------------+-----+
|scaled_features|Cured|
+---------------+-----+
|          [0.0]|    3|
|          [0.0]|    3|
|          [0.0]|    3|
|          [0.0]|    3|
|          [0.0]|    3|
+---------------+-----+
only showing top 5 rows



In [140]:
# Seeded 70/30 train/test split of the modelling frame.
train_data, test_data = finalDf.randomSplit(weights=[0.7, 0.3], seed=123)

# Report (row count, column count) for each split, one line apiece.
for split_label, split_df in (("Train Shape : ", train_data), ("Test Shape : ", test_data)):
    print(split_label, (split_df.count(), len(split_df.columns)))

Train Shape :  (9712, 2)
Test Shape :  (4251, 2)


In [143]:
# Linear regression estimator: predicts Cured from the scaled feature vector.
lr = LinearRegression(
    labelCol='Cured',
    featuresCol='scaled_features',
)

In [144]:
# Train the regression on the training split.
lrModel = lr.fit(train_data)

# Bare expression: displays the fitted model's repr as the cell output.
lrModel


LinearRegressionModel: uid=LinearRegression_97ef90a3138f, numFeatures=1

In [145]:
# Report the fitted coefficient vector and intercept, one per line.
print(
    "Coefficients: {}".format(lrModel.coefficients),
    'Intercept: {}'.format(lrModel.intercept),
    sep="\n",
)

Coefficients: [306959.18670548004]
Intercept: 53552.11268814639


In [146]:

# Evaluate the fitted model on the held-out split. Despite its name, `preds` is
# the evaluation summary object (it exposes residuals and metrics), not a
# DataFrame of predictions.
preds = lrModel.evaluate(dataset=test_data)

# Peek at the first five residuals.
preds.residuals.show(n=5)

+------------------+
|         residuals|
+------------------+
|-53551.11268814639|
|-53551.11268814639|
|-53551.11268814639|
|-53551.11268814639|
|-53551.11268814639|
+------------------+
only showing top 5 rows



In [147]:
# Regression metrics on the test split, taken from the evaluation summary and
# printed one per line.
print(
    f"R2-Score               :    {preds.r2}",
    f"Mean Squared Error     :    {preds.meanSquaredError}",
    f"Root MSE               :    {preds.rootMeanSquaredError}",
    f"Mean Absolute Error    :    {preds.meanAbsoluteError}",
    sep="\n",
)

R2-Score               :    0.8305907512121876
Mean Squared Error     :    21221719528.282024
Root MSE               :    145676.7638584892
Mean Absolute Error    :    90268.21269870624


In [148]:
# Plot the relationship the linear regression in this notebook models:
# Deaths (feature) on x against Cured (label) on y, with an OLS trendline drawn
# by plotly express. The original plotted Confirmed vs. Deaths, which is not the
# regression that was fitted and evaluated.
# Plotly fits its own OLS line on all filtered rows (trendline="ols" needs the
# statsmodels package), so the line approximates lrModel rather than reusing it.
# Only the two plotted columns are collected to pandas, so the unused columns
# are never pulled to the driver.
fig = px.scatter(
    filteredData.select('Deaths', 'Cured').toPandas(),
    x='Deaths',
    y='Cured',
    trendline="ols",
    title='Cured vs. Deaths with OLS trendline',
)

In [29]:
# Render the most recently built figure held in `fig`.
fig.show()


In [149]:
## Random forest
# Preview the first five rows of filteredData.

filteredData.show(5)

+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+
|Sno|      Date|   Time|State/UnionTerritory|ConfirmedIndianNational|ConfirmedForeignNational|Cured|Deaths|Confirmed|
+---+----------+-------+--------------------+-----------------------+------------------------+-----+------+---------+
| 38|2020-03-03|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|
| 41|2020-03-04|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|
| 48|2020-03-05|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|
| 54|2020-03-06|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|
| 58|2020-03-07|6:00 PM|              Kerala|                      3|                       0|    3|     0|        3|
+---+----------+-------+--------------------+-----------