In [1]:
import sys
import os

project_path = r"C:\Users\jjmcc\OneDrive\Documents\Thesis Interim\ITGC-interim"
sys.path.append(project_path)

print("Project path added.")

Project path added.


In [2]:
import os
import numpy as np
import random
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


  from .autonotebook import tqdm as notebook_tqdm


Running on: cpu


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
MRI_ROOT = r"E:/ProcessedMRI"

image_paths = []
for r, _, files in os.walk(MRI_ROOT):
    for f in files:
        if f.endswith(".png"):
            image_paths.append(os.path.join(r, f))

print("Total PNG:", len(image_paths))

N_IMAGES = 40
image_paths = random.sample(image_paths, N_IMAGES)


Total PNG: 3242


In [4]:
initial_prompts = [
    "short axis cardiac MRI slice",
    "long axis cardiac MRI slice",
    "basal left ventricle MRI slice",
    "mid-ventricular left ventricle MRI",
    "apical LV MRI slice",
    "end-diastolic cardiac MRI frame",
    "end-systolic cardiac MRI frame",
    "left ventricle cavity view",
    "right ventricle cavity view",
    "four chamber cardiac MRI"
]


In [5]:
def prompt_score(prompt, paths):
    imgs = [Image.open(p).convert("RGB") for p in paths]
    inputs = processor(text=[prompt]*len(paths), images=imgs,
                       return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        out = model(**inputs)

    sims = out.logits_per_image.cpu().numpy().squeeze()
    return float(np.mean(sims))

def score_prompt_list(prompts):
    scores = [(p, prompt_score(p, image_paths)) for p in prompts]
    for p, s in scores:
        print(f"{p}  →  {s:.4f}")
    return sorted(scores, key=lambda x: -x[1])


In [6]:
synonyms = {
    "short axis": ["short-axis", "SA", "transverse"],
    "long axis": ["long-axis", "LA", "vertical"],
    "end-diastolic": ["ED phase", "diastolic frame"],
    "end-systolic": ["ES phase", "systolic frame"],
    "left ventricle": ["LV", "left ventricular"],
    "right ventricle": ["RV", "right ventricular"],
    "slice": ["section", "image", "frame"],
    "cardiac MRI": ["MRI of the heart", "cine MRI"]
}

def mutate_local(prompt):
    words = prompt.split()

    # randomly replace synonyms
    new_words = []
    for w in words:
        key = next((k for k in synonyms if k in w), None)
        if key and random.random() < 0.5:
            new_words.append(random.choice(synonyms[key]))
        else:
            new_words.append(w)

    # randomly add structural expansions
    if random.random() < 0.4:
        extras = [
            "high contrast", "cine imaging", "ventricular view",
            "cardiac anatomy focus", "T1-weighted appearance"
        ]
        new_words.append(random.choice(extras))

    return " ".join(new_words)


In [7]:
def evolve_prompts(seed_prompts, rounds=3, top_k=3, mutations_per_prompt=2):
    pool = seed_prompts.copy()

    for r in range(rounds):
        print(f"\n===== ROUND {r+1} =====")
        
        # Score pool
        scored = score_prompt_list(pool)
        top_prompts = [p for p, _ in scored[:top_k]]

        print("Top prompts:", top_prompts)

        # Mutate locally
        new_prompts = []
        for p in top_prompts:
            for _ in range(mutations_per_prompt):
                new_prompts.append(mutate_local(p))

        pool = list(set(pool + new_prompts))

    return score_prompt_list(pool)


In [8]:
optimized = evolve_prompts(initial_prompts, rounds=3)
optimized[:10]



===== ROUND 1 =====
short axis cardiac MRI slice  →  32.0334
long axis cardiac MRI slice  →  31.9187
basal left ventricle MRI slice  →  31.5960
mid-ventricular left ventricle MRI  →  30.7079
apical LV MRI slice  →  31.5406
end-diastolic cardiac MRI frame  →  29.9618
end-systolic cardiac MRI frame  →  29.5835
left ventricle cavity view  →  26.8771
right ventricle cavity view  →  27.0361
four chamber cardiac MRI  →  29.9838
Top prompts: ['short axis cardiac MRI slice', 'long axis cardiac MRI slice', 'basal left ventricle MRI slice']

===== ROUND 2 =====
basal left ventricle MRI slice T1-weighted appearance  →  31.7524
long axis cardiac MRI section  →  29.5713
mid-ventricular left ventricle MRI  →  30.7079
four chamber cardiac MRI  →  29.9838
short axis cardiac MRI slice  →  32.0334
long axis cardiac MRI slice  →  31.9187
basal left ventricle MRI slice  →  31.5960
end-diastolic cardiac MRI frame  →  29.9618
right ventricle cavity view  →  27.0361
basal left ventricle MRI frame  →  30.654

[('short axis cardiac MRI slice T1-weighted appearance', 32.419349670410156),
 ('long axis cardiac MRI slice T1-weighted appearance', 32.34211349487305),
 ('short axis cardiac MRI slice', 32.03339385986328),
 ('short axis cardiac MRI image', 31.937673568725586),
 ('long axis cardiac MRI slice', 31.918739318847656),
 ('basal left ventricle MRI slice T1-weighted appearance', 31.7524356842041),
 ('basal left ventricle MRI slice', 31.59604263305664),
 ('apical LV MRI slice', 31.540599822998047),
 ('long axis cardiac MRI slice ventricular view', 31.259981155395508),
 ('mid-ventricular left ventricle MRI', 30.707895278930664)]

In [9]:
with open("best_prompts.txt", "w") as f:
    for p, s in optimized[:5]:
        f.write(f"{p} | {s}\n")

print("Saved top prompts to best_prompts.txt")


Saved top prompts to best_prompts.txt


In [10]:
print("\n===== OPTIMIZED PROMPTS =====\n")

for i, (prompt, score) in enumerate(optimized[:10], start=1):
    print(f"{i}. {prompt}   (CLIP score: {score:.4f})")



===== OPTIMIZED PROMPTS =====

1. short axis cardiac MRI slice T1-weighted appearance   (CLIP score: 32.4193)
2. long axis cardiac MRI slice T1-weighted appearance   (CLIP score: 32.3421)
3. short axis cardiac MRI slice   (CLIP score: 32.0334)
4. short axis cardiac MRI image   (CLIP score: 31.9377)
5. long axis cardiac MRI slice   (CLIP score: 31.9187)
6. basal left ventricle MRI slice T1-weighted appearance   (CLIP score: 31.7524)
7. basal left ventricle MRI slice   (CLIP score: 31.5960)
8. apical LV MRI slice   (CLIP score: 31.5406)
9. long axis cardiac MRI slice ventricular view   (CLIP score: 31.2600)
10. mid-ventricular left ventricle MRI   (CLIP score: 30.7079)
