In [21]:
import os
import sys
import time
import re
import torch
from IPython.display import display
from dotenv import load_dotenv
from tqdm import tqdm
import json
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import pandas as pd
from bert_score import score

from langchain import PromptTemplate
# from langchain.chains import LLMChain
from langchain.docstore.document import Document
# from langchain.document_loaders import DirectoryLoader, TextLoader

# Load variables from a .env file so the os.getenv() calls further down can see them.
load_dotenv()
# The parent of the current working directory is used as the project root.
root_dir = os.path.dirname(os.getcwd())
sys.path.append(root_dir) # makes the project-local `main` package importable from this notebook
# On non-Windows hosts, set TOKENIZERS_PARALLELISM to "false".
if os.name != "nt": os.environ["TOKENIZERS_PARALLELISM"]="false"

from main.llm.openai import load_OpenAI
from main.utils.data import loadDocuments, char_data_splitter
from main.utils.database import embed_documents, embed_database
from main.template.prompt import UniversalPromptTemplate
from main.utils.chain import Chain
from main.utils.database import embed_documents

### **setup var**

In [17]:
# NOTE(review): chunk_size and chunk_overlap are not referenced by any visible cell
# (loadDocuments below is called with literal values instead).
chunk_size = 2000
chunk_overlap = 200
# Only used here to build the persistence path of the vector store.
embedding_algorithm = "faiss"
source_directory = f"{root_dir}/docs"
persist_directory = f"{root_dir}/tmp/embeddings/{embedding_algorithm}"
# Source file name -> case number string. Used to prefix document text with its
# case number and to recognise case numbers cited in a question.
mapper = {
    "law_doc-84-89.txt": "761/2566",
    "law_doc-44-46.txt": "1301/2566",
    "law_doc-54-57.txt": "1225/2566",
    "law_doc-12-13.txt": "2525/2566",
    "law_doc-40-43.txt": "1305/2566",
    "law_doc-14-15.txt": "2085/2566",
    "law_doc-64-69.txt": "1090/2566",
    "law_doc-1-5.txt": "2610/2566",
    "law_doc-78-81.txt": "882/2566",
    "law_doc-82-83.txt": "835/2566",
    "law_doc-35-39.txt": "1306/2566",
    "law_doc-16-20.txt": "1574/2566",
    "law_doc-32-34.txt": "1373/2566",
    "law_doc-74-77.txt": "934/2566",
    "law_doc-6-11.txt": "2609/2566",
    "law_doc-90-92.txt": "756/2566",
    "law_doc-47-53.txt": "1300/2566",
    "law_doc-58-63.txt": "1101/2566",
    "law_doc-70-73.txt": "1003/2566",
    "law_doc-21-31.txt": "1542/2566"
}

### **clean up doc & export**

In [37]:
# chunk_size is 10e12 with no overlap -- presumably so that each file stays in a
# single chunk; confirm against loadDocuments.
documents = loadDocuments(source_dir=source_directory,chunk_size=10e12,chunk_overlap=0)
for document in documents:
    # Normalise Windows separators BEFORE taking the basename. The original only
    # normalised after the lookup, so a path containing "\" never matched `mapper`
    # and the document was left uncleaned and without its case-number prefix.
    source_path = document.metadata['source'].replace('\\','/')
    document.metadata['source'] = source_path
    file_name = source_path.split('/')[-1]
    if file_name not in mapper:
        continue
    # Collapse whitespace runs to one space, then turn runs of underscores into newlines.
    cleaned = re.sub(r'\s{2,}', ' ', document.page_content)
    cleaned = re.sub(r'_{2,}', '\n', cleaned)
    # Prefix the text with its case number so it can be found by number later.
    document.page_content = f"คดี {mapper[file_name]}\n" + cleaned

### **init LLMs**

In [5]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
import logging
from langchain.llms import HuggingFacePipeline
from dotenv import load_dotenv

load_dotenv()
# Matches any run of characters that fall outside the range ก-๙.
exclude_pattern = re.compile(r'[^ก-๙]+')
def is_exclude(text):
    """Return True when ``text`` contains at least one character outside ก-๙."""
    return exclude_pattern.search(text) is not None

def load_seaLLMs():
    """Build a LangChain LLM backed by a local SeaLLM-7B-Chat text-generation pipeline.

    The model is loaded on CPU in float32. Token ids whose vocabulary text
    contains any character outside ก-๙ (see ``is_exclude``) are handed to the
    pipeline as ``begin_suppress_tokens``.

    Returns:
        HuggingFacePipeline: LangChain wrapper around the transformers pipeline.
    """
    model_id = "SeaLLMs/SeaLLM-7B-Chat"  # @param ["pythainlp/wangchanglm-7.5B-sft-enth-sharded", "TinyPixel/Llama-2-7B-bf16-sharded", "SeaLLMs/SeaLLM-7B-Chat"]
    # NOTE(review): bnb_config is built but never used -- the quantization_config
    # argument below is commented out.
    bnb_config = BitsAndBytesConfig(
        # load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float32,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # quantization_config=bnb_config,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        token=os.getenv('SEALLM_ACCESS_TOKEN')
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv('SEALLM_ACCESS_TOKEN'), device_map={"": "cpu"})
    # Pad with the end-of-sequence token, on the right-hand side.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): add_special_tokens is normally a tokenizer method; this
    # assignment appears to replace it with a bool on this instance -- confirm intent.
    tokenizer.add_special_tokens = False
    tokenizer.padding_side = "right"
    
    # Tabulate the vocabulary as (text, id) rows and flag entries containing non-Thai characters.
    df = pd.DataFrame(tokenizer.vocab.items(), columns=['text', 'idx'])
    df['is_exclude'] = df.text.map(is_exclude)
    exclude_ids = df[df.is_exclude==True].idx.tolist()
    logging.info(f"Forced LLM model to only response in Thai: Yes")
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # NOTE(review): begin_suppress_tokens presumably constrains only the start of
        # generation; confirm it is enough to keep whole responses in Thai.
        begin_suppress_tokens=exclude_ids,
        no_repeat_ngram_size=2,
        # NOTE(review): a different env var (HUGGING_FACE) is read here than for the
        # model and tokenizer (SEALLM_ACCESS_TOKEN) -- confirm.
        token=os.getenv('HUGGING_FACE')
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm

# llm = load_seaLLMs()
# The OpenAI-backed model is used instead of the local SeaLLM pipeline.
# NOTE(review): `llm` is assigned again, with "gpt-4-0125-preview", in another cell.
llm = load_OpenAI(model="gpt-3.5-turbo-16k-0613")

In [18]:
from pythainlp import word_tokenize, pos_tag
from pythainlp.corpus.common import thai_stopwords
import nltk
from nltk.corpus import stopwords

# Part-of-speech tags that keyword_search keeps as keyword candidates.
key_tags = ["NCMN", "NCNM", "NPRP", "NONM", "NLBL", "NTTL"]

thaistopwords = list(thai_stopwords())
nltk.download('stopwords')
# Build the combined stop-word set once. The original called stopwords.words()
# (every language in the NLTK corpus) and scanned a list for every single token.
_ALL_STOPWORDS = frozenset(thaistopwords) | frozenset(stopwords.words())

def remove_stopwords(text):
    """Drop stop words from a sequence of tokens and lower-case the remainder.

    Args:
        text: iterable of token strings (despite the name, not one string).

    Returns:
        list[str]: the surviving tokens, lower-cased, in their original order.
    """
    # Membership is tested on the token as given (before lower-casing), as before.
    return [word.lower() for word in text if word not in _ALL_STOPWORDS]

def keyword_search(question):
    """Extract a de-duplicated list of keyword tokens from ``question``.

    The question is tokenised with the "newmm" engine, POS-tagged, reduced to
    the tokens whose tag is listed in ``key_tags``, stripped of stop words and
    de-duplicated through a set (so the order of the result is not fixed).
    """
    tagged = pos_tag(word_tokenize(question, engine="newmm", keep_whitespace=False))
    candidates = [pair[0] for pair in tagged if pair[1] in key_tags]
    return list(set(remove_stopwords(candidates)))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\akira\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping corpora\stopwords.zip.


In [19]:
def find_case_number(text):
    """Look for case numbers of the form ``<1-5 digits>/<4 digits>`` in ``text``.

    Returns:
        tuple: ``(True, [numbers, ...])`` when at least one number is found and
        every number found is a value of ``mapper``; otherwise ``(False, "")``.
    """
    case_number_re = re.compile(r"(?<!\d)(\d{1,5}/\d{4})(?!\d)")
    found = case_number_re.findall(text)
    # Nothing found must fail explicitly, because all() over no items is True.
    if not found:
        return False, ""
    for number in found:
        if number not in mapper.values():
            return False, ""
    return True, found

def keyword_matcher(doc, keywords):
    """Return the keywords that occur literally in ``doc.page_content``.

    Every keyword is regex-escaped, so it is matched as plain text. The result
    keeps the order (and any duplicates) of ``keywords``.
    """
    return [kw for kw in keywords if re.search(re.escape(kw), doc.page_content)]

def filter_docs_by_keywords(docs, keywords, question):
    """Select the documents that share enough keywords with the question.

    When the question cites known case numbers (see ``find_case_number``), a
    document qualifies only if it contains at least one of those numbers and
    matches at least ``min(3, len(keywords))`` keywords. Otherwise a document
    qualifies when it matches at least ``min(2, len(keywords))`` keywords.

    Args:
        docs: iterable of objects exposing ``page_content``.
        keywords: keyword strings, e.g. the output of ``keyword_search``.
        question: the raw user question.

    Returns:
        tuple[list, list[list[str]]]: the qualifying documents and, aligned by
        position, the keywords each of them matched.
    """
    # The question is the same for every document, so parse it once instead of
    # twice per document.
    has_case_number, case_numbers = find_case_number(question)
    # NOTE(review): with an empty keyword list the threshold is 0, so every
    # candidate document passes the keyword test -- confirm this is intended.
    required = min(3 if has_case_number else 2, len(keywords))

    filtered_docs = []
    matches = []
    for doc in docs:
        # any() admits a document once even when it mentions several of the
        # cited case numbers; the original appended it once per number.
        if has_case_number and not any(num in doc.page_content for num in case_numbers):
            continue
        matched_keywords = keyword_matcher(doc, keywords)
        if len(matched_keywords) >= required:
            matches.append(matched_keywords)
            filtered_docs.append(doc)
    return filtered_docs, matches

In [42]:
def parse_source_docs(source_docs):
    """Render documents as a single text block.

    A document whose source file name is a key of ``mapper`` is prefixed with a
    "คดีหมายเลข <case number>" line; any other document contributes its raw
    text. Entries are separated by a blank line.

    Returns:
        str: the joined text, or an empty list when ``source_docs`` is None.
    """
    if source_docs is None:
        return []

    def render(document):
        file_name = document.metadata["source"].split("/")[-1]
        if file_name not in mapper:
            return document.page_content
        return f"""คดีหมายเลข {mapper[file_name]}\n{document.page_content}"""

    return "\n\n".join([render(document) for document in source_docs])

def parse_matched_keywords(matched_keywords):
    """Format matched keywords as text, one ``str(item),`` entry per line.

    Returns:
        str: the formatted lines, or an empty list when the input is None.
    """
    if matched_keywords is None:
        return []
    lines = [str(entry) + ',' for entry in matched_keywords]
    return "\n".join(lines)

In [9]:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

def load_embedding_model(embedding_model_name="intfloat/multilingual-e5-small"):
    """Create a HuggingFaceEmbeddings instance for ``embedding_model_name``.

    The device is "cuda" when CUDA is available, otherwise "mps" when the MPS
    backend is available, otherwise "cpu".
    """
    device_type = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    return HuggingFaceEmbeddings(
        model_name=embedding_model_name,
        model_kwargs={"device": device_type},
    )

# def embed_database(documents):
#     embeddings = load_embedding_model()
#     vectordb = FAISS.from_documents(
#                 documents=documents,
#                 embedding=embeddings,
#             )
#     return vectordb

In [39]:
# Build the vector store over the cleaned documents; qa() queries it for similarity search.
# NOTE(review): embed_database is project code not shown here -- presumably it persists to
# or loads from persist_directory; confirm.
vector_database = embed_database(documents=documents,persist_directory=persist_directory)

In [45]:
# Rebinds `llm` (another cell sets it to "gpt-3.5-turbo-16k-0613"); qa() reads this global at call time.
llm = load_OpenAI(model="gpt-4-0125-preview")

  warn_deprecated(


In [46]:
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import cohere
# Cohere client used by qa() for reranking; the API key is read from the COHERE env var.
co = cohere.Client(os.getenv('COHERE'))

# NOTE(review): srcs and raw_src are not referenced by any visible cell.
srcs = pd.DataFrame()
raw_src = pd.DataFrame()

def _qa_result(question, **fields):
    """Return a qa() result dict holding every expected key, overridden by ``fields``."""
    result = {
        "time": 0,
        "question": question,
        "answer": "",
        "keywords": "",
        "matched_keywords": "",
        "keywords_filtered_docs": "",
        "num keywords_filtered_docs": 0,
        "retrieved_docs": "",
        "num retrieved_docs": 0,
        "reranked_docs": "",
        "num parse_reranked_docs": 0,
    }
    result.update(fields)
    return result


def qa(question):
    """Answer ``question`` via keyword filtering, vector retrieval, reranking and the LLM.

    Steps: extract keywords and filter ``documents`` by them; unless the
    question cites a known case number, also query ``vector_database``; rerank
    the union with Cohere (top 5); run the prompt from
    ``UniversalPromptTemplate.QAZeroShotThaiTemplate()`` on ``llm``.

    Args:
        question: the user question; '', '-' and None are rejected.

    Returns:
        dict: always contains "time", "question", "answer", "keywords",
        "matched_keywords", "keywords_filtered_docs",
        "num keywords_filtered_docs", "retrieved_docs", "num retrieved_docs",
        "reranked_docs" and "num parse_reranked_docs". When an exception was
        caught it additionally carries "error" plus the legacy "source_doc",
        "response" and "source" keys. Exceptions are printed, never raised.
    """
    # Imported locally so timing keeps working even if the notebook-level name
    # `time` has been rebound to something else (the cell that unpacks the
    # result of generateQAresponse does exactly that).
    import time as _time

    try:
        if question in ['','-',None]: raise Exception("No question")
        started = _time.time()

        # keywords search
        keywords = keyword_search(question)
        keywords_filtered_docs, matched_keywords = filter_docs_by_keywords(documents, keywords, question)
        if len(keywords_filtered_docs) == 0:
            return _qa_result(
                question,
                keywords=keywords,
                keywords_filtered_docs="No Relevant source doc by keywords search",
            )
        parse_match_keywords = parse_matched_keywords(matched_keywords)
        parse_keywords_filtered_docs = parse_source_docs(keywords_filtered_docs)

        # context search -- skipped when the question names a known case number
        retrieved_docs = []
        parse_retrieved_docs = []
        if not find_case_number(question)[0]:
            retriever = vector_database.as_retriever(search_type="similarity")
            retrieved_docs = retriever.get_relevant_documents(question)
            parse_retrieved_docs = parse_source_docs(retrieved_docs)
            if len(retrieved_docs) == 0:
                return _qa_result(
                    question,
                    retrieved_docs="No Relevant source doc by vector stores retriever",
                )

        # rerank the union of both candidate sets and keep the top 5
        relevant_src_docs = keywords_filtered_docs + retrieved_docs
        relevant_docs = [doc.page_content for doc in relevant_src_docs]
        rerank_hits = co.rerank(query=question, documents=relevant_docs, model='rerank-multilingual-v2.0', top_n=5)
        results = [relevant_src_docs[hit.index] for hit in rerank_hits]
        parse_reranked_docs = parse_source_docs(results)

        prompt = UniversalPromptTemplate.QAZeroShotThaiTemplate()
        chain = LLMChain(llm=llm, prompt=prompt)
        result = chain.run(context=parse_reranked_docs, question=question)
        finished = _time.time()
        return _qa_result(
            question,
            **{
                "time": finished - started,
                "answer": result,
                "keywords": keywords,
                "matched_keywords": parse_match_keywords,
                "keywords_filtered_docs": parse_keywords_filtered_docs,
                "num keywords_filtered_docs": len(keywords_filtered_docs),
                "retrieved_docs": parse_retrieved_docs,
                "num retrieved_docs": len(retrieved_docs),
                "reranked_docs": parse_reranked_docs,
                "num parse_reranked_docs": len(results),
            }
        )
    except Exception as e:
        print(f"{question} @{e}")
        # Return the full schema (plus the legacy error keys) so that callers
        # reading result['answer'] etc. do not hit a KeyError on a failed question.
        return _qa_result(question, error=str(e), source_doc=[], response="", source="")

# Smoke-test the pipeline on a single question and display the returned dict.
question = "ขอดูอย่างคดีที่มีการพิพากษาของศาลฎีกาต่างจากศาลอุทธรณ์หน่อยได้ไหมครับ"
qa_res = qa(question)
qa_res

  warn_deprecated(


{'time': 11.222966194152832,
 'question': 'ขอดูอย่างคดีที่มีการพิพากษาของศาลฎีกาต่างจากศาลอุทธรณ์หน่อยได้ไหมครับ',
 'answer': 'ไม่ทราบ',
 'keywords': ['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],
 'matched_keywords': "['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ไหม', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุทธรณ์', 'ศาลฎีกา'],\n['ศาลอุ

In [50]:
# Load the demo questions; the loop in the next cell reads the `question` column.
df = pd.read_csv("demo.csv",encoding="utf-8",header=0)
df.head(2)

Unnamed: 0,question
0,ถ้าเราเผลอไปยิงคนตาย จะผิดกฎหมายข้อไหน
1,ขายยาเสพติดจะผิดกฎหมายไหม


In [51]:
# Answer every demo question sequentially; results stay in row order.
res = [qa(row['question']) for _, row in tqdm(df.iterrows())]


10it [02:11, 13.10s/it]


In [52]:
# Read the fields defensively: a result dict that lacks them (e.g. an error
# result) yields '' instead of raising KeyError and aborting the cell.
answers = [r.get('answer', '') for r in res]
docs = [r.get('reranked_docs', '') for r in res]

df['answer'] = answers
df['sources'] = docs
    

In [53]:
# NOTE(review): this overwrites the same demo.csv that was read as input, now with the
# extra `answer` and `sources` columns.
df.to_csv("demo.csv",index=False,encoding="utf-8")

### **dataset**

In [11]:
# Name of the evaluation set; it selects the CSV under asset/dataset/law/ and names the output file.
dataset = "human"
dataset_path = f"{root_dir}/asset/dataset/law/{dataset}.csv"

In [13]:
# Load the evaluation set, drop rows with missing values and keep the rows from
# position 60 onward. One chained expression keeps the cell idempotent, and
# .copy() makes the slice an independent frame so the column assignments in
# later cells do not trigger pandas' SettingWithCopyWarning.
df = pd.read_csv(dataset_path).dropna().iloc[60:].copy()
display(df.head(3))

Unnamed: 0,question,answer
60,ขอดูอย่างคดีที่มีการพิพากษาของศาลฎีกาต่างจากศา...,"คดีหมายเลข 2610/2566, 2609/2666, 1306/2566, 13..."
61,ขอดูคดีที่เกี่ยวข้องกับยาเสพติดหน่อยครับ,"คดีหมายเลข 2610/2566, 2609/2666, 1225/2566, 10..."
62,มีคดีใดบ้าง ที่มีหน่วยงานของรัฐเป็นจำเลยของคดี,คดีหมายเลข 1542/2566 ได้แก่กรมสรรพากร


In [18]:
def qa_with_delay(question, delay_seconds=60):
    """Sleep for ``delay_seconds`` and then answer ``question`` with ``qa``.

    Args:
        question: passed through to ``qa`` unchanged.
        delay_seconds: pause before the call, in seconds (default 60, the
            previously hard-coded value).

    Returns:
        dict: whatever ``qa(question)`` returns.
    """
    # Imported locally so the delay keeps working even if the notebook-level
    # name `time` has been rebound (the cell that unpacks the result of
    # generateQAresponse does exactly that).
    from time import sleep

    sleep(delay_seconds)
    return qa(question)

In [19]:
def generateQAresponse(df, max_workers=3):
    """Run ``qa_with_delay`` over ``df['question']`` on a thread pool.

    Args:
        df: DataFrame with a ``question`` column.
        max_workers: size of the thread pool.

    Returns:
        tuple: ``(results, times, answer, keywords, matched_keywords,
        keywords_filtered_docs, num_keywords_filtered_docs, retrieved_docs,
        num_retrieved_docs, reranked_docs, num_parse_reranked_docs)`` where
        ``results`` is the list of raw result dicts and every other element is
        a list with one entry per question, in question order.
    """
    questions = df['question'].tolist()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in input order.
        results = list(tqdm(executor.map(qa_with_delay, questions), total=len(questions)))

    def column(key, default=""):
        # .get() keeps a result dict that lacks `key` (e.g. an error result)
        # from raising KeyError and discarding the whole batch.
        return [result.get(key, default) for result in results]

    return (
        results,
        column('time', 0),
        column('answer'),
        column('keywords'),
        column('matched_keywords'),
        column('keywords_filtered_docs'),
        column('num keywords_filtered_docs', 0),
        column('retrieved_docs'),
        column('num retrieved_docs', 0),
        column('reranked_docs'),
        column('num parse_reranked_docs', 0),
    )

In [20]:
# NOTE(review): the second target is named `time`, which rebinds the imported `time`
# module to a list. Code that reads the global `time` as a module after this cell runs
# will fail until the module is imported again; a later cell reads this list as `time`.
results, time, answer, keywords, match, keywords_docs, num_keywords_filtered_docs, retrieved, num_retrieved_docs, reranked_docs, num_parse_reranked_docs = generateQAresponse(df)

100%|██████████| 40/40 [14:41<00:00, 22.04s/it]


In [21]:
# Attach the generated answers and the retrieval diagnostics to the evaluation frame.
# NOTE(review): the column is labelled gpt3.5_0613_16k, but `llm` is also bound to
# "gpt-4-0125-preview" in another cell -- confirm which model produced these answers.
df["response_gpt3.5_0613_16k"] = answer
df["keywords"] = keywords
df["matched_keywords"] = match
df["keywords_search_docs"] = keywords_docs
df["num keywords_search_docs"] = num_keywords_filtered_docs
df["context_search_docs"] = retrieved
df["num context_search_docs"] = num_retrieved_docs
df["reranked_docs"] = reranked_docs
df["num_reranked_docs"] = num_parse_reranked_docs
# `time` here is the list unpacked from generateQAresponse, not the time module.
df["time"] = time

In [22]:
# Write the responses and diagnostics next to the source dataset.
df.to_csv(f"{root_dir}/asset/dataset/law/{dataset}_ir_response_gen_edge.csv", index=False)

### **Eval w/ L2**

In [117]:
# def l2_score(sources, answer):
#     r = {"l2": "", "l2_raw":[]}
#     try:
#         if sources in ['','-',None] or answer in ['', '-', None]: return {"l2": "", "l2_raw":[]}
#         vectordb = embed_documents(sources)
#         print("test")
#         doc_and_score = vectordb.similarity_search_with_score(answer)
#         l2 = [e[-1] for e in doc_and_score]
#         print(l2)
#         l2_results = [str(np.mean(l2)), str(min(l2)), str(max(l2))]
#         l2_response = '\n'.join([str(np.mean(l2)), str(min(l2)), str(max(l2))])
#         r["l2"] = l2_response
#         r["l2_raw"] = l2_results
#     except Exception as e:
#         print(e)
#         # raise Exception(e)
#     return r

# def generateL2(a,sd,max_workers=10):

#     with ThreadPoolExecutor(max_workers=max_workers) as executor:
#         results = list(tqdm(executor.map(l2_score, sd, a), total=len(a)))

#     l2 = [result['l2'] for result in results]
#     l2_raw = [result['l2_raw'] for result in results]
#     return l2_raw, l2

# l2,l2_raw = generateL2(a,sd)
# # ll2 = [f"mean {l[0]}\nmin {l[1]}\nmax {l[2]}" for l in l2]
# print(l2[:3])
# df["L2_gpt3.5_0613_16k"] = l2

In [118]:
# l2,l2_raw = generateL2(r,sd)
# ll2 = [f"mean {l[0]}\nmin {l[1]}\nmax {l[2]}" for l in l2]
# print(l2[:3])
# df["L2_gpt3.5_0613_16k"] = l2

In [119]:
# ll2 = [f"mean {l[0]}\nmin {l[1]}\nmax {l[2]}" for l in l2]

In [120]:
# df["L2_gpt3.5_0613_16k"] = ll2

### **Eval w/ BertScore**

In [23]:
# Replace NaN reference answers with "-"; string values are left unchanged.
df["answer"] = df["answer"].apply(lambda x: "-" if (not isinstance(x,str)) and np.isnan(x) else x)

In [24]:
# Score the generated answers (candidates) against the reference answers with BERTScore.
# NOTE(review): the references printed below are Thai text, yet lang='en' and an MNLI
# DeBERTa checkpoint are used -- confirm this model handles Thai; otherwise the scores
# may not be meaningful.
model_type = 'microsoft/deberta-xlarge-mnli'  
labels = df['answer'].tolist()
preds = df['response_gpt3.5_0613_16k'].tolist()
# NOTE(review): this prints the entire list of reference answers into the cell output.
print(labels)
P, R, F1 = score(preds, labels, lang='en', model_type=model_type, verbose=False)

['คดีหมายเลข 2610/2566, 2609/2666, 1306/2566, 1305/2566, 1301/2566', 'คดีหมายเลข 2610/2566, 2609/2666, 1225/2566, 1090/2566', 'คดีหมายเลข 1542/2566 ได้แก่กรมสรรพากร', 'จำคุกตั้งแต่ห้าปีถึงยี่สิบปี หรือจำคุกตลอดชีวิต และปรับตั้งแต่สองพันบาทถึงสี่หมื่นบาท', 'คดีหมายเลข 2085/2566, 1301/2566, 1300/2566, 1003/2566', 'คดีหมายเลข 2085/2566, 1306/2566, 1300/2566, 1101/2566, 1003/2566, 882/2566, 835/2566', 'ในการริบทรัพย์สิน นอกจากศาลจะมีอำนาจริบตามกฎหมายที่บัญญัติไว้โดยเฉพาะแล้ว ให้ศาลมีอำนาจสั่งให้ริบทรัพย์สินดังต่อไปนี้อีกด้วย คือ (1) ทรัพย์สินซึ่งบุคคลได้ใช้ หรือมีไว้เพื่อใช้ในการกระทำความผิด หรือ (2) ทรัพย์สินซึ่งบุคคลได้มาโดยได้กระทำความผิด', 'ในกรณีความผิดใดเกิดขึ้นโดยการกระทำของบุคคลตั้งแต่สองคนขึ้นไป ผู้ที่ได้ร่วมกระทำความผิดด้วยกันนั้นเป็นตัวการ ต้องระวางโทษตาม ที่กฎหมายกำหนดไว้สำหรับความผิดนั้น', 'หุ้นทุก ๆ หุ้นจำต้องให้ใช้เป็นเงินจนเต็มค่า เว้นแต่หุ้นซึ่งออกตาม บทบัญญัติมาตรา 1108 อนุมาตรา (5) หรือมาตรา 1221 ในการใช้เงินเป็นค่าหุ้นนั้น ผู้ถือหุ้นจะหักหนี้กับบริษัทหาได้ไม่', 'ผู้ใดต้



In [25]:
# Mean of the per-pair F1 values returned by score().
print(F1.mean().item())

0.5861912965774536


In [26]:
# Store the per-row F1 values alongside the answers.
df["BERTScore_F1_gpt3.5_0613_16k"] = F1.tolist()

In [27]:
# Re-save to the same output path as the earlier to_csv cell, now including the BERTScore column.
df.to_csv(f"{root_dir}/asset/dataset/law/{dataset}_ir_response_gen_edge.csv",index=False)