Skip to content

Commit dcad0da

Browse files
Migrate to modern NumPy interface (#2479)
* Migrate to modern NumPy interface Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> * Update onedal/utils/validation.py Co-authored-by: david-cortes-intel <david.cortes@intel.com> * Update sklearnex/svm/_common.py Co-authored-by: david-cortes-intel <david.cortes@intel.com> --------- Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com> Co-authored-by: david-cortes-intel <david.cortes@intel.com>
1 parent 8e3c8e7 commit dcad0da

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

onedal/utils/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _compute_class_weight(class_weight, classes, y):
9393

9494
le = LabelEncoder()
9595
y_ind = le.fit_transform(y_)
96-
if not all(np.in1d(classes, le.classes_)):
96+
if not np.isin(classes, le.classes_).all():
9797
raise ValueError("classes should have valid labels that are in y")
9898

9999
y_bin = np.bincount(y_ind).astype(np.float64)

sklearnex/svm/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _compute_balanced_class_weight(self, y):
233233

234234
le = LabelEncoder()
235235
y_ind = le.fit_transform(y_)
236-
if not all(np.in1d(classes, le.classes_)):
236+
if not np.isin(classes, le.classes_).all():
237237
raise ValueError("classes should have valid labels that are in y")
238238

239239
recip_freq = len(y_) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64))

0 commit comments

Comments
 (0)