/
impute_normal.py
142 lines (112 loc) · 4.82 KB
/
impute_normal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 27 10:25:10 2018
@author: JohnBauer
"""
# NOTE(review): the three paths below are machine-specific (one developer's
# macOS install). Anyone else must edit them, or export JAVA_HOME/SPARK_HOME
# in the shell before running -- consider reading them from the environment
# instead of hardcoding.
import os
os.environ['JAVA_HOME'] = "/Library/Java/JavaVirtualMachines/jdk1.8.0_191.jdk/Contents/Home"
os.environ['SPARK_HOME'] = "/Users/john.h.bauer/spark"
# NOTE(review): os.environ does NOT expand "$SPARK_HOME"/"$PYTHONPATH" -- the
# literal dollar-sign strings are stored. findspark.init() below is what
# actually puts pyspark on sys.path, so this assignment is likely inert;
# confirm before relying on it.
os.environ['PYTHONPATH'] = "$SPARK_HOME/python:$SPARK_HOME/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, randn
from pyspark.sql.functions import avg, stddev_samp
from pyspark import keyword_only
from pyspark.ml import Estimator, Model
# from pyspark.ml.feature import SQLTransformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
# Module-level session: created as a side effect of importing this file and
# reused by the __main__ demo at the bottom.
spark = SparkSession \
    .builder \
    .appName("ImputeNormal") \
    .getOrCreate()
class ImputeNormal(Estimator,
                   HasInputCol,
                   HasOutputCol,
                   DefaultParamsReadable,
                   DefaultParamsWritable,
                   ):
    """Estimator that imputes nulls in a numeric column with normal draws.

    Fitting computes the sample mean and standard deviation of
    ``inputCol``; the returned :class:`ImputeNormalModel` then fills each
    null with a draw from N(mean, stddev), so imputed values mimic the
    observed distribution instead of collapsing to a single constant.
    """

    @keyword_only
    def __init__(self, inputCol="inputCol", outputCol="outputCol"):
        super(ImputeNormal, self).__init__()
        self._setDefault(inputCol="inputCol", outputCol="outputCol")
        # Standard pyspark.ml pattern: @keyword_only stashes the call's
        # keyword arguments in _input_kwargs for Param bookkeeping.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol="inputCol", outputCol="outputCol"):
        """
        setParams(self, inputCol="inputCol", outputCol="outputCol")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _fit(self, data):
        """Estimate mean/stddev of inputCol and build the fitted model.

        :param data: DataFrame containing ``inputCol`` (numeric, may
            contain nulls).
        :return: a fitted :class:`ImputeNormalModel`.
        :raises ValueError: if the statistics cannot be computed -- e.g.
            the column is entirely null, or has fewer than two non-null
            values so the sample standard deviation is undefined.
        """
        inputCol = self.getInputCol()
        outputCol = self.getOutputCol()
        mean, stddev = data.agg(avg(inputCol), stddev_samp(inputCol)).first()
        # avg/stddev_samp yield SQL NULL (Python None) on empty or
        # single-value input; fail with a clear message rather than the
        # opaque TypeError that float(None) would raise below.
        if mean is None or stddev is None:
            raise ValueError(
                "Cannot fit ImputeNormal on column '{}': need at least two "
                "non-null values to estimate mean and stddev.".format(inputCol))
        return ImputeNormalModel(mean=float(mean),
                                 stddev=float(stddev),
                                 inputCol=inputCol,
                                 outputCol=outputCol,
                                 )

    # FOR A TRULY MINIMAL BUT LESS DIDACTICALLY EFFECTIVE DEMO, DO INSTEAD:
    # sql_text = "SELECT *, IF({inputCol} IS NULL, {stddev} * randn() + {mean}, {inputCol}) AS {outputCol} FROM __THIS__"
    #
    # return SQLTransformer(statement=sql_text.format(stddev=stddev, mean=mean, inputCol=inputCol, outputCol=outputCol))
class ImputeNormalModel(Model,
                        HasInputCol,
                        HasOutputCol,
                        DefaultParamsReadable,
                        DefaultParamsWritable,
                        ):
    """Model produced by :class:`ImputeNormal`.

    ``transform`` adds ``outputCol``: nulls in ``inputCol`` are replaced by
    draws from N(mean, stddev); non-null values are copied through
    unchanged.
    """

    # Params use the pyspark.ml registry (Params._dummy()) so they survive
    # copy()/save()/load() via DefaultParamsReadable/Writable.
    mean = Param(Params._dummy(), "mean", "Mean value of imputations. Calculated by fit method.",
                 typeConverter=TypeConverters.toFloat)
    stddev = Param(Params._dummy(), "stddev", "Standard deviation of imputations. Calculated by fit method.",
                   typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol"):
        super(ImputeNormalModel, self).__init__()
        self._setDefault(mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol"):
        """
        setParams(self, mean=0.0, stddev=1.0, inputCol="inputCol", outputCol="outputCol")
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def getMean(self):
        """Return the imputation mean (set by ImputeNormal._fit)."""
        return self.getOrDefault(self.mean)

    def setMean(self, mean):
        """Set the imputation mean; returns self for chaining."""
        self._set(mean=mean)
        return self

    def getStddev(self):
        """Return the imputation standard deviation."""
        return self.getOrDefault(self.stddev)

    def setStddev(self, stddev):
        """Set the imputation standard deviation; returns self for chaining."""
        self._set(stddev=stddev)
        return self

    def _transform(self, data):
        """Append outputCol with nulls imputed from N(mean, stddev).

        :param data: DataFrame containing ``inputCol``.
        :return: ``data`` with ``outputCol`` added.
        """
        mean = self.getMean()
        stddev = self.getStddev()
        inputCol = self.getInputCol()
        outputCol = self.getOutputCol()
        # randn() is unseeded here, so imputed values differ between runs
        # (and between actions that recompute the plan).
        df = data.withColumn(outputCol,
                             when(col(inputCol).isNull(),
                                  stddev * randn() + mean). \
                             otherwise(col(inputCol)))
        return df
if __name__ == "__main__":
    # Demo: 3 observed values plus 100 nulls to impute.
    train = spark.createDataFrame([[0], [1], [2]] + [[None]] * 100, ['input'])
    impute = ImputeNormal(inputCol='input', outputCol='output')
    impute_model = impute.fit(train)
    print("Input column: {}".format(impute_model.getInputCol()))
    print("Output column: {}".format(impute_model.getOutputCol()))
    print("Mean: {}".format(impute_model.getMean()))
    print("Standard Deviation: {}".format(impute_model.getStddev()))
    test = impute_model.transform(train)
    test.show(10)
    test.describe().show()
    print("mean and stddev for outputCol should be close to those of inputCol")
    # Release the JVM-side resources held by the module-level session;
    # the original script leaked it at exit.
    spark.stop()