## Importing

In [1]:
import os
import random
import re
import time

import pandas as pd
from google.api_core.exceptions import InvalidArgument
from google.api_core.exceptions import ResourceExhausted
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

## Accuracy function

In [2]:
def calculate_accuracy(predictions, true_labels):
    """Return the fraction of ``true_labels`` matched by the prediction at the same index.

    Args:
        predictions: Sized sequence of predicted labels.
        true_labels: Sized sequence of reference labels. Lists, tuples, NumPy
            arrays and pandas Series are all accepted.

    Returns:
        ``correct / len(true_labels)`` as a float, or ``0.0`` when
        ``true_labels`` is empty.

    Note:
        Pairs are formed with ``zip``, so if ``predictions`` is shorter than
        ``true_labels`` the missing positions simply count as incorrect, and
        surplus predictions are ignored.
    """
    # Use len() rather than truthiness: `if true_labels` raises for
    # multi-element NumPy arrays and pandas Series.
    total = len(true_labels)
    if total == 0:
        return 0.0
    correct = sum(1 for p, t in zip(predictions, true_labels) if p == t)
    return correct / total


## Classification Model

In [3]:
def classification_model(chat_history, classes, model_name="gemini-1.5-flash"):
    """Classify one conversation into one of ``classes`` using a Gemini chat model.

    API keys are read from the environment instead of being stored in the
    notebook: ``GOOGLE_API_KEYS`` may hold several comma-separated keys (one is
    picked at random per call); ``GOOGLE_API_KEY`` is used as a fallback.

    Args:
        chat_history: Conversation text substituted into the prompt's ``{chat}`` slot.
        classes: Candidate class labels substituted into the ``{classes}`` slot.
        model_name: Model identifier forwarded to ``ChatGoogleGenerativeAI``.

    Returns:
        Whatever the chain's ``run`` call returns, unmodified (no stripping).

    Raises:
        RuntimeError: If no API key is configured in the environment.
    """
    # Never commit credentials with the notebook; load them at call time.
    raw_keys = os.environ.get("GOOGLE_API_KEYS") or os.environ.get("GOOGLE_API_KEY", "")
    api_keys = [key.strip() for key in raw_keys.split(",") if key.strip()]
    if not api_keys:
        raise RuntimeError(
            "No API key configured. Set GOOGLE_API_KEYS (comma-separated) or GOOGLE_API_KEY."
        )

    classification_prompt = PromptTemplate(
    template="""
                Analyze the following conversation and determine which class it belongs to. Choose from the given options.

                Available Classes: {classes}

                Conversation:
                {chat}

                Output the class as one word only.

                example :
                chat:
                Customer: Hi Can i ask something?
                Admin: Hello How can i help you today?
                customer : can you help me i can't login to my account?
                Admin:sure i can help you .

                the output will be:
                login

                make sure to output the class only in one word without any explaination or special chars.
                """
                )
    random_api_key = random.choice(api_keys)
    llm = ChatGoogleGenerativeAI(api_key=random_api_key, model=model_name, temperature=0.1)
    classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
    # NOTE(review): `verbose=False` is kept exactly as the original call passed it;
    # confirm run() does not treat it as an extra prompt input and drop it if so.
    classification_result = classification_chain.run(chat=chat_history, classes=classes, verbose=False)
    return classification_result


## Test

In [5]:
# Smoke test: classify a single hand-written conversation and show the raw model answer.
classes = ["feedback", "support", "orders", "complaint", "other"]
test = "Customer: I'm not happy with my order. Admin: I'm sorry to hear that."
result = classification_model(chat_history=test, classes=classes)
print(result)

complaint 



## Evaluation

In [6]:
def classification_evaluation(test_chats, true_labels, classes, model_name):
    """Classify every chat in ``test_chats`` and record which predictions miss.

    API keys are read from the environment instead of being stored in the
    notebook: ``GOOGLE_API_KEYS`` may hold several comma-separated keys (one is
    picked at random per attempt); ``GOOGLE_API_KEY`` is used as a fallback.

    Each chat gets up to three attempts. A ``ResourceExhausted`` error waits
    five seconds and retries; an error whose message contains
    ``API_KEY_INVALID`` drops that key from the pool and retries with another.
    If no attempt succeeds, the prediction for that chat is ``"Error"``.

    Args:
        test_chats: Sequence of conversation texts.
        true_labels: Reference labels, indexed in parallel with ``test_chats``.
        classes: Candidate class labels substituted into the prompt.
        model_name: Model identifier forwarded to ``ChatGoogleGenerativeAI``.

    Returns:
        ``(predictions, wrong_predictions_IDs)``: the whitespace-stripped model
        output per chat, and the indices whose stripped output differs from
        ``true_labels``.

    Raises:
        RuntimeError: If no API key is configured, or if every key has been
            rejected as invalid.
        Exception: Any other error raised by the chain is re-raised unchanged.
    """
    # Never commit credentials with the notebook; load them at call time.
    raw_keys = os.environ.get("GOOGLE_API_KEYS") or os.environ.get("GOOGLE_API_KEY", "")
    api_keys = {key.strip() for key in raw_keys.split(",") if key.strip()}
    if not api_keys:
        raise RuntimeError(
            "No API key configured. Set GOOGLE_API_KEYS (comma-separated) or GOOGLE_API_KEY."
        )

    classification_prompt = PromptTemplate(
    template="""
                Analyze the following conversation and determine which class it belongs to. Choose from the given options.

                Available Classes: {classes}

                Conversation:
                {chat}

                Output the class as one word only.

                example :
                chat:
                Customer: Hi Can i ask something?
                Admin: Hello How can i help you today?
                customer : can you help me i can't login to my account?
                Admin:sure i can help you .

                the output will be:
                login

                make sure to output the class only in one word without any explaination or special chars.
                """
                )
    predictions = []
    wrong_predictions_IDs = []
    len_test_chats = len(test_chats)
    retry_attempts = 3
    for i, chat_history in enumerate(test_chats):
        # Reset per chat: without this, a chat whose attempts all end in an
        # invalid-key error would reuse the previous chat's result (or hit an
        # unbound name on the very first chat).
        classification_result = "Error"
        for attempt in range(retry_attempts):
            random_api_key = random.choice(list(api_keys))
            try:
                llm = ChatGoogleGenerativeAI(api_key=random_api_key, model=model_name, temperature=0.1)
                classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
                # NOTE(review): `verbose=False` is kept exactly as the original call passed it;
                # confirm run() does not treat it as an extra prompt input and drop it if so.
                classification_result = classification_chain.run(chat=chat_history, classes=classes, verbose=False)
                break
            except ResourceExhausted:
                if attempt < retry_attempts - 1:
                    print(f"Resource exhausted. Retrying... ({attempt + 1}/{retry_attempts})")
                    time.sleep(5)  # Wait for 5 seconds before retrying
                else:
                    print(f"Failed after {retry_attempts} attempts. Skipping this chat.")
                    classification_result = "Error"
            except Exception as e:
                if "API_KEY_INVALID" not in str(e):
                    raise
                # Show only the key's tail so the secret never lands in saved notebook output.
                print(f"Invalid API key ending in ...{random_api_key[-4:]}. Removing from the list.")
                api_keys.discard(random_api_key)
                if not api_keys:
                    raise RuntimeError("All API keys are invalid or expired.")
        prediction = classification_result.strip()
        predictions.append(prediction)
        if prediction != true_labels[i]:
            wrong_predictions_IDs.append(i)
        if i % 10 == 0:
            print(f"Progress: {i}/{len_test_chats}")
        time.sleep(1)  # Pause between chats; presumably to stay under API rate limits.
    return predictions, wrong_predictions_IDs

In [100]:
# Evaluate on the first N_EVAL_SAMPLES rows only, to keep the number of API calls small.
N_EVAL_SAMPLES = 20
df = pd.read_csv("zero_shot_classification_sentiment_dataset_old.csv")
classes = ["feedback", "support", "orders", "complaint", "other"]
eval_rows = df.head(N_EVAL_SAMPLES)
test_chats = eval_rows["chat_history"].tolist()
true_labels = eval_rows["target_class"].tolist()
model_name = "gemini-1.5-flash"
predictions, wrong_predictions_IDs = classification_evaluation(test_chats, true_labels, classes, model_name=model_name)

# Keep only word characters so stray punctuation/whitespace in the model output
# does not count as a mismatch.
predictions = [re.sub(r'\W+', '', prediction) for prediction in predictions]

# Recompute the misclassified indices from the cleaned predictions; the list
# returned above was built from the uncleaned output and could otherwise
# disagree with the accuracy reported below.
wrong_predictions_IDs = [
    idx
    for idx, (predicted, expected) in enumerate(zip(predictions, true_labels))
    if predicted != expected
]

accuracy = calculate_accuracy(predictions, true_labels)
print(f"Accuracy: {accuracy}")


Progress: 0/20
Progress: 10/20
Accuracy: 0.85


In [101]:
# List every misclassified chat with its predicted and expected class,
# one blank line after each record.
print("Wrong Predictions:")
for i in wrong_predictions_IDs:
    record_lines = [
        f"Chat ID: {i}",
        f"Chat: {test_chats[i]}",
        f"Predicted Class: {predictions[i]}",
        f"True Class: {true_labels[i]}",
        "",
    ]
    print("\n".join(record_lines))

Wrong Predictions:
Chat ID: 11
Chat: Customer: My order was delayed. Admin: I apologize for the inconvenience.
Predicted Class: orders
True Class: complaint

Chat ID: 13
Chat: Customer: What are your store hours? Admin: We are open from 9 AM to 9 PM.
Predicted Class: support
True Class: other

Chat ID: 19
Chat: Customer: Could you change the delivery address? Admin: Let me check that for you.
Predicted Class: support
True Class: orders

