# **HuatuoGPT-o1 Medical RAG and Reasoning**

## **Notebook Setup**

In [None]:
# Install/upgrade the Python libraries used below (-q keeps pip's output quiet).
# NOTE(review): torch is not installed here; presumably the runtime (e.g. Colab)
# already provides it — confirm before running elsewhere.
!pip install transformers datasets sentence-transformers scikit-learn --upgrade -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.2/69.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.9/275.9 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m43.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## **Load the Dataset**

In [None]:
from datasets import load_dataset

# Hugging Face Hub identifier of the ChatDoctor HealthCareMagic question/answer
# corpus that serves as the retrieval knowledge base.
DATASET_ID = "lavita/ChatDoctor-HealthCareMagic-100k"
dataset = load_dataset(DATASET_ID)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/542 [00:00<?, ?B/s]

(…)-00000-of-00001-5e7cb295b9cff0bf.parquet:   0%|          | 0.00/70.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112165 [00:00<?, ? examples/s]

## **Step 3: Initialize the Models**

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

# HuatuoGPT-o1 (7B) is the causal LM that writes the final answer.
# torch_dtype="auto" keeps the checkpoint's stored precision, and
# device_map="auto" lets transformers/accelerate place the weights on the
# available device(s).
model_name = "FreedomIntelligence/HuatuoGPT-o1-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

# Sentence-embedding model used for both the knowledge base and the queries.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

## **Prepare the Knowledge Base**

In [None]:
import pandas as pd
import numpy as np

# Convert the train split to a DataFrame. Dataset.to_pandas() converts the
# underlying Arrow table in one pass instead of iterating the rows one by one.
df = dataset["train"].to_pandas()

# Combine question and answer so each embedding covers the whole exchange.
# fillna("") guards against missing fields: a missing value would turn the
# concatenation into NaN, which the sentence encoder cannot embed.
df["combined"] = df["input"].fillna("") + " " + df["output"].fillna("")

# Generate one embedding per row, in the same order as df; retrieval maps the
# winning embedding indices back to df rows positionally.
print("Generating embeddings for the knowledge base...")
embeddings = embed_model.encode(df["combined"].tolist(), show_progress_bar=True, batch_size=128)
print("Embeddings generated!")

## **Implement Retrieval**

In [None]:
from sklearn.metrics.pairwise import cosine_similarity


def _top_k_indices(similarities: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the ``k`` highest scores, best first.

    Args:
        similarities (np.ndarray): 1-D array of scores.
        k (int): Number of indices wanted. Values <= 0 yield an empty array;
            values larger than the array length yield every index.

    Returns:
        np.ndarray: Integer indices sorted by descending score. Ties keep the
        lower (earlier) index first, so results are deterministic.
    """
    # Guard explicitly: slicing with [-0:] or a negative bound would otherwise
    # return (almost) the whole corpus instead of nothing.
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    # A stable ascending sort of the negated scores gives descending order
    # with ties broken by original position.
    return np.argsort(-similarities, kind="stable")[:k]


def retrieve_relevant_contexts(query: str, k: int = 3) -> list:
    """
    Retrieves the k most relevant contexts to a given query.

    Relevance is the cosine similarity between the query embedding and the
    precomputed knowledge-base ``embeddings``; rows are looked up in ``df``.

    Args:
        query (str): The user's medical query.
        k (int): The number of relevant contexts to retrieve. Non-positive
            values return an empty list.

    Returns:
        list: Dictionaries with keys ``"question"``, ``"answer"`` and
        ``"similarity"`` (a Python float), ordered from most to least similar.
    """
    # Embed the query with the same model used for the knowledge base.
    query_embedding = embed_model.encode([query])[0]

    # Cosine similarity of the query against every knowledge-base entry.
    similarities = cosine_similarity([query_embedding], embeddings)[0]

    contexts = []
    for idx in _top_k_indices(similarities, k):
        row = df.iloc[idx]
        contexts.append(
            {
                "question": row["input"],
                "answer": row["output"],
                "similarity": float(similarities[idx]),
            }
        )

    return contexts

## **Implement Response Generation**

In [None]:
def _build_prompt(query: str, contexts: list) -> str:
    """Render the retrieval-augmented user prompt for ``query``.

    Args:
        query (str): The user's medical query.
        contexts (list): Dictionaries with ``"question"`` and ``"answer"`` keys.

    Returns:
        str: The prompt text, with the references numbered from 1.
    """
    context_prompt = "\n".join(
        f"Reference {i}:\nQuestion: {ctx['question']}\nAnswer: {ctx['answer']}"
        for i, ctx in enumerate(contexts, start=1)
    )

    return f"""Based on the following references and your medical knowledge, provide a detailed response:

References:
{context_prompt}

Question: {query}

By considering:
1. The key medical concepts in the question.
2. How the reference cases relate to this question.
3. What medical principles should be applied.
4. Any potential complications or considerations.

Give the final response:
"""


def generate_structured_response(
    query: str,
    contexts: list,
    *,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
) -> str:
    """
    Generates a detailed response using the retrieved contexts.

    Args:
        query (str): The user's medical query.
        contexts (list): A list of relevant contexts (dicts with "question"
            and "answer" keys).
        max_new_tokens (int): Upper bound on the number of generated tokens.
        temperature (float): Sampling temperature.

    Returns:
        str: Only the model's generated text (the prompt is not included),
        with surrounding whitespace stripped.
        NOTE(review): HuatuoGPT-o1 is a reasoning model and may emit its
        reasoning before the final answer; the full generated text is returned.
    """
    prompt = _build_prompt(query, contexts)

    messages = [{"role": "user", "content": prompt}]
    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template already contains the model's special/role tokens, so
    # don't let the tokenizer prepend its own special tokens a second time.
    inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        num_beams=1,
        do_sample=True,
    )

    # generate() returns prompt + continuation; decode only the continuation so
    # neither the prompt nor chat-template role text leaks into the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_length:]

    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

## **Putting It All Together**

In [None]:
def process_query(query: str, k: int = 3) -> tuple:
    """
    Runs the full RAG pipeline for one query: retrieve, then generate.

    Args:
        query (str): The user's medical query.
        k (int): How many knowledge-base entries to retrieve as references.

    Returns:
        tuple: ``(response, contexts)`` — the generated answer text and the
        retrieved reference dictionaries it was conditioned on.
    """
    retrieved = retrieve_relevant_contexts(query, k)
    return generate_structured_response(query, retrieved), retrieved


# Demo: run the pipeline end-to-end on a sample symptom question.
query = "I've been experiencing persistent headaches and dizziness for the past week. What could be the cause?"

response, contexts = process_query(query)

# Show the query, each retrieved reference with its similarity score, and
# finally the model's answer.
print("\nQuery:", query)
print("\nRelevant Contexts:")
for i, ctx in enumerate(contexts, start=1):
    print(
        f"\nReference {i} (Similarity: {ctx['similarity']:.3f}):",
        f"Q: {ctx['question']}",
        f"A: {ctx['answer']}",
        sep="\n",
    )

print("\nGenerated Response:")
print(response)