# REINFORCE training with explainable rewards

This notebook implements REINFORCE with baseline RLHF method and its training process with explainable reward model. For REINFORCE w/ baseline training `PyTorch` library used. Explainable reward model is implemented in `src/reward.py` like in the paper [Explainable Rewards in RLHF Using LLM-as-a-Judge](https://openreview.net/forum?id=FaOeBrlPst).

### Setup

In [1]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, GenerationConfig, get_linear_schedule_with_warmup, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model
import logging
import time
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
import wandb
from huggingface_hub import login
import random
import numpy as np
import os
import sys
import gc

from src import config, utils, reward
from data import preprocess_helpsteer


for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

gc.collect()
torch.cuda.empty_cache()
torch.manual_seed(config.SEED)
random.seed(config.SEED)
np.random.seed(config.SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = utils.get_device()

login(token=config.HF_TOKEN)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


### Data loading and preparation

In [2]:
logger.info("Loading and preparing RL dataset")
try:
    rl_dataset = preprocess_helpsteer.load_and_prepare_rl_dataset()
    train_ds = rl_dataset["train"].shuffle(seed=config.SEED).select(range(1000))
    val_ds = rl_dataset["test"].shuffle(seed=config.SEED).select(range(100))
    # train_ds = rl_dataset["train"]   # TO_DO: CHANGE TO BIGGER PARTITION FOR FULL TRAINING
    logger.info(f"Loaded {len(train_ds)} training prompts.")

    def collate_fn(batch):
        return {'query': [item['query'] for item in batch]}

    dataloader_train = DataLoader(train_ds, batch_size=config.RL_BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    dataloader_val = DataLoader(val_ds, batch_size=config.RL_BATCH_SIZE, collate_fn=collate_fn)
    logger.info(f"Created DataLoaders with batch size {config.RL_BATCH_SIZE}.")
except Exception as e:
    logger.error(f"Failed to load dataset: {e}")
    raise

INFO - Loading and preparing RL dataset
INFO - Loading dataset: nvidia/HelpSteer2
INFO - Dataset loaded: DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 20324
    })
    validation: Dataset({
        features: ['prompt', 'response', 'helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'],
        num_rows: 1038
    })
})
INFO - Renamed prompt column to 'query'.
INFO - RL Dataset prepared with 20324 training prompts and 1038 test prompts.
INFO - Loaded 1000 training prompts.
INFO - Created DataLoaders with batch size 1.


In [3]:
train_ds

Dataset({
    features: ['query'],
    num_rows: 1000
})

In [4]:
val_ds

Dataset({
    features: ['query'],
    num_rows: 100
})

In [5]:
val_ds[0]

{'query': 'What does OOPS as a development methodology stand for?\n<extra_id_1>Assistant\nOOPS is an acronym for Object-Oriented Programming with Nested Structures. OOPS is a software development methodology that emphasizes the use of objects, classes, and inheritance in software design. OOPS also supports the use of modular programming techniques and the divide-and-conquer algorithm for problem solving. OOPS has gained popularity as a software development methodology because it supports the creation of modular, reusable code that is well-suited for large software projects.\n<extra_id_1>User\nInvestigate any two the divide-and-conquer algorithms. Select one real life problem in which they are applicable. Implement the solution to the problem. Provide both theoretical and experimental assessment of the performance the algorithm.\n<extra_id_1>Assistant\nOne of the most common divide-and-conquer algorithms is the merge sort algorithm. The merge sort algorithm takes an array of items and r

In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)

def tokenize_fn(examples):
    tokens = tokenizer(
        examples["query"],
        padding="max_length",
        truncation=True,
        max_length=config.MAX_SEQ_LENGTH,
    )
    return tokens

In [7]:
train_ds = train_ds.map(tokenize_fn, batched=True, remove_columns=["query"])
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
val_ds = val_ds.map(tokenize_fn, batched=True, remove_columns=["query"])
val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

train_ds

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1000
})

In [8]:
batch = next(iter(torch.utils.data.DataLoader(train_ds, batch_size=2)))
print(type(batch["input_ids"]))

<class 'torch.Tensor'>


In [9]:
# check that dataloader_train works
batch = next(iter(dataloader_train))
print(batch["query"][0])

riddled


### Load models

#### Policy model

In [10]:
logger.info(f"Loading policy model for REINFORCE: {config.BASE_MODEL_NAME}")

policy_model, tokenizer = utils.load_model_and_tokenizer(
    config.BASE_MODEL_NAME,
    load_4bit=True,
    add_lora=True
)

logger.info("Policy model and tokenizer loaded.")
policy_model.print_trainable_parameters()

INFO - Loading policy model for REINFORCE: Qwen/Qwen2.5-0.5B-Instruct
INFO - Loading model: Qwen/Qwen2.5-0.5B-Instruct for mode: causal
INFO - 4-bit quantization enabled.
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO - Prepared model for 4-bit training.
INFO - Initializing new LoRA layers for training.
INFO - Model and tokenizer loading complete.
INFO - Policy model and tokenizer loaded.


trainable params: 8,798,208 || all params: 502,830,976 || trainable%: 1.7497
trainable params: 8,798,208 || all params: 502,830,976 || trainable%: 1.7497


#### Reference model

In [11]:
logger.info(f"Loading reference model: {config.BASE_MODEL_NAME}")
ref_model, _ = utils.load_model_and_tokenizer(
    config.BASE_MODEL_NAME,
    load_4bit=True,
    add_lora=False
)

ref_model.eval()
logger.info("Reference model loaded.")

INFO - Loading reference model: Qwen/Qwen2.5-0.5B-Instruct
INFO - Loading model: Qwen/Qwen2.5-0.5B-Instruct for mode: causal
INFO - 4-bit quantization enabled.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO - Prepared model for 4-bit training.
INFO - Model and tokenizer loading complete.
INFO - Reference model loaded.


#### Reward model

In [12]:
logger.info("Initializing Explainable Reward Model...")
explainable_reward = reward.ExplainableRewardModel(model_name=config.JUDGE_MODEL_NAME, device=config.DEVICE)
logger.info("Reward model initialized.")

INFO - Initializing Explainable Reward Model...
INFO - Initializing explainable RM using judge: Qwen/Qwen2.5-0.5B-Instruct on device cuda
INFO - Loading model: Qwen/Qwen2.5-0.5B-Instruct for mode: causal
INFO - 4-bit quantization enabled.
INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO - Prepared model for 4-bit training.
INFO - Model and tokenizer loading complete.
INFO - Judge model loaded and set to evaluation mode.
INFO - Explainable RM initialized.
INFO - Reward model initialized.


### Training (REINFORCE w/ baseline implementation)

In [13]:
trainable_params = list(filter(lambda p: p.requires_grad, policy_model.parameters()))
optimizer = AdamW(trainable_params, lr=config.REINFORCE_LEARNING_RATE)

num_update_steps_per_epoch = len(dataloader_train) // config.REINFORCE_GRAD_ACCUMULATION
total_training_steps = num_update_steps_per_epoch * config.REINFORCE_NUM_EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.03 * total_training_steps) if total_training_steps > 0 else 0, 
    num_training_steps=total_training_steps if total_training_steps > 0 else 1
)

In [14]:
class MovingAverageBaseline:
    """Simple exponential moving average baseline."""
    def __init__(self, initial_value=0.0, alpha=config.REINFORCE_BASELINE_ALPHA):
        self.value = initial_value
        self.alpha = alpha
        self.initialized = False
        self.steps = 0

    def update(self, rewards: torch.Tensor):
        batch_mean = rewards.mean().item()
        if not self.initialized:
            self.value = batch_mean
            self.initialized = True
        else:
            # Exponential moving average
            self.value = self.alpha * self.value + (1 - self.alpha) * batch_mean
        self.steps += rewards.numel()

    def get(self) -> float:
        return self.value

In [15]:
baseline = MovingAverageBaseline(alpha=config.REINFORCE_BASELINE_ALPHA)

In [16]:
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

logger.info(f"Optimizer, Scheduler, Baseline, GradScaler initialized.")
logger.info(f"Training for {config.REINFORCE_NUM_EPOCHS} epochs.")
logger.info(f"Total expected optimizer steps: {total_training_steps}")

INFO - Optimizer, Scheduler, Baseline, GradScaler initialized.
INFO - Training for 1 epochs.
INFO - Total expected optimizer steps: 500


In [17]:
logger.info("Starting REINFORCE training loop...")

import wandb
wandb.login(key=config.WANDB_API)

if config.LOG_WITH == "wandb":
    try:
        run = wandb.init(
            project="xai-reinforce-explainable",
            name=f"reinforce-{config.BASE_MODEL_NAME.split('/')[-1]}"
        )
        logger.info("WandB initialized.")
    except Exception as e:
        logger.error(f"Failed to initialize WandB: {e}")

INFO - Starting REINFORCE training loop...
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\milya\_netrc
[34m[1mwandb[0m: Currently logged in as: [33mmiliusha2801[0m ([33mmiliusha2801-innopolis-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO - WandB initialized.


In [18]:
generation_config = GenerationConfig(
    max_new_tokens=config.RL_MAX_NEW_TOKENS,
    min_length=-1,
    top_k=config.RL_TOP_K,
    top_p=config.RL_TOP_P,
    temperature=config.RL_TEMPERATURE,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

In [19]:
def generate_responses(prompts, policy_model, tokenizer, generation_config, device):
    """Generates responses for a batch of prompts."""
    prompt_tokens = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=config.MAX_SEQ_LENGTH - config.RL_MAX_NEW_TOKENS
    ).to(device)

    response_ids_list = []
    with torch.no_grad():
        policy_model.eval()
        try:
            outputs = policy_model.generate(
                input_ids=prompt_tokens.input_ids,
                attention_mask=prompt_tokens.attention_mask,
                generation_config=generation_config
            )

            if tokenizer.padding_side == "left":
                response_ids_list = [ out[len(in_id):] for out, in_id in zip(outputs, prompt_tokens.input_ids) ]
            else:
                response_ids = outputs[:, prompt_tokens.input_ids.shape[1]:]
                response_ids_list = [r for r in response_ids]

        except Exception as e:
            logger.error(f"Error during generation: {e}", exc_info=True)
            response_ids_list = [torch.tensor([], dtype=torch.long, device=device) for _ in prompts]
        finally:
            policy_model.train()

    responses = tokenizer.batch_decode(response_ids_list, skip_special_tokens=True)
    return prompt_tokens, response_ids_list, responses

In [20]:
def calculate_log_probs(prompt_tokens, response_ids, policy_model, tokenizer, device):
    """Calculates log probabilities of generated sequences."""
    log_probs_list = []
    batch_size = prompt_tokens.input_ids.shape[0]

    for i in range(batch_size):
        # Combine prompt and response
        single_prompt_ids = prompt_tokens.input_ids[i : i + 1]
        single_response_ids = response_ids[i].unsqueeze(0)

        if single_response_ids.numel() == 0:
            log_probs_list.append(torch.tensor(0.0).to(device))
            continue
        full_ids = torch.cat([single_prompt_ids, single_response_ids], dim=1)

        # Create attention mask for the full sequence
        single_prompt_mask = prompt_tokens.attention_mask[i : i + 1]
        response_mask = torch.ones_like(single_response_ids)
        full_attention_mask = torch.cat([single_prompt_mask, response_mask], dim=1)

        # Truncate
        if full_ids.shape[1] > config.MAX_SEQ_LENGTH:
            full_ids = full_ids[:, :config.MAX_SEQ_LENGTH]
            full_attention_mask = full_attention_mask[:, :config.MAX_SEQ_LENGTH]

        try:
            with torch.autocast(device_type=config.DEVICE):
                model_outputs = policy_model(input_ids=full_ids, attention_mask=full_attention_mask)
                logits = model_outputs.logits

            # Shift logits/labels for calculating log prob of response
            start_index = single_prompt_ids.shape[1] - 1
            end_index = full_ids.shape[1] - 1

            if end_index <= start_index:
                 log_probs_list.append(torch.tensor(0.0).to(device))
                 continue

            shift_logits = logits[:, start_index:end_index, :].contiguous()
            shift_labels = full_ids[:, start_index+1:end_index+1].contiguous()

            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            log_probs_per_token = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).long())

            # Mask padding tokens within the response part
            label_mask = (shift_labels != tokenizer.pad_token_id).float()
            log_probs_per_token = log_probs_per_token.view(shift_labels.shape[0], -1) * label_mask
            sequence_log_probs = log_probs_per_token.sum(dim=1)
            log_probs_list.append(sequence_log_probs.squeeze())

        except Exception as e:
            logger.error(f"Error calculating log probs for sample {i}: {e}", exc_info=True)
            log_probs_list.append(torch.tensor(0.0).to(device))

    if not log_probs_list:
        return torch.tensor([], device=device)
    return torch.stack(log_probs_list)

In [21]:
def calculate_kl_penalty(prompt_tokens, response_ids, policy_model, ref_model, tokenizer, device):
    """Calculates KL penalty between policy and reference model."""
    kl_penalties = []
    batch_size = prompt_tokens.input_ids.shape[0]

    for i in range(batch_size):
        # Combine prompt and response
        single_prompt_ids = prompt_tokens.input_ids[i : i + 1]
        single_response_ids = response_ids[i].unsqueeze(0)

        if single_response_ids.numel() == 0:
            kl_penalties.append(torch.tensor(0.0).to(device))
            continue

        full_ids = torch.cat([single_prompt_ids, single_response_ids], dim=1)
        single_prompt_mask = prompt_tokens.attention_mask[i : i + 1]
        response_mask = torch.ones_like(single_response_ids)
        full_attention_mask = torch.cat([single_prompt_mask, response_mask], dim=1)

        if full_ids.shape[1] > config.MAX_SEQ_LENGTH:
            full_ids = full_ids[:, :config.MAX_SEQ_LENGTH]
            full_attention_mask = full_attention_mask[:, :config.MAX_SEQ_LENGTH]

        start_index = single_prompt_ids.shape[1] - 1
        end_index = full_ids.shape[1] - 1

        if end_index <= start_index:
            kl_penalties.append(torch.tensor(0.0).to(device))
            continue

        try:
            with torch.no_grad(), torch.autocast(device_type=config.DEVICE):
                # Get policy logits
                policy_outputs = policy_model(input_ids=full_ids, attention_mask=full_attention_mask)
                policy_logits_shifted = policy_outputs.logits[:, start_index:end_index, :].contiguous()

                # Get reference logits
                ref_outputs = ref_model(input_ids=full_ids, attention_mask=full_attention_mask)
                ref_logits_shifted = ref_outputs.logits[:, start_index:end_index, :].contiguous()

                # Calculate KL
                prob_policy = F.softmax(policy_logits_shifted, dim=-1)
                log_prob_policy = F.log_softmax(policy_logits_shifted, dim=-1)
                prob_ref = F.softmax(ref_logits_shifted, dim=-1)
                log_prob_ref = F.log_softmax(ref_logits_shifted, dim=-1)

                # Mask needs to match shifted logits shape
                shift_labels = full_ids[:, start_index+1:end_index+1].contiguous()
                label_mask = (shift_labels != tokenizer.pad_token_id).float()

                if label_mask.shape != kl_div_per_token.shape:
                     label_mask = label_mask.view_as(kl_div_per_token)

                # KL divergence: sum_vocab(P_policy * (log P_policy - log P_ref))
                kl_div_per_token = torch.sum(prob_policy * (log_prob_policy - log_prob_ref), dim=-1)
                kl_div_per_token = kl_div_per_token * label_mask
                kl_penalty = kl_div_per_token.sum(dim=1)
                kl_penalties.append(kl_penalty.squeeze())

        except Exception as e:
            logger.error(f"Error calculating KL for sample {i}: {e}", exc_info=True)
            kl_penalties.append(torch.tensor(0.0).to(device))

    if not kl_penalties:
        return torch.tensor([], device=device)
    return torch.stack(kl_penalties)

In [22]:
def compute_advantages(rewards, kl_penalty, baseline_value):
    """Applies KL penalty and calculates advantages."""
    final_rewards = rewards - config.KL_PENALTY_BETA * kl_penalty
    advantages = final_rewards - baseline_value
    return final_rewards, advantages

In [23]:
def reinforce_loss_step(advantages, sequence_log_probs):
    """Calculates the REINFORCE loss for a batch."""
    loss = - (advantages.detach() * sequence_log_probs).mean()
    return loss

In [24]:
logger.info("Starting REINFORCE training loop...")

global_step = 0
policy_model.train()

for epoch in range(config.REINFORCE_NUM_EPOCHS):
    logger.info(f"Epoch {epoch+1}/{config.REINFORCE_NUM_EPOCHS}")
    policy_model.train()

    progress_bar = tqdm(dataloader_train, desc=f"Epoch {epoch+1} Batches")
    step_accumulation_counter = 0
    epoch_metrics = {
        "loss": [], "reward_raw_mean": [], "kl_penalty_mean": [],
        "final_reward_mean": [], "advantage_mean": [], "log_prob_mean": []
    }

    for i, batch in enumerate(progress_bar):
        start_time_batch = time.time()
        prompts = batch["query"]

        # Generate responses
        prompt_tokens, response_ids_list, responses = generate_responses(
            prompts, policy_model, tokenizer, generation_config, device
        )

        # Calculate logprobs
        sequence_log_probs = calculate_log_probs(
            prompt_tokens, response_ids_list, policy_model, tokenizer, device
        )

        # Handle cases where log_prob calculation failed (if returned empty tensor)
        if sequence_log_probs.numel() == 0:
            logger.warning(f"Skipping batch {i} due to log_prob calculation errors.")
            continue

        # Calculate KL penalty
        kl_penalty = calculate_kl_penalty(
            prompt_tokens, response_ids_list, policy_model, ref_model, tokenizer, device
        )
        if kl_penalty.numel() == 0:
             kl_penalty = torch.zeros_like(sequence_log_probs)

        # Compute rewards
        try:
            rewards = explainable_reward(prompts, responses)
            rewards = rewards.to(device).squeeze()
            if rewards.shape != sequence_log_probs.shape:
                 logger.warning(f"Reward shape mismatch ({rewards.shape}) vs log_probs ({sequence_log_probs.shape}). Adjusting.")
                 aligned_rewards = torch.zeros_like(sequence_log_probs)
                 min_len = min(rewards.numel(), sequence_log_probs.numel())
                 aligned_rewards[:min_len] = rewards[:min_len]
                 rewards = aligned_rewards

        except Exception as e:
            logger.error(f"Error getting reward for batch {i}: {e}", exc_info=True)
            rewards = torch.zeros_like(sequence_log_probs)

        # Compute advantages
        current_baseline_value = baseline.get()
        final_rewards, advantages = compute_advantages(rewards, kl_penalty, current_baseline_value)

        # Calculate loss and backpropagate
        loss = reinforce_loss_step(advantages, sequence_log_probs)
        loss_scaled = loss / config.REINFORCE_GRAD_ACCUMULATION
        scaler.scale(loss_scaled).backward()

        step_accumulation_counter += 1

        # Logging
        epoch_metrics["loss"].append(loss.item())
        epoch_metrics["reward_raw_mean"].append(rewards.mean().item())
        epoch_metrics["kl_penalty_mean"].append(kl_penalty.mean().item())
        epoch_metrics["final_reward_mean"].append(final_rewards.mean().item())
        epoch_metrics["advantage_mean"].append(advantages.mean().item())
        epoch_metrics["log_prob_mean"].append(sequence_log_probs.mean().item())

        # Optimizer Step
        if step_accumulation_counter % config.REINFORCE_GRAD_ACCUMULATION == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

            step_time = time.time() - start_time_batch
            avg_metrics = {k: np.mean(v) for k, v in epoch_metrics.items()}
            baseline.update(torch.tensor(epoch_metrics["final_reward_mean"], device=device))

            logs = {
                f"train/{k}": v for k, v in avg_metrics.items()
            }
            logs.update({
                "train/baseline": baseline.get(),
                "train/learning_rate": scheduler.get_last_lr()[0],
                "train/step_time_s": step_time,
                "global_step": global_step + 1,
                "epoch": epoch + (i+1)/len(dataloader_train)
            })

            progress_bar.set_postfix({
                "Reward": f"{avg_metrics['final_reward_mean']:.3f}",
                "Loss": f"{avg_metrics['loss']:.3f}",
                "KL": f"{avg_metrics['kl_penalty_mean']:.3f}",
                "Adv": f"{avg_metrics['advantage_mean']:.3f}"
            })

            if run:
                try:
                    wandb.log(logs)
                except Exception as e:
                    logger.error(f"Failed to log metrics to WandB at step {global_step}: {e}")

            epoch_metrics = {k: [] for k in epoch_metrics}

            global_step += 1

            # Save checkpoint
            if global_step % config.SAVE_FREQ == 0:
                logger.info(f"Saving checkpoint at step {global_step}...")
                local_save_path = f"{config.OUTPUT_DIR}/reinforce/checkpoint-{global_step}"
                try:
                    os.makedirs(os.path.dirname(local_save_path), exist_ok=True)
                    policy_model.save_pretrained(local_save_path)
                    tokenizer.save_pretrained(local_save_path)
                    logger.info(f"REINFORCE checkpoint saved locally to {local_save_path}")
                except Exception as e:
                    logger.error(f"Failed to save checkpoint at step {global_step}: {e}")

            step_accumulation_counter = 0

    # Validation
    logger.info(f"Validation for epoch {epoch+1}")
    policy_model.eval()
    total_val_reward = 0.0
    total_val_samples = 0
    with torch.no_grad():
        for val_batch in tqdm(dataloader_val, desc=f"Validation epoch {epoch+1}", leave=False):
            val_prompts = val_batch["query"]
            # Generate responses for validation
            _, _, val_responses = generate_responses(
                val_prompts, policy_model, tokenizer, generation_config, device
            )

            try:
                # Score responses
                val_rewards = explainable_reward(val_prompts, val_responses)
                total_val_reward += val_rewards.sum().item()
                total_val_samples += val_rewards.numel()
            except Exception as e:
                logger.error(f"Error getting validation reward: {e}")

    avg_val_reward = total_val_reward / total_val_samples if total_val_samples > 0 else 0
    logger.info(f"Epoch {epoch+1} Validation average reward: {avg_val_reward:.4f}")

    if run:
        try:
            wandb.log({"validation/avg_reward": avg_val_reward, "epoch": epoch + 1, "global_step": global_step})
        except Exception as e:
             logger.error(f"Failed to log validation metrics to WandB: {e}")

logger.info("REINFORCE Training loop finished.")

INFO - Starting REINFORCE training loop...
INFO - Epoch 1/1
Epoch 1 Batches:   0%|          | 0/1000 [00:00<?, ?it/s]ERROR - Error during generation: index -1 is out of bounds for dimension 0 with size 0
Traceback (most recent call last):
  File "C:\Users\milya\AppData\Local\Temp\ipykernel_7916\1816318076.py", line 15, in generate_responses
    outputs = policy_model.generate(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "m:\python_projects\interpretable_rewards_in_RLHF\venv\Lib\site-packages\peft\peft_model.py", line 1875, in generate
    outputs = self.base_model.generate(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "m:\python_projects\interpretable_rewards_in_RLHF\venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "m:\python_projects\interpretable_rewards_in_RLHF\venv\Lib\site-packages\transformers\generation\utils.py", line 2465, in generate
    

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn