In [2]:
import matplotlib.pyplot as pl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from lime import lime_image
from PIL import Image
from skimage.segmentation import mark_boundaries
from dataclasses import dataclass
from bidict import bidict
from torchvision import transforms

from Models.VGG19_model import VGG19

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
vgg = VGG19(models.vgg19(weights="VGG19_Weights.DEFAULT"))
vgg.load_state_dict(torch.load("../assets/vgg19_transfer.pth", map_location=device))
vgg.eval()

  vgg.load_state_dict(torch.load("../assets/vgg19_transfer.pth", map_location=device))


VGG19(
  (original_vgg19_model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)

In [5]:
classes = bidict({0: "glioma", 1: "meningioma", 2: "notumor", 3: "pituitary"})

In [6]:
@dataclass
class Params:
    image_path: str
    label_to_explain: int
    num_samples: int
    num_features: int
    positive_only: bool

In [7]:
def get_model_prediction(model: nn.Module, image) -> int:
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.1855, 0.1855, 0.1855], std=[0.2003, 0.2003, 0.2004]
            ),
        ]
    )

    input_tensor = transform(image)
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        output = model(input_batch)

    return torch.argmax(output[0]).item()

In [8]:
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.1855, 0.1855, 0.1855], std=[0.2003, 0.2003, 0.2004]
        ),
    ]
)

preprocess_transform = transforms.ToTensor()

In [9]:
def explain(model: nn.Module, params: Params) -> None:
    image = Image.open(params.image_path)
    image = image.convert("RGB")
    image = image.resize((224, 224))
    print(f"Model prediction = {classes[get_model_prediction(model, image)]}")
    image = np.array(image)

    def batch_predict(images: np.ndarray) -> np.ndarray:
        model.eval()
        batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)

        batch = batch.to(device)
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

        return probs.detach().cpu().numpy()

    # segmenter = SegmentationAlgorithm(
    #     "quickshift",
    #     kernel_size=params.kernel_size,
    #     max_dist=params.max_dist,
    #     ratio=params.ratio,
    # )

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image=image,
        classifier_fn=batch_predict,
        # segmentation_fn=segmenter,
        num_samples=params.num_samples,
    )

    print(f"Top label = {classes[explanation.top_labels[0]]}")

    temp, mask = explanation.get_image_and_mask(
        label=params.label_to_explain,
        positive_only=params.positive_only,
        num_features=params.num_features,
    )

    marked_image = mark_boundaries(temp / 255.0, mask)

    return marked_image

In [10]:
images = []

image = Image.open("../assets/compression/meningioma_original.jpg")

image_50_compressed = image.copy()
image_50_compressed.save(
    "../assets/compression/meningioma_50.jpg", quality=50, optimize=True
)

image_1_compressed = image.copy()
image_1_compressed.save(
    "../assets/compression/meningioma_1.jpg", quality=1, optimize=True
)

image_20_compressed = image.copy()
image_20_compressed.save(
    "../assets/compression/meningioma_20.jpg", quality=20, optimize=True
)

images.append("../assets/compression/meningioma_original.jpg")
images.append("../assets/compression/meningioma_50.jpg")
images.append("../assets/compression/meningioma_1.jpg")
images.append("../assets/compression/meningioma_20.jpg")

In [11]:
# image = explain(vgg, images[0], classes.inverse["meningioma"])

TypeError: explain() takes 2 positional arguments but 3 were given