In [1]:
import os
import json
import sqlite3
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from groq import Groq
from fuzzywuzzy import fuzz
from nltk.tokenize import sent_tokenize
import nltk

In [2]:
# `sent_tokenize` needs NLTK's Punkt sentence model. NLTK >= 3.9 loads it from
# the pickle-free "punkt_tab" resource rather than "punkt", so fetch both to
# work on old and new NLTK releases (an up-to-date package is not re-fetched).
nltk.download("punkt")
nltk.download("punkt_tab")


[nltk_data] Downloading package punkt to
[nltk_data]     /Users/devayushrout/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Canonical symptom name -> alternative spellings a user might type instead
# (romanized Hindi and English variants). normalize_symptoms() fuzzy-matches
# these alternatives against the user's text.
symptom_synonyms = dict([
    ("fever", ["bukhar", "tapman", "high temperature", "tez bukhar"]),
    ("headache", ["sar dard", "sar mein dard", "migraine"]),
    ("cough", ["khaansi", "khansi", "dry cough"]),
    ("cold", ["zukaam", "runny nose", "nasal congestion"]),
    ("vomiting", ["ulti", "throwing up", "nausea"]),
    ("diarrhea", ["patla mal", "loose motions", "dast"]),
    ("body pain", ["jodo ka dard", "sareer mein dard", "body ache"]),
    ("sore throat", ["gale mein dard", "gala kharab"]),
])

In [4]:
# Groq SDK client (API key taken from the GROQ_API_KEY environment variable)
# plus a small helper that sends one prompt and returns the reply text.
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# NOTE(review): Groq periodically retires model IDs; confirm this one is still
# served before relying on it.
LLM_MODEL = "llama3-70b-8192"
LLM_TEMPERATURE = 0.7


def llm(prompt, model=LLM_MODEL, temperature=LLM_TEMPERATURE):
    """Send ``prompt`` as a single user message and return the model's reply.

    Args:
        prompt: Text of the user message.
        model: Groq model ID; defaults to ``LLM_MODEL``.
        temperature: Sampling temperature; defaults to ``LLM_TEMPERATURE``.

    Returns:
        The ``content`` of the first choice's message.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    return response.choices[0].message.content

In [5]:
# Load the prebuilt FAISS index from the local "baymax_vectorstore" folder.
# Queries are embedded with MiniLM; presumably this is the model the index was
# built with (they must match for distances to be meaningful) — TODO confirm.
# SECURITY: load_local unpickles the stored docstore, which is why LangChain
# requires allow_dangerous_deserialization=True. This is acceptable only
# because the folder is assumed to be generated locally and trusted; never
# point it at files obtained from an untrusted source.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vectorstore = FAISS.load_local(
    folder_path="baymax_vectorstore",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
)

  embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# SQLite store for per-user conversation state.
def init_db(db_path="conversations.db"):
    """Open (creating if needed) the conversation database and ensure its table exists.

    Args:
        db_path: SQLite database path. Defaults to ``conversations.db`` in the
            current working directory; pass ``":memory:"`` for a throwaway DB.

    Returns:
        An open ``sqlite3.Connection``. The caller owns it and must close it.

    The ``conversations`` table has columns
    ``(user_id, session, symptoms, duration, severity, history)``; ``severity``
    is INTEGER and the rest are TEXT. Calling this on an existing database is
    safe: ``IF NOT EXISTS`` leaves the existing table and rows alone.
    """
    conn = sqlite3.connect(db_path)
    try:
        # The connection context manager commits on success and rolls back on
        # error; it does not close the connection.
        with conn:
            conn.execute("""CREATE TABLE IF NOT EXISTS conversations
                (user_id TEXT, session TEXT, symptoms TEXT, duration TEXT, severity INTEGER, history TEXT)""")
    except Exception:
        # Don't leak the connection when schema setup fails.
        conn.close()
        raise
    return conn

In [7]:
# Map free-text user input (English or romanized Hindi) to canonical symptom names.
def normalize_symptoms(user_input, threshold=85):
    """Return the canonical symptoms from ``symptom_synonyms`` mentioned in ``user_input``.

    A canonical symptom matches when its own name, or any of its listed
    synonyms, scores at least ``threshold`` on ``fuzz.partial_ratio`` against
    the lower-cased input.

    Args:
        user_input: Raw text typed by the user. ``None`` or ``""`` yields ``[]``.
        threshold: Minimum partial-ratio score (0-100) that counts as a match.

    Returns:
        List of matched canonical names, in ``symptom_synonyms`` key order.
        The order is deterministic, unlike ``list(set(...))``, whose order for
        strings changes between interpreter runs because of hash randomization.
    """
    if not user_input:
        return []
    text = user_input.lower()
    matched = []
    for standard, synonyms in symptom_synonyms.items():
        # Include the canonical term itself; matching only the synonyms meant
        # plain English like "I have a fever" matched nothing.
        candidates = (standard, *synonyms)
        # NOTE(review): partial_ratio scores a short term like "ulti" at 100
        # inside unrelated words (e.g. "multiple"); consider word-boundary checks.
        if any(fuzz.partial_ratio(term.lower(), text) >= threshold for term in candidates):
            matched.append(standard)
    return matched

In [8]:
# Retrieve the best-matching chunk per source for a set of symptoms
# (adapted from get_top_chunk_by_source).
def retrieve_chunks(symptoms, vectorstore, top_k=5, score_threshold=0.85):
    """Return the most relevant chunk from each source document for a symptom query.

    The symptoms are joined with spaces into one query. The ``top_k`` nearest
    chunks are fetched with LangChain's
    ``similarity_search_with_relevance_scores``, whose scores lie in [0, 1]
    with higher meaning more similar.

    Args:
        symptoms: Iterable of symptom name strings (e.g. from normalize_symptoms).
        vectorstore: A LangChain vector store, such as a loaded FAISS index.
        top_k: Number of chunks to fetch before per-source de-duplication.
        score_threshold: A chunk is kept only if its relevance score is
            strictly greater than this value.

    Returns:
        Dict mapping ``doc.metadata["source"]`` to
        ``{"content": page_content, "metadata": metadata}`` for the first
        qualifying chunk seen from each source. Results are assumed to arrive
        best-first, as FAISS returns them. Empty when the query is blank.
    """
    query = " ".join(symptoms)
    if not query.strip():
        # A blank query has no meaningful nearest neighbours.
        return {}
    # Relevance scores (higher = better) rather than similarity_search_with_score:
    # for a default FAISS index, the latter returns L2 *distances* (lower =
    # better), so the old `score > threshold` test kept the weakest matches.
    results = vectorstore.similarity_search_with_relevance_scores(query, k=top_k)
    selected = {}
    for doc, score in results:
        if score <= score_threshold:
            continue
        # NOTE(review): chunks with no "source" all share the None key, so only
        # the first of them is kept.
        source = doc.metadata.get("source")
        if source not in selected:
            selected[source] = {"content": doc.page_content, "metadata": doc.metadata}
    return selected