In [None]:
# built-in libs
import json
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

from helper_function import *
# 3rd party libs
import hanlp
import opencc
import pandas as pd
import wikipedia
from hanlp.components.pipeline import Pipeline
from pandarallel import pandarallel
# our own libs
from utils import load_json
#from helper_function import *
from utils import (
    generate_evidence_to_wiki_pages_mapping,
    jsonl_dir_to_df,
    load_json,
)


# Worker pool for pandarallel (progress bars on, log output off).
# NOTE(review): no visible cell calls parallel_apply, so this pool appears unused.
pandarallel.initialize(nb_workers=10, progress_bar=True, verbose=0)
# Candidate titles are searched on the Chinese-language Wikipedia.
wikipedia.set_lang("zh")

In [None]:
# Load the wiki dump, then derive the lookup whose keys get_pred_pages_BM25_ver1
# tests noun phrases against (see utils.generate_evidence_to_wiki_pages_mapping).
wiki_pages = jsonl_dir_to_df("./data/wiki-pages")
mapping = generate_evidence_to_wiki_pages_mapping(
    wiki_pages,
)

In [None]:
# Private test claims; every record is read via its "claim" field downstream.
Private_Data = load_json(
    "./data/raw/Private/private_test_data.jsonl"
)

# HanLP predictor: tokenization + constituency parsing

In [None]:
# Two-stage HanLP pipeline: the fine-grained tokenizer writes "tok"; the
# constituency parser reads "tok" and writes "con".
predictor = hanlp.pipeline()
predictor = predictor.append(hanlp.load("FINE_ELECTRA_SMALL_ZH"), output_key="tok")
predictor = predictor.append(
    hanlp.load("CTB9_CON_ELECTRA_SMALL"), output_key="con", input_key="tok"
)

In [None]:
# Parsing every claim is slow, so the noun phrases are cached on disk.
# NOTE(review): get_nps_hanlp is defined in a later cell; on a fresh kernel with
# no cache this relies on `from helper_function import *` providing it — confirm.
hanlp_file = "./data/raw/Private/hanlp_con_PrivateTest_results.pkl"
if not Path(hanlp_file).exists():
    hanlp_results = [get_nps_hanlp(predictor, d) for d in Private_Data]
    with open(hanlp_file, "wb") as f:
        pickle.dump(hanlp_results, f)
else:
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(hanlp_file, "rb") as f:
        hanlp_results = pickle.load(f)

In [None]:
import torch
# OpenCC converters used by do_st_corrections: Traditional -> Simplified, and back.
CONVERTER_T2S, CONVERTER_S2T = (
    opencc.OpenCC("t2s.json"),
    opencc.OpenCC("s2t.json"),
)
def compare_strings(a, b):
    """Count the character positions at which ``a`` and ``b`` differ.

    Characters are compared position by position over the length of the
    shorter input. Every surplus character of the longer input counts as one
    further difference, giving a Hamming distance padded by the length gap.
    Despite the message printed for ``None`` input, the count is of
    *differing* characters, not matching ones.

    Returns:
        The number of differences, or ``None`` (after printing a message)
        when either argument is ``None``.
    """
    if a is None or b is None:
        print("Number of Same Characters: 0")
        return
    overlap = min(len(a), len(b))
    mismatches = sum(1 for idx in range(overlap) if a[idx] != b[idx])
    return mismatches + abs(len(a) - len(b))
    
def do_st_corrections(text: str) -> str:
    """Round-trip ``text`` Traditional -> Simplified -> Traditional with OpenCC.

    The round trip rewrites character variants into the forms the s2t
    converter emits, presumably so that noun phrases and search hits share
    one script with the wiki page ids — TODO confirm.
    """
    return CONVERTER_S2T.convert(CONVERTER_T2S.convert(text))
def get_nps_hanlp(
    predictor: Pipeline,
    d
) -> List[str]:
    """Extract noun phrases from a claim via HanLP constituency parsing.

    Args:
        predictor: pipeline whose output has a "con" entry (see the
            predictor cell).
        d: record holding the claim text under the "claim" key.

    Returns:
        For every subtree labelled "NP" (in the order ``subtrees`` yields
        them), its leaves joined into one string and passed through
        ``do_st_corrections``.
    """
    # Element [0] of "con" — presumably the tree of the first (only)
    # sentence; TODO confirm against HanLP's output for a single string.
    parse_tree = predictor(d["claim"])["con"][0]

    def is_noun_phrase(node):
        return node.label() == "NP"

    return [
        do_st_corrections("".join(np_tree.leaves()))
        for np_tree in parse_tree.subtrees(is_noun_phrase)
    ]

def TextTok(Text):
    """Tokenize every page text with the global HanLP ``predictor``.

    Args:
        Text: mapping of page id -> page text. Iteration order is preserved,
            so the i-th returned token list belongs to the i-th key (callers
            rely on this to map BM25 scores back to page ids).

    Returns:
        One token list per entry of ``Text``. Empty texts, and texts whose
        tokenization raised ``RuntimeError`` (e.g. CUDA out-of-memory), get
        an empty token list so positions stay aligned with ``Text``.
    """
    token_lists = []
    for title, text in Text.items():
        if text == "":
            token_lists.append([])
            continue
        try:
            token_lists.append(predictor(text)["tok"])
        except RuntimeError as e:
            if "out of memory" in str(e):
                # Release cached GPU memory so the next document can still fit.
                torch.cuda.empty_cache()
            # Best-effort as before, but no longer silent: any failed page is
            # scored as an empty document, which quietly skews BM25 otherwise.
            print(f"TextTok: tokenization failed for {title!r}, using no tokens: {e}")
            token_lists.append([])
    return token_lists


In [None]:
from rank_bm25 import BM25Okapi
import time
import numpy
# Module-level log: get_pred_pages_BM25_ver1 appends each noun phrase's
# deduplicated search hits here, so it keeps growing until this cell is re-run.
all_candidate_list = list()
def _match_in_claim(term: str, claim: str) -> Union[str, None]:
    """Return the first surface form of wiki title ``term`` found in ``claim``, else None.

    Tried in order: the title itself, the title without "·", its first
    space-separated token, the title with "-" replaced by a space, and
    finally each "·"-separated part of the title. The result may be "".
    """
    variants = (term, term.replace("·", ""), term.split(" ")[0], term.replace("-", " "))
    for variant in variants:
        if variant in claim:
            return variant
    if "·" in term:
        for part in term.split("·"):
            if part in claim:
                return part
    return None


def _bm25_top_candidate(claim: str, candidates: List[str]) -> List[str]:
    """BM25-rank ``candidates`` (unique page ids) against ``claim``; return the best as a 1-item list."""
    # One scan of the whole dump per claim (the old code filtered it twice per candidate).
    hits = wiki_pages[wiki_pages["id"].isin(candidates)].drop_duplicates(subset="id")
    # Positional column 1, exactly as before — presumably the page text; TODO confirm schema.
    id_to_text = dict(zip(hits["id"], hits.iloc[:, 1]))
    # Pages missing from the dump are scored as empty documents.
    page_texts = {page_id: id_to_text.get(page_id, "") for page_id in candidates}
    bm25 = BM25Okapi(TextTok(page_texts))
    scores = bm25.get_scores(predictor(claim)["tok"])
    # argmax takes the first maximum, so ties resolve deterministically over the ordered candidates.
    return [candidates[int(numpy.argmax(scores))]]


def get_pred_pages_BM25_ver1(series_data: pd.Series, mapping2, max_matched: int = 10) -> Set[str]:
    """Predict the wiki page ids relevant to one claim row.

    Candidates come from four sources:
      1. search hits (``wikipedia.search`` on each noun phrase) whose title,
         or a simple variant of it, occurs in the claim — capped at
         ``max_matched``, keeping the ones that occur earliest in the claim;
         if none match, every new hit of the first noun phrase instead;
      2. search hits that occur verbatim in the claim;
      3. noun phrases that are keys of ``mapping2``;
      4. the single best search hit according to BM25 over column 1 of
         ``wiki_pages``.

    Side effects: prints the row label, and appends each noun phrase's
    deduplicated hit list to the module-level ``all_candidate_list``.

    Args:
        series_data: row with "claim" (text) and "hanlp_results" (the claim's
            noun phrases) entries.
        mapping2: container whose membership is tested against the noun
            phrases (the notebook passes the evidence -> wiki page mapping).
        max_matched: cap on source 1 (previously hard-coded to 10).

    Returns:
        The set of predicted page ids; empty when the row has no noun phrases.
    """
    print(series_data.name)
    claim = series_data["claim"]
    nps = series_data["hanlp_results"]

    nps_in_mapping = [phrase for phrase in nps if phrase in mapping2]
    if len(nps) == 0:
        return set()

    all_candidates = []       # every deduplicated search hit, over all phrases
    candidates_in_claim = []  # hits that appear verbatim in the claim
    first_wiki_term = []      # fallback: new hits of the first noun phrase
    results = []              # normalized titles matched against the claim
    matched_forms = []        # surface forms already matched in the claim
    claim_position = {}       # normalized title -> offset of its match in claim

    for i, phrase in enumerate(nps):
        # Simplified/Traditional correction of the search hits.
        search_hits = [do_st_corrections(w) for w in wikipedia.search(phrase)]
        # The baseline stripped " (...)" disambiguators here with
        # re.sub(r"\s\(\S+\)", "", w); ver1 deliberately disabled that, so the
        # former groupby(sort=False).first() reduces to an order-preserving dedupe.
        candidates = list(dict.fromkeys(search_hits))
        all_candidate_list.append(candidates)
        all_candidates.extend(candidates)
        candidates_in_claim.extend(c for c in candidates if c in claim)

        for term in candidates:
            if term in matched_forms:
                continue
            if i == 0:
                # Keep at least one term from the first noun phrase as a fallback.
                first_wiki_term.append(term)
            matched_form = _match_in_claim(term, claim)
            if matched_form is None:  # "" is a valid (if degenerate) match
                continue
            # Presumably converts the display title into the dump's page-id format.
            page_id = term.replace(" ", "_").replace("-", "")
            results.append(page_id)
            claim_position[page_id] = claim.find(matched_form)
            matched_forms.append(matched_form)

    if not all_candidates:
        # BM25Okapi cannot be built on an empty corpus. The old code returned an
        # empty set here, which also discarded the mapping2 hits found above.
        return set(nps_in_mapping)

    # Ordered dedupe: list(set(...)) made BM25 tie-breaking depend on the hash seed.
    bm25_answer = _bm25_top_candidate(claim, list(dict.fromkeys(all_candidates)))

    if len(results) > max_matched:
        # Every matched form was found in the claim, so no offset can be -1.
        assert -1 not in claim_position.values()
        results = sorted(claim_position, key=claim_position.get)[:max_matched]
    elif not results:
        results = first_wiki_term
    results.extend(candidates_in_claim)
    results.extend(nps_in_mapping)
    results.extend(bm25_answer)
    return set(results)

In [None]:
# Reuse cached page predictions when present; otherwise run retrieval on every
# private claim (slow: one Wikipedia search per noun phrase plus BM25 per claim).
# NOTE(review): save_private_doc writes ./data/Stage1/..._BM25_Ver1_final.jsonl,
# not this path, so this notebook never populates this cache — confirm which
# file is intended.
doc_path = "./data/raw/Private/Private_doc10_BM25_final.jsonl"
if Path(doc_path).exists():
    with open(doc_path, "r", encoding="utf8") as f:
        predicted_results = pd.Series([
            set(json.loads(line)["predicted_pages"])
            for line in f
            if line.strip()  # tolerate a trailing blank line
        ])
else:
    private_df = pd.DataFrame(Private_Data)
    # An explicit object Series stores each row's whole noun-phrase list;
    # `.loc[:, col] = list_of_lists` lets pandas try to read the lists as 2-D values.
    private_df["hanlp_results"] = pd.Series(hanlp_results, index=private_df.index)
    predicted_results = private_df.apply(get_pred_pages_BM25_ver1, mapping2=mapping, axis=1)
    

In [None]:
def save_private_doc(
    data: List[Dict[str, Union[int, Claim, Evidence]]],
    predictions: pd.Series,
    mode: str = "train",
    num_pred_doc: int = 5,
) -> None:
    """Write each record plus its predicted pages as one JSON line.

    Output file: ``./data/Stage1/{mode}_doc{num_pred_doc}_BM25_Ver1_final.jsonl``.
    ``num_pred_doc`` only labels the file name; predictions are not truncated.

    Args:
        data: claim records, positionally aligned with ``predictions``.
        predictions: one iterable of page ids per record.
        mode: split name used in the output file name.
        num_pred_doc: label used in the output file name.
    """
    out_path = Path(f"./data/Stage1/{mode}_doc{num_pred_doc}_BM25_Ver1_final.jsonl")
    # Create ./data/Stage1 on a fresh checkout instead of failing inside open().
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf8") as f:
        for i, d in enumerate(data):
            # Copy instead of mutating the caller's records (the old version
            # added "predicted_pages" to every dict in `data` as a side effect).
            # Sorted because predictions are sets: keeps the file stable across runs.
            record = {**d, "predicted_pages": sorted(predictions.iloc[i])}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

In [None]:
# Persist the private-set predictions (file name is built inside save_private_doc).
save_private_doc(
    Private_Data,
    predicted_results,
    mode="Private",
    num_pred_doc=10,
)