In [35]:
import sys
sys.path.append("..")
from utils import GRPODataset, generate_dataset, load_dataset, is_topological_ordering, generate_topological_sort
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from torch.utils.data import Dataset, DataLoader
import os
import ast
import einops
from typing import Callable
import torch
import torch.nn.functional as F
torch.manual_seed(16)

<torch._C.Generator at 0x1553db8a4a70>

In [2]:
def tokenize_collate_fn(tokenizer: AutoTokenizer, batch_texts: list[str])->None:
    enc = tokenizer(
        batch_texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    )
    labels = enc["input_ids"].clone() # cloning labels because labels gives the next tokens the model needs to predict. they're automatically right shifted by 1. they must be cloned because if we pass them by reference and then do any in-place modificaitons of them they will also change the toknens

    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100 # -100 is the default ignore index for attention in hugging face. we're saying 'anywhere that is just padding, don't attend to it.'

    enc["labels"] = labels
    return enc
def generate_datasets_main(dataset_params: dict, n: int, k: int)->None:


    for ds in dataset_params.keys():
        print(f"generating dataset {ds}")
        _ = generate_dataset(num_graphs = dataset_params[ds]['size'],
                             writefile = dataset_params[ds]['writefile'],
                             n=n,
                             k=k)
        print(f"wrote dataset to {dataset_params[ds]['writefile']}")



In [3]:
# load / create datasets
n=10
k=10
batch_size=32
model_path = "./models/qwen2.5-0.5b-instruct"  # change if needed
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device is {device}")

dataset_params = {
    'train': {
        "size": 10000,
        "writefile":"./data/train.txt"
    },
    'test': {
        "size": 100,
        "writefile":"./data/test.txt"
    },
    'prompt_examples': {
        "size": 3,
        "writefile":"./data/prompt_samples.txt"
    }
}
def handle_load(dataset_params_file: str, key: str, )->None:
    if not os.path.exists(dataset_params_file[key]['writefile']):
        # generate dataset
        generate_datasets_main(dataset_params = dataset_params, n=n, k=k)
    ds = load_dataset(dataset_params[key]['writefile'])
    return ds
        
# load raw datasets

train_ds_raw = handle_load(dataset_params, key="train")
test_ds_raw = handle_load(dataset_params, key="test")
prompt_examples_raw = handle_load(dataset_params, key="prompt_examples")

# load in model and tokenizer 
tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=True, # use fast tells hugging face to load the rust implemented backend instead of the python one
        trust_remote_code=True
    )
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
print(f"model is {model}")

# make train / test dataset and dataloader objects
train_dataset = GRPODataset(load_dataset(dataset_params['train']['writefile']))
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    collate_fn=lambda batch: tokenize_collate_fn(tokenizer,batch)
)
test_dataset = GRPODataset(load_dataset(dataset_params['test']['writefile']))
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    drop_last=True,
    pin_memory=True,
    shuffle=False,
    collate_fn=lambda batch: tokenize_collate_fn(tokenizer,batch)
)
print(len(train_dataset))
print(len(test_dataset))

# form prompt_examples into a dict with key=graph, value=valid_topological_ordering
prompt_examples = "Examples:\n"
for i in range(len(prompt_examples_raw)):
    # prepare to extend prompt
    graph_str = prompt_examples_raw[i]
    prompt_examples += (f"Graph: {graph_str}\n")
    # get liter pythonic object list[tuple] from str representing it
    graph = ast.literal_eval(graph_str) # right now this is a str object representing a graph list[tuple]. this inflates it back into a pythonic object
    ordering = generate_topological_sort(graph, n=n)
    
    # now add topological sort to example
    prompt_examples += (f"Topological Sort: {str(ordering)}\n")
print(prompt_examples)


device is cuda


`torch_dtype` is deprecated! Use `dtype` instead!


model is Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_em

In [4]:
print(prompt_examples_raw)
print(len(prompt_examples_raw))

['[(1, 9), (4, 9), (8, 2), (3, 0), (7, 6), (1, 3), (2, 9), (7, 0), (8, 6), (3, 6)]', '[(8, 0), (6, 4), (5, 0), (6, 9), (7, 5), (7, 4), (3, 2), (6, 8), (7, 1), (9, 7)]', '[(7, 0), (5, 4), (7, 1), (9, 4), (4, 8), (4, 7), (6, 8), (3, 0), (6, 0), (4, 1)]']
3


In [20]:
# define GRPO train step funcs
def grpo_train_step(model: AutoModelForCausalLM,
                    batch: torch.Tensor,
                    reward_func: Callable
                   )->tuple[torch.Tensor, float]:
    # for sample in batch
        # form group
        # pass group through model
        # calculate group score using reward 
    pass
class StopOnToken(StoppingCriteria):
    def __init__(self, stop_str: str, tokenizer):
        self.stop_str = stop_str
        self.tokenizer = tokenizer
    
    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
        last_token = self.tokenizer.decode(input_ids[0, -1])
        return self.stop_str in last_token

def eval_func(model: AutoModelForCausalLM,
              loader: DataLoader,
              prefix: torch.Tensor,
              suffix: torch.Tensor,
              eval_metric: Callable,
              stopping_criteria: StoppingCriteriaList,
              n: int,
              tokenizer: AutoTokenizer,
              verbose: bool = False) -> tuple[float, float]:

    total = 0
    correct_metric = 0
    correct_format = 0

    for idx, batch in enumerate(loader):
        batch = batch.to(model.device)
        sample = batch['input_ids']
        B = sample.shape[0]
        total += B

        prefix_batch = einops.repeat(prefix, '1 T -> B T', B=B)
        suffix_batch = einops.repeat(suffix, '1 T -> B T', B=B)

        prompt = torch.cat([prefix_batch, sample, suffix_batch], dim=-1)
        with torch.no_grad():
            output = model.generate(
                prompt,
                max_new_tokens=100,
                do_sample=False,
                stopping_criteria=stopping_criteria,
                pad_token_id=tokenizer.eos_token_id
            )

        graph_strs = tokenizer.batch_decode(batch['input_ids'], skip_special_tokens=True)
        print(f"graph strs: {graph_strs}")
        response_strs = tokenizer.batch_decode(output[:, prompt.shape[-1]-1:], skip_special_tokens=True)

        for i in range(B):
            graph = ast.literal_eval(graph_strs[i])
            try:
                top_ord = ast.literal_eval(response_strs[i].strip())
                metric = is_topological_ordering(ordering=top_ord, dag=graph, n=n)
                formatting_reward = True
            except:
                if verbose:
                    print(f"qwen outputted illegal response {response_strs[i]}")
                formatting_reward = False
                metric = False

            if metric:
                correct_metric += 1
            if formatting_reward:
                correct_format += 1

            if verbose:
                print(f"[{idx*B+i}] Graph: {graph_strs[i]}")
                print(f"[{idx*B+i}] Response: {response_strs[i]}")
                print(f"[{idx*B+i}] Metric: {metric}, formatted correctly? {formatting_reward}\n")
                print("=" * 80)

    pct_metric = correct_metric / total if total > 0 else 0.0
    pct_format = correct_format / total if total > 0 else 0.0

    if verbose:
        print(f"\n% correct topological ordering: {pct_metric:.2%}")
        print(f"% correctly formatted: {pct_format:.2%}")

    return pct_metric, pct_format
           
prefix_text = f"""
    You are doing topological sort of a Directed Acyclic Graph (DAG). You are given an edge list,
    a list of tuples where each tuple represents (from_node, to_node). Return a valid topological ordering of the nodes.
    
    {prompt_examples}\n
    Output ONLY a TOPOLOGICAL SORT in the format of those above, with NO CODE OR EXPLANATIONS. JUST the ordering
    Graph:
    """
suffix_text= "\nTopological Sort: ["
prefix = tokenizer(prefix_text, return_tensors="pt", add_special_tokens=False)['input_ids']
suffix = tokenizer(suffix_text, return_tensors="pt", add_special_tokens=False)['input_ids']
print(prefix.shape)
print(suffix.shape)

model = model.to(device)
prefix = prefix.to(device)
suffix = suffix.to(device)

# ok

stopping_criteria = StoppingCriteriaList([StopOnToken(']', tokenizer)])

eval_func(model=model, 
          loader=test_loader, 
          prefix=prefix,
          suffix=suffix,
          tokenizer=tokenizer,
          stopping_criteria = stopping_criteria,
          eval_metric=is_topological_ordering,
          n=n,
          verbose=False)


torch.Size([1, 375])
torch.Size([1, 6])
graph strs: ['[(1, 9), (4, 9), (8, 2), (3, 0), (7, 6), (1, 3), (2, 9), (7, 0), (8, 6), (3, 6)]', '[(8, 0), (6, 4), (5, 0), (6, 9), (7, 5), (7, 4), (3, 2), (6, 8), (7, 1), (9, 7)]', '[(7, 0), (5, 4), (7, 1), (9, 4), (4, 8), (4, 7), (6, 8), (3, 0), (6, 0), (4, 1)]', '[(9, 4), (5, 3), (4, 7), (5, 2), (5, 4), (8, 3), (9, 3), (2, 8), (7, 2), (0, 4)]', '[(8, 0), (7, 9), (9, 5), (0, 6), (8, 6), (2, 1), (6, 4), (6, 3), (0, 3), (0, 7)]', '[(6, 9), (2, 8), (3, 0), (2, 4), (1, 9), (8, 7), (3, 4), (0, 8), (8, 9), (2, 5)]', '[(3, 9), (3, 6), (5, 3), (3, 4), (0, 6), (8, 2), (4, 6), (0, 9), (2, 1), (5, 7)]', '[(7, 6), (3, 9), (8, 4), (8, 9), (4, 6), (4, 9), (8, 0), (3, 2), (9, 5), (0, 9)]', '[(0, 6), (2, 1), (4, 7), (1, 3), (8, 9), (4, 3), (0, 1), (1, 6), (4, 0), (7, 3)]', '[(6, 7), (9, 1), (3, 4), (6, 1), (1, 4), (9, 2), (2, 0), (8, 2), (2, 5), (2, 4)]', '[(7, 1), (3, 9), (2, 1), (5, 1), (8, 0), (8, 3), (7, 3), (5, 4), (2, 0), (7, 9)]', '[(3, 8), (3, 9), (6, 5

(0.052083333333333336, 0.4270833333333333)

In [6]:
# DIAGNOSTIC - run this alone
test_str = "8, 2, 7, 5, 4, 1, 3, 6, 0, 9]"
tokens = tokenizer.encode(test_str, add_special_tokens=False)
print("tokens:", tokens)
print("decoded individually:", [tokenizer.decode([t]) for t in tokens])
print("last token id:", tokens[-1])
print("last token decoded:", tokenizer.decode([tokens[-1]]))

tokens: [23, 11, 220, 17, 11, 220, 22, 11, 220, 20, 11, 220, 19, 11, 220, 16, 11, 220, 18, 11, 220, 21, 11, 220, 15, 11, 220, 24, 60]
decoded individually: ['8', ',', ' ', '2', ',', ' ', '7', ',', ' ', '5', ',', ' ', '4', ',', ' ', '1', ',', ' ', '3', ',', ' ', '6', ',', ' ', '0', ',', ' ', '9', ']']
last token id: 60
last token decoded: ]


In [40]:
# okay, your eval harness wors. now you have time to go quickly eat and ten you can do GRPO

def teacher_force_logprobs(model, traj, use_grad, prompt_length):
    ctx = torch.enable_grad() if use_grad else torch.no_grad()
    with ctx:
        logits = model(traj)['logits'][:, prompt_length-1:-1, :]
        logprobs = F.log_softmax(logits, dim=-1)
        logprobs = logprobs.gather(dim=-1, index=traj[:, prompt_length:].unsqueeze(-1))
    return logprobs.squeeze(-1)


def grpo_train_step(model: AutoModelForCausalLM, 
                    batch: torch.Tensor, 
                    prefix: torch.Tensor,
                    stopping_criteria: StoppingCriteriaList,
                    suffix: torch.Tensor,
                    tokenizer: AutoTokenizer,
                    group_size: int = 32,
                    verbose: bool = False
                    ):
    
    for sample_id in range(batch.shape[0]):

        # form prompt
        sample = einops.rearrange(batch[sample_id, :], 'T -> 1 T')
        prompt = torch.cat([prefix, sample, suffix], dim = -1)

        # assemble group from prompt
        group = einops.repeat(prompt, '1 T -> G T', G=group_size)

        # pass through model
        output = model.generate(
            group,
            max_new_tokens=100,
            do_sample=True,
            stopping_criteria=stopping_criteria,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
        
        # decode the otuputs and determine correctness
        graph_str = tokenizer.decode(batch[sample_id,:], skip_special_tokens=True)
        graph = ast.literal_eval(graph_str.strip())
        response_strs = tokenizer.batch_decode(output[:, prompt.shape[-1]-1:], skip_special_tokens=True)
        correct_sorts = 0
        formatted_correctly = 0
        rewards = torch.zeros(size=(group_size,))

        for i in range(group_size):
            graph = ast.literal_eval(graph_str.strip())
            try:
                top_ord = ast.literal_eval(response_strs[i].strip())
                metric = is_topological_ordering(ordering=top_ord, dag=graph, n=n)
                
                formatted_correctly += 1
                if metric == True:
                    correct_sorts += 1
                    rewards[i] = 1.0
                else: # in this case the sort was wrong but the response was well-formed. assign a small correctness reward
                    rewards[i] = 0.1
                if verbose: 
                    formatting_reward = True
                    
            except:
                if verbose:
                    print(f"qwen outputted illegal response {response_strs[i]}")
                    formatting_reward = False
                continue
                
            if verbose:
                print(f"[group {sample_id}] Graph: {graph_str}")
                print(f"[group {sample_id}] Response: {response_strs[i]}")
                print(f"[group {sample_id}] Metric: {metric}, formatted correctly? {formatting_reward}\n")
                print("=" * 80)
                
        # calculate group score
        print(f"group {sample_id} correct (%): {100*(correct_sorts / group_size)}, formatting valid (%): {100*(formatted_correctly / group_size)}")

        ##################################
        # === Compute advantages ===
        A_i = rewards - rewards.mean()  # (G,)
        A_i = A_i.to(model.device)

        # === Forward pass with gradients to get log-probs ===
        prompt_len = prompt.shape[-1]
        
        # Gradient accumulation to avoid OOM
        mini_batch_size = 1
        total_loss = 0.0
        total_tokens = 0
        
        for start in range(0, group_size, mini_batch_size):
            end = min(start + mini_batch_size, group_size)
            chunk = output[start:end]
            chunk_advantages = A_i[start:end]
            
            log_probs = teacher_force_logprobs(model=model, traj=chunk, use_grad=True, prompt_length=prompt.shape[-1]-1)
            
            # Mask out padding
            response_tokens = chunk[:, prompt_len:]
            pad_mask = (response_tokens != tokenizer.pad_token_id).float()
            
            # Loss for this chunk
            chunk_loss = -(log_probs * chunk_advantages.unsqueeze(-1) * pad_mask).sum()
            chunk_tokens = pad_mask.sum().item()
            total_tokens += chunk_tokens
            
            chunk_loss.backward()
        
        total_loss = total_loss  # for logging only
        
        print(f"group {sample_id} | loss computed | advantages: {A_i.tolist()[:5]}...")


    # calculate the loss (this is going to go inside a trainer class)
        
       
    
batch = next(iter(train_loader))
# batch = batch['input_ids']
batch = batch['input_ids'][0,:].unsqueeze(0)
batch = batch.to(model.device)
grpo_train_step(
    model=model,
    batch=batch,
    prefix=prefix,
    stopping_criteria=stopping_criteria,
    suffix=suffix,
    tokenizer=tokenizer
)
    

    # form prompt 
    
    # project into group 
    # pass group through model (do_sample=True right?) 
    # calculate rewards

OutOfMemoryError: CUDA out of memory. Tried to allocate 26.00 MiB. GPU 0 has a total capacity of 39.49 GiB of which 3.56 MiB is free. Including non-PyTorch memory, this process has 39.48 GiB memory in use. Of the allocated memory 38.86 GiB is allocated by PyTorch, and 130.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)