In [1]:
import torch, numpy, json, os
from torch.utils.data import TensorDataset, DataLoader
from netdissect.progress import verbose_progress, default_progress
from netdissect.nethook import edit_layers
from netdissect.modelconfig import create_instrumented_model
from netdissect.zdataset import standard_z_sample, z_sample_for_model
from netdissect.easydict import EasyDict
from netdissect.aceoptimize import ace_loss
from netdissect.segmenter import UnifiedParsingSegmenter
from netdissect.fullablate import measure_full_ablation
from netdissect.plotutil import plot_tensor_images, plot_max_heatmap
import netdissect.aceoptimize
import netdissect.fullablate

# Enable netdissect's verbose progress reporting; the ablation evaluation
# below gets its progress wrapper from default_progress(), which presumably
# honors this flag (TODO confirm in netdissect.progress).
verbose_progress(True)



In [15]:
def _find_label_index(segmenter, classname,
                      categories=('object', 'material', 'part')):
    """Return the segmenter label index of ``classname``.

    Categories are tried in order; the first ``(classname, category)`` pair
    found in the segmenter's label list wins.

    Raises:
        ValueError: if ``classname`` is not a label in any of ``categories``.
    """
    labels = segmenter.get_label_and_category_names()[0]
    for category in categories:
        try:
            return labels.index((classname, category))
        except ValueError:
            continue
    raise ValueError('%r is not a segmenter label in any of the categories %s'
                     % (classname, list(categories)))


def _summed_ace_loss(segmenter, classnum, model, layer, loader,
                     unit_scores, **options):
    """Sum ``ace_loss`` over every batch of ``loader`` with gradients disabled.

    ``unit_scores`` is a 1-D per-unit vector passed to ``ace_loss`` as the
    ablation (reshaped to ``[1, units, 1, 1]`` and moved to the GPU); the
    high-replacement argument is all zeros of the same shape.  Extra keyword
    ``options`` (e.g. ``discrete_units``, ``mixed_units``) are forwarded.
    """
    ablation = unit_scores[None, :, None, None].cuda()
    high_replacement = torch.zeros_like(ablation)
    progress = default_progress()
    total = 0
    with torch.no_grad():
        for [small_sample] in progress(loader):
            total += ace_loss(segmenter, classnum, model, layer,
                              high_replacement, ablation,
                              small_sample, 0, 0, 0, run_backward=False,
                              discrete_pixels=True,
                              ablation_only=True,
                              fullimage_measurement=True,
                              fullimage_ablation=True,
                              **options)
    return total


def evaluate_ablation(scene, classname, layer='layer4', unit_count=20):
    """Compare an IoU-ranked and a learned unit ablation for one scene/class.

    Reads ``dissect/<scene>/dissect.json`` and the learned ablation snapshot
    ``dissect/<scene>/<layer>/ace/<classname>/snapshots/epoch-9.pth``, then
    sums ``ace_loss`` over 1000 fixed-seed z samples (batches of 10) for
    three settings: no ablation (baseline), ``unit_count`` units taken from
    the dissection's ``<classname>-iou`` ranking scores, and ``unit_count``
    units taken from the learned scores (with ``mixed_units=True``).

    Args:
        scene: dissection subdirectory under ``dissect/``.
        classname: segmenter label whose pixels are measured.
        layer: layer that is edited/ablated (default ``'layer4'``).
        unit_count: units ablated in the IoU and learned settings
            (passed to ``ace_loss`` as ``discrete_units``).

    Returns:
        float: ``1 - learned_loss / baseline_loss``.  If the baseline sum is
        zero this is not finite.

    Raises:
        ValueError: if ``classname`` is not a segmenter label, or the
            dissection has no record for ``layer`` or no
            ``<classname>-iou`` ranking.
    """
    dissectdir = os.path.join('dissect', scene)
    with open(os.path.join(dissectdir, 'dissect.json')) as f:
        dissection = EasyDict(json.load(f))

    segmenter = UnifiedParsingSegmenter(segsizes=[256], segdiv='quad')
    model = create_instrumented_model(dissection.settings)
    classnum = _find_label_index(segmenter, classname)
    edit_layers(model, [layer])

    # 1000 fixed-seed z samples, evaluated in batches of 10.
    big_sample = z_sample_for_model(model, 1000, seed=3)
    big_loader = DataLoader(TensorDataset(big_sample), batch_size=10)

    # Per-unit scores from the dissection's '<classname>-iou' ranking,
    # negated (presumably the ranking stores -IoU so that ascending order is
    # best-first; TODO confirm against the dissection code).
    ranking_name = '%s-iou' % classname
    lrec = next((l for l in dissection.layers if l.layer == layer), None)
    if lrec is None:
        raise ValueError('no layer %r in %s' % (layer, dissectdir))
    rrec = next((r for r in lrec.rankings if r.name == ranking_name), None)
    if rrec is None:
        raise ValueError('no ranking %r for layer %r in %s'
                         % (ranking_name, layer, dissectdir))
    iou_scores = -torch.tensor(rrec.score)

    # Learned per-unit ablation scores from the saved optimization snapshot.
    snapdir = os.path.join(dissectdir, layer, 'ace', classname, 'snapshots')
    data = torch.load(os.path.join(snapdir, 'epoch-9.pth'))
    learned_scores = data['ablation'][0, :, 0, 0]

    # Unit selection for the ablated settings is presumably done inside
    # ace_loss via discrete_units (TODO confirm in netdissect.aceoptimize).
    # (1) baseline: nothing ablated.
    baseline_loss = _summed_ace_loss(
        segmenter, classnum, model, layer, big_loader,
        torch.zeros_like(learned_scores), discrete_units=0)
    # (2) unit_count units from the IoU ranking.
    iou_loss = _summed_ace_loss(
        segmenter, classnum, model, layer, big_loader,
        iou_scores, discrete_units=unit_count)
    # (3) unit_count units from the learned ablation.
    learned_loss = _summed_ace_loss(
        segmenter, classnum, model, layer, big_loader,
        learned_scores, discrete_units=unit_count, mixed_units=True)

    # Relative reduction of the summed loss versus the baseline.
    iou_reduction = 1 - iou_loss / baseline_loss
    learned_reduction = 1 - learned_loss / baseline_loss

    print('%s - %s' % (scene, classname))
    print('baseline %g iou %g learned %g' % (baseline_loss.item(), iou_loss.item(), learned_loss.item()))
    print('iou', iou_reduction)
    print('learned', learned_reduction)
    return learned_reduction.item()


In [21]:
# Evaluate each (scene, classname) pair once.  `results` is created only if it
# does not exist yet, so a fresh kernel works and re-running this cell only
# computes the pairs that are still missing.
if 'results' not in globals():
    results = {}

ABLATION_CASES = [
    ('conferenceroom', 'window'),
    ('conferenceroom', 'chair'),
    ('conferenceroom', 'curtain'),
    ('conferenceroom', 'person'),
    ('conferenceroom', 'table'),
    ('churchoutdoor', 'window'),
    ('diningroom', 'window'),
    ('restaurant', 'window'),
    ('kitchen', 'window'),
    ('livingroom', 'window'),
    ('bedroom', 'window'),
    ('churchoutdoor', 'brick'),
    ('churchoutdoor', 'cloud'),
    ('churchoutdoor', 'dome'),
    ('churchoutdoor', 'door'),
    ('churchoutdoor', 'tree'),
]

for key in ABLATION_CASES:
    if key not in results:
        results[key] = evaluate_ablation(*key)




In [22]:
# Saved copy of an earlier `results` value, kept for reference (it is a
# comment, not recomputed).  Each value is what evaluate_ablation returned,
# i.e. 1 - learned_loss / baseline_loss for that (scene, classname) pair.
# {('conferenceroom', 'window'): 0.5683056116104126,
#  ('conferenceroom', 'chair'): 0.12282001972198486,
#  ('conferenceroom', 'curtain'): 0.5497019290924072,
#  ('conferenceroom', 'person'): 0.614344596862793,
#  ('conferenceroom', 'table'): 0.2871387004852295,
#  ('churchoutdoor', 'window'): 0.4741958975791931,
#  ('diningroom', 'window'): 0.3416675329208374,
#  ('restaurant', 'window'): 0.46097666025161743,
#  ('livingroom', 'window'): 0.26558423042297363,
#  ('bedroom', 'window'): 0.2773325443267822,
#  ('kitchen', 'window'): 0.5389571785926819,
#  ('churchoutdoor', 'brick'): 0.427112877368927,
#  ('churchoutdoor', 'cloud'): 0.27516114711761475,
#  ('churchoutdoor', 'dome'): 0.7740037441253662,
#  ('churchoutdoor', 'door'): 0.5343429446220398,
#  ('churchoutdoor', 'tree'): 0.5980440378189087}

# Display the live results dictionary computed above.
results

In [4]:
# Re-run one case directly to see its printed baseline / iou / learned numbers.
evaluate_ablation(scene='churchoutdoor', classname='tree')