# LLM Inference with Weave

## Weave Official Example

In [None]:
import json
import asyncio
import weave
from weave.flow.scorer import MultiTaskBinaryClassificationF1
import openai

# We create a model class with one predict function.
# All inputs, predictions and parameters are automatically captured for easy inspection.

class ExtractFruitsModel(weave.Model):
    model_name: str
    prompt_template: str

    @weave.op()
    async def predict(self, sentence: str) -> dict:
        client = openai.AsyncClient()

        response = await client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": self.prompt_template.format(sentence=sentence)}
            ],
            response_format={ "type": "json_object" }
        )
        result = response.choices[0].message.content
        if result is None:
            raise ValueError("No response from model")
        parsed = json.loads(result)
        return parsed

# We call init to begin capturing data in the project, intro-example.
weave.init('intro-example')

# We create our model with our system prompt.
model = ExtractFruitsModel(name='gpt4',
                           model_name='gpt-4-0125-preview',
                           prompt_template='Extract fields ("fruit": <str>, "color": <str>, "flavor") from the following text, as json: {sentence}')
sentences = ["There are many fruits that were found on the recently discovered planet Goocrux. There are neoskizzles that grow there, which are purple and taste like candy.",
"Pounits are a bright green color and are more savory than sweet.",
"Finally, there are fruits called glowls, which have a very sour and bitter taste which is acidic and caustic, and a pale orange tinge to them."]
labels = [
    {'fruit': 'neoskizzles', 'color': 'purple', 'flavor': 'candy'},
    {'fruit': 'pounits', 'color': 'bright green', 'flavor': 'savory'},
    {'fruit': 'glowls', 'color': 'pale orange', 'flavor': 'sour and bitter'}
]
examples = [
    {'id': '0', 'sentence': sentences[0], 'target': labels[0]},
    {'id': '1', 'sentence': sentences[1], 'target': labels[1]},
    {'id': '2', 'sentence': sentences[2], 'target': labels[2]}
]
# If you have already published the Dataset, you can run:
# dataset = weave.ref('example_labels').get()

# We define a scoring functions to compare our model predictions with a ground truth label.
@weave.op()
def fruit_name_score(target: dict, model_output: dict) -> dict:
    return {'correct': target['fruit'] == model_output['fruit']}

# Finally, we run an evaluation of this model.
# This will generate a prediction for each input example, and then score it with each scoring function.
evaluation = weave.Evaluation(
    name='fruit_eval',
    dataset=examples, scorers=[MultiTaskBinaryClassificationF1(class_names=["fruit", "color", "flavor"]), fruit_name_score],
)
print(asyncio.run(evaluation.evaluate(model)))
# if you're in a Jupyter Notebook, run:
# await evaluation.evaluate(model)

## Twitter Dataset

In [3]:
twitter_dataset = load_dataset("MAdAiLab/twitter_disaster")
twitter_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8700
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1088
    })
})

In [4]:
examples = twitter_dataset['train'].to_list()[:5]
examples

[{'text': '@sabcnewsroom sabotage!I rule out structural failure', 'label': 0},
 {'text': 'Two giant cranes holding a bridge collapse into nearby homes http://t.co/UmANaaHwMI',
  'label': 1},
 {'text': '@yeetrpan I asked if they were hiring and they said not you I was devastated.',
  'label': 0},
 {'text': 'Watch This Airport Get Swallowed Up By A Sandstorm In Under A Minute http://t.co/7IJlZ6BcSP',
  'label': 1},
 {'text': 'Survived my first #tubestrike thanks to @Citymapper', 'label': 0}]

## Intergrating Weave

In [1]:
import os
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv, find_dotenv
import weave
# import asyncio
from sklearn.metrics import accuracy_score

load_dotenv(find_dotenv())

HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

from vllm import LLM, SamplingParams
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

weave.init('seq-clf-vllm-inference')

Logged in as Weights & Biases user: akshat_patil.
View Weave data at https://wandb.ai/madailab/seq-clf-vllm-inference/weave




In [2]:
twitter_dataset = load_dataset("MAdAiLab/twitter_disaster")
twitter_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8700
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1088
    })
})

In [3]:
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=2,
    trust_remote_code=True,
    enforce_eager=True,
    gpu_memory_utilization=0.99,
    enable_prefix_caching=True
)


2024-05-13 00:07:56,225	INFO worker.py:1749 -- Started a local Ray instance.


INFO 05-13 00:07:56 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=meta-llama/Meta-Llama-3-8B-Instruct)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 05-13 00:07:59 utils.py:660] Found nccl from library /home/harpreet_guest2/.config/vllm/nccl/cu12/libnccl.so.2.18.1
[36m(RayWorkerWrapper pid=425395)[0m INFO 05-13 00:07:59 utils.py:660] Found nccl from library /home/harpreet_guest2/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-13 00:08:00 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
INFO 05-13 00:08:00 selector.py:32] Using XFormers backend.
[36m(RayWorkerWrapper pid=425395)[0m INFO 05-13 00:08:00 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
[36m(RayWorkerWrapper pid=425395)[0m INFO 05-13 00:08:00 selector.py:32] Using XFormers backend.
INFO 05-13 00:08:01 pynccl_utils.py:43] vLLM is using nccl==2.18.1
[36m(RayWorkerWrapper pid=425395)[0m INFO 05-13 00:08:01 pynccl_utils.py:43] vLLM is using nccl==2.18.1
INFO 05-13 00:08:02 utils.py:1

In [4]:
examples = twitter_dataset['train'].to_list()[:5]
# example_prompt = PromptTemplate(
#     input_variables=["text", "label"], template="Text: {text} \nClassification Label: {label}"
# )
# prompt = FewShotPromptTemplate(
#     examples=examples,
#     example_prompt=example_prompt,
#     suffix="Text: {text} \nClassification Label: ",
#     input_variables=["text"],
# )


In [5]:
examples[0]

{'text': '@sabcnewsroom sabotage!I rule out structural failure', 'label': 0}

In [6]:
prompt="""Text: @sabcnewsroom sabotage!I rule out structural failure 
Classification Label: 0

Text: Two giant cranes holding a bridge collapse into nearby homes http://t.co/UmANaaHwMI 
Classification Label: 1

Text: @yeetrpan I asked if they were hiring and they said not you I was devastated. 
Classification Label: 0

Text: Watch This Airport Get Swallowed Up By A Sandstorm In Under A Minute http://t.co/7IJlZ6BcSP 
Classification Label: 1

Text: Survived my first #tubestrike thanks to @Citymapper 
Classification Label: 0

Text: {text} 
Classification Label: 
"""

In [14]:
prompt="""
Given the following tweet:

"{text}"

0: negative
1: positive

What is your answer? Please respond with 0 or 1.

Answer: 
"""

In [7]:
class SequenceClassificationModel(weave.Model):
    model_name: str
    prompt_template: str
    llm: LLM

    @weave.op()
    def predict(self, texts: list[str]) -> list[int]:
        sampling_params = SamplingParams(temperature=0, max_tokens=1)
        prompt_texts = [self.prompt_template.format(text=text) for text in texts]
        outputs = self.llm.generate(prompt_texts, sampling_params)
        predicted_labels = []
        for output in outputs:
            try:
                predicted_labels.append(int(output.outputs[0].text))
            except ValueError:
                predicted_labels.append(-1)
        return predicted_labels

In [15]:
model = SequenceClassificationModel(
    name='twitter-zero-shot-classification',
    model_name='meta-llama/Meta-Llama-3-70B-Instruct',
    prompt_template=prompt,
    llm=llm
)

In [9]:
test_inputs = twitter_dataset['test']
test_examples = [
    {'id': str(i), 'text': text, 'label': label}
    for i, (text, label) in enumerate(zip(test_inputs['text'], test_inputs['label']))
]


In [10]:
test_examples[5]['text']

'Providence Health &amp; Services: Emergency Services Supervisor - Emergency Department... (#Kodiak AK) http://t.co/AQcSUSqbDy #Healthcare #Job'

In [11]:
test_examples[:5]

[{'id': '0',
  'text': 'Heavy Rainfall and Flooding in Northern #VietNam | Situation Report No.2 http://t.co/hVxu1Zcvau http://t.co/iJmCCMHh5G',
  'label': 1},
 {'id': '1',
  'text': 'Bolshevik government monopolized food supply to seize power over hunhry population. Artificial famine was the result https://t.co/0xOUv7DHWz',
  'label': 1},
 {'id': '2',
  'text': 'WHELEN MODEL 295SS-100 SIREN AMPLIFIER POLICE EMERGENCY VEHICLE - Full read by eBay http://t.co/UGR6REFZpT http://t.co/eYyUqX4Tbt',
  'label': 0},
 {'id': '3',
  'text': '#Autoinsurance industry clueless on driverless cars : #healthinsurance http://t.co/YdEtWgRibk',
  'label': 1},
 {'id': '4',
  'text': 'Gunmen kill four in El Salvador bus attack: Suspected Salvadoran gang members killed four people and wounded s... http://t.co/r8k6rXw6D6',
  'label': 1}]

In [16]:
@weave.op()
def evaluate_model(model: SequenceClassificationModel, test_examples: list) -> dict:
    texts = [ex['text'] for ex in test_examples]
    y_true = [ex['label'] for ex in test_examples]
    y_pred = model.predict(texts)
    valid_indices = [i for i, pred in enumerate(y_pred) if pred != -1]
    accuracy = accuracy_score([y_true[i] for i in valid_indices], [y_pred[i] for i in valid_indices])
    return {'accuracy': accuracy}

results = evaluate_model(model, test_examples)
print("Accuracy:", results['accuracy'])

Processed prompts: 100%|██████████| 1088/1088 [00:05<00:00, 206.86it/s]

🍩 https://wandb.ai/madailab/seq-clf-vllm-inference/r/call/6bfa769e-3f55-4ccd-aa38-227d80741c8a
Accuracy: 0.4071691176470588





In [17]:
results

{'accuracy': 0.4071691176470588}