In [None]:
import sys

# Install every dependency through the interpreter that runs this kernel
# (sys.executable), so packages land in the kernel's own environment.
# torch / torchvision / torchaudio are pinned together and may be resolved from
# the extra cu124 PyTorch wheel index.
!{sys.executable} -m pip install --no-cache-dir torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu124
# NOTE(review): tqdm, numpy, wandb, openai, python-dotenv, accelerate and requests
# carry no version pin below, so a fresh install may resolve different versions
# over time — consider pinning them for reproducibility.
!{sys.executable} -m pip install --no-cache-dir transformers==4.42.3 tqdm numpy
!{sys.executable} -m pip install --no-cache-dir bitsandbytes==0.43.3 datasets==3.0.1 wandb
!{sys.executable} -m pip install --no-cache-dir openai
!{sys.executable} -m pip install python-dotenv
!{sys.executable} -m pip install accelerate
!{sys.executable} -m pip install requests

In [None]:
# Download the gzipped English Numberbatch 19.08 embeddings and decompress them.
# The resulting numberbatch-en-19.08.txt is the default path read by load_numberbatch();
# gunzip -f overwrites a copy left over from an earlier run.
!wget https://conceptnet.s3.amazonaws.com/downloads/2019/numberbatch/numberbatch-en-19.08.txt.gz
!gunzip -f numberbatch-en-19.08.txt.gz

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from openai import OpenAI
import os
from dotenv import load_dotenv
import json
import re
import time
import numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    PreTrainedModel,
    PreTrainedTokenizer
)
from tqdm import tqdm
import wandb
import requests

# Load variables from a local .env file into the process environment; the
# os.getenv() calls below read OPENAI_API_KEY, BOT_ID and CHAT_ID from it.
load_dotenv()

# Disable ONLY console log capture
os.environ["WANDB_CONSOLE"] = "off"

# Disable code saving (keeps runs clean)
os.environ["WANDB_DISABLE_CODE"] = "true"

# Disable system metrics (saves overhead)
# NOTE(review): confirm against the wandb docs that WANDB_DISABLE_SERVICE is the
# variable that controls system-metric collection.
os.environ["WANDB_DISABLE_SERVICE"] = "true"

# OPTIONAL: Stop wandb from watching model gradients
# NOTE(review): nothing in this notebook calls wandb.watch(); verify this variable
# has any effect outside integrations that read it.
os.environ["WANDB_WATCH"] = "false"

# Starts the wandb run as a side effect of executing this cell; the environment
# variables above must therefore be set before this line.
wandb.init(project="project-adam", name="colors-gpt2-pgsrm")

# os.getenv returns None when a variable is missing, so each of these may be None.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Telegram bot token and destination chat used for progress notifications.
bot_id = os.getenv("BOT_ID")
chat_id = os.getenv("CHAT_ID")


In [None]:


class PPOTrainer:
  """PPO-style trainer: a trainable causal-LM actor, a frozen reference model and a linear critic head.

  Each call to `step` performs exactly one actor update and one critic update on a
  batch of (prompt, response, reward) triples and logs the statistics to wandb.
  """

  def __init__(
      self,
      # Annotations are forward-reference strings so the class can be defined
      # without the transformers names being resolved at definition time.
      actor_model: "PreTrainedModel",
      ref_model: "PreTrainedModel",
      tokenizer: "PreTrainedTokenizer",
      batch_size: int,
      actor_learning_rate: float = 1e-5,
      critic_learning_rate: float = 1e-5,
      clip_range: float = 0.2,
      value_coef: float = 0.5,
      entropy_coef: float = 0.01,
      kl_coef: float = 0.05,
      target_kl: float = 0.05,
      max_grad_norm: float = 1.0,
      device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
    """Move both models to `device`, freeze the reference model, and build the critic and optimizers.

    Args:
      actor_model: trainable policy; must expose `config.hidden_size`.
      ref_model: frozen reference policy used for the probability ratio and the KL term.
      tokenizer: only `pad_token_id` is read (to build the padding mask in `step`).
      batch_size: number of sequences `step` expects per call.
      actor_learning_rate / critic_learning_rate: AdamW learning rates.
      clip_range: PPO ratio clipping epsilon.
      value_coef: weight of the MSE value loss in the critic loss.
      entropy_coef: weight of the entropy bonus in the actor loss.
      kl_coef: initial KL penalty weight (adapted after every step).
      target_kl: KL level the adaptive controller steers towards.
      max_grad_norm: gradient-norm clip applied to actor and critic separately.
      device: device both models and the critic are moved to.
    """

    self.actor = actor_model.to(device)
    self.ref = ref_model.to(device)
    self.tokenizer = tokenizer
    self.device = device

    # freeze the reference model
    self.ref.eval()
    for param in self.ref.parameters():
      param.requires_grad = False

    # the critic is a single linear layer mapping the actor's final hidden state
    # (size hidden_size) to one scalar: the predicted reward for the sequence
    hidden_size = self.actor.config.hidden_size
    self.critic = nn.Linear(hidden_size, 1).to(device).float()

    self.clip_range = clip_range
    self.value_coef = value_coef
    self.entropy_coef = entropy_coef
    self.kl_coef = kl_coef
    self.target_kl = target_kl
    self.batch_size = batch_size

    # upper bound on the total gradient norm per optimization step
    self.max_grad_norm = max_grad_norm

    # actor and critic have different losses, so each gets its own optimizer
    self.actor_optimizer = AdamW(self.actor.parameters(), lr = actor_learning_rate)
    self.critic_optimizer = AdamW(self.critic.parameters(), lr = critic_learning_rate)

  # ---- Helper Functions ---- #
  @staticmethod
  def logprobs_from_logits(logits, labels):
    """Return the log-probability assigned to each label token.

    Applies log-softmax over the last (vocabulary) axis of `logits` and gathers
    the entry selected by `labels`; the result has the shape of `labels`.
    """
    logprobs = F.log_softmax(logits, dim = -1)
    # unsqueeze adds a trailing axis so gather can index the vocabulary axis,
    # squeeze removes it again after the lookup
    return logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)

  @staticmethod
  def masked_mean(values, mask):
    """Mean of `values` over the positions where `mask` is non-zero.

    Used to restrict the losses to response tokens. When the mask selects
    nothing the result is 0 instead of NaN (0 / 0); the numerator stays part of
    the autograd graph, so a later backward() still works and yields zero gradients.
    """
    denom = mask.sum()
    # only an all-zero mask is patched; every positive denominator is used as-is
    denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    return (values * mask).sum() / denom

  # ---- Training Function ----
  def step(self, prompts, responses, rewards, average_reward):
    """
    Perform one PPO optimization step.
    Inputs:
      prompts:   [batch_size, prompt_len]
      responses: [batch_size, response_len],
      rewards:   [batch_size, 1] (flattened internally, so [batch_size] also works)
      average_reward: value passed through unchanged to the logged stats
    Returns:
      dict of logged statistics, or None (after printing a message) when the
      number of prompts does not equal the configured batch size.
    """
    self.actor.train()

    if self.batch_size != prompts.size(0):
      print("batch size must match number of prompts")
      return

    # full sequence fed to both models: [batch_size, prompt_len + response_len]
    input_ids = torch.cat([prompts, responses], dim = 1).to(self.device)

    # build a mask that is 1 on non-pad response tokens and 0 everywhere else
    # example: input_ids[0] = [11, 22, 33, 44, 55, 66, 77, 88, 0, 0], pad_token_id = 0, prompt_len = 5
    #   input_mask[0] = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
    #   mask[0]       = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
    # NOTE(review): padding is detected by token id, so if the tokenizer reuses
    # EOS as its pad token, EOS tokens generated by the actor are masked out too.
    pad_id = self.tokenizer.pad_token_id
    prompt_len = prompts.size(1)
    input_mask = (input_ids != pad_id).long()
    mask = torch.zeros_like(input_ids)
    mask[:, prompt_len:] = input_mask[:, prompt_len:]

    # ---- Forward Pass ----
    # the reference model is inference-only, so no gradient bookkeeping is needed
    with torch.no_grad():
      ref_logits = self.ref(input_ids, attention_mask = input_mask).logits

    # the actor additionally returns its hidden states; the last layer feeds the critic
    actor_out = self.actor(input_ids, attention_mask = input_mask, output_hidden_states = True)
    logits = actor_out.logits
    hidden_states = actor_out.hidden_states[-1] # (batch_size, seq_len, hidden_size)

    # ---- Log probabilities & values ----
    # logits at position n predict token n + 1, so drop the last logit and the first token
    next_tokens = input_ids[:, 1:]
    logprobs_actor = self.logprobs_from_logits(logits[:, :-1], next_tokens)
    logprobs_ref = self.logprobs_from_logits(ref_logits[:, :-1], next_tokens)

    # the critic predicts the reward from the hidden state of the final non-pad response token
    # flipping the mask makes argmax return the distance of that token from the end of the sequence
    rev_mask = torch.flip(mask, dims=[1])
    last_nonpad_from_end = rev_mask.float().argmax(dim=1)
    seq_len = mask.size(1)
    last_nonpad_indices = (seq_len - 1) - last_nonpad_from_end # (batch_size,)
    batch_indices = torch.arange(hidden_states.size(0), device = hidden_states.device)
    last_hidden = hidden_states[batch_indices, last_nonpad_indices, :]
    # detach: the critic loss must not back-propagate into the actor
    values = self.critic(last_hidden.detach().float()).squeeze(-1)

    # one scalar reward / value per sequence
    rewards = rewards.view(-1).float().to(self.device)
    values = values.view(-1)

    # ---- Advantage computation ----
    # z-score the advantages to keep update sizes stable
    advantages = rewards - values.detach()
    if advantages.numel() > 1:
      advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # with a single sequence std() is NaN (Bessel correction divides by zero),
    # so the raw advantage is used instead of poisoning every loss term

    # broadcast the per-sequence advantage over the token axis: (batch_size, seq_len - 1)
    advantages = advantages.unsqueeze(1).expand_as(logprobs_actor)

    # the n-th logprob refers to token n + 1, hence the shifted mask
    response_mask = mask[:, 1:]

    # ---- PPO policy loss ----
    # clipped surrogate objective, restricted to response tokens
    ratios = torch.exp(logprobs_actor - logprobs_ref)
    unclipped = ratios * advantages
    clipped = torch.clamp(ratios, 1 - self.clip_range, 1 + self.clip_range) * advantages
    policy_loss = -1 * self.masked_mean(torch.min(unclipped, clipped), response_mask)

    # ---- Value loss ----
    # mean squared error between predicted and observed reward
    value_loss = F.mse_loss(values, rewards)

    # ---- Entropy ----
    # needs the full distribution at every position, not just the chosen token's logprob
    logprobs_full = F.log_softmax(logits[:, :-1], dim = -1) # (batch_size, seq_len - 1, vocab_size)
    probs = logprobs_full.exp()
    token_entropy = -1 * (probs * logprobs_full).sum(dim = -1) # (batch_size, seq_len - 1)
    entropy = self.masked_mean(token_entropy, response_mask) # scalar

    # ---- KL ----
    # mean log-ratio between actor and reference over response tokens
    kl = self.masked_mean(logprobs_actor - logprobs_ref, response_mask)

    # ---- Actor Optimization ----
    # minimize policy loss, maximize entropy, penalize drift from the reference
    self.actor_optimizer.zero_grad()
    actor_loss = policy_loss - self.entropy_coef * entropy + self.kl_coef * kl
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
    self.actor_optimizer.step()

    # ---- Critic Optimization ----
    self.critic_optimizer.zero_grad()
    critic_loss = value_loss * self.value_coef
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
    self.critic_optimizer.step()

    # ---- Adaptive KL ----
    # raise the penalty when the measured KL exceeds 1.5x the target,
    # lower it when the KL falls below target / 1.5, otherwise leave it alone
    kl_value = kl.item()
    if kl_value > 1.5 * self.target_kl:
      self.kl_coef *= 1.5
    elif kl_value < (self.target_kl / 1.5):
      self.kl_coef /= 1.5

    # ---- Logging ----
    stats = {
        "policy_loss": policy_loss.item(),
        "value_loss": value_loss.item(),
        "entropy": entropy.item(),
        "kl": kl_value,
        "kl_coef": self.kl_coef,
        "actor_loss": actor_loss.item(),
        "critic_loss": critic_loss.item(),
        "average_reward": average_reward
    }
    wandb.log(stats)

    return stats








In [None]:


def load_numberbatch(path="numberbatch-en-19.08.txt"):
    """Read a space-separated embedding text file into a word -> vector dict.

    The first line of the file is treated as a header and discarded. Every
    following line is split on single spaces: the first field is the word, the
    rest are parsed as float32 components. Only rows with exactly 300
    components are kept.

    Args:
        path: location of the decompressed embedding text file.

    Returns:
        dict mapping each word to a float32 numpy array of shape (300,).
    """
    table = {}
    with open(path, "r", encoding="utf-8") as handle:
        # consume the header line; the default keeps an empty file from raising
        next(handle, None)

        for row in handle:
            token, *fields = row.rstrip().split(" ")
            vector = np.array(fields, dtype=np.float32)

            # rows of any other width are silently dropped
            if vector.shape[0] != 300:
                continue
            table[token] = vector

    return table


# Load the embedding table once at module level from the default path.
# PGSRM.embed looks words up in this global, so this cell must run before any reward is computed.
numberbatch = load_numberbatch()
print("loaded:", len(numberbatch))


In [None]:


class PGSRM:
  """Reward model that scores a child model's output against a parent LLM's output.

  For each input the parent model (called through the OpenAI chat API) produces
  a reference answer. Both texts are embedded as the mean of their word vectors
  from the module-level `numberbatch` table, and the reward is derived from the
  cosine similarity or the euclidean distance between the two embeddings.
  """

  def __init__(self, parent_model, metric, task, max_retries=3, retry_delay=2):
    """
    Args:
      parent_model: model name passed to the chat completions API.
      metric: "cosine" or "euclidean".
      task: task description placed at the top of the parent's system prompt.
      max_retries: number of attempts per API call.
      retry_delay: base delay in seconds; doubled after every failed attempt.
    """
    self.client = OpenAI(api_key=OPENAI_API_KEY)
    self.parent_model = parent_model
    self.metric = metric  # "cosine" or "euclidean"
    self.task = task
    self.max_retries = max_retries
    self.retry_delay = retry_delay
    self.log_file = "episode_logs.jsonl"

  # ---------------------------------------------------------
  def safe_api_call(self, func, *args, **kwargs):
    """Call `func(*args, **kwargs)`, retrying with exponential backoff on any exception.

    Returns the call's result, or None once all `max_retries` attempts have
    failed (also None when `max_retries` is less than 1, since no attempt is made).
    """
    for attempt in range(1, self.max_retries + 1):
      try:
        return func(*args, **kwargs)
      except Exception as e:
        # deliberate best-effort: any failure is reported and retried
        print(f"[Warning] API call failed (attempt {attempt}/{self.max_retries}): {e}")
        if attempt < self.max_retries:
          wait_time = self.retry_delay * (2 ** (attempt - 1))
          print(f"Retrying in {wait_time:.1f} seconds...")
          time.sleep(wait_time)
        else:
          print("[Error] Max retries reached. Returning None.")
          return None

  # ---------------------------------------------------------
  def parent_generate(self, input_text):
    """Ask the parent model for its answer to `input_text` and return it as plain text.

    The parent is instructed to reply as {"OUTPUT": "<answer>"}. Parsing is
    tolerant:
      - a JSON object          -> its OUTPUT value, converted to str and stripped
                                  ("" when the key is missing or null)
      - a bare JSON string     -> that string, stripped
      - any other valid JSON   -> the raw reply text
      - invalid JSON           -> the OUTPUT value found by regex, else the raw reply text
    Returns "" when the API call fails after all retries or the reply has no content.
    """
    system_prompt = f"""
    {self.task}
    You will receive the input data as the value for the INPUT key.
    You must provide your response in a JSON format, with OUTPUT as the only key, and your response as the value.
    Your response must be in this format: {{"OUTPUT": "<your response>"}}.
    """

    input_prompt = f"""
    INPUT: "{input_text}"
    YOUR RESPONSE =>
    """

    response = self.safe_api_call(
      self.client.chat.completions.create,
      model=self.parent_model,
      messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_prompt},
      ],
      temperature=0.0,
    )

    if response is None:
      print("[Error] Parent model failed to generate response after all retries.")
      return ""

    content = response.choices[0].message.content
    if content is None:
      # a reply without text content would otherwise crash on .strip()
      return ""
    raw_text = content.strip()

    try:
      parsed = json.loads(raw_text)
    except json.JSONDecodeError:
      # not valid JSON (e.g. wrapped in extra text): try to pull the OUTPUT value out by pattern
      if "OUTPUT" in raw_text:
        match = re.search(r'"?OUTPUT"?\s*[:=]\s*"([^"]+)"', raw_text)
        return match.group(1).strip() if match else raw_text
      return raw_text

    if isinstance(parsed, dict):
      value = parsed.get("OUTPUT", "")
      # OUTPUT may be null or a non-string scalar; never call .strip() on those
      return "" if value is None else str(value).strip()
    if isinstance(parsed, str):
      # the parent answered with a bare JSON string such as "purple"
      return parsed.strip()
    # valid JSON of some other shape (list, number, ...): fall back to the raw text
    return raw_text

  # ---------------------------------------------------------
  def embed(self, text, is_euclidean):
    """Embed `text` as the mean of the word vectors of its in-vocabulary words.

    The text is lower-cased, every character other than a-z and whitespace is
    replaced by a space, and the remaining words are looked up in the
    module-level `numberbatch` dict. Returns a CPU float32 tensor; a zero vector
    of size 300 when no word is found. With `is_euclidean` the mean vector is
    L2-normalized.
    """
    text = text.lower()
    text = re.sub(r"[^a-z\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()

    vecs = [numberbatch[word] for word in text.split() if word in numberbatch]

    if not vecs:
      return torch.zeros(300)

    mean_vec = np.mean(vecs, axis = 0)
    vec = torch.tensor(mean_vec, dtype = torch.float32)

    if is_euclidean:
      vec = F.normalize(vec, dim = 0)

    return vec.to("cpu")

  # ---------------------------------------------------------
  def get_reward(self, input_text, child_output, episode, is_test: bool = False):
    """Score `child_output` against the parent model's answer for `input_text`.

    cosine:    similarity is rescaled from [-1, 1] to [0, 1] and raised to the 4th power.
    euclidean: reward is 1 minus the L2 distance between the normalized embeddings.
    The result is clamped to [-1, 1].

    A one-line summary is appended to `self.log_file`, or printed instead when
    `is_test` is True.

    Raises:
      ValueError: if `self.metric` is neither "cosine" nor "euclidean".
    """
    parent_output = self.parent_generate(input_text)

    is_euclidean = (self.metric == "euclidean")

    parent_vec = self.embed(parent_output.strip(), is_euclidean)
    child_vec = self.embed(child_output.strip(), is_euclidean)

    # NOTE(review): when either text has no in-vocabulary word, embed() returns
    # zeros; the cosine branch then appears to yield 0.5 ** 4 rather than 0 —
    # confirm this baseline reward is intended.
    if self.metric == "cosine":
      sim = F.cosine_similarity(parent_vec, child_vec, dim=0).item()
      sim = (sim + 1) / 2 # [0, 1]
      power = 4 # sharpens the reward so only close matches score highly
      reward = sim ** power
      reward = float(torch.clamp(torch.tensor(reward), min=0.0, max=1.0))
    elif self.metric == "euclidean":
      dist = torch.norm(parent_vec - child_vec, p=2).item()
      reward = 1 - dist
    else:
      raise ValueError("metric must be either 'cosine' or 'euclidean'")

    reward = float(torch.clamp(torch.tensor(reward), min=-1.0, max=1.0))

    log = (
      f"EPISODE {episode} | "
      f"Input: {input_text} | "
      f"Parent: {parent_output} | "
      f"Child: {child_output} | "
      f"Reward: {reward:.3f}\n"
    )
    if not is_test:
      # each entry is written as a JSON-encoded string followed by a blank line
      with open(self.log_file, "a") as f:
        f.write(json.dumps(log) + "\n\n")
    else:
      print(log)

    return reward



# Task description injected at the top of the parent model's system prompt.
task = """
You are an AI assistant tasked with combining two colors and outputting the resulting color name.
For example:
- "red + blue = " → "purple"
- "red + yellow = " → "orange"
- "blue + yellow = " → "green"
- "white + black = " → "gray"
If the colors are similar (e.g. "red + pink"), output the most dominant or blended color (e.g. "light red" or "pink").
Always output a single lowercase color name or a simple descriptive blend like "light blue" or "dark green".
"""
# Module-level reward model; the training loop calls reward_model.get_reward().
reward_model = PGSRM(
  parent_model="gpt-4o-mini",
  metric="cosine",
  task=task
)

# Smoke test: score a few candidate answers for one prompt. is_test=True makes
# get_reward print each result instead of appending it to the log file.
# Every call queries the parent model, so this cell issues one API request per option.
# NOTE(review): "[orange" looks like either a typo or a deliberate check of the
# punctuation stripping in embed() — confirm.
input_text = "red + blue = "
options = ["jungle", "brown", "violet", "[orange"]
for option in options:
  reward_model.get_reward(input_text, option, 0, True)


In [None]:
## Training Loop
# Samples one color-mixing prompt per episode, lets the actor generate a short
# response, scores it with the PGSRM reward model, and runs one PPO step every
# BATCH_SIZE episodes. A Telegram progress message is sent every MESSAGE_INTERVAL episodes.
import random

base_model = "gpt2"
BATCH_SIZE = 50
EPISODES = 50000
MESSAGE_INTERVAL = 5000
NOTIFY_TIMEOUT_SECONDS = 10  # upper bound on how long a progress notification may block training

tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    # no pad token defined: reuse EOS so sequences can be padded
    tokenizer.pad_token = tokenizer.eos_token

dtype = torch.float32

actor_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=dtype,               # defines precision used for weights
    device_map={"": "cuda:0"},       # loads the model entirely onto GPU 0
    low_cpu_mem_usage=True,          # reduces RAM load during model init
)
actor_model_device = next(actor_model.parameters()).device

# second, independent copy of the base model; PPOTrainer freezes it as the reference
ref_base = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=dtype,
    device_map={"": "cuda:0"},
    low_cpu_mem_usage=True,
)

ppo_trainer = PPOTrainer(
    actor_model = actor_model,
    ref_model = ref_base,
    tokenizer = tokenizer,
    device = actor_model_device,
    batch_size = BATCH_SIZE,
    entropy_coef = 0.01,
    clip_range = 0.4,
    kl_coef = 0.00005,
    target_kl = 0.8,
    value_coef = 0.5, # standard value
    max_grad_norm = 1, # standard for keeping stable updates,
    critic_learning_rate = 1e-4 # high learning rate for the critic, such that it picks up on variance in rewards quickly
)

colors = [
    "red", "blue", "yellow", "green", "orange", "purple", "pink",
    "brown", "black", "white"
]

# per-batch buffers, emptied after every PPO step
prompt_tensors, response_tensors, rewards = [], [], []

for episode in range(1, EPISODES + 1):

  # two distinct colors per episode
  color1, color2 = random.sample(colors, 2)

  full_input_state = f"{color1} + {color2} = "
  q = tokenizer(full_input_state, return_tensors="pt", padding=True).to(actor_model_device)
  query_tensors = q["input_ids"]

  with torch.no_grad():
    gen = ppo_trainer.actor.generate(
        query_tensors,
        attention_mask = q["attention_mask"], # explicit mask: pad and EOS share an id, so it cannot be inferred
        max_new_tokens = 2,
        do_sample = True,
        temperature = 1.0,
        pad_token_id=tokenizer.eos_token_id
    )

  # keep only the newly generated tokens
  response_ids = gen[:, query_tensors.size(1):]
  gen_txt = tokenizer.batch_decode(response_ids, skip_special_tokens = True)[0]

  # PGSRL reward calculation
  reward = reward_model.get_reward(full_input_state, gen_txt, episode)
  reward_t = torch.tensor([float(reward)], dtype=torch.float, device=actor_model_device)

  prompt_tensors.append(query_tensors)
  response_tensors.append(response_ids)
  rewards.append(reward_t)

  if episode % BATCH_SIZE == 0:

    # Pad variable-length tensors to same length
    prompts_batch = pad_sequence(
        [p.squeeze(0) for p in prompt_tensors], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    responses_batch = pad_sequence(
        [r.squeeze(0) for r in response_tensors], batch_first=True, padding_value=tokenizer.pad_token_id
    )
    rewards_batch = torch.cat(rewards, dim=0)

    average_reward = rewards_batch.mean().item()

    stats = ppo_trainer.step(prompts_batch, responses_batch, rewards_batch, average_reward)

    prompt_tensors, response_tensors, rewards = [], [], []

  # Progress notification. This check must live inside the episode loop;
  # at module level it would run only once, after the final episode.
  if episode % MESSAGE_INTERVAL == 0:
    if bot_id and chat_id:
      message = f"colors-gpt2-pgsrm: {episode} episodes completed"
      url = f"https://api.telegram.org/bot{bot_id}/sendMessage"
      payload = {
          "chat_id": chat_id,
          "text": message
      }
      try:
        # notifications are best-effort: never let a network problem stall or abort training
        requests.post(url, json = payload, timeout = NOTIFY_TIMEOUT_SECONDS)
      except requests.RequestException as exc:
        print(f"[Warning] Telegram notification failed: {exc}")
    else:
      print("[Warning] BOT_ID / CHAT_ID not set; skipping Telegram notification.")


















