Merged
Changes from all commits (37 commits)
453e5a9
Added multilabel option to training
Pringled Feb 14, 2025
0226494
Added multilabel option to training
Pringled Feb 14, 2025
a22d61a
Added multilabel option to training
Pringled Feb 14, 2025
68a4ae4
Added multilabel option to training
Pringled Feb 14, 2025
614069a
Added multilabel option to training
Pringled Feb 14, 2025
b50bc4a
Added multilabel option to training
Pringled Feb 14, 2025
6831bfe
Added threshold to predict
Pringled Feb 14, 2025
7bf46ea
Updated docs
Pringled Feb 14, 2025
d277e79
Updated docs
Pringled Feb 14, 2025
d28b895
Removed fallback logic
Pringled Feb 14, 2025
327ecb1
Updated docs
Pringled Feb 14, 2025
d38679f
Updated docs
Pringled Feb 14, 2025
6d80e90
Resolved feedback
Pringled Feb 14, 2025
ad8ea8d
Update model2vec/train/README.md
Pringled Feb 14, 2025
b3363ff
Resolved feedback
Pringled Feb 14, 2025
15f4873
Resolved feedback
Pringled Feb 14, 2025
06dc246
Resolved feedback
Pringled Feb 14, 2025
43de6da
Resolved feedback
Pringled Feb 14, 2025
8e944ab
add multilabel targets, fix tests (#194)
stephantul Feb 15, 2025
ff4043f
Merge branch 'main' of https://github.com/MinishLab/model2vec into ad…
Pringled Feb 15, 2025
5c9d397
Fixed bug with array conversion
Pringled Feb 15, 2025
6a4f89b
Optimized inference performance
Pringled Feb 15, 2025
3609e62
Changed classes to np array
Pringled Feb 15, 2025
b4df861
Added int as possible label type
Pringled Feb 16, 2025
ba29feb
Added int as possible label type
Pringled Feb 16, 2025
3dcddf5
Use previous logic
Pringled Feb 16, 2025
eccec80
Updated type check
Pringled Feb 16, 2025
f9037d9
Updated type check
Pringled Feb 16, 2025
2dc5b17
Updated type check logic
Pringled Feb 16, 2025
5003768
Fixed merge conflict
Pringled Feb 16, 2025
b6c00b8
Added evaluate function
Pringled Feb 16, 2025
a51f0bb
Updated evaluate, updated tests to also include int type labels
Pringled Feb 16, 2025
1c86d5e
Updated docs
Pringled Feb 16, 2025
f939695
Fixed inference tests
Pringled Feb 16, 2025
aa07183
Refactored evaluate. Made evaluate available for pipelines. Simplifie…
Pringled Feb 17, 2025
69d990a
Removed unused imports
Pringled Feb 17, 2025
065e04d
Updated classes logic
Pringled Feb 17, 2025
3 changes: 1 addition & 2 deletions README.md
@@ -112,8 +112,7 @@ ds = load_dataset("setfit/subj")
classifier.fit(ds["train"]["text"], ds["train"]["label"])

# Evaluate the classifier
predictions = classifier.predict(ds["test"]["text"])
accuracy = np.mean(np.array(predictions) == np.array(ds["test"]["label"])) * 100
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
```

For advanced usage, please refer to our [usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md).
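For context, here is a minimal end-to-end sketch of the updated README flow; the constructor call and model name are assumptions, since the hunk above only shows the fit and evaluate lines.

from datasets import load_dataset
from model2vec.train import StaticModelForClassification

# Assumed constructor and model name; only fit/evaluate appear in the hunk above.
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-8M")

# Load the dataset used in the README example
ds = load_dataset("setfit/subj")

# Train the classifier on the training split
classifier.fit(ds["train"]["text"], ds["train"]["label"])

# Evaluate the classifier: returns scikit-learn's classification report as a string
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
print(classification_report)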
4 changes: 2 additions & 2 deletions model2vec/inference/__init__.py
@@ -5,6 +5,6 @@
for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
importable(extra_dependency, _REQUIRED_EXTRA)

from model2vec.inference.model import StaticModelPipeline
from model2vec.inference.model import StaticModelPipeline, evaluate_single_or_multi_label

__all__ = ["StaticModelPipeline"]
__all__ = ["StaticModelPipeline", "evaluate_single_or_multi_label"]
61 changes: 61 additions & 0 deletions model2vec/inference/model.py
@@ -3,19 +3,24 @@
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypeVar

import huggingface_hub
import numpy as np
import skops.io
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer

from model2vec.hf_utils import _create_model_card
from model2vec.model import PathLike, StaticModel

_DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")
_DEFAULT_MODEL_FILENAME = "pipeline.skops"

LabelType = TypeVar("LabelType", list[str], list[list[str]])


class StaticModelPipeline:
def __init__(self, model: StaticModel, head: Pipeline) -> None:
@@ -169,6 +174,24 @@ def predict_proba(

return self.head.predict_proba(encoded)

def evaluate(
self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.

:param X: The texts to predict on.
:param y: The ground truth labels.
:param batch_size: The batch size.
:param threshold: The threshold for multilabel classification.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
"""
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

return report


def _load_pipeline(
folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
@@ -244,3 +267,41 @@ def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> Non
language=pipeline.model.language,
template_path="modelcards/classifier_template.md",
)


def _is_multi_label_shaped(y: LabelType) -> bool:
"""Check if the labels are in a multi-label shape."""
return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))


def evaluate_single_or_multi_label(
predictions: np.ndarray,
y: LabelType,
output_dict: bool = False,
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.

:param predictions: The predictions.
:param y: The ground truth labels.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
"""
if _is_multi_label_shaped(y):
classes = sorted(set([label for labels in y for label in labels]))
mlb = MultiLabelBinarizer(classes=classes)
y = mlb.fit_transform(y)
predictions = mlb.transform(predictions)
elif isinstance(y[0], (str, int)):
classes = sorted(set(y))

report = classification_report(
y,
predictions,
labels=np.arange(len(classes)),
target_names=[str(c) for c in classes],
output_dict=output_dict,
zero_division=0,
)

return report
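
A hedged usage sketch of the two new evaluation entry points follows; the pipeline path and the example texts and labels are placeholders, and only the signatures added in this diff are relied on.

from model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label

# Placeholder path: any folder or Hub repo containing a saved pipeline
pipeline = StaticModelPipeline.from_pretrained("path/to/saved-pipeline")

texts = ["a wonderful film", "a dull film"]
labels = ["positive", "negative"]  # placeholder ground-truth labels

# evaluate() predicts internally and returns scikit-learn's classification report
print(pipeline.evaluate(texts, labels))

# The helper can also be applied to precomputed predictions, e.g. in a custom loop
predictions = pipeline.predict(texts)
report = evaluate_single_or_multi_label(predictions=predictions, y=labels, output_dict=True)
print(report)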
19 changes: 4 additions & 15 deletions model2vec/train/README.md
@@ -44,11 +44,10 @@ test = ds["test"]
s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])

predicted = classifier.predict(test["text"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds
accuracy = np.mean([x == y for x, y in zip(predicted, test["label"])]) * 100
print(f"Achieved {accuracy} test accuracy")
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
print(classification_report)
# Achieved 91.0 test accuracy
```

@@ -95,18 +94,8 @@ Then, we can evaluate the classifier:
from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer

# Make predictions on the test set with a threshold of 0.3
predictions = classifier.predict(ds["test"]["text"], threshold=0.3)

# Evaluate the classifier
mlb = MultiLabelBinarizer(classes=classifier.classes)
y_true = mlb.fit_transform(ds["test"]["labels"])
y_pred = mlb.transform(predictions)

print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}")
print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
21 changes: 20 additions & 1 deletion model2vec/train/classifier.py
@@ -19,7 +19,7 @@
from torch import nn
from tqdm import trange

from model2vec.inference import StaticModelPipeline
from model2vec.inference import StaticModelPipeline, evaluate_single_or_multi_label
from model2vec.train.base import FinetunableStaticModel, TextDataset

logger = logging.getLogger(__name__)
@@ -227,6 +227,25 @@ def fit(
self.eval()
return self

def evaluate(
Contributor

I feel like this function is hitting the wrong abstraction level. Here are some observations:

  • The function doesn't need to know what the classifier is, because multi_output labels look different, so you can leave out the check for self.multilabel.
  • There's no need to encode labels for non-multilabel output, you can just pass the un-encoded labels.
  • This function is not available to converted pipelines, but equally applicable.

So I would refactor this into a function that takes a bunch of labels, and then, based on the type and shape of the output, returns a report. That helper is then called by this evaluate function.

So something like this:

def evaluate(self, ...):
    predictions = self.predict(...)
    return evaluate_single_or_multilabel(predictions, y)

The evaluate_single_or_multilabel then simplifies to:

def evaluate_single_or_multilabel(y, pred):
    if _is_multi_label_shaped(y):
        # Binarization etc.
        return classification_report(y_binarized, pred_binarized)
    return classification_report(y, pred)

That way you can also test these functions without needing to have models, and can also reuse them in other contexts. The consequence of all of this, however, is that evaluate simplifies to:

evaluate_single_or_multilabel(ds["label"], model.predict(ds["text"]))

So maybe having evaluate is not even necessary any more.

Member Author

I refactored the code as per your suggestions.

  • In the inference model.py there's now evaluate_single_or_multi_label and _is_multi_label_shaped
  • Both the inference and train model.py have an evaluate function that calls evaluate_single_or_multi_label with the model predictions
  • The single-label case doesn't use a LabelEncoder anymore

This way evaluate is available to both trained models and pipeline converted models. I also updated the tests to reflect this.

As for your other comment: yes, this is essentially a thin wrapper around MultiLabelBinarizer and classification_report. However, I think this is worth it. Consider the following example:

from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer

predictions = classifier.predict(X)

mlb = MultiLabelBinarizer(classes=classifier.classes)
y_true = mlb.fit_transform(ds["test"]["labels"])
y_pred = mlb.transform(predictions)

print(classification_report(y_true, y_pred, target_names=classifier.classes, zero_division=0))

Vs:

print(classifier.evaluate(X, y))

This is much easier to run and understand in my opinion, and fits in with the rest of our training code, which creates a wrapper around torch/lightning. While the function does not add much for the single-label case, it does provide a unified interface and function, and even in that case it does give a slightly nicer way to evaluate IMO:

from sklearn.metrics import classification_report

predictions = classifier.predict(X)

print(classification_report(y, predictions, target_names=classifier.classes, zero_division=0))

Vs:

print(classifier.evaluate(X, y))

self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.

:param X: The texts to predict on.
:param y: The ground truth labels.
:param batch_size: The batch size.
:param threshold: The threshold for multilabel classification.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
"""
self.eval()
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

return report

def _initialize(self, y: LabelType) -> None:
"""
Sets the output dimensionality, the classes, and initializes the head.
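As a brief usage note for the evaluate method added above: the sketch below is not part of this diff and assumes a classifier that has already been fit on multi-label targets, using the dataset from the training README.

# Assumes `classifier` is a fitted StaticModelForClassification trained on multi-label
# targets and `ds` is the dataset loaded in the training README above.
report = classifier.evaluate(
    ds["test"]["text"],    # texts to predict on
    ds["test"]["labels"],  # ground-truth label lists
    threshold=0.3,         # decision threshold for multilabel classification
    output_dict=False,     # set True to get a nested dict instead of a report string
)
print(report)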
24 changes: 16 additions & 8 deletions tests/conftest.py
@@ -83,22 +83,30 @@ def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification)
return mock_trained_pipeline.to_pipeline()


@pytest.fixture(params=[False, True], ids=["single_label", "multilabel"], scope="session")
@pytest.fixture(
params=[
(False, "single_label", "str"),
(False, "single_label", "int"),
(True, "multilabel", "str"),
(True, "multilabel", "int"),
],
ids=lambda param: f"{param[1]}_{param[2]}",
scope="session",
)
def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification:
"""Mock staticmodelforclassification."""
"""Mock StaticModelForClassification with different label formats."""
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
torch.random.manual_seed(42)
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")

X = ["dog", "cat"]
y: list[str] | list[list[str]]
if request.param:
# Use multilabel targets.
y = [["a", "b"], ["a"]]
is_multilabel, label_type = request.param[0], request.param[2]

if label_type == "str":
y = [["a", "b"], ["a"]] if is_multilabel else ["a", "b"] # type: ignore
else:
# Use singlelabel targets.
y = ["a", "b"]
y = [[0, 1], [0]] if is_multilabel else [0, 1] # type: ignore

model.fit(X, y)

40 changes: 34 additions & 6 deletions tests/test_inference.py
@@ -9,32 +9,60 @@


def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test successful initialization of StaticModelPipeline."""
"""Test successful init and predict with StaticModelPipeline."""
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = [["a", "b"]]
else:
target = [[0, 1]] # type: ignore
else:
target = ["b"]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = ["b"]
else:
target = [1] # type: ignore
assert mock_inference_pipeline.predict("dog").tolist() == target
assert mock_inference_pipeline.predict(["dog"]).tolist() == target


def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test successful initialization of StaticModelPipeline."""
"""Test successful init and predict_proba with StaticModelPipeline."""
assert mock_inference_pipeline.predict_proba("dog").argmax() == 1
assert mock_inference_pipeline.predict_proba(["dog"]).argmax(1).tolist() == [1]


def test_init_evaluate(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test successful init and evaluate with StaticModelPipeline."""
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
if isinstance(mock_inference_pipeline.classes_[0], str):
target = [["a", "b"]]
else:
target = [[0, 1]] # type: ignore
else:
if isinstance(mock_inference_pipeline.classes_[0], str):
target = ["b"]
else:
target = [1] # type: ignore
mock_inference_pipeline.evaluate("dog", target) # type: ignore


def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test saving and loading the pipeline."""
with TemporaryDirectory() as temp_dir:
mock_inference_pipeline.save_pretrained(temp_dir)
loaded = StaticModelPipeline.from_pretrained(temp_dir)
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = [["a", "b"]]
else:
target = [[0, 1]] # type: ignore
else:
target = ["b"]
if isinstance(mock_inference_pipeline.classes_[0], str):
target = ["b"]
else:
target = [1] # type: ignore
assert loaded.predict("dog").tolist() == target
assert loaded.predict(["dog"]).tolist() == target
assert loaded.predict_proba("dog").argmax() == 1
26 changes: 24 additions & 2 deletions tests/test_trainable.py
@@ -113,9 +113,15 @@ def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None:
"""Test the predict function."""
result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist()
if mock_trained_pipeline.multilabel:
assert result == [["a", "b"], ["a", "b"]]
if type(mock_trained_pipeline.classes_[0]) == str:
assert result == [["a", "b"], ["a", "b"]]
else:
assert result == [[0, 1], [0, 1]]
else:
assert result == ["b", "b"]
if type(mock_trained_pipeline.classes_[0]) == str:
assert result == ["b", "b"]
else:
assert result == [1, 1]


def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None:
@@ -146,3 +152,19 @@ def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -
assert len(b) == 2
assert len(c) == len(a)
assert len(d) == len(b)


def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
"""Test the evaluate function."""
if mock_trained_pipeline.multilabel:
if type(mock_trained_pipeline.classes_[0]) == str:
mock_trained_pipeline.evaluate(["dog cat", "dog"], [["a", "b"], ["a"]])
else:
# Ignore the type error since we don't support int labels in our typing, but the code does
mock_trained_pipeline.evaluate(["dog cat", "dog"], [[0, 1], [0]]) # type: ignore
else:
if type(mock_trained_pipeline.classes_[0]) == str:
mock_trained_pipeline.evaluate(["dog cat", "dog"], ["a", "a"])
else:
# Ignore the type error since we don't support int labels in our typing, but the code does
mock_trained_pipeline.evaluate(["dog cat", "dog"], [1, 1]) # type: ignore