<a href="https://colab.research.google.com/github/VincentK1991/BERT_summarization_1/blob/master/notebook/pre_processing_raw_text_for_GPT2_summarizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Input Preprocessing

This notebook illustrates, step by step, how to prepare training data for the GPT2 model from raw text.

This notebook consists of 4 main sections:
1. importing packages and cleaning the raw text dataset
2. keyword tagging and extraction
3. partitioning the dataset into training and validation sets
4. tokenization and collation

By the end of this notebook, we will have a tensor dataset for training our GPT2 summarizer model. We save it to a .pt file so that we can load it later.

# 1. Import packages and clean the raw text data

We will use the Huggingface implementations of GPT2 and BERT. Huggingface also provides tokenizers for both GPT2 and BERT, which is very helpful. However, for our GPT2 summarization task, we need to add special tokens that mark the start of the text, the end of the text, and whether a piece of text is keywords or summary, plus a padding token.
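As a minimal sketch of what adding these special tokens does to the vocabulary, the plain-Python snippet below assumes that a token already in the vocabulary keeps its id and that new tokens are appended, in the order given, after GPT2's 50257 base ids (0 to 50256). This is a simplified model rather than Huggingface's implementation, and the helper `assign_special_ids` is hypothetical, but it reproduces the ids checked by the asserts in section 1.1 and shows the layout of one training string.

```python
# Simplified model (assumption): tokens already in the vocabulary keep their id,
# genuinely new tokens get consecutive ids after the existing vocabulary.
GPT2_BASE_VOCAB_SIZE = 50257          # GPT2 ids 0..50256
EXISTING = {'<|endoftext|>': 50256}   # already part of GPT2's vocabulary

def assign_special_ids(new_tokens, base_size=GPT2_BASE_VOCAB_SIZE, existing=EXISTING):
    """Hypothetical helper: return a token -> id mapping and the new vocabulary size."""
    ids = dict(existing)
    next_id = base_size
    for tok in new_tokens:
        if tok not in ids:            # only new tokens receive a fresh id
            ids[tok] = next_id
            next_id += 1
    return ids, next_id

# same order as the special_tokens dict passed to add_special_tokens in section 1.1
ids, vocab_size = assign_special_ids(
    ['<|startoftext|>', '<|endoftext|>', '<pad>', '<|keyword|>', '<|summarize|>'])

# one training string: start token, keywords, <|summarize|>, abstract, end token
example = ('<|startoftext|> ' + 'virus spreads' + ' <|summarize|> '
           + 'The virus spreads quickly.' + ' <|endoftext|>')
print(ids, vocab_size)
print(example)
```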

## 1.1 Load Packages

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import timeit

import torch
print(torch.__version__,' pytorch version')
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

!pip install transformers==2.6.0

1.8.0+cu101  pytorch version


In [None]:
import transformers
from transformers import GPT2Tokenizer

In [None]:
assert transformers.__version__ == "2.6.0", 'make sure transformers version is 2.6.0'
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
special_tokens = {'bos_token':'<|startoftext|>','eos_token':'<|endoftext|>','pad_token':'<pad>','additional_special_tokens':['<|keyword|>','<|summarize|>']}
tokenizer.add_special_tokens(special_tokens)
# <|endoftext|> already exists in the GPT2 vocabulary (id 50256); the other 4 tokens are appended after it
assert len(tokenizer) == 50261, 'total length of vocab is 50261'
assert tokenizer.bos_token_id == 50257, 'beginning of sentence token is 50257'
assert tokenizer.eos_token_id == 50256, 'end of sentence token is 50256'
assert tokenizer.pad_token_id == 50258, '<pad> token is 50258'
assert tokenizer.additional_special_tokens_ids[0] == 50259, 'keyword token is 50259'  # <|keyword|>
assert tokenizer.additional_special_tokens_ids[1] == 50260, 'summary token is 50260'  # <|summarize|>


## 1.2 Clean up the dataset (raw text)

- Perform initial cleaning (remove NaN, etc.)
- Tag keywords using either NLTK or BERT (fine-tuned on a POS tagging task)
- Extract keywords

In [None]:
df = pd.read_csv("input/metadata.csv")
df = df[["title","sha","abstract"]]



In [None]:
df = df.dropna(subset=["abstract"]).reset_index(drop=True)
list_POS = ["FW","JJ","NN","NNS","NNP","VB","VBD","VBG","VBN","VBZ","VBP"]
df = df[:3000]

In [None]:
df

Unnamed: 0,title,sha,abstract
0,Clinical features of culture-proven Mycoplasma...,d1aafb70c066a2068b02786f8929fd9c900897fb,OBJECTIVE: This retrospective chart review des...
1,Nitric oxide: a pro-inflammatory mediator in l...,6b0567729c2143a66d737eb0a2f63f2dce2e5a7d,Inflammatory diseases of the respiratory tract...
2,Surfactant protein-D and pulmonary host defense,06ced00a5fc04215949aa72528f2eeaae1d58927,Surfactant protein-D (SP-D) participates in th...
3,Role of endothelin-1 in lung disease,348055649b6b8cf2b9a376498df9bf41f7123605,Endothelin-1 (ET-1) is a 21 amino acid peptide...
4,Gene expression in epithelial cells in respons...,5f48792a5fa08bed9f56016f4981ae2ca6031b32,Respiratory syncytial virus (RSV) and pneumoni...
...,...,...,...
2995,Risk factors for hematemesis in Hoima and Buli...,820acf55c4e52411482f6eb44360ffa35288b89a,"INTRODUCTION: On 17 September 2015, Buliisa Di..."
2996,Proteomic fingerprinting in HIV/HCV co-infecti...,0a01f5cf1c5cdc2711bcef74315dc54a6e143df0,BACKGROUND: Hepatic complications of hepatitis...
2997,Immune regulation of the unfolded protein resp...,6a80b22e84d2692545c6f11d7cb4c96602a25c39,Protein folding in the endoplasmic reticulum (...
2998,Preprints: An underutilized mechanism to accel...,6eb282b0887ed1a7ab59123919bbadbf9ce6ed55,"In an Essay, Michael Johansson and colleagues ..."


# 2. Keyword extraction

There are many approaches to keyword extraction. Here I will show two of them:
- NLTK part-of-speech tagging
- BERT fine-tuned on a token classification (POS tagging) task
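Both approaches end with the same filtering step: tag every word, then keep only the words whose tag is in an allowed list (the notebook's `list_POS`: foreign words, adjectives, nouns, and verbs). A minimal plain-Python sketch, where the hypothetical `filter_keywords` helper and the hand-written (word, tag) pairs stand in for the output of `nltk.pos_tag` or the BERT tagger:

```python
list_POS = ["FW", "JJ", "NN", "NNS", "NNP", "VB", "VBD", "VBG", "VBN", "VBZ", "VBP"]

def filter_keywords(tagged_words, allowed_tags):
    """Return the words, in their original order, whose POS tag is allowed."""
    allowed = set(allowed_tags)
    return [word for word, tag in tagged_words if tag in allowed]

# hand-written tags for illustration; determiners, adverbs and prepositions are dropped
tagged = [("The", "DT"), ("novel", "JJ"), ("virus", "NN"), ("spreads", "VBZ"),
          ("very", "RB"), ("quickly", "RB"), ("in", "IN"), ("Uganda", "NNP")]
keywords = filter_keywords(tagged, list_POS)
print(' '.join(keywords))   # novel virus spreads Uganda
```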

## 2.1 NLTK tagging and extraction

In [None]:
import nltk
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
def tag_pull_abstract(df, list_POS):
    """Return one keyword list per abstract.
    input: pandas DataFrame with an 'abstract' column,
           list of part-of-speech tags to keep as keywords
    output: List(List(keyword string))"""
    # tokenize and POS-tag every abstract with NLTK
    list_tokenized = df['abstract'].apply(
        lambda x: nltk.pos_tag(nltk.word_tokenize(x))).values
    # keep only the words whose tag is in list_POS
    list_answer = [[item[0] for item in row if item[1] in list_POS]
                   for row in list_tokenized]
    return list_answer

In [None]:
df['keyword_POS'] = tag_pull_abstract(df, list_POS)
df = df.dropna(subset = ["keyword_POS"]).reset_index(drop=True)
df['keyword_POS_str'] = df['keyword_POS'].apply( lambda x: ' '.join(x))

In [None]:
df["keyword_POS_str"].tail(5)

2995    INTRODUCTION September Buliisa District Health...
2996    BACKGROUND Hepatic complications hepatitis C v...
2997    Protein folding endoplasmic reticulum ER is su...
2998    Essay Michael Johansson colleagues advocate po...
2999    Industry-driven voluntary disease control prog...
Name: keyword_POS_str, dtype: object

## 2.2 Fine-tuned BERT for token extraction

In this case, we use BERT for token classification, where the task is to predict the part of speech of each token. To fine-tune BERT for this task, please refer to the separate notebook on BERT POS tagging.

In this section we use a BERT model previously fine-tuned on this task. The weights and tokenizer are included in the resources folder.

You are invited to train this model yourself. It is fairly easy and not time-consuming: it took me about 15-30 minutes on a Colab GPU.
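BERT tags WordPiece sub-tokens, not words, so its output needs post-processing before keyword filtering: continuation pieces ("##...") are glued back onto the previous word, the word keeps the tag predicted for its first piece, and special tokens such as [CLS] should be dropped (otherwise they can leak into the keywords, as the [CLS] entries in the df.tail(5) output of this section show). A minimal plain-Python sketch of that step; the `merge_wordpieces` helper and the tokens and tags are made up for illustration:

```python
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}

def merge_wordpieces(tokens, tags):
    """Hypothetical helper: merge WordPiece tokens into words and return (word, tag) pairs."""
    words, word_tags = [], []
    for token, tag in zip(tokens, tags):
        if token.startswith("##") and words:
            words[-1] += token[2:]          # continuation piece: extend the current word
        else:
            words.append(token)
            word_tags.append(tag)           # a word is tagged by its first piece
    # drop BERT's special tokens so they never become keywords
    return [(w, t) for w, t in zip(words, word_tags) if w not in SPECIAL_TOKENS]

tokens = ["[CLS]", "buli", "##isa", "district", "health", "[SEP]"]
tags = ["NN", "NNP", "NN", "NN", "NN", "NN"]
print(merge_wordpieces(tokens, tags))
# [('buliisa', 'NNP'), ('district', 'NN'), ('health', 'NN')]
```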

In [None]:
from transformers import BertForTokenClassification, AdamW,BertTokenizer
import json
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

In [None]:
BERT_model = BertForTokenClassification.from_pretrained("resources/POS_tagging")
tokenizer2 = BertTokenizer.from_pretrained("resources/POS_tagging")

In [None]:
with open('resources/POS_tagging/POS2idx.json', 'r') as fp:
    POS2idx = json.load(fp)

In [None]:
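list_POS = ["FW","JJ","NN","NNS","NNP","VB","VBD","VBG","VBN","VBZ","VBP"]
# tags ordered by their index, so POS_values[i] is the tag with id i (assumes ids are 0..n-1)
POS_values = sorted(POS2idx, key=POS2idx.get)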

In [None]:
# length of each abstract in GPT2 tokens (distilgpt2 accepts at most 1024 positions)
list_abs_len = []
for i in df['abstract']:
  if isinstance(i, str):
    j = tokenizer.encode(i)
    list_abs_len.append(len(j))

Token indices sequence length is longer than the specified maximum sequence length for this model (9964 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (24769 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
BERT_model = BERT_model.to(device)

In [None]:
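def find_keywords(test_sentence):
  """Tag an abstract with the fine-tuned BERT POS tagger and return the
  words whose predicted tag is in list_POS."""
  if not isinstance(test_sentence, str):
    return ['']

  # truncate to stay within BERT's 512-position limit
  tokenized_sentence = tokenizer2.encode(test_sentence)
  input_ids = torch.tensor([tokenized_sentence[:510]]).to(device)

  with torch.no_grad():
    output = BERT_model(input_ids)
  # output[0] holds the logits, shape (1, seq_len, number of POS tags)
  label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)

  tokens = tokenizer2.convert_ids_to_tokens(input_ids.to('cpu').numpy()[0])
  new_tokens, new_labels = [], []
  for token, label_idx in zip(tokens, label_indices[0]):
    if token.startswith("##") and new_tokens:
      # glue WordPiece continuations onto the previous word; the word keeps its first piece's tag
      new_tokens[-1] = new_tokens[-1] + token[2:]
    else:
      new_labels.append(POS_values[label_idx])
      new_tokens.append(token)

  # keep allowed tags, and never return [CLS]/[SEP] as keywords
  special_tokens = set(tokenizer2.all_special_tokens)
  list_keywords = [token for token, label in zip(new_tokens, new_labels)
                   if label in list_POS and token not in special_tokens]
  return list_keywords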

In [None]:
list_all_keywords = []
counter = 0
print_every = 1000
start = timeit.default_timer()
for i in df['abstract']:
  list_all_keywords.append(find_keywords(i))
  counter += 1
  if counter % print_every == 0:
    stop = timeit.default_timer()
    print('{} iterations take {:.3f} sec'.format(print_every, stop - start))
    start = timeit.default_timer()

In [None]:
df['keyword_POS'] = list_all_keywords
df = df.dropna(subset = ["keyword_POS"]).reset_index(drop=True)
df['keyword_POS_str'] = df['keyword_POS'].apply( lambda x: ' '.join(x))

In [None]:
df.tail(5)

Unnamed: 0,title,sha,abstract,keyword_POS,keyword_POS_str
2995,Risk factors for hematemesis in Hoima and Buli...,820acf55c4e52411482f6eb44360ffa35288b89a,"INTRODUCTION: On 17 September 2015, Buliisa Di...","[[CLS], introduction, september, buliisa, dist...",[CLS] introduction september buliisa district ...
2996,Proteomic fingerprinting in HIV/HCV co-infecti...,0a01f5cf1c5cdc2711bcef74315dc54a6e143df0,BACKGROUND: Hepatic complications of hepatitis...,"[background, hepatic, complications, hepatitis...",background hepatic complications hepatitis c v...
2997,Immune regulation of the unfolded protein resp...,6a80b22e84d2692545c6f11d7cb4c96602a25c39,Protein folding in the endoplasmic reticulum (...,"[[CLS], protein, folding, endoplasmic, reticul...",[CLS] protein folding endoplasmic reticulum er...
2998,Preprints: An underutilized mechanism to accel...,6eb282b0887ed1a7ab59123919bbadbf9ce6ed55,"In an Essay, Michael Johansson and colleagues ...","[[CLS], essay, michael, johansson, colleagues,...",[CLS] essay michael johansson colleagues advoc...
2999,Time-series analysis for porcine reproductive ...,c0e284c9dd2ca500ae5f10d04d0639d0f8779a16,Industry-driven voluntary disease control prog...,"[industry, -, driven, voluntary, disease, cont...",industry - driven voluntary disease control pr...


# 3. Partition the dataset into training and validation sets

I will use the scikit-learn train_test_split function to partition the data.
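Conceptually, a seeded split shuffles the row indices with a fixed random seed and holds out a fixed fraction. A minimal plain-Python sketch under that assumption; the `split_indices` helper is hypothetical and does not reproduce scikit-learn's exact shuffling, only the idea that the same seed gives the same 90/10 split on every run:

```python
import random

def split_indices(n_rows, test_size=0.1, seed=2021):
    """Hypothetical helper: return (train_indices, dev_indices) for a seeded random split."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)      # same seed -> same order every run
    n_test = int(round(n_rows * test_size))   # size of the held-out set
    return indices[n_test:], indices[:n_test]

train_idx, dev_idx = split_indices(3000)
print(len(train_idx), len(dev_idx))   # 2700 300
```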

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
frame_train, frame_dev = train_test_split(df, test_size=0.1, random_state=2021)

# 4. Tokenizing and collating

- Perform the following collation tasks:
    - create pairs of keywords and labels (summaries)
    - create decoy pairs (as part of the sentence selection training)
    - create segment tokens
    - add start and end tokens
    - add padding tokens up to 1024 tokens
    - shuffle
- Save the tensor file
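The decoy and shuffling steps above have one subtlety: every field of an example (inputs, segments, LM labels, last-token positions) must be reordered with the same permutation, and the multiple-choice label is simply the new position of the true pair. A minimal plain-Python sketch with a hypothetical `shuffle_fields` helper and toy strings in place of token ids:

```python
import random

def shuffle_fields(fields, seed=None):
    """Hypothetical helper. fields: dict of equal-length lists, one entry per choice,
    with the true pair at position 0. Returns the shuffled fields and the mc label."""
    n = len(next(iter(fields.values())))
    order = list(range(n))
    random.Random(seed).shuffle(order)          # one permutation shared by all fields
    shuffled = {name: [values[i] for i in order] for name, values in fields.items()}
    mc_label = order.index(0)                   # new position of the true pair
    return shuffled, mc_label

fields = {
    "text": ["kw <|summarize|> true abstract", "kw <|summarize|> decoy 1",
             "kw <|summarize|> decoy 2", "kw <|summarize|> decoy 3"],
    "has_lm_labels": [True, False, False, False],   # only the true pair is trained on
}
shuffled, mc_label = shuffle_fields(fields, seed=0)
print(mc_label, shuffled["text"][mc_label])
```

This is the same bookkeeping as shuffling [1,0,0,0] alongside the other arrays and taking its argmax.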

In [None]:
# reset the row labels to 0..n-1, because load_words looks rows up by label
frame_dev = frame_dev.reset_index()
frame_dev.head(5)

Unnamed: 0,index,title,sha,abstract,keyword_POS,keyword_POS_str
0,1271,The Natural History of Influenza Infection in ...,5f88b4d6e65a19b84991526dab51391ff4acf08f,Introduction. Medical advances have led to an ...,"[Introduction, Medical, advances, have, led, i...",Introduction Medical advances have led increas...
1,2262,Active Targeted Drug Delivery for Microbes Usi...,2891c813ce4dd156d5f56736db5a10a2065c9167,Although vaccines and antibiotics could kill o...,"[vaccines, antibiotics, kill, inhibit, microbe...",vaccines antibiotics kill inhibit microbes man...
2,2769,Screening of FDA-Approved Drugs for Inhibitors...,1bd2f6497996fc0fccd8dffd7f84846d3d36f964,"Japanese encephalitis virus (JEV), an arthropo...","[Japanese, encephalitis, virus, JEV, arthropod...",Japanese encephalitis virus JEV arthropod-born...
3,2597,Multiple Immunosuppressive Effects of CpG-c41 ...,212b5e2ca9c78f0864f5a8540b36b3c18ad06d27,A growing body of literature suggests that mos...,"[growing, body, literature, suggests, chronic,...",growing body literature suggests chronic autoi...
4,1993,Deletion of Dystrophin In-Frame Exon 5 Leads t...,593333395be3bd94387b4e273cc8ed13b398d5c0,Duchenne and Becker muscular dystrophy severit...,"[Duchenne, Becker, muscular, dystrophy, severi...",Duchenne Becker muscular dystrophy severity de...


## 4.1 Helper functions for collation

The 8 helper functions below:
1. read the dataframe and return the keywords, the true abstract, and 3 randomly drawn distractor abstracts
  - Note that the 0th element in the list is always the correct pair
2. tokenize the inputs and return a list of token ids
3. create the segment tokens (keyword, summary, or padding segment)
4. write the labels for the LM head (only on the summary of the correct pair, i.e. the 0th element of the list; everywhere else is masked with -100 so that the model does not compute cross-entropy loss there)
5. return the position of the last token before padding (i.e. <|endoftext|>), which the MC head reads for the multiple-choice loss
6. return [1,0,0,0], because the correct pair is always the 0th element of the list
7. shuffle the tuple, after which the correct pair can be at any position, and return numpy arrays
8. create tensor objects from the numpy arrays
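Helpers 3 to 6 can be illustrated on hand-made id lists. This is a toy sketch, not the notebook's code: it assumes the special ids asserted in section 1.1, uses made-up "word" ids (11, 12, ...), a 10-position sequence instead of 1024, and only 2 choices instead of 4; the names `segments` and `lm_labels` are hypothetical.

```python
BOS, EOS, PAD, KEYWORD, SUMMARIZE = 50257, 50256, 50258, 50259, 50260
IGNORE = -100   # label value skipped by the cross-entropy loss

def segments(ids, max_len):
    """Keyword segment up to and including <|summarize|>, summary segment up to
    and including <|endoftext|>, padding segment afterwards."""
    seg_a = ids.index(SUMMARIZE) + 1
    end = ids.index(EOS)
    seg_b = end - seg_a + 1
    return [KEYWORD] * seg_a + [SUMMARIZE] * seg_b + [PAD] * (max_len - end - 1)

def lm_labels(ids, segs, is_true_pair):
    """Copy the summary tokens of the true pair; mask everything else."""
    if not is_true_pair:
        return [IGNORE] * len(ids)             # decoys contribute no LM loss
    return [t if s == SUMMARIZE else IGNORE for t, s in zip(ids, segs)]

max_len = 10
true_ids = [BOS, 11, 12, SUMMARIZE, 21, 22, 23, EOS, PAD, PAD]
decoy_ids = [BOS, 11, 12, SUMMARIZE, 31, EOS, PAD, PAD, PAD, PAD]
choices = [true_ids, decoy_ids]

segs = [segments(ids, max_len) for ids in choices]
labels = [lm_labels(ids, s, k == 0) for k, (ids, s) in enumerate(zip(choices, segs))]
last_tokens = [ids.index(EOS) for ids in choices]   # position read by the mc head
mc_label = [1] + [0] * (len(choices) - 1)           # the true pair is always first
print(segs[0])
print(labels[0])
print(last_tokens)   # [7, 5]
```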

In [None]:
def load_words(df, num, with_title=False):
  """Pick row `num` of df and return a tuple of 5 strings:
  the keywords (optionally prefixed by the title), the abstract
  (gold label for summarization), and 3 distractor abstracts
  drawn at random from the other rows of df."""
  # sample distractors (with replacement) from every row except num,
  # so the true abstract is never used as its own decoy
  other_rows = np.delete(np.arange(len(df)), num)
  arr_distract = np.random.choice(other_rows, size=3)
  keyword = df['keyword_POS_str'][num]
  if with_title:
    title = df['title'][num]
    keyword = title + ' ' + keyword
  abstract = df['abstract'][num]
  distract1 = df['abstract'][arr_distract[0]]
  distract2 = df['abstract'][arr_distract[1]]
  distract3 = df['abstract'][arr_distract[2]]

  return (keyword, abstract, distract1, distract2, distract3)

In [None]:
key_batch = load_words(frame_dev,30)

In [None]:
key_batch

('TP53 gene is known “ guardian genome ” plays vital role regulating cell cycle cell proliferation DNA damage repair initiation programmed cell death suppressing tumor growth Non uniform usage synonymous codons specific amino acid translation protein known codon usage bias CUB is unique property genome shows species specific deviation Analysis codon usage bias compositional dynamics coding sequences has contributed understanding molecular mechanism evolution particular gene study complete nucleotide coding sequences TP53 gene different mammalian species were used CUB analysis results showed codon usage patterns TP53 gene different mammalian species has been influenced GC bias GC moderate bias exists codon usage TP53 gene observed nature has favored represented codon CTG leucine amino acid selected ATA codon isoleucine TP53 gene mammalian species course evolution',
 'TP53 gene is known as the “guardian of the genome” as it plays a vital role in regulating cell cycle, cell proliferation,

In [None]:
def write_input_ids(word_batch, max_len=1024):
  """Return a list of 4 token-id lists (true pair + 3 distractors),
  each right-padded with <pad> up to max_len."""
  key, abstract, dis1, dis2, dis3 = word_batch
  list_input_token = []
  for summary in (abstract, dis1, dis2, dis3):
    # the same keywords are paired with the true abstract and with each distractor
    ids = tokenizer.encode('<|startoftext|> ' + key + ' <|summarize|> ' + summary + ' <|endoftext|>',
                           max_length=tokenizer.max_len)
    list_input_token.append(ids)

  if max_len is None:
    max_len = max(len(ids) for ids in list_input_token)
  for ids in list_input_token:
    ids.extend([tokenizer.pad_token_id] * (max_len - len(ids)))
  return list_input_token

In [None]:
def write_token_type_labels(list_input_ids, max_len=1024):
  """Return segment ids: the <|keyword|> id up to and including <|summarize|>,
  the <|summarize|> id for the summary up to and including <|endoftext|>,
  and the <pad> id for the padding."""
  list_segment = []
  for item in list_input_ids:
    if tokenizer.eos_token_id not in item:
      # the sequence was truncated at max_length: force an end token in the last position
      item[-1] = tokenizer.eos_token_id
    num_seg_a = item.index(tokenizer.additional_special_tokens_ids[1]) + 1
    end_index = item.index(tokenizer.eos_token_id)
    num_seg_b = end_index - num_seg_a + 1
    num_pad = max_len - end_index - 1
    segment_ids = ([tokenizer.additional_special_tokens_ids[0]] * num_seg_a
                   + [tokenizer.additional_special_tokens_ids[1]] * num_seg_b
                   + [tokenizer.pad_token_id] * num_pad)
    list_segment.append(segment_ids)
  return list_segment

In [None]:
def write_lm_labels(list_input_ids,list_type_labels):
  list_lm_label = []
  is_true_label = True
  for input_tokens,segments in zip(list_input_ids,list_type_labels):
    if is_true_label:
      is_true_label = False
      temp_list = []
      for token,segment in zip(input_tokens,segments):
        if segment == tokenizer.additional_special_tokens_ids[1]:
          temp_list.append(token)
        else:
          temp_list.append(-100)
      list_lm_label.append(temp_list)
    else:
      temp_list = [-100]*len(input_tokens)
      list_lm_label.append(temp_list)
  return list_lm_label

In [None]:
def write_last_token(list_input_ids):
  list_mc_token = []
  for item in list_input_ids:
    list_mc_token.append(item.index(tokenizer.eos_token_id))
  return list_mc_token

In [None]:
def write_mc_label():
  return [1,0,0,0]

In [None]:
def shuffle_batch(list_input_ids,list_type_labels,list_last_tokens,list_lm_labels,list_mc_labels):
  array_input_token = np.array(list_input_ids)
  array_segment = np.array(list_type_labels)
  array_mc_token = np.array(list_last_tokens)
  array_lm_label = np.array(list_lm_labels)
  array_mc_label = np.array(list_mc_labels)

  randomize = np.arange(4)
  np.random.shuffle(randomize)

  array_input_token = array_input_token[randomize]
  array_segment = array_segment[randomize]
  array_mc_token = array_mc_token[randomize]
  array_lm_label = array_lm_label[randomize]
  array_mc_label = array_mc_label[randomize]

  return (array_input_token,array_segment,array_mc_token,array_lm_label,array_mc_label)

In [None]:
def write_torch_tensor(np_batch):
  torch_input_token = torch.tensor(np_batch[0], dtype=torch.long).unsqueeze(0)
  torch_segment = torch.tensor(np_batch[1],dtype=torch.long).unsqueeze(0)
  torch_mc_token = torch.tensor(np_batch[2],dtype=torch.long).unsqueeze(0)
  torch_lm_label = torch.tensor(np_batch[3],dtype=torch.long).unsqueeze(0)
  torch_mc_label = torch.tensor([np.argmax(np_batch[4])],dtype=torch.long).unsqueeze(0)
  return (torch_input_token,torch_segment,torch_mc_token,torch_lm_label,torch_mc_label)

## 4.2 Write a wrapper function that executes the helper functions in order

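The wrapper executes all 8 helper functions for each row of the dataframe.

The per-row tensors are accumulated into temporary tensors, which are concatenated into the big output tensors every 1000 rows.

This is done for time efficiency: every torch.cat call copies all of its inputs into a new tensor, so repeatedly concatenating onto one very long tensor becomes slow.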
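To make the time argument concrete, here is a rough cost model in plain Python. It is a simplification, assuming that the cost of each torch.cat is the number of rows it copies; the functions `copies_running_cat` and `copies_chunked` are hypothetical. It compares concatenating every row onto one running tensor with first building 1000-row temporary tensors and merging those into the big tensor.

```python
def copies_running_cat(n_rows):
    """Rows copied when each row is concatenated onto one running tensor."""
    total, length = 0, 0
    for _ in range(n_rows):
        if length > 0:
            total += length + 1        # cat copies the running tensor plus the new row
        length += 1
    return total

def copies_chunked(n_rows, chunk_size=1000):
    """Rows copied when rows are first gathered into chunk_size-row temporary
    tensors, and each temporary tensor is then concatenated onto the big one."""
    total, big_len = 0, 0
    for start in range(0, n_rows, chunk_size):
        size = min(chunk_size, n_rows - start)
        total += copies_running_cat(size)   # build the temporary tensor row by row
        if big_len > 0:
            total += big_len + size         # merge the temporary tensor into the big one
        big_len += size
    return total

print(copies_running_cat(3000))    # 4501499
print(copies_chunked(3000, 1000))  # 1506497
```

With 3000 rows the chunked scheme copies roughly a third as many rows, and because the running-tensor cost grows quadratically with the number of rows, the gap widens for larger datasets.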

In [None]:
def execute_all_function(df, chunk_size=1000):
  """Collate every row of df and return the 5 tensors:
  input ids, segment ids, mc token positions, lm labels, mc labels."""
  temp = [[] for _ in range(5)]   # per-row tensors not yet concatenated
  big = [[] for _ in range(5)]    # one concatenated tensor per chunk of rows
  start = timeit.default_timer()
  for num in range(len(df)):
    word_tuple = load_words(df, num)
    if not isinstance(word_tuple[0], str) or not isinstance(word_tuple[1], str):
      continue  # skip rows with missing keywords or abstract

    list_input_ids = write_input_ids(word_tuple)
    list_type_labels = write_token_type_labels(list_input_ids)
    list_lm_labels = write_lm_labels(list_input_ids, list_type_labels)
    list_last_tokens = write_last_token(list_input_ids)
    list_mc_labels = write_mc_label()

    np_tuple = shuffle_batch(list_input_ids, list_type_labels, list_last_tokens, list_lm_labels, list_mc_labels)
    tensor_tuple = write_torch_tensor(np_tuple)
    for k in range(5):
      temp[k].append(tensor_tuple[k])

    if len(temp[0]) == chunk_size:
      # concatenate the temporary tensors once per chunk instead of once per row
      for k in range(5):
        big[k].append(torch.cat(temp[k], 0))
        temp[k] = []
      stop = timeit.default_timer()
      print('iterations ', num, ' takes ', stop - start, ' sec')
      start = timeit.default_timer()

  if temp[0]:  # leftover rows from the last, partial chunk
    for k in range(5):
      big[k].append(torch.cat(temp[k], 0))
  return tuple(torch.cat(big[k], 0) for k in range(5))

In [None]:
tensor_1,tensor_2,tensor_3,tensor_4,tensor_5 = execute_all_function(frame_dev)

iterations  0  takes  0.0665676399999029  sec


1000 iterations take about 25 seconds, which means a 32k-example training set would take roughly 13 to 14 minutes (32 × 25 s = 800 s).



In [None]:
# create a tensor dataset object
tensor_dataset = TensorDataset(tensor_1,tensor_2,tensor_3,tensor_4,tensor_5)

## 4.3 Save the tensor file for later use

In [None]:
# save the tensor object to load later when training
torch.save(tensor_dataset, 'resources/tensor/torch_devFile_1_May07_2020.pt')

## Check your result with print statements

Make sure the labels are all correct and lined up.
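Beyond eyeballing the printout, the alignment can also be checked with assertions. A minimal sketch, assuming the special ids asserted in section 1.1 and that the fields of one example are first converted to plain lists; `check_example` is a hypothetical helper, shown here on toy data:

```python
EOS, PAD, KEYWORD, SUMMARIZE, IGNORE = 50256, 50258, 50259, 50260, -100

def check_example(input_ids, segment_ids, mc_token_ids, lm_labels, mc_label):
    """Raise AssertionError if the fields of one collated example are not lined up."""
    for choice, (ids, segs) in enumerate(zip(input_ids, segment_ids)):
        assert len(ids) == len(segs), 'inputs and segments differ in length'
        assert ids[mc_token_ids[choice]] == EOS, 'mc token must point at <|endoftext|>'
        assert set(segs) <= {KEYWORD, SUMMARIZE, PAD}, 'unexpected segment id'
    # exactly one choice carries LM labels, and it is the one the mc label points to
    labelled = [k for k, labels in enumerate(lm_labels) if any(l != IGNORE for l in labels)]
    assert labelled == [mc_label], 'LM labels and mc label disagree'
    # every LM label is either masked or equal to the input token at that position
    for ids, labels in zip(input_ids, lm_labels):
        assert all(l == IGNORE or l == t for t, l in zip(ids, labels)), 'label mismatch'
    return True

# toy example with 2 choices of length 5; choice 1 is the true pair
inputs = [[1, SUMMARIZE, 7, EOS, PAD], [1, SUMMARIZE, 9, 8, EOS]]
segs = [[KEYWORD, KEYWORD, SUMMARIZE, SUMMARIZE, PAD],
        [KEYWORD, KEYWORD, SUMMARIZE, SUMMARIZE, SUMMARIZE]]
mc_tokens = [3, 4]
labels = [[IGNORE] * 5, [IGNORE, IGNORE, 9, 8, EOS]]
print(check_example(inputs, segs, mc_tokens, labels, mc_label=1))   # True
```

On the notebook's tensors this could be called as `check_example(tensor_1[item].tolist(), tensor_2[item].tolist(), tensor_3[item].tolist(), tensor_4[item].tolist(), int(tensor_5[item]))`.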

In [None]:
tensor_5.numpy()

In [None]:
item = 151
print(tensor_1[item])
print(tensor_2[item])
print(tensor_3[item])
print(tensor_4[item])
print(tensor_5[item])

tensor([[50257,  3486,  2767,  ..., 50258, 50258, 50258],
        [50257,  3486,  2767,  ..., 50258, 50258, 50258],
        [50257,  3486,  2767,  ..., 50258, 50258, 50258],
        [50257,  3486,  2767,  ..., 50258, 50258, 50258]])
tensor([[50259, 50259, 50259,  ..., 50258, 50258, 50258],
        [50259, 50259, 50259,  ..., 50258, 50258, 50258],
        [50259, 50259, 50259,  ..., 50258, 50258, 50258],
        [50259, 50259, 50259,  ..., 50258, 50258, 50258]])
tensor([602, 508, 644, 617])
tensor([[-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]])
tensor([2])


In [None]:
# columns: input ids of choice 1, input ids of choice 2,
# segment ids of choice 2, and LM labels of choice 2 for the selected item
def decode_or_mask(token_id):
  """Decode a single token id, or return 'masked' for the -100 ignore index."""
  return 'masked' if token_id == -100 else tokenizer.decode(token_id)

row_format = '{:>2}{:>10}{:>10}{:>10}{:>10}{:>20}{:>10}{:>20}{:>10}'
print(row_format.format('count','input','decoded input','input','decoded input','input','decoded input','input','decoded input'))
for count, (i, j, k, m) in enumerate(zip(tensor_1[item][1], tensor_1[item][2], tensor_2[item][2], tensor_4[item][2])):
  i, j, k, m = int(i), int(j), int(k), int(m)
  print(row_format.format(count, i, decode_or_mask(i), j, decode_or_mask(j), k, decode_or_mask(k), m, decode_or_mask(m)))

count     inputdecoded input     inputdecoded input               inputdecoded input               inputdecoded input
 0     50257<|startoftext|>     50257<|startoftext|>               50259<|keyword|>                -100    masked
 1      3486        AP      3486        AP               50259<|keyword|>                -100    masked
 2      2767        ET      2767        ET               50259<|keyword|>                -100    masked
 3      1847        AL      1847        AL               50259<|keyword|>                -100    masked
 4        32         A        32         A               50259<|keyword|>                -100    masked
 5        17         2        17         2               50259<|keyword|>                -100    masked
 6      3486        AP      3486        AP               50259<|keyword|>                -100    masked
 7        17         2        17         2               50259<|keyword|>                -100    masked
 8     10812     genes     10812     gen