# Building a Knowledge-Graph (KG) Datastore for the MIMIC-NLE Dataset

In [None]:
import json
from tqdm import tqdm
from transformers import AutoTokenizer
from medclip import MedCLIPModel, MedCLIPVisionModelViT
from medclip import MedCLIPProcessor
import torch
import faiss
import os
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

def load_data(data_path):
    """Load image records for every split and captions for the train split.

    Args:
        data_path: Path to a JSON file containing a list of annotation
            dicts. Each dict is read for 'split', 'img_id' and 'img_path';
            train items that also carry a 'triplets' key yield a caption.

    Returns:
        A tuple ``(images, captions)``:
            images: one ``{'image_id', 'img_path'}`` dict per annotation,
                regardless of split.
            captions: one ``{'image_id', 'caption'}`` dict per train
                annotation that has 'triplets'.
    """
    # Context manager so the handle is closed even if JSON parsing fails.
    with open(data_path) as f:
        annotations = json.load(f)

    images = []
    captions = []
    for item in annotations:
        if item['split'] == 'train' and 'triplets' in item:
            captions.append({'image_id': item['img_id'], 'caption': item['triplets']})
        images.append({'image_id': item['img_id'], 'img_path': item['img_path']})

    return images, captions

def encode_captions(captions, model, device, preprocess):
    """Embed caption strings with the model's text encoder, in batches.

    Args:
        captions: List of caption strings.
        model: Object exposing ``to(device)`` and ``encode_text(...)``; the
            latter must return a tensor and accept an ``attention_mask``
            keyword (NOTE(review): true for MedCLIPModel as used in
            ``main`` -- confirm if another model is passed).
        device: Target device for the model and the tokenized batches.
        preprocess: Callable invoked as ``preprocess(text=..., ...)`` whose
            result supports ``.to(device)`` and key lookup.

    Returns:
        A NumPy array with the per-batch encoder outputs concatenated along
        the first axis, or an empty array when ``captions`` is empty.
    """
    bs = 256
    model.to(device)

    # Nothing to encode: keep the historical "empty array" result.
    if not captions:
        return np.array([])

    encoded_captions = []
    for idx in tqdm(range(0, len(captions), bs)):
        batch_captions = captions[idx:idx + bs]

        # Tokenize the captions with truncation and padding
        batch_inputs = preprocess(text=batch_captions, return_tensors='pt', padding=True, truncation=True, max_length=77).to(device)

        with torch.no_grad():
            # Captions are padded to a common length, so forward the mask
            # to keep the encoder from attending to padding tokens.
            encoded_batch = model.encode_text(
                batch_inputs['input_ids'],
                attention_mask=batch_inputs.get('attention_mask'),
            ).cpu().numpy()
        encoded_captions.append(encoded_batch)

    return np.concatenate(encoded_captions)

def encode_images(images, data_dir, model, feature_extractor, device):
    """Embed images with the model's image encoder, in batches.

    Args:
        images: List of dicts with 'image_id' and 'img_path' keys.
        data_dir: Unused; kept for interface compatibility. Every image is
            opened directly from its own 'img_path'.
        model: Object exposing ``encode_image(pixel_values=...)`` that
            returns a tensor.
        feature_extractor: Callable invoked as
            ``feature_extractor(images=..., return_tensors='pt')`` whose
            result has a ``pixel_values`` tensor.
        device: Device the pixel tensor is moved to before encoding.

    Returns:
        A tuple ``(image_ids, image_features)``: the ids in input order and
        a NumPy array of the per-batch encoder outputs concatenated along
        the first axis (an empty array when ``images`` is empty).
    """
    image_ids = [i['image_id'] for i in images]

    # np.concatenate raises on an empty list; mirror encode_captions and
    # return an empty array instead.
    if not images:
        return image_ids, np.array([])

    bs = 64
    image_features = []
    for idx in tqdm(range(0, len(images), bs)):
        batch_images = images[idx:idx + bs]  # Get a batch of image data

        processed_images = []
        for img in batch_images:
            # convert() returns a fully loaded copy, so the underlying file
            # can be closed as soon as the conversion is done.
            with Image.open(img['img_path']) as raw_image:
                processed_images.append(raw_image.convert("RGB"))

        image_input = feature_extractor(images=processed_images, return_tensors='pt').pixel_values.to(device)
        with torch.no_grad():
            image_embeds = model.encode_image(pixel_values=image_input).cpu().numpy()
        image_features.append(image_embeds)

    image_features = np.concatenate(image_features)

    return image_ids, image_features

def filter_captions(data, max_tokens=100):
    """Drop captions whose GPT-2 tokenization is longer than ``max_tokens``.

    The previous implementation tokenized with ``truncation=True`` and
    ``max_length=100``, which caps every encoding at 100 tokens and made
    the length check below always succeed. Captions are now tokenized at
    their full length so the filter actually removes over-long entries.

    Args:
        data: List of dicts with 'image_id' and 'caption' keys; 'caption'
            is passed to the tokenizer as text.
        max_tokens: Maximum number of GPT-2 tokens a caption may have to be
            kept. Defaults to 100, the limit used originally.

    Returns:
        A tuple ``(filtered_image_ids, filtered_captions)`` of parallel
        lists holding the ids and caption texts that passed the filter.
    """
    decoder_name = 'gpt2'
    tokenizer = AutoTokenizer.from_pretrained(decoder_name)
    bs = 512

    image_ids = [d['image_id'] for d in data]
    caps = [d['caption'] for d in data]

    # No truncation and no padding: each encoding keeps its true length,
    # which is the quantity the filter compares against max_tokens.
    encodings = []
    for idx in range(0, len(caps), bs):
        encodings += tokenizer(caps[idx:idx + bs])['input_ids']

    filtered_image_ids, filtered_captions = [], []

    assert len(image_ids) == len(caps) and len(caps) == len(encodings)
    for image_id, cap, encoding in zip(image_ids, caps, encodings):
        if len(encoding) <= max_tokens:
            filtered_image_ids.append(image_id)
            filtered_captions.append(cap)

    return filtered_image_ids, filtered_captions


def get_nns(captions, images, k=15):
    """Index caption embeddings and look up the ``k`` nearest per image.

    Both embedding matrices are copied to float32 and L2-normalized, so the
    inner-product index ranks neighbours by cosine similarity.

    Args:
        captions: Array of caption embeddings (the indexed vectors).
        images: Array of image embeddings (the query vectors).
        k: Number of neighbours to retrieve for each query.

    Returns:
        A tuple ``(index, neighbor_ids)``: the populated faiss index and the
        array of neighbour indices returned by the search.
    """
    query_vecs = images.astype(np.float32)
    caption_vecs = captions.astype(np.float32)

    # normalize_L2 works in place on the float32 copies made above.
    faiss.normalize_L2(caption_vecs)
    faiss.normalize_L2(query_vecs)

    index = faiss.IndexFlatIP(caption_vecs.shape[1])
    index.add(caption_vecs)

    # Similarity scores are not needed by callers; keep only the indices.
    _, neighbor_ids = index.search(query_vecs, k)

    return index, neighbor_ids

def filter_nns(nns, xb_image_ids, captions, xq_image_ids, num_neighbors=7):
    """Keep the first ``num_neighbors`` retrieved captions not owned by the query image.

    Args:
        nns: One sequence of datastore indices per query image, in ranked
            order (as returned by ``get_nns``).
        xb_image_ids: Image id of each datastore entry, indexed like
            ``captions``.
        captions: Caption text of each datastore entry.
        xq_image_ids: Image id of each query, parallel to ``nns``.
        num_neighbors: Number of captions to keep per query image.
            Defaults to 7.

    Returns:
        Dict mapping each query image id to its list of ``num_neighbors``
        retrieved captions.

    Raises:
        AssertionError: If a query has fewer than ``num_neighbors`` usable
            neighbours.
    """
    retrieved_captions = {}
    for nns_list, image_id in zip(nns, xq_image_ids):
        good_nns = []
        for nn in nns_list:
            if len(good_nns) == num_neighbors:
                break
            # faiss pads missing results with -1; without this guard Python
            # would treat it as "last element" and return a bogus caption.
            if nn < 0:
                continue
            # Skip captions that belong to the query image itself.
            if xb_image_ids[nn] == image_id:
                continue
            good_nns.append(captions[nn])
        assert len(good_nns) == num_neighbors, (
            f"only {len(good_nns)} of {num_neighbors} neighbors found for image {image_id}"
        )
        retrieved_captions[image_id] = good_nns
    return retrieved_captions

def main():
    """Build the KG-triplet datastore and retrieve triplets for every image.

    Steps: load annotations, filter train captions by token length, embed
    captions and images with MedCLIP, retrieve nearest captions per image,
    then write the faiss index, the indexed captions and the per-image
    retrieved triplets to disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder_name = MedCLIPVisionModelViT
    feature_extractor = MedCLIPProcessor()
    clip_model = MedCLIPModel(vision_cls=encoder_name)
    clip_model.from_pretrained()
    # Honour the device selected above instead of unconditionally using CUDA.
    clip_model = clip_model.to(device)
    # Inference only: switch off train-time behaviour such as dropout so the
    # embeddings are deterministic.
    clip_model.eval()

    # Images come from every split (they are the queries); captions come
    # from the train split only (they form the datastore). Previously
    # `images` was never defined, so encoding images raised NameError.
    images, captions = load_data("/data/mimic_nle_kg_suggestive.json")

    print('Filtering captions')
    xb_image_ids, captions = filter_captions(captions)

    print('Encoding captions')
    encoded_captions = encode_captions(captions, clip_model, device, preprocess=feature_extractor)

    print('Encoding images')
    # encode_images ignores its data_dir argument: each record carries its
    # own 'img_path', so no image directory is needed here.
    xq_image_ids, encoded_images = encode_images(images, None, clip_model, feature_extractor, device)

    print('Retrieving neighbors')
    index, nns = get_nns(encoded_captions, encoded_images)
    retrieved_caps = filter_nns(nns, xb_image_ids, captions, xq_image_ids)

    print('Writing files')
    faiss.write_index(index, "/data/datastore/kg_nle_index")
    # Context managers guarantee the JSON is flushed and the files closed.
    with open('/data/datastore/kg_nle_index_captions.json', 'w') as f:
        json.dump(captions, f)

    with open('/data/retrieved_triplets.json', 'w') as f:
        json.dump(retrieved_caps, f)


if __name__ == '__main__':
    main()

# Preparing Dataset for LLaVA Training on MIMIC-NLE dataset with KG-RAG

In [None]:
# To build the dataset in LLaVA format.
import json
import random

# Input JSON file paths:
#   input_json_file  - annotation records, passed to
#                      load_data_with_kg_triplets as the annotation path
#   kg_triplets_file - retrieved triplets, looked up by str(img_id)
input_json_file = '/data/mimic_nle_dataset.json'
kg_triplets_file = '/data/retrieved_triplets.json'

# Output JSON file path (the LLaVA-format records are written here)
output_json_file = '/data/mimic-nle-train-.json'

# Number of KG triplets to pick per record (the first K of the retrieved list)
K = 7

# Seed passed to random.seed so the per-record question template choice
# is reproducible
seed_value = 42

# Diagnosis names, indexed by position in a record's 'img_labels', and the
# certainty words. get_pathologies only ever emits certainty_list[1]
# ("uncertain") and certainty_list[2] ("positive").
diagnosis_list = [
    "Atelectasis",
    "Consolidation",
    "Edema",
    "Enlarged Cardiomediastinum",
    "Lung Lesion",
    "Lung Opacity",
    "Pleural Effusion",
    "Pleural Other",
    "Pneumonia",
    "Pneumothorax",
]
certainty_list = ["negative", "uncertain", "positive"]

# Question templates; one is chosen at random per record and its
# {pathologies} placeholder is filled with the get_pathologies string
question_templates = [
    "Which signs show that the patient has {pathologies}?",
    "Explain why these {pathologies} are present in the image.",
    "What evidence in the image indicates {pathologies}?",
    "How can you tell that the patient has {pathologies} from the image?",
    "What features suggest the presence of {pathologies} in this image?"
]

def get_pathologies(record):
    """Describe a record's uncertain and positive findings as one string.

    Each entry of ``record['img_labels']`` is matched by position with
    ``diagnosis_list``; a truthy value at index 1 or 2 of the entry adds
    "<certainty> <diagnosis>" using the corresponding ``certainty_list``
    word. The mentions are joined with ", ".

    Args:
        record: Annotation dict with an 'img_labels' sequence.

    Returns:
        The comma-separated description, or an empty string when no entry
        is flagged.
    """
    mentions = []
    for position, flags in enumerate(record['img_labels']):
        # Index 1 is checked before index 2, matching certainty_list order.
        for certainty_idx in (1, 2):
            if flags[certainty_idx]:
                mentions.append(f"{certainty_list[certainty_idx]} {diagnosis_list[position]}")
    return ", ".join(mentions)

def load_data_with_kg_triplets(annot_path, triplets_path, K):
    """Load annotations and attach each record's retrieved KG triplets.

    Args:
        annot_path: Path to a JSON file containing a list of annotation
            dicts, each with an 'img_id' key.
        triplets_path: Path to a JSON file containing a dict that maps
            ``str(img_id)`` to a list of triplets.
        K: Maximum number of triplets to keep per record (the first K).

    Returns:
        Dict mapping ``str(img_id)`` to its annotation dict, which gains a
        'kg_triplets' key (an empty list when the id has no entry in the
        triplets file).
    """
    # Context managers close both files even if JSON parsing fails.
    with open(annot_path) as annot_file:
        annotations = json.load(annot_file)
    with open(triplets_path) as triplets_file:
        kg_triplets = json.load(triplets_file)

    data = {}
    for item in annotations:
        key = str(item['img_id'])
        item['kg_triplets'] = kg_triplets.get(key, [])[:K]  # Pick the first K triplets
        data[key] = item

    return data

# Fix the RNG state so the per-record template choice is repeatable
random.seed(seed_value)

# Annotations keyed by image id, each carrying up to K retrieved triplets
data = load_data_with_kg_triplets(input_json_file, kg_triplets_file, K)

# Turn every train/val record into a two-turn (human, gpt) conversation
output_data = []
for record in data.values():
    if record["split"] not in ("train", "val"):
        continue

    pathologies = get_pathologies(record)
    question = random.choice(question_templates).format(pathologies=pathologies)
    triplet_text = "; ".join(record.get('kg_triplets', []))

    # For the test set, build the prompt without the leading "<image>\n":
    # f"The image-specific triplets from the knowledge graph are: {triplet_text}. And for the given image, {question}"
    prompt = f"<image>\nThe image-specific triplets from the knowledge graph are: {triplet_text}. And for the given image, {question}"

    output_data.append({
        "id": str(record["img_id"]),
        "split": record["split"],
        "image": record["img_path"],
        "conversations": [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": record["nle"]},
        ],
    })

# Serialize all conversation records to the output path
with open(output_json_file, 'w') as out_file:
    json.dump(output_data, out_file, indent=4)

print(f"Output JSON file has been created successfully: {output_json_file}")
