Skip to content

David-Flurin/selfexplainer

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

764 Commits
 
 
 
 
 
 

Repository files navigation

Self-Explainer

This repository contains the code to train the Self-Explainer and conduct the experiments from the report. Most of the code was directly taken from NN-Explainer and was then adapted to fit the objective.

Directories & Files

  • main.py: Called to start a training
  • plot.py: Some plotting functions used for the report
  • models/: Contains Self-Explainer and Resnet50 modes
  • data/: Defined dataloaders and datasets for the different datasets the Self-Explainer was tested on
  • evaluation/: Various scripts used for the evaluation of the Self-Explainer and baselines
  • utils/: Various helper functions
  • synthetic_dataset/: Contains code to generate samples of the synthetic dataset.
  • color_dataset/: Pixel dataset used in the early stage of the thesis to verify certain loss properties.

Run trainings

Run a training by executing main.py with a configuration file like:

python main.py -c config_files/configuration.cfg

To reproduce the results from the report, use the configuration files in the config_files/ directory for the respective dataset. To change the synthetic dataset from single-label to multi-label mode, set the configuration field multilabel=True.

Evaluate model

To evaluate a model, use the file evaluate_selfexplainer.py in the evaluation/ directory. Change the settings defined in the file (from line 164-192) to fit your environment. The attribution masks are first generated and then evaluated with a bunch of metrics (see file compute_scores.py). To print averaged metrics, run

python print_selfexplainer_mean_scores.py path_to_results_file

With the script evaluate_selfexplainer.py, you can also generate class-specific masks.

Synthetic dataset

To generate a new instance of the synthetic dataset, you have to run following code:

from synthetic_dataset.generator import Generator

g = Generator()
g.create_set(save_path, number_per_shape, proportions, multilabel)

The arguments of the method create_set() are:

  • save_path: Path to location where the new dataset is saved.
  • number_per_shape: How many samples per shape are generated. E.g. number_per_shape=200 leads to 8*200 = 1600 samples in total.
  • proportions: List of floats corresponding to the proportions of training, validation and test set. E.g. if we have a dataset with 10000 samples and proportions=[0.5, 0.2, 0.3], then the training set has 5000 samples, the validation set 2000 samples and the test set 3000 samples.
  • multilabel: Whether the dataset also has samples with 2 target classes.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages