In [111]:
import pickle
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import re
from datetime import datetime, timedelta
import json
import uuid
from dateutil import parser
import pandas as pd

## Load the Model

In [112]:
# Load the saved model
with open('chatbot_model.pkl', 'rb') as f:
    model_data = pickle.load(f)

tokenizer = model_data['tokenizer']
label_encoder = model_data['label_encoder']
reverse_label_encoder = model_data['reverse_label_encoder']

num_labels = len(label_encoder)
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=num_labels
)
model.load_state_dict(model_data['model_state_dict'])
model.eval()

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [113]:
class InferenceTool:
    def __init__(self, model, tokenizer, reverse_label_encoder):
        self.model = model
        self.tokenizer = tokenizer
        self.reverse_label_encoder = reverse_label_encoder
        self.conversation_state = {}  # Track state if needed
    
    def predict_and_respond(self, text, user_id='default'):
        intent, confidence = self.predict_intent(text)
        
        responses = {
            'greeting': "Hello! How can I help with your booking?",
            'reschedule_booking': "Sure, let's reschedule. Provide the new date and time.",
            'cancel_booking': "Got it. Confirm if you want to cancel.",
            'pricing_inquiry': "Prices start at $80. More details available.",
            'book_service': "I'd be happy to book.",
            'booking_status': "Please provide your booking reference.",
            'thanks': "You're welcome!",
            'confirm': "Confirmed!",
            'deny': "No problem.",
            'provide_datetime': "Noted the time. Proceeding.",
        }
        
        response = responses.get(intent, "Sorry, I didn't understand.")
        return {
            'response': response,
            'intent': intent,
            'confidence': confidence
        }
    
    def predict_intent(self, text):
        inputs = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_label = torch.argmax(predictions, dim=-1).item()
            confidence = predictions[0][predicted_label].item()
        
        intent = self.reverse_label_encoder[predicted_label]
        return intent, confidence
    
    def extract_datetime(self, text):
        try:
            # Parse the date from text
            parsed_date = parser.parse(text, fuzzy=True)
            # Format to 'YYYY-MM-DD HH:MM'
            standardized = parsed_date.strftime('%Y-%m-%d %H:%M')
            return standardized
        except ValueError:
            return None

# Initialize the tool
tool = InferenceTool(model, tokenizer, reverse_label_encoder)

# Example usage
query = "Can I reschedule my booking?"
result = tool.predict_and_respond(query)
print(f"Query: {query}")
print(f"Response: {result['response']}")
print(f"Intent: {result['intent']}")
print(f"Confidence Score: {result['confidence']}")

Query: Can I reschedule my booking?
Response: Sure, let's reschedule. Provide the new date and time.
Intent: reschedule_booking
Confidence Score: 0.9668230414390564


In [114]:
import sqlite3

class AppointmentTool:
    def __init__(self, db_path='appointments.db'):
        self.db_path = db_path
        self.init_db()
    
    def init_db(self):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS appointments (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id TEXT,
                service TEXT,
                date_time TEXT,
                status TEXT DEFAULT 'pending'
            )
        ''')
        conn.commit()
        conn.close()
    
    def add_appointment(self, user_id, service, date_time):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
            INSERT INTO appointments (user_id, service, date_time)
            VALUES (?, ?, ?)
        ''', (user_id, service, date_time))
        conn.commit()
        conn.close()
        return "Appointment added successfully."
    
    def get_appointments(self, user_id=None):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        if user_id:
            cursor.execute('SELECT * FROM appointments WHERE user_id = ?', (user_id,))
        else:
            cursor.execute('SELECT * FROM appointments')
        results = cursor.fetchall()
        conn.close()
        return results
    
    def cancel_appointment(self, appointment_id):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE appointments SET status = 'cancelled' WHERE id = ?
        ''', (appointment_id,))
        conn.commit()
        conn.close()
        return "Appointment cancelled successfully." if cursor.rowcount > 0 else "Appointment not found."
    
    def cancel_appointments_by_user(self, user_id):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE appointments SET status = 'cancelled' WHERE user_id = ?
        ''', (user_id,))
        conn.commit()
        conn.close()
        return f"{cursor.rowcount} appointments cancelled for user {user_id}."
    
    def reschedule_appointment(self, appointment_id, new_date_time):
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE appointments SET date_time = ? WHERE id = ?
        ''', (new_date_time, appointment_id))
        conn.commit()
        conn.close()
        return "Appointment rescheduled successfully." if cursor.rowcount > 0 else "Appointment not found."

# Initialize the tool
appt_tool = AppointmentTool()

# Example usage
# result = appt_tool.add_appointment('user123', 'Thai Massage', '2025-03-30 10:00')
# print(result)

# # Get appointments
# appointments = appt_tool.get_appointments('user123')
# print(appointments)

# result = appt_tool.cancel_appointment(1)  # Use actual appointment ID from get_appointments
# print(result)

# appointments = appt_tool.get_appointments('user123')
# print(appointments)

# result = appt_tool.cancel_appointments_by_user('user123')
# print(result)

# appointments = appt_tool.get_appointments('user123')
# print(appointments)

# result = appt_tool.reschedule_appointment(1, '2025-03-31 11:00')  # Replace 1 with actual ID
# print(result)

# appointments = appt_tool.get_appointments('user123')
# print(appointments)

In [115]:
query = "Book me for March 24 2 PM"
extracted = tool.extract_datetime(query)
print(f"Extracted datetime: {extracted}")

Extracted datetime: 2025-03-24 14:00


In [116]:
from langgraph.graph import StateGraph, START, END
from typing import TypedDict

#Define the states
class ChatState(TypedDict):
    query: str
    intent: str
    confidence: float
    response: str
    appointment_action: str
    datetime: str
    conversation_state: dict  # Add for multi-turn

In [117]:
# Define nodes
def intent_analysis(state: ChatState):
    result = tool.predict_and_respond(state['query'])
    state['intent'] = result['intent']
    state['confidence'] = result['confidence']
    state['response'] = result['response']
    # Check conversation state for pending actions
    if state.get('conversation_state', {}).get('pending') == 'reschedule':
        state['intent'] = 'provide_datetime'
    return state

def appointment_trigger(state: ChatState):
    if state['intent'] in ['book_service', 'reschedule_booking', 'cancel_booking']:
        state['appointment_action'] = state['intent']
        state['datetime'] = tool.extract_datetime(state['query']) or 'Not extracted'
        user_id = 'user123'  # Default user_id; extract from context if available
        if state['intent'] == 'book_service':
            service = 'Massage'  # Extract service from query if possible
            result = appt_tool.add_appointment(user_id, service, state['datetime'])
            state['response'] += f" {result}"
        elif state['intent'] == 'reschedule_booking':
            appointment_id = 1  # Extract or assume ID from query/context
            result = appt_tool.reschedule_appointment(appointment_id, state['datetime'])
            state['response'] += f" {result}"
        elif state['intent'] == 'cancel_booking':
            appointment_id = 1  # Extract or assume ID from query/context
            result = appt_tool.cancel_appointment(appointment_id)
            state['response'] += f" {result}"
        elif state['intent'] == 'confirm':
            if state.get('conversation_state', {}).get('pending') == 'reschedule':
                # Perform reschedule
                result = appt_tool.reschedule_appointment(1, state['datetime'])
                state['response'] = f"Sent reschedule information to pro, you will get notified once it's confirmed. {result}"
                state['conversation_state'] = {}
        elif state['intent'] == 'reschedule_booking':
            state['response'] = "Sure, let's reschedule. Provide the new date and time."
            state['conversation_state'] = {'pending': 'reschedule'}
    return state

In [118]:
import pandas as pd

class DataTool:
    def __init__(self, csv_path='simple_dataset.csv'):
        self.data = pd.read_csv(csv_path)
    
    def retrieve_and_generate(self, query):
        # Simple retrieval: search for keywords in Massage_Type
        keywords = query.lower().split()
        relevant_rows = self.data[self.data['Massage_Type'].str.lower().str.contains('|'.join(keywords))]
        
        if relevant_rows.empty:
            return "Sorry, I couldn't find information on that massage type."
        
        # Rank by number of matching keywords
        def count_matches(row):
            massage_type = row['Massage_Type'].lower()
            return sum(1 for keyword in keywords if keyword in massage_type)
        
        relevant_rows['match_count'] = relevant_rows.apply(count_matches, axis=1)
        relevant_rows = relevant_rows.sort_values(by='match_count', ascending=False)
        
        # Limit to top 1 result
        top_row = relevant_rows.iloc[0]
        response = f"The {top_row['Massage_Type']} average cost is ${top_row['Avg_Spending']} and duration of massage is {top_row['Duration_Minutes']} minutes."        
        return response

# Initialize the tool
rag_tool = DataTool()

# Example usage
query = "How much does the Hot massage cost?"
result = rag_tool.retrieve_and_generate(query)
print(result)

The Hot Stone Massage average cost is $125.0 and duration of massage is 75 minutes.


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  relevant_rows['match_count'] = relevant_rows.apply(count_matches, axis=1)


In [119]:
def data_retrieval(state: ChatState):
    if state['intent'] == 'pricing_inquiry':
        rag_result = rag_tool.retrieve_and_generate(state['query'])
        state['response'] = rag_result
    return state

def route_after_intent(state: ChatState):
    if state['intent'] == 'pricing_inquiry':
        return "data_retrieval"
    else:
        return "appointment_trigger"

# Build graph
graph = StateGraph(ChatState)
graph.add_node("intent_analysis", intent_analysis)
graph.add_node("data_retrieval", data_retrieval)
graph.add_node("appointment_trigger", appointment_trigger)
graph.add_edge(START, "intent_analysis")
graph.add_conditional_edges("intent_analysis", route_after_intent)
graph.add_edge("data_retrieval", "appointment_trigger")
graph.add_edge("appointment_trigger", END)

# Compile and run
compiled_graph = graph.compile()

In [120]:
# Initial state
state = {"query": "Can I reschedule my booking?", "conversation_state": {}}
result = compiled_graph.invoke(state)
print(result['response'])  # "Sure, let's reschedule. Provide the new date and time."

# Next turn
state = {"query": "Yes", "conversation_state": result.get('conversation_state', {})}
result = compiled_graph.invoke(state)
print(result['response'])  # "Please provide the new date that you would like to reschedule your booking at"

# Final turn
state = {"query": "30 Mar 2025 10 am", "conversation_state": result.get('conversation_state', {})}
result = compiled_graph.invoke(state)
print(result['response']) 

Sure, let's reschedule. Provide the new date and time. Appointment not found.
Confirmed!
Noted the time. Proceeding.
