In [11]:
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms
from captum.attr import Occlusion
from captum.attr import visualization as viz

## Get Dataset Statistics

In [12]:
from utils.preprocessing import get_mean_std, get_label_map

# Both helpers are pointed at the same dataset identifier, so name it once.
DATASET_NAME = 'dataset_1500'

# `mean` / `std` feed transforms.Normalize in the preprocessing cell;
# `label_map` is indexed by class position inside predict().
mean, std = get_mean_std(DATASET_NAME)
label_map = get_label_map(DATASET_NAME)

Calculating mean and std: 100%|██████████| 141/141 [00:14<00:00,  9.52it/s]


## Image Preprocessing

In [13]:
# Preprocessing applied to every uploaded image before it reaches the model:
# resize, centre-crop to 256x256, convert to a tensor, then normalise with the
# dataset statistics computed above.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform = transforms.Compose(_preprocess_steps)

def process_image(image):
    """Run `image` through the shared `transform` and add a leading batch axis.

    Args:
        image: Input accepted by the module-level `transform` pipeline
            (the Gradio handlers pass a PIL image).

    Returns:
        The transformed tensor with an extra dimension inserted at position 0.
    """
    tensor = transform(image)
    batched = tensor.unsqueeze(0)
    return batched

## Load Pre-Trained Model

In [14]:
from torchvision import models

# Backend preference: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"device: {device}")

# Build the EfficientNetV2-S architecture; the trained weights are loaded
# from the checkpoint further down.
model = models.efficientnet_v2_s()
print(model.classifier)

# Replace the stock classifier with a small head that ends in 3 outputs.
# The input width is read from the original head's Linear layer (index 1).
in_features = model.classifier[1].in_features
head_layers = [
    nn.Linear(in_features, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 3),
]
model.classifier = nn.Sequential(*head_layers)
print(model.classifier)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('ckpts/effv2s_bn_0.001_10_0.5/best_val_loss.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

device: mps
Sequential(
  (0): Dropout(p=0.2, inplace=True)
  (1): Linear(in_features=1280, out_features=1000, bias=True)
)
Sequential(
  (0): Linear(in_features=1280, out_features=128, bias=True)
  (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=128, out_features=3, bias=True)
)


## Predict Function

In [15]:
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: PIL image from the Gradio input, or None when Submit is
            pressed before anything has been uploaded.

    Returns:
        A dict mapping each `label_map` entry to its softmax probability
        (a Python float), or None when `image` is None so the Label
        component is simply left empty instead of raising.
    """
    # Nothing uploaded yet: return the same "empty" value clear_image() uses.
    if image is None:
        return None

    batch = process_image(image).to(device)

    model.eval()
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]

    # assumes label_map can be indexed by 0..len(label_map)-1 — TODO confirm
    return {label_map[i]: float(probabilities[i]) for i in range(len(label_map))}

## Get Occlusion Image Function

In [16]:
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    Args:
        fig: Object exposing Matplotlib's `savefig(file, format=...)`.

    Returns:
        A PIL image opened from an in-memory PNG rendering of `fig`.
    """
    import io

    buf = io.BytesIO()
    # Force PNG: with a file-like target savefig otherwise follows
    # rcParams["savefig.format"], which may be a format PIL cannot open.
    fig.savefig(buf, format="png")
    # Rewind explicitly so the reader starts at the first byte of the PNG.
    buf.seek(0)
    img = Image.open(buf)
    return img

In [17]:
def get_occlusion_image(image):
    """Render a Captum occlusion-attribution figure for the model's top class.

    Args:
        image: PIL image from the Gradio input, or None when Submit is
            pressed before anything has been uploaded.

    Returns:
        A PIL image of a three-panel figure (original image, blended heat
        map, heat map), or None when `image` is None.
    """
    # Nothing uploaded yet: leave the heat-map panel empty instead of raising.
    if image is None:
        return None

    input_tensor = process_image(image).to(device)

    model.eval()
    # Only the arg-max class is needed here, so skip building an autograd graph.
    # Softmax is monotonic, so the arg-max of the logits is the predicted class.
    with torch.no_grad():
        target_class = int(model(input_tensor).argmax(dim=1).item())

    occlusion = Occlusion(model)
    attributions_occ = occlusion.attribute(
        input_tensor,
        strides=(3, 8, 8),
        target=target_class,
        sliding_window_shapes=(3, 15, 15),
        baselines=0,
    )

    # Overlay the attributions on the same view the model saw. A plain
    # image.resize((256, 256)) squashes non-square inputs, whereas the model
    # input is resized and then centre-cropped, so the heat map would be
    # misaligned with the picture underneath it.
    display_image = transforms.CenterCrop(256)(transforms.Resize(256)(image))

    # Captum expects channel-last arrays: drop the batch axis, then C,H,W -> H,W,C.
    attr_hwc = np.transpose(
        attributions_occ.squeeze(0).cpu().detach().numpy(), (1, 2, 0)
    )

    figure, _ = viz.visualize_image_attr_multiple(
        attr_hwc,
        np.array(display_image),
        ["original_image", "blended_heat_map", "heat_map"],
        ["all", "positive", "positive"],
        fig_size=(15, 5),
        # This runs inside a Gradio worker; the figure is only rasterised,
        # so do not let Captum call plt.show() on it.
        use_pyplot=False,
        alpha_overlay=0.7,
        show_colorbar=True,
        outlier_perc=2,
    )

    return fig2img(figure)

## Clear Image Function

In [18]:
def clear_image():
    """Return a 3-tuple of None, one empty value per component being reset."""
    return (None,) * 3

## Gradio App

In [19]:
# Header markup passed to gr.Markdown at the top of the app.
# The bold element is closed with </b>; a second opening <b> left it unclosed.
title = (
    """
    <center>
        <h1> Blue Magpie Recognizer 🦜 </h1>
        <b> This model recognizes the species of a blue magpie from an image. </b>
    </center>
    """
)

In [20]:
import gradio as gr

# Build the UI: header, then a row holding the upload column (image plus the
# Clear / Submit buttons) next to the prediction column, then a second row
# for the Captum heat map.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown(title)
    with gr.Row(equal_height=True):
        with gr.Column():
            # type="pil": the handlers below call PIL-based code on this value.
            image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Row():
                clear_button = gr.Button("Clear")
                upload_button = gr.Button("Submit")
        with gr.Column():
            predict_output = gr.Label(num_top_classes=3, label="Prediction")
    
    with gr.Row():
        captum_output = gr.Image(type="pil", label="Captum heatmap")

    
    # Submit is wired to two separate handlers on the same input image:
    # one fills the probability label, the other fills the heat-map image.
    upload_button.click(predict, inputs=image_input, outputs=predict_output)
    upload_button.click(get_occlusion_image, inputs=image_input, outputs=captum_output)

    # Clear takes no inputs; clear_image() returns one None per listed output.
    clear_button.click(clear_image, inputs=[], outputs=[image_input, predict_output, captum_output])

# share=True requests a public share URL in addition to the local server
# (see the "public URL" line in the cell output).
demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://7bc71d70e302cf48a9.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




  plt.show()
  plt.show()
