In [None]:
from pathlib import Path
import os
import numpy as np
import torch
import time
import sys
from lime import lime_image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image

import myextensions

In [None]:
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

model = myextensions.get_vgg()
# Inference only: put the network in eval mode so dropout (and any batch-norm)
# behaves deterministically while LIME queries it hundreds of times per image.
# This is a harmless no-op if get_vgg() already switched it to eval mode.
model.eval()

# Preprocessing transform for VGG: resize to 224x224, then PIL image -> CHW float
# tensor in [0, 1].
# NOTE(review): there is no transforms.Normalize(mean, std) step here; this assumes
# get_vgg() normalizes internally (or expects [0, 1] inputs) -- confirm.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Function to preprocess image batch for LIME
def batch_predict(images):
    """Classify a batch of images and return softmax probabilities for LIME.

    Used as LIME's ``classifier_fn``: each element of ``images`` is converted with
    ``PIL.Image.fromarray``, run through ``preprocess``, stacked into one batch and
    sent to ``myextensions.DEVICE``.

    Args:
        images: iterable of image arrays accepted by ``PIL.Image.fromarray``.

    Returns:
        NumPy array of class probabilities, one row per input image.
    """
    tensors = [preprocess(Image.fromarray(arr)) for arr in images]
    inputs = torch.stack(tensors).to(myextensions.DEVICE)

    with torch.no_grad():
        scores = torch.nn.functional.softmax(model(inputs), dim=1)

    # LIME consumes NumPy arrays, so move the result back to host memory first.
    return scores.cpu().numpy()

# LIME image explainer with library defaults. Dataset and output paths are set per
# mode in the batch loop further down.
# NOTE(review): with random_state left at its default, LIME presumably draws from
# NumPy's global RNG, which np.random.seed(...) above seeds -- confirm for the
# installed lime version.
explainer = lime_image.LimeImageExplainer()


In [None]:
def save_3_heatmaps(savepath, filename, explanation, input_image):
    """Save the LIME superpixel-weight heatmap for the top predicted label.

    Despite the name, two files are written (nothing goes to ``pos/`` or ``neg/``):

    * ``<savepath>/all/<filename>``: the signed weight heatmap on its own;
    * ``<savepath>/blended/<filename>``: the same heatmap (alpha 0.7) drawn over a
      grayscale copy of ``input_image`` (mean over the channel axis).

    The colour scale is symmetric around zero using the ``bwr`` colormap. The
    output format is chosen by matplotlib from ``filename``'s extension.

    Args:
        savepath: output root; its ``all/`` and ``blended/`` subfolders must exist.
        filename: file name used for both saved images.
        explanation: LIME image explanation exposing ``segments``, ``top_labels``
            and ``local_exp``.
        input_image: image array with channels on axis 2 (H, W, C) that the
            explanation was computed for.
    """
    filepath_all = os.path.join(savepath, "all", filename)
    filepath_blended = os.path.join(savepath, "blended", filename)

    # Superpixel label map, same spatial size as the image.
    segments = explanation.segments

    # Explanation weights for the top label: {superpixel_idx: weight}.
    top_label = explanation.top_labels[0]
    weights = dict(explanation.local_exp[top_label])

    # Paint every superpixel with its weight (0 where LIME assigned none). One
    # vectorized lookup replaces a full-image boolean mask per superpixel.
    seg_ids, seg_pos = np.unique(segments, return_inverse=True)
    seg_weights = np.array([weights.get(seg_id, 0) for seg_id in seg_ids], dtype=float)
    # reshape: np.unique's inverse is flat in some NumPy versions.
    heatmap = seg_weights[seg_pos].reshape(segments.shape)

    # Symmetric limits so that a zero weight sits at the centre of the colormap.
    vmax = np.max(np.abs(heatmap))
    vmin = -vmax

    fig, plt_axis = plt.subplots(figsize=(7, 7))
    try:
        plt_axis.xaxis.set_ticks_position("none")
        plt_axis.yaxis.set_ticks_position("none")
        plt_axis.set_yticklabels([])
        plt_axis.set_xticklabels([])
        plt_axis.grid(visible=False)
        plt_axis.imshow(heatmap, cmap='bwr', alpha=0.7, vmin=vmin, vmax=vmax)

        fig.savefig(filepath_all, bbox_inches='tight', pad_inches=0)

        # Opaque grayscale image covers the first heatmap; re-draw the heatmap on top.
        plt_axis.imshow(np.mean(input_image, axis=2), cmap="gray")
        plt_axis.imshow(heatmap, cmap='bwr', alpha=0.7, vmin=vmin, vmax=vmax)

        fig.savefig(filepath_blended, bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until closed; this runs once per image.
        plt.close(fig)


def print_time(savepath,timings, filenum):
    """Write summary statistics of per-image attribution times to a text report.

    The report is written (overwriting any existing file) to
    ``<savepath>/time_statistics<filenum>.txt``. It lists the count, mean,
    population standard deviation, minimum and maximum, formatted to 4 decimals.

    Args:
        savepath: directory that receives the report.
        timings: sequence of durations in seconds. It must be non-empty, because
            the min/max reductions fail on an empty array.
        filenum: suffix appended to the report file name.
    """
    durations = np.asarray(timings)
    report_path = f'{savepath}/time_statistics{filenum}.txt'

    # Reducers are evaluated lazily, line by line, while the file is open.
    stat_rows = (
        ("Average time:      ", durations.mean),
        ("Standard deviation:", durations.std),
        ("Minimum time:      ", durations.min),
        ("Maximum time:      ", durations.max),
    )

    with open(report_path, "w") as report:
        report.write("=== Attribution Time Stats ===\n")
        report.write(f"Total images:      {len(timings)}\n")
        for label, reducer in stat_rows:
            report.write(f"{label}{reducer():.4f} s\n")

In [None]:
# Explain every image of every attack mode with LIME, save the heatmaps and write
# per-mode timing statistics.

# LIME sampling settings shared by all modes.
TOP_LABELS = 5
NUM_SAMPLES = 500  # You can increase for better results (slower)

INPUT_ROOT = "../../inputs"
# Overridable so the notebook runs on other machines; defaults to the original path.
OUTPUT_ROOT = os.environ.get("LIME_OUTPUT_ROOT", "/home/cat/uni/bakis/outputs/lime")

timings = []

MODES = \
[
    "adversial-patch/success", "adversial-patch/fail",
    "feature-adversaries/success", "feature-adversaries/fail",
    "fgsm/success", "fgsm/fail",
    "no-attack-lime",
    "pgd/success", "pgd/fail",
    "shadow-attack-nontargeted/success", "shadow-attack-nontargeted/fail",
    "square-attack-l2/success", "square-attack-l2/fail",
    "square-attack-linf/success", "square-attack-linf/fail",
]

for mode in MODES:
    dataset_path = f"{INPUT_ROOT}/{mode}"
    save_dir_path = f"{OUTPUT_ROOT}/{mode}"

    for subdir in ("pos", "neg", "all", "blended"):
        Path(save_dir_path, subdir).mkdir(parents=True, exist_ok=True)

    for filename in sorted(os.listdir(dataset_path)):
        full_path = os.path.join(dataset_path, filename)

        input_image = np.array(Image.open(full_path).convert('RGB'))

        # Time only the attribution itself, not image loading or saving.
        start_time = time.perf_counter()

        explanation = explainer.explain_instance(
            input_image,
            batch_predict,
            top_labels=TOP_LABELS,
            num_samples=NUM_SAMPLES,
        )

        elapsed = time.perf_counter() - start_time
        # Previously computed but never recorded, so no stats were ever written.
        timings.append(elapsed)

        # Best effort: a failed save is reported and the batch continues.
        try:
            save_3_heatmaps(save_dir_path, filename, explanation, input_image)
        except Exception as exc:
            print(f"Exception with: {full_path}: {exc!r}")

    # One report per mode; reset so the next mode starts fresh.
    if timings:
        print_time(save_dir_path, timings, 0)
        timings = []