In [None]:
import os
import torch
import transformers

HF_TOKEN = os.getenv("HF_TOKEN")

# model_name = "meta-llama/Llama-3.2-1B"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B"
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "google/gemma-2-2b"
# model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-9b"
# model_name = "google/gemma-2-9b-it"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    token=HF_TOKEN,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    token=HF_TOKEN,
)

print(model, model.config)

VBox(children=(Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s],))

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e

In [2]:
from src.util.json_io import *

train_qnas = load_jsonlines(f'data/gsm8k/train.jsonl')
test_qnas = load_jsonlines(f'data/gsm8k/test.jsonl')
len(train_qnas), len(test_qnas)

(7473, 1319)

In [None]:
import random
random.seed(42)

nshot_prompt = f""
for i in random.sample(range(len(train_qnas)), 8):
    nshot_prompt += f"Question: {train_qnas[i]['question']}\nAnswer: {train_qnas[i]['answer']}\n\n"

print(nshot_prompt)

Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberries.
Natalie picked 3

In [4]:
sample_question = 'An airport has only 2 planes that fly multiple times a day. Each day, the first plane goes to Greece for three-quarters of its flights, and the remaining flights are split equally between flights to France and flights to Germany. The other plane flies exclusively to Poland, and its 44 trips only amount to half the number of trips the first plane makes throughout each day. How many flights to France does the first plane take in one day?'
sample_answer = '''The second plane flies half as much as the first, so the first plane makes 44 flights * 2 = <<44*2=88>>88 flights a day.
If 3/4 of the first plane’s flights are to Greece, then flights to France or Germany make up 1 – 3/4 = 1/4 of the total daily flights.
Therefore, 88 daily flights / 4 = <<88/4=22>>22 flights to France or Germany.
Splitting these flights equally means the first plane makes 22 flights / 2 = <<22/2=11>>11 flights to France in one day.
#### 11'''

In [5]:
from src.util.gsm8k_helper import *
extract_num_from_ans(sample_answer)

11

In [6]:
def generate_answer(input_text):

    input_ids = tokenizer.encode(
        input_text, 
        return_tensors='pt', 
    ).to(model.device)

    with torch.no_grad():
        generated_tokens = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + 512,
            # top_k=1, # greedy decoding (deterministic)
            eos_token_id=tokenizer.encode(text='\n\n', add_special_tokens=False)[0],
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(generated_tokens[0])

    return generated_text.split('Answer: ')[-1].split(tokenizer.special_tokens_map['eos_token'])[0].split('\n\n')[0]

input_text = nshot_prompt + sample_question + " Let's think step by step.\nAnswer: "
print('=== Input Text ===\n', input_text)
print('\n=== Ground Truth ===\n', sample_answer)
print('\n=== Generated ===\n', generate_answer(input_text))

=== Input Text ===
 Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?
Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12

Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?
Answer: Matthew picked 16 + 20 = <<16+20=36>>36 strawberri

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)



=== Generated ===
 44 is half of the number of flights the first plane makes. 44 * 2 = 88
Since the first plane goes to Greece for three-quarters of its flights, and the remaining flights are split equally between flights to France and flights to Germany, 88/4 = 22 flights to Greece
The first plane flies 88 flights in total, so 88 - 22 = 66 flights to France
#### 66


In [7]:
from datetime import datetime

current_time = datetime.now().strftime("%y%m%d_%H%M%S")
errors_log_path = f"log/errors-{model_name.replace('/', '-')}-{current_time}.txt"
with open(errors_log_path, 'w', encoding='utf-8') as log_file:
    log_file.write(f"{model}\n\n")
    log_file.write(f"{model.config}\n\n")
    log_file.write(f"nshot_prompt:\n{nshot_prompt}\n")
    log_file.write("=" * 40 + "\n\n")

correct_count = total_count = 0
for i, qna in enumerate(test_qnas[:10]):
    input_text = nshot_prompt + qna['question'] + " Let's think step by step." +"\nAnswer: "

    generated_answer = generate_answer(input_text)
    generated_answer_int = extract_num_from_ans(generated_answer)
    ground_truth_int = extract_num_from_ans(qna['answer'])

    total_count += 1
    if generated_answer_int == ground_truth_int:
        correct_count += 1
        print('o', end='')
    else:
        print('x', end='')
        with open(errors_log_path, 'a', encoding='utf-8') as log_file:
            log_file.write(f"[Question {i}]\n {qna['question']}\n\n")
            log_file.write(f"[Response {model_name}]\n {generated_answer}\n\n")
            log_file.write(f"[Ground Truth]\n {qna['answer']}\n\n")
            log_file.write(f"[Accuracy]\n {correct_count / total_count * 100:.2f}\n\n\n")
print(f"\nAccuracy: {correct_count / total_count * 100:.2f}%")

print(f"Errors logged to {errors_log_path}")

ooxoxxoxxoooxoxoxoxoxxoooxoxxx
Accuracy: 50.00%
Errors logged to log/errors-meta-llama-Llama-3.2-3B-Instruct-241104_201132.txt
