Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 67 additions & 57 deletions python/pyspark/ml/fpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
#

import sys
from typing import Any, Dict, Optional, TYPE_CHECKING

from pyspark import keyword_only, since
from pyspark.sql import DataFrame
from pyspark.ml.util import JavaMLWritable, JavaMLReadable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject # type: ignore[import]

__all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan"]


Expand All @@ -33,26 +37,26 @@ class _FPGrowthParams(HasPredictionCol):
.. versionadded:: 3.0.0
"""

# Name of the input column holding the transaction itemsets.
itemsCol: Param[str] = Param(
    Params._dummy(), "itemsCol", "items column name", typeConverter=TypeConverters.toString
)
# Minimal support level in [0.0, 1.0] for an itemset to be considered frequent.
minSupport: Param[float] = Param(
    Params._dummy(),
    "minSupport",
    "Minimal support level of the frequent pattern. [0.0, 1.0]. "
    + "Any pattern that appears more than (minSupport * size-of-the-dataset) "
    + "times will be output in the frequent itemsets.",
    typeConverter=TypeConverters.toFloat,
)
# Partition count for the parallel FP-growth run; unset means "reuse the
# input dataset's partitioning".
numPartitions: Param[int] = Param(
    Params._dummy(),
    "numPartitions",
    "Number of partitions (at least 1) used by parallel FP-growth. "
    + "By default the param is not set, "
    + "and partition number of the input dataset is used.",
    typeConverter=TypeConverters.toInt,
)
minConfidence = Param(
minConfidence: Param[float] = Param(
Params._dummy(),
"minConfidence",
"Minimal confidence for generating Association Rule. [0.0, 1.0]. "
Expand All @@ -61,78 +65,78 @@ class _FPGrowthParams(HasPredictionCol):
typeConverter=TypeConverters.toFloat,
)

def __init__(self, *args: Any):
    """
    Initialize the shared FP-growth params and register their defaults
    (minSupport=0.3, minConfidence=0.8, itemsCol="items",
    predictionCol="prediction").
    """
    super(_FPGrowthParams, self).__init__(*args)
    self._setDefault(
        minSupport=0.3, minConfidence=0.8, itemsCol="items", predictionCol="prediction"
    )

def getItemsCol(self) -> str:
    """
    Gets the value of itemsCol or its default value.
    """
    return self.getOrDefault(self.itemsCol)

def getMinSupport(self) -> float:
    """
    Gets the value of minSupport or its default value.
    """
    return self.getOrDefault(self.minSupport)

def getNumPartitions(self) -> int:
    """
    Gets the value of :py:attr:`numPartitions` or its default value.
    """
    return self.getOrDefault(self.numPartitions)

def getMinConfidence(self) -> float:
    """
    Gets the value of minConfidence or its default value.
    """
    return self.getOrDefault(self.minConfidence)


class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
class FPGrowthModel(JavaModel, _FPGrowthParams, JavaMLWritable, JavaMLReadable["FPGrowthModel"]):
"""
Model fitted by FPGrowth.

.. versionadded:: 2.2.0
"""

@since("3.0.0")
def setItemsCol(self, value):
def setItemsCol(self, value: str) -> "FPGrowthModel":
"""
Sets the value of :py:attr:`itemsCol`.
"""
return self._set(itemsCol=value)

@since("3.0.0")
def setMinConfidence(self, value):
def setMinConfidence(self, value: float) -> "FPGrowthModel":
"""
Sets the value of :py:attr:`minConfidence`.
"""
return self._set(minConfidence=value)

@since("3.0.0")
def setPredictionCol(self, value):
def setPredictionCol(self, value: str) -> "FPGrowthModel":
"""
Sets the value of :py:attr:`predictionCol`.
"""
return self._set(predictionCol=value)

@property  # type: ignore[misc]
@since("2.2.0")
def freqItemsets(self) -> DataFrame:
    """
    DataFrame with two columns:
    * `items` - Itemset of the same type as the input column.
    * `freq` - Frequency of the itemset (`LongType`).
    """
    return self._call_java("freqItemsets")

@property
@property # type: ignore[misc]
@since("2.2.0")
def associationRules(self):
def associationRules(self) -> DataFrame:
"""
DataFrame with four columns:
* `antecedent` - Array of the same type as the input column.
Expand All @@ -143,7 +147,9 @@ def associationRules(self):
return self._call_java("associationRules")


class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
class FPGrowth(
JavaEstimator[FPGrowthModel], _FPGrowthParams, JavaMLWritable, JavaMLReadable["FPGrowth"]
):
r"""
A parallel FP-growth algorithm to mine frequent itemsets.

Expand Down Expand Up @@ -229,16 +235,17 @@ class FPGrowth(JavaEstimator, _FPGrowthParams, JavaMLWritable, JavaMLReadable):
>>> fpm.transform(data).take(1) == model2.transform(data).take(1)
True
"""
_input_kwargs: Dict[str, Any]

@keyword_only
def __init__(
self,
*,
minSupport=0.3,
minConfidence=0.8,
itemsCol="items",
predictionCol="prediction",
numPartitions=None,
minSupport: float = 0.3,
minConfidence: float = 0.8,
itemsCol: str = "items",
predictionCol: str = "prediction",
numPartitions: Optional[int] = None,
):
"""
__init__(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
Expand All @@ -254,50 +261,50 @@ def __init__(
def setParams(
    self,
    *,
    minSupport: float = 0.3,
    minConfidence: float = 0.8,
    itemsCol: str = "items",
    predictionCol: str = "prediction",
    numPartitions: Optional[int] = None,
) -> "FPGrowth":
    """
    setParams(self, \\*, minSupport=0.3, minConfidence=0.8, itemsCol="items", \
              predictionCol="prediction", numPartitions=None)
    Sets params for FPGrowth.
    """
    # @keyword_only stashes the caller's keyword args in _input_kwargs.
    kwargs = self._input_kwargs
    return self._set(**kwargs)

def setItemsCol(self, value: str) -> "FPGrowth":
    """
    Sets the value of :py:attr:`itemsCol`.
    """
    return self._set(itemsCol=value)

def setMinSupport(self, value: float) -> "FPGrowth":
    """
    Sets the value of :py:attr:`minSupport`.
    """
    return self._set(minSupport=value)

def setNumPartitions(self, value: int) -> "FPGrowth":
    """
    Sets the value of :py:attr:`numPartitions`.
    """
    return self._set(numPartitions=value)

def setMinConfidence(self, value: float) -> "FPGrowth":
    """
    Sets the value of :py:attr:`minConfidence`.
    """
    return self._set(minConfidence=value)

def setPredictionCol(self, value: str) -> "FPGrowth":
    """
    Sets the value of :py:attr:`predictionCol`.
    """
    return self._set(predictionCol=value)

def _create_model(self, java_model: "JavaObject") -> FPGrowthModel:
    """Wrap the fitted py4j Java model in a :py:class:`FPGrowthModel`."""
    return FPGrowthModel(java_model)


Expand Down Expand Up @@ -347,7 +354,9 @@ class PrefixSpan(JavaParams):
...
"""

minSupport = Param(
_input_kwargs: Dict[str, Any]

minSupport: Param[float] = Param(
Params._dummy(),
"minSupport",
"The minimal support level of the "
Expand All @@ -356,14 +365,14 @@ class PrefixSpan(JavaParams):
typeConverter=TypeConverters.toFloat,
)

# Upper bound on the length of mined sequential patterns; must be > 0.
maxPatternLength: Param[int] = Param(
    Params._dummy(),
    "maxPatternLength",
    "The maximal length of the sequential pattern. Must be > 0.",
    typeConverter=TypeConverters.toInt,
)

maxLocalProjDBSize = Param(
maxLocalProjDBSize: Param[int] = Param(
Params._dummy(),
"maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the "
Expand All @@ -374,7 +383,7 @@ class PrefixSpan(JavaParams):
typeConverter=TypeConverters.toInt,
)

sequenceCol = Param(
sequenceCol: Param[str] = Param(
Params._dummy(),
"sequenceCol",
"The name of the sequence column in "
Expand All @@ -386,10 +395,10 @@ class PrefixSpan(JavaParams):
def __init__(
self,
*,
minSupport=0.1,
maxPatternLength=10,
maxLocalProjDBSize=32000000,
sequenceCol="sequence",
minSupport: float = 0.1,
maxPatternLength: int = 10,
maxLocalProjDBSize: int = 32000000,
sequenceCol: str = "sequence",
):
"""
__init__(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
Expand All @@ -408,11 +417,11 @@ def __init__(
def setParams(
self,
*,
minSupport=0.1,
maxPatternLength=10,
maxLocalProjDBSize=32000000,
sequenceCol="sequence",
):
minSupport: float = 0.1,
maxPatternLength: int = 10,
maxLocalProjDBSize: int = 32000000,
sequenceCol: str = "sequence",
) -> "PrefixSpan":
"""
setParams(self, \\*, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \
sequenceCol="sequence")
Expand All @@ -421,62 +430,62 @@ def setParams(
return self._set(**kwargs)

@since("3.0.0")
def setMinSupport(self, value):
def setMinSupport(self, value: float) -> "PrefixSpan":
"""
Sets the value of :py:attr:`minSupport`.
"""
return self._set(minSupport=value)

@since("3.0.0")
def getMinSupport(self):
def getMinSupport(self) -> float:
"""
Gets the value of minSupport or its default value.
"""
return self.getOrDefault(self.minSupport)

@since("3.0.0")
def setMaxPatternLength(self, value):
def setMaxPatternLength(self, value: int) -> "PrefixSpan":
"""
Sets the value of :py:attr:`maxPatternLength`.
"""
return self._set(maxPatternLength=value)

@since("3.0.0")
def getMaxPatternLength(self):
def getMaxPatternLength(self) -> int:
"""
Gets the value of maxPatternLength or its default value.
"""
return self.getOrDefault(self.maxPatternLength)

@since("3.0.0")
def setMaxLocalProjDBSize(self, value):
def setMaxLocalProjDBSize(self, value: int) -> "PrefixSpan":
"""
Sets the value of :py:attr:`maxLocalProjDBSize`.
"""
return self._set(maxLocalProjDBSize=value)

@since("3.0.0")
def getMaxLocalProjDBSize(self):
def getMaxLocalProjDBSize(self) -> int:
"""
Gets the value of maxLocalProjDBSize or its default value.
"""
return self.getOrDefault(self.maxLocalProjDBSize)

@since("3.0.0")
def setSequenceCol(self, value):
def setSequenceCol(self, value: str) -> "PrefixSpan":
"""
Sets the value of :py:attr:`sequenceCol`.
"""
return self._set(sequenceCol=value)

@since("3.0.0")
def getSequenceCol(self):
def getSequenceCol(self) -> str:
"""
Gets the value of sequenceCol or its default value.
"""
return self.getOrDefault(self.sequenceCol)

def findFrequentSequentialPatterns(self, dataset):
def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
"""
Finds the complete set of frequent sequential patterns in the input sequences of itemsets.

Expand All @@ -499,6 +508,7 @@ def findFrequentSequentialPatterns(self, dataset):
"""

self._transfer_params_to_java()
assert self._java_obj is not None
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
return DataFrame(jdf, dataset.sql_ctx)

Expand Down
Loading