In [1]:
# Install required libraries.
# NOTE: the leading "!" is notebook (IPython/Colab) shell-escape syntax, not plain
# Python, so this line only works when the cell is run inside a notebook.
# "-q" asks pip for quieter output.
!pip install unsloth datasets langchain faiss-gpu transformers sentence-transformers langchain-community gradio -q

# Import necessary modules
import torch
import gradio as gr
from datasets import load_dataset
from unsloth import FastLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from typing import List, Any, Optional
from pydantic import PrivateAttr

# Attach Google Drive; the fine-tuned LoRA model directory is read from there.
from google.colab import drive

drive.mount("/content/drive")
drive_path = "/content/drive/My Drive/MedQA-Llama3.1-8B_LoRA_Model/lora_model"

# Options handed to Unsloth's loader: 2048-token sequence limit, dtype left
# unspecified (None), 4-bit loading, and automatic device placement.
_model_load_options = {
    "model_name": drive_path,
    "max_seq_length": 2048,
    "dtype": None,
    "load_in_4bit": True,
    "device_map": "auto",
}

# Load the LoRA-tuned Llama model together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(**_model_load_options)

# Prepare the loaded model for generation via Unsloth's inference switch.
FastLanguageModel.for_inference(model)

# Load the labeled PubMedQA training split (fold 0, "source" schema).
# NOTE: trust_remote_code=True lets the dataset's loading script run locally.
dataset = load_dataset(
    "bigbio/pubmed_qa",
    name="pubmed_qa_labeled_fold0_source",
    split="train",
    trust_remote_code=True,
)

# One document per example; the text is the example's LONG_ANSWER field,
# with a missing/empty value normalised to "".
docs = [
    {"id": f"pubmed_qa_train_{row}", "text": example["LONG_ANSWER"] or ""}
    for row, example in enumerate(dataset)
]

# Split every document into chunks of up to 1000 characters with 100 overlap.
# Chunk ids extend the parent document id with the chunk's position.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = [
    {"id": f"{doc['id']}_chunk_{position}", "text": piece}
    for doc in docs
    for position, piece in enumerate(text_splitter.split_text(doc["text"]))
]

# Embed the chunks and index them in FAISS; each chunk's id is stored as its
# "source" metadata.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
texts = [chunk["text"] for chunk in chunks]
metadatas = [{"source": chunk["id"]} for chunk in chunks]
db = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)

# Custom LLM class
class LoRAMedicalLLM(LLM):
    """LangChain ``LLM`` adapter around an already-loaded model/tokenizer pair.

    The model and tokenizer are kept as private attributes so they are not
    treated as pydantic fields. Generation settings are regular fields and can
    be overridden at construction time.
    """

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    # Upper bound on generated tokens per call. Matches the constructor default
    # (the constructor value always wins, so 256 is the effective default).
    max_new_tokens: int = 256
    # Device the tokenized prompt is moved to before generation.
    device: str = "cuda"
    # Prompts longer than this many tokens are truncated by the tokenizer.
    max_input_tokens: int = 2048

    def __init__(self, model, tokenizer, max_new_tokens=256, device="cuda", **kwargs):
        """Store the model/tokenizer and the generation settings.

        Args:
            model: Object exposing ``generate(**inputs, ...)``.
            tokenizer: Callable tokenizer that also exposes ``decode``.
            max_new_tokens: Maximum number of tokens to generate per call.
            device: Device string the tokenized inputs are moved to.
            **kwargs: Extra field values (e.g. ``max_input_tokens``) forwarded
                to the base class.
        """
        # Route the fields through the base constructor so they are validated
        # like any other field instead of being patched on afterwards.
        super().__init__(max_new_tokens=max_new_tokens, device=device, **kwargs)
        self._model = model
        self._tokenizer = tokenizer

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "lora-medical-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        Args:
            prompt: Fully formatted prompt text.
            stop: Optional stop sequences; the completion is cut at the
                earliest occurrence of any of them.
            run_manager: Accepted for compatibility with LangChain's calling
                convention; not used.
            **kwargs: Accepted for compatibility; not used.

        Returns:
            The generated continuation, without the echoed prompt.
        """
        inputs = self._tokenizer(
            [prompt],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_input_tokens,
        ).to(self.device)

        with torch.no_grad():
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                use_cache=True,
            )

        # The generated sequence starts with the prompt tokens; skip them so
        # the caller gets just the completion rather than prompt + completion.
        prompt_length = inputs["input_ids"].shape[1]
        completion = self._tokenizer.decode(
            output_ids[0][prompt_length:],
            skip_special_tokens=True,
        )

        if stop:
            # Cut at the earliest stop sequence that actually occurs.
            # Empty stop strings are ignored (they would match at index 0).
            hits = [completion.find(marker) for marker in stop if marker]
            hits = [index for index in hits if index != -1]
            if hits:
                completion = completion[: min(hits)]

        return completion

# Wrap the loaded model and tokenizer in the custom LangChain LLM.
llm = LoRAMedicalLLM(model=model, tokenizer=tokenizer)

# Prompt text for the chain; {context} and {question} are the template's
# two input variables.
prompt_template = """You are a medical QA system.
Use the following context to answer the question concisely without repeating the context or any lines in the final answer:
{context}

Question: {question}

Answer:
"""
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

# Similarity search over the FAISS store, returning 3 chunks per query.
retriever = db.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)

# Retrieval-augmented QA chain using the "stuff" chain type and the prompt above.
rag = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt},
)

# Define the function for Gradio interface
def medical_qa(query):
    """Answer a medical question and list the retrieved source passages.

    Args:
        query: Question text as entered by the user. May be empty or ``None``.

    Returns:
        A string of the form ``"Answer:\\n<answer>\\n\\nSources:\\n<sources>"``,
        or a short prompt to enter a question when ``query`` is blank.
    """
    question = (query or "").strip()
    if not question:
        # Nothing to retrieve or generate for; skip the model call entirely.
        return "Please enter a medical question."

    response = rag.invoke({"query": question})

    # Keep only the text after the last "Answer:" marker. When the marker is
    # absent, split() yields the whole string, so no separate branch is needed.
    final_answer = response["result"].split("Answer:")[-1].strip()

    # Fetch the passages to display alongside the answer. `invoke` replaces
    # the deprecated `get_relevant_documents`.
    retrieved_docs = retriever.invoke(question)
    sources = "\n\n".join(
        f"Source {number}: {doc.page_content}"
        for number, doc in enumerate(retrieved_docs, start=1)
    )

    return f"Answer:\n{final_answer}\n\nSources:\n{sources}"

# Build the Gradio UI: one text box for the question, one for the reply.
question_box = gr.Textbox(label="Enter your medical question")
response_box = gr.Textbox(label="Response and Sources")

interface = gr.Interface(
    fn=medical_qa,
    inputs=question_box,
    outputs=response_box,
    title="Medical QA System",
    description="Ask any medical question, and get a concise answer along with relevant sources.",
)

# Start the app; share=True requests a public share link.
interface.launch(share=True)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.4/60.4 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.4/177.4 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m65.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.5/57.5 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.6/320.6 kB[0m [31m403.1 kB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/230 [00:00<?, ?B/s]

Unsloth 2025.1.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


README.md:   0%|          | 0.00/2.36k [00:00<?, ?B/s]

pubmed_qa.py:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

bigbiohub.py:   0%|          | 0.00/19.3k [00:00<?, ?B/s]

pqal.zip:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://0a8af4eb40e1238229.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


