In [62]:
import pandas as pd
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from qdrant_client.http import models
from tqdm import tqdm
import openai
import instructor
from pydantic import BaseModel, validator
from typing import Literal
from tenacity import (
    retry,
    stop_after_attempt,
    wait_fixed,
)

In [2]:
# Embedding model to embed the text data
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Spawn a Qdrant client connection in memory
qdrant_client = QdrantClient(":memory:")
qdrant_client.create_collection('text-classification',vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE))

True

In [3]:
# Training CSV File
train_csv = pd.read_csv('data/text-classification-with-llms/train.csv')

In [4]:
train_csv.head()

Unnamed: 0,id,text,label_text
0,1,what the time difference from here to ottawa,datetime_convert
1,2,start robot cleaner,iot
2,3,"Okay, I found my card, can I put it back in th...",card_linking
3,4,My currency rate is inaccurate.,card_payment_wrong_exchange_rate
4,5,It seems that the rate I got is incorrect.,card_payment_wrong_exchange_rate


In [5]:
train_csv.shape

(3996, 3)

In [6]:
pd.set_option('display.max_rows', 500)
pd.DataFrame(train_csv.label_text.value_counts())

Unnamed: 0_level_0,count
label_text,Unnamed: 1_level_1
not toxic,917
iot,195
play,165
play_music,163
news_query,112
news,112
weather_query,108
weather,108
datetime,103
card_payment_wrong_exchange_rate,92


### Uploading Data to a Vector Store

In [7]:
category_labels= train_csv.label_text.unique()

In [8]:
# create a point valid for qdrant
embeddings = embedding_model.encode(train_csv['text'].tolist(), show_progress_bar=True)

Batches:   0%|          | 0/125 [00:00<?, ?it/s]

In [9]:
# Create and upload points to Qdrant
for idx, row in train_csv.iterrows():
    point = models.PointStruct(
        id=idx,  # Use the dataframe index as the point ID
        vector=embeddings[idx].tolist(),  # Convert the embedding to a list
        payload={'label_text': row['label_text'] , "text":row['text']}  # Use the label_text as the payload
    )
    qdrant_client.upload_points(collection_name='text-classification', points=[point])

### Querying the Data from the Vector Store

In [10]:
# Lets Move on to the Querying part
test_csv = pd.read_csv('data/text-classification-with-llms/test.csv')

In [11]:
query_text = test_csv.iloc[4]['text']
print(f"Query Text: {query_text}")
print("---"*80)
query_vector = embedding_model.encode(query_text)
qdrant_client.search(collection_name='text-classification', query_vector=query_vector, limit=5,score_threshold=0.3)

Query Text: do i have any incoming emails
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


[ScoredPoint(id=2596, version=0, score=0.5096157193183899, payload={'label_text': 'alarm', 'text': 'do i have any alarms'}, vector=None, shard_key=None),
 ScoredPoint(id=862, version=0, score=0.5096157193183899, payload={'label_text': 'alarm_query', 'text': 'do i have any alarms'}, vector=None, shard_key=None),
 ScoredPoint(id=3358, version=0, score=0.4228695034980774, payload={'label_text': 'card_arrival', 'text': "I haven't gotten my credit card in the mail."}, vector=None, shard_key=None),
 ScoredPoint(id=3546, version=0, score=0.3998033404350281, payload={'label_text': 'lost_or_stolen_card', 'text': 'Has there been any activity on my card today?'}, vector=None, shard_key=None),
 ScoredPoint(id=2943, version=0, score=0.39195123314857483, payload={'label_text': 'alarm', 'text': 'do i have any alarms set for today'}, vector=None, shard_key=None)]

In [12]:
def qdrant_search(query_text,top_k=5):
    query_vector = embedding_model.encode(query_text)
    search_response = qdrant_client.search(collection_name='text-classification', query_vector=query_vector, limit=top_k)
    return search_response

In [13]:
qdrant_search("do i have any incoming emails")

[ScoredPoint(id=2596, version=0, score=0.5096157193183899, payload={'label_text': 'alarm', 'text': 'do i have any alarms'}, vector=None, shard_key=None),
 ScoredPoint(id=862, version=0, score=0.5096157193183899, payload={'label_text': 'alarm_query', 'text': 'do i have any alarms'}, vector=None, shard_key=None),
 ScoredPoint(id=3358, version=0, score=0.4228695034980774, payload={'label_text': 'card_arrival', 'text': "I haven't gotten my credit card in the mail."}, vector=None, shard_key=None),
 ScoredPoint(id=3546, version=0, score=0.3998033404350281, payload={'label_text': 'lost_or_stolen_card', 'text': 'Has there been any activity on my card today?'}, vector=None, shard_key=None),
 ScoredPoint(id=2943, version=0, score=0.39195123314857483, payload={'label_text': 'alarm', 'text': 'do i have any alarms set for today'}, vector=None, shard_key=None)]

Let's Try to see what queries are semantically similar to the training set. For the purpose of demonstration let's trim the datasize 

In [None]:
# random sample of 1000 rows from the test csv
sample_test_csv = test_csv.sample(1000)

In [65]:
rag_data = []
# Iterate over the test_csv DataFrame
for idx, row in tqdm(sample_test_csv.iterrows(),total=sample_test_csv.shape[0]):
    query_text = row['text']

    # Search in Qdrant for the top match
    # (Note you can also do a batch search since it is more efficient, but for simplicity we are doing a single search here)
    search_results = qdrant_search(query_text)
    # Check if there are any results
    if search_results:
        top_match = search_results[0] # Results are sorted by score, so the top match is the first one
        top_matched_text = top_match.payload['text']
        top_matched_label = top_match.payload['label_text']
        score = top_match.score

        # Append the data to the few_shot_data list
        rag_data.append({
            "Query Text": query_text,
            "Top Matched Text": top_matched_text,
            "Label": top_matched_label,
            "Label Match Score": score,
        })

# Create a DataFrame from the few_shot_data list
few_shot_df = pd.DataFrame(rag_data)

100%|██████████| 1000/1000 [00:39<00:00, 25.47it/s]


In [72]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_colwidth', 500)
few_shot_df

Unnamed: 0,Query Text,Top Matched Text,Label
0,"The treaties are no more outdated than the Constitution of the US is outdated. \n\nThere is no denying that Native Americans struggle with socially destructive behaviors, but so do non-natives. All of the things you mention occur all over America everyday among non-Natives. It just isn't as obvious as it is in a small village. The rates may not be the same for all those issues, but at in at least two - drug and alcohol abuse - white Americans take the prize. A significantly higher perce...","Thank you Ms. Clarkson for this thoughtful column. As the child of immigrants from Scotland after the war we had the advantage of the language and my father had training in the Royal Navy as an engineer. He upgraded at our cities technical college in the evenings (during the day his job was tearing up trolley car tracks for the railway). Like many if not most Canadians, I have sympathy for the plight of the indigenous people. Driving through the core of our city I sometimes see toddlers of a...",not toxic
1,turn on the playlist i have dedicated to rock music,play my playlist,play
2,play that podcast i was listening to yesterday,which was that coldplay song i listened to yesterday evening can you play it again,music
3,contacts please,I have a contactless that's broken.,contactless_not_working
4,"It will never be ""cleaned up"" until the citizenry are willing to acknowledge the uncomfortable truth of the situation.",initiate cleaning process,iot_cleaning
...,...,...,...
995,"Could you help me reactivate my card? It was previously lost, but I found it this morning in my jacket.","I want to reactivate my card, I thought I had lost it but found it again in my jacket this morning.",card_linking
996,Will you reinstate my PIN?,What happens if I forget my PIN?,pin_blocked
997,could you please play the f. m. station which plays pop songs,play the latest pop song,play_music
998,i want to order a pizza from michael's pizza,i want to order some pizza,takeaway


Now Let's come to the ask of classifying the given `Query Text` using RAG on the documents we just inserted to the Qdrant Store

### System Prompt Setup

We'll Add following things to the Prompt
- All the Labels that we are running this classification for.
- Specific Instruction on what we will be passing on to the Prompt

In [66]:
categories_list = "- " + "\n- ".join(category_labels)
system_prompt = f"""
You are an agent that is specialized in classification tasks.\n
Along with the input text, you are provided with the top 10 documents retrieved from a Retrieval-Augmented Generation (RAG) model. 
Use this information to classify the input text into one of the following categories:
{categories_list}
Note: The documents are included in the user's message for context.
"""

In [67]:
print(system_prompt)


You are an agent that is specialized in classification tasks.

Along with the input text, you are provided with the top 10 documents retrieved from a Retrieval-Augmented Generation (RAG) model. 
Use this information to classify the input text into one of the following categories:
- datetime_convert
- iot
- card_linking
- card_payment_wrong_exchange_rate
- news
- play
- weather
- play_music
- exchange_via_app
- not toxic
- iot_wemo_on
- weather_query
- fiat_currency_support
- card_arrival
- iot_hue_lightup
- takeaway_order
- datetime
- toxic
- news_query
- exchange_rate
- card_delivery_estimate
- takeaway_query
- top_up_by_bank_transfer_charge
- music
- audio
- age_limit
- takeaway
- pending_top_up
- music_likeness
- pending_cash_withdrawal
- automatic_top_up
- audio_volume_up
- card_not_working
- iot_cleaning
- iot_hue_lightchange
- alarm_query
- audio_volume_down
- extra_charge_on_statement
- music_settings
- pin_blocked
- datetime_query
- general
- alarm
- music_query
- audio_volume

### Classification Function

Now the idea is to build a classification function that does the following
- Using the given text searches for the top_k documents 
- Uses those document as a reference to make the model contextually aware (can call it few shot) and then predict
- Ask the model to classify on one of the defined labels from the classification task

In [68]:
@retry(
    stop=stop_after_attempt(2),  # Stop after 2 attempts
    wait=wait_fixed(1),  # Wait 1 second between retries
) #Handle retries for the OpenAI API Rate Limit Calls
def classify_query_text(query_text) -> str:
    # Search the Qdrant For Related Documents
    search_results = qdrant_search(query_text, top_k=10)
    # Prepare the Sample Documents retrieved from RAG
    sample_documents = [{
            "Text": result.payload['text'],
            "Label": result.payload['label_text'],
        } for result in search_results]

    # Prepare the User Message , Append the sample documents to the user message
    user_message = f"""
    Reference Documents from RAG Model: {sample_documents}\n\n
    Input text to classify: {query_text}
    """
    # Prepare the OpenAI Request Body
    openai_request_body = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        # Response Model is the CategoryModel
        "response_model": CategoryModel,
        "model": "gpt-3.5-turbo",
        "temperature": 0.2, # Adjust the temperature for more creative responses
        "max_tokens": 100, # Limit the tokens as we are only classifying the text
        "seed": 42,
    }
    try:
        chat_completion = openai_client.chat.completions.create(**openai_request_body)
    except Exception as e:
        raise e
    # Assuming the chat_completion returns the category directly
    return chat_completion.category

In [46]:
# Create a Response Model for the classification task 
class CategoryModel(BaseModel):
    category: Literal[tuple(category_labels.tolist())]
    @validator('category')
    def check_category(cls, value):
        if value not in category_labels:
            raise ValueError(f"{value} is not a valid category")
        return value
# Specify the OpenAI API Key
openai_api_key = ''
# Patch the OpenAI client for instructor
openai_client = instructor.patch(openai.OpenAI(api_key=openai_api_key))

/var/folders/jg/y770c8z100j5422vmxzgvc_c0000gn/T/ipykernel_11637/751315324.py:3: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.4/migration/
  @validator('category')


In [69]:
category = classify_query_text("'i want to play the song again'")
print(category)

play_music


In [48]:
sample_test_csv

Unnamed: 0,id,text
6021,6022,It costs almost $1mil to train a physician (ab...
659,660,What are the charges for receiving money?
5998,5999,The exchange rate for my transaction last Satu...
10048,10049,Some people use debit because they have poor s...
10901,10902,i want to play the song again
...,...,...
9422,9423,describe the heart stone card game
11925,11926,health
12218,12219,today the following happened to me i had a mee...
13996,13997,please check the weather in kansas


Let us run this function on all of the datapoints in the sample_test_csv

In [52]:
# let us test the function on the sample_test_csv
sample_test_csv['predicted_category'] = None
# Iterate over the sample_test_csv DataFrame
for idx,row in tqdm(sample_test_csv.iterrows(),total=sample_test_csv.shape[0]):
    query_text = row['text']
    try:
        category = classify_query_text(query_text)
    except Exception:
        # handle for rate limit error
        category = "Error"
    sample_test_csv.loc[idx,'predicted_category'] = category

100%|██████████| 1000/1000 [20:42<00:00,  1.24s/it]


In [72]:
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.max_rows', 500)
sample_test_csv.head(20)

Unnamed: 0,id,text,label_text
6021,6022,It costs almost $1mil to train a physician (about 200k comes from the trainee and the rest from ...,not toxic
659,660,What are the charges for receiving money?,top_up_by_bank_transfer_charge
5998,5999,The exchange rate for my transaction last Saturday seems to have been wrong I got charged extra....,card_payment_wrong_exchange_rate
10048,10049,Some people use debit because they have poor self-control with credit.,general
10901,10902,i want to play the song again,play_music
1308,1309,please find all name start with alphabetic of a and create a list,general
6915,6916,tweet a message to suqcom that i am still waiting for my delivery,takeaway_query
11114,11115,It's always someone else's fault.\n\nParty of personal responsibility my backside.,not toxic
7165,7166,"YUGELY sad! Tremendously sad! All the tragic broflakes, mindlessly tuning in every week to read ...",not toxic
500,501,Time for what? He's battling zero right now. He's pursuing his lost causes while the country is ...,not toxic


In [71]:
# Save the results to a CSV file in required format
sample_test_csv.rename(columns={'predicted_category':'label_text'},inplace=True)
submission = sample_test_csv[['id','text','label_text']]
submission.to_csv('submission.csv',index=False)