In [1]:
import os
import json
import faiss
import torch
import xml.etree.ElementTree as ET
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from PIL import Image
from pathlib import Path
import requests
from io import BytesIO
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Step 1: Load MedQuAD XML data
chunks = []
data_root = Path("/Users/mahed/VsProjects/MedQuAD")  # Adjust if needed

for subfolder in data_root.glob("*_QA"):
    for file in subfolder.glob("*.xml"):
        try:
            tree = ET.parse(file)
            root = tree.getroot()

            q_el = root.find(".//Question")
            a_el = root.find(".//Answer")

            if q_el is not None and a_el is not None:
                question = q_el.text.strip() if q_el.text else ""
                answer = a_el.text.strip() if a_el.text else ""

                if question and answer:
                    chunks.append({
                        "text": f"Q: {question}\nA: {answer}",
                        "meta": {
                            "source": subfolder.name,
                            "file": file.name
                        }
                    })
        except Exception as e:
            print(f"Error processing {file}: {e}")

In [3]:
# Sanity-check the load: fail with a clear message (instead of an opaque
# IndexError on chunks[0]) when no QA pairs were found, e.g. a wrong data_root.
assert chunks, f"No MedQuAD QA pairs loaded from {data_root}; check the dataset path."
print(len(chunks))
print(chunks[0])

5394
{'text': 'Q: What is (are) keratoderma with woolly hair ?\nA: Keratoderma with woolly hair is a group of related conditions that affect the skin and hair and in many cases increase the risk of potentially life-threatening heart problems. People with these conditions have hair that is unusually coarse, dry, fine, and tightly curled. In some cases, the hair is also sparse. The woolly hair texture typically affects only scalp hair and is present from birth. Starting early in life, affected individuals also develop palmoplantar keratoderma, a condition that causes skin on the palms of the hands and the soles of the feet to become thick, scaly, and calloused.  Cardiomyopathy, which is a disease of the heart muscle, is a life-threatening health problem that can develop in people with keratoderma with woolly hair. Unlike the other features of this condition, signs and symptoms of cardiomyopathy may not appear until adolescence or later. Complications of cardiomyopathy can include an abno

In [4]:
# Step 2: Load the CLIP embedding model and its processor.
# Both are pulled from one checkpoint name so the preprocessing always matches the model.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# Step 3: Encode MedQuAD text with CLIP
# Texts are encoded in batches under torch.no_grad(): this is pure inference, so no
# autograd graph is needed, and batching avoids one forward pass per chunk.
# NOTE: CLIP's text encoder has a short context (77 tokens for this checkpoint),
# so truncation=True keeps only the question and the start of each answer.
TEXT_BATCH_SIZE = 64
clip_text_embeddings = []
with torch.no_grad():
    for start in range(0, len(chunks), TEXT_BATCH_SIZE):
        batch_texts = [chunk["text"] for chunk in chunks[start:start + TEXT_BATCH_SIZE]]
        inputs = clip_processor(text=batch_texts, return_tensors="pt", padding=True, truncation=True)
        batch_features = clip_model.get_text_features(**inputs).cpu().numpy()
        # extend() over a 2-D array appends one 1-D embedding per chunk, as before.
        clip_text_embeddings.extend(batch_features)

In [None]:
# Build the FAISS index over the MedQuAD text embeddings.
# CLIP embeddings are meant to be compared by cosine similarity, but
# get_text_features returns unnormalised vectors. L2-normalising them first makes
# IndexFlatL2 rank neighbours by cosine (||a - b||^2 = 2 - 2*cos(a, b) for unit vectors).
# FAISS also requires a C-contiguous float32 matrix.
clip_text_embeddings = np.ascontiguousarray(np.asarray(clip_text_embeddings, dtype="float32"))
if clip_text_embeddings.ndim != 2 or len(clip_text_embeddings) == 0:
    raise ValueError("No text embeddings to index; run the MedQuAD loading and encoding cells first.")
faiss.normalize_L2(clip_text_embeddings)  # in place; idempotent if the cell is re-run
text_index = faiss.IndexFlatL2(clip_text_embeddings.shape[1])
text_index.add(clip_text_embeddings)

In [None]:
# Step 4: Load DermNet dataset and embed sample images
# NOTE(review): confirm the Hub id "dermnet" resolves with your `datasets` version;
# community datasets usually need an "owner/name" id.
N_IMAGE_SAMPLES = 100  # limit for speed/testing; set to None to embed the whole split
image_chunks = []
image_embeddings = []
dermnet = load_dataset("dermnet", split="train")

# Clamp the sample count so a split shorter than N_IMAGE_SAMPLES does not raise IndexError.
n_images = len(dermnet) if N_IMAGE_SAMPLES is None else min(N_IMAGE_SAMPLES, len(dermnet))
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(n_images):
        item = dermnet[i]
        image = item["image"]
        label = item["label"]
        inputs = clip_processor(images=image, return_tensors="pt")
        embedding = clip_model.get_image_features(**inputs).cpu().numpy()[0]
        image_chunks.append({"image": image, "label": label})
        image_embeddings.append(embedding)

if not image_embeddings:
    raise ValueError("DermNet split is empty; nothing to index.")
# L2-normalise (FAISS needs contiguous float32) so IndexFlatL2 ranks images by
# cosine similarity, the metric CLIP embeddings are meant to be compared with.
image_embeddings = np.ascontiguousarray(np.asarray(image_embeddings, dtype="float32"))
faiss.normalize_L2(image_embeddings)
image_index = faiss.IndexFlatL2(image_embeddings.shape[1])
image_index.add(image_embeddings)