In [2]:
!pip install -r requirements.txt

Collecting albumentations==1.2.1 (from -r requirements.txt (line 1))
  Downloading albumentations-1.2.1-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.7/116.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting grad-cam==1.4.6 (from -r requirements.txt (line 2))
  Downloading grad-cam-1.4.6.tar.gz (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m56.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting torch_lr_finder==0.2.1 (from -r requirements.txt (line 3))
  Downloading torch_lr_finder-0.2.1-py3-none-any.whl (11 kB)
Collecting gradio (from -r requirements.txt (line 5))
  Downloading gradio-4.37.2-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m57.1 MB/s[0

In [3]:
import gradio as gr
import numpy as np
from PIL import Image
from albumentations import (
    Compose,
    Normalize
)
from albumentations.pytorch.transforms import ToTensorV2
import torch
from models.resnet import Lit_CIFAR10_Resnet18
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from utils.utils import get_misclassified_images, denormalize, show_misclassified_images

In [4]:
# The demo runs inference on CPU only.
device = torch.device("cpu")

trained_model = Lit_CIFAR10_Resnet18()

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source
# (consider weights_only=True if the checkpoint is confirmed to be a plain state dict).
state_dict = torch.load("Lit_CIFAR10_Resnet18_trained.pt", map_location=device)

# strict=False tolerates key mismatches; report them so a partially loaded
# (i.e. partly randomly initialised) model does not go unnoticed.
load_result = trained_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model - "
        f"missing keys: {list(load_result.missing_keys)}, "
        f"unexpected keys: {list(load_result.unexpected_keys)}"
    )

trained_model.prepare_data()
trained_model.setup()

# Switch to evaluation mode so layers such as BatchNorm/Dropout behave deterministically
# at inference time (a freshly constructed nn.Module starts in training mode).
trained_model.eval()

model = trained_model.model
classes = trained_model.classes

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 34735047.24it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [13]:
def resize_image_pil(image, new_width, new_height, resample=None):
    """Scale ``image`` to fit inside ``new_width`` x ``new_height`` and return a PIL image of exactly that size.

    The aspect ratio is preserved: the image is scaled by the smaller of the
    two per-axis factors and anchored at the top-left corner. The final crop
    box is always ``(0, 0, new_width, new_height)``; when the scaled image is
    smaller than the box along one axis, PIL fills the uncovered area
    (zero-valued pixels, i.e. black for RGB).

    Args:
        image: Anything ``np.array`` can turn into an image array
            (a PIL image or a numpy array).
        new_width: Target width in pixels.
        new_height: Target height in pixels.
        resample: Optional PIL resampling filter. Defaults to
            ``Image.NEAREST``, which is what this function always used.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    if resample is None:
        resample = Image.NEAREST

    # Convert to PIL image
    img = Image.fromarray(np.array(image))

    # Get original size
    width, height = img.size

    # Largest uniform scale at which the whole image still fits in the target box
    scale = min(new_width / width, new_height / height)

    # round() instead of int(): floating point error can put width*scale just
    # below the target (e.g. 31.999...), and truncation would then leave a
    # one-pixel empty border. Clamp to [1, target] so extreme aspect ratios
    # never request a zero-sized image, which Image.resize rejects.
    scaled_width = min(new_width, max(1, round(width * scale)))
    scaled_height = min(new_height, max(1, round(height * scale)))

    # Resize
    resized = img.resize((scaled_width, scaled_height), resample)

    # Crop (or pad) to the exact target size
    return resized.crop((0, 0, new_width, new_height))

def inference(input_img, is_grad_cam=True, transparency = 0.5, target_layer_number = -1,
              top_predictions=3, is_misclassified_images=True, num_misclassified_images=10):
    """Classify one image with the loaded model and optionally build the extra visualisations.

    Args:
        input_img: Image to classify (numpy array or PIL image). It is resized
            to 32x32 and converted to RGB before being fed to the model.
        is_grad_cam: When true, compute a GradCAM overlay for the image.
        transparency: ``image_weight`` passed to ``show_cam_on_image``
            (weight of the original image in the blended overlay).
        target_layer_number: Index into ``model.layer2`` selecting the layer
            GradCAM is computed on.
        top_predictions: Number of highest-confidence classes to return.
        is_misclassified_images: When true, collect and plot misclassified
            images via the project helpers.
        num_misclassified_images: ``num_samples`` passed to
            ``show_misclassified_images``.

    Returns:
        A 4-tuple ``(predicted_class, visualization, top_n_confidences,
        misclassified_images)``. ``visualization`` and ``misclassified_images``
        are ``None`` when the corresponding option is disabled.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        # Submitting the form without an image passes None; fail with a readable message.
        raise gr.Error("Please provide an input image.")

    # UI sliders may deliver floats; these values are used as an index, a slice bound and a count.
    target_layer_number = int(target_layer_number)
    top_predictions = int(top_predictions)
    num_misclassified_images = int(num_misclassified_images)

    # resize_image_pil returns a PIL image; force 3 channels so grayscale/RGBA inputs also work.
    resized_img = resize_image_pil(input_img, 32, 32).convert("RGB")

    # albumentations works on numpy arrays, not PIL images.
    org_img = np.array(resized_img)

    transforms = Compose(
        # Normalize
        [Normalize([0.49139968, 0.48215841, 0.44653091],
                   [0.24703223, 0.24348513, 0.26158784]),
        # Convert to tensor
        ToTensorV2()])
    input_tensor = transforms(image=org_img)['image'].unsqueeze(0)  # add batch dimension

    # Evaluation mode + no_grad for the plain prediction pass: BatchNorm/Dropout must not
    # run in training mode, and no autograd graph is needed here.
    model.eval()
    with torch.no_grad():
        outputs = model(input_tensor)

    probabilities = torch.softmax(outputs.flatten(), dim=0)
    confidences = {name: float(prob) for name, prob in zip(classes, probabilities)}
    predicted_index = int(torch.argmax(outputs, dim=1)[0])

    if is_grad_cam:
        target_layers = [model.layer2[target_layer_number]]
        # The context manager releases the hooks GradCAM registers on the model as soon as we are done.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            # targets=None lets GradCAM pick its default target category.
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        # show_cam_on_image expects a float image scaled to [0, 1].
        visualization = show_cam_on_image(np.float32(org_img) / 255, grayscale_cam,
                                          use_rgb=True, image_weight=transparency)
    else:
        visualization = None

    # Keep only the top-n classes, ordered by descending confidence.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_n_confidences = dict(ranked[:top_predictions])

    if is_misclassified_images:
        # Plot the misclassified data
        misclassified_data = get_misclassified_images(trained_model, device=device)
        misclassified_images = show_misclassified_images(misclassified_data, classes,
                                                         num_samples=num_misclassified_images)
    else:
        misclassified_images = None

    return classes[predicted_index], visualization, top_n_confidences, misclassified_images



In [14]:
# Heading and blurb shown at the top of the Gradio page.
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"

# Example rows are currently disabled. If sample images are added, each row follows the
# order of the inputs below, e.g. ["cat.jpg", True, 0.5, -1, 3, True, 10] (likewise for
# dog/bird/car/deer/frog/horse/plane/ship/truck .jpg), passed via `examples=`.

demo = gr.Interface(
    fn=inference,
    title=title,
    description=description,
    # One component per `inference` parameter, in signature order.
    inputs=[
        gr.Image(label="Input Image", width=256, height=256),
        gr.Checkbox(label="Show GradCAM"),
        gr.Slider(minimum=0, maximum=1, value=0.5, label="Overall Opacity of Image"),
        gr.Slider(minimum=-2, maximum=-1, value=-2, step=1, label="Which Layer?"),
        gr.Slider(minimum=2, maximum=10, value=3, step=1, label="Number of Top Classes"),
        gr.Checkbox(label="Show Misclassified Images"),
        gr.Slider(minimum=5, maximum=40, value=10, step=5, label="Number of Misclassified Images"),
    ],
    # One component per element of the tuple returned by `inference`.
    outputs=[
        "text",
        gr.Image(label="Output", width=256, height=256),
        gr.Label(label="Top Classes"),
        gr.Plot(label="Misclassified Images"),
    ],
)

# debug=True keeps this cell running so errors and logs stay visible in the notebook.
demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://1d383e8bcd1a04c2d0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


bird
bird
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://1d383e8bcd1a04c2d0.gradio.live


