## Data Parallel Training for ThirdAI's UDT

This notebook shows how to run data parallel training for ThirdAI's Universal Deep Transformer (UDT). For this demo, we train and evaluate on the small version of the CLINC 150 dataset, but you can easily replace it with your own workload.

ThirdAI's distributed data parallel training assumes that you already have a Ray cluster running. For this demo, we simulate one with a mock Ray cluster on a single machine. To set up a real Ray cluster, see https://docs.ray.io/en/latest/cluster/getting-started.html

In [None]:
!pip3 install thirdai --upgrade
!pip3 install pyarrow
!pip3 install 'ray>=2.7.0'
!pip3 install torch

import os
import thirdai
from thirdai import bolt
import thirdai.distributed_bolt as dist     

## Ray Cluster Initialization
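For the purpose of this demo, we initialize a mock Ray cluster of 2 nodes on the local machine: two Ray workers, each given an equal share of the CPUs left after reserving one CPU for the trainer.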

In [None]:
import ray
from ray import train
from ray.train import ScalingConfig, RunConfig

cpus_per_node = (dist.get_num_cpus() - 1) // 2

ray.init(ignore_reinit_error=True, runtime_env={"env_vars": {"OMP_NUM_THREADS": f"{cpus_per_node}"}})
scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=False,
    trainer_resources={"CPU": 1},
    resources_per_worker={"CPU": cpus_per_node},
    placement_strategy="PACK",
)

# We need to specify `storage_path` in `RunConfig`. From Ray 2.7.0 onwards, it must be
# a networked file system or cloud storage path that all workers can access.
run_config = RunConfig(
    name="distributed_clinc",
    # A local path works for this demo because both workers run on the same machine.
    storage_path="~/ray_results",
)
thirdai.licensing.activate("WUAT-V7FP-TXLJ-97KR-3MCV-H4UC-7ERL-JYAF") 
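The `cpus_per_node` line reserves one CPU for the trainer process (matching `trainer_resources={"CPU": 1}`) and splits the remaining CPUs evenly between the two workers; `OMP_NUM_THREADS` is set to the same value so each worker's thread count matches its CPU reservation. The helper below is a hypothetical sketch (not part of the thirdai or Ray APIs) that generalizes this arithmetic to any number of workers, so you can check the per-worker share before calling `ray.init`.

```python
def cpus_per_worker(total_cpus, num_workers, trainer_cpus=1):
    """Split the CPUs left after the trainer's reservation evenly across workers.

    For num_workers=2 and trainer_cpus=1 this reproduces the notebook's
    `(dist.get_num_cpus() - 1) // 2`. Leftover CPUs from the integer division
    are simply not assigned.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    available = total_cpus - trainer_cpus
    per_worker = available // num_workers
    if per_worker < 1:
        raise ValueError(
            f"{total_cpus} CPUs cannot host {num_workers} workers "
            f"plus {trainer_cpus} trainer CPU(s)"
        )
    return per_worker


# Example: how the two-worker demo would be sized on a few machine sizes.
for total in (4, 9, 16):
    print(total, "CPUs ->", cpus_per_worker(total, num_workers=2), "CPUs per worker")
```

If you change `num_workers` in `ScalingConfig`, the same value should be passed here so that `resources_per_worker` and `OMP_NUM_THREADS` stay consistent.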

## Dataset Download

We use the `thirdai.demos` module to download the CLINC 150 small dataset. You can replace this step and the next one with a download method and a UDT initialization specific to your dataset.

In [None]:
from thirdai.demos import download_clinc_dataset

train_filenames, test_filename, _ = download_clinc_dataset(num_training_files=2, clinc_small=True)
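The demo downloads one training file per worker, and `train_loop_per_worker` picks its file with `train_filenames[train.get_context().get_world_rank()]`, so your own data needs exactly one shard per worker. The hypothetical `shard_csv` helper below is a sketch of one way to produce those shards, assuming your data is a single CSV file whose header row names the columns (`text` and `category` in this demo). It deals rows round-robin so shard sizes differ by at most one row, and repeats the header in every shard.

```python
import csv
import os


def shard_csv(path, num_shards, out_dir):
    """Split a CSV file into num_shards files, repeating the header in each.

    Rows are dealt round-robin, so shard sizes differ by at most one row.
    Returns the shard paths, indexed by worker rank.
    """
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(path))[0]
    shard_paths = [
        os.path.join(out_dir, f"{base}_shard_{i}.csv") for i in range(num_shards)
    ]
    files = [open(p, "w", newline="") for p in shard_paths]
    try:
        writers = [csv.writer(f) for f in files]
        with open(path, newline="") as src:
            reader = csv.reader(src)
            header = next(reader)  # first row holds the column names
            for writer in writers:
                writer.writerow(header)
            for i, row in enumerate(reader):
                writers[i % num_shards].writerow(row)
    finally:
        for f in files:
            f.close()
    return shard_paths


# Hypothetical usage with your own file ("my_train.csv" is a placeholder):
# train_filenames = shard_csv("my_train.csv", num_shards=2, out_dir="shards")
```

Because `train_loop_per_worker` joins each filename with `config["curr_dir"]`, relative shard paths like the ones above resolve against the notebook's working directory.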

## UDT Initialization
We can now create a UDT model by passing in the types of each column in the dataset and the target column we want to be able to predict.

In [None]:
def get_udt_model():
    model = bolt.UniversalDeepTransformer(
        data_types={
            "text": bolt.types.text(),
            "category": bolt.types.categorical(),
        },
        target="category",
        n_target_classes=151,
        integer_target=True,
    )
    return model
    
def train_loop_per_worker(config):
    # thirdai.licensing.deactivate()
    thirdai.licensing.activate("WUAT-V7FP-TXLJ-97KR-3MCV-H4UC-7ERL-JYAF") 
    thirdai.logging.setup(log_to_stderr=False, path="log.txt", level="info")
    
    model = get_udt_model()
    model = dist.prepare_model(model)

    metrics = model.train_distributed(
        filename=os.path.join(config["curr_dir"], train_filenames[train.get_context().get_world_rank()]),
        learning_rate=0.02,
        epochs=1,
        batch_size=256,
        metrics=["categorical_accuracy"],
        verbose=True,
    )

    train.report(
        metrics=metrics,
        checkpoint=dist.UDTCheckPoint.from_model(model),
    )
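With `integer_target=True`, UDT expects the `category` column to hold integer class IDs rather than string labels. Our reading of `n_target_classes=151` (presumably CLINC 150's 150 intents plus an out-of-scope class) is that those IDs should fall in the range `[0, 151)`. Under that assumption, the hypothetical helper below is a quick pre-flight check you could run on each training file before launching distributed training, so a bad label fails fast on the driver rather than inside a Ray worker.

```python
import csv


def check_integer_targets(path, target="category", n_target_classes=151):
    """Check that every label in `target` is an integer in [0, n_target_classes).

    Returns (row_count, sorted list of distinct labels). Raises ValueError on
    the first bad label. Reported line numbers assume the header is line 1 and
    that no field spans multiple lines.
    """
    labels = set()
    count = 0
    with open(path, newline="") as f:
        for line_no, row in enumerate(csv.DictReader(f), start=2):
            raw = row[target]
            try:
                label = int(raw)
            except ValueError:
                raise ValueError(
                    f"line {line_no}: label {raw!r} is not an integer"
                ) from None
            if not 0 <= label < n_target_classes:
                raise ValueError(
                    f"line {line_no}: label {label} outside [0, {n_target_classes})"
                )
            labels.add(label)
            count += 1
    return count, sorted(labels)


# Hypothetical usage on the downloaded shards:
# for name in train_filenames:
#     print(name, check_integer_targets(name))
```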



## Distributed Training

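We now train the UDT model in a distributed data parallel fashion: each worker runs `train_loop_per_worker` on its own training file. Feel free to customize the number of epochs and the learning rate in `train_loop_per_worker`; the values used there (1 epoch, learning rate 0.02) were chosen to converge well on this dataset.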

In [None]:
import os
from ray.train.torch import TorchConfig

trainer = dist.BoltTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={
        "curr_dir": os.path.abspath(os.getcwd()),
    },
    scaling_config=scaling_config,
    backend_config=TorchConfig(backend="gloo"),
    run_config=run_config,  # pass the RunConfig defined above so `storage_path` is used
)

result_checkpoint_and_history = trainer.fit()
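To make the "data parallel" part concrete, here is a toy, pure-Python illustration of synchronous data-parallel gradient descent on a one-dimensional linear model. Each simulated worker computes a gradient on its own shard, the gradients are averaged (the job an all-reduce performs over a communication backend such as the `gloo` backend configured above), and every worker applies the same update so all model copies stay identical. This is a generic sketch of the technique under those assumptions, not ThirdAI's implementation, whose internals may differ; the model, data, and function names are all illustrative.

```python
def local_gradient(w, b, shard):
    """Gradient of mean squared error for y ~ w*x + b over one worker's shard."""
    n = len(shard)
    gw = sum(2 * (w * x + b - y) * x for x, y in shard) / n
    gb = sum(2 * (w * x + b - y) for x, y in shard) / n
    return gw, gb


def all_reduce_mean(grads):
    """Average per-worker (gw, gb) tuples, as an all-reduce would."""
    k = len(grads)
    return sum(g[0] for g in grads) / k, sum(g[1] for g in grads) / k


def train_data_parallel(shards, lr=0.5, epochs=500):
    w, b = 0.0, 0.0  # every worker starts from the same weights
    for _ in range(epochs):
        # On a real cluster these gradients are computed in parallel, one per worker.
        grads = [local_gradient(w, b, shard) for shard in shards]
        gw, gb = all_reduce_mean(grads)
        w, b = w - lr * gw, b - lr * gb  # identical update on every worker
    return w, b


# Noise-free data from y = 3x + 1, split across two simulated workers.
data = [(i / 10, 3 * (i / 10) + 1) for i in range(10)]
shards = [data[0::2], data[1::2]]
w, b = train_data_parallel(shards)
print(f"w={w:.3f}, b={b:.3f}")
```

With equal-sized shards, the averaged gradient equals the gradient over the full dataset, which is why this scheme matches single-machine training step for step while splitting the work.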


## Evaluation
Evaluating the trained UDT model takes just two lines: load the model from the checkpoint returned by `trainer.fit()`, then evaluate it on the test file.

In [None]:
model = dist.UDTCheckPoint.get_model(result_checkpoint_and_history.checkpoint)
model.evaluate(test_filename, metrics=["categorical_accuracy"])
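`model.evaluate` reports `categorical_accuracy`, which under its usual definition is the fraction of samples whose highest-scoring class matches the true label. The minimal sketch below spells out that definition, which can help when sanity-checking predictions you collect yourself. It assumes one list of per-class scores per sample and integer labels, and it is not ThirdAI's implementation.

```python
def categorical_accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    correct = 0
    for row, label in zip(scores, labels):
        # argmax over class scores; on ties, max() keeps the lowest index
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += predicted == label
    return correct / len(labels)


# Three samples, two classes: the first two predictions are right, the third is wrong.
print(categorical_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))
```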

## Ray Cluster Teardown

In [None]:
ray.shutdown()