
KosukeSumiyasu/MoXI


Identifying Important Group of Pixels using Interactions [CVPR'24]

Kosuke Sumiyasu, Kazuhiko Kawamoto, Hiroshi Kera [arXiv]

Overview

MoXI (Model eXplanation by Interactions) is a black-box, game-theoretic explanation method for image classifiers. Unlike other popular methods (e.g., Grad-CAM and Attention Rollout), it takes the cooperative contribution of pixel pairs into account and accurately identifies a group of pixels that have a high impact on prediction confidence.
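At a high level, MoXI grows the identified group greedily: at each step it adds the pixel (or patch) whose inclusion most increases an importance score computed from the model's confidence. The toy sketch below illustrates only that greedy loop; `score_fn` is a stand-in placeholder, not the repository's API, and the actual method scores candidates game-theoretically via self-context and pairwise interactions.

```python
from typing import Callable, List

def greedy_identify(n_patches: int,
                    score_fn: Callable[[List[int], int], float],
                    k: int) -> List[int]:
    """Greedily select k patches.

    score_fn(selected, candidate) returns the gain from adding
    `candidate` to the already-selected group.
    """
    selected: List[int] = []
    remaining = set(range(n_patches))
    for _ in range(k):
        best = max(remaining, key=lambda i: score_fn(selected, i))
        selected.append(best)
        remaining.remove(best)
    return selected

# Toy scoring: a fixed per-patch importance plus a small bonus that
# grows with the number of already-selected patches (mimicking a
# cooperative interaction term).
importance = [0.1, 0.9, 0.3, 0.7]
def toy_score(selected, i):
    return importance[i] + 0.05 * len(selected)

print(greedy_identify(4, toy_score, 2))  # → [1, 3]
```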

Installation

Clone this repo:

$ git clone https://github.com/KosukeSumiyasu/MoXI

Demo

You can walk through some examples as follows.

Step 0: download the ImageNet dataset and specify its path in the config_file_path.yaml file.

Step 1: run the following.

$ cd MoXI
$ pip install -r requirements.txt
$ ./online_identify.sh
$ ./evaluate_curve.sh

Step 2: open the Jupyter notebooks in notebook/:

  • 00_plot_insertion_deletion_curve.ipynb --- quantitative evaluation by insertion and deletion curves.
  • 01_visualize_heatmap.ipynb --- qualitative evaluation by heatmaps.

Try out MoXI on your own model

We offer two implementations of MoXI.

Implementation 1 (model-agnostic implementation).

If your model is a CNN, use this implementation.

from src.util.load_parser import load_parser
args = load_parser() # set args.interaction_method = 'pixel_zero_values'
...
model = load_your_model(...)
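With `args.interaction_method = 'pixel_zero_values'`, masked pixels are simply replaced by zeros before the forward pass, which is why this variant works with any architecture. A minimal illustration of such zero-masking on a patch grid (plain NumPy; the helper name `zero_mask_patches` is ours, not the repository's):

```python
import numpy as np

def zero_mask_patches(image: np.ndarray, patch_ids, patch_size: int) -> np.ndarray:
    """Zero out square patches of an (H, W, C) image.

    Patches are indexed row-major on an (H // patch_size, W // patch_size) grid.
    """
    out = image.copy()
    n_cols = image.shape[1] // patch_size
    for pid in patch_ids:
        r, c = divmod(pid, n_cols)
        out[r * patch_size:(r + 1) * patch_size,
            c * patch_size:(c + 1) * patch_size, :] = 0
    return out

img = np.ones((4, 4, 3))
masked = zero_mask_patches(img, [0], patch_size=2)  # zeros the top-left 2x2 patch
```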

Implementation 2 (ViT-aware implementation).

If you use Vision Transformer models, we highly recommend this implementation. If your model is based on Hugging Face's ViTForImageClassification class, the setup is very simple.

from src.util.load_parser import load_parser
args = load_parser() # set args.interaction_method = 'vit_embedding'
...
model = replace_vit_embedding_mask(args, model)

For example, refer to the model in "Visualize the heatmap" at https://github.com/KosukeSumiyasu/MoXI/blob/main/notebook/01_visualize_heatmap.ipynb

Otherwise, your model needs a few slight modifications:

  • allow the forward() functions to receive an embedding_mask keyword argument
  • call select_batch_removing() in the input embedding module. No worries; after these modifications, you can still load your pre-trained weights.
from .mask_vit_embedding import select_batch_removing

class YourViTClassifier(...):
  def __init__(self, ...):
    self.ViTModel = YourViTModel(...)
    ...
  def forward(self, x, ..., embedding_mask=None): # MODIFICATION: new keyword argument embedding_mask
    output = self.ViTModel(x, embedding_mask)
    ...

class YourViTModel(...):
  def __init__(self, ...):
    self.ViTEmbedding = YourViTEmbedding(...)
    ...
  def forward(self, x, ..., embedding_mask=None): # MODIFICATION: new keyword argument embedding_mask
    embedding = self.ViTEmbedding(x, embedding_mask)
    ...

class YourViTEmbedding(...):
  def __init__(self, ...):
    ...
  def forward(self, x, ..., embedding_mask=None): # MODIFICATION: new keyword argument embedding_mask
    ...
    embeddings = self.patch_embeddings(x, ...)
    embeddings = embeddings + self.position_embeddings[:, 1:, :]

    # MODIFICATION: two lines added.
    if embedding_mask is not None:
        embeddings = select_batch_removing(embeddings, embedding_mask)
    ...
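Conceptually, select_batch_removing drops the token embeddings of masked patches so the transformer never attends to them. The following simplified stand-in (our own sketch, not the repository's implementation, and the actual signature and mask semantics may differ) shows the idea, assuming the mask is a list of token indices to keep:

```python
import numpy as np

def select_batch_removing_sketch(embeddings: np.ndarray, keep_ids) -> np.ndarray:
    """Keep only the patch-token embeddings listed in keep_ids.

    embeddings: (batch, n_tokens, dim) array.
    Returns an array of shape (batch, len(keep_ids), dim).
    """
    return embeddings[:, list(keep_ids), :]

emb = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
reduced = select_batch_removing_sketch(emb, [0, 2])  # drop tokens 1 and 3
```

Because the masked tokens are removed rather than zeroed, the remaining tokens interact exactly as they would in a shorter input sequence.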

Contact

Citation

If you find this work useful, please cite:

@inproceedings{kosuke2024identifying,
  author    = {Kosuke Sumiyasu and Kazuhiko Kawamoto and Hiroshi Kera},
  title     = {Identifying Important Group of Pixels using Interactions},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2024}
}
