In [None]:
!pip install timm==0.9.12 torch torchvision --quiet
!pip install captum lime scikit-image matplotlib --quiet
!pip install --quiet scikit-image

In [None]:
import os, numpy as np, matplotlib.pyplot as plt
from matplotlib import gridspec
from tqdm import tqdm
import torch, torch.nn.functional as F
import timm
from timm.data import create_transform
from collections import Counter
from PIL import Image
from skimage.segmentation import slic, mark_boundaries
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score


In [None]:
# Mount Google Drive (Colab only) so the paths under /content/drive used by the
# later cells (model checkpoint, image folders, output directory) are reachable.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Inference device: use the GPU when CUDA is visible, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Input size handed to the timm transform factory.
IMG_SIZE = 300

# Evaluation-mode (is_training=False) preprocessing built by timm, with bicubic
# interpolation and the mean/std normalisation constants given below.
TRANSFORM = create_transform(
    input_size=IMG_SIZE,
    is_training=False,
    interpolation='bicubic',
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)


In [None]:
# Location of the fine-tuned weights on the mounted Drive.
CHECKPOINT = '/content/drive/MyDrive/1430project/best_model.pth'
MODEL_NAME = 'efficientnet_b3'

# Build the architecture with a 2-way head and no pretrained weights; the
# checkpoint below supplies every parameter.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=2).to(DEVICE)

# The file is consumed directly as a state_dict, so restrict unpickling to
# tensors/containers (weights_only=True) instead of allowing arbitrary objects.
state_dict = torch.load(CHECKPOINT, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

# Switch to inference behaviour before any prediction/attribution is computed.
model.eval()

# Index i is the display name used for model output class i.
class_names = ['fake', 'real']


In [None]:
# Ty to Allison for standardizing this explainer in advance so it was easy to choose all params and all that
class LimeExplainer:
    """LIME-style superpixel explainer for an image classifier.

    The image is split into SLIC superpixels, random subsets of superpixels
    are replaced by their grayscale version, the model is re-scored on every
    perturbed copy, and a ridge regression from the on/off masks to the target
    class probability yields one importance coefficient per superpixel.
    """

    def __init__(self, model, transform, device,
                 num_samples=1000, num_superpixels=50,
                 compactness=10, sigma=1, output_dir=None):
        """Store the model, its preprocessing and the sampling/segmentation settings.

        Args:
            model: Classifier called on a batched tensor; must return logits.
            transform: Callable turning a PIL image into a model input tensor.
            device: Device the input tensor is moved to before the forward pass.
            num_samples: Number of random perturbations drawn per explanation.
            num_superpixels: Requested number of SLIC segments.
            compactness: SLIC compactness parameter.
            sigma: SLIC pre-smoothing sigma.
            output_dir: Stored for callers; no method of this class reads it.
        """
        self.model           = model
        self.transform       = transform
        self.device          = device
        self.num_samples     = num_samples
        self.num_superpixels = num_superpixels
        self.compactness     = compactness
        self.sigma           = sigma
        self.output_dir      = output_dir

    @staticmethod
    def _normalize_importance(seg_imp):
        """Scale an importance array into [-1, 1] by its largest magnitude.

        Empty and all-zero inputs are returned unscaled (as an ndarray).
        """
        seg_imp = np.asarray(seg_imp)
        peak = np.max(np.abs(seg_imp)) if seg_imp.size else 0
        return seg_imp / peak if peak > 0 else seg_imp

    @staticmethod
    def _save_and_close(fig, save_path):
        """Write ``fig`` to ``save_path`` (skipped when it is None) and close that figure."""
        if save_path is not None:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        # Close this specific figure rather than whichever one happens to be current.
        plt.close(fig)

    def segment_image(self, image):
        """Split ``image`` into superpixels with SLIC.

        Returns:
            Integer label array; labels start at 1 (``start_label=1``).
        """
        img_array = np.array(image)
        segments  = slic(
            img_array,
            n_segments=self.num_superpixels,
            compactness=self.compactness,
            sigma=self.sigma,
            start_label=1
        )
        return segments

    def perturb_image(self, image, segments, perturb_mask):
        """Return a copy of ``image`` with every "off" superpixel grayed out.

        Args:
            image: PIL image with three colour channels.
            segments: Integer label array as produced by ``segment_image``.
            perturb_mask: Sequence of booleans; entry ``i`` tells whether the
                superpixel with label ``i + 1`` keeps its original pixels.

        Returns:
            New PIL image (uint8).
        """
        pixels = np.array(image).copy()
        # Per-pixel channel mean, repeated over the 3 channels.
        gray = np.mean(pixels, axis=2, keepdims=True).repeat(3, axis=2)

        # Lookup table indexed by label: True means "replace with gray".
        # Slot 0 stays False so any pixel with a label below 1 is left untouched.
        num_labels = max(int(np.max(segments)), 0)
        keep = np.asarray(perturb_mask, dtype=bool)
        turn_off = np.zeros(num_labels + 1, dtype=bool)
        turn_off[1:] = ~keep[:num_labels]
        off_pixels = turn_off[np.clip(segments, 0, None)]

        # One masked assignment instead of one full-image comparison per label.
        pixels[off_pixels] = gray[off_pixels]

        return Image.fromarray(pixels.astype(np.uint8))

    def get_model_prediction(self, image):
        """Run the model on one PIL image and return its softmax probabilities.

        Returns:
            1-D numpy array with one probability per class.
        """
        tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs  = F.softmax(logits, dim=1).cpu().numpy()[0]
        return probs

    def explain(self, image, target_class=None):
        """Estimate per-superpixel importance for ``target_class``.

        Args:
            image: PIL image to explain.
            target_class: Class index whose probability is explained; defaults
                to the model's predicted class for the unperturbed image.

        Returns:
            Tuple ``(segments, seg_imp, feat_imp, r2, pred_cls, initial_probs)``:
            the label map, a per-pixel importance map, one ridge coefficient
            per superpixel, the R^2 of the ridge fit on its own training
            samples, the predicted class index and the unperturbed probabilities.
        """
        initial_probs = self.get_model_prediction(image)
        pred_cls      = int(np.argmax(initial_probs))

        if target_class is None:
            target_class = pred_cls

        segments     = self.segment_image(image)
        num_segments = int(np.max(segments))

        perturbed_data, predictions = [], []
        for _ in tqdm(range(self.num_samples),
                      desc=f"Sampling for class {target_class}"):
            # Randomly switch superpixels on/off and record how that affects
            # the target-class probability.
            mask = np.random.randint(0, 2, num_segments, dtype=bool)
            pert_img = self.perturb_image(image, segments, mask)
            prob     = self.get_model_prediction(pert_img)[target_class]

            perturbed_data.append(mask)
            predictions.append(prob)

        perturbed_data = np.array(perturbed_data)
        predictions    = np.array(predictions)

        # Linear surrogate: probability ~ on/off mask (no sample weighting is applied).
        explainer = Ridge(alpha=1.0)
        explainer.fit(perturbed_data, predictions)
        feat_imp = explainer.coef_

        # Paint each superpixel with its coefficient; PIL size is (W, H), arrays are (H, W).
        seg_imp = np.zeros(image.size[::-1], dtype=np.float32)
        for seg_id in range(1, num_segments+1):
            seg_imp[segments == seg_id] = feat_imp[seg_id-1]

        r2 = r2_score(predictions, explainer.predict(perturbed_data))

        return segments, seg_imp, feat_imp, r2, pred_cls, initial_probs

    def visualize_explanation(self, image, segments, seg_imp,
                              class_label, probability, r2,
                              save_path=None):
        """Draw original image, superpixel boundaries and importance heat-map side by side.

        Args:
            image: PIL image that was explained.
            segments: Label map from ``explain``.
            seg_imp: Per-pixel importance map from ``explain``.
            class_label: Class name shown in the heat-map title.
            probability: Probability shown in the heat-map title.
            r2: Surrogate fit quality shown in the heat-map title.
            save_path: Where to write the figure; nothing is written when None.

        Returns:
            The (already closed) matplotlib figure.
        """
        img_array = np.array(image)
        seg_imp = self._normalize_importance(seg_imp)

        fig = plt.figure(figsize=(20, 15))
        gs  = gridspec.GridSpec(2, 3, height_ratios=[1, 0.05])

        # original
        ax = fig.add_subplot(gs[0, 0])
        ax.imshow(img_array)
        ax.set_title('Original Image', fontsize=14)
        ax.axis('off')

        # segmentation
        ax = fig.add_subplot(gs[0, 1])
        ax.imshow(mark_boundaries(img_array, segments))
        ax.set_title(f'Segmentation ({np.max(segments)} superpixels)',
                     fontsize=14)
        ax.axis('off')

        # heat-map
        ax = fig.add_subplot(gs[0, 2])
        cmap     = plt.cm.RdYlGn
        heat_img = ax.imshow(seg_imp, cmap=cmap, vmin=-1, vmax=1)
        ax.set_title(f'Importance Heat-map\nClass: {class_label}, '
            f'Prob: {probability:.4f}, R^2: {r2:.4f}',
            fontsize=14
        )
        ax.axis('off')

        # colour-bar spanning the full bottom row
        cax = fig.add_subplot(gs[1, :])
        fig.colorbar(heat_img, cax=cax, orientation='horizontal')
        cax.set_xlabel('Feature Importance (Red: Negative  Green: Positive)',
            fontsize=12)

        fig.tight_layout()

        self._save_and_close(fig, save_path)
        return fig

    def create_overlay_visualization(self, image, segments, seg_imp,
                                     class_label, probability,
                                     threshold=0.0, save_path=None):
        """Blend green (positive) and red (negative) importance over the image.

        Args:
            image: PIL image that was explained.
            segments: Label map from ``explain`` (not used for drawing here).
            seg_imp: Per-pixel importance map from ``explain``.
            class_label: Class name shown in the title.
            probability: Probability shown in the title.
            threshold: Normalised importance magnitude that must be exceeded
                for a pixel to be tinted.
            save_path: Where to write the figure; nothing is written when None.

        Returns:
            The (already closed) matplotlib figure.
        """
        img     = np.array(image).astype(float) / 255.0
        seg_imp = self._normalize_importance(seg_imp)

        pos = seg_imp > threshold
        neg = seg_imp < -threshold
        alpha = 0.5
        overlay = img.copy()
        if np.any(pos):
            # green overlay marks positively influential regions
            green = np.zeros_like(img)
            green[..., 1] = np.abs(seg_imp) * pos
            overlay = overlay * (1 - alpha * pos[..., None]) + green * alpha
        if np.any(neg):
            # red overlay marks negatively influential regions
            red = np.zeros_like(img)
            red[..., 0] = np.abs(seg_imp) * neg
            overlay = overlay * (1 - alpha * neg[..., None]) + red * alpha

        overlay = np.clip(overlay, 0, 1)

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(overlay)
        ax.set_title(f'LIME Overlay {class_label} ({probability:.4f})',
                     fontsize=14)
        ax.axis('off')

        self._save_and_close(fig, save_path)
        return fig

    def find_most_influential_regions(self, segments, feat_imp, top_n=5):
        """List the ``top_n`` superpixels with the largest absolute importance.

        Args:
            segments: Label map (unused; kept for a uniform call signature).
            feat_imp: One importance value per superpixel, label ``i + 1`` at index ``i``.
            top_n: Maximum number of regions to report.

        Returns:
            List of dicts with keys ``rank`` (1-based), ``segment_id``,
            ``importance`` and ``influence`` ('Positive' when the importance is
            strictly greater than zero, otherwise 'Negative'). Values are plain
            Python ``int``/``float`` so the result can be serialised directly.
        """
        importance = np.asarray(feat_imp, dtype=float)
        top_idx = np.argsort(np.abs(importance))[::-1][:top_n]

        regions = []
        for rank, idx in enumerate(top_idx, 1):
            value = float(importance[idx])
            regions.append({
                'rank'       : rank,
                'segment_id' : int(idx) + 1,
                'importance' : value,
                'influence'  : 'Positive' if value > 0 else 'Negative'
            })
        return regions

    def visualize_top_regions(self, image, segments, feat_imp,
                              top_n=5, save_path=None):
        """Show the segmentation plus one panel per top-ranked superpixel.

        Args:
            image: PIL image that was explained.
            segments: Label map from ``explain``.
            feat_imp: One importance value per superpixel.
            top_n: Number of region panels to reserve next to the overview panel.
            save_path: Where to write the figure; nothing is written when None.

        Returns:
            The (already closed) matplotlib figure.
        """
        img_array = np.array(image)
        abs_imp   = np.abs(feat_imp)
        top_idx   = np.argsort(abs_imp)[::-1][:top_n]

        # squeeze=False keeps a 2-D axes array even for a single column (top_n == 0).
        fig, axes = plt.subplots(1, top_n + 1, figsize=(20, 5), squeeze=False)
        axes = axes[0]

        axes[0].imshow(mark_boundaries(img_array, segments))
        axes[0].set_title('All Superpixels', fontsize=12)
        axes[0].axis('off')

        alpha = 0.5
        for i, idx in enumerate(top_idx):
            seg_id     = idx + 1
            importance = feat_imp[idx]
            mask       = segments == seg_id

            # Blend against a colour layer that is black outside the region:
            # the region is tinted green/red and the rest of the image is dimmed.
            highlighted = img_array.copy().astype(float)
            overlay = np.zeros_like(highlighted)
            overlay[mask] = [0, 255, 0] if importance > 0 else [255, 0, 0]
            highlighted = highlighted * (1 - alpha) + overlay * alpha
            highlighted = highlighted.astype(np.uint8)
            highlighted = mark_boundaries(highlighted, mask.astype(int),
                                          color=(1, 1, 1))
            sign = '+' if importance > 0 else '−'
            axes[i+1].imshow(highlighted)
            axes[i+1].set_title(
                f'Region {seg_id}\n{sign}{abs_imp[idx]:.4f}',
                fontsize=10
            )
            axes[i+1].axis('off')

        # Fewer superpixels than top_n: blank the panels that were never drawn.
        for spare in axes[len(top_idx) + 1:]:
            spare.axis('off')

        fig.tight_layout()

        self._save_and_close(fig, save_path)
        return fig


In [None]:

# Drive locations: input images are read from ROOT_DIR/<label>/ and the
# explanation artefacts are written under OUTPUT_DIR by the batch loop.
ROOT_DIR = '/content/drive/MyDrive/1430project'
OUTPUT_DIR = '/content/drive/MyDrive/1430project/explanations_batch_lg'

# Index i is the display name used for model output class i.
class_names = ['fake', 'real']

# Gather every PNG (case-insensitive extension) from the '0' and '1' folders.
img_paths = []
for cls in ('0', '1'):
    cls_dir = os.path.join(ROOT_DIR, cls)
    img_paths.extend(
        os.path.join(cls_dir, fname)
        for fname in os.listdir(cls_dir)
        if fname.lower().endswith('.png')
    )



In [None]:
!pip install --quiet captum


In [None]:
import matplotlib.pyplot as plt
import torch, numpy as np
from captum.attr import LayerGradCam, LayerAttribution
from skimage.transform import resize
from skimage.util import img_as_float32
from PIL import Image
import os

class GradCamHelper:
    """Thin wrapper around Captum's LayerGradCam that produces and saves heat-map overlays."""

    def __init__(self, model, device, target_layer=None,
                 alpha: float = 0.5, cmap: str = 'jet'):
        """Bind Grad-CAM to one layer of ``model``.

        Args:
            model: Network to attribute.
            device: Stored for callers; the methods here use tensors as given.
            target_layer: Layer whose activations are attributed; defaults to
                ``model.conv_head`` when None.
            alpha: Default heat-map weight in the blended overlay.
            cmap: Default matplotlib colormap name for the heat-map.
        """
        self.model   = model
        self.device  = device
        # Explicit None check: `layer or default` would evaluate the module's
        # truthiness, and container modules define __len__, so an empty one
        # would be silently swapped for the default.
        self.layer   = model.conv_head if target_layer is None else target_layer
        self.gradcam = LayerGradCam(self.model, self.layer)
        self.alpha   = alpha
        self.cmap    = cmap

    def get_heatmap(self, image_tensor, class_idx):
        """Compute a normalised Grad-CAM map for ``class_idx``.

        Args:
            image_tensor: Model input tensor. NOTE(review): the ``squeeze()``
                below yields a 2-D map only for a single image - confirm callers
                always pass a batch of one.
            class_idx: Target class passed to Captum.

        Returns:
            Numpy array at the spatial size of ``image_tensor``, with negative
            values clipped to 0 and scaled so the maximum is 1 (left unscaled
            when the map is all zeros).
        """
        attributions = self.gradcam.attribute(image_tensor, target=class_idx)
        # Upsample the layer-resolution CAM to the input tensor's spatial size.
        up = LayerAttribution.interpolate(attributions,
                                          image_tensor.shape[-2:])
        heat = up.squeeze().cpu().detach().numpy()
        heat = np.maximum(heat, 0)          # ReLU
        if heat.max() > 0:
            heat /= heat.max()
        return heat

    def save_overlay(self, heatmap, pil_image, out_path,
                     alpha=None, cmap=None):
        """Blend ``heatmap`` over ``pil_image`` and write the result to ``out_path``.

        Args:
            heatmap: 2-D array with values in [0, 1].
            pil_image: PIL image; converted to RGB so grayscale/RGBA inputs
                blend with the 3-channel colour-mapped heat-map.
            out_path: Destination file for the rendered overlay.
            alpha: Heat-map weight; falls back to the instance default when None.
            cmap: Colormap name; falls back to the instance default when None.
        """
        alpha = self.alpha if alpha is None else alpha
        cmap  = self.cmap  if cmap  is None else cmap

        img = img_as_float32(np.array(pil_image.convert('RGB')))

        # NOTE(review): the heat-map is stretched to the full image here; if the
        # preprocessing that produced the model input crops the image, the two
        # will not line up exactly - confirm against the transform in use.
        if heatmap.shape != img.shape[:2]:
            heatmap = resize(
                heatmap, img.shape[:2], order=1,
                preserve_range=True, anti_aliasing=True
            )

        # Colour-map the heat-map (dropping the alpha channel) and blend it with the image.
        hm_rgb = plt.colormaps.get_cmap(cmap)(heatmap)[..., :3]
        overlay = (1 - alpha) * img + alpha * hm_rgb
        overlay = np.clip(overlay, 0, 1)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(overlay)
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
        # Close this specific figure so batch runs do not accumulate open figures.
        plt.close(fig)


In [None]:

# Set up LIME and Grad-CAM on the same model.
lime = LimeExplainer(model, transform=TRANSFORM, device=DEVICE, output_dir=OUTPUT_DIR)
grad_helper = GradCamHelper(model, DEVICE)

# savefig / np.savez do not create missing directories, so make sure it exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for img_path in sorted(img_paths):
    image = Image.open(img_path).convert('RGB')

    # Model prediction and the index/name/probability of the top class.
    probs     = lime.get_model_prediction(image)
    pred_idx  = int(np.argmax(probs))
    pred_name = class_names[pred_idx]
    pred_prob = probs[pred_idx]

    base       = os.path.splitext(os.path.basename(img_path))[0]
    out_prefix = f"{base}_{pred_name}"

    # Grad-CAM: generate the heat-map for the predicted class and save its overlay.
    img_tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE)
    heat = grad_helper.get_heatmap(img_tensor, pred_idx)
    grad_path = os.path.join(OUTPUT_DIR, f"{base}_{pred_name}_gradcam.png")
    grad_helper.save_overlay(heat, image, grad_path, alpha=0.5)

    # LIME: per-superpixel importance for the same predicted class.
    # (The second return value was previously bound to the misspelled name
    # `seg_img`, leaving `seg_imp` undefined for the calls below.)
    segments, seg_imp, feat_imp, r2, pred_cls, _ = lime.explain(image, pred_idx)

    lime.visualize_explanation(
        image, segments, seg_imp, pred_name, pred_prob, r2,
        save_path=os.path.join(OUTPUT_DIR, f'{out_prefix}_explanation.png')
    )
    lime.create_overlay_visualization(
        image, segments, seg_imp,
        pred_name, pred_prob, threshold=0.0,
        save_path=os.path.join(OUTPUT_DIR, f'{out_prefix}_overlay.png')
    )
    lime.visualize_top_regions(
        image, segments, feat_imp, top_n=5,
        save_path=os.path.join(OUTPUT_DIR, f'{out_prefix}_top_regions.png')
    )

    # Save the raw arrays for later analysis.
    np.savez(
        os.path.join(OUTPUT_DIR, f'{out_prefix}_data.npz'),
        segments=segments,
        segments_importance=seg_imp,
        feature_importance=feat_imp,
        r2=r2,
        prediction=probs,
        gradcam=heat
    )





