In [1]:
import gradio as gr
from gradio.components import Dropdown, File, Image, Slider, Textbox, Number, Button
from gradio import Examples
from saliency.saliency_zoo import agi, big, guided_ig, fast_ig, ig, sm, sg, deeplift, saliencymap
import torch
import numpy as np
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Readiness registry for the required inputs, keyed by the class name of the
# value each component delivers; check_all_setted flips an entry to True.
# NOTE(review): this is module-level state, so every session served by this
# process shares it.
already_setted = dict.fromkeys(
    (
        "_TemporaryFileWrapper",  # uploaded model (File, type="file")
        "ndarray",                # uploaded image (Image, type="numpy")
        "str",                    # selected method (Dropdown)
        # "config",
    ),
    False,
)

def check_all_setted(name):
    """Record that one required input received a value and report overall readiness.

    Args:
        name: the new value of the component that changed (despite the
            parameter name, this is the value itself). Its class name is the
            key written into ``already_setted``.

    Returns:
        ``Button.update(interactive=True)`` once every entry in
        ``already_setted`` is True, otherwise ``Button.update(interactive=False)``.

    NOTE(review): if a component is cleared and its value arrives as None,
    this records a "NoneType" entry instead of un-setting anything, so the
    button is not disabled again.
    """
    already_setted[name.__class__.__name__] = True
    # Entries are only ever the bools True/False, so all() mirrors the
    # "any entry == False -> disabled" check.
    every_input_present = all(already_setted.values())
    return Button.update(interactive=every_input_present)

# Extra parameters per method: guided_ig takes steps; agi takes topk and steps; big, ig and sg each take steps.


def show_slider(method):
    """Return visibility/default updates for the ``topk`` and ``steps`` sliders.

    Adversarial Gradient Integration shows both sliders. Boundary Attribution,
    Guided-IntegratedGradient, SmoothGradient and IntegratedGradient show only
    ``steps``. Every other method hides both.

    Args:
        method: the name currently selected in the method dropdown.

    Returns:
        A two-element list ``[topk_update, steps_update]`` matching the
        order of the event's ``outputs``.
    """
    if method == "Adversarial Gradient Integration":
        return [Slider.update(visible=True, value=20), Slider.update(visible=True, value=50)]

    steps_only_methods = (
        "Boundary Attribution",
        "Guided-IntegratedGradient",
        "SmoothGradient",
        "IntegratedGradient",
    )
    if method not in steps_only_methods:
        return [Slider.update(visible=False), Slider.update(visible=False)]

    # Guided-IntegratedGradient starts at 15 steps; the others start at 50.
    default_steps = 15 if method == "Guided-IntegratedGradient" else 50
    return [Slider.update(visible=False), Slider.update(visible=True, value=default_steps)]


def load_model(model):
    """Load an uploaded PyTorch model file and prepare it for inference.

    Args:
        model: the uploaded file object; its ``.name`` attribute is used as
            the path passed to ``torch.load``.

    Returns:
        The deserialized object moved to ``device`` and switched to eval
        mode. Assumes the file holds a whole pickled ``nn.Module`` (not a
        state_dict), since ``.to()`` and ``.eval()`` are called on it.
    """
    # map_location remaps tensors saved on another device (e.g. CUDA) while
    # loading, so a GPU-saved checkpoint does not fail on a CPU-only host.
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code; only load model files from trusted sources.
    # NOTE(review): newer torch releases default to weights_only=True, which
    # rejects whole pickled modules; passing weights_only=False may be needed.
    loaded = torch.load(model.name, map_location=device)
    return loaded.to(device).eval()


def explaination_pipeline(model, method, image, topk, steps):
    """Run the selected attribution method (not implemented yet).

    The UI binds this to the run button and expects two images back: the
    attribution result and the insertion/deletion result.

    Args:
        model: value of the model ``File`` component.
        method: name selected in the method dropdown.
        image: value of the ``Image`` component (numpy array).
        topk: value of the ``topk`` slider.
        steps: value of the ``steps`` slider.

    Raises:
        NotImplementedError: always, until the pipeline is written.
    """
    raise NotImplementedError(
        f"explaination_pipeline is not implemented yet (method={method!r})"
    )


def _update_run_button(model_file, method_name, image_array):
    """Enable the run button only while all three required inputs hold a value.

    Wired to the change events of the model, method and image components with
    all three as inputs, so each call sees the current value of every one. A
    cleared component (None, or an empty dropdown selection) disables the
    button again, and no module-level state is shared between sessions.

    Args:
        model_file: value of the model ``File`` component.
        method_name: value of the method ``Dropdown``.
        image_array: value of the ``Image`` component (numpy array).

    Returns:
        A ``Button.update`` setting ``interactive``.
    """
    # Explicit None checks: truth-testing a numpy array is ambiguous.
    ready = model_file is not None and bool(method_name) and image_array is not None
    return Button.update(interactive=ready)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            model = File(label="pytorch model", type="file")
            image = Image(label="image", type="numpy")

        with gr.Column():
            explaination_result = Image(
                label="explaination result", interactive=False)
            explaination_insert_deletion = Image(
                label="explaination insert deletion result", interactive=False)

    with gr.Row():
        method = Dropdown(["Adversarial Gradient Integration", "Boundary Attribution", "Guided-IntegratedGradient", "Fast-IntegratedGradient", "IntegratedGradient",
                           "SaliencyGradient", "SmoothGradient", "DeepLift", "SaliencyMap"], label="explaination method")
        # Both sliders start hidden; show_slider reveals the ones the chosen
        # method uses.
        topk = Slider(minimum=0, maximum=1000, step=1,
                      label="topk", visible=False, interactive=True)
        steps = Slider(minimum=0, maximum=200, step=1,
                       label="steps", visible=False, interactive=True)
        method.change(show_slider, method, outputs=[topk, steps])

    with gr.Row():
        # Disabled until the model, method and image have all been provided.
        button = Button(value="Run Explaination!", interactive=False)
        button.click(explaination_pipeline, inputs=[model, method, image, topk, steps], outputs=[
                     explaination_result, explaination_insert_deletion], api_name="explaination")

        # Recompute readiness from the current values of all three inputs
        # whenever any of them changes.
        readiness_inputs = [model, method, image]
        for component in readiness_inputs:
            component.change(_update_run_button, inputs=readiness_inputs, outputs=[button])

if __name__ == "__main__":
    # In a notebook kernel __name__ is also "__main__", so this still launches there.
    demo.launch(debug=True)


# import torch
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# def load_model(model):
#     model = torch.load(model.name).to(device).eval()
#     return model

# model = File(label="pytorch model", type="file", description="upload your pytorch model")
# method = Dropdown(["Adversarial Gradient Integration", "Boundary Attribution", "Guided-IntegratedGradient", "Fast-IntegratedGradient", "IntegratedGradient",
#                 "SaliencyGradient", "SmoothGradient", "DeepLift", "SaliencyMap"], label="explaination method", description="choose your explaination method")
# image = Image(label="image", type="numpy", description="upload your image")
# config = File(label="config", type="file", description="upload your config")
# # target = Slider(minimum=0, maximum=1000, step=1, default=0, label="target", description="input your target class")
# # mean = Number(default=0.5, label="mean", description="input your mean")
# # std = Number(default=0.5, label="std", description="input your std")

# # Create the Gradio interface
# inputs = [
#     model,
#     method,
#     image,
#     config
# ]
# outputs = Image(label="explaination result", type="numpy", description="explaination result")


# def load_model(model):
#     model = torch.load(model.name)
#     return model

# def explaination_pipeline(model, method, image, target, mean, std):
#     model = load_model(model)
#     image = image / 255.0
#     image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
#     image.requires_grad = True
#     if method == "Adversarial Gradient Integration":
#         attribution_map = agi(model, image, 0, 1)
#     elif method == "Boundary Attribution":
#         attribution_map = big(model, image, 0)
#     elif method == "Guided-IntegratedGradient":
#         attribution_map = guided_ig(model, image, 0)
#     elif method == "Fast-IntegratedGradient":
#         attribution_map = fast_ig(model, image, 0)
#     elif method == "IntegratedGradient":
#         attribution_map = ig(model, image, 0)
#     elif method == "SaliencyGradient":
#         attribution_map = sm(model, image, 0)
#     elif method == "SmoothGradient":
#         attribution_map = sg(model, image, 0)
#     elif method == "DeepLift":
#         attribution_map = deeplift(model, image, 0)
#     elif method == "SaliencyMap":
#         attribution_map = saliencymap(model, image, 0)
#     attribution_map = attribution_map.squeeze().detach().cpu().numpy()
#     attribution_map = (attribution_map - attribution_map.min()) / (attribution_map.max() - attribution_map.min())
#     attribution_map = attribution_map * 255
#     attribution_map = attribution_map.astype("uint8")
#     attribution_map = attribution_map.transpose(1, 2, 0)
#     return attribution_map


# gr.Interface(fn=explaination_pipeline, inputs=inputs,
#              outputs=outputs, title="Model Explaination",interface="english").launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/gradio/routes.py", line 414, in run_predict
    output = await app.get_blocks().process_api(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/gradio/blocks.py", line 1320, in process_api
    result = await self.call_function(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/gradio/blocks.py", line 1048, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/anyio/to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "/Library/Frameworks/Python.framework/Versions/3.10