<ol>
    <li>瀏覽次數</li>
    <li>不刪除重複的標題</li>
    <li>用逗號隔開，並取每句話前三個詞</li>
</ol>

In [31]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
!cp drive/MyDrive/utils.py .
!cp drive/MyDrive/dataset.py .

In [3]:
!pip install hanlp
!pip install opencc
!pip install pandas
!pip install wikipedia-api
!pip install pandarallel
!pip install wikipedia

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting hanlp
  Downloading hanlp-2.1.0b50-py3-none-any.whl (651 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m651.1/651.1 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting hanlp-common>=0.0.19 (from hanlp)
  Downloading hanlp_common-0.0.19.tar.gz (28 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting hanlp-downloader (from hanlp)
  Downloading hanlp_downloader-0.0.25.tar.gz (13 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting hanlp-trie>=0.0.4 (from hanlp)
  Downloading hanlp_trie-0.0.5.tar.gz (6.7 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pynvml (from hanlp)
  Downloading pynvml-11.5.0-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece>=0.1.91 (from hanlp)
  Download

In [4]:
from opencc import OpenCC

# built-in libs
import json
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# 3rd party libs
import hanlp
import opencc
import pandas as pd
import wikipedia
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel

# our own libs
from utils import load_json

#from pytorchtools import EarlyStopping
import matplotlib.pyplot as plt

#from sentence_transformers import SentenceTransformer,util
from torch.utils.data import DataLoader

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
wikipedia.set_lang("zh")

device = 'cuda'

In [5]:

@dataclass
class Claim:
    r"""
    陳述句: str
    """
    data: str

@dataclass
class AnnotationID:
    r"""
    註解編號
    """
    id: int

@dataclass
class EvidenceID:
    r"""
    證據句編號
    """
    id: int

@dataclass
class PageTitle:
    r"""
    維基百科文章標題
    """
    title: str

@dataclass
class SentenceID:
    r"""
    維基百科文章句子編號
    """
    id: int

@dataclass
class Evidence:
    r"""
    陳述證據句
    來自維基百科文章中的句子
    包含哪一篇文章、文章句子編號
    """
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

#### 第一階段：資料前處理: 確定要搜尋的wiki文章

In [6]:
save_path = "/content/drive/MyDrive/NCKU-AICUP2023-baseline/"

In [54]:
#TRAIN_DATA = load_json(save_path + "data/public_train_0316.jsonl")
TRAIN_DATA = load_json(save_path + "data/merged_public_train.jsonl")#5/22有給第二批訓練集，加入第一批訓練集=merge檔 claim+evidence+label
#TEST_DATA = load_json(save_path + "data/public_test_data.jsonl")
TEST_DATA = load_json(save_path + "data/merged_public_private_test.jsonl") #只有claim 要用來預測答案上傳用的
CONVERTER_T2S = opencc.OpenCC("t2s")
CONVERTER_S2T = opencc.OpenCC("s2t")

In [55]:
def do_st_corrections(text: str) -> str:
    r"""
    透過OpenCC將繁體中文轉換成簡體中文，再轉換回繁體中文

    Args:
        text (str): 要轉換的文字
    """
    simplified = CONVERTER_T2S.convert(text)
    return CONVERTER_S2T.convert(simplified)

In [56]:
def get_nps_hanlp(
    predictor: Pipeline,
    d: Dict[str, Union[int, Claim, Evidence]],
) -> List[str]:
    r"""
    透過HanLP的parser取得claim的名詞片語

    Args:
        predictor (Pipeline): HanLP的parser
        
        d (Dict[str, Union[int, Claim, Evidence]]): claim的資料

    Returns:
        List[str]: claim的名詞片語

    """
    claim = d["claim"]
    tree = predictor(claim)["con"]
    nps = [
        do_st_corrections("".join(subtree.leaves()))
        #for subtree in tree.subtrees(lambda t: t.label() == "NP" or t.label() == "V"or t.label() == "A")
        for subtree in tree.subtrees(lambda t: t.label() == "NP" or t.label() == "A")
    ]

    return nps

In [57]:
def calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    r"""
    計算精確度

    Args:
        data (List[Dict[str, Union[int, Claim, Evidence]]]): claim的資料

        predictions (pd.Series): 預測的結果

    """
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision
    print(f"Precision: {precision / count}")


def calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    r"""
    計算召回率

    Args:
        data (List[Dict[str, Union[int, Claim, Evidence]]]): claim的資料

        predictions (pd.Series): 預測的結果

    """
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        recall += len(hits) / len(gt_pages)
        count += 1

    print(f"Recall: {recall / count}")

In [58]:
def get_pred_pages(series_data: pd.Series) -> Set[Dict[int, str]]:

    r"""
    傳入敘述句，使用Wikipedia API查詢名詞片語
    過濾出需要查詢的名詞片語

    Args:
        series_data (pd.Series): claim的資料

    Returns:
        Set[Dict[int, str]]: 需要查詢的名詞片語

    """
    results = []
    tmp_muji = []
    # wiki_page: its index showned in claim
    mapping = {}
    claim = series_data["claim"]
    nps = series_data["hanlp_results"]
    first_wiki_term = []

    for i, np in enumerate(nps):
        # Simplified Traditional Chinese Correction
        # 使用Wikipedia API搜尋名詞片語
        # 使用OpenCC將簡體中文轉換成繁體中文
        # 轨道 (力学) -> 軌道 (力學)
        wiki_search_results = [
            do_st_corrections(w) for w in wikipedia.search(np)
        ]

        # Remove the wiki page's description in brackets
        # 移除wiki搜尋結果中，括號內的描述
        # 軌道 (力學) -> 軌道
        wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
        # wiki_set = wiki_search_results
        wiki_df = pd.DataFrame({
            "wiki_set": wiki_set, # 將去除括號後的文章標題作為類別
            "wiki_results": wiki_search_results # 原始文章的標題
        })

        # Elements in wiki_set --> index
        # Extracting only the first element is one way to avoid extracting
        # too many of the similar wiki pages
        # 透過wiki_set欄位合併wiki_df，並取得第一個wiki_set的值
        # 擷取第一個wiki_set的值，是為了避免擷取太多相似的wiki頁面
        grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
        # grouped_df = wiki_df
        candidates = grouped_df["wiki_results"].tolist()
        # muji refers to wiki_set
        muji = grouped_df.index.tolist()

        for prefix, term in zip(muji, candidates):
            # 若prefix(類別)已經出現過，則不再重複擷取
            if prefix not in tmp_muji:
                matched = False

                # Take at least one term from the first noun phrase
                # 如果是第一個名詞片語，至少要取一個名詞片語
                if i == 0:
                    first_wiki_term.append(term)

                # Walrus operator :=
                # https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions
                # Through these filters, we are trying to figure out if the term
                # is within the claim
                # 透過取代.、空白、-，來判斷term是否在claim中
                if (((new_term := term) in claim) or
                    ((new_term := term.replace("·", "")) in claim) or
                    ((new_term := term.split(" ")[0]) in claim) or
                    ((new_term := term.replace("-", " ")) in claim)):
                    matched = True

                # term未出現在claim中，但term中有.，則以.為分隔符號，判斷term是否在claim中
                elif "·" in term:
                    splitted = term.split("·")
                    for split in splitted:
                        if (new_term := split) in claim:
                            matched = True
                            break

                # 發現term在claim中，進行後處理
                if matched:
                    # post-processing
                    # 將term中的 '-' -> '', ' ' -> '_'
                    term = term.replace(" ", "_")
                    term = term.replace("-", "")
                    results.append(term)
                    # 將文章中與claim相符的部分，以term為key，claim中相符的部分為value(句子中相符合名詞的index)
                    mapping[term] = claim.find(new_term)
                    # 將term加入tmp_muji中，避免重複擷取
                    tmp_muji.append(new_term)

    # 5 is a hyperparameter
    # 若取得的term超過5個，則取mapping中value最小的5個term（最早出現的5個term）
    if len(results) > 6:
        assert -1 not in mapping.values()
        results = sorted(mapping, key=mapping.get)[:6]
    elif len(results) < 1:
        # 至少回傳一個term
        results = first_wiki_term

    return set(results)

In [None]:
'''
s = re.sub(r"\s\(\S+\)", "", "軌道，eqweqw，，www (力學)")
s = re.split(r"[，]", s)
s.remove("")
print(s)
'''

['軌道', 'eqweqw', 'www']


In [59]:
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 6,
) -> None:
    
    r"""
    將預測的結果存成jsonl檔

    Args:
        data (List[Dict[str, Union[int, Claim, Evidence]]]): claim的資料

        predictions (pd.Series): 預測的結果

        mode (str, optional): 資料集的模式. 預設為"train"

        num_pred_doc (int, optional): 預測的數量. 預設值為5
    
    """
    with open(
        save_path + f"data/{mode}_doc{num_pred_doc}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d["predicted_pages"] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

##### hanlp Constituency Parsing Tree Predictor
用來切割出陳述句中的名詞片語

In [60]:
predictor = (hanlp.pipeline().append(
    hanlp.load("FINE_ELECTRA_SMALL_ZH"),
    output_key="tok",
).append(
    hanlp.load("CTB9_CON_ELECTRA_SMALL"),
    output_key="con",
    input_key="tok",
))

Downloading https://file.hankcs.com/hanlp/tok/fine_electra_small_20220615_231803.zip to /root/.hanlp/tok/fine_electra_small_20220615_231803.zip
Decompressing /root/.hanlp/tok/fine_electra_small_20220615_231803.zip to /root/.hanlp/tok
Downloading https://file.hankcs.com/hanlp/utils/char_table_20210602_202632.json.zip to /root/.hanlp/utils/char_table_20210602_202632.json.zip
Decompressing /root/.hanlp/utils/char_table_20210602_202632.json.zip to /root/.hanlp/utils
Downloading https://file.hankcs.com/hanlp/transformers/electra_zh_small_20210706_125427.zip to /root/.hanlp/transformers/electra_zh_small_20210706_125427.zip
Decompressing /root/.hanlp/transformers/electra_zh_small_20210706_125427.zip to /root/.hanlp/transformers
Downloading https://file.hankcs.com/hanlp/constituency/ctb9_con_electra_small_20220215_230116.zip to /root/.hanlp/constituency/ctb9_con_electra_small_20220215_230116.zip
Decompressing /root/.hanlp/constituency/ctb9_con_electra_small_20220215_230116.zip to /root/.hanlp/

使用hanlp Constituency Parsing Tree進行句子轉換並找出名詞

In [61]:
hanlp_file = save_path + "data/hanlp_con_results-emily.pkl" #自己的實驗要換檔名 取train claim名詞

# 若hanlp_file存在，則讀取檔案
if Path(hanlp_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    # 若hanlp_file不存在，則重新計算
    # 使用hanlp套件對claim進行斷詞，並取得名詞片語
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

In [62]:
doc_path = save_path + "data/train_doc6.jsonl" #自己的實驗要換檔名 用train claim名詞取得wiki page title
# 若doc_path存在，則讀取檔案
if Path(doc_path).exists():
    with open(doc_path, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ])
else:
    # 若doc_path不存在，則重新計算
    train_df = pd.DataFrame(TRAIN_DATA)
    # 將hanlp_results (斷詞結果)加入train_df
    train_df.loc[:, "hanlp_results"] = hanlp_results
    # 將train_df傳入get_pred_pages函式，取得需要查詢的名詞片語
    predicted_results = train_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TRAIN_DATA, predicted_results, mode="train")

In [63]:
predicted_results


0              {軌道, 天衛三, 仲夏夜_(羅文專輯), 仲夏夜之夢_(消歧義), 天王星, 磁層}
1                            {太平洋, 翼展, 北冰洋, 牠, 南太平洋, 信天翁科}
2             {飛_(消歧義), 主唱, 吉他手, 詹雯婷, F._R._大衛, Faye_Disc}
3                  {2001年, 24小時, 24_(電視劇), 機場, 小時, 香港國際機場}
4        {歷史, 中央部屬高校, 中華人民共和國, 高等學校, 校長, 黨委書記和校長列入中央管理的高校}
                               ...                        
11615                    {過敏, 伯明翰大學, 伯明翰, 研究, 過敏性鼻炎, 諾貝爾獎}
11616                           {2003年, 新界, 翠豐臺, 南豐紗廠, 荃灣}
11617         {大安區_(臺北市), 仁愛路_(臺北市), 信義區_(臺北市), 中正區_(臺北市)}
11618                       {臺南北極殿, 古蹟, 臺南市, 國定古蹟, 文化部, 南}
11619                               {王家驥, 校長, 副校長, 國立臺東大學}
Length: 11620, dtype: object

In [64]:
hanlp_test_file = f"data/hanlp_test_results-emily.pkl"
if Path(hanlp_test_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

KeyboardInterrupt: ignored

In [None]:
test_doc_path = f"data/test_doc6.jsonl"
if Path(test_doc_path).exists():
    with open(test_doc_path, "r", encoding="utf8") as f:
        test_results = pd.Series(
            [set(json.loads(line)["predicted_pages"]) for line in f])
else:
    test_df = pd.DataFrame(TEST_DATA)
    test_df.loc[:, "hanlp_results"] = hanlp_results
    test_results = test_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TEST_DATA, test_results, mode="test")

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=904), Label(value='0 / 904'))), HB…

#### 第二階段：語句檢索: 找出與搜尋關鍵字相關的wiki文章
- 先找出相關文章，並抽取相關的句子作為證據句
- 將陳述句與句子丟入bert模型進行分類：證據句 or 非證據句

In [65]:
# built-in libs
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset, Dataset

# local libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

In [66]:
SEED = 42

#TRAIN_DATA = load_json(save_path + "data/public_train_0316.jsonl")
TRAIN_DATA = load_json(save_path + "data/merged_public_train.jsonl") #合併第二批訓練資料 claim + evidence + label
#TEST_DATA = load_json(save_path + "data/public_test_data.jsonl") #只有claim 沒label
TEST_DATA = load_json(save_path + "data/merged_public_private_test.jsonl") #只有claim 沒label
#DOC_DATA = load_json(save_path + "data/train_doc5.jsonl")#baseline原本的檔案 有萃取出predict page
#DOC_DATA = load_json(save_path + "data/train_doc5sent5_merge.jsonl") #claim + predict page + evidence + label
DOC_DATA = load_json(save_path + "data/train_doc6.jsonl") #claim + predict page + evidence + label

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

#_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
_y = [LABEL2ID[data["label"]] for data in DOC_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=_y,
)

In [67]:
len(TRAIN_GT)

9296

In [None]:
'''
x = "城市規劃是城市建設及管理的依據 ， 位於城市管理之規劃"
sentences = re.split(r"[\s]\n(?=[0-9])", x)
sentences
'''

['城市規劃是城市建設及管理的依據 ， 位於城市管理之規劃']

In [68]:
wiki_pages = jsonl_dir_to_df(save_path + "data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in /content/drive/MyDrive/NCKU-AICUP2023-baseline/data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=118776), Label(value='0 / 118776')…

Transform to id to evidence_map mapping


In [69]:
mapping['亞伯拉罕諸教']

{'0': '亞伯拉罕諸教 ， 又稱亞伯拉罕宗教 、 亞伯拉罕一神諸教 、 天啓宗教 、 天啓諸教 、 閃族一神諸教 、 閃米特一神諸教 、 閃米特諸教等 ， 指世界主要的三個有共同源頭的一神教 ： 基督宗教 （ 包括天主教 、 基督新教與東正教 ） 、 伊斯蘭教與猶太教 。',
 '1': '如此稱呼 ， 皆因這三個宗教均給予聖經舊約中的亞伯拉罕 （ 阿拉伯語譯作易卜拉欣 ） 崇高的地位 ， 且均發源於西亞沙漠地區 ， 來源於閃米特人的原始宗教 。',
 '2': '廣義的沙漠宗教或閃米特宗教還包括當地或其族群中曾經存在的其他多神宗教 ， 儘管現在通常直接用來指這三種一神教 。',
 '3': ''}

In [70]:
def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0

    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0

        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

    return 0.0, 0.0

def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0

        claim = instance["claim"]

        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

In [71]:
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 6,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    # convert 2d array to 1d array
    probs = probs.reshape(-1)
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]

        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)

    if save_name is not None:
        # write doc7_sent5 file
        with open(save_path + f"data/{save_name}", "w") as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim]["predicted_evidence"].tolist()
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

In [72]:
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())

    return np.array(probs)


In [73]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]

        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten
    

In [74]:
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # 支持
    for i in range(len(df)):
        # 若label為NOT ENOUGH INFO則跳過
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]

        # 一個 '陳述句' 可能有多個 '證據組'
        for evidence_set in evidence_sets:
            sents = []

            # 每個 '證據組'可能有多個 '證據'
            for evidence in evidence_set:
                # evidence[2] 是文章的標題
                page = evidence[2].replace(" ", "_")
                # 唯一例外
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] 文章中的第幾句為證據，通常是一個數字，但mapping需要以str做key
                sent_idx = str(evidence[3])
                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            # 陳述句
            claims.append(claim)
            # 證據句
            sentences.append(whole_evidence)
            # 支持
            labels.append(1)

    # 反對
    for i in range(len(df)):
        # 若label為NOT ENOUGH INFO則跳過
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        # 陳述句
        claim = df["claim"].iloc[i]

        # 證據句
        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in df["evidence"][i]
                            for evidence in evidences])
        
        # 要搜尋的wiki頁面
        predicted_pages = df["predicted_pages"][i]

        # 每個要頁面
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    # sent_idx: wiki頁面中的第幾句
                    # page: wiki頁面
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print("{} is not in our Wiki db.".format(page))
                continue

            # 每個wiki頁面中的每句
            for pair in page_sent_id_pairs:
                # 如果這句已經在證據中，則跳過
                if pair in evidence_set:
                    continue
                text = mapping[page][pair[1]]
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                # 以0.05的機率加入反對的例子，避免太多反對的例子（資料不平衡）
                if text != "" and np.random.rand(1) <= negative_ratio:
                    claims.append(claim)
                    sentences.append(text)
                    labels.append(0)

    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})


def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # 反對
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue

        # 陳述句
        claim = df["claim"].iloc[i]

        # 要搜尋的wiki頁面標題
        predicted_pages = df["predicted_pages"][i]

        # 每個要查詢頁面標題
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                # k: wiki頁面中的每句index與句子
                # page: wiki頁面標題
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print("{} is not in our Wiki db.".format(page))
                continue

            # 每個wiki頁面中的每句
            for page_name, sentence_id in page_sent_id_pairs:
                # 取得wiki頁面中的每句
                text = mapping[page][sentence_id]
                # 若該句不為空，則加入
                if text != "":
                    claims.append(claim)
                    sentences.append(text)

                    # 若為dev set，則加入evidence，若為test set，則不加入
                    if not is_testset:
                        evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

設定超參數

In [75]:
#@title  { display-mode: "form" }

# LERT
#MODEL_NAME = "hfl/chinese-electra-base-generator"  #@param {type:"string"}
#MODEL_NAME = "hfl/chinese-bert-wwm"
MODEL_NAME = "hfl/chinese-lert-base"
# EPOCH
#NUM_EPOCHS = 3  #@param {type:"integer"}
NUM_EPOCHS = 4 #跑10次太久 應該4次就夠
# 學習率
#LR = 2e-4  #@param {type:"number"}
LR = 1e-4
# 訓練 BATCH_SIZE
TRAIN_BATCH_SIZE = 64  #@param {type:"integer"}
# 測試 BATCH_SIZE
TEST_BATCH_SIZE = 256  #@param {type:"integer"}


NEGATIVE_RATIO = 0.03  #@param {type:"number"}


VALIDATION_STEP = 50  #@param {type:"integer"}

# 選擇前幾句作為證據句
TOP_N = 5  #@param {type:"integer"}

實驗設定

In [76]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}"
LOG_DIR = save_path + "logs/" + EXP_DIR
CKPT_DIR = save_path + "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)


In [77]:
# 根據mapping將訓練集的claim和wiki的句子配對
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)

counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    7568
1    7027
Name: label, dtype: int64


訓練模型

In [78]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

Downloading (…)okenizer_config.json:   0%|          | 0.00/19.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

創建模型、優化器等

In [79]:
device = 'cuda'
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# model = SentenceTransformer('distiluse-base-multilingual-cased-v2')
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Downloading pytorch_model.bin:   0%|          | 0.00/412M [00:00<?, ?B/s]

開始訓練

In [None]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0
first_all_train_loss = []
first_all_val_loss = []

for epoch in range(NUM_EPOCHS):
    train_loss = 0
    val_loss = 0
    
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)
        train_loss += loss.item()

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(val_results)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(model, CKPT_DIR, current_steps)
        
    first_all_train_loss.append(train_loss)
    first_all_val_loss.append(val_loss)

print("Finished training!")

# 顯示訓練過程中的loss
plt.plot(first_all_train_loss, label='train_loss')
plt.plot(first_all_val_loss, label='val_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
print(first_all_train_loss, first_all_val_loss)

  0%|          | 0/916 [00:00<?, ?it/s]

Start validation


  0%|          | 0/336 [00:00<?, ?it/s]

{'F1 score': 0.36310764207373974, 'Precision': 0.24260930605695474, 'Recall': 0.7214199759326113}
Start validation


  0%|          | 0/336 [00:00<?, ?it/s]

驗證

In [None]:
ckpt_name = "model.50.pt"  #@param {type:"string"}
#model = load_model(model, ckpt_name, CKPT_DIR)
model = load_model(model, ckpt_name, "/content/drive/MyDrive/NCKU-AICUP2023-baseline/checkpoints/sent_retrieval/e4_bs64_0.0002_neg0.03_top5/")
print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
    save_name=f"train_doc5sent{TOP_N}.jsonl",
)
print(f"Training scores => {train_results}")

print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
    save_name=f"dev_doc5sent{TOP_N}.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/1347 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.3824456297781342, 'Precision': 0.2574968648106549, 'Recall': 0.7429646350639578}
Start validation


  0%|          | 0/336 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.3681352284322209, 'Precision': 0.24573806658643718, 'Recall': 0.733453670276775}


In [None]:
dev_evidences

Unnamed: 0,claim,text,evidence,predicted_evidence,prob
0,煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一...,北洋水師是大清帝國建立的現代化海軍 ， 1880年代籌建 ， 1888年12月17日在山東省...,"[[20367, None, None, None]]","[北洋水師, 0]",0.225842
1,煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一...,北洋水師是清朝新式海軍四支艦隊中規模最大 、 投資最巨者 ， 規模一度號稱爲世界第八 、 亞...,"[[20367, None, None, None]]","[北洋水師, 1]",0.084339
2,煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一...,但北洋大臣兼直隸總督李鴻章與北洋水師提督丁汝昌對現代海戰的殘酷性缺乏認識 ， 和李鴻章的政敵...,"[[20367, None, None, None]]","[北洋水師, 2]",0.057702
3,煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一...,學生是在受到國家或當地政府認可之教育機構 （ 如學校 、 學院 ） 學習或進修者 ， 並且該...,"[[20367, None, None, None]]","[學生, 0]",0.051152
4,煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一...,廣義而言 ， 在研究機構或工作單位 （ 如醫院 、 企業等地 ） 學習者有時也會自稱學生 ，...,"[[20367, None, None, None]]","[學生, 1]",0.006466
...,...,...,...,...,...
85939,生物一開始只分爲動物和植物兩界，而後來沒有細胞核的細菌被獨立爲一界，稱作細菌界，真菌也因爲沒...,1874年 ， 恩斯特 · 海克爾將動物界分爲多細胞後生動物 （ 動物的異名 ） 和原生動物...,"[[[4272, 4386, 界_(生物), 1], [4272, 4386, 界_(生物)...","[動物, 19]",0.692265
85940,生物一開始只分爲動物和植物兩界，而後來沒有細胞核的細菌被獨立爲一界，稱作細菌界，真菌也因爲沒...,在現代 ， 動物的生物學分類依賴分子系統發生學等先進分析技術 ， 能夠有效地證明動物分類單元...,"[[[4272, 4386, 界_(生物), 1], [4272, 4386, 界_(生物)...","[動物, 20]",0.035337
85941,生物一開始只分爲動物和植物兩界，而後來沒有細胞核的細菌被獨立爲一界，稱作細菌界，真菌也因爲沒...,在人類發展的過程中 ， 其他動物 （ 肉 、 卵和奶 ） 一直是人類重要的食物來源 ， 動物...,"[[[4272, 4386, 界_(生物), 1], [4272, 4386, 界_(生物)...","[動物, 23]",0.063437
85942,生物一開始只分爲動物和植物兩界，而後來沒有細胞核的細菌被獨立爲一界，稱作細菌界，真菌也因爲沒...,某些動物被馴化爲家禽 、 家畜或寵物 。,"[[[4272, 4386, 界_(生物), 1], [4272, 4386, 界_(生物)...","[動物, 24]",0.326030


預測test資料中的wiki page中的句子哪個是證據句

In [None]:
test_data = load_json(save_path + "data/test_doc6.jsonl") #包含由上面模型預測出的evidence

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer)
test_dataloader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name=f"merge_test_doc5sent{TOP_N}.jsonl",
)

Start predicting the test data


  0%|          | 0/1291 [00:00<?, ?it/s]

#### 第三部分：陳述句驗證
- 將陳述句與證據句丟入bert模型進行分類：支持（Supports）、、

In [7]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm

import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

In [8]:
save_path = "/content/drive/MyDrive/NCKU-AICUP2023-baseline/"

CONVERTER_T2S = opencc.OpenCC("t2s")
CONVERTER_S2T = opencc.OpenCC("s2t")

In [45]:
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json(save_path + "data/train_doc5sent5.jsonl") #上面產出的檔
DEV_DATA = load_json(save_path + "data/dev_doc5sent5.jsonl")
#TRAIN_PKL_FILE = Path(save_path + "data/train_doc5sent5.pkl")
#DEV_PKL_FILE = Path(save_path + "data/dev_doc5sent5.pkl")

#TRAIN_DATA = load_json(save_path + "data/train_doc5sent6.jsonl")
#DEV_DATA = load_json(save_path + "data/dev_doc5sent6.jsonl")
#TRAIN_PKL_FILE = Path(save_path + "data/train_doc5sent6.pkl")
#DEV_PKL_FILE = Path(save_path + "data/dev_doc5sent6.pkl")

TRAIN_PKL_FILE = Path(save_path + "data/train_doc5sent5_0601.pkl")
DEV_PKL_FILE = Path(save_path + "data/dev_doc5sent5_0601.pkl")

In [None]:
DEV_DATA[0]

{'id': 20847,
 'label': 'NOT ENOUGH INFO',
 'claim': '煙臺海軍學校的前身是北洋水師槍炮官所建立負責收要學習駕駛的學生的水師學堂，與八國聯軍時躲過一劫的天津水師學堂並稱。',
 'evidence': [[20367, None, None, None]],
 'predicted_pages': ['北洋水師', '學生', '煙臺海軍學校', '水師', '八國聯軍', '駕駛'],
 'predicted_evidence': [['煙臺海軍學校', 0],
  ['北洋水師', 0],
  ['北洋水師', 1],
  ['煙臺海軍學校', 1],
  ['北洋水師', 2]]}

載入資料集（與第二章相同）

In [21]:
wiki_pages = jsonl_dir_to_df(save_path + "data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
del wiki_pages

Reading and concatenating jsonl files in /content/drive/MyDrive/NCKU-AICUP2023-baseline/data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=296938), Label(value='0 / 296938')…

Transform to id to evidence_map mapping


In [46]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]

        # In case there are less than topk evidence sentences
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad
        concat_claim_evidence = " [SEP] ".join([*claim, *evidence])

        concat = self.tokenizer(
            concat_claim_evidence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        label = LABEL2ID[item["label"]] if "label" in item else -1
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}

        if "label" in item:
            concat_ten["labels"] = torch.tensor(label)

        return concat_ten

驗證函式

In [47]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())

    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

預測函式

In [48]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

In [49]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)

    print(f"Extracting evidence_list for the {mode} mode ...")
    if mode == "eval":
        # extract evidence
        df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
            mapping.get(evi_id, {}).get(str(evi_idx), "")
            for evi_id, evi_idx in x  # for each evidence list
        ][:topk] if isinstance(x, list) else [])
        print(df["evidence_list"][:5])
    else:
        # extract evidence
        df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
            " ".join([  # join evidence
                mapping.get(evi_id, {}).get(str(evi_idx), "")
                for _, _, evi_id, evi_idx in evi_list
            ]) if isinstance(evi_list, list) else ""
            for evi_list in x  # for each evidence list
        ][:1] if isinstance(x, list) else [])

    return df

超參數

In [44]:
#@title  { display-mode: "form" }

MODEL_NAME = "hfl/chinese-electra-180g-base-discriminator"
#MODEL_NAME = "hfl/chinese-electra-base-generator"  #@param {type:"string"}
#MODEL_NAME = "hfl/chinese-bert-wwm"
#MODEL_NAME = "hfl/chinese-lert-base"
#MODEL_NAME = "hfl/chinese-roberta-wwm-ext-large" #colab會爆掉
#TRAIN_BATCH_SIZE = 16  #@param {type:"integer"}
#TEST_BATCH_SIZE = 16  #@param {type:"integer"}
TRAIN_BATCH_SIZE = 32  
TEST_BATCH_SIZE = 32
SEED = 42  #@param {type:"integer"}
#LR = 6e-5  #@param {type:"number"}
#LR = 5e-4 # electra-base-generator
#LR = 5e-4 # chinese bert wwm
#LR = 7e-5 #discriminator
LR = 6e-5

#NUM_EPOCHS = 5  #chinese bert wwm
NUM_EPOCHS = 10
MAX_SEQ_LEN = 256  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
#VALIDATION_STEP = 25
VALIDATION_STEP = 50  #@param {type:"integer"}


In [14]:
OUTPUT_FILENAME = save_path + "submission-1.jsonl"

EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_top{EVIDENCE_TOPK}"
LOG_DIR = save_path + "logs/" + EXP_DIR
CKPT_DIR = save_path + "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

連接陳述句與證據句

In [39]:
if not TRAIN_PKL_FILE.exists():
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=4)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=4)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

In [50]:
train_df

Unnamed: 0,id,label,claim,evidence,predicted_pages,predicted_evidence,evidence_list
0,11393,supports,肖恩·克里斯坦因強姦人類與教唆強姦判刑，但在監獄待了兩年就獲得自由。,"[[[10409, 9397, 肖恩·克里斯坦_(市長), 7], [10409, 9397...","[自由, 肖恩·克里斯坦_(市長), 監獄, 肖恩·梅, 人類, 肖恩·康納利]","[[肖恩·克里斯坦_(市長), 8], [肖恩·克里斯坦_(市長), 7], [肖恩·克里斯...",[他在2004年的也受父親和兄長蘭德爾 （ Randall ） 牽連 ， 被引渡回新西蘭受審...
1,20366,NOT ENOUGH INFO,臺北市立內湖高級中學的水槍大戰是爲了讓畢業生釋放壓力，來迎接接下來的大考挑戰。,"[[[19913, None, None, None]]]","[中學, 畢業生, 臺北, 內湖, 壓力_(醫學), 臺北市立內湖高級中學]","[[壓力_(醫學), 28], [臺北市立內湖高級中學, 45], [臺北市立內湖高級中學,...","[, [PAD], [PAD], [PAD], [PAD]]"
2,5216,refutes,大波士頓是以波士頓領銜的區域，屬於人煙罕見處。,"[[[4375, 4476, 波士頓, 2]]]","[波士頓, 區域, 大波士頓, 處_(佛教)]","[[波士頓, 2], [波士頓, 0], [大波士頓, 0], [波士頓, 1], [波士頓...",[2016年的人口普查結果顯示 ， 以波士頓領銜的大波士頓擁有480萬人口 ， 乃全美第十大...
3,12640,refutes,浙江紹興府書香世家出身的張岱是漢代的歷史人物。,"[[[11568, 10319, 張岱_(明朝), 0]]]","[張岱_(明朝), 紹興府]","[[張岱_(明朝), 0], [張岱_(明朝), 1], [張岱_(明朝), 2], [張岱...",[張岱 ， 字宗子 、 石公 ， 號陶庵 、 蝶庵 ， 浙江紹興府山陰縣人 ， 明末清初作...
4,8699,supports,2010年臺北國際花卉博覽會是臺灣第一個正式獲得國際園藝家協會及國際展覽局認證授權舉辦的A2...,"[[[7918, 7443, 2010年臺北國際花卉博覽會, 0]]]","[國際園藝博覽會, 臺灣, 2010年臺北國際花卉博覽會, 2010年]","[[2010年臺北國際花卉博覽會, 0], [國際園藝博覽會, 0], [2010年臺北國際...",[2010臺北國際花卉博覽會 ， 簡稱臺北花博 、 臺北國際花博 ， 2010年11月6日至...
...,...,...,...,...,...,...,...
9291,5004,supports,2013年6月前Mac OS以真核生物來命名版本名。,"[[[3773, 3923, MacOS, 64], [3773, 3923, 動物, 0]]]","[生物, 真核生物, 2013年6月]","[[真核生物, 3], [生物, 10], [真核生物, 0], [生物, 4], [真核生...",[以前版本的 macOS 以大型貓科動物命名 ， 例如 Mac OS X v 10.8 被稱...
9292,9696,supports,第二次世界大戰過後，中華民國政府接收臺灣並設省份，設置負責南海諸島的接收工作的行政長官公署。,"[[[8745, 8084, 臺灣省行政長官公署, 6]]]","[臺灣, 世界大戰, 政府, 中華民國政府, 中華民國, 第二次世界大戰]","[[第二次世界大戰, 0], [中華民國, 1], [第二次世界大戰, 14], [第二次世...",[第二次世界大戰後 ， 中華民國國民政府從日本手中接管與統治臺灣 、 並設爲省份 ， 但並非...
9293,6114,refutes,粵菜是廣東菜，它不包含南洋香料。,"[[[4865, 4908, 粵菜, 10]]]","[粵菜, 菜]","[[粵菜, 0], [粵菜, 1], [粵菜, 10], [粵菜, 11], [粵菜, 7]]","[近代廣東得益於貿易經濟 ， 還有大量外省移民與南洋香料 ， 進一步提高了食物的豐富度 。,..."
9294,11658,refutes,位於奧地利的格林岑斯總人口1307人，是奧地利東部的一個市鎮。,"[[[10740, 9670, 格林岑斯, 0]]]","[格林岑斯, 奧地利, 東部, 人口, 人, 市鎮]","[[格林岑斯, 1], [格林岑斯, 0], [奧地利, 2], [東部, 6], [奧地利...","[格林岑斯 （ 德語 ： ） 是奧地利蒂羅爾州因斯布魯克蘭縣的一個市鎮 。, [PAD], ..."


In [None]:
#TRAIN_BATCH_SIZE = 45  
#TEST_BATCH_SIZE = 45

In [51]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

In [None]:
#train_df

Unnamed: 0,id,label,claim,evidence,predicted_pages,predicted_evidence,evidence_list
0,661,supports,亞伯拉罕諸教爲基督宗教、伊斯蘭教與猶太教的統稱。,"[[[992, 1028, 亞伯拉罕諸教, 0]]]","[亞伯拉罕, 斯賓塞·亞伯拉罕, F·莫瑞·亞伯拉罕, 亞伯拉罕諸教, 亞伯拉罕·林肯總統圖...","[[亞伯拉罕, 10], [亞伯拉罕諸教, 0], [亞伯拉罕, 0], [斯賓塞·亞伯拉罕...",[亞伯拉罕諸教 ， 又稱亞伯拉罕宗教 、 亞伯拉罕一神諸教 、 天啓宗教 、 天啓諸教 、 ...
1,2368,supports,朱塞佩·威爾第的處女作是《奧貝爾託》，而他寫的第一部喜劇是《一日國王》。,"[[[3147, 3324, 朱塞佩·威爾第, 10]]]","[朱塞佩·威爾第, 處女作, 朱塞佩·馬志尼, 朱塞佩·加里波第, 朱塞佩·西諾波利]","[[朱塞佩·威爾第, 10], [朱塞佩·威爾第, 0], [處女作, 0], [朱塞佩·馬...",[他的處女作是 《 奧貝爾託 》 （ Oberto ） ， 此後他又寫了 《 一日國王 》 ...
2,6866,NOT ENOUGH INFO,中華民國臺南市的主要道路中華東路的年紀是29，屬於臺1線的一部份。,"[[[5724, None, None, None]]]","[臺南市, 中華東路_(臺南市), 道路, 中華民國, 臺1線]","[[中華東路_(臺南市), 0], [中華東路_(臺南市), 2], [臺1線, 0], [...",[]
3,7962,supports,屬於俄羅斯少數民族的楚瓦什人是典型的歐亞混合人種 ，使用楚瓦什語這種以斯拉夫字母爲基礎的拼音...,"[[[7259, 6901, 楚瓦什人, 0], [7259, 6901, 楚瓦什人, 3]...","[人種, 俄羅斯, 楚瓦什人, 人, 少數民族]","[[楚瓦什人, 3], [楚瓦什人, 0], [楚瓦什人, 4], [楚瓦什人, 5], [...",[楚瓦什人 （ чӑваш ） ， 俄羅斯少數民族 。 屬典型的歐亞混合人種 。 使用楚瓦什...
4,8332,refutes,山西人的母親河汾河上游具太原市十分遙遠。,"[[[7481, 7090, 汾河, 3]]]","[河, 上游, 母親, 人, 汾河]","[[汾河, 0], [汾河, 6], [汾河, 1], [汾河, 8], [汾河, 3]]",[汾河在太原境內縱貫北南 ， 全長一百公里 ， 佔到整個汾河的七分之一 。]
...,...,...,...,...,...,...,...
3148,1993,refutes,白話字在由WMF所協作的閩南語維基百科中貢獻率爲零。,"[[[1788, 1928, 白話字, 2]]]","[WMF, 閩南語, 閩南語維基百科, 白話字, 白話]","[[白話字, 2], [閩南語維基百科, 12], [閩南語維基百科, 8], [閩南語維基...",[目前 ， 由維基媒體基金會所協作的閩南語維基百科便是以白話字做爲知識傳遞的文字媒介 ， 爲...
3149,8770,refutes,在聖彼得堡地鐵的一個車站海濱站開通於1700年，位於瓦西里島。,"[[[7731, 7296, 海濱站_(聖彼得堡地鐵), 9]]]","[聖彼得堡地鐵, 地鐵, 海濱站_(聖彼得堡地鐵), 聖彼得堡, 車站]","[[海濱站_(聖彼得堡地鐵), 0], [海濱站_(聖彼得堡地鐵), 4], [聖彼得堡地鐵...",[Category : 1979年啓用的鐵路車站]
3150,7417,refutes,位在俄羅斯的聖彼得堡有超過100萬種的動物，是俄羅斯人口前三大的大城市。,"[[[6705, 6449, 聖彼得堡, 3]]]","[動物, 俄羅斯人口, 俄羅斯, 人口, 聖彼得堡]","[[聖彼得堡, 3], [動物, 2], [聖彼得堡, 0], [聖彼得堡, 19], [俄...",[全市人口約520萬 ， 是俄羅斯人口第二大城市 、 以及世界上的最北端的居民超過100萬人...
3151,164,NOT ENOUGH INFO,颱風屬於一種食物。,"[[[217, None, None, None]]]","[食物, 颱風]","[[颱風, 1], [颱風, 6], [颱風, 0], [颱風, 8], [颱風, 2]]",[]


In [52]:
device = 'cuda'
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABEL2ID),
)
model.to(device)

#要先執行原本的model 找到checkpoint 再執行這四行
#model.load_state_dict(torch.load(save_path + "checkpoints/claim_verification/e10_bs32_7e-05_top5/val_loss_ 3.19_val_acc=0.5228_model.2000.pt"))
#model.to(device)
#LR = 8e-5
#NUM_EPOCHS = 5

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at hfl/chinese-electra-180g-base-discriminator and are newly initial

In [53]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            val_results = run_evaluation(model, eval_dataloader, device)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            save_checkpoint(
                model,
                CKPT_DIR,
                current_steps,
                mark=f"val_loss={val_results['val_loss']:.4f}, val_acc={val_results['val_acc']:.4f}",
            )

print("Finished training!")

  0%|          | 0/2910 [00:00<?, ?it/s]

Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.1360149710145715
val_acc: 0.43029259896729777
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.5049758414699608
val_acc: 0.4294320137693632
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.7350002543566978
val_acc: 0.41910499139414803
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.072584488620497
val_acc: 0.43029259896729777
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.2642198098848945
val_acc: 0.4294320137693632
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.182759745480263
val_acc: 0.4268502581755594
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.4181904253894335
val_acc: 0.43115318416523235
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.4520824053516126
val_acc: 0.4419104991394148
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.4259598549098182
val_acc: 0.4621342512908778
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.504790622894078
val_acc: 0.4629948364888124
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.6944513582203484
val_acc: 0.42986230636833045
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.6383689054071087
val_acc: 0.4625645438898451
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.799901972078297
val_acc: 0.4380378657487091
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7269804804292446
val_acc: 0.4617039586919105
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7055834613434255
val_acc: 0.45051635111876076
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8403383134162588
val_acc: 0.46858864027538727
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7415844332681942
val_acc: 0.4810671256454389
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.0008350055511683
val_acc: 0.4780550774526678
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.034804592393849
val_acc: 0.459552495697074
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.01606529706145
val_acc: 0.4677280550774527
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.135978390092719
val_acc: 0.45266781411359724
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.885288011537839
val_acc: 0.4836488812392427
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.9575106052503193
val_acc: 0.46944922547332185
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.2409800536011995
val_acc: 0.4827882960413081
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.196132086727717
val_acc: 0.4759036144578313
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.2067588975984758
val_acc: 0.47203098106712565
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.1782342963022727
val_acc: 0.47030981067125643
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.355478970971826
val_acc: 0.459552495697074
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.0184973266026747
val_acc: 0.4814974182444062
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.3333196313413853
val_acc: 0.49010327022375216
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.509904002490109
val_acc: 0.46858864027538727
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.4506302304463845
val_acc: 0.47676419965576594
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.4933105298917586
val_acc: 0.45998278829604133
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.3752484043983566
val_acc: 0.4866609294320138
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.347457768165902
val_acc: 0.47332185886402756
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.4138448711943954
val_acc: 0.4823580034423408
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.586080159226509
val_acc: 0.4784853700516351
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.812247681291136
val_acc: 0.4582616179001721
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.6100501132338016
val_acc: 0.4763339070567986
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.546699687226178
val_acc: 0.46858864027538727
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.825158471930517
val_acc: 0.45395869191049915
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.6870799456557184
val_acc: 0.4750430292598967
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.72894522588547
val_acc: 0.47332185886402756
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.5905756754417943
val_acc: 0.4823580034423408


KeyboardInterrupt: ignored

第四部分：製作上傳檔案

In [None]:
TEST_DATA = load_json(save_path + "data/merge_test_doc5sent5.jsonl")
TEST_PKL_FILE = Path(save_path + "data/merge_test_doc5sent5.pkl")

if not TEST_PKL_FILE.exists():
    test_df = join_with_topk_evidence(
        pd.DataFrame(TEST_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    test_df.to_pickle(TEST_PKL_FILE, protocol=4)
else:
    with open(TEST_PKL_FILE, "rb") as f:
        test_df = pickle.load(f)

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

預測

In [None]:
ckpt_name = "val_acc=0.4303_model.300.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
predicted_label = run_predict(model, test_dataloader, device)

Predicting:   0%|          | 0/31 [00:00<?, ?it/s]

寫入檔案

In [None]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    OUTPUT_FILENAME,
    orient="records",
    lines=True,
    force_ascii=False,
)

In [None]:
OUTPUT_FILENAME

'/content/drive/MyDrive/NCKU-AICUP2023-baseline/submission-1.jsonl'