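This notebook creates the "binary_grouped" pickles. For every sentence, and for every incremental step of that sentence, they record whether the object was found (1) or not (0). An object counts as found when the IoU between the predicted and the target bounding box is greater than 0.5.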

In [None]:
import pickle
import torch

In [2]:
def load_data(model,mode, dataset, split, base_dir="/home/users/fschreiber/project"):
    """Load predicted and target bounding boxes plus the incremental metadata.

    Args:
        model: model name used in the bounding-box directory names
            ("bboxes_<model>" / "bboxes_noninc_<model>").
        mode: "inc" reads the targets from "<split>_target_bbox_list.p" in the
            model's bbox directory; "non_inc" reads "<split>_pred_bbox_list.p"
            from the "bboxes_noninc_<model>" directory and uses it as the
            target list.
        dataset: dataset name used in the directory and file names.
        split: split name used in the file names.
        base_dir: project root that contains all input directories. Defaults
            to the location that was previously hard-coded.

    Returns:
        A tuple (pred_bbox_list, target_bbox_list, inc_len, data_model).
        If the mode is invalid or any input file is missing, the tuple
        (-1, -1, -1, -1) is returned instead so callers can skip the
        dataset/split combination.
    """
    # Validate the mode before touching the disk, so a wrong mode is always
    # reported and never masked by a missing file.
    if mode not in ("non_inc", "inc"):
        print("The mode can only be non_inc or inc")
        return -1,-1,-1,-1

    #the predicted bounding box
    pred_path=base_dir+"/bboxes_"+model+"/"+dataset+"/"+split+"_pred_bbox_list.p"

    #the target bounding box
    if mode=="non_inc":
        target_path=base_dir+"/bboxes_noninc_"+model+"/"+dataset+"/"+split+"_pred_bbox_list.p"
    else:
        target_path=base_dir+"/bboxes_"+model+"/"+dataset+"/"+split+"_target_bbox_list.p"

    #the number of incremental units per sentence ("the left zebra" would have length 3)
    inc_len_path=base_dir+"/incremental_pickles/length_incremental_units/"+dataset+"_"+split+"_length_unit.p"

    #the original model data split up incrementally
    data_model_path=base_dir+"/ready_inc_data/"+dataset+"/"+dataset+"_"+split+".pth"

    try:
        # NOTE(security): pickle.load / torch.load can execute arbitrary code;
        # only use them on files produced by this project.
        with open(pred_path,"rb") as f:
            pred_bbox_list=list(pickle.load(f))

        with open(target_path,"rb") as f:
            target_bbox_list=list(pickle.load(f))

        with open(inc_len_path,"rb") as f:
            inc_len=pickle.load(f)

        data_model=torch.load(data_model_path)

        return pred_bbox_list,target_bbox_list,inc_len,data_model

    except FileNotFoundError as e:
        # Not every dataset has every split, so a missing file is expected.
        # Report it instead of failing silently, then return the sentinel.
        print("Skipping",dataset,split,"- file not found:",e.filename)
        return  -1,-1,-1,-1
    


#TVG needs some extra adjustments to fit the same data format as Resc
def TVG_prep(pred_bbox_list,target_bbox_list):
    """Convert TVG boxes from (cx, cy, w, h) to corner format, in place.

    Each prediction and target is reshaped to one row with ``view(1, -1)`` and
    passed through ``xywh2xyxy``. Predictions are additionally clamped to the
    range [0, 1]; targets are not clamped. Both input lists are modified in
    place and also returned.
    """
    for position, (pred_box, target_box) in enumerate(zip(pred_bbox_list, target_bbox_list)):
        # prediction: reshape -> corner format -> clamp to [0, 1]
        pred_bbox_list[position] = torch.clamp(xywh2xyxy(pred_box.view(1, -1)), 0, 1)
        # target: reshape -> corner format (no clamping)
        target_bbox_list[position] = xywh2xyxy(target_box.view(1, -1))

    return pred_bbox_list, target_bbox_list

#copied from TransVG needed to transform the bounding box vectors
def xywh2xyxy(x):
    """Convert boxes from (center x, center y, width, height) to corners.

    The four values are read from the last dimension of ``x``. The result has
    the same layout, with the last dimension holding
    (x_min, y_min, x_max, y_max).
    """
    center_x, center_y, width, height = x.unbind(-1)
    corners = [
        center_x - 0.5 * width,
        center_y - 0.5 * height,
        center_x + 0.5 * width,
        center_y + 0.5 * height,
    ]
    return torch.stack(corners, dim=-1)


In [5]:
#copied from ReSC
def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    Returns the row-wise IoU (intersection over union) of two sets of boxes.

    Both inputs are indexed as ``box[:, 0..3]``. With ``x1y1x2y2=True`` the
    four columns are read as corners (x1, y1, x2, y2); otherwise they are read
    as (center x, center y, width, height) and converted to corners first.
    A 1e-16 term in the denominator avoids division by zero.
    """
    if x1y1x2y2:
        # Columns already hold the corner coordinates
        left1, top1, right1, bottom1 = (box1[:, col] for col in range(4))
        left2, top2, right2, bottom2 = (box2[:, col] for col in range(4))
    else:
        # Center/size format: move half the extent in each direction
        half_w1, half_h1 = box1[:, 2] / 2, box1[:, 3] / 2
        half_w2, half_h2 = box2[:, 2] / 2, box2[:, 3] / 2
        left1, right1 = box1[:, 0] - half_w1, box1[:, 0] + half_w1
        top1, bottom1 = box1[:, 1] - half_h1, box1[:, 1] + half_h1
        left2, right2 = box2[:, 0] - half_w2, box2[:, 0] + half_w2
        top2, bottom2 = box2[:, 1] - half_h2, box2[:, 1] + half_h2

    # Width and height of the overlap rectangle, floored at 0 for disjoint boxes
    overlap_w = torch.clamp(torch.min(right1, right2) - torch.max(left1, left2), 0)
    overlap_h = torch.clamp(torch.min(bottom1, bottom2) - torch.max(top1, top2), 0)
    inter_area = overlap_w * overlap_h

    # Individual box areas for the union
    area1 = (right1 - left1) * (bottom1 - top1)
    area2 = (right2 - left2) * (bottom2 - top2)

    return inter_area / (area1 + area2 - inter_area + 1e-16)

#copied from TransVG needed to transform the bounding box vectors
def xywh2xyxy(x):
    """Turn (center x, center y, width, height) boxes into (x1, y1, x2, y2).

    The four components are taken from, and written back to, the last
    dimension of ``x``.
    """
    # NOTE(review): an earlier cell of this notebook already defines an
    # identical xywh2xyxy; this redefinition shadows it. Consider keeping one.
    cx, cy, box_w, box_h = x.unbind(-1)
    x_min = cx - 0.5 * box_w
    y_min = cy - 0.5 * box_h
    x_max = cx + 0.5 * box_w
    y_max = cy + 0.5 * box_h
    return torch.stack([x_min, y_min, x_max, y_max], dim=-1)



In [6]:
#group sentences that belong to one incremental unit
def group_by_increment(bbox_list,inc_len):
    """Split a flat list into consecutive groups, one per incremental unit.

    ``inc_len`` holds the size of each group in order. The i-th group is the
    slice of ``bbox_list`` that starts where the previous group ended and
    spans ``inc_len[i]`` entries.
    """
    # Running boundaries: boundaries[k] is where group k starts
    boundaries = [0]
    for unit_length in inc_len:
        boundaries.append(boundaries[-1] + unit_length)

    return [bbox_list[begin:end] for begin, end in zip(boundaries, boundaries[1:])]

In [10]:
# Configuration of this run
mode="non_inc"
model_input="ReSc"
split_list=["testB","testA","val","test"]
dataset_list=["unc","unc+","gref_umd","gref"]
# a prediction counts as "found" when its IoU with the target is above this value
iou_threshold=0.5

print("The config is:",mode,model_input)
for file in dataset_list:
    for split in split_list:

        pred_bbox_list,target_bbox_list,inc_len,model= load_data(model_input,mode,file,split)

        # load_data returns -1 sentinels when the mode is invalid or a file is
        # missing (not every dataset has every split): skip that combination.
        if pred_bbox_list==-1 or target_bbox_list==-1 or inc_len==-1:
            continue

        if mode=="non_inc":
            # one target per sentence: repeat it once for every incremental
            # unit of that sentence so it lines up with the predictions
            target_bbox_list=[x for x,y in zip(target_bbox_list,inc_len) for _ in range(y)]

        if model_input=="TVG":
            pred_bbox_list,target_bbox_list=TVG_prep(pred_bbox_list,target_bbox_list)

        print(file+split)

        # zip() below stops at the shorter list, so make a mismatch visible
        if len(pred_bbox_list)!=len(target_bbox_list):
            print("Warning: number of predictions and targets differ:",len(pred_bbox_list),len(target_bbox_list))

        #IoU of every prediction with its target
        acc_list=[bbox_iou(i,j,True) for i,j in zip(pred_bbox_list,target_bbox_list)]

        if not acc_list:
            # avoids a ZeroDivisionError in the accuracy computation
            print("No bounding boxes to evaluate, skipping")
            continue

        #change the IoU values into 1 (found) or 0 (not found)
        binary_list=[1 if entry > iou_threshold else 0 for entry in acc_list]

        #gives the overall accuracy
        percentage = sum(binary_list) / len(binary_list)
        print("Accuracy:",percentage)

        # group_by_increment slices by inc_len; a mismatch would silently
        # produce short or missing groups
        if sum(inc_len)!=len(binary_list):
            print("Warning: sum of incremental lengths does not match the number of predictions:",sum(inc_len),len(binary_list))

        binary_grouped=group_by_increment(binary_list,inc_len)

        try:
            with open(r"/home/users/fschreiber/project/binary_grouped/"+model_input+"/"+mode+"/"+file+split+".p", "wb") as output_file:
                pickle.dump(binary_grouped, output_file)
        except Exception as e:
            # best effort: report the failure and continue with the next split
            print("something went wrong in saving")
            print(e)

The config is: non_inc ReSc
unctestB
Accuracy: 0.7455637307942004
unctestA
Accuracy: 0.7713018361873324
uncval
Accuracy: 0.7536696299359107
unc+testB
Accuracy: 0.7647570388085662
unc+testA
Accuracy: 0.7786550930767622
unc+val
Accuracy: 0.7699984447094199
gref_umdval
Accuracy: 0.7792182426328768
gref_umdtest
Accuracy: 0.7713480601821567
grefval
Accuracy: 0.7572466491294
