In [1]:
# !module unload anaconda3_gpu/4.13.0

import os
import sys
import json
import argparse
import collections
import numpy as np
import random
from datetime import datetime

sys.path.append('./Situation3D') 
from lib.sepdataset_bert import ScannetQADataset, ScannetQADatasetConfig
from lib.config import CONF
from models.sqa_module_bert import ScanQA
# from scripts.train import get_sqa

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
# from tensorboardX import SummaryWriter
# import wandb

def _filter_by_scenes(records, num_scenes):
    """Keep the records that belong to the first ``num_scenes`` scene ids.

    Scene ids are collected from ``record["scene_id"]`` and ordered with
    ``sorted``. ``num_scenes <= -1`` keeps every scene.

    Returns:
        (kept_records, chosen_scene_ids) -- ``kept_records`` preserves the
        input order; ``chosen_scene_ids`` is the sorted, truncated id list.

    Raises:
        AssertionError: if more scenes are requested than exist.
    """
    scene_ids = sorted({record["scene_id"] for record in records})
    if num_scenes <= -1:
        num_scenes = len(scene_ids)
    else:
        assert len(scene_ids) >= num_scenes

    chosen = scene_ids[:num_scenes]
    # Set lookup keeps the filter linear in the number of records.
    chosen_lookup = set(chosen)
    kept = [record for record in records if record["scene_id"] in chosen_lookup]
    return kept, chosen


def get_sqa(sqa_train, sqa_val, train_num_scenes, val_num_scenes):
    """Restrict the train/val records to a prefix of their sorted scene ids.

    Args:
        sqa_train: list of dict records, each with a "scene_id" key.
        sqa_val: list of dict records, each with a "scene_id" key.
        train_num_scenes: number of train scenes to keep; ``<= -1`` keeps all.
        val_num_scenes: number of val scenes to keep; ``<= -1`` keeps all.

    Returns:
        (new_sqa_train, new_sqa_val, all_scene_list) where ``all_scene_list``
        is the chosen train scene ids followed by the chosen val scene ids
        (the two lists are concatenated, not de-duplicated).

    Raises:
        AssertionError: if a split has fewer scenes than requested.
    """
    new_sqa_train, train_scene_list = _filter_by_scenes(sqa_train, train_num_scenes)
    new_sqa_val, val_scene_list = _filter_by_scenes(sqa_val, val_num_scenes)

    # all sqa scene
    all_scene_list = train_scene_list + val_scene_list
    return new_sqa_train, new_sqa_val, all_scene_list


def get_answer_cands(answer_counter_list, answer_max_size=-1):
    """Build the answer vocabulary and its frequency table.

    Args:
        answer_counter_list: iterable of answers, one entry per occurrence
            (presumably answer strings loaded from answer_counter.json).
        answer_max_size: keep only this many of the most frequent answers.
            A negative value (the default) keeps every distinct answer.

    Returns:
        (answer_cands, answer_counter) -- ``answer_cands`` is the sorted list
        of kept answers; ``answer_counter`` is a plain dict mapping each kept
        answer to its count, ordered from most to least frequent.
    """
    # Sorting first makes the order of equally frequent answers deterministic:
    # most_common() lists ties in first-encountered order.
    frequency = collections.Counter(sorted(answer_counter_list))
    num_all_answers = len(frequency)
    if answer_max_size < 0:
        answer_max_size = num_all_answers

    answer_counter = {
        answer: count
        for answer, count in frequency.most_common()[:answer_max_size]
        if count >= 1
    }
    print("using {} answers out of {} ones".format(len(answer_counter), num_all_answers))
    answer_cands = sorted(answer_counter)
    return answer_cands, answer_counter


def get_dataloader(sqa, all_scene_list, split, config, augment, answer_counter_list,
                   batch_size=32, shuffle=True, num_workers=4):
    """Create the ScannetQADataset for one split and wrap it in a DataLoader.

    Args:
        sqa: dict mapping split name to its list of records; ``sqa[split]``
            is handed to the dataset.
        all_scene_list: scene ids passed to the dataset as ``sqa_all_scene``.
        split: key into ``sqa``; also forwarded as the dataset's ``split``.
        config: object whose ``num_answers`` attribute is overwritten with
            the size of the answer vocabulary (side effect on the caller's
            object).
        augment: forwarded to the dataset's ``augment`` option.
        answer_counter_list: raw answer list fed to ``get_answer_cands``.
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader reshuffles every epoch. The default
            stays True for every split to match previous behaviour; pass
            False for a fixed evaluation order.
        num_workers: number of DataLoader worker processes.

    Returns:
        (dataset, dataloader)
    """
    answer_cands, answer_counter = get_answer_cands(answer_counter_list)
    config.num_answers = len(answer_cands)

    # The remaining dataset options are fixed for this notebook; their exact
    # semantics are defined by lib.sepdataset_bert, which is not shown here.
    dataset = ScannetQADataset(
        sqa=sqa[split],
        sqa_all_scene=all_scene_list,
        answer_cands=answer_cands,
        answer_counter=answer_counter,
        answer_cls_loss='bce',
        split=split,
        num_points=40000,
        use_height=True,
        use_color=True,
        use_normal=False,
        use_multiview=False,
        tokenizer=None,
        augment=augment,
        debug=False,
        wos=False,
        use_bert=True,
        no_mirror=True,
        no_rotx=True,
        no_roty=True
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return dataset, dataloader


def _load_json(path):
    """Parse the JSON file at ``path``, closing the file handle before returning."""
    with open(path) as handle:
        return json.load(handle)


DC = ScannetQADatasetConfig()
project_name = "SQA"
SQA_TRAIN = _load_json(os.path.join(CONF.PATH.SQA, project_name + "_train.json"))
# NOTE(review): the split called "val" in this notebook is read from the
# *_test.json file, not *_val.json -- confirm this is intentional.
SQA_VAL = _load_json(os.path.join(CONF.PATH.SQA, project_name + "_test.json"))
answer_counter_list = _load_json(os.path.join(CONF.PATH.SQA, "answer_counter.json"))
# NOTE(review): only Python's and NumPy's RNGs are seeded. The DataLoaders
# are created with shuffle=True, which presumably draws from torch's RNG, so
# batch order may still vary between runs -- consider torch.manual_seed(0).
random.seed(0)
np.random.seed(0)

# init training dataset
print("preparing data...")
# -1 / -1 keeps every scene of both splits.
sqa_train, sqa_val, all_scene_list = get_sqa(SQA_TRAIN, SQA_VAL, -1, -1)
sqa = {
    "train": sqa_train,
    "val": sqa_val,
}

# Augmentation is switched on for the training split only.
train_dataset, train_dataloader = get_dataloader(sqa, all_scene_list, "train", DC, True, answer_counter_list)
val_dataset, val_dataloader = get_dataloader(sqa, all_scene_list, "val", DC, False, answer_counter_list)
print("train on {} samples and val on {} samples".format(len(train_dataset), len(val_dataset)))

dataloader = {
    "train": train_dataloader,
    "val": val_dataloader
}

print('')
print('======================')
print('This cell is finished.')
print('======================')

preparing data...
using 707 answers out of 707 ones
all train: 26623
answerable train 26623
Tokenizing questions and situations using BERT Tokenizer...
This may take a while...
Finished tokenizing questions and situations using BERT Tokenizer
loading data...
using 707 answers out of 707 ones
all val: 3519
answerable val 3519
Tokenizing questions and situations using BERT Tokenizer...
This may take a while...
Finished tokenizing questions and situations using BERT Tokenizer
loading data...
train on 26623 samples and val on 3519 samples

This cell is finished.


In [2]:
# Pull a single batch from the training loader for interactive inspection.
# iter() builds a fresh iterator each time this cell runs, and the loader is
# created with shuffle=True, so the batch contents may differ between runs.
first_batch = next(iter(dataloader['train']))

In [22]:
def _print_running_longest(records, length_key, other_key):
    """Scan ``records`` and print every record that sets a new length record.

    For each record whose ``record[length_key]`` is longer than anything seen
    so far, prints the record index, that field, and ``record[other_key]``.

    Returns:
        (index_of_longest, longest_length). The index is None and the length
        is 0 when no record has a non-empty ``length_key`` field.
    """
    longest = 0
    longest_idx = None
    for idx, record in enumerate(records):
        current = len(record.get(length_key))
        if current > longest:
            longest = current
            longest_idx = idx
            print(idx, record.get(length_key), record.get(other_key))
    return longest_idx, longest


# Longest question (the printed fields are token lists, so this is a token count).
question_idx, max_length = _print_running_longest(sqa_train, 'question', 'situation')
print(max_length)

# Longest situation. Report the index of the longest record alongside its
# length, not the index of the last record visited.
x, max_length = _print_running_longest(sqa_train, 'situation', 'question')
print(x, max_length)



0 ['What', 'color', 'is', 'the', 'desk', 'to', 'my', 'right', '?'] ['I', 'am', 'facing', 'a', 'window', 'and', 'there', 'is', 'a', 'desk', 'on', 'my', 'right', 'and', 'a', 'chair', 'behind', 'me', '.']
1 ['What', 'is', 'on', 'the', '12', "o'clock", 'of', 'the', 'coffee', 'table', 'that', 'is', 'on', 'my', '1', "o'clock", '?'] ['I', 'am', 'sitting', 'on', 'the', 'edge', 'of', 'the', 'couch', 'with', 'a', 'curtain', 'right', 'next', 'to', 'me', 'on', 'the', 'left', '.']
5 ['What', 'object', 'to', 'my', '4', "o'clock", 'will', 'help', 'me', 'know', 'whether', 'I', 'am', 'running', 'late', 'to', 'class', '?'] ['I', 'am', 'facing', 'the', 'door', 'and', 'there', 'is', 'a', 'file', 'cabinet', 'on', 'my', 'left', '.']
14 ['Which', 'object', 'is', 'better', 'to', 'sit', 'on', 'while', 'working', '-', 'the', 'couch', 'or', 'the', 'chair', 'on', 'my', 'left', '?'] ['I', 'am', 'sitting', 'on', 'the', 'left', 'cushion', 'of', 'a', 'couch', 'facing', 'a', 'cabinet', '.']
106 ['Am', 'I', 'closer', '

In [4]:
# Distinct scene ids of each split, in sorted order.
train_scene_list = sorted({entry["scene_id"] for entry in sqa_train})
val_scene_list = sorted({entry["scene_id"] for entry in sqa_val})
# Lookup table exposed by the validation dataset. Presumably maps the integer
# stored in a batch's 'scene_id' back to the scene id string, e.g.
#   scene_number_to_id[first_batch['scene_id'][0].item()]   -- TODO confirm
scene_number_to_id = dataloader['val'].dataset.scene_number_to_id

# ---------------------------------------------------------------------------
# Reference notes recorded from earlier runs of this cell
# ---------------------------------------------------------------------------
# len(train_scene_list), len(val_scene_list), len(all_scene_list) -> 518 67 585
# len(sqa_train) -> 26623 samples; len(dataloader['train']) -> 832 iterations
#
# sqa_train[0].keys():
#   answers, object_ids, object_names, question, situation, question_id,
#   scene_id, position, original_question, original_situation
#
# first_batch.keys():
#   's_feat', 'q_feat', 'point_clouds', 's_len', 'q_len', 'center_label',
#   'heading_class_label', 'heading_residual_label', 'size_class_label',
#   'size_residual_label', 'num_bbox', 'sem_cls_label', 'box_label_mask',
#   'vote_label', 'vote_label_mask', 'scan_idx', 'pcl_color', 'auxiliary_task',
#   'scene_id', 'question_id', 'load_time', 'answer_cat', 'answer_cats',
#   'answer_cat_scores', 'question_type', 'situation', 'question', 'answers'
#
# Shapes observed for a batch of 32:
#   point_clouds                                   torch.Size([32, 40000, 7])
#   s_len, q_len                                   torch.Size([32])
#   center_label, size_residual_label              torch.Size([32, 128, 3])
#   heading_class_label, heading_residual_label,
#   size_class_label                               torch.Size([32, 128])
#   num_bbox                                       torch.Size([32])
#   sem_cls_label, box_label_mask                  torch.Size([32, 128])
#   vote_label / vote_label_mask                   torch.Size([32, 40000, 9]) / torch.Size([32, 40000])
#   scan_idx / pcl_color                           torch.Size([32]) / torch.Size([32, 40000, 3])
#   auxiliary_task                                 torch.Size([32, 7])
#       first 3 values are position, next 4 are quaternion
#   scene_id, question_id, load_time               torch.Size([32])
#   answer_cat                                     torch.Size([32])
#   answer_cats, answer_cat_scores                 torch.Size([32, 707])
#
# Other probes that were used here:
#   first_batch['answers'][0][0]
#   first_batch['question_id'][0], str(first_batch['question_id'][0].item())
#   first_batch['s_len'], first_batch['answer_cat']

# Show the 7-value auxiliary target of the first five samples in the batch.
for sample_idx in range(5):
    print(first_batch['auxiliary_task'][sample_idx].tolist())

print('')
print('======================')
print('This cell is finished.')
print('======================')

[-1.1695787906646729, 0.03118199296295643, 1.5532602071762085, 0.0, 0.0, -0.03951023891568184, 0.9992191791534424]
[0.009532594121992588, -2.462644100189209, 1.868485927581787, 0.0, 0.0, 0.999764621257782, -0.02169468253850937]
[0.030481506139039993, -0.6087040305137634, 0.9354695677757263, 0.0, 0.0, -0.041241951286792755, 0.9991492033004761]
[0.14939717948436737, -0.8001063466072083, 0.9635432362556458, 0.0, 0.0, 0.05638403072953224, 0.9984091520309448]
[0.6156883835792542, -1.4739642143249512, 1.2038722038269043, 0.0, 0.0, -0.6923461556434631, 0.7215655446052551]

This cell is finished.


In [3]:
# Frequency of every distinct entry in answer_counter_list.
answer_counter = collections.Counter(sorted(answer_counter_list))
num_all_answers = len(answer_counter)
print(num_all_answers) # 707
most_common = answer_counter.most_common()
# The 20 least frequent answers.
print(most_common[-20:])

# Total number of occurrences contributed by everything except the
# 14 most frequent answers.
counter = sum(entry[1] for entry in most_common[14:])
print(counter)  # 14715 in the recorded run

print('')
print('======================')
print('This cell is finished.')
print('======================')

707
[('under desk', 2), ('under table', 2), ('underneath', 2), ('untidy', 2), ('upstairs', 2), ('vacuum', 2), ('vertical', 2), ('walk backward', 2), ('walk to left', 2), ('warm', 2), ('water', 2), ('whiteboards', 2), ('whitec', 2), ('window sill', 2), ('windshield', 2), ('wine', 2), ('wood beam', 2), ('wood paneling', 2), ('wooden chairs', 2), ('yellow and orange', 2)]
14715

This cell is finished.


In [52]:
from collections import defaultdict
from collections import Counter

# SQA_TRAIN = json.load(open(os.path.join(CONF.PATH.SQA, project_name + "_train.json")))
# print(SQA_TRAIN[0])
# SQA_TRAIN[0]['original_question'] = 'Example question'
# print(SQA_TRAIN[0]['question'])
# SQA_TRAIN[0].update({'answers': [1]})
# print(SQA_TRAIN[0])

# Bucket the training records by the first whitespace-separated word of the
# question text.
# NOTE(review): 'original_question' is not among the keys printed for the raw
# SQA_train.json elsewhere in this notebook; presumably it is added to the
# records in place when the datasets are built, so this cell depends on that
# having run first -- confirm.
grouped = defaultdict(list)
for record in SQA_TRAIN:
    grouped[record['original_question'].split()[0]].append(record)

# (first word, number of questions) pairs, most frequent first.
counts = Counter({word: len(group) for word, group in grouped.items()}).most_common()
print(counts)
print(len(counts))

# Show up to 20 examples of the 7th most common question opener
# ('If' in the recorded output).
question_type = counts[6][0]
for sample in grouped[question_type][:20]:
    print(sample['original_question'], sample['answers'])

# List every opener beyond the 9 most common ones and total their questions.
counter = 0
for word, size in counts[9:]:
    print(word)
    counter += size
print(counter)



[('What', 8171), ('Is', 5055), ('How', 4014), ('Can', 2520), ('Which', 2269), ('Are', 891), ('If', 801), ('Where', 645), ('Am', 415), ('I', 325), ('Does', 256), ('The', 187), ('To', 140), ('In', 120), ('Do', 80), ("What's", 74), ('On', 69), ('Would', 67), ('From', 63), ('When', 51), ('Turning', 47), ('Will', 42), ('Could', 34), ('There', 32), ('My', 19), ('Should', 18), ('Need', 17), ('Gotta', 16), ('Did', 15), ('Looking', 14), ('After', 11), ('Between', 9), ('Besides', 8), ('It', 7), ('Have', 7), ('Behind', 6), ('Has', 5), ('Wanna', 5), ('Someone', 5), ('Towards', 4), ('A', 3), ('Going', 3), ('Across', 3), ("I've", 3), ('Feeling', 3), ('While', 2), ('Cam', 2), ('Wat', 2), ('Im', 2), ('Apart', 2), ('Excluding', 2), ('Wanting', 2), ('Walking', 2), ('Of', 2), ('All', 2), ('Without', 2), ('As', 2), ('Other', 2), ('Suppose', 2), ('One', 2), ('Who', 1), ('Whats', 1), ('Now', 1), ('Towhat', 1), ('Total', 1), ('So', 1), ('Directly', 1), ('Why', 1), ('Want', 1), ('Whacolor', 1), ('Above', 1), 

In [42]:
import json, sys
import numpy as np
from scipy.spatial.transform import Rotation as R
import argparse
sys.path.append(os.path.join(os.getcwd(), 'lib')) # HACK add the lib folder
from lib.config import CONF

annotation_dir = './dataset/sqa3d/SQA3D/assets/data/sqa_task/balanced'
qa_dir = './dataset/sqa3d/SQA3D/ScanQA/data/qa'

# Raw annotations, train split only. To cover the other splits as well,
# extend the list with the 'annotations' entries of
#   v1_balanced_sqa_annotations_val_scannetv2.json
#   v1_balanced_sqa_annotations_test_scannetv2.json
with open(os.path.join(annotation_dir, 'v1_balanced_sqa_annotations_train_scannetv2.json'), 'r') as annotation_file:
    all_annotations = json.load(annotation_file)['annotations']

# Converted QA records, train split only (SQA_val.json / SQA_test.json live
# in the same directory).
with open(os.path.join(qa_dir, 'SQA_train.json'), 'r') as qa_file:
    all_data = json.load(qa_file)

# question_id -> position of that annotation in all_annotations
# (if an id occurs more than once, the last occurrence wins).
qid2annoid = {
    annotation["question_id"]: idx
    for idx, annotation in enumerate(all_annotations)
}

print('all annotations: ', sorted(all_annotations[0].keys()))
print('all data: ', sorted(all_data[0].keys()))
print(len(all_annotations), len(all_data))
print(all_annotations[0]['position'], all_annotations[0]['rotation'])
print(all_data[0]['position'])

# compare all_annotations and all_data to see if they are the same
for i, record in enumerate(all_data):
    question_id = record["question_id"]
    annotation = all_annotations[qid2annoid[question_id]]

    first_answer = record["answers"][0]
    annotated_answer = annotation["answers"][0]['answer']
    # A first answer of 'unknown' in the QA record is not reported as a mismatch.
    if first_answer != annotated_answer and first_answer != 'unknown':
        print('answers mismatch: ', i, record["question"], first_answer, annotated_answer)
    if record["scene_id"] != annotation["scene_id"]:
        print('scene_id mismatch: ', i, record["scene_id"], annotation["scene_id"])


all annotations:  ['answer_type', 'answers', 'position', 'question_id', 'question_type', 'rotation', 'scene_id']
all data:  ['answers', 'object_ids', 'object_names', 'position', 'question', 'question_id', 'scene_id', 'situation']
26623 26623
{'x': -0.9651003385573296, 'y': -1.2417634435553606, 'z': 0} {'_x': 0, '_y': 0, '_z': 0.09983341664682724, '_w': 0.9950041652780182}
[-0.9651003385573296, -1.2417634435553606, 0, 0, 0, 0.09983341664682724, 0.9950041652780182]
