# In [1]:
import os
import numpy as np
from lime.lime_image import LimeImageExplainer
from PIL import Image
import torch
import matplotlib.pyplot as plt
from ultralytics import YOLO  # Assuming you're using the Ultralytics YOLOv8 implementation

# Load the YOLOv8 detector (replace 'yolov8m.pt' with your actual weights path).
model = YOLO('yolov8m.pt')
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: no explicit model.eval(). The ultralytics predictor switches the network
# to inference mode itself. Also, ultralytics' `Model` overrides `train()`, and
# the inherited `nn.Module.eval()` is implemented as `self.train(False)`. So on
# ultralytics versions without a dedicated `eval()`, that call can dispatch into
# the training entry point instead of just toggling the mode.

def preprocess(image):
    """Turn an RGB image into a float32 tensor with values scaled to [0, 1].

    Args:
        image: A PIL image, or any array-like of pixel values in [0, 255].

    Returns:
        A float32 ``torch.Tensor`` with the same layout as ``np.array(image)``.
        For an RGB PIL image that is (H, W, 3); no channel reordering is done.
    """
    # Hook for any extra YOLOv8-specific preprocessing, if ever needed.
    pixels = np.array(image)
    scaled = np.true_divide(pixels, 255.0)
    return torch.tensor(scaled, dtype=torch.float32)

def predict(input_image):
    """Score a batch of images for LIME as per-class detection confidences.

    LIME's ``explain_instance`` calls this with a stack of perturbed copies of
    the image being explained, and expects one row of class scores per image.
    YOLO is a detector rather than a classifier. Each row is therefore built by
    taking, for every class, the highest confidence among the boxes predicted
    for that class (0 when the class is not detected).

    Args:
        input_image: Array-like of shape (N, H, W, 3), or a single (H, W, 3)
            image, holding RGB values in [0, 1] (the layout ``preprocess``
            produces).

    Returns:
        ``np.ndarray`` of shape (N, len(model.names)), dtype float32.
    """
    batch = np.asarray(input_image, dtype=np.float32)
    if batch.ndim == 3:  # accept a single image as well as a batch
        batch = batch[np.newaxis]

    # Ultralytics treats numpy inputs as HWC uint8 frames in BGR channel order
    # (the OpenCV convention). So undo the [0, 1] scaling and flip RGB -> BGR.
    # Resizing/letterboxing, device placement and inference mode are then
    # handled by the predictor, which avoids the stride/shape constraints of
    # passing raw tensors.
    frames = [
        np.ascontiguousarray(
            (np.clip(img, 0.0, 1.0) * 255.0).round().astype(np.uint8)[..., ::-1]
        )
        for img in batch
    ]

    num_classes = len(model.names)
    scores = np.zeros((len(frames), num_classes), dtype=np.float32)
    if not frames:
        return scores

    results = model(frames, verbose=False)
    for row, result in zip(scores, results):
        boxes = result.boxes
        if boxes is None or len(boxes) == 0:
            continue  # nothing detected: all class scores stay 0
        classes = boxes.cls.cpu().numpy().astype(int)
        confidences = boxes.conf.cpu().numpy().astype(np.float32)
        # `row` is a view into `scores`; keep the max confidence per class.
        np.maximum.at(row, classes, confidences)
    return scores

def _rescale_to_unit(arr):
    """Min-max rescale ``arr`` to [0, 1]; clip instead when it is constant (avoids 0/0 -> NaN)."""
    lo, hi = np.min(arr), np.max(arr)
    if hi > lo:
        return (arr - lo) / (hi - lo)
    return np.clip(arr, 0.0, 1.0)


def _explain_image(explainer, image_path, save_path, num_samples, num_features):
    """Run LIME on one image file and save the highlighted explanation to ``save_path``."""
    # Close the file handle promptly; convert() returns an independent, loaded image.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    # preprocess() returns an (H, W, 3) float tensor in [0, 1], which is exactly
    # the layout LIME expects. No batch dimension or channel permute is needed:
    # permuting it as if it were CHW scrambles the image into (W, 3, H).
    pixels = preprocess(rgb).numpy()

    explanation = explainer.explain_instance(
        pixels,
        predict,
        top_labels=5,
        num_samples=num_samples,
    )

    # Highlight the segments most relevant to the top-scoring label.
    lime_image, _mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False,
        num_features=num_features,
        hide_rest=False,
    )
    lime_image = _rescale_to_unit(lime_image)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # The explanation keeps the input's (H, W) size, so it is saved directly.
    # There is no lossy JPEG re-open/resize round trip.
    plt.imsave(save_path, lime_image)


def generate_lime(
    image_path=None,
    save_path=None,
    *,
    input_dir="data/test/Task 1/",
    output_dir="docs/evaluation/lime/",
    num_samples=1000,
    num_features=10,
):
    """Generate and save LIME explanations of the YOLO model's predictions.

    Args:
        image_path: Image to explain. If None, every regular file in
            ``input_dir`` is explained instead.
        save_path: Output path for the explanation when ``image_path`` is
            given. Defaults to ``<output_dir>/<image stem>.jpg``. Ignored in
            directory mode.
        input_dir: Directory scanned when ``image_path`` is None.
        output_dir: Directory for explanations in directory mode, and for the
            default ``save_path``.
        num_samples: Number of perturbed samples LIME evaluates per image.
        num_features: Number of superpixels highlighted in the explanation.
    """
    explainer = LimeImageExplainer()

    if image_path is not None:
        print("Processing", image_path)
        if save_path is None:
            stem = os.path.splitext(os.path.basename(image_path))[0]
            save_path = os.path.join(output_dir, f"{stem}.jpg")
        _explain_image(explainer, image_path, save_path, num_samples, num_features)
        return

    for image_file in sorted(os.listdir(input_dir)):
        file_path = os.path.join(input_dir, image_file)
        if not os.path.isfile(file_path):
            continue  # skip sub-directories and other non-regular entries
        print("Processing", image_file)
        image_name = os.path.splitext(image_file)[0]
        _explain_image(
            explainer,
            file_path,
            os.path.join(output_dir, f"{image_name}.jpg"),
            num_samples,
            num_features,
        )

# Example usage. The paths below are placeholders: point them at a real image
# and output file. The guard keeps an import of this module from running it.
if __name__ == "__main__":
    generate_lime(image_path='path/to/your/image.jpg', save_path='path/to/save/lime_image.jpg')


# KeyboardInterrupt: