# LLM Inference  - Text Classification using Zero Shot Prompting

## Setup Environment

In [7]:
import os
from datasets import load_dataset,DatasetDict
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv())

HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

from vllm import LLM, SamplingParams

## Instantiate a LLM 

In [3]:
llm = LLM(
        model="meta-llama/Meta-Llama-3-70B-Instruct",
        tensor_parallel_size=4,
        trust_remote_code=True,
        enforce_eager=True,
        gpu_memory_utilization=0.99,
        enable_prefix_caching=True
)

2024-04-29 22:56:32,630	INFO worker.py:1749 -- Started a local Ray instance.


INFO 04-29 22:56:36 llm_engine.py:98] Initializing an LLM engine (v0.4.1) with config: model='meta-llama/Meta-Llama-3-70B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-70B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=4, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 04-29 22:56:45 utils.py:608] Found nccl from library /home/u.ap164907/.config/vllm/nccl/cu12/libnccl.so.2.18.1
[36m(RayWorkerWrapper pid=2577450)[0m INFO 04-29 22:56:45 utils.py:608] Found nccl from library /home/u.ap164907/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 04-29 22:56:49 selector.py:28] Using FlashAttention backend.
[36m(RayWorkerWrapper pid=2577450)[0m INFO 04-29 22:56:49 selector.py:28] Using FlashAttention backend.
[36m(RayWorkerWrapper pid=2577697)[0m INFO 04-29 22:56:45 utils.py:608] Found nccl from library /home/u.ap164907/.config/vllm/nccl/cu12/libnccl.so.2.18.1[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
INFO 04-29 22:56:50 pynccl_utils.py:43] vLLM is using nccl==2.18.1
[36m(RayWorkerWrapper pid=2577450)[0m INFO 04-29 22:56:50 pynccl_utils.py:43] vL

## Common function for all dataset

In [8]:
def zero_shot_classification(dataset_name, prefix):
    # Load the dataset
    dataset = load_dataset(dataset_name)

    # Iterate over the dataset splits (train, test, validation)
    modified_dataset_dict = {}
    for split in ["train", "test", "validation"]:
        # Get the texts and labels from the current split
        texts = dataset[split]["text"]
        labels = dataset[split]["label"]

        # Generate the prompts for each Text
        generating_prompts = [prefix + "Text: " + text + "\nResponse: " for text in texts]

        # Set the sampling parameters
        sampling_params = SamplingParams(temperature=0, max_tokens=1)

        # Generate the sentiment labels for each text
        outputs = llm.generate(generating_prompts, sampling_params)
        predicted_label = []
        for output in outputs:
            try:
                predicted_label.append(int(output.outputs[0].text))
            except ValueError:
                predicted_label.append(-1)

        # Add the predicted labels to the dataset
        modified_dataset = dataset[split].add_column("predicted_label", predicted_label)
        modified_dataset_dict[split] = modified_dataset

    # Create a DatasetDict with the modified datasets
    return DatasetDict(modified_dataset_dict)

## Twitter Dataset

In [4]:
prefix = """
You are an expert in sentiment analysis, with a deep understanding of natural language and human emotions.
Your task is to analyze the sentiment of the given text and classify it as either positive or negative.
When analyzing the sentiment, consider the overall tone, word choice, and emotional connotations within the text.
Positive sentiment typically conveys happiness, joy, approval, or praise, while negative sentiment expresses sadness, anger, criticism, or disappointment.
Provide your analysis in a concise and definitive manner, outputting either the number '1' if positive or '0' if negative based on your assessment of the sentiment expressed in the text.
Do not provide any additional commentary or explanation beyond the sentiment classification itself.
"""

In [9]:
twitter_modified = zero_shot_classification(
    dataset_name="MAdAiLab/twitter_disaster",
    prefix=prefix
)

Processed prompts: 100%|██████████| 8700/8700 [02:22<00:00, 61.06it/s]
Processed prompts: 100%|██████████| 1088/1088 [00:17<00:00, 63.13it/s]
Processed prompts: 100%|██████████| 1088/1088 [00:17<00:00, 61.76it/s]


In [10]:
twitter_modified

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'predicted_label'],
        num_rows: 8700
    })
    test: Dataset({
        features: ['text', 'label', 'predicted_label'],
        num_rows: 1088
    })
    validation: Dataset({
        features: ['text', 'label', 'predicted_label'],
        num_rows: 1088
    })
})

In [11]:
twitter_modified.save_to_disk("./output/twitter_predicted")

Saving the dataset (0/1 shards):   0%|          | 0/8700 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1088 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1088 [00:00<?, ? examples/s]

## Patent Classification Dataset

In [13]:
patent_dataset = load_dataset("MAdAiLab/patent_classification")

Downloading readme:   0%|          | 0.00/937 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.61M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.74M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.72M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [29]:
patent_dataset['train'][:5]['label']

[6, 0, 7, 0, 8]

In [32]:
prefix = """
You are an expert in patent classification, with a deep understanding of technical domains and patent categorization.
Your task is to analyze the given patent abstract text and classify it into one of the 9 categories:
'0': Human Necessities
'1': Performing Operations; Transporting
'2': Chemistry; Metallurgy
'3': Textiles; Paper
'4': Fixed Constructions
'5': Mechanical Engineering; Lightning; Heating; Weapons; Blasting
'6': Physics
'7': Electricity
'8': General tagging of new or cross-sectional technology
When analyzing the patent, consider the technical field, invention type, and application area described in the text.
Provide your classification in a concise and definitive manner, outputting the corresponding class label (0-8) based on your assessment of the patent's category.
Do not provide any additional commentary or explanation beyond the classification itself.
"""

In [37]:
prompts = patent_dataset['train'][:5]['text']
labels = patent_dataset['train'][:5]['label']

sampling_params = SamplingParams(temperature=0, max_tokens=1)

generating_prompts = [prefix + "Text: " + prompt + "\nResponse: " for prompt in prompts]

outputs = llm.generate(generating_prompts, sampling_params)

# Print the outputs
for i, output in enumerate(outputs, start=1):
    prompt = output.prompt
    generated_text = output.outputs[0].text

    print(f"Example {i}:")
    # print(f"Prefix: {prefix.strip()}")
    print(f"Text: {prompt.split('Text: ')[-1].strip()} {generated_text.strip()}")
    print(f"Actual label: {labels[i-1]}")
    # print(f"{generated_text.strip()}")
    print("-" * 50 + "\n")

Processed prompts: 100%|██████████| 5/5 [00:00<00:00, 40.07it/s]

Example 1:
Tweet: an apparatus for simultaneously testing multiple integrated circuits includes a sensing circuit associated with each of the tested circuits . each sensing circuit includes a differential amplifier with its positive input connected to the input of the test circuit , and its inversion input connected to the test circuit output . the test circuit input and positive amplifier input are biased to a selected voltage , and the voltage drop across the test circuit is provided to the amplifier inversion input . whenever the test circuit is open , intermittently open or highly resistive , the voltage drop across the test circuit exceeds the threshold voltage of the differential amplifier , causing the amplifier to generate a high level logic output representing an open circuit condition . the outputs of the various sensing circuits together form a digital word representative of the condition of all of the test circuits . the outputs of the differential amplifiers also are provi




In [39]:
patent_modified = zero_shot_classification(
    dataset_name="MAdAiLab/patent_classification",
    prefix=prefix
)

Processed prompts: 100%|██████████| 25000/25000 [23:26<00:00, 17.77it/s]
Processed prompts: 100%|██████████| 5000/5000 [04:41<00:00, 17.75it/s]
Processed prompts: 100%|██████████| 5000/5000 [04:42<00:00, 17.69it/s]


In [40]:
patent_modified.save_to_disk("./output/patent_predicted")

Saving the dataset (0/1 shards):   0%|          | 0/25000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]

## Scotus dataset

In [41]:
scotus_dataset = load_dataset("MAdAiLab/lex_glue_scotus")

Downloading readme:   0%|          | 0.00/813 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/94.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/40.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/39.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1400 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1400 [00:00<?, ? examples/s]

In [47]:
prefix = """
You are an expert in legal issue area classification, with a deep understanding of the US Supreme Court's opinions and the subject matter of controversies.
Your task is to analyze the given court opinion and classify it into one of the 14 relevant issue areas.
When analyzing the opinion, consider the overall content, legal concepts, and subject matter within the text.
The 14 issue areas are: (1) Criminal Procedure, (2) Civil Rights, (3) First Amendment, (4) Due Process, (5) Privacy, (6) Attorneys, (7) Unions, (8) Economic Activity, (9) Judicial Power, (10) Federalism, (11) Interstate Relations, (12) Federal Taxation, (13) Miscellaneous, and (14) Private Action.
Provide your analysis in a concise and definitive manner, outputting the number corresponding to the relevant issue area based on your assessment of the opinion's content.
"""

In [48]:
prompts = scotus_dataset['train'][:5]['text']
labels = scotus_dataset['train'][:5]['label']

sampling_params = SamplingParams(temperature=0, max_tokens=1)

generating_prompts = [prefix + "Text: " + prompt + "\nResponse: " for prompt in prompts]

outputs = llm.generate(generating_prompts, sampling_params,)

# Print the outputs
for i, output in enumerate(outputs, start=1):
    prompt = output.prompt
    generated_text = output.outputs[0].text

    print(f"Example {i}:")
    # print(f"Prefix: {prefix.strip()}")
    # print(f"Text: {prompt.split('Text: ')[-1].strip()}")
    print(f"Actual label: {labels[i-1]}")
    print(f"Predicted label: {generated_text.strip()}")
    print("-" * 50 + "\n")

Processed prompts:  40%|████      | 2/5 [00:03<00:06,  2.03s/it]



Processed prompts: 100%|██████████| 5/5 [00:06<00:00,  1.33s/it]

Example 1:
Actual label: 7
Predicted label: 8
--------------------------------------------------

Example 2:
Actual label: 7
Predicted label: 8
--------------------------------------------------

Example 3:
Actual label: 0
Predicted label: 4
--------------------------------------------------

Example 4:
Actual label: 1
Predicted label: 
--------------------------------------------------

Example 5:
Actual label: 7
Predicted label: 
--------------------------------------------------




