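## Support Ticket Classification

Classify a support ticket into one of five categories. The notebook first retrieves the most similar stored ticket/knowledge-base pair(s) from a Chroma vector store. It then scores the categories with a zero-shot classifier (`facebook/bart-large-mnli`), using the retrieved text as extra context.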

In [1]:
# --- Retrieval settings -------------------------------------------------------
# Root directory for the persisted Chroma store, the embedding model name passed
# to HuggingFaceEmbeddings, and how many documents the retriever returns per query.
VECTOR_DB_DIR = './vector_db'
EMBEDDING_MODEL = 'paraphrase-MiniLM-L3-v2'
NUM_RELEVANT_DOCS = 1

# --- Classification settings --------------------------------------------------
# Model loaded by the "zero-shot-classification" pipeline.
CLASSIFICATION_MODEL = "facebook/bart-large-mnli"

# Ticket text classified in this run.
query = "Cannot find my payment receipt for subscription"


Importing the libraries

In [2]:
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain.schema import Document
import json
from transformers import pipeline



  from .autonotebook import tqdm as notebook_tqdm


Importing Data

In [3]:
def load_json(file_path, encoding='utf-8'):
    """Read and parse a JSON file.

    Args:
        file_path: Path to the JSON file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8,
            the encoding JSON text is expected to use. Without an explicit
            value, ``open()`` falls back to the locale's preferred encoding,
            which differs between platforms and can mis-decode non-ASCII text.

    Returns:
        The decoded JSON value, exactly as produced by ``json.load``.
    """
    with open(file_path, 'r', encoding=encoding) as f:
        return json.load(f)

# Raw inputs: the knowledge-base entries and the support tickets (the two loads are independent).
knowledge_base = load_json('data/knowledge_base.json')
support_tickets = load_json('data/support_tickets.json')

In [4]:
# Combine each support ticket with the knowledge-base entry at the same position
# into one text document for the vector store.
#
# The pairing is purely positional, and zip() silently stops at the shorter list,
# so a length mismatch would drop the unmatched tail without any warning.
# Fail loudly instead.
if len(support_tickets) != len(knowledge_base):
    raise ValueError(
        f"support_tickets has {len(support_tickets)} entries but knowledge_base has "
        f"{len(knowledge_base)}; they are paired by position and must be the same length."
    )

docs = [
    f"Support Ticket: {ticket['text']}\nRelevant Knowledge: {entry['content']}"
    for ticket, entry in zip(support_tickets, knowledge_base)
]

# Show every combined document for a quick visual check.
for doc in docs:
    print(doc)

Support Ticket: My account login is not working. I've tried resetting my password twice.
Relevant Knowledge: Category 1 - Login Issues - Login issues often occur due to incorrect passwords or account lockouts.
Support Ticket: The app crashes every time I try to upload a photo.
Relevant Knowledge: Category 2 - App Functionality - App crashes can be caused by outdated software or device incompatibility.
Support Ticket: I was charged twice for my last subscription payment.
Relevant Knowledge: Category 3 - Billing - Billing discrepancies may result from processing errors or duplicate transactions.
Support Ticket: I can't find the option to change my profile picture.
Relevant Knowledge: Category 4 - Account Management - Account management includes tasks such as changing profile information, linking social media accounts, and managing privacy settings.
Support Ticket: The video playback is very laggy on my device.
Relevant Knowledge: Category 5 - Performance Issues - Performance issues can b

In [5]:
# Wrap each combined text in a LangChain Document for the vector store.
# This cell overwrites `docs`. Items that are already Documents (from a previous
# run of this cell) are kept as-is rather than being passed back in as
# page_content, which is meant to hold text. This makes the cell safe to re-run.
docs = [
    doc if isinstance(doc, Document) else Document(page_content=doc)
    for doc in docs
]

Embedding Model

In [6]:
# Embedding model shared by the vector store (to embed the stored documents) and,
# through it, the retriever (to embed queries). The model name comes from EMBEDDING_MODEL.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
)



In [8]:
# Create the persistent Chroma vector store (once), open it, and wrap it as a retriever.

os.makedirs(VECTOR_DB_DIR, exist_ok=True)
# One sub-directory per embedding model, keyed by the last path component of its name.
vector_db_path = os.path.join(VECTOR_DB_DIR, EMBEDDING_MODEL.split('/')[-1])

# Build the store only when nothing usable is on disk. A missing directory or one
# that exists but is empty (e.g. left behind by an interrupted run) both trigger
# creation; otherwise the store below would be opened with no documents in it.
# NOTE: a populated store is reused as-is, so delete vector_db_path after changing
# the source data.
if not os.path.isdir(vector_db_path) or not os.listdir(vector_db_path):
    print(f'Creating Vector Database at {vector_db_path}')
    Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        persist_directory=vector_db_path,
    )

db = Chroma(
    persist_directory=vector_db_path,
    embedding_function=embedding_model,
)

# Similarity search returning the NUM_RELEVANT_DOCS closest documents per query.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": NUM_RELEVANT_DOCS},
)

In [9]:
# Retrieve the stored documents most similar to the query (up to NUM_RELEVANT_DOCS).
# `invoke` replaces `get_relevant_documents`, which is deprecated and emitted a
# deprecation warning when this cell ran.
retrieved_docs = retriever.invoke(query)


  warn_deprecated(


Classification Model

In [10]:
# Candidate labels for the zero-shot classifier. They match the
# "Category N - Name" prefixes of the knowledge-base entries.
knowledge_categories = [
    "Category 1 - Login Issues",
    "Category 2 - App Functionality",
    "Category 3 - Billing",
    "Category 4 - Account Management",
    "Category 5 - Performance Issues"
]

# Zero-shot classification pipeline that scores a text against arbitrary
# candidate labels supplied at call time.
classification_model = pipeline("zero-shot-classification", model=CLASSIFICATION_MODEL)



In [11]:
# Classify a ticket with the zero-shot model, using retrieved documents as context.

def classify_ticket(query, retrieved_docs, candidate_labels=None, model=None):
    """Return the most likely support category for a ticket.

    The text of all retrieved documents is appended to the hypothesis template.
    Each candidate label is therefore scored as
    ``"This text is about <label>. <context>"`` against the ticket text.

    Args:
        query: Ticket text to classify (passed as the input sequence).
        retrieved_docs: Documents returned by the retriever. Only their
            ``page_content`` is used; may be empty.
        candidate_labels: Labels to choose between. Defaults to the
            notebook-level ``knowledge_categories``.
        model: Zero-shot classification callable. Defaults to the
            notebook-level ``classification_model``.

    Returns:
        The highest-ranked label, i.e. ``prediction['labels'][0]``.
    """
    if candidate_labels is None:
        candidate_labels = knowledge_categories
    if model is None:
        model = classification_model

    context = " ".join(doc.page_content for doc in retrieved_docs)
    # The zero-shot pipeline inserts each label via str.format on the template.
    # Literal braces in the retrieved text would be parsed as format fields, so
    # escape them before embedding the context in the template.
    escaped_context = context.replace("{", "{{").replace("}", "}}")

    prediction = model(
        query,
        candidate_labels=candidate_labels,
        hypothesis_template=f"This text is about {{}}. {escaped_context}",
    )
    return prediction['labels'][0]

Result

In [12]:
# Run the classifier on the configured query, with the retrieved documents as
# context, and report the winning category.
predicted_category = classify_ticket(query=query, retrieved_docs=retrieved_docs)
print(f"Predicted Category: {predicted_category}")

Predicted Category: Category 3 - Billing
