In [1]:
import transformers
import json
from Levenshtein import distance
import string
import random

In [2]:
def get_json_value_by_key(in_json, target_key, results=None):
    """Recursively collect every value stored under ``target_key``.

    Dicts, lists and tuples are walked depth-first. For a dict entry the
    value is searched first and only then appended itself, so nested
    matches appear before the enclosing match.

    Args:
        in_json: Parsed JSON-like data (dict / list / tuple / scalar).
        target_key: Dictionary key whose values should be collected.
        results: Optional list to append the matches to. A new list is
            created when omitted, so separate calls never share state.

    Returns:
        The list of matched values (the same object as ``results`` when
        one was supplied).
    """
    # A mutable default ([]) would be shared by all calls and accumulate
    # matches across them, so the accumulator is created per call instead.
    if results is None:
        results = []

    if isinstance(in_json, dict):
        for key, value in in_json.items():
            # Recurse into the value of the current key first.
            get_json_value_by_key(value, target_key, results=results)

            # Then record the value itself if the key is the one we want.
            if key == target_key:
                results.append(value)

    elif isinstance(in_json, (list, tuple)):
        for element in in_json:
            # Recurse into every element of the sequence.
            get_json_value_by_key(element, target_key, results=results)

    return results

In [3]:
# Parse the annotated corpus and gather every value stored under a
# 'transcript' key; an explicit empty list is passed as the accumulator.
with open('data_final_with_images_filtered_include_madeup_corrected.json', 'r') as f:
    data = json.load(f)
sentences = get_json_value_by_key(data, 'transcript', results=[])

In [4]:
# Gather every value stored under a 'salient_words' key in the loaded data,
# passing an explicit fresh list as the accumulator.
salient_words = get_json_value_by_key(data, 'salient_words',results=[])

In [5]:
def correct_label(ref, sentences, salient_words):
    """Overwrite entries of ``salient_words`` with corrections read from a file.

    The file is read as consecutive line pairs: the first line of a pair is
    a sentence, the second is its corrected spans separated by tabs. The
    sentence is located in ``sentences`` and the entry at the same position
    in ``salient_words`` is replaced in place.

    Pairs that cannot be applied (sentence not found, missing second line,
    or position outside ``salient_words``) are skipped and the line index of
    the pair is printed. Any other error is allowed to propagate.

    Args:
        ref: Path of the correction file.
        sentences: List of sentences to search.
        salient_words: List modified in place.
    """
    with open(ref, 'r') as f:
        correct_data = f.readlines()

    for i in range(0, len(correct_data), 2):
        try:
            sent = correct_data[i].strip('\n')
            correct_spans = correct_data[i+1].strip('\n').split('\t')
            index = sentences.index(sent)
            salient_words[index] = correct_spans
        except (IndexError, ValueError):
            # IndexError: dangling last line or position past salient_words.
            # ValueError: sentence not present in `sentences`.
            print(i)




In [9]:
# Apply the file-based corrections to salient_words (modified in place).
correct_label('error_labeled_sentence1.txt',sentences,salient_words)
# Manual overrides, applied after the file-based corrections: clear the spans
# at these positions.
# NOTE(review): the indices are hard-coded positions into salient_words, so
# they presumably only hold for this exact data file and ordering — confirm.
for i in [7012,10612,17537,18142,21794,22930]:  
    salient_words[i] = []
# Replace the spans at position 31382 with a hand-written list.
salient_words[31382] = ['sambal kangkong','spicy food','XO rice','penang rendang','nonya curry']

74


'Their sambal kangkong is a must try if you are a fan of spicy food! Other signature dishes include their XO rice, penang rendang and their nonya curry!'

In [10]:
# Tokenize every salient span and pair its tokens with B/I tags.
# salient_tokens_tag[k] is a list of (tokens, tags) tuples, one per span of
# salient_words[k]; it is an empty list when there are no spans.
tokenizer = transformers.AutoTokenizer.from_pretrained('vblagoje/bert-english-uncased-finetuned-pos',use_fast=False)
salient_tokens_tag = []
for item in salient_words:
    token_list_per_sentence = []
    for span in item:
        # These tokens are later searched for as a sub-list of the tokenized
        # sentence. Punctuation tokens are dropped (note: `in` on a string is
        # a substring test, which for single-character tokens is a
        # membership test).
        tokens = [
            token
            for token in tokenizer.tokenize(span)
            if token not in '.?!,:;' and token != ' '
        ]
        # First token is 'B', the rest 'I'. A span that keeps no tokens gets
        # no tags at all, so tokens and tags always have equal length.
        tags = (['B'] + ['I'] * (len(tokens) - 1)) if tokens else []
        token_list_per_sentence.append((tokens, tags))
    salient_tokens_tag.append(token_list_per_sentence)
            

In [11]:
## Use a sliding-window sub-list search to map the tokens of each salient span onto the tokenized sentence (e.g. ['salmon', 'wrap'])
def exact_match(target, list):
    """Return the first index at which ``target`` occurs as a contiguous
    sub-list of ``list``, or ``None`` when it does not occur.

    An empty ``target`` matches at index 0.
    """
    width = len(target)
    last_start = len(list) - width
    start = 0
    while start <= last_start:
        if list[start:start + width] == target:
            return start
        start += 1
    return None

def stripping(list):
    """Join word-piece tokens into one string for fuzzy comparison.

    A leading '##' is removed from each token, the pieces are concatenated
    without separators, the characters ``.?!,:;`` are deleted, and
    surrounding whitespace is stripped.

    Args:
        list: Sequence of token strings.

    Returns:
        The normalised string ('' for an empty sequence).
    """
    pieces = [token[2:] if token.startswith('##') else token for token in list]
    text = ''.join(pieces)
    for mark in '.?!,:;':
        text = text.replace(mark, '')
    # str.strip() returns a new string; the result has to be kept.
    return text.strip()
    
def approximate_match(target, list):
    """Find a token window of ``list`` that nearly spells ``target``.

    Both sides are normalised with ``stripping`` and compared with
    Levenshtein ``distance``; the first window (scanning start positions
    left to right, shortest window first) within distance 1 is returned.

    Args:
        target: Tokens of the span that could not be matched exactly.
        list: Tokens of the sentence.

    Returns:
        ``(start, end)`` such that ``list[start:end]`` is the matching
        window, or ``None`` when no window is close enough.
    """
    wanted = stripping(target)
    n_tokens = len(list)
    for i in range(n_tokens):
        # The end bound is exclusive, so it must be allowed to reach
        # n_tokens; otherwise the last token can never be part of a window.
        for j in range(i + 1, n_tokens + 1):
            candidate = stripping(list[i:j])
            # Extending the window never shortens the candidate, and the
            # edit distance is at least the length gap, so stop extending
            # once the candidate is more than one character too long.
            if len(candidate) > len(wanted) + 1:
                break
            if distance(wanted, candidate) <= 1:
                return i, j  # position of the span's correct form in the sentence
    return None

# Build (tokens, labels) for every sentence: labels start as all 'O' and the
# B/I tags of each salient span are written over the tokens where the span
# is found. Sentences with a span that could not be placed are recorded in
# error_sentence_index.
final_output = []
error_sentence_index = []
assert len(sentences) == len(salient_tokens_tag)
for i in range(len(sentences)):
    salient_reference = salient_tokens_tag[i]
    tokenized_sentence = tokenizer.tokenize(sentences[i])
    token_labels = ['O'] * len(tokenized_sentence)
    for (span, span_label) in salient_reference:
        if not span:
            # An empty span would "match" at index 0 and its label would be
            # inserted into token_labels, breaking the token/label alignment.
            error_sentence_index.append(i)
            continue
        position = exact_match(span, tokenized_sentence)
        if position is None:
            # Fall back to a fuzzy match; it yields None when nothing is close.
            approximate = approximate_match(span, tokenized_sentence)
            if approximate is not None:
                position, end = approximate
                span = tokenized_sentence[position:end]  # the span as spelled in the sentence
                span_label = ['B'] + ['I'] * (len(span) - 1)
        if position is not None:
            token_labels[position:position+len(span)] = span_label
        else:
            error_sentence_index.append(i)
    final_output.append((tokenized_sentence, token_labels))


In [12]:
# Double-check that every sentence has exactly one label per token.
# enumerate gives the true position of the offending entry; looking it up
# with list.index would report the first *equal* entry instead and rescan
# the list for every hit.
for position, (sentence, label) in enumerate(final_output):
    if len(sentence) != len(label):
        print(position)
        # nothing should be printed here

In [20]:
# Merge each consecutive pair of sentences into a single turn:
# [CLS] first [SEP] second [SEP], with 'O' labels on the special tokens.
assert len(final_output) % 2 == 0
output_per_turn = [
    (
        ['[CLS]'] + first_tokens + ['[SEP]'] + second_tokens + ['[SEP]'],
        ['O'] + first_labels + ['O'] + second_labels + ['O'],
    )
    for (first_tokens, first_labels), (second_tokens, second_labels)
    in zip(final_output[0::2], final_output[1::2])
]

In [23]:
## split data into train, dev and test
def split_data(full_list, ratio=(0.8, 0.1, 0.1), seed=0, out_dir='dataset'):
    """Shuffle the examples and write train/dev/test splits as JSON files.

    Args:
        full_list: Sequence of JSON-serialisable examples. It is copied
            before shuffling, so the caller's list keeps its order and
            repeated calls produce the same split.
        ratio: Fractions for train, dev and test; must sum to 1 (compared
            with a floating-point tolerance).
        seed: Value passed to ``random.seed`` before shuffling.
        out_dir: Directory that receives ``train_set.json``,
            ``dev_set.json`` and ``test_set.json``; created if missing.
    """
    import math
    import os

    random.seed(seed)
    # Exact float equality rejects valid ratios such as (0.7, 0.2, 0.1).
    assert math.isclose(sum(ratio), 1.0), 'ratio must sum to 1'
    cutpoint1 = round(len(full_list) * ratio[0])
    cutpoint2 = round(len(full_list) * sum(ratio[:2]))

    # Shuffle a copy: shuffling the argument in place would make a second
    # call start from a different order and yield a different split.
    shuffled = list(full_list)
    random.shuffle(shuffled)

    splits = {
        'train_set': shuffled[:cutpoint1],
        'dev_set': shuffled[cutpoint1:cutpoint2],
        'test_set': shuffled[cutpoint2:],
    }
    os.makedirs(out_dir, exist_ok=True)
    for name, subset in splits.items():
        with open(os.path.join(out_dir, f'{name}.json'), 'w') as g:
            json.dump(subset, g)

# Write the train/dev/test splits of the combined turns using the default
# ratio and seed.
split_data(output_per_turn)

In [48]:
## Check the saved data: print the index of every entry whose token list and
## label list differ in length.
## NOTE(review): no visible cell writes 'input_data.json' (split_data writes
## into dataset/) — confirm this is the intended file.
import json
with open('input_data.json','r') as f:
    data = json.load(f)
for i in range(len(data)):
    entry_tokens, entry_labels = data[i][0], data[i][1]
    if len(entry_tokens) != len(entry_labels):
        print(i)

In [24]:
def get_max_length(file):
    """Return the length of the longest first element among the pairs
    stored in a JSON file.

    The file must contain a list of two-item entries; only the first item
    of each entry is measured. Returns 0 for an empty list.
    """
    with open(file, 'r') as f:
        pairs = json.load(f)
    longest = 0
    for tokens, _labels in pairs:
        longest = max(longest, len(tokens))
    return longest
    

In [80]:
171  213  153

10

In [1]:
# Inspect the second combined turn.
# NOTE(review): the recorded output of this cell is a NameError, so it was
# presumably run before the cell that defines output_per_turn — confirm and
# re-run in order, or remove the cell.
output_per_turn[1]

NameError: name 'output_per_turn' is not defined