For a given model, this notebook compares the incremental setup to the non-incremental setup and counts how often the referred object was identified correctly: by both setups, by only one of them, or by neither.

In [None]:
import pickle
import torch
from collections import Counter
import pandas as pd

In [2]:
def load_data(model, mode, dataset, split,
              base_dir="/home/users/fschreiber/project", verbose=False):
    """Load predicted/target boxes and grouped correctness flags for one dataset split.

    Parameters
    ----------
    model : str
        Model name used in the directory layout (e.g. "ReSc", "TVG").
    mode : str
        "inc" uses the stored target boxes (<model>/<split>_target_bbox_list.p);
        "non_inc" uses the boxes in bboxes_noninc_<model>/<split>_pred_bbox_list.p
        as targets, repeated once per incremental unit.
    dataset, split : str
        Dataset and split names used in the file names (e.g. "gref", "val").
    base_dir : str
        Project root that contains all the data directories. Defaults to the
        original hard-coded location.
    verbose : bool
        If True, print the error when a file is missing instead of staying silent.

    Returns
    -------
    tuple
        (pred_bbox_list, target_bbox_list, inc_len, data_model, binary_grouped).
        If `mode` is invalid or any file is missing, returns
        (-1, -1, -1, -1, -1); callers treat that sentinel as "skip this split".

    Security: the files are read with pickle / torch.load, which can execute
    arbitrary code. Only use this on trusted project files.
    """
    failure = (-1, -1, -1, -1, -1)

    # Validate the mode before touching the file system.
    if mode not in ("inc", "non_inc"):
        print("The mode can only be non_inc or inc")
        return failure

    def _path(*parts):
        # Join components under the project root with "/" (same layout as the original paths).
        return "/".join((base_dir,) + parts)

    def _unpickle(path):
        # SECURITY: pickle.load can run arbitrary code -- trusted files only.
        with open(path, "rb") as f:
            return pickle.load(f)

    try:
        # the predicted bounding boxes
        pred_bbox_list = list(_unpickle(
            _path("bboxes_" + model, dataset, split + "_pred_bbox_list.p")))

        # the target bounding boxes, chosen by mode
        if mode == "non_inc":
            target_path = _path("bboxes_noninc_" + model, dataset, split + "_pred_bbox_list.p")
        else:
            target_path = _path("bboxes_" + model, dataset, split + "_target_bbox_list.p")
        target_bbox_list = list(_unpickle(target_path))

        # number of incremental units per sentence ("the left zebra" would have length 3)
        inc_len = _unpickle(_path("incremental_pickles", "length_incremental_units",
                                  dataset + "_" + split + "_length_unit.p"))

        # the original model data split up incrementally (torch.load also unpickles)
        data_model = torch.load(_path("ready_inc_data", dataset, dataset + "_" + split + ".pth"))

        binary_grouped = _unpickle(_path("binary_grouped", model, mode, dataset + split + ".p"))
    except FileNotFoundError as e:
        if verbose:
            print(e)
        return failure

    if mode == "non_inc":
        # repeat each non-incremental box inc_len[k] times so it lines up with the
        # incremental predictions
        target_bbox_list = [box for box, n_units in zip(target_bbox_list, inc_len)
                            for _ in range(n_units)]

    if model == "TVG":
        pred_bbox_list, target_bbox_list = TVG_prep(pred_bbox_list, target_bbox_list)

    return pred_bbox_list, target_bbox_list, inc_len, data_model, binary_grouped
    


# TVG needs some extra adjustments to fit the same data format as ReSc.
def TVG_prep(pred_bbox_list, target_bbox_list):
    """Convert TVG boxes, in place, to the corner format used for ReSc.

    Each box is reshaped to (1, -1) and converted with `xywh2xyxy`. Predicted
    boxes are additionally clamped to [0, 1]; targets are not. Only the first
    min(len(pred_bbox_list), len(target_bbox_list)) entries are converted.
    Both lists are modified in place and also returned.
    """
    count = min(len(pred_bbox_list), len(target_bbox_list))
    for idx in range(count):
        corner_pred = xywh2xyxy(pred_bbox_list[idx].view(1, -1))
        pred_bbox_list[idx] = torch.clamp(corner_pred, 0, 1)
        target_bbox_list[idx] = xywh2xyxy(target_bbox_list[idx].view(1, -1))
    return pred_bbox_list, target_bbox_list

# Copied from TransVG; needed to transform the bounding box vectors.
def xywh2xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    The last dimension of `x` must have size 4; leading dimensions are kept.
    """
    cx, cy, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(list(corners), dim=-1)


In [3]:
# ReSc on gref/val, loaded once per setup (incremental, then non-incremental).
(pred_bbox_list_inc, target_bbox_list_inc, inc_len_inc,
 model_inc, binary_grouped_inc) = load_data("ReSc", "inc", "gref", "val")

(pred_bbox_list_non_inc, target_bbox_list_non_inc, inc_len_non_inc,
 model_non_inc, binary_grouped_non_inc) = load_data("ReSc", "non_inc", "gref", "val")


In [6]:
# Calculate a confusion matrix comparing the incremental to the non-incremental data.
def confusion_matrix(binary_grouped_inc, binary_grouped_non_inc, last_only=False):
    """Count agreement between incremental and non-incremental 0/1 correctness flags.

    Both arguments are sequences of entries, and each entry is a sequence of 0/1
    flags (presumably one per incremental unit -- TODO confirm). Entries and flags
    are paired positionally with zip, so surplus items on the longer side are ignored.

    Parameters
    ----------
    binary_grouped_inc, binary_grouped_non_inc : sequence of sequences of {0, 1}
    last_only : bool, default False
        If True, compare only the last flag of each entry pair instead of every
        flag. Entry pairs where either side is empty are skipped.

    Returns
    -------
    (tp, fp, fn, tn) : tuple of int
        tp: both 1; fp: only inc is 1; fn: only non-inc is 1; tn: both 0.
        Pairs with any other value print a warning and are not counted.
    """
    tp = fp = fn = tn = 0
    for entry_inc, entry_non_inc in zip(binary_grouped_inc, binary_grouped_non_inc):
        if last_only:
            if len(entry_inc) == 0 or len(entry_non_inc) == 0:
                continue
            pairs = [(entry_inc[-1], entry_non_inc[-1])]
        else:
            pairs = zip(entry_inc, entry_non_inc)

        for i, j in pairs:
            if i == 1 and j == 1:
                tp += 1
            elif i == 1 and j == 0:
                fp += 1
            elif i == 0 and j == 1:
                fn += 1
            elif i == 0 and j == 0:
                tn += 1
            else:
                print("something is very wrong")

    return tp, fp, fn, tn


def confusion_matrix_end(binary_grouped_inc, binary_grouped_non_inc):
    """Confusion matrix over only the last flag of each entry (one count per entry).

    NOTE(review): the comparison cell calls this name, but no definition of it
    survives in the notebook (probably a deleted cell). This is a reconstruction.
    The saved output's totals equal one count per referring expression, which is
    consistent with comparing the final increment. TODO: confirm against the
    original definition.
    """
    return confusion_matrix(binary_grouped_inc, binary_grouped_non_inc, last_only=True)



In [9]:
# Compare incremental vs non-incremental correctness for every dataset/split whose
# files exist in both modes, then tabulate the counts.
found_sets_list, tp_list, fp_list, fn_list, tn_list = ([] for _ in range(5))

model_input = "TVG"
split_list = ["testB", "testA", "val", "test"]
dataset_list = ["unc", "unc+", "gref_umd", "gref"]


def _is_missing(value):
    """True if `value` is load_data's -1 sentinel (missing file or invalid mode)."""
    # isinstance guard: `value == -1` on an array/tensor would compare elementwise.
    return isinstance(value, int) and value == -1


for file in dataset_list:
    for split in split_list:
        # load incremental and non-incremental data
        pred_bbox_list, target_bbox_list, inc_len, model, binary_grouped_inc = load_data(model_input, "inc", file, split)
        pred_bbox_list, target_bbox_list, inc_len, model, binary_grouped_non_inc = load_data(model_input, "non_inc", file, split)

        # Skip this combination if EITHER load failed; checking only the second
        # load let a missing "inc" file crash confusion_matrix_end on zip(-1, ...).
        if _is_missing(binary_grouped_inc) or _is_missing(binary_grouped_non_inc):
            continue

        print(file + split)
        found_sets_list.append(file + " " + split)

        # NOTE(review): requires confusion_matrix_end to be defined in an earlier cell.
        tp, fp, fn, tn = confusion_matrix_end(binary_grouped_inc, binary_grouped_non_inc)

        print("Both correct", tp)
        print("Only inc correct:", fp)
        print("Only non inc correct", fn)
        print("Both incorrect", tn)
        print()

        tp_list.append(tp)
        fp_list.append(fp)
        fn_list.append(fn)
        tn_list.append(tn)

data = {
    "Dataset " + model_input: found_sets_list,
    "Both correct": tp_list,
    "Only inc correct": fp_list,
    "Only non inc correct": fn_list,
    "Both incorrect": tn_list,
}

# set_index without inplace=True so re-running this cell is idempotent.
df = pd.DataFrame(data).set_index("Dataset " + model_input)

df

unctestB
Both correct 3920
Only inc correct: 0
Only non inc correct 1175
Both incorrect 0

unctestA
Both correct 4717
Only inc correct: 0
Only non inc correct 940
Both incorrect 0

uncval
Both correct 8757
Only inc correct: 0
Only non inc correct 2077
Both incorrect 0

unc+testB
Both correct 2896
Only inc correct: 0
Only non inc correct 1993
Both incorrect 0

unc+testA
Both correct 4149
Only inc correct: 0
Only non inc correct 1577
Both incorrect 0

unc+val
Both correct 7315
Only inc correct: 0
Only non inc correct 3443
Both incorrect 0

gref_umdval
Both correct 3353
Only inc correct: 12
Only non inc correct 1509
Both incorrect 22

gref_umdtest
Both correct 6504
Only inc correct: 8
Only non inc correct 3048
Both incorrect 42

grefval
Both correct 6475
Only inc correct: 11
Only non inc correct 3023
Both incorrect 27



Unnamed: 0_level_0,Both correct,Only inc correct,Only non inc correct,Both incorrect
Dataset TVG,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
unc testB,3920,0,1175,0
unc testA,4717,0,940,0
unc val,8757,0,2077,0
unc+ testB,2896,0,1993,0
unc+ testA,4149,0,1577,0
unc+ val,7315,0,3443,0
gref_umd val,3353,12,1509,22
gref_umd test,6504,8,3048,42
gref val,6475,11,3023,27
