In [1]:
import os, sys
import torch
import numpy as np
from nerv.utils import load_obj, VideoReader, strip_suffix, read_img, dump_obj


In [2]:
# load the json files
gt_file = "/drive/yuwu3/CLEVRER/questions/validation.json"
json_gt = load_obj(gt_file)

pred_file = "/home/yuwu3/SpaceQA/Code/MultimodalBaseline/SlotFormer/pretrained/aloe_clevrer_params-rollout/CLEVRER_pred.json"
json_pred = load_obj(pred_file)



In [33]:
# standardize the json pred file
json_new = []

# print("length json_gt, json_pred: ", len(json_gt), len(json_pred))
for scene_id in range(len(json_gt)):
    sc_gt = json_gt[scene_id]['questions']
    sc_pred = json_pred[scene_id]['questions']
    # print("length sc_gt, sc_pred: ", len(sc_gt), len(sc_pred))
    for qid in range(len(sc_gt)):
        q_gt = sc_gt[qid]
        q_pred = sc_pred[qid]
        q_type = q_gt['question_type']
        if q_type == 'descriptive':
            entry = {
                "sceneid" : str(scene_id),
                "qid" : str(qid),
                "question" : q_gt['question'],
                "question_type" : q_type,
                "answer" : q_pred['answer'],
                "gt" : q_gt['answer'],
                "question_subtype" : q_gt['question_subtype'],
                "choices" : None,
            }
            json_new.append(entry)
        else:
            # multiple choice    
            choice_list = []
            answer_list = []
            gt_list = []    
            for cid in range(len(q_gt['choices'])):
                c_gt = q_gt['choices'][cid]
                c_pred = q_pred['choices'][cid]
                choice_list.append(c_gt['choice'])
                gt_list.append(c_gt['answer'])
                answer_list.append(c_pred['answer'])

            entry = {
                "sceneid" : str(scene_id),
                "qid" : str(qid),
                "question" : q_gt['question'],
                "question_type" : q_type,
                "answer" : answer_list,
                "gt" : gt_list,
                "question_subtype" : None,
                "choices" : choice_list
            }
            json_new.append(entry)

dump_obj(json_new, "/home/yuwu3/SpaceQA/Code/MultimodalBaseline/SlotFormer/pretrained/aloe_clevrer_params-rollout/CLEVRER_pred.json")

In [16]:
# compute overall accuracy
total, correct = 0, 0
total_per_q, correct_per_q = 0, 0
total_expl, correct_expl = 0, 0
total_expl_per_q, correct_expl_per_q = 0, 0
total_pred, correct_pred = 0, 0
total_pred_per_q, correct_pred_per_q = 0, 0
total_coun, correct_coun = 0, 0
total_coun_per_q, correct_coun_per_q = 0, 0
total_desc, correct_desc = 0, 0

error_case = []

# print("length json_gt, json_pred: ", len(json_gt), len(json_pred))
for scene_id in range(len(json_gt)):
    sc_gt = json_gt[scene_id]['questions']
    sc_pred = json_pred[scene_id]['questions']
    # print("length sc_gt, sc_pred: ", len(sc_gt), len(sc_pred))
    for qid in range(len(sc_gt)):
        q_gt = sc_gt[qid]
        q_pred = sc_pred[qid]
        q_type = q_gt['question_type']
        if q_type == 'descriptive':
            if q_gt['answer'] == q_pred['answer']:
                correct_desc += 1
                correct += 1
                correct_per_q += 1
            else:
                pass
            total_desc += 1
            total += 1
            total_per_q += 1
            continue

        # multiple choice
        correct_question = True
        # print("length of choices: ", len(q_gt['choices']), len(q_pred['choices']))
        
        for cid in range(len(q_gt['choices'])):
            c_gt = q_gt['choices'][cid]
            c_pred = q_pred['choices'][cid]
            ans = c_gt['answer']
            pred = c_pred['answer']

            if ans == pred:
                correct += 1
            else:
                correct_question = False
                error_case.append({'question':q_gt['question'], 
                                   'type':q_gt['question_type'],
                                   'choice' : c_gt['choice'],
                                   'answer':c_gt['answer'],
                                   'sid':scene_id,
                                   'qid': qid})
                
            total += 1
            if q_type.startswith('explanatory'):
                if ans == pred:
                    correct_expl += 1
                total_expl += 1

            if q_type.startswith('predictive'):
                # print(pred, ans)
                if ans == pred:
                    correct_pred += 1
                total_pred += 1

            if q_type.startswith('counterfactual'):
                if ans == pred:
                    correct_coun += 1
                total_coun += 1
        if correct_question:
            correct_per_q += 1
        total_per_q += 1

        if q_type.startswith('explanatory'):
            if correct_question:
                correct_expl_per_q += 1
            total_expl_per_q += 1

        if q_type.startswith('predictive'):
            if correct_question:
                correct_pred_per_q += 1
            total_pred_per_q += 1

        if q_type.startswith('counterfactual'):
            if correct_question:
                correct_coun_per_q += 1
            total_coun_per_q += 1

print('============ results ============')
print('overall accuracy per option: %f %%' % (float(correct) * 100.0 / total))
print('overall accuracy per question: %f %%' % (float(correct_per_q) * 100.0 / total_per_q))
print('descriptive accuracy per question: %f %%' % (float(correct_desc) * 100.0 / total_desc))
print('explanatory accuracy per option: %f %%' % (float(correct_expl) * 100.0 / total_expl))
print('explanatory accuracy per question: %f %%' % (float(correct_expl_per_q) * 100.0 / total_expl_per_q))
print('predictive accuracy per option: %f %%' % (float(correct_pred) * 100.0 / total_pred))
print('predictive accuracy per question: %f %%' % (float(correct_pred_per_q) * 100.0 / total_pred_per_q))
print('counterfactual accuracy per option: %f %%' % (float(correct_coun) * 100.0 / total_coun))
print('counterfactual accuracy per question: %f %%' % (float(correct_coun_per_q) * 100.0 / total_coun_per_q))
print('============ results ============')
print(total, total_per_q, total_desc, total_expl, total_expl_per_q, total_pred, total_pred_per_q, total_coun, total_coun_per_q)
   

overall accuracy per option: 94.666751 %
overall accuracy per question: 92.245443 %
descriptive accuracy per question: 95.079105 %
explanatory accuracy per option: 98.032381 %
explanatory accuracy per question: 94.521678 %
predictive accuracy per option: 96.007872 %
predictive accuracy per question: 92.324993 %
counterfactual accuracy per option: 90.566095 %
counterfactual accuracy per question: 73.449052 %
125852 76368 54990 30697 8488 7114 3557 33051 9333


In [24]:
# do error analysis
correct, num_obj, num_collision = [], [], []
TP, TN, FP, FN = [], [], [], []
qtype_list = []
last_sid = -1
for idx, q_entry in enumerate(json_pred):
    sid = int(q_entry['sceneid'])
    seg = (sid // 1000) * 1000
    seg1 = ((sid // 1000) + 1) * 1000
    if sid != last_sid:
        ano_file = f"/drive/yuwu3/CLEVRER/annotation_{seg:0>5}-{seg1:0>5}/annotation_{sid:0>5}.json"
        ano_json = load_obj(ano_file)
        n_obj = len(ano_json['object_property'])
        n_collision = len(ano_json['collision'])

    
    if q_entry['question_type'] == 'descriptive':
        pred = q_entry['answer']
        gt = q_entry['gt']
        qtype_list.append(q_entry['question_type'])
        correct.append(gt==pred)
        TP.append(-1)
        TN.append(-1)
        FP.append(-1)
        FN.append(-1)
        num_obj.append(n_obj)
        num_collision.append(n_collision)
    else:
        for c in range(len(q_entry['gt'])):
            pred = q_entry['answer'][c]
            gt = q_entry['gt'][c]
            qtype_list.append(q_entry['question_type'])
            correct.append(q_entry['gt']==q_entry['answer'])
            TP.append(gt=='correct' and pred=='correct')
            TN.append(gt=='wrong' and pred=='wrong')
            FP.append(gt=='wrong' and pred=='correct')
            FN.append(gt=='correct' and pred=='wrong')
            num_obj.append(n_obj)
            num_collision.append(n_collision)
            
            
correct = np.array(correct)
num_obj = np.array(num_obj)
num_collision = np.array(num_collision)
TP = np.array(TP)
TN = np.array(TN)
FP = np.array(FP)
FN = np.array(FN)
qtype_list = np.array(qtype_list)

In [6]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


# confusion_mat = confusion_matrix(correct, num_obj,)


In [22]:
gt, pred

('wrong', 'wrong')

In [20]:
print(TP.sum())

-54990


In [27]:
type_mask = qtype_list == "counterfactual"
for value in np.unique(num_obj):
    mask = num_obj == value
    mask = (type_mask & mask)
    subset = correct[mask]
    TP_subset = TP[mask]
    TN_subset = TN[mask]
    FP_subset = FP[mask]
    FN_subset = FN[mask]
    accuracy = subset.sum() / len(subset)
    TP_rate = TP_subset.sum() / len(TP_subset)
    FP_rate = FP_subset.sum() / len(FP_subset)
    TN_rate = TN_subset.sum() / len(TN_subset)
    FN_rate = FN_subset.sum() / len(FN_subset)
    print(f"num_obj = {value}, accuracy = {accuracy}, TP={TP_rate}, TN={TN_rate}, FP={FP_rate}, FN={FN_rate}")

for value in np.unique(num_collision):
    mask = num_collision == value
    mask = (type_mask & mask)
    subset = correct[mask]
    TP_subset = TP[mask]
    TN_subset = TN[mask]
    FP_subset = FP[mask]
    FN_subset = FN[mask]
    accuracy = subset.sum() / len(subset)
    TP_rate = TP_subset.sum() / len(TP_subset)
    FP_rate = FP_subset.sum() / len(FP_subset)
    TN_rate = TN_subset.sum() / len(TN_subset)
    FN_rate = FN_subset.sum() / len(FN_subset)
    print(f"num_collistion = {value}, accuracy = {accuracy}, TP={TP_rate}, TN={TN_rate}, FP={FP_rate}, FN={FN_rate}")

num_obj = 3, accuracy = 0.9615384615384616, TP=0.34615384615384615, TN=0.6153846153846154, FP=0.038461538461538464, FN=0.0
num_obj = 4, accuracy = 0.7181739453173327, TP=0.43548187554125944, TN=0.43238896449338116, FP=0.06420883335395274, FN=0.06792032661140665
num_obj = 5, accuracy = 0.7279861874117093, TP=0.45581541359284256, TN=0.4575419871291791, FP=0.04426306702244546, FN=0.042379532255532884
num_obj = 6, accuracy = 0.740983606557377, TP=0.4630327868852459, TN=0.4595081967213115, FP=0.03836065573770492, FN=0.039098360655737706
num_collistion = 1, accuracy = 0.7352614015572859, TP=0.4532814238042269, TN=0.45272525027808674, FP=0.047274749721913235, FN=0.04671857619577308
num_collistion = 2, accuracy = 0.7409561118932485, TP=0.4565097631144029, TN=0.45406133317010466, FP=0.04486747872926486, FN=0.04456142498622758
num_collistion = 3, accuracy = 0.7383354828920415, TP=0.45451786662620436, TN=0.4532281314012594, FP=0.045520066762764586, FN=0.04673393520977164
num_collistion = 4, accur

In [36]:
np.unique(num_obj)

array([1, 2, 3, 4, 5])

In [22]:
ano_file = "/drive/yuwu3/CLEVRER/annotation_10000-11000/annotation_10190.json"
ano_json = load_obj(ano_file)
len(ano_json['object_property'])
ano_json['collision']

[{'object_ids': [0, 1], 'frame_id': 34, 'location': [-1.971, 1.5359, 0.351]},
 {'object_ids': [1, 2], 'frame_id': 60, 'location': [-2.9715, 1.0351, 0.3509]},
 {'object_ids': [0, 2], 'frame_id': 74, 'location': [-2.6608, 0.6759, 0.38]},
 {'object_ids': [2, 3],
  'frame_id': 123,
  'location': [-2.9103, -0.4021, 0.3491]}]

In [16]:
for entry in json_pred:
    sid = int(entry["sceneid"])
    seg = (sid // 1000) * 1000
    print(f"{seg:0>4}")

0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000
0000


In [5]:
len(correct)

125852

In [17]:
for case in error_case:
    question = case['question']
    choice = case['choice']
    answer = case['answer']
    q_type = case['type']
    sid = case['sid']
    qid = case['qid']
    if q_type == 'counterfactual':
        print(sid, qid, question, choice, answer)

1 15 What will happen without the metal sphere? The yellow sphere collides with the gray cube correct
2 14 If the metal cube is removed, which event will happen? The metal sphere collides with the cyan cube correct
3 15 If the metal sphere is removed, which of the following will not happen? The cube collides with the rubber object correct
3 15 If the metal sphere is removed, which of the following will not happen? The rubber sphere collides with the cylinder correct
11 14 What will not happen without the cylinder? The brown object collides with the cube correct
12 13 If the cyan sphere is removed, which of the following will not happen? The blue object collides with the green sphere correct
13 14 If the blue cylinder is removed, which of the following will happen? The cube collides with the gray object wrong
14 13 Without the red object, what will not happen? The purple object collides with the green cube correct
14 13 Without the red object, what will not happen? The rubber cylinder a