# DEPENDENCY

In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import json
import torch
import pandas as pd

from torch.utils.data import Dataset
from transformers import BertTokenizer

import config

I0709 05:11:23.177581 140153932433216 file_utils.py:39] PyTorch version 1.5.1+cu101 available.


In [4]:
# Path of the SSQA yes/no dev split, taken from the project-local config module.
dev_path = config.SSQA_DEV
print(dev_path)

/work/2020-IIS-NLU-internship/SSQA/data/SSQA_se_yes_no_benchmark_V0.8/dev_yes_no.json


In [5]:
# Chinese BERT vocabulary; this tokenizer instance is also passed to SSQA_Dataset further down.
PRETRAINED_MODEL_NAME = "bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)

I0709 05:11:24.844888 140153932433216 tokenization_utils_base.py:1254] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /root/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a8671500e518a8c29de8c00


In [7]:
# Load the raw dev split (a list of question instances) and inspect the first one.
# JSON is UTF-8 by spec; pass the encoding explicitly so the Chinese text loads
# regardless of the platform's locale default.
with open(dev_path, encoding="utf-8") as fo:
    dev_set = json.load(fo)

# Fail fast if the schema differs from what SSQA_Dataset reads.
assert {"qid", "qtext", "paragraphs", "supporting_paragraphs_index"} <= set(dev_set[0])

print(list(dev_set[0].keys()))
print(dev_set[0]["qtext"])
print(dev_set[0]["paragraphs"])
print(dev_set[0]["supporting_paragraphs_index"])
print(dev_set[0]["answer"])
len(dev_set)

['qid', 'qtext', 'paragraphs', 'supporting_paragraphs_index', 'answer']
透過參加學校自治組織，可以協助維護校園的秩序與安全。
['身為學校的一份子，我們可以透過參加學校自治組織，協助維護校園的秩序與安全、反映同學的意見與想法、為同學爭取權益等。學校常見的自治組織有：糾察隊、衛生隊、交通隊、學生自治會等，負責推動各項自治活動。', '圖書志工協助整理圖書館的書籍。', '交通隊維持同學上學、放學的行進安全。', '衛生隊檢查校園環境整潔情形。', '學生自治會透過會議，反映同學的意見。', '參加學校自治組織，我們要遵守組織的規範。例如：執行任務時服裝整齊、確實遵守服務時間、保持親切態度、公平對待每位同學，並虛心接受建議，作為全校同學的學習榜樣。', '同學不分性別，都可以加入學校的自治組織，學習做事的方法、增進處理事情的能力、累積服務的經驗，讓我們的學習生活更有意義。', '參加自治組織的同學，付出自己的時間與心力為大家服務，我們要主動配合他們，並適時表達對他們的感謝，有機會，我們也能加入自治組織，為更多同學服務。', '＊小叮嚀：我們可以注意學校自治組織招募成員的訊息，主動報名參加。']
[0]
對


1508

In [16]:
# Sanity check: for this paragraph, joining the BERT subwords reproduces the original text.
"".join(tokenizer.tokenize(dev_set[0]["paragraphs"][-1]))

'＊小叮嚀：我們可以注意學校自治組織招募成員的訊息，主動報名參加。'

In [17]:
class SSQA_Dataset(Dataset):
    """
        SSQA release 0.8, training set is still in developmemt
        usage :  FGC_Dataset(file_path, mode, tokenizer)
        for tokenizer:
            PRETRAINED_MODEL_NAME = "bert-base-chinese"
            tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
        for file_path:
            refer to config
        for mode:
            ["train", "develop", "test"]
    """
    # read, preprocessing
    def __init__(self, data_file_path, mode, tokenizer=None):
        # load raw json
        assert mode in ["train", "develop", "test"]
        self.mode = mode
        with open(data_file_path) as fo:
            self.raw_data = json.load(fo)
        if tokenizer == None:
            self.tokenizer = BertTokenizer.from_pretrained(config.BERT_EMBEDDING)
        else:
            self.tokenizer = tokenizer 
        self.tokenlized_pair = None
        
        # generate raw pairs of q sent s
        self.raw_pair = list()
        for instance in self.raw_data:
            q = instance["qtext"]
            sentences = instance["paragraphs"]
            for idx, sent in enumerate(sentences):
                # check if is supporting evidence
                lab = idx in instance["supporting_paragraphs_index"]
                self.raw_pair.append((q, sent, lab))
        
        # generate tensors 
        self.dat = list()
        for instance in self.raw_pair:
            q, sent, label = instance
            
            if mode is not "test":
                label_tensor = torch.tensor(label)
            else:
                label_tensor = None
            
            # first sentence, use bert tokenizer to cut subwords
            subwords = ["[CLS]"]
            q_tokens = self.tokenizer.tokenize(q)
            subwords.extend(q_tokens)
            subwords.append("[SEP]")
            len_q = len(subwords)
            
            # second sentence
            sent_tokens = self.tokenizer.tokenize(sent)
            subwords.extend(sent_tokens)
            
            # truncate if > BERT_MAX_INPUT_LEN 
            if(len(subwords) > config.BERT_MAX_INPUT_LEN-1):
                subwords = subwords[:config.BERT_MAX_INPUT_LEN-1]
            
            subwords.append("[SEP]")
            len_sent = len(subwords) -len_q
            
            
            # subwords to ids, ids to torch tensor
            ids = self.tokenizer.convert_tokens_to_ids(subwords)
            tokens_tensor = torch.tensor(ids)
            
            # segments_tensor
            segments_tensor = torch.tensor([0] * len_q + [1] * len_sent, dtype=torch.long)
            self.dat.append((tokens_tensor, segments_tensor, label_tensor))
            
        # id to q
        self.id_to_qid = []
        for qidx, instance in enumerate(self.raw_data):
            cur_qid = instance["qid"]
            sentences = instance["paragraphs"]
            for idx, sent in enumerate(sentences):
                self.id_to_qid.append(cur_qid)
            
        return None
    
    # get one data of index idx
    def __getitem__(self, idx):
        return self.dat[idx]
    
    def __len__(self):
        return len(self.dat)

In [18]:
# NOTE(review): this rebinds `dev_set` from the raw JSON list loaded earlier to an
# SSQA_Dataset. Earlier cells that index `dev_set[0]["paragraphs"]` fail if re-run after this one.
dev_set = SSQA_Dataset(dev_path, "develop", tokenizer)

In [20]:
# Map example 0's ids back to subwords to check the [CLS] question [SEP] paragraph [SEP] layout.
"".join(tokenizer.convert_ids_to_tokens(dev_set[0][0]))

'[CLS]透過參加學校自治組織，可以協助維護校園的秩序與安全。[SEP]身為學校的一份子，我們可以透過參加學校自治組織，協助維護校園的秩序與安全、反映同學的意見與想法、為同學爭取權益等。學校常見的自治組織有：糾察隊、衛生隊、交通隊、學生自治會等，負責推動各項自治活動。[SEP]'

In [21]:
# Raw token ids of example 0 (101 = [CLS], 102 = [SEP] in this vocabulary, per the decoded cell above).
dev_set[0][0]

tensor([ 101, 6851, 6882, 1347, 1217, 2119, 3413, 5632, 3780, 5175, 5251, 8024,
        1377,  809, 1295, 1221, 5204, 6362, 3413, 1754, 4638, 4914, 2415, 5645,
        2128, 1059,  511,  102, 6716, 4158, 2119, 3413, 4638,  671,  819, 2094,
        8024, 2769,  947, 1377,  809, 6851, 6882, 1347, 1217, 2119, 3413, 5632,
        3780, 5175, 5251, 8024, 1295, 1221, 5204, 6362, 3413, 1754, 4638, 4914,
        2415, 5645, 2128, 1059,  510, 1353, 3216, 1398, 2119, 4638, 2692, 6210,
        5645, 2682, 3791,  510, 4158, 1398, 2119, 4261, 1357, 3609, 4660, 5023,
         511, 2119, 3413, 2382, 6210, 4638, 5632, 3780, 5175, 5251, 3300, 8038,
        5144, 2175, 7386,  510, 6127, 4495, 7386,  510,  769, 6858, 7386,  510,
        2119, 4495, 5632, 3780, 3298, 5023, 8024, 6511, 6519, 2972, 1240, 1392,
        7517, 5632, 3780, 3833, 1240,  511,  102])

In [22]:
# Example index -> source question id; all paragraphs of one question share a qid.
dev_set.id_to_qid

['PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002765',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002766',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002767',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-002768',
 'PubB-G3a-0303-03-0