In [33]:
import torch
import json
import numpy as np
import pandas as pd
from tqdm import tqdm

from utils import generate_detect_sg
from sgg_benchmark.config.paths_catalog import DatasetCatalog

# Run Evaluation for test or train split

In [2]:
# --- Run configuration -------------------------------------------------------
# Root of the local Scene-Graph-Benchmark checkout (keeps its trailing slash so
# the sub-paths below can be appended directly).
BASE_PATH = "/home/maelic/Documents/Scene-Graph-Benchmark-Cuda11.7/"

# Output directory of the SGDet run and the pretrained detector checkpoint,
# both located under BASE_PATH.
SGDET_DIR = f"{BASE_PATH}checkpoints/VG150/curated/sgdet_motifs_causal_tde"
FASTER_RCNN_DIR = f"{BASE_PATH}checkpoints/VG150/curated/pretrained_faster_rcnn/model_final.pth"

# Directory handed to the evaluation command as GLOVE_DIR.
GLOVE_DIR = "/home/maelic/glove"

# Split to evaluate; also used as the suffix of the inference result folder.
EVAL_TYPE = "test"

# Model config file, dataset key (looked up in DatasetCatalog) and task name.
CONFIG = "configs/VG150/curated/e2e_relation_X_101_32_8_FPN_1x.yaml"
DATASET = "VG150_curated_filtered"
TASK = "sgdet"

In [3]:
#!cd .. && pwd && CUDA_VISIBLE_DEVICES=0 python tools/relation_test_net.py --config-file $CONFIG --task $TASK --verbose MODEL.ROI_RELATION_HEAD.PREDICTOR CausalAnalysisPredictor MODEL.ROI_RELATION_HEAD.CAUSAL.EFFECT_TYPE TDE MODEL.ROI_RELATION_HEAD.CAUSAL.FUSION_TYPE sum MODEL.ROI_RELATION_HEAD.CAUSAL.CONTEXT_LAYER motifs  SOLVER.IMS_PER_BATCH 32 TEST.IMS_PER_BATCH 1 DTYPE "float16" SOLVER.MAX_ITER 30000 SOLVER.VAL_PERIOD 2000 SOLVER.CHECKPOINT_PERIOD 2000 SOLVER.PRE_VAL False GLOVE_DIR $GLOVE_DIR MODEL.PRETRAINED_DETECTOR_CKPT $FASTER_RCNN_DIR OUTPUT_DIR $SGDET_DIR DATASETS.TO_TEST $EVAL_TYPE

In [4]:
# Folder holding the inference outputs for the configured dataset and split
# (keeps a trailing slash so file names can be appended directly).
result_path = SGDET_DIR + '/inference/' + DATASET + '_' + EVAL_TYPE + '/'

# Use context managers so the JSON file handles are closed as soon as they are
# parsed instead of being left open until garbage collection.
with open(result_path + 'visual_info.json') as info_file:
    info_dict = json.load(info_file)

# NOTE(security): torch.load deserializes with Python's pickle machinery, so
# only load result files that were produced by a trusted run.
detected_result = torch.load(result_path + 'eval_results.pytorch')
result_dict = torch.load(result_path + 'result_dict.pytorch')

# Dictionary file registered for DATASET in the project's DatasetCatalog.
with open(DatasetCatalog.DATASETS[DATASET]['dict_file']) as dict_file:
    vg_dict = json.load(dict_file)

In [5]:
#print(compute_metrics(detected_result, TASK))

In [6]:
# Build the scene graphs from the loaded detections, visual info and dataset
# dictionary. Later cells index the result as
# output_scene_graphs[key][0]['entities' | 'relations'], with string keys such
# as '2343148' (presumably image ids -- confirm against generate_detect_sg).
output_scene_graphs = generate_detect_sg(detected_result, info_dict, vg_dict)

Generating scene graphs : 100%|██████████| 26980/26980 [03:25<00:00, 131.25it/s]


## Compute average triplet scores

In [31]:
# Inspect one example graph: list its entity labels, then print every relation
# both in raw form and as a readable "subject-predicate-object" string.
sample_graph = output_scene_graphs['2343148'][0]
sample_entities = sample_graph['entities']

print(sample_entities)
for r in sample_graph['relations']:
    print(r)
    # r[0] and r[1] index into the entity list; r[2] is the predicate label.
    rel = '-'.join((sample_entities[r[0]], r[2], sample_entities[r[1]]))
    print(rel)

['sky', 'plate', 'bus', 'road', 'tree', 'window', 'building', 'window', 'truck', 'window', 'window', 'car', 'sign', 'line', 'street', 'door', 'tree', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'bus', 'letter']
[0, 1, 'above', 0.6246896982192993]
sky-above-plate
[0, 2, 'above', 0.6532915234565735]
sky-above-bus
[0, 3, 'over', 0.49638083577156067]
sky-over-road
[0, 4, 'above', 0.4133039116859436]
sky-above-tree
[0, 5, 'has', 0.43939805030822754]
sky-has-window
[0, 6, 'above', 0.4709216356277466]
sky-above-building
[0, 7, 'above', 0.3530330955982208]
sky-above-window
[0, 8, 'above', 0.7329128980636597]
sky-above-truck
[0, 11, 'above', 0.6173756718635559]
sky-above-car
[0, 12, 'above', 0.7371037006378174]
sky-above-sign
[0, 13, 'above', 0.4656640291213989]
sky-above-line
[0, 14, 'above', 0.5301374197006226]
sky-above-street
[0, 15, 'above', 0.3871307671070099]
sky-above-door
[0, 16, 'above', 0.4400191605091095]
sky-above-tree
[0, 17, 'above', 0.6672959923744202]
sky-above-

In [36]:
def triplets_ranking(output_scene_graphs):
    """Rank relation triplets by their mean score across all scene graphs.

    Parameters
    ----------
    output_scene_graphs : dict
        Mapping whose values are sequences; only the first element of each
        value is read. That element must provide ``'entities'`` (a list of
        label strings) and ``'relations'`` (a list of
        ``[subject_index, object_index, predicate, score]`` entries, where the
        two indices point into ``'entities'``).

    Returns
    -------
    pandas.DataFrame
        One row per distinct ``"subject-predicate-object"`` string with the
        columns ``'Triplet'`` and ``'Score'`` (mean score of that triplet),
        sorted by ``'Score'`` in descending order. An empty input yields an
        empty frame with the same two columns and a float ``'Score'`` column.
    """
    triplet_names = []
    triplet_scores = []
    for graph in tqdm(output_scene_graphs.values()):
        scene = graph[0]
        # Look the entity list up once per graph rather than once per relation.
        entities = scene['entities']
        for r in scene['relations']:
            triplet_names.append('-'.join((entities[r[0]], r[2], entities[r[1]])))
            triplet_scores.append(r[3])

    # Explicit dtypes keep 'Score' numeric even when no relation was collected;
    # without them an empty input gives an object-dtype column whose mean is
    # not reliably defined across pandas versions.
    df = pd.DataFrame({
        'Triplet': pd.Series(triplet_names, dtype=object),
        'Score': pd.Series(triplet_scores, dtype='float64'),
    })
    # average the scores of identical triplets
    df = df.groupby('Triplet').mean().reset_index()
    # highest mean score first
    df = df.sort_values(by=['Score'], ascending=False)
    return df
# Print rows 50..99 (by position) of the ranking, i.e. the 51st to 100th
# triplets by mean score, since the returned frame is sorted in descending order.
print(triplets_ranking(output_scene_graphs)[50:100])


100%|██████████| 26980/26980 [00:03<00:00, 7600.57it/s] 


                         Triplet     Score
31122         child-wearing-jean  0.875166
71113           kid-wearing-jean  0.874428
88471        person-wearing-shoe  0.874031
20544         boy-wearing-jacket  0.874002
36035                cow-has-top  0.872811
118486      surfer-wearing-short  0.872768
36479            cow-wearing-eye  0.869921
112594       skier-wearing-glove  0.869231
112583        skier-wearing-coat  0.867818
31096          child-wearing-cap  0.866976
71110         kid-wearing-helmet  0.866470
112603      skier-wearing-jacket  0.866360
95641        player-wearing-pant  0.865502
31120       child-wearing-helmet  0.864698
112593       skier-wearing-glass  0.864486
54988           girl-wearing-hat  0.862856
31118          child-wearing-hat  0.861821
118479       surfer-wearing-pant  0.860773
39803          dish-behind-child  0.860765
112620       skier-wearing-short  0.858962
118484      surfer-wearing-shirt  0.857578
54982         girl-wearing-glove  0.856086
95631      