# Train your own classifier with our pre-trained model

In this notebook, we will show how to train a text classifier on your own data, starting from our pre-trained model. The pre-trained model used here is multilingual. As with all of ThirdAI's classifiers, inference latency stays under 5 ms per sample on a single CPU thread, even when you pass in a text chunk of 5,000 tokens!

For this notebook, we will be using the CureKart dataset hosted [here](https://github.com/hellohaptik/HINT3/tree/master/dataset/v2). CureKart has 20 unique labels with just 600 training samples.

Our model achieves a near-SOTA accuracy of 84% while being an order of magnitude faster at inference than the SOTA models.
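To check a per-sample latency figure on your own hardware, you can time individual predictions. The helper below is a minimal, hypothetical sketch (`measure_latency_ms` is not part of ThirdAI's API): it accepts any single-sample predict function, such as the trained model's `model.predict`, and reports wall-clock latency in milliseconds. It does not control how many CPU threads the library uses, so treat it as a rough comparison rather than a reproduction of the benchmark above.

```python
import statistics
import time
from typing import Callable, Dict, List


def measure_latency_ms(predict: Callable[[Dict[str, str]], object],
                       samples: List[Dict[str, str]],
                       warmup: int = 5) -> Dict[str, float]:
    """Time single-sample predictions and return latency statistics in ms."""
    # Warm-up calls are excluded so one-time setup costs do not skew results
    for sample in samples[:warmup]:
        predict(sample)

    timings = []
    for sample in samples:
        start = time.perf_counter()
        predict(sample)
        timings.append((time.perf_counter() - start) * 1000.0)

    return {
        "mean_ms": statistics.mean(timings),
        "median_ms": statistics.median(timings),
        "max_ms": max(timings),
    }


if __name__ == "__main__":
    # Stand-in predictor; swap in model.predict once the model is trained
    def dummy_predict(sample):
        return len(sample["query"])

    queries = [{"query": "how do I track my order"}] * 100
    print(measure_latency_ms(dummy_predict, queries))
```

Timing one call at a time (rather than a batch) matches how the latency claim is stated: per sample.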

In [None]:
!pip3 install thirdai --upgrade

### Activate ThirdAI's license key

In [None]:
from thirdai import bolt, licensing
import pandas as pd

import os
if "THIRDAI_KEY" in os.environ:
    licensing.activate(os.environ["THIRDAI_KEY"])
else:
    licensing.activate("")  # Enter your ThirdAI key here

### Download Pre-trained Model

In [None]:
import os

# Create the local models directory if it does not already exist
os.makedirs("./models/", exist_ok=True)

if not os.path.exists("./models/pretrained_multilingual.model"):
    os.system("wget -nv -O ./models/pretrained_multilingual.model 'https://www.dropbox.com/scl/fi/qem5aqhsh5no6bdb4395a/pretrained_multilingual.model?rlkey=o4cegybi7xc06kj83mhbv8tru&st=8y4v7l9q&dl=0'")
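The cell above shells out to `wget`, which must be installed, and `os.system` does not raise if the download fails. If you prefer to stay in Python, the hypothetical helper below sketches the same "download only if missing" logic with the standard library's `urllib.request`. It writes to a temporary `.part` file first so an interrupted download is not mistaken for a complete model on the next run. It does not verify that the bytes it receives are a valid model file.

```python
import os
import urllib.request


def download_if_missing(url: str, dest_path: str) -> bool:
    """Download url to dest_path unless a non-empty file is already there.

    Returns True if a download happened, False if it was skipped.
    """
    if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
        return False
    os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
    # Download to a temporary name, then atomically move it into place
    tmp_path = dest_path + ".part"
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, dest_path)
    return True
```

Usage: `download_if_missing(model_url, "./models/pretrained_multilingual.model")`, where `model_url` is the Dropbox link from the cell above.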


### Load the Pre-trained Model

In [None]:
pretrained_model = bolt.PretrainedBase.load("./models/pretrained_multilingual.model")

### Load the Dataset
We have bundled the pre-processed train and test CSV files in the repo. As mentioned earlier, the original dataset can be found [here](https://github.com/hellohaptik/HINT3/tree/master/dataset/v2).

In [None]:
train_file = "./datasets/curekart/curekart_train.csv"
test_file = "./datasets/curekart/curekart_test.csv"
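The UDT defined below declares a `query` text column and a `label` column in `data_types`. Assuming those names must match the CSV header exactly, a quick check before training can catch a mismatched file early. `missing_columns` is a hypothetical helper using only the standard library's `csv` module.

```python
import csv
from typing import Iterable, List, Set


def missing_columns(csv_path: str,
                    required: Iterable[str] = ("query", "label")) -> List[str]:
    """Return the required column names that are absent from the CSV header."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f), [])
    present: Set[str] = {name.strip() for name in header}
    return [col for col in required if col not in present]
```

Usage: `for path in (train_file, test_file): assert not missing_columns(path), path`.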

In [None]:
df = pd.read_csv(train_file)
# The number of distinct labels in the training set sets the classifier's output size
n_target_classes = df.label.nunique()
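`nunique()` counts distinct labels, not the largest label id. With `type="int"`, we assume the integer labels are expected to lie in `0 .. n_target_classes - 1`; if your own dataset uses non-contiguous ids (say `3, 7, 10`), the count would undersize the output. The hypothetical helpers below sketch a check and a remapping to contiguous ids.

```python
from typing import Dict, List, Sequence, Tuple


def labels_are_contiguous(labels: Sequence[int]) -> bool:
    """True if the distinct labels are exactly 0, 1, ..., k-1."""
    distinct = sorted(set(labels))
    return distinct == list(range(len(distinct)))


def remap_labels(labels: Sequence[int]) -> Tuple[List[int], Dict[int, int]]:
    """Map each distinct label to an id in 0..k-1, ordered by original value."""
    mapping = {old: new for new, old in enumerate(sorted(set(labels)))}
    return [mapping[label] for label in labels], mapping
```

Usage: `labels_are_contiguous(df.label.tolist())`. If it returns `False`, apply the same `mapping` to both the train and test labels and write out new CSVs before training; a test label missing from the training set would raise a `KeyError`.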

### Define a UDT (UniversalDeepTransformer) with the pre-trained model

In [None]:
model = bolt.UniversalDeepTransformer(
    data_types={
        "query": bolt.types.text(),
        "label": bolt.types.categorical(n_classes=n_target_classes, type="int"),
    },
    target="label",
    pretrained_model=pretrained_model,
)

### Train the model

In [None]:
model.train(train_file, epochs=50, learning_rate=0.001, metrics=["precision@1"])
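Training and evaluation both report `precision@1`: the fraction of samples whose single highest-scoring class is the true label. With exactly one true label per sample, as in CureKart, this is the same as accuracy. The function below is a small sketch of that standard definition, assuming one score per class for each sample; we assume ThirdAI's metric computes the same quantity.

```python
from typing import Sequence


def precision_at_1(scores: Sequence[Sequence[float]],
                   labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        return 0.0
    hits = 0
    for row, label in zip(scores, labels):
        # Index of the largest score; ties resolve to the lowest class id
        top = max(range(len(row)), key=row.__getitem__)
        hits += int(top == label)
    return hits / len(labels)
```

For example, scores `[[0.1, 0.9], [0.8, 0.2]]` with labels `[1, 0]` give a precision@1 of `1.0`.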

### Evaluate the model

In [None]:
metrics = model.evaluate(test_file, metrics=["precision@1"])
print(metrics)

In [None]:
model.predict({"query": "test query"})
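To turn a prediction into readable results, you usually want the few most likely classes. The sketch below assumes, without confirmation from this notebook, that `model.predict` returns one score per class indexed by label id (for example, an array of length `n_target_classes`); check the ThirdAI documentation for the actual return type. `top_k` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def top_k(scores: Sequence[float], k: int = 3) -> List[Tuple[int, float]]:
    """Return the k (class_id, score) pairs with the highest scores, best first."""
    # sorted() is stable, so tied scores keep ascending class-id order
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

Usage, under the assumption above: `top_k(list(model.predict({"query": "test query"})), k=3)`.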