In [None]:
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.models import VGG16_Weights

from lime import lime_image
from skimage.segmentation import mark_boundaries
from skimage.color import label2rgb


# Fall back to CPU so the notebook also runs on machines without a CUDA GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

weights = VGG16_Weights.DEFAULT
imagenet_classes = weights.meta['categories']  # class index -> human-readable ImageNet label

model = models.vgg16(weights=weights).to(device)
model.eval()

# Per-channel ImageNet statistics used by torchvision's pretrained classification weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Squash (no center crop) to 224x224 so every superpixel LIME perturbs stays visible to
# the model, then normalize the way the pretrained weights expect; without Normalize the
# model sees inputs in [0, 1] and its predictions (and thus the explanations) are skewed.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def batch_predict(images):
    """Classifier callback for LIME: score a batch of RGB image arrays.

    Each array in ``images`` is turned into a PIL image, run through the
    module-level ``preprocess`` pipeline, and the stacked batch is fed to
    ``model`` on ``device`` without gradient tracking.

    Args:
        images: iterable of image arrays accepted by ``PIL.Image.fromarray``.

    Returns:
        NumPy array of per-class softmax probabilities, one row per image.
    """
    model.eval()

    tensors = [preprocess(Image.fromarray(arr)) for arr in images]
    inputs = torch.stack(tensors).to(device)

    with torch.no_grad():
        scores = model(inputs)
        probabilities = scores.softmax(dim=1)

    return probabilities.cpu().numpy()

img_path = 'abba5.jpg'
# Context manager closes the underlying file handle once the pixels are loaded.
with Image.open(img_path) as pil_image:
    original_image = np.array(pil_image.convert('RGB'))

# LIME draws random perturbations; a fixed seed makes the explanation reproducible
# across Restart & Run All.
LIME_SEED = 42
explainer = lime_image.LimeImageExplainer(random_state=LIME_SEED)

explanation = explainer.explain_instance(
    original_image,
    batch_predict,
    top_labels=5,
    num_samples=2000,
)


In [None]:
# Build a per-pixel map of LIME weights for the top predicted class.
segments = explanation.segments  # 2D array of superpixel ids, same size as image

top_label = explanation.top_labels[0]
# Distinct name so the VGG16 `weights` object from the first cell is not clobbered.
superpixel_weights = dict(explanation.local_exp[top_label])  # { superpixel_idx: weight }

# Lookup table indexed by superpixel id (ids missing from the explanation stay 0),
# then a single fancy-indexing pass instead of one full-image mask per superpixel.
weight_lut = np.zeros(segments.max() + 1)
for seg_idx, seg_weight in superpixel_weights.items():
    weight_lut[seg_idx] = seg_weight
heatmap = weight_lut[segments]

# Min-max scaled copy in [0, 1]; the signed `heatmap` is what the diverging plot uses.
heatmap_range = np.max(heatmap) - np.min(heatmap)
if heatmap_range != 0:
    heatmap_norm = (heatmap - np.min(heatmap)) / heatmap_range
else:
    heatmap_norm = heatmap

# Symmetric limits so a weight of 0 maps to the neutral midpoint of the colormap.
vmax = np.max(np.abs(heatmap))
if vmax == 0:
    vmax = 1.0  # all-zero explanation: avoid a degenerate vmin == vmax color range
vmin = -vmax



top_label_name = imagenet_classes[top_label]

fig, ax = plt.subplots(figsize=(8, 8))
# Draw the photo underneath so the semi-transparent heatmap has visual context.
ax.imshow(original_image)
heatmap_img = ax.imshow(heatmap, cmap='bwr', alpha=0.5, vmin=vmin, vmax=vmax)  # red=positive, blue=negative
ax.set_title(f"LIME Heatmap (Strength) for {top_label_name}")
ax.axis('off')
# Bind the colorbar explicitly to the heatmap layer, not the underlying photo.
fig.colorbar(heatmap_img, ax=ax, label='Importance Weight')
plt.show()

In [None]:
top_label = explanation.top_labels[0]
# Outline the 5 superpixels LIME selects for this label (positive_only=False keeps
# both supporting and opposing regions) on top of the full, unmasked image.
temp, mask = explanation.get_image_and_mask(
    top_label, positive_only=False, num_features=5, hide_rest=False
)
img_boundary = mark_boundaries(temp / 255.0, mask)  # scale uint8 pixels to [0, 1] floats

plt.imshow(img_boundary)
# Show the human-readable class name rather than the bare ImageNet index.
plt.title(f'LIME Explanation for class {imagenet_classes[top_label]}')
plt.axis('off')
plt.show()