In [18]:
import json
from datasets import load_from_disk
from torch.utils.data import DataLoader, TensorDataset

In [19]:
train_file = load_from_disk("/opt/ml/input/data/data/train_dataset")["train"]
validation_file = load_from_disk("/opt/ml/input/data/data/train_dataset")["validation"]

with open("/opt/ml/lastcode/dataset/preprocess_wiki.json", "r") as f:
    wiki = json.load(f)
wiki_contexts = list(dict.fromkeys([v['text'] for v in wiki.values()]))
#wiki_articles = [{"document_text" : wiki_contexts[i]} for i in range(len(wiki_contexts))]

In [20]:
print(wiki[str(6)])
print(len(wiki_contexts))

{'text': "텍스트 파일을 아오조라 문고에 수록할 때, 텍스트 파일이 갖추어야 할 서식을 '아오조라 문고' 형식이라 부른다. 아오조라 문고 형식은 텍스트 파일로서 많은 환경에서 읽을 수 있도록 규격화되어있다. 때문에 가능한 한 원본의 충실한 재현을 목표로 삼고 있지만, 줄 바꿈이나 삽화 등의 정보는 원칙적으로 포함되지 않는다. 아오조라 문고 형식에 대응하는 텍스트 뷰어와 텍스트 편집기도 존재하며, 올림문자와 방점 등도 재현할 수 있다. 또 이러한 텍스트 뷰어에서는 본래 아오조라 문고 형식에 포함되지 않았던 삽화 정보를 삽입하거나 세로쓰기로 표시할 수 있으며, 텍스트를 읽기 쉽도록 만드는 다양한 기능이 포함되어 있다. 이러한 소프트웨어는 유료와 무료를 불문하고 종류가 다양하다.", 'corpus_source': '위키피디아', 'url': 'TODO', 'domain': None, 'title': '아오조라 문고', 'author': None, 'html': None, 'document_id': 6}
55963


In [21]:
qa_records = [{"example_id" : train_file[i]["id"], "document_title" : train_file[i]["title"], "question_text" : train_file[i]["question"], "answer" : train_file[i]["answers"]} for i in range(len(train_file))]
wiki_articles = [{"document_text" : wiki_contexts[i]} for i in range(len(wiki_contexts))]
#wiki_articles = [{"document_text" : wiki_contexts[i]["text"], "document_title" : wiki_contexts[i]["title"], "id" : wiki_contexts[i]["document_id"]} for i in range(len(wiki_contexts))]


In [22]:
print(qa_records[0])

{'example_id': 'mrc-1-000067', 'document_title': '미국 상원', 'question_text': '대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?', 'answer': {'answer_start': [235], 'text': ['하원']}}


In [6]:
print(type(wiki_articles[0]["id"]))

KeyError: 'id'

In [None]:
# # download elasticsearch
# ! wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.2-linux-x86_64.tar.gz -q
# ! tar -xzf elasticsearch-7.6.2-linux-x86_64.tar.gz
# ! chown -R daemon:daemon elasticsearch-7.6.2

In [7]:
import os
from subprocess import Popen, PIPE, STDOUT
es_server = Popen(['elasticsearch-7.6.2/bin/elasticsearch'],
                   stdout=PIPE, stderr=STDOUT,
                   preexec_fn=lambda: os.setuid(1)  # as daemon
                  )
# wait until ES has started
! sleep 30

In [9]:
# collapse-hide
# !pip install elasticsearch
# !pip install tqdm

In [8]:
from elasticsearch import Elasticsearch

config = {'host':'localhost', 'port':9200}
es = Elasticsearch([config])

# test connection
es.ping()

True

In [9]:
# es.indices.delete(index='nori-index', ignore=[400, 404])
# es.indices.delete(index='ngram-bm25-index', ignore=[400, 404])
# es.indices.delete(index='nori-dfr-index ', ignore=[400, 404])
# es.indices.delete(index='squad-standard-index', ignore=[400, 404])

In [10]:
es.indices.delete(index='nori-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer"}
                }
            }
        }

index_name = 'nori-index'
es.indices.create(index=index_name, body=index_config, ignore=400)

{'acknowledged': True, 'shards_acknowledged': True, 'index': 'nori-index'}

In [11]:
# collapse-hide
from tqdm.notebook import tqdm

def populate_index(es_obj, index_name, evidence_corpus):
    '''
    Loads records into an existing Elasticsearch index

    Args:
        es_obj (elasticsearch.client.Elasticsearch) - Elasticsearch client object
        index_name (str) - Name of index
        evidence_corpus (list) - List of dicts containing data records

    '''

    for i, rec in enumerate(tqdm(evidence_corpus)):
    
        try:
            index_status = es_obj.index(index=index_name, id=i, body=rec)
        except:
            print(f'Unable to load document {i}.')
            
    n_records = es_obj.count(index=index_name)['count']
    print(f'Succesfully loaded {n_records} into {index_name}')


    return

In [12]:
all_wiki_articles = wiki_articles

populate_index(es_obj=es, index_name='nori-index', evidence_corpus=all_wiki_articles)

HBox(children=(FloatProgress(value=0.0, max=55963.0), HTML(value='')))

Unable to load document 21954.

Succesfully loaded 55963 into nori-index


In [13]:
# collapse-hide
def search_es(es_obj, index_name, question_text, n_results):
    '''
    Execute an Elasticsearch query on a specified index
    
    Args:
        es_obj (elasticsearch.client.Elasticsearch) - Elasticsearch client object
        index_name (str) - Name of index to query
        query (dict) - Query DSL
        n_results (int) - Number of results to return
        
    Returns
        res - Elasticsearch response object
    
    '''
    
    # construct query
    query = {
            'query': {
                'match': {
                    'document_text': question_text
                    }
                }
            }
    
    res = es_obj.search(index=index_name, body=query, size=n_results)
    
    return res

In [14]:
question_text = "대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?"

# execute query
res = search_es(es_obj=es, index_name='nori-index', question_text=question_text, n_results=20)

In [15]:
print(f'Question: {question_text}')
print(f'Query Duration: {res["took"]} milliseconds')
print('Title, Relevance Score:')
context_list = [(hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']]

Question: 대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?
Query Duration: 14 milliseconds
Title, Relevance Score:


In [16]:
# collapse-hide
import numpy as np
import pandas as pd

def average_precision(binary_results):
    
    ''' Calculates the average precision for a list of binary indicators '''
    
    m = 0
    precs = []

    for i, val in enumerate(binary_results):
        if val == 1:
            m += 1
            precs.append(sum(binary_results[:i+1])/(i+1))
            
    ap = (1/m)*np.sum(precs) if m else 0
            
    return ap


def evaluate_retriever(es_obj, index_name, qa_records, n_results):
    '''
    This function loops through a set of question/answer examples from SQuAD2.0 and 
    evaluates Elasticsearch as a information retrieval tool in terms of recall, mAP, and query duration.
    
    Args:
        es_obj (elasticsearch.client.Elasticsearch) - Elasticsearch client object
        index_name (str) - name of index to query
        qa_records (list) - list of qa_records from preprocessing steps
        n_results (int) - the number of results ElasticSearch should return for a given query
        
    Returns:
        test_results_df (pd.DataFrame) - a dataframe recording search results info for every example in qa_records
    
    '''
    
    results = []
    
    for i, qa in enumerate(tqdm(qa_records)):
        
        ex_id = qa['example_id']
        question = 
        answer = qa['answer']
        
        # execute query
        res = search_es(es_obj=es_obj, index_name=index_name, question_text=question, n_results=n_results)
        
        # calculate performance metrics from query response info
        duration = res['took']
        binary_results = [int(answer["text"][0].lower() in doc['_source']['document_text'].lower()) for doc in res['hits']['hits']]
        ans_in_res = int(any(binary_results))
        ap = average_precision(binary_results)

        rec = (ex_id, question, answer, duration, ans_in_res, ap)
        results.append(rec)
    
    # format results dataframe
    cols = ['example_id', 'question', 'answer', 'query_duration', 'answer_present', 'average_precision']
    results_df = pd.DataFrame(results, columns=cols)
    
    # format results dict
    metrics = {'Recall': results_df.answer_present.value_counts(normalize=True)[1],
               'Mean Average Precision': results_df.average_precision.mean(),
               'Average Query Duration':results_df.query_duration.mean()
              }
               
    
    return results_df, metrics

In [29]:
all_qa_records = qa_records
qa_records_answerable = [record for record in all_qa_records if record['answer'] != '']

# run evaluation
results_df, metrics = evaluate_retriever(es_obj=es, index_name='nori-index', qa_records=qa_records_answerable, n_results=20)

HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))




In [30]:
metrics

{'Recall': 0.9461032388663968,
 'Mean Average Precision': 0.7185290783416389,
 'Average Query Duration': 9.699645748987853}

In [21]:
es.indices.delete(index='nori-dfr-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            },
            "index": {
                "similarity":{
                    "my_similarity": {
                        "type": "DFR",
                        "basic_model": "g",
                        "after_effect": "l",
                        "normalization": "h2",
                        "normalization.h2.c": "3.0"
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "document_title": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "id": {"type" : "integer"},
                }
            }
        }

es.indices.create(index='nori-dfr-index', body=index_config, ignore=400)
populate_index(es_obj=es, index_name='nori-dfr-index', evidence_corpus=all_wiki_articles)
dfr_results_df, dfr_stem_metrics = evaluate_retriever(es_obj=es, index_name='nori-dfr-index', qa_records=qa_records_answerable, n_results=1)
print(dfr_stem_metrics)

HBox(children=(FloatProgress(value=0.0, max=60613.0), HTML(value='')))


Succesfully loaded 60613 into nori-dfr-index


HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))


{'Recall': 0.715334008097166, 'Mean Average Precision': 0.715334008097166, 'Average Query Duration': 11.179402834008098}


In [22]:
es.indices.delete(index='nori-dfi-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            },
            "index": {
                "similarity":{
                    "my_similarity": {
                        "type": "DFI",
                        "independence_measure": "saturated"
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "document_title": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "id": {"type" : "integer"},
                }
            }
        }

es.indices.create(index='nori-dfi-index', body=index_config, ignore=400)
populate_index(es_obj=es, index_name='nori-dfi-index', evidence_corpus=all_wiki_articles)
dfr_results_df, dfr_stem_metrics = evaluate_retriever(es_obj=es, index_name='nori-dfi-index', qa_records=qa_records_answerable, n_results=1)
print(dfr_stem_metrics)

HBox(children=(FloatProgress(value=0.0, max=60613.0), HTML(value='')))


Succesfully loaded 60613 into nori-dfi-index


HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))


{'Recall': 0.694585020242915, 'Mean Average Precision': 0.694585020242915, 'Average Query Duration': 26.84741902834008}


In [23]:
es.indices.delete(index='nori-ib-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            },
            "index": {
                "similarity":{
                    "my_similarity": {
                        "type": "IB",
                        "distribution" : "ll",
                        "lambda" : "df",
                        "normalization": "h2",
                        "normalization.h2.c": "3.0"
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "document_title": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "id": {"type" : "integer"},
                }
            }
        }

es.indices.create(index='nori-ib-index', body=index_config, ignore=400)
populate_index(es_obj=es, index_name='nori-ib-index', evidence_corpus=all_wiki_articles)
dfr_results_df, dfr_stem_metrics = evaluate_retriever(es_obj=es, index_name='nori-ib-index', qa_records=qa_records_answerable, n_results=1)
print(dfr_stem_metrics)

HBox(children=(FloatProgress(value=0.0, max=60613.0), HTML(value='')))

Unable to load document 24711.
Unable to load document 58517.

Succesfully loaded 60613 into nori-ib-index


HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))


{'Recall': 0.6447368421052632, 'Mean Average Precision': 0.6447368421052632, 'Average Query Duration': 38.584514170040485}


In [None]:
 0.6955971659919028

In [24]:
es.indices.delete(index='nori-lmd-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            },
            "index": {
                "similarity":{
                    "my_similarity": {
                        "type": "LMDirichlet",
                        "mu": 1000
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "document_title": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "id": {"type" : "integer"},
                }
            }
        }

es.indices.create(index='nori-lmd-index', body=index_config, ignore=400)
populate_index(es_obj=es, index_name='nori-lmd-index', evidence_corpus=all_wiki_articles)
dfr_results_df, dfr_stem_metrics = evaluate_retriever(es_obj=es, index_name='nori-lmd-index', qa_records=qa_records_answerable, n_results=1)
print(dfr_stem_metrics)

HBox(children=(FloatProgress(value=0.0, max=60613.0), HTML(value='')))


Succesfully loaded 60613 into nori-lmd-index


HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))


{'Recall': 0.7249493927125507, 'Mean Average Precision': 0.7249493927125507, 'Average Query Duration': 32.183704453441294}


In [25]:
es.indices.delete(index='nori-lmj-index', ignore=[400, 404])
index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    "nori_analyzer": {
                        "type": "custom",
                        "tokenizer": "nori_tokenizer",
                        "decompound_mode": "mixed",
                        "stopwords": "_korean_",
                    }
                }
            },
            "index": {
                "similarity":{
                    "my_similarity": {
                        "type": "LMJelinekMercer",
                        "lambda": 0.7
                    }
                }
            }
        },
        "mappings": {
            "dynamic": "strict", 
            "properties": {
                "document_text": {"type": "text", "analyzer": "nori_analyzer", "similarity" : "my_similarity"},
                "document_title": {"type": "text"},
                "id": {"type" : "integer"},
                }
            }
        }

es.indices.create(index='nori-lmj-index', body=index_config, ignore=400)
populate_index(es_obj=es, index_name='nori-lmj-index', evidence_corpus=all_wiki_articles)
dfr_results_df, dfr_stem_metrics = evaluate_retriever(es_obj=es, index_name='nori-lmj-index', qa_records=qa_records_answerable, n_results=1)
print(dfr_stem_metrics)

HBox(children=(FloatProgress(value=0.0, max=60613.0), HTML(value='')))

Unable to load document 32746.

Succesfully loaded 60613 into nori-lmj-index


HBox(children=(FloatProgress(value=0.0, max=3952.0), HTML(value='')))


{'Recall': 0.7110323886639676, 'Mean Average Precision': 0.7110323886639676, 'Average Query Duration': 26.332995951417004}
