# Pre-training and fine-tuning an LLM on CPU on AG News with ThirdAI's UDT

In this notebook, we will pre-train an LLM from scratch on the popular AG News dataset (https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset) using ThirdAI's Universal Deep Transformer (UDT). We will demonstrate that UDT, pre-trained on just this small dataset, can outperform OpenAI's Semantic Search offering.

This demo shows that a single model for all tasks is sub-optimal, and that pre-training or fine-tuning on the specific downstream dataset is needed to get the best results.

While many LLMs are too large to fine-tune even on a powerful GPU, ThirdAI's UDT can train a billion-parameter model on a moderate CPU in a few minutes.

### Import thirdai and activate license

In [None]:
from thirdai import bolt, licensing

licensing.activate('71FC4B-F20E8F-D7C39E-4E936C-404BC9-V3')

### Download and process the dataset into a CSV file

In [None]:
from thirdai.demos import download_agnews_dataset

corpus_file = './datasets/agnews.csv'
n_target_classes = download_agnews_dataset(corpus_file)

In the above step, *corpus_file* is the path of a CSV file with a document id and the document text on each row; it can also contain additional columns with other metadata. Pre-training with UDT supports two types of text columns: strong and weak. For this demo, we use *text* as the strong column and leave the list of weak columns empty.

A couple of sample rows of the *corpus_file* are shown below.

PLEASE NOTE: Currently, UDT's cold_start function requires the *id* to be an integer. We will add support for other formats in a future release.

In [2]:
import pandas as pd

pd.options.display.max_colwidth = 700
pd.read_csv(corpus_file, nrows=2)

| Unnamed: 0 | id | text |
|---|---|---|
| 0 | 0 | wall st. bears claw back into the black (reuters) reuters - short-sellers wall street's dwindling\band of ultra-cynics are seeing green again. |
| 1 | 1 | carlyle looks toward commercial aerospace (reuters) reuters - private investment firm carlyle group \which has a reputation for making well-timed and occasionally\controversial plays in the defense industry has quietly placed\its bets on another part of the market. |
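Because `cold_start` currently needs integer ids, a corpus keyed by strings (URLs, UUIDs, file names) has to be remapped before pre-training. Below is a minimal standard-library sketch of that preprocessing step, not part of ThirdAI's API: the helper name `write_integer_id_corpus`, the example keys, and the file name `my_corpus.csv` are hypothetical. It assigns contiguous ids starting at 0, writes an `id`/`text` CSV shaped like the sample above, and returns the mapping so predicted ids can be translated back to the original keys.

```python
import csv


def write_integer_id_corpus(docs, out_path):
    """Write (key, text) pairs to a CSV with integer ids 0..n-1.

    Hypothetical helper. Returns a list where position i holds the
    original key of id i, so predicted ids can be mapped back.
    """
    id_to_key = []
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # 'id' must match the model's target column; 'text' is the column
        # later passed to cold_start as a strong column.
        writer.writerow(["id", "text"])
        for new_id, (key, text) in enumerate(docs):
            writer.writerow([new_id, text])
            id_to_key.append(key)
    return id_to_key


# Hypothetical documents keyed by strings rather than integers.
docs = [
    ("doc-a7f3", "wall st. bears claw back into the black"),
    ("doc-19c2", "carlyle looks toward commercial aerospace"),
]
id_to_key = write_integer_id_corpus(docs, "my_corpus.csv")
print(id_to_key[1])  # doc-19c2
```

Keeping the ids contiguous from 0 also keeps them consistent with `n_target_classes` when the model is defined with `integer_target=True`.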


### Define a UDT model

The column name *query* is arbitrary; choose any name you like.
The column name *id* is the target and must match the id column in the header of *corpus_file*.

In [None]:
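model = bolt.UniversalDeepTransformer(
    data_types={
        # Input column: free-text queries used at prediction time.
        "query": bolt.types.text(),
        # Target column: document ids; ':' separates multiple ids in one cell.
        "id": bolt.types.categorical(delimiter=':'),
    },
    target="id",
    n_target_classes=n_target_classes,  # number of distinct ids, returned by download_agnews_dataset
    integer_target=True,  # ids are integers, as cold_start requires
    model_config='./configs/embeddings_and_cold_start_0.005.config',  # this config file must exist locally
)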
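Here `download_agnews_dataset` supplies `n_target_classes`. If you bring your own corpus, a value consistent with `integer_target=True` can be derived from the CSV itself. The sketch below assumes one integer id per row (no `:`-separated multi-id cells) and that ids start at 0, so the number of classes is the largest id plus one; the helper name `count_target_classes` and the file `tiny_corpus.csv` are hypothetical. It also fails early if the target column named in `data_types` is missing from the header.

```python
import csv


def count_target_classes(corpus_file, target="id"):
    """Return max(id) + 1 for a CSV whose target column holds integer ids.

    Hypothetical helper; assumes one non-negative integer id per row,
    ideally contiguous from 0.
    """
    with open(corpus_file, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if target not in (reader.fieldnames or []):
            raise ValueError(f"column {target!r} not found in header {reader.fieldnames}")
        max_id = -1
        for row in reader:
            value = int(row[target])  # raises ValueError on non-integer ids
            if value < 0:
                raise ValueError(f"negative id {value}")
            max_id = max(max_id, value)
    return max_id + 1


# Tiny hypothetical corpus for illustration.
with open("tiny_corpus.csv", "w", newline="", encoding="utf-8") as f:
    f.write("id,text\n0,first doc\n1,second doc\n2,third doc\n")

print(count_target_classes("tiny_corpus.csv"))  # 3
```

`csv.DictReader` ignores extra columns such as the `Unnamed: 0` column in the AG News file, so only the target column is inspected.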

### Pre-train (Cold Start) on the *corpus_file*

In the following step, we pre-train the model by passing the strong and weak columns to *cold_start*. As described above, *text* is the strong column and the list of weak columns is empty; either list can contain more columns. The per-epoch training time and training accuracy are shown below.

In [5]:
model.cold_start(
    filename=corpus_file,
    strong_column_names=["text"],
    weak_column_names=[],
    learning_rate=0.001,
    epochs=5,
    metrics=['categorical_accuracy'],
)

loading data | source './datasets/agnews.csv'
loaded data | source './datasets/agnews.csv' | vectors 120000 | batches 59 | time 0s | complete

train | epoch 0 | train_steps 59 | {categorical_accuracy: 0.00035} | train_batches 59 | time 103s | complete

train | epoch 1 | train_steps 118 | {categorical_accuracy: 0.068775} | train_batches 59 | time 79s | complete

train | epoch 2 | train_steps 177 | {categorical_accuracy: 0.350508} | train_batches 59 | time 85s | complete

train | epoch 3 | train_steps 236 | {categorical_accuracy: 0.656625} | train_batches 59 | time 86s | complete

train | epoch 4 | train_steps 295 | {categorical_accuracy: 0.825342} | train_batches 59 | time 86s | complete



{'epoch_times': [79.0, 85.0, 86.0, 86.0],
 'categorical_accuracy': [0.068775,
  0.3505083333333333,
  0.656625,
  0.8253416666666666]}
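The log above reports the time of each epoch; summing them gives the total pre-training cost on this CPU. A small standard-library sketch that parses those lines follows. The log text is copied from the output above, and the regex is written against that format only, which may differ in other thirdai versions.

```python
import re

# Per-epoch lines copied from the cold_start log above.
LOG = """\
train | epoch 0 | train_steps 59 | {categorical_accuracy: 0.00035} | train_batches 59 | time 103s | complete
train | epoch 1 | train_steps 118 | {categorical_accuracy: 0.068775} | train_batches 59 | time 79s | complete
train | epoch 2 | train_steps 177 | {categorical_accuracy: 0.350508} | train_batches 59 | time 85s | complete
train | epoch 3 | train_steps 236 | {categorical_accuracy: 0.656625} | train_batches 59 | time 86s | complete
train | epoch 4 | train_steps 295 | {categorical_accuracy: 0.825342} | train_batches 59 | time 86s | complete
"""

# '.' does not match newlines, so each match stays within one log line.
EPOCH_RE = re.compile(r"epoch (\d+) .*?categorical_accuracy: ([\d.]+)\}.*?time (\d+)s")


def parse_epochs(log):
    """Return (epoch, accuracy, seconds) tuples for each epoch line."""
    return [
        (int(m.group(1)), float(m.group(2)), int(m.group(3)))
        for m in EPOCH_RE.finditer(log)
    ]


epochs = parse_epochs(LOG)
total_seconds = sum(seconds for _, _, seconds in epochs)
print(f"{len(epochs)} epochs, {total_seconds}s total (~{total_seconds / 60:.1f} min)")
# 5 epochs, 439s total (~7.3 min)
```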

### Save and load the model

In [6]:
model.save('./agnews.model')

model = bolt.UniversalDeepTransformer.load('./agnews.model')

### Make Predictions

In [9]:
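import numpy as np
import pandas as pd

# Score every document id against the query (one activation per id).
activations = model.predict({'query': 'british prime minister'})
# Ids of the 5 highest activations, best first.
top_preds = np.argsort(-activations)[:5]

# Look up the retrieved documents; ids start at 0 and match row positions here.
df = pd.read_csv(corpus_file)
df.iloc[top_preds]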

| Unnamed: 0 | id | text |
|---|---|---|
| 117040 | 117040 | top uk minister blunkett quits ahead of election (reuters) reuters - senior british government minister david\blunkett resigned on wednesday ripping a hole in prime\minister tony blair's team months before an expected general\election. |
| 3372 | 3372 | bomb is defused near a villa where berlusconi met with blair ome aug. 18 - the police defused a bomb early wednesday morning in porto rotondo sardinia the town where hours earlier prime minister silvio berlusconi had entertained the british prime minister tony blair and his wife cherie. |
| 11371 | 11371 | vote for pakistan prime minister pakistan's parliament is expected on friday to vote for outgoing finance minister shaukat aziz as prime minister. |
| 2061 | 2061 | pm #39;a trustworthy invidual #39; prime minister john howard was a trustworthy individual and most australians knew it deputy prime minister john anderson said. |
| 106055 | 106055 | myanmar overshadows trade deals at summit australian prime minister john howard (2nd l) sits next to (lr round table) thai prime minister thaksin shiwanatra singaporean pm lee hsien loong philippine president gloria arroyo myanmar prime minister soe win and malaysian prime minister abullah |
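The prediction cell sorts all 120,000 scores with `np.argsort` just to keep five. For top-k retrieval, `np.argpartition` can move the k largest scores to the end in linear time, after which only those k need sorting. The sketch below uses toy scores standing in for the output of `model.predict` (assumed, as in the cell above, to be a 1-D array with one score per id); the helper name `top_k_ids` is hypothetical.

```python
import numpy as np


def top_k_ids(activations: np.ndarray, k: int = 5) -> np.ndarray:
    """Return the indices of the k largest scores, highest first (k >= 1)."""
    k = min(k, activations.shape[0])
    # argpartition places the k largest scores in the last k slots
    # without fully sorting all n scores.
    candidates = np.argpartition(activations, -k)[-k:]
    # Sort only those k candidates by descending score.
    return candidates[np.argsort(-activations[candidates])]


# Toy scores standing in for model.predict(...) output, one per document id.
scores = np.array([0.1, 0.7, 0.05, 0.9, 0.3, 0.6])
print(top_k_ids(scores, k=3))  # [3 1 5]
```

For five results out of 120,000 ids the difference is small in absolute terms, but the partial selection avoids sorting scores that are discarded anyway; with tied scores the order among ties may differ from `np.argsort`.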
