In [1]:
import sys
import os

project_path = r"C:\Users\jjmcc\OneDrive\Documents\Thesis Interim\ITGC-interim"
sys.path.append(project_path)

print("Project path added.")

Project path added.


In [3]:
import os
import numpy as np
import random
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


Running on: cpu


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
MRI_ROOT = r"E:/ProcessedMRI"

image_paths = []
for r, _, files in os.walk(MRI_ROOT):
    for f in files:
        if f.endswith(".png"):
            image_paths.append(os.path.join(r, f))

print("Total PNG:", len(image_paths))

N_IMAGES = 40
image_paths = random.sample(image_paths, N_IMAGES)


Total PNG: 3242


In [5]:
initial_prompts = [
    "short axis cardiac MRI slice",
    "long axis cardiac MRI slice",
    "basal left ventricle MRI slice",
    "mid-ventricular left ventricle MRI",
    "apical LV MRI slice",
    "end-diastolic cardiac MRI frame",
    "end-systolic cardiac MRI frame",
    "left ventricle cavity view",
    "right ventricle cavity view",
    "four chamber cardiac MRI"
]


In [6]:
def prompt_score(prompt, paths):
    imgs = [Image.open(p).convert("RGB") for p in paths]
    inputs = processor(text=[prompt]*len(paths), images=imgs,
                       return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        out = model(**inputs)

    sims = out.logits_per_image.cpu().numpy().squeeze()
    return float(np.mean(sims))

def score_prompt_list(prompts):
    scores = [(p, prompt_score(p, image_paths)) for p in prompts]
    for p, s in scores:
        print(f"{p}  →  {s:.4f}")
    return sorted(scores, key=lambda x: -x[1])


In [7]:
synonyms = {
    "short axis": ["short-axis", "SA", "transverse"],
    "long axis": ["long-axis", "LA", "vertical"],
    "end-diastolic": ["ED phase", "diastolic frame"],
    "end-systolic": ["ES phase", "systolic frame"],
    "left ventricle": ["LV", "left ventricular"],
    "right ventricle": ["RV", "right ventricular"],
    "slice": ["section", "image", "frame"],
    "cardiac MRI": ["MRI of the heart", "cine MRI"]
}

def mutate_local(prompt):
    words = prompt.split()

    # randomly replace synonyms
    new_words = []
    for w in words:
        key = next((k for k in synonyms if k in w), None)
        if key and random.random() < 0.5:
            new_words.append(random.choice(synonyms[key]))
        else:
            new_words.append(w)

    # randomly add structural expansions
    if random.random() < 0.4:
        extras = [
            "high contrast", "cine imaging", "ventricular view",
            "cardiac anatomy focus", "T1-weighted appearance"
        ]
        new_words.append(random.choice(extras))

    return " ".join(new_words)


In [8]:
def evolve_prompts(seed_prompts, rounds=3, top_k=3, mutations_per_prompt=2):
    pool = seed_prompts.copy()

    for r in range(rounds):
        print(f"\n===== ROUND {r+1} =====")
        
        # Score pool
        scored = score_prompt_list(pool)
        top_prompts = [p for p, _ in scored[:top_k]]

        print("Top prompts:", top_prompts)

        # Mutate locally
        new_prompts = []
        for p in top_prompts:
            for _ in range(mutations_per_prompt):
                new_prompts.append(mutate_local(p))

        pool = list(set(pool + new_prompts))

    return score_prompt_list(pool)


In [9]:
optimized = evolve_prompts(initial_prompts, rounds=3)
optimized[:10]



===== ROUND 1 =====
short axis cardiac MRI slice  →  31.7322
long axis cardiac MRI slice  →  31.6762
basal left ventricle MRI slice  →  31.2942
mid-ventricular left ventricle MRI  →  30.3494
apical LV MRI slice  →  31.1012
end-diastolic cardiac MRI frame  →  29.6492
end-systolic cardiac MRI frame  →  29.3468
left ventricle cavity view  →  26.7331
right ventricle cavity view  →  26.9233
four chamber cardiac MRI  →  29.7297
Top prompts: ['short axis cardiac MRI slice', 'long axis cardiac MRI slice', 'basal left ventricle MRI slice']

===== ROUND 2 =====
short axis cardiac MRI image cine imaging  →  29.7086
long axis cardiac MRI slice  →  31.6762
left ventricle cavity view  →  26.7331
basal left ventricle MRI slice T1-weighted appearance  →  31.5163
basal left ventricle MRI slice  →  31.2942
short axis cardiac MRI slice  →  31.7322
four chamber cardiac MRI  →  29.7297
long axis cardiac MRI slice high contrast  →  31.2155
right ventricle cavity view  →  26.9233
basal left ventricle MRI sl

[('short axis cardiac MRI slice', 31.732168197631836),
 ('long axis cardiac MRI slice', 31.6761531829834),
 ('basal left ventricle MRI slice T1-weighted appearance', 31.516305923461914),
 ('basal left ventricle MRI image T1-weighted appearance', 31.29973793029785),
 ('basal left ventricle MRI slice T1-weighted appearance ventricular view',
  31.298656463623047),
 ('basal left ventricle MRI slice', 31.294248580932617),
 ('long axis cardiac MRI slice high contrast', 31.215526580810547),
 ('long axis cardiac MRI image', 31.118478775024414),
 ('apical LV MRI slice', 31.101245880126953),
 ('basal left ventricle MRI slice high contrast', 30.975536346435547)]

In [10]:
with open("best_prompts.txt", "w") as f:
    for p, s in optimized[:5]:
        f.write(f"{p} | {s}\n")

print("Saved top prompts to best_prompts.txt")


Saved top prompts to best_prompts.txt


In [11]:
print("\n===== OPTIMIZED PROMPTS =====\n")

for i, (prompt, score) in enumerate(optimized[:10], start=1):
    print(f"{i}. {prompt}   (CLIP score: {score:.4f})")



===== OPTIMIZED PROMPTS =====

1. short axis cardiac MRI slice   (CLIP score: 31.7322)
2. long axis cardiac MRI slice   (CLIP score: 31.6762)
3. basal left ventricle MRI slice T1-weighted appearance   (CLIP score: 31.5163)
4. basal left ventricle MRI image T1-weighted appearance   (CLIP score: 31.2997)
5. basal left ventricle MRI slice T1-weighted appearance ventricular view   (CLIP score: 31.2987)
6. basal left ventricle MRI slice   (CLIP score: 31.2942)
7. long axis cardiac MRI slice high contrast   (CLIP score: 31.2155)
8. long axis cardiac MRI image   (CLIP score: 31.1185)
9. apical LV MRI slice   (CLIP score: 31.1012)
10. basal left ventricle MRI slice high contrast   (CLIP score: 30.9755)
