# RFT - GSM8K
- Candidate: Eric Liu 

# Loading

In [1]:
%load_ext autoreload
%autoreload 2

import os  
import torch 
import numpy as np 

from tqdm import tqdm 
from textwrap import dedent  

import utils 
import prompt 
from utils import GSM8KParser, GMS8KEvaluator
from datasets import load_dataset
from main import GSM8KDataset, Phi3LightningModule 

from sympy.parsing.sympy_parser import parse_expr 
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from torch.utils.data import DataLoader

import pytorch_lightning as pl 
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import RichProgressBar 

import wandb 

# Set the TOKENIZERS_PARALLELISM environment variable for this process.
# NOTE(review): a DataLoader with num_workers=16 is created further down;
# presumably forked workers combined with tokenizer parallelism can trigger
# fork warnings -- confirm whether "false" is the safer setting here.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Model identifier shared by the tokenizer and the Lightning module below.
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"

In [2]:
# Load the tokenizer that belongs to MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Show the special-token map; in the captured output both the pad and eos
# tokens are "<|endoftext|>".
tokenizer.special_tokens_map_extended

{'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 'pad_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True)}

# Dataset Exploration 

## Inspect text

In [4]:
# Load the GSM8K "main" configuration once and reuse the returned object for
# both splits instead of calling load_dataset() a second time.
gsm8k = load_dataset('gsm8k', 'main')
train_dataset = gsm8k['train']
# The 'test' split serves as the validation set in this notebook.
val_dataset = gsm8k['test']
print(f"Num Training instances: {len(train_dataset)}")
print(f"Num Validation instances: {len(val_dataset)}")
print(type(train_dataset))

Num Training instances: 7473
Num Validation instances: 1319
<class 'datasets.arrow_dataset.Dataset'>


In [5]:
train_dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [6]:
# Print five randomly chosen training examples for a quick manual look.
# The drawn value is a row index into train_dataset (not an RNG seed).
for _ in range(5):
    row_idx = np.random.randint(0, len(train_dataset))
    print("*" * 100, f"Checking instance {row_idx}:", sep="\n")
    utils.inspect_instance(train_dataset, row_idx)

****************************************************************************************************
Checking instance 3917:
question
When Jack traveled to Canada, he had to wait 20 hours to get through customs, plus 14 days in coronavirus quarantine. How many hours total did Jack have to wait?
answer
First convert the quarantine wait time to hours by multiplying the number of days by the number of hours in a day: 14 days * 24 hours/day = <<14*24=336>>336 hours
Then add the time Jack spent waiting in customs to the quarantine time: 336 hours + 20 hours = <<336+20=356>>356 hours
#### 356
**************************************************
****************************************************************************************************
Checking instance 3259:
question
Michael has $42. Michael gives away half the money to his brother. His brother then buys 3 dollars worth of candy. If his brother has $35 left, how much money, in dollars, did his brother have at first?
answer
Michael giv

## Extract statistics 

### Calculate Num Tokens

For now we only inspect the training set, to collect statistics that will be
used during inference:

- Maximum length (num_tokens) of question: 239
- Maximum length (num_tokens) of answer: 475 

In [7]:
# Annotate every training row with token counts for its question and answer.
# The two parser helpers are applied back to back; afterwards the dataset
# exposes the 'question_length' and 'answer_length' columns read below.
train_dataset = (
    train_dataset
    .map(lambda row: GSM8KParser.get_question_length(row['question'], tokenizer))
    .map(lambda row: GSM8KParser.get_answer_length(row['answer'], tokenizer))
)

longest_answer = max(train_dataset['answer_length'])
print(f"Maximum answer num_tokens: {longest_answer}")
longest_question = max(train_dataset['question_length'])
print(f"Maximum question num_tokens: {longest_question}")

Maximum answer num_tokens: 475
Maximum question num_tokens: 239


### Extract answers

In [8]:
# Enrich each training row from its ground-truth answer text:
#   1. get_num_hops       -- infers the number of reasoning hops
#   2. get_answer_from_gt -- parses the final answer with the ground-truth parser
#      (presumably the source of the 'answer_str_digit' column used later)
train_dataset = (
    train_dataset
    .map(lambda row: GSM8KParser.get_num_hops(row['answer']))
    .map(lambda row: GSM8KParser.get_answer_from_gt(row['answer']))
)

In [9]:
# Optional cell: verifies that parsing the final answer with the prediction
# (completion) parser yields the same result as the ground-truth parser.

# Run the prediction parser over every ground-truth answer text.
answer_str_inf = []
for answer_text in train_dataset['answer']:
    parsed = GSM8KParser.get_answer_from_pred(answer_text)
    answer_str_inf.append(parsed['answer_str_digit'])

assert answer_str_inf == train_dataset['answer_str_digit']

### Instance Generation 
We select one of the longest training instances (ranked by answer token count) to try out generation.

In [None]:
# Upper bound on the number of tokens generated per sampled completion.
# The original cell referenced an undefined (and misspelled) name, which raises
# NameError on a fresh kernel. 512 leaves headroom above the longest training
# answer (475 tokens, printed earlier in this notebook); tune as needed.
MAX_NEW_TOKENS_SAMPLE = 512

# Keyword arguments for sampling-based generation on a single instance.
generation_config = {
    "max_new_tokens": MAX_NEW_TOKENS_SAMPLE,
    "temperature": 0.7,
    "num_return_sequences": 1,
    "top_p": 0.9,
    "eos_token_id": tokenizer.eos_token_id,  # Specify the EOS token
    # The tokenizer's pad token is the same "<|endoftext|>" token as EOS.
    "pad_token_id": tokenizer.eos_token_id,
    "do_sample": True,
    "output_scores": True,
    "return_dict_in_generate": True,
}

In [None]:
# Pick one long training example for a generation smoke test: rows are ordered
# by answer token count, longest first, and index 50 (the 51st-longest) is used.
instance = sorted(
    train_dataset,
    key=lambda row: row['answer_length'],
    reverse=True,
)[50]

# Two-message conversation: the evaluation system prompt followed by the
# question substituted into the evaluation user template.
chat = [
    {
        "role": "system",
        "content": prompt.EvalTemplate.system,
    },
    {
        "role": "user",
        "content": prompt.EvalTemplate.user.format(
            question=instance['question'],
            eos_token=tokenizer.eos_token,
        ),
    },
]

# Render the conversation to prompt text. A batch of one chat is passed, so the
# result is indexed with [0] below. `return_tensors` was dropped because it has
# no effect when tokenize=False.
prompts = tokenizer.apply_chat_template(
    [chat],
    add_generation_prompt=True,
    tokenize=False,
)

print(len(prompts))
print(prompts[0])

In [None]:
# Show the reference solution text next to its parsed final answer.
print(instance["answer"], instance["answer_str_digit"])

In [None]:
# NOTE(review): `model` is not defined in any cell visible in this notebook;
# on a fresh kernel this cell raises NameError -- confirm where the model is
# meant to be loaded.
model.eval()
# Sample completions for the rendered prompt(s) using generation_config.
# NOTE(review): a later cell calls utils.sample_answers with different keyword
# names (chats=, num_samples=); verify which call matches the helper's current
# signature.
outs = utils.sample_answers(
    tokenizer,
    model,
    prompts,
    **generation_config,
)

In [None]:
# Show the first sampled completion.
print(outs[0])

In [None]:
# Parse the final answer out of every sampled completion.
preds = []
for completion in outs:
    parsed = GSM8KParser.get_answer_from_pred(completion)
    preds.append(parsed["answer_str_digit"])
print(preds)

evaluator = GMS8KEvaluator()

# Single reference answer for the selected instance.
refs = [instance["answer_str_digit"]]
print(refs)

# Score each (prediction, reference) pair with the evaluator's maj@k metric.
maj_accs = []
for pred, ref in zip(preds, refs):
    maj_accs.append(evaluator.get_maj_at_k(pred, ref))

print(maj_accs)

In [None]:
# NOTE(review): `probs` is not assigned in any cell visible in this notebook;
# this looks like a leftover debugging cell -- confirm or remove.
probs.shape 

In [None]:
# NOTE(review): `out` is not assigned at cell scope anywhere visible (the
# earlier list-comprehension variable does not leak); presumably `outs` was
# intended -- confirm.
print(out[1])

In [None]:
# Parse and print the final answer from one sampled completion.
# NOTE(review): `out` is not assigned at cell scope anywhere visible;
# presumably `outs` was intended -- confirm.
print(
    GSM8KParser.get_answer_from_pred(out[1])["answer_str_digit"]
)

# Print the reference solution for comparison.
print(instance["answer"])


In [None]:
# Creates a fresh evaluator; the same assignment already appears in an earlier
# cell, so this rebinding is redundant when the notebook is run top to bottom.
evaluator = GMS8KEvaluator()

# Base Model Eval 
***

**Before we start, let's get a good sense of the base model's performance.**

In [10]:
# Wrap the validation split in the project's dataset class.
valData = GSM8KDataset(val_dataset, tokenizer)

# Batched, in-order loader over the validation data.
loader_options = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 16,
}
val_dataloader = DataLoader(valData, **loader_options)

# Sampling settings for evaluation; the generation budget is taken from the
# dataset object's inf_seq_length attribute.
generation_config = {
    "max_new_tokens": valData.inf_seq_length,
    "temperature": 0.7,
    "num_return_sequences": 1,
    "top_p": 0.9,
    "eos_token_id": tokenizer.eos_token_id,  # Specify the EOS token
    "pad_token_id": tokenizer.eos_token_id,
    "do_sample": True,
    "output_scores": True,
    "return_dict_in_generate": True,
}
print(f"Maximum num_tokens for inference: {valData.inf_seq_length}")

Maximum answer num_tokens: 430
Maximum question num_tokens: 289
Maximum sequence num_tokens: 719
Maximum new tokens in generation: 1024
Setup Completed dataset:
Dataset({
    features: ['question', 'answer', 'answer_str_digit', 'question_length', 'answer_length', 'question_input_ids', 'question_attention_mask', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1319
})
Maximum num_tokens for inference: 1024


In [12]:
# Build the Lightning module for MODEL_NAME, passing in the generation settings
# defined in the previous cell.
module = Phi3LightningModule(
    MODEL_NAME, 
    generation_config=generation_config
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
# Weights & Biases logger for this project; log_model="all" is forwarded to
# WandbLogger.
wandb_logger = WandbLogger(
    project="phi3-gsm8k-training", 
    log_model="all"
)

# NOTE(review): `pbar` is created but never passed to the Trainer (the
# callbacks argument below is commented out). RichProgressBar is imported from
# `lightning.pytorch` while the Trainer comes from `pytorch_lightning`;
# presumably mixing the two packages is why the callback was disabled -- confirm.
pbar = RichProgressBar()
# Trainer used below only for trainer.test(); devices=-1 with accelerator="auto".
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="auto",
    devices=-1,
    logger=wandb_logger,
    #strategy='DDP',
    #callbacks=[RichProgressBar()]
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
# Run the module's test loop over the validation dataloader to measure the
# base model. The captured output below ends in a KeyboardInterrupt, so no
# metrics were recorded in this saved run.
trainer.test(
    module,  
    dataloaders=val_dataloader, 
)

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. Calling `get_max_cache()` will raise error from v4.48
The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library.

KeyboardInterrupt



# Rejection Sampling 

In [None]:
# Select the instance to sample from.
# NOTE(review): `sorted_data` is not defined in any cell visible in this
# notebook; on a fresh kernel this raises NameError -- confirm how it is
# meant to be built (and by which sort key).
idx = 0 
instance = sorted_data[idx]

# Two-message conversation built from prompt.Template (an earlier cell uses
# prompt.EvalTemplate instead); only `question` is substituted into the user
# template here.
chat = [
    {
        "role": "system",
        "content": prompt.Template.system
    },
    {
        "role": "user",
        "content": prompt.Template.user.format(question=instance['question'])
    }
]


In [None]:
# Render the conversation to prompt text. A batch of one chat is passed, so
# `chats` is indexed with [0] below; more conversations can be batched by
# adding them to the list. `return_tensors` was dropped because it has no
# effect when tokenize=False, and the commented-out multi-instance experiment
# that used to sit at the top of this cell was removed as dead code.
chats = tokenizer.apply_chat_template(
    [chat],
    add_generation_prompt=True,
    tokenize=False,
)
print(type(chat), len(chats))
print(chats[0])

In [None]:
#tokenizer.batch_encode_plus(chats, return_tensors='pt', padding='longest')["input_ids"]

In [None]:
# Diagnostic cell: display whether transformers reports FlashAttention-2 as
# available in this environment.
# NOTE(review): this import sits mid-notebook rather than in the import cell
# at the top.
from transformers.utils import is_flash_attn_2_available 
is_flash_attn_2_available()

In [None]:
# NOTE(review): `model` is not defined in any cell visible in this notebook;
# confirm where it is meant to be loaded.
model.eval()
# Draw several sampled completions for the rendered chat prompt(s).
# NOTE(review): the keyword names used here (chats=, num_samples=) differ from
# the earlier positional call to utils.sample_answers -- verify against the
# helper's current signature.
samples = utils.sample_answers(
    tokenizer=tokenizer, 
    model=model, 
    chats = chats,
    max_new_tokens=256, 
    temperature=0.5,
    num_samples=10,
    top_p= 0.85,
)

In [None]:
# Number of completions returned by the sampling call.
print(len(samples))

In [None]:
# Concatenate all sampled completions into one text blob, each followed by a
# 50-asterisk divider line, then print it and save it to disk.
divider = "*" * 50 + '\n'
rand_samples_base = ''.join(sample + '\n' + divider for sample in samples)

print(rand_samples_base)
# The `with` block closes the file on exit, so no explicit close() is needed.
# An explicit encoding keeps non-ASCII model output from failing on platforms
# whose default encoding is not UTF-8.
with open("long_hop.txt", 'w', encoding="utf-8") as f:
    f.write(rand_samples_base)