In [1]:
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
spark=SparkSession.builder.master("local[*]").appName('anxiety').getOrCreate()
sc=spark.sparkContext

In [2]:
df = spark.read.csv("anxiety.csv", header = True, inferSchema = True, sep = ",")

In [3]:
# df1: 2020年及以后outcome为重症的病例
df1 = df.filter((df["YEAR"] >= 2020) & ((df["OUTC_COD"] == "DE") | (df["OUTC_COD"] == "HO") | (df["OUTC_COD"] == "RI"))).select("*")
# df2: str -> index
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
indexers = [StringIndexer(inputCol = column, outputCol = column+"_index")
            .fit(df1) for column in ['GNDR_COD', 'COVID', 'PT', 'DRUGNAME', 'OUTC_COD']]
pipeline = Pipeline(stages = indexers)
df2 = pipeline.fit(df1).transform(df1)
df2 = df2.drop('GNDR_COD', 'COVID', 'OUTC_COD', 'PT', 'DRUGNAME', 'INDI_PT')
df2.printSchema()

root
 |-- ISR: long (nullable = true)
 |-- CASE: double (nullable = true)
 |-- AGE: double (nullable = true)
 |-- WT: double (nullable = true)
 |-- YEAR: integer (nullable = true)
 |-- GNDR_COD_index: double (nullable = false)
 |-- COVID_index: double (nullable = false)
 |-- PT_index: double (nullable = false)
 |-- DRUGNAME_index: double (nullable = false)
 |-- OUTC_COD_index: double (nullable = false)



In [4]:
# feature_vector1: 看所有可能的自变量对因变量（重症outc_cod）的影响
feature1 = VectorAssembler(inputCols = df2.columns[2:9], outputCol = "features")
feature_vector1 = feature1.transform(df2)

In [5]:
train_data_1, test_data_1 = feature_vector1.randomSplit([0.5,0.5], seed = 27)

In [6]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
outc_glr_1 = LogisticRegression(labelCol="OUTC_COD_index", featuresCol="features", family="multinomial")
trained_outc_glr_Model_1 = outc_glr_1.fit(train_data_1)
outc_glr_prediction_1 = trained_outc_glr_Model_1.transform(test_data_1)
outc_glr_prediction_1.select("prediction", "OUTC_COD_index", "features").show(10)
evaluator = MulticlassClassificationEvaluator(labelCol="OUTC_COD_index", predictionCol="prediction", metricName="accuracy")
accuracy_1 = evaluator.evaluate(outc_glr_prediction_1)
print("Test Accuracy = %g" % accuracy_1)
print("Test Error = %g " % (1.0 - accuracy_1))

+----------+--------------+--------------------+
|prediction|OUTC_COD_index|            features|
+----------+--------------+--------------------+
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
+----------+--------------+--------------------+
only showing top 10 rows

Test Accuracy = 0.863749
Test Error = 0.136251 


In [7]:
outc_glr_prediction_1.select("prediction").distinct().show()

+----------+
|prediction|
+----------+
|       0.0|
+----------+



In [9]:
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(labelCol="OUTC_COD_index", featuresCol="features", maxBins=4538)
rf_model = rf.fit(train_data_1)
rf_prediction = rf_model.transform(test_data_1)
rf_prediction.select("prediction", "OUTC_COD_index", "features").show(10)
rf_accuracy = evaluator.evaluate(rf_prediction)
print("Accuracy of RandomForestClassifier is = %g"% (rf_accuracy))
print("Test Error of RandomForestClassifier  = %g " % (1.0 - rf_accuracy))

+----------+--------------+--------------------+
|prediction|OUTC_COD_index|            features|
+----------+--------------+--------------------+
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[47.8246575342465...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
|       0.0|           0.0|[50.0,54.0,2020.0...|
+----------+--------------+--------------------+
only showing top 10 rows

Accuracy of RandomForestClassifier is = 0.867177
Test Error of RandomForestClassifier  = 0.132823 


In [10]:
rf_prediction.select("prediction").distinct().show()

+----------+
|prediction|
+----------+
|       0.0|
|       1.0|
+----------+



In [11]:
rf_prediction.filter("prediction == 1.0").select("*").toPandas()

Unnamed: 0,ISR,CASE,AGE,WT,YEAR,GNDR_COD_index,COVID_index,PT_index,DRUGNAME_index,OUTC_COD_index,features,rawPrediction,probability,prediction
0,148320975,14832097.0,54.236592,72.641947,2020,0.0,0.0,4343.0,265.0,1.0,"[54.236591895403784, 72.6419470689881, 2020.0,...","[8.412909752954993, 11.439928046186399, 0.1471...","[0.42064548764774967, 0.5719964023093199, 0.00...",1.0
1,165986994,16598699.0,54.236592,72.641947,2020,0.0,0.0,1003.0,2.0,1.0,"[54.236591895403784, 72.6419470689881, 2020.0,...","[9.62070296198814, 10.28460846275694, 0.094688...","[0.4810351480994071, 0.5142304231378472, 0.004...",1.0
2,167695694,16769569.0,54.236592,72.641947,2020,0.0,0.0,87.0,2.0,1.0,"[54.236591895403784, 72.6419470689881, 2020.0,...","[9.071982214276357, 10.823144279883621, 0.1048...","[0.45359911071381787, 0.541157213994181, 0.005...",1.0
3,171721592,17172159.0,54.236592,54.420000,2020,1.0,0.0,34.0,23.0,1.0,"[54.236591895403784, 54.42, 2020.0, 1.0, 0.0, ...","[9.296412609938931, 10.573802169190042, 0.1297...","[0.46482063049694655, 0.5286901084595022, 0.00...",1.0
4,171721592,17172159.0,54.236592,54.420000,2020,1.0,0.0,3365.0,23.0,1.0,"[54.236591895403784, 54.42, 2020.0, 1.0, 0.0, ...","[9.718282100270862, 10.143366695602584, 0.1383...","[0.48591410501354304, 0.5071683347801291, 0.00...",1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
259,2061330224,20613302.0,54.236592,97.000000,2023,0.0,0.0,415.0,0.0,1.0,"[54.236591895403784, 97.0, 2023.0, 0.0, 0.0, 4...","[8.061915318280105, 11.877555955449854, 0.0605...","[0.40309576591400526, 0.5938777977724927, 0.00...",1.0
260,2061330224,20613302.0,54.236592,97.000000,2023,0.0,0.0,466.0,0.0,1.0,"[54.236591895403784, 97.0, 2023.0, 0.0, 0.0, 4...","[7.858475503029059, 12.087089061507339, 0.0544...","[0.39292377515145294, 0.604354453075367, 0.002...",1.0
261,2061330224,20613302.0,54.236592,97.000000,2023,0.0,0.0,510.0,0.0,1.0,"[54.236591895403784, 97.0, 2023.0, 0.0, 0.0, 5...","[7.487519643148482, 12.441870692254913, 0.0706...","[0.37437598215742407, 0.6220935346127456, 0.00...",1.0
262,2061330224,20613302.0,54.236592,97.000000,2023,0.0,0.0,544.0,0.0,1.0,"[54.236591895403784, 97.0, 2023.0, 0.0, 0.0, 5...","[7.797490335037921, 12.148839873045095, 0.0536...","[0.389874516751896, 0.6074419936522546, 0.0026...",1.0
