## Private Healthcare AI Assistant for Clinics Using Qdrant Hybrid Cloud (JWT-RBAC), DSPy and Groq — Llama3

In this post we build a private AI assistant for clinics and hospitals that fetches patient data and answers questions grounded in that data.

Flow diagram:

![image](https://miro.medium.com/v2/resize:fit:4800/format:webp/0*CIjgbRNz4iYP6zwK)



1. Dataset: We will work with a healthcare dataset containing patient records, including name, medical condition, medication, billing amount, hospital name, and more. Datasets like this are rarely available publicly, and this one is not real either: it was generated synthetically. It was originally published as a multi-label classification dataset and can be downloaded from Kaggle here: [🩺Healthcare Dataset 🧪](https://www.kaggle.com/datasets/prasad22/healthcare-dataset)

2. DSPy: DSPy is a framework for algorithmically optimizing LM prompts instead of writing them by hand.

3. Qdrant Managed Cloud: Qdrant is a lightweight vector database that also offers a managed cloud service, including a free cluster for trials and the option to upgrade as you need more features. We will use it to store our dataset as vectors.

4. Groq: Groq builds an AI accelerator application-specific integrated circuit (ASIC), which it calls the Language Processing Unit (LPU), along with related hardware to speed up inference for AI workloads. Groq offers free, rate-limited API access to recent models such as Llama 3, which is enough for our use case.

Before starting, get the following API keys:

1. An API key from [Qdrant](https://cloud.qdrant.io/login)
2. An API key from [Groq](https://groq.com)

In [1]:
%%capture
!pip install qdrant-client groq sentence-transformers dspy-ai fastembed gradio --upgrade

#### Google Colab

In [2]:
!kaggle datasets download -d prasad22/healthcare-dataset

Dataset URL: https://www.kaggle.com/datasets/prasad22/healthcare-dataset
License(s): CC0-1.0
Downloading healthcare-dataset.zip to /content
  0% 0.00/2.91M [00:00<?, ?B/s]
100% 2.91M/2.91M [00:00<00:00, 151MB/s]


In [3]:
!unzip /content/healthcare-dataset.zip -d /content/healthcare-dataset/

Archive:  /content/healthcare-dataset.zip
  inflating: /content/healthcare-dataset/healthcare_dataset.csv  


#### Kaggle

In [None]:
!kaggle datasets download -d prasad22/healthcare-dataset

Dataset URL: https://www.kaggle.com/datasets/prasad22/healthcare-dataset
License(s): CC0-1.0
Downloading healthcare-dataset.zip to /kaggle/working
100%|██████████████████████████████████████| 2.91M/2.91M [00:00<00:00, 5.43MB/s]
100%|██████████████████████████████████████| 2.91M/2.91M [00:00<00:00, 4.61MB/s]


In [None]:
!unzip /kaggle/working/healthcare-dataset.zip -d /kaggle/working/healthcare-dataset/

Archive:  /kaggle/working/healthcare-dataset.zip
  inflating: /kaggle/working/healthcare-dataset/healthcare_dataset.csv  


### Alternative method to download

In [None]:
!pip install opendatasets

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import opendatasets as od

od.download("https://www.kaggle.com/datasets/prasad22/healthcare-dataset")

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username:
Your Kaggle Key:
Downloading healthcare-dataset.zip to ./healthcare-dataset


100%|██████████| 2.91M/2.91M [00:01<00:00, 1.80MB/s]
In [None]:
%pwd

'/Users/kanishkhajaisankar/Downloads'

In [None]:
%ls

### 1. Data Exploration

In [4]:
import pandas as pd

# On Kaggle, use this path instead:
# df = pd.read_csv("/kaggle/working/healthcare-dataset/healthcare_dataset.csv")
df = pd.read_csv("/content/healthcare-dataset/healthcare_dataset.csv")  # Colab path
df[:5]  # first five rows, same as df.head()

| | Name | Age | Gender | Blood Type | Medical Condition | Date of Admission | Doctor | Hospital | Insurance Provider | Billing Amount | Room Number | Admission Type | Discharge Date | Medication | Test Results |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Bobby JacksOn | 30 | Male | B- | Cancer | 2024-01-31 | Matthew Smith | Sons and Miller | Blue Cross | 18856.281306 | 328 | Urgent | 2024-02-02 | Paracetamol | Normal |
| 1 | LesLie TErRy | 62 | Male | A+ | Obesity | 2019-08-20 | Samantha Davies | Kim Inc | Medicare | 33643.327287 | 265 | Emergency | 2019-08-26 | Ibuprofen | Inconclusive |
| 2 | DaNnY sMitH | 76 | Female | A- | Obesity | 2022-09-22 | Tiffany Mitchell | Cook PLC | Aetna | 27955.096079 | 205 | Emergency | 2022-10-07 | Aspirin | Normal |
| 3 | andrEw waTtS | 28 | Female | O+ | Diabetes | 2020-11-18 | Kevin Wells | Hernandez Rogers and Vang, | Medicare | 37909.78241 | 450 | Elective | 2020-12-18 | Ibuprofen | Abnormal |
| 4 | adrIENNE bEll | 43 | Female | AB+ | Cancer | 2022-09-19 | Kathleen Hanna | White-White | Aetna | 14238.317814 | 458 | Urgent | 2022-10-09 | Penicillin | Abnormal |


### 2. Data Pre-process

In [5]:
# Function to format each row into a single lowercase string
def format_row(row):
    text = (
        f"Name: {row['Name']}, Age: {row['Age']}, Gender: {row['Gender']}, "
        f"Blood Type: {row['Blood Type']}, Medical Condition: {row['Medical Condition']}, "
        f"Date of Admission: {row['Date of Admission']}, Doctor: {row['Doctor']}, "
        f"Hospital: {row['Hospital']}, Insurance Provider: {row['Insurance Provider']}, "
        f"Billing Amount: {row['Billing Amount']}, Room Number: {row['Room Number']}, "
        f"Admission Type: {row['Admission Type']}, Discharge Date: {row['Discharge Date']}, "
        f"Medication: {row['Medication']}, Test Results: {row['Test Results']}\n\n"
    )
    # Lowercase the whole record: names in the raw data have random capitalization
    return text.lower()


# Apply the function to each row and create a new column with the formatted text
df['formatted_text'] = df.apply(format_row, axis=1)


# Convert the formatted text into a list (or any other format you need)
text_data = df['formatted_text'].tolist()

In [6]:
df.shape

(55500, 16)

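The dataset has 55,500 rows (the 16th column is the `formatted_text` column added above). Since the free-tier Qdrant cluster has limited resources, it is better to work with a small subset of the data.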

In [7]:
from random import shuffle
sampled_dataset = text_data[:200]
shuffle(sampled_dataset)
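Slicing `text_data[:200]` and then shuffling only reorders the first 200 records; it does not draw a random sample from all 55,500 rows. If a representative sample is wanted, here is a minimal sketch using the standard library's `random.Random(seed).sample`. The helper name `sample_records` and the seed value 42 are assumptions for illustration, and the toy `records` list stands in for `text_data`.

```python
import random


def sample_records(records, n, seed=42):
    """Draw n distinct records uniformly at random, reproducibly for a given seed."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    return rng.sample(records, n)


# Toy stand-in for text_data (the real list holds 55,500 formatted rows)
records = [f"name: patient {i}\n\n" for i in range(1000)]
sampled = sample_records(records, 200)
print(len(sampled), len(set(sampled)))
```

Whatever sampling is used, the point ids uploaded to Qdrant later are positions in this list, so the same list has to be kept around to map search hits back to their text.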

In [8]:
sampled_dataset[:5]

['name: peter fitzgerald, age: 73, gender: male, blood type: ab+, medical condition: obesity, date of admission: 2020-05-15, doctor: angela contreras, hospital: garner-bowman, insurance provider: medicare, billing amount: 19746.83200760437, room number: 162, admission type: urgent, discharge date: 2020-05-20, medication: aspirin, test results: abnormal\n\n',
 'name: william hill, age: 38, gender: female, blood type: a+, medical condition: cancer, date of admission: 2023-05-16, doctor: matthew walker, hospital: lindsey inc, insurance provider: cigna, billing amount: 39476.94751437997, room number: 428, admission type: elective, discharge date: 2023-06-01, medication: aspirin, test results: abnormal\n\n',
 'name: nicholas hall, age: 31, gender: female, blood type: a-, medical condition: diabetes, date of admission: 2020-12-15, doctor: marissa stevenson, hospital: alexander and jensen andrews,, insurance provider: unitedhealthcare, billing amount: 3730.002190896625, room number: 218, admi

### 3. Generate Embeddings for the Records to Store Them in the Vector DB

In [9]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("BAAI/bge-large-en-v1.5", device='cuda')
vectors = model.encode(sampled_dataset)


In [10]:
vectors[0].shape

(1024,)

### 4. Initialize Qdrant Client

In [11]:
import os
from google.colab import userdata

# os.environ['QDRANT__SERVICE__API_KEY']=<your api key>
os.environ['QDRANT__SERVICE__API_KEY']= userdata.get('QDRANT_API_KEY')


from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams


# Initialize the client


client = QdrantClient(
    # url=<your cluster’s url>,
    url = "https://06d76ff3-ca6b-4a52-aa00-2773ae9154c9.us-east4-0.gcp.cloud.qdrant.io:6333",
    # url='http://localhost:6333',
    api_key=os.environ['QDRANT__SERVICE__API_KEY'],
)

The API key is meant to stay within your organization or services: only clients holding it can access the cluster, so keep it out of the notebook (e.g. in Colab secrets, as above) rather than hard-coding it.
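Outside Colab, `google.colab.userdata` is not available. A minimal sketch of a hypothetical helper, `get_secret`, that reads a key such as `QDRANT_API_KEY` from an environment variable and fails loudly when it is missing, so the key never has to be pasted into the notebook:

```python
import os


def get_secret(name: str) -> str:
    """Return the secret stored in environment variable `name` (hypothetical helper)."""
    value = os.environ.get(name)
    if not value:
        # Fail early instead of sending an empty key to the server
        raise RuntimeError(f"Environment variable {name!r} is not set")
    return value
```

Usage would then be `api_key=get_secret("QDRANT_API_KEY")` when constructing the `QdrantClient`, with the variable exported in the shell or a `.env` loader before the notebook starts.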

### 4.1 Create a collection named `phi_data`

PHI - Protected Health Information

In [12]:
client.recreate_collection(
    collection_name="phi_data",
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)


True

### 5. Upload the Vectors to the Cloud Cluster

In [13]:
client.upload_collection(
    collection_name="phi_data",
    ids=[i for i in range(len(sampled_dataset))],
    vectors=vectors,
    parallel=4,
    max_retries=3,
)

### 6. Qdrant JWT - Role Based Access Control

Qdrant also supports role-based access control (RBAC) via JWT. To try it, run a local Qdrant container with an API key and JWT RBAC enabled:

In [None]:
!docker pull qdrant/qdrant

In [None]:
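# Replace YOUR_API_KEY with the key that will also be used to sign the JWTs below.
# -d runs the container in the background so this cell does not block the notebook.
!docker run -d -p 6333:6333 -p 6334:6334 -e QDRANT__SERVICE__API_KEY=YOUR_API_KEY -e QDRANT__SERVICE__JWT_RBAC=true qdrant/qdrant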

Create a dummy collection using the original (full-access) API key.

In [None]:
# Connect to the local JWT-RBAC-enabled instance started above with the original API key
root_client = QdrantClient(
    url="http://localhost:6333",
    api_key=userdata.get('QDRANT_API_KEY'),
)


root_client.recreate_collection(
    collection_name="dummy",
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)


root_client.upload_collection(
    collection_name="dummy",
    ids=[i for i in range(len(sampled_dataset))],
    vectors=vectors,
    parallel=4,
    max_retries=3,
)

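Next, create a token that restricts a user to read-only access using JWT. The token is signed with the original API key and expires after the time set in `exp` (one hour here), so it works as a temporary, restricted key derived from the original one.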

In [None]:
import jwt
import time


# API key used as the secret to sign the token
api_key = userdata.get('QDRANT_API_KEY')


# Current time in seconds since the Unix epoch
current_time = int(time.time())


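# JWT payload
payload = {
    'exp': current_time + 3600,  # Token expires in 1 hour
    # The token is accepted only while demo_collection contains a point
    # whose payload has user == "John"
    'value_exists': {
        'collection': 'demo_collection',
        'matches': [
            {'key': 'user', 'value': 'John'}
        ]
    },
    # Read-only access to demo_collection, limited to points with user == "John"
    'access': [
        {
            'collection': 'demo_collection',
            'access': 'r',
            'payload': {'user': 'John'},
        }
    ],
}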


# Encode the JWT token
encoded_jwt = jwt.encode(payload, api_key, algorithm='HS256')


# Print the JWT token
print(encoded_jwt)
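`jwt.encode(..., algorithm='HS256')` produces a token made of a base64url-encoded header and payload joined by dots, followed by an HMAC-SHA256 signature computed with the API key. Because the server holds the same key, it can verify the signature and trust the claims, which is why the API key itself must never be handed out. The sketch below reproduces that construction with only the standard library to show what the server checks; the helper names are hypothetical, it does not validate `exp`, and in practice PyJWT should be used as above.

```python
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> str:
    # JWTs use unpadded base64url encoding
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def sign_hs256(payload: dict, secret: str) -> str:
    """Build header.payload.signature, signed with HMAC-SHA256."""
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = (
        b64url(json.dumps(header, separators=(",", ":")).encode())
        + "."
        + b64url(json.dumps(payload, separators=(",", ":")).encode())
    )
    sig = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(sig)


def verify_hs256(token: str, secret: str) -> dict:
    """Recompute the signature and return the claims if it matches."""
    signing_input, _, sig = token.rpartition(".")
    expected = b64url(hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        raise ValueError("invalid signature")
    payload_b64 = signing_input.split(".")[1]
    padded = payload_b64 + "=" * (-len(payload_b64) % 4)  # restore stripped padding
    return json.loads(base64.urlsafe_b64decode(padded))


token = sign_hs256({"access": "r"}, "my-api-key")
print(verify_hs256(token, "my-api-key"))  # {'access': 'r'}
```

A token signed with any other key fails verification, which is what lets the server reject forged or tampered tokens without storing them.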

Test uploading a new data point with the read-only token: the request should be rejected with a 403 (Forbidden) response.

In [None]:
# Uploading points with the read-only token should fail with a 403 (Forbidden) error
from qdrant_client import QdrantClient, models
import numpy as np


# Separate client name, so the `client` used by the RAG pipeline is not overwritten
readonly_client = QdrantClient(
    url="http://localhost:6333",
    api_key=encoded_jwt,  # the read-only JWT generated above
)


data = np.full(1024, 0.1)  # dummy 1024-dimensional vector
print(data.shape)


readonly_client.upload_points(
    collection_name="dummy",
    points=[
        models.PointStruct(
            id="5c56c793-69f3-4fbf-87e6-c4bf54c28c26",
            vector=data.tolist(),  # plain list of floats
        )
    ],
)

### 7. Set up DSPy for prompting

In [None]:
# In case of error: AttributeError: module 'google._upb._message' has no attribute 'MessageMapContainer'

#!pip install proto-plus==1.24.0.dev1

In [16]:
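from dspy.retrieve.qdrant_rm import QdrantRM

# Register Qdrant as DSPy's retriever. The RAG module below fetches context with
# get_context() instead, mapping hit ids back to sampled_dataset, because the
# phi_data points store vectors only.
qdrant_retriever_model = QdrantRM("phi_data", client, k=3)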

### 8. Initialize DSPy - Groq's integration using Groq's API Key

In [17]:
import dspy
llama3 = dspy.GROQ(model='llama3-8b-8192', api_key = userdata.get('GROQ_API_KEY') )

Use Qdrant as the retriever model and Groq as the LLM

In [18]:
dspy.settings.configure(rm=qdrant_retriever_model, lm=llama3)

### 9. Set up CoT (Chain of Thought) Modules and Signatures using DSPy

In [19]:
class GenerateAnswer(dspy.Signature):
    """Answer questions with logical factoid answers."""

    context = dspy.InputField(desc="will contain PHI medical data of patients matched with the query")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="an answer between 10 and 20 words")

Function to get the 3 best-matching records for the query:

In [20]:
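def get_context(text, k=3):
    # Embed the question with the same model used for the stored records
    query_vector = model.encode(text)

    hits = client.search(
        collection_name="phi_data",
        query_vector=query_vector,
        limit=k,  # return the k closest points
    )
    # Point ids are positions in sampled_dataset, so map each hit back to its text
    return "".join(sampled_dataset[hit.id] for hit in hits)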
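Conceptually, `get_context` embeds the question, ranks the stored vectors by cosine similarity (the `Distance.COSINE` metric chosen when the collection was created) and concatenates the text of the top `k` records. The brute-force sketch below does the same over an in-memory list, with toy 2-dimensional vectors standing in for the 1024-dimensional BGE embeddings; the function names are hypothetical. It is handy for checking the retrieval logic without a cluster, while Qdrant's index is what keeps this fast as the collection grows.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity: dot(u, v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def brute_force_context(query_vec, vectors, texts, k=3):
    """Return the texts of the k vectors most similar to query_vec, concatenated."""
    ranked = sorted(
        range(len(vectors)),
        key=lambda i: cosine_similarity(query_vec, vectors[i]),
        reverse=True,  # most similar first
    )
    return "".join(texts[i] for i in ranked[:k])


texts = ["cancer record\n\n", "obesity record\n\n", "diabetes record\n\n"]
vecs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
print(brute_force_context([1.0, 0.1], vecs, texts, k=2))
```

Here the query is closest to the "cancer" vector, then to the diagonal "diabetes" vector, so those two records are returned in that order.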

### 10. Main class that handles RAG Pipeline

In [21]:
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()


        self.retrieve = dspy.Retrieve(k=num_passages)  # not used by forward(), which calls get_context()
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)


    def forward(self, question):
        context = get_context(question)
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)

In [22]:
rag = RAG()
def respond(query):
    response = rag(query)
    return response.answer

### 11. Gradio for a Visually Pleasing UI

In [23]:
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(query, chat_history):
        # Run the RAG pipeline and append the (question, answer) pair to the chat
        response = rag(query)
        chat_history.append((query, response.answer))
        return "", chat_history  # clear the textbox and show the updated history


    msg.submit(respond, [msg, chatbot], [msg, chatbot])

In [24]:
#demo.launch()
demo.launch(share=True) #if using colab

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://bb8e4318f1eff3162f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


