## Gemma 2 2B Instruct Performance on GSM8k (Post Fine Tuning)

In [1]:
# Install or upgrade (-U) the third-party libraries for this notebook. Each line
# shells out to pip, so this only needs to run once per fresh runtime.
# NOTE(review): versions are not pinned, so a later re-run may pull different
# releases than the ones that produced the recorded results — consider pinning.
# NOTE(review): sentence_transformers and trl are not referenced in the cells
# visible here — confirm they are still required.
!pip install -U torch
!pip install -U datasets
!pip install -U sentence_transformers
!pip install -U accelerate
!pip install -U transformers
!pip install -U bitsandbytes
!pip install -U peft
!pip install -U trl

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [13]:
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, PeftModel

In [2]:
from huggingface_hub import notebook_login

# Show the interactive Hugging Face login widget so later Hub downloads are
# authenticated. Presumably needed because the base model is gated — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [14]:
# Quantization settings used when loading the base model: 4-bit NF4 weights,
# double quantization enabled, and float16 as the compute dtype.
compute_dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

In [15]:
# Checkpoint directory that the tokenizer and the PEFT adapter are loaded from
# in the following cells.
# NOTE(review): absolute Google Drive path — presumably requires Drive to be
# mounted at /content/drive before this notebook runs; confirm.
output_dir = "/content/drive/MyDrive/lora_gold/checkpoint-750"

In [16]:
from peft import PeftModel

# Load the tokenizer saved alongside the fine-tuned checkpoint.
ft_tokenizer = AutoTokenizer.from_pretrained(output_dir)
# Pad on the right and reuse the EOS token as the padding token.
# NOTE(review): presumably mirrors the tokenizer setup used during fine-tuning — confirm.
ft_tokenizer.padding_side = "right"
ft_tokenizer.pad_token = ft_tokenizer.eos_token

In [17]:
# Load the base instruct model with the quantization settings from bnb_config;
# device_map="auto" leaves device placement to the library.
base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", quantization_config=bnb_config, device_map="auto")
# Wrap the base model with the PEFT adapter stored in output_dir; ft_model is
# the model evaluated below.
ft_model = PeftModel.from_pretrained(base_model, output_dir)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
# Load the "main" configuration of GSM8K, caching files under /tmp, and keep
# handles to its train and test splits.
gsm8k = load_dataset("gsm8k", "main", cache_dir='/tmp')
gsm8k_train = gsm8k['train']
gsm8k_test = gsm8k['test']

In [22]:


def find_numbers(x: str) -> list[str]:
  numbers = re.compile(
      r'-?[\d,]*\.?\d+',
      re.MULTILINE | re.DOTALL | re.IGNORECASE,
  ).findall(x)
  return numbers

def find_number(x: str, answer_delimiter: str = '**Answer:**') -> str:
  """Pick the single number that most likely represents the answer in ``x``.

  If ``answer_delimiter`` occurs in ``x``, the first number after its last
  occurrence is returned. When the delimiter is absent, or no number follows
  it, the last number found anywhere in ``x`` is returned instead.

  Args:
    x: Text containing the answer.
    answer_delimiter: Marker that precedes the final answer.

  Returns:
    The chosen number as the raw matched substring (commas are not removed),
    or ``''`` when ``x`` contains no number at all.
  """
  if answer_delimiter in x:
    # Only the text after the last delimiter is considered here.
    tail = x.rpartition(answer_delimiter)[2]
    after_delimiter = find_numbers(tail)
    if after_delimiter:
      return after_delimiter[0]

  anywhere = find_numbers(x)
  return anywhere[-1] if anywhere else ''

def find_response_number(x: str) -> str:
  """Return the first number found in ``x``, or ``''`` if there is none.

  Args:
    x: Text to scan (the evaluation loop passes the last line of the model
      response).

  Returns:
    The first matched number as a raw substring, or the empty string.
  """
  candidates = find_numbers(x)
  return candidates[0] if candidates else ''


def maybe_remove_comma(x: str) -> str:
  """Return ``x`` with every comma removed (e.g. ``'1,234'`` -> ``'1234'``)."""
  return ''.join(x.split(','))

In [23]:
%%time

all_correct = 0
all_responses = {}
short_responses = {}
idx = 0
correct = 0


for task_id, problem in enumerate(gsm8k_test):

  if task_id == 50: break

  prompt = "Question: " + problem['question'] + "\nAnswer:"
  print(f"task_id {task_id}")

  input_ids = ft_tokenizer(prompt, return_tensors='pt').to("cuda")
  outputs = ft_model.generate(input_ids=input_ids.input_ids, max_length=1024, num_return_sequences=1, pad_token_id=ft_tokenizer.eos_token_id)
  response_text = ft_tokenizer.decode(outputs[0], skip_special_tokens=True)

  print(f"Model Response: {response_text}")

  all_responses[task_id] = response_text.strip().split("\n")[-1]
  short_responses[task_id] = maybe_remove_comma(find_response_number(all_responses[task_id]))
  print(f"Short answer: {short_responses[task_id]}")

  try:
    correct += float(maybe_remove_comma(find_number(problem['answer']))) == float(short_responses[task_id])
  except:
    correct += maybe_remove_comma(find_number(problem['answer'])) == maybe_remove_comma(find_response_number(short_responses[task_id]))

  print('-'*40)
  print(f"Ground truth answer {problem['answer']}")
  print(f"Short ground truth answer {find_number(problem['answer'])}")
  print(f"Correct: {correct} out of {idx+1}")
  print("="*40)
  idx += 1

task_id 0
Model Response: Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: Janet eats 3 eggs for breakfast every day, so she eats 3 eggs per day.

She bakes muffins for her friends every day with 4 eggs, so she bakes 4 eggs per day.

The total number of eggs she eats and bakes is 3 + 4 = 7 eggs per day.

She sells the remainder of the eggs, which is 16 - 7 = 9 eggs per day.

She sells these eggs for $2 per egg, so she makes 9 * $2 = $18 per day at the farmers' market.

Therefore, Janet makes $18 every day at the farmers' market.
Short answer: 18
----------------------------------------
Ground truth answer Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18
Short g

In [25]:
# Report accuracy as a percentage: `correct` matches out of the `idx` problems
# evaluated by the loop above (idx is incremented once per evaluated problem).
print(f"Model Accuracy (Post Fine Tuning) : {correct/(idx) * 100}%")

Model Accuracy (Post Fine Tuning) : 42.0%
