forked from zero-one-group/geni
-
Notifications
You must be signed in to change notification settings - Fork 0
/
polymorphic.clj
196 lines (174 loc) · 7.86 KB
/
polymorphic.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
(ns zero-one.geni.core.polymorphic
(:refer-clojure :exclude [alias
assoc
count
dissoc
filter
first
last
max
min
shuffle
update])
(:require
[clojure.string]
[zero-one.geni.core.column :refer [->col-array ->column]]
[zero-one.geni.core.dataset :as dataset]
[zero-one.geni.core.dataset-creation :as dataset-creation]
[zero-one.geni.core.functions :as sql]
[zero-one.geni.defaults]
[zero-one.geni.interop :as interop]
[zero-one.geni.utils :refer [->string-map arg-count ensure-coll]])
(:import
(org.apache.spark.ml.stat Correlation)
(org.apache.spark.sql Dataset
RelationalGroupedDataset
functions)))
;; Derefable holding the shared default SparkSession (see `@default-spark` in
;; `to-df`). NOTE(review): presumably a delay/atom in zero-one.geni.defaults —
;; confirm before relying on construction timing.
(def default-spark zero-one.geni.defaults/spark)
;; Polymorphic rename. Dispatches on the class of the first argument: a
;; Dataset is aliased directly, anything else is coerced to a Column first.
;; `alias` is a synonym.
(defmulti as (fn [head & _] (class head)))
(defmethod as :default [expr new-name]
  (-> expr ->column (.as (name new-name))))
(defmethod as Dataset [dataframe new-name]
  (.as dataframe (name new-name)))
(def alias as)
;; Polymorphic count: the `count` aggregate for a column expression, the row
;; count of a Dataset, and per-group counts for a RelationalGroupedDataset.
(defmulti count class)
(defmethod count :default [expr]
  (-> expr ->column functions/count))
(defmethod count Dataset [dataset]
  (.count dataset))
(defmethod count RelationalGroupedDataset [grouped]
  (.count grouped))
;; Polymorphic explain: prints a Dataset's plan or a Column's expression tree.
;; The `:default` (Column) method now also supports a single-arity form that
;; defaults to a non-extended explain, mirroring the Dataset method's arities
;; (backward-compatible: the original 2-arity is unchanged).
(defmulti explain (fn [head & _] (class head)))
(defmethod explain :default
  ([expr] (explain expr false))
  ([expr extended] (.explain (->column expr) extended)))
(defmethod explain Dataset
  ([dataset] (.explain dataset))
  ([dataset extended] (.explain dataset extended)))
;; Polymorphic mean: the `mean` aggregate for a column expression (extra args
;; are accepted and ignored), or per-group means of the named columns for a
;; RelationalGroupedDataset. `avg` is a synonym.
(defmulti mean (fn [head & _] (class head)))
(defmethod mean :default [expr & _]
  (-> expr ->column functions/mean))
(defmethod mean RelationalGroupedDataset [grouped & col-names]
  (->> col-names
       (clojure.core/map name)
       interop/->scala-seq
       (.mean grouped)))
(def avg mean)
;; Polymorphic max: the `max` aggregate for a column expression, or per-group
;; maxima of the named columns for a RelationalGroupedDataset.
(defmulti max (fn [head & _] (class head)))
(defmethod max :default [expr]
  (-> expr ->column functions/max))
(defmethod max RelationalGroupedDataset [grouped & col-names]
  (->> col-names
       (clojure.core/map name)
       interop/->scala-seq
       (.max grouped)))
;; Polymorphic min: the `min` aggregate for a column expression, or per-group
;; minima of the named columns for a RelationalGroupedDataset.
(defmulti min (fn [head & _] (class head)))
(defmethod min :default [expr]
  (-> expr ->column functions/min))
(defmethod min RelationalGroupedDataset [grouped & col-names]
  (->> col-names
       (clojure.core/map name)
       interop/->scala-seq
       (.min grouped)))
;; Polymorphic sum: the `sum` aggregate for a column expression, or per-group
;; sums of the named columns for a RelationalGroupedDataset.
(defmulti sum (fn [head & _] (class head)))
(defmethod sum :default [expr]
  (-> expr ->column functions/sum))
(defmethod sum RelationalGroupedDataset [grouped & col-names]
  (->> col-names
       (clojure.core/map name)
       interop/->scala-seq
       (.sum grouped)))
;; Polymorphic coalesce: on a Dataset it reduces the number of partitions;
;; on column expressions it is SQL COALESCE (first non-null argument).
(defmulti coalesce (fn [head & _] (class head)))
(defmethod coalesce Dataset [dataframe n-partitions]
  (.coalesce dataframe n-partitions))
(defmethod coalesce :default [& exprs]
  (-> exprs ->col-array functions/coalesce))
;; Polymorphic shuffle: randomly permutes the elements of an array column,
;; or reorders a Dataset's rows by sorting on a random-normal column.
(defmulti shuffle class)
(defmethod shuffle :default [expr]
  (-> expr ->column functions/shuffle))
(defmethod shuffle Dataset [dataframe]
  (dataset/sort dataframe (functions/randn)))
;; Polymorphic first: the first row of a Dataset (via `take`), or the
;; `first` aggregate for a column expression.
(defmulti first class)
(defmethod first Dataset [dataframe]
  (clojure.core/first (dataset/take dataframe 1)))
(defmethod first :default [expr]
  (-> expr ->column functions/first))
;; Polymorphic last: the last row of a Dataset (via `tail`), or the
;; `last` aggregate for a column expression.
(defmulti last class)
(defmethod last Dataset [dataframe]
  (clojure.core/first (dataset/tail dataframe 1)))
(defmethod last :default [expr]
  (-> expr ->column functions/last))
;; Polymorphic filter: row filtering for a Dataset (the expression is cast to
;; boolean), or the higher-order array `filter` for column expressions. The
;; predicate may take one argument (element) or two (element and index),
;; detected via `arg-count`. `where` is a synonym.
(defmulti filter (fn [head & _] (class head)))
(defmethod filter Dataset [dataframe expr]
  (.filter dataframe (.cast (->column expr) "boolean")))
(defmethod filter :default [expr predicate]
  (let [wrap (if (= 2 (arg-count predicate))
               interop/->scala-function2
               interop/->scala-function1)]
    (functions/filter (->column expr) (wrap predicate))))
(def where filter)
;; Polymorphic JSON conversion: `Dataset.toJSON` for a dataframe, or Spark's
;; `to_json` function for a column expression, with an optional options map
;; (converted to string keys/values).
(defmulti to-json (fn [head & _] (class head)))
(defmethod to-json Dataset [dataframe] (.toJSON dataframe))
(defmethod to-json :default
  ;; Single-arity form passes an empty options map through to to_json.
  ([expr] (functions/to_json (->column expr) {}))
  ([expr options]
   (functions/to_json (->column expr) (->string-map options))))
;; Polymorphic to-df. For a Dataset it delegates to `.toDF`, optionally
;; renaming columns. For anything else (`table`, presumably a seq of rows —
;; see `table->dataset`) it builds a new Dataset, dereferencing the shared
;; default SparkSession when none is supplied.
(defmulti to-df (fn [head & _] (class head)))
(defmethod to-df :default
  ;; Two-arity: fall back to the namespace-wide default SparkSession.
  ([table col-names]
   (to-df @default-spark table col-names))
  ([spark table col-names]
   (dataset-creation/table->dataset spark table col-names)))
(defmethod to-df Dataset
  ([dataframe] (.toDF dataframe))
  ;; col-names may mix scalars and collections; flatten before renaming.
  ([dataframe & col-names]
   (.toDF dataframe (->> col-names
                         (mapcat ensure-coll)
                         (map name)
                         interop/->scala-seq))))
;; Polymorphic correlation. For column expressions: the two-column `corr`
;; aggregate. For a Dataset: one col-name delegates to ML `Correlation/corr`
;; (NOTE(review): that API expects a Vector-typed column — confirm with
;; callers); two col-names compute a pairwise correlation via `.stat`, with
;; an optional method string (e.g. "pearson"/"spearman" — verify against
;; DataFrameStatFunctions).
(defmulti corr (fn [head & _] (class head)))
(defmethod corr :default [l-expr r-expr]
  (functions/corr (->column l-expr) (->column r-expr)))
(defmethod corr Dataset
  ([dataframe col-name]
   (Correlation/corr dataframe (name col-name)))
  ([dataframe col-name1 col-name2]
   (-> dataframe .stat (.corr (name col-name1) (name col-name2))))
  ([dataframe col-name1 col-name2 method]
   (-> dataframe .stat (.corr (name col-name1) (name col-name2) method))))
;; Tech ML
;; Tech-ML-style assoc: merges a key/value pair into a map column (via
;; map_concat), or adds a column to a Dataset (via withColumn). Additional
;; pairs are folded in pairwise; the trailing pairs must come in an even
;; number or an IllegalArgumentException is thrown.
(defmulti assoc (fn [head & _] (class head)))
(defmethod assoc :default
  ([expr k v] (sql/map-concat expr (sql/map k v)))
  ([expr k v & kvs]
   (when (odd? (clojure.core/count kvs))
     (throw (IllegalArgumentException.
             (str "assoc expects even number of arguments "
                  "after map/vector, found odd number"))))
   (reduce (fn [acc [k' v']] (assoc acc k' v'))
           (assoc expr k v)
           (partition 2 kvs))))
(defmethod assoc Dataset
  ([dataframe k v] (.withColumn dataframe (name k) (->column v)))
  ([dataframe k v & kvs]
   (when (odd? (clojure.core/count kvs))
     (throw (IllegalArgumentException.
             (str "assoc expects even number of arguments "
                  "after map/vector, found odd number"))))
   (reduce (fn [acc [k' v']] (assoc acc k' v'))
           (assoc dataframe k v)
           (partition 2 kvs))))
;; Tech-ML-style dissoc: removes entries from a map column (via map_filter),
;; or drops columns from a Dataset.
(defmulti dissoc (fn [head & _] (class head)))
(defmethod dissoc :default [expr & ks]
  (sql/map-filter
   expr
   ;; Keep only entries whose key column is NOT in `ks`.
   ;; NOTE(review): `ks` is handed to Column.isin unconverted — presumably
   ;; plain literals (strings/numbers); verify behavior for keyword keys.
   (fn [k _] (functions/not (.isin k (interop/->scala-seq ks))))))
(defmethod dissoc Dataset [dataframe & col-names]
  (apply dataset/drop dataframe col-names))
;; Tech-ML-style update: applies `(f v & args)` to the value stored under
;; key `k` — inside a map column via transform_values, or to a Dataset column
;; via with-column.
(defmulti update (fn [head & _] (class head)))
(defmethod update :default [expr k f & args]
  (sql/transform-values
   expr
   ;; Per entry: when the entry key equals `k`, replace the value with
   ;; (f v & args); otherwise keep the value unchanged.
   (fn [k' v] (sql/when (.equalTo (->column k') (->column k))
                (apply f v args)
                v))))
(defmethod update Dataset [dataframe k f & args]
  ;; NOTE(review): `f` receives the column key `k` itself (not a Column) as
  ;; its first argument — callers' `f` presumably resolves it; confirm.
  (dataset/with-column dataframe k (apply f k args)))
;; Pandas
;; Pandas-style quantile built on Spark SQL's percentile_approx. `percs` is
;; either a single probability or a collection of probabilities (rendered as
;; an `array(...)` literal, yielding an array result).
(defmulti quantile (fn [head & _] (class head)))
(defmethod quantile :default [col-name percs]
  (let [percs-str     (if (coll? percs)
                        (str "array(" (clojure.string/join ", " (map str percs)) ")")
                        (str percs))
        quantile-expr (str "percentile_approx(" (name col-name) ", " percs-str ")")
        quantile-name (str "quantile(" (name col-name) ", " percs-str ")")]
    (as (sql/expr quantile-expr) quantile-name)))
(defmethod quantile RelationalGroupedDataset [grouped percs col-names]
  ;; Note: unlike the other grouped aggregations here, `col-names` is a
  ;; collection argument, not variadic.
  (dataset/agg grouped (map #(quantile % percs) col-names)))
;; Interquartile range: quantile(0.75) - quantile(0.25), aliased
;; "iqr(<col>)". `interquartile-range` is a synonym.
(defmulti iqr (fn [head & _] (class head)))
(defmethod iqr :default [col-name]
  (let [q3 (quantile col-name 0.75)
        q1 (quantile col-name 0.25)]
    (as (.minus q3 q1) (str "iqr(" (name col-name) ")"))))
(defmethod iqr RelationalGroupedDataset [grouped & col-names]
  (->> col-names
       (mapcat ensure-coll)
       (map iqr)
       (dataset/agg grouped)))
(def interquartile-range iqr)
;; Median: percentile_approx at probability 0.5, aliased "median(<col>)".
(defmulti median (fn [head & _] (class head)))
(defmethod median :default [col-name]
  (let [col (name col-name)]
    (as (sql/expr (str "percentile_approx(" col " , 0.5)"))
        (str "median(" col ")"))))
(defmethod median RelationalGroupedDataset [grouped & col-names]
  (dataset/agg grouped (map median col-names)))