<a href="https://colab.research.google.com/github/LingYiJun0009/RAG_LLM_playmaInterview/blob/main/My_First_RAG_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Replace the Hugging Face access token below with your own Hugging Face token
### Then, remember to request access to "mistralai/Mistral-7B-Instruct-v0.2" on Hugging Face, as it is a gated model


In [None]:
import os

# Prefer the HF_TOKEN environment variable so a real token never has to be
# pasted into (and accidentally committed with) the notebook; otherwise fall
# back to the placeholder that should be replaced by hand.
huggingface_token = os.environ.get("HF_TOKEN", "place token here")

# Then, press Run All to run everything
### It may take a few minutes to download the packages and load the model

In [None]:
# Install the notebook's dependencies. The leading "!" is an IPython shell
# escape, so this line only works inside Jupyter/Colab, not as plain Python.
!pip install transformers torch sentence-transformers chromadb langchain streamlit pyngrok datasets accelerate bitsandbytes

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; required to download the gated Mistral model.
login(token=huggingface_token)


In [None]:
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import pandas as pd
import json
from datetime import datetime
import warnings
# Silence every warning category so notebook output stays readable
# (note: this also hides deprecation notices from the libraries above).
warnings.simplefilter('ignore')

In [None]:

class PatientCareRAGAgent:
    """Retrieval-augmented question answering over per-patient care notes.

    ``__init__`` builds three components:

    * a SentenceTransformer embedding model (``all-MiniLM-L6-v2``),
    * a 4-bit quantized ``mistralai/Mistral-7B-Instruct-v0.2`` placed on ``cuda:0``,
    * a ChromaDB collection named ``patient_care_data`` using cosine distance.

    Entries are stored with ``add_patient_data``. ``query_patient_info``
    retrieves a patient's most similar entries and asks the LLM to answer from them.
    """

    def __init__(self):
        self.setup_embedding_model()
        self.setup_llm()
        self.setup_vector_database()
        # NOTE(review): not read anywhere in this notebook; kept in case other code uses it.
        self.patient_data = {}

    def setup_embedding_model(self):
        """Initialize the embedding model for RAG (sets ``self.embedding_model``)."""
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded successfully!")

    def setup_llm(self):
        """Load the Mistral instruct model in 4-bit precision onto ``cuda:0``.

        Sets ``self.tokenizer`` and ``self.model``.
        """
        print("Loading LLM model...")

        # Use quantization for better memory efficiency in Colab
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        model_name = "mistralai/Mistral-7B-Instruct-v0.2"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Mistral is a built-in transformers architecture, so trust_remote_code
        # is not needed; leaving it off avoids executing Python code from the Hub repo.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="cuda:0",
        )
        # Add padding token if not present (EOS is reused as the pad token)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("✅ LLM model loaded successfully!")

    def setup_vector_database(self):
        """Create the ChromaDB client and open, or create, the ``patient_care_data`` collection.

        A new collection is created with cosine distance for its HNSW index.
        """
        print("Setting up vector database...")
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))

        collection_name = "patient_care_data"

        try:
            # Try to get existing collection first
            self.collection = self.chroma_client.get_collection(name=collection_name)
            print(f"✅ Found existing collection: {collection_name}")
        except Exception:
            # NOTE(review): broad catch, presumably because the "collection not
            # found" exception type differs across chromadb versions; confirm
            # before narrowing it.
            print(f"Collection '{collection_name}' not found. Creating new collection...")
            self.collection = self.chroma_client.create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"}
            )
            print(f"✅ Created new collection: {collection_name}")

        print("✅ Vector database setup complete!")

    def get_ids(self):
        """Return the IDs of every document stored in the collection."""
        # include=[] requests IDs only, without documents, metadatas or embeddings.
        all_ids = self.collection.get(include=[])['ids']

        return all_ids

    def add_patient_data(self, patient_id, data_entries):
        """Embed and store care entries for one patient.

        Args:
            patient_id: Stored in each entry's metadata and used as the
                prefix of its document ID.
            data_entries: Iterable of dicts with required keys ``category``,
                ``date`` and ``content``, and optional keys ``notes`` (shown as
                "Caregiver Notes", default ``'N/A'``) and ``source`` (metadata,
                default ``'manual_entry'``).

        Document IDs are ``f"{patient_id}_{category}_{index}"``. Entries are
        upserted, so re-running a cell replaces an entry with the same ID
        instead of colliding with it. An empty ``data_entries`` stores nothing.
        """
        print(f"Adding data for patient {patient_id}...")

        documents = []
        metadatas = []
        ids = []

        for i, entry in enumerate(data_entries):
            # Create structured document
            doc_text = f"""
            Patient ID: {patient_id}
            Category: {entry['category']}
            Date: {entry['date']}
            Information: {entry['content']}
            Caregiver Notes: {entry.get('notes', 'N/A')}
            """

            documents.append(doc_text)
            metadatas.append({
                'patient_id': patient_id,
                'category': entry['category'],
                'date': entry['date'],
                'source': entry.get('source', 'manual_entry')
            })
            ids.append(f"{patient_id}_{entry['category']}_{i}")

        if not documents:
            # Chroma rejects empty batches; nothing to embed or store.
            print(f"No entries to add for patient {patient_id}")
            return

        # Generate embeddings and store
        embeddings = self.embedding_model.encode(documents).tolist()

        # upsert (not add): the client/collection can outlive this object
        # (setup_vector_database reuses an existing collection), so re-running
        # would otherwise hit duplicate IDs.
        self.collection.upsert(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )

        print(f"✅ Added {len(documents)} entries for patient {patient_id}")

    def query_patient_info(self, patient_id, query, max_results=2):
        """Answer ``query`` about ``patient_id`` using retrieval plus the LLM.

        Retrieves up to ``max_results`` of that patient's entries most similar
        to the query and returns the dict produced by ``generate_response``.
        """
        print(f"Querying information for patient {patient_id}: {query}")

        # Generate query embedding
        query_embedding = self.embedding_model.encode([query]).tolist()

        # Search vector database, restricted to this patient's entries
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=max_results,
            where={"patient_id": patient_id}
        )

        # One query was sent, so its hits are in results['documents'][0]
        context_docs = results['documents'][0] if results['documents'] else []

        return self.generate_response(query, context_docs, patient_id)

    @staticmethod
    def _build_prompt(query, context):
        """Return the Mistral [INST] prompt embedding ``context`` and ``query``.

        No leading ``<s>``: the tokenizer adds the BOS token itself, and a
        literal ``<s>`` would produce a duplicated BOS.
        """
        return f"""[INST] You are a medical information assistant. Extract information from patient data to answer questions.

        Rules:
        - Answer ONLY based on the provided patient information
        - If information is not available, respond with "Information not available"
        - Be concise and factual
        - Do not make assumptions or add information not explicitly stated

        Patient Data:
        {context}

        Question: {query}

        Provide a direct answer based only on the information above. [/INST]"""

    def generate_response(self, query, context_docs, patient_id, max_context_docs=1):
        """Generate an answer to ``query`` grounded on retrieved documents.

        Args:
            query: The user's question.
            context_docs: Retrieved document texts, most relevant first.
            patient_id: Not used here; kept for interface compatibility.
            max_context_docs: How many of ``context_docs`` go into the prompt
                (default 1, as before; ``None`` uses all of them).

        Returns:
            dict with keys:
              ``response``: decoded prompt plus generated text, with special
                  tokens stripped;
              ``answer``: only the newly generated text;
              ``context_used``: number of documents placed in the prompt;
              ``retrieved_info``: all of ``context_docs``.
        """
        # Follow wherever the model was actually loaded instead of hard-coding a device.
        device = self.model.device
        used_docs = context_docs[:max_context_docs]
        context = "\n\n".join(used_docs)

        prompt = self._build_prompt(query, context)

        print(f"here is how the prompt look like: \n {prompt}")
        # Calling the tokenizer (not .encode) also yields the attention mask for generate().
        # NOTE(review): truncation keeps the first 512 tokens, so a very long
        # context would cut off the question and the closing [/INST].
        encoded = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        encoded = encoded.to(device)
        input_ids = encoded["input_ids"]
        prompt_len = input_ids.shape[1]

        print(f"inputs device: {input_ids.device}")
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=150,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # Decoder-only generate() returns prompt + continuation in one sequence.
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        answer = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        return {
            'response': response,
            'answer': answer,
            'context_used': len(used_docs),
            'retrieved_info': context_docs
        }

In [None]:
import gc

# When this cell is re-run, release the previous agent (and its GPU-resident
# model) before building a new one. Otherwise both models are held at once
# until the reassignment, which can exhaust GPU memory.
if "agent" in globals():
    del agent
    gc.collect()
    torch.cuda.empty_cache()
agent = PatientCareRAGAgent()

# Adding patient data
### You can add more data using the same format


In [None]:
# Sample care entries for PATIENT_001; each row is (category, date, content, notes).
sample_patient_data = [
    dict(zip(('category', 'date', 'content', 'notes'), row))
    for row in (
        ('medical_history', '2025-05-15',
         'Patient has diabetes type 2, diagnosed in 2020. Takes metformin 500mg twice daily. Last HbA1c was 7.2%.',
         'Patient sometimes forgets evening medication'),
        ('daily_habits', '2025-06-01',
         'Prefers morning showers around 8 AM. Likes to have breakfast at 8:30 AM - usually oatmeal or toast.',
         'Gets agitated if routine is changed'),
        ('preferences', '2025-06-10',
         'Enjoys classical music, especially Mozart. Likes to read mystery novels in the afternoon.',
         'Music helps calm anxiety episodes'),
        ('safety_concerns', '2025-06-15',
         'Has history of falls. Uses walker for mobility. Needs assistance with stairs.',
         'Never leave unattended near stairs'),
        ('social_info', '2025-06-12',
         'Has two children who visit weekly. Daughter Sarah calls every Tuesday evening.',
         'Gets excited about family visits, may need help preparing'),
    )
]

# Store PATIENT_001's sample entries in the vector database
patient_id = "PATIENT_001"
agent.add_patient_data(patient_id=patient_id, data_entries=sample_patient_data)

In [None]:
# Sample care entries for PATIENT_002; each row is (category, date, content, notes).
sample_patient_data2 = [
    dict(zip(('category', 'date', 'content', 'notes'), row))
    for row in (
        ('medical_history', '2025-05-20',
         'Patient has hypertension, diagnosed in 2018. Takes amlodipine 10mg daily. Blood pressure averages 130/80 mmHg.',
         'Occasionally misses doses when feeling well'),
        ('daily_habits', '2025-06-03',
         'Enjoys taking a walk in the garden every evening at 6 PM. Has lunch at noon; prefers light meals like salads or soups.',
         'Walk is important for mood, avoid interruptions'),
        ('preferences', '2025-06-11',
         'Likes watching nature documentaries after dinner. Prefers herbal tea over coffee.',
         'Tea helps with relaxation before bedtime'),
        ('safety_concerns', '2025-06-18',
         'Has mild visual impairment. Needs good lighting in all rooms. Avoids using sharp kitchen tools.',
         'Ensure nightlights are on in hallways'),
        ('social_info', '2025-06-14',
         'Lives with spouse. Grandchildren visit on weekends. Enjoys video calls with old friends every Friday.',
         'May need help setting up video calls'),
    )
]

# Store PATIENT_002's sample entries in the vector database
patient_id2 = "PATIENT_002"
agent.add_patient_data(patient_id=patient_id2, data_entries=sample_patient_data2)

# Run query
### Enter a patient ID first, then ask questions about that patient
#### Sample questions:
##### - What is the patient's favourite food?
##### - Who are the patient's family members?
##### - What illness does this patient have?
etc...

In [None]:
def extract_answer(text):
    """Return the model's answer from a decoded Mistral-instruct transcript.

    The answer is the text after the first ``[/INST]`` marker, up to a second
    marker if the model emitted one. If no marker is present (e.g. the prompt
    was truncated), the whole text is returned instead of raising IndexError.
    """
    parts = text.split('[/INST]')
    return parts[1] if len(parts) > 1 else text


def query_patient2(question, patient_id="PATIENT_001"):
    """Ask ``question`` about ``patient_id`` via the global ``agent`` and print the answer.

    Returns the result dict from ``agent.query_patient_info``.
    """
    result = agent.query_patient_info(patient_id, question)
    print(f"🤖 Response: {extract_answer(result['response'])}")
    return result


In [None]:
def _available_patient_ids():
    """Return the sorted, de-duplicated patient IDs stored in the agent's collection."""
    # Read patient_id from the metadata written by add_patient_data. Splitting
    # document IDs on "_" only works when the patient ID has exactly one underscore.
    stored = agent.collection.get(include=["metadatas"])
    return sorted({
        meta['patient_id']
        for meta in (stored['metadatas'] or [])
        if meta and 'patient_id' in meta
    })


# Interactive session: pick a patient, then ask questions until "exit".
# "switch" returns to patient selection.
query = ""
patient_ID = ""
while query != "exit":
    if patient_ID == "":
        unique_patients = _available_patient_ids()
        print(f"sample patient ID: \n{unique_patients}")
        patient_ID = input("please enter patient ID you'd like to query:").strip()
        if patient_ID not in unique_patients:
            patient_ID = ""
            print(f"patient ID not found, here are the available choices: {unique_patients}")
            continue
    print("enter 'exit' to end the session, or 'switch' to query another patient")
    query = input("Question:" ).strip()
    if query == "exit":
        print("Session stopped: ")
        continue
    if query == "switch":
        patient_ID = ""
        continue
    if not query:
        # Don't spend an LLM call on an empty question.
        continue
    query_patient2(question = query, patient_id = patient_ID)