[SPARK-45434][ML][CONNECT] LogisticRegression checks the training labels
### What changes were proposed in this pull request?

- Check the training labels (distinct labels must be integers in [0, numClasses)).
- Get `num_features` together with `num_rows` in a single aggregation (see the sketch below).
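
A minimal standalone sketch of the single-pass statistics query this change introduces; the local `SparkSession` and the `features`/`label` column names are illustrative assumptions, not part of the patch:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Toy dataset: an array-typed features column and integer labels.
df = spark.createDataFrame(
    [([0.0, 1.0], 0), ([1.0, 2.0], 1), ([2.0, 3.0], 1)],
    ["features", "label"],
)

# One query yields the row count, the feature dimension (array_size of the
# first row's features; requires Spark 3.4+), and the set of distinct labels.
num_rows, num_features, classes = df.select(
    sf.count(sf.lit(1)),
    sf.first(sf.array_size("features")),
    sf.collect_set("label"),
).head()

print(num_rows, num_features, sorted(classes))  # 3 2 [0, 1]
```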

### Why are the changes needed?
Training labels must be integers in [0, numClasses); previously this was only noted as a TODO in the code and never enforced.
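
For illustration only (the helper name below is hypothetical, not part of the patch), the new check requires at least two distinct labels and that every distinct label is an integer in [0, numClasses):

```python
def validate_labels(classes):
    """Mirrors the check added to LogisticRegression._fit (illustrative helper)."""
    num_classes = len(classes)
    if num_classes < 2:
        raise ValueError("Training dataset distinct labels must >= 2.")
    if any(c not in range(0, num_classes) for c in classes):
        raise ValueError("Training labels must be integers in [0, numClasses).")
    return num_classes


validate_labels([0, 1, 2])    # ok, returns 3
# validate_labels([0, 2])     # would raise: 2 is outside [0, 2)
# validate_labels([0.5, 1])   # would raise: 0.5 is not in range(0, 2)
```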

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #43246 from zhengruifeng/ml_lr_nit.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Oct 6, 2023
1 parent bb56226 commit 39d43e0
Showing 1 changed file with 11 additions and 9 deletions.
`python/pyspark/ml/connect/classification.py`:
@@ -41,7 +41,7 @@
 )
 from pyspark.ml.connect.base import Predictor, PredictionModel
 from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite
-from pyspark.sql.functions import lit, count, countDistinct
+from pyspark.sql import functions as sf
 
 import torch
 import torch.nn as torch_nn
@@ -232,18 +232,20 @@ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegressionMo
             num_train_workers
         )
 
-        # TODO: check label values are in range of [0, num_classes)
-        num_rows, num_classes = dataset.agg(
-            count(lit(1)), countDistinct(self.getLabelCol())
+        num_rows, num_features, classes = dataset.select(
+            sf.count(sf.lit(1)),
+            sf.first(sf.array_size(self.getFeaturesCol())),
+            sf.collect_set(self.getLabelCol()),
         ).head()  # type: ignore[misc]
 
-        num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
-        num_samples_per_worker = num_batches_per_worker * batch_size
-
-        num_features = len(dataset.select(self.getFeaturesCol()).head()[0])  # type: ignore[index]
-
+        num_classes = len(classes)
+        if num_classes < 2:
+            raise ValueError("Training dataset distinct labels must >= 2.")
+        if any(c not in range(0, num_classes) for c in classes):
+            raise ValueError("Training labels must be integers in [0, numClasses).")
+
+        num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size)
+        num_samples_per_worker = num_batches_per_worker * batch_size
 
         # TODO: support GPU.
         distributor = TorchDistributor(
