In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from captum.attr import Occlusion
from captum.attr import visualization as viz

## Get Dataset Statistics

In [2]:
from utils.preprocessing import get_mean_std, get_label_map

# Dataset folder and split settings shared by both helpers below; kept in one place so the
# normalisation statistics and the label map are always read from the same dataset.
DATASET_DIR = 'dataset_1500'
SPLIT_SEED = 2024

# NOTE(review): presumably these must match the split used when training the checkpoint
# so that `mean`/`std` agree with the statistics the model was trained on — confirm.
split_ratio = [0.7, 0.15, 0.15]
mean, std = get_mean_std(DATASET_DIR, split_ratio=split_ratio, random_seed=SPLIT_SEED)

# Used by the prediction functions as `label_map[class_index] -> label`.
label_map = get_label_map(DATASET_DIR)

Calculating mean and std: 100%|██████████| 99/99 [00:10<00:00,  9.57it/s]


## Image Preprocessing

In [3]:
# Inference preprocessing: resize the shorter side to 256 px, center-crop to 256x256,
# convert to a float tensor and normalise with the dataset `mean`/`std`.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

def process_image(image):
    """Turn a PIL image into a normalised, batched model input.

    Args:
        image: PIL image. Any non-RGB mode (RGBA PNG, grayscale, palette, ...) is
            converted to RGB first, because the model and ``Normalize`` expect
            exactly 3 channels (a 1-channel tensor would silently broadcast against
            the 3-element mean/std, a 4-channel one would fail).

    Returns:
        Float tensor of shape (1, 3, 256, 256).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return transform(image).unsqueeze(0)

## Load Pre-Trained Model

In [None]:
from torchvision import models

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"device: {device}")

CKPT_PATH = 'ckpts/effv2s_bn_si_0.001_10_0.5/best_val_acc.pth'

# Backbone built without pretrained weights: every parameter is overwritten by the checkpoint below.
model = models.efficientnet_v2_s()
print(model.classifier)

# Replace the stock classifier with the custom head used for this checkpoint. The layer
# layout must match the saved state_dict exactly, otherwise load_state_dict (strict) fails.
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Linear(in_features, 128),
    nn.BatchNorm1d(128),
    nn.SiLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 3),  # 3 output classes — presumably one per label_map entry; TODO confirm
)
print(model.classifier)

# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict contains, instead of allowing arbitrary objects from the checkpoint file.
state_dict = torch.load(CKPT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Inference-only notebook: move to the device and switch BatchNorm/Dropout to eval mode.
# Assigning (rather than ending the cell on the call) avoids dumping the full module repr.
model = model.to(device).eval()

## Predict Function

In [5]:
def get_prediction(image):
    """Classify an image and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            Predict is pressed without an uploaded image.

    Returns:
        ``{label: probability}`` for every class (plain Python floats, as ``gr.Label``
        expects), or ``None`` when no image was supplied so the output is emptied
        instead of the callback raising.
    """
    if image is None:
        return None

    batch = process_image(image).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)

    # [0]: the batch holds a single image. tolist() yields JSON-friendly Python floats.
    probs = F.softmax(logits, dim=1)[0].cpu().tolist()
    return {label_map[i]: prob for i, prob in enumerate(probs)}

## Get Occlusion Image Function

In [6]:
def get_occlusion_image(image):
    """Build an occlusion-attribution figure for the model's top predicted class.

    A 3x15x15 zero-valued patch is slid over the model input with stride 8 (Captum
    ``Occlusion``). The resulting attributions are shown as the original image, a
    blended heat map and a plain heat map (positive attributions only for the last two).

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A matplotlib ``Figure`` for ``gr.Plot``, or ``None`` when no image was supplied.
    """
    if image is None:
        return None

    rgb_image = image.convert("RGB")
    model_input = process_image(rgb_image).to(device)

    model.eval()
    with torch.no_grad():
        # argmax over the logits equals argmax over softmax, so no softmax is needed.
        pred = model(model_input).argmax(dim=1).item()

    occlusion = Occlusion(model)
    attributions_occ = occlusion.attribute(model_input,
                                           strides=(3, 8, 8),
                                           target=pred,
                                           sliding_window_shapes=(3, 15, 15),
                                           baselines=0)
    # (1, C, H, W) -> (H, W, C), the channel-last layout Captum's visualizer takes.
    attr_hwc = np.transpose(attributions_occ.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

    # Show the image exactly as the model saw it (shorter side -> 256, then center crop,
    # mirroring the geometric steps of `transform`) so the heat map lines up with it.
    # A plain resize((256, 256)) would squash non-square images and misalign the overlay.
    display_image = transforms.CenterCrop(256)(transforms.Resize(256)(rgb_image))

    # use_pyplot=False builds a standalone Figure: no plt.show() call (and its warning in
    # a non-interactive backend) and no figure left registered with pyplot per request.
    fig, _ = viz.visualize_image_attr_multiple(attr_hwc,
                                               np.array(display_image),
                                               ["original_image", "blended_heat_map", "heat_map"],
                                               ["all", "positive", "positive"],
                                               fig_size=(15, 5),
                                               use_pyplot=False,
                                               alpha_overlay=0.7,
                                               show_colorbar=True,
                                               outlier_perc=2)

    return fig

## Clear Image Function

In [7]:
def clear_image():
    """Reset callback for the Clear button.

    Returns one ``None`` per output component it is wired to (image input,
    prediction label, heat-map plot).
    """
    return (None,) * 3

## Gradio App

In [8]:
# HTML banner shown at the top of the Gradio page (rendered through gr.Markdown).
title = """
        <center>
            <h1> Blue Magpie Recognizer 🦜 </h1>
            <b> Upload an image and get the prediction! </b>
        </center>
    """

In [9]:
import gradio as gr

# Build the UI. Component creation order inside each `with` context determines the
# on-page layout, so the statements below are order-sensitive.
with gr.Blocks(theme="soft") as app:
    gr.Markdown(title)

    # Top row: upload area + buttons on the left, class probabilities on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            # type="pil" makes callbacks receive a PIL image (None when nothing is uploaded).
            image_input = gr.Image(type="pil", label="Upload Image")
            
            with gr.Row():
                clear_button = gr.Button("Clear")
                predict_button = gr.Button("Predict")
        
        with gr.Column():
            # Displays the {label: probability} dict returned by get_prediction.
            predict_output = gr.Label(num_top_classes=3, label="Prediction")
    
    # Bottom row: the matplotlib figure returned by get_occlusion_image.
    with gr.Row():
        captum_output = gr.Plot(label="Occlusion Heat Map")

    
    # Two independent listeners on the same button: a single click runs both the
    # class prediction and the occlusion attribution.
    predict_button.click(get_prediction, inputs=image_input, outputs=predict_output)
    predict_button.click(get_occlusion_image, inputs=image_input, outputs=captum_output)

    # clear_image returns three Nones, one for each output listed here.
    clear_button.click(clear_image, inputs=[], outputs=[image_input, predict_output, captum_output])

# share=True additionally opens a temporary public *.gradio.live URL (see the cell output),
# so anyone with the link can use the app; drop it for local-only use.
app.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://0618e88c156a931009.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




  plt.show()
