In [None]:
import os

# Expose only the GPU with index 2 to this process. This runs in the first
# cell, before the next cell imports torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [None]:
import pandas as pd
import glob
import json
import logging
import os
import torch
import plotly.express as px
import matplotlib.pyplot as plt
import re
import joblib

from langdetect import DetectorFactory, detect_langs
from datetime import datetime

from tqdm.auto import tqdm
from compute_ppl import compute_ppl
from transformers import AutoModelForCausalLM, BarthezTokenizer

# Fix langdetect's seed so that detect_langs() gives repeatable results.
DetectorFactory.seed = 0

# Let pandas show up to 200 characters per cell when rendering text columns.
pd.options.display.max_colwidth = 200

In [None]:
def open_md(file_path):
    """Read the file at ``file_path`` as UTF-8 text and return its full content."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        return handle.read()


def open_json(file_path):
    """Parse the UTF-8 JSON file at ``file_path`` and return the decoded object."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


def get_current_datetime():
    """Return the current local time formatted as ``YYYY-MM-DD_HH_MM_SS``."""
    return datetime.now().strftime("%Y-%m-%d_%H_%M_%S")


def setup_logger(
    prefix: str,
    console_level: str = "DEBUG",
    file_level: str = "WARNING",
):
    """
    Configure and return this module's logger with one console and one file handler.

    The logger is the shared ``logging.getLogger(__name__)`` object, so every
    call reconfigures the same instance. Handlers installed by an earlier call
    are removed and closed first; without that, each call would stack two more
    handlers and every record would be emitted several times (and into the log
    files of earlier calls).

    Args:
        prefix: Leading part of the log file name. The file is created at
            ``../log/{prefix}_{timestamp}.log``.
        console_level: Minimum level of records sent to the console handler.
        file_level: Minimum level of records written to the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(__name__)
    # Keep the logger itself wide open; each handler filters by its own level.
    logger.setLevel(logging.DEBUG)

    # Drop handlers left over from a previous call so records are not duplicated.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Create console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(console_level)

    # FileHandler does not create missing directories, so make sure it exists.
    os.makedirs("../log", exist_ok=True)
    log_file = f"../log/{prefix}_{get_current_datetime()}.log"

    file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
    file_handler.setLevel(file_level)

    # Define log format
    formatter = logging.Formatter(
        "{asctime},{levelname},{message}",
        style="{",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Assign formatter to handlers
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)

    # Add handlers to logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger

In [None]:
def filter_by_picture(json_files, logger):
    """Keep the documents that average fewer than one 'Picture' block per page.

    Each JSON file is expected to hold a ``page_stats`` list with one entry per
    page; an entry may carry ``block_counts``, a list of ``(block_type, count)``
    pairs. Files without page stats, and files that fail to load or parse, are
    logged and left out of the result.

    Args:
        json_files: Sized collection of paths to the per-document JSON files.
        logger: Logger that receives warnings, errors and the final summary.

    Returns:
        The paths from ``json_files`` that passed the filter, in input order.
    """
    remain = []

    for file_path in tqdm(json_files):
        try:
            data = open_json(file_path)
            page_stats = data.get("page_stats", [])

            if not page_stats:
                logger.warning(f"No page stats found in {file_path}, skipping.")
                continue

            # Total number of 'Picture' blocks over all pages of the document.
            num_all_pic = sum(
                count
                for page in page_stats
                for block, count in page.get("block_counts", [])
                if block == "Picture"
            )

            # page_stats is non-empty here, so the division cannot hit zero.
            if num_all_pic / len(page_stats) < 1:
                remain.append(file_path)

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            logger.error(f"Error processing {file_path}: {e}")

    logger.info(f"Num remain: {len(remain)}")
    logger.info(f"Num filter: {len(json_files) - len(remain)}")

    return remain


def filter_by_lang(md_files, logger):
    """Keep the Markdown files whose detected Thai probability exceeds 0.5.

    Files that cannot be read or analysed are logged as errors and dropped.
    Returns the surviving paths in input order.
    """
    kept = []

    for md_path in tqdm(md_files):
        try:
            text = open_md(md_path)

            # Probabilities reported for Thai ("th"), if any.
            thai_scores = [
                guess.prob for guess in detect_langs(text) if guess.lang == "th"
            ]

            is_thai = bool(thai_scores) and thai_scores[0] > 0.5
            if is_thai:
                kept.append(md_path)

        except Exception as e:
            logger.error(f"Cannot read OCR result for {md_path}: {e}")

    logger.info(f"Num remain: {len(kept)}")
    logger.info(f"Num filter: {len(md_files) - len(kept)}")

    return kept

# Filter by image and language

In [None]:
# Collect the OCR outputs: every .md and .json file exactly one folder below
# ../ocr_result (presumably one sub-folder per PDF — confirm).
md_files = glob.glob("../ocr_result/*/*.md")
json_files = glob.glob("../ocr_result/*/*.json")
print(f"Num pdf: {len(md_files)}")
# Sanity check: the number of Markdown files must equal the number of JSON files.
assert len(md_files) == len(json_files)

# Log file name starts with "ocr_filter_" and is written under ../log.
logger = setup_logger(prefix="ocr_filter")

In [None]:
# Stage 1: keep the JSON files whose documents average fewer than one
# 'Picture' block per page.
remain = filter_by_picture(json_files=json_files, logger=logger)

In [None]:
# Turn each kept "..._meta.json" path into the matching ".md" path, which is
# what the language filter reads.
remain = [p.replace("_meta.json", ".md") for p in remain]

In [None]:
# Stage 2: keep the Markdown files detected as Thai with probability > 0.5.
remain = filter_by_lang(md_files=remain, logger=logger)

In [None]:
# Persist the surviving Markdown paths; the perplexity section reloads this file.
pd.Series(remain).to_csv("./filtered_md.csv", index=False)

# Filter by perplexity

In [None]:
# Reload the paths saved by the picture/language filters (first CSV column).
remain = pd.read_csv("./filtered_md.csv").iloc[:, 0].to_list()
logger = setup_logger(prefix="ocr_filter_ppl")

model_name = "airesearch/wangchanbart-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Stop here when no GPU is available instead of continuing on CPU.
assert device == "cuda"

# Load the model used for perplexity scoring onto the device, plus its tokenizer.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = BarthezTokenizer.from_pretrained(model_name)

In [None]:
def compute_perplexity(
    max_length, stride, list_of_path, model, tokenizer, logger, device
):
    """
    Score every text file in ``list_of_path`` with ``compute_ppl``.

    Returns a list of ``(file_path, perplexity)`` tuples. Files that cannot be
    opened or scored are logged as errors and left out of the result.
    """
    print(f"Number of GPUs used: {torch.cuda.device_count()}")

    scored = []
    for md_path in tqdm(list_of_path):
        # Reading and scoring are guarded separately so the log tells them apart.
        try:
            document = open_md(md_path)
        except Exception as e:
            logger.error(f"Cannot open {md_path}: {e}")
            continue

        try:
            tokens = tokenizer(document, return_tensors="pt")
            _, score = compute_ppl(
                model,
                tokenizer,
                tokens,
                max_length=max_length,
                seq_len=tokens.input_ids.size(1),
                stride=stride,
                device=device,
            )
        except Exception as e:
            logger.error(f"Error computing perplexity for {md_path}: {e}")
        else:
            scored.append((md_path, score))

    return scored

In [None]:
# Both values are forwarded unchanged to compute_ppl (presumably the window
# size and the step between windows — confirm against compute_ppl).
max_length = model.config.max_position_embeddings
stride = 512

# Score every Markdown file that survived the picture and language filters.
ppl_list = compute_perplexity(
    max_length=max_length,
    stride=stride,
    list_of_path=remain,
    model=model,
    tokenizer=tokenizer,
    device=device,
    logger=logger,
)

In [None]:
# One row per scored file; saved so the analysis below can restart from the CSV.
ppl_df = pd.DataFrame(ppl_list, columns=["file_path", "ppl_score"])
ppl_df.to_csv("ppl_df.csv", index=False)

In [None]:
# Reload the saved scores (columns: file_path, ppl_score).
ppl_df = pd.read_csv("./ppl_df.csv")

In [None]:
# Summary statistics of the perplexity scores, including upper-tail
# percentiles, displayed as a one-column frame.
(
    ppl_df["ppl_score"]
    .describe(percentiles=[0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.999])
    .to_frame()
)

In [None]:
# Two side-by-side histograms of the same scores: full range (left) and a
# view limited to 0-3200 on the x-axis (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, panel_title in zip(axes, ("PPL score", "PPL score (Limited x-axis)")):
    ax.hist(ppl_df["ppl_score"], bins="auto")
    ax.set_title(panel_title)

# Only the right-hand panel gets the x-axis limit
axes[1].set_xlim(0, 3200)

# Display the plots
plt.tight_layout()
plt.show()

In [None]:
# Keep the documents whose perplexity is below 1500.
# NOTE(review): the threshold is presumably read off the distribution above — confirm.
filtered = ppl_df.loc[ppl_df["ppl_score"].lt(1500)].copy()
# Remove the literal "../" from the stored paths; the "Convert to jsonl"
# section joins ".." back on when it reads this file.
filtered["file_path"] = filtered["file_path"].str.replace(r"../", "", regex=False)
filtered.to_csv("../ppl_filtered_md.csv", index=False)

# LLM

In [None]:
from openai import OpenAI

# Documents that passed the perplexity filter.
md = pd.read_csv("../ppl_filtered_md.csv")

# Placeholder key passed to both clients.
# NOTE(review): presumably these self-hosted endpoints ignore the key — confirm.
openai_api_key = "EMPTY"
# NOTE(review): endpoint addresses are hard-coded; consider moving them to configuration.
llama_api = "http://10.204.100.76:11000/v1"
qwen_api = "http://10.204.100.79:11700/v1"

# One OpenAI-compatible client per endpoint.
llama_client = OpenAI(
    api_key=openai_api_key,
    base_url=llama_api,
)

qwen_client = OpenAI(
    api_key=openai_api_key,
    base_url=qwen_api,
)

# Model identifiers passed as `model` in chat-completion requests.
llama_model = "hugging-quants/Meta-Llama-3.1-405B-Instruct-AWQ-INT4"
qwen_model = "Qwen/Qwen2.5-72B-Instruct"


def get_output(client, model: str, messages: list[dict[str, str]]):
    """Send ``messages`` to ``model`` at temperature 0 and return the reply text.

    The reply is the content of the first choice with surrounding whitespace
    stripped.
    """
    completion = client.chat.completions.create(
        model=model, messages=messages, temperature=0
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()

In [None]:
# System message sent first in every refinement request: it sets the
# proofreading role and the rules for language, layout and Markdown handling.
system = {
    "role": "system",
    "content": """
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
You specialize in proofreading and refining OCR-extracted Thai text from PDF files while preserving the original context, language, and layout.

Your responsibilities:
1. **Analyze** the OCR-extracted text.
2. **Correct** OCR errors and refine the text for enhanced correctness and readability.
3. **Preserve** the original language (Thai, English, or mixed). Output Chinese only if the input is entirely in Chinese.
4. **Maintain** the original context, semantic meaning, and overall layout.
5. **Preserve** all Markdown formatting including headings, paragraphs, lists, tables, equations, and special symbols.
6. **For tables:** Retain the table structure and refine individual cell content as needed.
7. **Remove** any redundant, erroneous text, and html tags that does not contribute to the overall meaning or layout.
8. **If uncertain** about a correction, leave the original segment unchanged.

Return only the final refined text in **Markdown** format without any additional commentary or explanation.
""",
}


# Prefix of every user message: the task rules plus one worked example.
# The page text to refine is appended directly after the trailing "INPUT:" line.
user_content = """
Your task is to refine and correct the following Thai text extracted via OCR. Ensure you:

1. Preserve the original context, language, and layout (including tables and special symbols).
2. Maintain all Markdown formatting in your output.
3. Remove redundant or erroneous text that does not contribute to the overall meaning.
4. Retain any segment as-is if you are uncertain about the correction.
5. **Do not** add new headings or sections that do not exist in the input.
6. Correct OCR errors strictly without altering the semantic meaning.

**Example**

INPUT:
# ทที่ 3
# ่วงโซ่อุปทานและระบบโลจิสติกส์อุตสาหกรรมน้ำมันปาล์ม
## .1 ภาพรวมห่วงโซ่อุปทานอุตสาหกรรมน้ำมันปาล์ม
าพรวมห่วงโซ่อุปทาน (Supply Chain) อุตสาหกรรมน้ำมันปาล์มพิจารณาตามผลผลิต (Output) พบว่าห่วงโซ่อุปทานอุตสาหกรรมน้ำมันปาล์มของอินโดนีเซีย มาเลเซีย และไทยมีลักษณะ ามลำดับขั้นเป็นไปในลักษณะเดียวกัน กล่าวอีกนัยหนึ่งไม่มีความแตกต่างกัน ห่วงโซ่อูปทาน ียงลำดับเริ่มต้นจากการปลูก (Planting) การสกัด (Milling) การกลั่น (Refinery) และอุตสาหกรรม ลายน้ำ (Downstream) รายละเอียดดังนี้

EXPECTED OUTPUT:
# บทที่ 3
# ห่วงโซ่อุปทานและระบบโลจิสติกส์อุตสาหกรรมน้ำมันปาล์ม
## 3.1 ภาพรวมห่วงโซ่อุปทานอุตสาหกรรมน้ำมันปาล์ม
ภาพรวมห่วงโซ่อุปทาน (Supply Chain) อุตสาหกรรมน้ำมันปาล์มพิจารณาตามผลผลิต (Output) พบว่าห่วงโซ่อุปทานอุตสาหกรรมน้ำมันปาล์มของอินโดนีเซีย มาเลเซีย และไทยมีลักษณะตามลำดับขั้นเป็นไปในลักษณะเดียวกัน กล่าวอีกนัยหนึ่งไม่มีความแตกต่างกัน ห่วงโซ่อุปทานเรียงลำดับเริ่มต้นจากการปลูก (Planting) การสกัด (Milling) การกลั่น (Refinery) และอุตสาหกรรมปลายน้ำ (Downstream) รายละเอียดดังนี้

Below is your input:
INPUT:
"""

In [None]:
# md

In [None]:
# path = os.path.join("..", "ocr_result/pdf_doc_00351/pdf_doc_00351.md")
# test = open_md(path)
# print(test)

# split_pat = re.compile(r"\{\d+\}\-{48}")
# test_split = split_pat.split(test)

In [None]:
# outout = []
# for i, page in enumerate(tqdm(test_split)):
#     if page:
#         if i > 15:
#             break
#         user_content_w_input = user_content + page.strip()
#         user = {"role": "user", "content": user_content_w_input}
#         messages = [system, user]
#         llm_response = get_output(
#             client=qwen_client,
#             model=qwen_model,
#             messages=messages,
#         )
#         outout.append(llm_response)

In [None]:
# outout_tmp = [
#     f"{i}------------------------------------------------\n" + txt
#     for i, txt in enumerate(outout)
# ]
# print("\n\n".join(outout_tmp))

In [None]:
# # Open a file in write mode (using 'utf-8' encoding to support special characters)
# with open("../output.md", "w", encoding="utf-8") as file:
#     file.write("\n\n".join(outout_tmp))

# print("Markdown file saved as output.md")

In [None]:
# Load the JSONL corpus, one record per line.
# NOTE(review): `jsonl` is only referenced by the commented-out cells below; a
# later cell reloads the same file as `jsonl_df`.
path = "../pdf_data_2025-02-26_02_25_48_has_sep.jsonl"
jsonl = pd.read_json(path, lines=True, orient="records")
# Log at DEBUG level to both the console and the log file.
logger = setup_logger("pdf_llm", console_level="DEBUG", file_level="DEBUG")

In [None]:
# split_pat = re.compile(r"\{\d+\}\-{48}\n\n")
# result = []
# for i, text in enumerate(tqdm(jsonl["text"])):
#     logger.info(f"Processing pdf id {i}")
#     text_split = split_pat.split(text)
#     output = []
#     for page_no, page in enumerate(text_split):
#         if page:
#             user_content_w_input = user_content + page.strip()
#             user = {"role": "user", "content": user_content_w_input}
#             messages = [system, user]
#             try:
#                 llm_response = get_output(
#                     client=qwen_client,
#                     model=qwen_model,
#                     messages=messages,
#                 )
#                 output.append(llm_response)
#             except Exception as e:
#                 logger.error(f"There is error at pdf id {i}, page no: {page_no}")
#                 logger.error(f"The error is {e}")
#                 logger.error("Will append original page")
#                 output.append(page.strip())


#     text_llm = "\n\n".join(output)
#     result.append(text_llm)

In [None]:
# print(text)

In [None]:
# print(text_llm)

In [None]:
# def process_page(pdf_id: int, page_no: int, page_text: str) -> str:
#     """
#     Processes a single page by combining the prompt with the page text,
#     sending it to the LLM, and handling any errors.
#     """
#     if not page_text:
#         return ""
#     prompt = user_content + page_text.strip()
#     messages = [system, {"role": "user", "content": prompt}]
#     try:
#         result = get_output(client=qwen_client, model=qwen_model, messages=messages)
#     except Exception as e:
#         logger.error(f"Error processing pdf id {pdf_id}, page {page_no}: {e}")
#         result = page_text.strip()  # Fallback: return the original text
#     return result


# def process_pdf(pdf_id: int, pdf_text: str) -> str:
#     """
#     Splits the PDF text into pages and processes each page concurrently.
#     Returns the processed text for the PDF.
#     """
#     logger.info(f"Processing pdf id {pdf_id}")
#     pages = split_pattern.split(pdf_text)
#     processed_pages = [None] * len(pages)

#     # Process pages concurrently using multithreading
#     with ThreadPoolExecutor(max_workers=8) as page_executor:
#         future_to_index = {
#             page_executor.submit(process_page, pdf_id, idx, page): idx
#             for idx, page in enumerate(pages)
#         }
#         for future in as_completed(future_to_index):
#             idx = future_to_index[future]
#             try:
#                 processed_pages[idx] = future.result()
#             except Exception as e:
#                 logger.error(f"Error in pdf id {pdf_id}, page {idx}: {e}")
#                 processed_pages[idx] = pages[
#                     idx
#                 ].strip()  # Fallback: original page text
#     return "\n\n".join(processed_pages)

In [None]:
# num_pdfs = len(jsonl_df)
# results = [None] * num_pdfs

# # Process PDFs concurrently using multithreading
# with ThreadPoolExecutor(max_workers=32) as pdf_executor:
#     future_to_pdf = {
#         pdf_executor.submit(process_pdf, idx, text): idx
#         for idx, text in enumerate(jsonl_df["text"])
#     }
#     for future in tqdm(
#         as_completed(future_to_pdf), total=len(future_to_pdf), desc="Processing PDFs"
#     ):
#         idx = future_to_pdf[future]
#         try:
#             results[idx] = future.result()
#         except Exception as e:
#             logger.error(f"Error processing pdf id {idx}: {e}")
#             results[idx] = ""

In [None]:
# Input for the LLM refinement: the JSONL corpus, one record per line, whose
# "text" column is what process_pdf receives.
path = "../pdf_data_2025-02-26_02_25_48_has_sep.jsonl"
jsonl_df = pd.read_json(path, lines=True, orient="records")
# Log at DEBUG level to both the console and the log file.
logger = setup_logger("pdf_llm", console_level="DEBUG", file_level="DEBUG")
# Matches "{<digits>}" followed by 48 hyphens and two newlines; process_pdf
# splits on it (presumably the per-page separator in the text — confirm).
split_pat = re.compile(r"\{\d+\}\-{48}\n\n")

In [None]:
def process_pdf(pdf_id, pdf_text):
    """
    Refine one PDF's text page by page with the Qwen endpoint.

    The text is split on ``split_pat``; every non-blank page is sent to the
    model together with the shared prompt. When a request fails, the error is
    logged and the original (stripped) page is kept. The resulting pages are
    joined with blank lines and returned.
    """
    logger.info(f"Processing PDF {pdf_id}")

    refined_pages = []
    for page_no, raw_page in enumerate(split_pat.split(pdf_text)):
        page = raw_page.strip()
        if not page:
            # Blank fragments produced by the split carry no content.
            continue

        messages = [system, {"role": "user", "content": user_content + page}]
        try:
            refined = get_output(
                client=qwen_client, model=qwen_model, messages=messages
            )
        except Exception as e:
            logger.error(f"Error in PDF {pdf_id} page {page_no}: {e}")
            refined = page  # Fallback: use original text
        refined_pages.append(refined)

    return "\n\n".join(refined_pages)

In [None]:
checkpoint_file = "checkpoint_results.json"

# Load checkpoint if it exists, so an interrupted run resumes where it stopped
if os.path.exists(checkpoint_file):
    # The checkpoint is dumped with ensure_ascii=False, so it can hold raw
    # non-ASCII text; read it as UTF-8 explicitly rather than relying on the
    # platform's default encoding.
    with open(checkpoint_file, "r", encoding="utf-8") as f:
        checkpoint_data = json.load(f)
else:
    checkpoint_data = {}

total_pdfs = len(jsonl_df)
# JSON object keys are always strings; turn them back into row positions
results = {int(k): v for k, v in checkpoint_data.items()}  # keys as int

# Build a list of pdf indices that are not yet processed
pending_ids = [i for i in range(total_pdfs) if i not in results]

In [None]:
# Process pending PDFs concurrently
from concurrent.futures import ThreadPoolExecutor, as_completed

with ThreadPoolExecutor(max_workers=128) as executor:  # Adjust as needed
    # Map each submitted future back to the row position of its PDF
    future_to_pdf = {
        executor.submit(process_pdf, idx, jsonl_df.iloc[idx]["text"]): idx
        for idx in pending_ids
    }

    for future in tqdm(as_completed(future_to_pdf), total=len(future_to_pdf), desc="Processing PDFs"):
        pdf_id = future_to_pdf[future]
        try:
            processed_text = future.result()
            results[pdf_id] = processed_text
        except Exception as e:
            logger.error(f"Error processing PDF {pdf_id}: {e}")
            # NOTE(review): the empty string is checkpointed too, so a failed
            # PDF counts as processed and is not retried on the next run.
            results[pdf_id] = ""
        # Write out checkpoint after each PDF completes. The data goes to a
        # temporary file that then replaces the checkpoint in one step, so an
        # interruption mid-write cannot leave a truncated, unreadable file.
        # UTF-8 is explicit because ensure_ascii=False emits raw non-ASCII text.
        checkpoint_tmp = f"{checkpoint_file}.tmp"
        with open(checkpoint_tmp, "w", encoding="utf-8") as f:
            json.dump({str(k): v for k, v in results.items()}, f, ensure_ascii=False)
        os.replace(checkpoint_tmp, checkpoint_file)

# Convert to jsonl

In [None]:
# Paths in the CSV were stored without the leading "../"; join it back on.
path = "../ppl_filtered_md.csv"
selected = pd.read_csv(path)
selected_pdf = [os.path.join("..", p) for p in selected["file_path"].to_list()]
print(f"Num remaining pdf: {len(selected_pdf)}")

# Stack every metadata CSV found under ../pdf_meta into one frame.
meta = glob.glob("../pdf_meta/*.csv")
df_meta = pd.concat([pd.read_csv(p) for p in meta], axis=0)

# Columns taken from the openbase metadata file.
col = [
    "title",
    "pdf_link",
    "source",
    "license",
    "filename",
]
path = "../pdf_meta_openbase/meta_openbase_2025-02-07_09_52_59.csv"
df_meta_openbase = pd.read_csv(path)
# Rename "pdf_link" to "url", the column name the record-building cell reads.
df_meta_openbase = df_meta_openbase.loc[:, col].rename(columns={"pdf_link": "url"})

# Combined metadata lookup table, matched on "filename" when building records.
df_meta = pd.concat([df_meta, df_meta_openbase], ignore_index=True)
# df_meta.info()

In [None]:
remove_pat_1 = re.compile(r"\{\d+\}\-{48}\n\n")
# remove_pat_2 = re.compile('<.*?>')

# Build one record per selected Markdown file: id, text, source and metadata.
json_list = []
for i, path in enumerate(tqdm(selected_pdf)):
    # The name of the folder holding the Markdown file, plus ".pdf", is the
    # key used to look the document up in the metadata table.
    pdf_file_name = os.path.basename(os.path.dirname(path)) + ".pdf"

    text = open_md(path)
    # text = remove_pat_1.sub("", text)
    # text = remove_pat_2.sub("", text)

    meta_data = df_meta.loc[df_meta["filename"].eq(pdf_file_name)]
    if meta_data.empty:
        # Fail with the offending file named instead of a bare IndexError
        # from `.values[0]` below.
        raise LookupError(f"No metadata row found for {pdf_file_name} ({path})")

    # When several rows match, the first one is used.
    data = {
        "id": str(i),
        "text": text,
        "source": meta_data["source"].values[0],
        "metadata": {
            "title": meta_data["title"].values[0],
            "url": meta_data["url"].values[0],
            "license": meta_data["license"].values[0],
        },
    }
    json_list.append(data)

In [None]:
# Write the records as JSON Lines (one record per line) to a timestamped file;
# force_ascii=False keeps non-ASCII text unescaped in the output.
pd.DataFrame(json_list).to_json(
    f"../pdf_data_{get_current_datetime()}_has_sep.jsonl",
    orient="records",
    lines=True,
    force_ascii=False,
)

# tmp