In [1]:
%pip install -r requirements.txt

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Note: you may need to restart the kernel to use updated packages.


In [7]:
# built-in libs
import json
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# 3rd party libs
import hanlp
import opencc
import pandas as pd
import wikipedia
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel

# our own libs
from utils import load_json

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
wikipedia.set_lang("zh")

In [8]:
# PART 1. Document retrieval

# Step 1:使用Constituency Parser 找出 claim 中的的 Noun Phrases(NPs)
# Step 2:從Ｗikipedia API 中取出中取出 NP相符合的頁面名稱
# Step 3:保留出現在句子index最靠前的五篇文章作為相關文章
# Prepare the environment and import all library we need

In [9]:
# # 讀檔案
# import jsonlines

# file_path = 'data/merged.jsonl'

# with jsonlines.open(file_path, 'r') as file:
#     for i, item in enumerate(file.iter()):
#         print(item)
#         if i == 4: 
#             break

In [10]:
TRAIN_DATA = load_json("data/merge_train.jsonl")
TEST_DATA = load_json("data/merge_test_data.jsonl")
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")

In [11]:
@dataclass
class Claim:
    data: str

@dataclass
class AnnotationID:
    id: int

@dataclass
class EvidenceID:
    id: int

@dataclass
class PageTitle:
    title: str

@dataclass
class SentenceID:
    id: int

@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

In [12]:
# 將繁體中文轉成簡體中文
def do_st_corrections(text: str) -> str:
    simplified = CONVERTER_T2S.convert(text)
    return CONVERTER_S2T.convert(simplified)

In [13]:
!pip install hanlp_restful



In [14]:
# 以此提取句子中的名詞
def get_nps_hanlp(
    #pedictor 為HanLP的套件預測器，用來預測給定文本的語法樹，排序句子中詞的優先程度
    predictor: Pipeline,
    d: Dict[str, Union[int, Claim, Evidence]],
) -> List[str]:
    claim = d["claim"]
    tree = predictor(claim)["con"]
    nps = [
        #將claim宣稱的資料從繁體轉為簡體
        do_st_corrections("".join(subtree.leaves()))
        for subtree in tree.subtrees(lambda t: t.label() == "NP")
    ]
    #print(nps)
    return nps

In [15]:
#優化提取函式
# def get_nps_hanlp(
#     predictor: Pipeline,
#     d: Dict[str, Union[int, Claim, Evidence]],
# ) -> Generator[str, None, None]:
#     claim = d["claim"]
#     tree = predictor(claim)["con"]

#     leaves = []

#     for subtree in tree.subtrees(lambda t: t.label() == "NP"):
#         if not leaves:
#             leaves = list(subtree.leaves())
#         else:
#             leaves.extend(subtree.leaves())

#     nps = do_st_corrections("".join(leaves)) if leaves else None

#     yield nps


In [16]:
# 主要要看招回率是否足夠高
def calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision
    print(f"Precision: {precision / count}")


def calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        recall += len(hits) / len(gt_pages)
        count += 1

    print(f"Recall: {recall / count}")

In [17]:
# 預設檢索的文件數量最多為五份。根據您的目標，可以調整這個num_pred_doc數量。
# 以jsonl格式保存數據。
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 5,
) -> None:
    with open(
        f"data/{mode}_doc{num_pred_doc}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d["predicted_pages"] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

In [18]:
# 主要文件檢索功能

In [19]:
def get_pred_pages(series_data: pd.Series) -> Set[Dict[int, str]]:
    import wikipedia
    import pandas as pd
    from tqdm import tqdm
    import opencc
    import re
    wikipedia.set_lang('zh')
    CONVERTER_T2S = opencc.OpenCC("t2s.json")
    CONVERTER_S2T = opencc.OpenCC("s2t.json")
    def do_st_corrections(text: str) -> str:
        simplified = CONVERTER_T2S.convert(text)
        return CONVERTER_S2T.convert(simplified)
    
    results = []
    tmp_muji = []
    # wiki_page: its index showned in claim
    mapping = {}
    claim = series_data["claim"]
    nps = series_data["hanlp_results"]
    first_wiki_term = []

    for i, np in enumerate(nps):
        # Simplified Traditional Chinese Correction
        wiki_search_results = [
            do_st_corrections(w) for w in wikipedia.search(np)
        ]

        # Remove the wiki page's description in brackets
        wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
        wiki_df = pd.DataFrame({
            "wiki_set": wiki_set,
            "wiki_results": wiki_search_results
        })

        # Elements in wiki_set --> index
        # Extracting only the first element is one way to avoid extracting
        # too many of the similar wiki pages
        grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
        candidates = grouped_df["wiki_results"].tolist()
        # muji refers to wiki_set
        muji = grouped_df.index.tolist()

        for prefix, term in zip(muji, candidates):
            if prefix not in tmp_muji:
                matched = False

                # Take at least one term from the first noun phrase
                if i == 0:
                    first_wiki_term.append(term)

                # Walrus operator :=
                # https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions
                # Through these filters, we are trying to figure out if the term
                # is within the claim
                if (((new_term := term) in claim) or
                    ((new_term := term.replace("·", "")) in claim) or
                    ((new_term := term.split(" ")[0]) in claim) or
                    ((new_term := term.replace("-", " ")) in claim)):
                    matched = True

                elif "·" in term:
                    splitted = term.split("·")
                    for split in splitted:
                        if (new_term := split) in claim:
                            matched = True
                            break

                if matched:
                    # post-processing
                    term = term.replace(" ", "_")
                    term = term.replace("-", "")
                    results.append(term)
                    mapping[term] = claim.find(new_term)
                    tmp_muji.append(new_term)

    # 5 is a hyperparameter
    if len(results) > 5:
        assert -1 not in mapping.values()
        results = sorted(mapping, key=mapping.get)[:5]
    elif len(results) < 1:
        results = first_wiki_term

    return set(results)

In [20]:
# import wikipedia
# import pandas as pd
# from tqdm import tqdm
# import opencc
# import re
# from typing import List, Dict, Set

# def get_pred_pages(series_data: pd.Series) -> Set[Dict[int, str]]:
#     wikipedia.set_lang('zh')
#     CONVERTER_T2S = opencc.OpenCC("t2s.json")
#     CONVERTER_S2T = opencc.OpenCC("s2t.json")

#     def do_st_corrections(text: str) -> str:
#         simplified = CONVERTER_T2S.convert(text)
#         return CONVERTER_S2T.convert(simplified)

#     def get_filtered_term(term: str, claim: str) -> str:
#         new_term = term
#         if (new_term := term) in claim or \
#                 (new_term := term.replace("·", "")) in claim or \
#                 (new_term := term.split(" ")[0]) in claim or \
#                 (new_term := term.replace("-", " ")) in claim:
#             return new_term
#         elif "·" in term:
#             splitted = term.split("·")
#             for split in splitted:
#                 if (new_term := split) in claim:
#                     return new_term
#         return ""

#     results = []
#     tmp_muji = []
#     mapping = {}
#     claim = series_data["claim"]
#     nps = series_data["hanlp_results"]
#     first_wiki_term = []

#     for i, np in enumerate(nps):
#         wiki_search_results = [do_st_corrections(w) for w in wikipedia.search(np)]
#         wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
#         wiki_df = pd.DataFrame({"wiki_set": wiki_set, "wiki_results": wiki_search_results})
#         grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
#         candidates = grouped_df["wiki_results"].tolist()
#         muji = grouped_df.index.tolist()

#         for prefix, term in zip(muji, candidates):
#             if prefix not in tmp_muji:
#                 matched = False
#                 new_term = get_filtered_term(term, claim)
#                 if new_term:
#                     matched = True
#                 if matched:
#                     term = term.replace(" ", "_").replace("-", "")
#                     results.append(term)
#                     mapping[term] = claim.find(new_term)
#                     tmp_muji.append(new_term)

#     if len(results) > 5:
#         assert -1 not in mapping.values()
#         results = sorted(mapping, key=mapping.get)[:5]
#     elif len(results) < 1:
#         results = first_wiki_term

#     return set(results)

In [21]:
# 第一步：從HanLP的結構解析樹中獲取名詞片語
# 設置HanLP預測器（1分鐘）
# HanLP：面向生產環境的前沿多語種自然語言處理技術
predictor = (hanlp.pipeline().append(
    #原先模型FINE_ELECTRA_SMALL_ZH
    #https://hanlp.hankcs.com/docs/api/hanlp/pretrained/tok.html
    hanlp.load("MSR_TOK_ELECTRA_BASE_CRF"),
    output_key="tok",
).append(
    hanlp.load("CTB9_CON_ELECTRA_SMALL"),
    output_key="con",
    input_key="tok",
))
print('加載完成')

                                             

加載完成


In [22]:
hanlp_file = f"data/hanlp_con_results.pkl"
# 如果hanlp_file存在的話，將會直接開啟原本存在的檔案
if Path(hanlp_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

In [23]:
doc_path = f"data/train_doc5.jsonl"
# 同上
if Path(doc_path).exists():
    with open(doc_path, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ])
else:
    train_df = pd.DataFrame(TRAIN_DATA)
    train_df.loc[:, "hanlp_results"] = hanlp_results
    predicted_results = train_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TRAIN_DATA, predicted_results, mode="train")

In [24]:
# 計算 precision 和 recall
calculate_precision(TRAIN_DATA, predicted_results)
calculate_recall(TRAIN_DATA, predicted_results)

Precision: 0.246152922818402
Recall: 0.8314037930620573


In [25]:
# 確保兩者長度相同
print(len(TRAIN_DATA))
print(len(predicted_results))

11647
11647


In [26]:
hanlp_test_file = f"data/hanlp_con_test_results.pkl"
# 同上
if Path(hanlp_test_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

In [27]:
test_doc_path = f"data/test_doc5.jsonl"
# 同上
if Path(test_doc_path).exists():
    with open(test_doc_path, "r", encoding="utf8") as f:
        test_results = pd.Series(
            [set(json.loads(line)["predicted_pages"]) for line in f])
else:
    test_df = pd.DataFrame(TEST_DATA)
    test_df.loc[:, "hanlp_results"] = hanlp_results
    test_results = test_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TEST_DATA, test_results, mode="test")

In [28]:
# PART 2. Sentence retrieval

# Step1:從前一步驟找出的相關文章，再進一步抽取出相關句子作為證據句
# Step2:將claim與句子丟入BERT，訓練它做二分類，判斷「證據句/非證據句」
# built-in libs
# import 需要的libs
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset, Dataset

# local libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

In [29]:
# 設置全域變數 Global variable

SEED = 42

TRAIN_DATA = load_json("data/merge_train.jsonl")
TEST_DATA = load_json("data/merge_test_data.jsonl")
DOC_DATA = load_json("data/train_doc5.jsonl")

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
# 0.2
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=_y,
)

In [30]:
# 預先載入wikipedia的資料庫
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=118776), Label(value='0 / 118776')…

Transform to id to evidence_map mapping


In [31]:
# 輔助函式
# 計算句子檢索的精確度 precision

def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py
    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.
    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0
    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()
        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0
        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0
    return 0.0, 0.0

In [32]:
# 輔助函式
# 計算句子檢索的招回率 recall
def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()
        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

In [33]:
# 計算句子檢索的分數
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 5,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]

        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)

    if save_name is not None:
        # write doc7_sent5 file
        with open(f"data/{save_name}", "w", encoding='utf-8') as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim]["predicted_evidence"].tolist()
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

In [34]:
# 推論腳本以獲取候選證據句子的概率
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())

    return np.array(probs)

In [35]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]

        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten

In [36]:
# def pair_with_wiki_sentences(
#     mapping: Dict[str, Dict[int, str]],
#     df: pd.DataFrame,
#     negative_ratio: float,
# ) -> pd.DataFrame:
#     """Only for creating train sentences."""
#     data = []

#     # positive
#     for i in range(len(df)):
#         if df["label"].iloc[i] != "NOT ENOUGH INFO":
#             claim = df["claim"].iloc[i]
#             evidence_sets = df["evidence"].iloc[i]
#             for evidence_set in evidence_sets:
#                 sents = [mapping[evidence[2].replace(" ", "_")][str(evidence[3])] for evidence in evidence_set if evidence[2].replace(" ", "_") != "臺灣海峽危機#第二次臺灣海峽危機（1958）"]
#                 whole_evidence = " ".join(sents)
#                 data.append((claim, whole_evidence, 1))

#     # negative
#     for i in range(len(df)):
#         if df["label"].iloc[i] != "NOT ENOUGH INFO":
#             claim = df["claim"].iloc[i]
#             evidence_set = {(evidence[2], evidence[3]) for evidences in df["evidence"][i] for evidence in evidences}
#             predicted_pages = df["predicted_pages"][i]
#             for page in predicted_pages:
#                 page = page.replace(" ", "_")
#                 try:
#                     page_sent_id_pairs = [(page, sent_idx) for sent_idx in mapping[page].keys()]
#                 except KeyError:
#                     continue

#                 for pair in page_sent_id_pairs:
#                     if pair not in evidence_set:
#                         text = mapping[page][pair[1]]
#                         if text != "" and np.random.rand(1) <= negative_ratio:
#                             data.append((claim, text, 0))

#     return pd.DataFrame(data, columns=["claim", "text", "label"])

# def pair_with_wiki_sentences_eval(
#     mapping: Dict[str, Dict[int, str]],
#     df: pd.DataFrame,
#     is_testset: bool = False,
# ) -> pd.DataFrame:
#     """Only for creating dev and test sentences."""
#     data = []

#     # negative
#     for i in range(len(df)):
#         claim = df["claim"].iloc[i]
#         predicted_pages = df["predicted_pages"][i]
#         for page in predicted_pages:
#             page = page.replace(" ", "_")
#             try:
#                 page_sent_id_pairs = [(page, k) for k in mapping[page]]
#             except KeyError:
#                 continue

#             for page_name, sentence_id in page_sent_id_pairs:
#                 text = mapping[page][sentence_id]
#                 if text != "":
#                     if not is_testset:
#                         evidence = df.loc[i, "evidence"]
#                         data.append((claim, text, evidence))
#                     else:
#                         data.append((claim, text))

#     columns = ["claim", "text"]
#     if not is_testset:
#         columns.append("evidence")

#     return pd.DataFrame(data, columns=columns)

In [37]:
# 原function
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # positive
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]
        for evidence_set in evidence_sets:
            sents = []
            for evidence in evidence_set:
                # evidence[2] is the page title
                page = evidence[2].replace(" ", "_")
                # the only page with weird name
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] is in form of int however, mapping requires str
                sent_idx = str(evidence[3])
                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            claims.append(claim)
            sentences.append(whole_evidence)
            labels.append(1)

    # negative
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue
        claim = df["claim"].iloc[i]

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in df["evidence"][i]
                            for evidence in evidences])
        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for pair in page_sent_id_pairs:
                if pair in evidence_set:
                    continue
                text = mapping[page][pair[1]]
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                if text != "" and np.random.rand(1) <= negative_ratio:
                    claims.append(claim)
                    sentences.append(text)
                    labels.append(0)

    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})

def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue
        claim = df["claim"].iloc[i]

        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for page_name, sentence_id in page_sent_id_pairs:
                text = mapping[page][sentence_id]
                if text != "":
                    claims.append(claim)
                    sentences.append(text)
                    if not is_testset:
                        evidence.append(df.loc[i, "evidence"])
                    #evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

In [38]:
torch.cuda.empty_cache()

In [71]:
#第一個模型 
# Sentence Retrieval

#@title  { display-mode: "form" }

MODEL_NAME = "sijunhe/nezha-base-wwm"  #@param {type:"string"}
NUM_EPOCHS = 1  #@param {type:"integer"}
LR = 6e-5  #@param {type:"number"}
TRAIN_BATCH_SIZE = 32  #@param {type:"integer"}
TEST_BATCH_SIZE = 256  #@param {type:"integer"}
# Negative_ration 正負例子的比例
NEGATIVE_RATIO = 0.031  #@param {type:"number"}
VALIDATION_STEP = 150  #@param {type:"integer"}
# TOP_N 候選證據句子數量，指定每個輸入樣本預測的候選證據句子數量。
# 較大的值表示保留更多的候選證據句子，但可能增加計算成本。
TOP_N = 5  #@param {type:"integer"}

In [72]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_0601_1148{MODEL_NAME}{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

In [73]:
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)
# print(TRAIN_GT)
counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    7217
1    7079
Name: label, dtype: int64


In [74]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, user_faset=True)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(
    val_dataset, 
    batch_size=TEST_BATCH_SIZE,
)

In [75]:
del train_df

In [76]:
# 開啟GPU運算
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

In [78]:
# 訓練跑進度條
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(val_results)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(model, CKPT_DIR, current_steps)

print("Finished training!")

  0%|          | 0/447 [00:00<?, ?it/s]

Start validation


  0%|          | 0/307 [00:00<?, ?it/s]

{'F1 score': 0.3724839309687392, 'Precision': 0.24785042991401243, 'Recall': 0.749250149970006}
Start validation


  0%|          | 0/307 [00:00<?, ?it/s]

{'F1 score': 0.37388905218895857, 'Precision': 0.2485702859428067, 'Recall': 0.7540491901619676}
Finished training!


In [79]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 23172), started 6 days, 14:58:13 ago. (Use '!kill 23172' to kill it.)

In [80]:
# 驗證訓練結果
ckpt_name = "model.300.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
    save_name=f"train_doc5sent{TOP_N}.jsonl",
)
print(f"Training scores => {train_results}")

print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
    save_name=f"dev_doc5sent{TOP_N}.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/1230 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.3888014888167915, 'Precision': 0.26218237294919955, 'Recall': 0.7519507803121248}
Start validation


  0%|          | 0/307 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.37388905218895857, 'Precision': 0.2485702859428067, 'Recall': 0.7540491901619676}


In [81]:
test_data = load_json("data/test_doc5.jsonl")

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer)
test_dataloader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name=f"test_doc5sent{TOP_N}.jsonl",
)

Start predicting the test data


  0%|          | 0/1163 [00:00<?, ?it/s]

In [82]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm

import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

In [83]:
# 全域變數設置
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json("data/train_doc5sent5.jsonl")
DEV_DATA = load_json("data/dev_doc5sent5.jsonl")

TRAIN_PKL_FILE = Path("data/train_doc5sent5.pkl")
DEV_PKL_FILE = Path("data/dev_doc5sent5.pkl")

In [84]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=296938), Label(value='0 / 296938')…

Transform to id to evidence_map mapping


In [93]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]

        # In case there are less than topk evidence sentences
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad
        concat_claim_evidence = " [SEP] ".join([*claim, *evidence])

        concat = self.tokenizer(
            concat_claim_evidence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        label = LABEL2ID[item["label"]] if "label" in item else -1
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}

        if "label" in item:
            concat_ten["labels"] = torch.tensor(label)

        return concat_ten

In [94]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())

    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

In [95]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

In [96]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)

    print(f"Extracting evidence_list for the {mode} mode ...")
    if mode == "eval":
        # extract evidence
        df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
            mapping.get(evi_id, {}).get(str(evi_idx), "")
            for evi_id, evi_idx in x  # for each evidence list
        ][:topk] if isinstance(x, list) else [])
        print(df["evidence_list"][:5])
    else:
        # extract evidence
        df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
            " ".join([  # join evidence
                mapping.get(evi_id, {}).get(str(evi_idx), "")
                for _, _, evi_id, evi_idx in evi_list
            ]) if isinstance(evi_list, list) else ""
            for evi_list in x  # for each evidence list
        ][:1] if isinstance(x, list) else [])

    return df

In [97]:
#第二個模型 
# Claim Verification
#將證據和 claim 丟入 模型判斷正確錯誤或資訊缺乏

#@title  { display-mode: "form" }
# bert-base-chinese 原本模型
MODEL_NAME = "sijunhe/nezha-base-wwm"  #@param {type:"string"}
TRAIN_BATCH_SIZE = 32  #@param {type:"integer"}
TEST_BATCH_SIZE = 32  #@param {type:"integer"}
SEED = 1024  #@param {type:"integer"}
LR = 5e-5  #@param {type:"number"}
NUM_EPOCHS = 30  #@param {type:"integer"}
MAX_SEQ_LEN = 256  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
VALIDATION_STEP = 50  #@param {type:"integer"}

In [98]:
OUTPUT_FILENAME = "submission.jsonl"
# 原本 _bs{TRAIN_BATCH_SIZE}_
EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_test50_1223_new{MODEL_NAME}_{TRAIN_BATCH_SIZE}_{TEST_BATCH_SIZE}" + f"{LR}_top{EVIDENCE_TOPK}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

In [99]:
if not TRAIN_PKL_FILE.exists():
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=4)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=4)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

In [100]:
torch.cuda.empty_cache()
while torch.cuda.memory_allocated() > 5:
      pass
print("GPU memory is fully cleared.")

KeyboardInterrupt: 

In [101]:
from transformers import (
  BertTokenizerFast,
  AutoModel,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)



train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

In [102]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")
model = AutoModelForSequenceClassification.from_pretrained(

    MODEL_NAME,
    num_labels=len(LABEL2ID),
)
model.to(device)
model = nn.DataParallel(model)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

In [103]:
# from torch.cuda.amp import GradScaler, autocast

# scaler = GradScaler()

# progress_bar = tqdm(range(num_training_steps))
# current_steps = 0

# for epoch in range(NUM_EPOCHS):
#     model.train()

#     for batch in train_dataloader:
#         batch = {k: v.to(device) for k, v in batch.items()}

#         with autocast():
#             outputs = model(**batch)
#             loss = outputs.loss

#         scaler.scale(loss).backward()

#         scaler.step(optimizer)
#         scaler.update()
#         optimizer.zero_grad()

#         progress_bar.update(1)
#         writer.add_scalar("training_loss", loss.item(), current_steps)

#         y_pred = torch.argmax(outputs.logits, dim=1).tolist()
#         y_true = batch["labels"].tolist()

#         current_steps += 1

#         if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
#             print("Start validation")
#             val_results = run_evaluation(model, eval_dataloader, device)

#             # log each metric separately to TensorBoard
#             for metric_name, metric_value in val_results.items():
#                 print(f"{metric_name}: {metric_value}")
#                 writer.add_scalar(f"{metric_name}", metric_value, current_steps)

#             save_checkpoint(
#                 model,
#                 CKPT_DIR,
#                 current_steps,
#                 mark=f"val_acc={val_results['val_acc']:.7f}",
#             )

# print("Finished training!")


In [104]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            val_results = run_evaluation(model, eval_dataloader, device)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            save_checkpoint(
                model,
                CKPT_DIR,
                current_steps,
                mark=f"val_acc={val_results['val_acc']:.7f}",
            )

print("Finished training!")

  0%|          | 0/8760 [00:00<?, ?it/s]

Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.0841463902225232
val_acc: 0.4270386266094421
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.1347969574471042
val_acc: 0.4313304721030043
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.2339141956747395
val_acc: 0.43261802575107294
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.42541016454566
val_acc: 0.43047210300429184
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.542861226486833
val_acc: 0.4236051502145923
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.7586792299192247
val_acc: 0.42832618025751074
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.8471560584355706
val_acc: 0.4369098712446352
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.852574698729058
val_acc: 0.47296137339055794
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.036146578723437
val_acc: 0.4493562231759657
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 1.9359408617019653
val_acc: 0.48025751072961376
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.179202272467417
val_acc: 0.4682403433476395
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.0374499020511156
val_acc: 0.48412017167381977
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.3178424051363176
val_acc: 0.463519313304721
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.209486752340238
val_acc: 0.4781115879828326
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.3445157972100663
val_acc: 0.47639484978540775
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.3142970424808866
val_acc: 0.4854077253218884
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.479110934962965
val_acc: 0.47296137339055794
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.464578388488456
val_acc: 0.4918454935622318
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7198348224979556
val_acc: 0.4742489270386266
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.595059786757378
val_acc: 0.4901287553648069
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8140985345187253
val_acc: 0.45107296137339054
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.6404261033828944
val_acc: 0.496137339055794
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.634589730876766
val_acc: 0.4918454935622318
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.6182128403284777
val_acc: 0.492274678111588
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.699944695381269
val_acc: 0.4957081545064378
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8466658918824916
val_acc: 0.48626609442060087
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.883146931047309
val_acc: 0.4648068669527897
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7772392181500996
val_acc: 0.4738197424892704
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.717853221174789
val_acc: 0.49527896995708154
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7134608552880484
val_acc: 0.49141630901287553
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8270367775877863
val_acc: 0.496137339055794
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.9335894960246676
val_acc: 0.4957081545064378
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8655326235784244
val_acc: 0.5051502145922747
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.9609367406531555
val_acc: 0.4927038626609442
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.0767986937744976
val_acc: 0.48068669527896996
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.013879281200775
val_acc: 0.4858369098712446
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.040995330026705
val_acc: 0.4957081545064378
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.0021217963466906
val_acc: 0.4995708154506438
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.891991373610823
val_acc: 0.47339055793991414
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.7113682946113693
val_acc: 0.5034334763948498
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.050610259787677
val_acc: 0.4918454935622318
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.06031704112275
val_acc: 0.49914163090128755
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.8965122928358102
val_acc: 0.4982832618025751
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.888505380447597
val_acc: 0.4939914163090129
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.1024085854830807
val_acc: 0.4918454935622318
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 2.9438475958288532
val_acc: 0.5025751072961373
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.04124827091008
val_acc: 0.5025751072961373
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.148745407796886
val_acc: 0.4939914163090129
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.271685347165147
val_acc: 0.4939914163090129
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.275198653952716
val_acc: 0.49699570815450644
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.129294609370297
val_acc: 0.5008583690987124
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.232289250582865
val_acc: 0.5111587982832618
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.44342029911198
val_acc: 0.49871244635193135
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.5277858041737176
val_acc: 0.4832618025751073
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.497945586295977
val_acc: 0.48669527896995707
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.4374903865056496
val_acc: 0.49742489270386264
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.2584721633832747
val_acc: 0.5025751072961373
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.3095914941944486
val_acc: 0.4957081545064378
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.552575735196675
val_acc: 0.5004291845493563
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.6763370037078857
val_acc: 0.488412017167382
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.8309535082072426
val_acc: 0.47510729613733904
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.5186003805839854
val_acc: 0.49699570815450644
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.696123867818754
val_acc: 0.48454935622317596
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.3922827096834576
val_acc: 0.48454935622317596
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.6285011245779795
val_acc: 0.48884120171673817
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.663491916983095
val_acc: 0.49699570815450644
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.550302420576958
val_acc: 0.4896995708154506
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.6403930334195698
val_acc: 0.49141630901287553
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.702002526962594
val_acc: 0.49141630901287553
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.5918655085237057
val_acc: 0.4995708154506438
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.7811729271117955
val_acc: 0.49527896995708154
Start validation


  0%|          | 0/73 [00:00<?, ?it/s]

val_loss: 3.7480041588822455
val_acc: 0.49699570815450644


RuntimeError: [enforce fail at ..\caffe2\serialize\inline_container.cc:325] . unexpected pos 726114368 vs 726114256

In [None]:
TEST_DATA = load_json("data/test_doc5sent5.jsonl")
TEST_PKL_FILE = Path("data/test_doc5sent5.pkl")

if not TEST_PKL_FILE.exists():
    test_df = join_with_topk_evidence(
        pd.DataFrame(TEST_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    test_df.to_pickle(TEST_PKL_FILE, protocol=4)
else:
    with open(TEST_PKL_FILE, "rb") as f:
        test_df = pickle.load(f)

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

In [None]:
ckpt_name = "val_acc=0.5266094_model.6500.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
predicted_label = run_predict(model, test_dataloader, device)

In [None]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    OUTPUT_FILENAME,
    orient="records",
    lines=True,
    force_ascii=False,
)

In [None]:
torch.cuda.empty_cache()
while torch.cuda.memory_allocated() > 5:
    pass
print("GPU memory is fully cleared.")