In [None]:
# Install all necessary libraries
!pip install -q "transformers>=4.41.2" "datasets" "accelerate" "bitsandbytes>=0.43.2" "peft" "trl"
!pip install -q langchain langchain_community pypdf sentence-transformers chromadb "gradio>=4.0.0"

import os
import textwrap
import traceback
import zipfile

import gradio as gr
import torch
from datasets import load_dataset
from google.colab import drive
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig


# Mount Google Drive and unzip the fine-tuned model
print("Mounting Google Drive...")
drive.mount('/content/drive')

model_zip_path = '/content/drive/My Drive/phi3_finetuned_model.zip'
model_extract_path = './phi3_finetuned_model'

# Check for an already-extracted folder first: once it exists, the zip on Drive is
# no longer needed, so a missing zip must not be reported as an error.
if os.path.exists(model_extract_path):
    print("Fine-tuned model folder already exists. Skipping unzip.")
elif os.path.exists(model_zip_path):
    print("Found model zip file in Google Drive. Unzipping...")
    try:
        # Stdlib equivalent of `unzip -o <zip> -d .`: extracts into the working
        # directory and overwrites existing files, without needing a shell magic.
        # The archive is expected to contain a top-level 'phi3_finetuned_model/' folder.
        with zipfile.ZipFile(model_zip_path) as archive:
            archive.extractall('.')
    except zipfile.BadZipFile as err:
        # Stay best-effort (setup_chatbot raises later if the model folder is missing),
        # but report the failure instead of announcing success.
        print(f"ERROR: could not unzip '{model_zip_path}': {err}")
    else:
        print("Model unzipped successfully.")
else:
    print("--------------------------------------------------------------------------")
    print("ERROR: 'phi3_finetuned_model.zip' not found in your Google Drive.")
    print("Please upload the zip file to your main Google Drive directory and restart this cell.")
    print("--------------------------------------------------------------------------")


# Define the main function to set up the RAG chain
qa_chain = None  # Global cache for the assembled QA chain so it is only built (and the model loaded) once.

# Embedding model used both when building and when reopening the Chroma store.
# Keeping a single definition guarantees both paths embed with the same model.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# RAG prompt for the fine-tuned model. dedent() removes the source-code indentation
# so the model sees the "### ..." headers at column 0 rather than with four leading
# spaces on every line; lstrip() drops the leading blank line.
PROMPT_TEMPLATE = textwrap.dedent('''
    ### INSTRUCTION:
    You are a specialized medical assistant. Your role is to provide clear and accurate answers based ONLY on the provided context.
    If the information is not in the context, state that you cannot answer based on the given documents.
    Do not use any prior knowledge.

    ### CONTEXT:
    {context}

    ### QUESTION:
    {question}

    ### RESPONSE:
    ''').lstrip()


def _collect_documents(data_path):
    """Return chunked PDF pages found in *data_path* plus one Document per drug-dataset row."""
    all_docs = []
    # Load PDFs (only the top level of data_path is checked for .pdf files).
    pdf_files = [f for f in os.listdir(data_path) if f.endswith('.pdf')]
    if pdf_files:
        loader = PyPDFDirectoryLoader(data_path)
        pdf_documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        all_docs.extend(text_splitter.split_documents(pdf_documents))
    # Load Drug Dataset: each row becomes one self-contained Document.
    drug_dataset = load_dataset("MattBastar/Medicine_Details", split="train")
    for row in drug_dataset:
        content = (f"Medicine Name: {row.get('Medicine Name', 'N/A')}\n"
                   f"Composition: {row.get('Composition', 'N/A')}\n"
                   f"Uses: {row.get('Uses', 'N/A')}\n"
                   f"Side Effects: {row.get('Side effects', 'N/A')}\n"
                   f"Manufacturer: {row.get('Manufacturer', 'N/A')}\n"
                   f"Description: {row.get('Description', 'N/A')}")
        all_docs.append(Document(page_content=content, metadata={"source": f"drug_{row.get('Medicine Name', 'Unknown')}"}))
    return all_docs


def _load_vector_store(data_path, db_path):
    """Open the persisted Chroma store at *db_path*, building it from all sources if absent.

    Raises:
        FileNotFoundError: if the store has to be built but no documents were found.
    """
    if os.path.exists(db_path):
        print("\nLoading existing ChromaDB database...")
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
        return Chroma(persist_directory=db_path, embedding_function=embeddings)

    print("\nCreating new ChromaDB database from all sources...")
    all_docs = _collect_documents(data_path)
    if not all_docs:
        raise FileNotFoundError("No documents (PDFs or dataset) found to build the knowledge base.")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
    return Chroma.from_documents(documents=all_docs, embedding=embeddings, persist_directory=db_path)


def _load_finetuned_llm(model_id):
    """Load the fine-tuned model from *model_id* in 4-bit NF4 and wrap it as a LangChain LLM.

    Raises:
        FileNotFoundError: if *model_id* does not exist.
    """
    if not os.path.exists(model_id):
        raise FileNotFoundError(f"The fine-tuned model directory '{model_id}' was not found.")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # trust_remote_code=False: use the Phi-3 implementation bundled with transformers
    # (>=4.41.2 is required by the install cell) instead of executing the Hub's
    # modeling_phi3.py, which is re-downloaded unpinned on each load and uses the
    # deprecated `seen_tokens` cache attribute.
    # NOTE(review): the remote code is the likely cause of the generation crash with
    # newer transformers; confirm against the full traceback.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, torch_dtype="auto",
        trust_remote_code=False, attn_implementation="eager", device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    text_generation_pipeline = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512,
        do_sample=True, temperature=0.7, top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=text_generation_pipeline)


def setup_chatbot():
    """Build the knowledge base, load the fine-tuned model and return the RAG QA chain.

    The chain is cached in the module-level ``qa_chain``; later calls return it
    without rebuilding anything.

    Raises:
        FileNotFoundError: if no documents are available to build a new vector
            store, or the fine-tuned model directory is missing.
    """
    global qa_chain
    if qa_chain is not None:
        print("Chatbot already initialized.")
        return qa_chain

    data_path = 'medical_data/'
    db_path = 'chroma_db'
    os.makedirs(data_path, exist_ok=True)

    vector_store = _load_vector_store(data_path, db_path)
    llm = _load_finetuned_llm("./phi3_finetuned_model")

    # Assemble the QA chain: "stuff" places the top-3 retrieved chunks into {context}.
    prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm, chain_type="stuff", retriever=vector_store.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True, chain_type_kwargs={"prompt": prompt}
    )
    return qa_chain

# Build the Chatbot and Create the Gradio UI
print("\nInitializing chatbot")
try:
    qa_chain = setup_chatbot()
    print(" Chatbot is ready!")

    # --- Create the Gradio UI ---
    def predict(message, history):
        """Answer one chat message with the RAG chain and append the retrieved sources.

        ``history`` is passed by gr.ChatInterface but not used: each question is
        answered independently from the retrieved context.
        """
        print(f"Received message: {message}")

        try:
            result = qa_chain.invoke({"query": message})
        except Exception as err:
            # Keep the full traceback in the notebook log and show a readable
            # message in the chat UI instead of Gradio's generic error.
            traceback.print_exc()
            raise gr.Error(f"Failed to generate an answer: {err}") from err

        answer = result['result']
        sources = result.get('source_documents', [])
        if not sources:
            return answer

        # Several retrieved chunks often come from the same file; list each source
        # once, keeping retrieval order (dicts preserve insertion order).
        source_names = dict.fromkeys(
            os.path.basename(s.metadata.get('source', 'N/A')) for s in sources
        )
        return f"{answer}\n\n*Sources: {', '.join(source_names)}*"

    # Launch the Gradio Chat Interface
    gr.ChatInterface(
        predict,
        title=" Medibot",
        description="Ask me medical queries and drug based questions",
        examples=[
            ["What are the symptoms of influenza?"],
            ["What is Paracetamol used for?"],
            ["What are the side effects of metformin?"]
        ],
        theme="soft"
    ).launch(share=True, debug=True)  # share=True creates a public link, debug=True helps with errors

except Exception as e:
    # Top-level boundary of the notebook cell: report the failure with its full
    # traceback rather than only the exception message.
    print(f"\n An error occurred during setup: {e}")
    traceback.print_exc()


Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Found model zip file in Google Drive. Unzipping...
Archive:  /content/drive/My Drive/phi3_finetuned_model.zip
   creating: ./phi3_finetuned_model/
  inflating: ./phi3_finetuned_model/tokenizer_config.json  
  inflating: ./phi3_finetuned_model/tokenizer.json  
  inflating: ./phi3_finetuned_model/adapter_model.safetensors  
  inflating: ./phi3_finetuned_model/adapter_config.json  
  inflating: ./phi3_finetuned_model/README.md  
  inflating: ./phi3_finetuned_model/chat_template.jinja  
  inflating: ./phi3_finetuned_model/added_tokens.json  
  inflating: ./phi3_finetuned_model/special_tokens_map.json  
  inflating: ./phi3_finetuned_model/training_args.bin  
  inflating: ./phi3_finetuned_model/tokenizer.model  
Model unzipped successfully.

Initializing chatbot... This may take a few minutes.

Loading existing ChromaDB database...


  vector_store = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)


config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

configuration_phi3.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi3.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

Device set to use cuda:0
  llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


 Chatbot is ready!


  self.chatbot = Chatbot(


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://4b6a7d7229c625a263.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Received message: What is Paracetamol used for?


The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/blocks.py", line 2191, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/blocks.py", line 1700, in call_function
    prediction = await fn(*processed_input)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/utils.py", line 861, in async_wrapper
 