In [1]:
import os
import torch
import transformers

HF_TOKEN = os.getenv("HF_TOKEN")

model_name = "mistralai/Mistral-7B-Instruct-v0.1"
# model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# model_name = "meta-llama/Llama-3.2-1B"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B"
# model_name = "meta-llama/Llama-3.2-3B-Instruct"
# model_name = "google/gemma-2-2b"
# model_name = "google/gemma-2-2b-it"
# model_name = "google/gemma-2-9b"
# model_name = "google/gemma-2-9b-it"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    token=HF_TOKEN,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    token=HF_TOKEN,
)

print(model, model.config)

VBox(children=(Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s],))

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNo

In [2]:
def generate_text(input_text):
    input_ids = tokenizer.encode(
        input_text, 
        return_tensors='pt', 
    ).to(model.device)

    with torch.no_grad():
        generated_tokens = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + 512,
            do_sample=True, top_k=1, # greedy decoding (deterministic)
            # eos_token_id=tokenizer.encode(text='\n\n', add_special_tokens=False)[0],
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(generated_tokens[0])

    return generated_text

In [3]:
# Ref: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1

question_start = '[INST] Question: '  # "[INST] "
question_end = ' [\INST] '  # "[\INST] \n"

answer_start = 'Answer: ' # '' # "Answer: "
answer_end = '\n' # ''

oneshot_start = '' # ''  # tokenizer.bos_token
oneshot_end = '' # ''  # tokenizer.eos_token # tokenizer.eos_token + "\n\n"

In [4]:
def generate_oneshot_prompt(question, answer):
    return f'{oneshot_start}{question_start}{question}{question_end}{answer_start}{answer}{answer_end}{oneshot_end}'

# Simple multi-turn conversation
q1 = 'Write an email in the style of a pirate.'
a1 = 'To your frend Andrew?'
q2 = 'Yes.'

oneshot_prompt = generate_oneshot_prompt(q1, a1)
prompt = oneshot_prompt + question_start + q2 + question_end + answer_start
prompt

'[INST] Question: Write an email in the style of a pirate. [\\INST] Answer: To your frend Andrew?\n[INST] Question: Yes. [\\INST] Answer: '

In [5]:
generate_text(prompt)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


"<s> [INST] Question: Write an email in the style of a pirate. [\\INST] Answer: To your frend Andrew?\n[INST] Question: Yes. [\\INST] Answer: \n\nArr, me hearty friend Andrew!\n\nHow be ye farin'? I hope this missive finds ye in good health and high spirits. I've been havin' a grand ol' time on the high seas, huntin' for treasure and makin' new friends along the way.\n\nBut enough about me, tell me of your own adventures! Have ye been keepin' busy? I'd love to hear of any new challenges or victories ye've had recently. And if ye've got any tales of daring deeds or swashbucklin' escapades, well, I'm all ears!\n\nIn the meantime, I've got me eye on a new prize - a treasure map that's said to lead to a hidden treasure beyond imagination. I'm thinkin' of puttin' together a crew and settin' sail for the X that marks the spot. Would ye be interested in joinin' me on this grand adventure?\n\nYours truly,\nCaptain Jack</s>"

In [6]:
from src.util.json_io import *

train_qnas = load_jsonlines(f'data/gsm8k/train.jsonl')
# train_qnas = load_jsonlines(f'data/gsm8k/train_socratic.jsonl')
test_qnas = load_jsonlines(f'data/gsm8k/test.jsonl')
# test_qnas = load_jsonlines(f'data/gsm8k/test_socratic.jsonl')
len(train_qnas), len(test_qnas)

(7473, 1319)

In [7]:
import random; rseed = 42; random.seed(rseed)

nshot_prompt = f""
for i in random.sample(range(len(train_qnas)), 8):
    oneshot_prompt = generate_oneshot_prompt(train_qnas[i]['question'], train_qnas[i]['answer'])
    nshot_prompt += oneshot_prompt

print(nshot_prompt)

[INST] Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive? [\INST] Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12
[INST] Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked? [\INST] Answer: Matthew picked 16 + 20 = <<16+20=36>>36 s

In [8]:
def get_prompt(question, nshot_prompt=nshot_prompt):
    cot_prompt = " Let's think step by step."
    return f"{nshot_prompt}{oneshot_start}{question_start}{question}{cot_prompt}{question_end}{answer_start}"

sample_i = 4
print(get_prompt(test_qnas[sample_i]['question']))

[INST] Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive? [\INST] Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12
[INST] Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked? [\INST] Answer: Matthew picked 16 + 20 = <<16+20=36>>36 s

In [9]:
print('=== Input Text ===\n', get_prompt(test_qnas[sample_i]['question']))

=== Input Text ===
 [INST] Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive? [\INST] Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12
[INST] Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked? [\INST] Answer: Matthew picked 16 + 2

In [10]:
def get_answer(generated_text):
    return generated_text.split(question_end+answer_start)[-1].split(tokenizer.eos_token)[0].split('\n\n')[0].split('[')[0]

generated_text = generate_text(get_prompt(test_qnas[sample_i]['question']))
print('\n=== Generated Text===\n', generated_text)
print('\n=== Generated Answer===\n', get_answer(generated_text))


=== Generated Text===
 <s> [INST] Question: For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive? [\INST] Answer: There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.
So, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.
There are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.
So, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.
Therefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.
#### 12
[INST] Question: Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked? [\INST] Answer: Matthew picke

In [11]:
from src.util.gsm8k_helper import *
print('=== Ground Truth ===')
print(test_qnas[sample_i]['answer'])
print('\n=== Ground Truth (Integer) ===')
print(extract_num_from_ans(test_qnas[sample_i]['answer']))
print('\n=== Generated (Integer) ===')
print(extract_num_from_ans(get_answer(generated_text)))

=== Ground Truth ===
If each chicken eats 3 cups of feed per day, then for 20 chickens they would need 3*20=<<3*20=60>>60 cups of feed per day.
If she feeds the flock 15 cups of feed in the morning, and 25 cups in the afternoon, then the final meal would require 60-15-25=<<60-15-25=20>>20 cups of chicken feed.
#### 20

=== Ground Truth (Integer) ===
20

=== Generated (Integer) ===
25


In [None]:
from datetime import datetime

current_time = datetime.now().strftime("%y%m%d_%H%M%S")
errors_log_path = f"log/errors-{model_name.replace('/', '-')}-{current_time}.txt"
with open(errors_log_path, 'w', encoding='utf-8') as log_file:
    log_file.write(f"{model}\n\n")
    log_file.write(f"{model.config}\n\n")
    log_file.write(f"[random seed]: {rseed}\n\n")
    log_file.write(f"[Nshot prompt]:\n{nshot_prompt}\n\n")
    log_file.write("=" * 40 + "\n\n")

correct_count = total_count = 0
for i, qna in enumerate(test_qnas[:100]):
    input_text = get_prompt(qna['question'])

    generated_answer = get_answer(generate_text(input_text))
    generated_answer_int = extract_num_from_ans(generated_answer)
    ground_truth_int = extract_num_from_ans(qna['answer'])

    total_count += 1
    if generated_answer_int == ground_truth_int:
        correct_count += 1
        print('o', end='')
    else:
        print('x', end='')
        with open(errors_log_path, 'a', encoding='utf-8') as log_file:
            log_file.write(f"[Question (i:{i})]\n{qna['question']}\n\n")
            log_file.write(f"[Response {model_name}]\n{generated_answer}\n\n")
            log_file.write(f"[Ground Truth]\n{qna['answer']}\n\n")
            log_file.write(f"[Accuracy]\n{correct_count / total_count * 100:.2f}\n\n\n")
print(f"\nAccuracy: {correct_count / total_count * 100:.2f}%")

print(f"Errors logged to {errors_log_path}")