# Reward Model Training

In [None]:
from transformers import AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
%pip install datasets==3.5.0

In [None]:
from datasets import load_dataset
dataset_name = 'sst2'
dataset = load_dataset(dataset_name)
dataset

In [4]:
ds_train, ds_val = dataset['train'], dataset['validation']

In [5]:
ds_train[4]

{'idx': 4,
 'sentence': 'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up ',
 'label': 0}

## Tokenize the dataset

In [6]:
REWARD_TOKEN_ID = tokenizer.eos_token_id

In [7]:
REWARD_TOKEN_ID

50256

In [None]:
def tokenize(batch):
    outputs = tokenizer(batch['sentence'])
    outputs['score'] = [0] * len(outputs['input_ids'])
    outputs['score_index'] = [0] * len(outputs['input_ids'])
    for i in range(len(outputs['input_ids'])):
        outputs['input_ids'][i].append(REWARD_TOKEN_ID)
        outputs['attention_mask'][i].append(1)
        outputs['score'][i] = float(batch['label'][i])
        outputs['score_index'][i] = len(outputs['input_ids'][i]) - 1
    return outputs

map_kwargs = {
    "batched": True,
    "batch_size": 512,
    "remove_columns": ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)

In [9]:
tokenized_dataset_train[4]

{'input_ids': [261,
  262,
  5290,
  15827,
  12,
  1659,
  12,
  1169,
  12,
  1008,
  9310,
  35478,
  20954,
  262,
  28303,
  714,
  47478,
  469,
  510,
  220,
  50256],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'score': 0.0,
 'score_index': 20}

In [10]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

In [11]:
tokenized_dataset_train[4]

{'input_ids': tensor([  261,   262,  5290, 15827,    12,  1659,    12,  1169,    12,  1008,
          9310, 35478, 20954,   262, 28303,   714, 47478,   469,   510,   220,
         50256]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 'score': tensor(0.),
 'score_index': tensor(20)}

### Filter out shorter tweets

In [None]:
tokenized_dataset_train = tokenized_dataset_train.filter(lambda x: len(x['input_ids']) > 6)
tokenized_dataset_val = tokenized_dataset_val.filter(lambda x: len(x['input_ids']) > 6)

In [13]:
len(tokenized_dataset_train)

49401

## LLM with Reward Head

In [14]:
import torch
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class RewardHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.reward = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(self.reward.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.reward.bias)

    def forward(self, hidden_states):
        return self.reward(hidden_states)

class GPT2RewardHead(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        self.reward_head = RewardHead(self.llm.config)

    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = transformer_outputs.hidden_states[-1]
        reward = self.reward_head(last_hidden_state).squeeze(-1)
        return torch.sigmoid(reward)


In [None]:
model = GPT2RewardHead(model_name)

In [16]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

tokenizer.pad_token = tokenizer.eos_token

data_collator = DataCollatorWithPadding(tokenizer)
dataloader_params = {
    'batch_size': 64,
    'shuffle': True,
    'collate_fn': data_collator
}
train_dataloader = DataLoader(tokenized_dataset_train, **dataloader_params)
val_dataloader = DataLoader(tokenized_dataset_val, **dataloader_params)

In [17]:
batch = next(iter(train_dataloader))
print(batch.keys())

dict_keys(['input_ids', 'attention_mask', 'score', 'score_index'])


In [18]:
print(batch['input_ids'][1])
print(batch['attention_mask'][1])
print(batch['score'][1])
print(batch['score_index'][1])

tensor([ 5661,  4260, 10590, 32251,   287,  2989,   286,  4007,   393,   772,
          257,  7110,   220, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor(0.)
tensor(13)


In [19]:
print(tokenizer.decode(batch['input_ids'][1]))

this alleged psychological thriller in search of purpose or even a plot <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


In [20]:
batch['attention_mask'][1].nonzero()[-1]

tensor([13])

In [21]:
outputs = model(batch['input_ids'], batch['attention_mask'])

In [22]:
print(outputs.shape)

torch.Size([64, 46])


### Training Config

In [23]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCELoss()
num_epochs = 1 # N+ Implementation Detail paper


In [24]:
def validate():
    model.eval()
    total_loss = 0
    for i, batch in enumerate(val_dataloader):
        inputs = batch.to(device)
        model_inputs = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask']
        }
        with torch.no_grad():
            scores = model(**model_inputs)
            batch_indices = torch.arange(scores.shape[0])
            score = scores[batch_indices, inputs['score_index']]
            target = inputs['score']
            loss = criterion(score, target)
        total_loss += loss.item()
    print('validation loss:', total_loss / len(val_dataloader))

### Training Loop

In [25]:
model.to(device)

validate()
for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(train_dataloader):
        inputs = batch.to(device)
        model_inputs = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask']
        }
        scores = model(**model_inputs)
        batch_indices = torch.arange(scores.shape[0])
        score = scores[batch_indices, inputs['score_index']]
        target = inputs['score']
        loss = criterion(score, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
    validate()


validation loss: 4.236975857189724
3.081965923309326
2.295074939727783
0.9389018416404724
1.3304227590560913
2.5907254219055176
2.9665000438690186
2.1143782138824463
1.1857104301452637
0.9820933938026428
1.11643385887146
0.9809982776641846
0.9127055406570435
0.9850586652755737
0.8745125532150269
0.879654049873352
0.9371069669723511
0.8351837396621704
0.7151826620101929
0.7962563037872314
0.8667033910751343
0.8059372305870056
0.8398273587226868
0.8546754121780396
0.8896375894546509
0.8449140191078186
0.720230758190155
0.858842134475708
1.077275276184082
0.7604016661643982
0.7401319742202759
0.7428716421127319
0.8113798499107361
0.7954106330871582
0.7568372488021851
0.7153240442276001
0.8176007270812988
0.8225393295288086
0.8309254050254822
0.7394857406616211
0.823320746421814
0.7697409987449646
0.7335368990898132
0.7643752694129944
0.7798486948013306
0.7756420373916626
0.6968182325363159
0.6697213649749756
0.7042258977890015
0.7872732877731323
0.7080168128013611
0.6795705556869507
0.737

In [26]:
torch.save(model.state_dict(), 'reward_model.pt')

In [27]:
validate()

validation loss: 0.2508827745914459


### Confusion Matrix

In [28]:
from sklearn.metrics import confusion_matrix
model.eval()

all_predictions = []
all_labels = []

for i, batch in enumerate(val_dataloader):
    inputs = batch.to(device)
    model_inputs = {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask']
    }
    with torch.no_grad():
        scores = model(**model_inputs)
        batch_indices = torch.arange(scores.shape[0])
        score = scores[batch_indices, inputs['score_index']]
        target = inputs['score']
    predictions = (score > 0.5).int()

    all_predictions.extend(predictions.cpu().numpy())
    all_labels.extend(target.cpu().numpy())

confusion_matrix(all_labels, all_predictions)

array([[391,  33],
       [ 49, 394]])