/
_validation.py
606 lines (511 loc) · 21.8 KB
/
_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
"""Utilities for input validation
"""
# Adapted from imbalanced-learn
# Authors: Guillaume Lemaitre
# Zhining Liu <zhining.liu@outlook.com>
# License: MIT
# %%
LOCAL_DEBUG = False
if not LOCAL_DEBUG:
from ..exceptions import raise_isinstance_error
else: # pragma: no cover
import sys # For local test
sys.path.append("..")
from exceptions import raise_isinstance_error
import warnings
from collections import OrderedDict
from functools import wraps
from inspect import Parameter, signature
from numbers import Integral, Real
import numpy as np
from sklearn.base import clone
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors._base import KNeighborsMixin
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import type_of_target
# Valid values for the ``sampling_type`` argument of
# ``check_sampling_strategy``; "ensemble" and "bypass" are pass-through
# kinds (the strategy is returned unchanged for them).
SAMPLING_KIND = (
    "over-sampling",
    "under-sampling",
    "clean-sampling",
    "ensemble",
    "bypass",
)
# Target types supported by the samplers (see ``check_target_type``).
TARGET_KIND = ("binary", "multiclass", "multilabel-indicator")
class ArraysTransformer:
    """A class to convert sampler output arrays to their original types.

    At construction time the container type (``list`` / pandas ``DataFrame``
    / pandas ``Series`` / anything else) and the associated pandas metadata
    (``columns``, ``name``, ``dtypes``) of ``X`` and ``y`` are recorded, so
    that resampled numpy arrays can later be converted back to the original
    container types via :meth:`transform`.
    """

    def __init__(self, X, y):
        # Capture the container properties before any resampling happens.
        # FIX: internal method names were misspelled (_gets_props /
        # _transfrom_one); corrected to _get_props / _transform_one.
        self.x_props = self._get_props(X)
        self.y_props = self._get_props(y)

    def transform(self, X, y):
        """Convert ``X`` and ``y`` back to the recorded container types."""
        X = self._transform_one(X, self.x_props)
        y = self._transform_one(y, self.y_props)
        return X, y

    def _get_props(self, array):
        # Collect the type name and optional pandas metadata; the getattr
        # defaults make this a no-op for plain lists / numpy arrays.
        props = {}
        props["type"] = array.__class__.__name__
        props["columns"] = getattr(array, "columns", None)
        props["name"] = getattr(array, "name", None)
        props["dtypes"] = getattr(array, "dtypes", None)
        return props

    def _transform_one(self, array, props):
        # Rebuild the original container type; unknown types (e.g. numpy
        # ndarray) are returned unchanged.
        type_ = props["type"].lower()
        if type_ == "list":
            ret = array.tolist()
        elif type_ == "dataframe":
            import pandas as pd  # local import: pandas is an optional dependency

            ret = pd.DataFrame(array, columns=props["columns"])
            ret = ret.astype(props["dtypes"])
        elif type_ == "series":
            import pandas as pd

            ret = pd.Series(array, dtype=props["dtypes"], name=props["name"])
        else:
            ret = array
        return ret
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
    """Check the objects is consistent to be a NN.

    Several methods in imbens.sampler rely on a nearest-neighbors object.
    Accepts either an integer (a neighbor count) or an estimator implementing
    ``KNeighborsMixin``; anything else raises a type error.

    Parameters
    ----------
    nn_name : str
        The name associated to the object to raise an error if needed.

    nn_object : int or KNeighborsMixin
        The object to be checked.

    additional_neighbor : int, default=0
        Sometimes, some algorithm need an additional neighbors.

    Returns
    -------
    nn_object : KNeighborsMixin
        The k-NN object.
    """
    if isinstance(nn_object, Integral):
        # Build a fresh estimator with the requested neighbor count.
        n_neighbors = nn_object + additional_neighbor
        return NearestNeighbors(n_neighbors=n_neighbors)
    if isinstance(nn_object, KNeighborsMixin):
        # Clone so the caller's estimator is never fitted or mutated here.
        return clone(nn_object)
    raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)
def _count_class_sample(y):
unique, counts = np.unique(y, return_counts=True)
return dict(zip(unique, counts))
def check_target_type(y, indicate_one_vs_all=False):
    """Check the target types to be conform to the current samplers.

    The current samplers should be compatible with ``'binary'``,
    ``'multilabel-indicator'`` and ``'multiclass'`` targets only.

    Parameters
    ----------
    y : ndarray
        The array containing the target.

    indicate_one_vs_all : bool, default=False
        Either to indicate if the targets are encoded in a one-vs-all fashion.

    Returns
    -------
    y : ndarray
        The returned target.

    is_one_vs_all : bool, optional
        Indicate if the target was originally encoded in a one-vs-all fashion.
        Only returned if ``indicate_multilabel=True``.
    """
    target_type = type_of_target(y)
    one_vs_all = target_type == "multilabel-indicator"
    if one_vs_all:
        # Reject true multilabel targets (more than one positive per row);
        # only one-hot (binarized multiclass) encodings are accepted.
        if np.any(y.sum(axis=1) > 1):
            raise ValueError(
                "Imbalanced-learn currently supports binary, multiclass and "
                "binarized encoded multiclasss targets. Multilabel and "
                "multioutput targets are not supported."
            )
        # Collapse the one-hot encoding back to class indices.
        y = y.argmax(axis=1)
    else:
        y = column_or_1d(y)
    if indicate_one_vs_all:
        return y, one_vs_all
    return y
def _sampling_strategy_all(y, sampling_type):
"""Returns sampling target by targeting all classes."""
target_stats = _count_class_sample(y)
if sampling_type == "over-sampling":
n_sample_majority = max(target_stats.values())
sampling_strategy = {
key: n_sample_majority - value for (key, value) in target_stats.items()
}
elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
n_sample_minority = min(target_stats.values())
sampling_strategy = {key: n_sample_minority for key in target_stats.keys()}
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_majority(y, sampling_type):
"""Returns sampling target by targeting the majority class only."""
if sampling_type == "over-sampling":
raise ValueError(
"'sampling_strategy'='majority' cannot be used with over-sampler."
)
elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
target_stats = _count_class_sample(y)
class_majority = max(target_stats, key=target_stats.get)
n_sample_minority = min(target_stats.values())
sampling_strategy = {
key: n_sample_minority
for key in target_stats.keys()
if key == class_majority
}
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_not_majority(y, sampling_type):
"""Returns sampling target by targeting all classes but not the
majority."""
target_stats = _count_class_sample(y)
if sampling_type == "over-sampling":
n_sample_majority = max(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_majority - value
for (key, value) in target_stats.items()
if key != class_majority
}
elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
n_sample_minority = min(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_minority
for key in target_stats.keys()
if key != class_majority
}
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_not_minority(y, sampling_type):
"""Returns sampling target by targeting all classes but not the
minority."""
target_stats = _count_class_sample(y)
if sampling_type == "over-sampling":
n_sample_majority = max(target_stats.values())
class_minority = min(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_majority - value
for (key, value) in target_stats.items()
if key != class_minority
}
elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
n_sample_minority = min(target_stats.values())
class_minority = min(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_minority
for key in target_stats.keys()
if key != class_minority
}
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_minority(y, sampling_type):
"""Returns sampling target by targeting the minority class only."""
target_stats = _count_class_sample(y)
if sampling_type == "over-sampling":
n_sample_majority = max(target_stats.values())
class_minority = min(target_stats, key=target_stats.get)
sampling_strategy = {
key: n_sample_majority - value
for (key, value) in target_stats.items()
if key == class_minority
}
elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
raise ValueError(
"'sampling_strategy'='minority' cannot be used with"
" under-sampler and clean-sampler."
)
else:
raise NotImplementedError
return sampling_strategy
def _sampling_strategy_auto(y, sampling_type):
    """Dispatch the ``'auto'`` strategy: not-majority for over-sampling,
    not-minority for under-/clean-sampling."""
    if sampling_type == "over-sampling":
        return _sampling_strategy_not_majority(y, sampling_type)
    if sampling_type in ("under-sampling", "clean-sampling"):
        return _sampling_strategy_not_minority(y, sampling_type)
    # NOTE(review): any other sampling type falls through and returns None,
    # mirroring the original implicit behavior.
def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
"""Returns sampling target by converting the dictionary depending of the
sampling."""
target_stats = _count_class_sample(y)
# check that all keys in sampling_strategy are also in y
set_diff_sampling_strategy_target = set(sampling_strategy.keys()) - set(
target_stats.keys()
)
if len(set_diff_sampling_strategy_target) > 0:
raise ValueError(
f"The {set_diff_sampling_strategy_target} target class is/are not "
f"present in the data."
)
# check that there is no negative number
if any(n_samples < 0 for n_samples in sampling_strategy.values()):
raise ValueError(
f"The number of samples in a class cannot be negative."
f"'sampling_strategy' contains some negative value: {sampling_strategy}"
)
sampling_strategy_ = {}
if sampling_type == "over-sampling":
n_samples_majority = max(target_stats.values())
class_majority = max(target_stats, key=target_stats.get)
for class_sample, n_samples in sampling_strategy.items():
if n_samples < target_stats[class_sample]:
raise ValueError(
f"With over-sampling methods, the number"
f" of samples in a class should be greater"
f" or equal to the original number of samples."
f" Originally, there is {target_stats[class_sample]} "
f"samples and {n_samples} samples are asked."
)
if n_samples > n_samples_majority:
warnings.warn(
f"After over-sampling, the number of samples ({n_samples})"
f" in class {class_sample} will be larger than the number of"
f" samples in the majority class (class #{class_majority} ->"
f" {n_samples_majority})"
)
sampling_strategy_[class_sample] = n_samples - target_stats[class_sample]
elif sampling_type == "under-sampling":
for class_sample, n_samples in sampling_strategy.items():
if n_samples > target_stats[class_sample]:
raise ValueError(
f"With under-sampling methods, the number of"
f" samples in a class should be less or equal"
f" to the original number of samples."
f" Originally, there is {target_stats[class_sample]} "
f"samples and {n_samples} samples are asked."
)
sampling_strategy_[class_sample] = n_samples
elif sampling_type == "clean-sampling":
raise ValueError(
"'sampling_strategy' as a dict for cleaning methods is "
"not supported. Please give a list of the classes to be "
"targeted by the sampling."
)
else:
raise NotImplementedError
return sampling_strategy_
def _sampling_strategy_list(sampling_strategy, y, sampling_type):
"""With cleaning methods, sampling_strategy can be a list to target the
class of interest."""
if sampling_type != "clean-sampling":
raise ValueError(
"'sampling_strategy' cannot be a list for samplers "
"which are not cleaning methods."
)
target_stats = _count_class_sample(y)
# check that all keys in sampling_strategy are also in y
set_diff_sampling_strategy_target = set(sampling_strategy) - set(
target_stats.keys()
)
if len(set_diff_sampling_strategy_target) > 0:
raise ValueError(
f"The {set_diff_sampling_strategy_target} target class is/are not "
f"present in the data."
)
return {
class_sample: min(target_stats.values()) for class_sample in sampling_strategy
}
def _sampling_strategy_float(sampling_strategy, y, sampling_type):
    """Take a proportion of the majority (over-sampling) or minority
    (under-sampling) class in binary classification.

    Parameters
    ----------
    sampling_strategy : float
        Desired ratio of minority over majority samples after resampling.

    y : ndarray
        The target array (must be binary).

    sampling_type : str
        Either ``'over-sampling'`` or ``'under-sampling'``; cleaning
        methods reject a float strategy.

    Returns
    -------
    sampling_strategy_ : dict
        Mapping from the resampled class to the number of samples to
        generate (over-sampling) or keep (under-sampling).
    """
    type_y = type_of_target(y)
    if type_y != "binary":
        raise ValueError(
            '"sampling_strategy" can be a float only when the type '
            "of target is binary. For multi-class, use a dict."
        )
    target_stats = _count_class_sample(y)
    if sampling_type == "over-sampling":
        n_sample_majority = max(target_stats.values())
        class_majority = max(target_stats, key=target_stats.get)
        # Desired minority size is ratio * majority size; store the number
        # of *new* samples to generate for the minority class.
        sampling_strategy_ = {
            key: int(n_sample_majority * sampling_strategy - value)
            for (key, value) in target_stats.items()
            if key != class_majority
        }
        if any(n_samples <= 0 for n_samples in sampling_strategy_.values()):
            raise ValueError(
                "The specified ratio required to remove samples "
                "from the minority class while trying to "
                "generate new samples. Please increase the "
                "ratio."
            )
    elif sampling_type == "under-sampling":
        n_sample_minority = min(target_stats.values())
        class_minority = min(target_stats, key=target_stats.get)
        # Desired majority size is minority size / ratio; store the number
        # of samples to *keep* in the majority class.
        sampling_strategy_ = {
            key: int(n_sample_minority / sampling_strategy)
            for (key, value) in target_stats.items()
            if key != class_minority
        }
        if any(
            n_samples > target_stats[target]
            for target, n_samples in sampling_strategy_.items()
        ):
            raise ValueError(
                "The specified ratio required to generate new "
                "sample in the majority class while trying to "
                "remove samples. Please increase the ratio."
            )
    else:
        # BUG FIX: the original message read "do let the user specify",
        # which says the opposite of what is meant.
        raise ValueError(
            "'clean-sampling' methods do not let the user "
            "specify the sampling ratio."
        )
    return sampling_strategy_
def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
    """Sampling target validation for samplers.

    Checks that ``sampling_strategy`` is of consistent type and return a
    dictionary containing each targeted class with its corresponding
    number of sample. It is used in :class:`~imbens.base.BaseSampler`.

    Parameters
    ----------
    sampling_strategy : float, str, dict, list or callable
        Sampling information to sample the data set.

        - When ``float``, it gives the desired ratio between the number of
          samples in the minority and majority classes after resampling.
          Only available for **binary** classification; an error is raised
          with cleaning samplers.
        - When ``str``, specify the class targeted by the resampling. One
          of ``'minority'``, ``'majority'``, ``'not minority'``,
          ``'not majority'``, ``'all'``, ``'auto'``.
        - When ``dict``, the keys correspond to the targeted classes and
          the values to the desired number of samples for each class
          (under- and over-sampling only; cleaning methods need a list).
        - When ``list``, the list contains the targeted classes (cleaning
          methods only).
        - When callable, function taking ``y`` and returning a ``dict``
          formatted as above.

    y : ndarray of shape (n_samples,)
        The target array.

    sampling_type : {'over-sampling', 'under-sampling', 'clean-sampling', \
            'ensemble', 'bypass'}
        The type of sampling. ``'ensemble'`` and ``'bypass'`` return
        ``sampling_strategy`` unchanged.

    kwargs : dict
        Dictionary of additional keyword arguments to pass to
        ``sampling_strategy`` when this is a callable.

    Returns
    -------
    sampling_strategy_converted : dict
        The converted and validated sampling target. Returns a dictionary
        with the key being the class target and the value being the desired
        number of samples.
    """
    if sampling_type not in SAMPLING_KIND:
        # BUG FIX: the closing quote after {sampling_type} was missing in
        # the original message ("Got 'x instead.").
        raise ValueError(
            f"'sampling_type' should be one of {SAMPLING_KIND}. "
            f"Got '{sampling_type}' instead."
        )

    if np.unique(y).size <= 1:
        raise ValueError(
            f"The target 'y' needs to have more than 1 class. "
            f"Got {np.unique(y).size} class instead"
        )

    # Ensemble/bypass samplers forward the strategy untouched.
    if sampling_type in ("ensemble", "bypass"):
        return sampling_strategy

    if isinstance(sampling_strategy, str):
        if sampling_strategy not in SAMPLING_TARGET_KIND.keys():
            raise ValueError(
                f"When 'sampling_strategy' is a string, it needs"
                f" to be one of {SAMPLING_TARGET_KIND}. Got '{sampling_strategy}' "
                f"instead."
            )
        return OrderedDict(
            sorted(SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type).items())
        )
    elif isinstance(sampling_strategy, dict):
        return OrderedDict(
            sorted(_sampling_strategy_dict(sampling_strategy, y, sampling_type).items())
        )
    elif isinstance(sampling_strategy, list):
        return OrderedDict(
            sorted(_sampling_strategy_list(sampling_strategy, y, sampling_type).items())
        )
    elif isinstance(sampling_strategy, Real):
        if sampling_strategy <= 0 or sampling_strategy > 1:
            raise ValueError(
                f"When 'sampling_strategy' is a float, it should be "
                f"in the range (0, 1]. Got {sampling_strategy} instead."
            )
        return OrderedDict(
            sorted(
                _sampling_strategy_float(sampling_strategy, y, sampling_type).items()
            )
        )
    elif callable(sampling_strategy):
        sampling_strategy_ = sampling_strategy(y, **kwargs)
        return OrderedDict(
            sorted(
                _sampling_strategy_dict(sampling_strategy_, y, sampling_type).items()
            )
        )
    # NOTE(review): any other type falls through and returns None silently,
    # matching the historical behavior.
# Maps each string value accepted by ``check_sampling_strategy`` to the
# function that converts it into a ``{class: n_samples}`` dictionary.
SAMPLING_TARGET_KIND = {
    "minority": _sampling_strategy_minority,
    "majority": _sampling_strategy_majority,
    "not minority": _sampling_strategy_not_minority,
    "not majority": _sampling_strategy_not_majority,
    "all": _sampling_strategy_all,
    "auto": _sampling_strategy_auto,
}
def _deprecate_positional_args(f):
"""Decorator for methods that issues warnings for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the
* will issue a warning when passed as a positional argument.
Parameters
----------
f : function
function to check arguments on.
"""
sig = signature(f)
kwonly_args = []
all_args = []
for name, param in sig.parameters.items():
if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
all_args.append(name)
elif param.kind == Parameter.KEYWORD_ONLY:
kwonly_args.append(name)
@wraps(f)
def inner_f(*args, **kwargs):
extra_args = len(args) - len(all_args)
if extra_args > 0:
# ignore first 'self' argument for instance methods
args_msg = [
f"{name}={arg}"
for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
]
warnings.warn(
f"Pass {', '.join(args_msg)} as keyword args. From version 0.9 "
f"passing these as positional arguments will "
f"result in an error",
FutureWarning,
)
kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
return f(**kwargs)
return inner_f