In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import avg, year
from pyspark.sql.functions import desc
import matplotlib.pyplot as plt
# Start (or reuse) the Spark session for this analysis.
spark = SparkSession.builder.appName("GovernmentPolicyAnalysis").getOrCreate()

# Load the raw COVID dataset; inferSchema lets Spark type the numeric/date columns.
df = spark.read.csv("/covid.csv", header=True, inferSchema=True)

# Only these columns are used downstream. A bare `na.drop()` removes every row
# that has a null in *any* column, which on a wide, sparsely populated CSV
# throws away most of the data. Restrict the null check to what we use.
required_columns = ["location", "date", "total_cases", "total_deaths", "stringency_index"]
if "continent" in df.columns:
    # NOTE(review): assumes an OWID-style export where aggregate rows
    # ("World", continents, income groups) have a null continent; requiring it
    # keeps the analysis country-only, as the blanket na.drop() did
    # incidentally. Confirm against the actual CSV.
    required_columns.append("continent")
df = df.dropDuplicates().na.drop(subset=required_columns)

# Derive the calendar year so metrics can be aggregated per location and year.
df = df.withColumn("year", year(df.date))

# One row per (location, year): mean total_cases, total_deaths and
# stringency_index over that year's records.
avg_cases_deaths_df = df.groupBy("location", "year").agg(
    avg("total_cases").alias("avg_total_cases"),
    avg("total_deaths").alias("avg_total_deaths"),
    avg("stringency_index").alias("avg_stringency_index"),
)

# Attach the per-(location, year) averages back onto every individual row.
df_with_avg_cases_deaths = df.join(avg_cases_deaths_df, ["location", "year"])
# Columns describing each aggregated (location, year) record.
selected_columns = ["location", "year", "avg_total_cases", "avg_total_deaths", "avg_stringency_index"]

# Predict the average stringency index from the case/death averages. The label
# must NOT be part of the feature vector: assembling selected_columns[2:]
# includes avg_stringency_index itself, letting the model copy the label and
# report a meaningless near-zero RMSE.
feature_columns = ["avg_total_cases", "avg_total_deaths"]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

# Train on one row per (location, year). The joined per-row frame repeats each
# group's averages on every record, so a random split would put identical rows
# in both train and test and overstate accuracy.
df_transformed = assembler.transform(avg_cases_deaths_df)

# Fixed seed so the split, and therefore the reported RMSE, is reproducible.
(trainingData, testData) = df_transformed.randomSplit([0.8, 0.2], seed=42)

lr = LinearRegression(featuresCol="features", labelCol="avg_stringency_index")
model = lr.fit(trainingData)

# Score the held-out rows and report the root-mean-square error.
predictions = model.transform(testData)
evaluator = RegressionEvaluator(labelCol="avg_stringency_index", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Square Error (RMSE):", rmse)
# Rank locations by their average total_cases across all of their records.
# Deduplicating the per-(location, year) joined rows on "location" keeps an
# arbitrary year per location, so the averages and the ranking were not
# deterministic. Aggregating per location gives one well-defined value each.
# The secondary sort on location makes ties stable.
top_10_countries = (
    df.groupBy("location")
    .agg(
        avg("stringency_index").alias("avg_stringency_index"),
        avg("total_cases").alias("avg_total_cases"),
        avg("total_deaths").alias("avg_total_deaths"),
    )
    .orderBy(desc("avg_total_cases"), "location")
    .limit(10)
)

top_10_countries_pd = top_10_countries.toPandas()

# Draw the three metrics side by side for each location. Plotting them all at
# the same x position lets the tallest bar hide the others.
metrics = [
    ("avg_stringency_index", "Average Stringency Index"),
    ("avg_total_cases", "Average Total Cases"),
    ("avg_total_deaths", "Average Total Deaths"),
]
bar_width = 0.8 / len(metrics)
positions = list(range(len(top_10_countries_pd)))

fig, ax = plt.subplots(figsize=(12, 6))
for i, (column, label) in enumerate(metrics):
    # Shift each series so the group of bars is centred on its tick.
    offset = (i - (len(metrics) - 1) / 2) * bar_width
    ax.bar([p + offset for p in positions], top_10_countries_pd[column], width=bar_width, label=label)

# The metrics can differ by orders of magnitude (an index vs. counts), so a log
# scale keeps every series visible on the shared axis.
ax.set_yscale("log")
ax.set_xticks(positions)
ax.set_xticklabels(top_10_countries_pd["location"], rotation=45, ha="right")
ax.set_xlabel("Country")
ax.set_ylabel("Average Values (log scale)")
ax.set_title("Top 10 Unique Countries with Highest Averages")
ax.legend()
fig.tight_layout()
plt.show()