In [3]:
import os
import base64
from io import BytesIO
from openai import AzureOpenAI
from PIL import Image
import matplotlib.pyplot as plt
import torch
from captum.attr import LayerGradCam
from torchvision import models, transforms
import numpy as np

# -------------------------
# 1. Azure OpenAI Setup
# -------------------------
# The API key is read from the environment so that no secret is stored in the
# notebook. Export AZURE_OPENAI_API_KEY before starting the kernel.
# NOTE(review): any key that was ever committed inside this notebook should be
# treated as leaked and rotated.
if not os.environ.get("AZURE_OPENAI_API_KEY"):
    raise RuntimeError(
        "AZURE_OPENAI_API_KEY is not set; export it before running this cell."
    )
# The endpoint is not a secret: keep the project default, but let an exported
# AZURE_OPENAI_ENDPOINT take precedence.
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://carqa-recitals-article.openai.azure.com/")
AZURE_OPENAI_API_VERSION = "2024-02-01"
GPT_DEPLOYMENT_NAME = "gpt-4o"

client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"]
)


# 2. Load ResNet Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classes = ["Benign", "Early", "Pre", "Pro"]
# Derived from the label list so the classifier head can never disagree with it.
num_classes = len(classes)

# Preprocessing applied to every query image: resize to 224x224, convert to a
# tensor and normalise each channel with the mean/std below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Build the bare architecture; `weights=None` is the current spelling of the
# deprecated `pretrained=False` (no weights are downloaded).
model = models.resnet18(weights=None)
# Replace the stock classifier with the 256-unit hidden layer + dropout head
# whose parameters are stored in the checkpoint loaded below.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(256, num_classes)
)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
# NOTE(review): the checkpoint path is machine-specific; consider a configurable location.
model.load_state_dict(torch.load(
    r"C:\Users\94718\Downloads\CW\ANN-based-Acute-Lymphoblastic-Leukemia-Classifier\model\resnet18_cancer_model.pth",
    map_location=device,
    weights_only=True,
))
model = model.to(device)
# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()


# 3. Grad-CAM Setup

# Attributions are computed at the second convolution of block 1 in layer4.
target_layer = model.layer4[1].conv2
gradcam = LayerGradCam(forward_func=model, layer=target_layer)

def generate_gradcam(image_tensor, label_idx=None):
    """Classify one image and render its Grad-CAM overlay as an in-memory PNG.

    Args:
        image_tensor: One preprocessed image of shape (3, H, W) without a batch
            dimension; the batch axis is added here.
        label_idx: Class index to explain. When None, the model's arg-max
            prediction is explained.

    Returns:
        tuple: ``(pred_class, buf)`` where ``pred_class`` is the arg-max class
        index and ``buf`` is a ``BytesIO`` rewound to position 0 that holds the
        PNG of the image with the heatmap drawn on top.
    """
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # The prediction pass needs no autograd graph; the attribution call below
    # performs its own forward/backward pass.
    with torch.no_grad():
        output = model(image_tensor)
    pred_class = output.argmax(1).item()
    label_idx = label_idx if label_idx is not None else pred_class

    # NOTE(review): LayerGradCam.attribute is used with its defaults, which (per
    # the captum docs) do not apply ReLU to the attributions - confirm intended.
    attr = gradcam.attribute(image_tensor, target=label_idx)
    # Upsample the coarse layer map to the input's own spatial size instead of
    # a hard-coded 224x224, so other input resolutions still line up.
    attr_upsampled = torch.nn.functional.interpolate(
        attr, size=tuple(image_tensor.shape[-2:]), mode='bilinear', align_corners=False
    ).squeeze().cpu().detach().numpy()

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing by
    # it would fill the heatmap with NaN, so fall back to an all-zero map.
    attr_min = attr_upsampled.min()
    attr_range = attr_upsampled.max() - attr_min
    if attr_range > 0:
        attr_upsampled = (attr_upsampled - attr_min) / attr_range
    else:
        attr_upsampled = np.zeros_like(attr_upsampled)

    # Undo the channel normalisation so the photo is displayed in [0, 1] RGB.
    img_np = image_tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_np = np.clip(img_np * std + mean, 0, 1)

    fig, ax = plt.subplots()
    try:
        ax.imshow(img_np)
        ax.imshow(attr_upsampled, cmap='jet', alpha=0.5)
        ax.axis('off')

        # Save Grad-CAM figure to bytes
        buf = BytesIO()
        fig.savefig(buf, format='PNG')
        buf.seek(0)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    return pred_class, buf


# 4. Load Query Image

# Blood-smear image to classify and explain (replace with your image).
query_image_path = r"C:\Users\94718\OneDrive\Desktop\NIBM\AI\CourseWork\code\try2\Preprocessed\Early\WBC-Malignant-Early-005.jpg"
img = Image.open(query_image_path)
img = img.convert("RGB")
img_tensor = transform(img)

# Classify the image and render the Grad-CAM overlay for the predicted class.
pred_class_idx, gradcam_buf = generate_gradcam(img_tensor)

# Base64 text of the overlay PNG, ready to be embedded in a data URL.
gradcam_base64 = str(base64.b64encode(gradcam_buf.read()), "utf-8")


# 5. GPT Explanation Prompt

# Instruction text; the predicted class name is interpolated into it.
prompt = f"""
You are a medical AI assistant. 
A CNN model predicted the class '{classes[pred_class_idx]}' for the provided peripheral blood smear image. 
Here is the Grad-CAM heatmap highlighting important regions (attached). 
Explain in detail why the model predicted this class, which regions were important, 
and provide a concise clinical-style interpretation of the image.
"""

# The overlay PNG travels inline as a base64 data URL.
gradcam_data_url = "data:image/png;base64," + gradcam_base64

system_message = {
    "role": "system",
    "content": "You are a helpful AI for cancer image interpretation.",
}
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": gradcam_data_url}},
    ],
}

response = client.chat.completions.create(
    model=GPT_DEPLOYMENT_NAME,
    messages=[system_message, user_message],
)


# 6. Display Results

# Show the classifier's label, then the text of the first returned choice.
print(f"Predicted Class: {classes[pred_class_idx]}")
print("GPT Explanation:", end="\n\n")
print(response.choices[0].message.content)


Predicted Class: Early
GPT Explanation:

The attached Grad-CAM heatmap provides valuable insights into the regions of the peripheral blood smear image that were key in the model's prediction of the class 'Early'. Grad-CAM, or Gradient-weighted Class Activation Mapping, visualizes which parts of the image contribute most to the model’s decision by highlighting them with warmer colors, such as red and yellow.

### Model Prediction:
The model predicted the class 'Early' for the provided peripheral blood smear image. This class likely corresponds to an early stage of a hematologic condition that is apparent through subtle changes in the morphology and distribution of cells in the blood smear.

### Important Regions Highlighted:
- **Warmer Colors (Red/Yellow Areas):** These areas are the regions that the model focused on most during its decision-making process. In this image, warmer colors are concentrated on certain cells within the smear.
- **Cell Morphology:** The cells highlighted as im

In [4]:
import os
import base64
from io import BytesIO
from openai import AzureOpenAI
from PIL import Image
import matplotlib.pyplot as plt
import torch
from captum.attr import LayerGradCam
from torchvision import models, transforms
import numpy as np


keys_file = "azure_keys.txt"

# Read the KEY=VALUE settings file. "utf-8-sig" drops a leading byte-order
# mark, which would otherwise become part of the first key name.
with open(keys_file, "r", encoding="utf-8-sig") as f:
    lines = f.read().splitlines()

# Whitespace around names and values is stripped so "KEY = value" is accepted;
# blank lines, "#" comments and lines without "=" are skipped.
key_dict = {}
for line in lines:
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
        continue
    name, value = line.split("=", 1)
    key_dict[name.strip()] = value.strip()

# A value from the file wins; otherwise an already-exported environment
# variable is kept rather than being overwritten with an empty string.
# Missing credentials fail here instead of on the first API request.
for name in ("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"):
    value = key_dict.get(name) or os.environ.get(name, "")
    if not value:
        raise RuntimeError(f"{name} is missing from {keys_file!r} and from the environment.")
    os.environ[name] = value
AZURE_OPENAI_API_VERSION = key_dict.get("AZURE_OPENAI_API_VERSION") or "2024-02-01"
GPT_DEPLOYMENT_NAME = key_dict.get("GPT_DEPLOYMENT_NAME") or "gpt-4o"

client = AzureOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=AZURE_OPENAI_API_VERSION,
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"]
)

# Load the ResNet-18 classifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classes = ["Benign", "Early", "Pre", "Pro"]
# Derived from the label list so the classifier head can never disagree with it.
num_classes = len(classes)

# Preprocessing applied to every query image: resize to 224x224, convert to a
# tensor and normalise each channel with the mean/std below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Build the bare architecture; `weights=None` is the current spelling of the
# deprecated `pretrained=False` (no weights are downloaded).
model = models.resnet18(weights=None)
# Replace the stock classifier with the 256-unit hidden layer + dropout head
# whose parameters are stored in the checkpoint loaded below.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(256, num_classes)
)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
# NOTE(review): the checkpoint path is machine-specific; consider a configurable location.
model.load_state_dict(torch.load(
    r"C:\Users\94718\Downloads\CW\ANN-based-Acute-Lymphoblastic-Leukemia-Classifier\model\resnet18_cancer_model.pth",
    map_location=device,
    weights_only=True,
))
model = model.to(device)
# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()

# Grad-CAM: attributions are computed at the second convolution of block 1 in layer4.
target_layer = model.layer4[1].conv2
gradcam = LayerGradCam(forward_func=model, layer=target_layer)

def generate_gradcam(image_tensor, label_idx=None):
    """Classify one image and render its Grad-CAM overlay as an in-memory PNG.

    Args:
        image_tensor: One preprocessed image of shape (3, H, W) without a batch
            dimension; the batch axis is added here.
        label_idx: Class index to explain. When None, the model's arg-max
            prediction is explained.

    Returns:
        tuple: ``(pred_class, buf)`` where ``pred_class`` is the arg-max class
        index and ``buf`` is a ``BytesIO`` rewound to position 0 that holds the
        PNG of the image with the heatmap drawn on top.
    """
    image_tensor = image_tensor.unsqueeze(0).to(device)

    # The prediction pass needs no autograd graph; the attribution call below
    # performs its own forward/backward pass.
    with torch.no_grad():
        output = model(image_tensor)
    pred_class = output.argmax(1).item()
    label_idx = label_idx if label_idx is not None else pred_class

    # NOTE(review): LayerGradCam.attribute is used with its defaults, which (per
    # the captum docs) do not apply ReLU to the attributions - confirm intended.
    attr = gradcam.attribute(image_tensor, target=label_idx)
    # Upsample the coarse layer map to the input's own spatial size instead of
    # a hard-coded 224x224, so other input resolutions still line up.
    attr_upsampled = torch.nn.functional.interpolate(
        attr, size=tuple(image_tensor.shape[-2:]), mode='bilinear', align_corners=False
    ).squeeze().cpu().detach().numpy()

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing by
    # it would fill the heatmap with NaN, so fall back to an all-zero map.
    attr_min = attr_upsampled.min()
    attr_range = attr_upsampled.max() - attr_min
    if attr_range > 0:
        attr_upsampled = (attr_upsampled - attr_min) / attr_range
    else:
        attr_upsampled = np.zeros_like(attr_upsampled)

    # Undo the channel normalisation so the photo is displayed in [0, 1] RGB.
    img_np = image_tensor.squeeze(0).detach().cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_np = np.clip(img_np * std + mean, 0, 1)

    fig, ax = plt.subplots()
    try:
        ax.imshow(img_np)
        ax.imshow(attr_upsampled, cmap='jet', alpha=0.5)
        ax.axis('off')

        # Save Grad-CAM figure to bytes
        buf = BytesIO()
        fig.savefig(buf, format='PNG')
        buf.seek(0)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    return pred_class, buf

# Query image: the blood-smear picture to classify and explain.
query_image_path = r"C:\Users\94718\OneDrive\Desktop\NIBM\AI\CourseWork\code\try2\Preprocessed\Early\WBC-Malignant-Early-005.jpg"
img = Image.open(query_image_path)
img = img.convert("RGB")
img_tensor = transform(img)

# Classify the image and render the Grad-CAM overlay for the predicted class.
pred_class_idx, gradcam_buf = generate_gradcam(img_tensor)

# Base64 text of the overlay PNG, ready to be embedded in a data URL.
gradcam_base64 = str(base64.b64encode(gradcam_buf.read()), "utf-8")

# GPT prompt: instruction text with the predicted class name interpolated.
prompt = f"""
You are a medical AI assistant. 
A CNN model predicted the class '{classes[pred_class_idx]}' for the provided peripheral blood smear image. 
Here is the Grad-CAM heatmap highlighting important regions (attached). 
Explain in detail why the model predicted this class, which regions were important, 
and provide a concise clinical-style interpretation of the image.
"""

# The overlay PNG travels inline as a base64 data URL.
gradcam_data_url = "data:image/png;base64," + gradcam_base64

system_message = {
    "role": "system",
    "content": "You are a helpful AI for cancer image interpretation.",
}
user_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": gradcam_data_url}},
    ],
}

response = client.chat.completions.create(
    model=GPT_DEPLOYMENT_NAME,
    messages=[system_message, user_message],
)

# Results: the classifier's label, then the text of the first returned choice.
print(f"Predicted Class: {classes[pred_class_idx]}")
print("GPT Explanation:", end="\n\n")
print(response.choices[0].message.content)


Predicted Class: Early
GPT Explanation:

### Detailed Explanation of Model Prediction:

The provided Grad-CAM (Gradient-weighted Class Activation Mapping) heatmap is used to visually interpret the CNN model's decision-making process. The heatmap highlights regions in the image that were most influential in the model's prediction. 

1. **Model's Prediction:**
   - The CNN model has predicted the class 'Early' for the provided peripheral blood smear image. This suggests that the model has identified features typical of early-stage pathology in the blood smear.

2. **Important Regions Highlighted:**
   - **Red/orange regions:** Indicate areas of high importance in the model's prediction. These areas have more significant weight in determining that the image belongs to the 'Early' class.
   - **Blue regions:** Indicate areas of lesser importance.

From the heatmap:
 - **Central region of the image:** This area contains red/orange hues, indicating the model's substantial focus on cellular f