### setting

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install -q transformers

### tokenization

In [3]:
# useful functions for data cleansing

import unicodedata

def decode_bytes(text):

  '''check for bytes within text and decode them'''

  if isinstance(text, str):
    return text
  elif isinstance(text, bytes):
    return text.decode("utf-8", "ignore")
  else:
    raise ValueError(f"unsupported string type: {text}")

def whitespace_tokenize(text):

  '''strip text and split them'''

  text = text.strip()
  if not text:
    return []
  tokens = text.split()
  return tokens

def is_control(char):

  ''' detect control characters except whitespace '''

  if unicodedata.category(char).startswith('C'):
    if char == '\t' or char == '\n' or char == '\r': # category starts with C but is whitespace
      return False
    return True
  else:
    return False

def is_whitespace(char):

  ''' check whether character is whitespace '''

  if char == ' ' or char == '\t' or char == '\n' or char == '\r':
    return True
  elif unicodedata.category(char) == "Zs": # "space separator" category
    return True
  else:
    return False

def is_punc(char):

  ''' check whether character is a punctuation '''
  # ~32 : 각종 공백문자, 48~57 : 숫자, 65~90 : 대문자 알파벳, 97~122 : 소문자 알파벳, 126~ : 이상한 기호들
  if (ord(char) >= 33 and ord(char) <= 47) or (ord(char) >= 58 and ord(char) <= 64):
    return True
  elif (ord(char) >= 91 and ord(char) <= 96) or (ord(char) > 123 and ord(char) <= 126):
    return True
  elif unicodedata.category(char).startswith("P"): # punctuation
    return True
  else:
    return False

In [4]:
import collections

def load_vocab(vocab_file):

  ''' load vocab file as dictionary '''

  vocab_dict = collections.OrderedDict()
  idx=0

  with open(vocab_file, "r") as rf:
    tokens = rf.read().splitlines()

  for token in tokens:
    token = decode_bytes(token)
    token = token.strip()
    vocab_dict[token] = idx
    idx += 1

  return vocab_dict

In [25]:
class BasicTokenizer(object):
  def __init__(self, do_lower_case=True):
    self.do_lower_case = do_lower_case

  def clean_text(self, text):

    ''' skip invalid characters, convert all whitespaces into single space, and return text '''

    output = []
    for char in text:
      # check if char is NULL,�(unrecognizable), or control character
      if ord(char) == 0 or ord(char) == 65533 or is_control(char):
        continue
      if is_whitespace(char):
        output.append(" ") # all whitespace -> " "
      else:
        output.append(char)
    return "".join(output)

  def run_split_on_punc(self, text):

    ''' split text based on punctuations '''

    output = []
    token = ''
    for char in list(text):
      if is_punc(char):
        output.append(token)
        output.append(char)
        token = ''
      else:
        token += char

    if len(token) > 0:
      output.append(token)

    return output

  def tokenize(self, text):
    text = decode_bytes(text)
    text = self.clean_text(text)

    orig_tokens = whitespace_tokenize(text) # 공백문자 단위로 분리
    split_tokens = []

    for token in orig_tokens: # 토큰별로 punctuation 있을 경우 분리
      if self.do_lower_case:
        token = token.lower()
      split_tokens.extend(self.run_split_on_punc(token))

    return split_tokens

In [6]:
class WordpieceTokenizer(object):
  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    ''' input : a single token (via BasicTokenizer), output : wordpiece tokens

    wordpiece merges characters by "score = freq(pair) / freq(first)*freq(second)".
    pairs that frequently appear together are merged, but not if each element also appears frequently

    ex) unable : "un", "##able"
    ex) hugging : "hugging" '''

    text = decode_bytes(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      # 너무 긴 단어 (>200)는 unk 처리
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []

      # ex. unfriendly : length 10
      while start < len(chars): # 0~9
        end = len(chars) # 10
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end]) # unfriendly, unfriendl, unfriend, ... , un
          if start > 0:
            substr = "##" + substr # indicates subword
          if substr in self.vocab: # "unfriendly" not in vocab / ... "un" in vocab
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)

        start = end # subword 떨어져나간 경우 (un) 끝자리->시작자리 (f)

      if is_bad: # character 한개도 vocab에 없는 경우
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)

    return output_tokens



In [76]:
class FullTokenizer(object):
  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.vocab_reverse = {v:k for k,v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text): # cleanse text, tokenize by punctuation
      for sub_token in self.wordpiece.tokenize(token): # tokenize tokens (subtokens)
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    output = []
    for token in tokens:
      output.append(self.vocab[token])
    return output

### finetune biobert

In [8]:
from absl import flags

flags.DEFINE_string("bert_config_file", None, "config json file corresponding to the pretrained BERT model")
flags.DEFINE_string("vocab_File", None, "vocab file on which BERT was trained")
flags.DEFINE_string("output_dir", None, "output dir for model checkpoints")
flags.DEFINE_string("train_file", None, "squad-formatted json for training. E.g., train-v1.1.json")
flags.DEFINE_string("predict_file", None, "squad-formatted json for predictions. E.g., dev(test)-v1.1.json")
flags.DEFINE_string("init_checkpoint", None, "initial checkpoint (usually from a pretrained BERT model)")

flags.DEFINE_bool("do_lower_case", True, "whether to lower-case input text. True for lower case.")
flags.DEFINE_bool("do_train", True, "whether to run training")
flags.DEFINE_bool("do_predict", True, "whether to run eval on the dev set")
flags.DEFINE_bool("use_tpu", False, "whether to use TPU or GPU/CPU")
flags.DEFINE_bool("verbose_logging", False, "whether to print all warnings during data processing. Other warnings are printed by default.")

flags.DEFINE_integer("max_seq_len", 384, "maximum input sequence after WordPiece tokenization. Will be padded if shorter, truncated if longer.")
flags.DEFINE_integer("doc_stride", 128, "when splitting up a long document into chunks, how much stride to take between chunks")
flags.DEFINE_integer("max_query_len", 64, "maximum query sequence after WordPiece tokenization. Will be truncated if longer.")
flags.DEFINE_integer("predict_batch_size", 8, "total batch size for predictions")
flags.DEFINE_integer("save_checkpoint_step", 1000, "how often to save model checkpoints")
flags.DEFINE_integer("iterations_per_loop", 1000, "how many steps to make in each estimator call")
flags.DEFINE_integer("n_best_size", 20, "how many n-best predictions to generate in nbest_predictions.json output file")
flags.DEFINE_integer("max_answer_len", 30, "maximum length of generated answer")

flags.DEFINE_float("learning_rate", 5e-5, "initial learning rate for Adam optimizer")
flags.DEFINE_float("num_train_epochs", 3.0, "number of training epochs")
flags.DEFINE_float("warmup_proportion", 0.1, "proportion of training for linear lr warmup. E.g., 0.1 refers to 10% of training")
flags.DEFINE_float("null_score_diff_threshold", 0.0, "to predict null if (null_score - best_non_null) > threshold")

<absl.flags._flagvalues.FlagHolder at 0x7fae0c1a1b40>

In [10]:
class SquadExample(object):
  def __init__(self, qas_id, question_text, doc_tokens, origin_answer_text=None, start_pos=None, end_pos=None):
    self.qas_id = qas_id
    self.question_text = question_text
    self.doc_tokens = doc_tokens
    self.origin_answer_text = origin_answer_text
    self.start_pos = start_pos
    self.end_pos = end_pos

  def __str__(self):
    return self.__repr__()

  def __repr__(self):
    s = f"qas_id: {decode_bytes(self.qas_id)}" # check if text is in str, bytes or unicode and convert
    s += f", question_text: {decode_bytes(self.question_text)}"
    s += f", doc_tokens: {[' '.join(self.doc_tokens)]}"

    if self.start_pos:
      s += f", start_pos: {self.start_pos}, end_pos: {self.end_pos}"

    return s


In [12]:
def read_squad_examples(input_file, is_training):
  """Read a squad json file into a list of SquadExample"""

  with open(input_file, "r") as rf:
    input_data = json.load(rf)["data"][0]["paragraphs"]

  examples = []
  for entry in input_data:
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True

    context = entry["context"]
    for c in context:
      if is_whitespace(c):
        prev_is_whitespace=True
      else:
        if prev_is_whitespace: # c comes after whitespace -> add as new token
          doc_tokens.append(c)
        else:
          doc_tokens[-1] += c # c comes after c -> attach to recent token
        prev_is_whitespace=False
      char_to_word_offset.append(len(doc_tokens) -1) # "I went home" -> [0,0,1,1,1,1,1,2,2,2,2]

    qas = entry["qas"]
    for qa in qas:
      qas_id = qa["id"]
      question_text = qa["question"]
      start_pos = None
      end_pos = None
      origin_answer_text = None

      if is_training:
        if len(qa["answers"]) != 1:
          raise ValueError("Each question should have exactly 1 answer.")

        answer = qa["answers"][0]
        origin_answer_text = answer["text"] #"Bazex syndrome"
        answer_offset = answer["answer_start"] #93
        answer_length = len(origin_answer_text)
        start_pos = char_to_word_offset[answer_offset] # 몇번째 단어부터
        end_pos = char_to_word_offset[answer_offset + answer_length-1] # 몇번째 단어까지

        # check if answer text is extractable from context (아닌 경우 건너뜀)
        answer_from_context = " ".join(doc_tokens[start_pos:end_pos+1]) # context에서 주어진 범위로 인덱싱한 정답
        cleaned_answer = " ".join(whitespace_tokenize(origin_answer_text)) # 실제 정답 (strip 및 split)

        if answer_from_context.find(cleaned_answer) == -1: # context에 정답이 들어있지 않은 경우
          tf.logging.warning(f"Could not find answer from context. {answer_from_context} vs {cleaned_answer}")
          continue

      else: # inference용
        start_pos = -1
        end_pos = -1
        origin_answer_text = ""

      example = SquadExample(qas_id=qas_id, question_text=question_text, doc_tokens=doc_tokens,
                              origin_answer_text=origin_answer_text, start_pos=start_pos, end_pos=end_pos)
      examples.append(example)

  return examples

In [13]:
examples = read_squad_examples('/content/drive/MyDrive/cose474_final/trainset_bioasq.json', is_training=True)

In [15]:
len(examples)

11171

In [69]:
def check_is_max_context(doc_spans, cur_span_idx, pos):
  '''check whether a span gives maximum "score" for a token.
  ex) Span A: the man went to the
  Span B: to the store and bought
  Span C: and bought a gallon of

  "bought" will have 2 scores from span B and C.
  we choose span with maximum context, which is defined as min(left, right).
  Span B : min(4,0) = 0
  Span C : min(1,3) = 1

  We choose Span C as the maximum context for the token "bought".'''

  best_score = None
  best_span_idx = None

  for idx, doc_span in enumerate(doc_spans):
    end = doc_span.start + doc_span.length - 1
    if pos < doc_span.start:
      continue
    if pos > end:
      continue
    num_left_context = pos - doc_span.start
    num_right_context = end - pos
    score = min(num_left_context, num_right_context) + 0.01*doc_span.length
    if best_score is None or score > best_score:
      best_score = score
      best_span_idx = idx

  return cur_span_idx == best_span_idx

In [88]:
import json
import tensorflow as tf

class InputFeatures(object):
  def __init__(self, unique_id, example_index, doc_span_index, tokens, token_to_origin_map, token_is_max_context,
             input_ids, input_mask, segment_ids, start_pos=None, end_pos=None):
    self.unique_id = unique_id
    self.example_index = example_index
    self.doc_span_index = doc_span_index
    self.tokens = tokens
    self.token_to_origin_map = token_to_origin_map
    self.token_is_max_context = token_is_max_context
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.start_pos = start_pos
    self.end_pos = end_pos

In [93]:
def convert_examples_to_features(examples, tokenizer, max_seq_len, doc_stride, max_query_len, is_training, output_fn):
  unique_id = 1000000000

  for example_idx, example in enumerate(examples):
    query_tokens = tokenizer.tokenize(example.question_text)

    if len(query_tokens) > max_query_len:
      query_tokens = query_tokens[:max_query_len] # query 최대 길이로 제한

    word_to_subtoken_idx = [] # 각 단어가 subtoken 기준으로 몇번째인지 (ex. 3번째 단어 -> 9번째 subword부터)
    subtoken_to_word_idx = [] # 각 subtoken이 원래 몇번째 단어에 속하는지 (ex. 9번째 subword -> 3번째 단어)
    all_doc_tokens = [] # 모든 subtoken 모음

    for i, token in enumerate(example.doc_tokens): # Psoriasiform, dermatitis, in, a, case, ...
      word_to_subtoken_idx.append(len(all_doc_tokens))
      sub_tokens = tokenizer.tokenize(token) # tokenize into subwords
      for sub_token in sub_tokens:
        subtoken_to_word_idx.append(i)
        all_doc_tokens.append(sub_token)

    tok_start_pos = None
    tok_end_pos = None

    print(f'all_doc_tokens : {all_doc_tokens}')
    print(f'word_to_subtoken_idx : {word_to_subtoken_idx}')
    print(f'subtoken_to_word_idx : {subtoken_to_word_idx}')

    if is_training:
      tok_start_pos = word_to_subtoken_idx[example.start_pos]
      if example.end_pos < len(example.doc_tokens) - 1:
        tok_end_pos = word_to_subtoken_idx[example.end_pos + 1] - 1
      else: # 범위 벗어난 경우
        tok_end_pos = len(all_doc_tokens) - 1

    print(f'tok_start_pos : {tok_start_pos}')
    print(f'tok_end_pos : {tok_end_pos}')
    print(f'answer tokens : {all_doc_tokens[tok_start_pos:tok_end_pos+1]}')

    # max_seq_len = len(query_tokens) + len(doc_tokens) + 3 ([CLS], [SEP], [SEP] : 문장 시작, 문장 구분, 문장 끝)
    max_tokens_for_doc = max_seq_len - len(query_tokens) - 3 # (384-11-3) = 370

    # sliding window for long document training
    Docspan = collections.namedtuple('Docspan', ['start', 'length'])
    doc_spans = []
    start = 0 # 0 -> 128 -> 256 -> 384
    while start < len(all_doc_tokens):
      length = len(all_doc_tokens) - start # 729 -> 601 -> 473 -> 345
      if length > max_tokens_for_doc:
        length = max_tokens_for_doc
      doc_spans.append(Docspan(start=start, length=length))
      if start + length == len(all_doc_tokens): # 384 + 345 = 729
        break
      start += min(length, doc_stride)
    print(doc_spans)

    for doc_span_idx, doc_span in enumerate(doc_spans):
      tokens = []
      token_to_orig_map = {}
      token_is_max_context = {}
      segment_ids = [] # [cls]query[sep] / doc[sep] : 00000/11111 로 구분

      tokens.append("[CLS]")
      segment_ids.append(0)

      for token in query_tokens:
        tokens.append(token)
        segment_ids.append(0)
      tokens.append("[SEP]")
      segment_ids.append(0)

      for i in range(doc_span.length):
        doc_idx = doc_span.start + i
        token_to_orig_map[len(tokens)] = subtoken_to_word_idx[doc_idx]

        is_max_context = check_is_max_context(doc_spans, doc_span_idx, doc_idx)
        token_is_max_context[len(tokens)] = is_max_context
        tokens.append(all_doc_tokens[doc_idx])
        segment_ids.append(1)

      tokens.append("[SEP]")
      segment_ids.append(1)

      # encode tokens to ids (index numbers of vocabs)
      input_ids = tokenizer.convert_tokens_to_ids(tokens)

      # masking (1 for real tokens, 0 for padding tokens)
      input_mask = [1] * len(input_ids)
      while len(input_ids) < max_seq_len:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

      assert len(input_ids) == max_seq_len
      assert len(input_mask) == max_seq_len
      assert len(segment_ids) == max_seq_len

      start_pos = None
      end_pos = None

      if is_training:
        doc_start = doc_span.start
        doc_end = doc_span.start + doc_span.length - 1
        out_of_span = False

        if not (tok_start_pos >= doc_start and tok_end_pos <= doc_end): # tok_start_pos : word_to_subtoken_idx[example.start_pos]
          out_of_span = True
        if out_of_span:
          start_pos = 0
          end_pos = 0
        else:
          doc_offset = len(query_tokens) + 2
          start_pos = tok_start_pos - doc_start + doc_offset
          end_pos = tok_end_pos - doc_start + doc_offset

      ##############################################################################################################

      feature = InputFeatures(
          unique_id=unique_id,
          example_index=example_idx,
          doc_span_index=doc_span_idx,
          tokens=tokens,
          token_to_origin_map=token_to_orig_map,
          token_is_max_context=token_is_max_context,
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          start_pos=start_pos,
          end_pos=end_pos
      )
      output_fn(feature)

      unique_id += 1


In [52]:
# 10번째~12번째 단어가 정답 범위. subtokenize 이후에는 subtoken 단위로 변경해줘야.
examples[17]

qas_id: 56c1f01def6e394741000045_001, question_text: Orteronel was developed for treatment of which cancer?, doc_tokens: ['Efficacy and safety of second-line agents for treatment of metastatic castration-resistant prostate cancer progressing after docetaxel. A systematic review and meta-analysis. We performed a systematic review of the literature to assess the efficacy and the safety of second-line agents targeting metastatic castration-resistant prostate cancer (mCRPC) that has progressed after docetaxel. Pooled-analysis was also performed, to assess the effectiveness of agents targeting the androgen axis via identical mechanisms of action (abiraterone acetate, orteronel). MATERIALS AND We included phase III randomized controlled trials that enrolled patients with mCRPC progressing during or after first-line docetaxel treatment. Trials were identified by electronic database searching. The primary outcome of the review was overall survival. Secondary outcomes were radiographic progress

In [95]:
# convert_examples_to_features([examples[17]], FullTokenizer('/content/drive/MyDrive/cose474_final/vocab_biobert_large_cased.txt', do_lower_case=False), 384, 128, 64, True, None)

all_doc_tokens : ['Eff', '##ica', '##cy', 'and', 'safety', 'of', 'second', '-', 'line', 'agents', 'for', 'treatment', 'of', 'meta', '##static', 'cast', '##ration', '-', 'resistant', 'prost', '##ate', 'cancer', 'progress', '##ing', 'after', 'do', '##ce', '##tax', '##el', '.', 'A', 'systematic', 'review', 'and', 'meta', '-', 'analysis', '.', 'We', 'performed', 'a', 'systematic', 'review', 'of', 'the', 'literature', 'to', 'assess', 'the', 'efficacy', 'and', 'the', 'safety', 'of', 'second', '-', 'line', 'agents', 'targeting', 'meta', '##static', 'cast', '##ration', '-', 'resistant', 'prost', '##ate', 'cancer', '(', 'm', '##CRP', '##C', ')', 'that', 'has', 'progressed', 'after', 'do', '##ce', '##tax', '##el', '.', 'Poole', '##d', '-', 'analysis', 'was', 'also', 'performed', ',', 'to', 'assess', 'the', 'effectiveness', 'of', 'agents', 'targeting', 'the', 'androgen', 'axis', 'via', 'identical', 'mechanisms', 'of', 'action', '(', 'abi', '##rate', '##rone', 'acetate', ',', 'ort', '##eron', '##e

TypeError: ignored