**Part 4**

GRPO Loss

* Terrible Loss
* Gives One Sided Predictions

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from tqdm import tqdm

In [2]:
# Define the neural network
class LogicNet(nn.Module):
    def __init__(self):
        super(LogicNet, self).__init__()
        self.fc1 = nn.Linear(2, 4)  # Input layer -> Hidden Layer
        self.fc2 = nn.Linear(4, 1)  # Hidden Layer -> Output Layer

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

    def get_action_and_or_log_prob(self, state, action=None):
        """Helper method to get action and its log_prob from logits"""
        logits = self.forward(state)    # Get the logits from a forward pass of the Policy Network
        # For a binary output (0 or 1), Bernoulli distribution is appropriate
        probs = torch.distributions.Bernoulli(logits=logits)

        if action is None:
            sampled_action = probs.sample() # Sample action based on current probabilities (returns 0 or 1)
            log_prob = probs.log_prob(sampled_action)   # Calculate the log of the probability the sampled action is chosen
            return sampled_action, log_prob
        else:
            log_prob = probs.log_prob(action)       # Returns the log of the probability the action is chosen
            return log_prob


In [3]:
# Define the environment
class LogicGateEnv:
    def __init__(self, gate="AND"):
        self.gate = gate
        self.data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
        self.targets = self.get_targets(gate)

    def get_targets(self, gate:str):
        if gate == "AND":
            return torch.tensor([[0], [0], [0], [1]], dtype=torch.float32)
        elif gate == "OR":
            return torch.tensor([[0], [1], [1], [1]], dtype=torch.float32)
        elif gate == "XOR":
            return torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
        elif gate == "XNOR":
            return torch.tensor([[1], [0], [0], [1]], dtype=torch.float32)

    def step(self, input_idx: int, prediction):
        target = self.targets[input_idx]
        # Take the mean squared error
        # print(f"prediction: {prediction} || target: {target}")
        error = (prediction - target).pow(2).mean().item()
        reward = 1.0 - error
        return reward

In [None]:
# Training loop
def train_logic_gate(gate="XOR", epochs=1000, learning_rate=0.0001, batch_size=64, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5):
    print(f"Training {gate} gate with {epochs} epochs, {learning_rate} learning rate, batch size {batch_size}, and KL beta {beta_kl}.")
    # Initialize Agent's Policy, Environment, parameter optimizer, and Total Correct Counter
    env = LogicGateEnv(gate)
    Policy_New = LogicNet()   # STEP 1 || CREATE π_new
    optimizer = optim.Adam(Policy_New.parameters(), lr=learning_rate)
    num_correct = 0.0
    # STEP 2 || FOR I ITERATION STEPS OMITTED
    # STEP 3 || CREATE REFERENCE MODEL OMITTED

    for epoch in range(epochs):     # STEP 4 || FOR M ITERATION STEPS
        rewards_batch = []  # will be a list of floats
        inputs_batch = []   # will be a list of tensors
        targets_batch = []  # will be a list of tensors
        
        # STEP 5 || Sample a batch D_b from D --> OMITTED 
        # STEP 6 || Update the old policy model π_old <- π_new
        Policy_Old = LogicNet()
        Policy_Old.load_state_dict(Policy_New.state_dict())
        Policy_Old.eval()   # Prevent Gradient tracking

        # --- STEP 7 || Collect a Batch of Experiences ---
        # Loop agent prediction, recording important values to lists:
        for i in range(batch_size):
            # Get model inputs and target
            idx = random.randint(0, 3)
            inputs = env.data[idx]
            target = env.targets[idx]

            # Get model prediction and log_prob of that prediction using the old policy
            with torch.no_grad(): # No need to track gradients during data collection
                pred, log_prob  = Policy_Old.get_action_and_or_log_prob(state=inputs)   # returns tensors

            # Calculate reward
            reward = env.step(idx, pred.item())

            # Append to lists
            inputs_batch.append(inputs)
            rewards_batch.append(reward)
            targets_batch.append(target)

        # Convert collected batch lists into PyTorch tensors
        inputs_batch_tensor = torch.stack(inputs_batch)     # Shape: (batch_size, 3)
        targets_old_batch_tensor = torch.stack(targets_batch)   # Shape: (batch_size, 1)
        rewards_batch_tensor = torch.tensor(rewards_batch, dtype=torch.float32)     # Shape: (batch_size,)

        num_correct += (rewards_batch_tensor).sum().item()  ### need to change
        # print(f"Number correct, this iteration: {(rewards_batch_tensor).sum().item()}")

        # STEP 8 || Calculate Discounted Rewards
        # Unsqueeze to ensure rewards_batch_t has the same shape as targets_batch_t for element-wise ops SHAPE:(1, batch_size)
        rewards_batch_t = rewards_batch_tensor.unsqueeze(1)

        # --- STEP 9 || START OF ADVANTAGE CALCULATION ---
        # Calculate the mean of the rewards in the current batch
        mean_reward = rewards_batch_tensor.mean()

        # Calculate the standard deviation of the rewards in the current batch
        # Add a small epsilon (1e-8) to prevent division by zero in case all rewards are identical
        std_reward = rewards_batch_tensor.std() + 1e-8
        # print(f"rewards_batch_t shape: {rewards_batch_t.shape} || mean_reward: {mean_reward}")
        # Calculate the advantage for each time step in the batch using your specified formula
        advantages_of_batch = (rewards_batch_t - mean_reward) / (std_reward)
        # --- END OF ADVANTAGE CALCULATION ---

        # Detach these to prevent gradients from flowing back into old_policy_net
        inputs_batch_tensor = inputs_batch_tensor.detach()
        targets_old_batch_tensor = targets_old_batch_tensor.detach()
        rewards_batch_tensor = rewards_batch_tensor.detach()
        # log_prob_old_batch_tensor = log_prob_old_batch_tensor.detach()
        advantages_of_batch = advantages_of_batch.detach()

        # Get log_probabilities for the collected 'targets' from the OLD policy
        # Detach these to prevent gradients from flowing back into old_net
        with torch.no_grad():
            old_logits = Policy_Old(inputs_batch_tensor)
            # Use the get_action_and_or_log_prob helper
            log_prob_old = Policy_Old.get_action_and_or_log_prob(inputs_batch_tensor, targets_old_batch_tensor).detach()
            # The .detach() is critical here to ensure old_net remains fixed.
            q_dist = torch.distributions.Bernoulli(logits=old_logits.detach())

        # --- STEP 10 || GRPO Optimization ---
        for _ in tqdm(range(k_epochs), desc=f"Epoch {epoch+1}/{epochs} (Inner K-Epochs)", leave=False):
            new_policy_logits = Policy_New(inputs_batch_tensor)
            log_prob_new = Policy_New.get_action_and_or_log_prob(inputs_batch_tensor, targets_old_batch_tensor)


            # --- KL Divergence Calculation ---
            # Create Bernoulli distributions for new and old policies using their logits
            p_dist = torch.distributions.Bernoulli(logits=new_policy_logits)
            # use same q_dist calculated above

            # Calculate KL divergence per sample, then take the mean over the batch
            kl_div_per_sample = torch.distributions.kl.kl_divergence(p_dist, q_dist)
            kl_loss = kl_div_per_sample.mean() # Mean over the batch


            # print(f"log_prob_new: {log_prob_new}")
            # print(f"log_prob_old: {log_prob_old}")

            # Calculate the ratio of each Trajectory in the Group
            # r_t(0) = π_0(a_t|s_t) / π_0_old(a_t|s_t) = exp(log(π_0(a_t|s_t) - log(π_0_old(a_t|s_t)))
            ratio = torch.exp(log_prob_new - log_prob_old)

            # print(f"Ratio: {ratio}")

            surrogate_1 = ratio * advantages_of_batch
            surrogate_2 = torch.clamp(input=ratio, min= 1.0 - epsilon, max= 1.0 + epsilon) * advantages_of_batch

            # Combine clipped loss with KL penalty
            # Remember: we minimize the negative of the main objective, and add the KL term
            # Maximize: min(...) - beta * D_KL(...) => Minimize: -min(...) + beta * D_KL(...)
            policy_objective_term = -torch.min(surrogate_1, surrogate_2).mean()

            # print(f"policy_objective_term: {policy_objective_term}")
            
            loss = policy_objective_term + beta_kl * kl_loss # Add KL term with beta_kl weight

            # In GRPO, the objective function is typically designed to be maximized (e.g., maximizing the expected return). Since PyTorch optimizers are designed for minimization, the common practice is to minimize the negative of the objective function.

            # STEP 11 || Policy Updates
            optimizer.zero_grad()
            loss.backward()
            # --- ADDING GRADIENT CLIPPING HERE TO LIMIT PARAMETER UPDATES---
            torch.nn.utils.clip_grad_norm_(Policy_New.parameters(), max_norm=max_grad_norm)
            # ----------------------------------
            optimizer.step()



        # LOG IF ENOUGH EPOCHS HAVE ELAPSED
        if epoch % 100 == 0:
            avg_reward = rewards_batch_tensor.mean().item()
            print(f"Epoch {epoch}: Loss = {loss.item()}, Avg Reward = {avg_reward:.4f}, Mean Advantage: {advantages_of_batch.mean().item()}")
            # Validation Step
            print("Validating the Model:")
            with torch.no_grad():
                for i in range(4):
                    logits = Policy_New(env.data[i])
                    pred = torch.round(torch.sigmoid(logits)).item()
                    print(f"Input: {env.data[i].tolist()}, Logits: {logits}, Prediction: {pred}, Actual: {env.targets[i].item()}")

    print("Training completed.\n")
    print(f"Number of correct predictions: {num_correct}/{epochs * batch_size}")
    print(f"Accuracy: {num_correct/(epochs * batch_size)}%")

    print("\nTesting Trained Model:")
    for i in range(4):
        logits = Policy_New(env.data[i])
        pred = torch.round(torch.sigmoid(logits)).item()
        print(f"Input: {env.data[i].tolist()}, Prediction: {pred}, Actual: {env.targets[i].item()}")


In [10]:

# Run training
train_logic_gate(gate="XOR", epochs=1000, learning_rate=0.0001, batch_size=64, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5)

Training XOR gate with 1000 epochs, 0.0001 learning rate, batch size 64, and KL beta 0.01.


                                                                     

Epoch 0: Loss = -0.0031273080967366695, Avg Reward = 0.4688, Mean Advantage: 1.4901161193847656e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([0.2180]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([0.4819]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([0.3433]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([0.4804]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 100: Loss = -0.03756774216890335, Avg Reward = 0.4062, Mean Advantage: 0.0
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([3.4598]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([6.5537]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([6.0721]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([9.1660]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 200: Loss = -0.06453421711921692, Avg Reward = 0.4375, Mean Advantage: 0.0
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([9.0213]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([17.2350]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([16.4411]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([24.6548]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 300: Loss = -0.0758737176656723, Avg Reward = 0.3438, Mean Advantage: -8.381903171539307e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([17.0490]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([32.7010]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([31.5960]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([47.2480]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 400: Loss = -0.08336611837148666, Avg Reward = 0.4844, Mean Advantage: 1.5832483768463135e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([27.4625]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([52.2739]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([50.8568]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([75.6682]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 500: Loss = -0.07657364010810852, Avg Reward = 0.5781, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([40.2311]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([75.7176]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([73.9888]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([109.4752]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 600: Loss = -0.08424654603004456, Avg Reward = 0.5781, Mean Advantage: -3.725290298461914e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([55.3809]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([102.9234]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([100.8833]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([148.4258]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 700: Loss = -0.09047267585992813, Avg Reward = 0.4531, Mean Advantage: 3.725290298461914e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([72.9494]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([133.8081]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([131.4569]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([192.3157]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 800: Loss = -0.09459985047578812, Avg Reward = 0.4062, Mean Advantage: 0.0
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([92.9401]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([168.3655]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([165.7035]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([241.1289]), Prediction: 1.0, Actual: 0.0


                                                                       

Epoch 900: Loss = -0.09916718304157257, Avg Reward = 0.5156, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([115.2267]), Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([206.3324]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Logits: tensor([203.3612]), Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Logits: tensor([294.4669]), Prediction: 1.0, Actual: 0.0


                                                                        

Training completed.

Number of correct predictions: 32020.0/64000
Accuracy: 0.5003125%

Testing Trained Model:
Input: [0.0, 0.0], Prediction: 1.0, Actual: 0.0
Input: [0.0, 1.0], Prediction: 1.0, Actual: 1.0
Input: [1.0, 0.0], Prediction: 1.0, Actual: 1.0
Input: [1.0, 1.0], Prediction: 1.0, Actual: 0.0


