In [8]:
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import torch

In [13]:
base_policy_id = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr"
reward_model_id = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"
trained_value_model_id = "Prathyusha101/pythia_1b_100_epochs_0.005"

In [None]:
reward_model_id = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
trained_value_model_id = AutoModelForSequenceClassification.from_pretrained(trained_value_model_id)

In [14]:
tokenizer = AutoTokenizer.from_pretrained(base_policy_id)
tokenizer.context_length = 2048

In [8]:
def first_true_indices(bools: torch.Tensor, dtype=torch.long):
    """
    Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving
    the position of the first True in each "row".

    Returns the length of the rows (bools.size(-1)) if no element is True in a given row.

    Args:
        bools (`torch.Tensor`):
            An N-dimensional boolean tensor.
        dtype (`torch.dtype`, optional):
            The desired data type of the output tensor. Defaults to `torch.long`.

    Returns:
        `torch.Tensor`:
            An (N-1)-dimensional tensor of integers indicating the position of the first True
            in each row. If no True value is found in a row, returns the length of the row.
    """
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values

In [9]:
def get_reward(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the reward logits and the rewards for a given model and query responses.

    Args:
        model (`torch.nn.Module`):
            The model used to compute the reward logits.
        query_responses (`torch.Tensor`):
            The tensor containing the query responses.
        pad_token_id (`int`):
            The token ID representing the pad token.
        context_length (`int`):
            The length of the context in the query responses.

    Returns:
        tuple:
            - `reward_logits` (`torch.Tensor`):
                The logits for the reward model.
            - `final_rewards` (`torch.Tensor`):
                The final rewards for each query response.
            - `sequence_lengths` (`torch.Tensor`):
                The lengths of the sequences in the query responses.
    """
    attention_mask = query_responses != pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
    lm_backbone = getattr(model, model.base_model_prefix)
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    output = lm_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,  # otherwise mistral-based RM would error out
    )
    reward_logits = model.score(output.hidden_states[-1])
    sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    return (
        reward_logits,
        reward_logits[
            torch.arange(reward_logits.size(0), device=reward_logits.device),
            sequence_lengths,
        ].squeeze(-1),
        sequence_lengths,
    )


In [6]:
import pandas as pd
df = pd.read_csv("base_model_with_everything_corrected.csv", usecols=["prompt", "response"])

In [None]:
tokenizer.pad_token 

NameError: name 'tokenizer' is not defined

In [11]:
tokenized = tokenizer(
    text=df["prompt"].tolist(),
    text_pair=df["response"].tolist(),
    padding=True,
    truncation=True,
    max_length=2048,
    return_tensors="pt",
)

NameError: name 'tokenizer' is not defined

In [40]:
import torch

def process_rewards_rowwise(df, model, tokenizer, max_length=2048):
    reward_scores = []
    sequence_lengths = []
    reward_logits_list = []

    model_device = next(model.parameters()).device  # get model device

    for idx, row in df.iterrows():
        prompt = row["prompt"]
        response = row["response"]

        # Tokenize prompt + response pair
        tokenized = tokenizer(
            text=prompt,
            text_pair=response,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"].to(model_device)

        # Tokenize prompt alone to get context length
        prompt_tokenized = tokenizer(
            prompt,
            padding=False,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        context_length = prompt_tokenized["input_ids"].size(1)  # prompt token length

        # Call get_reward (input_ids shape: [1, seq_len])
        reward_logits, scores, sequence_length = get_reward(
            model,
            input_ids,
            tokenizer.pad_token_id,
            context_length
        )

        # Store outputs as scalars or lists
        reward_scores.append(scores.item())
        sequence_lengths.append(sequence_length.item())
        reward_logits_list.append(reward_logits.squeeze(0).detach().cpu().tolist())  # full logits per token

    # Add columns to DataFrame
    df["reward_score"] = reward_scores
    df["sequence_length"] = sequence_lengths
    df["reward_logits"] = reward_logits_list

    return df


In [41]:
df = process_rewards_rowwise(df, trained_value_model_id, tokenizer)

In [43]:
df.rename(columns={"reward_score": "RM_reward_score", "reward_logits" : "RM_tokenwise_rewards"}, inplace=True)

In [45]:
df.to_csv("base_model_with_RM_details", index=False)

In [15]:
df

Unnamed: 0,prompt,response
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc..."
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'..."
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b..."
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...
...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve..."
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see..."
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them..."


In [25]:
import pandas as pd
df = pd.read_csv("base_model_with_RM_details.csv")

In [27]:
df

Unnamed: 0,prompt,response,input_ids,RM_reward_score,sequence_length,RM_tokenwise_rewards
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.561722,544,"[[0.16722245514392853], [0.6319139003753662], ..."
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 1585...",0.730876,323,"[[0.16722245514392853], [0.6319139003753662], ..."
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",1.547613,234,"[[0.16722245514392853], [0.6319139003753662], ..."
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 4016...",0.761009,398,"[[0.16722245514392853], [0.6319139003753662], ..."
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",1.390531,245,"[[0.16722245514392853], [0.6319139003753662], ..."
...,...,...,...,...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 4016...",0.759217,472,"[[0.16722245514392853], [0.6319139003753662], ..."
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.903624,417,"[[0.16722245514392853], [0.6319139003753662], ..."
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2727...",0.913609,454,"[[0.16722245514392853], [0.6319139003753662], ..."
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.763344,470,"[[0.16722245514392853], [0.6319139003753662], ..."


In [28]:
df.rename(columns={"RM_tokenwise_rewards" : "RM_tokenwise_values"}, inplace=True)

In [30]:
import ast


def parse_tokenwise_value(row, device="cpu"):
    """
    Parses a single string of RM_tokenwise_values from a DataFrame row.
    Returns a 1D tensor of floats.
    """
    try:
        nested_list = ast.literal_eval(row)  # [[x], [y], ...]
        flat_list = [x[0] for x in nested_list]
        return torch.tensor(flat_list, device=device)
    except Exception as e:
        raise ValueError(f"Error parsing row: {row[:100]}...") from e

def parse_tokenwise_column(df, column="RM_tokenwise_values", device="cpu"):
    """
    Applies the parse_tokenwise_value function to the entire column.
    Returns a new Series of tensors.
    """
    return df[column].apply(lambda row: parse_tokenwise_value(row, device=device))


In [31]:
new_col = parse_tokenwise_column(df)
df["RM_tokenwise_values"] = new_col


In [32]:
def safe_sum_tokenwise_values(value):
    """
    Handles both raw string and already-parsed tensor formats.
    """
    if isinstance(value, str):
        try:
            nested_list = ast.literal_eval(value)  # [[x], [y], ...]
            flat_list = [x[0] for x in nested_list]
            return sum(flat_list)
        except Exception as e:
            raise ValueError(f"Failed to parse string row: {value[:100]}...") from e

    elif isinstance(value, torch.Tensor):
        return value.sum().item()

    elif isinstance(value, list):  # in case it's a list of floats already
        return sum(value)

    else:
        raise TypeError(f"Unexpected type for RM_tokenwise_values: {type(value)}")


def add_tokenwise_sum_column(df, source_col="RM_tokenwise_values", new_col="RM_tokenwise_sum"):
    """
    Adds a new column with the sum of tokenwise values per row.
    """
    df[new_col] = df[source_col].apply(safe_sum_tokenwise_values)
    return df


### verifying RM tokenwise values are different, even thought first few values look the same in all rows

In [None]:
add_tokenwise_sum_column(df)

Unnamed: 0,prompt,response,input_ids,RM_reward_score,sequence_length,RM_tokenwise_values,RM_tokenwise_sum
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.561722,544,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1172.383057
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 1585...",0.730876,323,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1611.715820
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",1.547613,234,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",3149.649658
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 4016...",0.761009,398,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1671.090332
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",1.390531,245,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2826.423340
...,...,...,...,...,...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 4016...",0.759217,472,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1531.397705
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...,"[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.903624,417,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2054.581055
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2727...",0.913609,454,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2052.542725
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them...","[6971, 7941, 1703, 37, 1433, 27, 391, 16, 2284...",0.763344,470,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1662.169678


In [34]:
def compute_gae_and_returns(df, gamma=1.0, lam=0.95):
    """
    Computes GAE, TD errors, and returns from a DataFrame containing
    'RM_tokenwise_values' and 'RM_reward_score'.
    
    Parameters:
    - df: DataFrame with required columns
    - gamma: discount factor
    - lam: GAE lambda parameter

    Returns:
    - advantages: torch.Tensor [batch_size, seq_len]
    - td_errors: torch.Tensor [batch_size, seq_len]
    - returns: torch.Tensor [batch_size, seq_len]
    """
    values = torch.stack(df['RM_tokenwise_values'].tolist())  # [B, T]
    rewards = torch.zeros_like(values)
    for i, score in enumerate(df['RM_reward_score']):
        rewards[i, -1] = score  # reward only at the last token

    batch_size, seq_len = values.shape
    advantages_reversed = []
    td_errors_reversed = []
    lastgaelam = torch.zeros(batch_size, device=values.device)

    for t in reversed(range(seq_len)):
        nextvalues = values[:, t + 1] if t < seq_len - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        td_errors_reversed.append(delta)
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)

    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    td_errors = torch.stack(td_errors_reversed[::-1], dim=1)
    returns = advantages + values
    return advantages, td_errors, returns


In [35]:
advantages, td_errors, returns = compute_gae_and_returns(df)
df["RM_advantages"] = advantages.tolist()
df["RM_td_errors"] = td_errors.tolist()
df["RM_returns"] = returns.tolist()

In [36]:
ten_x_value_model_id =  "10x_lr_20k/value_model"
hun_x_value_model_id =  "100x_lr_20k/value_model"


In [17]:
ten_x_value_model = AutoModelForSequenceClassification.from_pretrained(ten_x_value_model_id)

In [18]:
ten_x_value_model.config.pad_token_id = tokenizer.pad_token_id


In [19]:
vf_critic_backbone = getattr(ten_x_value_model, ten_x_value_model.base_model_prefix)

In [38]:
def parse_input_ids_column_to_lists(df, column="input_ids"):
    """
    Convert a string column of input_ids like '[1, 2, 3]' into actual Python lists of ints.
    """
    df[column] = df[column].apply(ast.literal_eval)
    return df

def convert_input_ids_to_tensors(df, column="input_ids"):
    """
    Convert a DataFrame column of input_ids from list of ints to torch.Tensor.
    """
    df[column] = df[column].apply(lambda x: torch.tensor(x, dtype=torch.long))
    return df


parse_input_ids_column_to_lists(df)
convert_input_ids_to_tensors(df)

Unnamed: 0,prompt,response,input_ids,RM_reward_score,sequence_length,RM_tokenwise_values,RM_tokenwise_sum,RM_advantages,RM_td_errors,RM_returns
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.561722,544,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1172.383057,"[0.7460709810256958, 0.29618895053863525, 0.32...","[0.4646914601325989, -0.008473753929138184, 0....","[0.9132934212684631, 0.9281028509140015, 0.944..."
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.730876,323,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1611.715820,"[0.8514430522918701, 0.4071069359779358, 0.437...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0186655521392822, 1.0390207767486572, 1.060..."
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b...","[tensor(6971), tensor(7941), tensor(1703), ten...",1.547613,234,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",3149.649658,"[1.0517261028289795, 0.6179311871528625, 0.659...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2189486026763916, 1.249845027923584, 1.2828..."
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.761009,398,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1671.090332,"[0.8933334946632385, 0.4512021541595459, 0.483...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0605559349060059, 1.083116054534912, 1.1073..."
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...,"[tensor(6971), tensor(7941), tensor(1703), ten...",1.390531,245,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2826.423340,"[0.8885331153869629, 0.44614914059638977, 0.47...","[0.4646914601325989, -0.008473753929138184, 0....","[1.055755615234375, 1.0780630111694336, 1.1019..."
...,...,...,...,...,...,...,...,...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.759217,472,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1531.397705,"[0.9096167087554932, 0.4683423638343811, 0.501...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0768392086029053, 1.1002562046051025, 1.125..."
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.903624,417,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2054.581055,"[0.9705684781074524, 0.5325021147727966, 0.569...","[0.4646914601325989, -0.008473753929138184, 0....","[1.1377909183502197, 1.1644160747528076, 1.192..."
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.913609,454,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2052.542725,"[0.9002393484115601, 0.45847150683403015, 0.49...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0674618482589722, 1.0903854370117188, 1.114..."
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.763344,470,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1662.169678,"[1.0358400344848633, 0.6012091040611267, 0.641...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2030625343322754, 1.2331230640411377, 1.265..."


In [58]:
vf_critic_backbone.config.pad_token_id = tokenizer.pad_token_id

In [1]:
import numpy as np

In [39]:
from tqdm import tqdm

def get_per_token_values_from_prompt_response(
    df,
    tokenizer,
    value_model,
    prompt_col="prompt",
    response_col="response",
    new_col="VF_pred_value",
    device="cuda"
):
    """
    Returns per-token value predictions from value_model.score(hidden_states).
    Stores the list of values (1 per token) in a new DataFrame column.
    """
    value_model.eval()
    value_model.to(device)
    critic_backbone = getattr(value_model, value_model.base_model_prefix)

    all_token_values = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Scoring tokens"):
        text = row[prompt_col] + row[response_col]
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        with torch.no_grad():
            output = critic_backbone(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            hidden_states = output.hidden_states[-1]  # [1, seq_len, hidden_dim]
            value_logits = value_model.score(hidden_states).squeeze(0).squeeze(-1)  # [seq_len]

        all_token_values.append(value_logits.detach().cpu().tolist())

    df[new_col] = all_token_values
    return df


In [40]:
df = get_per_token_values_from_prompt_response(df, tokenizer, ten_x_value_model)

Scoring tokens: 100%|██████████| 100/100 [00:08<00:00, 12.47it/s]


In [43]:
len(df["VF_pred_value"][0]), len(df["RM_returns"][0]), len(df["RM_tokenwise_values"][0])

(545, 2048, 2048)

In [41]:
df

Unnamed: 0,prompt,response,input_ids,RM_reward_score,sequence_length,RM_tokenwise_values,RM_tokenwise_sum,RM_advantages,RM_td_errors,RM_returns,VF_pred_value
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.561722,544,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1172.383057,"[0.7460709810256958, 0.29618895053863525, 0.32...","[0.4646914601325989, -0.008473753929138184, 0....","[0.9132934212684631, 0.9281028509140015, 0.944...","[3.5257740020751953, 3.05199933052063, 3.61255..."
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.730876,323,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1611.715820,"[0.8514430522918701, 0.4071069359779358, 0.437...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0186655521392822, 1.0390207767486572, 1.060...","[3.5257749557495117, 3.051999568939209, 3.6125..."
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b...","[tensor(6971), tensor(7941), tensor(1703), ten...",1.547613,234,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",3149.649658,"[1.0517261028289795, 0.6179311871528625, 0.659...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2189486026763916, 1.249845027923584, 1.2828...","[3.5257749557495117, 3.051999807357788, 3.6125..."
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.761009,398,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1671.090332,"[0.8933334946632385, 0.4512021541595459, 0.483...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0605559349060059, 1.083116054534912, 1.1073...","[3.5257749557495117, 3.051999807357788, 3.6125..."
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...,"[tensor(6971), tensor(7941), tensor(1703), ten...",1.390531,245,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2826.423340,"[0.8885331153869629, 0.44614914059638977, 0.47...","[0.4646914601325989, -0.008473753929138184, 0....","[1.055755615234375, 1.0780630111694336, 1.1019...","[3.5257749557495117, 3.051999807357788, 3.6125..."
...,...,...,...,...,...,...,...,...,...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.759217,472,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1531.397705,"[0.9096167087554932, 0.4683423638343811, 0.501...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0768392086029053, 1.1002562046051025, 1.125...","[3.5257749557495117, 3.051999568939209, 3.6125..."
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.903624,417,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2054.581055,"[0.9705684781074524, 0.5325021147727966, 0.569...","[0.4646914601325989, -0.008473753929138184, 0....","[1.1377909183502197, 1.1644160747528076, 1.192...","[3.5257749557495117, 3.051999807357788, 3.6125..."
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.913609,454,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2052.542725,"[0.9002393484115601, 0.45847150683403015, 0.49...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0674618482589722, 1.0903854370117188, 1.114...","[3.5257749557495117, 3.051999568939209, 3.6125..."
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.763344,470,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1662.169678,"[1.0358400344848633, 0.6012091040611267, 0.641...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2030625343322754, 1.2331230640411377, 1.265...","[3.5257749557495117, 3.051999568939209, 3.6125..."


In [44]:
import torch.nn.functional as F

def compute_value_loss_mse(df, vf_col="VF_pred_value", rm_col="RM_returns", new_col="value_loss"):
    losses = []

    for idx, row in df.iterrows():
        vf_pred = torch.tensor(row[vf_col], dtype=torch.float32)
        rm_return = torch.tensor(row[rm_col], dtype=torch.float32)[:len(vf_pred)]

        # Compute MSE loss
        loss = F.mse_loss(vf_pred, rm_return, reduction="mean").item()
        losses.append(loss)

    df[new_col] = losses
    return df


In [45]:
df = compute_value_loss_mse(df)


In [46]:
df

Unnamed: 0,prompt,response,input_ids,RM_reward_score,sequence_length,RM_tokenwise_values,RM_tokenwise_sum,RM_advantages,RM_td_errors,RM_returns,VF_pred_value,value_loss
0,SUBREDDIT: r/relationships\n\nTITLE: Me [19 F]...,"Talked to some guy online, he sent me a Snapc...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.561722,544,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1172.383057,"[0.7460709810256958, 0.29618895053863525, 0.32...","[0.4646914601325989, -0.008473753929138184, 0....","[0.9132934212684631, 0.9281028509140015, 0.944...","[3.5257740020751953, 3.05199933052063, 3.61255...",0.333823
1,SUBREDDIT: r/Parenting\n\nTITLE: My 11 year ol...,"My 11 year old, recently deceased friend won'...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.730876,323,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1611.715820,"[0.8514430522918701, 0.4071069359779358, 0.437...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0186655521392822, 1.0390207767486572, 1.060...","[3.5257749557495117, 3.051999568939209, 3.6125...",0.269458
2,SUBREDDIT: r/relationships\n\nTITLE: The girl ...,"A month long friend and I were texting, and b...","[tensor(6971), tensor(7941), tensor(1703), ten...",1.547613,234,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",3149.649658,"[1.0517261028289795, 0.6179311871528625, 0.659...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2189486026763916, 1.249845027923584, 1.2828...","[3.5257749557495117, 3.051999807357788, 3.6125...",0.419698
3,SUBREDDIT: r/tifu\n\nTITLE: TIFU by accidently...,I tripped while going to my grandadas funeral...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.761009,398,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1671.090332,"[0.8933334946632385, 0.4512021541595459, 0.483...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0605559349060059, 1.083116054534912, 1.1073...","[3.5257749557495117, 3.051999807357788, 3.6125...",0.355657
4,SUBREDDIT: r/relationships\n\nTITLE: I [32 M] ...,found out my wife (now wife) had an affair wi...,"[tensor(6971), tensor(7941), tensor(1703), ten...",1.390531,245,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2826.423340,"[0.8885331153869629, 0.44614914059638977, 0.47...","[0.4646914601325989, -0.008473753929138184, 0....","[1.055755615234375, 1.0780630111694336, 1.1019...","[3.5257749557495117, 3.051999807357788, 3.6125...",0.173714
...,...,...,...,...,...,...,...,...,...,...,...,...
95,SUBREDDIT: r/tifu\n\nTITLE: TIFU by inviting t...,"I couldn't Pee, I was an asshole that day eve...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.759217,472,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1531.397705,"[0.9096167087554932, 0.4683423638343811, 0.501...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0768392086029053, 1.1002562046051025, 1.125...","[3.5257749557495117, 3.051999568939209, 3.6125...",0.409522
96,SUBREDDIT: r/relationships\n\nTITLE: My GF [25...,GF lost her job and is unable to find a new o...,"[tensor(6971), tensor(7941), tensor(1703), ten...",0.903624,417,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2054.581055,"[0.9705684781074524, 0.5325021147727966, 0.569...","[0.4646914601325989, -0.008473753929138184, 0....","[1.1377909183502197, 1.1644160747528076, 1.192...","[3.5257749557495117, 3.051999807357788, 3.6125...",0.142988
97,SUBREDDIT: r/offmychest\n\nTITLE: I am sick an...,"My friend, severely depressed, refuses to see...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.913609,454,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",2052.542725,"[0.9002393484115601, 0.45847150683403015, 0.49...","[0.4646914601325989, -0.008473753929138184, 0....","[1.0674618482589722, 1.0903854370117188, 1.114...","[3.5257749557495117, 3.051999568939209, 3.6125...",0.542371
98,"SUBREDDIT: r/relationships\n\nTITLE: Help, ple...","Wife hates wedding photos, does not like them...","[tensor(6971), tensor(7941), tensor(1703), ten...",0.763344,470,"[tensor(0.1672), tensor(0.6319), tensor(0.6234...",1662.169678,"[1.0358400344848633, 0.6012091040611267, 0.641...","[0.4646914601325989, -0.008473753929138184, 0....","[1.2030625343322754, 1.2331230640411377, 1.265...","[3.5257749557495117, 3.051999568939209, 3.6125...",0.158014


In [47]:
df.to_csv("ten_lr_td_value_eval.csv", index=False)