**Part 4**

GRPO Loss

Improvements:
* Added Log Counter as train function parameter

Attempted the Following:
* Adding Gradient Clipping
* Tried adjusting the reward logic
* Added Entropy penalty to Loss calculation

Look into:
* detatching advantages

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from tqdm import tqdm

In [19]:
# Define the neural network
class LogicNet(nn.Module):
    def __init__(self):
        super(LogicNet, self).__init__()
        self.fc1 = nn.Linear(2, 4)  # Input layer -> Hidden Layer
        self.fc2 = nn.Linear(4, 1)  # Hidden Layer -> Output Layer

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

    def get_action_and_or_log_prob(self, state, action=None):
        """Helper method to get action and its log_prob from logits"""
        logits = self.forward(state)    # Get the logits from a forward pass of the Policy Network
        # For a binary output (0 or 1), Bernoulli distribution is appropriate
        probs = torch.distributions.Bernoulli(logits=logits)

        if action is None:
            sampled_action = probs.sample() # Sample action based on current probabilities (returns 0 or 1)
            log_prob = probs.log_prob(sampled_action)   # Calculate the log of the probability the sampled action is chosen
            return sampled_action, log_prob
        else:
            log_prob = probs.log_prob(action)       # Returns the log of the probability the action is chosen
            return log_prob


In [20]:
# Define the environment
class LogicGateEnv:
    def __init__(self, gate="AND"):
        self.gate = gate
        self.data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
        self.targets = self.get_targets(gate)

    def get_targets(self, gate:str):
        if gate == "AND":
            return torch.tensor([[0], [0], [0], [1]], dtype=torch.float32)
        elif gate == "OR":
            return torch.tensor([[0], [1], [1], [1]], dtype=torch.float32)
        elif gate == "XOR":
            return torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
        elif gate == "XNOR":
            return torch.tensor([[1], [0], [0], [1]], dtype=torch.float32)

    def step(self, input_idx: int, prediction):
        target = self.targets[input_idx].item()
        # Rounds up to 1 if it is >=.5 to get prediction; else 0
        reward = 1.0 if round(prediction.item()) == target else -10.0
        return reward

In [None]:
# Training loop
def train_logic_gate(gate="XOR", epochs=100, learning_rate=0.0001, batch_size=64, k_epochs=64, epsilon=0.2, beta_kl=0.01, max_grad_norm=0.5, entropy_coeff=0.5, log_iterations=10):
    print(f"Training {gate} gate with {epochs} epochs, {learning_rate} learning rate, batch size {batch_size}, and KL beta {beta_kl}.")
    #Initialize Agent's Policy, Environment, parameter optimizer, and Total Correct Counter
    env = LogicGateEnv(gate)
    Policy_New = LogicNet()
    optimizer = optim.Adam(Policy_New.parameters(), lr=learning_rate)
    num_correct = 0.0

    for epoch in range(epochs):
        rewards_batch = []
        inputs_batch = []
        targets_batch = []

        # --- 1. Collect a Batch of Experiences ---
        # Loop agent prediction, recording important values to lists:
        for i in range(batch_size):
            # Get model inputs and target
            idx = random.randint(0, 3)
            inputs = env.data[idx]
            target = env.targets[idx]

            # Get model prediction
            # Get logits from current policy and formulate the model's prediction for reward calculation
            with torch.no_grad(): # No need to track gradients during data collection
                prediction_logits = Policy_New(inputs)
                # print(f"prediction logits: {prediction_logits}")
                pred = torch.round(torch.sigmoid(prediction_logits)).float()

            # Calculate reward
            reward = env.step(idx, pred)

            # Append to lists
            inputs_batch.append(inputs)
            rewards_batch.append(reward)
            targets_batch.append(target)

        # Convert collected batch lists into PyTorch tensors
        inputs_batch_tensor = torch.stack(inputs_batch)
        targets_batch_tensor = torch.stack(targets_batch)
        rewards_batch_tensor = torch.tensor(rewards_batch, dtype=torch.float32)

        num_correct += (rewards_batch_tensor==1.0).sum().item()  ### need to change
        # print(f"Number correct, this iteration: {(rewards_batch_tensor).sum().item()}")

        # Unsqueeze to ensure rewards_batch_t has the same shape as targets_batch_t for element-wise ops SHAPE:(1, batch_size)
        rewards_batch_t = rewards_batch_tensor.unsqueeze(1)

        # --- START OF ADVANTAGE CALCULATION ---
        # Calculate the mean of the rewards in the current batch
        mean_reward = rewards_batch_tensor.mean()

        # Calculate the standard deviation of the rewards in the current batch
        # Add a small epsilon (1e-8) to prevent division by zero in case all rewards are identical
        std_reward = rewards_batch_tensor.std() + 1e-8

        # print(f"rewards_batch_t shape: {rewards_batch_t.shape} || mean_reward: {mean_reward}")
        # Calculate the advantage for each time step in the batch using your specified formula
        advantages_of_batch = (rewards_batch_t - mean_reward) / (std_reward)
        # --- END OF ADVANTAGE CALCULATION ---

        # --- 2. Store "Old Policy" Parameters ---
        # Transfer the weights to the Old Policy Model
        Policy_Old = LogicNet()
        Policy_Old.load_state_dict(Policy_New.state_dict())
        Policy_Old.eval()       # Tells Pytorch not to calculate gradients for this network

        # Get log_probabilities for the collected 'targets' from the OLD policy
        # Detach these to prevent gradients from flowing back into old_net
        with torch.no_grad():
            old_logits = Policy_Old(inputs_batch_tensor)
            # Use the get_action_and_or_log_prob helper
            log_prob_old = Policy_Old.get_action_and_or_log_prob(inputs_batch_tensor, targets_batch_tensor).detach()
            # The .detach() is critical here to ensure old_net remains fixed.

        # --- 3. Inner Loop (K_epochs) --- GRPO Optimization iteration
        for _ in tqdm(range(k_epochs), desc=f"Epoch {epoch+1}/{epochs} (Inner K-Epochs)", leave=False):
            new_policy_logits = Policy_New(inputs_batch_tensor)
            log_prob_new = Policy_New.get_action_and_or_log_prob(inputs_batch_tensor, targets_batch_tensor)


            # --- KL Divergence Calculation ---
            # Create Bernoulli distributions for new and old policies using their logits
            p_dist = torch.distributions.Bernoulli(logits=new_policy_logits)
            q_dist = torch.distributions.Bernoulli(logits=old_logits) # Use the detached old_logits

            # Calculate KL divergence per sample, then take the mean over the batch
            kl_div_per_sample = torch.distributions.kl.kl_divergence(p_dist, q_dist)
            kl_loss = kl_div_per_sample.mean() # Mean over the batch


            # print(f"log_prob_new: {log_prob_new}")
            # print(f"log_prob_old: {log_prob_old}")

            # Calculate the ratio of each Trajectory in the Group
            # r_t(0) = π_0(a_t|s_t) / π_0_old(a_t|s_t) = exp(log(π_0(a_t|s_t) - log(π_0_old(a_t|s_t)))
            ratio = torch.exp(log_prob_new - log_prob_old)

            # print(f"Ratio: {ratio}")

            surrogate_1 = ratio * advantages_of_batch
            surrogate_2 = torch.clamp(input=ratio, min= 1.0 - epsilon, max= 1.0 + epsilon) * advantages_of_batch

            # Combine clipped loss with KL penalty
            # Remember: we minimize the negative of the main objective, and add the KL term
            # Maximize: min(...) - beta * D_KL(...) => Minimize: -min(...) + beta * D_KL(...)
            policy_objective_term = -torch.min(surrogate_1, surrogate_2).mean()

            # print(f"policy_objective_term: {policy_objective_term}")
            
            # Calculate the entropy 
            entropy = p_dist.entropy().mean()
            # print(f"Entropy Monitor: {entropy_coeff * entropy}") # Entropy goes to 0
            loss = policy_objective_term + beta_kl * kl_loss - entropy_coeff * entropy# Add KL term with beta_kl weight

            # In GRPO, the objective function is typically designed to be maximized (e.g., maximizing the expected return). Since PyTorch optimizers are designed for minimization, the common practice is to minimize the negative of the objective function.

            # Update the New Policy Model
            optimizer.zero_grad()
            loss.backward()
            # --- ADDING GRADIENT CLIPPING HERE TO LIMIT PARAMETER UPDATES---
            # torch.nn.utils.clip_grad_norm_(Policy_New.parameters(), max_norm=max_grad_norm)
            # ----------------------------------
            optimizer.step()



        # LOG IF ENOUGH EPOCHS HAVE ELAPSED
        if epoch % log_iterations == 0:
            avg_reward = rewards_batch_tensor.mean().item()
            print(f"Epoch {epoch}: Loss = {loss.item()}, Avg Reward = {avg_reward:.4f}, Mean Advantage: {advantages_of_batch.mean().item()}")
            # Validation Step
            print("Validating the Model:")
            with torch.no_grad():
                for i in range(4):
                    logits = Policy_New(env.data[i])
                    pred = torch.round(torch.sigmoid(logits)).item()
                    print(f"Input: {env.data[i].tolist()}, Logits: {logits}, Prediction: {pred}, Actual: {env.targets[i].item()}")

    print("Training completed.\n")
    print(f"Number of correct predictions: {num_correct}/{epochs * batch_size}")
    print(f"Accuracy: {num_correct/(epochs * batch_size)}%")

    print("\nTesting Trained Model:")
    for i in range(4):
        logits = Policy_New(env.data[i])
        pred = torch.round(torch.sigmoid(logits)).item()
        print(f"Input: {env.data[i].tolist()}, Prediction: {pred}, Actual: {env.targets[i].item()}")


In [24]:

# Run training
train_logic_gate("AND")

Training AND gate with 100 epochs, 0.0001 learning rate, batch size 64, and KL beta 0.01.


                                                                              

Epoch 0: Loss = -0.33785897493362427, Avg Reward = -0.7188, Mean Advantage: 2.2351741790771484e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-0.4696]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-0.4915]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-0.4276]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-0.4438]), Prediction: 0.0, Actual: 1.0


                                                                              

Epoch 10: Loss = -0.33281025290489197, Avg Reward = -1.7500, Mean Advantage: 3.3527612686157227e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-0.6069]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-0.6749]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-0.5331]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-0.6183]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 20: Loss = -0.3191086947917938, Avg Reward = -1.4062, Mean Advantage: 1.4901161193847656e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-0.7694]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-0.9191]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-0.6811]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-0.8804]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 30: Loss = -0.2978871762752533, Avg Reward = -2.0938, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-0.9811]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-1.2464]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-0.9108]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-1.2449]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 40: Loss = -0.2597149610519409, Avg Reward = -2.0938, Mean Advantage: -1.1175870895385742e-08
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-1.2173]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-1.6331]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-1.1935]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-1.6909]), Prediction: 0.0, Actual: 1.0


                                                                     

Epoch 50: Loss = -0.2265864908695221, Avg Reward = -2.6094, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-1.4324]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-2.0088]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-1.5176]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-2.1989]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 60: Loss = -0.18781158328056335, Avg Reward = -2.4375, Mean Advantage: 0.0
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-1.6692]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-2.4700]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-1.8971]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-2.7892]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 70: Loss = -0.14591234922409058, Avg Reward = -1.2344, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-1.9237]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-2.9861]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-2.3308]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-3.4592]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 80: Loss = -0.12646956741809845, Avg Reward = -1.2344, Mean Advantage: -7.450580596923828e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-2.1854]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-3.5261]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-2.7926]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-4.1695]), Prediction: 0.0, Actual: 1.0


                                                                               

Epoch 90: Loss = -0.098661869764328, Avg Reward = -2.2656, Mean Advantage: -2.7939677238464355e-09
Validating the Model:
Input: [0.0, 0.0], Logits: tensor([-2.4917]), Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Logits: tensor([-4.1304]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Logits: tensor([-3.3173]), Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Logits: tensor([-4.9723]), Prediction: 0.0, Actual: 1.0


                                                                                

Training completed.

Number of correct predictions: 4808.0/6400
Accuracy: 0.75125%

Testing Trained Model:
Input: [0.0, 0.0], Prediction: 0.0, Actual: 0.0
Input: [0.0, 1.0], Prediction: 0.0, Actual: 0.0
Input: [1.0, 0.0], Prediction: 0.0, Actual: 0.0
Input: [1.0, 1.0], Prediction: 0.0, Actual: 1.0


