In [34]:
import pandas as pd
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from qdrant_client.http import models
from tqdm import tqdm
import openai
import instructor
from pydantic import BaseModel, field_validator
from typing import Literal
from tenacity import (
    retry,
    stop_after_attempt,
    wait_fixed,
)

In [2]:
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Spawn a Qdrant client connection in memory
qdrant_client = QdrantClient(":memory:")
qdrant_client.create_collection('text-classification',vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE))

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

True

In [11]:
train_csv = pd.read_csv('data/train.csv')

In [12]:
train_csv.head()

Unnamed: 0,id,text,label_text
0,1,I already made a transfer and want to cancel i...,cancel_transfer
1,2,I don't think I made this payment that is show...,card_payment_not_recognised
2,3,"You seem to forget there are so called ""rights...",not toxic
3,4,play the fifty songs i listen to most often,play_music
4,5,please make the smart socket turn off,iot


In [13]:
train_csv.shape

(7994, 3)

In [14]:
pd.set_option('display.max_rows', 500)
pd.DataFrame(train_csv.label_text.value_counts())

Unnamed: 0_level_0,count
label_text,Unnamed: 1_level_1
not toxic,1842
iot,374
play,291
play_music,288
weather,260
weather_query,260
news_query,212
news,212
datetime,191
alarm,167


In [15]:

category_labels = train_csv.label_text.unique()

In [16]:
# create a point valid for qdrant
embeddings = embedding_model.encode(train_csv['text'].tolist(), show_progress_bar=True)

Batches:   0%|          | 0/250 [00:00<?, ?it/s]

In [17]:
# Create and upload points to Qdrant
points = []
for idx, row in train_csv.iterrows():
    point = models.PointStruct(
        id=idx,  # Use the dataframe index as the point ID
        vector=embeddings[idx].tolist(),  # Convert the embedding to a list
        payload={'label_text': row['label_text'] , "text":row['text']}  # Use the label_text as the payload
    )
    points.append(point)
qdrant_client.upload_points(collection_name='text-classification', points=points)

In [18]:
# Lets Move on to the Querying part
test_csv = pd.read_csv('data/test.csv')

In [19]:
test_csv.shape

(4000, 2)

In [20]:
query_text = test_csv.iloc[4]['text']
print(f"Query Text: {query_text}")
print("---"*80)
query_vector = embedding_model.encode(query_text)
qdrant_client.search(collection_name='text-classification', query_vector=query_vector, limit=5,score_threshold=0.3)

Query Text: Can I open an account for a child?
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


[ScoredPoint(id=2868, version=0, score=0.9713864326477051, payload={'label_text': 'age_limit', 'text': 'Can I open up an account for my child?'}, vector=None, shard_key=None),
 ScoredPoint(id=1021, version=0, score=0.9092108607292175, payload={'label_text': 'age_limit', 'text': 'Could I open an account for children?'}, vector=None, shard_key=None),
 ScoredPoint(id=3725, version=0, score=0.8878611326217651, payload={'label_text': 'age_limit', 'text': 'Can my children open an account?'}, vector=None, shard_key=None),
 ScoredPoint(id=7204, version=0, score=0.8877559304237366, payload={'label_text': 'age_limit', 'text': 'Would it be possible to open up an account for children?'}, vector=None, shard_key=None),
 ScoredPoint(id=7760, version=0, score=0.8860688209533691, payload={'label_text': 'age_limit', 'text': 'I want to open an account for my child.'}, vector=None, shard_key=None)]

In [21]:
def qdrant_search(query_text,top_k=5):
    query_vector = embedding_model.encode(query_text)
    search_response = qdrant_client.search(collection_name='text-classification', query_vector=query_vector, limit=top_k)
    return search_response

In [22]:
qdrant_search("do i have any incoming emails")

[ScoredPoint(id=2399, version=0, score=0.5249748229980469, payload={'label_text': 'edit_personal_details', 'text': 'I have a new email.'}, vector=None, shard_key=None),
 ScoredPoint(id=6510, version=0, score=0.43753761053085327, payload={'label_text': 'not toxic', 'text': 'How would any of us know if emails had been doctored unless we had access to them and to an expert?'}, vector=None, shard_key=None),
 ScoredPoint(id=7487, version=0, score=0.42286956310272217, payload={'label_text': 'card_arrival', 'text': "I haven't gotten my credit card in the mail."}, vector=None, shard_key=None),
 ScoredPoint(id=6710, version=0, score=0.4187886714935303, payload={'label_text': 'transfer_not_received_by_recipient', 'text': 'I dont see my reciept'}, vector=None, shard_key=None),
 ScoredPoint(id=710, version=0, score=0.4184560179710388, payload={'label_text': 'news_query', 'text': 'do you have any updates on blank'}, vector=None, shard_key=None)]

In [23]:
# random sample of 1000 rows from the test csv
sample_test_csv = test_csv.sample(1000)

In [24]:
rag_data = []
# Iterate over the test_csv DataFrame
for idx, row in tqdm(sample_test_csv.iterrows(),total=sample_test_csv.shape[0]):
    query_text = row['text']

    # Search in Qdrant for the top match
    # (Note you can also do a batch search since it is more efficient, but for simplicity we are doing a single search here)
    search_results = qdrant_search(query_text)
    # Check if there are any results
    if search_results:
        top_match = search_results[0] # Results are sorted by score, so the top match is the first one
        top_matched_text = top_match.payload['text']
        top_matched_label = top_match.payload['label_text']
        score = top_match.score

        # Append the data to the few_shot_data list
        rag_data.append({
            "Query Text": query_text,
            "Top Matched Text": top_matched_text,
            "Label": top_matched_label,
            "Label Match Score": score,
        })

# Create a DataFrame from the few_shot_data list
few_shot_df = pd.DataFrame(rag_data)

100%|███████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:35<00:00, 10.44it/s]


In [25]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_colwidth', 500)
few_shot_df

Unnamed: 0,Query Text,Top Matched Text,Label,Label Match Score
0,song title,lyrics,music,0.697595
1,convert eight thirty from gtm four to g. m. t. five thirty,convert one thousand and thirty from g. m. t. plus two hundred and thirty to g. m. t. zero hundred,datetime,0.681218
2,is it ten,play one through ten on list,play,0.513409
3,I was really touched by the promo for this movie. Now I’m horrified after watching that video. I will NOT be seeing this movie.,should i watch this movie,recommendation,0.388121
4,turn it up olly,olly lights off,iot,0.696151
...,...,...,...,...
995,hold,stop,audio_volume_mute,0.478837
996,There's a recent charge on my card that I know I didn't make because I've never seen the name before. Can we investigate this?,There was a purchase on my card recently to a name that I don't recognize at all. What can be done about this? I need my money back.,card_payment_not_recognised,0.749043
997,turn off the alarm,set the alarm,alarm,0.854216
998,did i leave the light on in the garage,turn off the garage light,iot_hue_lightoff,0.726741


In [26]:
categories_list = "- " + "\n- ".join(category_labels)
system_prompt = f"""
You are an agent that is specialized in classification tasks.\n
Along with the input text, you are provided with the top 10 documents retrieved from a Retrieval-Augmented Generation (RAG) model. 
Use this information to classify the input text into one of the following categories:
{categories_list}
Note: The documents are included in the user's message for context.
"""

In [27]:
print(system_prompt)


You are an agent that is specialized in classification tasks.

Along with the input text, you are provided with the top 10 documents retrieved from a Retrieval-Augmented Generation (RAG) model. 
Use this information to classify the input text into one of the following categories:
- cancel_transfer
- card_payment_not_recognised
- not toxic
- play_music
- iot
- datetime
- alarm
- alarm_set
- music
- card_payment_fee_charged
- news_query
- iot_hue_lightoff
- card_arrival
- weather_query
- why_verify_identity
- datetime_query
- audio
- alarm_query
- iot_wemo_on
- balance_not_updated_after_cheque_or_cash_deposit
- card_payment_wrong_exchange_rate
- weather
- takeaway_query
- news
- top_up_reverted
- music_query
- toxic
- card_acceptance
- takeaway
- iot_hue_lightdim
- exchange_rate
- pending_cash_withdrawal
- fiat_currency_support
- play
- music_likeness
- top_up_limits
- iot_cleaning
- card_linking
- wrong_amount_of_cash_received
- age_limit
- takeaway_order
- card_not_working
- audio_vol

In [35]:
# Create a Response Model for the classification task 
class CategoryModel(BaseModel):
    category: Literal[tuple(category_labels.tolist())]
    @field_validator('category')
    def check_category(cls, value):
        if value not in category_labels:
            raise ValueError(f"{value} is not a valid category")
        return value
# Specify the OpenAI API Key
openai_api_key = 'sk-proj-mcb0iV62IhxFxgnxi6BHT3BlbkFJ2kb3bWu5jxEcKWfb2AfE'
# Patch the OpenAI client for instructor
openai_client = instructor.patch(openai.OpenAI(api_key=openai_api_key))

In [36]:
@retry(
    stop=stop_after_attempt(2),  # Stop after 2 attempts
    wait=wait_fixed(60),  # Wait 60 second between retries
) #Handle retries for the OpenAI API Rate Limit Calls
def classify_query_text(query_text) -> str:
    # Search the Qdrant For Related Documents
    search_results = qdrant_search(query_text, top_k=10)
    # Prepare the Sample Documents retrieved from RAG
    sample_documents = [{
            "Text": result.payload['text'],
            "Label": result.payload['label_text'],
        } for result in search_results]

    # Prepare the User Message , Append the sample documents to the user message
    user_message = f"""
    Reference Documents from RAG Model: {sample_documents}\n\n
    Input text to classify: {query_text}
    """
    # Prepare the OpenAI Request Body
    openai_request_body = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        # Response Model is the CategoryModel
        "response_model": CategoryModel,
        "model": "gpt-3.5-turbo",
        "temperature": 0.2, # Adjust the temperature for more creative responses
        "max_tokens": 100, # Limit the tokens as we are only classifying the text
        "seed": 42,
    }
    try:
        chat_completion = openai_client.chat.completions.create(**openai_request_body)
    except Exception as e:
        raise e
    # Assuming the chat_completion returns the category directly
    return chat_completion.category

In [37]:
category = classify_query_text("'i want to play the song again'")
print(category)

play_music


In [38]:
test_csv.head()

Unnamed: 0,id,text
0,1,i don't want any alarms
1,2,remove new year from calendar
2,3,"I noticed an extra $1 charge on my statement, can you tell me why that is?"
3,4,Help! I can't find my card.
4,5,Can I open an account for a child?


In [39]:
# let us test the function on the test data
test_csv['predicted_category'] = None
# Iterate over the test DataFrame
for idx,row in tqdm(test_csv.iterrows(),total=test_csv.shape[0]):
    query_text = row['text']
    try:
        category = classify_query_text(query_text)
    except Exception:
        # handle for any exception
        category = "Error"
    test_csv.loc[idx,'predicted_category'] = category

 23%|██████████████████                                                            | 927/4000 [17:33<1:02:20,  1.22s/it]Incomplete output detected, should increase max_tokens
Incomplete output detected, should increase max_tokens
 31%|████████████████████████▍                                                      | 1235/4000 [23:54<59:46,  1.30s/it]Incomplete output detected, should increase max_tokens
 74%|██████████████████████████████████████████████████████████▍                    | 2960/4000 [56:54<19:40,  1.14s/it]Incomplete output detected, should increase max_tokens
100%|█████████████████████████████████████████████████████████████████████████████| 4000/4000 [1:26:10<00:00,  1.29s/it]


In [40]:
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.max_rows', 500)
test_csv.head(20)

Unnamed: 0,id,text,predicted_category
0,1,i don't want any alarms,alarm_remove
1,2,remove new year from calendar,calendar_query
2,3,"I noticed an extra $1 charge on my statement, can you tell me why that is?",extra_charge_on_statement
3,4,Help! I can't find my card.,lost_or_stolen_card
4,5,Can I open an account for a child?,age_limit
5,6,play country radio,play_music
6,7,clean my house,iot_cleaning
7,8,please turn socket off,iot
8,9,is there anything i should be reminded about,alarm_query
9,10,What is the limit to number of transactions I can do with a disposable card?,card_acceptance


In [41]:
# Save the results to a CSV file in required format
test_csv.rename(columns={'predicted_category':'label_text'},inplace=True)
submission = test_csv[['id','label_text']]
submission.to_csv('submission_1.csv',index=False)

In [42]:
submission

Unnamed: 0,id,label_text
0,1,alarm_remove
1,2,calendar_query
2,3,extra_charge_on_statement
3,4,lost_or_stolen_card
4,5,age_limit
...,...,...
3995,3996,general_joke
3996,3997,cancel_transfer
3997,3998,not toxic
3998,3999,top_up_reverted


In [44]:
.to_csv('output.csv', index=False)


NameError: name 'DataFrame' is not defined