In [1]:
# !python.exe -m pip install --upgrade pip
# !pip install --upgrade jupyter ipywidgets

In [2]:
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [3]:
# !pip install transformers accelerate bitsandbytes flash-attn

## HuggingFace Setting

For Windows users, type the following command in Command Prompt:

```
setx HF_TOKEN "your_token_here"
```

For macOS users, type the following command in Terminal:

```
export HF_TOKEN="your_token_here"
```

In [4]:
import os
HF_TOKEN = os.getenv("HF_TOKEN")
HF_TOKEN[:3]+'...'

'hf_...'

## Load Model

In [5]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

In [6]:
import transformers
import torch

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    token=HF_TOKEN,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, 
    token=HF_TOKEN, 
)

generator = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    pad_token_id=tokenizer.eos_token_id,
    # max_new_tokens=1024,
)

def get_response(chats): 
    gen_text = generator(chats)[0]  # First return sequence
    return gen_text['generated_text'][-1]['content']

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Load Data

In [7]:
from src.util.json_io import *

train_data = load_jsonlines('data/gsm8k/train.jsonl')
test_data = load_jsonlines('data/gsm8k/test.jsonl')

from src.util.gsm8k_helper import *

messages = nshot_chats(nshot_data=train_data, n=8, question=test_data[0]['question'])  # 8-shot prompt
messages

[{'role': 'user',
  'content': 'For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?'},
 {'role': 'assistant',
  'content': 'There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.\nSo, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.\nThere are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.\nSo, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.\nTherefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.\n#### 12'},
 {'role': 'user',
  'content': 'Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they pic

In [8]:
response = get_response(messages)
print(response)

pred_ans = extract_ans_from_response(response)
pred_ans

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Janet's ducks lay 16 eggs per day, so she has 16 eggs initially.

She eats 3 eggs for breakfast, leaving her with 16 - 3 = 13 eggs.

She uses 4 eggs to bake muffins, leaving her with 13 - 4 = 9 eggs.

She sells the remaining 9 eggs at the farmers' market for $2 per egg, making a total of 9 * $2 = $18.

#### 18


'18'

In [9]:
ground_truth = test_data[0]['answer']
print(ground_truth)

true_ans = extract_ans_from_response(ground_truth)
true_ans

Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18


'18'

In [10]:
messages = nshot_chats(nshot_data=train_data, n=8, question=test_data[1]['question'])  # 8-shot prompt
print(messages)

response = get_response(messages)
print(response)

pred_ans = extract_ans_from_response(response)
pred_ans

[{'role': 'user', 'content': 'For every 12 cans you recycle, you receive $0.50, and for every 5 kilograms of newspapers, you receive $1.50. If your family collected 144 cans and 20 kilograms of newspapers, how much money would you receive?'}, {'role': 'assistant', 'content': 'There are 144/12 = <<144/12=12>>12 sets of 12 cans that the family collected.\nSo, the family would receive $0.50 x 12 = $<<0.50*12=6>>6 for the cans.\nThere are 20/5 = <<20/5=4>>4 sets of 5 kilograms of newspapers that the family collected.\nSo, the family would receive $1.50 x 4 = $<<1.50*4=6>>6 for the newspapers.\nTherefore, the family would receive a total of $6 + $6 = $<<6+6=12>>12.\n#### 12'}, {'role': 'user', 'content': 'Betty picked 16 strawberries. Matthew picked 20 more strawberries than Betty and twice as many as Natalie. They used their strawberries to make jam. One jar of jam used 7 strawberries and they sold each jar at $4. How much money were they able to make from the strawberries they picked?'}, 

'The robe takes 2 bolts of blue fiber. For the white fiber it takes half of that which is 2/2 = 1 bolt. In total the robe takes 2 + 1 = 3 bolts of fiber.'

In [11]:
from tqdm import tqdm
import os

if not os.path.exists('log'):
    os.makedirs('log')

log_file_path = 'log/errors.txt'
with open(log_file_path, 'w') as log_file:
    log_file.write('')


total = correct = 0
for qna in tqdm(test_data[:20]):

    messages = nshot_chats(nshot_data=train_data, n=8, question=qna['question'])
    response = get_response(messages)
    
    pred_ans = extract_ans_from_response(response)
    true_ans = extract_ans_from_response(qna['answer'])

    total += 1
    if pred_ans != true_ans:
        with open(log_file_path, 'a', encoding='utf-8') as log_file:
            log_file.write(f"{messages}\n\n")
            log_file.write(f"Response: {response}\n\n")
            log_file.write(f"Ground Truth: {qna['answer']}\n\n")
            log_file.write(f"Current Accuracy: {correct/total:.3f}\n\n")
            log_file.write('\n\n')
    else:
        correct += 1

 40%|████      | 8/20 [00:44<01:18,  6.55s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
100%|██████████| 20/20 [02:26<00:00,  7.33s/it]


In [12]:
print(f"Total Accuracy: {correct/total:.3f}")

Total Accuracy: 0.350
