### Necessary Imports

In [None]:
from sklearn.model_selection import train_test_split
import os
from PIL import Image
import torchvision.models as models
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import re
import torchxrayvision as xrv
import torchvision
from tqdm import tqdm

In [None]:
#Using standard dataset
def create_dataloaders(jhu_image_dir=None,ucsf_image_dir=None,washu_image_dir=None,jhu_label=None,ucsf_label=None,washu_label=None,replicate=2):
    """Collect image paths and binary ODI labels for the JHU, UCSF and WashU sets.

    For every site the ``.npy`` score file is loaded, each score is turned
    into a 0/1 label (1 when the score is greater than 40) and the label is
    repeated ``replicate`` times so that it lines up with the images found in
    the site's sub-folders.

    Args:
        jhu_image_dir, ucsf_image_dir, washu_image_dir: Directories whose
            sub-folders contain the image files.
        jhu_label, ucsf_label, washu_label: Paths of the ``.npy`` score files.
        replicate: Number of images expected in each sub-folder.

    Returns:
        ``(images, scores)``: image paths and their labels, concatenated in
        the order JHU, UCSF, WashU.
    """
    sites = (
        ("jhu", jhu_image_dir, jhu_label),
        ("ucsf", ucsf_image_dir, ucsf_label),
        ("washu", washu_image_dir, washu_label),
    )

    images = []
    scores = []
    for site, image_dir, label_file in sites:
        # Vectorised threshold: 1 where the score exceeds 40, otherwise 0.
        labels = (np.load(label_file) > 40).astype(np.int32)
        site_scores = list(np.repeat(labels, replicate))

        # Both directory levels are sorted so the path order is reproducible
        # from run to run (os.listdir returns entries in arbitrary order).
        site_images = [
            os.path.join(image_dir, img_folder, img)
            for img_folder in sorted(os.listdir(image_dir))
            for img in sorted(os.listdir(os.path.join(image_dir, img_folder)))
        ]

        # Labels are matched to images purely by position, so a count
        # mismatch means the two lists are no longer aligned.
        if len(site_images) != len(site_scores):
            print(f"Warning: Mismatch in {site} dataset. Found {len(site_images)} images and {len(site_scores)} labels.")

        images += site_images
        scores += site_scores

    print(len(images))
    print(len(scores))

    return images,scores

# Build the path/label lists for the original (non-augmented) images.
images,scores=create_dataloaders(
    jhu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/JPG',
    ucsf_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF',
    washu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/final_images',
    jhu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/jhu.npy',
    ucsf_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/ucsf.npy',
    washu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/washu.npy',
)

In [None]:
#Using augmented dataset
def create_aug_dataloaders(jhu_image_dir=None,ucsf_image_dir=None,washu_image_dir=None,jhu_label=None,ucsf_label=None,washu_label=None,replicate=10):
    """Collect augmented image paths and binary ODI labels for JHU, UCSF and WashU.

    For every site the ``.npy`` score file is loaded, each score is turned
    into a 0/1 label (1 when the score is greater than 40) and the label is
    repeated ``replicate`` times so that it lines up with the images found in
    the site's sub-folders.

    Args:
        jhu_image_dir, ucsf_image_dir, washu_image_dir: Directories whose
            sub-folders contain the image files.
        jhu_label, ucsf_label, washu_label: Paths of the ``.npy`` score files.
        replicate: Number of images expected in each sub-folder.

    Returns:
        ``(images, scores)``: image paths and their labels, concatenated in
        the order JHU, UCSF, WashU.
    """
    sites = (
        ("jhu", jhu_image_dir, jhu_label),
        ("ucsf", ucsf_image_dir, ucsf_label),
        ("washu", washu_image_dir, washu_label),
    )

    images = []
    scores = []
    for site, image_dir, label_file in sites:
        # Vectorised threshold: 1 where the score exceeds 40, otherwise 0.
        labels = (np.load(label_file) > 40).astype(np.int32)
        site_scores = list(np.repeat(labels, replicate))

        # Both directory levels are sorted so the path order is reproducible
        # from run to run (os.listdir returns entries in arbitrary order).
        site_images = [
            os.path.join(image_dir, img_folder, img)
            for img_folder in sorted(os.listdir(image_dir))
            for img in sorted(os.listdir(os.path.join(image_dir, img_folder)))
        ]

        # Labels are matched to images purely by position, so a count
        # mismatch means the two lists are no longer aligned.
        if len(site_images) != len(site_scores):
            print(f"Warning: Mismatch in {site} dataset. Found {len(site_images)} images and {len(site_scores)} labels.")

        images += site_images
        scores += site_scores

    print(len(images))
    print(len(scores))

    return images,scores

# Build the path/label lists for the augmented images.
images,scores=create_aug_dataloaders(
    jhu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/augmented_images',
    ucsf_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/augmented_images',
    washu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/augmented_images',
    jhu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/jhu.npy',
    ucsf_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/ucsf.npy',
    washu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/washu.npy',
)

In [None]:
#Using augmented dataset
def create_aug_dataloaders_resnet(jhu_image_dir=None,ucsf_image_dir=None,washu_image_dir=None,jhu_label=None,ucsf_label=None,washu_label=None,replicate=10):
    """Collect augmented (ResNet variant) image paths and binary ODI labels.

    For every site (JHU, UCSF, WashU) the ``.npy`` score file is loaded, each
    score is turned into a 0/1 label (1 when the score is greater than 40)
    and the label is repeated ``replicate`` times so that it lines up with
    the images found in the site's sub-folders.

    Args:
        jhu_image_dir, ucsf_image_dir, washu_image_dir: Directories whose
            sub-folders contain the image files.
        jhu_label, ucsf_label, washu_label: Paths of the ``.npy`` score files.
        replicate: Number of images expected in each sub-folder.

    Returns:
        ``(images, scores)``: image paths and their labels, concatenated in
        the order JHU, UCSF, WashU.
    """
    sites = (
        ("jhu", jhu_image_dir, jhu_label),
        ("ucsf", ucsf_image_dir, ucsf_label),
        ("washu", washu_image_dir, washu_label),
    )

    images = []
    scores = []
    for site, image_dir, label_file in sites:
        # Vectorised threshold: 1 where the score exceeds 40, otherwise 0.
        labels = (np.load(label_file) > 40).astype(np.int32)
        site_scores = list(np.repeat(labels, replicate))

        # Both directory levels are sorted so the path order is reproducible
        # from run to run (os.listdir returns entries in arbitrary order).
        site_images = [
            os.path.join(image_dir, img_folder, img)
            for img_folder in sorted(os.listdir(image_dir))
            for img in sorted(os.listdir(os.path.join(image_dir, img_folder)))
        ]

        # Labels are matched to images purely by position, so a count
        # mismatch means the two lists are no longer aligned.
        if len(site_images) != len(site_scores):
            print(f"Warning: Mismatch in {site} dataset. Found {len(site_images)} images and {len(site_scores)} labels.")

        images += site_images
        scores += site_scores

    print(len(images))
    print(len(scores))

    return images,scores

# Build the path/label lists for the ResNet-specific augmented images.
images,scores=create_aug_dataloaders_resnet(
    jhu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/augmented_images_resnet',
    ucsf_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/augmented_images_resnet',
    washu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/augmented_images_resnet',
    jhu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/jhu.npy',
    ucsf_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/ucsf.npy',
    washu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/washu.npy',
)

## Create separate image lists and embeddings for JHU, WashU and UCSF

In [None]:
#Using augmented dataset
def create_aug_dataloaders_resnet_10(jhu_image_dir=None,ucsf_image_dir=None,washu_image_dir=None,jhu_label=None,ucsf_label=None,washu_label=None,replicate=22):
    """Collect augmented image paths and binary ODI labels, kept separate per site.

    For every site (JHU, UCSF, WashU) the ``.npy`` score file is loaded, each
    score is turned into a 0/1 label (1 when the score is greater than 40)
    and the label is repeated ``replicate`` times so that it lines up with
    the images found in the site's sub-folders.

    Args:
        jhu_image_dir, ucsf_image_dir, washu_image_dir: Directories whose
            sub-folders contain the image files.
        jhu_label, ucsf_label, washu_label: Paths of the ``.npy`` score files.
        replicate: Number of images expected in each sub-folder.

    Returns:
        ``(jhu_images, jhu_odi, ucsf_images, ucsf_odi, washu_images,
        washu_odi)``: for each site, the image paths followed by their labels.
    """
    sites = (
        ("jhu", jhu_image_dir, jhu_label),
        ("ucsf", ucsf_image_dir, ucsf_label),
        ("washu", washu_image_dir, washu_label),
    )

    per_site = {}
    for site, image_dir, label_file in sites:
        # Vectorised threshold: 1 where the score exceeds 40, otherwise 0.
        labels = (np.load(label_file) > 40).astype(np.int32)
        site_scores = list(np.repeat(labels, replicate))

        # Both directory levels are sorted so the path order is reproducible
        # from run to run (os.listdir returns entries in arbitrary order).
        site_images = [
            os.path.join(image_dir, img_folder, img)
            for img_folder in sorted(os.listdir(image_dir))
            for img in sorted(os.listdir(os.path.join(image_dir, img_folder)))
        ]

        # Labels are matched to images purely by position, so a count
        # mismatch means the two lists are no longer aligned.
        if len(site_images) != len(site_scores):
            print(f"Warning: Mismatch in {site} dataset. Found {len(site_images)} images and {len(site_scores)} labels.")

        per_site[site] = (site_images, site_scores)

    # Report the combined totals across the three sites.
    print(sum(len(paths) for paths, _ in per_site.values()))
    print(sum(len(labels) for _, labels in per_site.values()))

    jhu_images, jhu_odi = per_site["jhu"]
    ucsf_images, ucsf_odi = per_site["ucsf"]
    washu_images, washu_odi = per_site["washu"]
    return jhu_images,jhu_odi,ucsf_images,ucsf_odi,washu_images,washu_odi

# Build per-site path/label lists for the "resnet_10" augmented images.
jhu_images,jhu_scores,ucsf_images,ucsf_scores,washu_images,washu_scores=create_aug_dataloaders_resnet_10(
    jhu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/augmented_images_resnet_10',
    ucsf_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/augmented_images_resnet_10',
    washu_image_dir='/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/augmented_images_resnet_10',
    jhu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/jhu.npy',
    ucsf_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/ucsf.npy',
    washu_label='/home/blu/ai/xray_score/JHU_UCSF_WASHU/washu.npy',
)

In [None]:
# Sanity check: number of per-image JHU labels returned by the loader call.
print(len(jhu_scores))

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
##Resnet
#Load the resnet model
# `weights` is passed by keyword: torchvision has deprecated positional use
# of this argument since 0.13.
model=models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# Keep every child module except the last one (the classification layer), so
# the network outputs pooled features instead of class scores.
model = torch.nn.Sequential(*list(model.children())[:-1])  # Remove final classification layer
model.eval()

In [None]:
#Get the embeddings for each image- store it in np array
# Preprocessing applied to every image before it is embedded: resize to
# 224x224 with Lanczos resampling, convert to a tensor, then normalise each of
# the three channels with the mean/std values listed below.
final_transform = transforms.Compose(
    [
        transforms.Resize(
            (224, 224),
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # one value per RGB channel
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def image_embedding_generator(image_paths):
    """Yield one embedding (NumPy array) per image path, in input order.

    Each image is opened, converted to RGB, passed through ``final_transform``
    and fed to the module-level ``model`` as a batch of one. A running count
    and the raw output shape are printed for every image.

    Args:
        image_paths: Iterable of image file paths.

    Yields:
        The squeezed model output for each image as a NumPy array.
    """
    for index, img_path in enumerate(image_paths, start=1):
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        batch = final_transform(img).unsqueeze(0)

        # Only the forward pass runs under no_grad. Yielding from inside the
        # context would leave gradients disabled in the consumer's code while
        # this generator is suspended.
        with torch.no_grad():
            embedding = model(batch)

        print(index, embedding.shape)
        yield embedding.squeeze().numpy()

# Embed every WashU image and stack the per-image vectors into one array.
embeddings=np.array(list(image_embedding_generator(washu_images)))
print(embeddings.shape)
np.save('washu_image_embeddings_aug_resnet_10.npy',embeddings)

In [None]:
#Load torchxray vision
#Load the model
# TorchXRayVision ResNet with the "resnet50-res512-all" weights, placed on the
# selected device and switched to inference mode.
# NOTE: this rebinds the module-level name `model`.
model = xrv.models.ResNet(weights="resnet50-res512-all").to(device)
model.eval()

# Resizing step (512x512) applied to each normalised image array.
transform = torchvision.transforms.Compose([xrv.datasets.XRayResizer(512)])
def load_and_preprocess_image(img_path):
    """Read an image as grayscale and prepare it for the torchxrayvision model.

    The pixels are rescaled with ``xrv.datasets.normalize`` (max value 255),
    given a leading channel axis, run through the module-level ``transform``
    and returned as a torch tensor.
    """
    grayscale = Image.open(img_path).convert('L')
    pixels = xrv.datasets.normalize(np.array(grayscale), 255)
    with_channel = pixels[None, :, :]  # leading channel axis
    resized = transform(with_channel)
    return torch.from_numpy(resized)

def image_embeddings_in_batches(image_paths, batch_size=16):
    """Embed images with ``model.features`` in groups of ``batch_size``.

    Images are preprocessed one at a time (with a tqdm progress bar) and
    collected; whenever ``batch_size`` tensors have accumulated they are
    stacked, moved to ``device`` and embedded without gradient tracking. A
    final, smaller group is embedded after the loop.

    Returns:
        A single NumPy array made by concatenating the per-batch outputs
        along axis 0.
    """
    chunks = []
    pending = []

    def _embed_pending():
        # Stack the queued tensors, embed them and empty the queue.
        stacked = torch.stack(pending).to(device)
        with torch.no_grad():
            features = model.features(stacked)
        chunks.append(features.cpu().numpy())
        pending.clear()

    for path in tqdm(image_paths):
        pending.append(load_and_preprocess_image(path))
        if len(pending) == batch_size:
            _embed_pending()

    # Whatever did not fill a complete batch.
    if pending:
        _embed_pending()

    return np.concatenate(chunks, axis=0)

# Embed the paths currently bound to `images` (set by whichever loader cell
# ran last), report the array shape and save the array to disk.
embeddings = image_embeddings_in_batches(images, batch_size=16)
print(embeddings.shape)
np.save('image_embeddings_aug_torchxray.npy', embeddings)

In [None]:
#Train test split
#Loading the embeddings into a list and performing split
#image_embeddings=np.load('image_embeddings_aug_torchxray.npy').tolist()
from sklearn.model_selection import GroupShuffleSplit

jhu_embeddings=np.load('jhu_image_embeddings_aug_resnet_10.npy')
ucsf_embeddings=np.load('ucsf_image_embeddings_aug_resnet_10.npy')
washu_embeddings=np.load('washu_image_embeddings_aug_resnet_10.npy')
embedding_matrix=np.vstack((jhu_embeddings,ucsf_embeddings,washu_embeddings))
image_embeddings=embedding_matrix.tolist()
image_labels=jhu_scores+ucsf_scores+washu_scores
# image_embeddings=np.load('image_embeddings_aug_resnet.npy').tolist()
# image_labels=scores
# image_embeddings=image_embeddings[:-940]
# image_labels=image_labels[:-940]
print(len(image_embeddings))
print(len(image_labels))

#Split the data
# Every source folder contributes a run of consecutive augmented images that
# all carry the same label. A plain random split scatters those near-duplicate
# images over both train and test, which leaks test information into training
# and inflates the reported metrics. Splitting by group keeps each folder's
# images on one side only.
# IMAGES_PER_GROUP must equal the `replicate` value used when the score lists
# were built (22 is the default of create_aug_dataloaders_resnet_10).
# NOTE(review): assumes every folder really holds exactly that many images
# and that the saved embeddings follow the loader's path order - confirm.
IMAGES_PER_GROUP=22
group_ids=np.arange(len(image_labels))//IMAGES_PER_GROUP
splitter=GroupShuffleSplit(n_splits=1,test_size=0.2,random_state=42)
train_idx,test_idx=next(splitter.split(embedding_matrix,image_labels,groups=group_ids))
train_embeddings=embedding_matrix[train_idx].tolist()
test_embeddings=embedding_matrix[test_idx].tolist()
train_labels=[image_labels[i] for i in train_idx]
test_labels=[image_labels[i] for i in test_idx]

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier,plot_tree
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Fit a class-balanced random forest on the training embeddings and report
# its performance on the held-out embeddings.
rf_clf = RandomForestClassifier(
    n_estimators=200,
    max_depth=10,
    class_weight='balanced',
    random_state=40,
)
rf_clf.fit(train_embeddings, train_labels)

rf_predictions = rf_clf.predict(test_embeddings)
print(classification_report(test_labels, rf_predictions))
print(confusion_matrix(test_labels, rf_predictions))


In [None]:
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# Take the first tree of the fitted forest as a single example to visualise.
tree = rf_clf.estimators_[0]

# Large canvas so the nodes of the tree stay readable.
plt.figure(figsize=(20, 10))
plot_tree(tree, filled=True, feature_names=None, class_names=True)
plt.show()

### Heatmap generation

In [1]:
#Find the 5 highest ODI score images and the 5 lowest ones
# Find their embeddings
# Create the heatmap
import numpy as np
import os

def find_target_images(base_dirs=None, label_files=None):
    """
    Pair every image folder with its original score and sort by score.

    For each dataset key, the continuous scores are loaded from the ``.npy``
    file and matched by position with the sorted entries of the dataset's
    base directory. The 5 lowest- and 5 highest-scoring entries are printed.

    Args:
        base_dirs: Optional mapping of dataset key -> directory whose entries
            are the per-score image folders. Defaults to the JHU, UCSF and
            WashU directories used in this notebook.
        label_files: Optional mapping of dataset key -> ``.npy`` score file.
            Defaults to the matching JHU, UCSF and WashU score files.

    Returns:
        A list of ``(score, folder_path)`` tuples for ALL matched folders,
        sorted by ascending score. A dataset whose folder count differs from
        its score count is reported with a warning and left out.
    """
    if base_dirs is None:
        base_dirs = {
            'jhu': '/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/JPG',
            'ucsf': '/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF',
            'washu': '/mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/final_images'
        }
    if label_files is None:
        label_files = {
            'jhu': '/home/blu/ai/xray_score/JHU_UCSF_WASHU/jhu.npy',
            'ucsf': '/home/blu/ai/xray_score/JHU_UCSF_WASHU/ucsf.npy',
            'washu': '/home/blu/ai/xray_score/JHU_UCSF_WASHU/washu.npy'
        }

    all_images_with_scores = []

    for key, base_dir in base_dirs.items():
        # Load the original continuous scores
        scores = np.load(label_files[key])

        # One path per entry of the base directory, sorted so that entry i
        # lines up with score i. Every entry is used; nothing is filtered.
        image_paths = [
            os.path.join(base_dir, img_folder)
            for img_folder in sorted(os.listdir(base_dir))
        ]

        # Scores and folders are matched by position only, so the counts
        # must agree; otherwise the whole dataset is skipped.
        if len(scores) != len(image_paths):
            print(f"Warning: Mismatch in {key} dataset. Found {len(image_paths)} images and {len(scores)} scores.")
            continue

        all_images_with_scores.extend(zip(scores, image_paths))

    # Sort the list based on the score (the first element in the tuple)
    all_images_with_scores.sort(key=lambda x: x[0])

    # Get the 5 lowest and 5 highest (for display only)
    lowest_5 = all_images_with_scores[:5]
    highest_5 = all_images_with_scores[-5:]

    print("--- 5 Images with Lowest ODI Scores ---")
    for score, path in lowest_5:
        print(f"Score: {score:.2f}, Path: {path}")

    print("\n--- 5 Images with Highest ODI Scores ---")
    for score, path in highest_5:
        print(f"Score: {score:.2f}, Path: {path}")

    return all_images_with_scores


# Run the lookup. The returned list holds every (score, folder) pair sorted by
# score, not only the 5 lowest / 5 highest entries that get printed.
target_images = find_target_images()

--- 5 Images with Lowest ODI Scores ---
Score: 0.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/final_images/WUNCPXY0072-TC
Score: 4.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/JPG/JHUCV0050_CK
Score: 6.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/JPG/JHUCV0029_CL
Score: 6.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_WashU/final_images/WUNCPXY0010-CJ
Score: 10.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_JHU/JPG/JHUCV0051_CS

--- 5 Images with Highest ODI Scores ---
Score: 84.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF/USFCPXY0090-ML
Score: 84.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF/USFCPXY0138-JC
Score: 86.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF/USFCPXY0163-RK
Score: 86.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_UCSF/UCSF/USFCPXY0166-BS
Score: 94.00, Path: /mnt/c/Users/swapnil/Downloads/XRAY_Dataset/XRAY_U

In [None]:
! pip install grad-cam

In [2]:
import torch
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import matplotlib.pyplot as plt
import cv2

def generate_and_plot_heatmap(image_path, score, cam_model, target_layer):
    """
    Generate a CAM heatmap overlay for one image and save it under ``heatmaps/``.

    Args:
        image_path: Path of the image file to explain.
        score: Score associated with the image; only used in the output
            file name.
        cam_model: CAM object, called as
            ``cam_model(input_tensor=..., targets=None)``.
        target_layer: Not used inside this function; kept so existing
            callers keep working.

    Returns:
        None. Prints an error and returns early when the image file does not
        exist; otherwise writes ``heatmaps/<score>_<image file name>``.
    """
    try:
        # Convert inside the context manager so the file handle is released.
        with Image.open(image_path) as handle:
            original_img = handle.convert('RGB')
    except FileNotFoundError:
        print(f"Error: Could not find image at {image_path}")
        return

    # The whole image is resized to 224x224 (no centre crop): the heatmap is
    # later stretched back over the full original image, so the network must
    # see the full image for the overlay to line up with the pixels.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    input_tensor = preprocess(original_img).unsqueeze(0)

    # Generate the heatmap. We pass `targets=None` so Grad-CAM automatically
    # picks the class with the highest score from the ImageNet model.
    # Index 0 selects the heatmap of the single image in the batch.
    grayscale_cam = cam_model(input_tensor=input_tensor, targets=None)[0, :]

    # Overlay the heatmap on the original image scaled to [0, 1] (float32).
    visualization_img = np.asarray(original_img, dtype=np.float32) / 255.0
    # PIL's size and cv2's dsize are both (width, height).
    resized_heatmap = cv2.resize(grayscale_cam, original_img.size)
    cam_image = show_cam_on_image(visualization_img, resized_heatmap, use_rgb=True)

    # NOTE(review): the name is built from the score and the file's base name
    # only, so two files with the same name and score in different folders
    # overwrite each other - confirm whether that can happen in this dataset.
    filename = f'{score}_{os.path.basename(image_path)}'
    os.makedirs('heatmaps', exist_ok=True)
    save_path = os.path.join('heatmaps', filename)

    fig, ax = plt.subplots()
    try:
        ax.imshow(cam_image)
        ax.axis('off')
        fig.savefig(save_path,bbox_inches='tight', pad_inches=0)
    finally:
        # Always close the figure, even if saving fails, to release memory.
        plt.close(fig)

# --- Main Execution for Heatmap Generation ---

# 1. Load pre-trained ResNet50 and define the target layer
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()
target_layer = [model.layer4[-1]] # The last block of layer4

# 2. Initialize GradCAM
cam = GradCAM(model=model, target_layers=target_layer)
os.makedirs('heatmaps',exist_ok=True)
# 3. Loop through every (score, folder) pair in target_images and generate a
#    heatmap for each file found in the folder
print("\n--- Generating Heatmaps ---")
if 'target_images' in locals() and target_images:
    for idx, (score, folder_path) in enumerate(target_images, start=1):
        print(f'idx: {idx}')
        # A dedicated loop name is used so the notebook-level `images` list is
        # not overwritten; sorting makes the processing order reproducible.
        for file_name in sorted(os.listdir(folder_path)):
            img_path=os.path.join(folder_path,file_name)
            generate_and_plot_heatmap(img_path, score, cam, target_layer)
else:
    print("Could not find target images. Please run the first part of the script successfully.")


--- Generating Heatmaps ---
idx: 1
idx: 2
idx: 3
idx: 4
idx: 5
idx: 6
idx: 7
idx: 8
idx: 9
idx: 10
idx: 11
idx: 12
idx: 13
idx: 14
idx: 15
idx: 16
idx: 17
idx: 18
idx: 19
idx: 20
idx: 21
idx: 22
idx: 23
idx: 24
idx: 25
idx: 26
idx: 27
idx: 28
idx: 29
idx: 30
idx: 31
idx: 32
idx: 33
idx: 34
idx: 35
idx: 36
idx: 37
idx: 38
idx: 39
idx: 40
idx: 41
idx: 42
idx: 43
idx: 44
idx: 45
idx: 46
idx: 47
idx: 48
idx: 49
idx: 50
idx: 51
idx: 52
idx: 53


: 