In [None]:

import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
import faiss
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Set paths
PREPROCESSED_DIR = Path("./experiments/preprocessed")
MODEL_DIR = Path("./experiments/models/llm")
# Create the model output directory up front; exist_ok makes a re-run harmless.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Seed the RNGs of the libraries imported above so that a fresh
# Restart & Run All is repeatable (later cells train with dropout enabled,
# which draws from torch's global RNG).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Synthetic medical documents (the retrieval corpus)
documents = [
    "Patient has fever and cough, suggest chest X-ray",
    "MRI shows lesion in frontal lobe",
    "Blood pressure elevated, patient prescribed ACE inhibitor",
    "Patient has shortness of breath and low oxygen saturation",
    "CT scan shows pulmonary embolism"
]

# Corresponding IDs: doc_ids[i] is the position of documents[i]
doc_ids = np.arange(len(documents))

# Synthetic queries used to exercise retrieval and generation
queries = [
    "What does the MRI indicate?",
    "Recommend treatment for high BP",
    "Explain the X-ray findings"
]

print("Synthetic medical documents and queries created.")


In [None]:

# Sentence-embedding model; reused later to embed queries
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Encode the whole corpus in one call, returned as a NumPy array
doc_embeddings = embed_model.encode(
    documents,
    convert_to_numpy=True,
)

# Flat L2 index whose dimensionality is taken from axis 1 of the embeddings
embedding_dim = doc_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(doc_embeddings)

# index.ntotal reports how many vectors the index now holds
print(f"FAISS index built with {index.ntotal} documents.")


In [None]:

def retrieve_documents(query, top_k=2):
    """Return the documents nearest to ``query`` in the FAISS index.

    Args:
        query: Query text; it is embedded with the module-level ``embed_model``.
        top_k: Maximum number of documents to return (default 2).

    Returns:
        A list of at most ``top_k`` entries of ``documents``, in the order the
        index reports them. The list is empty when ``top_k`` is not positive
        or the index holds no vectors.
    """
    # Never request more neighbours than are indexed: FAISS pads the missing
    # result slots with -1, and documents[-1] would silently hand back the
    # last document as a bogus extra hit.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    query_emb = embed_model.encode([query], convert_to_numpy=True)
    _distances, indices = index.search(query_emb, k)
    # indices has one row per query; a single query was sent, so use row 0.
    # Negative ids are still filtered as a guard against padded results.
    return [documents[i] for i in indices[0] if i >= 0]

# Smoke-test retrieval: show the retrieved documents for every sample query
for q in queries:
    hits = retrieve_documents(q)
    print(f"\nQuery: {q}")
    # Number the hits from 1 for display
    for rank, doc in enumerate(hits, start=1):
        print(f"Doc {rank}: {doc}")


In [None]:

# Tokenizer and seq2seq model, both loaded from the same Flan-T5 checkpoint
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# Retrieval-augmented generation: for each query, join the retrieved
# documents into a context string, build the prompt, and decode one answer.
for q in queries:
    context = " ".join(retrieve_documents(q))
    prompt = f"Question: {q} Context: {context} Answer:"

    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded, max_length=50)
    # Only the first returned sequence is decoded
    answer = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"\nQuery: {q}")
    print(f"Generated Answer: {answer}")


In [None]:

from peft import LoraConfig, get_peft_model, TaskType

# LoRA configuration: rank-8 adapters on the modules named "q" and "v"
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
)

# Wrap Flan-T5 model with LoRA.
# NOTE: PEFT injects the adapter layers into `model` itself, so `model` and
# `lora_model` share the same underlying modules, including train/eval mode.
# NOTE(review): re-running this cell wraps an already-adapted model again;
# reload `model` first if the cell has to be executed twice.
lora_model = get_peft_model(model, lora_config)
lora_model.train()  # set to training mode (enables LoRA dropout)

# Dummy fine-tuning loop (for demonstration).
# Hand the optimizer only the parameters that require gradients instead of
# every parameter of the wrapped model, frozen ones included.
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
for epoch in range(1):
    for q in queries:
        retrieved_docs = retrieve_documents(q)
        context = " ".join(retrieved_docs)
        input_text = f"Question: {q} Context: {context} Answer:"
        # Same placeholder target for every query: this is a wiring demo only
        target_text = "Placeholder answer"

        inputs = tokenizer(input_text, return_tensors="pt")
        labels = tokenizer(target_text, return_tensors="pt").input_ids
        # Passing labels makes the forward pass return the training loss
        outputs = lora_model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"Epoch {epoch}, Query: {q}, Loss: {loss.item():.4f}")

# Switch back to inference mode: otherwise dropout stays active and any later
# generate() call on the shared model would be non-deterministic.
lora_model.eval()

print("LoRA fine-tuning demo completed.")
