forked from zero-one-group/geni
/
foreign_idioms.clj
216 lines (194 loc) · 8.28 KB
/
foreign_idioms.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
;; Docstring Sources:
;; https://numpy.org/doc/
;; https://pandas.pydata.org/docs/
(ns zero-one.geni.core.foreign-idioms
(:refer-clojure :exclude [replace])
(:require
[clojure.string :as string]
[potemkin :refer [import-fn]]
[zero-one.geni.core.column :as column]
[zero-one.geni.core.data-sources :as data-sources]
[zero-one.geni.core.dataset :as dataset]
[zero-one.geni.core.dataset-creation :as dataset-creation]
[zero-one.geni.core.functions :as sql]
[zero-one.geni.core.polymorphic :as polymorphic]
[zero-one.geni.core.window :as window]
[zero-one.geni.utils :as utils])
(:import
(org.apache.spark.sql Column functions)))
;; NumPy
(defn clip
  "Returns a new Column where values outside `[low, high]` are clipped to the interval edges."
  [expr low high]
  (let [col      (column/->column expr)
        new-name (format "clip(%s, %s, %s)" (.toString col) (str low) (str high))
        ;; coalesce takes the first non-null branch: low edge, high edge, else the value itself.
        clipped  (polymorphic/coalesce
                  (sql/when (column/<= col low) low)
                  (sql/when (column/<= high col) high)
                  col)]
    (polymorphic/as clipped new-name)))
(defn random-uniform
  "Returns a new Column of draws from a uniform distribution on `[low, high)`.
  Defaults to the unit interval and a random seed."
  ([] (random-uniform 0.0 1.0))
  ([low high] (random-uniform low high (rand-int Integer/MAX_VALUE)))
  ([low high seed]
   ;; Scale and shift a U(0, 1) draw; works with `low` and `high` in either order.
   (let [span   (Math/abs (- high low))
         offset (min high low)]
     (column/+ offset (column/* span (sql/rand seed))))))
(import-fn random-uniform runiform)
(import-fn random-uniform runif)
(defn random-norm
  "Returns a new Column of draws from a normal distribution with mean `mu` and
  standard deviation `sigma`. Defaults to the standard normal and a random seed."
  ([] (random-norm 0.0 1.0))
  ([mu sigma] (random-norm mu sigma (rand-int Integer/MAX_VALUE)))
  ([mu sigma seed]
   ;; Location-scale transform of a standard-normal draw.
   (column/+ mu (column/* sigma (sql/randn seed)))))
(import-fn random-norm rnorm)
(defn random-exp
  "Returns a new Column of draws from an exponential distribution with the given
  `rate`. Defaults to unit rate and a random seed."
  ([] (random-exp 1.0))
  ([rate] (random-exp rate (rand-int Integer/MAX_VALUE)))
  ([rate seed]
   ;; Inverse-CDF sampling: -ln(U) / rate with U ~ Uniform(0, 1).
   (column// (column/* (sql/log (sql/rand seed)) -1.0) rate)))
(import-fn random-exp rexp)
(defn random-int
  "Returns a new Column of random integers from `low` (inclusive) to `high` (exclusive)."
  ([] (random-int 0 (dec Integer/MAX_VALUE)))
  ([low high] (random-int low high (rand-int Integer/MAX_VALUE)))
  ([low high seed]
   (let [as-long #(column/cast % "long")
         span    (Math/abs (- high low))
         offset  (min high low)]
     ;; Casting the scaled uniform draw to long truncates it, keeping `high` exclusive.
     (column/+ (as-long offset) (as-long (column/* span (sql/rand seed)))))))
(defn random-choice
  "Returns a new Column of a random sample from a given collection of `choices`.
  `probs` gives the sampling weight of each choice; it defaults to equal weights
  and must be positive, match `choices` in length, and sum to one."
  ([choices]
   (let [n-choices (count choices)]
     (random-choice choices (repeat n-choices (/ 1.0 n-choices)))))
  ([choices probs] (random-choice choices probs (rand-int Integer/MAX_VALUE)))
  ([choices probs seed]
   (assert (and (= (count choices) (count probs))
                (every? pos? probs))
           "random-choice choices and probs must have the same length, and probs must be positive.")
   (assert (< (Math/abs (- (apply + probs) 1.0)) 1e-4)
           "random-choice probs must sum to one.")
   (let [rand-col    (column/->column (sql/rand seed))
         cum-probs   (reductions + probs)
         ;; Each branch matches when the draw falls below the cumulative probability;
         ;; coalesce selects the first (i.e. smallest) matching bucket.
         choice-cols (map (fn [choice prob]
                            (sql/when (column/< rand-col (+ prob 1e-6))
                              (column/->column choice)))
                          choices
                          cum-probs)]
     (.as (apply polymorphic/coalesce choice-cols)
          (format "choice(%s, %s)" (str choices) (str probs))))))
(import-fn random-choice rchoice)
;; Pandas
(defn value-counts
  "Returns a Dataset containing counts of unique rows.
  The resulting object will be in descending order so that the
  first element is the most frequently-occurring element."
  [dataframe]
  (let [grouped (dataset/group-by dataframe (dataset/columns dataframe))
        counted (dataset/agg grouped {:count (functions/count "*")})]
    (dataset/order-by counted (.desc (column/->column :count)))))
(defn shape
  "Returns a vector representing the dimensionality of the Dataset: [n-rows n-cols]."
  [dataframe]
  (let [n-rows (.count dataframe)
        n-cols (-> dataframe .columns count)]
    [n-rows n-cols]))
(defn nlargest
  "Return the Dataset with the first `n-rows` rows ordered by `expr` in descending order."
  [dataframe n-rows expr]
  (let [desc-col (.desc (column/->column expr))]
    (dataset/limit (dataset/order-by dataframe desc-col) n-rows)))
(defn nsmallest
  "Return the Dataset with the first `n-rows` rows ordered by `expr` in ascending order."
  [dataframe n-rows expr]
  (let [sort-col (column/->column expr)]
    (dataset/limit (dataset/order-by dataframe sort-col) n-rows)))
(defn nunique
  "Count distinct observations over all columns in the Dataset."
  [dataframe]
  (let [distinct-count (fn [col-name]
                         ;; countDistinct requires an explicit (here empty) varargs array.
                         (functions/countDistinct (column/->column col-name)
                                                  (into-array Column [])))]
    (dataset/agg-all dataframe distinct-count)))
(defn- resolve-probs
  "Normalises `num-buckets-or-probs` into an increasing seq of interior
  quantile probabilities: a collection is validated and returned as-is, an
  integer n yields (1/n, 2/n, ..., (n-1)/n)."
  [num-buckets-or-probs]
  (if-not (coll? num-buckets-or-probs)
    (let [n (double num-buckets-or-probs)]
      (for [i (range (dec num-buckets-or-probs))]
        (/ (inc i) n)))
    (do
      (assert (and (apply < num-buckets-or-probs)
                   (every? #(< 0.0 % 1.0) num-buckets-or-probs))
              "Probs array must be increasing and in the unit interval.")
      num-buckets-or-probs)))
(defn qcut
  "Returns a new Column of discretised `expr` into equal-sized buckets based
  on rank or based on sample quantiles."
  [expr num-buckets-or-probs]
  (let [probs     (resolve-probs num-buckets-or-probs)
        col       (column/->column expr)
        rank-col  (window/windowed {:window-col (sql/percent-rank) :order-by col})
        ;; Label a row when its percent rank falls within [low, high].
        bucket    (fn [low high]
                    (sql/when (column/<= low rank-col high)
                      (column/lit (format "%s[%s, %s]"
                                          (.toString col)
                                          (str low)
                                          (str high)))))
        qcut-cols (map bucket (cons 0.0 probs) (concat probs [1.0]))]
    (.as (apply polymorphic/coalesce qcut-cols)
         (format "qcut(%s, %s)" (.toString col) (str probs)))))
(defn cut
  "Returns a new Column of discretised `expr` into the intervals defined by `bins`.
  `bins` must be strictly increasing; values below the first or above the last
  bin edge fall into the unbounded edge intervals."
  [expr bins]
  ;; A bare assert reports only the failing form; give the caller an actionable message.
  (assert (apply < bins) "cut bins must be strictly increasing.")
  (let [col      (column/->column expr)
        ;; Pad the bin edges with infinities so every value lands in some interval.
        cut-cols (map (fn [low high]
                        (sql/when (column/<= low col high)
                          (column/lit (format "%s[%s, %s]"
                                              (.toString col)
                                              (str low)
                                              (str high)))))
                      (concat [Double/NEGATIVE_INFINITY] bins)
                      (concat bins [Double/POSITIVE_INFINITY]))]
    (.as (apply polymorphic/coalesce cut-cols)
         (format "cut(%s, %s)" (.toString col) (str bins)))))
(defn replace
  "Returns a new Column where `from-value-or-values` is replaced with `to-value`.
  The two-arg arity takes a map of from->to substitutions and applies each in turn."
  ([expr lookup-map]
   (reduce-kv replace expr lookup-map))
  ([expr from-value-or-values to-value]
   (let [from-values (utils/ensure-coll from-value-or-values)
         matches?    (column/isin expr from-values)]
     (sql/when matches? (column/lit to-value) expr))))
;; Tech ML
(defn- apply-options
  "Applies post-load options to `dataset`: `:column-whitelist` selects a subset
  of columns and `:n-records` caps the number of rows."
  [dataset options]
  (let [{:keys [column-whitelist n-records]} options
        selected (if column-whitelist
                   (dataset/select dataset (map name column-whitelist))
                   dataset)]
    (if n-records
      (dataset/limit selected n-records)
      selected)))
(defmulti ->dataset
  "Create a Dataset from a path or a collection of records."
  ;; Dispatch on the class of the first argument: a String selects the
  ;; path-reading method, anything else falls through to :default.
  (fn [head & _] (class head)))
;; TODO: support excel files
(defmethod ->dataset java.lang.String
  ([path]
   ;; Infers the format from the path. NOTE(review): `includes?` rather than
   ;; `ends-with?` — presumably so paths like `data.parquet/part-0001` still
   ;; match; confirm before tightening.
   (cond
     (string/includes? path ".avro") (data-sources/read-avro! path)
     (string/includes? path ".csv") (data-sources/read-csv! path)
     (string/includes? path ".json") (data-sources/read-json! path)
     (string/includes? path ".parquet") (data-sources/read-parquet! path)
     ;; Include the offending path so the error is actionable.
     :else (throw (Exception. (str "Unsupported file format: " path)))))
  ([path options] (apply-options (->dataset path) options)))
(defmethod ->dataset :default
  ;; Fallback: treat the argument as a collection of records.
  ([records] (dataset-creation/records->dataset records))
  ([records options] (apply-options (->dataset records) options)))
;; Aliases matching the Tech ML (tech.ml.dataset) naming convention.
(import-fn dataset-creation/map->dataset name-value-seq->dataset)
(import-fn dataset/select select-columns)