In [1]:
import copy

from tqdm import tqdm

In [2]:
class Dependency():
    def __init__(self, idx, word, tag, head, rel):
        self.id = idx
        self.word = word
        self.tag = tag
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.id), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.tag}, {self.head}, {self.rel})"

In [3]:
def load_codt(data_file: str):
    # id, form, tag, head, rel
#     sentence:List[Dependency] = [Dependency('0', '<root>', '_', '0', '_')]
    sentence:List[Dependency] = []
    
    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1	上海	_	NR	NR	_	2	nn	_	_
        for line in f.readlines():
            toks = line.split()
            if len(toks) == 0:
                yield sentence
#                 sentence = [Dependency('0', '<root>', '_', '0', '_')]
                sentence = []
            elif len(toks) == 10:
                dep = Dependency(toks[0], toks[1], toks[3], toks[6], toks[7])
                sentence.append(dep)

In [4]:
file1 = '../aug/diag_codt_sampled/diag_dev_sampled_fixed.conll'
file2 = '../aug/diag_weakcodt_sampled/diag_dev_sampled.conll'

In [5]:
data1 = [d for d in load_codt(file1)]
data2 = [d for d in load_codt(file2)]

print(len(data1), len(data2))

7435 6576


In [6]:
sentences1 = [' '.join([dep.word for dep in deps]) for deps in data1]
sentences2 = [' '.join([dep.word for dep in deps]) for deps in data2]

In [7]:
same_cnt = 0
for i, sentence1 in tqdm(enumerate(sentences1)):
    for j, sentence2 in enumerate(sentences1):
        
        if sentence1 == sentence2:
            same_cnt += 1
            sentences1.pop(j)
            data1.pop(j)
            continue
            
print('----------------')
print(same_cnt)

same_cnt = 0
for i, sentence1 in tqdm(enumerate(sentences2)):
    for j, sentence2 in enumerate(sentences2):
        
        if sentence1 == sentence2:
            same_cnt += 1
            sentences2.pop(j)
            data2.pop(j)
            continue

print(same_cnt)

2894it [00:01, 1770.69it/s]


----------------
4541


2566it [00:01, 2031.66it/s]

4010





In [8]:
sentences1 = [' '.join([dep.word for dep in deps]) for deps in data1]
sentences2 = [' '.join([dep.word for dep in deps]) for deps in data2]

In [9]:
merged = []

same_cnt = 0
for i, sentence1 in tqdm(enumerate(sentences1)):
    
    tmp = []
    for j, sentence2 in enumerate(sentences2):
        
        if sentence1 == sentence2:
            same_cnt += 1
            sentences2.pop(j)
            data2.pop(j)
            continue
    
    merged.append(data1[i])

merged.extend(data2)
print(same_cnt)
print(len(merged))

2894it [00:00, 4143.17it/s]

1195
4265





In [10]:
out_file = '../aug/diag_weakcodt_sampled/diag_dev_merged.conll'
fw = open(out_file, 'w+', encoding='utf-8')

for i, d in enumerate(merged):
    for idx, dep in enumerate(d):
        save_str = f'{dep.id}\t{dep.word}\t_\t_\t_\t_\t{dep.head}\t{dep.rel}\t_\t_\n'
        fw.write(save_str)
        
    fw.write('\n')
        
fw.close()