In [6]:
import numpy as np
import random
from accelerate import Accelerator
import torch
import importlib
from transformers import AutoTokenizer, AutoModelForCausalLM
from rich import print
import tqdm

In [7]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [9]:
# Load the fine-tuned checkpoint from local disk and place it on `device`.
# NOTE(review): this is an absolute, machine-specific path; the cell will fail on
# any other machine unless the checkpoint is placed at the same location.
model = AutoModelForCausalLM.from_pretrained(
    '/Users/peeyushsharma/personal/reasoning_slm/lora_sft',
    device_map=device,
)

In [11]:
# Tokenizer of the base model, fetched by its Hub id.
# NOTE(review): `padding_side` is left at the tokenizer's default here.
tokenizer = AutoTokenizer.from_pretrained(
    'HuggingFaceTB/SmolLM2-360M-Instruct',
)

In [12]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import reasoning_gym
from functools import partial


class dataset(Dataset):
    """Minimal ``Dataset`` wrapper around any collection supporting ``len()`` and indexing."""

    def __init__(self, data) -> None:
        super().__init__()
        # Stored as-is; length and item access are delegated straight to this object.
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item the wrapped collection holds at ``index``."""
        return self.data[index]


def load_model(model_id, device):
    """Load a causal language model in bfloat16 and place it on ``device``.

    Args:
        model_id: Identifier or path forwarded to ``from_pretrained``.
        device: Value forwarded as ``device_map``.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.bfloat16,
        device_map=device,
    )

def load_tokenizer(model_id):
    """Load the tokenizer for ``model_id`` configured for left padding.

    If the tokenizer defines no pad token, the EOS token is reused for padding.

    Args:
        model_id: Identifier or path forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    tok.padding_side = "left"

    pad_is_missing = tok.pad_token is None
    if pad_is_missing:
        # No dedicated pad token: fall back to EOS.
        tok.pad_token = tok.eos_token

    return tok
    
    


def _left_pad(sequences, fill_value, width):
    """Left-pad every id list in ``sequences`` to ``width`` and stack them into a LongTensor."""
    return torch.tensor(
        [[fill_value] * (width - len(seq)) + list(seq) for seq in sequences],
        dtype=torch.long,
    )


def collate_fn(batch, tokenizer):
    """Turn a list of question/answer samples into next-token-prediction tensors.

    Each sample is rendered as ``prompt + " " + answer + EOS``. Only the
    ``" " + answer + EOS`` suffix is supervised; every other label position is
    set to -100. Sequences are padded on the LEFT by this function itself, so
    the result does not depend on ``tokenizer.padding_side``: the supervised
    suffix of every row ends at the last column, which is what keeps the
    separately tokenized answer aligned with the full sequence.

    Args:
        batch: List of mappings, each with a ``'question'`` and an ``'answer'`` entry.
        tokenizer: Callable tokenizer exposing ``eos_token`` and ``pad_token_id``;
            calling it on a list of strings must return a mapping whose
            ``'input_ids'`` entry is a list of id lists.

    Returns:
        dict with three tensors of identical shape ``(len(batch), longest - 1)``:
            ``"input"``: token ids with the final position dropped,
            ``"target"``: labels with the first position dropped (-100 where unsupervised),
            ``"attention_mask"``: bool mask, False on padding positions of ``"input"``.
    """
    # The supervised suffix is built once and reused for both the full text and the labels.
    answers = [
        " " + f"<think>....</think> <answer>{sample['answer']}</answer>." + tokenizer.eos_token
        for sample in batch
    ]

    full_responses = []
    for sample, answer in zip(batch, answers):
        question = sample['question']

        prompt = f"""
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and
<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: {question}. Assistant:"""

        full_responses.append(prompt + answer)

    # Tokenize WITHOUT padding and pad manually below. Relying on the tokenizer's own
    # padding silently misaligns labels and inputs whenever it pads on the right.
    full_ids = tokenizer(full_responses, add_special_tokens=False)['input_ids']
    answer_ids = tokenizer(answers, add_special_tokens=False)['input_ids']

    width = max(len(ids) for ids in full_ids)

    input_tokenized = _left_pad(full_ids, tokenizer.pad_token_id, width)
    # Filling labels with -100 directly (instead of masking pad ids afterwards) keeps
    # the trailing EOS label intact even when the pad token and EOS token are the same.
    labels_tokenized = _left_pad(answer_ids, -100, width)

    # Mask derived from true sequence lengths, so a real token that happens to share
    # the pad id is never mistaken for padding.
    lengths = torch.tensor([len(ids) for ids in full_ids], dtype=torch.long)
    is_real_token = torch.arange(width).unsqueeze(0) >= (width - lengths).unsqueeze(1)

    # Standard causal-LM shift: position t of the input predicts token t + 1.
    return {
        "input": input_tokenized[:, :-1],
        "target": labels_tokenized[:, 1:],
        'attention_mask': is_real_token[:, :-1],
    }
    
def get_dataloader(data_set, tokenizer, size=50, seed=42, batch_size=8, shuffle=True):
    """Build a ``DataLoader`` over a generated ``reasoning_gym`` dataset.

    Args:
        data_set: Dataset name forwarded to ``reasoning_gym.create_dataset``.
        tokenizer: Tokenizer bound into ``collate_fn`` for every batch.
        size: Number of samples to generate (default 50).
        seed: Seed forwarded to ``reasoning_gym.create_dataset`` (default 42).
        batch_size: Samples per batch (default 8).
        shuffle: Whether the loader reshuffles the samples (default True).

    Returns:
        A ``DataLoader`` whose batches are the dictionaries produced by ``collate_fn``.
    """
    gym_data = reasoning_gym.create_dataset(data_set, size=size, seed=seed)
    return DataLoader(
        dataset(gym_data),
        batch_size=batch_size,
        # The tokenizer is bound up front because DataLoader calls collate_fn(batch) only.
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
        shuffle=shuffle,
    )

In [14]:
# Smoke-test prompt. It must end with "Assistant:" (it was truncated to "Assistan"),
# matching the turn marker that closes the prompt template used in collate_fn.
prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves
it. The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think>...</think>and <answer>...</answer> tags, respectively, i.e., <think> reasoning process here </think>
<answer> answer here </answer>. User: How are you model?. Assistant:"""

# Put the ids on the model's device; leaving them on the CPU fails when the model is on CUDA.
tokens = tokenizer(prompt, add_special_tokens=False, return_tensors='pt')['input_ids'].to(model.device)

# Single forward pass to obtain logits. `max_new_tokens` is a generate() option, not a
# forward() argument (it either raises TypeError or is ignored), so it is not passed here.
out = model(tokens)

In [17]:
# Highest-scoring vocabulary index at each position. Despite the variable name,
# these are token ids (argmax indices), not logit values.
selected_logits = torch.argmax(out['logits'], dim=-1)