diff --git a/Makefile b/Makefile index ef0f2d3e..be38ccb6 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ install: uv run pre-commit install install-no-pre-commit: - uv pip install ".[dev,distill]" + uv pip install ".[dev,distill,inference,train]" uv pip install "torch<2.5.0" install-base: diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 44d0b5e9..a24fdd31 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -60,6 +60,7 @@ def _create_model_card( license: str = "mit", language: list[str] | None = None, model_name: str | None = None, + template_path: str = "modelcards/model_card_template.md", **kwargs: Any, ) -> None: """ @@ -70,11 +71,12 @@ def _create_model_card( :param license: The license to use. :param language: The language of the model. :param model_name: The name of the model to use in the Model Card. + :param template_path: The path to the template. :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.). """ folder_path = Path(folder_path) model_name = model_name or folder_path.name - template_path = Path(__file__).parent / "model_card_template.md" + full_path = Path(__file__).parent / template_path model_card_data = ModelCardData( model_name=model_name, @@ -85,7 +87,7 @@ def _create_model_card( library_name="model2vec", **kwargs, ) - model_card = ModelCard.from_template(model_card_data, template_path=template_path) + model_card = ModelCard.from_template(model_card_data, template_path=full_path) model_card.save(folder_path / "README.md") diff --git a/model2vec/inference/README.md b/model2vec/inference/README.md new file mode 100644 index 00000000..a14c3e1d --- /dev/null +++ b/model2vec/inference/README.md @@ -0,0 +1,18 @@ +# Inference + +This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines. + +If you're looking for information on how to train a model, see [here](../train/README.md). + +# Usage + +Let's assume you're using our [potion-edu classifier](https://huggingface.co/minishlab/potion-8m-edu-classifier). + +```python +from model2vec.inference import StaticModelPipeline + +classifier = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier") +label = classifier.predict("Attitudes towards cattle in the Alps: a study in letting go.") +``` + +This should just work. diff --git a/model2vec/inference/__init__.py b/model2vec/inference/__init__.py new file mode 100644 index 00000000..de94e18d --- /dev/null +++ b/model2vec/inference/__init__.py @@ -0,0 +1,10 @@ +from model2vec.utils import get_package_extras, importable + +_REQUIRED_EXTRA = "inference" + +for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA): + importable(extra_dependency, _REQUIRED_EXTRA) + +from model2vec.inference.model import StaticModelPipeline + +__all__ = ["StaticModelPipeline"] diff --git a/model2vec/inference/model.py b/model2vec/inference/model.py new file mode 100644 index 00000000..5b08dad8 --- /dev/null +++ b/model2vec/inference/model.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import re +from pathlib import Path +from tempfile import TemporaryDirectory + +import huggingface_hub +import numpy as np +import skops.io +from sklearn.pipeline import Pipeline + +from model2vec.hf_utils import _create_model_card +from model2vec.model import PathLike, StaticModel + +_DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+") +_DEFAULT_MODEL_FILENAME = "pipeline.skops" + + +class StaticModelPipeline: + def __init__(self, model: StaticModel, head: Pipeline) -> None: + """Create a pipeline with a StaticModel encoder.""" + self.model = model + self.head = head + + @classmethod + def from_pretrained( + cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False + ) -> StaticModelPipeline: + """ + Load a StaticModel from a local path or huggingface hub path. + + NOTE: if you load a private model from the huggingface hub, you need to pass a token. + + :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub + :param token: The token to use to download the pipeline from the hub. + :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `sklearn`. + :return: The loaded pipeline. + """ + model, head = _load_pipeline(path, token, trust_remote_code) + model.embedding = np.nan_to_num(model.embedding) + + return cls(model, head) + + def save_pretrained(self, path: str) -> None: + """Save the model to a folder.""" + save_pipeline(self, path) + + def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = False) -> None: + """ + Save a model to a folder, and then push that folder to the hf hub. + + :param repo_id: The id of the repository to push to. + :param token: The token to use to push to the hub. + :param private: Whether the repository should be private. + """ + from model2vec.hf_utils import push_folder_to_hub + + with TemporaryDirectory() as temp_dir: + save_pipeline(self, temp_dir) + self.model.save_pretrained(temp_dir) + push_folder_to_hub(Path(temp_dir), repo_id, private, token) + + def _predict_and_coerce_to_2d( + self, + X: list[str] | str, + show_progress_bar: bool, + max_length: int | None, + batch_size: int, + use_multiprocessing: bool, + multiprocessing_threshold: int, + ) -> np.ndarray: + """Predict the labels of the input and coerce the output to a matrix.""" + encoded = self.model.encode( + X, + show_progress_bar=show_progress_bar, + max_length=max_length, + batch_size=batch_size, + use_multiprocessing=use_multiprocessing, + multiprocessing_threshold=multiprocessing_threshold, + ) + if np.ndim(encoded) == 1: + encoded = encoded[None, :] + + return encoded + + def predict( + self, + X: list[str] | str, + show_progress_bar: bool = False, + max_length: int | None = 512, + batch_size: int = 1024, + use_multiprocessing: bool = True, + multiprocessing_threshold: int = 10_000, + ) -> np.ndarray: + """Predict the labels of the input.""" + encoded = self._predict_and_coerce_to_2d( + X, + show_progress_bar=show_progress_bar, + max_length=max_length, + batch_size=batch_size, + use_multiprocessing=use_multiprocessing, + multiprocessing_threshold=multiprocessing_threshold, + ) + + return self.head.predict(encoded) + + def predict_proba( + self, + X: list[str] | str, + show_progress_bar: bool = False, + max_length: int | None = 512, + batch_size: int = 1024, + use_multiprocessing: bool = True, + multiprocessing_threshold: int = 10_000, + ) -> np.ndarray: + """Predict the probabilities of the labels of the input.""" + encoded = self._predict_and_coerce_to_2d( + X, + show_progress_bar=show_progress_bar, + max_length=max_length, + batch_size=batch_size, + use_multiprocessing=use_multiprocessing, + multiprocessing_threshold=multiprocessing_threshold, + ) + + return self.head.predict_proba(encoded) + + +def _load_pipeline( + folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False +) -> tuple[StaticModel, Pipeline]: + """ + Load a model and an sklearn pipeline. + + This assumes the following files are present in the repo: + - `pipeline.skops`: The head of the pipeline. + - `config.json`: The configuration of the model. + - `model.safetensors`: The weights of the model. + - `tokenizer.json`: The tokenizer of the model. + + :param folder_or_repo_path: The path to the folder containing the pipeline. + :param token: The token to use to download the pipeline from the hub. If this is None, you will only + be able to load the pipeline from a local folder, public repository, or a repository that you have access to + because you are logged in. + :param trust_remote_code: Whether to trust the remote code. If this is False, + we will only load components coming from `sklearn`. If this is True, we will load all components. + If you set this to True, you are responsible for whatever happens. + :return: The encoder model and the loaded head + :raises FileNotFoundError: If the pipeline file does not exist in the folder. + :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False. + """ + folder_or_repo_path = Path(folder_or_repo_path) + model_filename = _DEFAULT_MODEL_FILENAME + if folder_or_repo_path.exists(): + head_pipeline_path = folder_or_repo_path / model_filename + if not head_pipeline_path.exists(): + raise FileNotFoundError(f"Pipeline file does not exist in {folder_or_repo_path}") + else: + head_pipeline_path = huggingface_hub.hf_hub_download( + folder_or_repo_path.as_posix(), model_filename, token=token + ) + + model = StaticModel.from_pretrained(folder_or_repo_path) + + unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path) + # If the user does not trust remote code, we should check that the unknown types are trusted. + # By default, we trust everything coming from scikit-learn. + if not trust_remote_code: + for t in unknown_types: + if not _DEFAULT_TRUST_PATTERN.match(t): + raise ValueError(f"Untrusted type {t}.") + head = skops.io.load(head_pipeline_path, trusted=unknown_types) + + return model, head + + +def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None: + """ + Save a pipeline to a folder. + + :param pipeline: The pipeline to save. + :param folder_path: The path to the folder to save the pipeline to. + """ + folder_path = Path(folder_path) + folder_path.mkdir(parents=True, exist_ok=True) + model_filename = _DEFAULT_MODEL_FILENAME + head_pipeline_path = folder_path / model_filename + skops.io.dump(pipeline.head, head_pipeline_path) + pipeline.model.save_pretrained(folder_path) + base_model_name = pipeline.model.base_model_name + if isinstance(base_model_name, list) and base_model_name: + name = base_model_name[0] + elif isinstance(base_model_name, str): + name = base_model_name + else: + name = "unknown" + _create_model_card( + folder_path, + base_model_name=name, + language=pipeline.model.language, + template_path="modelcards/classifier_template.md", + ) diff --git a/model2vec/model.py b/model2vec/model.py index 39ca64a1..62255b8c 100644 --- a/model2vec/model.py +++ b/model2vec/model.py @@ -87,7 +87,7 @@ def normalize(self) -> bool: @normalize.setter def normalize(self, value: bool) -> None: """Update the config if the value of normalize changes.""" - config_normalize = self.config.get("normalize", False) + config_normalize = self.config.get("normalize") self._normalize = value if config_normalize is not None and value != config_normalize: logger.warning( diff --git a/model2vec/modelcards/classifier_template.md b/model2vec/modelcards/classifier_template.md new file mode 100644 index 00000000..7f38d272 --- /dev/null +++ b/model2vec/modelcards/classifier_template.md @@ -0,0 +1,49 @@ +--- +{{ card_data }} +--- + +# {{ model_name }} Model Card + +This [Model2Vec](https://github.com/MinishLab/model2vec) model is a fine-tuned version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Model2Vec model. It also includes a classifier head on top. + +## Installation + +Install model2vec using pip: +``` +pip install model2vec[inference] +``` + +## Usage +Load this model using the `from_pretrained` method: +```python +from model2vec.inference import StaticModelPipeline + +# Load a pretrained Model2Vec model +model = StaticModelPipeline.from_pretrained("{{ model_name }}") + +# Predict labels +predicted = model.predict(["Example sentence"]) +``` + +## Additional Resources + +- [All Model2Vec models on the hub](https://huggingface.co/models?library=model2vec) +- [Model2Vec Repo](https://github.com/MinishLab/model2vec) +- [Model2Vec Results](https://github.com/MinishLab/model2vec?tab=readme-ov-file#results) +- [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials) + +## Library Authors + +Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled). + +## Citation + +Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work. +``` +@software{minishlab2024model2vec, + authors = {Stephan Tulkens, Thomas van Dongen}, + title = {Model2Vec: Turn any Sentence Transformer into a Small Fast Model}, + year = {2024}, + url = {https://github.com/MinishLab/model2vec}, +} +``` diff --git a/model2vec/model_card_template.md b/model2vec/modelcards/model_card_template.md similarity index 85% rename from model2vec/model_card_template.md rename to model2vec/modelcards/model_card_template.md index bfd80e41..b304d7ee 100644 --- a/model2vec/model_card_template.md +++ b/model2vec/modelcards/model_card_template.md @@ -4,7 +4,7 @@ # {{ model_name }} Model Card -This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. +This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the {{ base_model }}(https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. ## Installation diff --git a/model2vec/train/README.md b/model2vec/train/README.md new file mode 100644 index 00000000..e4c9d1ad --- /dev/null +++ b/model2vec/train/README.md @@ -0,0 +1,137 @@ +# Training + +Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html). + +# Installation + +To train, make sure you install the training extra: + +``` +pip install model2vec[training] +``` + +# Quickstart + +To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows: + +```python +from model2vec.distill import distill +from model2vec.train import StaticModelForClassification + +# From a distilled model +distilled_model = distill("baai/bge-base-en-v1.5") +classifier = StaticModelForClassification.from_static_model(distilled_model) + +# From a pre-trained model: potion is the default +classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-8m") +``` + +This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number units through some parameters on both functions. Note that the default for `from_pretrained` is [potion-base-8m](https://huggingface.co/minishlab/potion-base-8M), our best model to date. This is our recommended path if you're working with general English data. + +Now that you have created the classifier, let's just train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed. + +```python +import numpy as np +from datasets import load_dataset + +# Load the subj dataset +ds = load_dataset("setfit/subj") +train = ds["train"] +test = ds["test"] + +s = perf_counter() +classifier = classifier.fit(train["text"], train["label"]) + +predicted = classifier.predict(test["text"]) +print(f"Training took {int(perf_counter() - s)} seconds.") +# Training took 81 seconds +accuracy = np.mean([x == y for x, y in zip(predicted, test["label"])]) * 100 +print(f"Achieved {accuracy} test accuracy") +# Achieved 91.0 test accuracy +``` + +As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training. + +The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default the training loop splits the data into a train and validation split, with 90% of the data being used for training and 10% for validation. By default, it runs with early stopping on the validation set accuracy, with a patience of 5. + +Note that this model is as fast as you're used to from us: + +```python +from time import perf_counter + +s = perf_counter() +classifier.predict(test["text"]) +print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.") +# Took 67 milliseconds for 2000 instances on CPU. +``` + +# Persistence + +You can turn a classifier into a scikit-learn compatible pipeline, as follows: + +```python +pipeline = classifier.to_pipeline() +``` + +This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inferene pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization. + +If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions: + +```python +pipeline.save_pretrained(path) +pipeline.push_to_hub("my_cool/project") +``` + +Later, you can load these as follows: + +```python +from model2vec.inference import StaticModelPipeline + +pipeline = StaticModelPipeline.from_pretrained("my_cool/project") +``` + +Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk. + +# Results + +The main results are detailed in our training blogpost, but we'll do a comparison with vanilla model2vec here. In a vanilla model2vec classifier, you just put a scikit-learn `LogisticRegressionCV` on top of the model encoder. In contrast, training a `StaticModelForClassification` fine-tunes the full model, including the `StaticModel` weights. + +We use 14 classification datasets, using 1000 examples from the train set, and the full test set. No parameters were tuned on any validation set. All datasets were taken from the [Setfit organization on Hugging Face](https://huggingface.co/datasets/SetFit). + +| dataset_name | model2vec logreg | setfit | model2vec full finetune | +|:---------------------------|---------------------------------------------:|-------------------------------------------------:|--------------------------------------:| +| 20_newgroups | 0.545312 | 0.595426 | 0.555459 | +| ade | 0.715725 | 0.788789 | 0.740307 | +| ag_news | 0.860154 | 0.880142 | 0.858304 | +| amazon_counterfactual | 0.637754 | 0.873249 | 0.744288 | +| bbc | 0.955719 | 0.965823 | 0.965018 | +| emotion | 0.516267 | 0.598852 | 0.586328 | +| enron_spam | 0.951975 | 0.974498 | 0.964994 | +| hatespeech_offensive | 0.543758 | 0.659873 | 0.592587 | +| imdb | 0.839002 | 0.860037 | 0.846198 | +| massive_scenario | 0.797779 | 0.814601 | 0.822825 | +| senteval_cr | 0.743436 | 0.8526 | 0.745863 | +| sst5 | 0.290249 | 0.393179 | 0.363071 | +| student | 0.806069 | 0.889399 | 0.837581 | +| subj | 0.878394 | 0.937955 | 0.88941 | +| tweet_sentiment_extraction | 0.638664 | 0.755296 | 0.632009 | + +| | logreg | full finetune | +|:---------------------------|-----------:|---------------:| +| average | 0.714 | 0.742 | + +As you can see, full fine-tuning brings modest performance improvements in some cases, but very large ones in other cases, leading to a pretty large increase in average score. Our advice is to test both if you can use `potion-base-8m`, and to use full fine-tuning if you are starting from another base model. + +# Bring your own architecture + +Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc. + +The core functionality of the `StaticModelForClassification` is contained in a couple of functions: + +* `construct_head`: This function constructs the classifier on top of the staticmodel. For example, if you want to create a model that has LayerNorm, just subclass, and replace this function. This should be the main function to update if you want to change model behavior. +* `train_test_split`: governs the train test split before classification. +* `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `Dataloader` during training. +* `_encode`: The encoding function used in the model. +* `fit`: contains all the lightning-related fitting logic. + +The training of the model is done in a `lighting.LightningModule`, which can be modified but is very basic. diff --git a/model2vec/train/__init__.py b/model2vec/train/__init__.py new file mode 100644 index 00000000..c70f8039 --- /dev/null +++ b/model2vec/train/__init__.py @@ -0,0 +1,10 @@ +from model2vec.utils import get_package_extras, importable + +_REQUIRED_EXTRA = "train" + +for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA): + importable(extra_dependency, _REQUIRED_EXTRA) + +from model2vec.train.classifier import StaticModelForClassification + +__all__ = ["StaticModelForClassification"] diff --git a/model2vec/train/base.py b/model2vec/train/base.py new file mode 100644 index 00000000..65f4b45e --- /dev/null +++ b/model2vec/train/base.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any, TypeVar + +import numpy as np +import torch +from tokenizers import Encoding, Tokenizer +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset + +from model2vec import StaticModel + + +class FinetunableStaticModel(nn.Module): + def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None: + """ + Initialize a trainable StaticModel from a StaticModel. + + :param vectors: The embeddings of the staticmodel. + :param tokenizer: The tokenizer. + :param out_dim: The output dimension of the head. + :param pad_id: The padding id. This is set to 0 in almost all model2vec models + """ + super().__init__() + self.pad_id = pad_id + self.out_dim = out_dim + self.embed_dim = vectors.shape[1] + self.vectors = vectors + + self.embeddings = nn.Embedding.from_pretrained(vectors.clone().float(), freeze=False, padding_idx=pad_id) + self.head = self.construct_head() + self.w = self.construct_weights() + self.tokenizer = tokenizer + + def construct_weights(self) -> nn.Parameter: + """Construct the weights for the model.""" + weights = torch.zeros(len(self.vectors)) + weights[self.pad_id] = -10_000 + return nn.Parameter(weights) + + def construct_head(self) -> nn.Sequential: + """Method should be overridden for various other classes.""" + return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim)) + + @classmethod + def from_pretrained( + cls: type[ModelType], out_dim: int = 2, model_name: str = "minishlab/potion-base-8m", **kwargs: Any + ) -> ModelType: + """Load the model from a pretrained model2vec model.""" + model = StaticModel.from_pretrained(model_name) + return cls.from_static_model(model, out_dim, **kwargs) + + @classmethod + def from_static_model(cls: type[ModelType], model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType: + """Load the model from a static model.""" + model.embedding = np.nan_to_num(model.embedding) + embeddings_converted = torch.from_numpy(model.embedding) + return cls( + vectors=embeddings_converted, + pad_id=model.tokenizer.token_to_id("[PAD]"), + out_dim=out_dim, + tokenizer=model.tokenizer, + **kwargs, + ) + + def _encode(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + A forward pass and mean pooling. + + This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients + to pass through. + + :param input_ids: A 2D tensor of input ids. All input ids are have to be within bounds. + :return: The mean over the input ids, weighted by token weights. + """ + w = self.w[input_ids] + w = torch.sigmoid(w) + zeros = (input_ids != self.pad_id).float() + w = w * zeros + # Add a small epsilon to avoid division by zero + length = zeros.sum(1) + 1e-16 + embedded = self.embeddings(input_ids) + # Simulate actual mean + # Zero out the padding + embedded = torch.bmm(w[:, None, :], embedded).squeeze(1) + # embedded = embedded.sum(1) + embedded = embedded / length[:, None] + + return nn.functional.normalize(embedded) + + def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass through the mean, and a classifier layer after.""" + encoded = self._encode(input_ids) + return self.head(encoded), encoded + + def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor: + """ + Tokenize a bunch of strings into a single padded 2D tensor. + + Note that this is not used during training. + + :param texts: The texts to tokenize. + :param max_length: If this is None, the sequence lengths are truncated to 512. + :return: A 2D padded tensor + """ + encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False) + encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded] + return pad_sequence(encoded_ids, batch_first=True) + + @property + def device(self) -> str: + """Get the device of the model.""" + return self.embeddings.weight.device + + def to_static_model(self) -> StaticModel: + """Convert the model to a static model.""" + emb = self.embeddings.weight.detach().cpu().numpy() + w = torch.sigmoid(self.w).detach().cpu().numpy() + + return StaticModel(emb * w[:, None], self.tokenizer, normalize=True) + + +class TextDataset(Dataset): + def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None: + """ + A dataset of texts. + + :param tokenized_texts: The tokenized texts. Each text is a list of token ids. + :param targets: The targets. + :raises ValueError: If the number of labels does not match the number of texts. + """ + if len(targets) != len(tokenized_texts): + raise ValueError("Number of labels does not match number of texts.") + self.tokenized_texts = tokenized_texts + self.targets = targets + + def __len__(self) -> int: + """Return the length of the dataset.""" + return len(self.tokenized_texts) + + def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]: + """Gets an item.""" + return self.tokenized_texts[index], self.targets[index] + + @staticmethod + def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]: + """Collate function.""" + texts, targets = zip(*batch) + + tensors = [torch.LongTensor(x) for x in texts] + padded = pad_sequence(tensors, batch_first=True, padding_value=0) + + return padded, torch.stack(targets) + + def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader: + """Convert the dataset to a DataLoader.""" + return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size) + + +ModelType = TypeVar("ModelType", bound=FinetunableStaticModel) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py new file mode 100644 index 00000000..9ec30ee0 --- /dev/null +++ b/model2vec/train/classifier.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +import logging +from collections import Counter +from tempfile import TemporaryDirectory + +import lightning as pl +import numpy as np +import torch +from lightning.pytorch.callbacks import Callback, EarlyStopping +from lightning.pytorch.utilities.types import OptimizerLRScheduler +from sklearn.model_selection import train_test_split +from sklearn.neural_network import MLPClassifier +from sklearn.pipeline import make_pipeline +from tokenizers import Tokenizer +from torch import nn +from tqdm import trange + +from model2vec.inference import StaticModelPipeline +from model2vec.train.base import FinetunableStaticModel, TextDataset + +logger = logging.getLogger(__name__) + +_RANDOM_SEED = 42 + + +class StaticModelForClassification(FinetunableStaticModel): + def __init__( + self, + *, + vectors: torch.Tensor, + tokenizer: Tokenizer, + n_layers: int = 1, + hidden_dim: int = 512, + out_dim: int = 2, + pad_id: int = 0, + ) -> None: + """Initialize a standard classifier model.""" + self.n_layers = n_layers + self.hidden_dim = hidden_dim + # Alias: Follows scikit-learn. Set to dummy classes + self.classes_: list[str] = [str(x) for x in range(out_dim)] + super().__init__(vectors=vectors, out_dim=out_dim, pad_id=pad_id, tokenizer=tokenizer) + + @property + def classes(self) -> list[str]: + """Return all clasess in the correct order.""" + return self.classes_ + + def construct_head(self) -> nn.Sequential: + """Constructs a simple classifier head.""" + if self.n_layers == 0: + return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim)) + modules = [ + nn.Linear(self.embed_dim, self.hidden_dim), + nn.ReLU(), + ] + for _ in range(self.n_layers - 1): + modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()]) + modules.extend([nn.Linear(self.hidden_dim, self.out_dim)]) + + for module in modules: + if isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight) + nn.init.zeros_(module.bias) + + return nn.Sequential(*modules) + + def predict(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: + """Predict a class for a set of texts.""" + pred: list[str] = [] + for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): + logits = self._predict_single_batch(X[batch : batch + batch_size]) + pred.extend([self.classes[idx] for idx in logits.argmax(1)]) + + return np.asarray(pred) + + @torch.no_grad() + def _predict_single_batch(self, X: list[str]) -> torch.Tensor: + input_ids = self.tokenize(X) + vectors, _ = self.forward(input_ids) + return vectors + + def predict_proba(self, X: list[str], show_progress_bar: bool = False, batch_size: int = 1024) -> np.ndarray: + """Predict the probability of each class.""" + pred: list[np.ndarray] = [] + for batch in trange(0, len(X), batch_size, disable=not show_progress_bar): + logits = self._predict_single_batch(X[batch : batch + batch_size]) + pred.append(torch.softmax(logits, dim=1).numpy()) + + return np.concatenate(pred) + + def fit( + self, + X: list[str], + y: list[str], + learning_rate: float = 1e-3, + batch_size: int | None = None, + early_stopping_patience: int | None = 5, + test_size: float = 0.1, + device: str = "auto", + ) -> StaticModelForClassification: + """ + Fit a model. + + This function creates a Lightning Trainer object and fits the model to the data. + We use early stopping. After training, the weigths of the best model are loaded back into the model. + + This function seeds everything with a seed of 42, so the results are reproducible. + It also splits the data into a train and validation set, again with a random seed. + + :param X: The texts to train on. + :param y: The labels to train on. + :param learning_rate: The learning rate. + :param batch_size: The batch size. + If this is None, a good batch size is chosen automatically. + :param early_stopping_patience: The patience for early stopping. + If this is None, early stopping is disabled. + :param test_size: The test size for the train-test split. + :param device: The device to train on. If this is "auto", the device is chosen automatically. + :return: The fitted model. + """ + pl.seed_everything(_RANDOM_SEED) + logger.info("Re-initializing model.") + self._initialize(y) + + train_texts, validation_texts, train_labels, validation_labels = self._train_test_split( + X, y, test_size=test_size + ) + + if batch_size is None: + # Set to a multiple of 32 + base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16)) + batch_size = int(base_number * 32) + logger.info("Batch size automatically set to %d.", batch_size) + + logger.info("Preparing train dataset.") + train_dataset = self._prepare_dataset(train_texts, train_labels) + logger.info("Preparing validation dataset.") + val_dataset = self._prepare_dataset(validation_texts, validation_labels) + + c = _ClassifierLightningModule(self, learning_rate=learning_rate) + + n_train_batches = len(train_dataset) // batch_size + callbacks: list[Callback] = [] + if early_stopping_patience is not None: + callback = EarlyStopping(monitor="val_accuracy", mode="max", patience=early_stopping_patience) + callbacks.append(callback) + + # If the dataset is small, we check the validation set every epoch. + # If the dataset is large, we check the validation set every 250 batches. + if n_train_batches < 250: + val_check_interval = None + check_val_every_epoch = 1 + else: + val_check_interval = max(250, 2 * len(val_dataset) // batch_size) + check_val_every_epoch = None + + with TemporaryDirectory() as tempdir: + trainer = pl.Trainer( + max_epochs=500, + callbacks=callbacks, + val_check_interval=val_check_interval, + check_val_every_n_epoch=check_val_every_epoch, + accelerator=device, + default_root_dir=tempdir, + ) + + trainer.fit( + c, + train_dataloaders=train_dataset.to_dataloader(shuffle=True, batch_size=batch_size), + val_dataloaders=val_dataset.to_dataloader(shuffle=False, batch_size=batch_size), + ) + best_model_path = trainer.checkpoint_callback.best_model_path # type: ignore + best_model_weights = torch.load(best_model_path, weights_only=True) + + state_dict = {} + for weight_name, weight in best_model_weights["state_dict"].items(): + state_dict[weight_name.removeprefix("model.")] = weight + + self.load_state_dict(state_dict) + self.eval() + + return self + + def _initialize(self, y: list[str]) -> None: + """Sets the out dimensionality, the classes and initializes the head.""" + classes = sorted(set(y)) + self.classes_ = classes + + if len(self.classes) != self.out_dim: + self.out_dim = len(self.classes) + + self.head = self.construct_head() + self.embeddings = nn.Embedding.from_pretrained(self.vectors.clone(), freeze=False, padding_idx=self.pad_id) + self.w = self.construct_weights() + self.train() + + def _prepare_dataset(self, X: list[str], y: list[str], max_length: int = 512) -> TextDataset: + """Prepare a dataset.""" + # This is a speed optimization. + # assumes a mean token length of 10, which is really high, so safe. + truncate_length = max_length * 10 + X = [x[:truncate_length] for x in X] + tokenized: list[list[int]] = [ + encoding.ids[:max_length] for encoding in self.tokenizer.encode_batch_fast(X, add_special_tokens=False) + ] + labels_tensor = torch.Tensor([self.classes.index(label) for label in y]).long() + + return TextDataset(tokenized, labels_tensor) + + @staticmethod + def _train_test_split( + X: list[str], y: list[str], test_size: float + ) -> tuple[list[str], list[str], list[str], list[str]]: + """Split the data.""" + label_counts = Counter(y) + if min(label_counts.values()) < 2: + logger.info("Some classes have less than 2 samples. Stratification is disabled.") + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True) + return train_test_split(X, y, test_size=test_size, random_state=42, shuffle=True, stratify=y) + + def to_pipeline(self) -> StaticModelPipeline: + """Convert the model to an sklearn pipeline.""" + static_model = self.to_static_model() + + random_state = np.random.RandomState(_RANDOM_SEED) + n_items = len(self.classes) + X = random_state.randn(n_items, static_model.dim) + y = self.classes + + converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers)) + converted.fit(X, y) + mlp_head: MLPClassifier = converted[-1] + + for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]): + mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T + mlp_head.intercepts_[index] = layer.bias.detach().cpu().numpy() + # Below is necessary to ensure that the converted model works correctly. + # In scikit-learn, a binary classifier only has a single vector of output coefficients + # and a single intercept. We use two output vectors. + # To convert correctly, we need to set the outputs correctly, and fix the activation function. + # Make sure n_outputs is set to > 1. + mlp_head.n_outputs_ = self.out_dim + # Set to softmax + mlp_head.out_activation_ = "softmax" + + return StaticModelPipeline(static_model, converted) + + +class _ClassifierLightningModule(pl.LightningModule): + def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None: + """Initialize the lightningmodule.""" + super().__init__() + self.model = model + self.learning_rate = learning_rate + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Simple forward pass.""" + return self.model(x) + + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Simple training step using cross entropy loss.""" + x, y = batch + head_out, _ = self.model(x) + loss = nn.functional.cross_entropy(head_out, y).mean() + + self.log("train_loss", loss) + return loss + + def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Simple validation step using cross entropy loss and accuracy.""" + x, y = batch + head_out, _ = self.model(x) + loss = nn.functional.cross_entropy(head_out, y).mean() + accuracy = (head_out.argmax(1) == y).float().mean() + + self.log("val_loss", loss) + self.log("val_accuracy", accuracy, prog_bar=True) + + return loss + + def configure_optimizers(self) -> OptimizerLRScheduler: + """Simple Adam optimizer.""" + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=0.5, + patience=3, + verbose=True, + min_lr=1e-6, + threshold=0.03, + threshold_mode="rel", + ) + + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} diff --git a/model2vec/utils.py b/model2vec/utils.py index 235c3d91..f11f079b 100644 --- a/model2vec/utils.py +++ b/model2vec/utils.py @@ -88,7 +88,7 @@ def importable(module: str, extra: str) -> None: import_module(module) except ImportError: raise ImportError( - f"`{module}`, is required. Please reinstall model2vec with the `distill` extra. `pip install model2vec[{extra}]`" + f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. `pip install model2vec[{extra}]`" ) diff --git a/pyproject.toml b/pyproject.toml index fbf11576..c1c56d3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ packages = ["model2vec"] include-package-data = true [tool.setuptools.package-data] -model2vec = ["assets/model_card_template.md"] +model2vec = ["assets/modelcards/model_card_template.md", "assets/modelcards/classifier_template.md"] [project.optional-dependencies] dev = [ @@ -54,8 +54,12 @@ dev = [ "pytest-cov", "ruff", ] + distill = ["torch", "transformers", "scikit-learn"] onnx = ["onnx", "torch"] +# train also installs inference +train = ["torch", "lightning", "scikit-learn", "skops"] +inference = ["scikit-learn", "skops"] [project.urls] "Homepage" = "https://github.com/MinishLab" diff --git a/tests/conftest.py b/tests/conftest.py index ced1abca..62203292 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,16 +5,22 @@ import numpy as np import pytest import torch +from sklearn.neural_network import MLPClassifier +from sklearn.pipeline import make_pipeline from tokenizers import Tokenizer from tokenizers.models import WordLevel from tokenizers.pre_tokenizers import Whitespace from transformers import AutoModel, AutoTokenizer +from model2vec.inference import StaticModelPipeline +from model2vec.model import StaticModel +from model2vec.train import StaticModelForClassification -@pytest.fixture + +@pytest.fixture(scope="session") def mock_tokenizer() -> Tokenizer: """Create a mock tokenizer.""" - vocab = ["word1", "word2", "word3", "[UNK]", "[PAD]"] + vocab = ["[PAD]", "word1", "word2", "word3", "[UNK]"] unk_token = "[UNK]" model = WordLevel(vocab={word: idx for idx, word in enumerate(vocab)}, unk_token=unk_token) @@ -62,7 +68,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return MockPreTrainedModel() -@pytest.fixture +@pytest.fixture(scope="session") def mock_vectors() -> np.ndarray: """Create mock vectors.""" return np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.0, 0.0], [0.0, 0.0]]) @@ -72,3 +78,21 @@ def mock_vectors() -> np.ndarray: def mock_config() -> dict[str, str]: """Create a mock config.""" return {"some_config": "value"} + + +@pytest.fixture(scope="session") +def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification) -> StaticModelPipeline: + """Mock pipeline.""" + return mock_trained_pipeline.to_pipeline() + + +@pytest.fixture(scope="session") +def mock_trained_pipeline() -> StaticModelForClassification: + """Mock staticmodelforclassification.""" + tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer + torch.random.manual_seed(42) + vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) + s = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") + s.fit(["dog", "cat"], ["a", "b"], device="cpu") + + return s diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 00000000..9f4618df --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,50 @@ +import os +import re +from tempfile import TemporaryDirectory +from unittest.mock import patch + +import pytest + +from model2vec.inference import StaticModelPipeline + + +def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: + """Test successful initialization of StaticModelPipeline.""" + assert mock_inference_pipeline.predict("dog").tolist() == ["b"] + assert mock_inference_pipeline.predict(["dog"]).tolist() == ["b"] + + +def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None: + """Test successful initialization of StaticModelPipeline.""" + assert mock_inference_pipeline.predict_proba("dog").argmax() == 1 + assert mock_inference_pipeline.predict_proba(["dog"]).argmax(1).tolist() == [1] + + +def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None: + """Test saving and loading the pipeline.""" + with TemporaryDirectory() as temp_dir: + mock_inference_pipeline.save_pretrained(temp_dir) + loaded = StaticModelPipeline.from_pretrained(temp_dir) + assert loaded.predict("dog") == ["b"] + assert loaded.predict(["dog"]) == ["b"] + assert loaded.predict_proba("dog").argmax() == 1 + assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1] + + +@patch("model2vec.inference.model._DEFAULT_TRUST_PATTERN", re.compile("torch")) +def test_roundtrip_save_mock_trust_pattern(mock_inference_pipeline: StaticModelPipeline) -> None: + """Test saving and loading the pipeline.""" + with TemporaryDirectory() as temp_dir: + mock_inference_pipeline.save_pretrained(temp_dir) + with pytest.raises(ValueError): + StaticModelPipeline.from_pretrained(temp_dir) + + +def test_roundtrip_save_file_gone(mock_inference_pipeline: StaticModelPipeline) -> None: + """Test saving and loading the pipeline.""" + with TemporaryDirectory() as temp_dir: + mock_inference_pipeline.save_pretrained(temp_dir) + # Rename the file to abc.pipeline, so that it looks like it was downloaded from the hub + os.unlink(os.path.join(temp_dir, "pipeline.skops")) + with pytest.raises(FileNotFoundError): + StaticModelPipeline.from_pretrained(temp_dir) diff --git a/tests/test_trainable.py b/tests/test_trainable.py new file mode 100644 index 00000000..dc9bb811 --- /dev/null +++ b/tests/test_trainable.py @@ -0,0 +1,145 @@ +from tempfile import TemporaryDirectory + +import numpy as np +import pytest +import torch +from tokenizers import Tokenizer + +from model2vec.model import StaticModel +from model2vec.train import StaticModelForClassification +from model2vec.train.base import FinetunableStaticModel, TextDataset + + +@pytest.mark.parametrize("n_layers", [0, 1, 2, 3]) +def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: + """Test successful initialization of StaticModelForClassification.""" + vectors_torched = torch.from_numpy(mock_vectors) + s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + assert s.classes == s.classes_ + assert s.classes == ["0", "1"] + + head = s.construct_head() + assert head[0].in_features == mock_vectors.shape[1] + head = s.construct_head() + assert head[0].in_features == mock_vectors.shape[1] + assert head[-1].out_features == 2 + + +def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: + """Test successful initialization of the base class.""" + vectors_torched = torch.from_numpy(mock_vectors) + s = FinetunableStaticModel(vectors=vectors_torched, tokenizer=mock_tokenizer) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + head = s.construct_head() + assert head[0].in_features == mock_vectors.shape[1] + + +def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: + """Test initializion from a static model.""" + model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) + s = FinetunableStaticModel.from_static_model(model) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + with TemporaryDirectory() as temp_dir: + model.save_pretrained(temp_dir) + s = FinetunableStaticModel.from_pretrained(model_name=temp_dir) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + +def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: + """Test initializion from a static model.""" + model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) + s = StaticModelForClassification.from_static_model(model) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + with TemporaryDirectory() as temp_dir: + model.save_pretrained(temp_dir) + s = StaticModelForClassification.from_pretrained(model_name=temp_dir) + assert s.vectors.shape == mock_vectors.shape + assert s.w.shape[0] == mock_vectors.shape[0] + + +def test_encode(mock_trained_pipeline: StaticModelForClassification) -> None: + """Test the encode function.""" + result = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()) + assert result.shape == (2, 12) + assert torch.allclose(result[0], result[1]) + + +def test_tokenize(mock_trained_pipeline: StaticModelForClassification) -> None: + """Test the encode function.""" + result = mock_trained_pipeline.tokenize(["dog dog", "cat"]) + assert result.shape == torch.Size([2, 2]) + assert result[1, 1] == 0 + + +def test_device(mock_trained_pipeline: StaticModelForClassification) -> None: + """Get the device.""" + assert mock_trained_pipeline.device == torch.device(type="cpu") # type: ignore # False positive + assert mock_trained_pipeline.device == mock_trained_pipeline.w.device + + +def test_conversion(mock_trained_pipeline: StaticModelForClassification) -> None: + """Test the conversion to numpy.""" + staticmodel = mock_trained_pipeline.to_static_model() + with torch.no_grad(): + result_1 = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()).numpy() + result_2 = staticmodel.embedding[[[0, 1], [1, 0]]].mean(0) + result_2 /= np.linalg.norm(result_2, axis=1, keepdims=True) + + assert np.allclose(result_1, result_2) + + +def test_textdataset_init() -> None: + """Test the textdataset init.""" + dataset = TextDataset([[0], [1]], torch.arange(2)) + assert len(dataset) == 2 + + +def test_textdataset_init_incorrect() -> None: + """Test the textdataset init.""" + with pytest.raises(ValueError): + TextDataset([[0]], torch.arange(2)) + + +def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None: + """Test the predict function.""" + result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() + assert result == ["b", "b"] + + +def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None: + """Test the predict function.""" + result = mock_trained_pipeline.predict_proba(["dog cat", "dog"]) + assert result.shape == (2, 2) + + +def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification) -> None: + """Convert a model to a pipeline.""" + mock_trained_pipeline.eval() + pipeline = mock_trained_pipeline.to_pipeline() + encoded_pipeline = pipeline.model.encode(["dog cat", "dog"]) + encoded_model = mock_trained_pipeline(mock_trained_pipeline.tokenize(["dog cat", "dog"]))[1].detach().numpy() + assert np.allclose(encoded_pipeline, encoded_model) + a = pipeline.predict(["dog cat", "dog"]).tolist() + b = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() + assert a == b + p1 = pipeline.predict_proba(["dog cat", "dog"]) + p2 = mock_trained_pipeline.predict_proba(["dog cat", "dog"]) + assert np.allclose(p1, p2) + + +def test_train_test_split() -> None: + """Test the train test split function.""" + a, b, c, d = StaticModelForClassification._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) + assert len(a) == 2 + assert len(b) == 2 + assert len(c) == len(a) + assert len(d) == len(b) diff --git a/tutorials/README.md b/tutorials/README.md index 6d9c6805..2874196a 100644 --- a/tutorials/README.md +++ b/tutorials/README.md @@ -13,3 +13,4 @@ This is a list of all our tutorials. They are all self-contained ipython noteboo |--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| | **Recipe search** 🍝 | Learn how to do lightning-fast semantic search by distilling a small model. Compare a really tiny model to a larger with one with a better vocabulary. Learn what Fattoush is (delicious). | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/recipe_search.ipynb) | | **Semantic chunking** 🧩 | Learn how to chunk your text into meaningful segments with [Chonkie](https://github.com/bhavnicksm/chonkie) at lightning-speed. Efficiently query your chunks with [Vicinity](https://github.com/MinishLab/vicinity). | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/semantic_chunking.ipynb) | +| **Training a classifier** 🧩 | Learn how to train a classifier using model2vec. Lightning fast, great performance, especially on small datasets | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/train_classifier.ipynb) | diff --git a/tutorials/train_classifier.ipynb b/tutorials/train_classifier.ipynb new file mode 100644 index 00000000..988007d6 --- /dev/null +++ b/tutorials/train_classifier.ipynb @@ -0,0 +1,806 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a classifier using model2vec\n", + "\n", + "Model2Vec supports built-in classifier training with an easy, scikit-learn-based syntax. Just give the model your data in `.fit`, and you'll have a trained model!\n", + "\n", + "How it works:\n", + "* We load a base `StaticModel` using as a torch module. By default we use [potion-base-8m](https://huggingface.co/minishlab/potion-base-8M).\n", + "* We add a one-layer MLP with 512 hidden units and `ReLU` activation as a head.\n", + "* We train the model using cross-entropy, using [`pytorch-lightning`](https://lightning.ai/docs/pytorch/stable/) as a training framework.\n", + "\n", + "After training, you can export the model using regular torch tools, such as `torch.save` and `torch.load`, or you can export the model to a `scikit-learn` pipeline. The latter option leads to a really small footprint during inference, as there is no longer a need to use `torch`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n", + "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 4ms\u001b[0m\u001b[0m\n", + "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n", + "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 8ms\u001b[0m\u001b[0m\n", + "\u001b[2mUsing Python 3.11.4 environment at: /Users/stephantulkens/Documents/GitHub/model2vec/.venv\u001b[0m\n", + "\u001b[2mAudited \u001b[1m1 package\u001b[0m \u001b[2min 3ms\u001b[0m\u001b[0m\n" + ] + } + ], + "source": [ + "# Install the necessary libraries\n", + "!uv pip install \"model2vec[train,inference]\"\n", + "!uv pip install \"datasets\"\n", + "!uv pip install \"scikit-learn\"\n", + "\n", + "# Import the necessary libraries\n", + "from model2vec.train import StaticModelForClassification\n", + "from model2vec.inference import StaticModelPipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To demonstrate how to train a model, we'll be using the `20_newsgroups` dataset, which contains posts from 1 of 20 newsgroups." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Repo card metadata block was not found. Setting CardData to empty.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text', 'label', 'label_text'],\n", + " num_rows: 11314\n", + " })\n", + " test: Dataset({\n", + " features: ['text', 'label', 'label_text'],\n", + " num_rows: 7532\n", + " })\n", + "})\n" + ] + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"setfit/20_newsgroups\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's take a look at the first five training samples:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TEXT: I was wondering if anyone out there could enlighten me on this car I saw\n", + "the other day. It was a 2-door sports car, looked to be from the late 60s/\n", + "early 70s. It was called a Bricklin. The doors were really small. In addition,\n", + "the front bumper was separate from the rest of the body. This is \n", + "all I know. If anyone can tellme a model name, engine specs, years\n", + "of production, where this car is made, history, or whatever info you\n", + "have on this funky looking car, please e-mail. LABEL: rec.autos\n", + "TEXT: A fair number of brave souls who upgraded their SI clock oscillator have\n", + "shared their experiences for this poll. Please send a brief message detailing\n", + "your experiences with the procedure. Top speed attained, CPU rated speed,\n", + "add on cards and adapters, heat sinks, hour of usage per day, floppy disk\n", + "functionality with 800 and 1.4 m floppies are especially requested.\n", + "\n", + "I will be summarizing in the next two days, so please add to the network\n", + "knowledge base if you have done the clock upgrade and haven't answered this\n", + "poll. Thanks. LABEL: comp.sys.mac.hardware\n", + "TEXT: well folks, my mac plus finally gave up the ghost this weekend after\n", + "starting life as a 512k way back in 1985. sooo, i'm in the market for a\n", + "new machine a bit sooner than i intended to be...\n", + "\n", + "i'm looking into picking up a powerbook 160 or maybe 180 and have a bunch\n", + "of questions that (hopefully) somebody can answer:\n", + "\n", + "* does anybody know any dirt on when the next round of powerbook\n", + "introductions are expected? i'd heard the 185c was supposed to make an\n", + "appearence \"this summer\" but haven't heard anymore on it - and since i\n", + "don't have access to macleak, i was wondering if anybody out there had\n", + "more info...\n", + "\n", + "* has anybody heard rumors about price drops to the powerbook line like the\n", + "ones the duo's just went through recently?\n", + "\n", + "* what's the impression of the display on the 180? i could probably swing\n", + "a 180 if i got the 80Mb disk rather than the 120, but i don't really have\n", + "a feel for how much \"better\" the display is (yea, it looks great in the\n", + "store, but is that all \"wow\" or is it really that good?). could i solicit\n", + "some opinions of people who use the 160 and 180 day-to-day on if its worth\n", + "taking the disk size and money hit to get the active display? (i realize\n", + "this is a real subjective question, but i've only played around with the\n", + "machines in a computer store breifly and figured the opinions of somebody\n", + "who actually uses the machine daily might prove helpful).\n", + "\n", + "* how well does hellcats perform? ;)\n", + "\n", + "thanks a bunch in advance for any info - if you could email, i'll post a\n", + "summary (news reading time is at a premium with finals just around the\n", + "corner... :( )\n", + "--\n", + "Tom Willis \\ twillis@ecn.purdue.edu \\ Purdue Electrical Engineering LABEL: comp.sys.mac.hardware\n", + "TEXT: \n", + "Do you have Weitek's address/phone number? I'd like to get some information\n", + "about this chip.\n", + " LABEL: comp.graphics\n", + "TEXT: From article , by tombaker@world.std.com (Tom A Baker):\n", + "\n", + "\n", + "My understanding is that the 'expected errors' are basically\n", + "known bugs in the warning system software - things are checked\n", + "that don't have the right values in yet because they aren't\n", + "set till after launch, and suchlike. Rather than fix the code\n", + "and possibly introduce new bugs, they just tell the crew\n", + "'ok, if you see a warning no. 213 before liftoff, ignore it'. LABEL: sci.space\n" + ] + } + ], + "source": [ + "# First 5 training samples:\n", + "for record in dataset[\"train\"].to_list()[:5]:\n", + " print(f\"TEXT: {record['text']} LABEL: {record['label_text']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StaticModelForClassification(\n", + " (embeddings): Embedding(29528, 256, padding_idx=0)\n", + " (head): Sequential(\n", + " (0): Linear(in_features=256, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=512, out_features=2, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "# Define the staticmodel\n", + "model = StaticModelForClassification.from_pretrained()\n", + "# Optional arguments:\n", + "# model_name: the name of the base model (defaults to potion-base-8m)\n", + "# n_layers: the number of layers in the MLP (defaults to 1)\n", + "# hidden_dim: the number of hidden units (defaults to 512)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's train the model on a subset of examples. We pick the first 1000 examples to train on." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n", + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/torch/optim/lr_scheduler.py:60: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n", + " warnings.warn(\n", + "\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------\n", + "0 | model | StaticModelForClassification | 7.7 M | train\n", + "---------------------------------------------------------------\n", + "7.7 M Trainable params\n", + "0 Non-trainable params\n", + "7.7 M Total params\n", + "30.922 Total estimated model params size (MB)\n", + "6 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2351ba8c0b53458fb680e8d29e0f0a6c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00=0.20" }, { name = "torch", marker = "extra == 'distill'" }, { name = "torch", marker = "extra == 'onnx'" }, + { name = "torch", marker = "extra == 'train'" }, { name = "tqdm" }, { name = "transformers", marker = "extra == 'distill'" }, ] @@ -604,6 +875,93 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, ] +[[package]] +name = "multidict" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/be/504b89a5e9ca731cd47487e91c469064f8ae5af93b7259758dcfc2b9c848/multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a", size = 64002 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/68/259dee7fd14cf56a17c554125e534f6274c2860159692a414d0b402b9a6d/multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60", size = 48628 }, + { url = "https://files.pythonhosted.org/packages/50/79/53ba256069fe5386a4a9e80d4e12857ced9de295baf3e20c68cdda746e04/multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1", size = 29327 }, + { url = "https://files.pythonhosted.org/packages/ff/10/71f1379b05b196dae749b5ac062e87273e3f11634f447ebac12a571d90ae/multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53", size = 29689 }, + { url = "https://files.pythonhosted.org/packages/71/45/70bac4f87438ded36ad4793793c0095de6572d433d98575a5752629ef549/multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5", size = 126639 }, + { url = "https://files.pythonhosted.org/packages/80/cf/17f35b3b9509b4959303c05379c4bfb0d7dd05c3306039fc79cf035bbac0/multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581", size = 134315 }, + { url = "https://files.pythonhosted.org/packages/ef/1f/652d70ab5effb33c031510a3503d4d6efc5ec93153562f1ee0acdc895a57/multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56", size = 129471 }, + { url = "https://files.pythonhosted.org/packages/a6/64/2dd6c4c681688c0165dea3975a6a4eab4944ea30f35000f8b8af1df3148c/multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429", size = 124585 }, + { url = "https://files.pythonhosted.org/packages/87/56/e6ee5459894c7e554b57ba88f7257dc3c3d2d379cb15baaa1e265b8c6165/multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748", size = 116957 }, + { url = "https://files.pythonhosted.org/packages/36/9e/616ce5e8d375c24b84f14fc263c7ef1d8d5e8ef529dbc0f1df8ce71bb5b8/multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db", size = 128609 }, + { url = "https://files.pythonhosted.org/packages/8c/4f/4783e48a38495d000f2124020dc96bacc806a4340345211b1ab6175a6cb4/multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056", size = 123016 }, + { url = "https://files.pythonhosted.org/packages/3e/b3/4950551ab8fc39862ba5e9907dc821f896aa829b4524b4deefd3e12945ab/multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76", size = 133542 }, + { url = "https://files.pythonhosted.org/packages/96/4d/f0ce6ac9914168a2a71df117935bb1f1781916acdecbb43285e225b484b8/multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160", size = 130163 }, + { url = "https://files.pythonhosted.org/packages/be/72/17c9f67e7542a49dd252c5ae50248607dfb780bcc03035907dafefb067e3/multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7", size = 126832 }, + { url = "https://files.pythonhosted.org/packages/71/9f/72d719e248cbd755c8736c6d14780533a1606ffb3fbb0fbd77da9f0372da/multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0", size = 26402 }, + { url = "https://files.pythonhosted.org/packages/04/5a/d88cd5d00a184e1ddffc82aa2e6e915164a6d2641ed3606e766b5d2f275a/multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d", size = 28800 }, + { url = "https://files.pythonhosted.org/packages/93/13/df3505a46d0cd08428e4c8169a196131d1b0c4b515c3649829258843dde6/multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6", size = 48570 }, + { url = "https://files.pythonhosted.org/packages/f0/e1/a215908bfae1343cdb72f805366592bdd60487b4232d039c437fe8f5013d/multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156", size = 29316 }, + { url = "https://files.pythonhosted.org/packages/70/0f/6dc70ddf5d442702ed74f298d69977f904960b82368532c88e854b79f72b/multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb", size = 29640 }, + { url = "https://files.pythonhosted.org/packages/d8/6d/9c87b73a13d1cdea30b321ef4b3824449866bd7f7127eceed066ccb9b9ff/multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b", size = 131067 }, + { url = "https://files.pythonhosted.org/packages/cc/1e/1b34154fef373371fd6c65125b3d42ff5f56c7ccc6bfff91b9b3c60ae9e0/multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72", size = 138507 }, + { url = "https://files.pythonhosted.org/packages/fb/e0/0bc6b2bac6e461822b5f575eae85da6aae76d0e2a79b6665d6206b8e2e48/multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304", size = 133905 }, + { url = "https://files.pythonhosted.org/packages/ba/af/73d13b918071ff9b2205fcf773d316e0f8fefb4ec65354bbcf0b10908cc6/multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351", size = 129004 }, + { url = "https://files.pythonhosted.org/packages/74/21/23960627b00ed39643302d81bcda44c9444ebcdc04ee5bedd0757513f259/multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb", size = 121308 }, + { url = "https://files.pythonhosted.org/packages/8b/5c/cf282263ffce4a596ed0bb2aa1a1dddfe1996d6a62d08842a8d4b33dca13/multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3", size = 132608 }, + { url = "https://files.pythonhosted.org/packages/d7/3e/97e778c041c72063f42b290888daff008d3ab1427f5b09b714f5a8eff294/multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399", size = 127029 }, + { url = "https://files.pythonhosted.org/packages/47/ac/3efb7bfe2f3aefcf8d103e9a7162572f01936155ab2f7ebcc7c255a23212/multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423", size = 137594 }, + { url = "https://files.pythonhosted.org/packages/42/9b/6c6e9e8dc4f915fc90a9b7798c44a30773dea2995fdcb619870e705afe2b/multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3", size = 134556 }, + { url = "https://files.pythonhosted.org/packages/1d/10/8e881743b26aaf718379a14ac58572a240e8293a1c9d68e1418fb11c0f90/multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753", size = 130993 }, + { url = "https://files.pythonhosted.org/packages/45/84/3eb91b4b557442802d058a7579e864b329968c8d0ea57d907e7023c677f2/multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80", size = 26405 }, + { url = "https://files.pythonhosted.org/packages/9f/0b/ad879847ecbf6d27e90a6eabb7eff6b62c129eefe617ea45eae7c1f0aead/multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926", size = 28795 }, + { url = "https://files.pythonhosted.org/packages/fd/16/92057c74ba3b96d5e211b553895cd6dc7cc4d1e43d9ab8fafc727681ef71/multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa", size = 48713 }, + { url = "https://files.pythonhosted.org/packages/94/3d/37d1b8893ae79716179540b89fc6a0ee56b4a65fcc0d63535c6f5d96f217/multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436", size = 29516 }, + { url = "https://files.pythonhosted.org/packages/a2/12/adb6b3200c363062f805275b4c1e656be2b3681aada66c80129932ff0bae/multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761", size = 29557 }, + { url = "https://files.pythonhosted.org/packages/47/e9/604bb05e6e5bce1e6a5cf80a474e0f072e80d8ac105f1b994a53e0b28c42/multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e", size = 130170 }, + { url = "https://files.pythonhosted.org/packages/7e/13/9efa50801785eccbf7086b3c83b71a4fb501a4d43549c2f2f80b8787d69f/multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef", size = 134836 }, + { url = "https://files.pythonhosted.org/packages/bf/0f/93808b765192780d117814a6dfcc2e75de6dcc610009ad408b8814dca3ba/multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95", size = 133475 }, + { url = "https://files.pythonhosted.org/packages/d3/c8/529101d7176fe7dfe1d99604e48d69c5dfdcadb4f06561f465c8ef12b4df/multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925", size = 131049 }, + { url = "https://files.pythonhosted.org/packages/ca/0c/fc85b439014d5a58063e19c3a158a889deec399d47b5269a0f3b6a2e28bc/multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966", size = 120370 }, + { url = "https://files.pythonhosted.org/packages/db/46/d4416eb20176492d2258fbd47b4abe729ff3b6e9c829ea4236f93c865089/multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305", size = 125178 }, + { url = "https://files.pythonhosted.org/packages/5b/46/73697ad7ec521df7de5531a32780bbfd908ded0643cbe457f981a701457c/multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2", size = 119567 }, + { url = "https://files.pythonhosted.org/packages/cd/ed/51f060e2cb0e7635329fa6ff930aa5cffa17f4c7f5c6c3ddc3500708e2f2/multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2", size = 129822 }, + { url = "https://files.pythonhosted.org/packages/df/9e/ee7d1954b1331da3eddea0c4e08d9142da5f14b1321c7301f5014f49d492/multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6", size = 128656 }, + { url = "https://files.pythonhosted.org/packages/77/00/8538f11e3356b5d95fa4b024aa566cde7a38aa7a5f08f4912b32a037c5dc/multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3", size = 125360 }, + { url = "https://files.pythonhosted.org/packages/be/05/5d334c1f2462d43fec2363cd00b1c44c93a78c3925d952e9a71caf662e96/multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133", size = 26382 }, + { url = "https://files.pythonhosted.org/packages/a3/bf/f332a13486b1ed0496d624bcc7e8357bb8053823e8cd4b9a18edc1d97e73/multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1", size = 28529 }, + { url = "https://files.pythonhosted.org/packages/22/67/1c7c0f39fe069aa4e5d794f323be24bf4d33d62d2a348acdb7991f8f30db/multidict-6.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d569388c381b24671589335a3be6e1d45546c2988c2ebe30fdcada8457a31008", size = 48771 }, + { url = "https://files.pythonhosted.org/packages/3c/25/c186ee7b212bdf0df2519eacfb1981a017bda34392c67542c274651daf23/multidict-6.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:052e10d2d37810b99cc170b785945421141bf7bb7d2f8799d431e7db229c385f", size = 29533 }, + { url = "https://files.pythonhosted.org/packages/67/5e/04575fd837e0958e324ca035b339cea174554f6f641d3fb2b4f2e7ff44a2/multidict-6.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f90c822a402cb865e396a504f9fc8173ef34212a342d92e362ca498cad308e28", size = 29595 }, + { url = "https://files.pythonhosted.org/packages/d3/b2/e56388f86663810c07cfe4a3c3d87227f3811eeb2d08450b9e5d19d78876/multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b225d95519a5bf73860323e633a664b0d85ad3d5bede6d30d95b35d4dfe8805b", size = 130094 }, + { url = "https://files.pythonhosted.org/packages/6c/ee/30ae9b4186a644d284543d55d491fbd4239b015d36b23fea43b4c94f7052/multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23bfd518810af7de1116313ebd9092cb9aa629beb12f6ed631ad53356ed6b86c", size = 134876 }, + { url = "https://files.pythonhosted.org/packages/84/c7/70461c13ba8ce3c779503c70ec9d0345ae84de04521c1f45a04d5f48943d/multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c09fcfdccdd0b57867577b719c69e347a436b86cd83747f179dbf0cc0d4c1f3", size = 133500 }, + { url = "https://files.pythonhosted.org/packages/4a/9f/002af221253f10f99959561123fae676148dd730e2daa2cd053846a58507/multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf6bea52ec97e95560af5ae576bdac3aa3aae0b6758c6efa115236d9e07dae44", size = 131099 }, + { url = "https://files.pythonhosted.org/packages/82/42/d1c7a7301d52af79d88548a97e297f9d99c961ad76bbe6f67442bb77f097/multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57feec87371dbb3520da6192213c7d6fc892d5589a93db548331954de8248fd2", size = 120403 }, + { url = "https://files.pythonhosted.org/packages/68/f3/471985c2c7ac707547553e8f37cff5158030d36bdec4414cb825fbaa5327/multidict-6.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0c3f390dc53279cbc8ba976e5f8035eab997829066756d811616b652b00a23a3", size = 125348 }, + { url = "https://files.pythonhosted.org/packages/67/2c/e6df05c77e0e433c214ec1d21ddd203d9a4770a1f2866a8ca40a545869a0/multidict-6.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:59bfeae4b25ec05b34f1956eaa1cb38032282cd4dfabc5056d0a1ec4d696d3aa", size = 119673 }, + { url = "https://files.pythonhosted.org/packages/c5/cd/bc8608fff06239c9fb333f9db7743a1b2eafe98c2666c9a196e867a3a0a4/multidict-6.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b2f59caeaf7632cc633b5cf6fc449372b83bbdf0da4ae04d5be36118e46cc0aa", size = 129927 }, + { url = "https://files.pythonhosted.org/packages/44/8e/281b69b7bc84fc963a44dc6e0bbcc7150e517b91df368a27834299a526ac/multidict-6.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:37bb93b2178e02b7b618893990941900fd25b6b9ac0fa49931a40aecdf083fe4", size = 128711 }, + { url = "https://files.pythonhosted.org/packages/12/a4/63e7cd38ed29dd9f1881d5119f272c898ca92536cdb53ffe0843197f6c85/multidict-6.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4e9f48f58c2c523d5a06faea47866cd35b32655c46b443f163d08c6d0ddb17d6", size = 125519 }, + { url = "https://files.pythonhosted.org/packages/38/e0/4f5855037a72cd8a7a2f60a3952d9aa45feedb37ae7831642102604e8a37/multidict-6.1.0-cp313-cp313-win32.whl", hash = "sha256:3a37ffb35399029b45c6cc33640a92bef403c9fd388acce75cdc88f58bd19a81", size = 26426 }, + { url = "https://files.pythonhosted.org/packages/7e/a5/17ee3a4db1e310b7405f5d25834460073a8ccd86198ce044dfaf69eac073/multidict-6.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e9aa71e15d9d9beaad2c6b9319edcdc0a49a43ef5c0a4c8265ca9ee7d6c67774", size = 28531 }, + { url = "https://files.pythonhosted.org/packages/e7/c9/9e153a6572b38ac5ff4434113af38acf8d5e9957897cdb1f513b3d6614ed/multidict-6.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4e18b656c5e844539d506a0a06432274d7bd52a7487e6828c63a63d69185626c", size = 48550 }, + { url = "https://files.pythonhosted.org/packages/76/f5/79565ddb629eba6c7f704f09a09df085c8dc04643b12506f10f718cee37a/multidict-6.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a185f876e69897a6f3325c3f19f26a297fa058c5e456bfcff8015e9a27e83ae1", size = 29298 }, + { url = "https://files.pythonhosted.org/packages/60/1b/9851878b704bc98e641a3e0bce49382ae9e05743dac6d97748feb5b7baba/multidict-6.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab7c4ceb38d91570a650dba194e1ca87c2b543488fe9309b4212694174fd539c", size = 29641 }, + { url = "https://files.pythonhosted.org/packages/89/87/d451d45aab9e422cb0fb2f7720c31a4c1d3012c740483c37f642eba568fb/multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e617fb6b0b6953fffd762669610c1c4ffd05632c138d61ac7e14ad187870669c", size = 126202 }, + { url = "https://files.pythonhosted.org/packages/fa/b4/27cbe9f3e2e469359887653f2e45470272eef7295139916cc21107c6b48c/multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5f4bf4e603eb1fdd5d8180f1a25f30056f22e55ce51fb3d6ad4ab29f7d96f", size = 133925 }, + { url = "https://files.pythonhosted.org/packages/4d/a3/afc841899face8adfd004235ce759a37619f6ec99eafd959650c5ce4df57/multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c035da3f544b1882bac24115f3e2e8760f10a0107614fc9839fd232200b875", size = 129039 }, + { url = "https://files.pythonhosted.org/packages/5e/41/0d0fb18c1ad574f807196f5f3d99164edf9de3e169a58c6dc2d6ed5742b9/multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:957cf8e4b6e123a9eea554fa7ebc85674674b713551de587eb318a2df3e00255", size = 124072 }, + { url = "https://files.pythonhosted.org/packages/00/22/defd7a2e71a44e6e5b9a5428f972e5b572e7fe28e404dfa6519bbf057c93/multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:483a6aea59cb89904e1ceabd2b47368b5600fb7de78a6e4a2c2987b2d256cf30", size = 116532 }, + { url = "https://files.pythonhosted.org/packages/91/25/f7545102def0b1d456ab6449388eed2dfd822debba1d65af60194904a23a/multidict-6.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:87701f25a2352e5bf7454caa64757642734da9f6b11384c1f9d1a8e699758057", size = 128173 }, + { url = "https://files.pythonhosted.org/packages/45/79/3dbe8d35fc99f5ea610813a72ab55f426cb9cf482f860fa8496e5409be11/multidict-6.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:682b987361e5fd7a139ed565e30d81fd81e9629acc7d925a205366877d8c8657", size = 122654 }, + { url = "https://files.pythonhosted.org/packages/97/cb/209e735eeab96e1b160825b5d0b36c56d3862abff828fc43999bb957dcad/multidict-6.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce2186a7df133a9c895dea3331ddc5ddad42cdd0d1ea2f0a51e5d161e4762f28", size = 133197 }, + { url = "https://files.pythonhosted.org/packages/e4/3a/a13808a7ada62808afccea67837a79d00ad6581440015ef00f726d064c2d/multidict-6.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9f636b730f7e8cb19feb87094949ba54ee5357440b9658b2a32a5ce4bce53972", size = 129754 }, + { url = "https://files.pythonhosted.org/packages/77/dd/8540e139eafb240079242da8f8ffdf9d3f4b4ad1aac5a786cd4050923783/multidict-6.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:73eae06aa53af2ea5270cc066dcaf02cc60d2994bbb2c4ef5764949257d10f43", size = 126402 }, + { url = "https://files.pythonhosted.org/packages/86/99/e82e1a275d8b1ea16d3a251474262258dbbe41c05cce0c01bceda1fc8ea5/multidict-6.1.0-cp39-cp39-win32.whl", hash = "sha256:1ca0083e80e791cffc6efce7660ad24af66c8d4079d2a750b29001b53ff59ada", size = 26421 }, + { url = "https://files.pythonhosted.org/packages/86/1c/9fa630272355af7e4446a2c7550c259f11ee422ab2d30ff90a0a71cf3d9e/multidict-6.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:aa466da5b15ccea564bdab9c89175c762bc12825f4659c11227f515cee76fa4a", size = 28791 }, + { url = "https://files.pythonhosted.org/packages/99/b7/b9e70fde2c0f0c9af4cc5277782a89b66d35948ea3369ec9f598358c3ac5/multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506", size = 10051 }, +] + [[package]] name = "mypy" version = "1.14.0" @@ -1034,6 +1392,95 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/6a/fd08d94654f7e67c52ca30523a178b3f8ccc4237fce4be90d39c938a831a/prompt_toolkit-3.0.48-py3-none-any.whl", hash = "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e", size = 386595 }, ] +[[package]] +name = "propcache" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/c8/2a13f78d82211490855b2fb303b6721348d0787fdd9a12ac46d99d3acde1/propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64", size = 41735 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/a5/0ea64c9426959ef145a938e38c832fc551843481d356713ececa9a8a64e8/propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6", size = 79296 }, + { url = "https://files.pythonhosted.org/packages/76/5a/916db1aba735f55e5eca4733eea4d1973845cf77dfe67c2381a2ca3ce52d/propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2", size = 45622 }, + { url = "https://files.pythonhosted.org/packages/2d/62/685d3cf268b8401ec12b250b925b21d152b9d193b7bffa5fdc4815c392c2/propcache-0.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6445804cf4ec763dc70de65a3b0d9954e868609e83850a47ca4f0cb64bd79fea", size = 45133 }, + { url = "https://files.pythonhosted.org/packages/4d/3d/31c9c29ee7192defc05aa4d01624fd85a41cf98e5922aaed206017329944/propcache-0.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9479aa06a793c5aeba49ce5c5692ffb51fcd9a7016e017d555d5e2b0045d212", size = 204809 }, + { url = "https://files.pythonhosted.org/packages/10/a1/e4050776f4797fc86140ac9a480d5dc069fbfa9d499fe5c5d2fa1ae71f07/propcache-0.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9631c5e8b5b3a0fda99cb0d29c18133bca1e18aea9effe55adb3da1adef80d3", size = 219109 }, + { url = "https://files.pythonhosted.org/packages/c9/c0/e7ae0df76343d5e107d81e59acc085cea5fd36a48aa53ef09add7503e888/propcache-0.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3156628250f46a0895f1f36e1d4fbe062a1af8718ec3ebeb746f1d23f0c5dc4d", size = 217368 }, + { url = "https://files.pythonhosted.org/packages/fc/e1/e0a2ed6394b5772508868a977d3238f4afb2eebaf9976f0b44a8d347ad63/propcache-0.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6fb63ae352e13748289f04f37868099e69dba4c2b3e271c46061e82c745634", size = 205124 }, + { url = "https://files.pythonhosted.org/packages/50/c1/e388c232d15ca10f233c778bbdc1034ba53ede14c207a72008de45b2db2e/propcache-0.2.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:887d9b0a65404929641a9fabb6452b07fe4572b269d901d622d8a34a4e9043b2", size = 195463 }, + { url = "https://files.pythonhosted.org/packages/0a/fd/71b349b9def426cc73813dbd0f33e266de77305e337c8c12bfb0a2a82bfb/propcache-0.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a96dc1fa45bd8c407a0af03b2d5218392729e1822b0c32e62c5bf7eeb5fb3958", size = 198358 }, + { url = "https://files.pythonhosted.org/packages/02/f2/d7c497cd148ebfc5b0ae32808e6c1af5922215fe38c7a06e4e722fe937c8/propcache-0.2.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a7e65eb5c003a303b94aa2c3852ef130230ec79e349632d030e9571b87c4698c", size = 195560 }, + { url = "https://files.pythonhosted.org/packages/bb/57/f37041bbe5e0dfed80a3f6be2612a3a75b9cfe2652abf2c99bef3455bbad/propcache-0.2.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:999779addc413181912e984b942fbcc951be1f5b3663cd80b2687758f434c583", size = 196895 }, + { url = "https://files.pythonhosted.org/packages/83/36/ae3cc3e4f310bff2f064e3d2ed5558935cc7778d6f827dce74dcfa125304/propcache-0.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:19a0f89a7bb9d8048d9c4370c9c543c396e894c76be5525f5e1ad287f1750ddf", size = 207124 }, + { url = "https://files.pythonhosted.org/packages/8c/c4/811b9f311f10ce9d31a32ff14ce58500458443627e4df4ae9c264defba7f/propcache-0.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1ac2f5fe02fa75f56e1ad473f1175e11f475606ec9bd0be2e78e4734ad575034", size = 210442 }, + { url = "https://files.pythonhosted.org/packages/18/dd/a1670d483a61ecac0d7fc4305d91caaac7a8fc1b200ea3965a01cf03bced/propcache-0.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:574faa3b79e8ebac7cb1d7930f51184ba1ccf69adfdec53a12f319a06030a68b", size = 203219 }, + { url = "https://files.pythonhosted.org/packages/f9/2d/30ced5afde41b099b2dc0c6573b66b45d16d73090e85655f1a30c5a24e07/propcache-0.2.1-cp310-cp310-win32.whl", hash = "sha256:03ff9d3f665769b2a85e6157ac8b439644f2d7fd17615a82fa55739bc97863f4", size = 40313 }, + { url = "https://files.pythonhosted.org/packages/23/84/bd9b207ac80da237af77aa6e153b08ffa83264b1c7882495984fcbfcf85c/propcache-0.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:2d3af2e79991102678f53e0dbf4c35de99b6b8b58f29a27ca0325816364caaba", size = 44428 }, + { url = "https://files.pythonhosted.org/packages/bc/0f/2913b6791ebefb2b25b4efd4bb2299c985e09786b9f5b19184a88e5778dd/propcache-0.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1ffc3cca89bb438fb9c95c13fc874012f7b9466b89328c3c8b1aa93cdcfadd16", size = 79297 }, + { url = "https://files.pythonhosted.org/packages/cf/73/af2053aeccd40b05d6e19058419ac77674daecdd32478088b79375b9ab54/propcache-0.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f174bbd484294ed9fdf09437f889f95807e5f229d5d93588d34e92106fbf6717", size = 45611 }, + { url = "https://files.pythonhosted.org/packages/3c/09/8386115ba7775ea3b9537730e8cf718d83bbf95bffe30757ccf37ec4e5da/propcache-0.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:70693319e0b8fd35dd863e3e29513875eb15c51945bf32519ef52927ca883bc3", size = 45146 }, + { url = "https://files.pythonhosted.org/packages/03/7a/793aa12f0537b2e520bf09f4c6833706b63170a211ad042ca71cbf79d9cb/propcache-0.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b480c6a4e1138e1aa137c0079b9b6305ec6dcc1098a8ca5196283e8a49df95a9", size = 232136 }, + { url = "https://files.pythonhosted.org/packages/f1/38/b921b3168d72111769f648314100558c2ea1d52eb3d1ba7ea5c4aa6f9848/propcache-0.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d27b84d5880f6d8aa9ae3edb253c59d9f6642ffbb2c889b78b60361eed449787", size = 239706 }, + { url = "https://files.pythonhosted.org/packages/14/29/4636f500c69b5edea7786db3c34eb6166f3384b905665ce312a6e42c720c/propcache-0.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:857112b22acd417c40fa4595db2fe28ab900c8c5fe4670c7989b1c0230955465", size = 238531 }, + { url = "https://files.pythonhosted.org/packages/85/14/01fe53580a8e1734ebb704a3482b7829a0ef4ea68d356141cf0994d9659b/propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf6c4150f8c0e32d241436526f3c3f9cbd34429492abddbada2ffcff506c51af", size = 231063 }, + { url = "https://files.pythonhosted.org/packages/33/5c/1d961299f3c3b8438301ccfbff0143b69afcc30c05fa28673cface692305/propcache-0.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d4cfda1d8ed687daa4bc0274fcfd5267873db9a5bc0418c2da19273040eeb7", size = 220134 }, + { url = "https://files.pythonhosted.org/packages/00/d0/ed735e76db279ba67a7d3b45ba4c654e7b02bc2f8050671ec365d8665e21/propcache-0.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c2f992c07c0fca81655066705beae35fc95a2fa7366467366db627d9f2ee097f", size = 220009 }, + { url = "https://files.pythonhosted.org/packages/75/90/ee8fab7304ad6533872fee982cfff5a53b63d095d78140827d93de22e2d4/propcache-0.2.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:4a571d97dbe66ef38e472703067021b1467025ec85707d57e78711c085984e54", size = 212199 }, + { url = "https://files.pythonhosted.org/packages/eb/ec/977ffaf1664f82e90737275873461695d4c9407d52abc2f3c3e24716da13/propcache-0.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bb6178c241278d5fe853b3de743087be7f5f4c6f7d6d22a3b524d323eecec505", size = 214827 }, + { url = "https://files.pythonhosted.org/packages/57/48/031fb87ab6081764054821a71b71942161619549396224cbb242922525e8/propcache-0.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ad1af54a62ffe39cf34db1aa6ed1a1873bd548f6401db39d8e7cd060b9211f82", size = 228009 }, + { url = "https://files.pythonhosted.org/packages/1a/06/ef1390f2524850838f2390421b23a8b298f6ce3396a7cc6d39dedd4047b0/propcache-0.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e7048abd75fe40712005bcfc06bb44b9dfcd8e101dda2ecf2f5aa46115ad07ca", size = 231638 }, + { url = "https://files.pythonhosted.org/packages/38/2a/101e6386d5a93358395da1d41642b79c1ee0f3b12e31727932b069282b1d/propcache-0.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:160291c60081f23ee43d44b08a7e5fb76681221a8e10b3139618c5a9a291b84e", size = 222788 }, + { url = "https://files.pythonhosted.org/packages/db/81/786f687951d0979007e05ad9346cd357e50e3d0b0f1a1d6074df334b1bbb/propcache-0.2.1-cp311-cp311-win32.whl", hash = "sha256:819ce3b883b7576ca28da3861c7e1a88afd08cc8c96908e08a3f4dd64a228034", size = 40170 }, + { url = "https://files.pythonhosted.org/packages/cf/59/7cc7037b295d5772eceb426358bb1b86e6cab4616d971bd74275395d100d/propcache-0.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:edc9fc7051e3350643ad929df55c451899bb9ae6d24998a949d2e4c87fb596d3", size = 44404 }, + { url = "https://files.pythonhosted.org/packages/4c/28/1d205fe49be8b1b4df4c50024e62480a442b1a7b818e734308bb0d17e7fb/propcache-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:081a430aa8d5e8876c6909b67bd2d937bfd531b0382d3fdedb82612c618bc41a", size = 79588 }, + { url = "https://files.pythonhosted.org/packages/21/ee/fc4d893f8d81cd4971affef2a6cb542b36617cd1d8ce56b406112cb80bf7/propcache-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2ccec9ac47cf4e04897619c0e0c1a48c54a71bdf045117d3a26f80d38ab1fb0", size = 45825 }, + { url = "https://files.pythonhosted.org/packages/4a/de/bbe712f94d088da1d237c35d735f675e494a816fd6f54e9db2f61ef4d03f/propcache-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:14d86fe14b7e04fa306e0c43cdbeebe6b2c2156a0c9ce56b815faacc193e320d", size = 45357 }, + { url = "https://files.pythonhosted.org/packages/7f/14/7ae06a6cf2a2f1cb382586d5a99efe66b0b3d0c6f9ac2f759e6f7af9d7cf/propcache-0.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:049324ee97bb67285b49632132db351b41e77833678432be52bdd0289c0e05e4", size = 241869 }, + { url = "https://files.pythonhosted.org/packages/cc/59/227a78be960b54a41124e639e2c39e8807ac0c751c735a900e21315f8c2b/propcache-0.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cd9a1d071158de1cc1c71a26014dcdfa7dd3d5f4f88c298c7f90ad6f27bb46d", size = 247884 }, + { url = "https://files.pythonhosted.org/packages/84/58/f62b4ffaedf88dc1b17f04d57d8536601e4e030feb26617228ef930c3279/propcache-0.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98110aa363f1bb4c073e8dcfaefd3a5cea0f0834c2aab23dda657e4dab2f53b5", size = 248486 }, + { url = "https://files.pythonhosted.org/packages/1c/07/ebe102777a830bca91bbb93e3479cd34c2ca5d0361b83be9dbd93104865e/propcache-0.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:647894f5ae99c4cf6bb82a1bb3a796f6e06af3caa3d32e26d2350d0e3e3faf24", size = 243649 }, + { url = "https://files.pythonhosted.org/packages/ed/bc/4f7aba7f08f520376c4bb6a20b9a981a581b7f2e385fa0ec9f789bb2d362/propcache-0.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfd3223c15bebe26518d58ccf9a39b93948d3dcb3e57a20480dfdd315356baff", size = 229103 }, + { url = "https://files.pythonhosted.org/packages/fe/d5/04ac9cd4e51a57a96f78795e03c5a0ddb8f23ec098b86f92de028d7f2a6b/propcache-0.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d71264a80f3fcf512eb4f18f59423fe82d6e346ee97b90625f283df56aee103f", size = 226607 }, + { url = "https://files.pythonhosted.org/packages/e3/f0/24060d959ea41d7a7cc7fdbf68b31852331aabda914a0c63bdb0e22e96d6/propcache-0.2.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e73091191e4280403bde6c9a52a6999d69cdfde498f1fdf629105247599b57ec", size = 221153 }, + { url = "https://files.pythonhosted.org/packages/77/a7/3ac76045a077b3e4de4859a0753010765e45749bdf53bd02bc4d372da1a0/propcache-0.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3935bfa5fede35fb202c4b569bb9c042f337ca4ff7bd540a0aa5e37131659348", size = 222151 }, + { url = "https://files.pythonhosted.org/packages/e7/af/5e29da6f80cebab3f5a4dcd2a3240e7f56f2c4abf51cbfcc99be34e17f0b/propcache-0.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f508b0491767bb1f2b87fdfacaba5f7eddc2f867740ec69ece6d1946d29029a6", size = 233812 }, + { url = "https://files.pythonhosted.org/packages/8c/89/ebe3ad52642cc5509eaa453e9f4b94b374d81bae3265c59d5c2d98efa1b4/propcache-0.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1672137af7c46662a1c2be1e8dc78cb6d224319aaa40271c9257d886be4363a6", size = 238829 }, + { url = "https://files.pythonhosted.org/packages/e9/2f/6b32f273fa02e978b7577159eae7471b3cfb88b48563b1c2578b2d7ca0bb/propcache-0.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b74c261802d3d2b85c9df2dfb2fa81b6f90deeef63c2db9f0e029a3cac50b518", size = 230704 }, + { url = "https://files.pythonhosted.org/packages/5c/2e/f40ae6ff5624a5f77edd7b8359b208b5455ea113f68309e2b00a2e1426b6/propcache-0.2.1-cp312-cp312-win32.whl", hash = "sha256:d09c333d36c1409d56a9d29b3a1b800a42c76a57a5a8907eacdbce3f18768246", size = 40050 }, + { url = "https://files.pythonhosted.org/packages/3b/77/a92c3ef994e47180862b9d7d11e37624fb1c00a16d61faf55115d970628b/propcache-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:c214999039d4f2a5b2073ac506bba279945233da8c786e490d411dfc30f855c1", size = 44117 }, + { url = "https://files.pythonhosted.org/packages/0f/2a/329e0547cf2def8857157f9477669043e75524cc3e6251cef332b3ff256f/propcache-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aca405706e0b0a44cc6bfd41fbe89919a6a56999157f6de7e182a990c36e37bc", size = 77002 }, + { url = "https://files.pythonhosted.org/packages/12/2d/c4df5415e2382f840dc2ecbca0eeb2293024bc28e57a80392f2012b4708c/propcache-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:12d1083f001ace206fe34b6bdc2cb94be66d57a850866f0b908972f90996b3e9", size = 44639 }, + { url = "https://files.pythonhosted.org/packages/d0/5a/21aaa4ea2f326edaa4e240959ac8b8386ea31dedfdaa636a3544d9e7a408/propcache-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d93f3307ad32a27bda2e88ec81134b823c240aa3abb55821a8da553eed8d9439", size = 44049 }, + { url = "https://files.pythonhosted.org/packages/4e/3e/021b6cd86c0acc90d74784ccbb66808b0bd36067a1bf3e2deb0f3845f618/propcache-0.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba278acf14471d36316159c94a802933d10b6a1e117b8554fe0d0d9b75c9d536", size = 224819 }, + { url = "https://files.pythonhosted.org/packages/3c/57/c2fdeed1b3b8918b1770a133ba5c43ad3d78e18285b0c06364861ef5cc38/propcache-0.2.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4e6281aedfca15301c41f74d7005e6e3f4ca143584ba696ac69df4f02f40d629", size = 229625 }, + { url = "https://files.pythonhosted.org/packages/9d/81/70d4ff57bf2877b5780b466471bebf5892f851a7e2ca0ae7ffd728220281/propcache-0.2.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b750a8e5a1262434fb1517ddf64b5de58327f1adc3524a5e44c2ca43305eb0b", size = 232934 }, + { url = "https://files.pythonhosted.org/packages/3c/b9/bb51ea95d73b3fb4100cb95adbd4e1acaf2cbb1fd1083f5468eeb4a099a8/propcache-0.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf72af5e0fb40e9babf594308911436c8efde3cb5e75b6f206c34ad18be5c052", size = 227361 }, + { url = "https://files.pythonhosted.org/packages/f1/20/3c6d696cd6fd70b29445960cc803b1851a1131e7a2e4ee261ee48e002bcd/propcache-0.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2d0a12018b04f4cb820781ec0dffb5f7c7c1d2a5cd22bff7fb055a2cb19ebce", size = 213904 }, + { url = "https://files.pythonhosted.org/packages/a1/cb/1593bfc5ac6d40c010fa823f128056d6bc25b667f5393781e37d62f12005/propcache-0.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e800776a79a5aabdb17dcc2346a7d66d0777e942e4cd251defeb084762ecd17d", size = 212632 }, + { url = "https://files.pythonhosted.org/packages/6d/5c/e95617e222be14a34c709442a0ec179f3207f8a2b900273720501a70ec5e/propcache-0.2.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4160d9283bd382fa6c0c2b5e017acc95bc183570cd70968b9202ad6d8fc48dce", size = 207897 }, + { url = "https://files.pythonhosted.org/packages/8e/3b/56c5ab3dc00f6375fbcdeefdede5adf9bee94f1fab04adc8db118f0f9e25/propcache-0.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:30b43e74f1359353341a7adb783c8f1b1c676367b011709f466f42fda2045e95", size = 208118 }, + { url = "https://files.pythonhosted.org/packages/86/25/d7ef738323fbc6ebcbce33eb2a19c5e07a89a3df2fded206065bd5e868a9/propcache-0.2.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:58791550b27d5488b1bb52bc96328456095d96206a250d28d874fafe11b3dfaf", size = 217851 }, + { url = "https://files.pythonhosted.org/packages/b3/77/763e6cef1852cf1ba740590364ec50309b89d1c818e3256d3929eb92fabf/propcache-0.2.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:0f022d381747f0dfe27e99d928e31bc51a18b65bb9e481ae0af1380a6725dd1f", size = 222630 }, + { url = "https://files.pythonhosted.org/packages/4f/e9/0f86be33602089c701696fbed8d8c4c07b6ee9605c5b7536fd27ed540c5b/propcache-0.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:297878dc9d0a334358f9b608b56d02e72899f3b8499fc6044133f0d319e2ec30", size = 216269 }, + { url = "https://files.pythonhosted.org/packages/cc/02/5ac83217d522394b6a2e81a2e888167e7ca629ef6569a3f09852d6dcb01a/propcache-0.2.1-cp313-cp313-win32.whl", hash = "sha256:ddfab44e4489bd79bda09d84c430677fc7f0a4939a73d2bba3073036f487a0a6", size = 39472 }, + { url = "https://files.pythonhosted.org/packages/f4/33/d6f5420252a36034bc8a3a01171bc55b4bff5df50d1c63d9caa50693662f/propcache-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:556fc6c10989f19a179e4321e5d678db8eb2924131e64652a51fe83e4c3db0e1", size = 43363 }, + { url = "https://files.pythonhosted.org/packages/0a/08/6ab7f65240a16fa01023125e65258acf7e4884f483f267cdd6fcc48f37db/propcache-0.2.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6a9a8c34fb7bb609419a211e59da8887eeca40d300b5ea8e56af98f6fbbb1541", size = 80403 }, + { url = "https://files.pythonhosted.org/packages/34/fe/e7180285e21b4e6dff7d311fdf22490c9146a09a02834b5232d6248c6004/propcache-0.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ae1aa1cd222c6d205853b3013c69cd04515f9d6ab6de4b0603e2e1c33221303e", size = 46152 }, + { url = "https://files.pythonhosted.org/packages/9c/36/aa74d884af826030ba9cee2ac109b0664beb7e9449c315c9c44db99efbb3/propcache-0.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:accb6150ce61c9c4b7738d45550806aa2b71c7668c6942f17b0ac182b6142fd4", size = 45674 }, + { url = "https://files.pythonhosted.org/packages/22/59/6fe80a3fe7720f715f2c0f6df250dacbd7cad42832410dbd84c719c52f78/propcache-0.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eee736daafa7af6d0a2dc15cc75e05c64f37fc37bafef2e00d77c14171c2097", size = 207792 }, + { url = "https://files.pythonhosted.org/packages/4a/68/584cd51dd8f4d0f5fff5b128ce0cdb257cde903898eecfb92156bbc2c780/propcache-0.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7a31fc1e1bd362874863fdeed71aed92d348f5336fd84f2197ba40c59f061bd", size = 223280 }, + { url = "https://files.pythonhosted.org/packages/85/cb/4c3528460c41e61b06ec3f970c0f89f87fa21f63acac8642ed81a886c164/propcache-0.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba4cfa1052819d16699e1d55d18c92b6e094d4517c41dd231a8b9f87b6fa681", size = 221293 }, + { url = "https://files.pythonhosted.org/packages/69/c0/560e050aa6d31eeece3490d1174da508f05ab27536dfc8474af88b97160a/propcache-0.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f089118d584e859c62b3da0892b88a83d611c2033ac410e929cb6754eec0ed16", size = 208259 }, + { url = "https://files.pythonhosted.org/packages/0c/87/d6c86a77632eb1ba86a328e3313159f246e7564cb5951e05ed77555826a0/propcache-0.2.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:781e65134efaf88feb447e8c97a51772aa75e48b794352f94cb7ea717dedda0d", size = 198632 }, + { url = "https://files.pythonhosted.org/packages/3a/2b/3690ea7b662dc762ab7af5f3ef0e2d7513c823d193d7b2a1b4cda472c2be/propcache-0.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31f5af773530fd3c658b32b6bdc2d0838543de70eb9a2156c03e410f7b0d3aae", size = 203516 }, + { url = "https://files.pythonhosted.org/packages/4d/b5/afe716c16c23c77657185c257a41918b83e03993b6ccdfa748e5e7d328e9/propcache-0.2.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a7a078f5d37bee6690959c813977da5291b24286e7b962e62a94cec31aa5188b", size = 199402 }, + { url = "https://files.pythonhosted.org/packages/a4/c0/2d2df3aa7f8660d0d4cc4f1e00490c48d5958da57082e70dea7af366f876/propcache-0.2.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:cea7daf9fc7ae6687cf1e2c049752f19f146fdc37c2cc376e7d0032cf4f25347", size = 200528 }, + { url = "https://files.pythonhosted.org/packages/21/c8/65ac9142f5e40c8497f7176e71d18826b09e06dd4eb401c9a4ee41aa9c74/propcache-0.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b3489ff1ed1e8315674d0775dc7d2195fb13ca17b3808721b54dbe9fd020faf", size = 211254 }, + { url = "https://files.pythonhosted.org/packages/09/e4/edb70b447a1d8142df51ec7511e84aa64d7f6ce0a0fdf5eb55363cdd0935/propcache-0.2.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9403db39be1393618dd80c746cb22ccda168efce239c73af13c3763ef56ffc04", size = 214589 }, + { url = "https://files.pythonhosted.org/packages/cb/02/817f309ec8d8883287781d6d9390f80b14db6e6de08bc659dfe798a825c2/propcache-0.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5d97151bc92d2b2578ff7ce779cdb9174337390a535953cbb9452fb65164c587", size = 207283 }, + { url = "https://files.pythonhosted.org/packages/d7/fe/2d18612096ed2212cfef821b6fccdba5d52efc1d64511c206c5c16be28fd/propcache-0.2.1-cp39-cp39-win32.whl", hash = "sha256:9caac6b54914bdf41bcc91e7eb9147d331d29235a7c967c150ef5df6464fd1bb", size = 40866 }, + { url = "https://files.pythonhosted.org/packages/24/2e/b5134802e7b57c403c7b73c7a39374e7a6b7f128d1968b4a4b4c0b700250/propcache-0.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:92fc4500fcb33899b05ba73276dfb684a20d31caa567b7cb5252d48f896a91b1", size = 44975 }, + { url = "https://files.pythonhosted.org/packages/41/b6/c5319caea262f4821995dca2107483b94a3345d4607ad797c76cb9c36bcc/propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54", size = 11818 }, +] + [[package]] name = "protobuf" version = "5.29.2" @@ -1107,6 +1554,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949 }, ] +[[package]] +name = "pytorch-lightning" +version = "2.5.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fsspec", extra = ["http"] }, + { name = "lightning-utilities" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "torch" }, + { name = "torchmetrics" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/b1/3c1d08db3feb1dfcd5be1b9f2406455f3740f7525d7ea1b9244f67b11cb5/pytorch_lightning-2.5.0.post0.tar.gz", hash = "sha256:347235bf8573b4ebcf507a0dd755fcb9ce58c420c77220a9756a6edca0418532", size = 631450 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/df/0c7e4582b74264fe2179e78fcdeb9313f680d40ffe1dd4b078da5a2cbf82/pytorch_lightning-2.5.0.post0-py3-none-any.whl", hash = "sha256:c86bf4fded58b386f312f75337696a9b2d57077b858b3b9524400a03a0179b3a", size = 819282 }, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -1518,6 +1984,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/55/21/47d163f615df1d30c094f6c8bbb353619274edccf0327b185cc2493c2c33/setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d", size = 1224032 }, ] +[[package]] +name = "skops" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "packaging" }, + { name = "scikit-learn" }, + { name = "tabulate" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/80/15/5718ee04c70425b083e070aa4da651292e1527144ea87b3ed07591d729a9/skops-0.11.0.tar.gz", hash = "sha256:229c867fbc5e669a1c6a88661c3883a14f3591abd9bfa6073df308d63ae1fa3a", size = 610701 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/5f/a3ec074e67b5dcce2de290bb5b2edb60f78c9304d86485dc1570bca22c2d/skops-0.11.0-py3-none-any.whl", hash = "sha256:8c6109e27e4d762948cad7d21de008034bd14e15f111e9405c7930e74a7fe8c1", size = 146956 }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1544,6 +2025,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -1627,19 +2117,19 @@ dependencies = [ { name = "jinja2" }, { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -1661,12 +2151,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/81/c05013695bfb3762f3c657a557407f152a0a0452b3ccec437a4a59848fb5/torch-2.4.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a38de2803ee6050309aac032676536c3d3b6a9804248537e38e098d0e14817ec", size = 62139344 }, ] +[[package]] +name = "torchmetrics" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lightning-utilities" }, + { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "numpy", version = "2.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "packaging" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/c5/8d916585d4d6eb158105c21b28cd4b0ed296d74e499bf8f104368de16619/torchmetrics-1.6.1.tar.gz", hash = "sha256:a5dc236694b392180949fdd0a0fcf2b57135c8b600e557c725e077eb41e53e64", size = 540022 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/e1/84066ff60a20dfa63f4d9d8ddc280d5ed323b7f06504dbb51c523b690116/torchmetrics-1.6.1-py3-none-any.whl", hash = "sha256:c3090aa2341129e994c0a659abb6d4140ae75169a6ebf45bffc16c5cb553b38e", size = 927305 }, +] + [[package]] name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ @@ -1758,3 +2264,97 @@ sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc wheels = [ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, ] + +[[package]] +name = "yarl" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/9d/4b94a8e6d2b51b599516a5cb88e5bc99b4d8d4583e468057eaa29d5f0918/yarl-1.18.3.tar.gz", hash = "sha256:ac1801c45cbf77b6c99242eeff4fffb5e4e73a800b5c4ad4fc0be5def634d2e1", size = 181062 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/98/e005bc608765a8a5569f58e650961314873c8469c333616eb40bff19ae97/yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34", size = 141458 }, + { url = "https://files.pythonhosted.org/packages/df/5d/f8106b263b8ae8a866b46d9be869ac01f9b3fb7f2325f3ecb3df8003f796/yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7", size = 94365 }, + { url = "https://files.pythonhosted.org/packages/56/3e/d8637ddb9ba69bf851f765a3ee288676f7cf64fb3be13760c18cbc9d10bd/yarl-1.18.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:602d98f2c2d929f8e697ed274fbadc09902c4025c5a9963bf4e9edfc3ab6f7ed", size = 92181 }, + { url = "https://files.pythonhosted.org/packages/76/f9/d616a5c2daae281171de10fba41e1c0e2d8207166fc3547252f7d469b4e1/yarl-1.18.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c654d5207c78e0bd6d749f6dae1dcbbfde3403ad3a4b11f3c5544d9906969dde", size = 315349 }, + { url = "https://files.pythonhosted.org/packages/bb/b4/3ea5e7b6f08f698b3769a06054783e434f6d59857181b5c4e145de83f59b/yarl-1.18.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5094d9206c64181d0f6e76ebd8fb2f8fe274950a63890ee9e0ebfd58bf9d787b", size = 330494 }, + { url = "https://files.pythonhosted.org/packages/55/f1/e0fc810554877b1b67420568afff51b967baed5b53bcc983ab164eebf9c9/yarl-1.18.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35098b24e0327fc4ebdc8ffe336cee0a87a700c24ffed13161af80124b7dc8e5", size = 326927 }, + { url = "https://files.pythonhosted.org/packages/a9/42/b1753949b327b36f210899f2dd0a0947c0c74e42a32de3f8eb5c7d93edca/yarl-1.18.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3236da9272872443f81fedc389bace88408f64f89f75d1bdb2256069a8730ccc", size = 319703 }, + { url = "https://files.pythonhosted.org/packages/f0/6d/e87c62dc9635daefb064b56f5c97df55a2e9cc947a2b3afd4fd2f3b841c7/yarl-1.18.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2c08cc9b16f4f4bc522771d96734c7901e7ebef70c6c5c35dd0f10845270bcd", size = 310246 }, + { url = "https://files.pythonhosted.org/packages/e3/ef/e2e8d1785cdcbd986f7622d7f0098205f3644546da7919c24b95790ec65a/yarl-1.18.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:80316a8bd5109320d38eef8833ccf5f89608c9107d02d2a7f985f98ed6876990", size = 319730 }, + { url = "https://files.pythonhosted.org/packages/fc/15/8723e22345bc160dfde68c4b3ae8b236e868f9963c74015f1bc8a614101c/yarl-1.18.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c1e1cc06da1491e6734f0ea1e6294ce00792193c463350626571c287c9a704db", size = 321681 }, + { url = "https://files.pythonhosted.org/packages/86/09/bf764e974f1516efa0ae2801494a5951e959f1610dd41edbfc07e5e0f978/yarl-1.18.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fea09ca13323376a2fdfb353a5fa2e59f90cd18d7ca4eaa1fd31f0a8b4f91e62", size = 324812 }, + { url = "https://files.pythonhosted.org/packages/f6/4c/20a0187e3b903c97d857cf0272d687c1b08b03438968ae8ffc50fe78b0d6/yarl-1.18.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e3b9fd71836999aad54084906f8663dffcd2a7fb5cdafd6c37713b2e72be1760", size = 337011 }, + { url = "https://files.pythonhosted.org/packages/c9/71/6244599a6e1cc4c9f73254a627234e0dad3883ece40cc33dce6265977461/yarl-1.18.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:757e81cae69244257d125ff31663249b3013b5dc0a8520d73694aed497fb195b", size = 338132 }, + { url = "https://files.pythonhosted.org/packages/af/f5/e0c3efaf74566c4b4a41cb76d27097df424052a064216beccae8d303c90f/yarl-1.18.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b1771de9944d875f1b98a745bc547e684b863abf8f8287da8466cf470ef52690", size = 331849 }, + { url = "https://files.pythonhosted.org/packages/8a/b8/3d16209c2014c2f98a8f658850a57b716efb97930aebf1ca0d9325933731/yarl-1.18.3-cp310-cp310-win32.whl", hash = "sha256:8874027a53e3aea659a6d62751800cf6e63314c160fd607489ba5c2edd753cf6", size = 84309 }, + { url = "https://files.pythonhosted.org/packages/fd/b7/2e9a5b18eb0fe24c3a0e8bae994e812ed9852ab4fd067c0107fadde0d5f0/yarl-1.18.3-cp310-cp310-win_amd64.whl", hash = "sha256:93b2e109287f93db79210f86deb6b9bbb81ac32fc97236b16f7433db7fc437d8", size = 90484 }, + { url = "https://files.pythonhosted.org/packages/40/93/282b5f4898d8e8efaf0790ba6d10e2245d2c9f30e199d1a85cae9356098c/yarl-1.18.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8503ad47387b8ebd39cbbbdf0bf113e17330ffd339ba1144074da24c545f0069", size = 141555 }, + { url = "https://files.pythonhosted.org/packages/6d/9c/0a49af78df099c283ca3444560f10718fadb8a18dc8b3edf8c7bd9fd7d89/yarl-1.18.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02ddb6756f8f4517a2d5e99d8b2f272488e18dd0bfbc802f31c16c6c20f22193", size = 94351 }, + { url = "https://files.pythonhosted.org/packages/5a/a1/205ab51e148fdcedad189ca8dd587794c6f119882437d04c33c01a75dece/yarl-1.18.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:67a283dd2882ac98cc6318384f565bffc751ab564605959df4752d42483ad889", size = 92286 }, + { url = "https://files.pythonhosted.org/packages/ed/fe/88b690b30f3f59275fb674f5f93ddd4a3ae796c2b62e5bb9ece8a4914b83/yarl-1.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d980e0325b6eddc81331d3f4551e2a333999fb176fd153e075c6d1c2530aa8a8", size = 340649 }, + { url = "https://files.pythonhosted.org/packages/07/eb/3b65499b568e01f36e847cebdc8d7ccb51fff716dbda1ae83c3cbb8ca1c9/yarl-1.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b643562c12680b01e17239be267bc306bbc6aac1f34f6444d1bded0c5ce438ca", size = 356623 }, + { url = "https://files.pythonhosted.org/packages/33/46/f559dc184280b745fc76ec6b1954de2c55595f0ec0a7614238b9ebf69618/yarl-1.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c017a3b6df3a1bd45b9fa49a0f54005e53fbcad16633870104b66fa1a30a29d8", size = 354007 }, + { url = "https://files.pythonhosted.org/packages/af/ba/1865d85212351ad160f19fb99808acf23aab9a0f8ff31c8c9f1b4d671fc9/yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75674776d96d7b851b6498f17824ba17849d790a44d282929c42dbb77d4f17ae", size = 344145 }, + { url = "https://files.pythonhosted.org/packages/94/cb/5c3e975d77755d7b3d5193e92056b19d83752ea2da7ab394e22260a7b824/yarl-1.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ccaa3a4b521b780a7e771cc336a2dba389a0861592bbce09a476190bb0c8b4b3", size = 336133 }, + { url = "https://files.pythonhosted.org/packages/19/89/b77d3fd249ab52a5c40859815765d35c91425b6bb82e7427ab2f78f5ff55/yarl-1.18.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d06d3005e668744e11ed80812e61efd77d70bb7f03e33c1598c301eea20efbb", size = 347967 }, + { url = "https://files.pythonhosted.org/packages/35/bd/f6b7630ba2cc06c319c3235634c582a6ab014d52311e7d7c22f9518189b5/yarl-1.18.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:9d41beda9dc97ca9ab0b9888cb71f7539124bc05df02c0cff6e5acc5a19dcc6e", size = 346397 }, + { url = "https://files.pythonhosted.org/packages/18/1a/0b4e367d5a72d1f095318344848e93ea70da728118221f84f1bf6c1e39e7/yarl-1.18.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ba23302c0c61a9999784e73809427c9dbedd79f66a13d84ad1b1943802eaaf59", size = 350206 }, + { url = "https://files.pythonhosted.org/packages/b5/cf/320fff4367341fb77809a2d8d7fe75b5d323a8e1b35710aafe41fdbf327b/yarl-1.18.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6748dbf9bfa5ba1afcc7556b71cda0d7ce5f24768043a02a58846e4a443d808d", size = 362089 }, + { url = "https://files.pythonhosted.org/packages/57/cf/aadba261d8b920253204085268bad5e8cdd86b50162fcb1b10c10834885a/yarl-1.18.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0b0cad37311123211dc91eadcb322ef4d4a66008d3e1bdc404808992260e1a0e", size = 366267 }, + { url = "https://files.pythonhosted.org/packages/54/58/fb4cadd81acdee6dafe14abeb258f876e4dd410518099ae9a35c88d8097c/yarl-1.18.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fb2171a4486bb075316ee754c6d8382ea6eb8b399d4ec62fde2b591f879778a", size = 359141 }, + { url = "https://files.pythonhosted.org/packages/9a/7a/4c571597589da4cd5c14ed2a0b17ac56ec9ee7ee615013f74653169e702d/yarl-1.18.3-cp311-cp311-win32.whl", hash = "sha256:61b1a825a13bef4a5f10b1885245377d3cd0bf87cba068e1d9a88c2ae36880e1", size = 84402 }, + { url = "https://files.pythonhosted.org/packages/ae/7b/8600250b3d89b625f1121d897062f629883c2f45339623b69b1747ec65fa/yarl-1.18.3-cp311-cp311-win_amd64.whl", hash = "sha256:b9d60031cf568c627d028239693fd718025719c02c9f55df0a53e587aab951b5", size = 91030 }, + { url = "https://files.pythonhosted.org/packages/33/85/bd2e2729752ff4c77338e0102914897512e92496375e079ce0150a6dc306/yarl-1.18.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1dd4bdd05407ced96fed3d7f25dbbf88d2ffb045a0db60dbc247f5b3c5c25d50", size = 142644 }, + { url = "https://files.pythonhosted.org/packages/ff/74/1178322cc0f10288d7eefa6e4a85d8d2e28187ccab13d5b844e8b5d7c88d/yarl-1.18.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7c33dd1931a95e5d9a772d0ac5e44cac8957eaf58e3c8da8c1414de7dd27c576", size = 94962 }, + { url = "https://files.pythonhosted.org/packages/be/75/79c6acc0261e2c2ae8a1c41cf12265e91628c8c58ae91f5ff59e29c0787f/yarl-1.18.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b411eddcfd56a2f0cd6a384e9f4f7aa3efee14b188de13048c25b5e91f1640", size = 92795 }, + { url = "https://files.pythonhosted.org/packages/6b/32/927b2d67a412c31199e83fefdce6e645247b4fb164aa1ecb35a0f9eb2058/yarl-1.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436c4fc0a4d66b2badc6c5fc5ef4e47bb10e4fd9bf0c79524ac719a01f3607c2", size = 332368 }, + { url = "https://files.pythonhosted.org/packages/19/e5/859fca07169d6eceeaa4fde1997c91d8abde4e9a7c018e371640c2da2b71/yarl-1.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e35ef8683211db69ffe129a25d5634319a677570ab6b2eba4afa860f54eeaf75", size = 342314 }, + { url = "https://files.pythonhosted.org/packages/08/75/76b63ccd91c9e03ab213ef27ae6add2e3400e77e5cdddf8ed2dbc36e3f21/yarl-1.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84b2deecba4a3f1a398df819151eb72d29bfeb3b69abb145a00ddc8d30094512", size = 341987 }, + { url = "https://files.pythonhosted.org/packages/1a/e1/a097d5755d3ea8479a42856f51d97eeff7a3a7160593332d98f2709b3580/yarl-1.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e5a1fea0fd4f5bfa7440a47eff01d9822a65b4488f7cff83155a0f31a2ecba", size = 336914 }, + { url = "https://files.pythonhosted.org/packages/0b/42/e1b4d0e396b7987feceebe565286c27bc085bf07d61a59508cdaf2d45e63/yarl-1.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0e883008013c0e4aef84dcfe2a0b172c4d23c2669412cf5b3371003941f72bb", size = 325765 }, + { url = "https://files.pythonhosted.org/packages/7e/18/03a5834ccc9177f97ca1bbb245b93c13e58e8225276f01eedc4cc98ab820/yarl-1.18.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5a3f356548e34a70b0172d8890006c37be92995f62d95a07b4a42e90fba54272", size = 344444 }, + { url = "https://files.pythonhosted.org/packages/c8/03/a713633bdde0640b0472aa197b5b86e90fbc4c5bc05b727b714cd8a40e6d/yarl-1.18.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ccd17349166b1bee6e529b4add61727d3f55edb7babbe4069b5764c9587a8cc6", size = 340760 }, + { url = "https://files.pythonhosted.org/packages/eb/99/f6567e3f3bbad8fd101886ea0276c68ecb86a2b58be0f64077396cd4b95e/yarl-1.18.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b958ddd075ddba5b09bb0be8a6d9906d2ce933aee81100db289badbeb966f54e", size = 346484 }, + { url = "https://files.pythonhosted.org/packages/8e/a9/84717c896b2fc6cb15bd4eecd64e34a2f0a9fd6669e69170c73a8b46795a/yarl-1.18.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c7d79f7d9aabd6011004e33b22bc13056a3e3fb54794d138af57f5ee9d9032cb", size = 359864 }, + { url = "https://files.pythonhosted.org/packages/1e/2e/d0f5f1bef7ee93ed17e739ec8dbcb47794af891f7d165fa6014517b48169/yarl-1.18.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4891ed92157e5430874dad17b15eb1fda57627710756c27422200c52d8a4e393", size = 364537 }, + { url = "https://files.pythonhosted.org/packages/97/8a/568d07c5d4964da5b02621a517532adb8ec5ba181ad1687191fffeda0ab6/yarl-1.18.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce1af883b94304f493698b00d0f006d56aea98aeb49d75ec7d98cd4a777e9285", size = 357861 }, + { url = "https://files.pythonhosted.org/packages/7d/e3/924c3f64b6b3077889df9a1ece1ed8947e7b61b0a933f2ec93041990a677/yarl-1.18.3-cp312-cp312-win32.whl", hash = "sha256:f91c4803173928a25e1a55b943c81f55b8872f0018be83e3ad4938adffb77dd2", size = 84097 }, + { url = "https://files.pythonhosted.org/packages/34/45/0e055320daaabfc169b21ff6174567b2c910c45617b0d79c68d7ab349b02/yarl-1.18.3-cp312-cp312-win_amd64.whl", hash = "sha256:7e2ee16578af3b52ac2f334c3b1f92262f47e02cc6193c598502bd46f5cd1477", size = 90399 }, + { url = "https://files.pythonhosted.org/packages/30/c7/c790513d5328a8390be8f47be5d52e141f78b66c6c48f48d241ca6bd5265/yarl-1.18.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:90adb47ad432332d4f0bc28f83a5963f426ce9a1a8809f5e584e704b82685dcb", size = 140789 }, + { url = "https://files.pythonhosted.org/packages/30/aa/a2f84e93554a578463e2edaaf2300faa61c8701f0898725842c704ba5444/yarl-1.18.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:913829534200eb0f789d45349e55203a091f45c37a2674678744ae52fae23efa", size = 94144 }, + { url = "https://files.pythonhosted.org/packages/c6/fc/d68d8f83714b221a85ce7866832cba36d7c04a68fa6a960b908c2c84f325/yarl-1.18.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ef9f7768395923c3039055c14334ba4d926f3baf7b776c923c93d80195624782", size = 91974 }, + { url = "https://files.pythonhosted.org/packages/56/4e/d2563d8323a7e9a414b5b25341b3942af5902a2263d36d20fb17c40411e2/yarl-1.18.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a19f62ff30117e706ebc9090b8ecc79aeb77d0b1f5ec10d2d27a12bc9f66d0", size = 333587 }, + { url = "https://files.pythonhosted.org/packages/25/c9/cfec0bc0cac8d054be223e9f2c7909d3e8442a856af9dbce7e3442a8ec8d/yarl-1.18.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e17c9361d46a4d5addf777c6dd5eab0715a7684c2f11b88c67ac37edfba6c482", size = 344386 }, + { url = "https://files.pythonhosted.org/packages/ab/5d/4c532190113b25f1364d25f4c319322e86232d69175b91f27e3ebc2caf9a/yarl-1.18.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a74a13a4c857a84a845505fd2d68e54826a2cd01935a96efb1e9d86c728e186", size = 345421 }, + { url = "https://files.pythonhosted.org/packages/23/d1/6cdd1632da013aa6ba18cee4d750d953104a5e7aac44e249d9410a972bf5/yarl-1.18.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41f7ce59d6ee7741af71d82020346af364949314ed3d87553763a2df1829cc58", size = 339384 }, + { url = "https://files.pythonhosted.org/packages/9a/c4/6b3c39bec352e441bd30f432cda6ba51681ab19bb8abe023f0d19777aad1/yarl-1.18.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f52a265001d830bc425f82ca9eabda94a64a4d753b07d623a9f2863fde532b53", size = 326689 }, + { url = "https://files.pythonhosted.org/packages/23/30/07fb088f2eefdc0aa4fc1af4e3ca4eb1a3aadd1ce7d866d74c0f124e6a85/yarl-1.18.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:82123d0c954dc58db301f5021a01854a85bf1f3bb7d12ae0c01afc414a882ca2", size = 345453 }, + { url = "https://files.pythonhosted.org/packages/63/09/d54befb48f9cd8eec43797f624ec37783a0266855f4930a91e3d5c7717f8/yarl-1.18.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:2ec9bbba33b2d00999af4631a3397d1fd78290c48e2a3e52d8dd72db3a067ac8", size = 341872 }, + { url = "https://files.pythonhosted.org/packages/91/26/fd0ef9bf29dd906a84b59f0cd1281e65b0c3e08c6aa94b57f7d11f593518/yarl-1.18.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fbd6748e8ab9b41171bb95c6142faf068f5ef1511935a0aa07025438dd9a9bc1", size = 347497 }, + { url = "https://files.pythonhosted.org/packages/d9/b5/14ac7a256d0511b2ac168d50d4b7d744aea1c1aa20c79f620d1059aab8b2/yarl-1.18.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:877d209b6aebeb5b16c42cbb377f5f94d9e556626b1bfff66d7b0d115be88d0a", size = 359981 }, + { url = "https://files.pythonhosted.org/packages/ca/b3/d493221ad5cbd18bc07e642894030437e405e1413c4236dd5db6e46bcec9/yarl-1.18.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b464c4ab4bfcb41e3bfd3f1c26600d038376c2de3297760dfe064d2cb7ea8e10", size = 366229 }, + { url = "https://files.pythonhosted.org/packages/04/56/6a3e2a5d9152c56c346df9b8fb8edd2c8888b1e03f96324d457e5cf06d34/yarl-1.18.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8d39d351e7faf01483cc7ff7c0213c412e38e5a340238826be7e0e4da450fdc8", size = 360383 }, + { url = "https://files.pythonhosted.org/packages/fd/b7/4b3c7c7913a278d445cc6284e59b2e62fa25e72758f888b7a7a39eb8423f/yarl-1.18.3-cp313-cp313-win32.whl", hash = "sha256:61ee62ead9b68b9123ec24bc866cbef297dd266175d53296e2db5e7f797f902d", size = 310152 }, + { url = "https://files.pythonhosted.org/packages/f5/d5/688db678e987c3e0fb17867970700b92603cadf36c56e5fb08f23e822a0c/yarl-1.18.3-cp313-cp313-win_amd64.whl", hash = "sha256:578e281c393af575879990861823ef19d66e2b1d0098414855dd367e234f5b3c", size = 315723 }, + { url = "https://files.pythonhosted.org/packages/6a/3b/fec4b08f5e88f68e56ee698a59284a73704df2e0e0b5bdf6536c86e76c76/yarl-1.18.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:61e5e68cb65ac8f547f6b5ef933f510134a6bf31bb178be428994b0cb46c2a04", size = 142780 }, + { url = "https://files.pythonhosted.org/packages/ed/85/796b0d6a22d536ec8e14bdbb86519250bad980cec450b6e299b1c2a9079e/yarl-1.18.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe57328fbc1bfd0bd0514470ac692630f3901c0ee39052ae47acd1d90a436719", size = 94981 }, + { url = "https://files.pythonhosted.org/packages/ee/0e/a830fd2238f7a29050f6dd0de748b3d6f33a7dbb67dbbc081a970b2bbbeb/yarl-1.18.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a440a2a624683108a1b454705ecd7afc1c3438a08e890a1513d468671d90a04e", size = 92789 }, + { url = "https://files.pythonhosted.org/packages/0f/4f/438c9fd668954779e48f08c0688ee25e0673380a21bb1e8ccc56de5b55d7/yarl-1.18.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c7907c8548bcd6ab860e5f513e727c53b4a714f459b084f6580b49fa1b9cee", size = 317327 }, + { url = "https://files.pythonhosted.org/packages/bd/79/a78066f06179b4ed4581186c136c12fcfb928c475cbeb23743e71a991935/yarl-1.18.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4f6450109834af88cb4cc5ecddfc5380ebb9c228695afc11915a0bf82116789", size = 336999 }, + { url = "https://files.pythonhosted.org/packages/55/02/527963cf65f34a06aed1e766ff9a3b3e7d0eaa1c90736b2948a62e528e1d/yarl-1.18.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9ca04806f3be0ac6d558fffc2fdf8fcef767e0489d2684a21912cc4ed0cd1b8", size = 331693 }, + { url = "https://files.pythonhosted.org/packages/a2/2a/167447ae39252ba624b98b8c13c0ba35994d40d9110e8a724c83dbbb5822/yarl-1.18.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77a6e85b90a7641d2e07184df5557132a337f136250caafc9ccaa4a2a998ca2c", size = 321473 }, + { url = "https://files.pythonhosted.org/packages/55/03/07955fabb20082373be311c91fd78abe458bc7ff9069d34385e8bddad20e/yarl-1.18.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6333c5a377c8e2f5fae35e7b8f145c617b02c939d04110c76f29ee3676b5f9a5", size = 313571 }, + { url = "https://files.pythonhosted.org/packages/95/e2/67c8d3ec58a8cd8ddb1d63bd06eb7e7b91c9f148707a3eeb5a7ed87df0ef/yarl-1.18.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0b3c92fa08759dbf12b3a59579a4096ba9af8dd344d9a813fc7f5070d86bbab1", size = 325004 }, + { url = "https://files.pythonhosted.org/packages/06/43/51ceb3e427368fe6ccd9eccd162be227fd082523e02bad1fd3063daf68da/yarl-1.18.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:4ac515b860c36becb81bb84b667466885096b5fc85596948548b667da3bf9f24", size = 322677 }, + { url = "https://files.pythonhosted.org/packages/e4/0e/7ef286bfb23267739a703f7b967a858e2128c10bea898de8fa027e962521/yarl-1.18.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:045b8482ce9483ada4f3f23b3774f4e1bf4f23a2d5c912ed5170f68efb053318", size = 332806 }, + { url = "https://files.pythonhosted.org/packages/c8/94/2d1f060f4bfa47c8bd0bcb652bfe71fba881564bcac06ebb6d8ced9ac3bc/yarl-1.18.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:a4bb030cf46a434ec0225bddbebd4b89e6471814ca851abb8696170adb163985", size = 339919 }, + { url = "https://files.pythonhosted.org/packages/8e/8d/73b5f9a6ab69acddf1ca1d5e7bc92f50b69124512e6c26b36844531d7f23/yarl-1.18.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:54d6921f07555713b9300bee9c50fb46e57e2e639027089b1d795ecd9f7fa910", size = 340960 }, + { url = "https://files.pythonhosted.org/packages/41/13/ce6bc32be4476b60f4f8694831f49590884b2c975afcffc8d533bf2be7ec/yarl-1.18.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1d407181cfa6e70077df3377938c08012d18893f9f20e92f7d2f314a437c30b1", size = 336592 }, + { url = "https://files.pythonhosted.org/packages/81/d5/6e0460292d6299ac3919945f912b16b104f4e81ab20bf53e0872a1296daf/yarl-1.18.3-cp39-cp39-win32.whl", hash = "sha256:ac36703a585e0929b032fbaab0707b75dc12703766d0b53486eabd5139ebadd5", size = 84833 }, + { url = "https://files.pythonhosted.org/packages/b2/fc/a8aef69156ad5508165d8ae956736d55c3a68890610834bd985540966008/yarl-1.18.3-cp39-cp39-win_amd64.whl", hash = "sha256:ba87babd629f8af77f557b61e49e7c7cac36f22f871156b91e10a6e9d4f829e9", size = 90968 }, + { url = "https://files.pythonhosted.org/packages/f5/4b/a06e0ec3d155924f77835ed2d167ebd3b211a7b0853da1cf8d8414d784ef/yarl-1.18.3-py3-none-any.whl", hash = "sha256:b57f4f58099328dfb26c6a771d09fb20dbbae81d20cfb66141251ea063bd101b", size = 45109 }, +]