<a href="https://colab.research.google.com/github/VincentK1991/BERT_summarization_1/blob/master/notebook/pre-processing-text-for-GPT2-fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load Packages

In [0]:
%cd '/content/drive/My Drive/Colab Notebooks/GPT-2/summarization_preprocessing'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import timeit

import torch
print(torch.__version__,' pytorch version')
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

!pip install transformers==2.6.0

In [0]:
import transformers
print(transformers.__version__,' make sure transformers version is 2.6.0')
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
special_tokens = {'bos_token':'<|startoftext|>','eos_token':'<|endoftext|>','pad_token':'<pad>','additional_special_tokens':['<|keyword|>','<|summarize|>']}
tokenizer.add_special_tokens(special_tokens)
print(len(tokenizer), 'total length of vocab')
print(tokenizer.bos_token_id, 'bos_token')
print(tokenizer.eos_token_id, 'eos_token')
print(tokenizer.pad_token_id, 'pad_token')  #token for <pad>, len of all tokens in the tokenizer
print(tokenizer.additional_special_tokens_ids[0], 'keyword_token') #token for <|keyword|>
print(tokenizer.additional_special_tokens_ids[1], 'summary_token') #token for <|summarize|>

2.6.0  make sure transformers version is 2.6.0


50261 total length of vocab
50257 bos_token
50256 eos_token
50258 pad_token
50259 keyword_token
50260 summary_token


# Load the dataset (before processing)

Reset the index before pre-processing so that the row labels run from 0 to len(frame_dev) - 1 and match the row numbers that `load_words` looks up.

In [0]:
frame_dev = pd.read_csv('Copy_Copy_updated_COVID19_devdata_Apr25_2020.csv',index_col=0)

In [0]:
frame_dev = frame_dev.reset_index()
frame_dev.head(5)

Unnamed: 0,index,sha,title,abstract,keyword_POS,keyword_NER,full_text_file,keyword_POS_str
0,18799,d3f172114638130e814f81f946359556da900e0c,Fieber nach Tropenaufenthalt,Bei Fieber nach einem Tropenaufenthalt kommen ...,"['Bei', 'Fieber', 'einem', 'Tropenaufenthalt',...","['Bei', 'nach', 'einem', 'kommen', 'viele', 'm...",custom_license,Bei Fieber einem Tropenaufenthalt kommen Erkra...
1,30112,7ac70544937971da0d91ef466686e9dc3fdd461b,Pathogenesis of Virus-Induced Demyelination,Publisher Summary Demyelination is a component...,"['Publisher', 'Summary', 'Demyelination', 'is'...","['Publisher', 'is', 'a', 'component', 'of', 's...",custom_license,Publisher Summary Demyelination is component d...
2,5504,b2d2050a4b62e13a5c78e23190f61a7895ca33e2,Analysis of synonymous codon usage patterns in...,Synonymous codon usage bias (CUB) is a defined...,"['codon', 'usage', 'bias', 'CUB', 'is', 'usage...","['Synonymous', 'codon', 'usage', 'bias', 'CUB'...",comm_use_subset,codon usage bias CUB is usage codons amino aci...
3,13902,e376a4c9ab21f569e4c61dacc4e98cf04c05e848,BrainStem Encephalitis Associated with Chandip...,Clinical data of 104 hospitalized children dur...,"['data', 'children', 'epidemic', 'encephalitis...","['Clinical', 'data', 'of', '104', 'hospitalize...",custom_license,data children epidemic encephalitis Andhra Pra...
4,23896,483e7cca05d9d7f861d2b6eddca54b686858122d,Evaluation of echinacea for the prevention and...,Summary Echinacea is one of the most commonly ...,"['Summary', 'Echinacea', 'is', 'herbal', 'prod...","['Summary', 'is', 'one', 'of', 'the', 'most', ...",custom_license,Summary Echinacea is herbal products controver...


# Helper Functions

The eight helper functions do the following:
1. Read the dataframe and return the keyword string, the true abstract, and three distractor abstracts.

  - Note that at this stage the correct pair is always the 0th element of the list.

2. Tokenize the inputs and return a list of token-id lists, padded to the same length.
3. Create the segment (token type) ids: keyword segment, summary segment, or padding.
4. Write the labels for the LM head. Only the summary segment of the correct pair (the 0th element of the list) keeps its token ids; everything else is set to -100 so that the cross-entropy loss ignores it.
5. Return the position of the last token before padding (i.e. the <|endoftext|> token). The MC head reads this position for the multiple-choice loss.
6. Return [1,0,0,0], because the correct pair is always the 0th element of the list.

7. Shuffle the four candidates with one shared permutation. After this the correct pair can be at any position. Returns numpy arrays.

8. Create tensor objects from the numpy arrays.
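
To make steps 2 to 5 concrete, here is a minimal sketch in plain Python that builds the per-candidate lists for one toy example. The special-token ids are the ones printed by the tokenizer cell above (50256 to 50260). The ordinary token ids, the padded length of 10, and the helper name `build_toy_example` are made up for illustration and are not part of the notebook.

```python
# Special-token ids as printed by the tokenizer cell above.
EOS, BOS, PAD, KEYWORD, SUMMARIZE = 50256, 50257, 50258, 50259, 50260

def build_toy_example(keyword_ids, abstract_ids, max_len, is_true_pair):
  """Hypothetical helper mirroring steps 2-5 for a single candidate."""
  input_ids = [BOS] + keyword_ids + [SUMMARIZE] + abstract_ids + [EOS]
  input_ids = input_ids + [PAD] * (max_len - len(input_ids))

  # Step 3: keyword segment up to and including <|summarize|>,
  # summary segment up to and including <|endoftext|>, then padding.
  num_seg_a = input_ids.index(SUMMARIZE) + 1
  end_index = input_ids.index(EOS)
  token_types = ([KEYWORD] * num_seg_a
                 + [SUMMARIZE] * (end_index - num_seg_a + 1)
                 + [PAD] * (max_len - end_index - 1))

  # Step 4: only the summary segment of the true pair feeds the LM loss.
  if is_true_pair:
    lm_labels = [tok if seg == SUMMARIZE else -100
                 for tok, seg in zip(input_ids, token_types)]
  else:
    lm_labels = [-100] * max_len

  # Step 5: position of <|endoftext|>, used by the multiple-choice head.
  mc_token_id = end_index
  return input_ids, token_types, lm_labels, mc_token_id

ids, types, labels, mc = build_toy_example([11, 12], [21, 22, 23], 10, True)
print(ids)     # [50257, 11, 12, 50260, 21, 22, 23, 50256, 50258, 50258]
print(types)   # [50259, 50259, 50259, 50259, 50260, 50260, 50260, 50260, 50258, 50258]
print(labels)  # [-100, -100, -100, -100, 21, 22, 23, 50256, -100, -100]
print(mc)      # 7
```

Calling the same helper with `is_true_pair=False` gives identical input and segment ids but an all -100 label list, which is what the distractor candidates receive.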

In [0]:
def load_words(df, num, with_title=False):
  """Take a dataframe and the row number of the sample to use.
  Return the keyword string (optionally prefixed with the title),
  the abstract (gold label for summarization),
  and 3 distractor abstracts, all as a tuple of 5 strings."""
  # note: the random rows are not checked against num, so a distractor
  # can occasionally be the true abstract itself
  arr_distract = np.random.randint(len(df), size=3)
  keyword = df['keyword_POS_str'][num]
  if with_title:
    title = df['title'][num]
    keyword = title + ' ' + keyword  # keep the title and the first keyword as separate words
  abstract = df['abstract'][num]
  distract1 = df['abstract'][arr_distract[0]]
  distract2 = df['abstract'][arr_distract[1]]
  distract3 = df['abstract'][arr_distract[2]]

  return (keyword,abstract,distract1,distract2,distract3)
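
`np.random.randint(len(df), size=3)` draws with replacement from all rows, so it can return the true row or the same distractor twice. One possible way to avoid this, sketched here with the standard library only, is to sample from the rows that remain after removing the true row and any row whose abstract is not a string. The function name `sample_distractor_indices` and the toy abstracts are hypothetical.

```python
import random

def sample_distractor_indices(abstracts, true_index, k=3, seed=None):
  """Pick k distinct row indices, excluding the true row and non-string abstracts."""
  rng = random.Random(seed)
  candidates = [i for i, text in enumerate(abstracts)
                if i != true_index and isinstance(text, str)]
  if len(candidates) < k:
    raise ValueError('not enough valid abstracts to draw distractors from')
  return rng.sample(candidates, k)

# Row 1 mimics a missing abstract, which pandas reads as NaN (a float).
toy_abstracts = ['abstract A', float('nan'), 'abstract C', 'abstract D', 'abstract E']
picked = sample_distractor_indices(toy_abstracts, true_index=0, seed=0)
print(picked)  # some ordering of the indices 2, 3 and 4
```

The returned indices could then replace `arr_distract` inside `load_words`, with `df['abstract']` converted to a list as the first argument.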

In [0]:
key_batch = load_words(frame_dev,30)

In [0]:
key_batch

(nan,
 'Unknown',
 'The intracellular expression of antibodies or antibody fragments (intrabodies) in different compartments of mammalian cells allows to block or modulate the function of endogenous molecules. Intrabodies can alter protein folding, protein-protein, protein-DNA, protein-RNA interactions and protein modification. They can induce a phenotypic knockout and work as neutralizing agents by direct binding to the target antigen, by diverting its intracellular traffic or by inhibiting its association with binding partners. They have been largely employed as research tools and are emerging as therapeutic molecules for the treatment of human diseases as viral pathologies, cancer and misfolding diseases. The fast growing bio-market of recombinant antibodies provides intrabodies with enhanced binding specificity, stability and solubility, together with lower immunogenicity, for their use in therapy. This chapter describes the crucial aspects required to express intrabodies in differ

In [0]:
def write_input_ids(word_batch,max_len=1024):
  """return four lists of input tokens (true pair first, then the 3 distractors),
  all padded with the pad token to the same length"""
  key, abstract, dis1,dis2,dis3 = word_batch

  # truncate to max_len (never beyond the tokenizer limit) so every list ends up the same length
  limit = tokenizer.max_len if max_len is None else min(max_len, tokenizer.max_len)
  list_input_token = []
  for text in (abstract,dis1,dis2,dis3):
    prompt = '<|startoftext|> ' + key + ' <|summarize|> '+ text + ' <|endoftext|>'
    list_input_token.append(tokenizer.encode(prompt,max_length = limit))

  if max_len is None:
    # no fixed length requested: pad to the longest of the four candidates
    max_len = max(len(i) for i in list_input_token)
  for i in list_input_token:
    i.extend([tokenizer.pad_token_id]*(max_len - len(i)))
  return list_input_token
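
When a candidate is longer than the length limit, truncation can cut off the trailing <|endoftext|> token; the next helper handles that case by overwriting the last position with the end token. The sketch below shows the same pad-or-truncate idea on plain lists. The ids 50256 and 50258 are the ones printed by the tokenizer cell, while `pad_or_truncate` and the short toy sequences are hypothetical.

```python
PAD, EOS = 50258, 50256  # ids printed by the tokenizer cell above

def pad_or_truncate(token_ids, max_len):
  """Hypothetical helper: return exactly max_len ids that contain EOS before any padding."""
  ids = list(token_ids[:max_len])
  if EOS not in ids:
    ids[-1] = EOS  # truncation cut the end marker, so restore it
  return ids + [PAD] * (max_len - len(ids))

print(pad_or_truncate([1, 2, EOS], 5))              # [1, 2, 50256, 50258, 50258]
print(pad_or_truncate([1, 2, 3, 4, 5, 6, EOS], 5))  # [1, 2, 3, 4, 50256]
```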

In [0]:
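def write_token_type_labels(list_input_ids,max_len=1024):
  """return one list of segment ids per candidate: keyword ids up to and including
  <|summarize|>, summary ids up to and including <|endoftext|>, then pad ids.
  A candidate whose <|endoftext|> was truncated away is modified in place."""
  list_segment = []
  for item in list_input_ids:
    if tokenizer.eos_token_id not in item:
      # truncation removed <|endoftext|>, so put it back in the last position
      item[-1] = tokenizer.eos_token_id
    num_seg_a = item.index(tokenizer.additional_special_tokens_ids[1]) + 1
    end_index = item.index(tokenizer.eos_token_id)
    num_seg_b = end_index - num_seg_a + 1
    num_pad = max_len - end_index - 1
    segment_ids = [tokenizer.additional_special_tokens_ids[0]]*num_seg_a + [tokenizer.additional_special_tokens_ids[1]]*num_seg_b + [tokenizer.pad_token_id]*num_pad
    list_segment.append(segment_ids)
  return list_segment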

In [0]:
def write_lm_labels(list_input_ids,list_type_labels):
  list_lm_label = []
  is_true_label = True
  for input_tokens,segments in zip(list_input_ids,list_type_labels):
    if is_true_label:
      is_true_label = False
      temp_list = []
      for token,segment in zip(input_tokens,segments):
        if segment == tokenizer.additional_special_tokens_ids[1]:
          temp_list.append(token)
        else:
          temp_list.append(-100)
      list_lm_label.append(temp_list)
    else:
      temp_list = [-100]*len(input_tokens)
      list_lm_label.append(temp_list)
  return list_lm_label

In [0]:
def write_last_token(list_input_ids):
  list_mc_token = []
  for item in list_input_ids:
    list_mc_token.append(item.index(tokenizer.eos_token_id))
  return list_mc_token

In [0]:
def write_mc_label():
  return [1,0,0,0]

In [0]:
def shuffle_batch(list_input_ids,list_type_labels,list_last_tokens,list_lm_labels,list_mc_labels):
  array_input_token = np.array(list_input_ids)
  array_segment = np.array(list_type_labels)
  array_mc_token = np.array(list_last_tokens)
  array_lm_label = np.array(list_lm_labels)
  array_mc_label = np.array(list_mc_labels)

  randomize = np.arange(4)
  np.random.shuffle(randomize)

  array_input_token = array_input_token[randomize]
  array_segment = array_segment[randomize]
  array_mc_token = array_mc_token[randomize]
  array_lm_label = array_lm_label[randomize]
  array_mc_label = array_mc_label[randomize]

  return (array_input_token,array_segment,array_mc_token,array_lm_label,array_mc_label)
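
The important property of the shuffle is that one permutation is applied to every array, so position i in each array still describes the same candidate. A minimal standard-library sketch of that idea follows; `shuffle_together` and the toy candidate names are hypothetical.

```python
import random

def shuffle_together(columns, seed=None):
  """Apply one random permutation to every list in columns (hypothetical helper)."""
  order = list(range(len(columns[0])))
  random.Random(seed).shuffle(order)
  return [[col[i] for i in order] for col in columns], order

names = ['true', 'dis1', 'dis2', 'dis3']
mc_labels = [1, 0, 0, 0]
(shuffled_names, shuffled_labels), order = shuffle_together([names, mc_labels], seed=1)
print(shuffled_names, shuffled_labels)
# Wherever 'true' ends up, the label 1 ends up at the same position.
```

Shuffling each array with its own permutation would break this alignment, which is why `shuffle_batch` indexes all five arrays with the same `randomize` array.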

In [0]:
def write_torch_tensor(np_batch):
  torch_input_token = torch.tensor(np_batch[0], dtype=torch.long).unsqueeze(0)
  torch_segment = torch.tensor(np_batch[1],dtype=torch.long).unsqueeze(0)
  torch_mc_token = torch.tensor(np_batch[2],dtype=torch.long).unsqueeze(0)
  torch_lm_label = torch.tensor(np_batch[3],dtype=torch.long).unsqueeze(0)
  torch_mc_label = torch.tensor([np.argmax(np_batch[4])],dtype=torch.long).unsqueeze(0)
  return (torch_input_token,torch_segment,torch_mc_token,torch_lm_label,torch_mc_label)

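# Write a big function

Execute all 8 helper functions for every row of the dataframe.

The per-row tensors are first concatenated onto a small temporary tensor, and the temporary tensor is appended to the full tensor every 1000 rows.

This is done for the sake of time efficiency: `torch.cat` copies its inputs into a new tensor, so concatenating onto a very long tensor at every row takes a while.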
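
A rough way to see why the chunking helps is to count how many rows get copied in total, since each concatenation produces a new tensor holding all rows of both inputs. The sketch below is only an arithmetic model under that assumption, not a benchmark, and the function names are hypothetical.

```python
def rows_copied_naive(n):
  """Rows copied when every new row is concatenated onto one growing tensor."""
  total, size = 0, 0
  for _ in range(n):
    size += 1
    total += size      # the result of each cat holds `size` rows
  return total

def rows_copied_chunked(n, chunk=1000):
  """Rows copied when rows go to a small temporary tensor that is flushed every `chunk` rows."""
  total, temp, big = 0, 0, 0
  for _ in range(n):
    temp += 1
    total += temp      # cat onto the small temporary tensor
    if temp == chunk:
      big += temp
      total += big     # one cat onto the big tensor per chunk
      temp = 0
  if temp:
    big += temp
    total += big       # final flush of the leftover rows
  return total

print(rows_copied_naive(32000))    # 512016000
print(rows_copied_chunked(32000))  # 16544000
```

Under this model the chunked scheme copies about 31 times fewer rows for 32,000 samples. Collecting the per-row tensors in a Python list and calling `torch.cat` once at the end would avoid the repeated copies altogether.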

In [0]:
def execute_all_function(df):
  exist_temp_tensor = False
  exist_big_tensor = False
  start = timeit.default_timer()
  for num in range(len(df)):
    #print(num)
    word_tuple = load_words(df, num)
    # skip rows where the keyword, the abstract or a distractor is missing (NaN is a float, not a str)
    if not all(isinstance(text, str) for text in word_tuple):
      continue
    
    list_input_ids = write_input_ids(word_tuple)
    list_type_labels = write_token_type_labels(list_input_ids)
    list_lm_labels = write_lm_labels(list_input_ids,list_type_labels)
    list_last_tokens = write_last_token(list_input_ids)
    list_mc_labels = write_mc_label()

    np_tuple = shuffle_batch(list_input_ids,list_type_labels,list_last_tokens,list_lm_labels,list_mc_labels)
    tensor_tuple = write_torch_tensor(np_tuple)
    
    if not exist_temp_tensor:
      temp_0 = tensor_tuple[0]
      temp_1 = tensor_tuple[1]
      temp_2 = tensor_tuple[2]
      temp_3 = tensor_tuple[3]
      temp_4 = tensor_tuple[4]
      exist_temp_tensor = True
    elif exist_temp_tensor:
      temp_0 = torch.cat((temp_0,tensor_tuple[0]),0)
      temp_1 = torch.cat((temp_1,tensor_tuple[1]),0)
      temp_2 = torch.cat((temp_2,tensor_tuple[2]),0)
      temp_3 = torch.cat((temp_3,tensor_tuple[3]),0)
      temp_4 = torch.cat((temp_4,tensor_tuple[4]),0)

    if num % 1000 == 0:
      if not exist_big_tensor:
        big_first_tensor = temp_0
        big_second_tensor = temp_1
        big_third_tensor = temp_2
        big_fourth_tensor = temp_3
        big_fifth_tensor = temp_4
        exist_temp_tensor = False
        exist_big_tensor = True
        del temp_0,temp_1,temp_2,temp_3,temp_4
      else:
        big_first_tensor = torch.cat((big_first_tensor,temp_0),0)
        big_second_tensor = torch.cat((big_second_tensor,temp_1),0)
        big_third_tensor = torch.cat((big_third_tensor,temp_2),0)
        big_fourth_tensor = torch.cat((big_fourth_tensor,temp_3),0)
        big_fifth_tensor = torch.cat((big_fifth_tensor,temp_4),0)
        exist_temp_tensor = False
        del temp_0,temp_1,temp_2,temp_3,temp_4
      
      stop = timeit.default_timer()
      print('iterations ',num,' takes ', stop - start,' sec')
      start = timeit.default_timer()
  
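  # Append the rows left over since the last flush. The temporary tensors do not
  # exist if the last processed row triggered a flush, so check the flag first.
  if exist_temp_tensor:
    if not exist_big_tensor:
      return temp_0, temp_1, temp_2, temp_3, temp_4
    big_first_tensor = torch.cat((big_first_tensor,temp_0),0)
    big_second_tensor = torch.cat((big_second_tensor,temp_1),0)
    big_third_tensor = torch.cat((big_third_tensor,temp_2),0)
    big_fourth_tensor = torch.cat((big_fourth_tensor,temp_3),0)
    big_fifth_tensor = torch.cat((big_fifth_tensor,temp_4),0)
  return big_first_tensor, big_second_tensor, big_third_tensor,big_fourth_tensor,big_fifth_tensor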

In [0]:
tensor_1,tensor_2,tensor_3,tensor_4,tensor_5 = execute_all_function(frame_dev)

iterations  0  takes  0.021691910999834363  sec
iterations  1000  takes  25.79339128399988  sec
iterations  2000  takes  25.48767054900054  sec
iterations  3000  takes  24.083790039000633  sec


1000 iterations take about 25 seconds.

That means a training set of 32k samples would take roughly 13 to 14 minutes.
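
The estimate is plain proportional scaling of the timings printed above. Written out, assuming the per-1000 time stays at about 25 seconds for the whole run:

```python
seconds_per_1000 = 25   # measured on the dev set above
n_samples = 32_000      # approximate training-set size quoted in the text
minutes = n_samples / 1000 * seconds_per_1000 / 60
print(round(minutes, 1))  # 13.3
```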

In [0]:
# create a tensor dataset object
tensor_dataset = TensorDataset(tensor_1,tensor_2,tensor_3,tensor_4,tensor_5)

In [0]:
# save the tensor object to load later when training
torch.save(tensor_dataset, 'torch_devFile_1_May07_2020.pt')

# Check the result by printing

Make sure the labels are all correct and line up with the inputs.

In [0]:
tensor_5.numpy()

In [0]:
item = 1515
print(tensor_1[item])
print(tensor_2[item])
print(tensor_3[item])
print(tensor_4[item])
print(tensor_5[item])

In [0]:
# Columns: input ids of choice 1, then input ids, token types and LM labels of choice 2,
# each followed by its decoded token.
row_format = '{:>2}{:>10}{:>10}{:>10}{:>10}{:>20}{:>10}{:>20}{:>10}'
print(row_format.format('count','input 1','decoded','input 2','decoded','token type 2','decoded','lm label 2','decoded'))

def decode_or_masked(token_id):
  """decode a single token id; -100 marks a label that is masked out of the loss"""
  return 'masked' if token_id == -100 else tokenizer.decode([token_id])

rows = zip(tensor_1[item][1],tensor_1[item][2],tensor_2[item][2],tensor_4[item][2])
for count,values in enumerate(rows):
  cells = [count]
  for value in values:
    value = int(value)
    cells += [value, decode_or_masked(value)]
  print(row_format.format(*cells))
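
Besides reading the printed table, the same checks can be automated. The sketch below works on plain nested lists and tests two properties for one example: only the candidate named by the MC label carries LM targets, and every MC token position points at <|endoftext|>. The function `check_example` and the toy data are hypothetical; 50256 and 50258 are the end and pad ids printed earlier.

```python
def check_example(lm_labels, mc_token_ids, mc_label, input_ids, eos_id=50256):
  """Return a list of problems found in one shuffled example (empty if none)."""
  problems = []
  for choice, labels in enumerate(lm_labels):
    has_targets = any(tok != -100 for tok in labels)
    if has_targets != (choice == mc_label):
      problems.append('choice %d: LM targets do not match mc_label' % choice)
    if input_ids[choice][mc_token_ids[choice]] != eos_id:
      problems.append('choice %d: mc token is not <|endoftext|>' % choice)
  return problems

# Toy example with two candidates of length 4; candidate 1 is the true pair.
toy_input = [[7, 8, 50256, 50258], [7, 9, 50256, 50258]]
toy_lm = [[-100] * 4, [-100, 9, 50256, -100]]
print(check_example(toy_lm, [2, 2], 1, toy_input))  # []
```

With the real tensors, one call per example could look like `check_example(tensor_4[item].tolist(), tensor_3[item].tolist(), int(tensor_5[item]), tensor_1[item].tolist())`.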