-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathshared.ex
More file actions
208 lines (188 loc) · 6.79 KB
/
shared.ex
File metadata and controls
208 lines (188 loc) · 6.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
defmodule Mighty.Preprocessing.Shared do
  # Shared NimbleOptions schemas and custom validators for the Mighty
  # vectorizers (CountVectorizer / TfidfVectorizer). Custom validators follow
  # the NimbleOptions `{:custom, mod, fun, args}` contract: they must return
  # `{:ok, value}` or `{:error, message}`.
  @moduledoc false
  import Exterval

  vectorizer_schema_opts = [
    ngram_range: [
      type: {:custom, __MODULE__, :validate_ngram_range, []},
      default: {1, 1},
      doc: """
      The lower and upper boundary of the range of n-values for different n-grams to be extracted.
      All values of $n$ such that $min_n <= n <= max_n$ will be used.
      For example an `ngram_range` of `{1, 1}` means only unigrams, `{1, 2}` means unigrams and bigrams,
      and `{2, 2}` means only bigrams.
      """
    ],
    max_features: [
      type: {:or, [:pos_integer, nil]},
      default: nil,
      doc: """
      If not `nil`, build a vocabulary that only consider the top `max_features` ordered by term frequency across the corpus.
      This parameter is ignored if `vocabulary` is not `nil`.
      """
    ],
    min_df: [
      type: {:or, [:non_neg_integer, {:in, ~i<[0,1]>}]},
      default: 1,
      doc: """
      When building the vocabulary ignore terms that have a document frequency strictly lower than the given threshold.
      This value is also called cut-off in the literature.
      If float, the parameter represents a proportion of documents, integer absolute counts.
      This parameter is ignored if `vocabulary` is not `nil`.
      """
    ],
    max_df: [
      type: {:or, [:non_neg_integer, {:in, ~i<[0,1]>}]},
      default: 1.0,
      doc: """
      When building the vocabulary ignore terms that have a document frequency strictly higher than the given threshold.
      This value is also called cut-off in the literature.
      If float, the parameter represents a proportion of documents, integer absolute counts.
      This parameter is ignored if `vocabulary` is not `nil`.
      """
    ],
    stop_words: [
      type: {:custom, __MODULE__, :validate_stop_words, []},
      default: [],
      doc: """
      If `stop_words` is `nil`, no stop words will be used.
      If `stop_words` is `:english`, a built-in stop word list for English is used.
      If `stop_words` is a list, that list is assumed to contain stop words, all of which will be removed from the resulting tokens.
      Only applies if `analyzer` is not callable.
      """
    ],
    binary: [
      type: :boolean,
      default: false,
      doc: """
      If `true`, all non-zero counts are set to 1.
      This is useful for discrete probabilistic models that model binary events rather than integer counts.
      """
    ],
    dtype: [
      type:
        {:in,
         [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :f16, :f32, :f64, :bf16, :c64, :c128]},
      default: :f64,
      doc: """
      Type of the matrix returned by `fit_transform` or `transform`.
      """
    ],
    tokenizer: [
      type: :mfa,
      default: {String, :split, []},
      doc: """
      Provide the tokenization function to use on the corpus.
      The n-grams generated will use the tokens produced by the tokenizer.
      Must be in MFA format (e.g. `{Module, :function, arity}`).
      If `tokenizer` is `nil`, `String.split/2` is used.
      """
    ],
    preprocessor: [
      type: :mfa,
      default: {__MODULE__, :default_preprocessor, []},
      doc: """
      Override the preprocessing (string transformation) stage while preserving the tokenizing and n-grams generation steps.
      Must be in MFA format (e.g. `{Module, :function, arity}`).
      Default performs `String.downcase/1` |> `String.normalize(:nfkd)`.
      """
    ],
    vocabulary: [
      type: {:custom, __MODULE__, :validate_vocabulary, []},
      default: nil,
      doc: """
      Either a map where keys are terms and values are indices in the feature matrix, or a list of terms.
      If `vocabulary` is `nil`, a vocabulary is determined from the input documents.
      Indices in the vocabulary are expected to be unique.
      """
    ]
  ]

  tfidf_schema_opts = [
    norm: [
      type: {:in, [nil, :euclidean, :manhattan, :chebyshev]},
      default: :euclidean,
      doc: """
      Norm used to normalize term vectors. If `nil`, no normalization is applied.
      Valid options are the same as Scholar.Preprocessing.normalize/2.
      See https://hexdocs.pm/scholar/Scholar.Preprocessing.html#normalize/2
      """
    ],
    use_idf: [
      type: :boolean,
      default: true,
      doc: ~S"""
      Enable inverse-document-frequency reweighting.
      If false, $idf(t) = 1$.
      """
    ],
    smooth_idf: [
      type: :boolean,
      default: true,
      doc: """
      Smooth idf weights by adding one to document frequencies, as if an extra document was seen containing every term in the collection exactly once.
      Prevents zero divisions.
      """
    ],
    sublinear_tf: [
      type: :boolean,
      default: false,
      doc: ~S"""
      Apply sublinear tf scaling, i.e. replace tf with $1 + \log(tf)$.
      """
    ]
  ]

  # The TF-IDF schema is a strict superset of the vectorizer schema.
  @vectorizer_schema NimbleOptions.new!(vectorizer_schema_opts)
  @tfidf_schema NimbleOptions.new!(vectorizer_schema_opts ++ tfidf_schema_opts)

  # Returns the raw keyword schema for the count vectorizer options.
  def get_vectorizer_schema() do
    @vectorizer_schema.schema
  end

  # Returns the raw keyword schema for the TF-IDF vectorizer options.
  def get_tfidf_schema() do
    @tfidf_schema.schema
  end

  # Validates and normalizes the `:vocabulary` option.
  #
  # Accepts:
  #   * `nil`    — vocabulary will be learned from the corpus;
  #   * a MapSet or list of terms — normalized into a term => index map
  #     (terms sorted, then zero-indexed);
  #   * a map of term => index — indices must be unique and cover
  #     `0..map_size-1` without gaps.
  #
  # NOTE: the `%MapSet{}` clause must precede the `is_map` clause, since a
  # MapSet is itself a map.
  def validate_vocabulary(nil), do: {:ok, nil}

  def validate_vocabulary(%MapSet{} = vocabulary) do
    {:ok, vocabulary |> Enum.sort() |> Enum.with_index() |> Map.new()}
  end

  def validate_vocabulary(vocabulary) when is_list(vocabulary) do
    {:ok, vocabulary |> Enum.sort() |> Enum.with_index() |> Map.new()}
  end

  def validate_vocabulary(vocabulary) when is_map(vocabulary) do
    indices = vocabulary |> Map.values() |> MapSet.new()
    size = map_size(vocabulary)

    cond do
      # Duplicate indices collapse in the MapSet, so a size mismatch means
      # at least two terms share an index.
      MapSet.size(indices) != size ->
        {:error, "vocabulary indices must be unique"}

      # Indices must be exactly 0..size-1 with no gaps. The original code
      # checked `Map.has_key?(vocabulary, i)` (the term keys) instead of the
      # index values, and discarded the error tuples entirely.
      true ->
        case Enum.find(0..(size - 1)//1, &(not MapSet.member?(indices, &1))) do
          nil -> {:ok, vocabulary}
          i -> {:error, "Vocabulary of size #{size} missing index #{i}"}
        end
    end
  end

  def validate_vocabulary(_vocabulary) do
    {:error, "vocabulary must be of type Map, MapSet, or List"}
  end

  # Default `:preprocessor` — lowercases then applies NFKD Unicode
  # normalization (referenced by name in the schema default above).
  def default_preprocessor(text) do
    text
    |> String.downcase()
    |> String.normalize(:nfkd)
  end

  # Validates `opts` against the vectorizer schema; raises on failure.
  def validate_shared!(opts) do
    NimbleOptions.validate!(opts, @vectorizer_schema)
  end

  # Validates `opts` against the TF-IDF schema; raises on failure.
  def validate_tfidf!(opts) do
    NimbleOptions.validate!(opts, @tfidf_schema)
  end

  # Validates the `:ngram_range` option: a `{min, max}` tuple with min <= max.
  def validate_ngram_range({min, max} = value) do
    if min <= max, do: {:ok, value}, else: {:error, "min must be less than or equal to max"}
  end

  def validate_ngram_range(_value) do
    # Previously this clause returned `nil` when given a 2-tuple (a dead path,
    # since the clause above matches all 2-tuples); return the error directly.
    {:error, "ngram_range must be a tuple of length 2"}
  end

  # Validates the `:stop_words` option and normalizes it to a MapSet.
  #
  # `nil` means "no stop words" per the schema doc, so it is accepted here.
  # TODO(review): the schema doc also advertises `:english`, but no built-in
  # English list is wired up in this module — add one or drop it from the doc.
  def validate_stop_words(nil), do: {:ok, MapSet.new()}

  def validate_stop_words(stop_words) when is_list(stop_words) do
    # The original discarded the error tuple from `unless`, silently accepting
    # non-string entries; return the error instead.
    if Enum.all?(stop_words, &is_binary/1) do
      {:ok, MapSet.new(stop_words)}
    else
      {:error, "stop_words must only consist of strings"}
    end
  end

  def validate_stop_words(stop_words) do
    {:error, "stop_words must be a list, got #{inspect(stop_words)}"}
  end
end