In [2]:
import json
import time
from torch.utils.data import Dataset
import ipdb
import config
import os
import pickle

In [3]:
class VQADataset_custom(Dataset):
    """VQA dataset that joins tagged questions, answers and per-entry COCO statistics.

    The three sources are expected to be index-aligned: entry ``idx`` of the
    COCO records, of the questions and of the answers must carry the same
    ``image_id``. ``__getitem__`` asserts this for every access.
    """

    def __init__(self, coco_pkl_file, ques_ann_path, ans_ann_path, mode):
        """
        Args:
            coco_pkl_file (string): Path to the pickle file whose
                'area_and_intersection' entry holds the COCO records.
            ques_ann_path (string): Path to the json file with ques_annotations
                (its 'questions' entry is used).
            ans_ann_path (string): Path to the json file with ans_annotations
                (its 'answers' entry is used).
            mode (string): Stored as ``self.mode``; not otherwise used by this class.
        """
        # NOTE: pickle.load can execute arbitrary code while loading; only
        # point this at files you produced yourself or otherwise trust.
        with open(coco_pkl_file, 'rb') as coco_file:
            self.coco_details = pickle.load(coco_file)['area_and_intersection']
        # Context managers make sure the annotation files are closed promptly
        # instead of leaking the handles until garbage collection.
        with open(ques_ann_path, 'r') as ques_file:
            self.questions = json.load(ques_file)['questions']
        with open(ans_ann_path, 'r') as ans_file:
            self.answers = json.load(ans_file)['answers']
        self.mode = mode

    def __len__(self):
        """Number of samples, i.e. the number of question entries."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return the data for entry ``idx`` as a 10-tuple.

        Returns:
            tuple: ``(answers, classes_img, nouns_q_coco, nouns_ans, img_id,
            question_id, question, percent_area_per_catId_all_inst,
            if_intersect_overlap_sq5, percent_area_per_catId_max_inst)`` where
            ``answers`` is the list of 'answer' strings of the answer entry and
            every other element is read unchanged from the corresponding
            question / answer / COCO record.

        Raises:
            AssertionError: If the three sources disagree on the image id at ``idx``.
        """
        ques_entry = self.questions[idx]
        coco_entry = self.coco_details[idx]
        ans_entry = self.answers[idx]

        # The three lists are only meaningful together if they are index-aligned.
        assert ques_entry['image_id'] == coco_entry['image_id'] == ans_entry['image_id'], \
            'image_id mismatch between questions, COCO details and answers at index {}'.format(idx)

        classes_img = coco_entry['classes_img']
        percent_area_per_catId_all_inst = coco_entry['percent_area_per_catId_all_inst']
        percent_area_per_catId_max_inst = coco_entry['percent_area_per_catId_max_inst']
        # The 'if_intersect_overlap_default' key is the alternative the original code mentions.
        if_intersect_overlap_sq5 = coco_entry['if_intersect_overlap_sq5']

        img_id = ques_entry['image_id']
        question = ques_entry['question']
        question_id = ques_entry['question_id']

        nouns_q_coco = ques_entry['nouns_q_COCO']
        nouns_ans = ans_entry['ans_match_COCO']

        answers = [entry['answer'] for entry in ans_entry['answers']]

        return answers, classes_img, nouns_q_coco, nouns_ans, img_id, question_id, question, \
                percent_area_per_catId_all_inst, if_intersect_overlap_sq5, percent_area_per_catId_max_inst

In [4]:
# Per-split COCO area/intersection files, passed to final_prep as `coco_pkl`.
# VQADataset_custom reads them with pickle.load even though the names end in
# '.json' — NOTE(review): confirm the on-disk format matches.
coco_val_pkl = './coco_areas_and_intersection/coco_vqa_val2014.json'
coco_train_pkl = './coco_areas_and_intersection/coco_vqa_train2014.json'

In [5]:
def prep_q_json(dataset, filename, area_thresh, overlap_thresh, all_area_thresh=None):
    """Write a questions json with one entry per (question, remaining image class) pair.

    Only samples whose answers are all identical are used. For such a sample
    the candidate classes are the classes present in the image that are named
    neither in the question nor in the answers. A candidate is dropped when

    * its max-instance area exceeds ``area_thresh`` (and, if
      ``all_area_thresh`` is given, its all-instance area also exceeds
      ``all_area_thresh``), or
    * its overlap entry with any question/answer class exceeds ``overlap_thresh``.

    Every surviving class yields one record whose ``image_id`` is the
    zero-padded image id and class id joined by ``'_'``.

    Args:
        dataset: Indexable object with ``len()`` whose items are the 10-tuples
            produced by ``VQADataset_custom.__getitem__``.
        filename: Path of the json file to write (``{'questions': [...]}``).
        area_thresh: Threshold on the per-class max-instance area value.
        overlap_thresh: Threshold on the pairwise overlap value.
        all_area_thresh: Optional threshold on the per-class all-instance area value.

    Side effects:
        Writes ``filename`` and prints the elapsed time in seconds.
    """
    t0 = time.time()
    records = []

    for sample_idx in range(len(dataset)):
        (answers, classes_img, nouns_q, nouns_ans, img_id, ques_id, question,
         area_all_inst, pair_overlap, area_max_inst) = dataset[sample_idx]

        # Skip samples where the annotators did not all give the same answer.
        if len(set(answers)) != 1:
            continue

        present = sorted(set(classes_img))
        mentioned = set(nouns_q) | set(nouns_ans)
        rejected = set()

        # Area criterion: evaluated for every class present in the image.
        for cat_id in present:
            too_large = area_max_inst[cat_id] > area_thresh
            if too_large and all_area_thresh is not None:
                too_large = area_all_inst[cat_id] > all_area_thresh
            if too_large:
                rejected.add(cat_id)

        # Overlap criterion: look up (question/answer class, image class) pairs.
        for noun_id in sorted(mentioned):
            for cat_id in present:
                pair = (noun_id, cat_id)
                if pair in pair_overlap and pair_overlap[pair] > overlap_thresh:
                    rejected.add(cat_id)

        # `present` is already sorted, so the filtered list stays sorted.
        targets = [c for c in present if c not in mentioned and c not in rejected]

        records.extend(
            {"image_id": str(img_id).zfill(12) + '_' + str(obj_class).zfill(12),
             "question": question,
             "question_id": ques_id}
            for obj_class in targets
        )

    with open(filename, 'w') as out_fh:
        json.dump({'questions': records}, out_fh)

    print(time.time() - t0)

In [6]:
# NOTE(review): not referenced by any other cell visible in this notebook —
# possibly a leftover counter; confirm before relying on or removing it.
all_ans_same = 0

In [12]:
def final_prep(mode, coco_pkl, area_thresh, overlap_thresh, all_area_thresh=None):  # mode='val2014'  string
    """Build the edited questions file for one split and print a quick sanity check.

    Loads 'tagged_<mode>_questions.json' and 'tagged_<mode>_answers.json'
    (relative to the working directory) together with ``coco_pkl`` into a
    ``VQADataset_custom``, runs ``prep_q_json`` on it and writes
    'v2_OpenEnded_mscoco_<mode>_questions.json' into ``config.iv_q_dir``
    (created if missing).

    Args:
        mode (string): Split name used to build the file names, e.g. 'val2014'.
        coco_pkl (string): Path to the COCO area/intersection pickle for the split.
        area_thresh: Forwarded to ``prep_q_json``.
        overlap_thresh: Forwarded to ``prep_q_json``.
        all_area_thresh: Forwarded to ``prep_q_json``; ``None`` by default.

    Side effects:
        Writes the result file and prints the dataset load time, the first
        dataset sample (if any), the number of generated questions and the
        first generated question (if any). ``prep_q_json`` additionally prints
        its own elapsed time.
    """
    question_path = 'tagged_' + mode + '_questions.json'        ## corresponds to question.json file
    answer_path = 'tagged_' + mode + '_answers.json'

    # in case you want to play with thresholds- good idea to store in different folders
    #root_dir = os.path.join('mini_datasets_qa', str(area_thresh)+ '_'+str(overlap_thresh))
    root_dir = config.iv_q_dir
    os.makedirs(root_dir, exist_ok=True)
    res_path = os.path.join(root_dir, 'v2_OpenEnded_mscoco_' + mode + '_questions.json')

    start = time.time()
    dataset_mode = VQADataset_custom(coco_pkl, question_path, answer_path, mode)
    print(time.time()-start)
    prep_q_json(dataset_mode, res_path, area_thresh=area_thresh,
                overlap_thresh=overlap_thresh, all_area_thresh=all_area_thresh)

    # Sample check to make sure. Guard the [0] look-ups: an empty dataset or
    # thresholds that filter out every class must not raise IndexError after
    # the result file has already been written.
    if len(dataset_mode) > 0:
        print(dataset_mode[0])
    with open(res_path) as f:
        edited_questions = json.load(f)['questions']
    print(len(edited_questions))
    if edited_questions:
        print(edited_questions[0])


In [13]:
# Build the val2014 question file: area_thresh=0.1, overlap_thresh=0.0,
# all_area_thresh left at its default (None).
final_prep('val2014', coco_val_pkl, 0.1, 0.0) 

9.443247318267822
3.5640757083892822
(['down', 'down', 'at table', 'skateboard', 'down', 'table', 'down', 'down', 'down', 'down'], [5, 15, 41, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 8, 27, 31, 67, 1, 15, 1], [1], [41, 41, 42, 36, 41, 41, 67], 262148, 262148000, 'Where is he looking?', {1: 0.099, 5: 0.0, 8: 0.011, 15: 0.077, 27: 0.0, 31: 0.001, 41: 0.004, 67: 0.014}, {(1, 8): 0.046, (1, 15): 0.002, (1, 27): 0.011, (1, 31): 0.018, (1, 41): 0.027, (1, 67): 0.001, (8, 1): 0.384, (8, 27): 0.039, (15, 1): 0.003, (15, 67): 0.21, (27, 1): 0.809, (27, 8): 0.351, (31, 1): 1.0, (41, 1): 0.486, (67, 1): 0.004, (67, 15): 0.981}, {1: 0.043, 5: 0.0, 8: 0.009, 15: 0.075, 27: 0.0, 31: 0.001, 41: 0.004, 67: 0.014})
> [0;32m<ipython-input-12-3c9bb66e501e>[0m(22)[0;36mfinal_prep[0;34m()[0m
[0;32m     21 [0;31m    [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m    [0;32mwith[0m [0mopen[0m[0;34m([0m [0mos[0m[0;34m.[0m[0

In [14]:
# Build the train2014 question file: area_thresh=0.1, overlap_thresh=0.0,
# all_area_thresh left at its default (None).
final_prep('train2014', coco_train_pkl, 0.1, 0.0)  

20.333593606948853
6.156264781951904
(['net', 'net', 'net', 'netting', 'net', 'net', 'mesh', 'net', 'net', 'net'], [37, 1, 40], [58], [], 458752, 458752000, 'What is this photo taken looking through?', {1: 0.072, 37: 0.0, 40: 0.005}, {(1, 37): 0.011, (1, 40): 0.084, (37, 1): 0.993, (40, 1): 0.991}, {1: 0.072, 37: 0.0, 40: 0.005})
> [0;32m<ipython-input-12-3c9bb66e501e>[0m(22)[0;36mfinal_prep[0;34m()[0m
[0;32m     21 [0;31m    [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m    [0;32mwith[0m [0mopen[0m[0;34m([0m [0mos[0m[0;34m.[0m[0mpath[0m[0;34m.[0m[0mjoin[0m[0;34m([0m[0mroot_dir[0m[0;34m,[0m [0mres_file[0m[0;34m)[0m[0;34m)[0m [0;32mas[0m [0mf[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m        [0medited_questions[0m [0;34m=[0m [0mjson[0m[0;34m.[0m[0mload[0m[0;34m([0m[0mf[0m[0;34m)[0m[0;34m[[0m[0;34m'questions'[0m[0;34m][0m[0;34m[0m