In [26]:
import os

import torch
from torchvision import models, transforms
from PIL import Image as PilImage

from omnixai.data.image import Image
from omnixai.explainers.vision.specific.gradcam.pytorch.gradcam import GradCAM


In [14]:
os.chdir("..")

In [17]:
from src.model import ImageClassificationModel

In [24]:
model_path='model_assets/model.pt'
img_paths=['data/test/virus/person8_virus_28.jpeg','data/test/bacteria/person1474_bacteria_3837.jpeg']

In [42]:
def grad_pretrained_only(img_paths:list):
    cam_model = models.mobilenet_v2(pretrained=True)
    for img_path in img_paths:
        # Load the test image
        img = Image(PilImage.open(img_path).convert('RGB'))
        # The preprocessing model
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])
        explainer = GradCAM(
            model= cam_model,
            target_layer=cam_model.features[-1][-1],
            preprocess_function=preprocess
        )
        # Explain the top label
        explanations = explainer.explain(img)
        explanations.ipython_plot(index=0, class_names="virus")

    return

In [48]:
def grad_trained(model_path:str, img_paths:list):
    model = ImageClassificationModel(num_classes=3)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    for img_path in img_paths:
        # Load the test image
        img = Image(PilImage.open(img_path).convert('RGB'))
        # The preprocessing model
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])
        explainer = GradCAM(
            model= model,
            target_layer=model.mobilenet.features[-1][-1],
            preprocess_function=preprocess
        )
        # Explain the top label
        explanations = explainer.explain(img)
        explanations.ipython_plot(index=0, class_names="virus")

    return

In [46]:
grad_pretrained_only(img_paths)


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.


Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.



In [49]:
grad_trained(model_path, img_paths)