# Part #1: How to use TART for Inference

In [1]:
%load_ext autoreload
%reload_ext autoreload

In [2]:
# Mode 2: reload modules before executing each cell, so changes to local source
# files (such as the `tart` package) take effect without a kernel restart.
%autoreload 2

In [3]:
import os
import sys
import warnings

import torch
import yaml

# Make the directory two levels above the notebook's working directory
# importable so the local `tart` package resolves. This must run before the
# `tart` imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tart.tart_modules import Tart
from tart.registry import DATASET_REGISTRY

# Silence noisy library warnings in the rendered output. The UserWarning filter
# is not restricted to a module, so it already covers UserWarnings raised from
# tqdm; a separate tqdm-specific filter would be redundant.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

### !! Running this notebook with a pre-trained TART head !!
* Download the [pre-trained TART reasoning module](https://github.com/HazyResearch/TART/releases/download/reasoning_module/tart_heads.zip); you need it to run the sample notebooks in `src/notebooks`. Extract the archive and set `path_tart_weights` below to the directory that contains the checkpoint.

In [4]:
# Uncomment the line below to download the pre-trained TART reasoning head as a
# zip archive. After extracting it, point `path_tart_weights` (next cell) at the
# directory that holds the checkpoint. The set-up cell expects a `model_24000.pt`
# file in that directory; verify the name against the archive contents.
#! wget https://github.com/HazyResearch/TART/releases/download/reasoning_module/tart_heads.zip

In [5]:
#### CUSTOMIZE AS NEEDED ####
# Each path can be overridden with an environment variable, so the notebook can
# run outside the original cluster without editing this cell. When a variable is
# unset, the original hard-coded location is used.

# Directory containing the TART reasoning-head checkpoint (`model_24000.pt` is
# loaded from it in the set-up cell).
path_tart_weights = os.environ.get(
    "TART_WEIGHTS_DIR",
    '/u/scr/nlp/data/ic-fluffy-head-k-2/3e9724ed-5a49-4070-9b7d-4209a30e2392',
)
# Cache directory handed to the dataset loader (datasets are cached under it).
cache_dir = os.environ.get("TART_CACHE_DIR", '/u/scr/nlp/data/neo/hub')
# YAML config for the reasoning head.
# If you are using the pre-trained module above, don't change this!
path_tart_config = os.environ.get("TART_CONFIG_PATH", 'tart_conf.yaml')
# Not referenced by any other cell in this notebook.
data_dir_path = None

### Step #1: Set up TART

In [27]:
# Name of the LLM used to embed inputs; passed to `Tart(embed_model_name=...)`.
BASE_EMBED_MODEL = "EleutherAI/gpt-neo-125m"
# Reasoning-head checkpoint inside the weights directory configured above.
PATH_TO_TART_REASONING_HEAD = f"{path_tart_weights}/model_24000.pt"
# Read the head config with a context manager so the file handle is closed
# promptly instead of being left open.
# NOTE(review): FullLoader is kept for compatibility; `yaml.safe_load` should
# suffice if tart_conf.yaml contains only plain YAML (no Python tags). Confirm.
with open(path_tart_config, "r") as tart_config_file:
    TART_CONFIG = yaml.load(tart_config_file, Loader=yaml.FullLoader)
CACHE_DIR = cache_dir
# Data domain. It is passed to `Tart(domain=...)` and also selects the
# DATASET_REGISTRY namespace.
DOMAIN = "text"


In [7]:
#### Instantiate TART module ####
# Build the TART wrapper from the data domain, the embedding-model name, and the
# reasoning-head config and checkpoint path. Later cells use the resulting
# object `t`.
t = Tart(
    domain=DOMAIN,
    embed_model_name=BASE_EMBED_MODEL,
    tart_head_config=TART_CONFIG,
    path_to_pretrained_head=PATH_TO_TART_REASONING_HEAD,
)

In [8]:
#### Load TartReasoningHead ####
# Load the reasoning head from PATH_TO_TART_REASONING_HEAD using TART_CONFIG.
# NOTE(review): the same checkpoint path and config were already passed to
# `Tart(...)` in the previous cell, so this call may be redundant. It also relies
# on a private (`_`-prefixed) method. Confirm whether the constructor already
# loads the head.
t._load_tart_head(PATH_TO_TART_REASONING_HEAD, TART_CONFIG)

In [9]:
#### Set TART LLM embed model ####
#### Note: we can use any LLM here! TART is LLM-agnostic ####
# Reuse BASE_EMBED_MODEL from the Step #1 set-up cell rather than re-declaring
# the model name here. A second hard-coded copy could silently diverge from the
# name passed to `Tart(...)`. To try a different LLM, change BASE_EMBED_MODEL in
# the set-up cell, or pass another model name directly to `set_embed_model`.
t.set_embed_model(BASE_EMBED_MODEL)

loading embed model: EleutherAI/gpt-neo-125m ...


### Step #2: Load in sample data...
* For the purposes of this demo, we use 64 in-context examples and evaluate on 4 test samples.


In [35]:
#### RUN THE FOLLOWING CELLS ####
# Demo settings: dataset, sample counts, and the random seed.
DATASET_NAME = "sms_spam"
seed = 42  # forwarded to the dataset loader and to the embedding step

# Use one in-context size: every loaded training example.
TOTAL_IN_CONTEXT_EXAMPLES = 64
k_range = [TOTAL_IN_CONTEXT_EXAMPLES]  # number of samples to use as in-context examples

max_eval_samples = 4  # number of samples to evaluate on

In [29]:
# Build the dataset registered under DOMAIN/DATASET_NAME, sized by the knobs in
# the previous cell.
dataset = DATASET_REGISTRY[DOMAIN][DATASET_NAME](
    total_train_samples=TOTAL_IN_CONTEXT_EXAMPLES,
    k_range=k_range,
    seed=seed,
    cache_dir=CACHE_DIR,
    max_eval_samples=max_eval_samples,
)

# `get_dataset` is accessed without a call, so it is presumably a property.
X_ice, y_ice, X_test, y_test = dataset.get_dataset

# Later cells pair inputs with labels by position, so fail fast if they are
# misaligned.
assert len(X_ice) == len(y_ice), "in-context inputs/labels length mismatch"
assert len(X_test) == len(y_test), "test inputs/labels length mismatch"


Found cached dataset sms_spam (/u/scr/nlp/data/neo/hub/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c)


  0%|          | 0/1 [00:00<?, ?it/s]

In [30]:
#### RUN THE FOLLOWING CELLS TO INSPECT DATA ####
# Show the last in-context example and its label.
print("In-context example input: {}".format(X_ice[-1].strip()))
print("In-context example label: {}".format(y_ice[-1]))

In-context example input: Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
In-context example label: 1


In [31]:
#### RUN THE FOLLOWING CELLS TO INSPECT DATA ####
# Show the first test sample and its label.
print("Test sample input: {}".format(X_test[0].strip()))
print("Test sample label: {}".format(y_test[0]))

Test sample input: Ha! I wouldn't say that I just didn't read anything into way u seemed. I don't like 2 be judgemental....i save that for fridays in the pub!
Test sample label: 0


### Step #3: Evaluate!

Step 1: Embed the in-context examples and test samples.


In [32]:
# Embed the in-context examples and the test samples with the base LLM.
# NOTE(review): the inputs are passed as (X_test, X_ice, y_ice, y_test), while
# the outputs are unpacked as (X_ice, X_test, y_ice, y_test) embeddings. Confirm
# this ordering against `embed_layer.embed`'s signature.
# Assumption (verify): `k` is the number of in-context examples to use, and
# `text_threshold=1000` caps the length of each text input.
(
    X_ice_embed,
    X_test_embed,
    y_ice_embed,
    y_test_embed
) = t.embed_layer.embed(
    X_test, X_ice, y_ice, y_test, k=k_range[0], seed=seed, text_threshold=1000
)


Step 2: Concatenate embeddings
* Returns, for each test sample, a sequence of embeddings formed by concatenating the embeddings of the in-context examples with the embedding of that test sample.

In [33]:
# Combine the in-context embeddings (inputs and labels) with the test
# embeddings into evaluation sequences. Presumably there is one sequence per
# test sample, since the prediction loop below indexes X_test by position.
# NOTE(review): relies on the private `_concatenate_inputs` method.
eval_seqs = t._concatenate_inputs(X_ice_embed, y_ice_embed, X_test_embed, y_test_embed)

Step 3: Predict
* Pass each concatenated sequence of embeddings from Step 2 to the TART reasoning module to generate a prediction.

In [34]:
# Run the TART reasoning head on each evaluation sequence and print the
# prediction next to the input text and its ground-truth label.
print("Task: SMS Spam Classification\n\n")  # no placeholders, so no f-string needed
for i, eval_seq in enumerate(eval_seqs):
    pred = t.predict(eval_seq)

    # eval_seqs is assumed to be aligned by position with X_test / y_test_embed.
    print(f"Input: {X_test[i].strip()}")
    print(f"Ground Truth Label: {int(y_test_embed[i])}")
    print(f"TART Predicted Label: {pred}\n\n")

Task: SMS Spam Classification


Input: Ha! I wouldn't say that I just didn't read anything into way u seemed. I don't like 2 be judgemental....i save that for fridays in the pub!
Ground Truth Label: 0
TART Predicted Label: 0


Input: K go and sleep well. Take rest:-).
Ground Truth Label: 0
TART Predicted Label: 0


Input: Your next amazing xxx PICSFREE1 video will be sent to you enjoy! If one vid is not enough for 2day text back the keyword PICSFREE1 to get the next video.
Ground Truth Label: 1
TART Predicted Label: 1


Input: Had your mobile 11mths ? Update for FREE to Oranges latest colour camera mobiles & unlimited weekend calls. Call Mobile Upd8 on freefone 08000839402 or 2StopTx
Ground Truth Label: 1
TART Predicted Label: 1


