In [1]:
import os
cache_dir = '/scratch3/workspace/wenlongzhao_umass_edu-reason/dev_kedar/transformers_cache'
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME']=cache_dir
os.environ['HF_HUB_CACHE']=cache_dir+'/hub'
hf_token=os.getenv('hf_token')

In [2]:
import numpy as np
import random
from datasets import load_from_disk, Dataset, load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
data=load_from_disk('../datasets/gsm8k/feedback/')
data=data[0:71]
data=Dataset.from_dict(data)
data

Dataset({
    features: ['question', 'answer'],
    num_rows: 71
})

In [4]:
import re

In [5]:
def formatting_prompts_func(examples):
    answer = format_answer(examples['answer'])
    # text = f'<|start_header_id|>user<|end_header_id|>\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: {examples['question']}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{answer}'
    text = f'<|start_header_id|>user<|end_header_id|>\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: {examples['question']}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{answer}<|eot_id|>'
    
    return text

def format_answer(answer):
        answer = re.sub(r'<<.*?>>', '', answer)
        answer = answer.replace('####', 'The final answer is')
        return answer

In [6]:
model_name='meta-llama/Llama-3.2-3B-Instruct'
output_dir= 'sft'
add_special_tokens= True
epochs= 5
lr=1e-5 
lr_scheduler_type= 'cosine'
warmup= 0.1 
weight_decay= 0.01
per_device_train_batch_size= 4
gradient_accumulation_steps= 4
max_seq_length= 500 
torch_dtype='bfloat16'

In [7]:
from trl import  SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, cache_dir=cache_dir)
tokenizer.pad_token_id = tokenizer.eos_token_id

response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"


collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Set up the trainer
training_args = SFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "cache_dir":cache_dir
    },
    output_dir=output_dir,
    num_train_epochs=epochs,
    learning_rate=lr,
    lr_scheduler_type=lr_scheduler_type,
    weight_decay=weight_decay,
    warmup_ratio=warmup,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_strategy="epoch",
    logging_steps=100,
    # Using this 3072(prompt) + 512(output). The 3072(prompt) is taken from LLaMA : https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals?row=0
    max_seq_length  = max_seq_length
)

training_args.add_special_tokens = add_special_tokens

In [9]:
trainer = SFTTrainer(
        model=model_name,
        args=training_args,
        train_dataset=data,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        tokenizer=tokenizer
    )

  trainer = SFTTrainer(
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 58.81it/s]
Applying formatting function to train dataset: 100%|██████████| 71/71 [00:00<00:00, 11356.27 examples/s]
Converting train dataset to ChatML: 100%|██████████| 71/71 [00:00<00:00, 17647.15 examples/s]
Applying chat template to train dataset: 100%|██████████| 71/71 [00:00<00:00, 20416.54 examples/s]
Tokenizing train dataset: 100%|██████████| 71/71 [00:00<00:00, 2154.63 examples/s]
Tokenizing train dataset: 100%|██████████| 71/71 [00:00<00:00, 4551.29 examples/s]


In [10]:
trainer.train()

Step,Training Loss


TrainOutput(global_step=20, training_loss=0.5624967098236084, metrics={'train_runtime': 185.8885, 'train_samples_per_second': 1.91, 'train_steps_per_second': 0.108, 'total_flos': 1192871346315264.0, 'train_loss': 0.5624967098236084})

In [4]:
teacher_data_path = '../outputs/exp-2.0.3/eval_1/generated_outputs.json'
teacher_data = eacher_data=load_dataset('json',data_files=teacher_data_path)['train']
teacher_data

Dataset({
    features: ['input', 'output', 'token_ids', 'log_probs', 'all_returned_log_probs', 'model_answer', 'GT_Answer', 'score'],
    num_rows: 1000
})

In [5]:
student_data_path='../outputs/exp-2.1.1/eval_1/logprobs.json'
student_data=load_dataset('json', data_files=student_data_path)['train']
student_data

Dataset({
    features: ['prompt', 'gt_reasoning', 'gt_answer', 'student_token_ids', 'student_reasoning', 'student_answer', 'student_correctness', 'student_log_probs'],
    num_rows: 1000
})

In [6]:
# data_path='../outputs/exp-2.1.1/eval_1/logprobs1.json'
# logprobs=load_dataset('json', data_files=data_path)['train']
# logprobs

In [9]:
def get_log_prob_ratio(teacher_log_prob, student_log_prob):
    tr_stu_logprob=[]
    student_logprob=[]
    teacher_logprob=[]
    for i in range(len(teacher_log_prob)):
        student_log_probs=np.array(student_log_prob[i])
        teacher_log_probs=np.array(teacher_log_prob[i])
        student_logprob.append(np.mean(student_log_probs))
        teacher_logprob.append(np.mean(teacher_log_probs))
        tr_stu_logprob.append(
            np.subtract(
                np.mean(teacher_log_probs),
                np.mean(student_log_probs)
            )
        )
                
    return tr_stu_logprob
def merge_and_sample_data(data,teacher_data,student_data, remove_incorrects, sampling_ratio, seed=42 ):
    # print(teacher_data['log_probs'][1])
    # print(len(teacher_data['log_probs']))
    tr_stu_logprob_ratio=get_log_prob_ratio(teacher_data['log_probs'],student_data['student_log_probs'])
    threshold=np.median(tr_stu_logprob_ratio)
    teacher_answers=[]
    teacher_scores=[]
    print(f'Median teacher-student-logprob-ratio: {threshold}')
    for i in range(teacher_data.num_rows):
        teacher_answers.append(teacher_data['output'][i][0])
        teacher_scores.append(teacher_data['score'][i])
    questions=data['question']
    new_data = {
        'question': questions,
        'answer': teacher_answers,
        'score': teacher_scores,
        'logprob_ratio':tr_stu_logprob_ratio
    }
    
    data= Dataset.from_dict(new_data)

    if remove_incorrects:
        data= data.filter(lambda x: x['score']==1)
    print(f'After removing incorrects from teacher:{data.num_rows}')
    
    
    rng = random.Random(seed)

    total_size = len(data)
    size_below = int(total_size * sampling_ratio)
    size_above = total_size - size_below

    below_thresh = data.filter(lambda example: example['logprob_ratio'] < threshold)
    above_thresh = data.filter(lambda example: example['logprob_ratio'] >= threshold)
    print(f'below threshold:{below_thresh.num_rows}')
    print(f'above threshold:{above_thresh.num_rows}')

    def upsample(ds, target_size):
        if len(ds) == 0:
            return ds  # Avoid divide-by-zero
        indices = [rng.randint(0, len(ds) - 1) for _ in range(target_size)]
        return ds.select(indices)

    sampled_below = upsample(below_thresh, size_below)
    sampled_above = upsample(above_thresh, size_above)

    data = concatenate_datasets([sampled_below, sampled_above])
    return data.shuffle(seed=seed)
    
    
    
    

In [10]:
merge_and_sample_data(data, teacher_data, student_data, True, 0.9)

Median teacher-student-logprob-ratio: 0.06102693236978156


Filter: 100%|██████████| 1000/1000 [00:00<00:00, 235993.02 examples/s]


After removing incorrects from teacher:951


Filter: 100%|██████████| 951/951 [00:00<00:00, 115990.09 examples/s]
Filter: 100%|██████████| 951/951 [00:00<00:00, 114864.46 examples/s]

below threshold:475
above threshold:476





Dataset({
    features: ['question', 'answer', 'score', 'logprob_ratio'],
    num_rows: 951
})

In [12]:
import re
def formatting_prompts_func(examples):
    answer = format_answer(examples['answer'])
    text = f'<|start_header_id|>user<|end_header_id|>\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: {examples['question']}\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{answer}'
    return text

def format_answer(answer):
        answer = re.sub(r'<<.*?>>', '', answer)
        answer = answer.replace('####', 'The final answer is')
        return answer

In [14]:
# Loading model
hf_token = os.getenv("hf_token")

# model_name= "meta-llama/Llama-3.2-1B-Instruct"
model_name= "meta-llama/Llama-3.2-3B-Instruct"
# model_name= "meta-llama/Llama-3.1-8B-Instruct"
# model_name= "meta-llama/Llama-3.2-3B"
# model_name= "meta-llama/Llama-3.3-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, cache_dir=cache_dir)
tokenizer.pad_token_id = tokenizer.eos_token_id



In [18]:
examples={'question':data['question'][0], 'answer':data['answer'][0]}
formatted = formatting_prompts_func(data[0])
print(formatted)


<|start_header_id|>user<|end_header_id|>

Given the following problem, reason and give a final answer to the problem.
Problem: Nicole collected 400 Pokemon cards. Cindy collected twice as many, and Rex collected half of Nicole and Cindy's combined total. If Rex divided his card equally among himself and his three younger siblings, how many cards does Rex have left?
Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Cindy has 400 x 2 = 800 cards.
Nicole and Cindy have 400 + 800 = 1200 cards.
Rex has 1200/2 = 600 cards.
Rex is left with 600/(3+1=4) = 150 cards
The final answer is 150


In [21]:
input_ids = tokenizer(formatted)
input_ids

{'input_ids': [128000, 128006, 882, 128007, 271, 22818, 279, 2768, 3575, 11, 2944, 323, 3041, 264, 1620, 4320, 311, 279, 3575, 627, 32298, 25, 45130, 14890, 220, 3443, 28831, 7563, 13, 70431, 14890, 11157, 439, 1690, 11, 323, 42907, 14890, 4376, 315, 45130, 323, 70431, 596, 11093, 2860, 13, 1442, 42907, 18255, 813, 3786, 18813, 4315, 5678, 323, 813, 2380, 14992, 37783, 11, 1268, 1690, 7563, 1587, 42907, 617, 2163, 5380, 7927, 2077, 1288, 842, 449, 330, 791, 1620, 4320, 374, 510, 9399, 19727, 1405, 510, 9399, 60, 374, 279, 2077, 311, 279, 3575, 627, 128009, 128006, 78191, 128007, 271, 34, 50090, 706, 220, 3443, 865, 220, 17, 284, 220, 4728, 7563, 627, 58916, 1286, 323, 70431, 617, 220, 3443, 489, 220, 4728, 284, 220, 4364, 15, 7563, 627, 49, 327, 706, 220, 4364, 15, 14, 17, 284, 220, 5067, 7563, 627, 49, 327, 374, 2163, 449, 220, 5067, 12148, 18, 10, 16, 28, 19, 8, 284, 220, 3965, 7563, 198, 791, 1620, 4320, 374, 220, 3965], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [17]:
print(tokenizer.convert_ids_to_tokens(input_ids["input_ids"]))

['<|begin_of_text|>', '<|start_header_id|>', 'user', '<|end_header_id|>', 'ĊĊ', 'Given', 'Ġthe', 'Ġfollowing', 'Ġproblem', ',', 'Ġreason', 'Ġand', 'Ġgive', 'Ġa', 'Ġfinal', 'Ġanswer', 'Ġto', 'Ġthe', 'Ġproblem', '.Ċ', 'Problem', ':', 'ĠNicole', 'Ġcollected', 'Ġ', '400', 'ĠPokemon', 'Ġcards', '.', 'ĠCindy', 'Ġcollected', 'Ġtwice', 'Ġas', 'Ġmany', ',', 'Ġand', 'ĠRex', 'Ġcollected', 'Ġhalf', 'Ġof', 'ĠNicole', 'Ġand', 'ĠCindy', "'s", 'Ġcombined', 'Ġtotal', '.', 'ĠIf', 'ĠRex', 'Ġdivided', 'Ġhis', 'Ġcard', 'Ġequally', 'Ġamong', 'Ġhimself', 'Ġand', 'Ġhis', 'Ġthree', 'Ġyounger', 'Ġsiblings', ',', 'Ġhow', 'Ġmany', 'Ġcards', 'Ġdoes', 'ĠRex', 'Ġhave', 'Ġleft', '?Ċ', 'Your', 'Ġresponse', 'Ġshould', 'Ġend', 'Ġwith', 'Ġ"', 'The', 'Ġfinal', 'Ġanswer', 'Ġis', 'Ġ[', 'answer', ']"', 'Ġwhere', 'Ġ[', 'answer', ']', 'Ġis', 'Ġthe', 'Ġresponse', 'Ġto', 'Ġthe', 'Ġproblem', '.Ċ', '<|eot_id|>', '<|start_header_id|>', 'assistant', '<|end_header_id|>', 'ĊĊ', 'C', 'indy', 'Ġhas', 'Ġ', '400', 'Ġx', 'Ġ', '2', 'Ġ=', 'Ġ