In [1]:
import logging
import random
import json
import os

from prompts import description_sys_prompt, nl_sys_prompt
from datasets import load_dataset
from dotenv import load_dotenv
from typing import TextIO
from openai import OpenAI
from time import sleep


In [2]:
LOG_FILE = "datagen.log"

logger = logging.getLogger("datagen")
logger.setLevel(logging.INFO)

# Attach a file handler to this logger directly instead of relying on
# logging.basicConfig, which silently does nothing when the root logger already
# has handlers (this can happen inside a Jupyter kernel). The guard keeps
# re-runs of this cell from attaching a second handler and duplicating lines.
_log_path = os.path.abspath(LOG_FILE)
if not any(
    isinstance(handler, logging.FileHandler) and handler.baseFilename == _log_path
    for handler in logger.handlers
):
    _file_handler = logging.FileHandler(LOG_FILE, encoding="utf-8")
    # Same "LEVEL:name:message" layout basicConfig uses by default.
    _file_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
    logger.addHandler(_file_handler)

In [3]:
load_dotenv()

TOGETHER_BASE_URL = "https://api.together.xyz/v1"

# Read the Together key explicitly and fail fast if it is missing: passing
# api_key=None makes the OpenAI SDK fall back to OPENAI_API_KEY, which would
# then be sent to the Together endpoint configured below.
together_api_key = os.getenv("TOGETHER_API_KEY")
if not together_api_key:
    raise RuntimeError(
        "TOGETHER_API_KEY is not set; export it or add it to a .env file."
    )

client = OpenAI(api_key=together_api_key, base_url=TOGETHER_BASE_URL)

In [4]:
# data = load_dataset("OleehyO/latex-formulas", data_files="cleaned_formulas/train-00000-of-00006.parquet") 

In [5]:
def generate_instruction(
    latex_eqn: str,
    sys_prompt: str,
    model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
) -> str:
    """Generate description given a LaTeX equation and a system prompt describing the style of 'description' required.

    Args:
        latex_eqn: LaTeX source of the equation, sent verbatim as the user message.
        sys_prompt: System prompt that sets the style of the generated text.
        model: Chat-completions model identifier to query.

    Returns:
        The text content of the first completion choice.

    Raises:
        ValueError: If the first choice carries no text content. Failing loudly
            avoids writing rows with a null instruction into the dataset.
    """
    logger.info(" [LLM] Sending request to LLM")
    res = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": latex_eqn},
        ]
    )
    logger.info(" [LLM] Received response from LLM")
    content = res.choices[0].message.content
    if content is None:
        raise ValueError("LLM response contained no text content")
    return content

In [6]:
def save_data(qn: str, ans: str, ans_type: str, data_file: TextIO):
    """Append one instruction/answer record to an open JSON Lines stream.

    The record is a single JSON object with keys ``qn``, ``ans`` and ``type``,
    followed by a newline. Non-ASCII characters are written as-is, not escaped.

    Args:
        qn: Instruction text (create_instruction_data passes the LLM output).
        ans: Answer text (create_instruction_data passes the LaTeX line).
        ans_type: Label of the instruction style, stored under ``"type"``.
        data_file: Writable text stream that receives the record.
    """
    logger.info(" [DATA_PERSIST] Persisting data of type %s", ans_type)
    record = {"qn": qn, "ans": ans, "type": ans_type}
    data_file.write(json.dumps(record, ensure_ascii=False) + "\n")
    logger.info(" [DATA_PERSIST] Successfully persisted data of type %s", ans_type)

In [7]:
def create_instruction_data(
    eqn_file: str,
    data_file: str,
    max_lines: "int | None" = 20,
    delay: float = 1.0,
):
    """Create an instruction dataset from a file of LaTeX equations and append it to an output path.

    Each non-blank line of ``eqn_file`` is treated as one equation. For every
    equation, two records are appended to ``data_file`` (JSON Lines): one of
    type ``"desc"`` and one of type ``"nl"``, each produced by
    ``generate_instruction`` with the matching system prompt.

    Args:
        eqn_file: Path to a UTF-8 text file with one LaTeX equation per line.
        data_file: Path of the JSON Lines output; opened in append mode.
        max_lines: Maximum number of equations to process; ``None`` processes
            the whole file.
        delay: Seconds to wait between consecutive equations, to pace the
            LLM requests.
    """
    # (answer type, system prompt) pairs generated for every equation, in order.
    styles = (("desc", description_sys_prompt), ("nl", nl_sys_prompt))
    processed = 0

    # `with` closes (and flushes) both files even if an LLM call raises.
    # UTF-8 is explicit because save_data writes raw non-ASCII characters,
    # which a non-UTF-8 locale encoding could fail to encode.
    with open(eqn_file, "r", encoding="utf-8") as eqn_fh, \
            open(data_file, "a", encoding="utf-8") as out_fh:
        for lineno, raw_line in enumerate(eqn_fh, start=1):
            if max_lines is not None and processed >= max_lines:
                break

            # Drop the line terminator so it is neither sent to the LLM nor
            # stored in the "ans" field; skip blank lines entirely.
            eqn = raw_line.strip()
            if not eqn:
                continue

            if processed:
                sleep(delay)

            for ans_type, sys_prompt in styles:
                logger.info(f" [DATA_GEN] Starting generation for line {lineno} type '{ans_type}'")
                instruction = generate_instruction(eqn, sys_prompt)
                save_data(instruction, eqn, ans_type, out_fh)
                logger.info(f" [DATA_GEN] Finished generation for line {lineno} type '{ans_type}'")

            processed += 1

In [8]:
# Input: one LaTeX equation per line. Output: JSON Lines, opened in append mode.
EQN_PATH = "sample.txt"
DATA_PATH = "data.jsonl"

create_instruction_data(eqn_file=EQN_PATH, data_file=DATA_PATH)