In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
!pip install sentence-transformers faiss-cpu PyPDF2 tqdm

In [2]:
# ============================================================================
# CAG INFERENCE - LOAD EXISTING CACHE AND QUERY
# Use this when your cache folder already exists
# ============================================================================

import torch
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from unsloth import FastLanguageModel
from transformers import TextStreamer
import pickle
from pathlib import Path
from typing import List, Dict, Tuple
import re

print("="*80)
print("CAG INFERENCE MODE - Loading Saved Cache")
print("="*80 + "\n")

# ============================================================================
# STEP 1: CONFIGURE YOUR PATHS (EDIT THESE)
# ============================================================================

# Path to your fine-tuned model
MODEL_PATH = "/content/drive/MyDrive/MED-MCP/TRAIN-LLM/medical_triage_lora"

# Path to your existing cache directory (folder that contains cache files)
CACHE_DIR = "/content/drive/MyDrive/MED-MCP/CAG/medical_cache_large"

# Cache filename (if you used default, keep as is)
CACHE_FILENAME = "medical_cache_large.pkl"

# FAISS index filename (if you used default, keep as is)
FAISS_INDEX_FILENAME = "faiss_index_large.bin"

print(f"Model Path: {MODEL_PATH}")
print(f"Cache Directory: {CACHE_DIR}")
print(f"Cache File: {CACHE_FILENAME}")
print(f"FAISS Index: {FAISS_INDEX_FILENAME}\n")

# ============================================================================
# STEP 2: LOAD MODEL AND EMBEDDING MODEL
# ============================================================================

print("Loading fine-tuned medical triage model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
print("✓ Model loaded successfully!\n")

print("Loading embedding model...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("✓ Embedding model loaded successfully!\n")

# ============================================================================
# STEP 3: LOAD CACHE FROM DISK
# ============================================================================

cache_dir = Path(CACHE_DIR)
cache_path = cache_dir / CACHE_FILENAME
index_path = cache_dir / FAISS_INDEX_FILENAME

# Check if cache exists
if not cache_path.exists():
    raise FileNotFoundError(f"Cache file not found: {cache_path}")
if not index_path.exists():
    raise FileNotFoundError(f"FAISS index not found: {index_path}")

print(f"Loading cache from {cache_path}...")
with open(cache_path, 'rb') as f:
    cache_data = pickle.load(f)

chunks = cache_data['chunks']
chunk_metadata = cache_data['metadata']
chunk_size = cache_data.get('chunk_size', 500)
chunk_overlap = cache_data.get('chunk_overlap', 50)

print(f"Loading FAISS index from {index_path}...")
index = faiss.read_index(str(index_path))

print("\n" + "="*80)
print("CACHE LOADED SUCCESSFULLY!")
print("="*80)
print(f"✓ Total chunks: {len(chunks)}")
print(f"✓ Total documents: {len(set(m['filename'] for m in chunk_metadata))}")
print(f"✓ FAISS index size: {index.ntotal} vectors")
print(f"✓ Chunk size: {chunk_size} characters")
print(f"✓ Chunk overlap: {chunk_overlap} characters")

# Show sample documents
unique_files = list(set(m['filename'] for m in chunk_metadata))
print(f"\nDocuments in cache (showing first 10):")
for i, filename in enumerate(unique_files[:10], 1):
    print(f"  {i}. {filename}")
if len(unique_files) > 10:
    print(f"  ... and {len(unique_files) - 10} more")

print("="*80 + "\n")

# ============================================================================
# STEP 4: DEFINE INFERENCE FUNCTIONS
# ============================================================================

def retrieve_context(query: str, top_k: int = 5) -> List[Tuple[str, Dict, float]]:
    """
    Retrieve most relevant chunks from cache

    Args:
        query: Query text (patient case description)
        top_k: Number of relevant chunks to retrieve

    Returns:
        List of (chunk_text, metadata, distance_score) tuples
    """
    if index is None or len(chunks) == 0:
        print("Warning: No documents in cache")
        return []

    # Embed the query
    query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype('float32')

    # Search FAISS index
    search_k = min(top_k * 3, len(chunks))
    distances, indices = index.search(query_embedding, search_k)

    # Apply diversity filtering (avoid too many chunks from same document)
    results = []
    seen_sources = set()

    for idx, dist in zip(indices[0], distances[0]):
        if idx < len(chunks):
            chunk = chunks[idx]
            meta = chunk_metadata[idx]
            source = meta['filename']

            # Encourage diversity
            if source in seen_sources and len(results) >= top_k // 2:
                continue

            results.append((chunk, meta, float(dist)))
            seen_sources.add(source)

            if len(results) >= top_k:
                break

    return results


def generate_with_context(
    user_query: str,
    top_k: int = 5,
    max_new_tokens: int = 1200,
    temperature: float = 1.5,
    max_context_length: int = 1800,
    verbose: bool = True
):
    """
    Generate medical triage response with retrieved context

    Args:
        user_query: Natural language patient case description
        top_k: Number of context chunks to retrieve
        max_new_tokens: Maximum tokens to generate
        temperature: Generation temperature (higher = more creative)
        max_context_length: Maximum characters per context chunk
        verbose: Print detailed output

    Returns:
        Generated response text
    """
    # Retrieve relevant context
    if verbose:
        print("Retrieving relevant medical context...")

    context_results = retrieve_context(user_query, top_k=top_k)

    # Format context
    context_text = ""
    if context_results:
        context_text = "\n\nRelevant Medical Guidelines:\n"
        for i, (chunk, meta, score) in enumerate(context_results, 1):
            source = meta.get('filename', 'Unknown')
            chunk_id = meta.get('chunk_id', '')

            # Truncate long chunks
            display_chunk = chunk[:max_context_length]
            if len(chunk) > max_context_length:
                display_chunk += "..."

            context_text += f"\n[Source {i}: {source} (section {chunk_id})]\n{display_chunk}\n"

    # Construct final prompt
    full_prompt = context_text + "\n" + user_query + "\n\nBased on the medical guidelines above, provide specialist referral recommendation with clinical reasoning."

    # Create messages

    # messages = [
    #     {
    #         "role": "system",
    #         "content": "You are a medical triage specialist. Analyze patient symptoms using the provided medical guidelines and recommend appropriate specialist referrals with urgency levels and clinical reasoning."
    #     },


    messages = [
    {
        "role": "system",
        "content": """You are a medical triage specialist.

        Urgency Levels:
        - EMERGENCY: Life-threatening, needs immediate intervention (minutes)
        - URGENT: Serious, needs evaluation within 24 hours
        - ROUTINE: Non-urgent, can wait days-weeks

        Red flags for EMERGENCY:
        - Worst headache of life
        - Sudden neurological deficits
        - Chest pain with cardiac symptoms
        - Active bleeding
        - Loss of consciousness
        - Severe difficulty breathing"""
        },

        {
            "role": "user",
            "content": full_prompt
        }
    ]

    # Tokenize
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    # Print case info
    if verbose:
        print("="*80)
        print("PATIENT CASE:")
        print("-"*80)
        print(user_query)
        print("-"*80)

        if context_results:
            print("RETRIEVED CONTEXT:")
            print("-"*80)
            for i, (_, meta, score) in enumerate(context_results, 1):
                source = meta.get('filename', 'Unknown')
                chunk_id = meta.get('chunk_id', 0)
                print(f"  {i}. {source} [chunk {chunk_id}] (distance: {score:.4f})")
            print("-"*80)

        print("MODEL RESPONSE:")
        print("-"*80)

    # Generate response
    text_streamer = TextStreamer(tokenizer, skip_prompt=True)

    output = model.generate(
        input_ids=inputs,
        streamer=text_streamer if verbose else None,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        temperature=temperature,
        min_p=0.1
    )

    if verbose:
        print("="*80 + "\n")

    # Return generated text
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response




Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
CAG INFERENCE MODE - Loading Saved Cache

Model Path: /content/drive/MyDrive/MED-MCP/TRAIN-LLM/medical_triage_lora
Cache Directory: /content/drive/MyDrive/MED-MCP/CAG/medical_cache_large
Cache File: medical_cache_large.pkl
FAISS Index: faiss_index_large.bin

Loading fine-tuned medical triage model...
==((====))==  Unsloth 2025.12.8: Fast Llama patching. Transformers: 4.56.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

Unsloth 2025.12.8 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


✓ Model loaded successfully!

Loading embedding model...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✓ Embedding model loaded successfully!

Loading cache from /content/drive/MyDrive/MED-MCP/CAG/medical_cache_large/medical_cache_large.pkl...
Loading FAISS index from /content/drive/MyDrive/MED-MCP/CAG/medical_cache_large/faiss_index_large.bin...

CACHE LOADED SUCCESSFULLY!
✓ Total chunks: 3006
✓ Total documents: 2
✓ FAISS index size: 3006 vectors
✓ Chunk size: 500 characters
✓ Chunk overlap: 50 characters

Documents in cache (showing first 10):
  1. swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf
  2. disease_n_treatments.pdf



In [3]:
# # ============================================================================
# # STEP 5: QUICK TEST QUERY (Optional - Run to verify system works)
# # ============================================================================

# print("Running test query to verify system...\n")

# test_query = "55-year-old male with chest pain radiating to left arm, sweating, and shortness of breath for 3 days. Patient is diabetic and smoker. Pain is 7/10."

# generate_with_context(
#     user_query=test_query,
#     top_k=5,
#     verbose=True
# )

# print("\n" + "="*80)
# print("✓ SYSTEM READY FOR INFERENCE!")
# print("="*80)
# print("\nYou can now use the generate_with_context() function for new queries.")
# print("Example usage shown below...")

Running test query to verify system...

Retrieving relevant medical context...


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


PATIENT CASE:
--------------------------------------------------------------------------------
55-year-old male with chest pain radiating to left arm, sweating, and shortness of breath for 3 days. Patient is diabetic and smoker. Pain is 7/10.
--------------------------------------------------------------------------------
RETRIEVED CONTEXT:
--------------------------------------------------------------------------------
  1. swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf [chunk 871] (distance: 0.7069)
  2. disease_n_treatments.pdf [chunk 496] (distance: 0.8940)
--------------------------------------------------------------------------------
MODEL RESPONSE:
--------------------------------------------------------------------------------
**Specialist Referral Assessment**

**Primary Specialty:** Emergency Medicine
**Urgency Level:** Urgent
**Confidence Score:** 0.85

**Secondary Specialty:** Cardiology

**Critical Questions to Ask:**
• Has patient ever experienced cardiac symp

In [4]:
# ============================================================================
# STEP 6: YOUR CUSTOM QUERIES (EDIT THIS SECTION)
# ============================================================================

print("\n\n" + "="*80)
print("CUSTOM QUERY EXAMPLE")
print("="*80 + "\n")


USER_QUERY = """
I have a 42-year-old female patient presenting with severe headache,
visual disturbances, and nausea. She has a history of hypertension.
The headache started suddenly 2 hours ago and is the worst she's ever had.
She rates it 9/10. She also mentions neck stiffness.
"""

# Generate response
generate_with_context(
    user_query=USER_QUERY.strip(),
    top_k=5,
    max_new_tokens=1024,
    temperature=1.5,
    verbose=True
)



CUSTOM QUERY EXAMPLE

Retrieving relevant medical context...
PATIENT CASE:
--------------------------------------------------------------------------------
I have a 42-year-old female patient presenting with severe headache, 
visual disturbances, and nausea. She has a history of hypertension. 
The headache started suddenly 2 hours ago and is the worst she's ever had. 
She rates it 9/10. She also mentions neck stiffness.
--------------------------------------------------------------------------------
RETRIEVED CONTEXT:
--------------------------------------------------------------------------------
  1. swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf [chunk 140] (distance: 0.8715)
  2. swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf [chunk 146] (distance: 0.9457)
  3. disease_n_treatments.pdf [chunk 1754] (distance: 0.9617)
--------------------------------------------------------------------------------
MODEL RESPONSE:
----------------------------------------------

"system\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\nYou are a medical triage specialist. Analyze patient symptoms using the provided medical guidelines and recommend appropriate specialist referrals with urgency levels and clinical reasoning.user\n\n\n\nRelevant Medical Guidelines:\n\n[Source 1: swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf (section 140)]\nng condition that might cause secondary headaches. Look for red-flag symptoms and signs, and/or conditions, and refer the patient to a hospital for further assessment: ■Focal neurological signs (e.g., motor, sensory, visual disturbances, loss of balance) ■Consciousness ■Fever and chills ■Seizures ■Nuchal rigidity or other meningism ■Papilloedema, pre-retinal or retinal hemorrhage ■History of bleeding diathesis, hypercoagulable state, cancer, HIV/AIDS, autoimmune disorders, illicit drug abusers P\n\n[Source 2: swz-mn-78-01-guideline-2012-eng-swaziland-stg-booklet.pdf (section 146)]\n■Fever and chi