In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import re
import os
import sys
import json
import tempfile
import subprocess
import collections

# import util
import conll
from bert import tokenization

In [3]:
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list.

    Only one level of nesting is removed; the order of the items is preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [83]:
# Load the raw medical coreference data (one JSON document per line) that is
# later converted to the SpanBERT input format.
raw_train, raw_test = [], []

for source_path, documents in (
    ("./data/medical_train.jsonlines", raw_train),
    ("./data/medical_test.jsonlines", raw_test),
):
    with open(source_path, encoding="utf-8") as source_file:
        documents.extend(json.loads(record) for record in source_file)

In [5]:
# Inspect the top-level fields of the first raw training document.
raw_train[0].keys()

dict_keys(['sentences', 'clusters', 'doc_key', 'speakers'])

In [6]:
# Vocabulary file for the tokenizer (presumably a cased BERT vocabulary, going
# by the directory name -- confirm against the model checkpoint in use).
vocab_file = "./cased_config_vocab/vocab.txt"
# do_lower_case=False is passed so the tokenizer is not asked to lower-case input.
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=False)

In [79]:
# Maximum segment lengths to export; the export cell writes one train file and
# one test file per value.
seg_lens =  [128, 256, 384, 512]

In [88]:
def process_medic(raw, tokenizer, max_seg_len):
    """Convert one raw medical coreference document to the SpanBERT input layout.

    Every word of the document is split into subtokens with ``tokenizer``; the
    subtokens are packed into consecutive segments of at most ``max_seg_len``
    tokens, each wrapped in ``[CLS]`` ... ``[SEP]``.  Word-level mention spans
    are re-indexed to positions in the flattened subtoken sequence.

    Args:
        raw: dict with the keys ``"doc_key"``, ``"sentences"`` (list of lists
            of word strings), ``"clusters"`` (list of clusters, each a list of
            inclusive ``[first_word, last_word]`` pairs indexing the flattened
            words) and ``"speakers"``.  Only ``raw["speakers"][0][0]`` is read;
            it is used as the speaker of every subtoken.
        tokenizer: object exposing ``tokenize(text)`` that returns a list of
            subtoken strings.
        max_seg_len: maximum number of tokens per segment, counting the
            ``[CLS]`` and ``[SEP]`` markers.  Must be at least 3.

    Returns:
        dict with the keys ``"doc_key"``, ``"clusters"``, ``"sentences"``,
        ``"speakers"``, ``"sentence_map"`` and ``"subtoken_map"`` (in that
        insertion order).  ``"sentences"`` holds the segments, ``"speakers"``
        holds one equally long list per segment (``"[SPL]"`` at the marker
        positions), and the two maps hold, for every position of the flattened
        segments, the sentence index and the word index it belongs to.  A
        ``[SEP]`` and the ``[CLS]`` that follows it are mapped to the word and
        sentence of the subtoken just before them; the very first ``[CLS]`` is
        mapped to word 0 / sentence 0.  Cluster spans are inclusive
        ``[start, end]`` positions of the first subtoken of the first word and
        the last subtoken of the last word, so they never start or end on a
        marker.  A document without any word yields empty lists.

    Raises:
        ValueError: if ``max_seg_len`` is smaller than 3, or if a mention
            refers to a word index that does not exist in the document.
    """
    if max_seg_len < 3:
        raise ValueError(
            "max_seg_len must be at least 3 ([CLS] + one subtoken + [SEP]), "
            "got {!r}".format(max_seg_len)
        )

    new = dict()
    new["doc_key"] = raw["doc_key"]
    new["clusters"] = []
    new["sentences"] = []
    new["speakers"] = []
    new["sentence_map"] = []
    new["subtoken_map"] = []

    speaker_id = raw["speakers"][0][0]

    # Pass 1: tokenize every word up front.  Knowing the total number of
    # subtokens lets pass 2 recognise the last one reliably, which is what
    # decides whether the final (usually partial) segment gets written.
    pieces = []  # (sentence index, word index, subtoken)
    word_idx = 0
    for sent_idx, sentence in enumerate(raw["sentences"]):
        for token in sentence:
            # "/." and "/?" are escaped punctuation marks; drop the slash.
            if token == "/." or token == "/?":
                token = token[1:]

            subtokens = tokenizer.tokenize(token)
            # Every word must own at least one subtoken, otherwise mentions
            # that start or end on it could not be re-indexed.
            if token == "" or not subtokens:
                subtokens = ["[UNK]"]

            for subtoken in subtokens:
                pieces.append((sent_idx, word_idx, subtoken))
            word_idx += 1

    # Pass 2: pack the subtokens into segments and record, per word, the
    # positions of its first and last real subtoken in the flattened output.
    first_piece = {}
    last_piece = {}
    segment = None
    if pieces:
        segment = ["[CLS]"]
        new["subtoken_map"].append(0)
        new["sentence_map"].append(0)

    final_k = len(pieces) - 1
    for k, (sent_idx, piece_word, subtoken) in enumerate(pieces):
        position = len(new["subtoken_map"])
        first_piece.setdefault(piece_word, position)
        last_piece[piece_word] = position

        segment.append(subtoken)
        new["subtoken_map"].append(piece_word)
        new["sentence_map"].append(sent_idx)

        is_last = k == final_k
        # Close the segment once only the [SEP] slot is left, or at the end
        # of the document.
        if len(segment) == max_seg_len - 1 or is_last:
            segment.append("[SEP]")
            new["sentences"].append(segment)
            new["subtoken_map"].append(piece_word)
            new["sentence_map"].append(sent_idx)
            new["speakers"].append(
                ["[SPL]"] + [speaker_id] * (len(segment) - 2) + ["[SPL]"]
            )

            if not is_last:
                # Start a fresh list so the stored segment is never mutated.
                segment = ["[CLS]"]
                new["subtoken_map"].append(piece_word)
                new["sentence_map"].append(sent_idx)

    # Re-index the word-level mention spans to subtoken positions.
    for cluster in raw["clusters"]:
        mapped_cluster = []
        for mention in cluster:
            try:
                new_start = first_piece[mention[0]]
                new_end = last_piece[mention[1]]
            except KeyError:
                raise ValueError(
                    "mention {!r} in document {!r} refers to a word index "
                    "outside the document".format(mention, raw["doc_key"])
                )
            mapped_cluster.append([new_start, new_end])
        new["clusters"].append(mapped_cluster)

    return new

In [89]:
# Spot-check the converter on a single training document (index 55) with a
# maximum segment length of 128 before running the full export.
tokens = flatten(raw_train[55]["sentences"])
# print(len(tokens))
# print(tokens[80:86])
# print(tokens[83])

new_piece = process_medic(raw_train[55], tokenizer, 128)

In [90]:
# Convert every document once per segment length and write the results as
# JSON lines: the train split first, then the test split.
for seg_len in seg_lens:
    for path_prefix, split_docs in (
        ("./data/train.medical.", raw_train),
        ("./data/test.medical.", raw_test),
    ):
        with open(path_prefix + str(seg_len) + ".jsonlines", "w") as out_file:
            for raw_piece in split_docs:
                new_piece = process_medic(raw_piece, tokenizer, seg_len)
                out_file.write(json.dumps(new_piece) + "\n")