# Baseline
python: 3.8.*

use ```Ctrl + ]``` to collapse all section :)

### Download our starter pack (3~5 min)

In [None]:
!gdown 1Xq2Fv6UGA1pc25pF0qwEc_l7Fa5jPP6p

In [None]:
!gdown --folder 1T6jpOtdf_i6XNYA6F_lqU4mRRh1xYPcl
!mv baseline/* ./

In [None]:
!gdown --folder 1hnVYEgN-gYzFCeBZo8cbKjGLBP-YTnTW

In [None]:
%pip install -r requirements.txt

## PART 1. Document retrieval

Prepare the environment and import all library we need

In [1]:
import json
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union
from functools import partial

# 3rd party libs
import hanlp
import opencc
import pandas as pd
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel

# our own libs
from utils import load_json
from hw3_utils import jsonl_dir_to_df

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

In [2]:
from TCSP import read_stopwords_list

stopwords = read_stopwords_list()

Preload the data.

In [3]:
TRAIN_DATA = load_json("data/public_train_0522.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")

Data class for type hinting

In [4]:
@dataclass
class Claim:
    data: str

@dataclass
class AnnotationID:
    id: int

@dataclass
class EvidenceID:
    id: int

@dataclass
class PageTitle:
    title: str

@dataclass
class SentenceID:
    id: int

@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

### Helper function

For the sake of consistency, we convert traditional to simplified Chinese first before converting it back to traditional Chinese.  This is due to some errors occuring when converting traditional to traditional Chinese.

In [5]:
def do_st_corrections(text: str) -> str:
    simplified = CONVERTER_T2S.convert(text)

    return CONVERTER_S2T.convert(simplified)

We use constituency parsing to separate part of speeches or so called constituent to extract noun phrases.  In the later stages, we will use the noun phrases as the query to search for relevant documents.  

In [6]:
def get_nps_hanlp(
    predictor: Pipeline,
    d: Dict[str, Union[int, Claim, Evidence]],
) -> List[str]:
    claim = d["claim"]
    tree = predictor(claim)["con"]
    nps = [
        do_st_corrections("".join(subtree.leaves()))
        for subtree in tree.subtrees(lambda t: t.label() == "NP")
    ]

    return nps

Precision refers to how many related documents are retrieved.  Recall refers to how many relevant documents are retrieved.  

In [5]:
def calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> float:
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision
    print(f"Precision: {precision / count}")
    return precision / count


def calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> float:
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        recall += len(hits) / len(gt_pages)
        count += 1

    print(f"Recall: {recall / count}")
    return recall / count

In [6]:
def calculate_f1(precision: float, recall: float) -> float:
    return 2*(precision*recall)/(precision+recall)

The default amount of documents retrieved is at most five documents.  This `num_pred_doc` can be adjusted based on your objective.  Save data in jsonl format.

In [42]:
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    suffix: str = "",
    num_pred_doc: int = 5,
    col_name = "predicted_pages"
) -> None:
    with open(
        f"data/{mode}_doc{num_pred_doc}{suffix}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d[col_name] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

In [10]:
import jieba
jieba.set_dictionary('data/jieba_dict/dict.txt.big')
jieba.initialize()

Building prefix dict from d:\VSCodeProject\Gaber_AICUP2023\data\jieba_dict\dict.txt.big ...
Dumping model to file cache C:\Users\GABERI~1\AppData\Local\Temp\jieba.ueab88e364d2cf84fdc7d374c3dcfbc37.cache
Loading model cost 1.072 seconds.
Prefix dict has been built successfully.


In [11]:
def tokenize(text: str, stopwords: list) -> str:
    tokens = jieba.cut(text)

    return " ".join([w for w in tokens if w not in stopwords])

In [12]:
wiki_path = "data/wiki-pages"
min_wiki_length = 10
topk = 50
min_df = 1
max_df = 0.8
use_idf = True
sublinear_tf = True

In [13]:
wiki_cache = "wiki"
target_column = "text"

wiki_cache_path = Path(f"data/{wiki_cache}.pkl")
if wiki_cache_path.exists():
    wiki_pages = pd.read_pickle(wiki_cache_path)
else:
    def text_split(line: str) -> list:
        import re
        line = re.sub(r"[0-9]+\t", "", line)
        lines = line.split("\n")
        lines = list(filter(None, lines))
        return lines
    # You need to download `wiki-pages.zip` from the AICUP website
    wiki_pages = jsonl_dir_to_df(wiki_path)
    # wiki_pages are combined into one dataframe, so we need to reset the index
    wiki_pages = wiki_pages.reset_index(drop=True)

    # tokenize the text and keep the result in a new column `processed_text`
    wiki_pages["lines"] = wiki_pages["lines"].parallel_apply(text_split)
    wiki_pages["processed_text"] = wiki_pages[target_column].parallel_apply(
        partial(tokenize, stopwords=stopwords)
    )
    # save the result to a pickle file
    wiki_pages.to_pickle(wiki_cache_path, protocol=4)

In [14]:
wiki_pages = wiki_pages[
    wiki_pages['processed_text'].str.len() > min_wiki_length
]

### Tfidf Init

In [15]:
corpus = wiki_pages["processed_text"].tolist()

In [16]:
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

In [17]:
import gensim.models
w2vmodel = gensim.models.Word2Vec.load("models/w2v.zh.300/word2vec.model")
w2v = dict(zip(w2vmodel.wv.index_to_key, w2vmodel.wv.vectors))

In [18]:
from collections import defaultdict
class TfidfEmbeddingVectorizer(object):
    def __init__(self, word2vec, size=300):
        self.word2vec = word2vec
        self.word2weight = None
        self.dim = size
    
    def fit(self, X):
        tfidf = TfidfVectorizer(
            min_df=min_df,
            max_df=max_df,
            use_idf=use_idf,
            sublinear_tf=sublinear_tf,
            dtype=np.float64,
            analyzer=lambda x: x
        )
        tfidf.fit(X)
        # if a word was never seen - it must be at least as infrequent
        # as any of the known words - so the default idf is the max of 
        # known idf's
        max_idf = max(tfidf.idf_)
        self.word2weight = defaultdict(
            lambda: max_idf,
            [(w, tfidf.idf_[i]) for w, i in tfidf.vocabulary_.items()])

        return self

    def transform(self, X):
        return np.array([
                np.mean([self.word2vec[w] * self.word2weight[w]
                         for w in words if w in self.word2vec] or
                        [np.zeros(self.dim)], axis=0)
                for words in X
            ])

In [19]:
# vectorizer = TfidfEmbeddingVectorizer(w2v)

In [17]:
vectorizer = TfidfVectorizer(
    min_df=min_df,
    max_df=max_df,
    use_idf=use_idf,
    sublinear_tf=sublinear_tf,
    # dtype=np.float64,
    ngram_range=(1,2),
    # norm=None,
)

In [18]:
# X = vectorizer.fit(corpus).transform(corpus)
X = vectorizer.fit_transform(corpus)

### Sentence BERT

In [25]:
from sentence_transformers import SentenceTransformer, util

sbert_model = SentenceTransformer('uer/sbert-base-chinese-nli', device='cuda')

No sentence-transformers model found with name /home/P78081057/.cache/torch/sentence_transformers/uer_sbert-base-chinese-nli. Creating a new one with MEAN pooling.


In [26]:
pool = sbert_model.start_multi_process_pool()
print(pool)

{'input': <multiprocessing.queues.Queue object at 0x7fdad5f70a30>, 'output': <multiprocessing.queues.Queue object at 0x7fda2210cbb0>, 'processes': [<SpawnProcess name='SpawnProcess-49' pid=1527826 parent=1527156 started daemon>, <SpawnProcess name='SpawnProcess-50' pid=1527842 parent=1527156 started daemon>]}


### Main function for document retrieval

In [21]:
def get_pred_pages(
        series_data: pd.Series, 
        ) -> Set[Dict[int, str]]:
    import wikipedia
    import re
    import opencc
    import pandas as pd

    from TCSP import read_stopwords_list
    stopwords = read_stopwords_list()

    import numpy as np
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    wikipedia.set_lang("zh")
    CONVERTER_T2S = opencc.OpenCC("t2s.json")
    CONVERTER_S2T = opencc.OpenCC("s2t.json")
    
    def do_st_corrections(text: str) -> str:
        simplified = CONVERTER_T2S.convert(text)
        return CONVERTER_S2T.convert(simplified)

    results = []
    tmp_muji = []
    # wiki_page: its index showned in claim
    mapping = {}
    claim = series_data["claim"]
    nps = series_data["hanlp_results"]
    first_wiki_term = []
    repeated_mention = []
    quote_search = []

    def clean_claim(claim) -> str:     # Clean claim function because hanlp has error when conducting cons
        def multiple_replacer(*kv):
            replace_dict = dict(kv)
            replace_func = lambda match: replace_dict[match.group(0)]
            pattern = re.compile("|".join([re.escape(k) for k, v in kv]), re.M)
            return lambda string: pattern.sub(replace_func, string) 
        def multiple_replace(string, *kv):
            return multiple_replacer(*replace_dict)(claim)

        replace_dict = (" ", ""), ("牠", "它"), ("（", "("), ("）", ")"), ("，", ","), ("、", ","), ("群", "羣"), ("“", "\""), ("”", "\""), ("「", "“"), ("」", "”")
        claim = multiple_replace(claim, *replace_dict)
        claim = claim.lower()
        return claim

    claim = clean_claim(claim)

    def post_processing(np, page, loc):
        page = do_st_corrections(page)
        page = page.replace(" ", "_")
        page = page.replace("-", "")
        search_pos = claim.find(np)
        if search_pos != -1:
            if page in results:
                repeated_mention.append(page)
                # results.insert(0, results.pop(results.index(page)))     # Fresh page to front if it was mention before
            else:
                results.append(page)
            if loc == 0:
                pass
                # print(f"Add: {page}, at page direct search, np={np}")
            elif loc == 1:
                pass
                # print(f"Add: {page}, at match, new term={np}")
            mapping[page] = search_pos
            tmp_muji.append(np)

    def if_page_exists(page: str) -> bool:
        import requests
        url_base = "https://zh.wikipedia.org/wiki/"
        new_url = [url_base + page, url_base + page.upper()]
        for url in new_url:
            r = requests.head(url)
            if r.status_code == 200:
                return True
            else:
                continue
        return False
    
    def clean_time_format(np: str):
        if (matched := re.search(r"\d+年", np)) != None:
            return True
        if (matched := re.search(r"\d+月\d+日", np)) != None:
            return True
        if (matched := re.search(r"\d+小時", np)) != None:
            return True
        if (matched := re.search(r"\d+天", np)) != None:
            return True
        if (matched := re.search(r"\d+世紀", np)) != None:
            return True
        if (matched := re.search(r"\d+年代", np)) != None:
            return True
        return False
    
    def tokenize(text: str, stopwords: list) -> str:
        import jieba
        """This function performs Chinese word segmentation and removes stopwords.

        Args:
            text (str): claim or wikipedia article
            stopwords (list): common words that contribute little to the meaning of a sentence

        Returns:
            str: word segments separated by space (e.g. "我 喜歡 吃 蘋果")
        """

        tokens = jieba.cut(text)

        return " ".join([w for w in tokens if w not in stopwords])

    for i, np in enumerate(nps):
        # print(f"searching {np}")
        quote_dup = False
        if np in stopwords:         # 如果包含停用詞
            continue
        if clean_time_format(np):   # 如果包含時間
            continue
        
        # Ignore parsing among quotation mark, for example, if《仲夏夜之夢》exists, ignore「仲夏夜」and「夢」
        for search in quote_search:
            if search.find(np) != -1:
                quote_dup = True
        if quote_dup == True:
            continue

        # Delete Bookname Mark, Quote Mark
        np_no_quote = re.sub(r"《|》|〈|〉|【|】|「|」|『|』|（|）", "", np)
        if np != np_no_quote:
            quote_search.append(np_no_quote)
            np = np_no_quote

        # Simplified Traditional Chinese Correction
        wiki_search_results = [
            do_st_corrections(w) for w in wikipedia.search(np)
        ]

        # Directly Search by Redirection
            # Check if a page exists
        if (if_page_exists(np)):
            try:
                page = do_st_corrections(wikipedia.page(title=np).title)
                if page == np:
                    # print(f"Found, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
                else:
                    # print(f"Redirect, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
            except wikipedia.DisambiguationError as diserr:
                page = do_st_corrections(wikipedia.search(np)[0])
                if page == np:
                    # print(f"Disambig, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
            except wikipedia.PageError as pageerr:
                pass

        # Remove the wiki page's description in brackets
        wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
        wiki_df = pd.DataFrame({
            "wiki_set": wiki_set,
            "wiki_results": wiki_search_results
        })

        # Elements in wiki_set --> index
        # Extracting only the first element is one way to avoid extracting
        # too many of the similar wiki pages
        grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
        candidates = grouped_df["wiki_results"].tolist()
        # muji refers to wiki_set
        muji = grouped_df.index.tolist()

        for prefix, term in zip(muji, candidates):
            if prefix not in tmp_muji:  #忽略掉括號，如果括號有重複的話。假設如果有" 1 (數字)", 則"1 (符號)" 會被忽略
                matched = False

                # Take at least one term from the first noun phrase
                if i == 0:
                    first_wiki_term.append(term)

                # try:
                #     print(term)
                #     term_idx = wiki_pages.index[wiki_pages['id'] == do_st_corrections(term.replace(" ", "_").replace("-", ""))].tolist()[0]
                #     processed_tokens = wiki_pages['processed_text'][term_idx]
                #     processed_text_vector = vectorizer.transform([processed_tokens])
                #     sim_score = cosine_similarity(processed_text_vector, claim_vector)[0][0]
                #     if sim_score > 0.25: # 0.25 is hyperparam
                #         score_mapping[term] = sim_score
                #         print(sim_score, term)
                # except IndexError:
                #     pass
                # except wikipedia.DisambiguationError:
                #     pass
                # except wikipedia.PageError:
                #     pass

                # Walrus operator :=
                # https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions
                # Through these filters, we are trying to figure out if the term
                # is within the claim
                if (((new_term := term) in claim) or
                    ((new_term := term) in claim.replace(" ", "")) or
                    ((new_term := term.replace("·", "")) in claim) or                                   # 過濾人名
                    ((new_term := re.sub(r"\s\(\S+\)", "", term)) in claim) or                          # 過濾空格 / 消歧義
                    ((new_term := term.replace("(", "").replace(")", "").split()[0]) in claim and       # 消歧義與括號內皆有在裡面
                     (new_term := term.replace("(", "").replace(")", "").split()[1]) in claim) or
                    ((new_term := term.replace("-", " ")) in claim) or                                  # 過濾槓號
                    ((new_term := term.lower()) in claim) or                                            # 過濾大小寫
                    ((new_term := term.lower().replace("-", "")) in claim) or                           # 過濾大小寫及槓號
                    ((new_term := re.sub(r"\s\(\S+\)", "", term.lower().replace("-", ""))) in claim)    # 過濾大小寫、槓號及消歧義
                    ):
                    matched = True
                    # print(new_term, term)

                # 人名匹配
                elif "·" in term:
                    splitted = term.split("·")
                    if "·" not in claim:        # 要求claim顯示的不為全名，不然都需要全名
                        for split in splitted:
                            if (new_term := split) in claim:
                                matched = True
                                break

                if matched:
                    post_processing(np=new_term, page=term, loc=1)

    # score_results = sorted(score_mapping, key=score_mapping.get)[:-5]

    # 8 is a hyperparameter
    if len(results) > 8:
        assert -1 not in mapping.values()
        # print("長度大於8", results)

        results = repeated_mention + sorted(mapping, key=mapping.get)[:8]
        results = list(set(results))            # remove duplicates
        # print("排序後", results)
    if len(results) < 1:
        results = first_wiki_term
        # print("第一搜尋結果", results)
    
    print(results)
    return set(results)

In [19]:
def get_pred_pages_search(
        series_data: pd.Series, 
        ):
    import wikipedia
    import re
    import opencc
    import pandas as pd

    import numpy as np

    wikipedia.set_lang("zh")
    CONVERTER_T2S = opencc.OpenCC("t2s.json")
    CONVERTER_S2T = opencc.OpenCC("s2t.json")
    
    def do_st_corrections(text: str) -> str:
        simplified = CONVERTER_T2S.convert(text)
        return CONVERTER_S2T.convert(simplified)
    
    def if_page_exists(page: str) -> bool:
        import requests
        url_base = "https://zh.wikipedia.org/wiki/"
        new_url = [url_base + page, url_base + page.upper()]
        for url in new_url:
            r = requests.head(url)
            if r.status_code == 200:
                return True
            else:
                continue
        return False

    claim = series_data["claim"]
    results = []
    direct_results = []
    nps = series_data["hanlp_results"]
    nps.append(claim)

    def post_processing(page):
        page = do_st_corrections(page)
        page = page.replace(" ", "_")
        page = page.replace("-", "")

    for i, np in enumerate(nps):
        # print(f"searching {np}")

        if (if_page_exists(np)):
            try:
                page = do_st_corrections(wikipedia.page(title=np).title)
                if page == np:
                    # print(f"Found, np={np}, page={page}, claim={claim}")
                    post_processing(page)
                    direct_results.append(page)
                else:
                    # print(f"Redirect, np={np}, page={page}, claim={claim}")
                    post_processing(page)
                    direct_results.append(page)
            except wikipedia.DisambiguationError as diserr:
                for option in diserr.options:
                    option = do_st_corrections(option)
                    if new_option := re.sub(r"\s\(\S+\)", "", option) in claim:
                        # print(f"Disambig, np={np}, page={option}, claim={claim}")
                        post_processing(option)
                        direct_results.append(option)
                    post_processing(option)
                    results.append(option)
                page = do_st_corrections(wikipedia.search(np)[0])
                if page == np:
                    # print(f"Disambig, np={np}, page={page}, claim={claim}")
                    post_processing(page)
                    direct_results.append(page)
            except wikipedia.PageError as pageerr:
                pass

        # Simplified Traditional Chinese Correction
        wiki_search_results = [
            do_st_corrections(w) for w in wikipedia.search(np)
        ]

        for term in wiki_search_results:
            if (((new_term := term) in claim) or
                ((new_term := term) in claim.replace(" ", "")) or
                ((new_term := term.replace("·", "")) in claim) or                                   # 過濾人名
                ((new_term := re.sub(r"\s\(\S+\)", "", term)) in claim) or                          # 過濾空格 / 消歧義
                ((new_term := term.replace("(", "").replace(")", "").split()[0]) in claim and       # 消歧義與括號內皆有在裡面
                    (new_term := term.replace("(", "").replace(")", "").split()[1]) in claim) or
                ((new_term := term.replace("-", " ")) in claim) or                                  # 過濾槓號
                ((new_term := term.lower()) in claim) or                                            # 過濾大小寫
                ((new_term := term.lower().replace("-", "")) in claim) or                           # 過濾大小寫及槓號
                ((new_term := re.sub(r"\s\(\S+\)", "", term.lower().replace("-", ""))) in claim)    # 過濾大小寫、槓號及消歧義
                ):
                post_processing(term)
                direct_results.append(term)
            # if prefix not in tmp_muji:  #忽略掉括號，如果括號有重複的話。假設如果有" 1 (數字)", 則"1 (符號)" 會被忽略
            post_processing(term)
            results.append(term)

    direct_results = list(set(direct_results))
    results = list(set(results))            # remove duplicates
    series_data["predicted_pages"] = results
    series_data["direct_match"] = direct_results

    return series_data

In [40]:
def get_pred_pages_sbert(
    series_data: pd.Series, 
    tokenizing_method: callable,
    # model: SentenceTransformer,
    # wiki_pages: pd.DataFrame,
    topk: int,
    threshold: float,
    i: int,
) -> set:
    # Disable huggingface tokenizor parallelism warning
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

    import torch.cuda as cuda
    cuda.empty_cache()
    
    # Parameters:
    THRESHOLD_LOWEST = 0.6
    THRESHOLD_SIM_LINE = threshold
    WEIGHT_SIM_ID = 0.05    # The lower it is, the higher sim_id is when it directly matches claim.
    
    def sim_score_eval(sim_line, sim_id):
        if len(claim) > 15:
            if sim_line > THRESHOLD_SIM_LINE:
                res = 2*(1.1*sim_line*1.1*sim_id)/(1.1*sim_line+1.1*sim_id)
            else:
                res = 0
        else:
            res = sim_id
        
        return res
    
    
    def post_processing(page) -> str:
        import opencc
        CONVERTER_T2S = opencc.OpenCC("t2s.json")
        CONVERTER_S2T = opencc.OpenCC("s2t.json")
    
        simplified = CONVERTER_T2S.convert(page)
        page = CONVERTER_S2T.convert(simplified)
        page = page.replace(" ", "_")
        page = page.replace("-", "")
        return page

    claim = series_data["claim"]
    search_list = series_data["predicted_pages"]
    direct_search = series_data["direct_match"]
    results = []
    mapping = {}
    df_res = []

    tokens = tokenizing_method(claim)
    emb_claim_tok = sbert_model.encode(tokens)
    emb_claim = sbert_model.encode(claim)

    search_list = [post_processing(id) for id in search_list]
    '''
    if series_data["label"] != "NOT ENOUGH INFO":
        gt_pages = set([
            evidence[2]
            for evidence_set in series_data["evidence"]
            for evidence in evidence_set
        ])
    else:
        gt_pages = set([])
    '''

    for search_id in search_list:
        # print(search_id)
        search_series = wiki_pages.loc[wiki_pages['id'] == search_id]
        if search_series.empty:
            continue
        try:
            for temp in search_series["lines"]:
                search_lines = temp
        except:
            continue

        if len(search_lines) == 0:
             continue
        search_id_tok = tokenizing_method(search_id)
        emb_id = sbert_model.encode(search_id_tok)
        sim_id = util.pytorch_cos_sim(emb_id, emb_claim).numpy()
        sim_id = sim_id[0][0]
        new_sim_id = 0
        if search_id in direct_search:
            if sim_id > 0:
                new_sim_id = 1-((1-sim_id)*WEIGHT_SIM_ID)
            else:
                sim_id = 0
                new_sim_id = 1-((1-sim_id)*WEIGHT_SIM_ID)
        else:
            new_sim_id = sim_id

        sim_score = 0
        sim_line = 0
        sim_line_b = 0

        embs = sbert_model.encode_multi_process(search_lines, pool=pool)
        for emb in embs:
            sim = util.pytorch_cos_sim(emb, emb_claim).numpy()
            sim = sim[0][0]
            sim_line = max(sim, sim_line)

        search_lines_tok = [tokenizing_method(line) for line in search_lines]
        embs = sbert_model.encode_multi_process(search_lines_tok, pool=pool)
        for emb in embs:
            sim = util.pytorch_cos_sim(emb, emb_claim_tok).numpy()
            sim = sim[0][0]
            sim_line = max(sim, sim_line)

        if sim_line > THRESHOLD_SIM_LINE:
            sim_line = max(sim_line, sim_line_b)
            sim_line_b = sim_line
            sim_score = sim_score_eval(sim_line, new_sim_id)
            sim_score = max(sim_score, sim_line_b)
            # print(sim_score, sim_line, search_id)
            if sim_score > THRESHOLD_LOWEST:
                search_id = post_processing(search_id)
                if search_id in mapping:
                    mapping[search_id] = max(sim_score, mapping[search_id])
                else:
                    mapping[search_id] = sim_score
        data = (claim, search_id, sim_id, new_sim_id, sim_line, sim_score)
        df_res.append(data)

    mapping_sorted = sorted(mapping.items(), key=lambda x:x[1], reverse=True)
    # print(mapping_sorted[:topk])
    DIFF = 0.125
    for k, v in mapping_sorted:
        THRESHOLD_TOP = v
        break
    if len(mapping_sorted) >= topk:
        results = [k for k, v in mapping_sorted if v > THRESHOLD_TOP-DIFF][:topk]
    else:
        results = [k for k, v in mapping_sorted if v > THRESHOLD_LOWEST][:topk]
    if not results:
        results = [k for k, v in mapping_sorted][:topk]
    if not results:
        results = series_data["direct_match"]
    if not results:
        results = series_data["predicted_pages"][:topk]
    # print(results)

    # Analysis on missed pages
    '''
    if series_data["label"] != "NOT ENOUGH INFO":
        for page in gt_pages:
            if page in mapping:
                if page not in results:
                    print(f"Missed: ID={page}, score={mapping[page]}")
                else:
                    continue
            else:
                if page not in search_list:
                    print(f"Missed: ID={page}, not in search_list")
                else:
                    print(f"Missed: ID={page}, score < {THRESHOLD_LOWEST}")
    '''
    df = pd.DataFrame(df_res, columns=['Claim', 'Search_ID', 'Sim_ID', 'Sim_ID_Adjusted', 'Sim_Line', 'Sim_Score'])

    with open(f"data/train_doc5_logging_0522_{i}.jsonl", "a", encoding="utf8") as f:
        f.write(df.to_json(orient='records', lines=True, force_ascii=False))
    

    return set(results)

In [None]:
claim = "天衛三軌道在天王星內部的磁層，以《 仲夏夜之夢 》作者緹坦妮雅命名。"
proof = "1787年由威廉·赫雪爾發現，並以威廉·莎士比亞的《仲夏夜之夢》中的妖精王后緹坦妮雅命名。"
claim_tok = tokenize(claim, stopwords=stopwords)
proof_tok = tokenize(proof, stopwords=stopwords)
print(claim_tok)
print(proof_tok)

emb_claim = sbert_model.encode(claim_tok)
emb_proof = sbert_model.encode(proof_tok)
print(util.pytorch_cos_sim(emb_proof, emb_claim).numpy()[0][0])

emb_claim = sbert_model.encode(claim)
emb_proof = sbert_model.encode(proof)
print(util.pytorch_cos_sim(emb_proof, emb_claim).numpy()[0][0])

In [19]:
import scipy
def get_pred_pages_tfidf(
    series_data: pd.Series, 
    tokenizing_method: callable,
    vectorizer: TfidfVectorizer,
    tf_idf_matrix: scipy.sparse.csr_matrix,
    wiki_pages: pd.DataFrame,
    topk: int,
    threshold: float
) -> set:
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    claim = series_data["claim"]
    results = []

    tokens = tokenizing_method(claim)
    claim_vector = vectorizer.transform([tokens])
    sim_scores = cosine_similarity(tf_idf_matrix, claim_vector)
    sim_scores = sim_scores[:, 0]  # flatten the array
    sorted_indices = np.argsort(sim_scores)[::-1]
    topk_sorted_indices = sorted_indices[:topk]
    results = wiki_pages.iloc[topk_sorted_indices]["id"]

    # for search_id in search_list:
    #     search_tokens = wiki_pages.loc[wiki_pages['id'] == search_id]
    #     if search_tokens.empty:
    #         continue
    #     search_processed_text = search_tokens["processed_text"]
    #     search_vector = vectorizer.transform(search_processed_text)
    #     sim_scores = cosine_similarity(search_vector, claim_vector)
    #     sim_scores = sim_scores[0][0]
    #     if sim_scores > threshold:
    #         mapping[search_id] = sim_scores
            # print(sim_scores, search_id)

    # print(mapping)
    # results = sorted(mapping, key=mapping.get, reverse=True)[:topk]
    # print(results)
    return set(results)


### Step 1. Get noun phrases from hanlp consituency parsing tree

Setup [HanLP](https://github.com/hankcs/HanLP) predictor (1 min)

In [21]:
predictor = (hanlp.pipeline().append(
    hanlp.load("FINE_ELECTRA_SMALL_ZH"),
    output_key="tok",
).append(
    hanlp.load("CTB9_CON_ELECTRA_SMALL"),
    output_key="con",
    input_key="tok",
))

                                             

We will skip this process which for creating parsing tree when demo on class

In [22]:
hanlp_file = f"data/hanlp_con_results_0522.pkl"
if Path(hanlp_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

Get pages via wiki online api

In [None]:
from google.colab import output
output.enable_custom_widget_manager()

In [35]:
from tqdm.notebook import tqdm
tqdm.pandas()

In [16]:
doc_path = f"data/train_doc5.jsonl"
doc_path_aicup = f"data/train_doc5_aicup.jsonl"
doc_path_sbert = f"data/train_doc5_sbert.jsonl"
doc_path_search = f"data/train_doc5_search_0522.jsonl"
doc_path_tfidf = f"data/train_doc5_tfidf_0522.jsonl"

In [35]:
batch = 2000
start = 6000
# if Path(doc_path_search).exists():
#     with open(doc_path_search, "r", encoding="utf8") as f:
#         predicted_results_search = pd.Series([
#             set(json.loads(line)["predicted_pages"])
#             for line in f
#         ], name="search")
# else:
pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=32)
train_df = pd.DataFrame(TRAIN_DATA[start:start+batch])
train_df.loc[:, "hanlp_results"] = hanlp_results[start:start+batch]
train_df_search = train_df.parallel_apply(
    get_pred_pages_search, axis=1)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=53), Label(value='0 / 53'))), HBox…



  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


  lis = 

In [36]:
predicted_results_search_p = train_df_search["predicted_pages"]
predicted_results_search_d = train_df_search["direct_match"]
save_doc(TRAIN_DATA[start:start+batch], predicted_results_search_p, mode="train", suffix=f"_search_{start}p", col_name="predicted_pages")
TRAIN_DATA_SEARCH = load_json(f"data/train_doc5_search_{start}p.jsonl")
save_doc(TRAIN_DATA_SEARCH, predicted_results_search_d, mode="train", suffix=f"_search_{start}d", col_name="direct_match")

In [None]:
if Path(doc_path).exists():
    with open(doc_path_aicup, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ])
else:
    if Path(doc_path_search).exists():
        with open(doc_path_search, "r", encoding="utf8") as f:
            predicted_results_search = pd.Series([
                set(json.loads(line)["predicted_pages"])
                for line in f
            ], name="search")
    else:
        pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
        train_df = pd.DataFrame(TRAIN_DATA)
        train_df.loc[:, "hanlp_results"] = hanlp_results
        # predicted_results = train_df.progress_apply(get_pred_pages, axis=1)
        predicted_results_search = train_df.parallel_apply(
            get_pred_pages_search, axis=1)
        save_doc(TRAIN_DATA, predicted_results_search, mode="train", suffix="_search")

    if Path(doc_path_aicup).exists():
        with open(doc_path_aicup, "r", encoding="utf8") as f:
            predicted_results_aicup = pd.Series([
                set(json.loads(line)["predicted_pages"])
                for line in f
            ], name="aicup")
    else:
        pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
        train_df = pd.DataFrame(TRAIN_DATA)
        train_df.loc[:, "hanlp_results"] = hanlp_results
        # predicted_results = train_df.progress_apply(get_pred_pages, axis=1)
        predicted_results_aicup = train_df.parallel_apply(
            get_pred_pages, axis=1)
        save_doc(TRAIN_DATA, predicted_results_aicup, mode="train", suffix="_aicup")

    # if Path(doc_path_tfidf).exists():
    #     with open(doc_path_tfidf, "r", encoding="utf8") as f:
    #         predicted_results_tfidf = pd.Series([
    #             set(json.loads(line)["predicted_pages"])
    #             for line in f
    #         ], name="tfidf")
    # else:
    #     pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=3)
    #     TRAIN_DATA_SEARCH = load_json(doc_path_search)
    #     train_df_search = pd.DataFrame(TRAIN_DATA_SEARCH)
    #     predicted_results_tfidf = train_df_search.parallel_apply(
    #         partial(
    #             get_pred_pages_tfidf,
    #             tokenizing_method=partial(tokenize, stopwords=stopwords),
    #             vectorizer=vectorizer,
    #             tf_idf_matrix=X,
    #             wiki_pages=wiki_pages,
    #             topk=topk,
    #             threshold=0.65
    #         ), axis=1)
    #     save_doc(TRAIN_DATA, predicted_results_tfidf, mode="train", suffix="_tfidf")

On TF-IDF data:

In [28]:
num_of_samples = 500
start = 7000

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=2)
# train_df = pd.DataFrame(TRAIN_DATA)
# old_precision = 0.017423887587821863
# old_recall = 0.8056206088992974
# old_f1 = 0.03411004626798726
for i in range(14, 16):
    train_df_batch = pd.DataFrame(TRAIN_DATA[start:start+num_of_samples])
    predicted_results_tfidf = train_df_batch.parallel_apply(
        partial(
            get_pred_pages_tfidf,
            tokenizing_method=partial(tokenize, stopwords=stopwords),
            vectorizer=vectorizer,
            tf_idf_matrix=X,
            wiki_pages=wiki_pages,
            topk=50,
            threshold=0.0
        ), axis=1)
    save_doc(TRAIN_DATA[start:start+num_of_samples], predicted_results_tfidf, mode="train", suffix=f"_tfidf_0522_{start}")
    if i <= 13:
        print(f"On TFIDF top 50 Data, batch = {i}:")
        precision = calculate_precision(TRAIN_DATA[start:start+num_of_samples], predicted_results_tfidf)
        # print(f"(Diff: {precision-old_precision})")
        recall = calculate_recall(TRAIN_DATA[start:start+num_of_samples], predicted_results_tfidf)
        # print(f"(Diff: {recall-old_recall})")
        f1 = calculate_f1(precision, recall)
        print(f"F1-Score: {f1}")
    start += num_of_samples

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=250), Label(value='0 / 250'))), HB…

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=89), Label(value='0 / 89'))), HBox…

In [30]:
batch = 500
start = 0
with open(f'data/train_doc5_tfidf_0522_{start}.jsonl') as fp:
    data = fp.read()
for i in range(1, 16):
    with open(f'data/train_doc5_tfidf_0522_{start+batch}.jsonl') as fp:
        data2 = fp.read()
        data += data2
    start += batch

with open (f'data/train_doc5_tfidf_0522.jsonl', 'w') as fp:
    fp.write(data)
with open(f"data/train_doc5_tfidf_0522.jsonl", "r", encoding="utf8") as f:
    predicted_results_tfidf = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="tfidf")
precision = calculate_precision(TRAIN_DATA, predicted_results_tfidf)
# print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA, predicted_results_tfidf)
# print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")

Precision: 0.017680763983628816
Recall: 0.8118300526213217
F1-Score: 0.0346078080427261


In [46]:
def union_result(series_data: pd.Series,) -> set:
    tfidf = series_data["tfidf"]
    search = series_data["search"]
    # print(tfidf, search)
    return set(set(tfidf).union(set(search)))

In [47]:
old_precision = 0.018778644595822937
old_recall = 0.9515607743779643
old_f1 = 0.03683045590844447

In [48]:
with open(doc_path_search, "r", encoding="utf8") as f:
    predicted_results_search = pd.Series([
        set(json.loads(line)["direct_match"])
        for line in f
    ], name="search")
with open(f"data/train_doc5_tfidf_0522.jsonl", "r", encoding="utf8") as f:
    predicted_results_tfidf = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="tfidf")
# total = 0
# for data in predicted_results_search:
#     total += len(data)
# print(total)

results_df = pd.merge(pd.Series([line for line in predicted_results_tfidf], name="tfidf"), 
                      pd.Series([line for line in predicted_results_search], name="search"), right_index=True, left_index=True)
predicted_results = results_df.apply(union_result, axis=1)
save_doc(TRAIN_DATA, predicted_results, mode="train", suffix=f"_tfidf_0522_with_d", col_name="predicted_pages")
total = 0
for data in predicted_results:
    total += len(data)
print(total)
precision = calculate_precision(TRAIN_DATA, predicted_results)
print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA, predicted_results)
print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")
print(f"(Diff: {f1-old_f1})")

433861
Precision: 0.018778644595822937
(Diff: 0.0)
Recall: 0.9515607743779643
(Diff: 0.0)
F1-Score: 0.03683045590844447
(Diff: 0.0)


In [49]:
with open(f"data/train_doc5_tfidf_0522_with_d.jsonl", "r", encoding="utf8") as f:
    predicted_results_tfidf = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="tfidf")
total = 0
for data in predicted_results_tfidf:
    total += len(data)
print(total)
precision = calculate_precision(TRAIN_DATA, predicted_results_tfidf)
print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA, predicted_results_tfidf)
print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")
print(f"(Diff: {f1-old_f1})")

433861
Precision: 0.018778644595822937
(Diff: 0.0)
Recall: 0.9515607743779643
(Diff: 0.0)
F1-Score: 0.03683045590844447
(Diff: 0.0)


On Search Data:

In [24]:
batch = 2000
start = 0
with open(f'data/train_doc5_search_{start}d.jsonl') as fp:
    data = fp.read()
with open(f'data/train_doc5_search_{start+batch}d.jsonl') as fp:
    data2 = fp.read()
    data += data2
with open(f'data/train_doc5_search_{start+batch*2}d.jsonl') as fp:
    data2 = fp.read()
    data += data2
with open(f'data/train_doc5_search_{start+batch*3}d.jsonl') as fp:
    data2 = fp.read()
    data += data2

with open (f'data/train_doc5_search_0522.jsonl', 'w') as fp:
    fp.write(data)

In [18]:
def clean(series_data):
    def post_processing(page) -> str:
        page = page.replace(" ", "_")
        page = page.replace("-", "")
        return page
    
    result = []
    for element in series_data:
        # print(series_data)
        element = post_processing(element)
        if "Template:" in element:
            continue
        result.append(element)

    return set(result)
doc_path_search = f"data/train_doc5_search.jsonl"
doc_path_search_backup = f"data/train_doc5_search_backup.jsonl"
TRAIN_DATA_SEARCH1 = load_json(doc_path_search_backup)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=12)
train_df = pd.DataFrame(TRAIN_DATA)
train_df.loc[:, "hanlp_results"] = hanlp_results
train_df_search1 = pd.DataFrame(TRAIN_DATA_SEARCH1)
predicted_results_search = train_df_search1.loc[:, "predicted_pages"]
direct_match = train_df_search1.loc[:, "direct_match"]
# print(predicted_results_search)
predicted_results_search = predicted_results_search.apply(clean)
direct_match = direct_match.apply(clean)

# predicted_results = train_df.progress_apply(get_pred_pages, axis=1)
save_doc(TRAIN_DATA, predicted_results_search, mode="train", suffix="_search", col_name="predicted_pages")
TRAIN_DATA_SEARCH = load_json(doc_path_search)
# direct_match = train_df.parallel_apply(
#     get_pred_pages_search, axis=1)
save_doc(TRAIN_DATA_SEARCH, direct_match, mode="train", suffix="_search", col_name="direct_match")

FileNotFoundError: [Errno 2] No such file or directory: 'data/train_doc5_search_backup.jsonl'

In [29]:
if Path(doc_path_search).exists():
    with open(doc_path_search, "r", encoding="utf8") as f:
        predicted_results_search = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ], name="search")

In [50]:
TRAIN_DATA_TFIDF = load_json(f"data/train_doc5_tfidf_0522_with_d.jsonl")
with open(f"data/train_doc5_search_0522.jsonl", "r", encoding="utf8") as f:
    direct_match = pd.Series([
        set(json.loads(line)["direct_match"])
        for line in f
    ], name="direct_match")
save_doc(TRAIN_DATA_TFIDF, direct_match, mode="train", suffix=f"_tfidf_0522_with_d", col_name="direct_match")

In [51]:
num_of_samples = 500
TRAIN_DATA_TFIDF = load_json(f"data/train_doc5_tfidf_0522_with_d.jsonl")

for i in range(0, 1):
    start = i*num_of_samples
    train_df_search = pd.DataFrame(TRAIN_DATA_TFIDF[start:start+num_of_samples])
    # pandarallel.initialize(progress_bar=False, verbose=0, nb_workers=10)
    predicted_results_sbert = train_df_search.progress_apply(
        partial(
            get_pred_pages_sbert,
            tokenizing_method=partial(tokenize, stopwords=stopwords),
            # model=sbert_model,
            # wiki_pages=wiki_pages,
            topk=5,
            threshold=0.375,
            i = i,
        ), axis=1)
    save_doc(TRAIN_DATA[start:start+num_of_samples], predicted_results_sbert, mode="train", suffix=f"_sbert_0522_{start}")
    if i <= 13:
        print(f"On Sbert Data, batch = {i}:")
        precision = calculate_precision(TRAIN_DATA[start:start+num_of_samples], predicted_results_sbert)
        # print(f"(Diff: {precision-old_precision})")
        recall = calculate_recall(TRAIN_DATA[start:start+num_of_samples], predicted_results_sbert)
        # print(f"(Diff: {recall-old_recall})")
        f1 = calculate_f1(precision, recall)
        print(f"F1-Score: {f1}")

  0%|          | 0/500 [00:00<?, ?it/s]

In [37]:
print(f"On Sbert Data, batch = {i}:")
precision = calculate_precision(TRAIN_DATA[0:500], predicted_results_sbert)
# print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA[0:500], predicted_results_sbert)
# print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")

On Sbert Data, batch = 0:
Precision: 0.5401899557637262
Recall: 0.7907884465261513
F1-Score: 0.6418976825055823


In [23]:
batch = 500
start = 0
with open(f'data/train_doc5_sbert_0522_{start}.jsonl') as fp:
    data = fp.read()
for i in range(1, 16):
    with open(f'data/train_doc5_sbert_0522_{start+batch}.jsonl') as fp:
        data2 = fp.read()
        data += data2
    start += batch

with open (f'data/train_doc5_sbert_0522.jsonl', 'w') as fp:
    fp.write(data)

In [24]:
with open(f"data/train_doc5_sbert_0522.jsonl", "r", encoding="utf8") as f:
    predicted_results_sbert = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="sbert")

In [27]:
# old_precision = 0.4831244778613204
# old_recall = 0.8928989139515455
# old_f1 = 0.6269970759980318
old_precision = 0.4584699453551913
old_recall = 0.8663153786104606
old_f1 = 0.5996134726235804

In [30]:
print("On SBERT Data:")
precision = calculate_precision(TRAIN_DATA, predicted_results_search)
print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA, predicted_results_search)
print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")
print(f"(Diff: {f1-old_f1})")

On SBERT Data:
Precision: 0.01454832695285644
(Diff: -0.44392161840233485)
Recall: 0.8922643409341909
(Diff: 0.02594896232373034)
F1-Score: 0.028629845656065505
(Diff: -0.5709836269675149)


In [None]:
sbert_model.stop_multi_process_pool(pool=pool)

#### Operation on Log File

In [56]:
def get_pred_pages_log(
    data: pd.DataFrame, 
    topk: int,
    threshold: float,
    progress_bar,
    
):
    # Parameters:
    THRESHOLD_LOWEST = 0.6
    THRESHOLD_MID = 0.7
    THRESHOLD_HIGHEST = 0.885
    THRESHOLD_SIM_LINE = threshold
    WEIGHT_SIM_ID = 0.2    # The lower it is, the higher sim_id is when it directly matches claim.
    
    def sim_score_eval(sim_line, sim_id, claim):
        # res = (weight_id + weight_line)*(s1*s2)/(weight_line*s1+weight_id*s2)
        # if sim_line > 0.5:
        #     res = sim_line + (1-sim_line)*sim_id
        # elif sim_line < 0.5 and sim_id > 0.5:
        #     res = sim_line + (1-sim_line)*sim_id
        # else:
        #     res = 0
        if len(claim) <= 15:
            res = sim_id
        else:
            w_line = 1.1
            w_id = 1.1
            if sim_line > THRESHOLD_SIM_LINE:
                res = 2*(w_line*sim_line*w_id*sim_id)/(w_line*sim_line+w_id*sim_id)
            else:
                res = 0
        
        return res
    
    def post_processing(page) -> str:
        import opencc
        CONVERTER_T2S = opencc.OpenCC("t2s.json")
        CONVERTER_S2T = opencc.OpenCC("s2t.json")
    
        simplified = CONVERTER_T2S.convert(page)
        page = CONVERTER_S2T.convert(simplified)
        page = page.replace(" ", "_")
        page = page.replace("-", "")
        return page

    results = []
    doc_res = []
    mapping = {}
    claim_prev = ""
    claim_count = 0
    claim_comma = 0
    direct_match = []
    predicted_pages = []

    for index, series_data in data.iterrows():
        claim = series_data["Claim"]
        search_id = series_data["Search_ID"]
        sim_id = series_data["Sim_ID"]
        sim_id_new = series_data["Sim_ID_Adjusted"]
        sim_line = series_data["Sim_Line"]

        if index == 0:  
            claim_prev = claim
            claim_comma = claim.count("，")
        elif claim != claim_prev:
            mapping_sorted = sorted(mapping.items(), key=lambda x:x[1], reverse=True)
            DIFF = 0.125
            for k, v in mapping_sorted:
                THRESHOLD_TOP = v
                break
            # print(mapping_sorted[:topk])
            if len(mapping_sorted) >= topk:
                doc_res = [k for k, v in mapping_sorted if v > THRESHOLD_TOP-DIFF][:topk]
            else:
                doc_res= [k for k, v in mapping_sorted if v > THRESHOLD_LOWEST][:topk]
            if not doc_res:
                doc_res = direct_match[:topk]
            if not doc_res:
                doc_res = predicted_pages[:topk]
            
            results.append(doc_res)
            #print(claim_count, mapping)
            doc_res = []
            mapping = {}
            claim_prev = claim
            claim_comma = claim.count("，")
            claim_count = claim_count + 1
            progress_bar.update(1)
            # if claim_count % 100 == 0:
            #     print(f"已處理{claim_count}筆資料")

        if sim_id != sim_id_new:
            direct_match.append(search_id)
            if sim_id > 0:
                # print(f"{search_id}: sim_id={sim_id}")
                sim_id_new = 1-((1-sim_id)*WEIGHT_SIM_ID)
            else:
                sim_id = 0
                sim_id_new = 1-((1-sim_id)*WEIGHT_SIM_ID)
        else:
            sim_id_new = sim_id

        predicted_pages.append(search_id)
        sim_score = sim_score_eval(sim_line=sim_line, sim_id=sim_id_new, claim=claim)
        if sim_score > 0:
            sim_score = max(sim_score, sim_line)
            # print(sim_score, search_id)
            if sim_score > THRESHOLD_LOWEST:
                search_id = post_processing(search_id)
                if search_id in mapping:
                    mapping[search_id] = max(sim_score, mapping[search_id])
                else:
                    mapping[search_id] = sim_score

    mapping_sorted = sorted(mapping.items(), key=lambda x:x[1], reverse=True)
    # print(mapping_sorted[:topk])
    # if len(mapping_sorted) >= topk:
    #     doc_res = [k for k, v in mapping_sorted if v > THRESHOLD_HIGHEST][:topk]
    # else:
    #     doc_res= [k for k, v in mapping_sorted if v > THRESHOLD_LOWEST][:topk]

    # if not doc_res:
    #     doc_res = [k for k, v in mapping_sorted if v > THRESHOLD_MID][:topk]
    if len(mapping_sorted) >= topk:
        doc_res = [k for k, v in mapping_sorted if v > THRESHOLD_TOP-DIFF][:topk]
    else:
        doc_res= [k for k, v in mapping_sorted if v > THRESHOLD_LOWEST][:topk]
    if not doc_res:
        doc_res = direct_match[:topk]
    if not doc_res:
        doc_res = predicted_pages[:topk]

    results.append(doc_res)

    return results

In [59]:
topk = 6
num_of_samples = 500

def merge(series_data: pd.Series) -> set:
    result = []
    for i in range(0, topk):
        if series_data.iloc[i] != None:
            result.append(series_data.iloc[i])
    # print(set(result))
    return set(result)

for i in range(14, 16):
    start = i*num_of_samples
    doc_log = f"data/train_doc5_logging_0522_{i}.jsonl"
    TRAIN_DATA_LOG = load_json(doc_log)
    train_df_log = pd.DataFrame(TRAIN_DATA_LOG)

    progress_bar = tqdm(range(500))
    
    predicted_results_log = get_pred_pages_log(
        data=train_df_log, 
        topk=topk, 
        threshold=0.375, 
        progress_bar=progress_bar
    )
    predicted_results_log_df = pd.DataFrame(predicted_results_log)
    predicted_results_log_df_b = predicted_results_log_df.apply(merge, axis=1)
    save_doc(TRAIN_DATA[start:start+num_of_samples], predicted_results_log_df_b, mode="train", suffix=f"_log_0522_{i}")

    with open(f"data/train_doc5_sbert_0522.jsonl", "r", encoding="utf8") as f:
        predicted_results_original = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ], name="sbert")

    if i < 13:
        print(f"On Original Data, batch = {i}")
        old_precision = calculate_precision(TRAIN_DATA[start:start+num_of_samples], predicted_results_original[start:start+num_of_samples])
        old_recall = calculate_recall(TRAIN_DATA[start:start+num_of_samples], predicted_results_original[start:start+num_of_samples])
        old_f1 = calculate_f1(precision, recall)

        print(f"\nOn Log Data, batch = {i}")
        precision = calculate_precision(TRAIN_DATA[start:start+num_of_samples], predicted_results_log_df_b)
        print(f"(Diff: {precision-old_precision})")
        recall = calculate_recall(TRAIN_DATA[start:start+num_of_samples], predicted_results_log_df_b)
        print(f"(Diff: {recall-old_recall})")
        f1 = calculate_f1(precision, recall)
        print(f"F1-Score: {f1}")
        print(f"(Diff: {f1-old_f1})")
    

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [77]:
batch = 500
start = 0
with open(f'data/train_doc5_log_0522_{start}.jsonl', "r", encoding="utf8") as fp:
    data = fp.read()
    fp.close()
for i in range(1, 16):
    with open(f'data/train_doc5_log_0522_{i}.jsonl', "r", encoding="utf8") as fp:
        data2 = fp.read()
        data += data2
        fp.close()
    start += batch

with open(f'data/train_doc5_log_0522.jsonl', 'w', encoding="utf8") as fp:
    fp.write(data)
    fp.close()

In [80]:
print(f"On Original Data, batch")
with open(f"data/train_doc5_sbert_0522.jsonl", "r", encoding="utf8") as f:
    predicted_results_original = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="sbert")
old_precision = calculate_precision(TRAIN_DATA, predicted_results_original)
old_recall = calculate_recall(TRAIN_DATA, predicted_results_original)
old_f1 = calculate_f1(precision, recall)

with open(f'data/train_doc5_log_0522.jsonl', 'r', encoding="utf8") as fp:
    predicted_results_log_df_b = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in fp
    ], name="sbert")
print(f"\nOn Log Data, batch")
precision = calculate_precision(TRAIN_DATA, predicted_results_log_df_b)
print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA, predicted_results_log_df_b)
print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")
print(f"(Diff: {f1-old_f1})")

On Original Data, batch
Precision: 0.4800688624699537
Recall: 0.8738712401741052

On Log Data, batch
Precision: 0.3968784512440751
(Diff: -0.08319041122587861)
Recall: 0.8965844864548823
(Diff: 0.02271324628077709)
F1-Score: 0.5502052699348324
(Diff: -0.01867088936685879)


In [81]:
start = 0
doc_log_test = f"data/test_doc5_logging_0522_20.jsonl"
TEST_DATA_LOG = load_json(doc_log_test)
test_df_log = pd.DataFrame(TEST_DATA_LOG)

progress_bar = tqdm(range(1000))

test_results_log = get_pred_pages_log(
    data=test_df_log, 
    topk=topk, 
    threshold=0.375, 
    progress_bar=progress_bar
)
test_results_log_df = pd.DataFrame(test_results_log)
test_results_log_df_b = test_results_log_df.apply(merge, axis=1)
save_doc(TEST_DATA, test_results_log_df_b, mode="test", suffix=f"_log_0522")

  0%|          | 0/1000 [00:00<?, ?it/s]

In [36]:
old_precision = 0.4584699453551913
old_recall = 0.8663153786104606
old_f1 = 0.5996134726235804
num_of_samples = 500
i = 12
start = i*num_of_samples
with open(f"data/train_doc5_sbert_0522.jsonl", "r", encoding="utf8") as f:
    predicted_results_original = pd.Series([
        set(json.loads(line)["predicted_pages"])
        for line in f
    ], name="sbert")
precision = calculate_precision(TRAIN_DATA[start:start+num_of_samples], predicted_results_original[start:start+num_of_samples])
# print(f"(Diff: {precision-old_precision})")
recall = calculate_recall(TRAIN_DATA[start:start+num_of_samples], predicted_results_original[start:start+num_of_samples])
# print(f"(Diff: {recall-old_recall})")
f1 = calculate_f1(precision, recall)
print(f"F1-Score: {f1}")
# print(f"(Diff: {f1-old_f1})")

Precision: 0.48262411347517725
Recall: 0.9042553191489362
F1-Score: 0.6293487544677586


Merge Two Pandas Series.

In [None]:
def union_result(series_data: pd.Series,) -> set:
    aicup = series_data["aicup"]
    sbert = series_data["sbert"]
    print(aicup, sbert)
    return set(aicup).union(set(sbert))

In [None]:
results_df = pd.merge(pd.Series([line for line in predicted_results_aicup[:500]], name="aicup"), 
                      pd.Series([line for line in predicted_results_sbert], name="sbert"), right_index=True, left_index=True)
predicted_results = results_df.apply(union_result, axis=1)
save_doc(TRAIN_DATA[:500], predicted_results, mode="train", suffix="_temp")

In [None]:
doc_path_temp = f"data/train_doc5_temp.jsonl"

if Path(doc_path).exists():
    with open(doc_path, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ])
else:
    pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=3)
    TRAIN_DATA_TEMP = load_json(doc_path_temp)
    train_df_temp = pd.DataFrame(TRAIN_DATA_TEMP)
    predicted_results = train_df_temp.parallel_apply(
        partial(
            get_pred_pages_tfidf,
            tokenizing_method=partial(tokenize, stopwords=stopwords),
            vectorizer=vectorizer,
            tf_idf_matrix=X,
            wiki_pages=wiki_pages,
            topk=topk,
            threshold=0,
        ), axis=1)
    save_doc(TRAIN_DATA, predicted_results, mode="train")

In [None]:
print("On Data:")
precision = calculate_precision(TRAIN_DATA[:500], predicted_results)
recall = calculate_recall(TRAIN_DATA[:500], predicted_results)
print(calculate_f1(precision, recall))

### Step 2. Calculate our results

In [None]:
precision = 0.143234
recall = 0.635255
print(calculate_f1(precision, recall))

In [None]:
precision = calculate_precision(TRAIN_DATA, predicted_results)
recall = calculate_recall(TRAIN_DATA, predicted_results)
f1_score = 2*(precision*recall)/(precision+recall)
print(f"F1 Score: {f1_score}")

### Step 3. Repeat the same process on test set
Create parsing tree

In [19]:
hanlp_test_file = f"data/hanlp_con_test_results.pkl"
if Path(hanlp_test_file).exists():
    with open(hanlp_test_file, "rb") as f:
        hanlp_test_results = pickle.load(f)
else:
    hanlp_test_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
    with open(hanlp_test_file, "wb") as f:
        pickle.dump(hanlp_test_results, f)

Get pages via wiki online api

In [20]:
test_doc_path = f"data/test_doc5.jsonl"
test_doc_path_aicup = f"data/test_doc5_aicup.jsonl"
test_doc_path_search = f"data/test_doc5_search.jsonl"
test_doc_path_tfidf = f"data/test_doc5_tfidf.jsonl"
test_doc_path_sbert = f"data/test_doc5_sbert.jsonl"

In [21]:
if Path(test_doc_path_search).exists():
    with open(test_doc_path_search, "r", encoding="utf8") as f:
        test_results_search = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ], name="search")
else:
    pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)
    test_df = pd.DataFrame(TEST_DATA)
    test_df.loc[:, "hanlp_results"] = hanlp_test_results
    # predicted_results = test_df.progress_apply(get_pred_pages, axis=1)
    test_results_search = test_df.parallel_apply(
        get_pred_pages_search, axis=1)
    save_doc(TEST_DATA, test_results_search, mode="test", suffix="_search")

In [30]:
TRAIN_DATA_SEARCH = load_json(doc_path_search)
train_df_search = pd.DataFrame(TRAIN_DATA_SEARCH)
res = get_pred_pages_sbert(
    series_data=train_df_search.loc[274],
    tokenizing_method=partial(tokenize, stopwords=stopwords),
    # model=sbert_model,
    # wiki_pages=wiki_pages,
    topk=5,
    threshold=0.375
)

print(list(res))

['資料流程圖']


In [24]:
if Path(test_doc_path_sbert).exists():
    with open(test_doc_path_sbert, "r", encoding="utf8") as f:
        test_results_sbert = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ], name="sbert")
else:
    pandarallel.initialize(progress_bar=False, verbose=0, nb_workers=10)
    TEST_DATA_SEARCH = load_json(test_doc_path_search)
    test_df_search = pd.DataFrame(TEST_DATA_SEARCH)
    test_results_sbert = test_df_search.progress_apply(
        partial(
            get_pred_pages_sbert,
            tokenizing_method=partial(tokenize, stopwords=stopwords),
            # model=sbert_model,
            # wiki_pages=wiki_pages,
            topk=5,
            threshold=0.375
        ), axis=1)
    save_doc(TEST_DATA, test_results_sbert, mode="test", suffix="_sbert")

  0%|          | 0/989 [00:00<?, ?it/s]

In [None]:
results_df = pd.merge(pd.Series([line for line in test_results_aicup], name="aicup"), 
                      pd.Series([line for line in test_results_tfidf], name="tfidf"), right_index=True, left_index=True)
test_results = results_df.apply(union_result, axis=1)
save_doc(TEST_DATA, test_results, mode="test")

notebook2
## PART 2. Sentence retrieval

Import some libs

In [82]:
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset, Dataset

# local libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

Global variable

In [83]:
SEED = 42

TRAIN_DATA = load_json("data/public_train_0522.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
DOC_DATA = load_json("data/train_doc5_sbert_0522.jsonl")

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=_y,
)

Preload wiki database (1 min)

In [84]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=118776), Label(value='0 / 118776')…

Transform to id to evidence_map mapping


### Helper function

Calculate precision for sentence retrieval

In [85]:
def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0

    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0

        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

    return 0.0, 0.0

Calculate recall for sentence retrieval

In [86]:
def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0

        claim = instance["claim"]

        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

Calculate the scores of sentence retrieval

In [87]:
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 5,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]

        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)

    if save_name is not None:
        # write doc7_sent5 file
        with open(f"data/{save_name}", "w", encoding="utf8") as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim]["predicted_evidence"].tolist()
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

Inference script to get probabilites for the candidate evidence sentences

In [88]:
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())

    return np.array(probs)

AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences

In [89]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]

        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten

### Main function for sentence retrieval

In [90]:
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # positive
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]
        for evidence_set in evidence_sets:
            sents = []
            for evidence in evidence_set:
                # evidence[2] is the page title
                page = evidence[2].replace(" ", "_")
                # the only page with weird name
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] is in form of int however, mapping requires str
                sent_idx = str(evidence[3])
                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            claims.append(claim)
            sentences.append(whole_evidence)
            labels.append(1)

    # negative
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue
        claim = df["claim"].iloc[i]

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in df["evidence"][i]
                            for evidence in evidences])
        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for pair in page_sent_id_pairs:
                if pair in evidence_set:
                    continue
                text = mapping[page][pair[1]]
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                if text != "" and np.random.rand(1) <= negative_ratio:
                    claims.append(claim)
                    sentences.append(text)
                    labels.append(0)

    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})


def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue
        claim = df["claim"].iloc[i]

        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for page_name, sentence_id in page_sent_id_pairs:
                text = mapping[page][sentence_id]
                if text != "":
                    claims.append(claim)
                    sentences.append(text)
                    if not is_testset:
                        evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

### Step 1. Setup training environment

Hyperparams

In [91]:
# MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-bert-wwm" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-bert-wwm-ext" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-macbert-base" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-roberta-wwm-ext" #@param {type:"string"}
MODEL_NAME = "hfl/chinese-lert-base" #@param {type:"string"}

MODEL_SHORT = "lert"
NUM_EPOCHS = 1  #@param {type:"integer"}
LR = 2e-5  #@param {type:"number"}
TRAIN_BATCH_SIZE = 64  #@param {type:"integer"}
TEST_BATCH_SIZE = 256  #@param {type:"integer"}
NEGATIVE_RATIO = 0.075  #@param {type:"number"}
VALIDATION_STEP = 25  #@param {type:"integer"}
TOP_N = 5  #@param {type:"integer"}
#@title  { display-mode: "form" }

Experiment Directory

In [92]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}_{MODEL_SHORT}_new_8"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Combine claims and evidences

In [93]:
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)
counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    8883
1    4289
Name: label, dtype: int64


### Step 3. Start training

Dataloader things

In [94]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

Downloading (…)okenizer_config.json:   0%|          | 0.00/19.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Save your memory.

Trainer

In [95]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")
print(torch.cuda.is_available())
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
if torch.cuda.device_count() > 1:
    # import os
    # os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ['MASTER_PORT'] = '5678'
    # torch.distributed.init_process_group(backend="nccl")
    model = nn.DataParallel(model)
    # model = model.cuda()
    # model = nn.parallel.DistributedDataParallel(model)
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

True


Downloading pytorch_model.bin:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at hfl/chinese-lert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not init

Please make sure that you are using gpu when training (5 min)

In [96]:
torch.cuda.empty_cache()

In [30]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        torch.cuda.empty_cache()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.sum().backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.sum().item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(current_steps, val_results)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(model, CKPT_DIR, current_steps)

print("Finished training!")

  0%|          | 0/207 [00:00<?, ?it/s]



Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

25 {'F1 score': 0.4005842287692988, 'Precision': 0.27602339181286334, 'Recall': 0.7300194931773879}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

50 {'F1 score': 0.43010652185426856, 'Precision': 0.2935672514619854, 'Recall': 0.804093567251462}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

75 {'F1 score': 0.42843127113835744, 'Precision': 0.2927875243664689, 'Recall': 0.7982456140350878}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

100 {'F1 score': 0.4331048981982797, 'Precision': 0.29571150097465604, 'Recall': 0.8089668615984406}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

125 {'F1 score': 0.4348473811787237, 'Precision': 0.29707602339180994, 'Recall': 0.8109161793372319}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

150 {'F1 score': 0.4325441824067054, 'Precision': 0.29571150097465604, 'Recall': 0.8050682261208577}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

175 {'F1 score': 0.43240727814230223, 'Precision': 0.2953216374268977, 'Recall': 0.8070175438596491}




Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

200 {'F1 score': 0.4340807091725054, 'Precision': 0.2964912280701726, 'Recall': 0.8099415204678363}




Finished training!


In [82]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 1035698), started 2:11:00 ago. (Use '!kill 1035698' to kill it.)

Validation part (15 mins)

In [40]:
torch.cuda.empty_cache()

In [53]:
import json
ckpt_name = "model.200.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
    save_name=f"sent_retrieval/train_doc5sent{TOP_N}_neg{NEGATIVE_RATIO}_{LR}_e{NUM_EPOCHS}_{MODEL_SHORT}_new.jsonl",
)
print(f"Training scores => {train_results}")

print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
    save_name=f"sent_retrieval/dev_doc5sent{TOP_N}_neg{NEGATIVE_RATIO}_{LR}_e{NUM_EPOCHS}_{MODEL_SHORT}_new.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/714 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.4351978361772329, 'Precision': 0.2997929354445913, 'Recall': 0.7936662606577345}
Start validation


  0%|          | 0/181 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.4340807091725054, 'Precision': 0.2964912280701726, 'Recall': 0.8099415204678363}


Load the model we want.

In [97]:
ckpt_name = "model.200.pt"
model = load_model(model, ckpt_name, CKPT_DIR)

RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
	Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.self.query.weight", "bert.encoder.layer.0.attention.self.query.bias", "bert.encoder.layer.0.attention.self.key.weight", "bert.encoder.layer.0.attention.self.key.bias", "bert.encoder.layer.0.attention.self.value.weight", "bert.encoder.layer.0.attention.self.value.bias", "bert.encoder.layer.0.attention.output.dense.weight", "bert.encoder.layer.0.attention.output.dense.bias", "bert.encoder.layer.0.attention.output.LayerNorm.weight", "bert.encoder.layer.0.attention.output.LayerNorm.bias", "bert.encoder.layer.0.intermediate.dense.weight", "bert.encoder.layer.0.intermediate.dense.bias", "bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.bias", "bert.encoder.layer.0.output.LayerNorm.weight", "bert.encoder.layer.0.output.LayerNorm.bias", "bert.encoder.layer.1.attention.self.query.weight", "bert.encoder.layer.1.attention.self.query.bias", "bert.encoder.layer.1.attention.self.key.weight", "bert.encoder.layer.1.attention.self.key.bias", "bert.encoder.layer.1.attention.self.value.weight", "bert.encoder.layer.1.attention.self.value.bias", "bert.encoder.layer.1.attention.output.dense.weight", "bert.encoder.layer.1.attention.output.dense.bias", "bert.encoder.layer.1.attention.output.LayerNorm.weight", "bert.encoder.layer.1.attention.output.LayerNorm.bias", "bert.encoder.layer.1.intermediate.dense.weight", "bert.encoder.layer.1.intermediate.dense.bias", "bert.encoder.layer.1.output.dense.weight", "bert.encoder.layer.1.output.dense.bias", "bert.encoder.layer.1.output.LayerNorm.weight", "bert.encoder.layer.1.output.LayerNorm.bias", "bert.encoder.layer.2.attention.self.query.weight", "bert.encoder.layer.2.attention.self.query.bias", "bert.encoder.layer.2.attention.self.key.weight", "bert.encoder.layer.2.attention.self.key.bias", "bert.encoder.layer.2.attention.self.value.weight", "bert.encoder.layer.2.attention.self.value.bias", "bert.encoder.layer.2.attention.output.dense.weight", "bert.encoder.layer.2.attention.output.dense.bias", "bert.encoder.layer.2.attention.output.LayerNorm.weight", "bert.encoder.layer.2.attention.output.LayerNorm.bias", "bert.encoder.layer.2.intermediate.dense.weight", "bert.encoder.layer.2.intermediate.dense.bias", "bert.encoder.layer.2.output.dense.weight", "bert.encoder.layer.2.output.dense.bias", "bert.encoder.layer.2.output.LayerNorm.weight", "bert.encoder.layer.2.output.LayerNorm.bias", "bert.encoder.layer.3.attention.self.query.weight", "bert.encoder.layer.3.attention.self.query.bias", "bert.encoder.layer.3.attention.self.key.weight", "bert.encoder.layer.3.attention.self.key.bias", "bert.encoder.layer.3.attention.self.value.weight", "bert.encoder.layer.3.attention.self.value.bias", "bert.encoder.layer.3.attention.output.dense.weight", "bert.encoder.layer.3.attention.output.dense.bias", "bert.encoder.layer.3.attention.output.LayerNorm.weight", "bert.encoder.layer.3.attention.output.LayerNorm.bias", "bert.encoder.layer.3.intermediate.dense.weight", "bert.encoder.layer.3.intermediate.dense.bias", "bert.encoder.layer.3.output.dense.weight", "bert.encoder.layer.3.output.dense.bias", "bert.encoder.layer.3.output.LayerNorm.weight", "bert.encoder.layer.3.output.LayerNorm.bias", "bert.encoder.layer.4.attention.self.query.weight", "bert.encoder.layer.4.attention.self.query.bias", "bert.encoder.layer.4.attention.self.key.weight", "bert.encoder.layer.4.attention.self.key.bias", "bert.encoder.layer.4.attention.self.value.weight", "bert.encoder.layer.4.attention.self.value.bias", "bert.encoder.layer.4.attention.output.dense.weight", "bert.encoder.layer.4.attention.output.dense.bias", "bert.encoder.layer.4.attention.output.LayerNorm.weight", "bert.encoder.layer.4.attention.output.LayerNorm.bias", "bert.encoder.layer.4.intermediate.dense.weight", "bert.encoder.layer.4.intermediate.dense.bias", "bert.encoder.layer.4.output.dense.weight", "bert.encoder.layer.4.output.dense.bias", "bert.encoder.layer.4.output.LayerNorm.weight", "bert.encoder.layer.4.output.LayerNorm.bias", "bert.encoder.layer.5.attention.self.query.weight", "bert.encoder.layer.5.attention.self.query.bias", "bert.encoder.layer.5.attention.self.key.weight", "bert.encoder.layer.5.attention.self.key.bias", "bert.encoder.layer.5.attention.self.value.weight", "bert.encoder.layer.5.attention.self.value.bias", "bert.encoder.layer.5.attention.output.dense.weight", "bert.encoder.layer.5.attention.output.dense.bias", "bert.encoder.layer.5.attention.output.LayerNorm.weight", "bert.encoder.layer.5.attention.output.LayerNorm.bias", "bert.encoder.layer.5.intermediate.dense.weight", "bert.encoder.layer.5.intermediate.dense.bias", "bert.encoder.layer.5.output.dense.weight", "bert.encoder.layer.5.output.dense.bias", "bert.encoder.layer.5.output.LayerNorm.weight", "bert.encoder.layer.5.output.LayerNorm.bias", "bert.encoder.layer.6.attention.self.query.weight", "bert.encoder.layer.6.attention.self.query.bias", "bert.encoder.layer.6.attention.self.key.weight", "bert.encoder.layer.6.attention.self.key.bias", "bert.encoder.layer.6.attention.self.value.weight", "bert.encoder.layer.6.attention.self.value.bias", "bert.encoder.layer.6.attention.output.dense.weight", "bert.encoder.layer.6.attention.output.dense.bias", "bert.encoder.layer.6.attention.output.LayerNorm.weight", "bert.encoder.layer.6.attention.output.LayerNorm.bias", "bert.encoder.layer.6.intermediate.dense.weight", "bert.encoder.layer.6.intermediate.dense.bias", "bert.encoder.layer.6.output.dense.weight", "bert.encoder.layer.6.output.dense.bias", "bert.encoder.layer.6.output.LayerNorm.weight", "bert.encoder.layer.6.output.LayerNorm.bias", "bert.encoder.layer.7.attention.self.query.weight", "bert.encoder.layer.7.attention.self.query.bias", "bert.encoder.layer.7.attention.self.key.weight", "bert.encoder.layer.7.attention.self.key.bias", "bert.encoder.layer.7.attention.self.value.weight", "bert.encoder.layer.7.attention.self.value.bias", "bert.encoder.layer.7.attention.output.dense.weight", "bert.encoder.layer.7.attention.output.dense.bias", "bert.encoder.layer.7.attention.output.LayerNorm.weight", "bert.encoder.layer.7.attention.output.LayerNorm.bias", "bert.encoder.layer.7.intermediate.dense.weight", "bert.encoder.layer.7.intermediate.dense.bias", "bert.encoder.layer.7.output.dense.weight", "bert.encoder.layer.7.output.dense.bias", "bert.encoder.layer.7.output.LayerNorm.weight", "bert.encoder.layer.7.output.LayerNorm.bias", "bert.encoder.layer.8.attention.self.query.weight", "bert.encoder.layer.8.attention.self.query.bias", "bert.encoder.layer.8.attention.self.key.weight", "bert.encoder.layer.8.attention.self.key.bias", "bert.encoder.layer.8.attention.self.value.weight", "bert.encoder.layer.8.attention.self.value.bias", "bert.encoder.layer.8.attention.output.dense.weight", "bert.encoder.layer.8.attention.output.dense.bias", "bert.encoder.layer.8.attention.output.LayerNorm.weight", "bert.encoder.layer.8.attention.output.LayerNorm.bias", "bert.encoder.layer.8.intermediate.dense.weight", "bert.encoder.layer.8.intermediate.dense.bias", "bert.encoder.layer.8.output.dense.weight", "bert.encoder.layer.8.output.dense.bias", "bert.encoder.layer.8.output.LayerNorm.weight", "bert.encoder.layer.8.output.LayerNorm.bias", "bert.encoder.layer.9.attention.self.query.weight", "bert.encoder.layer.9.attention.self.query.bias", "bert.encoder.layer.9.attention.self.key.weight", "bert.encoder.layer.9.attention.self.key.bias", "bert.encoder.layer.9.attention.self.value.weight", "bert.encoder.layer.9.attention.self.value.bias", "bert.encoder.layer.9.attention.output.dense.weight", "bert.encoder.layer.9.attention.output.dense.bias", "bert.encoder.layer.9.attention.output.LayerNorm.weight", "bert.encoder.layer.9.attention.output.LayerNorm.bias", "bert.encoder.layer.9.intermediate.dense.weight", "bert.encoder.layer.9.intermediate.dense.bias", "bert.encoder.layer.9.output.dense.weight", "bert.encoder.layer.9.output.dense.bias", "bert.encoder.layer.9.output.LayerNorm.weight", "bert.encoder.layer.9.output.LayerNorm.bias", "bert.encoder.layer.10.attention.self.query.weight", "bert.encoder.layer.10.attention.self.query.bias", "bert.encoder.layer.10.attention.self.key.weight", "bert.encoder.layer.10.attention.self.key.bias", "bert.encoder.layer.10.attention.self.value.weight", "bert.encoder.layer.10.attention.self.value.bias", "bert.encoder.layer.10.attention.output.dense.weight", "bert.encoder.layer.10.attention.output.dense.bias", "bert.encoder.layer.10.attention.output.LayerNorm.weight", "bert.encoder.layer.10.attention.output.LayerNorm.bias", "bert.encoder.layer.10.intermediate.dense.weight", "bert.encoder.layer.10.intermediate.dense.bias", "bert.encoder.layer.10.output.dense.weight", "bert.encoder.layer.10.output.dense.bias", "bert.encoder.layer.10.output.LayerNorm.weight", "bert.encoder.layer.10.output.LayerNorm.bias", "bert.encoder.layer.11.attention.self.query.weight", "bert.encoder.layer.11.attention.self.query.bias", "bert.encoder.layer.11.attention.self.key.weight", "bert.encoder.layer.11.attention.self.key.bias", "bert.encoder.layer.11.attention.self.value.weight", "bert.encoder.layer.11.attention.self.value.bias", "bert.encoder.layer.11.attention.output.dense.weight", "bert.encoder.layer.11.attention.output.dense.bias", "bert.encoder.layer.11.attention.output.LayerNorm.weight", "bert.encoder.layer.11.attention.output.LayerNorm.bias", "bert.encoder.layer.11.intermediate.dense.weight", "bert.encoder.layer.11.intermediate.dense.bias", "bert.encoder.layer.11.output.dense.weight", "bert.encoder.layer.11.output.dense.bias", "bert.encoder.layer.11.output.LayerNorm.weight", "bert.encoder.layer.11.output.LayerNorm.bias", "bert.pooler.dense.weight", "bert.pooler.dense.bias", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "module.bert.embeddings.position_ids", "module.bert.embeddings.word_embeddings.weight", "module.bert.embeddings.position_embeddings.weight", "module.bert.embeddings.token_type_embeddings.weight", "module.bert.embeddings.LayerNorm.weight", "module.bert.embeddings.LayerNorm.bias", "module.bert.encoder.layer.0.attention.self.query.weight", "module.bert.encoder.layer.0.attention.self.query.bias", "module.bert.encoder.layer.0.attention.self.key.weight", "module.bert.encoder.layer.0.attention.self.key.bias", "module.bert.encoder.layer.0.attention.self.value.weight", "module.bert.encoder.layer.0.attention.self.value.bias", "module.bert.encoder.layer.0.attention.output.dense.weight", "module.bert.encoder.layer.0.attention.output.dense.bias", "module.bert.encoder.layer.0.attention.output.LayerNorm.weight", "module.bert.encoder.layer.0.attention.output.LayerNorm.bias", "module.bert.encoder.layer.0.intermediate.dense.weight", "module.bert.encoder.layer.0.intermediate.dense.bias", "module.bert.encoder.layer.0.output.dense.weight", "module.bert.encoder.layer.0.output.dense.bias", "module.bert.encoder.layer.0.output.LayerNorm.weight", "module.bert.encoder.layer.0.output.LayerNorm.bias", "module.bert.encoder.layer.1.attention.self.query.weight", "module.bert.encoder.layer.1.attention.self.query.bias", "module.bert.encoder.layer.1.attention.self.key.weight", "module.bert.encoder.layer.1.attention.self.key.bias", "module.bert.encoder.layer.1.attention.self.value.weight", "module.bert.encoder.layer.1.attention.self.value.bias", "module.bert.encoder.layer.1.attention.output.dense.weight", "module.bert.encoder.layer.1.attention.output.dense.bias", "module.bert.encoder.layer.1.attention.output.LayerNorm.weight", "module.bert.encoder.layer.1.attention.output.LayerNorm.bias", "module.bert.encoder.layer.1.intermediate.dense.weight", "module.bert.encoder.layer.1.intermediate.dense.bias", "module.bert.encoder.layer.1.output.dense.weight", "module.bert.encoder.layer.1.output.dense.bias", "module.bert.encoder.layer.1.output.LayerNorm.weight", "module.bert.encoder.layer.1.output.LayerNorm.bias", "module.bert.encoder.layer.2.attention.self.query.weight", "module.bert.encoder.layer.2.attention.self.query.bias", "module.bert.encoder.layer.2.attention.self.key.weight", "module.bert.encoder.layer.2.attention.self.key.bias", "module.bert.encoder.layer.2.attention.self.value.weight", "module.bert.encoder.layer.2.attention.self.value.bias", "module.bert.encoder.layer.2.attention.output.dense.weight", "module.bert.encoder.layer.2.attention.output.dense.bias", "module.bert.encoder.layer.2.attention.output.LayerNorm.weight", "module.bert.encoder.layer.2.attention.output.LayerNorm.bias", "module.bert.encoder.layer.2.intermediate.dense.weight", "module.bert.encoder.layer.2.intermediate.dense.bias", "module.bert.encoder.layer.2.output.dense.weight", "module.bert.encoder.layer.2.output.dense.bias", "module.bert.encoder.layer.2.output.LayerNorm.weight", "module.bert.encoder.layer.2.output.LayerNorm.bias", "module.bert.encoder.layer.3.attention.self.query.weight", "module.bert.encoder.layer.3.attention.self.query.bias", "module.bert.encoder.layer.3.attention.self.key.weight", "module.bert.encoder.layer.3.attention.self.key.bias", "module.bert.encoder.layer.3.attention.self.value.weight", "module.bert.encoder.layer.3.attention.self.value.bias", "module.bert.encoder.layer.3.attention.output.dense.weight", "module.bert.encoder.layer.3.attention.output.dense.bias", "module.bert.encoder.layer.3.attention.output.LayerNorm.weight", "module.bert.encoder.layer.3.attention.output.LayerNorm.bias", "module.bert.encoder.layer.3.intermediate.dense.weight", "module.bert.encoder.layer.3.intermediate.dense.bias", "module.bert.encoder.layer.3.output.dense.weight", "module.bert.encoder.layer.3.output.dense.bias", "module.bert.encoder.layer.3.output.LayerNorm.weight", "module.bert.encoder.layer.3.output.LayerNorm.bias", "module.bert.encoder.layer.4.attention.self.query.weight", "module.bert.encoder.layer.4.attention.self.query.bias", "module.bert.encoder.layer.4.attention.self.key.weight", "module.bert.encoder.layer.4.attention.self.key.bias", "module.bert.encoder.layer.4.attention.self.value.weight", "module.bert.encoder.layer.4.attention.self.value.bias", "module.bert.encoder.layer.4.attention.output.dense.weight", "module.bert.encoder.layer.4.attention.output.dense.bias", "module.bert.encoder.layer.4.attention.output.LayerNorm.weight", "module.bert.encoder.layer.4.attention.output.LayerNorm.bias", "module.bert.encoder.layer.4.intermediate.dense.weight", "module.bert.encoder.layer.4.intermediate.dense.bias", "module.bert.encoder.layer.4.output.dense.weight", "module.bert.encoder.layer.4.output.dense.bias", "module.bert.encoder.layer.4.output.LayerNorm.weight", "module.bert.encoder.layer.4.output.LayerNorm.bias", "module.bert.encoder.layer.5.attention.self.query.weight", "module.bert.encoder.layer.5.attention.self.query.bias", "module.bert.encoder.layer.5.attention.self.key.weight", "module.bert.encoder.layer.5.attention.self.key.bias", "module.bert.encoder.layer.5.attention.self.value.weight", "module.bert.encoder.layer.5.attention.self.value.bias", "module.bert.encoder.layer.5.attention.output.dense.weight", "module.bert.encoder.layer.5.attention.output.dense.bias", "module.bert.encoder.layer.5.attention.output.LayerNorm.weight", "module.bert.encoder.layer.5.attention.output.LayerNorm.bias", "module.bert.encoder.layer.5.intermediate.dense.weight", "module.bert.encoder.layer.5.intermediate.dense.bias", "module.bert.encoder.layer.5.output.dense.weight", "module.bert.encoder.layer.5.output.dense.bias", "module.bert.encoder.layer.5.output.LayerNorm.weight", "module.bert.encoder.layer.5.output.LayerNorm.bias", "module.bert.encoder.layer.6.attention.self.query.weight", "module.bert.encoder.layer.6.attention.self.query.bias", "module.bert.encoder.layer.6.attention.self.key.weight", "module.bert.encoder.layer.6.attention.self.key.bias", "module.bert.encoder.layer.6.attention.self.value.weight", "module.bert.encoder.layer.6.attention.self.value.bias", "module.bert.encoder.layer.6.attention.output.dense.weight", "module.bert.encoder.layer.6.attention.output.dense.bias", "module.bert.encoder.layer.6.attention.output.LayerNorm.weight", "module.bert.encoder.layer.6.attention.output.LayerNorm.bias", "module.bert.encoder.layer.6.intermediate.dense.weight", "module.bert.encoder.layer.6.intermediate.dense.bias", "module.bert.encoder.layer.6.output.dense.weight", "module.bert.encoder.layer.6.output.dense.bias", "module.bert.encoder.layer.6.output.LayerNorm.weight", "module.bert.encoder.layer.6.output.LayerNorm.bias", "module.bert.encoder.layer.7.attention.self.query.weight", "module.bert.encoder.layer.7.attention.self.query.bias", "module.bert.encoder.layer.7.attention.self.key.weight", "module.bert.encoder.layer.7.attention.self.key.bias", "module.bert.encoder.layer.7.attention.self.value.weight", "module.bert.encoder.layer.7.attention.self.value.bias", "module.bert.encoder.layer.7.attention.output.dense.weight", "module.bert.encoder.layer.7.attention.output.dense.bias", "module.bert.encoder.layer.7.attention.output.LayerNorm.weight", "module.bert.encoder.layer.7.attention.output.LayerNorm.bias", "module.bert.encoder.layer.7.intermediate.dense.weight", "module.bert.encoder.layer.7.intermediate.dense.bias", "module.bert.encoder.layer.7.output.dense.weight", "module.bert.encoder.layer.7.output.dense.bias", "module.bert.encoder.layer.7.output.LayerNorm.weight", "module.bert.encoder.layer.7.output.LayerNorm.bias", "module.bert.encoder.layer.8.attention.self.query.weight", "module.bert.encoder.layer.8.attention.self.query.bias", "module.bert.encoder.layer.8.attention.self.key.weight", "module.bert.encoder.layer.8.attention.self.key.bias", "module.bert.encoder.layer.8.attention.self.value.weight", "module.bert.encoder.layer.8.attention.self.value.bias", "module.bert.encoder.layer.8.attention.output.dense.weight", "module.bert.encoder.layer.8.attention.output.dense.bias", "module.bert.encoder.layer.8.attention.output.LayerNorm.weight", "module.bert.encoder.layer.8.attention.output.LayerNorm.bias", "module.bert.encoder.layer.8.intermediate.dense.weight", "module.bert.encoder.layer.8.intermediate.dense.bias", "module.bert.encoder.layer.8.output.dense.weight", "module.bert.encoder.layer.8.output.dense.bias", "module.bert.encoder.layer.8.output.LayerNorm.weight", "module.bert.encoder.layer.8.output.LayerNorm.bias", "module.bert.encoder.layer.9.attention.self.query.weight", "module.bert.encoder.layer.9.attention.self.query.bias", "module.bert.encoder.layer.9.attention.self.key.weight", "module.bert.encoder.layer.9.attention.self.key.bias", "module.bert.encoder.layer.9.attention.self.value.weight", "module.bert.encoder.layer.9.attention.self.value.bias", "module.bert.encoder.layer.9.attention.output.dense.weight", "module.bert.encoder.layer.9.attention.output.dense.bias", "module.bert.encoder.layer.9.attention.output.LayerNorm.weight", "module.bert.encoder.layer.9.attention.output.LayerNorm.bias", "module.bert.encoder.layer.9.intermediate.dense.weight", "module.bert.encoder.layer.9.intermediate.dense.bias", "module.bert.encoder.layer.9.output.dense.weight", "module.bert.encoder.layer.9.output.dense.bias", "module.bert.encoder.layer.9.output.LayerNorm.weight", "module.bert.encoder.layer.9.output.LayerNorm.bias", "module.bert.encoder.layer.10.attention.self.query.weight", "module.bert.encoder.layer.10.attention.self.query.bias", "module.bert.encoder.layer.10.attention.self.key.weight", "module.bert.encoder.layer.10.attention.self.key.bias", "module.bert.encoder.layer.10.attention.self.value.weight", "module.bert.encoder.layer.10.attention.self.value.bias", "module.bert.encoder.layer.10.attention.output.dense.weight", "module.bert.encoder.layer.10.attention.output.dense.bias", "module.bert.encoder.layer.10.attention.output.LayerNorm.weight", "module.bert.encoder.layer.10.attention.output.LayerNorm.bias", "module.bert.encoder.layer.10.intermediate.dense.weight", "module.bert.encoder.layer.10.intermediate.dense.bias", "module.bert.encoder.layer.10.output.dense.weight", "module.bert.encoder.layer.10.output.dense.bias", "module.bert.encoder.layer.10.output.LayerNorm.weight", "module.bert.encoder.layer.10.output.LayerNorm.bias", "module.bert.encoder.layer.11.attention.self.query.weight", "module.bert.encoder.layer.11.attention.self.query.bias", "module.bert.encoder.layer.11.attention.self.key.weight", "module.bert.encoder.layer.11.attention.self.key.bias", "module.bert.encoder.layer.11.attention.self.value.weight", "module.bert.encoder.layer.11.attention.self.value.bias", "module.bert.encoder.layer.11.attention.output.dense.weight", "module.bert.encoder.layer.11.attention.output.dense.bias", "module.bert.encoder.layer.11.attention.output.LayerNorm.weight", "module.bert.encoder.layer.11.attention.output.LayerNorm.bias", "module.bert.encoder.layer.11.intermediate.dense.weight", "module.bert.encoder.layer.11.intermediate.dense.bias", "module.bert.encoder.layer.11.output.dense.weight", "module.bert.encoder.layer.11.output.dense.bias", "module.bert.encoder.layer.11.output.LayerNorm.weight", "module.bert.encoder.layer.11.output.LayerNorm.bias", "module.bert.pooler.dense.weight", "module.bert.pooler.dense.bias", "module.classifier.weight", "module.classifier.bias". 

### Step 4. Check on our test data
(5 min)

In [54]:
test_data = load_json("data/test_doc5.jsonl")

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer)
test_dataloader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name= f"sent_retrieval/test_doc5sent{TOP_N}_neg{NEGATIVE_RATIO}_{LR}_e{NUM_EPOCHS}_{MODEL_SHORT}_new.jsonl",
    # save_name=f"test_doc5sent{TOP_N}.jsonl",
)

Start predicting the test data


  0%|          | 0/112 [00:00<?, ?it/s]

notebook3
## PART 3. Claim verification

import libs

In [1]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm
from functools import partial

import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

Global variables

In [2]:
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json("data/train_doc5sent5.jsonl")
DEV_DATA = load_json("data/dev_doc5sent5.jsonl")

TRAIN_PKL_FILE = Path("data/train_doc5sent5.pkl")
DEV_PKL_FILE = Path("data/dev_doc5sent5.pkl")

Preload wiki database (same as part 2.)

In [3]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=296938), Label(value='0 / 296938')…

Transform to id to evidence_map mapping


### Helper function

AICUP dataset with top-k evidence sentences.

In [4]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]

        # In case there are less than topk evidence sentences
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad
        concat_claim_evidence = " [SEP] ".join([*claim, *evidence])

        concat = self.tokenizer(
            concat_claim_evidence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        label = LABEL2ID[item["label"]] if "label" in item else -1
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}

        if "label" in item:
            concat_ten["labels"] = torch.tensor(label)

        return concat_ten

Evaluation function

In [5]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.sum().item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())

    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

Prediction

In [6]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

### Main function

In [7]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)

    print(f"Extracting evidence_list for the {mode} mode ...")
    # if mode == "eval":
        # extract evidence
    df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
        mapping.get(evi_id, {}).get(str(evi_idx), "")
        for evi_id, evi_idx in x  # for each evidence list
    ][:topk] if isinstance(x, list) else [])
    print(df["evidence_list"][:topk])
    # else:
    #     # extract evidence
    #     # if df["label"] == "NOT ENOUGH INFO":
    #     #     df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
    #     #         mapping.get(evi_id, {}).get(str(evi_idx), "")
    #     #         for evi_id, evi_idx in x  # for each evidence list
    #     #     ][:topk] if isinstance(x, list) else [])
    #     # else:
    #     df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
    #         " ".join([  # join evidence
    #             mapping.get(evi_id, {}).get(str(evi_idx), "")
    #             for _, _, evi_id, evi_idx in evi_list
    #         ]) if isinstance(evi_list, list) else ""
    #         for evi_list in x  # for each evidence list
    #     ][:1] if isinstance(x, list) else [])

    return df

In [8]:
# def join_with_topk_evidence(
#     df: pd.Series,
#     mapping: dict,
#     mode: str = "train",
#     topk: int = 5,
# ) -> pd.Series:
#     # format evidence column to List[List[Tuple[str, str, str, str]]]
#     if "evidence" in df:
#         df["evidence"] = [[df["evidence"]]] if not isinstance(df["evidence"][0], list) else [df["evidence"]] if not isinstance(df["evidence"][0][0], list) else df["evidence"]

#     print(f"Extracting evidence_list for the {mode} mode ...")
#     if mode == "eval":
#         df["evidence_list"] = [
#             mapping.get(evi_id, {}).get(str(evi_idx), "")
#             for evi_id, evi_idx in df["predicted_evidence"]  # for each evidence list
#         ][:1] if isinstance(df["predicted_evidence"], list) else []
#         print(df["evidence_list"][:1])
#     else:
#         if df["label"] == "NOT ENOUGH INFO":
#             df["evidence_list"] = [
#                 mapping.get(evi_id, {}).get(str(evi_idx), "")
#                 for evi_id, evi_idx in df["predicted_evidence"]  # for each evidence list
#             ][:1] if isinstance(df["predicted_evidence"], list) else []
#             print(df["evidence_list"][:1])
#         else:
#             df["evidence_list"] = [
#                 " ".join([  # join evidence
#                     mapping.get(evi_id, {}).get(str(evi_idx), "")
#                     for _, _, evi_id, evi_idx in evi_list
#                 ]) if isinstance(evi_list, list) else ""
#                 for evi_list in df["evidence"]  # for each evidence list
#             ][:1] if isinstance(df["evidence"], list) else []
#     # else:
#     #     # extract evidence
#     #     # if df["label"] == "NOT ENOUGH INFO":
#     #     #     df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
#     #     #         mapping.get(evi_id, {}).get(str(evi_idx), "")
#     #     #         for evi_id, evi_idx in x  # for each evidence list
#     #     #     ][:topk] if isinstance(x, list) else [])
#     #     # else:
#     #     df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
#     #         " ".join([  # join evidence
#     #             mapping.get(evi_id, {}).get(str(evi_idx), "")
#     #             for _, _, evi_id, evi_idx in evi_list
#     #         ]) if isinstance(evi_list, list) else ""
#     #         for evi_list in x  # for each evidence list
#     #     ][:1] if isinstance(x, list) else [])

#     return df

### Step 1. Setup training environment

Hyperparams

In [8]:
#@title  { display-mode: "form" }

# MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
# MODEL_NAME = "ckiplab/bert-base-chinese" #@param {type:"string"}
# MODEL_NAME = "ckiplab/albert-base-chinese" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-bert-wwm" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-bert-wwm-ext" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-macbert-base" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-roberta-wwm-ext" #@param {type:"string"}
MODEL_NAME = "hfl/chinese-lert-base" #@param {type:"string"}
# MODEL_NAME = "hfl/chinese-lert-large" #@param {type:"string"}

MODEL_SHORT = "hfl_lert"
EVAL_VERSION = 2
TRAIN_BATCH_SIZE = 36  #@param {type:"integer"}
TEST_BATCH_SIZE = 32  #@param {type:"integer"}
SEED = 42  #@param {type:"integer"}
LR = 7e-5  #@param {type:"number"}
NUM_EPOCHS = 30  #@param {type:"integer"}
REAL_EPOCHS = 12
MAX_SEQ_LEN = 256  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
VALIDATION_STEP = 40  #@param {type:"integer"}


Experiment Directory

In [77]:
OUTPUT_FILENAME = "submission.jsonl"

EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_top{EVIDENCE_TOPK}_{MODEL_SHORT}_{EVAL_VERSION}_new"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Concat claim and evidences
join topk evidence

In [71]:
if not TRAIN_PKL_FILE.exists():
    # train_df = pd.DataFrame(TRAIN_DATA)
    # train_df = train_df.parallel_apply(partial(
    #     join_with_topk_evidence,
    #     mapping=mapping,
    #     topk=EVIDENCE_TOPK,
    # ), axis=1)
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=4)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    # dev_df = pd.DataFrame(DEV_DATA)
    # dev_df = dev_df.parallel_apply(partial(
    #     join_with_topk_evidence,
    #     mapping=mapping,
    #     mode="eval",
    #     topk=EVIDENCE_TOPK,
    # ), axis=1)
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=4)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

### Step 3. Training

Prevent CUDA out of memory

In [72]:
torch.cuda.empty_cache()

In [73]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=0,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE, num_workers=0,)

In [74]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABEL2ID),
)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
torch.cuda.empty_cache()
model.to(device)
optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at hfl/chinese-lert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not init

Training (30 mins)

In [75]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(REAL_EPOCHS):
    model.train()

    for batch in train_dataloader:
        torch.cuda.empty_cache()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.sum().backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.sum().item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print(f"Start validation: current_steps={current_steps}, epoch={epoch}")
            val_results = run_evaluation(model, eval_dataloader, device)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            val_acc = val_results['val_acc']
            if val_acc > 0.6:
                save_checkpoint(
                    model,
                    CKPT_DIR,
                    current_steps,
                    mark=f"val_acc={val_acc:.4f}",
                )

print("Finished training!")

  0%|          | 0/5130 [00:00<?, ?it/s]



Start validation: current_steps=40, epoch=0


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.132581335802873
val_acc: 0.4446614583333333




Start validation: current_steps=80, epoch=0


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.095502379039923
val_acc: 0.4720052083333333




Start validation: current_steps=120, epoch=0


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.9874166672428448
val_acc: 0.5305989583333334




Start validation: current_steps=160, epoch=0


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.9543982446193695
val_acc: 0.5436197916666666




Start validation: current_steps=200, epoch=1


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.6221595307191212
val_acc: 0.642578125




Start validation: current_steps=240, epoch=1


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.675708954532941
val_acc: 0.6484375




Start validation: current_steps=280, epoch=1


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.6783662773668766
val_acc: 0.654296875




Start validation: current_steps=320, epoch=1


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4261435605585575
val_acc: 0.6861979166666666




Start validation: current_steps=360, epoch=2


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.5751533011595409
val_acc: 0.7024739583333334




Start validation: current_steps=400, epoch=2


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4366440251469612
val_acc: 0.712890625




Start validation: current_steps=440, epoch=2


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.639624825368325
val_acc: 0.673828125




Start validation: current_steps=480, epoch=2


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4106120032568772
val_acc: 0.7018229166666666




Start validation: current_steps=520, epoch=3


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.414666519810756
val_acc: 0.7096354166666666




Start validation: current_steps=560, epoch=3


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.700703961153825
val_acc: 0.7018229166666666




Start validation: current_steps=600, epoch=3


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4411491366724174
val_acc: 0.7356770833333334




Start validation: current_steps=640, epoch=3


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4172103901704152
val_acc: 0.7265625




Start validation: current_steps=680, epoch=3


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.4785202940305073
val_acc: 0.7115885416666666




Start validation: current_steps=720, epoch=4


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.6205676620205243
val_acc: 0.7278645833333334




Start validation: current_steps=760, epoch=4


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.5663796477019787
val_acc: 0.7239583333333334




Start validation: current_steps=800, epoch=4


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.6477253213524818
val_acc: 0.703125




Start validation: current_steps=840, epoch=4


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.5940060416857402
val_acc: 0.720703125




Start validation: current_steps=880, epoch=5


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.016075844566027
val_acc: 0.716796875




Start validation: current_steps=920, epoch=5


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.739190336316824
val_acc: 0.7024739583333334




Start validation: current_steps=960, epoch=5


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.08507943029205
val_acc: 0.7037760416666666




Start validation: current_steps=1000, epoch=5


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 1.7601755112409592
val_acc: 0.7109375




Start validation: current_steps=1040, epoch=6


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.0265119398633638
val_acc: 0.7330729166666666




Start validation: current_steps=1080, epoch=6


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.1912097732226052
val_acc: 0.7115885416666666




Start validation: current_steps=1120, epoch=6


  0%|          | 0/48 [00:00<?, ?it/s]

val_loss: 2.002070697645346
val_acc: 0.71484375




Start validation: current_steps=1160, epoch=6


  0%|          | 0/48 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Step 4. Make your submission

In [78]:
TEST_DATA = load_json("data/test_doc5sent5.jsonl")
TEST_PKL_FILE = Path("data/test_doc5sent5.pkl")

if not TEST_PKL_FILE.exists():
    # test_df = pd.DataFrame(TEST_DATA)
    # test_df = test_df.parallel_apply(partial(
    #     join_with_topk_evidence,
    #     mapping=mapping,
    #     topk=EVIDENCE_TOPK,
    #     mode="eval",
    # ), axis=1)
    test_df = join_with_topk_evidence(
        pd.DataFrame(TEST_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    test_df.to_pickle(TEST_PKL_FILE, protocol=4)
else:
    with open(TEST_PKL_FILE, "rb") as f:
        test_df = pickle.load(f)

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

Extracting evidence_list for the eval mode ...


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=248), Label(value='0 / 248'))), HB…

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Prediction

In [35]:
torch.cuda.empty_cache()

In [83]:
ckpt_name = "val_acc=0.7617_model.1320.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
predicted_label = run_predict(model, test_dataloader, device)

Predicting:   0%|          | 0/31 [00:00<?, ?it/s]

Write files

In [84]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    f"submission/{ckpt_name[:14]}_{MODEL_SHORT}_{LR}_{EVAL_VERSION}_e{NUM_EPOCHS}_new_{OUTPUT_FILENAME}",
    orient="records",
    lines=True,
    force_ascii=False,
)