In [1]:
import os
import json
import numpy as np

In [2]:
import pickle

In [3]:
!pwd

/usr/itetnas04/data-scratch-01/fencai/data/diora


In [4]:
class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if word not in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

In [5]:
with open('./pytorch/data/partit_data/partnet.dict.pkl', 'rb') as r:
    x = pickle.load(r)

In [6]:
x

<__main__.Vocabulary at 0x7faa243cbb50>

## Validation

In [90]:
# 0.chair aaff8d73 model.step_329900.pt 1380e247 [all: ab1811e0 / 06db957b / 14126ca7]
# 1.table aea84162 model.step_317900.pt b3d65fcf [all: 1cc2528f / d42498e3 / 22f0f3e5]
# 2.bed e81dd9de model.step_57900.pt d43523b6 [all: 408181da / 2e1de6e1 / ea6511bd]
# 3.bag 9149db2e model.step_25900.pt 4bbd9841 [all: 4bb6f870 / 12b06d81 / 56c1faa4]

# 0.chair 037dee7e
# 1.table 1229d735
# 2.bed d66314ad
# 3.bag 87bf56b8

path = './log/87bf56b8/parse.jsonl'
type_ = '3.bag'

In [91]:
with open(path, 'r') as f:
    lines = [l.strip() for l in f.readlines()]

In [92]:
lines_res = [
    json.loads(l) for l in lines
]

In [93]:
lines_res[0]['tree'][0]

[['this', [['is', 'a'], ['bag', 'with']]], ['a', 'body']]

In [94]:
lines_res[0]['tree'][1]

[['and', ['a', ['single', 'handle']]], [['at', 'the'], 'top']]

In [95]:
def get_len(tree):
    if isinstance(tree, str):
        return 1

    return sum([get_len(x) for x in tree])

In [96]:
get_len(lines_res[0]['tree'])

14

In [97]:
# bfs
def get_spans(tree):
    queue = [(tree, 0)]
    spans = []

    while queue:
        current_node = queue.pop(0)

        tree = current_node[0]
        offset = current_node[1]

        spans.append((offset, offset + get_len(tree)-1))

        if not isinstance(tree[0], str):
            queue.append((tree[0], offset))

        if not isinstance(tree[1], str):
            queue.append((tree[1], offset + get_len(tree[0])))
    return set(spans)

In [98]:
def get_stats(span1, span2):
    tp = 0
    fp = 0
    fn = 0
    for span in span1:
        if span in span2:
            tp += 1
        else:
            fp += 1

    for span in span2:
        if span not in span1:
            fn += 1
#     print('tp: {}; fp: {}; fn: {}'.format(tp, fp, fn))
    return tp, fp, fn

In [99]:
sent_f1_txt, corpus_f1_txt = [], [0., 0., 0.]

for idx, line in enumerate(lines_res):
    pred_txt = get_spans(line['tree'])
    example_id = line['example_id']
    with open(os.path.join(f'pytorch/data/partit_data/{type_}/test/', example_id, 'lan_spans.txt'), 'r') as w:
        gold_txt = json.loads(w.read())
    gold_txt = set([(a, b) for a, b in gold_txt])
    
    tp_txt, fp_txt, fn_txt = get_stats(pred_txt, gold_txt) 
    corpus_f1_txt[0] += tp_txt
    corpus_f1_txt[1] += fp_txt
    corpus_f1_txt[2] += fn_txt

    overlap_txt = pred_txt.intersection(gold_txt)
    prec_txt = float(len(overlap_txt)) / (len(pred_txt) + 1e-8)
    reca_txt = float(len(overlap_txt)) / (len(gold_txt) + 1e-8)

    if len(gold_txt) == 0:
        reca_txt = 1. 
        if len(pred_txt) == 0:
            prec_txt = 1.
    f1_txt = 2 * prec_txt * reca_txt / (prec_txt + reca_txt + 1e-8)
    sent_f1_txt.append(f1_txt)

tp_txt, fp_txt, fn_txt = corpus_f1_txt  
prec_txt = tp_txt / (tp_txt + fp_txt)
recall_txt = tp_txt / (tp_txt + fn_txt)
corpus_f1_txt = 2 * prec_txt * recall_txt / (prec_txt + recall_txt) if prec_txt + recall_txt > 0 else 0.
sent_f1_txt = np.mean(np.array(sent_f1_txt))
print('prec_txt: ', prec_txt)
print('recall_txt: ', recall_txt)
print('corpus_f1_txt: ', corpus_f1_txt)
print('sent_f1_txt: ', sent_f1_txt)

prec_txt:  0.42033898305084744
recall_txt:  0.42033898305084744
corpus_f1_txt:  0.42033898305084744
sent_f1_txt:  0.42260239227647994


In [100]:
variable = 42
user_input = "The answer is {variable}"
user_input_formatted = user_input.format(variable=variable)
print(user_input_formatted)

The answer is 42


## Check the coverage between train and test

In [92]:
file_dir = f'./pytorch/data/partit_data//train/'.format(type_)
dir_list = [x for x in os.listdir(file_dir) if '.' not in x]
textfile_list = [
    os.path.join(file_dir, dir_name, 'utterance.txt') for dir_name in dir_list
]

sentences = []

for textfile in textfile_list:
    with open(textfile, 'r') as r:
        sentences.append(r.read().replace(".", "").replace(",", " , ").replace(":", " : ").replace(";", " ; ").replace("/", " ").replace("\'", " \'").replace("\"", " \"").strip().split())

In [97]:
word_train_set = set([tk for tks in sentences for tk in tks])

In [99]:
len(word_train_set)

1480

In [100]:
file_dir = './pytorch/data/partit_data/0.chair/test/'
dir_list = [x for x in os.listdir(file_dir) if '.' not in x]
textfile_list = [
    os.path.join(file_dir, dir_name, 'utterance.txt') for dir_name in dir_list
]

sentences = []

for textfile in textfile_list:
    with open(textfile, 'r') as r:
        sentences.append(r.read().replace(".", "").replace(",", " , ").replace(":", " : ").replace(";", " ; ").replace("/", " ").replace("\'", " \'").replace("\"", " \"").strip().split())

In [101]:
word_test_set = set([tk for tks in sentences for tk in tks])

In [102]:
len(word_test_set)

696

In [104]:
len(word_train_set.intersection(word_test_set))

544