In [1]:
import os
import re

from llama_cpp import Llama

# Folder holding the local GGUF checkpoints. Set the LLM_MODEL_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
MODEL_DIR = os.environ.get("LLM_MODEL_DIR", r"E:\cs222")

# Candidate checkpoints; only the DeepSeek-Math base model is loaded below.
model_Janus = os.path.join(MODEL_DIR, "Janus-Pro-7B-LM.Q6_K.gguf")
model_deepseek_r1 = os.path.join(MODEL_DIR, "DeepSeek-R1-q5_k_m.gguf")
model_deepseek_math = os.path.join(MODEL_DIR, "deepseek-math-7b-base-q8_0.gguf")

llm = Llama(
    model_path=model_deepseek_math,
    n_ctx=4096,
    n_threads=16,
    n_gpu_layers=30,
    verbose=False,
)


In [2]:
# Example of the fully rendered instruction prompt for one input
# (same layout that build_prompt_latex produces).
prompt_test = (
    "### Instruction:\n"
    "Convert the following math expression into LaTeX format only.\n"
    "Do not add any explanation or formatting.\n"
    "### Input:\n"
    "x squared plus y squared equals pi\n"
    "\n"
    "### Output:\n"
)

In [3]:
def build_prompt_latex(user_input: str) -> str:
    """Wrap a natural-language math expression in the LaTeX-conversion prompt.

    Args:
        user_input: The expression to convert, e.g. "a over b".

    Returns:
        An instruction / input / output prompt whose "### Output:" section is
        left open (ending in a newline) for the model to complete.
    """
    header = (
        "### Instruction:\n"
        "Convert the following math expression into LaTeX format only.\n"
        "Do not add any explanation or formatting.\n"
        "### Input:\n"
    )
    footer = "\n\n### Output:\n"
    return f"{header}{user_input}{footer}"

In [4]:
import re

def clean_latex_output(raw_output: str) -> str:
    """Strip wrapper noise from a model completion, leaving bare LaTeX.

    Removes, in order: a leading code-fence language tag line (e.g. ```latex),
    any run of two or more quote/backtick characters (code fences, triple
    quotes), trailing literal backslash-n escape sequences, and leading /
    trailing `$` or `$$` math delimiters. An escaped dollar at the very end
    (backslash-dollar) is kept, since it is part of the LaTeX itself.

    Args:
        raw_output: Text generated by the model.

    Returns:
        The cleaned LaTeX string with surrounding whitespace removed.
    """
    text = raw_output.strip()
    # Fences must be removed first: otherwise the $-delimiters they wrap are
    # not at the string boundaries and survive the delimiter strip below.
    text = re.sub(r'^["`]{3,}[A-Za-z]*[ \t]*\n', '', text)
    text = re.sub(r'["`]{2,}', '', text).strip()
    # Literal backslash-n sequences (not real newlines) can trail the output;
    # (?:...)+ removes all of them, not just the last one.
    text = re.sub(r'(?:\\n)+$', '', text).strip()
    # Leading/trailing $ or $$; the lookbehind keeps an escaped \$ at the end.
    text = re.sub(r'^\${1,2}\s*|(?<!\\)\s*\${1,2}$', '', text)
    return text.strip()

In [5]:
def get_latex(user_input: str, model=None, max_tokens: int = 1200) -> str:
    """Ask the language model to render a natural-language math expression as LaTeX.

    Args:
        user_input: Plain-English description of the expression, e.g.
            "integral of y from 0 to 100".
        model: Callable llama_cpp ``Llama`` instance to query. Defaults to the
            notebook-level ``llm`` loaded in the first cell.
        max_tokens: Upper bound on generated tokens (large matrices need room).

    Returns:
        The model's LaTeX, post-processed by ``clean_latex_output`` to drop
        code fences and ``$`` delimiters.
    """
    model = llm if model is None else model
    prompt = build_prompt_latex(user_input)
    response = model(
        prompt,
        max_tokens=max_tokens,
        temperature=0,       # 0 selects greedy decoding -> deterministic output
        repeat_penalty=1.0,  # 1.0 = no penalty, so repeated LaTeX tokens (e.g. matrix '&') are not discouraged
        top_k=0,             # 0 disables top-k filtering; has no effect under greedy decoding
        top_p=1.0,           # 1.0 disables nucleus filtering; likewise no effect when greedy
        # Stop before the model starts a new prompt section or closes a triple quote.
        stop=["###", '"""'],
    )
    return clean_latex_output(response["choices"][0]["text"])



In [6]:
# Quick sanity check: a definite integral with explicit bounds.
get_latex(user_input="integral of y from 0 to 100")

'\\int_{0}^{100} y dy'

In [7]:
# Smoke-test prompts for get_latex, ranging from simple notation
# (fractions, subscripts) up to a fully specified 6x6 matrix.
testing_input_complex_matrix = "A 6*6 matrix A with elements a_ij"
testing_input_simple_matrix = "A 2*2 matrix A with elements a_ij"
testing_input_simple_derivative = "the derivative of x squared"
testing_input_simple_integral = "the integral of x squared from 0 to 1"
testing_input_simple_sum = "sum from i equals 1 to n of i squared"
testing_input_simple_fraction = "a over b"
testing_input_simple_subscript = "x sub i"
testing_input_matrix_crazy = (
    "A 6*6 matrix A with elements a_ij, where i and j are integers from 1 to 6, "
    "and the elements are defined as follows: "
    "a_ij = i^2 + j^2 for all i,j in {1,2,3,4,5,6}"
)

tests = [
    testing_input_complex_matrix,
    testing_input_simple_matrix,
    testing_input_simple_derivative,
    testing_input_simple_integral,
    testing_input_simple_sum,
    testing_input_simple_fraction,
    testing_input_simple_subscript,
    testing_input_matrix_crazy,
]

# Print each prompt before querying the model, so a failing input is visible.
for test in tests:
    print("Input:", test)
    print("Output:", get_latex(test))
    print("-" * 50)

Input: A 6*6 matrix A with elements a_ij
Output: \begin{bmatrix}
a_{11} & a_{12} & a_{13} & a_{14} & a_{15} & a_{16} \\
a_{21} & a_{22} & a_{23} & a_{24} & a_{25} & a_{26} \\
a_{31} & a_{32} & a_{33} & a_{34} & a_{35} & a_{36} \\
a_{41} & a_{42} & a_{43} & a_{44} & a_{45} & a_{46} \\
a_{51} & a_{52} & a_{53} & a_{54} & a_{55} & a_{56} \\
a_{61} & a_{62} & a_{63} & a_{64} & a_{65} & a_{66} \\
\end{bmatrix}
--------------------------------------------------
Input: A 2*2 matrix A with elements a_ij
Output: \begin{bmatrix}
a_{11} & a_{12} \\
a_{21} & a_{22}
\end{bmatrix}
--------------------------------------------------
Input: the derivative of x squared
Output: \frac{d}{dx}x^2
--------------------------------------------------
Input: the integral of x squared from 0 to 1
Output: \int_{0}^{1} x^2 dx
--------------------------------------------------
Input: sum from i equals 1 to n of i squared
Output: \sum_{i=1}^{n}i^2
--------------------------------------------------
Input: a over b
Out