# Chinese Question Answering (QA)

In [None]:
# Fetch the tutorial repository; the later cells read requirements.txt, data/ and
# import `utils` from inside it.
# NOTE(review): once cloned, re-running this line will likely report that the
# destination path already exists (git refuses to clone into a non-empty dir).
!git clone https://github.com/GitYCC/bert-minimal-tutorial.git

Cloning into 'bert-minimal-tutorial'...
remote: Enumerating objects: 100, done.[K
remote: Counting objects: 100% (100/100), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 100 (delta 52), reused 56 (delta 17), pack-reused 0[K
Receiving objects: 100% (100/100), 38.75 MiB | 10.36 MiB/s, done.
Resolving deltas: 100% (52/52), done.


In [None]:
# Work from inside the cloned repo so the relative paths used below
# (requirements.txt, data/DRCD_training.json, `from utils import ...`) resolve.
%cd bert-minimal-tutorial

/content/bert-minimal-tutorial


In [None]:
!pip install -q -r requirements.txt

[?25l[K     |█▍                              | 10kB 26.1MB/s eta 0:00:01[K     |██▉                             | 20kB 29.7MB/s eta 0:00:01[K     |████▎                           | 30kB 25.6MB/s eta 0:00:01[K     |█████▊                          | 40kB 19.2MB/s eta 0:00:01[K     |███████▏                        | 51kB 13.1MB/s eta 0:00:01[K     |████████▋                       | 61kB 12.7MB/s eta 0:00:01[K     |██████████                      | 71kB 5.8MB/s eta 0:00:01[K     |███████████▍                    | 81kB 6.3MB/s eta 0:00:01[K     |████████████▉                   | 92kB 6.7MB/s eta 0:00:01[K     |██████████████▎                 | 102kB 7.2MB/s eta 0:00:01[K     |███████████████▊                | 112kB 7.2MB/s eta 0:00:01[K     |█████████████████▏              | 122kB 7.2MB/s eta 0:00:01[K     |██████████████████▋             | 133kB 7.2MB/s eta 0:00:01[K     |████████████████████            | 143kB 7.2MB/s eta 0:00:01[K     |█████████████████████

In [None]:
import os
import json

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertForQuestionAnswering
from tqdm.notebook import tqdm

from utils import RunningAverage, tokenize_and_map

# Pretrained checkpoint name, shared by the tokenizer and the QA model below.
MODEL_NAME = 'bert-base-chinese'
SEED = 1234

# Seed the CPU generator and every CUDA device so runs are repeatable.
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

## Dataloader

In [None]:
queries = []
contents = []
indexes_in_content = []

# JSON is UTF-8 by specification (RFC 8259). Pass the encoding explicitly so the
# read does not depend on the platform's default encoding (not UTF-8 on e.g. Windows).
with open('data/DRCD_training.json', encoding='utf-8') as fr:
    rows = json.load(fr)['data']

# Flatten data -> paragraphs -> qas -> answers into one example per
# (question, answer) pair; a paragraph's context is repeated for each of them.
for row in rows:
    for paragraph in row['paragraphs']:
        content = paragraph['context']
        for qa in paragraph['qas']:
            query = qa['question']
            for answer in qa['answers']:
                answer_start = answer['answer_start']
                answer_text = answer['text']

                start_index = answer_start
                # Inclusive end: end_index is the position of the answer's last character.
                end_index = answer_start + len(answer_text) - 1

                queries.append(query)
                contents.append(content)
                indexes_in_content.append((start_index, end_index))

In [None]:
# Spot-check one parsed example: the inclusive (start, end) span must slice
# exactly the answer text out of its paragraph.
idx = 700
query, content = queries[idx], contents[idx]
start_index, end_index = indexes_in_content[idx]
answer = content[start_index:end_index + 1]

for label, value in (('query:', query), ('content:', content), ('answer:', answer)):
    print(label, value)

query: 誰成功阻止了蜀漢的北伐？
content: 曹魏主要戰爭都是抗衡蜀漢與孫吳的攻擊，在魏帝曹丕去去世後由曹真、曹休、司馬懿及陳群四人輔佐魏帝曹叡，而張郃和滿寵都是一方大將。這些將領守衛著魏國，其中以司馬懿最為卓越，他成功抵禦蜀漢北伐，並討於遼東之戰攻滅叛變的公孫淵。在曹叡死後，同為託孤大臣的曹爽與司馬懿發生權力鬥爭。最後司馬懿在249年發動政變，史稱高平陵之變，曹爽及其黨羽被滅族，魏國朝政為司馬懿父子掌握。其子司馬師、司馬昭相繼掌權，展開外除方鎮內廢魏帝的行動。當時守衛曹魏東方的重鎮壽春發生三次反抗司馬氏的舉兵，分別是王淩、毌丘儉與文欽、諸葛誕等三次叛亂，史稱壽春三叛。除王凌外的叛軍雖然獲得孫吳的援軍，最後仍被司馬氏擊潰。司馬氏專政期間，支持魏帝的將領與大臣有的反對司馬氏事敗，有的自危，于是或被殺害或逃亡至蜀吳二國，而司馬昭在殺害魏帝曹髦後，因為徹底清除異己，他開始準備篡位稱帝。
answer: 司馬懿


In [None]:
class QADataset(Dataset):
    """Extractive-QA dataset producing ``[CLS] query [SEP] content [SEP]`` inputs.

    Each item is ``(input_ids, token_type_ids, attention_mask[, start, end], content_info)``.
    ``start``/``end`` (train mode only) are token positions in the full input
    sequence. ``content_info`` carries the (possibly truncated) content text,
    its tokens, its char->token map and the offset of the first content token,
    so predicted token positions can be mapped back to characters.

    Args:
        tokenizer: tokenizer handed to ``tokenize_and_map`` and used for
            ``convert_tokens_to_ids``.
        queries: question strings.
        contents: context strings, aligned with ``queries``.
        indexes_in_content: ``(start_char, end_char)`` answer spans with an
            inclusive end, one per item; required when ``for_train`` is True.
        max_len: maximum number of tokens per sequence, special tokens included.
        for_train: when True, items also carry start/end labels.
    """

    def __init__(self, tokenizer, queries, contents, indexes_in_content=None, max_len=512, for_train=True):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.for_train = for_train

        self.queries = queries
        self.contents = contents
        self.indexes_in_content = indexes_in_content

    @staticmethod
    def _lower_keep_length(text):
        """Lower-case ``text`` without changing its length.

        ``str.lower`` can expand a single character (e.g. 'İ' becomes two code
        points). That would shift every later character and invalidate the
        character-level answer offsets, so such characters are left unchanged.
        """
        return ''.join(
            low if len(low) == 1 else char
            for char, low in ((c, c.lower()) for c in text)
        )

    def __getitem__(self, idx):
        query = self._lower_keep_length(self.queries[idx])
        content = self._lower_keep_length(self.contents[idx])

        # tokenize_and_map (from utils) appears to return the tokens plus a list with
        # one token index per character (that is how the map is indexed below).
        query_tokens, _ = tokenize_and_map(self.tokenizer, query)
        content_tokens, content_index_map = tokenize_and_map(self.tokenizer, content)

        # Token budget left for the content after [CLS], the query and two [SEP]s.
        cut_index = self.max_len - len(query_tokens) - 3
        if cut_index < len(content_tokens):
            # Cut the text at the first character of the first dropped token.
            cut_text_index = content_index_map.index(cut_index)
            content_tokens = content_tokens[:cut_index]
            content = content[:cut_text_index]
            content_index_map = content_index_map[:cut_text_index]

        processed_tokens = ['[CLS]'] + query_tokens + ['[SEP]'] + content_tokens + ['[SEP]']

        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(processed_tokens))
        token_type_ids = torch.tensor([0] * (2 + len(query_tokens)) + [1] * (1 + len(content_tokens)))
        attention_mask = torch.tensor([1] * len(processed_tokens))

        outputs = (input_ids, token_type_ids, attention_mask)

        # Position of the first content token in the full input sequence.
        offset = 2 + len(query_tokens)
        if self.for_train:
            start_index_in_content, end_index_in_content = self.indexes_in_content[idx]

            if end_index_in_content >= len(content):
                # The answer was (partly) cut off by max_len => no answer: point both
                # labels at [CLS] (position 0). Do not look up -1 in content_index_map:
                # Python resolves it to the LAST content token, producing a bogus
                # answer span at the end of the passage.
                start_index = end_index = 0
            else:
                start_index = offset + content_index_map[start_index_in_content]
                end_index = offset + content_index_map[end_index_in_content]

            outputs += (torch.tensor(start_index), torch.tensor(end_index), )

        content_info = {
            'text': content,
            'tokens': content_tokens,
            'index_map': content_index_map,
            'offset': offset
        }
        outputs += (content_info, )
        return outputs

    def __len__(self):
        return len(self.queries)

    def create_mini_batch(self, samples):
        """Collate items from ``__getitem__`` into a padded batch.

        Returns ``(input_ids, token_type_ids, attention_mask, start_indexes, end_indexes)``
        in train mode and ``(input_ids, token_type_ids, attention_mask, content_infos)``
        otherwise.
        """
        outputs = list(zip(*samples))

        # Right-pad with 0 to the longest sequence in the batch. 0 is the [PAD]
        # embedding index of this checkpoint (padding_idx=0), and a 0 in
        # attention_mask masks the padded positions.
        input_ids = pad_sequence(outputs[0], batch_first=True)
        token_type_ids = pad_sequence(outputs[1], batch_first=True)
        attention_mask = pad_sequence(outputs[2], batch_first=True)

        batch_output = (input_ids, token_type_ids, attention_mask)

        if self.for_train:
            start_indexes = torch.stack(outputs[3])
            end_indexes = torch.stack(outputs[4])
            batch_output += (start_indexes, end_indexes, )
        else:
            content_infos = outputs[3]
            batch_output += (content_infos, )

        return batch_output

In [None]:
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Ids of the [CLS], [SEP] and [PAD] special tokens.
SKIP_TOKEN_IDS = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

dataset = QADataset(tokenizer, queries, contents, indexes_in_content)

CUT_RATIO = 0.9  # fraction of examples used for training; the rest is validation
train_size = int(CUT_RATIO * len(dataset))
valid_size = len(dataset) - train_size
# A dedicated seeded generator makes the split independent of how much global RNG
# state earlier cells consumed. Re-running this cell (e.g. after training) then
# reproduces the same split instead of reshuffling validation examples into training.
# NOTE(review): the split is per question, and several questions share one paragraph,
# so validation paragraphs also appear in training; split by paragraph for a
# stricter estimate.
train_dataset, valid_dataset = random_split(
    dataset, [train_size, valid_size], generator=torch.Generator().manual_seed(SEED)
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=109540.0, style=ProgressStyle(descripti…




In [None]:
batch_size = 8

# Both loaders collate with the parent dataset's create_mini_batch (padding plus
# label stacking); only the training loader shuffles.
_loader_kwargs = {'batch_size': batch_size, 'collate_fn': dataset.create_mini_batch}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(dataset=valid_dataset, **_loader_kwargs)

## Model

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

# Pretrained encoder weights plus a QA head; the load log reports which weights
# were not initialized from the checkpoint.
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, return_dict=True)
model.to(device)

device: cuda


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=624.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411577189.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-chinese a

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

## Train

In [None]:
def train_batch(model, data, optimizer, device):
    """Run one optimisation step on a single training batch.

    Args:
        model: QA model whose output exposes ``.loss`` when labels are given.
        data: ``(input_ids, token_type_ids, attention_mask, start_indexes, end_indexes)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every tensor in ``data`` is moved to.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    input_ids, token_type_ids, attention_mask, start_positions, end_positions = (
        tensor.to(device) for tensor in data
    )

    optimizer.zero_grad()
    batch_loss = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
    ).loss
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


def evaluate(model, valid_loader, device):
    """Score ``model`` on every batch of ``valid_loader`` without updating it.

    A prediction counts as correct only when BOTH the argmax start position and
    the argmax end position equal the labelled ones.

    Returns:
        ``(loss, accuracy)`` as reported by ``RunningAverage.get()`` over the
        per-batch losses and per-example hit flags (presumably running means).
    """
    model.eval()
    loss_meter, acc_meter = RunningAverage(), RunningAverage()

    with torch.no_grad():
        for batch in tqdm(valid_loader, desc='evaluate'):
            input_ids, token_type_ids, attention_mask, start_positions, end_positions = (
                tensor.to(device) for tensor in batch
            )

            result = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions,
            )
            loss_meter.add(result.loss.item())

            start_hit = result.start_logits.argmax(dim=-1) == start_positions
            end_hit = result.end_logits.argmax(dim=-1) == end_positions
            acc_meter.add_all((start_hit & end_hit).cpu().tolist())

    return loss_meter.get(), acc_meter.get()

In [None]:
lr = 0.00001
max_iter = 600
show_per_iter = 10
valid_per_iter = 150
save_per_iter = 300
save_checkpoint_dir = 'models/'
model_prefix = 'cn_qa_'

# Checkpoint directories are named after the latest validation loss, so every
# save step must also be a validation step.
assert save_per_iter % valid_per_iter == 0
# With no batches the `while` loop below would never reach max_iter and spin forever.
assert len(train_loader) > 0, 'train_loader yields no batches'

optimizer = optim.Adam(model.parameters(), lr=lr)

i = 1
is_running = True
train_loss_averager = RunningAverage()
model_paths = []  # saved checkpoint directories, oldest first
while is_running:
    # Cycle over the training set as many times as needed to run max_iter steps.
    for train_data in train_loader:
        train_loss = train_batch(model, train_data, optimizer, device)
        train_loss_averager.add(train_loss)

        if i % show_per_iter == 0:
            print('train [{}]: loss={}'.format(i, train_loss_averager.get()))
            train_loss_averager.flush()

        if i % valid_per_iter == 0:
            valid_loss, valid_acc = evaluate(model, valid_loader, device)
            print(f'valid: loss={valid_loss} acc={valid_acc}')

        if i % save_per_iter == 0:
            # valid_loss was refreshed on this same step (see the assert above);
            # it is kept separate from train_loss so the name cannot pick up a batch loss.
            path = os.path.join(save_checkpoint_dir, model_prefix + f'loss{valid_loss:.5}/')
            print(f'save model at {path}')
            model.save_pretrained(path)
            model_paths.append(path)

        if i == max_iter:
            is_running = False
            break

        i += 1

train [10]: loss=6.1591521263122555
train [20]: loss=5.79662823677063
train [30]: loss=5.454067325592041
train [40]: loss=5.159710788726807
train [50]: loss=4.822084856033325
train [60]: loss=4.4720964431762695
train [70]: loss=4.38559501171112
train [80]: loss=4.0061897993087765
train [90]: loss=3.5085565567016603
train [100]: loss=3.220983386039734
train [110]: loss=3.039277935028076
train [120]: loss=2.549754571914673
train [130]: loss=2.2049063444137573
train [140]: loss=2.6635491609573365
train [150]: loss=2.223186028003693


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=337.0, style=ProgressStyle(description_wid…


valid: loss=2.1453104891126165 acc=0.3040089086859688
train [160]: loss=2.3869534015655516
train [170]: loss=2.2542189598083495
train [180]: loss=2.182021069526672
train [190]: loss=2.2956191062927247
train [200]: loss=1.8180055260658263
train [210]: loss=2.1757318496704103
train [220]: loss=2.0445661425590513
train [230]: loss=1.71410413980484
train [240]: loss=1.6620292484760284
train [250]: loss=1.9851764559745788
train [260]: loss=1.7982043623924255
train [270]: loss=1.8725574493408204
train [280]: loss=1.499091124534607
train [290]: loss=1.577263706922531
train [300]: loss=1.688332235813141


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=337.0, style=ProgressStyle(description_wid…


valid: loss=1.3913321849677258 acc=0.5174461766889383
save model at models/cn_qa_loss1.3913/
train [310]: loss=1.398196029663086
train [320]: loss=1.4431588232517243
train [330]: loss=1.5344918370246887
train [340]: loss=1.5915170431137085
train [350]: loss=1.2457448959350585
train [360]: loss=1.3680573940277099
train [370]: loss=1.5721929967403412
train [380]: loss=1.6213249266147614
train [390]: loss=1.408910322189331
train [400]: loss=1.4728749632835387
train [410]: loss=1.4926735579967498
train [420]: loss=1.5066946148872375
train [430]: loss=1.5251627445220948
train [440]: loss=1.4519009172916413
train [450]: loss=1.1699353814125062


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=337.0, style=ProgressStyle(description_wid…


valid: loss=1.1175130269205535 acc=0.6043058648849294
train [460]: loss=1.4228537619113921
train [470]: loss=1.2027886033058166
train [480]: loss=1.0095838218927384
train [490]: loss=1.2084775805473327
train [500]: loss=1.3297013640403748
train [510]: loss=1.1536094844341278
train [520]: loss=1.744823545217514
train [530]: loss=1.2151461005210877
train [540]: loss=1.2107483804225923
train [550]: loss=0.8021921932697296
train [560]: loss=1.1398630440235138
train [570]: loss=1.1675307154655457
train [580]: loss=1.2327691614627838
train [590]: loss=1.0360288113355636
train [600]: loss=1.3316008359193803


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=337.0, style=ProgressStyle(description_wid…


valid: loss=0.9974118396418739 acc=0.6273199703043801
save model at models/cn_qa_loss0.99741/


## Predict

In [None]:
reload_checkpoint = model_paths[-1]

# Distinct names so the training lists `queries` / `contents` are not clobbered.
# Otherwise re-running an earlier cell after this one would silently use these two examples.
pred_queries = [
    '陸特和漢斯雷頓開創了哪一地區對梵語的學術研究？',
    '「北京皇家祭壇—天壇」在哪一年的時候，正式被列為世界文化遺產?'
]
pred_contents = [
    '在歐洲，梵語的學術研究，由德國學者陸特和漢斯雷頓開創。後來威廉·瓊斯發現印歐語系，也要歸功於對梵語的研究。此外，梵語研究，也對西方文字學及歷史語言學的發展，貢獻不少。1786年2月2日，亞洲協會在加爾各答舉行。會中，威廉·瓊斯發表了下面這段著名的言論：「梵語儘管非常古老，構造卻精妙絕倫：比希臘語還完美，比拉丁語還豐富，精緻之處同時勝過此兩者，但在動詞詞根和語法形式上，又跟此兩者無比相似，不可能是巧合的結果。這三種語言太相似了，使任何同時稽考三者的語文學家都不得不相信三者同出一源，出自一種可能已經消逝的語言。基於相似的原因，儘管缺少同樣有力的證據，我們可以推想哥德語和凱爾特語，雖然混入了迥然不同的語彙，也與梵語有著相同的起源；而古波斯語可能也是這一語系的子裔。」',
    '北京天壇位於北京市東城區，是明清兩朝帝王祭天、祈穀和祈雨的場所。是現存中國古代規模最大、倫理等級最高的祭祀建築群。1961年，天壇被中華人民共和國國務院公布為第一批全國重點文物保護單位之一。1998年，「北京皇家祭壇—天壇」被列為世界文化遺產。北京天壇最初為明永樂十八年仿南京城形制而建的天地壇，嘉靖九年實行四郊分祀制度後，在北郊覓地另建地壇，原天地壇則專事祭天、祈穀和祈雨，並改名為天壇。清代基本沿襲明制，在乾隆年間曾進行過大規模的改擴建，但年門和皇乾殿是明代建築而無改建除外。1900年八國聯軍進攻北京時，甚至還把司令部設在這裡，並在圜丘壇上架設大炮，攻擊正陽門和紫禁城，聯軍們將幾乎所有的陳設和祭器都席捲而去。1912年中華民國成立後，除了中華民國大總統袁世凱在1913年冬至祭天外，天壇不再進行任何祭祀活動。1918年起闢為公園，正式對民眾開放。目前園內古柏蔥鬱，是北京城南的一座大型園林。'
]


def extract_answer(start_logit, end_logit, content_info):
    """Turn one example's start/end logits into an answer substring of its content.

    Args:
        start_logit, end_logit: 1-D per-position scores for a single example.
        content_info: the dict built by ``QADataset.__getitem__`` ('text',
            'index_map' with one token index per character, 'offset').

    Returns:
        The answer text. Returns '' when the best start or end position falls outside
        the content tokens ([CLS], query, [SEP], padding) or when end precedes start.
    """
    offset = content_info['offset']
    index_map = content_info['index_map']
    text = content_info['text']

    # Shift input-sequence positions into content-token positions.
    token_start = start_logit.argmax(dim=-1).item() - offset
    token_end = end_logit.argmax(dim=-1).item() - offset
    if token_start < 0 or token_end < token_start:
        return ''
    if token_start not in index_map or token_end not in index_map:
        return ''

    char_start = index_map.index(token_start)  # first character of the start token
    # One past the LAST character of the end token, so multi-character tokens
    # (numbers, Latin words) are not truncated.
    char_end = len(index_map) - index_map[::-1].index(token_end)
    if char_start >= char_end:
        return ''
    return text[char_start:char_end]


pred_dataset = QADataset(tokenizer, pred_queries, pred_contents, for_train=False)

pred_loader = DataLoader(
    dataset=pred_dataset,
    batch_size=batch_size,
    collate_fn=pred_dataset.create_mini_batch,
)

model = BertForQuestionAnswering.from_pretrained(reload_checkpoint)
model.to(device)
model.eval()  # make inference mode explicit (disables dropout)

answers = []
with torch.no_grad():
    for data in tqdm(pred_loader, desc='predict'):
        input_ids, token_type_ids, attention_mask = [d.to(device) for d in data[:3]]
        content_infos = data[3]

        outputs = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )

        for start_logit, end_logit, content_info in zip(outputs.start_logits,
                                                        outputs.end_logits,
                                                        content_infos):
            answers.append(extract_answer(start_logit, end_logit, content_info))

print('predict result: ')
for answer in answers:
    print('answer:', answer)

HBox(children=(FloatProgress(value=0.0, description='predict', max=1.0, style=ProgressStyle(description_width=…


predict result: 
answer: 德國
answer: 1998年
