In [1]:
# 환경설정

!pip3 install transformers
!pip3 install torch
!pip3 install konlpy

Collecting transformers
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 MB 13.6 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 53.7 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 47.7 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 41.4 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.1.2-py3-none-any.whl (59 kB)
[K     |████████████████████████████████| 59 kB 6.2 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Atte

In [2]:
# 시험데이터셋 수집

!wget https://raw.githubusercontent.com/aifactory-team/hanryubank/main/dataset_test.csv

--2021-11-19 06:03:41--  https://raw.githubusercontent.com/aifactory-team/hanryubank/main/dataset_test.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21935 (21K) [text/plain]
Saving to: ‘dataset_test.csv’


2021-11-19 06:03:41 (106 MB/s) - ‘dataset_test.csv’ saved [21935/21935]



In [3]:
# Module setup: load the QA model/pipeline and the Korean morphological analyzer.

import konlpy
import pandas as pd
from transformers import ElectraForQuestionAnswering, ElectraTokenizer, pipeline

# Single source of truth for the checkpoint, so tokenizer and model can never
# silently come from different checkpoints. Per its name this is a distilled
# KoELECTRA-small fine-tuned on KorQuAD ("384" presumably the max sequence
# length -- not verified here).
MODEL_NAME = "monologg/koelectra-small-v2-distilled-korquad-384"

tokenizer = ElectraTokenizer.from_pretrained(MODEL_NAME)
model = ElectraForQuestionAnswering.from_pretrained(MODEL_NAME)
qa = pipeline("question-answering", tokenizer=tokenizer, model=model)

# KoNLPy Okt tagger; used by modify_answer to drop particles/endings.
okt = konlpy.tag.Okt()

Downloading:   0%|          | 0.00/249k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/52.3M [00:00<?, ?B/s]

In [15]:
# Look up the stored question matching the query and return it with its passage.
# NOTE: this is an exact string match on the 'Q1 (질문1)' column, not a
# similarity search.

def get_context_and_question(query, csv_path='dataset_test.csv'):
    """Return ``(context, question)`` for the dataset row whose question equals ``query``.

    Args:
        query: question text; must equal a 'Q1 (질문1)' cell exactly.
        csv_path: CSV file to search (defaults to the downloaded test set).

    Returns:
        Tuple of the row's 'Context (지문)' passage and its 'Q1 (질문1)' question.

    Raises:
        ValueError: if no row, or more than one row, matches ``query``.
    """
    df = pd.read_csv(csv_path)
    matches = df[df['Q1 (질문1)'] == query]

    # Series.item() would also fail on 0 or >1 rows, but only with the opaque
    # "can only convert an array of size 1" message; say what went wrong instead.
    if matches.empty:
        raise ValueError(
            "no question in {0!r} matches query {1!r}".format(csv_path, query))
    if len(matches) > 1:
        raise ValueError(
            "{0} questions in {1!r} match query {2!r}; expected exactly one".format(
                len(matches), csv_path, query))

    context = matches['Context (지문)'].item()
    question = matches['Q1 (질문1)'].item()

    return context, question

In [16]:
# Run the extractive QA pipeline on a (passage, question) pair.

def get_answer(context, question):
    """Ask the global ``qa`` pipeline to extract an answer span from ``context``.

    Args:
        context: passage text that should contain the answer.
        question: question text.

    Returns:
        Tuple ``(answer, start, end, score)``: the predicted answer text, its
        start/end character offsets within ``context``, and the pipeline's
        confidence score.
    """
    prediction = qa({"question": question, "context": context})
    score, start, end, answer = (
        prediction[field] for field in ("score", "start", "end", "answer")
    )
    return answer, start, end, score

In [17]:
# Rephrase an extracted answer span as a short declarative reply.

def _copula_ending(stem):
    """Pick the copula ending for ``stem``: '라네.' after an open syllable, else '이라네.'.

    A precomposed Hangul syllable (U+AC00..U+D7A3) encodes its final consonant
    (jongseong) as ``(code - 0xAC00) % 28``; 0 means there is none, and then
    the copula '이' is dropped ('나라' -> '나라라네.'). Anything else -- a
    closed syllable, a non-Hangul character, or an empty stem -- keeps '이라네.'.
    """
    if stem:
        code = ord(stem[-1])
        if 0xAC00 <= code <= 0xD7A3 and (code - 0xAC00) % 28 == 0:
            return '라네.'
    return '이라네.'


def modify_answer(before_answer):
    """Turn a raw answer span into a reply ending in '(이)라네.'.

    The span is tagged with the global Okt tagger (``stem=True`` asks Okt to
    normalize conjugated forms to their stems). Particles ('Josa'), endings
    ('Eomi') and punctuation are dropped, the remaining morphemes are joined
    without spaces, and the copula ending is appended: '이라네.' after a final
    consonant (e.g. '애민 정신을' -> '애민정신이라네.'), '라네.' after an open
    syllable (e.g. '나라를' -> '나라라네.').

    Args:
        before_answer: answer text as returned by the QA pipeline.

    Returns:
        The rephrased reply string.
    """
    kept = [
        morpheme
        for morpheme, tag in okt.pos(before_answer, stem=True)
        if tag not in ('Josa', 'Eomi', 'Punctuation')
    ]
    stem = ''.join(kept)
    return stem + _copula_ending(stem)

In [20]:
if __name__ == "__main__":

    # Demo query: must equal a question stored in the dataset.
    query = '훈민정음을 창제한 이유'

    # Fetch the matching stored question and the passage it belongs to.
    context, question = get_context_and_question(query)

    # Let the QA model extract the answer span from that passage.
    answer, start, end, score = get_answer(context, question)

    # Report: raw query/question, raw prediction with offsets and score,
    # and the rephrased reply.
    print(f"query: {query}")
    print(f"question: {question}")
    print(f"predict: {answer} ({start},{end},{score})")
    print(f"service: {modify_answer(answer)}")

query: 훈민정음을 창제한 이유
question: 훈민정음을 창제한 이유
predict: 애민 정신을 (320,326,0.6614944338798523)
service: 애민정신이라네.


In [None]:
def full_test(csv_path='dataset_test.csv'):
    """Run the QA model over every row of a test CSV and print prediction vs. reference.

    For each row this prints the row index, the question ('Q1 (질문1)'), the
    reference answer ('Ans (답변)'), the raw prediction with its start/end
    offsets and score, and the reply produced by modify_answer.

    Args:
        csv_path: CSV file to evaluate (defaults to the downloaded test set).
    """
    df = pd.read_csv(csv_path)

    for idx, row in df.iterrows():
        it_context = row['Context (지문)']
        it_question = row['Q1 (질문1)']
        it_ans = row['Ans (답변)']

        answer, start, end, score = get_answer(it_context, it_question)

        # f-strings instead of '+': an empty reference answer is read by pandas
        # as NaN (a float), and string concatenation with it raised TypeError,
        # aborting the whole run partway through.
        print(f"[{idx}]")
        print(f"question: {it_question}")
        print(f"real: {it_ans}")
        print(f"predict: {answer} ({start},{end},{score})")
        print(f"service: {modify_answer(answer)}")