# Fine Tune reward model from scratch

# TODOs:

#TODO: double-check that labels are not somehow misaligned...

#TODO: check if you need to plot 

1. LoRA learns the position of the low rank adaptation matrix that is needed to finetune a model of a much higher rank

#TODO: double check model performance, generate output, maybe adjust training metrics

## 1. Imports, setup, and global variables

In [1]:
import torch
import pandas as pd
import os
import sys
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Add the parent directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.getcwd()), '..')))

from transformers import TrainingArguments, EarlyStoppingCallback
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from collections import Counter

from datasets import Dataset, DatasetDict, load_from_disk

from peft import LoraConfig, get_peft_model, PeftModel

from utils import parse_ratings, tokenize_fn_with_best_window, tokenize_fn_basic_batched, CustomRewardTrainer, find_best_window, convert_label_to_int


# from nltk.tokenize import sent_tokenize

# load the relevant devices available on the server
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("AVAILABLE_DEVICES")

# Enable expandable CUDA segments
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# load cuda
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print("CUDA is available. Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

  from .autonotebook import tqdm as notebook_tqdm


There are 2 GPU(s) available.
CUDA is available. Using GPU: NVIDIA L40S


In [2]:
# load training variables
FEEDBACK_TO_TRAIN_ON = os.getenv("FEEDBACK_TO_TRAIN_ON")
FEEDBACK_TO_REMOVE = os.getenv("FEEDBACK_TO_REMOVE")
MODEL = os.getenv("REWARD_MODEL")
DATASET = os.getenv("REWARD_DATASET")
TOKENIZE_FN = os.getenv("TOKENIZE_FN")
MAX_LENGTH = os.getenv("MAX_LENGTH")
STRIDE = os.getenv("STRIDE")
LORA_CHECKPOINTS_FOLDER = os.getenv("LORA_CHECKPOINTS_FOLDER")
FINAL_LORA_ADAPTERS = os.getenv("FINAL_LORA_ADAPTERS_FOLDER") + f"_{FEEDBACK_TO_TRAIN_ON}_{TOKENIZE_FN}_{DATASET}"
TOKENIZED_DATA_TRAIN = os.getenv("TOKENIZED_DATA") + f"_{FEEDBACK_TO_TRAIN_ON}_{TOKENIZE_FN}_{DATASET}_{MODEL}_train"
TOKENIZED_DATA_EVAL = os.getenv("TOKENIZED_DATA") + f"_{FEEDBACK_TO_TRAIN_ON}_{TOKENIZE_FN}_{DATASET}_{MODEL}_eval"
TOKENIZED_DATA_TEST = os.getenv("TOKENIZED_DATA") + f"_{FEEDBACK_TO_TRAIN_ON}_{TOKENIZE_FN}_{DATASET}_{MODEL}_test"
DATASET_STRUCTURE = os.getenv("DATASET_STRUCTURE")

REWARD_DATA_PATH = os.getenv("REWARD_DATA_PATH")

if DATASET_STRUCTURE == "determined":
    REWARD_MODEL_TRAIN_DATA_HUMAN = REWARD_DATA_PATH + "/train_human_determined.csv"
    REWARD_MODEL_EVAL_DATA_HUMAN = REWARD_DATA_PATH + "/validation_human_determined.csv"
    REWARD_MODEL_TEST_DATA_HUMAN = REWARD_DATA_PATH + "/test_human_determined.csv"

    REWARD_MODEL_TRAIN_DATA_SYNTH = REWARD_DATA_PATH + "/train_synth_determined.csv"
    REWARD_MODEL_EVAL_DATA_SYNTH = REWARD_DATA_PATH + "/validation_synth_determined.csv"
    REWARD_MODEL_TEST_DATA_SYNTH = REWARD_DATA_PATH + "/test_synth_determined.csv"

elif DATASET_STRUCTURE == "random":
    REWARD_MODEL_TRAIN_DATA_HUMAN = REWARD_DATA_PATH + "/train_human_random.csv"
    REWARD_MODEL_EVAL_DATA_HUMAN = REWARD_DATA_PATH + "/validation_human_random.csv"
    REWARD_MODEL_TEST_DATA_HUMAN = REWARD_DATA_PATH + "/test_human_random.csv"

    REWARD_MODEL_TRAIN_DATA_SYNTH = REWARD_DATA_PATH + "/train_synth_random.csv"
    REWARD_MODEL_EVAL_DATA_SYNTH = REWARD_DATA_PATH + "/validation_synth_random.csv"
    REWARD_MODEL_TEST_DATA_SYNTH = REWARD_DATA_PATH + "/test_synth_random.csv"

## 2. Dataset loading and preprocessing

In [3]:
# # load dataframes
# df_1 = pd.read_csv(FILE_1, sep=";")
# df_5 = pd.read_csv(FILE_5, sep=";")
# df_7 = pd.read_csv(FILE_7, sep=";")
# df_9 = pd.read_csv(FILE_9, sep=";")
# df_10_1 = pd.read_csv(FILE_10_1, sep=";")
# df_10_2 = pd.read_csv(FILE_10_2, sep=";")
# df_synth = pd.read_csv(FILE_SYNTH, sep=";")

# df_human = pd.concat([df_1, df_5, df_7, df_9, df_10_1, df_10_2], ignore_index=True)

#### Re-structure df synthetic to fit in training loop

In [4]:
# print("Synthetic feedback shape:", df_synth.shape)


# # Save the current headers since forgot to store headers in csv file
# old_headers = df_synth.columns.tolist()

# # print("Old headers:", old_headers)

# # Step 2: Insert the headers as the first row
# df_synth.loc[-1] = old_headers # Add headers as a new row
# df_synth.index = df_synth.index + 1 # Shift index
# df_synth = df_synth.sort_index() # Sort index to place the new row at the top


# # Step 3: Assign new headers (optional)
# df_synth.columns = ['file', 
#                     'frame_ID', 
#                     'frame_type', 
#                     'frame_text', 
#                     'precondition_id', 
#                     'precondition_text', 
#                     'precondition_position', 
#                     'response_text', 
#                     'prompt_config_examples', 
#                     'prompt_config_chain_of_thought', 
#                     'feedback_extraction', 
#                     'feedback_detection', 
#                     'additional_feedback',
#                     'synthetic_feedback',
#                 ]

# print(df_synth.columns)


# df_synth['prompt_config_examples'] = (df_synth['prompt_config_examples']                                              
#                                                 .astype(str)
#                                                 .str.strip()
#                                                 .str.lower()
#                                                 .map({'true': True, 'false': False})
# )

# df_synth['prompt_config_chain_of_thought'] = (df_synth['prompt_config_chain_of_thought']
#                                                 .astype(str)
#                                                 .str.strip()
#                                                 .str.lower()
#                                                 .map({'true': True, 'false': False})
# )



# print("Synthetic feedback shape:", df_synth.shape)

In [5]:
if DATASET == "human":
    df_train = pd.read_csv(REWARD_MODEL_TRAIN_DATA_HUMAN, sep=";")
    df_eval = pd.read_csv(REWARD_MODEL_EVAL_DATA_HUMAN, sep=";")
    df_test = pd.read_csv(REWARD_MODEL_TEST_DATA_HUMAN, sep=";")
elif DATASET == "synthetic":
    df_train = pd.read_csv(REWARD_MODEL_TRAIN_DATA_SYNTH, sep=";")
    df_eval = pd.read_csv(REWARD_MODEL_EVAL_DATA_SYNTH, sep=";")
    df_test = pd.read_csv(REWARD_MODEL_TEST_DATA_SYNTH, sep=";")
    
    
df_train.shape
df_train.columns

Index(['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id',
       'precondition_text', 'precondition_position', 'response_text',
       'prompt_config_examples', 'prompt_config_chain_of_thought',
       'feedback_extraction', 'feedback_detection', 'additional_feedback'],
      dtype='object')

### 2. a) Parse ratings to numeric values for MSE Loss

In [6]:
df_train[FEEDBACK_TO_TRAIN_ON] = [parse_ratings(feedback) for feedback in df_train[FEEDBACK_TO_TRAIN_ON]]
df_eval[FEEDBACK_TO_TRAIN_ON] = [parse_ratings(feedback) for feedback in df_eval[FEEDBACK_TO_TRAIN_ON]]
df_test[FEEDBACK_TO_TRAIN_ON] = [parse_ratings(feedback) for feedback in df_test[FEEDBACK_TO_TRAIN_ON]]
print("Parsed feedback for extraction:", df_train[FEEDBACK_TO_TRAIN_ON][:5])

Parsed feedback for extraction: 0    2
1    2
2    2
3    2
4    3
Name: feedback_extraction, dtype: object


### 2. b) look at biases in feedback to train on for weights in RL loop --> feedback_detection is very biased through way it was collected, so gets less weight overall...

In [7]:
df_train[FEEDBACK_TO_TRAIN_ON].value_counts()

feedback_extraction
0    354
3    153
1     84
2     83
Name: count, dtype: int64

### 2. c) keep only relevant feedback column

In [8]:
dataset_train = Dataset.from_pandas(df_train)
dataset_eval = Dataset.from_pandas(df_eval)
dataset_test = Dataset.from_pandas(df_test)

print(dataset_train)
print(FEEDBACK_TO_TRAIN_ON) 

datasets = [dataset_train, dataset_eval, dataset_test]

Dataset({
    features: ['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id', 'precondition_text', 'precondition_position', 'response_text', 'prompt_config_examples', 'prompt_config_chain_of_thought', 'feedback_extraction', 'feedback_detection', 'additional_feedback'],
    num_rows: 674
})
feedback_extraction


In [18]:
datasets= [dataset.remove_columns([FEEDBACK_TO_REMOVE]) for dataset in datasets]
datasets = [dataset.rename_column(FEEDBACK_TO_TRAIN_ON, "label") for dataset in datasets]

print(datasets[0]["label"])

['2', '2', '2', '2', '3', '2', '2', '2', '2', '3', '0', '0', '2', '2', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '3', '3', '0', '0', '2', '0', '0', '0', '0', '0', '0', '0', '0', '0', '3', '3', '3', '3', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '2', '3', '3', '3', '3', '3', '3', '3', '3', '3', '3', '0', '3', '3', '3', '3', '3', '3', '2', '3', '3', '3', '3', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '2', '3', '1', '1', '3', '0', '2', '2', '1', '3', '0', '0', '2', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '2', '0', '1', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '2', '3', '1', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '2', '2', '2', '2', '1', '3', '0', '0', '2', '2', '2', '0', '2', '2', '2', '3', '2', '2', '1', '0', '3', '2', '2', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '1', '1', '1',

## 3. Load model with LoRA layer

In [10]:
# Load the model and the tokenizer
model_id = MODEL 
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) # num_labels = 1 since we want to prodict a single scalar (the rating)

# Comment: Automodel for sequence classification with num_labels=1 already has a regression head
print(model)


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ModernBertForSequenceClassification(
  (model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
          (rotary_emb): ModernBertRotaryEmbedding()
          (Wo): Linear(in_features=768, out_features=768, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=768, out_features=2304, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=1152, out_features=768, bias=False)
        )
      

In [11]:
print(tokenizer.model_max_length)
print(model.config.max_position_embeddings)

8192
8192


In [12]:
# Define LoRA config


if MODEL == "answerdotai/ModernBERT-base":

    lora_config = LoraConfig(
    r=8,           # Rank of the LoRA matrices (smaller = less memory)
    lora_alpha=16, # Scaling factor (higher = stronger adaptation)
    target_modules=["Wqkv", "Wo"], # Apply LoRA to attention layers
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS"  # classify each answer 
    )
else:
    lora_config = LoraConfig(
    r=8,           # Rank of the LoRA matrices (smaller = less memory)
    lora_alpha=16, # Scaling factor (higher = stronger adaptation)
    target_modules=["query", "key", "value"], # Apply LoRA to attention layers
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_CLS"  # classify each answer 
    )
    

# Freeze base model
for param in model.base_model.parameters():
    param.requires_grad = False



# Convert the model to a PEFT (LoRA) model
model = get_peft_model(model, lora_config)
# model.gradient_checkpointing_enable()
model.print_trainable_parameters()  # Check trainable params (~0.1% of full model)


trainable params: 1,149,697 || all params: 150,755,330 || trainable%: 0.7626


In [13]:
# Test tokenizer
sample_data = ["What is the capital of France?", "What is the largest capital in the world?"]
tokenizer(sample_data, padding=True, truncation=True, max_length=512)

{'input_ids': [[50281, 1276, 310, 253, 5347, 273, 6181, 32, 50282, 50283, 50283], [50281, 1276, 310, 253, 6253, 5347, 275, 253, 1533, 32, 50282]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

## 5. Encode dataset

In [19]:
print(datasets[0].column_names)
# mao string labels to integers
datasets = [dataset.map(convert_label_to_int) for dataset in datasets]

print(datasets[0]["label"][:5])  # Check labels
print(datasets[0]["response_text"][:5])  # Check labels

['file', 'frame_ID', 'frame_type', 'frame_text', 'precondition_id', 'precondition_text', 'precondition_position', 'response_text', 'prompt_config_examples', 'prompt_config_chain_of_thought', 'label', 'additional_feedback']


Map: 100%|██████████| 674/674 [00:00<00:00, 1818.20 examples/s]
Map: 100%|██████████| 127/127 [00:00<00:00, 7866.10 examples/s]
Map: 100%|██████████| 128/128 [00:00<00:00, 9701.32 examples/s]

[2, 2, 2, 2, 3]
['1. Subfact: Onze Minister\n                2. Positie: Artikel 1, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                3. Subfact: Onze Minister\n                4. Positie: Artikel 8, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                5. Subfact: Onze Minister\n                6. Positie: Artikel 14, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                7. Subfact: Onze Minister\n                8. Positie: Artikel 16, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                9. Subfact: Onze Minister\n                10. Positie: Artikel 17, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                11. Subfact: Onze Minister\n                12. Positie: Artikel 17a, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                13. Subfact: Onze Minister\n                14. Positie: Artikel 26, sectie 1 IN Vreemdelingenwet geldig vanaf 2024\n                15. Subfact: Onze Minister\n                16. Positie: Arti




## Comment

1. Needed for feedback extraction: precondition_text, response_text, label(rating feedback extraction)
2. Needed for feedback detection: precondition_text, precondition_position, response_text, label (rating feedback detection)
3. For the precondition position to be found well, it is a crucial for the model to find the precondition text (at least to a recognizable degree) as well, otherwise the precondition is not found at all...

In [20]:
# Code to test bestw indow function

test_text = """
        Titel: De Weg Door Het Leven

Het leven is een reis vol onverwachte wendingen, een pad dat zich zelden rechtlijnig ontvouwt. Vanaf het moment dat we onze eerste ademhaling nemen, worden we ondergedompeld in een wereld die we nog moeten leren begrijpen. Als kind lijkt alles eenvoudig: lachen, spelen, ontdekken. Maar naarmate we ouder worden, beginnen de lagen van complexiteit zich op te stapelen. We leren dat mensen niet altijd zeggen wat ze bedoelen, dat keuzes consequenties hebben, en dat geluk soms vluchtiger is dan we zouden willen.

In de vroege ochtenden, wanneer de zon net boven de horizon verschijnt en de wereld nog stil is, denken velen na over hun plaats in het grotere geheel. Sommigen vragen zich af of ze de juiste keuzes hebben gemaakt, of ze trouw zijn gebleven aan zichzelf. Anderen proberen simpelweg de dag door te komen, met hoop op iets beters. In die momenten van stilte komt vaak het besef dat, hoewel we allemaal verschillende paden bewandelen, we één waarheid delen: dat het leven, ondanks al onze inspanningen en verlangens, nooit gemakkelijk is. Of, zoals mijn grootmoeder het ooit zei terwijl ze haar handen vouwde na een lange dag werken op het land: “Je moet weten, kind, het leven is nooit gemakkelijk, maar het is wel de moeite waard.”

We worden gevormd door onze ervaringen, door de mensen die we ontmoeten en de obstakels die we overwinnen. Elke fout, elk succes, elke traan en elke glimlach draagt bij aan wie we zijn. En toch, ondanks al die ervaringen, blijven we zoeken. Naar betekenis. Naar verbinding. Naar rust.

Soms lijkt het alsof de wereld te snel draait. Technologie verandert ons leven in een razend tempo, verwachtingen worden hoger, en de druk om te presteren neemt toe. In die chaos vergeten we soms stil te staan. Te ademen. Te voelen. Maar juist in die momenten van rust vinden we vaak de antwoorden die we zo hard nodig hebben.

De liefde, bijvoorbeeld, is een van de krachtigste krachten die ons voortdrijft. Liefde voor een partner, een kind, een vriend, of zelfs voor een passie. Het is die liefde die ons helpt vol te houden wanneer alles tegenzit. Die ons eraan herinnert waarom we begonnen zijn, waarom we blijven proberen.

En dan is er verlies. Een onvermijdelijk onderdeel van het leven. We verliezen mensen, kansen, dromen. Maar in dat verlies schuilt ook groei. We leren loslaten, opnieuw beginnen, sterker worden. Het is pijnlijk, ja, maar ook noodzakelijk.

Wanneer we terugkijken op ons leven, zijn het zelden de materiële zaken die we herinneren. Het zijn de momenten. De gesprekken bij kaarslicht. De wandelingen in de regen. De onverwachte lachbuien. De stilte van een gedeeld verdriet. Die momenten vormen de essentie van ons bestaan.

Dus ja, het leven is vol uitdagingen. Het is rommelig, verwarrend, soms oneerlijk. Maar het is ook prachtig, rijk aan betekenis, en gevuld met kansen om te groeien, te leren en lief te hebben. En misschien is dat wel de grootste les van allemaal: dat we, ondanks alles, blijven kiezen voor hoop. Voor verbinding. Voor het leven zelf.
        """


test_ground_truth = "Het leven is nooit gemakkelijk."

print(find_best_window(test_text, test_ground_truth, device, tokenizer))

# Works as expectd, I am impressed.


        Titel: De Weg Door Het Leven

Het leven is een reis vol onverwachte wendingen, een pad dat zich zelden rechtlijnig ontvouwt. Vanaf het moment dat we onze eerste ademhaling nemen, worden we ondergedompeld in een wereld die we nog moeten leren begrijpen. Als kind lijkt alles eenvoudig: lachen, spelen, ontdekken. Maar naarmate we ouder worden, beginnen de lagen van complexiteit zich op te stapelen. We leren dat mensen niet altijd zeggen wat ze bedoelen, dat keuzes consequenties hebben, en dat geluk soms vluchtiger is dan we zouden willen.

In de vroege ochtenden, wanneer de zon net boven de horizon verschijnt en de wereld nog stil is, denken velen na over hun plaats in het grotere geheel. Sommigen vragen zich af of ze de juiste keuzes hebben gemaakt, of ze trouw zijn gebleven aan zichzelf. Anderen proberen simpelweg de dag door te komen, met hoop op iets beters. In die momenten van stilte komt vaak het besef dat, hoewel we allemaal verschillende paden bewandelen, we één waarheid 

In [None]:
if not os.path.exists(TOKENIZED_DATA_TRAIN):
    if TOKENIZE_FN == "best_window":
        datasets = [dataset.map(tokenize_fn_with_best_window, 
                                fn_kwargs={"feedback_train": FEEDBACK_TO_TRAIN_ON, 
                                            "tokenizer": tokenizer, 
                                            "max_length": int(MAX_LENGTH), 
                                            "stride": int(STRIDE),
                                            "device": device
                                            },
                                batched=False) for dataset in datasets]
    else:
        datasets = [dataset.map(tokenize_fn_basic_batched, 
                                fn_kwargs={"feedback_train": FEEDBACK_TO_TRAIN_ON, 
                                            "tokenizer": tokenizer 
                                            },
                                batched=True) for dataset in datasets]
    
    
    datasets[0].save_to_disk(TOKENIZED_DATA_TRAIN)
    datasets[1].save_to_disk(TOKENIZED_DATA_EVAL)
    datasets[2].save_to_disk(TOKENIZED_DATA_TEST)
else:
    datasets[0] = load_from_disk(TOKENIZED_DATA_TRAIN)
    datasets[1] = load_from_disk(TOKENIZED_DATA_TEST)
    datasets[2] = load_from_disk(TOKENIZED_DATA_EVAL)

Map: 100%|██████████| 674/674 [00:02<00:00, 233.69 examples/s]
Map: 100%|██████████| 127/127 [00:00<00:00, 163.04 examples/s]
Map: 100%|██████████| 128/128 [00:00<00:00, 155.53 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 674/674 [00:00<00:00, 50989.52 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 127/127 [00:00<00:00, 14139.85 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 128/128 [00:00<00:00, 19254.42 examples/s]


In [23]:
print(Counter(datasets[0]['file']))

Counter({'Interpretatie_Vw_over_besluiten_op_aanvragen_voor_een_verblijfsvergunning_regulier_bepaalde_tijd.json': 432, 'rijksbegrotingscyclus.json': 242})


## 6. Train reward model

In [24]:
# Training arguments
training_args = TrainingArguments(
    output_dir=LORA_CHECKPOINTS_FOLDER,
    eval_strategy='steps',
    save_strategy='steps',
    save_steps=10,
    eval_steps=10,
    save_total_limit=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=3e-4,
    num_train_epochs=20,
    logging_steps=10,
    label_names=["labels"],
    # report_to="none",
    logging_dir="./logs",
    fp16=True,  # Use mixed precision training
    metric_for_best_model="eval_loss", # or "eval_loss"
    greater_is_better=False, # False if using loss
    # gradient_accumulation_steps=4, # 
    # weight_decay=0.01
)

# Initialize custom trainer
trainer = CustomRewardTrainer(
    model=model,
    args=training_args,
    train_dataset=datasets[0],
    eval_dataset=datasets[1],
    # compute_metrics=trainer.compute_metrics,  # Use the custom metrics function
    processing_class=tokenizer,
    loss_type="huber",  # "mse" or "huber"
    weight_strategy="linear",  # "linear", "inverse", or None
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=50)] # use early stopping since we are sing high amount of epochs
    # data_collator=RewardDataCollator()
)

print(trainer.args.device)

[2025-06-18 15:19:41,005] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/jacques.furst/miniconda3/envs/RL/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/jacques.furst/miniconda3/envs/RL/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


[2025-06-18 15:19:41,386] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
cuda:0


In [25]:
# if not os.path.exists(FINAL_LORA_ADAPTERS):
# train model
trainer.train()
# # store final model parameters
model.save_pretrained(FINAL_LORA_ADAPTERS)

# #TODO: not storing this properly I suppose, need to change

[34m[1mwandb[0m: Currently logged in as: [33mjacques-furst123[0m ([33mjacques-furst123-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 97, in _worker
    output = module(*input, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/peft/peft_model.py", line 1559, in forward
    return self.base_model(
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 193, in forward
    return self.model.forward(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1166, in forward
    outputs = self.model(
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 881, in forward
    layer_outputs = encoder_layer(
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 548, in forward
    self.compiled_mlp(hidden_states)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
    return fn(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 523, in compiled_mlp
    @torch.compile(dynamic=True)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1209, in forward
    return compiled_fn(full_args)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 315, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100, in g
    return f(*args)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1937, in forward
    fw_outs = call_func_at_runtime_with_args(
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
    return compiled_fn(runtime_args)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
    outs = compiled_fn(args)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 460, in __call__
    return self.current_callable(inputs)
  File "/home/jacques.furst/miniconda3/envs/RL/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2404, in run
    return model(new_inputs)
  File "/tmp/torchinductor_jacques.furst/we/cwe323lhnvrbiewv5s226esrfjevpk74mapz5wkqvyjb6sgwx4cw.py", line 249, in call
    buf11 = empty_strided_cuda((s0, s1, 1152), (1152*s1, 1152, 1), torch.float32)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 576.00 MiB. GPU 0 has a total capacity of 44.40 GiB of which 114.31 MiB is free. Including non-PyTorch memory, this process has 44.28 GiB memory in use. Of the allocated memory 43.59 GiB is allocated by PyTorch, and 19.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


# Reload saved LoRA adapter for inference 

In [None]:
base_model_test = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=1)
new_model = PeftModel.from_pretrained(base_model_test, FINAL_LORA_ADAPTERS)
# new_model = new_model.merge_and_unload()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Initialize trainer with new model
trainer = CustomRewardTrainer(
    model=new_model,
    args=training_args,
    train_dataset=datasets[0],
    eval_dataset=datasets[1],
    # compute_metrics=trainer.compute_metrics,  # Use the custom metrics function
    processing_class=tokenizer,
    loss_type="huber",  # "mse" or "huber"
    weight_strategy="linear",  # "linear", "inverse", or None
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=50)] # use early stopping since we are sing high amount of epochs
    # data_collator=RewardDataCollator()
)

In [None]:
# Evaluate the model on the test set
test_results = trainer.evaluate(eval_dataset=datasets[2])
print("Test Results:", test_results)

Test Results: {'eval_loss': 0.5358006358146667, 'eval_model_preparation_time': 0.008, 'eval_runtime': 0.2558, 'eval_samples_per_second': 547.295, 'eval_steps_per_second': 35.183}


In [None]:
# evaluate model manually on some test cases
model.to(device)
model.eval()

#TODO: change tokenization function here!

with torch.no_grad():
    for i in range(20):
        sample = datasets[2][i]
        inputs = tokenizer(sample['precondition_text'] + " " + sample['response_text'], return_tensors='pt', truncation=True, padding="max_length").to(device)
        outputs = model(**inputs)
        prediction = outputs.logits.item()
        print(f"Sample {i+1}: Predicted Rating: {prediction}, True Rating: {sample['label']}")


Sample 1: Predicted Rating: 2.779296875, True Rating: 1
Sample 2: Predicted Rating: -0.07684326171875, True Rating: 0
Sample 3: Predicted Rating: 2.345703125, True Rating: 3
Sample 4: Predicted Rating: -0.010162353515625, True Rating: 0
Sample 5: Predicted Rating: 2.9140625, True Rating: 3
Sample 6: Predicted Rating: 0.10626220703125, True Rating: 0
Sample 7: Predicted Rating: 2.4453125, True Rating: 1
Sample 8: Predicted Rating: 0.07196044921875, True Rating: 0
Sample 9: Predicted Rating: 0.056854248046875, True Rating: 0
Sample 10: Predicted Rating: -0.040679931640625, True Rating: 0
Sample 11: Predicted Rating: 0.053863525390625, True Rating: 0
Sample 12: Predicted Rating: 0.17333984375, True Rating: 0
Sample 13: Predicted Rating: -0.03515625, True Rating: 0
Sample 14: Predicted Rating: 1.3251953125, True Rating: 1
Sample 15: Predicted Rating: 1.2685546875, True Rating: 0
Sample 16: Predicted Rating: 0.18017578125, True Rating: 0
Sample 17: Predicted Rating: 2.4453125, True Rating: 