# A Practical Guide to Running NLP Models: BERT Use Case

This notebook is a practical guide to running Natural Language Processing (NLP) models on Tenstorrent's Grayskull E75 and E150 AI accelerator cards using the TT-BUDA compiler stack.

This tutorial walks through an example of running the [BERT](https://en.wikipedia.org/wiki/BERT_(language_model)) model on Tenstorrent AI accelerator hardware. The model weights are downloaded directly from the [HuggingFace Transformers library](https://huggingface.co/docs/transformers/model_doc/bert) and executed through the PyBUDA SDK.

**Note on terminology:**

TT-BUDA is the official Tenstorrent AI/ML compiler stack, and PyBUDA is its Python interface. TT-BUDA provides the core compiler technology, while PyBUDA lets users access and use TT-BUDA's features directly from Python. This includes importing model architectures and weights from PyTorch, TensorFlow, ONNX, and TFLite.

## Guide Overview

In this guide, we walk through the steps for running a BERT model fine-tuned on the [SQuADv1.1](https://rajpurkar.github.io/SQuAD-explorer/explore/1.1/dev/) dataset for the **Question Answering** task.

You will learn how to:

* import the required libraries;
* download model weights from a popular hub such as HuggingFace;
* use the PyBUDA API to run an inference experiment;
* inspect the results from running on Tenstorrent hardware.

## Step 1: Import libraries

Make sure that you have an active Python environment with the latest version of PyBUDA installed.

```python
# Start by importing the pybuda library and the BERT classes from HuggingFace's transformers library
import pybuda
from transformers import BertForQuestionAnswering, BertTokenizer
```


## Step 2: Download the model weights from HuggingFace

```python
# Load the BERT tokenizer and model from HuggingFace for the Q&A task
model_ckpt = "bert-large-cased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(model_ckpt)
model = BertForQuestionAnswering.from_pretrained(model_ckpt)
```

```
Some weights of the model checkpoint at bert-large-cased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
```


## Step 3: Set example input

```python
# Alternative input: a data sample from SQuADv1.1 (uncomment to use)
# context = """Super Bowl 50 was an American football game to determine the champion of the National Football League
# (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the
# National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title.
# The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.
# As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed
# initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals
# (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently
# feature the Arabic numerals 50."""

# question = "Which NFL team represented the AFC at Super Bowl 50?"

context = "Shakespear was a poet who wrote in old english. He used terms like Ye-old lady"
question = "What sort of terms did Shakespear use?"
```


## Step 4: Data Preprocessing

Data preprocessing is an important step in the AI inference pipeline. For NLP models, the input text must be converted into the same tokenized representation that was used when the model was trained.

```python
# Data preprocessing
input_tokens = tokenizer(
    question,  # pass the question as the first sequence
    context,  # pass the context as the second sequence
    max_length=384,  # maximum total sequence length (question + context + special tokens)
    padding="max_length",  # pad to max_length for a fixed input size
    truncation=True,  # truncate anything longer than max_length
    return_tensors="pt",  # return PyTorch tensors
)
```
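Here is a minimal pure-Python sketch that makes `padding="max_length"` and the question/context pair encoding concrete. It shows the layout BERT expects: `[CLS] question [SEP] context [SEP]`, followed by padding. It is an illustration, not what `BertTokenizer` does internally. The helper `build_pair_inputs` and its token IDs are made up. The special-token IDs (101, 102, 0) are assumed BERT vocabulary defaults. Truncation here trims only the context, which simplifies what `truncation=True` does.

```python
from typing import Dict, List

# Assumed special-token IDs (typical BERT WordPiece vocabulary defaults).
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def build_pair_inputs(question_ids: List[int], context_ids: List[int],
                      max_length: int) -> Dict[str, List[int]]:
    """Hypothetical helper: lay out [CLS] q [SEP] c [SEP] and pad to max_length."""
    # Reserve 3 slots for the special tokens; trim only the context if too long.
    room = max_length - len(question_ids) - 3
    if room < 0:
        raise ValueError("question alone exceeds max_length")
    context_ids = context_ids[:room]

    input_ids = [CLS_ID] + question_ids + [SEP_ID] + context_ids + [SEP_ID]
    # Segment 0 = [CLS] + question + first [SEP]; segment 1 = context + final [SEP].
    token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
    # 1 marks real tokens; 0 marks padding the model should ignore.
    attention_mask = [1] * len(input_ids)

    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [PAD_ID] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


enc = build_pair_inputs([11, 12], [21, 22, 23, 24], max_length=12)
print(enc["input_ids"])       # [101, 11, 12, 102, 21, 22, 23, 24, 102, 0, 0, 0]
print(enc["attention_mask"])  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
```

The real tokenizer in this notebook returns the same three fields: `input_ids`, `token_type_ids` and `attention_mask`. With `return_tensors="pt"` they come back as PyTorch tensors of shape (1, 384).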

## Step 5: Configure PyBUDA Parameters

Several configuration options can be adjusted before compiling and running a model on Tenstorrent hardware. Some are necessary for the model to compile at all; others are tunable parameters that can be adjusted to improve performance.

For the BERT model, two key parameters are required for compilation:

* `default_df_override`
* `default_dram_parameters`

```python
# Set PyBUDA configurations
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.default_df_override = pybuda._C.DataFormat.Float16_b  # default data format: Float16_b (bfloat16)
compiler_cfg.default_dram_parameters = False  # don't place parameters in DRAM by default
```
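`Float16_b` is Tenstorrent's name for the bfloat16 format. It keeps float32's sign bit and full 8-bit exponent, but only 7 mantissa bits. The sketch below emulates that precision loss in plain Python by zeroing the low 16 bits of a float32 value. It is an illustration only: it truncates toward zero, whereas hardware and framework conversions typically round to nearest even. Individual values may therefore differ in the last bit.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Illustrative float32 -> bfloat16 -> float conversion (truncation, not rounding)."""
    # Reinterpret the value as its IEEE-754 float32 bit pattern.
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    # bfloat16 keeps the sign, all 8 exponent bits and the top 7 mantissa bits,
    # i.e. the upper 16 bits of the float32 pattern.
    (y,) = struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))
    return y


for v in [1.0, 3.14159, 1e-3, 1e30]:
    print(f"{v!r:>10} -> {to_bfloat16(v)!r}")
```

For example, `3.14159` becomes `3.140625`, while a value as large as `1e30` stays within about 1% of its original magnitude. bfloat16 trades precision (roughly 2 to 3 significant decimal digits) for float32's dynamic range.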

## Step 6: Instantiate Tenstorrent device

The first time we use PyBUDA, we must initialize a `TTDevice` object, which serves as the abstraction over the target hardware.

```python
tt0 = pybuda.TTDevice(
    name="tt_device_0",  # any name works; it is used for tracking purposes
    arch=pybuda.BackendDevice.Grayskull,  # target device architecture to compile for
)
```

## Step 7: Create a PyBUDA module from PyTorch model

Next, we wrap the PyTorch model loaded from HuggingFace in a `pybuda.PyTorchModule` object. This tells the BUDA compiler which model architecture to compile and which AI framework it comes from.

We then "place" this module onto the previously initialized `TTDevice`.

```python
# Create module
pybuda_module = pybuda.PyTorchModule(
    name="pt_bert_question_answering",  # module name, used for tracking purposes
    module=model,  # the model being targeted for compilation
)

# Place module on device
tt0.place_module(module=pybuda_module)
```

## Step 8: Push the (tokenized) inputs into the model input queue

```python
# Push inputs
tt0.push_to_inputs(input_tokens)
```

## Step 9: Run inference on the targeted device

Running a model on a Tenstorrent device involves two parts: compilation and runtime.

Compilation -- TT-BUDA is a compiler, meaning that it takes a model architecture graph and compiles it for the target hardware. Compilation can take anywhere from a few seconds to a few minutes, depending on the model, and only needs to happen once. When you execute the following block of code, the compilation logs will be displayed.

Runtime -- once the model has been compiled and loaded onto the device, the user can push new inputs, which execute immediately.

The `run_inference` API handles both steps in a single call. On the first call, the model is compiled; subsequent calls execute the runtime only.

Please refer to the documentation for alternative APIs such as `initialize_pipeline` and `run_forward`.
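A small toy model can picture two behaviors described above: `run_inference` compiles once and then runs many times, and inputs and outputs flow through queues. The `ToyDevice` class below is hypothetical and is not PyBUDA's implementation. On its first run it caches a "compiled" callable. After that it processes queued inputs in first-in, first-out order using Python's standard `queue.Queue`.

```python
import queue
from typing import Any, Callable, Optional


class ToyDevice:
    """Hypothetical stand-in for an accelerator; NOT the PyBUDA API."""

    def __init__(self, model_fn: Callable[[Any], Any]):
        self.model_fn = model_fn
        self.inputs: "queue.Queue[Any]" = queue.Queue()
        self.outputs: "queue.Queue[Any]" = queue.Queue()
        self.compiled: Optional[Callable[[Any], Any]] = None
        self.compile_count = 0

    def push_input(self, x: Any) -> None:
        # Similar in spirit to pushing tokenized inputs onto a device input queue.
        self.inputs.put(x)

    def _compile(self) -> Callable[[Any], Any]:
        # Stand-in for the expensive, one-time compilation step.
        self.compile_count += 1
        return self.model_fn

    def run(self) -> "queue.Queue[Any]":
        # Compile only on the first call; later calls reuse the cached result.
        if self.compiled is None:
            self.compiled = self._compile()
        # Runtime: drain pending inputs in FIFO order into the output queue.
        while not self.inputs.empty():
            self.outputs.put(self.compiled(self.inputs.get()))
        return self.outputs


dev = ToyDevice(lambda x: x * 2)
dev.push_input(1)
dev.push_input(2)
out_q = dev.run()
print(out_q.get(), out_q.get())  # 2 4 (oldest result first)
dev.push_input(3)
print(dev.run().get())            # 6
print(dev.compile_count)          # 1: compiled once, run twice
```

Because the toy output queue is FIFO, `get()` returns the oldest pending result, not the most recent one.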

```python
# Run inference on Tenstorrent device
output_q = pybuda.run_inference()  # compiles (on the first call only), then runs
output = output_q.get()  # retrieve the next result from the output queue
```

```
2024-03-06 14:08:29.740 | INFO     | Always          - initialize_child_process called on pid 521050

2024-03-06 14:09:01.247 | INFO     | tvm.relay.op.contrib.buda.buda:visit_call:817 - Adding: embedding to fallback
2024-03-06 14:09:01.248 | INFO     | tvm.relay.op.contrib.buda.buda:visit_call:817 - Adding: embedding to fallback
2024-03-06 14:09:01.250 | INFO     | tvm.relay.op.contrib.buda.buda:visit_call:817 - Adding: embedding to fallback
2024-03-06 14:09:14.504 | INFO     | tvm.relay.op.contrib.buda.buda:_cpu_eval:562 - cast will be executed on CPU
2024-03-06 14:09:14.506 | INFO     | tvm.relay.op.contrib.buda.buda:_cpu_eval:562 - cast will be executed on CPU
2024-03-06 14:09:14.507 | INFO     | tvm.relay.op.contrib.buda.buda:_cpu_eval:562 - add will be executed on CPU
2024-03-06 14:09:14.507 | INFO     | tvm.relay.op.contrib.buda.buda:_cpu_eval:562 - strided_slice will be executed on CPU
2024-03-06 14:09:14.508 | INFO     | tvm.relay.op.contrib.buda.buda:_cpu_eval:562 - cas

2024-03-06 14:09:27.613 | INFO     | Always          - initialize_child_process called on pid 521357

2024-03-06 14:09:27.819 | INFO     | pybuda.device_connector:pusher_thread_main:144 - Pusher thread on <pybuda.device_connector.DirectPusherDeviceConnector object at 0x7f102c1e68e0> starting
2024-03-06 14:09:27.821 | DEBUG    | pybuda.device:run_next_command:452 - Received COMPILE command on TTDevice 'tt_device_0' / 521367
2024-03-06 14:09:27.821 | DEBUG    | pybuda.ttdevice:compile_for:770 - Compiling for Inference mode on TTDevice 'tt_device_0'
2024-03-06 14:09:27.821 | INFO     | pybuda.ci:initialize_output_build_directory:94 - Pybuda output build directory for compiled artifacts: /tmp/jonathan/c56c8a3963a1
2024-03-06 14:09:27.822 | INFO     | pybuda.ci:create_symlink:85 - Symlink created from /home/jonathan/Desktop/tenstorrent/tt-buda-demos/first_5_steps/tt_build/test_out to /tmp/jonathan/c56c8a3963a1
2024-03-06 14:09:27.882 | INFO     | pybuda.compile:pybuda_compile:220 - Device grid size: r = 10, c = 12
2024-03-06 14:09:27.882 | INFO     | pybuda.compile:pybuda_compile:230 - Usin

2024-03-06 14:09:27.819 | INFO     | Always          - initialize_child_process called on pid 521367
2024-03-06 14:09:27.836 | INFO     | SiliconDriver   - Detected 1 PCI device
2024-03-06 14:09:27.842 | INFO     | SiliconDriver   - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)

2024-03-06 14:09:29.094 | INFO     | pybuda.compile:pybuda_compile:319 - Running post initial graph pass
2024-03-06 14:09:30.127 | INFO     | pybuda.compile:pybuda_compile:391 - Running post autograd graph pass
2024-03-06 14:09:30.485 | INFO     | pybuda.compile:pybuda_compile:424 - Lowering to Buda

2024-03-06 14:09:30.903 | INFO     | GraphCompiler   - Running with Automatic Mixed Precision Level = 0.
2024-03-06 14:09:31.023 | INFO     | Always          - Running Balancer with Policy: PolicyType::NLP
2024-03-06 14:09:32.786 | INFO     | Always          - Running Balancer with Policy: PolicyType::NLP
2024-03-06 14:09:36.332 | INFO     | Balancer        - Based on NLP matmul analysis, target cycle count is set to 125000
2024-03-06 14:09:37.000 | INFO     | Balancer        - Balancing 3% complete.
2024-03-06 14:09:37.521 | INFO     | Balancer        - Balancing 6% complete.
2024-03-06 14:09:37.601 | INFO     | Balancer        - Balancing 7% complete.
2024-03-06 14:09:3

2024-03-06 14:09:53.941 | INFO     | pybuda.compile:pybuda_compile:626 - Generating Netlist
2024-03-06 14:09:54.667 | INFO     | pybuda.ci:create_symlink:85 - Symlink created from /home/jonathan/Desktop/tenstorrent/tt-buda-demos/first_5_steps/pt_bert_question_answering_tt_1_netlist.yaml to /tmp/jonathan/c56c8a3963a1/pt_bert_question_answering_tt_1_netlist.yaml
2024-03-06 14:09:59.395 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - ConstEval graph: input_1_multiply_18
2024-03-06 14:09:59.395 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - ConstEval graph: input_0_subtract_21
2024-03-06 14:09:59.395 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - ConstEval graph: input_1_multiply_22
2024-03-06 14:09:59.395 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - ConstEval graph: input_1_multiply_75
2024-03-06 14:09:59.396 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - ConstEval graph: input_1_multiply_128
2024-03-06 14:09:59.396 | DEBUG    | pybuda.tensor:consteval_tensor:1177 - C

2024-03-06 14:10:00.021 | INFO     | Always          - Running tt_runtime on host: 'benderv2'
2024-03-06 14:10:00.021 | INFO     | PerfInfra       - Backend profiler is disabled
2024-03-06 14:10:00.021 | INFO     | Netlist         - Parsing Netlist from file: /tmp/jonathan/c56c8a3963a1/pt_bert_question_answering_tt_1_netlist.yaml
2024-03-06 14:10:00.549 | INFO     | SiliconDriver   - Detected 1 PCI device
2024-03-06 14:10:00.551 | INFO     | SiliconDriver   - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)
2024-03-06 14:10:02.493 | INFO     | Always          - Using Default BRISC Bin
2024-03-06 14:10:02.493 | INFO     | Compile

2024-03-06 14:11:23.856 | INFO     | pybuda.backend:feeder_thread_main:120 - Feeder thread on <pybuda.backend.BackendAPI object at 0x7f1100fd7e80> starting
2024-03-06 14:11:23.857 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant lc.input_tensor.layernorm_0.dc.reduce_sum.0.0
2024-03-06 14:11:23.857 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant dc.input_tensor.layernorm_0.1
2024-03-06 14:11:23.858 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant lc.input_tensor.layernorm_0.dc.reduce_sum.5.0
2024-03-06 14:11:23.858 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant dc.input_tensor.layernorm_0.6
2024-03-06 14:11:23.858 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant dc.input_tensor.layernorm_0.8
2024-03-06 14:11:23.858 | DEBUG    | pybuda.backend:push_constants_and_parameters:435 - Pushing to constant input_1_multip

2024-03-06 14:11:24.349 | INFO     | Netlist         - Parsing Netlist from file: /tmp/jonathan/c56c8a3963a1/pt_bert_question_answering_tt_1_netlist.yaml
2024-03-06 14:11:24.823 | INFO     | SiliconDriver   - Detected 1 PCI device
2024-03-06 14:11:24.865 | INFO     | SiliconDriver   - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)
2024-03-06 14:11:24.910 | INFO     | SiliconDriver   - Disable PCIE DMA
2024-03-06 14:11:24.912 | INFO     | Netlist         - Parsing Netlist from file: /tmp/jonathan/c56c8a3963a1/pt_bert_question_answering_tt_1_netlist.yaml
2024-03-06 14:11:24.916 | INFO     | Runtime         - Running program 'run_fwd_0' with params [("$p_loop_count", "1

2024-03-06 14:11:24.915 | DEBUG    | pybuda.device:run_next_command:426 - Received RUN_FORWARD command on TTDevice 'tt_device_0' / 521367
2024-03-06 14:11:24.915 | DEBUG    | pybuda.ttdevice:forward:862 - Starting forward on TTDevice 'tt_device_0'
2024-03-06 14:11:24.916 | DEBUG    | pybuda.backend:feeder_thread_main:142 - Run feeder thread cmd: fwd
2024-03-06 14:11:24.916 | DEBUG    | pybuda.backend:read_queues:316 - Reading output queue pt_bert_question_answering_tt_1.output_reshape_1285

2024-03-06 14:11:25.488 | INFO     | SiliconDriver   - Detected 1 PCI device
2024-03-06 14:11:25.520 | INFO     | SiliconDriver   - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)
2024-03-06 14:11:25.537 | INFO     | SiliconDriver   - Disable PCIE DMA

2024-03-06 14:11:25.537 | DEBUG    | pybuda.device:run_next_command:426 - Received RUN_FORWARD command on CPUDevice 'cpu0_fallback' / 521357
2024-03-06 14:11:25.537 | DEBUG    | pybuda.cpudevice:forward_pt:191 - Starting forward on CPUDevice 'cpu0_fallback'
2024-03-06 14:11:25.540 | DEBUG    | pybuda.cpudevice:forward_pt:265 - Ending forward on CPUDevice 'cpu0_fallback'
2024-03-06 14:11:25.540 | DEBUG    | pybuda.device_connector:pusher_thread_main:159 - Pusher thread pushing tensors
2024-03-06 14:11:25.541 | DEBUG    | pybuda.backend:push_to_queues:407 - Pushing to queue pybuda_6_i0
2024-03-06 14:11:25.542 | DEBUG    | pybuda.backend:push_to_queues:407 - Pushing to queue attention_mask_1
2024-03-06 14:11:25.808 | DEBUG    | pybuda.backend:read_queues:316 - Reading output queue pt_bert_question_answering_tt_1.output_reshape_1292
2024-03-06 14:11:25.808 | DEBUG    | pybuda.backend:read_queues:376 - Done reading queues
2024-03-06 14:11:25.808 | DEBUG    | pybuda.backend:pop_queues:382 - 
```

## Step 10: Data Postprocessing

Data postprocessing converts the model outputs into a readable, useful format. For extractive question answering, the model outputs two sets of logits: one scoring each token position as the answer start, and one scoring it as the answer end. We take the highest-scoring start and end positions and decode the tokens between them back into text.

```python
# Data postprocessing
answer_start = output[0].value().argmax().item()  # most likely start position (start logits)
answer_end = output[1].value().argmax().item()  # most likely end position (end logits)
answer = tokenizer.decode(input_tokens["input_ids"][0, answer_start : answer_end + 1])
```
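The code above takes the `argmax` of the start and end logits independently. That works for this example, but nothing stops the predicted end from landing before the start, which yields an empty answer. A common refinement, similar in spirit to HuggingFace's question-answering pipeline, is to score only (start, end) pairs with start ≤ end and a capped answer length, then keep the pair with the highest sum of logits. The sketch below works on plain Python lists. `best_span` and its `max_answer_len=30` default are illustrative choices, not part of PyBUDA or transformers. It also does not mask out question or special tokens, which a production decoder would do.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Return the (start, end) pair maximizing start + end logit, with start <= end."""
    best, best_score = (0, 0), float("-inf")
    for s, s_score in enumerate(start_logits):
        # Only consider ends at or after the start, and within the length cap.
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


# Toy logits where the independent argmaxes disagree: start -> 4, end -> 2.
tokens = ["[CLS]", "what", "terms", "[SEP]", "terms", "like", "ye", "[SEP]"]
start_logits = [0.1, 0.0, 0.5, 0.0, 3.0, 0.2, 0.1, 0.0]
end_logits = [0.0, 0.0, 4.0, 0.0, 0.1, 0.3, 2.5, 0.0]

naive_start = max(range(len(start_logits)), key=start_logits.__getitem__)
naive_end = max(range(len(end_logits)), key=end_logits.__getitem__)
print(tokens[naive_start : naive_end + 1])  # [] (end lands before start)

s, e = best_span(start_logits, end_logits)
print(tokens[s : e + 1])  # ['terms', 'like', 'ye']
```

To apply this to the notebook's outputs, you would first convert each logits tensor to a list. For example, use `output[0].value()[0].tolist()`, assuming the tensor has a batch-first shape of (1, 384).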

## Step 11: Print and evaluate outputs

```python
# Print outputs
print(f"Input context:\n{context}")
print(f"\nInput question:\n{question}")
print(f"\nOutput from model running on TTDevice:\n{answer}")
```

```
Input context:
Shakespear was a poet who wrote in old english. He used terms like Ye-old lady

Input question:
What sort of terms did Shakespear use?

Output from model running on TTDevice:
terms like Ye - old lady
```
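Notice that the answer reads `Ye - old` rather than `Ye-old`. BERT's tokenizer splits punctuation such as `-` into separate tokens. Decoding then joins tokens with spaces and re-attaches only the subword pieces marked with `##`. The sketch below is a simplified, hypothetical detokenizer that reproduces this effect. The real `tokenizer.decode` may also apply clean-up rules, for example around periods and commas, which are omitted here. The token lists are illustrative, not the model's actual tokenization.

```python
from typing import List


def wordpiece_detokenize(tokens: List[str]) -> str:
    """Simplified WordPiece detokenizer: join with spaces, glue '##' pieces."""
    words: List[str] = []
    for tok in tokens:
        if tok.startswith("##") and words:
            words[-1] += tok[2:]  # continuation piece: attach to previous word
        else:
            words.append(tok)     # new word or standalone punctuation token
    return " ".join(words)


print(wordpiece_detokenize(["terms", "like", "Ye", "-", "old", "lady"]))
# terms like Ye - old lady
print(wordpiece_detokenize(["Shakes", "##pear", "used", "terms"]))
# Shakespear used terms
```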


## Step 12: Shut down PyBUDA

```python
pybuda.shutdown()
```

```
2024-03-06 14:11:25.838 | DEBUG    | pybuda.run.impl:_shutdown:1262 - PyBuda shutdown
2024-03-06 14:11:25.839 | DEBUG    | pybuda.device:run_next_command:416 - Received SHUTDOWN command on CPUDevice 'cpu0_fallback'
2024-03-06 14:11:25.839 | DEBUG    | pybuda.device:run_next_command:416 - Received SHUTDOWN command on TTDevice 'tt_device_0'
2024-03-06 14:11:25.839 | DEBUG    | pybuda.device:run_next_command:419 - Waiting for barrier on CPUDevice 'cpu0_fallback'
2024-03-06 14:11:25.839 | DEBUG    | pybuda.device:run_next_command:419 - Waiting for barrier on TTDevice 'tt_device_0'
2024-03-06 14:11:25.840 | DEBUG    | pybuda.run.impl:_shutdown:1278 - Waiting until processes done
2024-03-06 14:11:25.840 | DEBUG    | pybuda.device:run_next_command:421 - Shutting down on CPUDevice 'cpu0_fallback'
2024-03-06 14:11:25.840 | DEBUG    | pybuda.device:run_next_command:421 - Shutting down on TTDevice 'tt_device_0'
2024-03-06 14:11:25.869 | DEBUG    | pybuda.device:atexit_handler:919 - atexit handler

2024-03-06 14:11:25.840 | INFO     | Always          - finish_child_process called on pid 521357
2024-03-06 14:11:25.870 | INFO     | Always          - finish_child_process called on pid 521357
2024-03-06 14:11:25.960 | INFO     | Runtime         - Waiting for cluster completion
2024-03-06 14:11:25.961 | INFO     | PerfPostProcess - Writing the host postprocess report in /tmp/jonathan/c56c8a3963a1/perf_results//host/device_alignment_th_356742039_proc_521367.json
2024-03-06 14:11:26.057 | INFO     | Runtime         - Closed all devices successfully
2024-03-06 14:11:26.057 | INFO     | PerfCheck       - Starting performance check for host events
2024-03-06 14:11:26.057 | INFO     |

2024-03-06 14:11:26.136 | DEBUG    | pybuda.device:atexit_handler:919 - atexit handler called for (TTDevice 'tt_device_0',)
2024-03-06 14:11:26.143 | DEBUG    | pybuda.device:atexit_handler:923 - atexit handler completed

2024-03-06 14:11:27.028 | INFO     | Always          - finish_child_process called on pid 521050
```
