## PART 1. Document retrieval

In [6]:
# original import
import json
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union
# 3rd party libs
import hanlp
import opencc
import pandas as pd
import wikipedia
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel
# our own libs
from utils import load_json, generate_evidence_to_wiki_pages_mapping

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
wikipedia.set_lang("zh")

In [7]:
# HW3 import
from pathlib import Path
from functools import partial
import re
import numpy as np
import pandas as pd
import jieba
import scipy

jieba.set_dictionary("dict.txt.big")
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from pandarallel import pandarallel
# Adjust the number of workers if you want
pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

from tqdm import tqdm
tqdm.pandas()

from hw3_utils import (jsonl_dir_to_df, calculate_precision, calculate_recall)
from TCSP import read_stopwords_list
stopwords = read_stopwords_list()

In [3]:
# Preload the data.

# TRAIN_DATA = load_json("data/public_train_0522.jsonl")
# TEST_DATA = load_json("data/private_test_data.jsonl")
# CONVERTER_T2S = opencc.OpenCC("t2s.json")
# CONVERTER_S2T = opencc.OpenCC("s2t.json")

In [4]:
# Data class for type hinting

@dataclass
class Claim:
    data: str
@dataclass
class AnnotationID:
    id: int
@dataclass
class EvidenceID:
    id: int
@dataclass
class PageTitle:
    title: str
@dataclass
class SentenceID:
    id: int
@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

In [5]:
def tokenize(text: str, stopwords: list) -> str:
    # 用jieba把整段句子拆成token list
    tokens = list(jieba.cut(text))
    result = " ".join([word for word in tokens if word not in stopwords])
    return result

The default amount of documents retrieved is at most five documents.  This `num_pred_doc` can be adjusted based on your objective.  Save data in jsonl format.

In [6]:
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 5,
) -> None:
    with open(
        f"data/{mode}_doc{num_pred_doc}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d["predicted_pages"] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

In [7]:
def get_pred_docs_sklearn_sentence(
    claim: str,
    tokenizing_method: callable,
    vectorizer: TfidfVectorizer,
    topk: int,
) -> set:
    global wiki_sentences
    tokens = tokenizing_method(claim)
    claim_vector = vectorizer.transform([tokens])
    # TODO: Write your code here
    similarity_scores = cosine_similarity(claim_vector , X)
    # flatten the array
    similarity_scores = similarity_scores[0, :]

    # Sort the similarity scores in descending order
    # TODO: Write your code here
    sorted_indices = similarity_scores.argsort()[::-1]
    topk_sorted_indices = sorted_indices[:1000]
    results = []
    for idx in topk_sorted_indices:
        real_id = wiki_sentences.iloc[idx]['id']
        if(real_id not in results):
            results.append(real_id)
            if(len(results) == topk):
                break
            
    exact_matchs = []
    Count = 0
    for i,result in enumerate(results):
        if (
            (result in claim)
            or (result in claim.replace(" ", ""))
            or (result.replace("·", "") in claim)
            or (result.replace("-", "") in claim)
        ):
            exact_matchs.append(result)
        elif "·" in result:
            splitted = result.split("·")
            for split in splitted:
                if split in claim:
                    exact_matchs.append(result)
                    break
        elif "_(" in result:
            splitted = result.split("_(")
            splitted[1] = splitted[1][:-1]
            
            for split in splitted:
                if(")" in split):
                    split = split[:-1]
                    
                if split in claim:
                    exact_matchs.append(result)
                    break
            
    return set(exact_matchs)

In [8]:
# First time running this cell will 34 minutes using Google Colab.
wiki_path = "data/wiki-pages"
wiki_cache = "wiki"
target_column = "text"

wiki_cache_path = Path(f"data/{wiki_cache}.pkl")
if wiki_cache_path.exists():
    wiki_pages = pd.read_pickle(wiki_cache_path)
else:
    wiki_pages = jsonl_dir_to_df(wiki_path)
    wiki_pages = wiki_pages.reset_index(drop=True)
    wiki_pages["processed_text"] = wiki_pages[target_column].progress_apply(
        partial(tokenize, stopwords=stopwords)
    )
    # save the result to a pickle file
    wiki_pages.to_pickle(wiki_cache_path, protocol=4)

In [9]:
mapping_path = Path(f"data/mapping_document_retrieval.json")
if mapping_path.exists():
    mapping = json.load( open( "data/mapping_document_retrieval.json" ) )
else:
    mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
    json.dump( mapping, open( "data/mapping_document_retrieval.json", 'w' ) )

## 只取大於min_wiki_length的wiki_pages

In [10]:
# 參數
min_wiki_length = 100
min_sentence_length = 15
num_of_samples = 500
topk = 10
use_idf = True
sublinear_tf = True

wiki_pages = wiki_pages[wiki_pages['processed_text'].str.len() > min_wiki_length]
len(wiki_pages)

781253

## 產生wiki_sentences.pkl

In [11]:
path = Path("data/wiki_sentence.pkl")
if path.exists():
    wiki_sentences = pd.read_pickle("data/wiki_sentence.pkl")  
else:
    data = {'id': [], 'idx': [], 'text': []}
    wiki_sentences = []
    for i in tqdm(range(len(wiki_pages))):
        id = wiki_pages.iloc[i]['id']
        for sentence in mapping[id].values():
            if(sentence != ''):
                dic = {'id':id, 'idx':int(i), 'text':sentence}
                wiki_sentences.append(dic)
    wiki_sentences = pd.DataFrame(wiki_sentences)
    del wiki_pages
    wiki_sentences["processed_text"] = wiki_sentences['text'].progress_apply(
        partial(tokenize, stopwords=stopwords)
    )
    wiki_sentences.to_pickle('data/wiki_sentence.pkl')

In [12]:
len(wiki_sentences)

3732467

## TF-IDF(HW3)

從wiki_sentences中取長度大於min_sentence_length的

In [13]:
wiki_sentences = wiki_sentences[ wiki_sentences['processed_text'].str.len() >= min_sentence_length]

In [14]:
TRAIN_DATA = load_json("data/public_train.jsonl")
doc_path = f"data/train_doc5.jsonl"
TRAIN_GT = pd.DataFrame(TRAIN_DATA)

CPU times: total: 125 ms
Wall time: 124 ms


### 將public_train.json的資料放入wiki_sentences

In [15]:
train_list = []
for i,row in TRAIN_GT.iterrows():
    if(row['label'] == 'NOT ENOUGH INFO'):
        continue
    wiki_names = []
    evidence_sets = row['evidence']
    for sets in evidence_sets:
        for one_set in sets:
            if(one_set[2] not in wiki_names):
                wiki_names.append(one_set[2])
    
    claim = tokenize(row['claim']  , stopwords)
    if len(wiki_names) > 0:
        for name in wiki_names:
            dic = {'id':name, 'idx':int(i), 'text':row['claim'], 'processed_text':claim}
            train_list.append(dic)
        
    # 整理成dictionary並新增到wiki_sentences中

Building prefix dict from D:\FDA\AICUP2023-baseline\dict.txt.big ...
Loading model from cache C:\Users\User\AppData\Local\Temp\jieba.u9fc27dfe134194f973eec0c33d27bf6b.cache
Loading model cost 2.222 seconds.
Prefix dict has been built successfully.


In [16]:
wiki_sentences = pd.concat([wiki_sentences , pd.DataFrame(train_list)])
wiki_sentences = wiki_sentences.drop(['idx'], axis=1)
wiki_sentences = wiki_sentences.reset_index(drop=True)
wiki_sentences

Unnamed: 0,id,text,processed_text
0,數學,數學 ， 是研究數量 、 結構以及空間等概念及其變化的一門學科 ， 屬於形式科學的一種 。,數學 研究 數量 結構 空間 概念 變化 一門 學科 屬於 形式 科學
1,數學,基礎數學的知識與運用是生活中不可或缺的一環 。,基礎 數學 知識 運用 生活 不可或缺 一環
2,數學,對數學基本概念的完善 ， 早在古埃及 、 美索不達米亞及古印度歷史上的古代數學文本便可觀見 ...,數學 基本概念 完善 早 古埃及 美索不達米亞 古印度 歷史 古代 數學 文...
3,數學,從那時開始 ， 數學的發展便持續不斷地小幅進展 ， 至16世紀的文藝復興時期 ， 因爲新的科...,數學 發展 便 持續 不斷 小幅 進展 16 世紀 文藝復興 時期 ...
4,數學,數學併成爲許多國家及地區的教育中的一部分 。,數學 併 成爲 國家 地區 教育 一部分
...,...,...,...
3398032,京畿道,亞洲有超過一半的人口在京畿道。,亞洲 超過 一半 人口 京畿道
3398033,亞洲,亞洲有超過一半的人口在京畿道。,亞洲 超過 一半 人口 京畿道
3398034,軟件測試,一種用來抵銷鑑定軟體的過程稱作軟件測試，測試的定義之一是爲了評估而質疑產品的過程。,抵銷 鑑定 軟體 過程 稱作 軟件測試 測試 定義 爲了 評估 質疑 產品 過程
3398035,福山城_(備後國),備後國的福山城現今位於日本廣島縣，曾在戰爭中遭到攻擊。,備後國 福 山城 現今 位於 日本 廣島 縣 戰爭 遭到 攻擊


In [17]:
corpus = wiki_sentences['processed_text'].tolist()

In [18]:
# Build the TfidfVectorizer
# TODO: Write your code here
vectorizer = TfidfVectorizer(use_idf=use_idf, sublinear_tf=sublinear_tf, stop_words = stopwords)

In [19]:
X = vectorizer.fit_transform(corpus)

### training

In [20]:
train = load_json("data/public_train.jsonl")

# encode the corpus with TF-IDF
train_df = pd.DataFrame(train)

# prediction
train_df["predicted_pages"] = train_df["claim"].progress_apply(
    partial(
        get_pred_docs_sklearn_sentence,
        tokenizing_method=partial(tokenize, stopwords=stopwords),
        vectorizer=vectorizer,
        topk=topk,
    )
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3969/3969 [4:58:21<00:00,  4.51s/it]

CPU times: total: 4h 57min 57s
Wall time: 4h 58min 21s





In [21]:
precision = calculate_precision(train, train_df["predicted_pages"])
recall = calculate_recall(train, train_df["predicted_pages"])
dictionary = {'topk':topk,
              'num_of_samples':num_of_samples,
              'min_sentence_length':min_sentence_length,
              'precision':precision,
              'recall':recall}
print(dictionary)

Precision: 0.7159799107142859
Recall: 0.8962968750000002
{'topk': 10, 'num_of_samples': 500, 'min_sentence_length': 15, 'precision': 0.7159799107142859, 'recall': 0.8962968750000002}


In [22]:
save_doc(train, train_df["predicted_pages"], mode="train")

### prediction

In [23]:
test = load_json("data/all_test_data.jsonl")

test_df = pd.DataFrame(test)
# prediction
test_df["predicted_pages"] = test_df["claim"].progress_apply(
    partial(
        get_pred_docs_sklearn_sentence,
        tokenizing_method=partial(tokenize, stopwords=stopwords),
        vectorizer=vectorizer,
        topk=topk,
    )
)
save_doc(test, test_df["predicted_pages"], mode="test")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9038/9038 [11:17:17<00:00,  4.50s/it]


## 下面是修改TOP N版(舊版)的document retrieval

### Helper function

For the sake of consistency, we convert traditional to simplified Chinese first before converting it back to traditional Chinese.  This is due to some errors occuring when converting traditional to traditional Chinese.

In [24]:
# def do_st_corrections(text: str) -> str:
#     simplified = CONVERTER_T2S.convert(text)

#     return CONVERTER_S2T.convert(simplified)

We use constituency parsing to separate part of speeches or so called constituent to extract noun phrases.  In the later stages, we will use the noun phrases as the query to search for relevant documents.  

In [25]:
# def get_nps_hanlp(
#     predictor: Pipeline,
#     d: Dict[str, Union[int, Claim, Evidence]],
# ) -> List[str]:
#     claim = d["claim"]
#     tree = predictor(claim)["con"]
#     nps = [
#         do_st_corrections("".join(subtree.leaves()))
#         for subtree in tree.subtrees(lambda t: t.label() == "NP")
#     ]

#     return nps

Precision refers to how many related documents are retrieved.  Recall refers to how many relevant documents are retrieved.  

In [26]:
# def calculate_precision(
#     data: List[Dict[str, Union[int, Claim, Evidence]]],
#     predictions: pd.Series,
# ) -> None:
#     precision = 0
#     count = 0

#     for i, d in enumerate(data):
#         if d["label"] == "NOT ENOUGH INFO":
#             continue

#         # Extract all ground truth of titles of the wikipedia pages
#         # evidence[2] refers to the title of the wikipedia page
#         gt_pages = set([
#             evidence[2]
#             for evidence_set in d["evidence"]
#             for evidence in evidence_set
#         ])

#         predicted_pages = predictions.iloc[i]
#         hits = predicted_pages.intersection(gt_pages)
#         if len(predicted_pages) != 0:
#             precision += len(hits) / len(predicted_pages)

#         count += 1

#     # Macro precision
#     print(f"Precision: {precision / count}")


# def calculate_recall(
#     data: List[Dict[str, Union[int, Claim, Evidence]]],
#     predictions: pd.Series,
# ) -> None:
#     recall = 0
#     count = 0

#     for i, d in enumerate(data):
#         if d["label"] == "NOT ENOUGH INFO":
#             continue

#         gt_pages = set([
#             evidence[2]
#             for evidence_set in d["evidence"]
#             for evidence in evidence_set
#         ])
#         predicted_pages = predictions.iloc[i]
#         hits = predicted_pages.intersection(gt_pages)
#         recall += len(hits) / len(gt_pages)
#         count += 1

#     print(f"Recall: {recall / count}")

### Main function for document retrieval

In [27]:
# def get_pred_pages(series_data: pd.Series) -> Set[Dict[int, str]]:
#     results = []
#     tmp_muji = []
#     # wiki_page: its index showned in claim
#     mapping = {}
#     claim = series_data["claim"]
#     nps = series_data["hanlp_results"]
#     first_wiki_term = []

#     for i, np in enumerate(nps):
#         # Simplified Traditional Chinese Correction
#         wiki_search_results = [
#             do_st_corrections(w) for w in wikipedia.search(np)
#         ]

#         # Remove the wiki page's description in brackets
#         wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
#         wiki_df = pd.DataFrame({
#             "wiki_set": wiki_set,
#             "wiki_results": wiki_search_results
#         })

#         # Elements in wiki_set --> index
#         # Extracting only the first element is one way to avoid extracting
#         # too many of the similar wiki pages
#         grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
#         candidates = grouped_df["wiki_results"].tolist()
#         # muji refers to wiki_set
#         muji = grouped_df.index.tolist()

#         for prefix, term in zip(muji, candidates):
#             if prefix not in tmp_muji:
#                 matched = False

#                 # Take at least one term from the first noun phrase
#                 if i == 0:
#                     first_wiki_term.append(term)

#                 # Walrus operator :=
#                 # https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions
#                 # Through these filters, we are trying to figure out if the term
#                 # is within the claim
#                 if (((new_term := term) in claim) or
#                     ((new_term := term.replace("·", "")) in claim) or
#                     ((new_term := term.split(" ")[0]) in claim) or
#                     ((new_term := term.replace("-", " ")) in claim)):
#                     matched = True

#                 elif "·" in term:
#                     splitted = term.split("·")
#                     for split in splitted:
#                         if (new_term := split) in claim:
#                             matched = True
#                             break

#                 if matched:
#                     # post-processing
#                     term = term.replace(" ", "_")
#                     term = term.replace("-", "")
#                     results.append(term)
#                     mapping[term] = claim.find(new_term)
#                     tmp_muji.append(new_term)

#     # 5 is a hyperparameter
#     if len(results) > 5:
#         assert -1 not in mapping.values()
#         results = sorted(mapping, key=mapping.get)[:5]
#         # results = sorted(mapping, key=mapping.get)[:5]
#     elif len(results) >=3 and len(results) <= 5:
#         assert -1 not in mapping.values()
#         results = sorted(mapping, key=mapping.get)[:3]
#     elif len(results) < 1:
#         results = first_wiki_term

#     return set(results)

### Step 1. Get noun phrases from hanlp consituency parsing tree

Setup [HanLP](https://github.com/hankcs/HanLP) predictor (1 min)

In [28]:
# predictor = (hanlp.pipeline().append(
#     hanlp.load("FINE_ELECTRA_SMALL_ZH"),
#     output_key="tok",
# ).append(
#     hanlp.load("CTB9_CON_ELECTRA_SMALL"),
#     output_key="con",
#     input_key="tok",
# ))

We will skip this process which for creating parsing tree when demo on class

In [29]:
# hanlp_file = f"data/hanlp_con_results.pkl"
# if Path(hanlp_file).exists():
#     with open(hanlp_file, "rb") as f:
#         hanlp_results = pickle.load(f)
# else:
#     hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
#     with open(hanlp_file, "wb") as f:
#         pickle.dump(hanlp_results, f)

Get pages via wiki online api

In [31]:
# doc_path = f"data/train_doc5.jsonl"
# if Path(doc_path).exists():
#     with open(doc_path, "r", encoding="utf8") as f:
#         predicted_results = pd.Series([
#             set(json.loads(line)["predicted_pages"])
#             for line in f
#         ])
# else:
#     train_df = pd.DataFrame(TRAIN_DATA)
#     train_df.loc[:, "hanlp_results"] = hanlp_results
#     predicted_results = train_df.apply(get_pred_pages, axis=1)
#     save_doc(TRAIN_DATA, predicted_results, mode="train")

### Step 2. Calculate our results

In [32]:
# calculate_precision(TRAIN_DATA, predicted_results)
# calculate_recall(TRAIN_DATA, predicted_results)

### Step 3. Repeat the same process on test set
Create parsing tree

In [33]:
# hanlp_test_file = f"data/hanlp_con_test_results.pkl"
# if Path(hanlp_test_file).exists():
#     with open(hanlp_file, "rb") as f:
#         hanlp_results = pickle.load(f)
# else:
#     hanlp_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
#     with open(hanlp_file, "wb") as f:
#         pickle.dump(hanlp_results, f)

In [34]:
# test_doc_path = f"data/test_doc5.jsonl"
# if Path(test_doc_path).exists():
#     with open(test_doc_path, "r", encoding="utf8") as f:
#         test_results = pd.Series(
#             [set(json.loads(line)["predicted_pages"]) for line in f])
# else:
#     test_df = pd.DataFrame(TEST_DATA)
#     test_df.loc[:, "hanlp_results"] = hanlp_results
#     test_results = test_df.parallel_apply(get_pred_pages, axis=1)
#     save_doc(TEST_DATA, test_results, mode="test")

## PART 2. Sentence retrieval

Import some libs

In [21]:
# built-in libs
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
    AutoConfig
)

from dataset import BERTDataset, Dataset

# local libs
from utils import (generate_evidence_to_wiki_pages_mapping, jsonl_dir_to_df, load_json, load_model, save_checkpoint, set_lr_scheduler)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

Global variable

In [36]:
SEED = 42

TRAIN_DATA = load_json("data/public_train.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
DOC_DATA = load_json("data/train_doc5.jsonl")

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=_y,
)

Preload wiki database (1 min)

In [37]:
mapping_path = Path("data/mapping_sentence_retrieval.json")
if mapping_path.exists():
    mapping = json.load( open( "data/mapping_sentence_retrieval.json" ) )
else:
    wiki_pages = jsonl_dir_to_df("data/wiki-pages")
    mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
    json.dump( mapping, open( "data/mapping_sentence_retrieval.json", 'w' ) )
    del wiki_pages

# OLD VERSION
# wiki_pages = jsonl_dir_to_df("data/wiki-pages")
# mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
# wiki_pages
# del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping
Transform to id to evidence_map mapping


### Helper function

Calculate precision for sentence retrieval

In [38]:
def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0

    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0

        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

    return 0.0, 0.0

Calculate recall for sentence retrieval

In [39]:
def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0

        claim = instance["claim"]

        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

Calculate the scores of sentence retrieval

In [40]:
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 5,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]

        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)

    if save_name is not None:
        # write doc7_sent5 file
        with open(f"data/{save_name}", "w", encoding="utf-8") as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim]["predicted_evidence"].tolist()
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

Inference script to get probabilites for the candidate evidence sentences

In [41]:
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())

    return np.array(probs)

AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences

In [42]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]

        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten

### Main function for sentence retrieval

In [43]:
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # positive
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]
        for evidence_set in evidence_sets:
            sents = []
            for evidence in evidence_set:
                # evidence[2] is the page title
                page = evidence[2].replace(" ", "_")
                # the only page with weird name
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] is in form of int however, mapping requires str
                sent_idx = str(evidence[3])
                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            claims.append(claim)
            sentences.append(whole_evidence)
            labels.append(1)

    # negative
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue
        claim = df["claim"].iloc[i]

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in df["evidence"][i]
                            for evidence in evidences])
        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for pair in page_sent_id_pairs:
                if pair in evidence_set:
                    continue
                text = mapping[page][pair[1]]
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                if text != "" and np.random.rand(1) <= negative_ratio:
                    claims.append(claim)
                    sentences.append(text)
                    labels.append(0)

    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})


def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue
        claim = df["claim"].iloc[i]

        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for page_name, sentence_id in page_sent_id_pairs:
                text = mapping[page][sentence_id]
                if text != "":
                    claims.append(claim)
                    sentences.append(text)
                    if not is_testset:
                        evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

### Step 1. Setup training environment

Hyperparams

In [69]:
MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
NUM_EPOCHS = 10  #@param {type:"integer"}
LR = 2e-5  #@param {type:"number"}
TRAIN_BATCH_SIZE = 8  #@param {type:"integer"}
TEST_BATCH_SIZE = 8  #@param {type:"integer"}
MAX_SEQ_LEN =  256  #@param {type:"integer"}
NEGATIVE_RATIO = 0.1  #@param {type:"number"}
VALIDATION_STEP = 100  #@param {type:"integer"}
TOP_N = 5  #@param {type:"integer"}

Experiment Directory

In [70]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Combine claims and evidences

In [71]:
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)
counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    4397
1    2774
Name: label, dtype: int64


### Step 3. Start training

Dataloader things

In [72]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer, max_length=MAX_SEQ_LEN)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer, max_length=MAX_SEQ_LEN)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

In [73]:
del train_df

Trainer

In [74]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Please make sure that you are using gpu when training (5 min)

In [75]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(val_results)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(model, CKPT_DIR, current_steps)

print("Finished training!")

  0%|          | 0/8970 [00:00<?, ?it/s]

Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.41048323729701325, 'Precision': 0.28958333333333247, 'Recall': 0.7046875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45497366193208405, 'Precision': 0.319895833333332, 'Recall': 0.7875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4634112906933956, 'Precision': 0.32489583333333183, 'Recall': 0.8078125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4665389607432355, 'Precision': 0.32645833333333174, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.465261124281935, 'Precision': 0.3252083333333317, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4641794120340673, 'Precision': 0.3248958333333317, 'Recall': 0.8125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46590039187095444, 'Precision': 0.32583333333333175, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45965037871789977, 'Precision': 0.3214583333333318, 'Recall': 0.80625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46347540045766433, 'Precision': 0.32395833333333174, 'Recall': 0.8140625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46672695575140666, 'Precision': 0.3261458333333317, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4646879852041266, 'Precision': 0.3248958333333317, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4476575150014273, 'Precision': 0.31395833333333195, 'Recall': 0.7796875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4656461033035209, 'Precision': 0.3258333333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4646211568287549, 'Precision': 0.3245833333333317, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4620746735331969, 'Precision': 0.3233333333333318, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4646211568287549, 'Precision': 0.3245833333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46915228650262186, 'Precision': 0.32802083333333176, 'Recall': 0.8234375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4632865548724521, 'Precision': 0.32427083333333173, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46576623169955284, 'Precision': 0.3252083333333318, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4649412281101772, 'Precision': 0.3248958333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46672695575140666, 'Precision': 0.3261458333333317, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46672695575140666, 'Precision': 0.3261458333333317, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45543544457977936, 'Precision': 0.31958333333333194, 'Recall': 0.7921875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4667932508641061, 'Precision': 0.3264583333333318, 'Recall': 0.81875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45965037871789993, 'Precision': 0.32145833333333196, 'Recall': 0.80625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.459715126225262, 'Precision': 0.3217708333333319, 'Recall': 0.8046875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4609894582486591, 'Precision': 0.32302083333333187, 'Recall': 0.8046875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.464621156828755, 'Precision': 0.32458333333333184, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46583409194355785, 'Precision': 0.3255208333333318, 'Recall': 0.81875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46551402294663835, 'Precision': 0.32520833333333177, 'Recall': 0.81875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4646879852041266, 'Precision': 0.3248958333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46150084054900364, 'Precision': 0.32302083333333187, 'Recall': 0.8078125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46755197481655775, 'Precision': 0.3264583333333317, 'Recall': 0.8234375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4646211568287549, 'Precision': 0.3245833333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46583409194355774, 'Precision': 0.32552083333333176, 'Recall': 0.81875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46608664772727126, 'Precision': 0.3255208333333318, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4666591032855311, 'Precision': 0.32583333333333175, 'Recall': 0.821875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4669106317411386, 'Precision': 0.32583333333333175, 'Recall': 0.8234375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46608664772727126, 'Precision': 0.3255208333333318, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46640688903026284, 'Precision': 0.32583333333333175, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643009102938477, 'Precision': 0.3242708333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46658969854940907, 'Precision': 0.3255208333333318, 'Recall': 0.8234375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.464621156828755, 'Precision': 0.32458333333333184, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46544564080407336, 'Precision': 0.32489583333333183, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46608664772727126, 'Precision': 0.3255208333333318, 'Recall': 0.8203125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4665896985494091, 'Precision': 0.32552083333333176, 'Recall': 0.8234375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46430091029384774, 'Precision': 0.32427083333333184, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46430091029384774, 'Precision': 0.32427083333333184, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46430091029384774, 'Precision': 0.32427083333333184, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46111654394845697, 'Precision': 0.32239583333333194, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46226263645536964, 'Precision': 0.32302083333333187, 'Recall': 0.8125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46200940428072607, 'Precision': 0.32302083333333187, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4616896765597707, 'Precision': 0.32270833333333193, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46200940428072607, 'Precision': 0.32302083333333187, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46290226460071354, 'Precision': 0.3236458333333318, 'Recall': 0.8125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46150084054900364, 'Precision': 0.3230208333333318, 'Recall': 0.8078125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4606079833871698, 'Precision': 0.32239583333333194, 'Recall': 0.80625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.460288958641062, 'Precision': 0.32208333333333194, 'Recall': 0.80625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46111654394845686, 'Precision': 0.3223958333333319, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4616896765597707, 'Precision': 0.32270833333333193, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46372829067641536, 'Precision': 0.32395833333333185, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46111654394845697, 'Precision': 0.32239583333333194, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45971512622526217, 'Precision': 0.321770833333332, 'Recall': 0.8046875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.45996975713362126, 'Precision': 0.32177083333333195, 'Recall': 0.80625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46175547327752603, 'Precision': 0.323020833333332, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46054323776159173, 'Precision': 0.322083333333332, 'Recall': 0.8078125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46290226460071365, 'Precision': 0.32364583333333186, 'Recall': 0.8125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4649412281101772, 'Precision': 0.32489583333333183, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46436826237894974, 'Precision': 0.32458333333333184, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46340804151046755, 'Precision': 0.32364583333333186, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46200940428072607, 'Precision': 0.32302083333333187, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46258253851797365, 'Precision': 0.3233333333333319, 'Recall': 0.8125}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46436826237894974, 'Precision': 0.32458333333333184, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4623936977107643, 'Precision': 0.32364583333333186, 'Recall': 0.809375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4641146738633229, 'Precision': 0.3245833333333318, 'Recall': 0.8140625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4649412281101772, 'Precision': 0.3248958333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4623289558269799, 'Precision': 0.32333333333333186, 'Recall': 0.8109375}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4637951249084904, 'Precision': 0.3242708333333318, 'Recall': 0.8140625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46347540045766433, 'Precision': 0.32395833333333185, 'Recall': 0.8140625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643009102938477, 'Precision': 0.3242708333333318, 'Recall': 0.8171875}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46404836425111784, 'Precision': 0.3242708333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46404836425111784, 'Precision': 0.3242708333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.46404836425111784, 'Precision': 0.3242708333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

{'F1 score': 0.4643682623789496, 'Precision': 0.3245833333333318, 'Recall': 0.815625}
Finished training!


In [79]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Validation part (15 mins)

In [80]:
ckpt_name = "model.1700.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer, max_length=MAX_SEQ_LEN)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
#     save_name="train_doc5sent.jsonl",
#     save_name="train_doc5sent"+str(TOP_N)+".jsonl",
    save_name=f"train_doc5sent{TOP_N}.jsonl",
)
print(f"Training scores => {train_results}")

print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
#     save_name="dev_doc5sent.jsonl",
#     save_name="dev_doc5sent"+str(TOP_N)+".jsonl",
    save_name=f"dev_doc5sent{TOP_N}.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/6515 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.48269295870007345, 'Precision': 0.34375000000000583, 'Recall': 0.81015625}
Start validation


  0%|          | 0/1734 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.46915228650262186, 'Precision': 0.32802083333333176, 'Recall': 0.8234375}


### Step 4. Check on our test data
(5 min)

In [81]:
test_data = load_json("data/test_doc5.jsonl")

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer, max_length=MAX_SEQ_LEN)
test_dataloader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name=f"test_doc5sent{TOP_N}.jsonl",
)

Start predicting the test data


  0%|          | 0/16608 [00:00<?, ?it/s]

notebook3
## PART 3. Claim verification

In [54]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm

import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

Global variables

In [55]:
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json("data/train_doc5sent5.jsonl")
DEV_DATA = load_json("data/dev_doc5sent5.jsonl")

TRAIN_PKL_FILE = Path("data/train_doc5sent5.pkl")
DEV_PKL_FILE = Path("data/dev_doc5sent5.pkl")

In [56]:
mapping_path = Path(f"data/mapping_claim_verification.json")
if mapping_path.exists():
    mapping = json.load( open( "data/mapping_claim_verification.json" ) )
else:
    wiki_pages = jsonl_dir_to_df("data/wiki-pages")
    mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
    json.dump( mapping, open( "data/mapping_claim_verification.json", 'w' ) )
    del wiki_pages

# OLD VERSION
# wiki_pages = jsonl_dir_to_df("data/wiki-pages")
# mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
# del wiki_pages

### Helper function

AICUP dataset with top-k evidence sentences.

In [57]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]

        # In case there are less than topk evidence sentences
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad
        concat_claim_evidence = " [SEP] ".join([*claim, *evidence])

        concat = self.tokenizer(
            concat_claim_evidence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        label = LABEL2ID[item["label"]] if "label" in item else -1
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}

        if "label" in item:
            concat_ten["labels"] = torch.tensor(label)

        return concat_ten

Evaluation function

In [58]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())

    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

Prediction

In [59]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

### Main function

In [60]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)

    print(f"Extracting evidence_list for the {mode} mode ...")
    if mode == "eval":
        # extract evidence
        df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
            mapping.get(evi_id, {}).get(str(evi_idx), "")
            for evi_id, evi_idx in x  # for each evidence list
        ][:topk] if isinstance(x, list) else [])
        print(df["evidence_list"][:5])
    else:
        # extract evidence
        df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
            " ".join([  # join evidence
                mapping.get(evi_id, {}).get(str(evi_idx), "")
                for _, _, evi_id, evi_idx in evi_list
            ]) if isinstance(evi_list, list) else ""
            for evi_list in x  # for each evidence list
        ][:1] if isinstance(x, list) else [])

    return df

### Step 1. Setup training environment

Hyperparams

In [61]:
#@title  { display-mode: "form" }

MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
TRAIN_BATCH_SIZE = 8  #@param {type:"integer"}
TEST_BATCH_SIZE = 8  #@param {type:"integer"}
SEED = 42  #@param {type:"integer"}
LR = 1e-6  #@param {type:"number"}
NUM_EPOCHS = 10  #@param {type:"integer"}
MAX_SEQ_LEN = 256  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
VALIDATION_STEP = 100  #@param {type:"integer"}


Experiment Directory

In [62]:
OUTPUT_FILENAME = "submission.jsonl"

EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_top{EVIDENCE_TOPK}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Concat claim and evidences
join topk 
evidence

In [63]:
if not TRAIN_PKL_FILE.exists():
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=4)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=4)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

### Step 3. Training

Prevent CUDA out of memory

In [64]:
torch.cuda.empty_cache()

In [65]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

In [66]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = AutoConfig.from_pretrained(MODEL_NAME, num_labels=len(LABEL2ID))
config.hidden_dropout_prob = 0.2
config.attention_probs_dropout_prob = 0.2
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(LABEL2ID))
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Training (30 mins)

In [67]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            val_results = run_evaluation(model, eval_dataloader, device)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            save_checkpoint(
                model,
                CKPT_DIR,
                current_steps,
                mark=f"val_acc={val_results['val_acc']:.4f}",
            )

print("Finished training!")

  0%|          | 0/3970 [00:00<?, ?it/s]

Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.1724776935577392
val_acc: 0.20025188916876574
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.064797071814537
val_acc: 0.40428211586901763
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.046289986371994
val_acc: 0.40554156171284633
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.0956793999671937
val_acc: 0.40554156171284633
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.1900182938575745
val_acc: 0.40428211586901763
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.2913585525751115
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.3349124109745025
val_acc: 0.4105793450881612
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.3362459594011307
val_acc: 0.41183879093198994
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.3941460812091828
val_acc: 0.41183879093198994
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.4036365890502929
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.410123153924942
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.4645028388500214
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.4789185154438018
val_acc: 0.41435768261964734
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.4765325695276261
val_acc: 0.42065491183879095
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.493186427950859
val_acc: 0.4168765743073048
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.4931908667087554
val_acc: 0.4105793450881612
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.546320687532425
val_acc: 0.4105793450881612
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5661044442653655
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.570606762766838
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5527435302734376
val_acc: 0.41435768261964734
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5805345606803893
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5866148585081101
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5903266471624375
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5775149899721146
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.618773390650749
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.5940221613645553
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6452614277601243
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.632832238972187
val_acc: 0.4105793450881612
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6393212151527405
val_acc: 0.4105793450881612
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6515728795528413
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.649721218943596
val_acc: 0.4093198992443325
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6572446432709693
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6624558275938035
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6715858888626098
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6756228053569793
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6788299867510796
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6795055854320526
val_acc: 0.4068010075566751
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.6832346433401109
val_acc: 0.4080604534005038
Start validation


  0%|          | 0/100 [00:00<?, ?it/s]

val_loss: 1.684688812494278
val_acc: 0.4068010075566751
Finished training!


### Step 4. Make your submission

In [68]:
TEST_DATA = load_json("data/test_doc5sent5.jsonl")
TEST_PKL_FILE = Path("data/test_doc5sent5.pkl")

if not TEST_PKL_FILE.exists():
    test_df = join_with_topk_evidence(
        pd.DataFrame(TEST_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    test_df.to_pickle(TEST_PKL_FILE, protocol=4)
else:
    with open(TEST_PKL_FILE, "rb") as f:
        test_df = pickle.load(f)

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

Extracting evidence_list for the eval mode ...


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=2260), Label(value='0 / 2260'))), …

0    [顯微鏡泛指將微小不可見或難見物品之影像放大 ， 而能被肉眼或其他成像儀器觀察之工具 。, ...
1    [許多昆蟲被認爲是對生態有益的捕食者 ， 少數昆蟲提供直接的經濟利益 。, 蠶產絲 ， 蜜蜂...
2    [綠山城縣  ， 是波蘭的縣份 ， 位於該國西部 ， 由盧布斯卡省負責管轄 ， 首府設於綠山...
3    [《 魂斷藍橋 》 （ Waterloo Bridge ） 是美國黑白電影 ， 由米高梅電影...
4    [2015年以 《 刺客聶隱娘 》 獲得第68屆坎城影展最佳導演獎及第52屆金馬獎最佳導演獎...
Name: evidence_list, dtype: object


Prediction

In [70]:
ckpt_name = "val_acc=0.4207_model.1400.pt"  #@param {type:"string"}

model = load_model(model, ckpt_name, CKPT_DIR)
predicted_label = run_predict(model, test_dataloader, device)

Predicting:   0%|          | 0/1130 [00:00<?, ?it/s]

Write files

In [71]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    OUTPUT_FILENAME,
    orient="records",
    lines=True,
    force_ascii=False,
)