## Imports


In [2]:
import torch
from torch import nn

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from reporter import Reporter

## Data

In [None]:
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Dataset (just 4 rows for prototyping)
dataset = load_dataset("amazon_polarity", split="test[:4]")

In [4]:
# Will be used to form the prompt
def get_prompt(text, label):
    sentiment = ["negative", "positive"][label]
    return f"""Below is a movie review in triple backticks: \
        \n```\n{text}\n``` \
        \n\nThe sentiment of the review is {sentiment}
    """

In [None]:
# Get input texts
pos_input_texts = [get_prompt(text, label=1) + tokenizer.eos_token for text in dataset["content"]]
neg_input_texts = [get_prompt(text, label=0) + tokenizer.eos_token for text in dataset["content"]]

# Get the inputs (input_ids and attention_masks) for the positive and negative texts
pos_inputs = tokenizer(pos_input_texts, return_tensors="pt", padding=True)
neg_inputs = tokenizer(neg_input_texts, return_tensors="pt", padding=True)

pos_inputs["input_ids"].shape, neg_inputs["input_ids"].shape

## Combine the reporter with a language model

In [55]:
class MyRewardModel(nn.Module):
    def __init__(self, language_model, reporter, layer=-1):
        super().__init__()
        
        self.language_model = language_model # e.g. GPT-2
        self.reporter = reporter # EigenReporter, loaded using Reporter.load(path)
        self.layer = layer # which layer to extract

    
    def forward(self, pos_inputs, neg_inputs):
        # Get the hidden states
        pos_hidden_states = self.language_model(
            **pos_inputs, output_hidden_states=True,
        ).hidden_states[self.layer]
        neg_hidden_states = self.language_model(
            **neg_inputs, output_hidden_states=True,
        ).hidden_states[self.layer]
        
        # Find the index of the last non-padding token
        pos_last_token_index = torch.sum(pos_inputs["attention_mask"], dim=1) - 1
        neg_last_token_index = torch.sum(neg_inputs["attention_mask"], dim=1) - 1

        # Get the last token's output
        pos_last_tokens = pos_hidden_states[range(len(pos_last_token_index)), pos_last_token_index]
        neg_last_tokens = neg_hidden_states[range(len(neg_last_token_index)), neg_last_token_index]

        # Get the logits for the two classes
        pos_logits = self.reporter(pos_last_tokens)
        neg_logits = self.reporter(neg_last_tokens)

        # Return the difference in logits which will later be
        # passed through a sigmoid function
        return pos_logits - neg_logits


In [56]:
# Config
LAYER = 12
reporter_path = f"elk-reporters/gpt2/reporters/layer_{LAYER}.pt"

# Load the models
language_model = AutoModelForCausalLM.from_pretrained("gpt2")
reporter = Reporter.load(reporter_path)

# Combine the models
my_reward_model = MyRewardModel(language_model, reporter)

In [57]:
# Run the forward pass
with torch.no_grad():
    credences = my_reward_model(pos_inputs, neg_inputs)

outputs = torch.sigmoid(credences)
outputs

tensor([0.5005, 0.5015, 0.4993, 0.5002])

In [58]:
# Compare ground-truth with predictions
dataset["label"], (outputs > 0.5).int().tolist()

([1, 1, 0, 1], [1, 1, 0, 1])