In [1]:
import logging
import os
from typing import Iterable, Iterator, Sequence

import numpy as np
from liga.experiments.experiment import Experiment
from liga.interpret.colors import PerceivableColorsInterpreter
from liga.interpret.common import JoinedInterpreter
from liga.interpret.ground_truth_objects import GroundTruthObjectsInterpreter
from liga.torch_extensions.classifier import TorchImageClassifierSerialization, TorchImageClassifierLoader
from liga.type1.tree import TreeType1Explainer
from liga.type2.attribution import SaliencyType2Explainer, IntegratedGradientsType2Explainer, DeepLiftType2Explainer
from liga.type2.lime import LimeType2Explainer
from simexp.describe.torch_based.common import TorchConfig
from simexp.oiv4.metadata import OIV4MetadataProvider
from simexp.places365.metadata import Places365Task

# Seed drawn once with np.random.randint(0, 2 ** 32 - 1) and then frozen, so the
# random stream handed to the experiments is the same after every kernel restart.
SEED = 2372775446
rng = np.random.default_rng(seed=SEED)

# Show INFO-level progress messages emitted through the root logger.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

  assert (len(self.all_classes) == self.classifier.num_classes,


In [2]:
# URL of the parquet store with the downsampled OpenImages v4 test images.
# The default is the original machine-local path; set the OIV4_IMAGES_URL
# environment variable to point the notebook at another copy without editing
# this cell.
images_url = os.environ.get(
    'OIV4_IMAGES_URL',
    'file:///var/ssd/renftlem/openimages-v4-downsampled/'
    'parquet-stores/test-224x224.parquet',
)
# Places365 task; its class_names are used as the label set of every experiment.
p365_task = Places365Task()
# OpenImages v4 metadata; used as the ground-truth object provider.
oiv4_meta = OIV4MetadataProvider()

In [None]:
# from simexp.spark import SparkSessionConfig
# from simexp.convert import ConvertTask, ConvertWriteConfig
#
# spark_cfg = SparkSessionConfig(master='local[*]',
#                                driver_memory='15G',
#                                exec_memory='15G')
# write_cfg = ConvertWriteConfig(output_url='file:///home/renftlem/2020-[article]-evaluating-xai/'
#                                           'downsampled-open-images-v4/'
#                                           'parquet-stores/validation.parquet',
#                                row_size=512)
# convert = ConvertTask(images_dir='/home/renftlem/2020-[article]-evaluating-xai/downsampled-open-images-v4/'
#                                  'raw-datasets/validation',
#                       glob='*.jpg',
#                       spark_cfg=spark_cfg,
#                       write_cfg=write_cfg,
#                       meta=oiv4_meta,
#                       subset='validation',
#                       sample_size=20000)

In [4]:
# convert.run()


In [5]:
def _get_classifier(name: str, torch_cfg: TorchConfig):
    """Load the classifier described by the serialization file ``name``.

    :param name: serialization file name (e.g. ``'places365_resnet18.json'``).
    :param torch_cfg: torch configuration handed to the loader.
    :return: the ``classifier`` attribute of the resulting loader.
    """
    loader = TorchImageClassifierLoader(TorchImageClassifierSerialization(name),
                                        torch_cfg)
    return loader.classifier

def get_experiments(classifiers: Sequence[str] = ('places365_resnet18.json',
                                                  'places365_alexnet.json'),
                    num_train_obs_values: Sequence[int] = (100, 500, 1000, 2000, 4000, 8000),
                    num_test_obs: int = 100) -> Iterator[Experiment]:
    """Lazily generate the grid of LIGA experiments.

    The grid is the Cartesian product of
    classifier x interpreter x type-2 explainer x type-1 explainer x training-set size.
    Because this is a generator, nothing (not even ``TorchConfig``) is built until
    the first item is requested, and each classifier is loaded only when the first
    experiment using it is requested. ``next(get_experiments())`` therefore loads
    just the first classifier.

    :param classifiers: serialization file names of the classifiers to explain,
        each passed to ``TorchImageClassifierSerialization``.
    :param num_train_obs_values: values of ``num_train_obs`` to sweep over, in order.
    :param num_test_obs: forwarded unchanged as ``num_test_obs`` to every ``Experiment``.
    :return: an iterator of configured, not-yet-run ``Experiment`` objects.
    """
    torch_cfg = TorchConfig()

    gt_objects_interpreter = GroundTruthObjectsInterpreter(gt_object_provider=oiv4_meta,
                                                           subset='test',
                                                           ignore_images_without_objects=True)
    color_interpreter = PerceivableColorsInterpreter()
    interpreters = [JoinedInterpreter(gt_objects_interpreter, color_interpreter)]

    type2_classes = [SaliencyType2Explainer,
                     IntegratedGradientsType2Explainer,
                     DeepLiftType2Explainer,
                     LimeType2Explainer]

    # NOTE(review): one type-1 instance is shared by all experiments, and one
    # type-2 instance by all train sizes of a classifier/interpreter pair. This is
    # only safe if the explainers keep no per-experiment state; confirm.
    # Likewise every experiment receives the same module-level ``rng``, so results
    # presumably depend on which experiments ran before.
    type1_instances = [TreeType1Explainer()]

    for classifier_name in classifiers:
        classifier = _get_classifier(classifier_name, torch_cfg)
        for interpreter in interpreters:
            for type2_cls in type2_classes:
                type2 = type2_cls(classifier=classifier,
                                  interpreter=interpreter)
                for type1 in type1_instances:
                    for num_train_obs in num_train_obs_values:
                        yield Experiment(rng,
                                         images_url=images_url,
                                         num_train_obs=num_train_obs,
                                         num_test_obs=num_test_obs,
                                         all_classes=p365_task.class_names,
                                         type1=type1,
                                         type2=type2)

In [6]:
def run(k_folds: int = 2, n_jobs: int = 3):
    """Run only the first experiment produced by ``get_experiments()``.

    :param k_folds: forwarded unchanged to ``Experiment.run`` (presumably the
        number of cross-validation folds).
    :param n_jobs: forwarded unchanged to ``Experiment.run`` (presumably the
        number of parallel workers).
    :return: whatever ``Experiment.run`` returns. In the recorded output this is
        a dict with ``cross_entropy``, ``gini``, ``acc`` and ``runtime_s``.
    """
    first_experiment = next(get_experiments())
    return first_experiment.run(k_folds=k_folds, n_jobs=n_jobs)
run()

INFO:root:<Running experiment...>
INFO:root:-- <Parameters: {'classifier': 'resnet18', 'images_url': 'file:///home/renftlem/2020-[article]-evaluating-xai/downsampled-open-images-v4/parquet-stores/validation-sample-20000.parquet', 'num_train_obs': 100, 'interpreter': <liga.interpret.ground_truth_objects.GroundTruthObjectsInterpreter object at 0x7f41b1d7e690>, 'type1': <liga.type1.tree.TreeType1Explainer object at 0x7f41b1d7e8d0>, 'type2': saliency, 'num_test_obs': 100}/>
INFO:root:-- <Running LIGA...>
  "required_grads has been set automatically." % index
INFO:root:---- <Status update>
INFO:root:------ <Processed 0 observations/>
INFO:root:------ <1 observations had influential concepts/>
INFO:root:------ <LIGA's augmentation produced 0 additional observations./>
INFO:root:---- <done/>
INFO:root:---- <Status update>
INFO:root:------ <Processed 2 observations/>
INFO:root:------ <2 observations had influential concepts/>
INFO:root:------ <LIGA's augmentation produced 1 additional observat

{'cross_entropy': 31.782605755123797,
 'gini': 0.037763975155279406,
 'acc': 0.09,
 'runtime_s': 56.50141887400241}