In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")



In [2]:
# give a model
# give a reward function
# step function
# optimise function

class GRPO_agent():

    def __init__(self, model, tokenizer, chat_template: str, amount_of_answers: int = 5):
        self.model = model
        self.reference_model = None #model
        self.chat_template = chat_template
        self.tokenizer = tokenizer
        self.amount = amount_of_answers

    def get_action(self, prompt):
        # Do I only sample from the value model or also from reference model?
        # Maybe make it optionaL?
        # More effictient way of getting value + better naming?

        messages = [
            {"role": "system", "content": self.chat_template},
            {"role": "user", "content": prompt}
        ]

        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        model_inputs = tokenizer([text] * self.amount, return_tensors="pt", padding=True).to(device)

        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=512,
            num_return_sequences=self.amount
        )

        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return answers

    def optimise(self):
        pass

    def step(self):
        pass

In [3]:
import datasets

dataset = datasets.load_dataset("TIGER-Lab/AceCode-87K", split='train')

In [None]:
print(dataset[0])
print(type(dataset))
# https://huggingface.co/datasets/TIGER-Lab/AceCode-87K
# example dataset https://www.oxen.ai/ox/Rust/file/main/results/GRPO_82_2025-03-02_22-49-17_Qwen2.5-Coder-1.5B-Instruct/results_code_and_tests.parquet



In [4]:
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader

data_loader = DataLoader(dataset,
                         batch_size = 1,
                         shuffle = True
                        )



In [None]:
SYSTEM_PROMPT = """You are a pragmatic Rust programmer who enjoys test driven development. Given the following question, write a Rust function to complete the task. Make the code simple and easy to understand. The code should pass `cargo build` and `cargo clippy`. Try to limit library usage to the standard library std. Be careful with your types, and try to limit yourself to the basic built in types and standard library functions. When writing the function you can think through how to solve the problem and perform reasoning in the comments above the function.

    Then write unit tests for the function you defined. Write multiple unit tests for the function. The tests should be a simple line delimited list of assert! or assert_eq! statements. When writing the unit tests you can have comments specifying what you are testing in plain english. The tests should use super::*.


    An example output should look like the following:

    ```rust
    /// Reasoning goes here
    /// and can be multi-line
    fn add_nums(x: i32, y: i32) -> i32 {
      x + y
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_add_nums() {
            // Test adding positive numbers
            assert_eq!(add_nums(4, 2), 6);
            // Test adding a positive and negative number
            assert_eq!(add_nums(4, -2), 2);
            // Test adding two negative numbers
            assert_eq!(add_nums(-12, -1), -13);
        }
    }
    ```

    Make sure to only respond with a single  ```rust``` block. The unit tests must be defined inside the mod tests {} module. Make sure to import any standard library modules that you need. Do not add a main function.
    """

x = GRPO_agent(model, tokenizer, SYSTEM_PROMPT, 2)

for k, prompt_batch in enumerate(data_loader):
    if k == 4:
        break
    q = prompt_batch["question"][0]
    q.replace("Python", "Rust")


    action = x.get_action(q)

    # In common rl settings we send action to environment and use the output as input for our next run
    # In this case we run output through the model and get the next action
    # We use that to calculate the reward and update the model

    # The "environment" tries to run the code and tests and gives a reward based on the output


    reward, compiler_output = env(action)

    print("---------------------------------------------------------------")
    print(q)
    print("---------------------------------------------------------------")
    for i in action:
        print(i)





In [None]:
import re
from typing import Optional

rustcode = '''```rust
fn sort_list(mut list: Vec<i32>) -> Vec<i32> {
    list.sort();
    list
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sort_list() {
        let unsorted = vec![5, 3, 8, 1, 2];
        let sorted = sort_list(unsorted.clone());
        assert_eq!(sorted, vec![1, 2, 3, 5, 8]);
        assert_eq!(sorted, vec![1, 2, 3, 5, 8]);
        assert_eq!(sorted, vec![1, 2, a3, 5, 8]);
        assert_eq!(sorted, vec![1, 2, a3, 5, 8]);
        assert_eq!(sorted, vec![1, 2123, a3, 5, 8]);
    }
}
```'''

def extract_rust_code(text: str) -> Optional[str]:
    pattern = r'```rust\n(.*?)\n```'
    match = re.search(pattern, text, re.DOTALL)
    print(match)
    if match:
        return match.group(1)
    return None

# How useful is this? To combat reward hacking?
# Maybe just check if non empty
def check_code_not_empty(code: str) -> bool:
    if len(code) > 10:
        return True
    return False

def check_code_block(code: str) -> bool:
    if extract_rust_code(code):
        return True
    return False

def check_test_block(code: str) -> bool:
    pattern = r'(#\[cfg\(test\)\]\s*mod\s+tests\s*\{.*?\})'
    match = re.search(pattern, code, re.DOTALL)
    if match:
        return True
    return False

def response_contains_asserts(code: str) -> float:
    pattern = r'#\[cfg\(test\)\]\s*mod\s+tests\s*\{([^}]*)\}'
    match = re.search(pattern, code, re.DOTALL)

    if not match:
        return 0.0
    
    test_block = match.group(0)

    # Find all assert statements
    assert_pattern = r'assert(?:_eq)?\!(.*?);'
    all_asserts = re.findall(assert_pattern, test_block)
    total_asserts = len(all_asserts)
    
    if total_asserts == 0:
        return 0.0
        
    # Store unique assert statements
    unique_asserts = set(assert_stmt.strip() for assert_stmt in all_asserts)
    
    return len(unique_asserts) / total_asserts

# code running rewards and output rewards
# This should all be in the environment
def get_rewards(code: str):
    total_reward = {"not empty": 0, "code block": 0, "test block": 0, "asserts": 0}
    if check_code_not_empty(code):
        total_reward["not empty"] = 1
    if check_code_block(code):
        total_reward["code block"] = 1
    if check_test_block(code):
        total_reward["test block"] = 1
    total_reward["asserts"] = response_contains_asserts(code)
    return total_reward
    



<re.Match object; span=(0, 522), match='```rust\nfn sort_list(mut list: Vec<i32>) -> Vec<>
True True True 0.6
