# [Distributed] Pre-training and fine-tuning an LLM on CPU on AG News with ThirdAI's UDT

This notebook shows how to run data-parallel pre-training of an LLM from scratch on the popular AG News dataset (https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset) using ThirdAI's Universal Deep Transformer (UDT).

## 1. Import thirdai and activate license

In [1]:
!pip3 install datasets
# !pip3 install thirdai --upgrade
!pip3 install ray
!pip3 install torch

import thirdai
from thirdai import bolt
import thirdai.distributed_bolt as dist  
# thirdai.licensing.activate('71FC4B-F20E8F-D7C39E-4E936C-404BC9-V3')





## 2. Ray Cluster Initialization
For the purpose of this demo, we initialize a mock Ray cluster of 2 worker nodes on the local machine. Change *cpus_per_node* and *num_workers* to match your hardware.

In [2]:
import ray
from ray.air import ScalingConfig, session

cpus_per_node = (dist.get_num_cpus() - 1) // 2

ray.init(ignore_reinit_error=True)
scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=False,
    trainer_resources={"CPU": 1},
    resources_per_worker={"CPU": cpus_per_node},
    placement_strategy="PACK",
)

2023-08-08 08:17:31,019	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265
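
The CPU split used above reserves one core for the trainer and divides the remainder evenly between the two workers. Below is a minimal sketch of that arithmetic. It uses `os.cpu_count()` as a stand-in for `dist.get_num_cpus()` (an assumption; the two may report different values), and `cpus_per_worker` is a hypothetical helper, not part of ThirdAI or Ray.

```python
import os


def cpus_per_worker(total_cpus, num_workers, reserved_for_trainer=1):
    """Split the CPUs left after the trainer's share evenly across workers.

    Hypothetical helper for illustration. Unlike the notebook's one-liner,
    it never returns less than 1 so that each worker can still be scheduled.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    available = total_cpus - reserved_for_trainer
    return max(available // num_workers, 1)


# os.cpu_count() can return None, so fall back to a single CPU.
total = os.cpu_count() or 1
print(cpus_per_worker(total, num_workers=2))
```

With more workers, pass the same `num_workers` value to both this helper and `ScalingConfig` so the resource request stays consistent.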


## 3. Download and process the dataset into CSV files
### 3.1 We divide our dataset into 2 datasets for the purpose of data-parallel training.

In [3]:
from datasets import load_dataset

corpus = load_dataset("ag_news")["train"]["text"]
num_datapoints = len(corpus)

# Write the first half of the corpus to one file and the second half to the other.
with open('agnews_train_0.csv', 'w') as file_1, open('agnews_train_1.csv', 'w') as file_2:
    file_1.write("id,text\n")
    file_2.write("id,text\n")

    for idx, line in enumerate(corpus):
        # Commas are replaced so that every row has exactly two CSV fields.
        row = str(idx) + "," + line.replace(",", " ").lower() + "\n"
        if idx < num_datapoints // 2:
            file_1.write(row)
        else:
            file_2.write(row)

train_filenames = ['agnews_train_0.csv', 'agnews_train_1.csv']
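
The two-way split above generalizes to any number of workers. The following is a rough sketch, assuming contiguous shards of near-equal size and the same `id,text` layout; `write_shards` and its parameters are hypothetical names introduced here, and the sample texts are made up for illustration.

```python
import csv
import os
import tempfile


def write_shards(texts, num_shards, out_dir, prefix="agnews_train"):
    """Write contiguous, near-equal shards of `texts` as id,text CSV files."""
    shard_size = -(-len(texts) // num_shards)  # ceiling division
    filenames = []
    for shard in range(num_shards):
        path = os.path.join(out_dir, f"{prefix}_{shard}.csv")
        start = shard * shard_size
        end = min(start + shard_size, len(texts))
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["id", "text"])
            for idx in range(start, end):
                # Same normalization as the notebook: drop commas, lowercase.
                writer.writerow([idx, texts[idx].replace(",", " ").lower()])
        filenames.append(path)
    return filenames


# Made-up sample rows, written to a temporary directory.
sample = ["Wall St. Bears, Claw Back", "Oil and Economy", "Iraq Halts Oil"]
with tempfile.TemporaryDirectory() as tmp:
    for name in write_shards(sample, num_shards=2, out_dir=tmp):
        with open(name) as f:
            print(os.path.basename(name), f.read().splitlines())
```

Each worker would then read the shard whose index equals its rank, which is how `train_filenames` is indexed in the training loop.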

### 3.2 Looking at the dataset
In the above step, the *agnews_train_[0/1].csv* files are the corpus files, each containing a document id and its text. The files can have more columns with other metadata for each row.

A couple of sample rows from the first corpus file are shown below.

In [4]:
import pandas as pd

pd.options.display.max_colwidth = 700
pd.read_csv(train_filenames[0], nrows=4)

| | id | text |
| --- | --- | --- |
| 0 | 0 | wall st. bears claw back into the black (reuters) reuters - short-sellers wall street's dwindling\band of ultra-cynics are seeing green again. |
| 1 | 1 | carlyle looks toward commercial aerospace (reuters) reuters - private investment firm carlyle group \which has a reputation for making well-timed and occasionally\controversial plays in the defense industry has quietly placed\its bets on another part of the market. |
| 2 | 2 | oil and economy cloud stocks' outlook (reuters) reuters - soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums. |
| 3 | 3 | iraq halts oil exports from main southern pipeline (reuters) reuters - authorities have halted oil export\flows from the main pipeline in southern iraq after\intelligence showed a rebel militia could strike\infrastructure an oil official said on saturday. |


## 4. Define a UDT model and training loop

In the UDT model, the name of the text column in <i>data_types</i> (*query* here) can be anything of your choice, but *id* must match the column name in the header of the corpus files.

The <b><i>train_loop_per_worker</i></b> defines the training that will run on each worker node.

Pre-training with UDT supports two types of columns, strong and weak. For the purpose of this demo, we choose *text* as the strong column and leave the weak column list empty.

We will now train a UDT model in a distributed data-parallel fashion. Feel free to customize the number of epochs and the learning rate.

<u><b>PLEASE NOTE:</b></u> Currently, UDT's cold_start function requires the *id* to be an integer. We will add support for other formats in a future release.

In [5]:
import os

def get_udt_model(model_config_path):
    model = bolt.UniversalDeepTransformer(
        data_types={
            "query": bolt.types.text(),
            "id": bolt.types.categorical(delimiter=':'),
        },
        target="id",
        n_target_classes=num_datapoints,
        integer_target=True,
        model_config=model_config_path,
    )
    return model
    
def train_loop_per_worker(config):
    # thirdai.licensing.deactivate()
    # thirdai.licensing.activate("HN7J-W79C-KN9U-WTKE-9PNM-4PVR-CNPJ-WTWE") 

    model_config_path = os.path.join(config["curr_dir"], '../configs/embeddings_and_cold_start_0.005.config')
    model = get_udt_model(model_config_path)
    model = dist.prepare_model(model)

    metrics = model.coldstart_distributed_v2(
        filename=os.path.join(config["curr_dir"], train_filenames[session.get_world_rank()]),
        strong_column_names=["text"],
        weak_column_names=[],
        learning_rate=0.001,
        epochs=5,
        metrics=["categorical_accuracy"],
        verbose=True,
    )

    session.report(
        metrics=metrics,
        checkpoint=dist.UDTCheckPoint.from_model(model),
    )
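
Because the *id* column currently has to be an integer, a corpus keyed by strings needs a mapping step before training. Here is a minimal sketch of one way to do that in plain Python. `build_id_map` and the sample keys are hypothetical, and nothing here is part of the ThirdAI API.

```python
def build_id_map(keys):
    """Assign each distinct key a 0-based integer id in order of first appearance."""
    key_to_id = {}
    for key in keys:
        # setdefault only inserts when the key is new, so ids stay contiguous.
        key_to_id.setdefault(key, len(key_to_id))
    return key_to_id


# Hypothetical string ids, not taken from AG News.
doc_keys = ["reuters-0017", "ap-0002", "reuters-0017", "afp-0101"]
key_to_id = build_id_map(doc_keys)

# Reverse lookup to translate predicted integer ids back to the original keys.
id_to_key = {i: k for k, i in key_to_id.items()}

integer_ids = [key_to_id[k] for k in doc_keys]
print(integer_ids)
print(id_to_key[2])
```

Under this scheme the number of distinct keys, `len(key_to_id)`, would presumably play the role that `num_datapoints` plays for `n_target_classes` in this notebook.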


## 5. Distributed Training

Now, we start the training using <b>ThirdAI</b> <i>BoltTrainer</i> which runs the <i>train_loop_per_worker</i> function on different worker nodes.

In [6]:
import os
from ray.train.torch import TorchConfig

# thirdai.licensing.activate("HN7J-W79C-KN9U-WTKE-9PNM-4PVR-CNPJ-WTWE") 
licensing_lambda = None
# if hasattr(thirdai._thirdai, "licensing"):
#     license_state = thirdai._thirdai.licensing._get_license_state()
#     licensing_lambda = lambda: thirdai._thirdai.licensing._set_license_state(license_state)

trainer = dist.BoltTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={
        "curr_dir": os.path.abspath(os.getcwd()),
    },
    scaling_config=scaling_config,
    backend_config=TorchConfig(backend="gloo"),
)

result_checkpoint_and_history = trainer.fit()

(pid=863909) NCCL seems unavailable. Please install Cupy following the guide at: https://docs.cupy.dev/en/stable/install.html.
(BoltTrainer pid=863909) Starting distributed worker processes: ['863982 (192.168.1.5)', '863983 (192.168.1.5)']
(RayTrainWorker pid=863982) Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=863982) NCCL seems unavailable. Please install Cupy following the guide at: https://docs.cupy.dev/en/stable/install.html.
(RayTrainWorker pid=863982) loading data | source '/home/mritunjay/Demos/distributed/agnews_train_0.csv'
(RayTrainWorker pid=863982) loaded data | source '/home/mritunjay/Demos/distributed/agnews_train_0.csv' | vectors 60000 | batches 30 | time 0s | complete
(RayTrainWorker pid=863983) loading data | source '/home/mritunjay/Demos/distributed/agnews_train_1.csv'
(RayTrainWorker pid=863983) loaded data | source '/home/mritunjay/Demos/distributed/agnews_train_1.csv' | vectors 60000 | batches 30 | time 0s | complete

## 6. Save and load the model

In [None]:
model = result_checkpoint_and_history.checkpoint.get_model()
model.save('./agnews.model')

model = bolt.UniversalDeepTransformer.load('./agnews.model')

## 7. Make Predictions

### Example 1

In [None]:
import numpy as np
import pandas as pd

from thirdai.demos import download_agnews_dataset

corpus_file = './agnews.csv'
download_agnews_dataset(corpus_file)

df = pd.read_csv(corpus_file)

activations = model.predict({'query':'BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the  quot;alarming quot; growth of greenhouse gases'})
top_preds = np.argsort(-activations)[:5]

df.iloc[top_preds]
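
The ranking step above negates the activations so that `np.argsort`, which sorts in ascending order, returns the highest-scoring ids first. This can be sketched on a toy score vector as follows. The numbers are made up and `top_k_indices` is a hypothetical helper. The sketch assumes, as the cell above does, that the activations form a one-dimensional array with one score per document id.

```python
import numpy as np


def top_k_indices(activations, k=5):
    """Return the indices of the k largest scores, highest first."""
    scores = np.asarray(activations, dtype=float)
    # Negating the scores turns argsort's ascending order into descending order.
    return np.argsort(-scores)[:k]


# Made-up scores for five documents.
scores = np.array([0.05, 0.40, 0.10, 0.30, 0.15])
print(top_k_indices(scores, k=3))
```

Because the document ids were assigned as row numbers when the corpus was written, the returned indices can be passed directly to `df.iloc` to look up the matching rows.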

For the same query, here are the top-5 results returned by OpenAI's Search and Recommendation notebook (https://github.com/openai/openai-cookbook/blob/main/examples/Recommendation_using_embeddings.ipynb).

| text |
| --- |
| THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today |
| LONDON, England -- A US scientist is reported to have observed a surprising jump in the amount of carbon dioxide, the main greenhouse gas. |
| The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war. |
| Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed |
| AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair. |


### Example 2

In [None]:
activations = model.predict({'query':'PC World - Upcoming chip set will include built-in security features for your PC'})
top_preds = np.argsort(-activations)[:5]

df.iloc[top_preds]

For the same query, here are the top-5 results returned by OpenAI's Search and Recommendation notebook (https://github.com/openai/openai-cookbook/blob/main/examples/Recommendation_using_embeddings.ipynb).

| text |
| --- |
| PC World - Updated antivirus software for businesses adds intrusion prevention features. |
| PC World - The one-time World Class Product of the Year PDA gets a much-needed upgrade. |
| PC World - Send your video throughout your house--wirelessly--with new gateways and media adapters. |
| PC World - Symantec, McAfee hope raising virus-definition fees will move users to\  suites. |
| Gateway computers will be more widely available at Office Depot, in the PC maker #39;s latest move to broaden distribution at retail stores since acquiring rival eMachines this year. |