# KoSBERT 사용법
1. [현재 리포지토리](https://github.com/KAIST-Lawgical/Sentence-Embedding-Is-All-You-Need)를 `git clone`을 통해 받는다.
2. Lawgical KAIST의 구글 드라이브 내의 `Checkpoint` 폴더를 다운로드하여 `Sentence-Embedding-Is-All-You-Need` 바로 아래에 둔다.
3. 디렉토리 `Sentence-Embedding-Is-All-You-Need`에서 명령어 `docker build -t sbert:1.0.0 .`를 실행하여 도커 이미지를 빌드한다.
4. 명령어 `docker run -itd --ipc host --gpus all --name sbert sbert:1.0.0`를 실행하여 도커 컨테이너를 실행한다.
5. `Visual Studio Code`의 `remote explorer` 등을 통해 도커 컨테이너에 접속한다.
6. 도커 컨테이너 내부의 워크 디렉토리에서 이 파일, 즉 `test.ipynb`를 찾아 아래 코드를 활용하여 준비서면-판결문 문장 간 유사도 측정 작업을 실행한다.

In [None]:
import re
import csv
import pandas
import pickle
import numpy as np
from tqdm import tqdm
from kiwipiepy import Kiwi
from sentence_transformers import SentenceTransformer, util

# Pre-compiled pattern matching any run of one or more whitespace characters.
_whitespace = re.compile(r'\s+')
# Kiwi instance; used below for sentence splitting (split_into_sents).
kiwi = Kiwi()
# Location of the KoSBERT checkpoint, relative to the working directory.
model_path = './Sentence-Embedding-Is-All-You-Need/Checkpoint/KoSBERT/kosbert-klue-bert-base/'
# Sentence embedding model loaded from the checkpoint directory above.
embedder = SentenceTransformer(model_path)

In [None]:
def collapse_whitespace(text):
    """Return *text* with every run of whitespace replaced by one space."""
    # Use the module-level compiled pattern directly.
    return _whitespace.sub(' ', text)


def calculate_similarity(case_number, document_text, judgement_text, wr, top_k=1):
    """Write the most similar judgement sentence for each brief sentence.

    Both texts are split into sentences with ``kiwi``. Every brief sentence is
    embedded and compared, by cosine similarity, against the embeddings of all
    judgement sentences; one row per brief sentence is written to ``wr``.

    Args:
        case_number: Case identifier; written as ``str`` in the first column.
        document_text: ``(id, text)`` pair for the brief document.
        judgement_text: ``(id, text)`` pair for the judgement document.
        wr: Object exposing ``writerow`` (e.g. a ``csv.writer``).
        top_k: Number of best candidates to partially sort. Only the single
            best match is written. Clamped to the range
            ``[1, number of judgement sentences]``.

    Returns:
        None. Nothing is written when either text yields no sentences.
    """
    document_id = document_text[0]
    # Keep only the first field of each split result
    # (presumably the sentence text -- verify against the kiwipiepy API).
    documents = [row[0] for row in kiwi.split_into_sents(document_text[1])]

    judgement_id = judgement_text[0]
    judgements = [row[0] for row in kiwi.split_into_sents(judgement_text[1])]

    if not documents or not judgements:
        return

    # np.argpartition raises when a kth index is out of bounds, and an empty
    # kth range would leave no result to read, so keep top_k within range.
    top_k = max(1, min(top_k, len(judgements)))

    # Judgement sentences are embedded once and reused for every brief sentence.
    judgement_embeddings = embedder.encode(judgements, convert_to_tensor=True)

    for document in documents:
        document_embedding = embedder.encode(document, convert_to_tensor=True)
        cos_scores = util.pytorch_cos_sim(document_embedding, judgement_embeddings)[0]
        cos_scores = cos_scores.cpu()

        # We use np.argpartition to only partially sort the top_k results;
        # scores are negated so the highest similarity comes first.
        top_results = np.argpartition(-cos_scores, range(top_k))[0:top_k]

        max_idx = top_results[0]
        max_similarity = round(float(cos_scores[max_idx]), 4)
        max_sentence = judgements[max_idx].strip()

        wr.writerow([
            str(case_number),
            document_id,
            judgement_id,
            collapse_whitespace(document),
            max_similarity,
            collapse_whitespace(max_sentence),
        ])
        

In [None]:
# Load the dataset to process.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only load files from a trusted source.
with open('./Sentence-Embedding-Is-All-You-Need/data_output_dec.pickle', 'rb') as datafile:
    data = pickle.load(datafile)


def _write_case_similarities(case_number, document_text_list, judgement_text, wr):
    """Write similarity rows for one case; do nothing if a side is missing."""
    if not document_text_list or not judgement_text:
        return
    for document_text in document_text_list:
        try:
            calculate_similarity(case_number, document_text, judgement_text, wr)
        except Exception as exc:
            # Best effort: report the failing document and keep going.
            print(case_number, document_text[0], repr(exc))


# The context manager guarantees the CSV is flushed and closed at the end.
with open('./Sentence-Embedding-Is-All-You-Need/data.csv', 'w', encoding='utf-8-sig', newline='') as f:
    wr = csv.writer(f)
    wr.writerow(['case_number', 'brief_edms_id', 'judgement_edms_id', 'brief_sentence', 'similarity', 'similar_judgement_sentence'])

    prior_case_number = None
    document_text_list = []
    judgement_text = ''

    # Rows are grouped by consecutive case number
    # (assumes rows of the same case are adjacent -- TODO confirm).
    for idx, row in tqdm(data.iterrows(), total=len(data)):
        # NOTE(review): integer keys on the row assume the columns are
        # addressable this way -- confirm against the pickled frame.
        case_number = row[1]
        edms_id = row[2]
        document_type = row[4]
        text = str(row[6])

        if case_number != prior_case_number:
            # The accumulated texts belong to the previous case, so they are
            # written under prior_case_number, not the new case_number.
            _write_case_similarities(prior_case_number, document_text_list, judgement_text, wr)

            prior_case_number = case_number
            document_text_list = []
            judgement_text = ''

        # 01 = judgement, 011 = answer, 010 = brief, 701 = complaint
        if document_type == '01':
            judgement_text = (edms_id, text)
        else:
            document_text_list.append((edms_id, text))

    # Flush the final case, which has no following row to trigger it.
    _write_case_similarities(prior_case_number, document_text_list, judgement_text, wr)

In [None]:
# Reload the pickled dataset and print it for inspection.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only load files from a trusted source.
with open('./Sentence-Embedding-Is-All-You-Need/data_output_dec.pickle', 'rb') as datafile:
    data = pickle.load(datafile)

print(data)