<a href="https://colab.research.google.com/github/Xuanyiyiren/An-enthusiast-of-mathematics-and-physics/blob/main/LatexOCR_correction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install --quiet "openai>=1,<2"

Collecting openai
  Downloading openai-1.45.0-py3-none-any.whl.metadata (22 kB)
Collecting httpx<1,>=0.23.0 (from openai)
  Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting jiter<1,>=0.4.0 (from openai)
  Downloading jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)
  Downloading httpcore-1.0.5-py3-none-any.whl.metadata (20 kB)
Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)
  Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)
Downloading openai-1.45.0-py3-none-any.whl (374 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.1/374.1 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading httpx-0.27.2-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading httpcore-1.0.5-py3-none-any.whl (77 kB)
[2K   [90m━

In [None]:
from openai import OpenAI
from google.colab import userdata
import os
import re

class LatexError(ValueError):
    """Raised when a closing LaTeX delimiter has no matching opening delimiter."""


class AI_latexCode_correction:
    """
    Correct LaTeX source produced by LaTeXOCR with an LLM, one chunk at a time.

    Typical use:
      1. ``get_content`` reads a .tex file (the newest one in the current
         directory unless ``file_path`` was set beforehand).
      2. ``split_into_chunks`` groups the lines into chunks whose estimated
         token count stays within ``max_tokens`` while keeping display math
         and environments in one piece.
      3. ``recombination`` sends each chunk to the DeepSeek model and joins the
         corrected chunks into ``modified_content``.
    """

    def __init__(self, max_tokens=1000):
        """
        Args:
            max_tokens: budget for the estimated token count of a single chunk.
        """
        self.max_tokens = max_tokens
        # Backslashes are doubled on purpose: in a normal string literal "\b" is a
        # backspace character, so a single backslash would send "<BS>egin{...}".
        self.Sys_prompt = """\
        你将接收到一些经过 LaTeXOCR 得到的 LaTeX 源代码，请你帮忙检查并修正其中的语法错误、OCR产生的错别字等问题。同时请注意调整以下内容
        1. 将所有的行间公式调整为有编号的行间公式，即\\begin{equation}...\\end{equation}，其中不带有*
        2. 将所有的小标题 section,subsection,subsubsection 调整为有编号的，即去掉原代码中的*和手动编号
        3. 将所有的图片或表格调整为有编号的，即\\begin{figure}...\\end{figure} ，即去掉原代码中的*
        """
        self.file_path = ""
        self.content = ""
        self.modified_content = ""
        # True while no "$$ ... $$" display-math block is open; toggled by _pre_split.
        self._dollar_begin = True

    # First Part: get the latex codes
    def get_files_and_set_path(self, directory="./"):
        """
        Find the .tex files in ``directory`` and point ``self.file_path`` at the newest one.

        Args:
            directory: directory to scan (default: the current working directory).

        Returns:
            list[str]: the .tex file names (not full paths), sorted by
            modification time, newest first.

        Raises:
            FileNotFoundError: if ``directory`` contains no .tex file.
        """
        files = [
            f for f in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, f)) and os.path.splitext(f)[1] == ".tex"
        ]
        if not files:
            raise FileNotFoundError(f"No .tex file found in {directory!r}.")
        # Join with the directory: a bare file name only resolves relative to the CWD.
        files.sort(key=lambda f: os.path.getmtime(os.path.join(directory, f)), reverse=True)
        self.file_path = os.path.join(directory, files[0])
        print("Get content automatically.")
        return files

    def get_content(self):
        """
        Read the file at ``self.file_path``, store it in ``self.content`` and return it.

        If ``self.file_path`` is empty, ``get_files_and_set_path`` is called first
        to pick the newest .tex file in the current directory.
        """
        if not self.file_path:
            self.get_files_and_set_path()
        # The documents may contain Chinese text; decode explicitly as UTF-8 rather
        # than relying on the platform's default encoding.
        with open(self.file_path, "r", encoding="utf-8") as file:
            content = file.read()
        self.content = content
        return content

    # Second Part: split the codes into chunks according to the max token.
    # The difficulties here are:
    # 1. Splitting into parts while keeping the basic semantics and grammar of the LaTeX code.
    # 2. Estimating the token count of a text.
    @staticmethod
    def _is_single_line_block(line):
        """Return True if ``line`` opens a $$, \\[ or begin{...} block and also closes it."""
        if line.startswith("$$"):
            return "$$" in line[2:]
        if line.startswith(r"\["):
            return r"\]" in line[2:]
        if line.startswith(r"\begin{") and not line.startswith(r"\begin{document}"):
            return line.count(r"\begin{") == line.count(r"\end{")
        return False

    def _is_start_tag(self, line):
        """
        Return ``(True, kind)`` if ``line`` starts an indivisible block, else ``(False, None)``.

        ``kind`` is "$$" (only while no "$$" block is open), "[]" for a line
        starting with \\[, or "be" for begin{...} other than begin{document}.
        Only the start of the line is inspected. No state is modified.
        """
        if line.startswith("$$") and self._dollar_begin:
            return True, "$$"
        if line.startswith(r"\["):
            return True, "[]"
        if line.startswith(r"\begin{") and not line.startswith(r"\begin{document}"):
            return True, "be"
        return False, None

    def _is_end_tag(self, line):
        """
        Return ``(True, kind)`` if ``line`` ends an indivisible block, else ``(False, None)``.

        ``kind`` uses the same codes as ``_is_start_tag``; "$$" only counts as a
        closing delimiter while a "$$" block is open. No state is modified.
        """
        if line.startswith("$$") and not self._dollar_begin:
            return True, "$$"
        if line.startswith(r"\]"):
            return True, "[]"
        if line.startswith(r"\end{") and not line.startswith(r"\end{document}"):
            return True, "be"
        return False, None

    def _pre_split(self):
        """
        Split ``self.content`` into indivisible parts.

        A part is either one ordinary line, or a whole multi-line block that must
        stay together: "$$ ... $$", "\\[ ... \\]" or "begin{...} ... end{...}".
        begin{document}/end{document} are treated as ordinary lines. Only
        delimiters at the very start of a line are recognised.

        Returns:
            list[str]: the parts in order; joining them reproduces the content.

        Raises:
            LatexError: if a closing delimiter does not match the innermost open block.
        """
        if not self.content:
            self.get_content()
        self._dollar_begin = True
        lines = self.content.splitlines(keepends=True)
        basic_parts = []
        stack = []
        start_pos = 0
        for index, line in enumerate(lines):
            # A block opened and closed on the same line (e.g. "\[ x \]") never
            # changes the nesting depth.
            if self._is_single_line_block(line):
                if not stack:
                    basic_parts.append(line)
                continue

            is_start, start_flag = self._is_start_tag(line)
            if is_start:
                if not stack:
                    start_pos = index
                stack.append(start_flag)
                if start_flag == "$$":
                    self._dollar_begin = False
                continue

            is_end, end_flag = self._is_end_tag(line)
            if is_end:
                if not stack or stack[-1] != end_flag:
                    raise LatexError(
                        f"Not Match! Unexpected closing delimiter at line {index + 1}: {line.strip()!r}"
                    )
                stack.pop()
                if end_flag == "$$":
                    self._dollar_begin = True
                if not stack:
                    basic_parts.append("".join(lines[start_pos:index + 1]))
                continue

            if not stack:
                basic_parts.append(line)

        # An unterminated block would otherwise be silently dropped; keep its lines.
        if stack:
            basic_parts.append("".join(lines[start_pos:]))
        return basic_parts

    token_counter_map = {
        "ave": "token_counter_average",
    }
    chinese_punctuations = "，。！？；：、“”‘’（）《》【】"

    @staticmethod
    def _count_chars(text):
        """
        Count characters of ``text`` in two classes, ignoring spaces and newlines.

        Returns:
            tuple[int, int]: (CJK ideographs U+4E00..U+9FFF plus Chinese
            punctuation, every other character).
        """
        chinese_count = 0
        english_count = 0
        chinese_punctuation_count = 0
        english_punctuation_count = 0

        for char in text:
            if char in [' ', '\n']:
                continue
            if '\u4e00' <= char <= '\u9fff':
                chinese_count += 1
            elif 'a' <= char <= 'z' or 'A' <= char <= 'Z':
                english_count += 1
            else:
                if char in AI_latexCode_correction.chinese_punctuations:
                    chinese_punctuation_count += 1
                else:
                    english_punctuation_count += 1

        return (chinese_count + chinese_punctuation_count,
                english_count + english_punctuation_count)

    @staticmethod
    def token_counter_average(text: str) -> float:
        """
        Estimate the token count of ``text``: 0.6 per Chinese character, 0.3 per other character.

        Ratios taken from `https://platform.deepseek.com/api-docs/guides/token_usage/`.
        """
        chinese_count, english_count = AI_latexCode_correction._count_chars(text)
        return chinese_count * 0.6 + english_count * 0.3

    def split_into_chunks(self, count_mode="ave"):
        """
        Group the indivisible parts of ``self.content`` into chunks of at most ``self.max_tokens``.

        A single part larger than the budget becomes a chunk on its own.

        Args:
            count_mode: key of ``token_counter_map`` selecting the token estimator.

        Returns:
            list[str]: the chunks, in document order.

        Raises:
            ValueError: if ``count_mode`` is unknown.
        """
        counter_name = self.token_counter_map.get(count_mode)
        if counter_name is None:
            raise ValueError(f"Invalid count_mode: {count_mode}")
        token_counter = getattr(self, counter_name)

        sentences = self._pre_split()

        # Finished chunks, plus the chunk currently being filled.
        chunks = []
        current_chunk = ""
        current_token_count = 0

        for sentence in sentences:
            new_token_count = token_counter(sentence)

            # If adding this part would exceed the budget, close the current chunk
            # and start a new one with this part.
            if current_token_count + new_token_count > self.max_tokens:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence
                current_token_count = new_token_count
            else:
                current_chunk += sentence
                current_token_count += new_token_count

        # Keep the last, partially filled chunk.
        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    # Part 3:
    # This part uses the LLM to correct the latex codes chunk by chunk.

    def LLMmodel_Deepseek(self, Usr_prompt):
        """
        Send ``Usr_prompt`` to the deepseek-coder model and return the corrected code.

        The system prompt (``self.Sys_prompt``, set in ``__init__``) asks the model
        to check and correct the LaTeX code. The API key is read from the Colab
        secret ``DEEPSEEK_API_KEY``. The code is extracted with ``_extract_code``.
        """
        client = OpenAI(api_key=userdata.get('DEEPSEEK_API_KEY'), base_url="https://api.deepseek.com")
        response = client.chat.completions.create(
            model="deepseek-coder",
            messages=[
                {"role": "system", "content": self.Sys_prompt},
                {"role": "user", "content": Usr_prompt},
            ],
            stream=False
        )
        return self._extract_code(response.choices[0].message.content)

    @staticmethod
    def _extract_code(answer):
        """
        Return the body of the first ``` fenced block in ``answer``.

        The info string after the opening fence (e.g. "latex") is not part of the
        returned code. If the answer has no fenced block, the whole answer is
        returned so that one unfenced reply does not abort the whole run.
        """
        match = re.search(r"```[^\n`]*\n(.*?)```", answer, re.DOTALL)
        if match is None:
            # Fences without a line break after the opening backticks.
            match = re.search(r"```(.*?)```", answer, re.DOTALL)
        return match.group(1) if match else answer

    def recombination(self):
        """
        Correct every chunk with the LLM and join the results.

        Returns:
            str: the corrected document, also stored in ``self.modified_content``.
        """
        modified_chunks = []
        for chunk in self.split_into_chunks():
            modified_chunk = self.LLMmodel_Deepseek(chunk)
            # The model may drop the final newline; without it the last line of this
            # chunk would be glued to the first line of the next one.
            if modified_chunk and not modified_chunk.endswith("\n"):
                modified_chunk += "\n"
            modified_chunks.append(modified_chunk)
        self.modified_content = "".join(modified_chunks)
        return self.modified_content



In [None]:
# Smoke test: load the newest .tex file, show how it is chunked, then run the LLM correction.
test = AI_latexCode_correction(max_tokens=500)

print("Getting files and contents ...")
test.get_content()
print("Getting contents successfully, the content is: ")
print('-' * 30, test.content, '-' * 30, sep="\n")

print("Splitting into chunks ...")
chunks = test.split_into_chunks()
print("Splitting successfully, the chunks are: ")
print('=' * 30)
for index, chunk in enumerate(chunks):
    print(f"The #{index + 1} chunks:", '-' * 20, chunk, '-' * 20, sep="\n")
print('=' * 30)

print("Using LLM correcting ...")
test.recombination()
print("Correction successfully, the modified content is: ")
print('-' * 30, test.modified_content, '-' * 30, sep="\n")

Getting files and contents ...
Get content automatically.
Getting contents successfully, the content is: 
------------------------------
% This LaTeX document needs to be compiled with XeLaTeX.
\documentclass[10pt]{article}
\usepackage[utf8]{inputenc}
\usepackage{ucharclasses}
\usepackage{amsmath}
\usepackage{amsfonts}
\usepackage{amssymb}
\usepackage[version=4]{mhchem}
\usepackage{stmaryrd}
\usepackage{graphicx}
\usepackage[export]{adjustbox}
\graphicspath{ {./images/} }
\usepackage{multirow}
\usepackage[fallback]{xeCJK}
\usepackage{polyglossia}
\usepackage{fontspec}
\setCJKmainfont{Noto Serif CJK SC}

\setmainlanguage{english}
\setotherlanguages{hindi}
\newfontfamily\hindifont{Noto Serif Devanagari}
\newfontfamily\lgcfont{CMU Serif}
\setDefaultTransitions{\lgcfont}{}
\setTransitionsFor{Hindi}{\hindifont}{\lgcfont}

\title{使用离子流计算分压力的方法: }

\author{}
\date{}


\begin{document}
\maketitle
实验内容:

\begin{enumerate}
  \item 使用法拉第筒收集缡子流，以分压力为纵坐标测三个谱
  \item 将上达测得的一个谱转化成以离子流为纵坐标, 并用离子流求出残余气