# Fine Tune reward model from scratch

# TODOs:

#TODO: double-check that labels are not somehow misaligned...

#TODO: check if you need to plot 

1. LoRA learns the position of the low rank adaptation matrix that is needed to finetune a model of a much higher rank

#TODO: double check model performance, generate output, maybe adjust training metrics

## 1. Imports, setup, and global variables

In [1]:
import torch
import pandas as pd
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from torch.nn import functional as F
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error

from datasets import Dataset, DatasetDict, load_dataset

from peft import LoraConfig, get_peft_model, PeftModel

from utils import parse_ratings

# load the relevant devices available on the server
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("AVAILABLE_DEVICES")

# Enable expandable CUDA segments
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# load cuda
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print("CUDA is available. Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

  from .autonotebook import tqdm as notebook_tqdm


There are 1 GPU(s) available.
CUDA is available. Using GPU: NVIDIA L40S


In [2]:
# load training variables
FEEDBACK_TO_TRAIN_ON = os.getenv("FEEDBACK_TO_TRAIN_ON")
FEEDBACK_TO_REMOVE = os.getenv("FEEDBACK_TO_REMOVE")
MODEL = os.getenv("REWARD_MODEL")
LORA_CHECKPOINTS_FOLDER = os.getenv("LORA_CHECKPOINTS_FOLDER")
FINAL_LORA_ADAPTERS = os.getenv("FINAL_LORA_ADAPTERS_FOLDER") + "_" + FEEDBACK_TO_TRAIN_ON

# load training data
FILE_1 = os.getenv("FILE_1")
FILE_5 = os.getenv("FILE_5")
FILE_7 = os.getenv("FILE_7")
FILE_9 = os.getenv("FILE_9")
FILE_10_1 = os.getenv("FILE_10_1")
FILE_10_2 = os.getenv("FILE_10_2")
FILE_SYNTH = os.getenv("FILE_SYNTH")

## 2. Dataset loading and preprocessing

In [3]:
# load dataframes
df_1 = pd.read_csv(FILE_1, sep=";")
df_5 = pd.read_csv(FILE_5, sep=";")
df_7 = pd.read_csv(FILE_7, sep=";")
df_9 = pd.read_csv(FILE_9, sep=";")
df_10_1 = pd.read_csv(FILE_10_1, sep=";")
df_10_2 = pd.read_csv(FILE_10_2, sep=";")
df_synth = pd.read_csv(FILE_SYNTH, sep=";")

df_human = pd.concat([df_1, df_5, df_7, df_9, df_10_1, df_10_2], ignore_index=True)

In [4]:
df_train = df_human
df_train.shape

(929, 13)

### 2. a) Parse ratings to numeric values for MSE Loss

In [5]:
df_train[FEEDBACK_TO_TRAIN_ON] = [parse_ratings(feedback) for feedback in df_train[FEEDBACK_TO_TRAIN_ON]]
print("Parsed feedback for extraction:", df_train[FEEDBACK_TO_TRAIN_ON][:5])

Parsed feedback for extraction: 0    4
1    4
2    4
3    4
4    4
Name: feedback_detection, dtype: object


### 2. b) keep only relevant feedback column

In [6]:
dataset = Dataset.from_pandas(df_train)

print(dataset)

Dataset({
    features: ['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id', 'precondition_text', 'precondition_position', 'response_text', 'prompt_config_examples', 'prompt_config_chain_of_thought', 'feedback_extraction', 'feedback_detection', 'additional_feedback'],
    num_rows: 929
})


In [7]:
dataset = dataset.remove_columns([FEEDBACK_TO_REMOVE])
dataset = dataset.rename_column(FEEDBACK_TO_TRAIN_ON, "label")

## 3. Load model with LoRA layer

In [8]:
# Load the model and the tokenizer
model_id = MODEL 
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) # num_labels = 1 since we want to prodict a single scalar (the rating)

# Comment: Automodel for sequence classification with num_labels=1 already has a regression head
print(model)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [9]:
print(tokenizer.model_max_length)
print(model.config.max_position_embeddings)

512
512


In [10]:
# Define LoRA config
lora_config = LoraConfig(
    r=8,           # Rank of the LoRA matrices (smaller = less memory)
    lora_alpha=16, # Scaling factor (higher = stronger adaptation)
    target_modules=["query", "key", "value"], # Apply LoRA to attention layers
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS"  # classify each answer 
)

# Convert the model to a PEFT (LoRA) model
model = get_peft_model(model, lora_config)
# model.gradient_checkpointing_enable()
model.print_trainable_parameters()  # Check trainable params (~0.1% of full model)


trainable params: 443,137 || all params: 109,926,146 || trainable%: 0.4031


In [11]:
# Test tokenizer
sample_data = ["What is the capital of France?", "What is the largest capital in the world?"]
tokenizer(sample_data, padding=True, truncation=True, max_length=512)

{'input_ids': [[101, 1067, 223, 207, 580, 210, 1335, 124, 102, 0, 0, 0], [101, 1067, 223, 207, 5601, 190, 580, 213, 207, 1727, 124, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

## 4. Define Custom Trainer to be used for the task

In [12]:
class CustomRewardTrainer(Trainer):
    def __init__(self, *args, loss_type="mse", weight_strategy=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_type = loss_type  # "mse", "huber", or custom
        self.weight_strategy = weight_strategy  # "linear", "inverse", or None

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        # Extract labels (ratings) and optional sample weights
        labels = inputs.pop("labels").float()  # Shape: (batch_size)
        
        # Optional: Compute sample weights dynamically
        weights = self._get_sample_weights(labels) if self.weight_strategy else None
        
        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits.squeeze()  # Shape: (batch_size) --> logits are the predicted rewards in this case
        
        # Custom loss calculation
        loss = self._compute_custom_loss(logits, labels, weights)
        
        return (loss, outputs) if return_outputs else loss

    def _compute_custom_loss(self, logits, labels, weights=None):
        if self.loss_type == "mse":
            loss = F.mse_loss(logits, labels, reduction="none") # --> MSE provides precise regression BUT sensitive to outliers
        elif self.loss_type == "huber":
            loss = F.huber_loss(logits, labels, reduction="none", delta=1.0) #--> balances between MSE and MAE for data that has outliers/ noise
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        # Apply sample weights if provided
        if weights is not None:
            loss = loss * weights
            loss = loss.mean()  # Normalize by mean if weights are unnormalized
        else:
            loss = loss.mean()
        
        return loss

    def _get_sample_weights(self, labels):
        """
        Generate sample weights based on rating values.
        
        
        """
        if self.weight_strategy == "linear":
            # Linear weighting (e.g., emphasize extremes)
            weights = torch.abs(labels - labels.mean()) + 1.0
        elif self.weight_strategy == "inverse":
            # Inverse frequency weighting (if ratings are skewed)
            unique, counts = torch.unique(labels, return_counts=True)
            freq = counts.float() / len(labels)
            weight_map = 1.0 / (freq + 1e-6)  # Avoid division by zero
            weights = torch.tensor([weight_map[(unique == lbl).nonzero().item()] for lbl in labels])
        else:
            weights = None
        
        return weights.to(labels.device) if weights is not None else None



    def compute_metrics(self, eval_preds):
        predictions, labels = eval_preds
        predictions = predictions.squeeze()
        
        # Regression metrics
        mse = mean_squared_error(labels, predictions)
        pearson = pearsonr(labels, predictions)[0] # Pearson correlation coefficient
        
        # Threshold accuracy --> 
        tolerance_acc = (np.abs(predictions - labels) <= 0.5).mean()
        
        return {"mse": mse, "pearson": pearson, "tolerance_acc": tolerance_acc}
    


    
    #TODO: evaluate whether the plotting should be done or whether it is redundant to add them

    # def evaluation_loop(self, *args, **kwargs):
    #     output = super().evaluation_loop(*args, **kwargs)
    #     predictions = output.predictions.squeeze()
    #     labels = output.label_ids
        
    #     # Generate plots (saved to disk or logged to W&B)
    #     plot_distributions(predictions, labels, self.state.epoch)
    #     plot_calibration(predictions, labels)
        
    #     return output

In [13]:
#TODO: debug if this is truly needed...

# add distributioncallback to trianing to evaluate 

# class DistributionCallback(TrainerCallback):
#     def on_evaluate(self, args, state, control, metrics, **kwargs):
#         # Get predictions and labels from the trainer's eval loop
#         eval_results = trainer.evaluate()
#         predictions = eval_results["eval_predictions"]
#         labels = eval_results["eval_labels"]
        
#         # Log histogram to W&B
#         wandb.log({
#             "reward_histogram": wandb.Histogram(predictions),
#             "true_ratings_histogram": wandb.Histogram(labels),
#         })


# def plot_distributions(predictions, labels, epoch):
#     plt.figure(figsize=(10, 4))
#     plt.subplot(1, 2, 1)
#     plt.hist(predictions, bins=20, alpha=0.7, label="Predicted")
#     plt.title("Predicted Rewards")
    
#     plt.subplot(1, 2, 2)
#     plt.hist(labels, bins=20, alpha=0.7, label="True Ratings", color="orange")
#     plt.title("True Ratings")
    
#     plt.savefig(f"distributions_epoch_{epoch}.png")
#     plt.close()

# class PlotCallback(TrainerCallback):
#     def on_evaluate(self, args, state, control, **kwargs):
#         predictions = trainer.predict(test_dataset).predictions.squeeze()
#         labels = test_dataset["ratings"]
#         plot_distributions(predictions, labels, state.epoch)




# def plot_calibration(predictions, labels):
#     """
#     Function to check if the sdfs

    

#     """
#     bin_means = np.linspace(1, 5, num=5)  # For 1-5 ratings
#     bin_centers = []
#     empirical_means = []
    
#     for i in range(len(bin_means) - 1):
#         mask = (labels >= bin_means[i]) & (labels < bin_means[i+1])
#         if mask.sum() > 0:
#             bin_centers.append((bin_means[i] + bin_means[i+1]) / 2)
#             empirical_means.append(predictions[mask].mean())
    
#     plt.plot(bin_centers, empirical_means, marker="o")
#     plt.plot([1, 5], [1, 5], linestyle="--", color="gray")  # Ideal line
#     plt.xlabel("True Rating")
#     plt.ylabel("Predicted Reward")
#     plt.savefig("calibration_plot.png")



### 4. a) Define a custom data collator

In [14]:
# from transformers import DefaultDataCollator

# #TODO: only do this if the labels fix does not work for some reason

# class RewardDataCollator(DefaultDataCollator):
#     def __call__(self, features):

#         ratings = [f.pop("rating") for f in features]  # Removes rating from features temporarily
#         batch = super().__call__(features)
#         # Explicitly ensure rating is included
#         print(features)
#         # Re-inject ratings into the batch
#         batch["rating"] = torch.tensor(ratings, dtype=torch.float32)
#         return batch

## 5. Encode dataset

In [15]:
# if labels are not integers, convert them to integers
def convert_label_to_int(data):
    data["label"] = int(data["label"])
    return data


print(dataset.column_names)
# mao string labels to integers
dataset = dataset.map(convert_label_to_int)  # Assuming 'text' is the column with the text data

print(dataset["label"][:5])  # Check labels

['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id', 'precondition_text', 'precondition_position', 'response_text', 'prompt_config_examples', 'prompt_config_chain_of_thought', 'label', 'additional_feedback']


Map: 100%|██████████| 929/929 [00:00<00:00, 14576.69 examples/s]

[4, 4, 4, 4, 4]





## Comment

1. Needed for feedback extraction: precondition_text, response_text, label(rating feedback extraction)
2. Needed for feedback detection: precondition_text, precondition_position, response_text, label (rating feedback detection)
3. For the precondition position to be found well, it is a crucial for the model to find the precondition text (at least to a recognizable degree) as well, otherwise the precondition is not found at all...

In [None]:
# tokenize queries and answers together to provide proper context to reward model
def tokenize_fn(examples):
    if FEEDBACK_TO_TRAIN_ON == "feedback_extraction":
        print("Tokenizing for feedback extraction")
        combined_texts = [f"{t} {r}" for t, r in zip(examples["precondition_text"], examples["response_text"])]
    elif FEEDBACK_TO_TRAIN_ON == "feedback_detection":
        print("Tokenizing for feedback detection")
        combined_texts = [f"{t} {p} {r}" for t, p, r in zip(examples["precondition_text"], examples["precondition_position"], examples["response_text"])]
    return tokenizer(combined_texts, truncation=True, padding="max_length")

dataset = dataset.map(tokenize_fn, batched=True)

Map: 100%|██████████| 929/929 [00:00<00:00, 3612.76 examples/s]


In [17]:
print(dataset)

Dataset({
    features: ['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id', 'precondition_text', 'precondition_position', 'response_text', 'prompt_config_examples', 'prompt_config_chain_of_thought', 'label', 'additional_feedback', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 929
})


# Split dataset into train, test, eval

In [18]:
# split into train, test and eval sets
train_test_split = dataset.train_test_split(test_size=0.3, seed=42)
eval_test_split = train_test_split["test"].train_test_split(test_size=0.5, seed=42)


final_splits = DatasetDict({
    'train': train_test_split['train'],
    'validation': eval_test_split['train'],
    'test': eval_test_split['test']
})

## 6. Train reward model

In [19]:
# Training arguments
training_args = TrainingArguments(
    output_dir=LORA_CHECKPOINTS_FOLDER,
    eval_strategy='steps',
    save_strategy='steps',
    save_steps=100,
    eval_steps=100,
    save_total_limit=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=3e-4,
    num_train_epochs=20,
    logging_steps=10,
    label_names=["labels"],
    # report_to="none",
    logging_dir="./logs",
    fp16=True,  # Use mixed precision training
    metric_for_best_model="eval_loss", # or "eval_loss"
    greater_is_better=False, # or False if using loss
    # gradient_accumulation_steps=4 # 
)

# Initialize custom trainer
trainer = CustomRewardTrainer(
    model=model,
    args=training_args,
    train_dataset=final_splits['train'],
    eval_dataset=final_splits['validation'],
    # compute_metrics=trainer.compute_metrics,  # Use the custom metrics function
    processing_class=tokenizer,
    loss_type="huber",  # Try "mse" or "huber"
    weight_strategy="linear",  # Try "linear", "inverse", or None
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] # use early stopping since we are sing high amount of epochs
    # data_collator=RewardDataCollator()
)

# # add distributioncallback to trainer TODO: only integrate if relevant
# trainer.add_callback(DistributionCallback())

print(trainer.args.device)

cuda:0


In [20]:
#train model
trainer.train()

Using EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.


Step,Training Loss,Validation Loss
100,0.1626,0.191629
200,0.1524,0.180362
300,0.1576,0.121676
400,0.1042,0.109979
500,0.1026,0.102007
600,0.0972,0.090389
700,0.0968,0.083614
800,0.0954,0.084119


TrainOutput(global_step=820, training_loss=0.2142195125905479, metrics={'train_runtime': 60.5279, 'train_samples_per_second': 214.777, 'train_steps_per_second': 13.547, 'total_flos': 3438110128128000.0, 'train_loss': 0.2142195125905479, 'epoch': 20.0})

In [21]:
# store final model parameters
model.save_pretrained(FINAL_LORA_ADAPTERS)

# Reload saved LoRA adapter for inference 

In [22]:
base_model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=1)
model = PeftModel.from_pretrained(base_model, FINAL_LORA_ADAPTERS)
model = model.merge_and_unload()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
# Evaluate the model on the test set
test_results = trainer.evaluate(eval_dataset=final_splits['test'])
print("Test Results:", test_results)

Test Results: {'eval_loss': 0.0590321347117424, 'eval_runtime': 0.2548, 'eval_samples_per_second': 549.361, 'eval_steps_per_second': 35.316, 'epoch': 20.0}


In [24]:
# evaluate model manually on some test cases
model.to(device)
model.eval()

with torch.no_grad():
    for i in range(20):
        sample = final_splits['test'][i]
        inputs = tokenizer(sample['precondition_text'] + " " + sample['response_text'], return_tensors='pt', truncation=True, padding="max_length").to(device)
        outputs = model(**inputs)
        prediction = outputs.logits.item()
        print(f"Sample {i+1}: Predicted Rating: {prediction}, True Rating: {sample['label']}")


Sample 1: Predicted Rating: 4.429320335388184, True Rating: 5
Sample 2: Predicted Rating: 5.043562889099121, True Rating: 5
Sample 3: Predicted Rating: 4.477137565612793, True Rating: 4
Sample 4: Predicted Rating: 5.0522565841674805, True Rating: 5
Sample 5: Predicted Rating: 4.228446960449219, True Rating: 4
Sample 6: Predicted Rating: 5.122300148010254, True Rating: 5
Sample 7: Predicted Rating: 4.155384063720703, True Rating: 4
Sample 8: Predicted Rating: 5.1412458419799805, True Rating: 5
Sample 9: Predicted Rating: 4.935603141784668, True Rating: 5
Sample 10: Predicted Rating: 5.197958469390869, True Rating: 5
Sample 11: Predicted Rating: 4.745537757873535, True Rating: 5
Sample 12: Predicted Rating: 4.44428825378418, True Rating: 5
Sample 13: Predicted Rating: 5.061437606811523, True Rating: 5
Sample 14: Predicted Rating: 5.102313041687012, True Rating: 5
Sample 15: Predicted Rating: 3.9752120971679688, True Rating: 5
Sample 16: Predicted Rating: 4.981609344482422, True Rating: 5