# Train your own classifier with our pre-trained model

In this notebook, we show how to train a text classifier on your own data, starting from our pre-trained model. The pre-trained model used here is multilingual. As with all of ThirdAI's classifiers, inference latency stays under 5 ms per sample on a single CPU thread, even when you pass in a text chunk of 5000 tokens!

For this notebook, we use the CureKart dataset hosted [here](https://github.com/hellohaptik/HINT3/tree/master/dataset/v2). CureKart has 20 unique labels and only about 600 training samples.

Our model achieves a near-SOTA accuracy of 84% while being an order of magnitude faster than SOTA models at inference.

In [None]:
!pip3 install thirdai --upgrade

### Activate ThirdAI's license key

In [3]:
import os

import pandas as pd
from thirdai import bolt, licensing

# Use the key from the environment if it is set; otherwise paste it below.
if "THIRDAI_KEY" in os.environ:
    licensing.activate(os.environ["THIRDAI_KEY"])
else:
    licensing.activate("")  # Enter your ThirdAI key here

### Download Pre-trained Model

In [4]:
import os

# Create the models directory without shelling out; a no-op if it already exists.
os.makedirs("./models/", exist_ok=True)

if not os.path.exists("./models/pretrained_multilingual.model"):
    os.system("wget -nv -O ./models/pretrained_multilingual.model 'https://www.dropbox.com/scl/fi/qem5aqhsh5no6bdb4395a/pretrained_multilingual.model?rlkey=o4cegybi7xc06kj83mhbv8tru&st=8y4v7l9q&dl=0'")
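
The cell above relies on the `wget` binary being installed. On machines without it, a pure-Python download is one alternative. The following is a minimal sketch, not part of the original notebook: `download_if_missing` is a hypothetical helper name, and the demo uses a local `file://` URL as a stand-in so that it runs offline. To use it for the model, pass the same URL and destination path as in the cell above.

```python
import os
import pathlib
import tempfile
import urllib.request


def download_if_missing(url: str, dest: str) -> bool:
    """Fetch `url` into `dest` unless `dest` already exists.

    Hypothetical helper. Returns True if a download happened,
    False if the file was already present.
    """
    if os.path.exists(dest):
        return False
    # Create the parent directory (for example ./models) if needed.
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return True


# Offline demo: a file:// URL stands in for the real model link.
with tempfile.TemporaryDirectory() as tmp:
    src = pathlib.Path(tmp, "source.bin")
    src.write_bytes(b"model bytes")
    dest = os.path.join(tmp, "models", "copy.bin")
    first = download_if_missing(src.as_uri(), dest)   # performs the copy
    second = download_if_missing(src.as_uri(), dest)  # skipped, file exists
    copied = pathlib.Path(dest).read_bytes()
```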


### Load the Pre-trained Model

In [5]:
pretrained_model = bolt.PretrainedBase.load("./models/pretrained_multilingual.model")

In [8]:
train_file = "./datasets/curekart/curekart_train.csv"
test_file = "./datasets/curekart/curekart_test.csv"

In [9]:
df = pd.read_csv(train_file)
n_target_classes = df.label.nunique()
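
Before building the model, it can help to sanity-check the training file. The sketch below assumes, as the column names used in this notebook suggest, that the CSV has a `query` text column and a `label` column, and that integer labels are expected to lie in the range `0` to `n_target_classes - 1`. That range requirement is an assumption here, not something this notebook states. `check_training_frame` is a hypothetical helper, shown on a toy DataFrame rather than the real data.

```python
import pandas as pd


def check_training_frame(df: pd.DataFrame,
                         text_col: str = "query",
                         label_col: str = "label") -> int:
    """Validate the columns and labels, then return the number of classes."""
    missing = {text_col, label_col} - set(df.columns)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")

    labels = df[label_col]
    if not pd.api.types.is_integer_dtype(labels):
        raise ValueError("labels must be integers")

    n_classes = labels.nunique()
    # Assumed requirement: labels are contiguous ids 0 .. n_classes - 1.
    if labels.min() < 0 or labels.max() >= n_classes:
        raise ValueError("labels must lie in [0, n_classes)")
    return n_classes


# Toy data standing in for the CureKart training file.
toy = pd.DataFrame({
    "query": ["where is my order", "cancel order", "refund status", "cancel it"],
    "label": [0, 1, 2, 1],
})
n_classes = check_training_frame(toy)
print(n_classes)
```

With the real data, `check_training_frame(df)` would take the place of the bare `df.label.nunique()` call and fail early on malformed input.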

In [10]:
model = bolt.UniversalDeepTransformer(
    data_types={
        "query": bolt.types.text(),
        "label": bolt.types.categorical(),
    },
    target="label",
    n_target_classes=n_target_classes,
    integer_target=True,
    pretrained_model=pretrained_model,
    options={
        "embedding_dimension": 2000,
    },
)

In [11]:
model.train(train_file, epochs=50, learning_rate=0.001, metrics=["precision@1"])

loading data | source './datasets/curekart/curekart_train.csv'
loading data | source './datasets/curekart/curekart_train.csv' | vectors 599 | batches 1 | time 0.072s | complete

train | epoch 0 | train_steps 1 | train_precision@1=0.0417362  | train_batches 1 | time 2.339s

train | epoch 1 | train_steps 2 | train_precision@1=0.732888  | train_batches 1 | time 0.082s

train | epoch 2 | train_steps 3 | train_precision@1=0.796327  | train_batches 1 | time 0.072s

train | epoch 3 | train_steps 4 | train_precision@1=0.792988  | train_batches 1 | time 0.062s

train | epoch 4 | train_steps 5 | train_precision@1=0.784641  | train_batches 1 | time 0.078s

train | epoch 5 | train_steps 6 | train_precision@1=0.784641  | train_batches 1 | time 0.064s

train | epoch 6 | train_steps 7 | train_precision@1=0.784641  | train_batches 1 | time 0.071s

train | epoch 7 | train_steps 8 | train_precision@1=0.78631  | train_batches 1 | time 0.068s

train | epoch 8 | train_steps 9 | train_precision@1=0.808013  

{'epoch_times': [2.3410000801086426,
  0.08399999886751175,
  0.07400000095367432,
  0.06800000369548798,
  0.07999999821186066,
  0.0689999982714653,
  0.07699999958276749,
  0.07400000095367432,
  0.0729999989271164,
  0.08299999684095383,
  0.057999998331069946,
  0.08500000089406967,
  0.08900000154972076,
  0.06599999964237213,
  0.07800000160932541,
  0.07599999755620956,
  0.07599999755620956,
  0.08299999684095383,
  0.07000000029802322,
  0.0820000022649765,
  0.07100000232458115,
  0.07999999821186066,
  0.08100000023841858,
  0.06300000101327896,
  0.0989999994635582,
  0.05700000002980232,
  0.0820000022649765,
  0.08299999684095383,
  0.06400000303983688,
  0.0729999989271164,
  0.06300000101327896,
  0.07000000029802322,
  0.08100000023841858,
  0.0689999982714653,
  0.08100000023841858,
  0.08100000023841858,
  0.06800000369548798,
  0.09700000286102295,
  0.07000000029802322,
  0.07999999821186066,
  0.07199999690055847,
  0.09099999815225601,
  0.07000000029802322,
  0

In [12]:
metrics = model.evaluate(test_file, metrics=["precision@1"])
print(metrics)

loading data | source './datasets/curekart/curekart_test.csv'
loading data | source './datasets/curekart/curekart_test.csv' | vectors 459 | batches 1 | time 0.018s | complete

validate | epoch 50 | train_steps 50 | val_precision@1=0.840959  | val_batches 1 | time 0.027s

{'val_times': [0.027000000700354576], 'val_precision@1': [0.8409585952758789]}
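
For single-label classification, `precision@1` is the fraction of samples whose top-scoring class matches the true label, which is why it can be read as accuracy here. The sketch below computes it with NumPy on a made-up score matrix. It illustrates the general definition and is not necessarily how the library computes the metric internally; `precision_at_1` is a hypothetical helper.

```python
import numpy as np


def precision_at_1(scores, labels) -> float:
    """Fraction of rows whose highest-scoring class equals the true label."""
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    predicted = scores.argmax(axis=1)  # top-1 class per sample
    return float((predicted == labels).mean())


# Made-up scores for 4 samples over 3 classes.
scores = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.8, 0.1, 0.1],  # predicts class 0
    [0.3, 0.3, 0.4],  # predicts class 2
    [0.2, 0.5, 0.3],  # predicts class 1
]
labels = [1, 0, 1, 1]

result = precision_at_1(scores, labels)
print(result)  # 0.75 (3 of 4 correct)
```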


In [None]:
model.predict({"query":"test query"})
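
To check the latency claim on your own hardware, one option is to time repeated single-sample predictions. The sketch below is a generic timing harness, not a ThirdAI utility: `median_latency_ms` is a hypothetical helper, and `fake_predict` is a stand-in so that the example runs without a trained model.

```python
import statistics
import time


def median_latency_ms(predict_fn, sample, n_runs: int = 100, n_warmup: int = 5) -> float:
    """Median wall-clock time of predict_fn(sample), in milliseconds."""
    # Warm-up calls are excluded so one-time setup cost does not skew the result.
    for _ in range(n_warmup):
        predict_fn(sample)

    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        predict_fn(sample)
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)


calls = []


def fake_predict(sample):
    """Stand-in for a real model's predict call."""
    calls.append(sample)
    return len(sample["query"].split())


latency = median_latency_ms(fake_predict, {"query": "test query"}, n_runs=20)
print(f"median latency: {latency:.4f} ms")
```

With the trained model from this notebook, passing `model.predict` in place of `fake_predict` should give a per-sample figure for your machine.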