In [3]:
# Importing necessary libraries for machine learning and deep learning
import logging      # Standard-library logging used for status and error messages
import typing       # Library for type hinting and type annotations

import torch        # PyTorch library for tensor computations and neural networks
import transformers  # Hugging Face library for working with pre-trained models





In [4]:
class MedicalChatbot:
    """
    Wrapper around a Hugging Face text-generation pipeline for medical Q&A.

    The chatbot keeps a current medical category (``med_cat``) and the system
    prompt that belongs to it (``system_prompt``). Every call to
    ``generate_response`` sends that system prompt plus the user's question
    through the model's chat template.
    """

    # Dictionary mapping medical question categories to tailored system prompts
    category_prompts = {
        "Basic Medical Knowledge": (
            "You are an experienced medical professional specializing in general medical knowledge. "
            "Provide detailed and comprehensible answers to basic medical questions, ensuring accuracy and clarity."
        ),
        "Pharmacological Queries": (
            "You are a pharmacology expert with extensive knowledge of drugs, their mechanisms, and side effects. "
            "Provide precise and evidence-based information about medications and treatments."
        ),
        "Diagnostic Reasoning": (
            "You are a diagnostic expert proficient in identifying medical conditions. "
            "Offer clear and methodical guidance for diagnostic reasoning and differential diagnoses."
        ),
        "Treatment & Management": (
            "You are a specialist in medical treatment and patient management strategies. "
            "Provide comprehensive explanations of treatment guidelines and management protocols."
        ),
        "Specialized Medical Topics": (
            "You are an expert in advanced medical research and innovations. "
            "Discuss specialized medical topics with a focus on the latest scientific advances."
        ),
        "Preventive Medicine": (
            "You are a preventive medicine specialist with expertise in lifestyle modifications and public health. "
            "Explain preventive strategies and their importance in reducing disease risk."
        ),
        "Specialized Medical Scenarios": (
            "You are a clinical specialist adept at handling unique and complex medical scenarios. "
            "Provide nuanced insights into complications and specialized conditions."
        )
    }

    # Fallback system prompt for undefined categories
    default_prompt = (
        "You are an expert and experienced medical professional with extensive medical knowledge. "
        "Provide precise, evidence-based medical explanations that are scientifically accurate and comprehensible to a general audience."
    )

    def __init__(self, model_id="aaditya/OpenBioLLM-Llama3-70B"):
        """
        Initialize the medical chatbot with a specific medical language model

        Args:
            model_id (str): Identifier for the pre-trained medical language model

        Raises:
            RuntimeError: If the text-generation pipeline cannot be created
                (the original exception is attached as ``__cause__``).
        """
        # Set up logging BEFORE anything that can fail, so the error path below
        # always has a logger to write to.
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

        # Log model initialization
        self.logger.info(f"Initializing Medical Chatbot with model: {model_id}")

        try:
            # Create a text generation pipeline using the specified model
            self.pipeline = transformers.pipeline(
                "text-generation",            # Specify the task as text generation
                model=model_id,               # Use the specified model from Hugging Face
                model_kwargs={"torch_dtype": torch.bfloat16},  # Use lower precision for memory efficiency
                # "auto" is not a valid torch device string; automatic placement
                # is requested through device_map (needs the `accelerate` package).
                device_map="auto",
            )
        except Exception as e:
            self.logger.error(f"Failed to initialize model: {e}")
            # Chain explicitly so the original traceback stays attached
            raise RuntimeError(f"Model initialization failed: {e}") from e

        # Log successful model loading
        self.logger.info("Model successfully initialized")

        # Default medical category (can be dynamically updated later);
        # set_category also fills in the matching system prompt.
        self.set_category("Basic Medical Knowledge")

    def set_category(self, category: str):
        """
        Set the medical category dynamically and update the system prompt.

        Categories that are not keys of ``category_prompts`` fall back to
        ``default_prompt``.

        Args:
            category (str): New medical category to set.
        """
        self.med_cat = category
        self.system_prompt = self.category_prompts.get(self.med_cat, self.default_prompt)

    def generate_response(
        self,
        user_query: str,           # Type hint for user's input query
        max_tokens: int = 256,     # Default maximum response length
        temperature: float = 0.1,   # Default temperature for response variability
        safety_threshold: float = 0.7  # Length-ratio threshold for the warning below
    ) -> str:                      # Type hint indicating return is a string
        """
        Generate a medical response to a user's query

        Args:
            user_query (str): Medical question or prompt
            max_tokens (int): Maximum number of new tokens to generate; must be positive
            temperature (float): Controls response randomness/creativity. Clamped to
                [0.0, 1.0]; a value of 0.0 switches to greedy (non-sampling) decoding.
            safety_threshold (float): If the number of whitespace-separated words in
                the response divided by ``max_tokens`` exceeds this value, a warning
                is logged. This is a length heuristic only; it does not inspect
                the content of the response.

        Returns:
            str: Generated medical response. If generation fails, a string that
            starts with "An error occurred while processing your query:" is
            returned instead of raising.

        Raises:
            ValueError: If ``user_query`` is not a non-blank string or
                ``max_tokens`` is not positive.
        """

        # Input validation (whitespace-only queries count as empty)
        if not isinstance(user_query, str) or not user_query.strip():
            raise ValueError("User query must be a non-empty string")
        if max_tokens <= 0:
            # Also protects the length-ratio division further down
            raise ValueError("max_tokens must be a positive integer")

        # Ensure temperature is within a reasonable range
        temperature = max(0.0, min(temperature, 1.0))

        try:
            tokenizer = self.pipeline.tokenizer

            # Construct message list with system and user roles
            messages = [
                {"role": "system", "content": self.system_prompt},  # System prompt defining AI's role
                {"role": "user", "content": user_query}             # User query for the model
            ]

            # Apply chat template to format messages for the model
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,               # Return as string, not tokens
                add_generation_prompt=True    # Add markers for response generation
            )

            # Token IDs that terminate generation; drop any the tokenizer does not define
            candidate_ids = (
                tokenizer.eos_token_id,                        # Standard end-of-sequence token
                tokenizer.convert_tokens_to_ids("<|eot_id|>"),  # Custom end token
            )
            terminators = [token_id for token_id in candidate_ids if token_id is not None]

            generation_kwargs = {
                "max_new_tokens": max_tokens,    # Limit response length
                "eos_token_id": terminators,     # Use defined termination tokens
            }
            if temperature > 0.0:
                generation_kwargs.update(
                    do_sample=True,              # Enable probabilistic sampling
                    temperature=temperature,     # Control response randomness
                    top_p=0.9,                   # Nucleus sampling parameter
                    top_k=50,                    # Top-k sampling for more controlled generation
                )
            else:
                # Sampling requires a strictly positive temperature, so a
                # temperature of 0 is served with greedy decoding instead.
                generation_kwargs["do_sample"] = False

            # Generate response using the language model
            outputs = self.pipeline(prompt, **generation_kwargs)

            # Extract and clean the generated response; only strip the prompt
            # when the output really begins with it.
            full_text = outputs[0]["generated_text"]
            if full_text.startswith(prompt):
                full_text = full_text[len(prompt):]
            generated_text = full_text.strip()

            # Very basic length heuristic (not a content filter)
            if len(generated_text.split()) / max_tokens > safety_threshold:
                self.logger.warning("Generated response might exceed safety threshold")

            return generated_text

        except Exception as e:
            # Best-effort behaviour: log the failure and hand the caller a message
            self.logger.error(f"Response generation failed: {e}")
            return f"An error occurred while processing your query: {str(e)}"

In [5]:
# Predefined medical test questions, grouped by category.
# Keys are category names; each value is the list of sample questions for that category.
medical_questions = {
    "Basic Medical Knowledge": [
        "What are the primary symptoms of type 2 diabetes?",
        "Explain the pathophysiology of hypertension.",
        "What are the recommended screening protocols for breast cancer?"
    ],
    "Pharmacological Queries": [
        "What are the potential side effects of statins?",
        "How do ACE inhibitors work to manage blood pressure?",
        "Compare the mechanisms of different antidepressant classes."
    ],
    "Diagnostic Reasoning": [
        "What diagnostic tests would you recommend for suspected rheumatoid arthritis?",
        "Describe the differential diagnosis for chest pain in a 45-year-old male."
    ],
    "Treatment & Management": [
        "What are current guidelines for managing type 1 diabetes in adolescents?",
        "Explain the stages of cancer treatment and potential therapies."
    ],
    "Specialized Medical Topics": [
        "How does CRISPR technology potentially impact genetic disease treatment?",
        "What are the latest advances in immunotherapy for cancer?"
    ],
    "Preventive Medicine": [
        "What lifestyle modifications can reduce the risk of cardiovascular disease?",
        "Discuss the importance of vaccination in preventing infectious diseases."
    ],
    "Specialized Medical Scenarios": [
        "What are the complications of untreated gestational diabetes?",
        "Explain the neurological manifestations of multiple sclerosis."
    ]
}

med_cat = "Basic Medical Knowledge" # Category to query; must be one of the keys of medical_questions
User_questions = medical_questions[med_cat] # Questions that belong to the selected category


In [None]:
# Batch demo: run every predefined question through the chatbot, category by category
def main():
    """
    Ask the chatbot every question in ``medical_questions``, one category at a time.

    For each category the chatbot's system prompt is switched to that category
    before its questions are sent, so the heading that is printed always matches
    the prompt and the questions actually used.

    Errors for a single question are printed and the run continues; any other
    error (for example a failure to load the model) is printed as a critical
    error and ends the run.
    """
    try:
        # Create an instance of the MedicalChatbot
        chatbot = MedicalChatbot()

        # Support processing multiple categories
        for category, questions in medical_questions.items():
            print(f"\n🩺 Category: {category}")

            # Use the system prompt that belongs to this category
            chatbot.set_category(category)

            # Iterate through THIS category's questions (not a fixed global list)
            for question in questions:
                try:
                    print(f"\n📝 Query: {question}")
                    response = chatbot.generate_response(question)
                    print(f"📝 Response: {response}")
                    print("-" * 50)
                except Exception as question_error:
                    print(f"Error processing question: {question_error}")

    except Exception as main_error:
        print(f"Critical error in main execution: {main_error}")

In [None]:
# Run main() only when this code executes as the top-level module
# (__name__ == "__main__"); an import from another module will not trigger it.
if __name__ == "__main__":
    main()

In [6]:
# Optional interactive main function that lets the user choose the category.
# NOTE: this reuses the name `main`, so running this cell replaces any `main`
# that was defined earlier in the same kernel session.
def main():
    """
    Interactively query the chatbot by category.

    Prints a numbered menu of the categories in ``medical_questions`` and then
    repeatedly asks for a number. For a valid number (1..number of categories)
    the chatbot is switched to that category and all of its questions are asked.
    Entering 'q' (any case) ends the loop. Anything else, including 0,
    negative numbers and non-numeric text, is rejected with a message.

    Errors for a single question are printed and the loop continues; any other
    error is printed as a critical error and ends the function.
    """
    try:
        # Create an instance of the MedicalChatbot
        chatbot = MedicalChatbot()

        # Menu order; computed once because it does not change inside the loop
        category_list = list(medical_questions.keys())

        # Print available categories
        print("Available Medical Categories:")
        for idx, category in enumerate(category_list, 1):
            print(f"{idx}. {category}")

        # Get user input for category selection
        while True:
            selection = input("\nEnter the number of the category you want to query (or 'q' to quit): ").strip()

            # Allow quitting
            if selection.lower() == 'q':
                break

            # Convert selection to an integer; non-numeric input becomes an
            # out-of-range value so it is rejected by the range check below.
            try:
                choice = int(selection)
            except ValueError:
                choice = 0

            # Explicit range check: without it, 0 or a negative number would be
            # accepted by Python's negative indexing and pick the wrong category.
            if not 1 <= choice <= len(category_list):
                print("Invalid selection. Please enter a valid category number.")
                continue

            selected_category = category_list[choice - 1]
            print(f"\n🩺 Selected Category: {selected_category}")

            # Set the chatbot's category once for the whole batch of questions
            chatbot.set_category(selected_category)

            # Process questions in the selected category
            for question in medical_questions[selected_category]:
                try:
                    print(f"\n📝 Query: {question}")
                    response = chatbot.generate_response(question)
                    print(f"📝 Response: {response}")
                    print("-" * 50)
                except Exception as question_error:
                    print(f"Error processing question: {question_error}")

    except Exception as main_error:
        print(f"Critical error in main execution: {main_error}")

# Run the interactive main() defined above, but only when this code executes
# as the top-level module (__name__ == "__main__").
if __name__ == "__main__":
    main()


2024-12-18 20:44:36,775 - INFO - Initializing Medical Chatbot with model: aaditya/OpenBioLLM-Llama3-70B
2024-12-18 20:44:56,885 - INFO - Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-12-18 20:44:56,887 - INFO - NumExpr defaulting to 8 threads.


Downloading shards:   0%|          | 0/30 [00:00<?, ?it/s]

pytorch_model-00001-of-00030.bin:  34%|###4      | 1.56G/4.58G [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


pytorch_model-00002-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00003-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00004-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00005-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00006-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00007-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00008-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00009-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00010-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00011-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00012-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00013-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00014-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00015-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00016-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00017-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00018-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00019-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00020-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00021-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00022-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00023-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00024-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00025-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00026-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00027-of-00030.bin:   0%|          | 0.00/4.66G [00:00<?, ?B/s]

pytorch_model-00028-of-00030.bin:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

pytorch_model-00029-of-00030.bin:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

pytorch_model-00030-of-00030.bin:   0%|          | 0.00/2.10G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs-us-1.hf.co/repos/e6/ef/e6ef22fa4d8cf7165fad5efd5bb7c80f909046e18d6d90669e7a2a38a85dc485/1d3e5e7ac799a39688899311c6a2c926dc6be33e2ef73bb38ddeff2a049b9057?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27pytorch_model-00030-of-00030.bin%3B+filename%3D%22pytorch_model-00030-of-00030.bin%22%3B&response-content-type=application%2Foctet-stream&Expires=1734829787&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczNDgyOTc4N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2U2L2VmL2U2ZWYyMmZhNGQ4Y2Y3MTY1ZmFkNWVmZDViYjdjODBmOTA5MDQ2ZTE4ZDZkOTA2NjllN2EyYTM4YTg1ZGM0ODUvMWQzZTVlN2FjNzk5YTM5Njg4ODk5MzExYzZhMmM5MjZkYzZiZTMzZTJlZjczYmIzOGRkZWZmMmEwNDliOTA1Nz9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC10eXBlPSoifV19&Signature=QAx4patNX9b2-KEYOJagGpouCwoIfPY8auqVpf9hsz4804-wXRg%7EIoh2bDGfFSuipsQtwDMY2rqSQ54ZOZU5gSB4So6X%7ECKf3VNQzhl0MHdHJ3wGoEqPdV95msMUKJtcRAETKohdNU3SxVJ2J

pytorch_model-00030-of-00030.bin:  51%|#####1    | 1.08G/2.10G [00:00<?, ?B/s]

: 