## Settings

In [None]:
# Import necessary modules from the HuggingFace transformers library for NLP tasks and model handling
from transformers import ( 
    AutoTokenizer,                # For tokenizing input text for transformer models
    AutoModelForCausalLM,         # For loading causal language models (e.g., GPT-style)
    BitsAndBytesConfig,           # For quantization and memory-efficient model loading
    TrainingArguments,            # For specifying training hyperparameters
    AutoModel,                    # For loading generic transformer models
    Trainer,                      # For training transformer models
    pipeline,                     # For easy-to-use inference pipelines
    DataCollatorForSeq2Seq,       # For data collation in sequence-to-sequence tasks
    AutoModelForSeq2SeqLM         # For loading sequence-to-sequence language models (e.g., T5, BART)
)

import polars as pl                # High-performance DataFrame library, alternative to pandas
import warnings                    # For controlling warning messages
from tqdm import tqdm              # For progress bars in loops
import glob, os                    # For file and directory operations
from PyPDF2 import PdfReader       # For reading PDF files
from sklearn.model_selection import train_test_split  # For splitting datasets
import pandas as pd                # Data analysis and manipulation tool
import numpy as np                 # Numerical computing library
import torch                       # PyTorch for deep learning
import torch._dynamo               # For PyTorch optimization and compilation
from datasets import Dataset       # HuggingFace datasets library for handling datasets
import json                        # For JSON file operations
import faiss                       # For efficient similarity search and clustering of dense vectors
import re                          # For regular expressions
from peft import (                 # Parameter-Efficient Fine-Tuning (PEFT) utilities
    LoraConfig, 
    get_peft_model, 
    TaskType, 
    PeftModel, 
    prepare_model_for_kbit_training
)
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction  # For BLEU score calculation (NLP evaluation)
from rouge_score import rouge_scorer                                 # For ROUGE score calculation (NLP evaluation)
from pathlib import Path                                             # For object-oriented filesystem paths
from sentence_transformers import CrossEncoder                       # For cross-encoder models (sentence similarity)
from sklearn.metrics.pairwise import cosine_similarity as skl_cosine # For cosine similarity between vectors
from typing import List, Tuple                                       # For type hinting

# Function to split text into overlapping chunks for processing (e.g., for LLM context windows)
def chunk_text(text, max_length=500, overlap=50):
    """Split ``text`` into character-bounded chunks with a word-level overlap.

    Sentences are found naively by splitting on ". " and are greedily packed
    into the current chunk while it stays within ``max_length`` characters
    (counting the ". " re-appended after each sentence). When the next
    sentence does not fit, the current chunk is emitted and the next chunk is
    seeded with the last ``overlap`` words of the emitted one. A single
    sentence is never split, so a chunk can exceed ``max_length``.

    Args:
        text: Input text.
        max_length: Soft character budget per chunk.
        overlap: Number of trailing words carried into the next chunk;
            ``0`` or a negative value disables the overlap.

    Returns:
        List of stripped, non-empty chunk strings (``[]`` for blank text).
    """
    if not text or not text.strip():
        return []
    sentences = text.split(". ")  # Naive sentence splitting by period
    chunks = []
    current_chunk = ""
    for sent in sentences:
        # The sentence still fits: append it and move on
        if len(current_chunk) + len(sent) + 2 <= max_length:
            current_chunk += sent + ". "
            continue
        # Emit the finished chunk, but never an empty one (this happens when
        # the very first sentence is already longer than max_length)
        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        # `words[-0:]` would return *all* words, so overlap <= 0 must be special-cased
        overlap_words = current_chunk.split()[-overlap:] if overlap > 0 else []
        current_chunk = " ".join(overlap_words + [sent]) + ". "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks

# Function to compute cosine similarity between two vectors (e.g., sentence embeddings)
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b`` as a Python float.

    Each vector is divided by its L2 norm plus 1e-8 before the dot product, so a
    zero vector yields 0.0 instead of a division error or NaN.
    """
    eps = 1e-8
    unit_a, unit_b = (vec / (np.linalg.norm(vec) + eps) for vec in (a, b))
    return float(np.dot(unit_a, unit_b))

# Hugging Face id of the causal LLM used by this notebook.
# NOTE(review): the "8B" in the id suggests ~8B parameters, i.e. not lightweight; on a
# Tesla T4 it presumably only fits with the 4-bit quantization used later — confirm headroom.
MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"

# Notebook hygiene: silence every Python warning so cell output stays readable.
warnings.filterwarnings("ignore")

# If TorchDynamo (torch.compile) fails, fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

## Finetuning

In [None]:
# Disabled exploratory snippet kept as a bare (never executed) string: it records how the
# infection-related consultation document types were looked up in the master table.
'''
# (Optional) Example code for filtering document types related to infection
doc_type = pl.read_excel('datasets_folder/SNUHNOTE/1.0.1/document_type_level_mst.xlsx')
doc_type = doc_type.filter(
    (pl.col("doc_type_id") == "D008") & 
    (pl.col("mdfm_name").str.contains("감염"))
)
# mdfm_id = 41813, 41814, 42345 refers to 타과의뢰회신 (consultation requests & replies)
'''

# One joined (reply, request) frame per workbook; concatenated after the loop.
dfs = []

# Loop over the 7 workbooks DGNS_1.xlsx ... DGNS_7.xlsx
for i in tqdm(range(1, 8)):
    # Read one workbook into a Polars DataFrame
    df = pl.read_excel(f'datasets_folder/SNUHNOTE/1.0.1/4_DGNS/DGNS_{i}.xlsx')

    # --- df1: consultation REPLIES ---
    # Rows under one of the consultation forms (level_path contains 41813, 41814 or 42345) that:
    #   * are not a bare approval note ("approved." / "will approve." / "medication approved.");
    #     these are exact-string matches, so other whitespace/newline variants are NOT excluded;
    #   * do not mention an inquiry ("문의"), a referral ("의뢰") or a cancelled order ("취소된 처방"),
    #     which indicate request text rather than a reply;
    #   * contain at least one reply-style keyword: recommend/advise ("추천", "권장", "권고"),
    #     approve ("승인"), consider ("고려"), or the formal sentence ending "니다".
    # NOTE(review): "니다" matches almost any formal Korean sentence, so the last OR-group is very permissive.
    df1 = df.filter(
        ((pl.col("level_path").str.contains("41813"))|(pl.col("level_path").str.contains("41814"))|(pl.col("level_path").str.contains("42345"))) &
        ~(pl.col("content")=='승인하였습니다.\n') & ~(pl.col("content")=='승인하였습니다. \n') &  ~(pl.col("content")=='승인하였습니다. \n\n') & ~(pl.col("content")=='승인하였습니다.\n\n') &
        ~(pl.col("content")=='승인하였습니다.\n\n\n') & ~(pl.col("content")=='승인하였습니다. \n\n\n') &
        ~(pl.col("content")=='승인하겠습니다.\n') & ~(pl.col("content")=='승인하겠습니다. \n') &  ~(pl.col("content")=='승인하겠습니다. \n\n') & ~(pl.col("content")=='승인하겠습니다.\n\n') &
        ~(pl.col("content")=='승인하겠습니다.\n\n\n') & ~(pl.col("content")=='승인하겠습니다. \n\n\n') &
        ~(pl.col("content")=='투약 승인하였습니다.\n') & ~(pl.col("content")=='투약 승인하였습니다. \n') & ~(pl.col("content")=='투약 승인하였습니다. \n\n') &~(pl.col("content")=='투약 승인하였습니다.\n\n') &
        ~(pl.col("content").str.contains("문의")) & ~(pl.col("content").str.contains("의뢰")) &~(pl.col("content").str.contains("취소된 처방")) &
        ((pl.col("content").str.contains("추천"))|(pl.col("content").str.contains("권장"))|(pl.col("content").str.contains("승인"))|(pl.col("content").str.contains("고려"))|
        (pl.col("content").str.contains("니다")|(pl.col("content").str.contains("권고"))))
    )
    # Order each patient's reply rows chronologically
    df1 = df1.sort(["nid", "rec_dt_offset"])
    # One row per patient: date and level_path of the earliest reply row, plus all reply texts as a list.
    # NOTE(review): all reply rows of a patient in this workbook are merged, even if they belong to
    # different consultation episodes — confirm that is acceptable.
    df1 = df1.group_by("nid").agg([
        pl.col("rec_dt_offset").first().alias("rec_dt_offset"),
        pl.col("level_path").first().alias("level_path"),
        pl.col("content").implode().alias("content_list")
    ])
    # Join the reply texts (space-separated) into one string per patient; level_path is dropped here
    df1 = df1.with_columns(
        pl.col("content_list").list.join(" ").alias("content")
    ).select(["nid", "rec_dt_offset", "content"])

    # --- df2: consultation REQUESTS ---
    # Rows under the same consultation forms that mention an inquiry / discussion / referral
    # ("문의", "상의", "의뢰") but none of the reply-style phrases: recommend / approve / consider /
    # advise, "바랍니다" (please), thanks-for-the-referral phrases ("의뢰주셔서", "의뢰 감사"),
    # "분 이내" (within N minutes), "해당없음" (not applicable), or the formal ending "니다".
    # NOTE(review): excluding "니다" also drops polite request endings such as "...드립니다";
    # confirm this is intended, as it may discard many well-formed requests.
    df2 = df.filter(
        ((pl.col("level_path").str.contains("41813"))|(pl.col("level_path").str.contains("41814"))|(pl.col("level_path").str.contains("42345"))) &
        ((pl.col("content").str.contains("문의"))|(pl.col("content").str.contains("상의"))|(pl.col("content").str.contains("의뢰"))) &
        ~((pl.col("content").str.contains("추천"))|(pl.col("content").str.contains("권장"))|(pl.col("content").str.contains("승인"))|(pl.col("content").str.contains("고려"))|
        (pl.col("content").str.contains("바랍니다"))|(pl.col("content").str.contains("의뢰주셔서"))|(pl.col("content").str.contains("의뢰 감사"))|(pl.col("content").str.contains("권고"))|
        (pl.col("content").str.contains("분 이내"))|(pl.col("content").str.contains("해당없음"))|(pl.col("content").str.contains("니다")) )
    )

    # Collect each patient's request dates and texts as parallel lists
    df2 = df2.sort(["nid", "rec_dt_offset"])
    df2 = df2.group_by("nid").agg([
        pl.col("rec_dt_offset").alias("rec_dt_offset_list"),
        pl.col("content").alias("content_list")
    ])

    # Explode back to one row per request: (nid, df2_rec_dt_offset, cst), re-sorted for join_asof
    df2_long = (
        df2
        .select(["nid", "rec_dt_offset_list", "content_list"])
        .explode(["rec_dt_offset_list", "content_list"])
        .rename({"rec_dt_offset_list": "df2_rec_dt_offset",
                 "content_list":        "cst"})
        .sort(["nid", "df2_rec_dt_offset"])
    )

    # --- Pair each patient's reply with the request that preceded it ---
    # join_asof(strategy="backward", by="nid"): for each reply row in df1 (left), attach the latest
    # request in df2_long (right) whose date is <= the reply's first date; cst is null when the
    # patient has no such earlier request.
    final_df = (
        df1.sort(["nid", "rec_dt_offset"])
           .join_asof(
                df2_long,
                left_on="rec_dt_offset",
                right_on="df2_rec_dt_offset",
                by="nid",
                strategy="backward"
           )
           .select(["nid", "rec_dt_offset", "content", "cst"])
    )
    dfs.append(final_df)

# Stack the per-workbook results
x1 = pl.concat(dfs)
# If a patient appears in several workbooks, keep only their earliest (reply, request) pair
x1 = x1.sort("rec_dt_offset").group_by("nid").first()
# Keep only replies that could be matched to a preceding request
x1 = x1.filter(pl.col("cst").is_not_null())

# Pre-compiled signature matcher: from "IM" up to and including the next "배상"
# (DOTALL lets the span cross line breaks).
_IM_SIGNATURE_RE = re.compile(r'IM((?:(?!배상).)*?)배상', flags=re.DOTALL)


def remove_im_to_baesang(text):
    """Strip the replier's signature from a consultation reply.

    Every span that starts at "IM" and ends at the first following "배상"
    (both markers included) is deleted; text without such a span is returned unchanged.

    NOTE(review): any other uppercase "IM" in the text (e.g. an intramuscular route
    abbreviation) that is followed later by "배상" would also start a deletion —
    confirm this cannot occur in the notes.
    """
    return _IM_SIGNATURE_RE.sub('', text)

# Strip the replier signature from every consultation reply
x1 = x1.with_columns(content=pl.col('content').map_elements(remove_im_to_baesang))

# Fixed instruction header of every fine-tuning question; the attending physician's
# request text (cst) and the answer cue are appended per row.
_QA_INSTRUCTIONS = (
    "@@@ Task: \n"
    "You are an infectious-disease consultant.  Compose ONE brief consultation note as the [@@@ Answer] that obeys **all** rules 1~4 below:\n"
    "1. Respond including 3 categories below:\n"
    "   • (1) suspected / confirmed pathogen (if any)\n"
    "   • (2) recommended medications (if any)\n"
    "   • (3) additional labs / cultures (if any)\n"
    "2. You must NOT repeat any phrase or idea in each sentence.\n"
    "3. Do not copy ANY text from [@@@ Attending physician’s message].\n"
    "4. Write in complete sentences and ALWAYS finish each sentence with a period.\n\n"
    "@@@ Attending physician’s message\n"
)


def _build_qa_pair(record):
    """Map one (request, reply) row to a {"question", "answer"} fine-tuning pair."""
    return {
        "question": _QA_INSTRUCTIONS + f"{record['cst']}\n\n" + "@@@ Answer:",
        "answer": record["content"],
    }


# Question = instructions + attending physician's request; answer = consultant's reply
qa_pairs = [
    _build_qa_pair(record)
    for record in tqdm(x1.iter_rows(named=True), total=x1.height, desc="Generating QA pairs")
]

# Persist the (Korean) QA pairs for the translation step
with open("qa_pairs.json", "w", encoding="utf-8") as f:
    json.dump(qa_pairs, f, ensure_ascii=False, indent=2)

In [None]:
# Hugging Face id of the NLLB-200 (3.3B) translation model.
# Kept in its own variable on purpose: re-binding the shared MODEL_NAME here would make the
# fine-tuning and evaluation cells load this translation model instead of the LLM.
TRANSLATION_MODEL_NAME = "facebook/nllb-200-3.3B"

# Tokenizer with Korean (Hangul script) declared as the source language
tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_NAME, src_lang="kor_Hang")

# Translation model in eval mode (dropout off); it is only used for inference here
model = AutoModelForSeq2SeqLM.from_pretrained(TRANSLATION_MODEL_NAME).eval()

# Run on the GPU when one is visible, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def translate_ko2en(text: str, max_len: int = 512, num_beams: int = 6) -> str:
    """
    Translate Korean text to English with the module-level NLLB-200 ``model``,
    ``tokenizer`` and ``device``.

    Args:
        text (str): Korean input text. Inputs longer than ``max_len`` tokens are
            truncated, so the tail of a very long text is silently dropped.
        max_len (int): Maximum token length of the tokenized input and of the
            generated output.
        num_beams (int): Beam width for beam search (1 = greedy decoding).
            Previously accepted but ignored; it is now passed to ``generate``.
    Returns:
        str: English translation with special tokens removed and whitespace stripped.
    """
    # Tokenize the input text and move tensors to the model's device
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_len,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Forcing the first generated token to the English language code makes NLLB output English
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=max_len,
            num_beams=num_beams,
        )

    # Decode the single sequence, dropping special/language tokens
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    
# Source (Korean) and destination (English answers) QA-pair files
SRC_JSON = Path("qa_pairs.json")
DST_JSON = Path("qa_pairs_en.json")

# Whole-word "investment" (any case): a recurring mistranslation that should read "administration".
# The previous pattern `(?<=\S)investment(?=\S)` required non-space characters on BOTH sides, so it
# only matched "investment" glued inside another token and never the standalone word.
_INVESTMENT_RE = re.compile(r"\binvestment\b", flags=re.IGNORECASE)


def _fix_investment(text: str) -> str:
    """Replace whole-word "investment" with "administration", keeping a leading capital."""
    return _INVESTMENT_RE.sub(
        lambda m: "Administration" if m.group(0)[0].isupper() else "administration",
        text,
    )


# Load the original (Korean) QA pairs
with SRC_JSON.open("r", encoding="utf-8") as f:
    data = json.load(f)

# Only the 'answer' field is translated; 'question' (English instructions + Korean
# attending-physician message) is copied unchanged.
# NOTE(review): later cells call this file "English-translated", but the attending physician's
# message inside 'question' stays Korean — confirm this is intended.
translated = []
for item in tqdm(data, desc="Translating 'answer' fields"):
    new_item = item.copy()  # Keep the original item unchanged
    if isinstance(item.get("answer"), str):
        new_item["answer"] = _fix_investment(translate_ko2en(item["answer"]))
    translated.append(new_item)

# Save the translated QA pairs
with DST_JSON.open("w", encoding="utf-8") as f:
    json.dump(translated, f, ensure_ascii=False, indent=2)
print(f"Saved translated data ➜ {DST_JSON.resolve()}")

In [None]:
# 1. Settings for LoRA Fine-tuning

# Pin the LLM id for this cell: other cells (translation, MedCPT indexing) have re-bound the
# shared MODEL_NAME global, which would make this cell load the wrong checkpoint.
LLM_MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"

# Hugging Face access token, read from the HF_TOKEN environment variable instead of being
# hard-coded; None falls back to the locally cached `huggingface-cli login` credentials.
HF_TOKEN = os.environ.get("HF_TOKEN")

# 4-bit NF4 quantization with double quantization for memory-efficient loading
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # NOTE(review): T4 GPUs have no native bf16 — confirm speed/accuracy
    bnb_4bit_use_double_quant=True,         # Also quantize the quantization constants
)

# device_map="auto" lets accelerate place the quantized weights. The model is deliberately NOT
# moved afterwards with .to(device): bitsandbytes-quantized models do not support it, and it
# would break a multi-GPU placement.
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_cfg,
    token=HF_TOKEN,
)

# Prepare the quantized model for training (PEFT helper)
model = prepare_model_for_kbit_training(model)

# LoRA adapters on the attention projections
lora_cfg = LoraConfig(
    r=16,               # Rank of the LoRA update matrices
    lora_alpha=32,      # Scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)

# Load the tokenizer once (previously loaded twice with inconsistent auth arguments)
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, token=HF_TOKEN)
# Some causal-LM tokenizers have no pad token; reuse EOS so batches can be padded
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Kept for code that reads `device`; the model itself stays where device_map placed it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the LoRA parameters should be trainable
model.print_trainable_parameters()

# Allow TensorFloat-32 matmuls (independent of torch.compile, which the old check tested for)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# 2. Split the translated QA pairs: 90% for LoRA fine-tuning, 10% held out for evaluation

with open("qa_pairs_en.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Supervised fine-tuning format: prompt "Question: ...\n", completion " <answer>"
prompts = [f"Question: {sample['question']}\n" for sample in raw_data]
labels = [" " + str(sample["answer"]) for sample in raw_data]

# The evaluation must reuse this exact split (test_size=0.1, random_state=42) so the
# held-out 10% is never trained on.
train_p, test_p, train_l, test_l = train_test_split(prompts, labels, test_size=0.1, random_state=42)
train_dataset = Dataset.from_dict({"prompt": train_p, "completion": train_l})

# 3. Token budget per training example (prompt + completion)
# NOTE(review): the fixed instruction header is a large share of 256 tokens, leaving little
# room for the answer — consider raising it if GPU memory allows.
MAX_LEN = 256

def tokenize_fn(example, min_answer_tokens=32):
    """Tokenize one prompt/completion pair for causal-LM fine-tuning.

    Prompt and completion are concatenated and tokenized together. Label
    positions covering the prompt are set to -100, so the loss is computed only
    on the completion (answer) tokens.

    If the pair exceeds ``MAX_LEN`` tokens, the completion is truncated on the
    right, exactly as before. The exception is when the prompt alone would leave
    fewer than ``min_answer_tokens`` tokens for the answer: then the prompt is
    also trimmed from the *left*, so at least that many answer tokens survive.
    Previously every label became -100 in that case, so the example carried no
    training signal.

    Args:
        example: Mapping with string fields "prompt" and "completion".
        min_answer_tokens: Answer tokens to keep when the prompt is too long.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels" lists of equal
        length (at most ``MAX_LEN``).
    """
    full_ids = tokenizer(
        example["prompt"] + example["completion"], truncation=False, padding=False
    )["input_ids"]
    # Prompt length inside the joint sequence (capped in case tokens merge across the boundary)
    prompt_len = min(len(tokenizer(example["prompt"], truncation=False)["input_ids"]), len(full_ids))
    prompt_ids, answer_ids = full_ids[:prompt_len], full_ids[prompt_len:]

    if len(full_ids) > MAX_LEN:
        room = max(MAX_LEN - prompt_len, 0)
        keep_answer = min(len(answer_ids), max(room, min_answer_tokens), MAX_LEN)
        keep_prompt = MAX_LEN - keep_answer  # always <= prompt_len when overflowing
        # Positive start index (never `[-0:]`) so keep_prompt == 0 yields an empty prompt
        prompt_ids = prompt_ids[len(prompt_ids) - keep_prompt:]
        answer_ids = answer_ids[:keep_answer]

    input_ids = prompt_ids + answer_ids
    # -100 is ignored by the loss function
    labels = [-100] * len(prompt_ids) + answer_ids
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids), "labels": labels}

# Tokenize the training split (text columns are replaced by token columns)
train_tok = train_dataset.map(tokenize_fn, remove_columns=["prompt", "completion"])

# Drop examples whose labels are all -100 (answer fully truncated away): they contribute no
# supervised tokens, and a batch made only of such examples can give an undefined (NaN) loss.
n_before_filter = len(train_tok)
train_tok = train_tok.filter(lambda ex: any(tok != -100 for tok in ex["labels"]))
if len(train_tok) < n_before_filter:
    print(f"Dropped {n_before_filter - len(train_tok)} training examples with no answer tokens")

# Pads input_ids/attention_mask to the longest sequence in the batch and labels with -100
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")

# 4. Training arguments for LoRA fine-tuning

training_args = TrainingArguments(
    output_dir="lora_rag",              # Checkpoint directory
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    # NOTE(review): fp16 mixed precision while the 4-bit layers compute in bfloat16 — confirm this mix is intended
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_strategy="epoch",              # Checkpoint at the end of each epoch
    save_total_limit=2,                 # Keep only the last 2 checkpoints
    lr_scheduler_type="cosine",
    report_to="none",                   # No experiment tracker
    label_names=[],                     # No label columns declared to Trainer
    remove_unused_columns=False,        # Keep input_ids / attention_mask / labels as produced above
)

# 5. Train and save the LoRA adapter weights to "lora_rag_weight"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    data_collator=data_collator,
)

trainer.train()

# For a PEFT model this saves the adapter weights
trainer.save_model("lora_rag_weight")

# Training loss observed in the original run:
# Step  Training Loss
# 50    3.748400
# 100   1.965000
# 150   2.850100
# 200   1.830300
# 250   0.645100

## RAG

In [None]:
# Assumes ~100 UpToDate PDF review articles (infectious diseases and related topics)
# have already been downloaded into the "rag/" folder.

rag_folder = "rag/"
# Sorted so the CSV row order, and therefore chunk order and FAISS ids downstream, is
# reproducible: glob order depends on the filesystem.
pdf_files = sorted(glob.glob(os.path.join(rag_folder, "*.pdf")))
records = []  # One {"title", "main_text"} dict per successfully parsed PDF

for pdf_path in tqdm(pdf_files, desc="Parsing PDFs"):
    try:
        reader = PdfReader(pdf_path)
        # Pages without extractable text contribute "". Pages are joined with a newline:
        # plain concatenation glued the last word of one page to the first word of the next.
        text = "\n".join(page.extract_text() or "" for page in reader.pages).strip()
    except Exception as e:
        # Best-effort: an unreadable PDF is reported and skipped rather than aborting the run
        print(f"Failed to parse {pdf_path}: {e}")
        continue
    if text:
        records.append({
            "title": os.path.basename(pdf_path),
            "main_text": text,
        })

# Persist the extracted texts for the chunking / indexing cells
df = pd.DataFrame(records, columns=["title", "main_text"])
df.to_csv('review_articles.csv', index=False)

In [None]:
df = pd.read_csv('review_articles.csv')  # Article texts extracted from the PDFs

# Chunk every article: <= 500 characters per chunk, 10-word overlap.
# FAISS stores only row ids, so any cell that maps search results back to text must rebuild
# `all_chunks` with exactly these parameters.
all_chunks = [
    chunk
    for article in tqdm(df['main_text'])
    for chunk in chunk_text(article, max_length=500, overlap=10)
]
print(f"chunk total: {len(all_chunks)}")
if not all_chunks:
    raise ValueError("No text chunks were produced from review_articles.csv; nothing to index.")

# MedCPT encoder used as the embedder. Kept in its own variable so the LLM id in MODEL_NAME
# is not overwritten (later cells load the causal LLM from that name).
# NOTE(review): per the MedCPT model card, documents are meant to be encoded with the separate
# MedCPT-Article-Encoder and compared by inner product; here chunks use the query encoder and
# L2 distance — confirm this is intended.
MEDCPT_MODEL_NAME = "ncbi/MedCPT-Query-Encoder"
medcpt_tokenizer = AutoTokenizer.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model = AutoModel.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model.eval()
medcpt_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
medcpt_model = medcpt_model.to(medcpt_device)
BATCH_SIZE = 16  # Chunks encoded per forward pass

# Width of the [CLS] hidden state used as the embedding
embedding_dim = medcpt_model.config.hidden_size

# Exact (brute-force) L2 index
index = faiss.IndexFlatL2(embedding_dim)
chunk_embeddings_list = []

# Encode chunks in batches; row i of the index corresponds to all_chunks[i]
for i in tqdm(range(0, len(all_chunks), BATCH_SIZE)):
    batch_chunks = all_chunks[i:i + BATCH_SIZE]
    with torch.no_grad():
        encoded = medcpt_tokenizer(
            batch_chunks,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=512,  # Model's maximum input length
        )
        encoded = {k: v.to(medcpt_device) for k, v in encoded.items()}
        emb = medcpt_model(**encoded).last_hidden_state[:, 0, :]  # [CLS] token embedding
        chunk_embeddings_list.append(emb.cpu().numpy().astype("float32"))  # FAISS needs float32
chunk_embeddings = np.concatenate(chunk_embeddings_list, axis=0)
index.add(chunk_embeddings)

# Persist the index for the evaluation cell
faiss.write_index(index, "faiss_index.index")

## Evaluation

In [None]:
torch.cuda.empty_cache()  # Release cached GPU memory before loading the models

# FAISS index built by the indexing cell (row i <-> i-th chunk)
index = faiss.read_index("faiss_index.index")

# Reference articles the index was built from
df = pd.read_csv('review_articles.csv')

# Rebuild the chunk list with EXACTLY the parameters used when the index was built
# (max_length=500, overlap=10). The former overlap=50 produced a different chunk list,
# so FAISS ids pointed at the wrong texts.
all_chunks = []
for example in tqdm(df['main_text']):
    for c in chunk_text(example, max_length=500, overlap=10):
        all_chunks.append(c)
if index.ntotal != len(all_chunks):
    raise ValueError(
        f"FAISS index holds {index.ntotal} vectors but {len(all_chunks)} chunks were rebuilt; "
        "chunking parameters or review_articles.csv differ from those used to build the index."
    )

# MedCPT encoder for embedding queries (and chunks, for MMR)
MEDCPT_MODEL_NAME = "ncbi/MedCPT-Query-Encoder"
medcpt_tokenizer = AutoTokenizer.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model = AutoModel.from_pretrained(MEDCPT_MODEL_NAME)
medcpt_model.eval()
medcpt_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
medcpt_model = medcpt_model.to(medcpt_device)

# Pin the LLM id: MODEL_NAME may have been re-bound by earlier cells (e.g. to the MedCPT id),
# which would load the wrong model as a causal LM.
LLM_MODEL_NAME = "Intelligent-Internet/II-Medical-8B-1706"
# Hugging Face token from the environment; None falls back to the cached login
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the LLM in 4-bit NF4 (bitsandbytes), matching the fine-tuning configuration
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_cfg,
    token=HF_TOKEN,
)

# Attach the LoRA adapter trained in the fine-tuning cell
model = PeftModel.from_pretrained(model, "lora_rag_weight")

tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME, token=HF_TOKEN)

# Kept for code that reads `device`. The quantized model is NOT moved with .to(): device_map
# already placed it, and bitsandbytes models do not support .to().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Ensure the tokenizer has a pad token (reuse EOS)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def medcpt_embed(texts, batch_size=16):
    """
    Embed texts with the module-level MedCPT encoder.

    Each text is tokenized (truncated to 512 tokens) and represented by the
    final-layer hidden state of its first ([CLS]) token. Texts are encoded in
    padded batches, as when the FAISS index was built, rather than one forward
    pass per text; results match per-text encoding up to floating-point noise.

    Args:
        texts: Sequence of strings.
        batch_size: Number of texts per forward pass.

    Returns:
        float32 numpy array with one row per text. Empty input returns an empty
        ``(0, hidden_size)`` array instead of raising.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, medcpt_model.config.hidden_size), dtype="float32")
    embs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = medcpt_tokenizer(
                texts[start:start + batch_size],
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=512,
            )
            encoded = {k: v.to(medcpt_device) for k, v in encoded.items()}
            # [CLS] token embedding as the sentence representation
            embs.append(medcpt_model(**encoded).last_hidden_state[:, 0, :].cpu().numpy().astype("float32"))
    return np.concatenate(embs, axis=0)

class AdvancedRetriever:
    """
    Two-stage retriever over a fixed list of text chunks:
    1. Dense recall: the query is embedded with ``embed_fn`` and the
       ``recall_k`` nearest chunks are fetched from the FAISS index.
    2. Re-ranking: a sentence-transformers CrossEncoder scores every
       (query, chunk) pair and the ``rerank_k`` best are kept.
    3. Optional MMR (Maximal Marginal Relevance) pass over those chunks.

    The i-th vector in the FAISS index must correspond to ``all_chunks[i]``.
    """
    def __init__(
        self,
        faiss_index: faiss.Index,
        all_chunks: List[str],
        embed_fn,
        cross_encoder_name: str = "cross-encoder/ms-marco-MiniLM-L6-v2",
        device: str | None = None,
    ):
        """
        Args:
            faiss_index: Index whose row ids are positions in ``all_chunks``.
            all_chunks: Chunk texts aligned with the index rows.
            embed_fn: Callable mapping a list of strings to a 2-D float array
                (one row per string); used for the query and for MMR.
            cross_encoder_name: Hugging Face id of the re-ranking cross-encoder.
            device: Device for the cross-encoder; defaults to CUDA when available.
        """
        self.index        = faiss_index
        self.all_chunks   = all_chunks
        self.embed_fn     = embed_fn
        self.cross_enc    = CrossEncoder(
            cross_encoder_name,
            device=device or ("cuda" if torch.cuda.is_available() else "cpu")
        )

    def retrieve(
        self,
        query: str,
        recall_k: int = 50,
        rerank_k: int = 8,
        use_mmr: bool = True,
        lambda_mmr: float = 0.6,
    ) -> Tuple[List[str], List[int]]:
        """
        Retrieve up to ``rerank_k`` relevant chunks for ``query``.

        Returns:
            texts   – final context strings
            indices – their original indices in ``all_chunks``
            Both lists are empty when the index is empty or ``rerank_k <= 0``.
        """
        # FAISS pads results with id -1 when asked for more neighbours than it holds; -1 would
        # silently select the LAST chunk, so clamp k and drop padding ids below.
        k = min(recall_k, self.index.ntotal)
        if k <= 0 or rerank_k <= 0:
            return [], []

        # 1) Dense retrieval: top-k candidates from FAISS
        q_emb = self.embed_fn([query]).astype("float32")
        _, I = self.index.search(q_emb, k)
        cand_idx = [i for i in I[0].tolist() if i >= 0]
        if not cand_idx:
            return [], []
        cand_texts = [self.all_chunks[i] for i in cand_idx]

        # 2) Cross-encoder re-ranking: score each (query, candidate) pair, keep the best rerank_k
        pairs = [(query, text) for text in cand_texts]
        ce_scores = self.cross_enc.predict(pairs, convert_to_numpy=True)
        ranked = sorted(
            zip(cand_idx, cand_texts, ce_scores),
            key=lambda x: x[2],
            reverse=True,
        )[:rerank_k]
        idxs, texts, _ = zip(*ranked)

        # 3) Optional MMR ordering. k=rerank_k is >= the number of kept chunks, so every chunk
        #    is selected: this only changes their ORDER, it does not drop near-duplicates.
        if use_mmr and rerank_k > 1:
            embs  = self.embed_fn(list(texts))
            order = self._mmr(embs, q_emb[0], k=rerank_k, lamb=lambda_mmr)
            idxs  = [idxs[i] for i in order]
            texts = [texts[i] for i in order]
        return list(texts), list(idxs)

    @staticmethod
    def _mmr(doc_embs: np.ndarray, query_emb: np.ndarray,
             k: int = 20, lamb: float = 0.6) -> List[int]:
        """
        Greedy Maximal Marginal Relevance selection.

        Picks the document most similar to the query first, then repeatedly the
        document maximizing ``lamb * relevance - (1 - lamb) * max similarity to
        already selected documents``, until ``min(k, len(doc_embs))`` are chosen.

        Returns:
            Indices into ``doc_embs`` in selection order.
        """
        selected, cand = [], list(range(len(doc_embs)))

        # Cosine relevance of each document to the query
        rel = skl_cosine(doc_embs, query_emb.reshape(1, -1)).flatten()

        while len(selected) < min(k, len(doc_embs)):
            if not selected:
                sel = int(np.argmax(rel))
            else:
                # Redundancy: highest similarity to any already-selected document
                redund = skl_cosine(
                    doc_embs[cand],
                    doc_embs[selected]
                ).max(axis=1)
                mmr = lamb * rel[cand] - (1 - lamb) * redund
                sel = cand[int(np.argmax(mmr))]
            selected.append(sel)
            cand.remove(sel)
        return selected

class RAGGenerator:
    """
    Answers a query with the LLM, grounded on chunks returned by an AdvancedRetriever.

    The retrieved chunks are inlined as "[Doc i]" evidence blocks ahead of an
    instruction and the query, and the LLM continues the prompt after the
    "@@@ Answer" cue.
    """
    def __init__(self, retriever: AdvancedRetriever, llm, tokenizer):
        """
        Args:
            retriever: Object providing ``retrieve(query, recall_k, rerank_k)``.
            llm: Model exposing ``generate`` and ``device``.
            tokenizer: Tokenizer matching ``llm``; its pad token falls back to EOS.
        """
        self.ret   = retriever
        self.llm   = llm
        self.tok   = tokenizer
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token

    @torch.inference_mode()
    def __call__(
        self,
        query: str,
        recall_k: int = 50,
        rerank_k: int = 8,
        max_new_tokens: int = 128,
        top_p: float = 0.4,
        temperature: float = 0.1,
    ) -> str:
        """
        Retrieve evidence for ``query`` and generate an answer.

        Args:
            query: Question text (used both for retrieval and in the prompt).
            recall_k / rerank_k: Passed to the retriever.
            max_new_tokens, top_p, temperature: Generation settings.

        Returns:
            The newly generated text (special tokens removed), cut to what follows
            the last "@@@ Answer" marker the model may have echoed, whitespace-stripped.
        """
        ctx_texts, _ = self.ret.retrieve(query, recall_k, rerank_k)
        evidence = "\n\n".join(f"[Doc {i+1}]\n{txt}" for i, txt in enumerate(ctx_texts))
        prompt   = (
            "@@@ Evidence\n"
            f"{evidence}\n\n"
            "@@@ Instruction: Base your [@@@ Answer] ONLY on the evidence above. "
            "Do NOT copy phrases verbatim and keep it under 200 characters.\n\n"
            f"Question: {query}\n@@@ Answer"
        )
        # NOTE(review): with the tokenizer's default (right-side) truncation, a prompt longer than
        # 4096 tokens would lose the question and the answer cue at its end.
        inputs = self.tok(
            prompt, return_tensors="pt", truncation=True, max_length=4096
        ).to(self.llm.device)
        # Low-temperature nucleus sampling: near-greedy, but not reproducible without a fixed seed
        output = self.llm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tok.eos_token_id,
            pad_token_id=self.tok.eos_token_id,
            no_repeat_ngram_size=3,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.15,
            do_sample=True,
        )
        # Decode only the newly generated tokens. Decoding the full sequence and splitting on the
        # marker returned the whole prompt (evidence included) whenever truncation removed the marker.
        prompt_len = inputs["input_ids"].shape[1]
        answer = self.tok.decode(output[0][prompt_len:], skip_special_tokens=True)
        return answer.split("@@@ Answer")[-1].strip()

# Two-stage retriever over the FAISS index and the chunk list rebuilt in this cell, using
# medcpt_embed for query/MMR embeddings; cross-encoder name and device use the class defaults.
adv_retriever = AdvancedRetriever(
    faiss_index=index,
    all_chunks=all_chunks,
    embed_fn=medcpt_embed,
)

# RAG generator: retrieval + prompting of the LoRA-adapted LLM (`model`) with its `tokenizer`
rag_gen = RAGGenerator(
    retriever=adv_retriever,
    llm=model,
    tokenizer=tokenizer,
)

In [None]:
import ast  # literal_eval: safe parsing of dict-looking reference answers (replaces eval)

torch.cuda.empty_cache()  # Release cached GPU memory before generation

# Load the QA pairs
with open("qa_pairs_en.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

prompts = [f"Question: {sample['question']}" for sample in raw_data]
labels = [" " + str(sample["answer"]) for sample in raw_data]

# Must reproduce the fine-tuning split exactly (test_size=0.1, random_state=42). The former
# random_state=1 drew a different 10%, so examples seen in training leaked into this "test"
# set and inflated every metric. (The split indices depend only on the dataset size and seed,
# not on the prompt text.)
train_p, test_p, train_l, test_l = train_test_split(prompts, labels, test_size=0.1, random_state=42)
train_dataset = pd.DataFrame({"prompt": train_p, "completion": train_l})


def _reference_text(ref):
    """Return the reference string to score against.

    If the stripped reference parses as a non-empty Python dict literal, its last
    value (as str) is used; otherwise the stripped string itself. ``ast.literal_eval``
    replaces ``eval``: these strings are model-translated data and must never be executed.
    """
    text = ref.strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return text
    if isinstance(parsed, dict) and parsed:
        return str(list(parsed.values())[-1])
    return text


# Run the RAG + fine-tuned model on every held-out question and collect the reference answers.
# NOTE(review): training prompts ended with "\n" and had no evidence block, and the RAG prompt
# prepends another "Question: " — the inference format differs from training; confirm intended.
gen_ans, ref_ans, qs = [], [], []
for p, ref in tqdm(list(zip(test_p, test_l)), desc="Testing (RAG)"):
    gen_ans.append(rag_gen(p))
    ref_ans.append(_reference_text(ref))
    qs.append(p)

# Function to cut the generated answer to the last period for cleaner output
def cut_to_last_period(text):
    """Trim *text* back to the end of its last complete sentence.

    A period counts as a sentence terminator only when it is not immediately
    followed by a digit. Decimal points such as ``2.5 mg`` are therefore not
    mistaken for sentence ends; a plain ``rfind('.')`` would cut
    ``"Dose is 2.5 mg"`` down to ``"Dose is 2."``.

    Args:
        text: Generated answer string.

    Returns:
        The text up to and including the last sentence-ending period, with
        surrounding whitespace stripped. If there is no such period, the whole
        text is returned, stripped.
    """
    last_end = -1
    for match in re.finditer(r"\.(?!\d)", text):
        last_end = match.end()
    if last_end == -1:
        return text.strip()
    return text[:last_end].strip()

# Trim every generated answer back to its last complete sentence
gen_ans = [cut_to_last_period(ans) for ans in gen_ans]

# Embed the generated and reference answers using MedCPT embeddings.
# np.asarray lets the arithmetic below work on array or nested-list output.
# NOTE(review): assumes medcpt_embed returns one row per input text (2-D) — confirm.
gen_embs = np.asarray(medcpt_embed(gen_ans))
ref_embs = np.asarray(medcpt_embed(ref_ans))

# Row-wise cosine similarity between each generated answer and its reference.
# A row where either embedding has zero norm is scored 0.0 instead of NaN, so a
# single degenerate embedding cannot turn the reported mean into NaN.
num   = (gen_embs * ref_embs).sum(axis=1)
denom = np.linalg.norm(gen_embs, axis=1) * np.linalg.norm(ref_embs, axis=1)
sim   = np.divide(num, denom, out=np.zeros(num.shape, dtype=np.float64), where=denom > 0)
mean_sim = float(np.mean(sim)) if sim.size else 0.0

# BLEU expects tokenized references (a list of alternatives per sample) and hypotheses
bleu_refs = [[ref.split()] for ref in ref_ans]
bleu_hyps = [gen.split() for gen in gen_ans]
bleu_score = corpus_bleu(bleu_refs, bleu_hyps, smoothing_function=SmoothingFunction().method1)

# ROUGE-L (longest common subsequence) F1 per pair, then averaged
rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
rouge_l_f1s = [rouge.score(ref, gen)['rougeL'].fmeasure for ref, gen in zip(ref_ans, gen_ans)]
rouge_l_f1 = sum(rouge_l_f1s) / len(rouge_l_f1s) if rouge_l_f1s else 0.0

# Print evaluation metrics
print(f"\nTest set mean cosine similarity: {mean_sim:.4f}")
print(f"Test set BLEU: {bleu_score:.4f}")
print(f"Test set ROUGE-LCS F1: {rouge_l_f1:.4f}")

# Print a few example QAs for inspection
for q, ref, gen, s in list(zip(qs, ref_ans, gen_ans, sim))[:3]:
    print(f"\nQ: {q}\nReference: {ref}\nGenerated: {gen}\nCosine similarity: {s:.4f}")

'''
<Example of Reply>
Reference: (1) Antibiotics are recommended to be maintained as they are, and adjusted according to the results of culture. (2) Esophageal candidasis is also suspected, and fluconazole dosage is recommended to be changed to 400 mg on day 1, then 200 to 400 mg once daily as a normal new functional baseline. Currently, after application of PTGBD, a reduction in inflammatory markers and a fever spike is not observed, but a suppressed RUQ tenderness is observed.
Generated: : #. There is no specific recommendation for empirical antibiotic therapy because there is limited information about the previous medical condition and current status of the patient who has already recovered from the asymptomatic stage of the past infection until now, but only empirically administered vancomycin and meropenem were given up to the time of confirmation of the BAL culture result, so please continue administration of vanco + mero according to the doctor's discretion until the recovery of the lung function is completed. If you want to change the medication, we recommend confirming the susceptibility test of the cultured bacteria before changing the drug.
'''