In [1]:
from detr.explainer import DetrExplainer
from pathlib import Path
from PIL import Image

import yaml
import torch
import wandb
import transformers as tr

In [2]:
def recursive_eval(d: dict):
    """Evaluate the string leaves of a (possibly nested) config dict in place.

    Every string value other than the sentinel ``'all'`` is passed to
    ``eval``. If evaluation succeeds the value is replaced by the result;
    if it fails with one of the handled errors the original string is kept.
    Nested dicts are processed recursively. Non-string leaves, and containers
    other than ``dict`` (e.g. lists), are left untouched.

    Args:
        d: Dictionary to process. It is modified in place.

    Returns:
        The same ``d`` object that was passed in.

    Warning:
        ``eval`` executes arbitrary Python expressions. Only call this on
        configuration that comes from a trusted source.
    """
    for k, v in d.items():
        if isinstance(v, dict):
            recursive_eval(v)
            continue

        # Only strings are candidates for evaluation. 'all' is kept verbatim,
        # otherwise it would evaluate to the builtin function ``all``.
        if not isinstance(v, str) or v == 'all':
            continue

        try:
            # Empty locals: without this, the strings 'd', 'k' and 'v' would
            # resolve to this function's own local variables instead of
            # staying plain strings. Module globals are still visible, as before.
            d[k] = eval(v, globals(), {})
        except (SyntaxError, NameError, TypeError, AttributeError):
            # Not a valid expression (e.g. a label such as 'traffic light' or a
            # dotted file name such as 'model.ckpt'): keep the raw string.
            pass
    return d

In [3]:
def load_cityscapes_detr(path_ckpt: Path):
    """Build a ``DetrForObjectDetection`` model and load fine-tuned weights into it.

    The architecture is created from the ``facebook/detr-resnet-50`` pretrained
    model with the number of labels taken from the checkpoint
    (``hyper_parameters['max_class_id'] + 1``). The checkpoint's ``state_dict``
    keys are then renamed to match the Hugging Face model and loaded.

    Args:
        path_ckpt: Path to the checkpoint file. It must contain the keys
            ``'hyper_parameters'`` (with ``'max_class_id'``) and ``'state_dict'``.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        RuntimeError: From ``load_state_dict`` if the renamed keys do not match
            the model's parameters.
    """
    # Load onto CPU so that a checkpoint saved from a GPU run can also be read
    # on a CPU-only machine and does not allocate GPU memory up front.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(path_ckpt, map_location='cpu')

    model = tr.DetrForObjectDetection.from_pretrained(
        "facebook/detr-resnet-50",
        num_labels=ckpt['hyper_parameters']['max_class_id'] + 1,
        ignore_mismatched_sizes=True)

    # Handle the key mismatch between the checkpoint and the Hugging Face model.
    # The rules are applied in this order and each one replaces only the first
    # occurrence of its pattern in a key.
    key_renames = (
        ('model.', ''),
        ('detr.', 'model.'),
        ('model.class_labels_classifier.', 'class_labels_classifier.'),
        ('model.bbox_predictor.', 'bbox_predictor.'),
    )
    state_dict = {}
    for key, tensor in ckpt['state_dict'].items():
        for old, new in key_renames:
            key = key.replace(old, new, 1)
        state_dict[key] = tensor

    model.load_state_dict(state_dict)

    return model

In [4]:
# Download the model artifact from Weights & Biases into a local folder.
artifact_name = 'suciucezar07/detr/model-in8k822y:v0'
artifact_dir = (
    wandb.Api()
    .artifact(artifact_name, type='model')
    .download(root=Path(r'./resources/artifacts'))
)

# Build the DETR model from the checkpoint file inside the downloaded artifact.
model = load_cityscapes_detr(Path(artifact_dir, 'model.ckpt'))

[34m[1mwandb[0m: Downloading large artifact model-in8k822y:v0, 475.93MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0
Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenc

In [9]:
# Read the YAML configuration. The encoding is given explicitly so the result
# does not depend on the platform's default locale.
with open('config.yaml', encoding='utf-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# NOTE: recursive_eval() calls eval() on string values, so config.yaml must
# come from a trusted source.
config = recursive_eval(config)

In [10]:
# Display the class-id -> label-name mapping loaded from the config.
config['id2label']

{0: 'road',
 1: 'sidewalk',
 2: 'building',
 3: 'wall',
 4: 'fence',
 5: 'pole',
 6: 'traffic light',
 7: 'traffic sign',
 8: 'vegetation',
 9: 'terrain',
 10: 'sky',
 11: 'person',
 12: 'rider',
 13: 'car',
 14: 'truck',
 15: 'bus',
 16: 'train',
 17: 'motorcycle',
 18: 'bicycle'}

In [11]:
# Image processor matching the pretrained DETR checkpoint the model was built from.
image_processor = tr.DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

# Wire the model, processor and label configuration into the explainer.
detr_explainer = DetrExplainer(
    model=model,
    processor=image_processor,
    id2label=config['id2label'],
    no_object_id=config['no_object_id'],
    device='cuda',
)

In [12]:
# Run the explainer on every PNG in the Cityscapes sample folder; each image
# gets its own output directory named after the file stem.
for img_path in Path(r'./resources/images/cityscapes').glob('*.png'):
    # Open the image in a context manager so the underlying file handle is
    # closed after each iteration instead of leaking one handle per image.
    with Image.open(img_path) as image:
        detr_explainer.explain(
            image=image,
            include_labels=config['include_labels'],
            threshold=config['threshold'],
            output_dir=Path(r'./explanations') / img_path.stem)

                                                                                                                    