In [1]:
#!apt-get -y install openjdk-8-jre-headless
#!pip install pyspark

## 建立Spark物件

In [2]:
#from google.colab import drive
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, rand, split, create_map
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator 
from itertools import chain
import seaborn as sns
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
#drive.mount('/content/gdrive')

In [3]:
# Create (or reuse) the local Spark session used throughout this notebook.
spark = (
    SparkSession.builder
    .appName("final")
    .getOrCreate()
)

23/05/29 14:20:16 WARN Utils: Your hostname, Shihs-PC.local resolves to a loopback address: 127.0.0.1; using 172.20.10.8 instead (on interface en0)
23/05/29 14:20:16 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/05/29 14:20:17 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## 讀檔與合併檔案
新增 holiday 欄位：0 表示該日需上班，2 表示該日放假。事故資料以發生日期與行事曆日期做 inner join，行事曆中找不到的日期會被捨棄。

In [4]:
# Load the accident records (parquet) and the calendar (CSV), then attach the
# calendar's holiday flag to each accident by date.
# `header` is a CSV reader option with no effect on parquet, so it is not passed here.
traffic_data = spark.read.parquet("data_parquet/*/*")

calendar = spark.read.csv("calendars", header=True, encoding='utf8')

# Calendar columns are addressed by position. Capture the original header names
# first so the drop/rename below clearly refer to the raw CSV layout.
# NOTE(review): assumes column 0 is the date and column 2 the holiday flag
# (0 = working day, 2 = day off); columns 1 and 3 are discarded. Confirm against the CSV.
cal_cols = calendar.columns
calendar = (
    calendar.drop(cal_cols[1], cal_cols[3])
    .withColumnRenamed(cal_cols[0], "Date")
    .withColumnRenamed(cal_cols[2], "holiday")
)

# Inner join: accidents whose 發生日期 has no matching calendar row are dropped.
data = (
    traffic_data.join(calendar, traffic_data["發生日期"] == calendar["Date"], "inner")
    .drop(calendar["Date"])
)

#data.printSchema()

                                                                                

In [5]:
# Raw label distribution (A1 vs A2) before any filtering or joining.
(
    traffic_data
    .groupby("事故類別名稱")
    .count()
    .show()
)



+------------+-------+
|事故類別名稱|  count|
+------------+-------+
|          A2|3024862|
|          A1|  15026|
+------------+-------+



                                                                                

## Data preprocessing

In [6]:
# Remove the casualty-count columns from the feature set.
# NOTE(review): these counts presumably determine the A1/A2 label and would leak it. Confirm.
casualty_cols = ["死亡受傷人數", "死亡人數", "受傷人數"]
data = data.drop(*casualty_cols)

### 定義數值特徵



In [7]:
# Columns treated as numeric features (cast to integer in the next cell).
# Every other column is treated as categorical and string-indexed later.
# 死亡人數 / 受傷人數 are intentionally absent: they are dropped in the previous cell.
num_list = [
    "發生日期",
    "發生時間",
    "速限-第1當事者",
    "當事者事故發生時年齡",
]

In [8]:
# Cast each numeric feature that is present in `data` to integer.
# Values that cannot be parsed become null.
present_numeric = [name for name in num_list if name in data.columns]
for name in present_numeric:
    data = data.withColumn(name, col(name).cast("integer"))

In [9]:
#data.printSchema()

### 僅保留當事者順位為一的資料

In [10]:
# Keep only rows for the first-ranked party (當事者順位 == "1"), then drop the
# rank column, which is constant after the filter.
data = data.filter(col("當事者順位") == "1").drop("當事者順位")

In [11]:
# Columns excluded from the feature set. Presumably free-text location,
# coordinates, the handling police unit, and redundant category levels. Confirm intent.
unused_cols = [
    "發生地點",
    "道路型態大類別名稱",
    "事故位置大類別名稱",
    "道路障礙-視距名稱",
    "車道劃分設施-分向設施子類別名稱",
    "當事者行動狀態大類別名稱",
    "車輛撞擊部位大類別名稱-其他",
    "車輛撞擊部位子類別名稱-其他",
    "肇因研判大類別名稱-個別",
    "肇因研判子類別名稱-個別",
    "經度",
    "緯度",
    "處理單位名稱警局層",
]
data = data.drop(*unused_cols)


In [12]:
# Peek at the first five rows after column pruning.
data.show(n=5)

23/05/29 14:20:31 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+--------+--------+------------+--------+------------+-----------------------+--------------+------------------+------------------------+---------------------+---------------------+---------------------+-------------------+---------------------+------------------------------+-----------------+-------------------------------+--------------------------------------------+------------------------------------+----------------------------------+------------------------+------------------------+-----------------------+----------------------------+-------------------------------+-------------------------------+------+----+--------------------+----------------------------------+------------------------------------+------------------------+---------------------------+---------------------------+-------------------------+------------+-------+
|發生日期|發生時間|事故類別名稱|天候名稱|    光線名稱|道路類別-第1當事者-名稱|速限-第1當事者|道路型態子類別名稱|      事故位置子類別名稱|路面狀況-路面鋪裝名稱|路面狀況-路面狀態名稱|路面狀況-路面缺陷名稱|道路障礙-障礙物名稱|道路障礙-視距品質名稱|             號誌-

### 將 boolean 欄位（無或物）與 holiday 轉成字串

In [13]:
# StringIndexer accepts only string or numeric inputs, so the boolean 無或物 flag
# is cast to string. holiday (read from CSV) is cast as well so both are indexed as labels.
for flag_col in ("無或物", "holiday"):
    data = data.withColumn(flag_col, col(flag_col).cast("string"))


### 將類別特徵轉換成index

In [14]:
# Encode every non-numeric column (including the label 事故類別名稱) as a
# double-valued index column named `<col>_numeric`, then drop the raw column.
# A single multi-column StringIndexer (requires Spark 3.0+) computes all label
# frequencies in one pass over `data`. The per-column loop it replaces
# launched a separate fit job, re-evaluating the whole un-cached lineage, per column.
# handleInvalid="keep" maps unseen labels and nulls to an extra index instead of raising.
# NOTE(review): the indexer is fit before the train/test split, so the
# label->index mapping is learned from rows that later land in the test set.
cat_list = [name for name in data.columns if name not in num_list]
indexed_cols = [f"{name}_numeric" for name in cat_list]
if cat_list:  # nothing to index; avoid fitting with an empty inputCols list
    indexer = StringIndexer(
        inputCols=cat_list,
        outputCols=indexed_cols,
        handleInvalid="keep",
    ).fit(data)
    data = indexer.transform(data).drop(*cat_list)
# Downstream code expects cat_list to hold the indexed column names.
cat_list = indexed_cols

                                                                                

In [15]:
# Inspect the schema after indexing: numeric features are integers and every
# categorical column has been replaced by a double-valued `<name>_numeric` index.
data.printSchema()

root
 |-- 發生日期: integer (nullable = true)
 |-- 發生時間: integer (nullable = true)
 |-- 速限-第1當事者: integer (nullable = true)
 |-- 當事者事故發生時年齡: integer (nullable = true)
 |-- 事故類別名稱_numeric: double (nullable = false)
 |-- 天候名稱_numeric: double (nullable = false)
 |-- 光線名稱_numeric: double (nullable = false)
 |-- 道路類別-第1當事者-名稱_numeric: double (nullable = false)
 |-- 道路型態子類別名稱_numeric: double (nullable = false)
 |-- 事故位置子類別名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面鋪裝名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面狀態名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面缺陷名稱_numeric: double (nullable = false)
 |-- 道路障礙-障礙物名稱_numeric: double (nullable = false)
 |-- 道路障礙-視距品質名稱_numeric: double (nullable = false)
 |-- 號誌-號誌種類名稱_numeric: double (nullable = false)
 |-- 號誌-號誌動作名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分向設施大類別名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分道設施-快車道或一般車道間名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分道設施-快慢車道間名稱_numeric: double (nullable = false)
 |-- 車

### Impute missing values


In [16]:
# Replace nulls with 0. A numeric fill value only affects numeric columns: here the
# nullable integer features from num_list (the *_numeric index columns are non-null).
# VectorAssembler's default handleInvalid="error" would otherwise fail on nulls.
# NOTE(review): 0 acts as a sentinel (e.g. age 0, speed limit 0). Consider a
# median fill or an explicit missing-indicator column instead.
data = data.na.fill(0)

### 將特徵轉換成vector

In [17]:
# Assemble every column except the label index into a single `features` vector.
label_col = "事故類別名稱_numeric"
features = [name for name in data.columns if name != label_col]
vector = VectorAssembler(inputCols=features, outputCol='features')
transformed_data = vector.transform(data)

In [18]:
# List the columns that were assembled into the feature vector (all but the label index).
print(features)

['發生日期', '發生時間', '速限-第1當事者', '當事者事故發生時年齡', '天候名稱_numeric', '光線名稱_numeric', '道路類別-第1當事者-名稱_numeric', '道路型態子類別名稱_numeric', '事故位置子類別名稱_numeric', '路面狀況-路面鋪裝名稱_numeric', '路面狀況-路面狀態名稱_numeric', '路面狀況-路面缺陷名稱_numeric', '道路障礙-障礙物名稱_numeric', '道路障礙-視距品質名稱_numeric', '號誌-號誌種類名稱_numeric', '號誌-號誌動作名稱_numeric', '車道劃分設施-分向設施大類別名稱_numeric', '車道劃分設施-分道設施-快車道或一般車道間名稱_numeric', '車道劃分設施-分道設施-快慢車道間名稱_numeric', '車道劃分設施-分道設施-路面邊線名稱_numeric', '事故類型及型態大類別名稱_numeric', '事故類型及型態子類別名稱_numeric', '肇因研判大類別名稱-主要_numeric', '肇因研判子類別名稱-主要_numeric', '當事者區分-類別-大類別名稱-車種_numeric', '當事者區分-類別-子類別名稱-車種_numeric', '無或物_numeric', '性別_numeric', '保護裝備名稱_numeric', '行動電話或電腦或其他相類功能裝置名稱_numeric', '當事者行動狀態子類別名稱_numeric', '車輛撞擊部位大類別名稱-最初_numeric', '車輛撞擊部位子類別名稱-最初_numeric', '肇事逃逸類別名稱-是否肇逃_numeric', 'site_id_numeric', 'holiday_numeric']


In [19]:
# The assembled DataFrame keeps all input columns and adds the `features`
# vector column produced by the VectorAssembler above.
transformed_data.printSchema()

root
 |-- 發生日期: integer (nullable = true)
 |-- 發生時間: integer (nullable = true)
 |-- 速限-第1當事者: integer (nullable = true)
 |-- 當事者事故發生時年齡: integer (nullable = true)
 |-- 事故類別名稱_numeric: double (nullable = false)
 |-- 天候名稱_numeric: double (nullable = false)
 |-- 光線名稱_numeric: double (nullable = false)
 |-- 道路類別-第1當事者-名稱_numeric: double (nullable = false)
 |-- 道路型態子類別名稱_numeric: double (nullable = false)
 |-- 事故位置子類別名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面鋪裝名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面狀態名稱_numeric: double (nullable = false)
 |-- 路面狀況-路面缺陷名稱_numeric: double (nullable = false)
 |-- 道路障礙-障礙物名稱_numeric: double (nullable = false)
 |-- 道路障礙-視距品質名稱_numeric: double (nullable = false)
 |-- 號誌-號誌種類名稱_numeric: double (nullable = false)
 |-- 號誌-號誌動作名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分向設施大類別名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分道設施-快車道或一般車道間名稱_numeric: double (nullable = false)
 |-- 車道劃分設施-分道設施-快慢車道間名稱_numeric: double (nullable = false)
 |-- 車

In [20]:
# Label distribution after keeping only first-ranked parties. The classes are
# heavily imbalanced (roughly 0.5% in class 1.0).
(
    transformed_data
    .groupby("事故類別名稱_numeric")
    .count()
    .show()
)



+--------------------+-------+
|事故類別名稱_numeric|  count|
+--------------------+-------+
|                 0.0|1415429|
|                 1.0|   6991|
+--------------------+-------+



                                                                                

### 切分訓練集以及測試集

In [21]:
# 70/30 random split with a fixed seed so the split is reproducible.
# NOTE(review): randomSplit is not stratified. With so few positives, the test
# set's class ratio may differ slightly from the full data.
train, test = transformed_data.randomSplit(weights=[0.7, 0.3], seed=24)

### 儲存處理好的資料

In [22]:
# Persist the processed splits as parquet directories `train/` and `test/`.
# mode("overwrite") keeps Restart & Run All working: the default mode
# ("errorifexists") raises on a second run because the directories already exist.
# The former "encoding"/"charset" options were CSV/text options ignored by the
# parquet writer (parquet string columns are UTF-8 by specification), so they are dropped.
for split_df, out_path in ((train, "train"), (test, "test")):
    split_df.write.mode("overwrite").parquet(out_path)

23/05/29 14:21:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [23]:
# Release the Spark session now that the processed splits have been written.
spark.stop()

In [24]:
# https://towardsdatascience.com/machine-learning-on-a-large-scale-2eef3bb749ee