# Training a classifier using Model2Vec

Model2Vec supports built-in classifier training with a simple, scikit-learn-style API. Pass your data to `.fit`, and you get a trained model!

How it works:
* We load a base `StaticModel` as a torch module. By default, we use [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M).
* We add a one-layer MLP with 512 hidden units and `ReLU` activation as a head.
* We train the model with a cross-entropy loss, using [`pytorch-lightning`](https://lightning.ai/docs/pytorch/stable/) as the training framework.
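
The steps above can be sketched in plain NumPy as a rough illustration. This is not Model2Vec's implementation: the sizes, the random weights, and the helper names `forward` and `cross_entropy` are all made up for this example, and mean pooling over token embeddings is an assumption about how the static model produces a sentence vector.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only to keep the example small.
vocab_size, embed_dim, hidden_dim, n_classes = 100, 16, 512, 3

# Stand-ins for the static token embeddings and the MLP head weights.
embeddings = rng.normal(size=(vocab_size, embed_dim))
w_hidden = rng.normal(scale=0.1, size=(embed_dim, hidden_dim))
b_hidden = np.zeros(hidden_dim)
w_out = rng.normal(scale=0.1, size=(hidden_dim, n_classes))
b_out = np.zeros(n_classes)


def forward(token_ids):
    """Return class probabilities for one tokenized text."""
    # Assumed pooling: average the static embeddings of the tokens.
    sentence_vector = embeddings[token_ids].mean(axis=0)
    # One hidden layer with ReLU activation.
    hidden = np.maximum(sentence_vector @ w_hidden + b_hidden, 0.0)
    logits = hidden @ w_out + b_out
    # Numerically stable softmax.
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


def cross_entropy(probs, label):
    """Negative log-likelihood of the true class."""
    return -np.log(probs[label])


probs = forward(np.array([3, 17, 42]))
loss = cross_entropy(probs, label=1)
print(probs.shape, float(loss))
```

During training, the loss is averaged over a batch and minimized with gradient descent; in the real library, torch and `pytorch-lightning` handle that part.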

After training, you can export the model using regular torch tools, such as `torch.save` and `torch.load`, or you can export the model to a `scikit-learn` pipeline. The latter option leads to a really small footprint during inference, as there is no longer a need to use `torch`.

In [None]:
# Install the necessary libraries
!uv pip install "model2vec[train,inference]"
!uv pip install "datasets"
!uv pip install "scikit-learn"

# Import the necessary libraries
from model2vec.train import StaticModelForClassification
from model2vec.inference import StaticModelPipeline

To demonstrate how to train a model, we'll be using the `20_newsgroups` dataset, which contains newsgroup posts, each labeled with one of 20 topics.

In [None]:
from datasets import load_dataset

dataset = load_dataset("setfit/20_newsgroups")
print(dataset)

Let's take a look at the first five training samples:

In [None]:
# First 5 training samples:
for record in dataset["train"].select(range(5)):
    print(f"TEXT: {record['text']} LABEL: {record['label_text']}")

In [None]:
# Define the staticmodel
model = StaticModelForClassification.from_pretrained()
# Optional arguments:
# model_name: the name of the base model (defaults to potion-base-8M)
# n_layers: the number of layers in the MLP (defaults to 1)
# hidden_dim: the number of hidden units (defaults to 512)
print(model)

Now let's train the model on a subset of examples. We pick the first 1000 examples to train on.

In [None]:
import time

# Fit the model on the first 1000 records
subset = dataset["train"].select(range(1000))
start = time.perf_counter()
model = model.fit(subset["text"], subset["label_text"])
print(f"Training took {time.perf_counter() - start:.1f} seconds")
# `fit` accepts many optional arguments; check its docstring for details.

We have trained a classifier in five seconds. Nice!

Let's take a look at how good it is.

In [None]:
from sklearn.metrics import classification_report

predictions = model.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))

Our model scores 0.55 accuracy. But what does this mean? Let's compare it to a `tf-idf` pipeline from `scikit-learn`.

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

sklearn_pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())
sklearn_pipeline.fit(subset["text"], subset["label_text"])
predictions = sklearn_pipeline.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))

Pretty good! We outperform the tf-idf pipeline by a wide margin.
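
Another way to put an accuracy figure in context is to compare it with trivial baselines. The sketch below uses made-up toy labels and two hypothetical helpers, `accuracy` and `majority_baseline_accuracy`, neither of which is part of Model2Vec or scikit-learn. It also shows the chance level for 20 classes, assuming the classes are about equally frequent.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that equal the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def majority_baseline_accuracy(y_train, y_test):
    """Accuracy of always predicting the most frequent training label."""
    most_common_label, _ = Counter(y_train).most_common(1)[0]
    return accuracy(y_test, [most_common_label] * len(y_test))


# Toy labels, made up for illustration.
toy_train = ["sci", "sci", "rec", "talk"]
toy_test = ["sci", "rec", "rec", "talk"]

print(majority_baseline_accuracy(toy_train, toy_test))  # 0.25
# With k equally frequent classes, uniform random guessing gets 1 / k right.
print(1 / 20)  # 0.05
```

In this notebook, the same helper could be called as `majority_baseline_accuracy(subset["label_text"], dataset["test"]["label_text"])`.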

We can now export the model to scikit-learn, and push it to the hub. But first, let's verify whether the predictions of this model and the original model match.

In [None]:
pipeline = model.to_pipeline()

predictions = pipeline.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))
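
The classification report only shows whether the two models reach the same aggregate scores. To check agreement prediction by prediction, a small helper like the one below could be used. `agreement_rate` is a hypothetical helper, not part of Model2Vec, and the toy predictions are made up for illustration.

```python
def agreement_rate(predictions_a, predictions_b):
    """Fraction of positions where two prediction sequences agree."""
    if len(predictions_a) != len(predictions_b):
        raise ValueError("Both prediction sequences must have the same length.")
    matches = sum(a == b for a, b in zip(predictions_a, predictions_b))
    return matches / len(predictions_a)


# Toy predictions, made up for illustration.
torch_predictions = ["sci", "rec", "talk", "sci"]
pipeline_predictions = ["sci", "rec", "talk", "rec"]
print(agreement_rate(torch_predictions, pipeline_predictions))  # 0.75
```

In this notebook, the call would be `agreement_rate(model.predict(dataset["test"]["text"]), pipeline.predict(dataset["test"]["text"]))`; a value of 1.0 means the exported pipeline reproduces every prediction of the original model.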

Ok, so let's save the model locally, or push it to the hub!

In [None]:
pipeline.save_pretrained("my_cool_model")
# Fill in your own org
# pipeline.push_to_hub("my_org/my_model")

This saves a model to a local folder. The model can then be loaded as follows:

In [None]:
new_model = StaticModelPipeline.from_pretrained("my_cool_model")
# Or from the hub
# new_model = StaticModelPipeline.from_pretrained("my_org/my_model")

One reason to work like this is that the `StaticModelPipeline` does not require torch to be installed at all, leading to really fast cold start predictions, smaller images, and a lot less hassle overall.

And that's it! Super fast, super small, super good classifiers.