In [1]:
# Core libraries: PyTorch plus the Hugging Face model/tokenizer loaders and datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re
import random
# Colab-only: mount Google Drive so files can be written under `save_path`
from google.colab import drive
drive.mount('/content/drive')
# Directory on the mounted drive; the training cell writes one CSV per experiment here
save_path = '/content/drive/MyDrive/qwen_beta=0_001_E_5_epochs=3'
import pandas as pd
from torch.utils.data import DataLoader


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:

#stopwatch placeholder

In [3]:
# Learning rate handed to the AdamW optimizer in the training cell
LR = 3e-6
# Hugging Face checkpoint loaded for both the trained model and the reference model
CHECKPOINT = 'Qwen/Qwen2-0.5B-Instruct'
# Number of prompts per DataLoader batch
BATCH_SIZE = 64

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Tokenizer that belongs to the chosen checkpoint
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# Batched tokenization below uses padding; reuse the EOS id when no pad token is defined
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

cuda


In [5]:
# Policy model (the one that gets optimised), loaded in bfloat16 and moved to `device`.
# The helper functions below read this global name directly.
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT,
                                             dtype = torch.bfloat16).to(device)
# Switch to training mode; as the cell's last expression the returned module is displayed
model.train()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2

In [6]:
# Reference model: a second copy of the same checkpoint, used by calculate_loss
# for the KL term.
ref_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT,
                                             dtype = torch.bfloat16).to(device)
# Evaluation mode; as the cell's last expression the returned module is displayed
ref_model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2

In [7]:
# Freeze the reference model: none of its parameters should receive gradients
ref_model.requires_grad_(False)

# Sanity check: prints False once every parameter is frozen
print(any(param.requires_grad for param in ref_model.parameters()))

False


In [8]:
# GSM8K ('main' configuration): the train split is used for training, the test split for evaluation
dataset = load_dataset('gsm8k','main')['train']
test_set = load_dataset('gsm8k','main')['test']


In [9]:
def format_prompt(question):
    """Render a question as a chat-formatted prompt string.

    The prompt consists of a fixed system message (asking for step-by-step
    reasoning followed by the final integer after "#### ") and the question
    as the user turn, rendered with the global tokenizer's chat template.

    input:
    question: str
    output:
    str - the rendered prompt text (not token ids)
    """
    instructions = 'Think step by step to answer the question given to you. After clearly stating your reasoning, state your final numerical answer after "#### ". For example if the answer is 2 finish your answer with: #### 2. The final answer is always an integer with no spaces'
    conversation = [
        {'role':'system','content':instructions},
        {'role':'user','content':question},
    ]

    # tokenize=False -> return the rendered text; add_generation_prompt=True
    # asks the template to append the generation prompt after the user turn.
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

In [10]:
def extract_answer(text):
    """Parse the integer that follows the first ``'#### '`` marker in ``text``.

    The match is deliberately strict: everything between the marker and the
    end of the text must be a single integer literal (surrounding whitespace
    is tolerated by ``int``). Anything else - trailing prose, a trailing
    period, thousands separators such as ``1,000`` - counts as "no answer".

    input:
    text: str
    output:
    int when a clean integer follows the marker, otherwise None
    """
    # DOTALL is important: it lets '.' cross newlines, so the capture runs to
    # the end of the text and any trailing content makes int() fail below.
    match = re.search(r'#### (.*)', text, re.DOTALL)
    if match is None:
        return None

    try:
        return int(match.group(1))
    except ValueError:
        # Only the "not an integer" failure is expected here; other exceptions
        # (e.g. KeyboardInterrupt) must not be silently swallowed.
        return None

def map_func(example):
    """Build the training fields for one dataset example.

    input:
    example: mapping with 'question' and 'answer' entries
    output:
    dict with
        'prompt'   - the chat-formatted question (see format_prompt)
        'solution' - the example's 'answer' text, unchanged
        'answer'   - the integer parsed from that text, or None when parsing fails
    """
    reference_solution = example['answer']
    return {
        'prompt': format_prompt(example['question']),
        'solution': reference_solution,
        'answer': extract_answer(reference_solution),
    }



In [11]:
# Add the 'prompt', 'solution' and 'answer' columns, then drop every example
# whose reference answer could not be parsed into an integer (answer is None).
dataset = dataset.map(map_func)
dataset = dataset.filter(lambda example: example['answer'] is not None)

# Same preparation for the evaluation split
test_set = test_set.map(map_func)
test_set = test_set.filter(lambda example: example['answer'] is not None)

In [12]:
# Shuffled batches of BATCH_SIZE examples for training and for evaluation.
# The training cell builds its own loaders again for every experiment.
loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)

In [13]:
# Number of batches per pass over the training and the test loader
len(loader),len(test_loader)

(116, 21)

In [14]:
def generate_answers(prompts,num_return_sequences = 8,temp = 1,use_completion_dict = False):
    '''
    Sample completions for a batch of prompts from the global `model`.

    input:
    prompts: List[str]
    num_return_sequences: number of completions sampled per prompt
    temp: sampling temperature
    use_completion_dict: when True, group the completions by prompt
    output:
    completions: List[str] (flat, all prompts) or Dict[str,List[str]] keyed by prompt text
    '''
    # Left padding so that every prompt ends right where generation starts
    encoded_prompts = tokenizer(prompts,return_tensors='pt',padding_side='left',padding = True).to(device)
    # Every row is padded to the same width, so the first row's length is the
    # number of leading columns to strip from each generated sequence.
    padded_prompt_width = len(encoded_prompts['input_ids'][0])

    generated = model.generate(
        **encoded_prompts,
        num_return_sequences = num_return_sequences,
        max_new_tokens = 1024,
        do_sample = True,
        temperature = temp,
        top_p = 0.95,
    )
    # Decode only the newly generated part of each row
    completions = tokenizer.batch_decode(generated[:,padded_prompt_width:],skip_special_tokens=True)

    if not use_completion_dict:
        return completions

    # Consecutive chunks of `num_return_sequences` completions are assigned to
    # the prompts in order (assumes generate() keeps each prompt's samples
    # adjacent). Identical prompt strings share one key; the later chunk wins.
    return {
        prompts[prompt_num]: completions[start:start+num_return_sequences]
        for prompt_num,start in enumerate(range(0,len(completions),num_return_sequences))
    }

def extract_reasoning(text):
    """Return the part of `text` before the first '####' marker (all of it when the marker is absent)."""
    reasoning, _marker, _rest = text.partition('####')
    return reasoning

def calculate_rewards(rollouts,real_answer,R = None):
    """Score every rollout of one prompt.

    input:
    rollouts: List[str] - completions for a single prompt.
        WARNING: sorted IN PLACE when R is 'penalty' or 'reward'. The returned
        rewards follow the sorted order, and callers (calculate_loss) rely on
        the list being reordered so rewards stay aligned with the completions.
    real_answer: the reference answer the extracted integer is compared to
    R: length-shaping mode
        None      - every correct rollout gets 1
        'penalty' - rollouts are ordered longest reasoning first, so among
                    correct rollouts the longest gets the smallest reward
        'reward'  - rollouts are ordered shortest reasoning first, so among
                    correct rollouts the longest gets the largest reward
        With any non-None R a correct rollout gets (position+1)/len(rollouts).
    output:
    List[float|int] - one reward per rollout: -0.1 when no integer answer can
    be extracted, 0 for a wrong answer, and the value described above for a
    correct one.
    """
    if R in ('penalty','reward'):
        # Reasoning length = token count of the text before the '####' marker
        rollouts.sort(key=lambda x: len(tokenizer.encode(extract_reasoning(x))),reverse=(R == 'penalty'))

    rewards = []
    for rank,rollout in enumerate(rollouts,start=1):
        try:
            ans = extract_answer(rollout)
        except Exception:
            # Best effort: an unparsable rollout is treated as "no answer".
            # (Exception, not a bare except, so Ctrl-C still interrupts.)
            ans = None

        if ans is None:
            rewards.append(-0.1)
        elif ans == real_answer:
            rewards.append(1 if R is None else rank/len(rollouts))
        else:
            rewards.append(0)
    return rewards

def calculate_advantages(rewards):
    """Centre a group of rewards on their mean.

    input:
    rewards: non-empty List[number]
    output:
    List[float] - reward minus the group mean, in the same order
    """
    baseline = sum(rewards)/len(rewards)
    # Standard-deviation normalisation is switched off here: the scale is
    # fixed at 1, so the advantage is just the mean-centred reward.
    scale = 1
    return [(reward-baseline)/scale for reward in rewards]

def is_correct(ans,real_ans):
    """Return True when `ans` equals `real_ans`, else False (always a plain bool)."""
    # bool(...) keeps the result a Python bool even when the comparison
    # itself yields some other truthy/falsy object.
    return bool(ans == real_ans)

def calculate_accuracy(rollouts,real_answer):
    """Fraction of `rollouts` whose extracted answer equals `real_answer`."""
    hits = sum(1 for rollout in rollouts if extract_answer(rollout) == real_answer)
    return hits/len(rollouts)






In [15]:
def get_log_probs(model,prompt,completions,strategy = 'sum'):
    '''
    Log-probabilities that `model` assigns to each completion's tokens, given `prompt`.

    input:
    model: causal LM; called as model(**inputs) and read via `.logits`
    prompt: str
    completions: List[str]
    strategy:
        'sum'  - 1-D tensor with the summed log-prob of every completion
        'mean' - 1-D tensor with the per-token average log-prob of every
                 completion (0 for a completion without tokens)
        'list' - List[tensor]: for every completion a 1-D tensor of its
                 per-token log-probs with the padding removed
        other  - the padded 2-D matrix of per-token log-probs, padding zeroed
    output:
    log_probs: torch.Tensor or List[torch.Tensor], depending on `strategy`
    '''
    # Number of prompt tokens, used to slice the completion part out of every
    # row. NOTE(review): this assumes that tokenizing prompt+completion leaves
    # the prompt's tokens unchanged as a prefix - confirm for this tokenizer.
    prompt_length = len(tokenizer.encode(prompt))
    prompt_completions = [prompt+completion for completion in completions]

    # Right padding keeps the prompt at the start of every row, so a single
    # slice index works for the whole batch.
    prompt_inputs = tokenizer(prompt_completions,return_tensors='pt',padding_side='right',padding = True).to(device)
    completion_ids = prompt_inputs['input_ids'][:,prompt_length:]

    logits = model(**prompt_inputs).logits
    # Shifted by one position: the logits at position t score the token at t+1
    completion_logits = logits[:,prompt_length-1:-1,:]
    log_probs = torch.nn.functional.log_softmax(completion_logits,dim=-1)
    # Keep only the log-prob of the token that actually occurs at each position
    log_probs = torch.gather(log_probs,-1,completion_ids.unsqueeze(-1)).squeeze(-1)

    # Zero out padding positions so they do not contribute to sums
    mask = prompt_inputs['attention_mask'][:,prompt_length:]
    log_probs = log_probs*mask

    if strategy == 'sum':
        log_probs = log_probs.sum(dim=-1)

    elif strategy == 'mean':
        summed_log_probs = log_probs.sum(dim=-1)
        # Clamp the token count: a completion with zero tokens would
        # otherwise produce 0/0 = nan.
        num_tokens = mask.sum(dim=-1).clamp(min=1)
        log_probs = summed_log_probs/num_tokens

    elif strategy == 'list':
        # With right padding the real tokens are the first `count` entries of each row
        token_counts = mask.sum(dim = -1)
        log_probs = [log_probs[i,:token_counts[i]] for i in range(log_probs.shape[0])]

    return log_probs

In [16]:
#deprecated

# def calculate_loss(prompt,completions,real_answer,R = None,beta = 0):
#     rewards = calculate_rewards(completions,real_answer,R = R)
#     advantages = torch.tensor(calculate_advantages(rewards)).to(device)
#     total_objective = 0

#     log_probs_summed = get_log_probs(model,prompt,completions,strategy = 'sum')
#     total_objective += (log_probs_summed*advantages).mean()

#     log_probs_ref_meanned = get_log_probs(ref_model,prompt,completions,strategy = 'mean')
#     log_probs_meanned = get_log_probs(model,prompt,completions,strategy = 'mean')

#     kl_div = (log_probs_meanned-log_probs_ref_meanned).mean()

#     return -(total_objective/1024.0 - beta*kl_div),kl_div.item()




In [17]:
def calculate_loss(prompt,completions,real_answer,R = None,beta = 0):
    """Loss for one prompt and its group of sampled completions.

    input:
    prompt: str
    completions: List[str] - may be reordered in place by calculate_rewards
    real_answer: reference answer used for the rewards
    R: length-shaping mode forwarded to calculate_rewards
    beta: weight of the KL term
    output:
    (loss, kl) - loss is a scalar tensor to call backward() on:
        -(mean(sequence log-prob * advantage) / total completion tokens - beta * KL);
    kl is the KL estimate as a Python float.
    """
    # Must come first: calculate_rewards may sort `completions` in place, and
    # the log-probs below have to be computed in that same order.
    group_rewards = calculate_rewards(completions,real_answer,R = R)
    group_advantages = torch.tensor(calculate_advantages(group_rewards)).to(device)

    # Per-token log-probs under the current policy (global `model`)
    policy_token_log_probs = get_log_probs(model,prompt,completions,strategy = 'list')
    sequence_log_probs = torch.stack([tokens.sum() for tokens in policy_token_log_probs]).to(device)
    policy_objective = (sequence_log_probs*group_advantages).mean()

    # Per-token log-probs under the reference model
    reference_token_log_probs = get_log_probs(ref_model,prompt,completions,strategy='list')

    # Flatten the group so the KL estimate is a mean over all completion tokens
    flat_reference = torch.concat(reference_token_log_probs)
    flat_policy = torch.concat(policy_token_log_probs)

    # Per-token estimator exp(r) - r - 1 with r = log pi_ref - log pi_theta
    log_ratio = flat_reference - flat_policy
    kl_div = (torch.exp(log_ratio) - log_ratio - 1).mean()

    total_tokens = len(flat_policy)

    return -(policy_objective/total_tokens - beta*kl_div),kl_div.item()


In [18]:
# `time` supplies the wall-clock timestamps used to time each training batch
import time

# Current Unix timestamp in seconds (shown as the cell output)
time.time()



1763050466.4828677

In [19]:
def pass_at_N(completions,answer):
    """True when at least one completion's extracted answer equals `answer`, else False."""
    return any(extract_answer(completion) == answer for completion in completions)

In [None]:
import os
import time

import matplotlib.pyplot as plt

# Fail fast: make sure the output directory exists BEFORE training, so a
# missing folder cannot throw away a finished run when the CSV is written.
os.makedirs(save_path, exist_ok=True)

# Runs three independent experiments (different seeds). Each one fine-tunes a
# fresh copy of the checkpoint, plots and saves the training metrics, and
# then measures pass@N on the test split.
for experiment in range(3):
    print(f'experiment {experiment}')
    #seed
    torch.manual_seed(42+experiment)

    #models
    # Rebinding the global `model` is what makes generate_answers and
    # calculate_loss use this experiment's fresh copy.
    model = AutoModelForCausalLM.from_pretrained(CHECKPOINT,
                                                 dtype = torch.bfloat16).to(device)

    model.train()


    loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
    test_loader = DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(),lr=LR)

    EPOCHS = 3
    epoch_accuracies = []  # currently unused

    # Reset so the "sample completions" printout below can never show
    # completions left over from an earlier experiment or cell.
    completions = []

    # Metrics of ALL epochs of this experiment. The per-epoch lists below are
    # re-created every epoch (they drive the moving averages in the log), so
    # without this only the final epoch would reach the plots and the CSV.
    history = {
        'batch_accuracies': [],
        'batch_average_lengths': [],
        'batch_kl_divs': [],
        'batch_time_taken': [],
        'batch_average_lengths_with_prompts': [],
        'epoch': [],
    }

    for epoch in range(EPOCHS):

        batch_kl_divs = []
        batch_accuracies = []
        batch_average_lengths = []
        batch_average_lengths_with_prompts = []
        batch_time_taken = []

        for batch_num,batch in enumerate(loader):

            start_time = time.time()
            try:
                # Length-shaping schedule: 'penalty' for the first two epochs
                # and the first 70% of the last one, then 5 batches of
                # 'reward', then plain correctness rewards.
                if batch_num <len(loader)*0.7 or epoch <= 1:
                    R = 'penalty'
                elif batch_num < len(loader)*0.7 + 5:
                    R = 'reward'
                else:
                    R = None

                prompts = batch['prompt']
                answers = batch['answer']

                # Sampling needs no gradients
                with torch.no_grad():
                    completion_dict = generate_answers(prompts,use_completion_dict = True)

                optimizer.zero_grad()

                max_lengths = []
                max_lengths_with_promps = []
                accuracies = []
                kl_divs = []

                # Gradient accumulation: one backward pass per prompt, a
                # single optimizer step per batch.
                for prompt,real_answer in zip(prompts,answers):

                    completions = completion_dict[prompt]
                    loss,kl_div = calculate_loss(prompt,completions,real_answer,R = R,beta = 0.001)

                    # NOTE(review): the accumulated loss is scaled by the
                    # number of completions, not by the number of prompts in
                    # the batch - confirm this is the intended normalisation.
                    loss /= len(completions)
                    loss.backward()

                    kl_divs.append(kl_div)

                    # Token length of the longest completion for this prompt
                    lengths = [len(tokenizer.encode(completion)) for completion in completions]
                    lengths_with_prompts = [len(tokenizer.encode(prompt+completion)) for completion in completions]
                    max_lengths.append(max(lengths))
                    max_lengths_with_promps.append(max(lengths_with_prompts))


                    are_correct = [is_correct(extract_answer(completion),real_answer) for completion in completions]
                    accuracies.append(sum(are_correct)/len(are_correct))

                torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
                optimizer.step()

                end_time = time.time()
                batch_time_taken.append(end_time-start_time)
                batch_accuracies.append(sum(accuracies)/len(accuracies))
                batch_average_lengths.append(sum(max_lengths)/len(max_lengths))
                batch_average_lengths_with_prompts.append(sum(max_lengths_with_promps)/len(max_lengths_with_promps))
                batch_kl_divs.append(sum(kl_divs)/len(kl_divs))
                status_for_print = f'''

                Epoch:{epoch}
                Batch Num:{batch_num}
                Accuracy:{sum(accuracies)/len(accuracies)}
                Moving Average Accuracy:{sum(batch_accuracies[-10:])/len(batch_accuracies[-10:])}
                Average Length:{batch_average_lengths[-1]}
                Average Length With Prompt:{batch_average_lengths_with_prompts[-1]}
                Time Taken:{batch_time_taken[-1]}
                Time Taken per average_word: = {batch_time_taken[-1]/batch_average_lengths[-1]}
                Time Taken per average_word_with_prompt: = {batch_time_taken[-1]/batch_average_lengths_with_prompts[-1]}
                Moving Average Length:{sum(batch_average_lengths[-10:])/len(batch_average_lengths[-10:])}
                Average KL Divergence:{sum(kl_divs)/len(kl_divs)}
                Moving Average KL Divergence:{sum(batch_kl_divs[-10:])/len(batch_kl_divs[-10:])}
                '''
                print(status_for_print)

            except Exception as e:
                # Best effort: a failing batch is reported and skipped.
                # Gradients it may have accumulated are cleared by the
                # zero_grad() call at the start of the next batch.
                print(e)
                print(f'Batch Num:{batch_num}')

        # Keep this epoch's metrics before the per-epoch lists are re-created
        history['batch_accuracies'].extend(batch_accuracies)
        history['batch_average_lengths'].extend(batch_average_lengths)
        history['batch_kl_divs'].extend(batch_kl_divs)
        history['batch_time_taken'].extend(batch_time_taken)
        history['batch_average_lengths_with_prompts'].extend(batch_average_lengths_with_prompts)
        history['epoch'].extend([epoch]*len(batch_accuracies))

    print(f'Sample Completions for experiment {experiment} are:')
    for i,completion in enumerate(completions):
        print(f'{i+1}. {completion}')

    # Training curves over every successfully processed batch of all epochs
    fig,axes = plt.subplots(3,1,figsize=(15,20))
    curves = (
        ('Accuracy',history['batch_accuracies']),
        ('Average Length',history['batch_average_lengths']),
        ('KL Divergence',history['batch_kl_divs']),
    )
    for ax,(title,values) in zip(axes,curves):
        ax.plot(values)
        ax.set_title(title)
    plt.show()
    plt.close(fig)

    # One row per successfully processed batch; 'epoch' tells the rows apart
    df_for_save = pd.DataFrame(history)
    df_for_save.to_csv(f'{save_path}/experiment_{experiment}.csv')

    #eval
    # pass@N: a test prompt counts as solved when any of its sampled
    # completions gives the right answer.
    model.eval()
    num_correct_answers = 0
    total_answers = 0
    for batch_num,batch in enumerate(test_loader):

        try:
            prompts = batch['prompt']
            real_answers = batch['answer']
            solutions = batch['solution']

            with torch.no_grad():
                completion_dict = generate_answers(prompts,use_completion_dict=True)

                for prompt,real_answer in zip(prompts,real_answers):
                    completions = completion_dict[prompt]
                    correctness = pass_at_N(completions,real_answer)
                    total_answers += 1
                    if correctness:
                        num_correct_answers += 1
        except Exception as e:
            print(e)
            print(f'Batch Num:{batch_num}')

    # nan instead of a ZeroDivisionError when every evaluation batch failed
    test_accuracy = num_correct_answers/total_answers if total_answers else float('nan')
    print(f'Test Accuracy for experiment {experiment} is {test_accuracy}')




experiment 0


                Epoch:0
                Batch Num:0
                Accuracy:0.04296875
                Moving Average Accuracy:0.04296875
                Average Length:427.234375
                Average Length With Prompt:555.875
                Time Taken:151.08661723136902
                Time Taken per average_word: = 0.35363871933612323
                Time Taken per average_word_with_prompt: = 0.2717996262313812
                Moving Average Length:427.234375
                Average KL Divergence:0.0
                Moving Average KL Divergence:0.0
                


                Epoch:0
                Batch Num:1
                Accuracy:0.060546875
                Moving Average Accuracy:0.0517578125
                Average Length:391.1875
                Average Length With Prompt:515.40625
                Time Taken:147.4160213470459
                Time Taken per average_word: = 0.3768423616476649
                Time Taken per average_word_with_prompt: 

In [None]:
# Display whatever `completions` currently holds (last assigned in the training/evaluation cell)
completions

In [None]:
# Scratch check: summing each tensor of a list, as calculate_loss does with
# the per-completion log-prob tensors.
A = torch.tensor([1,2,3],dtype=torch.float32)
B = torch.tensor([4,5,6],dtype=torch.float32)
A.requires_grad = True
B.requires_grad = True
A
l = [A,B]
l_summed = [t.sum() for t in l]
# Displayed as the cell output: a list with one scalar (0-d) tensor per input
l_summed

In [None]:
# `_` is IPython's shortcut for the most recent cell output (interactive use only)
_