<a href="https://colab.research.google.com/github/HWAN722/self-improvement/blob/main/Agentic_RAG(Research_R1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

A simplified version of Agentic RAG using **Search-R1**.

> A plain policy-gradient update is used here to approximate GRPO.


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import Adam

# Single source of truth for the checkpoint id, so the model and the tokenizer
# can never be loaded from two different checkpoints by accident.
MODEL_NAME = "Qwen/Qwen2.5-7B"

# Module-level policy model and tokenizer; `train_rl` reads `tokenizer` as a
# global.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class SearchEngine:
    """Placeholder retrieval backend used by the trajectory rollout.

    The `search` method is an unimplemented stub; a real backend is expected
    to replace it.
    """

    def search(self, query):
        """Look up `query`; the stub ignores it and returns ``None``."""
        # search and return outcome
        return None

def _sequence_log_prob(model, sequences, prompt_len):
    """Return the summed log-probability of the tokens after `prompt_len`.

    Runs one forward pass over the full sequence, so the result carries
    gradients whenever autograd is enabled in the calling context.
    """
    logits = model(input_ids=sequences).logits
    # The logits at position t score the token at position t + 1, hence the
    # one-step shift between the logit slice and the target slice.
    predicted = torch.log_softmax(logits[0, prompt_len - 1:-1].float(), dim=-1)
    targets = sequences[0, prompt_len:].unsqueeze(-1)
    return predicted.gather(-1, targets).squeeze(-1).sum()


def generate_trajectory(model, tokenizer, question, search_engine, max_steps=8):
    """Roll out one search-augmented episode for `question`.

    Each turn the model continues the current state. If the continuation
    contains ``<search>``, the query is extracted, sent to `search_engine`,
    and the continuation plus the results are appended to the state;
    otherwise the continuation is treated as the final answer and the
    episode ends.

    Args:
        model: Causal LM exposing `generate`, `__call__` (returning
            `.logits`) and `.device`.
        tokenizer: Tokenizer matching `model`.
        question: Initial prompt text.
        search_engine: Object with a `search(query)` method; a ``None``
            result is treated as an empty string.
        max_steps: Upper bound on the number of generation turns, so a model
            that keeps emitting ``<search>`` cannot loop forever.

    Returns:
        A tuple `(trajectory, actions, log_probs)`: the generated text of
        each turn (without the prompt), the matching ``"search"`` /
        ``"answer"`` labels, and a 1-D tensor holding the summed token
        log-probability of each turn. The tensor is differentiable when the
        caller has autograd enabled. If `max_steps` is exhausted, the last
        action is ``"search"`` rather than ``"answer"``.
    """
    trajectory = []
    actions = []
    log_probs = []
    state = question

    for _ in range(max_steps):
        # inference
        inputs = tokenizer(state, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        outputs = model.generate(
            **inputs, max_new_tokens=100, return_dict_in_generate=True
        )

        # `sequences` holds prompt + completion. Decode only the completion:
        # otherwise the prompt is echoed into the state every turn, and a
        # "<search>" from an earlier turn would make every later turn look
        # like a search request.
        new_tokens = outputs.sequences[0, prompt_len:]
        step_output = tokenizer.decode(new_tokens)
        trajectory.append(step_output)

        # Real log-probability of the sampled completion (not the max softmax
        # probability of the last token), kept as a tensor so it can be
        # back-propagated through.
        log_probs.append(_sequence_log_prob(model, outputs.sequences, prompt_len))

        # check whether search again
        if "<search>" not in step_output:
            actions.append("answer")
            break

        actions.append("search")
        query = extract_search_query(step_output)
        # A backend that returns None must not break the concatenation below.
        search_results = search_engine.search(query) or ""
        state = state + step_output + search_results

    if not log_probs:
        # Only reachable with max_steps <= 0.
        return trajectory, actions, torch.empty(0)
    return trajectory, actions, torch.stack(log_probs)


# define RL
def train_rl(model, dataset, search_engine, epochs=3, lr=1e-5):
    """Fine-tune `model` with a REINFORCE-style policy-gradient loop.

    For every `(question, answer)` pair, one trajectory is sampled, the last
    step is scored against `answer` with `compute_reward`, and the loss
    ``-mean(log_probs * reward)`` is minimised with Adam.

    Relies on the module-level `tokenizer`.

    Args:
        model: Policy model to optimise; updated in place.
        dataset: Iterable of `(question, answer)` pairs. It is iterated once
            per epoch, so a one-shot generator only feeds the first epoch.
        search_engine: Backend passed through to `generate_trajectory`.
        epochs: Number of passes over `dataset`.
        lr: Adam learning rate.

    Returns:
        The same `model` object after training.
    """
    optimizer = Adam(model.parameters(), lr=lr)

    for _ in range(epochs):
        for question, answer in dataset:
            # Do NOT wrap this in torch.no_grad(): the returned log-probs
            # must stay attached to the graph, otherwise backward() below
            # fails because nothing requires grad.
            trajectory, actions, log_probs = generate_trajectory(
                model, tokenizer, question, search_engine
            )

            # cal rewards
            final_answer = trajectory[-1]
            reward = compute_reward(final_answer, answer)

            # cal gradient loss
            policy_loss = -torch.mean(log_probs * reward)

            # update model
            optimizer.zero_grad()
            policy_loss.backward()
            optimizer.step()

    return model

def extract_search_query(text):
    """Return the query inside the first ``<search>...</search>`` pair.

    The closing tag is looked up only after the first opening tag, so a stray
    ``</search>`` earlier in `text` does not hide a valid query.

    Args:
        text: Model output expected to contain ``<search>query</search>``.

    Returns:
        The query with surrounding whitespace stripped, or ``""`` when the
        opening tag is missing or is never closed.
    """
    # suppose format: <search>query</search>
    start_tag = "<search>"
    end_tag = "</search>"

    _, opened, remainder = text.partition(start_tag)
    if not opened:
        return ""

    query, closed, _ = remainder.partition(end_tag)
    if not closed:
        return ""
    return query.strip()

def compute_reward(prediction, ground_truth):
    """Score a prediction by exact match against the reference.

    Returns ``1.0`` when `prediction` equals `ground_truth` and ``0.0``
    otherwise; no normalisation is applied before comparing.
    """
    return 1.0 if prediction == ground_truth else 0.0