# Baseline
python: 3.8.*

use ```Ctrl + ]``` to collapse all section :)

Download our starter pack (3~5 min)

In [21]:
# !gdown --folder 1T6jpOtdf_i6XNYA6F_lqU4mRRh1xYPcl
# !mv baseline/* ./

In [22]:
# %pip install -r requirements.txt

notebook1
## PART 1. Document retrieval

Prepare the environment and import all library we need

In [23]:
import json
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# 3rd party libs
import hanlp
import opencc
import pandas as pd
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel

# our own libs
from utils import load_json

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

Preload the data.

In [24]:
TRAIN_DATA = load_json("data/public_train.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
CONVERTER_T2S = opencc.OpenCC("t2s.json")
CONVERTER_S2T = opencc.OpenCC("s2t.json")

Data class for type hinting

In [25]:
@dataclass
class Claim:
    data: str

@dataclass
class AnnotationID:
    id: int

@dataclass
class EvidenceID:
    id: int

@dataclass
class PageTitle:
    title: str

@dataclass
class SentenceID:
    id: int

@dataclass
class Evidence:
    data: List[List[Tuple[AnnotationID, EvidenceID, PageTitle, SentenceID]]]

### Helper function

For the sake of consistency, we convert traditional to simplified Chinese first before converting it back to traditional Chinese.  This is due to some errors occuring when converting traditional to traditional Chinese.

In [26]:
def do_st_corrections(text: str) -> str:
    simplified = CONVERTER_T2S.convert(text)

    return CONVERTER_S2T.convert(simplified)

We use constituency parsing to separate part of speeches or so called constituent to extract noun phrases.  In the later stages, we will use the noun phrases as the query to search for relevant documents.  

In [27]:
def get_nps_hanlp(
    predictor: Pipeline,
    d: Dict[str, Union[int, Claim, Evidence]],
) -> List[str]:
    claim = d["claim"]
    tree = predictor(claim)["con"]
    nps = [
        do_st_corrections("".join(subtree.leaves()))
        for subtree in tree.subtrees(lambda t: t.label() == "NP")
    ]

    return nps

Precision refers to how many related documents are retrieved.  Recall refers to how many relevant documents are retrieved.  

In [28]:
def calculate_precision(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Extract all ground truth of titles of the wikipedia pages
        # evidence[2] refers to the title of the wikipedia page
        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])

        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision
    print(f"Precision: {precision / count}")


def calculate_recall(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
) -> None:
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = set([
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        ])
        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        recall += len(hits) / len(gt_pages)
        count += 1

    print(f"Recall: {recall / count}")

The default amount of documents retrieved is at most five documents.  This `num_pred_doc` can be adjusted based on your objective.  Save data in jsonl format.

In [29]:
def save_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 5,
) -> None:
    with open(
        f"data/{mode}_doc{num_pred_doc}.jsonl",
        "w",
        encoding="utf8",
    ) as f:
        for i, d in enumerate(data):
            d["predicted_pages"] = list(predictions.iloc[i])
            f.write(json.dumps(d, ensure_ascii=False) + "\n")

### Main function for document retrieval

In [None]:
wiki_path = "data/wiki-pages"
min_wiki_length = 10
num_of_samples = 500
topk = 15
min_df = 2
max_df = 0.8
use_idf = True
sublinear_tf = True

!pip install TCSP
from TCSP import read_stopwords_list
stopwords = read_stopwords_list()

# Set up the experiment name for logging
exp_name = (
    f"len{min_wiki_length}_top{topk}_min_df={min_df}_"
    + f"max_df={max_df}_{num_of_samples}s"
)
if sublinear_tf:
    exp_name = "sublinearTF_" + exp_name
if not use_idf:
    exp_name = "no_idf_" + exp_name

In [None]:
def tokenize(text: str, stopwords: list) -> str:
    import jieba
    """This function performs Chinese word segmentation and removes stopwords.

    Args:
        text (str): claim or wikipedia article
        stopwords (list): common words that contribute little to the meaning of a sentence

    Returns:
        str: word segments separated by space (e.g. "我 喜歡 吃 蘋果")
    """

    tokens = jieba.cut(text)

    return " ".join([w for w in tokens if w not in stopwords])

In [None]:
from functools import partial
from utils import jsonl_dir_to_df

wiki_cache = "wiki"
target_column = "text"

wiki_cache_path = Path(f"data/{wiki_cache}.pkl")
if wiki_cache_path.exists():
    wiki_pages = pd.read_pickle(wiki_cache_path)
else:
    # You need to download `wiki-pages.zip` from the AICUP website
    wiki_pages = jsonl_dir_to_df(wiki_path)
    # wiki_pages are combined into one dataframe, so we need to reset the index
    wiki_pages = wiki_pages.reset_index(drop=True)

    # tokenize the text and keep the result in a new column `processed_text`
    wiki_pages["processed_text"] = wiki_pages[target_column].parallel_apply(
        partial(tokenize, stopwords=stopwords)
    )
    # save the result to a pickle file
    wiki_pages.to_pickle(wiki_cache_path, protocol=4)

In [31]:
def get_pred_pages(series_data: pd.Series) -> Set[Dict[int, str]]:
    import wikipedia
    import re
    import opencc
    import pandas as pd

    from TCSP import read_stopwords_list
    stopwords = read_stopwords_list()

    wikipedia.set_lang("zh")
    CONVERTER_T2S = opencc.OpenCC("t2s.json")
    CONVERTER_S2T = opencc.OpenCC("s2t.json")
    
    def do_st_corrections(text: str) -> str:
        simplified = CONVERTER_T2S.convert(text)
        return CONVERTER_S2T.convert(simplified)

    results = []
    tmp_muji = []
    # wiki_page: its index showned in claim
    mapping = {}
    claim = series_data["claim"]
    nps = series_data["hanlp_results"]
    first_wiki_term = []
    repeated_mention = []
    quote_search = []

    def clean_claim(claim) -> str:     # Clean claim function because hanlp has error when conducting cons
        def multiple_replacer(*kv):
            replace_dict = dict(kv)
            replace_func = lambda match: replace_dict[match.group(0)]
            pattern = re.compile("|".join([re.escape(k) for k, v in kv]), re.M)
            return lambda string: pattern.sub(replace_func, string) 
        def multiple_replace(string, *kv):
            return multiple_replacer(*replace_dict)(claim)

        replace_dict = (" ", ""), ("牠", "它"), ("（", "("), ("）", ")"), ("，", ","), ("、", ","), ("群", "羣"), ("“", "\""), ("”", "\""), ("「", "“"), ("」", "”")
        claim = multiple_replace(claim, *replace_dict)
        claim = claim.lower()
        return claim

    claim = clean_claim(claim)

    def post_processing(np, page, loc):
        page = do_st_corrections(page)
        page = page.replace(" ", "_")
        page = page.replace("-", "")
        search_pos = claim.find(np)
        if search_pos != -1:
            if page in results:
                repeated_mention.append(page)
                # results.insert(0, results.pop(results.index(page)))     # Fresh page to front if it was mention before
            else:
                results.append(page)
            if loc == 0:
                pass
                # print(f"Add: {page}, at page direct search, np={np}")
            elif loc == 1:
                pass
                # print(f"Add: {page}, at match, new term={np}")
            mapping[page] = search_pos
            tmp_muji.append(np)

    def if_page_exists(page: str) -> bool:
        import requests
        url_base = "https://zh.wikipedia.org/wiki/"
        new_url = [url_base + page, url_base + page.upper()]
        for url in new_url:
            r = requests.head(url)
            if r.status_code == 200:
                return True
            else:
                continue
        return False
    
    def clean_time_format(np: str):
        if (matched := re.search(r"\d+年", np)) != None:
            return True
        if (matched := re.search(r"\d+月\d+日", np)) != None:
            return True
        if (matched := re.search(r"\d+小時", np)) != None:
            return True
        if (matched := re.search(r"\d+天", np)) != None:
            return True
        if (matched := re.search(r"\d+世紀", np)) != None:
            return True
        if (matched := re.search(r"\d+年代", np)) != None:
            return True
        return False

    for i, np in enumerate(nps):
        # print(f"searching {np}")
        quote_dup = False
        if np in stopwords:         # 如果包含停用詞
            continue
        if clean_time_format(np):   # 如果包含時間
            continue
        
        # Ignore parsing among quotation mark, for example, if《仲夏夜之夢》exists, ignore「仲夏夜」and「夢」
        for search in quote_search:
            if search.find(np) != -1:
                quote_dup = True
        if quote_dup == True:
            continue

        # Delete Bookname Mark, Quote Mark
        np_no_quote = re.sub(r"《|》|〈|〉|【|】|「|」|『|』|（|）", "", np)
        if np != np_no_quote:
            quote_search.append(np_no_quote)
            np = np_no_quote

        # Simplified Traditional Chinese Correction
        wiki_search_results = [
            do_st_corrections(w) for w in wikipedia.search(np)
        ]

        # Directly Search by Redirection
            # Check if a page exists
        if (if_page_exists(np)):
            try:
                page = do_st_corrections(wikipedia.page(title=np).title)
                if page == np:
                    # print(f"Found, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
                else:
                    # print(f"Redirect, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
            except wikipedia.DisambiguationError as diserr:
                page = do_st_corrections(wikipedia.search(np)[0])
                if page == np:
                    # print(f"Disambig, np={np}, page={page}")
                    post_processing(np=np, page=page, loc=0)
            except wikipedia.PageError as pageerr:
                pass

        # Remove the wiki page's description in brackets
        wiki_set = [re.sub(r"\s\(\S+\)", "", w) for w in wiki_search_results]
        wiki_df = pd.DataFrame({
            "wiki_set": wiki_set,
            "wiki_results": wiki_search_results
        })

        # Elements in wiki_set --> index
        # Extracting only the first element is one way to avoid extracting
        # too many of the similar wiki pages
        grouped_df = wiki_df.groupby("wiki_set", sort=False).first()
        candidates = grouped_df["wiki_results"].tolist()
        # muji refers to wiki_set
        muji = grouped_df.index.tolist()

        for prefix, term in zip(muji, candidates):
            if prefix not in tmp_muji:  #忽略掉括號，如果括號有重複的話。假設如果有" 1 (數字)", 則"1 (符號)" 會被忽略
                matched = False

                # Take at least one term from the first noun phrase
                if i == 0:
                    first_wiki_term.append(term)

                # Walrus operator :=
                # https://docs.python.org/3/whatsnew/3.8.html#assignment-expressions
                # Through these filters, we are trying to figure out if the term
                # is within the claim
                if (((new_term := term) in claim) or
                    ((new_term := term) in claim.replace(" ", "")) or
                    ((new_term := term.replace("·", "")) in claim) or                                   # 過濾人名
                    ((new_term := re.sub(r"\s\(\S+\)", "", term)) in claim) or                          # 過濾空格 / 消歧義
                    ((new_term := term.replace("(", "").replace(")", "").split()[0]) in claim and       # 消歧義與括號內皆有在裡面
                     (new_term := term.replace("(", "").replace(")", "").split()[1]) in claim) or
                    ((new_term := term.replace("-", " ")) in claim) or                                  # 過濾槓號
                    ((new_term := term.lower()) in claim) or                                            # 過濾大小寫
                    ((new_term := term.lower().replace("-", "")) in claim) or                           # 過濾大小寫及槓號
                    ((new_term := re.sub(r"\s\(\S+\)", "", term.lower().replace("-", ""))) in claim)    # 過濾大小寫、槓號及消歧義
                    ):
                    matched = True
                    # print(new_term, term)

                # 人名匹配
                elif "·" in term:
                    splitted = term.split("·")
                    if "·" not in claim:        # 要求claim顯示的不為全名，不然都需要全名
                        for split in splitted:
                            if (new_term := split) in claim:
                                matched = True
                                break

                if matched:
                    post_processing(np=new_term, page=term, loc=1)

    # 8 is a hyperparameter
    if len(results) > 8:
        assert -1 not in mapping.values()
        # print("長度大於8", results)

        results = repeated_mention + sorted(mapping, key=mapping.get)[:8]
        results = list(set(results))            # remove duplicates
        # print("排序後", results)
    if len(results) < 1:
        results = first_wiki_term
        # print("第一搜尋結果", results)
    
    print(results)
    return set(results)

### Step 1. Get noun phrases from hanlp consituency parsing tree

Setup [HanLP](https://github.com/hankcs/HanLP) predictor (1 min)

In [34]:
predictor = (hanlp.pipeline().append(
    hanlp.load("FINE_ELECTRA_SMALL_ZH"),
    output_key="tok",
).append(
    hanlp.load("CTB9_CON_ELECTRA_SMALL"),
    output_key="con",
    input_key="tok",
))

                                             

We will skip this process which for creating parsing tree when demo on class

In [35]:
hanlp_file = f"data/hanlp_con_results.pkl"
if Path(hanlp_file).exists():
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)
else:
    hanlp_results = [get_nps_hanlp(predictor, d) for d in TRAIN_DATA]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)

Get pages via wiki online api

In [36]:
doc_path = f"data/train_doc5.jsonl"
if Path(doc_path).exists():
    with open(doc_path, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
        ])
else:
    train_df = pd.DataFrame(TRAIN_DATA)
    train_df.loc[:, "hanlp_results"] = hanlp_results
    # predicted_results = train_df.progress_apply(get_pred_pages, axis=1)
    predicted_results = train_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TRAIN_DATA, predicted_results, mode="train")

### Step 2. Calculate our results

In [37]:
calculate_precision(TRAIN_DATA, predicted_results)
calculate_recall(TRAIN_DATA, predicted_results)

Precision: 0.24198516414141477
Recall: 0.8701666666666668


### Step 3. Repeat the same process on test set
Create parsing tree

In [38]:
hanlp_test_file = f"data/hanlp_con_test_results.pkl"
if Path(hanlp_test_file).exists():
    with open(hanlp_test_file, "rb") as f:
        hanlp_test_results = pickle.load(f)
else:
    hanlp_test_results = [get_nps_hanlp(predictor, d) for d in TEST_DATA]
    with open(hanlp_test_file, "wb") as f:
        pickle.dump(hanlp_test_results, f)

Get pages via wiki online api

In [39]:
test_doc_path = f"data/test_doc5.jsonl"
if Path(test_doc_path).exists():
    with open(test_doc_path, "r", encoding="utf8") as f:
        test_results = pd.Series(
            [set(json.loads(line)["predicted_pages"]) for line in f])
else:
    test_df = pd.DataFrame(TEST_DATA)
    test_df.loc[:, "hanlp_results"] = hanlp_test_results
    test_results = test_df.parallel_apply(get_pred_pages, axis=1)
    save_doc(TEST_DATA, test_results, mode="test")

notebook2
## PART 2. Sentence retrieval

Import some libs

In [40]:
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

# third-party libs
import numpy as np
import pandas as pd
from pandarallel import pandarallel
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset, Dataset

# local libs
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=10)

Global variable

In [41]:
SEED = 42

TRAIN_DATA = load_json("data/public_train.jsonl")
TEST_DATA = load_json("data/public_test.jsonl")
DOC_DATA = load_json("data/train_doc5.jsonl")

LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

_y = [LABEL2ID[data["label"]] for data in TRAIN_DATA]
# GT means Ground Truth
TRAIN_GT, DEV_GT = train_test_split(
    DOC_DATA,
    test_size=0.2,
    random_state=SEED,
    shuffle=True,
    stratify=_y,
)

Preload wiki database (1 min)

In [42]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=118776), Label(value='0 / 118776')…

### Helper function

Calculate precision for sentence retrieval

In [None]:
def evidence_macro_precision(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate precision for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of precision)
        [2]: retrieved (denominator of precision)
    """
    this_precision = 0.0
    this_precision_hits = 0.0

    # Return 0, 0 if label is not enough info since not enough info does not
    # contain any evidence.
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # e[2] is the page title, e[3] is the sentence index
        all_evi = [[e[2], e[3]]
                   for eg in instance["evidence"]
                   for e in eg
                   if e[3] is not None]
        claim = instance["claim"]
        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for prediction in predicted_evidence:
            if prediction in all_evi:
                this_precision += 1.0
            this_precision_hits += 1.0

        return (this_precision /
                this_precision_hits) if this_precision_hits > 0 else 1.0, 1.0

    return 0.0, 0.0

Calculate recall for sentence retrieval

In [None]:
def evidence_macro_recall(
    instance: Dict,
    top_rows: pd.DataFrame,
) -> Tuple[float, float]:
    """Calculate recall for sentence retrieval
    This function is modified from fever-scorer.
    https://github.com/sheffieldnlp/fever-scorer/blob/master/src/fever/scorer.py

    Args:
        instance (dict): a row of the dev set (dev.jsonl) of test set (test.jsonl)
        top_rows (pd.DataFrame): our predictions with the top probabilities

        IMPORTANT!!!
        instance (dict) should have the key of `evidence`.
        top_rows (pd.DataFrame) should have a column `predicted_evidence`.

    Returns:
        Tuple[float, float]:
        [1]: relevant and retrieved (numerator of recall)
        [2]: relevant (denominator of recall)
    """
    # We only want to score F1/Precision/Recall of recalled evidence for NEI claims
    if instance["label"].upper() != "NOT ENOUGH INFO":
        # If there's no evidence to predict, return 1
        if len(instance["evidence"]) == 0 or all(
            [len(eg) == 0 for eg in instance]):
            return 1.0, 1.0

        claim = instance["claim"]

        predicted_evidence = top_rows[top_rows["claim"] ==
                                      claim]["predicted_evidence"].tolist()

        for evidence_group in instance["evidence"]:
            evidence = [[e[2], e[3]] for e in evidence_group]
            if all([item in predicted_evidence for item in evidence]):
                # We only want to score complete groups of evidence. Incomplete
                # groups are worthless.
                return 1.0, 1.0
        return 0.0, 1.0
    return 0.0, 0.0

Calculate the scores of sentence retrieval

In [None]:
def evaluate_retrieval(
    probs: np.ndarray,
    df_evidences: pd.DataFrame,
    ground_truths: pd.DataFrame,
    top_n: int = 5,
    cal_scores: bool = True,
    save_name: str = None,
) -> Dict[str, float]:
    """Calculate the scores of sentence retrieval

    Args:
        probs (np.ndarray): probabilities of the candidate retrieved sentences
        df_evidences (pd.DataFrame): the candiate evidence sentences paired with claims
        ground_truths (pd.DataFrame): the loaded data of dev.jsonl or test.jsonl
        top_n (int, optional): the number of the retrieved sentences. Defaults to 2.

    Returns:
        Dict[str, float]: F1 score, precision, and recall
    """
    df_evidences["prob"] = probs
    top_rows = (
        df_evidences.groupby("claim").apply(
        lambda x: x.nlargest(top_n, "prob"))
        .reset_index(drop=True)
    )

    if cal_scores:
        macro_precision = 0
        macro_precision_hits = 0
        macro_recall = 0
        macro_recall_hits = 0

        for i, instance in enumerate(ground_truths):
            macro_prec = evidence_macro_precision(instance, top_rows)
            macro_precision += macro_prec[0]
            macro_precision_hits += macro_prec[1]

            macro_rec = evidence_macro_recall(instance, top_rows)
            macro_recall += macro_rec[0]
            macro_recall_hits += macro_rec[1]

        pr = (macro_precision /
              macro_precision_hits) if macro_precision_hits > 0 else 1.0
        rec = (macro_recall /
               macro_recall_hits) if macro_recall_hits > 0 else 0.0
        f1 = 2.0 * pr * rec / (pr + rec)

    if save_name is not None:
        # write doc7_sent5 file
        with open(f"data/{save_name}", "w", encoding="utf8") as f:
            for instance in ground_truths:
                claim = instance["claim"]
                predicted_evidence = top_rows[
                    top_rows["claim"] == claim]["predicted_evidence"].tolist()
                instance["predicted_evidence"] = predicted_evidence
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")

    if cal_scores:
        return {"F1 score": f1, "Precision": pr, "Recall": rec}

Inference script to get probabilites for the candidate evidence sentences

In [None]:
def get_predicted_probs(
    model: nn.Module,
    dataloader: Dataset,
    device: torch.device,
) -> np.ndarray:
    """Inference script to get probabilites for the candidate evidence sentences

    Args:
        model: the one from HuggingFace Transformers
        dataloader: devset or testset in torch dataloader

    Returns:
        np.ndarray: probabilites of the candidate evidence sentences
    """
    model.eval()
    probs = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            probs.extend(torch.softmax(logits, dim=1)[:, 1].tolist())

    return np.array(probs)

AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences

In [None]:
class SentRetrievalBERTDataset(BERTDataset):
    """AicupTopkEvidenceBERTDataset class for AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        sentA = item["claim"]
        sentB = item["text"]

        # claim [SEP] text
        concat = self.tokenizer(
            sentA,
            sentB,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}
        if "label" in item:
            concat_ten["labels"] = torch.tensor(item["label"])

        return concat_ten

### Main function for sentence retrieval

In [None]:
def pair_with_wiki_sentences(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    negative_ratio: float,
) -> pd.DataFrame:
    """Only for creating train sentences."""
    claims = []
    sentences = []
    labels = []

    # positive
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue

        claim = df["claim"].iloc[i]
        evidence_sets = df["evidence"].iloc[i]
        for evidence_set in evidence_sets:
            sents = []
            for evidence in evidence_set:
                # evidence[2] is the page title
                page = evidence[2].replace(" ", "_")
                # the only page with weird name
                if page == "臺灣海峽危機#第二次臺灣海峽危機（1958）":
                    continue
                # evidence[3] is in form of int however, mapping requires str
                sent_idx = str(evidence[3])
                sents.append(mapping[page][sent_idx])

            whole_evidence = " ".join(sents)

            claims.append(claim)
            sentences.append(whole_evidence)
            labels.append(1)

    # negative
    for i in range(len(df)):
        if df["label"].iloc[i] == "NOT ENOUGH INFO":
            continue
        claim = df["claim"].iloc[i]

        evidence_set = set([(evidence[2], evidence[3])
                            for evidences in df["evidence"][i]
                            for evidence in evidences])
        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [
                    (page, sent_idx) for sent_idx in mapping[page].keys()
                ]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for pair in page_sent_id_pairs:
                if pair in evidence_set:
                    continue
                text = mapping[page][pair[1]]
                # `np.random.rand(1) <= 0.05`: Control not to add too many negative samples
                if text != "" and np.random.rand(1) <= negative_ratio:
                    claims.append(claim)
                    sentences.append(text)
                    labels.append(0)

    return pd.DataFrame({"claim": claims, "text": sentences, "label": labels})


def pair_with_wiki_sentences_eval(
    mapping: Dict[str, Dict[int, str]],
    df: pd.DataFrame,
    is_testset: bool = False,
) -> pd.DataFrame:
    """Only for creating dev and test sentences."""
    claims = []
    sentences = []
    evidence = []
    predicted_evidence = []

    # negative
    for i in range(len(df)):
        # if df["label"].iloc[i] == "NOT ENOUGH INFO":
        #     continue
        claim = df["claim"].iloc[i]

        predicted_pages = df["predicted_pages"][i]
        for page in predicted_pages:
            page = page.replace(" ", "_")
            try:
                page_sent_id_pairs = [(page, k) for k in mapping[page]]
            except KeyError:
                # print(f"{page} is not in our Wiki db.")
                continue

            for page_name, sentence_id in page_sent_id_pairs:
                text = mapping[page][sentence_id]
                if text != "":
                    claims.append(claim)
                    sentences.append(text)
                    if not is_testset:
                        evidence.append(df["evidence"].iloc[i])
                    predicted_evidence.append([page_name, int(sentence_id)])

    return pd.DataFrame({
        "claim": claims,
        "text": sentences,
        "evidence": evidence if not is_testset else None,
        "predicted_evidence": predicted_evidence,
    })

### Step 1. Setup training environment

Hyperparams

In [None]:
MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
NUM_EPOCHS = 1  #@param {type:"integer"}
LR = 2e-5  #@param {type:"number"}
TRAIN_BATCH_SIZE = 64  #@param {type:"integer"}
TEST_BATCH_SIZE = 256  #@param {type:"integer"}
NEGATIVE_RATIO = 0.03  #@param {type:"number"}
VALIDATION_STEP = 50  #@param {type:"integer"}
TOP_N = 5  #@param {type:"integer"}
#@title  { display-mode: "form" }

Experiment Directory

In [None]:
EXP_DIR = f"sent_retrieval/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_neg{NEGATIVE_RATIO}_top{TOP_N}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Combine claims and evidences

In [17]:
train_df = pair_with_wiki_sentences(
    mapping,
    pd.DataFrame(TRAIN_GT),
    NEGATIVE_RATIO,
)
counts = train_df["label"].value_counts()
print("Now using the following train data with 0 (Negative) and 1 (Positive)")
print(counts)

dev_evidences = pair_with_wiki_sentences_eval(mapping, pd.DataFrame(DEV_GT))

Now using the following train data with 0 (Negative) and 1 (Positive)
0    3347
1    2774
Name: label, dtype: int64


### Step 3. Start training

Dataloader things

In [18]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = SentRetrievalBERTDataset(train_df, tokenizer=tokenizer)
val_dataset = SentRetrievalBERTDataset(dev_evidences, tokenizer=tokenizer)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE)

Save your memory.

In [19]:
del train_df

Trainer

In [20]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")
print(torch.cuda.is_available())
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(device)

optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

True


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Please make sure that you are using gpu when training (5 min)

In [35]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print("Start validation")
            probs = get_predicted_probs(model, eval_dataloader, device)

            val_results = evaluate_retrieval(
                probs=probs,
                df_evidences=dev_evidences,
                ground_truths=DEV_GT,
                top_n=TOP_N,
            )
            print(val_results)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                writer.add_scalar(
                    f"dev_{metric_name}",
                    metric_value,
                    current_steps,
                )

            save_checkpoint(model, CKPT_DIR, current_steps)

print("Finished training!")

  0%|          | 0/185 [00:00<?, ?it/s]

Start validation


  0%|          | 0/128 [00:00<?, ?it/s]

In [34]:
%load_ext tensorboard
%tensorboard --logdir logs

Reusing TensorBoard on port 6006 (pid 17672), started 0:15:52 ago. (Use '!kill 17672' to kill it.)

Validation part (15 mins)

In [37]:
ckpt_name = "model.100.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
print("Start final evaluations and write prediction files.")

train_evidences = pair_with_wiki_sentences_eval(
    mapping=mapping,
    df=pd.DataFrame(TRAIN_GT),
)
train_set = SentRetrievalBERTDataset(train_evidences, tokenizer)
train_dataloader = DataLoader(train_set, batch_size=TEST_BATCH_SIZE)

print("Start calculating training scores")
probs = get_predicted_probs(model, train_dataloader, device)
train_results = evaluate_retrieval(
    probs=probs,
    df_evidences=train_evidences,
    ground_truths=TRAIN_GT,
    top_n=TOP_N,
    save_name=f"train_doc5sent{TOP_N}.jsonl",
)
print(f"Training scores => {train_results}")

print("Start validation")
probs = get_predicted_probs(model, eval_dataloader, device)
val_results = evaluate_retrieval(
    probs=probs,
    df_evidences=dev_evidences,
    ground_truths=DEV_GT,
    top_n=TOP_N,
    save_name=f"dev_doc5sent{TOP_N}.jsonl",
)

print(f"Validation scores => {val_results}")

Start final evaluations and write prediction files.
Start calculating training scores


  0%|          | 0/505 [00:00<?, ?it/s]

Training scores => {'F1 score': 0.4025395110732468, 'Precision': 0.26937664041994885, 'Recall': 0.7960629921259843}
Start validation


  0%|          | 0/128 [00:00<?, ?it/s]

Validation scores => {'F1 score': 0.3646408438320653, 'Precision': 0.23765723270440292, 'Recall': 0.7830188679245284}


Load the model we want.

In [39]:
ckpt_name = "model.100.pt"
model = load_model(model, ckpt_name, CKPT_DIR)

### Step 4. Check on our test data
(5 min)

In [40]:
test_data = load_json("data/test_doc5.jsonl")

test_evidences = pair_with_wiki_sentences_eval(
    mapping,
    pd.DataFrame(test_data),
    is_testset=True,
)
test_set = SentRetrievalBERTDataset(test_evidences, tokenizer)
test_dataloader = DataLoader(test_set, batch_size=TEST_BATCH_SIZE)

print("Start predicting the test data")
probs = get_predicted_probs(model, test_dataloader, device)
evaluate_retrieval(
    probs=probs,
    df_evidences=test_evidences,
    ground_truths=test_data,
    top_n=TOP_N,
    cal_scores=False,
    save_name=f"test_doc5sent{TOP_N}.jsonl",
)

Start predicting the test data


  0%|          | 0/162 [00:00<?, ?it/s]

In [None]:
del wiki_pages

notebook3
## PART 3. Claim verification

import libs

In [3]:
import pickle
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from tqdm.auto import tqdm

import torch
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
)

from dataset import BERTDataset
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
    load_model,
    save_checkpoint,
    set_lr_scheduler,
)

pandarallel.initialize(progress_bar=True, verbose=0, nb_workers=4)

Global variables

In [4]:
LABEL2ID: Dict[str, int] = {
    "supports": 0,
    "refutes": 1,
    "NOT ENOUGH INFO": 2,
}
ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}

TRAIN_DATA = load_json("data/train_doc5sent5.jsonl")
DEV_DATA = load_json("data/dev_doc5sent5.jsonl")

TRAIN_PKL_FILE = Path("data/train_doc5sent5.pkl")
DEV_PKL_FILE = Path("data/dev_doc5sent5.pkl")

Preload wiki database (same as part 2.)

In [36]:
wiki_pages = jsonl_dir_to_df("data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages,)
del wiki_pages

Reading and concatenating jsonl files in data/wiki-pages
Generate parse mapping


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=296938), Label(value='0 / 296938')…

Transform to id to evidence_map mapping


### Helper function

AICUP dataset with top-k evidence sentences.

In [5]:
class AicupTopkEvidenceBERTDataset(BERTDataset):
    """AICUP dataset with top-k evidence sentences."""

    def __getitem__(
        self,
        idx: int,
        **kwargs,
    ) -> Tuple[Dict[str, torch.Tensor], int]:
        item = self.data.iloc[idx]
        claim = item["claim"]
        evidence = item["evidence_list"]

        # In case there are less than topk evidence sentences
        pad = ["[PAD]"] * (self.topk - len(evidence))
        evidence += pad
        concat_claim_evidence = " [SEP] ".join([*claim, *evidence])

        concat = self.tokenizer(
            concat_claim_evidence,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        label = LABEL2ID[item["label"]] if "label" in item else -1
        concat_ten = {k: torch.tensor(v) for k, v in concat.items()}

        if "label" in item:
            concat_ten["labels"] = torch.tensor(label)

        return concat_ten

Evaluation function

In [6]:
def run_evaluation(model: torch.nn.Module, dataloader: DataLoader, device):
    model.eval()

    loss = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            y_true.extend(batch["labels"].tolist())

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss += outputs.loss.item()
            logits = outputs.logits
            y_pred.extend(torch.argmax(logits, dim=1).tolist())

    acc = accuracy_score(y_true, y_pred)

    return {"val_loss": loss / len(dataloader), "val_acc": acc}

Prediction

In [7]:
def run_predict(model: torch.nn.Module, test_dl: DataLoader, device) -> list:
    model.eval()

    preds = []
    for batch in tqdm(test_dl,
                      total=len(test_dl),
                      leave=False,
                      desc="Predicting"):
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(**batch).logits
        pred = torch.argmax(pred, dim=1)
        preds.extend(pred.tolist())
    return preds

### Main function

In [8]:
def join_with_topk_evidence(
    df: pd.DataFrame,
    mapping: dict,
    mode: str = "train",
    topk: int = 5,
) -> pd.DataFrame:
    """join_with_topk_evidence join the dataset with topk evidence.

    Note:
        After extraction, the dataset will be like this:
               id     label         claim                           evidence            evidence_list
        0    4604  supports       高行健...     [[[3393, 3552, 高行健, 0], [...  [高行健 （ ）江西赣州出...
        ..    ...       ...            ...                                ...                     ...
        945  2095  supports       美國總...  [[[1879, 2032, 吉米·卡特, 16], [...  [卸任后 ， 卡特積極參與...
        停各种战争及人質危機的斡旋工作 ， 反对美国小布什政府攻打伊拉克...

        [946 rows x 5 columns]

    Args:
        df (pd.DataFrame): The dataset with evidence.
        wiki_pages (pd.DataFrame): The wiki pages dataframe
        topk (int, optional): The topk evidence. Defaults to 5.
        cache(Union[Path, str], optional): The cache file path. Defaults to None.
            If cache is None, return the result directly.

    Returns:
        pd.DataFrame: The dataset with topk evidence_list.
            The `evidence_list` column will be: List[str]
    """

    # format evidence column to List[List[Tuple[str, str, str, str]]]
    if "evidence" in df.columns:
        df["evidence"] = df["evidence"].parallel_map(
            lambda x: [[x]] if not isinstance(x[0], list) else [x]
            if not isinstance(x[0][0], list) else x)

    print(f"Extracting evidence_list for the {mode} mode ...")
    if mode == "eval":
        # extract evidence
        df["evidence_list"] = df["predicted_evidence"].parallel_map(lambda x: [
            mapping.get(evi_id, {}).get(str(evi_idx), "")
            for evi_id, evi_idx in x  # for each evidence list
        ][:topk] if isinstance(x, list) else [])
        print(df["evidence_list"][:5])
    else:
        # extract evidence
        df["evidence_list"] = df["evidence"].parallel_map(lambda x: [
            " ".join([  # join evidence
                mapping.get(evi_id, {}).get(str(evi_idx), "")
                for _, _, evi_id, evi_idx in evi_list
            ]) if isinstance(evi_list, list) else ""
            for evi_list in x  # for each evidence list
        ][:1] if isinstance(x, list) else [])

    return df

### Step 1. Setup training environment

Hyperparams

In [9]:
#@title  { display-mode: "form" }

MODEL_NAME = "bert-base-chinese"  #@param {type:"string"}
TRAIN_BATCH_SIZE = 32  #@param {type:"integer"}
TEST_BATCH_SIZE = 32  #@param {type:"integer"}
SEED = 42  #@param {type:"integer"}
LR = 7e-5  #@param {type:"number"}
NUM_EPOCHS = 20  #@param {type:"integer"}
MAX_SEQ_LEN = 256  #@param {type:"integer"}
EVIDENCE_TOPK = 5  #@param {type:"integer"}
VALIDATION_STEP = 25  #@param {type:"integer"}


Experiment Directory

In [10]:
OUTPUT_FILENAME = "submission.jsonl"

EXP_DIR = f"claim_verification/e{NUM_EPOCHS}_bs{TRAIN_BATCH_SIZE}_" + f"{LR}_top{EVIDENCE_TOPK}"
LOG_DIR = "logs/" + EXP_DIR
CKPT_DIR = "checkpoints/" + EXP_DIR

if not Path(LOG_DIR).exists():
    Path(LOG_DIR).mkdir(parents=True)

if not Path(CKPT_DIR).exists():
    Path(CKPT_DIR).mkdir(parents=True)

### Step 2. Concat claim and evidences
join topk evidence

In [13]:
if not TRAIN_PKL_FILE.exists():
    train_df = join_with_topk_evidence(
        pd.DataFrame(TRAIN_DATA),
        mapping,
        topk=EVIDENCE_TOPK,
    )
    train_df.to_pickle(TRAIN_PKL_FILE, protocol=4)
else:
    with open(TRAIN_PKL_FILE, "rb") as f:
        train_df = pickle.load(f)

if not DEV_PKL_FILE.exists():
    dev_df = join_with_topk_evidence(
        pd.DataFrame(DEV_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    dev_df.to_pickle(DEV_PKL_FILE, protocol=4)
else:
    with open(DEV_PKL_FILE, "rb") as f:
        dev_df = pickle.load(f)

### Step 3. Training

Prevent CUDA out of memory

In [14]:
torch.cuda.empty_cache()

In [15]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset = AicupTopkEvidenceBERTDataset(
    train_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
val_dataset = AicupTopkEvidenceBERTDataset(
    dev_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=0,
)
eval_dataloader = DataLoader(val_dataset, batch_size=TEST_BATCH_SIZE, num_workers=0,)

In [18]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABEL2ID),
)
torch.cuda.empty_cache()
model.to(device)
optimizer = AdamW(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = set_lr_scheduler(optimizer, num_training_steps)

writer = SummaryWriter(LOG_DIR)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Training (30 mins)

In [48]:
progress_bar = tqdm(range(num_training_steps))
current_steps = 0

for epoch in range(NUM_EPOCHS):
    model.train()

    for batch in train_dataloader:
        torch.cuda.empty_cache()
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        writer.add_scalar("training_loss", loss.item(), current_steps)

        y_pred = torch.argmax(outputs.logits, dim=1).tolist()
        y_true = batch["labels"].tolist()

        current_steps += 1

        if current_steps % VALIDATION_STEP == 0 and current_steps > 0:
            print(f"Start validation: current_steps={current_steps}, epoch={epoch}")
            val_results = run_evaluation(model, eval_dataloader, device)

            # log each metric separately to TensorBoard
            for metric_name, metric_value in val_results.items():
                print(f"{metric_name}: {metric_value}")
                writer.add_scalar(f"{metric_name}", metric_value, current_steps)

            val_acc = val_results['val_acc']
            if val_acc > 0.425:
                save_checkpoint(
                    model,
                    CKPT_DIR,
                    current_steps,
                    mark=f"val_acc={val_acc:.4f}",
                )

print("Finished training!")

  0%|          | 0/7900 [00:00<?, ?it/s]

Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.114389246762401
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.1316226064556776
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.19079269905283
val_acc: 0.40304182509505704
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.1761916279792786
val_acc: 0.39036755386565275
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.279089006510648
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3377732556275648
val_acc: 0.41191381495564006
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3664308292697174
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.442280523102693
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.4725218969162064
val_acc: 0.40430925221799746
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.456760449240906
val_acc: 0.40304182509505704
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3379164584959395
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5474961592693521
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6274360466485072
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3992377730331036
val_acc: 0.3979721166032953
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5774370040556398
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6046950558219293
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7114758515598798
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6305967120811193
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6619010350920937
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5201638870769076
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5772019179180414
val_acc: 0.40430925221799746
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6082600812719325
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.4425658607723737
val_acc: 0.40304182509505704
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5432460416447034
val_acc: 0.39543726235741444
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6776298838432389
val_acc: 0.40430925221799746
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6333725133327524
val_acc: 0.3979721166032953
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6597484482659235
val_acc: 0.3979721166032953
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.565665921779594
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6728206439451738
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7221430214968594
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.788847894379587
val_acc: 0.39923954372623577
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6063556057034116
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6055542092130641
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6360067761305608
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.688050759561134
val_acc: 0.39670468948035487
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7771392651278564
val_acc: 0.40304182509505704
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 2.0613622436619767
val_acc: 0.39670468948035487
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3851500550905864
val_acc: 0.41698352344740175
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.759592714333775
val_acc: 0.41064638783269963
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 2.0486410573275404
val_acc: 0.4017743979721166
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 2.036674429069866
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.836132632361518
val_acc: 0.394169835234474
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8691654301653005
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9113480069420554
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5370751360450128
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.3365787207478224
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.624794015378663
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5801026429792848
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.2080159837549382
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.2133192444088483
val_acc: 0.40304182509505704
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5159313865382262
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.2682004414423547
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.4858153974167023
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5088602045569757
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6082136354061087
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6596766914984193
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5790591468714705
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.5893238849110074
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6433561396117162
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.6717540356847975
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.754118314897171
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7130930893348926
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7489637750567812
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7651879058943853
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7776750524838765
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.7989339798387856
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8128677769140764
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.82206956726132
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8278968779727667
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8315473042353234
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8458542444489219
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8523936259626137
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8644671404000484
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.87747550251508
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9081054293748103
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8806627212148723
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8887571788797475
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8973556383691652
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.904657446374797
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.910041412921867
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9168186627253136
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9234089797193354
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9300729343385408
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.950245231691033
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.942913289021964
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9610961609416537
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9480240465414644
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.956796215640174
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9471057039318662
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8823885863477534
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8816798485890784
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8693001186004792
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8905226919386122
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8981373984404284
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9113300165744742
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.910336140430335
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9045647581418355
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.873437609937456
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8850812315940857
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.893773782734919
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8821155651651247
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8670349343858583
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8809945468950753
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8902107850469725
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.898110523368373
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9073972954894558
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9160297055437108
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9253904903777923
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9358821445041232
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9544267257054646
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9328162387163952
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9347268108165625
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9444632385716294
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9486715588906798
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9552927625299705
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9621414313412675
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9700979632560653
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9692478047476873
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9747218971300606
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9919854258046006
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9857500905942436
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 2.0050617450415484
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.996324851055338
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9597034815585974
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9558044304751387
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9089498495814776
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.922931264747273
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.926931715974904
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9450206708426427
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9614314780090794
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.958143209568178
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9757732381724349
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9168919818569916
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.911529075015675
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9238110527847752
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9386637403507425
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.945758372846276
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9408821668287721
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8883594900670677
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8965851122682744
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9198751130489389
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.92873265827545
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9449204495458892
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.951263210388145
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.959362561654563
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9652047030853503
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9647740884260698
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9605914823936694
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.918159554101
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.916269190985747
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9286567893895237
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9376779469576748
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9577501626929852
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.958535464725109
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9721567642809166
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9643251004845206
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9158637391196356
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9252703653441534
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9407567165114663
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.953162361275066
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.953540907363699
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8991600630259273
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.899287380955436
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8972414804227424
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8575215285474604
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8780827925662802
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.895766971087215
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.907292375660906
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9197067516018647
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9307089146941598
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.947939730653859
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9515001713627516
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9022691207702713
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.902116273388718
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9129085480564771
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9261670371498725
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9370805807787963
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9501849551393529
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9592483507262335
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9667067810742542
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9738587472173903
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9795717037085332
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9918602307637532
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 2.002118235886699
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.98721462789208
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9314115149806244
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8669846045850502
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8830660280555185
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.890835479654447
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8436683019002278
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8611581542275168
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8836158098596516
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9121750051325017
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9160132751320347
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.870487169183866
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8718265927199162
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8869341768399635
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8382199394582497
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8509901184024233
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.861169129309028
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8112889215199635
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.828856271926803
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.847919427987301
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8641256118061567
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8794542701557428
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.892731011515916
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9055095519682375
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9144494551600832
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.926146787826461
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9202162418702635
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.927043288645118
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.935493908747278
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8955080750012638
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.893350773387485
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8773889475398593
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.853958345422841
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.815235632838625
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8168292075696617
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.833876417140768
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8498192604142005
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8659137198419282
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8836074242688188
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9070822373785155
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8676744440589288
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8641666902436151
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8749460246827867
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.890156505685864
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8973789750927625
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8825749357541401
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8331665631496545
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8233193827397896
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8369527796302179
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8508106060702392
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.863029956817627
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8734649245185082
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8851159082518683
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.886540128727152
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8924893875314732
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8983475445496916
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8997215140949597
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8886984177310058
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8721110977307716
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8802075078993132
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8923964903812216
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9031694134076436
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9064556890063815
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9069673641763552
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8386737806628448
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8441867840410484
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8554518150560784
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8667356871595286
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8793858783413666
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8842536861246282
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8924444721202658
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8998611804210779
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9070864410111399
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9136741570752076
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9239052275214532
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9312280007083007
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.931041778940143
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9348831303191907
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9327395082724215
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9122241789644414
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9068616863453027
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9126368761062622
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9170794035449172
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.894884503248966
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.862934900654687
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8683555914898111
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8432207926355226
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.824076002294367
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.829669855459772
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8386082035122495
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8506405979695946
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8583870162867537
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.864117818649369
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8699574344085925
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8777647114763356
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8830173996963886
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8876455986138545
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8876917458543874
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.87339478189295
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8741495236001833
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8785654432845837
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8842043720110497
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8872819771670333
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.891517733684694
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8954879623470884
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8988508491805105
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9033363015964777
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9075293661368014
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.9101697862750353
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.910006749509561
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8994662418509975
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.898671188739815
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8901634180184566
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8879606350503786
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.881472399138441
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8795611129866705
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8821152514881558
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8844396544225288
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8871957322563788
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.879960054701025
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8813557751250989
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8750986499015732
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8733888184181366
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8743026015734432
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8750392915022494
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8759119263803117
val_acc: 0.4005069708491762
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8768047136489792
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8773339715870945
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8784166761118957
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.878962141696853
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8775354780332008
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8774738805462616
val_acc: 0.4055766793409379
Start validation


  0%|          | 0/99 [00:00<?, ?it/s]

val_loss: 1.8775473104582892
val_acc: 0.4055766793409379
Finished training!


### Step 4. Make your submission

In [16]:
TEST_DATA = load_json("data/test_doc5sent5.jsonl")
TEST_PKL_FILE = Path("data/test_doc5sent5.pkl")

if not TEST_PKL_FILE.exists():
    test_df = join_with_topk_evidence(
        pd.DataFrame(TEST_DATA),
        mapping,
        mode="eval",
        topk=EVIDENCE_TOPK,
    )
    test_df.to_pickle(TEST_PKL_FILE, protocol=4)
else:
    with open(TEST_PKL_FILE, "rb") as f:
        test_df = pickle.load(f)

test_dataset = AicupTopkEvidenceBERTDataset(
    test_df,
    tokenizer=tokenizer,
    max_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

Prediction

In [19]:
ckpt_name = "val_acc=0.4559_model.250.pt"  #@param {type:"string"}
model = load_model(model, ckpt_name, CKPT_DIR)
predicted_label = run_predict(model, test_dataloader, device)

Predicting:   0%|          | 0/31 [00:00<?, ?it/s]

Write files

In [20]:
predict_dataset = test_df.copy()
predict_dataset["predicted_label"] = list(map(ID2LABEL.get, predicted_label))
predict_dataset[["id", "predicted_label", "predicted_evidence"]].to_json(
    f"submission/{ckpt_name[:14]}_{OUTPUT_FILENAME}",
    orient="records",
    lines=True,
    force_ascii=False,
)