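This notebook calculates the average sentence length for each dataset and split, reported separately for sentences the model grounded correctly, incorrectly, or (over the whole increment) with mixed results.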

In [None]:
import pickle
import torch
import pandas as pd

In [2]:
def load_data(model,mode, dataset, split, project_dir="/home/users/fschreiber/project"):
    """Load predicted/target boxes and incremental metadata for one model, dataset and split.

    Parameters
    ----------
    model : str
        Model name used inside the directory names (e.g. "ReSc"). The value
        "TVG" additionally routes both box lists through TVG_prep.
    mode : str
        "inc" reads the targets from ``<split>_target_bbox_list.p`` next to the
        predictions. "non_inc" reads them from the ``bboxes_noninc_<model>``
        directory (file ``<split>_pred_bbox_list.p``) and repeats every entry
        ``inc_len[i]`` times.
    dataset, split : str
        Dataset and split names used to build the file names.
    project_dir : str, optional
        Root directory containing the ``bboxes_*``, ``incremental_pickles``,
        ``ready_inc_data`` and ``binary_grouped`` folders. Defaults to the
        original hard-coded location.

    Returns
    -------
    tuple
        ``(pred_bbox_list, target_bbox_list, inc_len, data_model, binary_grouped)``.
        If ``mode`` is invalid or any required file is missing, the sentinel
        ``(-1, -1, -1, -1, -1)`` is returned instead; callers test for ``-1``.

    Notes
    -----
    ``pickle.load`` and ``torch.load`` can execute arbitrary code while
    deserialising, so this must only be pointed at trusted files.
    """
    not_available = (-1, -1, -1, -1, -1)

    # Validate the mode before touching the disk. Previously the check ran only
    # after the first pickle had been read, so a bad mode combined with a
    # missing file returned the sentinel without ever printing this message.
    if mode not in ("inc", "non_inc"):
        print("The mode can only be non_inc or inc")
        return not_available

    def _read_pickle(path):
        with open(path, "rb") as f:
            return pickle.load(f)

    pred_path = "/".join([project_dir, "bboxes_" + model, dataset, split + "_pred_bbox_list.p"])
    if mode == "non_inc":
        target_path = "/".join([project_dir, "bboxes_noninc_" + model, dataset, split + "_pred_bbox_list.p"])
    else:
        target_path = "/".join([project_dir, "bboxes_" + model, dataset, split + "_target_bbox_list.p"])
    inc_len_path = "/".join([project_dir, "incremental_pickles", "length_incremental_units",
                             dataset + "_" + split + "_length_unit.p"])
    data_path = "/".join([project_dir, "ready_inc_data", dataset, dataset + "_" + split + ".pth"])
    grouped_path = "/".join([project_dir, "binary_grouped", model, mode, dataset + split + ".p"])

    try:
        # the predicted bounding boxes
        pred_bbox_list = list(_read_pickle(pred_path))

        # the target bounding boxes (source file depends on the mode, see docstring)
        target_bbox_list = list(_read_pickle(target_path))

        # number of incremental units per sentence ("the left zebra" would have length 3)
        inc_len = _read_pickle(inc_len_path)

        # the original model data split up incrementally
        data_model = torch.load(data_path)

        binary_grouped = _read_pickle(grouped_path)
    except FileNotFoundError:
        # Not every dataset/split combination exists on disk; callers skip the
        # combination when they receive the sentinel.
        return not_available

    if mode == "non_inc":
        # Repeat each entry once per incremental unit of its sentence.
        target_bbox_list = [box for box, n_units in zip(target_bbox_list, inc_len) for _ in range(n_units)]

    if model == "TVG":
        pred_bbox_list, target_bbox_list = TVG_prep(pred_bbox_list, target_bbox_list)

    return pred_bbox_list, target_bbox_list, inc_len, data_model, binary_grouped
    


#TVG boxes need extra conversion so they match the format used for ReSc
def TVG_prep(pred_bbox_list,target_bbox_list):
    """Convert TVG predictions and targets to corner-format boxes, in place.

    Every prediction is reshaped to (1, -1), converted with xywh2xyxy and
    clamped to the range [0, 1]. Every target is reshaped and converted the
    same way but is not clamped. Entries are paired with zip, so only as many
    positions as the shorter list has are rewritten.

    Returns the same two list objects that were passed in.
    """
    for idx, (raw_pred, raw_target) in enumerate(zip(pred_bbox_list, target_bbox_list)):
        pred_bbox_list[idx] = torch.clamp(xywh2xyxy(raw_pred.view(1, -1)), 0, 1)
        target_bbox_list[idx] = xywh2xyxy(raw_target.view(1, -1))

    return pred_bbox_list, target_bbox_list

#copied from TransVG; needed to transform the bounding box vectors
def xywh2xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max).

    The last dimension of ``x`` is unpacked into exactly four components; all
    leading dimensions are preserved in the returned tensor.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (center_x - half_w, center_y - half_h,
               center_x + half_w, center_y + half_h)
    return torch.stack(corners, dim=-1)


In [3]:
# Example load: ReSc, incremental mode ("inc"), dataset "unc", split "testB".
# NOTE: `model` receives the fourth value returned by load_data, i.e. the
# incrementally split data read with torch.load -- it is not a network object.
pred_bbox_list,target_bbox_list,inc_len,model,binary_grouped=load_data("ReSc","inc","unc","testB")


In [4]:
#group the entries that belong to one incremental unit
def group_by_increment(bbox_list,inc_len):
    """Cut ``bbox_list`` into consecutive slices whose sizes are given by ``inc_len``.

    The first slice holds the first ``inc_len[0]`` items, the second the next
    ``inc_len[1]`` items, and so on. Returns the list of slices.
    """
    groups = []
    start = 0
    for size in inc_len:
        stop = start + size
        groups.append(bbox_list[start:stop])
        start = stop
    return groups



In [5]:
# Slice the flat, incrementally split data (`model`) into one group per
# sentence, using the per-sentence unit counts in `inc_len`.
model_group=group_by_increment(model,inc_len)

In [6]:
# Inspect the last value of the first group in binary_grouped.
# length_right_wrong treats a final value of 1 as "right" and 0 as "wrong".
binary_grouped[0][-1]

1

In [7]:
# Disabled scratch code, kept as a string literal so that it does not execute.
# It would collect the indices of binary_grouped entries that are all 1,
# all 0, or a mixture of both.
"""

right_only = []
wrong_only = []
mixed_only = []

for index, entry in enumerate(binary_grouped):
    
    if all(p == 1 for p in entry):
        right_only.append(index)
    elif all(p == 0 for p in entry):
        wrong_only.append(index)
    else:
        mixed_only.append(index)

"""        

'\n\nright_only = []\nwrong_only = []\nmixed_only = []\n\nfor index, entry in enumerate(binary_grouped):\n    \n    if all(p == 1 for p in entry):\n        right_only.append(index)\n    elif all(p == 0 for p in entry):\n        wrong_only.append(index)\n    else:\n        mixed_only.append(index)\n\n'

In [8]:
#Calculates sentence length split by condition (completely right, completely wrong, mixed)
def length_right_wrong(model_group, binary_grouped, mode):
    """Average sentence length per correctness category.

    Parameters
    ----------
    model_group : sequence
        One group per sentence. The length is taken from the last entry of
        each group: ``len(group[-1][3].split())``, so field 3 must be a string
        (presumably the full expression text -- confirm against the data).
    binary_grouped : sequence
        One sequence of 0/1 values per sentence, aligned with ``model_group``.
    mode : str
        "final": classify by the last value only (right if 1, wrong if 0).
        "all":   right if every value is 1, wrong if every value is 0,
                 otherwise mixed.

    Returns
    -------
    tuple
        ``(avg_right, avg_wrong, avg_mixed)``. A category without any
        sentences yields 0 rather than raising ZeroDivisionError.

    Raises
    ------
    ValueError
        If ``mode`` is neither "final" nor "all".
    """
    # Fail once and clearly instead of printing per sentence and then
    # crashing on an empty average.
    if mode not in ("final", "all"):
        raise ValueError("Mode can only be final or all")

    def _average(values):
        # Empty categories are reported as 0 for all three results.
        return sum(values) / len(values) if values else 0

    length_right = []
    length_wrong = []
    length_mixed = []

    # iterate through sentences and their classification
    for sent, pred in zip(model_group, binary_grouped):

        # sentence length in whitespace-separated tokens, from the last entry of the group
        length = len(sent[-1][3].split())

        if mode == "final":
            # only the outcome at the end of the increment counts
            if pred[-1] == 1:
                length_right.append(length)
            elif pred[-1] == 0:
                length_wrong.append(length)
            else:
                print("Value should be 0 or 1, value is:", pred[-1])
        else:
            # mode == "all": look at the complete increment
            if all(p == 1 for p in pred):
                length_right.append(length)
            elif all(p == 0 for p in pred):
                length_wrong.append(length)
            else:
                length_mixed.append(length)

    return _average(length_right), _average(length_wrong), _average(length_mixed)


In [9]:
# Per-(dataset, split) averages; one entry is appended for every set that could be loaded.
all_len_right, all_len_wrong, all_len_mixed = [], [], []
found_sets_list = []

# Run configuration
mode = "inc"
length_mode = "all"
model_input = "ReSc"
split_list = ["testB", "testA", "val", "test"]
dataset_list = ["unc", "unc+", "gref_umd", "gref"]

for file in dataset_list:
    for split in split_list:

        pred_bbox_list, target_bbox_list, inc_len, model, binary_grouped = load_data(model_input, mode, file, split)

        # load_data hands back -1 placeholders when the files for this
        # combination do not exist; skip those combinations.
        if pred_bbox_list == -1 or target_bbox_list == -1 or inc_len == -1:
            continue

        print(file + split)
        found_sets_list.append(file + " " + split)

        model_group = group_by_increment(model, inc_len)

        length_right, length_wrong, length_mixed = length_right_wrong(model_group, binary_grouped, mode=length_mode)

        all_len_right.append(length_right)
        all_len_wrong.append(length_wrong)
        all_len_mixed.append(length_mixed)


# Assemble the table columns; the mixed column only exists for length_mode "all".
data = {
    "Dataset " + model_input + " " + length_mode: found_sets_list,
    "Length Right": all_len_right,
    "Length Wrong": all_len_wrong,
}
if length_mode == "all":
    data["Length Mixed"] = all_len_mixed

# Build the DataFrame, rounded to two decimals, and display it
df = pd.DataFrame(data).round(2)
df
            
            

unctestB
unctestA
uncval
unc+testB
unc+testA
unc+val
gref_umdval
gref_umdtest
grefval


Unnamed: 0,Dataset ReSc all,Length Right,Length Wrong,Length Mixed
0,unc testB,2.86,3.91,4.47
1,unc testA,2.69,3.59,4.49
2,unc val,2.76,3.98,4.49
3,unc+ testB,2.86,3.98,4.62
4,unc+ testA,2.53,3.5,4.46
5,unc+ val,2.71,3.8,4.58
6,gref_umd val,7.69,7.93,8.86
7,gref_umd test,7.62,8.03,8.67
8,gref val,7.2,8.22,8.84
