In [None]:
import operator
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, isnull, when, count
from pyspark.sql.types import NumericType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler, StringIndexer
from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Create (or reuse) the Spark session for this notebook.
spark = SparkSession.builder.appName("ForestFirePrediction").getOrCreate()

# Load the weather and fire CSV files; column types are inferred from the data.
df1 = spark.read.csv('weather.csv', header=True, inferSchema=True)
df2 = spark.read.csv('fires.csv', header=True, inferSchema=True)

# Print each input's schema followed by its per-column null counts.
for frame in (df1, df2):
    frame.printSchema()
    frame.select([count(when(isnull(col(c)), 1)).alias(c) for c in frame.columns]).show()

# Full outer join on the shared key columns: rows present in only one file are
# kept, with the other file's columns set to null.
df3 = df1.join(df2, on=['X', 'Y', 'num', 'country'], how='outer')

# Per country: number of merged rows containing at least one null, the total
# number of rows, and their ratio. Aggregates must go through groupBy().agg();
# an aggregate inside withColumn() with no grouping raises an AnalysisException.
row_has_null = reduce(operator.or_, [isnull(col(c)) for c in df3.columns])
missing_ratio_per_country = (
    df3.groupBy('country')
    .agg(
        count(when(row_has_null, 1)).alias('missing_count'),
        count('*').alias('total_count'),
    )
    .withColumn('missing_ratio', col('missing_count') / col('total_count'))
    .orderBy('missing_count', ascending=False)
)
missing_ratio_per_country.show()

# Integer codes for the categorical string columns. Months and weekdays are
# numbered from 1 in calendar order.
MONTH_CODES = {name: i for i, name in enumerate(
    ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'], start=1)}
DAY_CODES = {name: i for i, name in enumerate(
    ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun'], start=1)}
COUNTRY_CODES = {'Portugal': 1, 'Brazil': 2}


def _encode(column_name, codes):
    """Return an integer Column that maps the strings in ``column_name`` through ``codes``.

    Values not present in ``codes`` (including nulls) become null. They are
    then filled by the median imputation further down in this cell.

    ``DataFrame.replace`` cannot do this job. It raises
    ``ValueError("Mixed type replacements are not supported")`` for
    string -> int replacements, and it always casts the new value back to
    the column's existing (string) type.
    """
    source = col(column_name)
    (first_label, first_code), *rest = codes.items()
    expr = when(source == first_label, first_code)
    for label, code in rest:
        expr = expr.when(source == label, code)
    return expr


df3 = (
    df3.withColumn('month', _encode('month', MONTH_CODES))
    .withColumn('day', _encode('day', DAY_CODES))
    .withColumn('country', _encode('country', COUNTRY_CODES))
)

# Drop the Brazil rows BEFORE imputing, so the medians describe only the data
# that is kept. eqNullSafe keeps rows whose country is null (a plain != would
# silently drop them); those rows receive the median country code below.
df3 = df3.filter(~col('country').eqNullSafe(COUNTRY_CODES['Brazil']))

# Replace nulls in every numeric column with that column's exact median
# (relativeError=0.0). approxQuantile ignores nulls and returns an empty list
# for a column that is entirely null; such a column is left unfilled.
numeric_columns = [f.name for f in df3.schema.fields if isinstance(f.dataType, NumericType)]
if numeric_columns:
    medians = df3.approxQuantile(numeric_columns, [0.5], 0.0)
    fill_values = {name: q[0] for name, q in zip(numeric_columns, medians) if q}
    if fill_values:
        df3 = df3.fillna(fill_values)

# Drop the columns that are not used as model inputs.
df5 = df3.drop('num', 'country', 'passenger')

# Collect the cleaned Spark DataFrame to the driver as pandas for plotting.
df5_pd = df5.toPandas()

# Bar chart: number of nulls remaining in each column.
missing_values = df5_pd.isnull().sum()
fig, ax = plt.subplots(figsize=(10, 6))
missing_values.plot(kind='bar', ax=ax)
ax.set_title('Missing Values in Each Column')
ax.set_xlabel('Columns')
ax.set_ylabel('Number of Missing Values')
plt.setp(ax.get_xticklabels(), rotation=45)
plt.show()

# Filled kernel-density estimate of the 'area' column.
sns.kdeplot(data=df5_pd['area'], fill=True)
plt.show()

# Bar chart: per month, how many rows have a non-null 'area' value.
monthly_counts = (
    df5_pd
    .groupby('month')['area']
    .count()
    .reset_index()
)
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(monthly_counts['month'], monthly_counts['area'], color='skyblue')
ax.set_xlabel('Month')
ax.set_ylabel('Number of Records')
ax.set_title('Relationship Between the Number of Area Records and Month')
plt.show()

# Every column except the regression target 'area' is used as a feature.
feature_columns = [c for c in df5.columns if c not in ['area']]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")

# Split the raw rows BEFORE anything is fitted. The scaler is a pipeline stage,
# so its statistics are estimated on training data only (and re-estimated inside
# each cross-validation fold). Fitting it on the full dataset, as before,
# leaked test-set statistics into the training features.
(train_data, test_data) = df5.randomSplit([0.8, 0.2], seed=42)

# Linear regression, with regParam chosen by 5-fold cross-validation using the
# RegressionEvaluator's default metric (RMSE).
lr = LinearRegression(featuresCol='scaledFeatures', labelCol='area')
lr_pipeline = Pipeline(stages=[assembler, scaler, lr])
paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01]).build()
crossval = CrossValidator(
    estimator=lr_pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=RegressionEvaluator(labelCol="area"),
    numFolds=5,
)
cvModel = crossval.fit(train_data)

# Score the best cross-validated pipeline on the held-out test split.
predictions = cvModel.transform(test_data)
evaluator = RegressionEvaluator(labelCol="area", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE) on test data = {rmse}")

# Random forest on the same features, fitted through the same leak-free pipeline.
rf = RandomForestRegressor(featuresCol='scaledFeatures', labelCol='area')
rfModel = Pipeline(stages=[assembler, scaler, rf]).fit(train_data)
rf_predictions = rfModel.transform(test_data)
rf_rmse = evaluator.evaluate(rf_predictions)
print(f"Random Forest RMSE on test data = {rf_rmse}")

spark.stop()


Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/23 19:15:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

root
 |-- num: integer (nullable = true)
 |-- X: integer (nullable = true)
 |-- Y: integer (nullable = true)
 |-- DC: double (nullable = true)
 |-- ISI: double (nullable = true)
 |-- country: string (nullable = true)

root
 |-- num: integer (nullable = true)
 |-- X: integer (nullable = true)
 |-- Y: integer (nullable = true)
 |-- month: string (nullable = true)
 |-- day: string (nullable = true)
 |-- FFMC: double (nullable = true)
 |-- DMC: double (nullable = true)
 |-- temp: double (nullable = true)
 |-- RH: integer (nullable = true)
 |-- wind: double (nullable = true)
 |-- rain: double (nullable = true)
 |-- area: double (nullable = true)
 |-- passenger: integer (nullable = true)
 |-- country: string (nullable = true)

