# Training a classifier using Model2Vec

Model2Vec supports built-in classifier training with an easy, scikit-learn-style syntax. Just pass your data to `.fit`, and you'll have a trained model!

How it works:
* We load a base `StaticModel` as a torch module. By default we use [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M).
* We add an MLP head with one hidden layer of 512 units and `ReLU` activation.
* We train the model with a cross-entropy loss, using [`pytorch-lightning`](https://lightning.ai/docs/pytorch/stable/) as the training framework.
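To make the steps above concrete, here is a minimal NumPy sketch of the forward pass and the loss. It is not Model2Vec's implementation: the toy sizes, the random weights, and the plain mean pooling over non-padding tokens are all assumptions made for illustration, and the real pooling and training loop may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, chosen only for illustration (the real model is much larger).
vocab_size, embed_dim, hidden_dim, n_classes = 100, 16, 32, 3

# Static token embeddings; index 0 is reserved for padding.
embeddings = rng.normal(size=(vocab_size, embed_dim))
embeddings[0] = 0.0

# MLP head with one hidden layer: Linear -> ReLU -> Linear.
w1 = rng.normal(scale=0.1, size=(embed_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
w2 = rng.normal(scale=0.1, size=(hidden_dim, n_classes))
b2 = np.zeros(n_classes)


def forward(token_ids):
    """Map padded token ids of shape (batch, seq_len) to class logits."""
    mask = (token_ids != 0)[..., None]                    # (batch, seq_len, 1)
    summed = (embeddings[token_ids] * mask).sum(axis=1)   # (batch, embed_dim)
    counts = np.maximum(mask.sum(axis=1), 1)              # avoid dividing by zero
    pooled = summed / counts                              # mean over real tokens
    hidden = np.maximum(pooled @ w1 + b1, 0.0)            # ReLU
    return hidden @ w2 + b2                               # (batch, n_classes)


def cross_entropy(logits, labels):
    """Mean cross-entropy between logits and integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


batch = np.array([[5, 17, 42, 0], [8, 3, 0, 0]])  # two padded "sentences"
labels = np.array([2, 0])
logits = forward(batch)
loss = cross_entropy(logits, labels)
print(logits.shape, float(loss))
```

Training then amounts to lowering this loss by gradient descent on the head and, in the sketch's terms, the embedding table.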

After training, you can save and reload the model using regular torch tools, such as `torch.save` and `torch.load`, or you can export it to a `scikit-learn` pipeline. The latter option leads to a much smaller footprint during inference, as `torch` is no longer needed.

In [1]:
# Install the necessary libraries
!uv pip install "model2vec[train,inference]"
!uv pip install "datasets"
!uv pip install "scikit-learn"

# Import the necessary libraries
from model2vec.train import StaticModelForClassification
from model2vec.inference import StaticModelPipeline


To demonstrate how to train a model, we'll use the `20_newsgroups` dataset, in which each post comes from one of 20 newsgroups.

In [2]:
from datasets import load_dataset

dataset = load_dataset("setfit/20_newsgroups")
print(dataset)

Repo card metadata block was not found. Setting CardData to empty.


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 11314
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 7532
    })
})


Let's take a look at the first five training samples:

In [3]:
# Print the first 5 training samples
for record in dataset["train"].select(range(5)):
    print(f"TEXT: {record['text']} LABEL: {record['label_text']}")

TEXT: I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail. LABEL: rec.autos
TEXT: A fair number of brave souls who upgraded their SI clock oscillator have
shared their experiences for this poll. Please send a brief message detailing
your experiences with the procedure. Top speed attained, CPU rated speed,
add on cards and adapters, heat sinks, hour of usage per day, floppy disk
functionality with 800 and 1.4 m floppies are especially requested.

I will be summarizing in the next two days, so please add to the network
knowledge base if you have done the clock upgrade and 

In [4]:
# Load the base StaticModel and attach a classification head
model = StaticModelForClassification.from_pretrained()
# Optional arguments:
# model_name: the name of the base model (defaults to potion-base-8M)
# n_layers: the number of hidden layers in the MLP (defaults to 1)
# hidden_dim: the number of hidden units per layer (defaults to 512)
print(model)

StaticModelForClassification(
  (embeddings): Embedding(29528, 256, padding_idx=0)
  (head): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=2, bias=True)
  )
)
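As a quick sanity check, the parameter count can be worked out by hand from the printout above. The sketch below counts only the modules shown there; any parameters that do not appear in the printout are not included. The 20-class variant assumes the output layer is resized to the number of labels once the model is fitted on the 20 newsgroups, which is an assumption rather than something the printout confirms.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# Sizes read off the printed model.
embedding_params = 29528 * 256            # Embedding(29528, 256)
hidden_params = linear_params(256, 512)   # Linear(256 -> 512)

# Output layer as printed (2 outputs) and as assumed after fitting (20 classes).
total_as_printed = embedding_params + hidden_params + linear_params(512, 2)
total_with_20_classes = embedding_params + hidden_params + linear_params(512, 20)

print(f"{total_as_printed:,}")       # 7,691,778
print(f"{total_with_20_classes:,}")  # 7,701,012
```

Both totals round to 7.7 M, which is in line with the trainable parameter count in the training log further down. Almost all of the parameters sit in the embedding table; the head itself is tiny.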


Now let's train the model on a subset of examples. We pick the first 1000 examples to train on.

In [5]:
import time

# Fit the model on the first 1000 records
subset = dataset["train"].select(range(1000))
start = time.time()
model = model.fit(subset["text"], subset["label_text"])
print(f"training took {time.time() - start} seconds")
# `fit` accepts many more optional arguments, check them out!

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name  | Type                         | Params | Mode 
---------------------------------------------------------------
0 | model | StaticModelForClassification | 7.7 M  | train
---------------------------------------------------------------
7.7 M     Trainable params
0         Non-trainable params

/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/stephantulkens/Documents/GitHub/model2vec/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (29) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training 

training took 8.715388059616089 seconds


We have trained a classifier in about 9 seconds. Nice!

Let's take a look at how good it is.

In [6]:
from sklearn.metrics import classification_report

predictions = model.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))

                          precision    recall  f1-score   support

             alt.atheism       0.27      0.38      0.31       319
           comp.graphics       0.62      0.55      0.58       389
 comp.os.ms-windows.misc       0.47      0.50      0.48       394
comp.sys.ibm.pc.hardware       0.50      0.49      0.49       392
   comp.sys.mac.hardware       0.47      0.47      0.47       385
          comp.windows.x       0.75      0.57      0.65       395
            misc.forsale       0.69      0.75      0.72       390
               rec.autos       0.46      0.67      0.54       396
         rec.motorcycles       0.69      0.56      0.61       398
      rec.sport.baseball       0.73      0.72      0.72       397
        rec.sport.hockey       0.82      0.76      0.79       399
               sci.crypt       0.60      0.62      0.61       396
         sci.electronics       0.42      0.47      0.44       393
                 sci.med       0.68      0.75      0.71       396
         

Our model reaches an accuracy of 0.57. But what does this mean? To put the number in context, let's compare it to a `tf-idf` pipeline from `scikit-learn`, trained on the same 1000 examples.

In [7]:
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

sklearn_pipeline = make_pipeline(TfidfVectorizer(), LogisticRegression())
sklearn_pipeline.fit(subset["text"], subset["label_text"])
predictions = sklearn_pipeline.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))

                          precision    recall  f1-score   support

             alt.atheism       0.21      0.17      0.19       319
           comp.graphics       0.73      0.30      0.43       389
 comp.os.ms-windows.misc       0.44      0.57      0.50       394
comp.sys.ibm.pc.hardware       0.67      0.18      0.28       392
   comp.sys.mac.hardware       0.29      0.67      0.41       385
          comp.windows.x       0.83      0.41      0.55       395
            misc.forsale       0.49      0.79      0.61       390
               rec.autos       0.61      0.57      0.59       396
         rec.motorcycles       0.80      0.40      0.53       398
      rec.sport.baseball       0.29      0.64      0.40       397
        rec.sport.hockey       0.77      0.61      0.68       399
               sci.crypt       0.71      0.48      0.57       396
         sci.electronics       0.32      0.31      0.31       393
                 sci.med       0.66      0.31      0.42       396
         

Pretty good! We outperform the tf-idf pipeline by a wide margin.
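Another way to put an accuracy score in context is to compare it against trivial baselines. With 20 classes of roughly equal size, random guessing lands at about 1/20 = 0.05. The sketch below computes accuracy and a majority-class baseline in plain Python; the helper names and the toy labels are made up for illustration and are not the notebook's data.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def majority_baseline(y_train, y_test):
    """Accuracy of always predicting the most frequent training label."""
    most_common_label, _ = Counter(y_train).most_common(1)[0]
    return accuracy(y_test, [most_common_label] * len(y_test))


# Toy labels, for illustration only.
y_train = ["rec.autos", "sci.med", "sci.med", "sci.crypt"]
y_test = ["sci.med", "rec.autos", "sci.med", "sci.crypt", "sci.med"]
y_pred = ["sci.med", "rec.autos", "sci.crypt", "sci.crypt", "sci.med"]

model_accuracy = accuracy(y_test, y_pred)               # 4 of 5 correct
baseline_accuracy = majority_baseline(y_train, y_test)  # always "sci.med"
chance_level = 1 / 20                                   # 20 balanced classes

print(model_accuracy, baseline_accuracy, chance_level)
```

In the notebook, the same helpers could be called with `subset["label_text"]`, `dataset["test"]["label_text"]`, and the predictions of either model.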

We can now export the model to a scikit-learn pipeline and push it to the hub. But first, let's verify that the predictions of the exported pipeline match those of the original model.

In [8]:
pipeline = model.to_pipeline()

predictions = pipeline.predict(dataset["test"]["text"])

print(classification_report(dataset["test"]["label_text"], predictions))

                          precision    recall  f1-score   support

             alt.atheism       0.27      0.38      0.31       319
           comp.graphics       0.62      0.55      0.58       389
 comp.os.ms-windows.misc       0.47      0.50      0.48       394
comp.sys.ibm.pc.hardware       0.50      0.49      0.49       392
   comp.sys.mac.hardware       0.47      0.47      0.47       385
          comp.windows.x       0.75      0.57      0.65       395
            misc.forsale       0.69      0.75      0.72       390
               rec.autos       0.46      0.67      0.54       396
         rec.motorcycles       0.69      0.56      0.61       398
      rec.sport.baseball       0.73      0.72      0.72       397
        rec.sport.hockey       0.82      0.76      0.79       399
               sci.crypt       0.60      0.62      0.61       396
         sci.electronics       0.42      0.47      0.44       393
                 sci.med       0.68      0.75      0.71       396
         

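Comparing two classification reports by eye works, but a more direct check is to compare the two prediction lists element by element. The helper below is a small sketch of that idea; its name and the toy inputs are hypothetical.

```python
def prediction_agreement(preds_a, preds_b):
    """Return the agreement rate and the indices where two prediction lists differ."""
    if len(preds_a) != len(preds_b):
        raise ValueError("both prediction lists must have the same length")
    mismatches = [i for i, (a, b) in enumerate(zip(preds_a, preds_b)) if a != b]
    rate = 1.0 - len(mismatches) / len(preds_a) if len(preds_a) else 1.0
    return rate, mismatches


# Toy predictions, for illustration only.
torch_preds = ["sci.med", "rec.autos", "sci.crypt", "sci.med"]
sklearn_preds = ["sci.med", "rec.autos", "sci.crypt", "rec.autos"]

rate, mismatches = prediction_agreement(torch_preds, sklearn_preds)
print(rate, mismatches)
```

In the notebook, `model.predict(dataset["test"]["text"])` and `pipeline.predict(dataset["test"]["text"])` would take the place of the toy lists, and an agreement rate of 1.0 would mean the export changed nothing.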
Now let's save the pipeline locally, or push it to the hub!

In [9]:
pipeline.save_pretrained("my_cool_model")
# Fill in your own org
# pipeline.push_to_hub("my_org/my_model")

This saves the pipeline to a local folder. It can then be loaded as follows:

In [10]:
new_model = StaticModelPipeline.from_pretrained("my_cool_model")
# Or from the hub
# new_model = StaticModelPipeline.from_pretrained("my_org/my_model")

One reason to work like this is that `StaticModelPipeline` does not require `torch` to be installed at all, which leads to really fast cold-start predictions, smaller images, and a lot less hassle overall.

And that's it! Super fast, super small, super good classifiers.