In [1]:
import json
import os
import random
import re

from dotenv import load_dotenv, find_dotenv
from rich.console import Console
from rich.markup import escape
from rich.table import Table

from llama_r1_zero.llama import Llama

# Load environment variables from a .env file before anything reads them;
# MODEL_PATH is looked up in the environment when the model is built below.
load_dotenv(find_dotenv())

True

In [2]:
# Build the model from the checkpoint directory named by the MODEL_PATH
# environment variable. The batch size of 2 and the 1024-token sequence
# limit are passed straight through to Llama.build.
llama = Llama.build(ckpt_dir=os.getenv("MODEL_PATH"), max_seq_len=1024, max_batch_size=2)

In [3]:
# Load both training sets from JSON Lines files (one JSON object per line).
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale encoding, and blank lines are skipped instead of being handed
# to json.loads (which would raise on an empty string).
# NOTE: `math` here is the list of GSM8K records, not the stdlib `math` module.
with open("data/gsm8k_train.jsonl", "r", encoding="utf-8") as f1, \
        open("data/sciq_train.jsonl", "r", encoding="utf-8") as f2:
    math = [json.loads(line) for line in f1 if line.strip()]
    sci = [json.loads(line) for line in f2 if line.strip()]

# Last expression of the cell: show how many records each set contains.
len(math), len(sci)

(7470, 12679)

In [4]:
# Peek at the first math record to confirm its schema ('question' / 'answer' keys).
math[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': '72'}

In [11]:
# some parts generated using gpt-4o :)
# System prompt sent with every question: it tells the model to put its
# reasoning inside <think>...</think> and the final answer inside
# <answer>...</answer>, and shows one worked example.
# Raw string on purpose: the example uses LaTeX-style \( ... \) delimiters, and
# in a normal string literal "\(" is an invalid escape sequence that newer
# Python versions warn about. The r-prefix keeps the text byte-identical.
SYSTEM_PROMPT = r"""
You are a reasoning model, which given a question thinks before generating the final answer. 
The reasoning to answer the question MUST be within the '<think>' tags and the final, concise answer must be within the '<answer>' tags

Guidelines for Reasoning:
	1.	Break Down the Problem – Analyze the query carefully, identifying key components.
	2.	Logical Deduction – Apply reasoning, knowledge, or calculations to arrive at an informed answer.
	3.	Uncertainty Handling – If applicable, explicitly state assumptions or degrees of confidence.
	4.	Explain Thought Process – Ensure reasoning is transparent and verifiable.

Guidelines for Output:
	1. The thinking part should be within '<think>' and '</think>' tags.
    2. The final, concise answer content should be within '<answer>' and '</answer>' tags.

Following is an example:
Question: 
What is the derivative of  f(x) = x^2 + 3x ?
Response: 
<think>To find the derivative of \( f(x) = x^2 + 3x \), we differentiate each term separately:
- The derivative of \( x^2 \) is \( 2x \).
- The derivative of \( 3x \) is \( 3 \).
Thus, the derivative of \( f(x) \) is \( f'(x) = 2x + 3 \).</think><answer>2x + 3</answer>
"""

In [12]:
def sample_and_print():
    """Query the model on one random math and one random science question and print a table.

    Picks one record from each of the module-level ``math`` and ``sci`` lists,
    sends both questions (each preceded by ``SYSTEM_PROMPT``) to ``llama`` in a
    single batch, extracts the ``<think>`` and ``<answer>`` sections from each
    completion, and prints them next to the reference answer using rich.

    Fallbacks: if no ``<think>`` section is found the cell shows "Not Found";
    if no ``<answer>`` section is found the whole completion is shown instead.

    Returns:
        None. The table is written to the console.
    """
    console = Console()

    # DOTALL on BOTH patterns so either section may span several lines.
    # Without it a multi-line <answer> body never matches and the raw
    # completion is displayed instead of the extracted answer.
    think_pt = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    answer_pt = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)

    samples = [random.choice(math), random.choice(sci)]
    actual_answers = [sample['answer'] for sample in samples]
    questions = [f"Question: {sample['question']}" for sample in samples]
    prompts = [
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ]
        for question in questions
    ]

    completions: list[str] = llama.text_completion(prompts=prompts, temperature=0.1)

    table = Table(show_header=True, header_style="bold cyan")
    table.add_column("Question", style="dim", width=40)
    table.add_column("Thinking", style="yellow", width=40)
    table.add_column("LLM Answer", style="green", width=30)
    table.add_column("Actual Answer", style="bold magenta", width=20)

    for question, completion, actual in zip(questions, completions, actual_answers):
        think_match = think_pt.search(completion)
        think_text = think_match.group(1).strip() if think_match else "Not Found"

        final_answer_match = answer_pt.search(completion)
        final_answer = final_answer_match.group(1).strip() if final_answer_match else completion

        # Dataset and model text is arbitrary: escape it so square brackets are
        # shown literally rather than being parsed as rich console markup.
        table.add_row(
            escape(question),
            escape(think_text),
            escape(final_answer),
            escape(str(actual)),
        )

    console.print(table)

In [13]:
# Run one random math question and one random science question through the
# model and print the comparison table.
sample_and_print()