In [18]:
!pip install transformers sentence-transformers faiss-cpu --quiet


In [19]:
import numpy as np
import pandas as pd
import torch
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import re

# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [20]:
# Toy knowledge base: one short, self-contained description per condition.
# Swap in (or append) your own passages to widen what the retriever can find.
medical_corpus = [
    # Cardiovascular / metabolic
    "Hypertension is a condition where the force of blood against artery walls is too high, potentially causing health issues like heart disease.",
    "Diabetes affects how your body turns food into energy, often requiring blood sugar monitoring and medication.",
    # Respiratory / blood
    "Asthma causes airways to narrow and produce extra mucus, leading to breathing difficulties.",
    "Anemia occurs when you lack sufficient healthy red blood cells to transport oxygen to tissues, causing fatigue.",
    "Cardiovascular disease involves narrowed blood vessels that can lead to heart attacks or strokes.",
    # Neurological / infectious / immune
    "Migraines are intense headaches often accompanied by nausea and sensitivity to light and sound.",
    "Influenza (flu) is a contagious respiratory illness caused by influenza viruses, with symptoms like fever and body aches.",
    "Allergies occur when the immune system reacts to substances in the environment, such as pollen or pet dander.",
    # Musculoskeletal / endocrine
    "Osteoarthritis is a joint inflammation characterized by cartilage breakdown, leading to pain and stiffness.",
    "Hyperthyroidism results from an overactive thyroid gland, causing rapid heartbeat and weight loss.",
]

# Single-column frame ("text") with the default 0..n-1 integer index;
# the retriever looks passages up by that position.
corpus_df = pd.DataFrame(medical_corpus, columns=["text"])
corpus_df.head(10)


Unnamed: 0,text
0,Hypertension is a condition where the force of...
1,Diabetes affects how your body turns food into...
2,Asthma causes airways to narrow and produce ex...
3,Anemia occurs when you lack sufficient healthy...
4,Cardiovascular disease involves narrowed blood...
5,Migraines are intense headaches often accompan...
6,Influenza (flu) is a contagious respiratory il...
7,Allergies occur when the immune system reacts ...
8,Osteoarthritis is a joint inflammation charact...
9,Hyperthyroidism results from an overactive thy...


In [21]:
# Sentence-embedding model; the same instance later embeds user queries
embedder = SentenceTransformer('all-MiniLM-L6-v2')

# One embedding row per corpus passage, returned as a NumPy array
corpus_embeddings = embedder.encode(
    corpus_df["text"].tolist(),
    convert_to_numpy=True,
)

# Flat L2-distance FAISS index sized to the embedding width, then filled
# with every corpus vector (row i of the index == row i of corpus_df)
dim = corpus_embeddings.shape[-1]
faiss_index = faiss.IndexFlatL2(dim)
faiss_index.add(corpus_embeddings)
print(f"FAISS index size: {faiss_index.ntotal}")


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

FAISS index size: 10


In [22]:
def retrieve_medical_info(query, top_k=2):
    """
    Return the corpus passages closest to `query` in embedding space.

    The query is embedded with the module-level `embedder` and searched
    against the module-level `faiss_index`; hits are mapped back to text via
    their row position in `corpus_df["text"]`.

    Args:
        query: The user question as a string.
        top_k: Maximum number of passages to return. Values larger than the
            index size are clamped; zero or negative values yield no results.

    Returns:
        A list of at most `top_k` passage strings, in the order FAISS
        returned them (nearest first).
    """
    # Never request more neighbours than the index holds: FAISS pads the
    # missing slots with id -1, and iloc[-1] would silently return the LAST
    # corpus row as if it were a genuine hit.
    k = min(int(top_k), faiss_index.ntotal)
    if k <= 0:
        return []

    query_embedding = embedder.encode([query], convert_to_numpy=True)
    _distances, indices = faiss_index.search(query_embedding, k)

    texts = corpus_df["text"]
    # Defensive filter: skip any -1 placeholder the index might still emit.
    return [texts.iloc[int(idx)] for idx in indices[0] if idx >= 0]


In [23]:
# Session transcript shared by the cells below: medical_chatbot appends one
# {"user": ..., "assistant": ...} dict per completed exchange, and
# generate_prompt reads the most recent entries back into the prompt.
conversation_history = []

def generate_prompt(user_input, retrieved_docs, conversation_history, max_history=3):
    """
    Build the multi-turn text prompt that is fed to the language model.

    The returned string is laid out, one item per line, as:
      1. a fixed instruction line describing the assistant's role;
      2. "Conversation so far:" plus the most recent exchanges as
         "User: ..." / "Assistant: ..." pairs (omitted when there are none);
      3. the current message as "User: ...";
      4. "Relevant medical information:" plus one "- ..." bullet per document;
      5. a trailing "Assistant:" cue (no newline after it) for the model to
         complete.

    Args:
        user_input: The current user message.
        retrieved_docs: Iterable of passage strings to include as context.
            None is treated as "no documents".
        conversation_history: Sequence of dicts with "user" and "assistant"
            keys, oldest first. None or empty means no prior turns.
        max_history: Maximum number of most recent exchanges to include.
            Zero or a negative value includes none.

    Returns:
        The assembled prompt string.
    """
    # seq[-0:] is the WHOLE sequence, so a zero (or negative) limit has to be
    # handled explicitly instead of being fed straight into the slice.
    if conversation_history and max_history > 0:
        relevant_history = conversation_history[-max_history:]
    else:
        relevant_history = []

    lines = [
        "You are a highly knowledgeable medical assistant. Provide helpful, accurate, and clear answers."
    ]

    if relevant_history:
        lines.append("Conversation so far:")
        for turn in relevant_history:
            lines.append(f"User: {turn['user']}")
            lines.append(f"Assistant: {turn['assistant']}")

    lines.append(f"User: {user_input}")
    lines.append("Relevant medical information:")
    lines.extend(f"- {doc}" for doc in (retrieved_docs or []))
    lines.append("Assistant:")

    # Joining once avoids repeated string concatenation and guarantees exactly
    # one newline between items with none after the final cue.
    return "\n".join(lines)


In [24]:
# Checkpoint shared by the tokenizer and the causal language model
model_name = "microsoft/DialoGPT-medium"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights and move them to the selected device in one step
lm_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

print("DialoGPT model loaded successfully.")


DialoGPT model loaded successfully.


In [25]:
def medical_chatbot(num_turns=10, max_new_tokens=200):
    """
    Run an interactive, retrieval-augmented chat loop on stdin/stdout.

    Each iteration:
      1. Reads one line with input(). A blank line prints a reminder and
         still uses up one of the `num_turns` iterations.
      2. Ends the session early on "exit", "quit" or "stop" (case-insensitive).
      3. Retrieves the 2 closest corpus passages for the message.
      4. Builds a prompt from the message, the passages and recent history.
      5. Samples a reply from the language model and prints it.
      6. Appends {"user", "assistant"} to the global `conversation_history`.

    Replies are sampled (do_sample=True) and no seed is set here, so output
    varies between runs.

    Args:
        num_turns: Maximum number of input prompts shown to the user.
        max_new_tokens: Upper bound on tokens generated per reply. Unlike a
            total `max_length`, this does not shrink as the prompt grows.

    Returns:
        None. Results are printed and recorded in `conversation_history`.
    """
    global conversation_history
    print("Medical Chatbot: Hello! I am your advanced medical assistant. How can I help you today?")

    # Prompt + reply must fit in the model's position window. Fall back to
    # 1024 (the GPT-2 window size) if the config does not expose the value.
    context_window = getattr(lm_model.config, "max_position_embeddings", 1024)
    prompt_budget = max(1, context_window - max_new_tokens)

    for _ in range(num_turns):
        user_input = input("You: ").strip()
        if not user_input:
            print("Medical Chatbot: Please say something or type 'exit' to quit.")
            continue
        if user_input.lower() in ["exit", "quit", "stop"]:
            print("Medical Chatbot: Thank you for chatting. Take care!")
            break

        # Retrieve relevant docs
        retrieved_docs = retrieve_medical_info(user_input, top_k=2)

        # Generate the multi-turn prompt
        prompt = generate_prompt(user_input, retrieved_docs, conversation_history)

        # Tokenize via __call__ so an attention mask is produced as well; the
        # pad token is set to EOS below, so the model cannot infer the mask.
        encoded = tokenizer(prompt, return_tensors='pt')

        # If the prompt is too long, keep its tail: that is where the current
        # question, the retrieved passages and the "Assistant:" cue live.
        input_ids = encoded["input_ids"][:, -prompt_budget:].to(device)
        attention_mask = encoded["attention_mask"][:, -prompt_budget:].to(device)

        # Generate a response. max_new_tokens bounds only the reply, whereas
        # max_length counted the prompt too and left no room once history grew.
        output_ids = lm_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7
        )

        # Decode only the tokens that follow the prompt
        generated_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
        response = generated_text.strip()

        # Print the chatbot's response
        print("Medical Chatbot:", response)

        # Update conversation history
        conversation_history.append({"user": user_input, "assistant": response})


In [None]:
# Start the medical chatbot. This launches the interactive loop, which reads
# from input() and so blocks until the user types "exit", "quit" or "stop",
# or until 10 prompts have been shown.
medical_chatbot(num_turns=10)
