Combining Chain of Thought (CoT) with Reinforcement Learning (RL) means letting an agent explicitly generate intermediate reasoning steps before it selects an action.

Key Components

Chain of Thought (CoT) Module: Generates intermediate reasoning steps (e.g., natural language or structured thoughts) based on the environment state.

Policy Network: Uses the CoT output to decide actions.

RL Training Loop: Trains the agent from rewards, updating both the policy and the CoT module.

Grid World Navigation

Goal: Train an agent to navigate a grid world to reach a goal. The agent generates CoT steps (e.g., "Move right to avoid obstacle") before acting.
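As a toy illustration of "think, then act" on this grid, the rule-based function below writes a one-line thought and returns the action that thought justifies. It is a hand-written stand-in, not the learned CoT module trained later in this notebook; the names `GRID`, `MOVES` and `think_then_act` are hypothetical, and the layout is copied from the `GridWorld` class.

```python
# Same 4x5 layout as the GridWorld environment: S = start, # = obstacle, G = goal
GRID = ["S....",
        ".#...",
        "..#..",
        "...#G"]
MOVES = {"down": (1, 0), "right": (0, 1), "up": (-1, 0), "left": (0, -1)}

def think_then_act(pos, goal=(3, 4)):
    """Hypothetical rule-based CoT step: write a thought, then return the action it justifies."""
    r, c = pos
    # Moves that reduce the distance to the goal, tried in this order
    toward = [m for m, ok in (("down", goal[0] > r), ("right", goal[1] > c),
                              ("up", goal[0] < r), ("left", goal[1] < c)) if ok]
    blocked = []
    for move in toward:
        dr, dc = MOVES[move]
        if GRID[r + dr][c + dc] == "#":
            blocked.append(move)  # remember why this move was rejected
            continue
        reason = (f"moving {blocked[0]} would hit an obstacle" if blocked
                  else "it brings me closer to the goal")
        return f"At {pos}: move {move} because {reason}.", move
    return f"At {pos}: no free move toward the goal, so a detour is needed.", None

for pos in [(0, 0), (0, 1)]:
    thought, action = think_then_act(pos)
    print(thought)
```

At (0, 1) the preferred move down is blocked by the obstacle at (1, 1), so the thought records that reason before the agent moves right.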

Define Environment

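This code implements a Reinforcement Learning (RL) agent that navigates a simple grid world. It is trained with a policy gradient method (REINFORCE) and uses a "Chain of Thought" (CoT) inspired policy network: instead of emitting text, the network carries an LSTM hidden state from step to step as a latent stand-in for the reasoning chain.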

Define CoT + Policy Network

In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# torch.randn(1, 1, 5*5) draws a tensor of shape (1, 1, 25) from a standard normal distribution.
# Dim 0 (size 1): batch size; a single sample.
# Dim 1 (size 1): a singleton axis, e.g. one channel, or one time step for a batch_first LSTM.
# Dim 2 (size 25): 5*5 = 25 features, e.g. a flattened 5x5 grid.
torch.randn(1, 1, 5*5)

tensor([[[-0.9429, -0.9260,  0.1712,  1.3570,  0.6236,  1.4876, -0.3113,
           0.3550, -0.1785, -0.1906,  1.0326,  0.1951, -2.4507,  2.7198,
          -0.3184, -1.5103,  0.6222,  0.9984, -1.4907,  0.1795, -1.0040,
           0.3620, -1.5469, -0.0327,  1.5555]]])

In [12]:
class GridWorld:
    def __init__(self, grid_size=None, max_steps=100):
        # 4x5 layout: S = start, . = empty, # = obstacle, G = goal
        self.grid = np.array([
            ['S', '.', '.', '.', '.'],
            ['.', '#', '.', '.', '.'],
            ['.', '.', '#', '.', '.'],
            ['.', '.', '.', '#', 'G']
        ])
        # Derive the size from the layout so the bounds checks in step() always match the grid
        if grid_size is not None and tuple(grid_size) != self.grid.shape:
            raise ValueError(f"grid_size {grid_size} does not match layout {self.grid.shape}")
        self.grid_size = self.grid.shape
        self.max_steps = max_steps  # cap episode length so untrained policies still terminate
        self.action_space = [0, 1, 2, 3]  # Up, Down, Left, Right
        self.n_actions = len(self.action_space)
        self.reset()

    def reset(self):
        self.agent_pos = self._find_start_pos()
        self.steps = 0
        return self._get_state()

    def _find_start_pos(self):
        for r in range(self.grid_size[0]):
            for c in range(self.grid_size[1]):
                if self.grid[r, c] == 'S':
                    return (r, c)
        return (0, 0)  # Default if 'S' not found

    def _get_state(self):
        # Return a tuple of (grid, agent_position)
        return self.grid.copy(), self.agent_pos

    def step(self, action):
        x, y = self.agent_pos  # x = row, y = column
        if action == 0:    # Up
            x = max(0, x - 1)
        elif action == 1:  # Down
            x = min(self.grid_size[0] - 1, x + 1)
        elif action == 2:  # Left
            y = max(0, y - 1)
        elif action == 3:  # Right
            y = min(self.grid_size[1] - 1, y + 1)

        done = False
        if self.grid[x, y] == '#':
            reward = -1               # Hit obstacle: penalized, agent stays in place
            x, y = self.agent_pos
        elif self.grid[x, y] == 'G':
            reward = 10               # Reach goal
            done = True
        else:
            reward = -0.1             # Step penalty

        self.agent_pos = (x, y)
        self.steps += 1
        if self.steps >= self.max_steps:
            done = True               # Truncate overly long episodes
        return self._get_state(), reward, done

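CoT_Policy takes three constructor arguments:

grid_size: The (rows, cols) shape of the grid. The grid is one-hot encoded into 4 channels (start, empty, obstacle, goal), passed through a 3x3 convolution with 16 output channels, flattened, and concatenated with the agent's (row, col) position, so the LSTM input size is 16 * rows * cols + 2.

hidden_dim: The number of hidden units in the LSTM and in the hidden layer of the policy head.

action_dim: The number of possible actions (4 in this grid world).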
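A quick shape check of the LSTM input size for the 4x5 grid used in the training cell later on: with kernel_size=3, stride=1 and padding=1 the convolution keeps the 4x5 spatial size, so flattening its 16 output channels gives 16 * 4 * 5 features, plus 2 for the agent position. The all-zero tensor is only a placeholder input of the right shape.

```python
import torch
import torch.nn as nn

grid_size = (4, 5)  # rows, cols of the layout used in this notebook
conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)

one_hot_grid = torch.zeros(1, 4, *grid_size)  # (batch, channels, rows, cols)
conv_features = nn.Flatten()(conv(one_hot_grid))
lstm_input_size = conv_features.shape[1] + 2  # + agent (row, col)

print(conv_features.shape)  # torch.Size([1, 320])
print(lstm_input_size)      # 322
```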

In [13]:
class CoT_Policy(nn.Module):
    def __init__(self, grid_size, hidden_dim, action_dim):
        super().__init__()
        self.grid_size = grid_size
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim

        # Convolutional layer to process the grid
        self.conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

        # LSTM for temporal processing / "Chain of Thought"
        self.lstm = nn.LSTM(input_size=16 * grid_size[0] * grid_size[1] + 2,  # Conv output + agent pos
                            hidden_size=hidden_dim,
                            batch_first=True)

        # Policy Network
        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def _encode_state(self, grid, agent_pos):
        # One-hot encode the grid
        encoding = np.zeros((4, self.grid_size[0], self.grid_size[1]), dtype=np.float32)
        for r in range(self.grid_size[0]):
            for c in range(self.grid_size[1]):
                if grid[r, c] == 'S':
                    encoding[0, r, c] = 1
                elif grid[r, c] == '.':
                    encoding[1, r, c] = 1
                elif grid[r, c] == '#':
                    encoding[2, r, c] = 1
                elif grid[r, c] == 'G':
                    encoding[3, r, c] = 1
        return encoding

    def forward(self, state_tuple, hidden=None):
        grid, agent_pos = state_tuple
        encoded_grid = self._encode_state(grid, agent_pos)
        grid_tensor = torch.from_numpy(encoded_grid).unsqueeze(0) # Add batch dimension

        conv_out = self.relu(self.conv(grid_tensor))
        flattened_conv = self.flatten(conv_out)

        agent_pos_tensor = torch.tensor(agent_pos, dtype=torch.float32).unsqueeze(0) # Add batch dimension
        lstm_input = torch.cat((flattened_conv, agent_pos_tensor), dim=1).unsqueeze(1) # Add sequence dimension

        lstm_out, hidden = self.lstm(lstm_input, hidden)
        action_logits = self.policy(lstm_out[:, -1, :]) # Take the output of the last time step

        return action_logits, hidden

Training Loop with Policy Gradients

In [None]:
def train(env, model, optimizer, episodes=1000, gamma=0.99):
    model.train()
    for ep in range(episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        hidden_state = None # Initialize LSTM hidden state

        while not done:
            # Get action probabilities from CoT-guided policy
            action_logits, hidden_state = model(state, hidden_state)
            action_probs = F.softmax(action_logits, dim=-1)
            action_dist = torch.distributions.Categorical(action_probs)
            action = action_dist.sample()

            # Execute action
            next_state, reward, done = env.step(action.item())

            # Store log prob and reward
            log_probs.append(action_dist.log_prob(action))
            rewards.append(reward)

            state = next_state

            # Detach hidden state to prevent backpropagation through entire episode
            if hidden_state is not None:
                hidden_state = (hidden_state[0].detach(), hidden_state[1].detach())

        # Calculate discounted returns
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)

        # Normalize returns
        returns = torch.tensor(returns)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        else:
            returns = torch.zeros_like(returns) # Handle single step episodes

        # Policy gradient loss
        policy_loss = []
        for log_prob, R in zip(log_probs, returns):
            policy_loss.append(-log_prob * R)
        policy_loss = torch.cat(policy_loss).sum()

        # Update model
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        print(f"Episode {ep+1}, Total Reward: {sum(rewards):.1f}")
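A worked example of the discounted-return recursion used in `train()`, on a hypothetical three-step episode whose rewards match the environment's step penalty and goal reward, with gamma = 0.99. `discounted_returns` is a standalone copy of that loop for illustration.

```python
def discounted_returns(rewards, gamma=0.99):
    # Same backward recursion as in train(): G_t = r_t + gamma * G_{t+1}
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return returns

# Hypothetical three-step episode: two ordinary moves (-0.1 each), then the goal (+10)
returns = discounted_returns([-0.1, -0.1, 10])
print([round(g, 3) for g in returns])  # [9.602, 9.8, 10.0]
```

Early steps receive almost the full goal reward, discounted once per remaining step; `train()` then standardizes the returns to zero mean and unit variance before weighting each log-probability.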


In [14]:
grid_size = (4, 5)
env = GridWorld(grid_size=grid_size)
hidden_dim = 128
model = CoT_Policy(grid_size, hidden_dim, env.n_actions)
optimizer = optim.Adam(model.parameters(), lr=0.001)
episodes = 1000
train(env, model, optimizer, episodes=episodes)

In [None]:
# Optional: Evaluation loop (greedy actions)
# The environment and the greedy policy are both deterministic, so every episode is identical.
model.eval()
total_reward = 0
num_eval_episodes = 10
max_eval_steps = 100  # guard against a greedy policy that cycles forever
for _ in range(num_eval_episodes):
    state = env.reset()
    done = False
    episode_reward = 0
    hidden_state = None
    steps = 0
    while not done and steps < max_eval_steps:
        with torch.no_grad():
            action_logits, hidden_state = model(state, hidden_state)
            # softmax is monotonic, so the argmax of the logits is the most likely action
            action = torch.argmax(action_logits, dim=-1).item()
        state, reward, done = env.step(action)
        episode_reward += reward
        steps += 1
    total_reward += episode_reward

print(f"\nAverage reward over {num_eval_episodes} evaluation episodes: {total_reward / num_eval_episodes:.2f}")

Advanced Architecture: CoT + Transformer + PPO

We’ll use a transformer to produce CoT embeddings (latent reasoning steps) and feed them to an actor-critic trained with the Proximal Policy Optimization (PPO) algorithm.

Key Components

CoT Transformer: Encodes the state into CoT embeddings, a learned stand-in for reasoning steps such as "Avoid obstacle at (x,y)".

Policy Network (actor): Uses the CoT embeddings to select actions.

Value Network (critic): Estimates the state value used to compute advantages for PPO updates.
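The core of PPO is the clipped surrogate objective. With probability ratio r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t) and advantage A_t, each sample contributes min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t), so the policy gains nothing from pushing the ratio outside [1 - eps, 1 + eps]. `clipped_surrogate` is a hypothetical standalone helper mirroring the policy-loss lines of the training loop; the probabilities are made-up numbers.

```python
import torch

def clipped_surrogate(new_log_probs, old_log_probs, advantages, clip_epsilon=0.2):
    """Per-sample PPO clipped objective (to be maximized; the loss is its negative mean)."""
    ratio = (new_log_probs - old_log_probs).exp()  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    return torch.min(unclipped, clipped)

old = torch.log(torch.tensor([0.5, 0.5]))
new = torch.log(torch.tensor([0.8, 0.2]))  # ratios 1.6 and 0.4
adv = torch.tensor([1.0, -1.0])
print(clipped_surrogate(new, old, adv))    # approximately [1.2, -0.8]
```

With eps = 0.2 the first sample's ratio of 1.6 is capped at 1.2. For the second sample (negative advantage, ratio 0.4) the pessimistic clipped term -0.8 is selected, and the objective is flat in the ratio there, so pushing that probability even lower earns no further improvement.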

Define the CoT Transformer

In [18]:
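import torch
import torch.nn as nn
from torch.distributions import Categorical

class CoTTransformer(nn.Module):
    def __init__(self, state_dim, cot_dim, n_heads=4):
        super().__init__()
        # batch_first=True so inputs are (batch_size, seq_len, state_dim), as documented in forward().
        # state_dim must be divisible by n_heads. dropout=0.0 makes the PPO rollout and update
        # passes compute the same policy, so the probability ratio starts at exactly 1.
        self.encoder = nn.TransformerEncoderLayer(
            d_model=state_dim, nhead=n_heads, dim_feedforward=256,
            dropout=0.0, batch_first=True
        )
        self.cot_proj = nn.Linear(state_dim, cot_dim)  # CoT embeddings

    def forward(self, state):
        # state: (batch_size, seq_len, state_dim)
        cot_embed = self.encoder(state)
        cot_steps = self.cot_proj(cot_embed)  # (batch_size, seq_len, cot_dim)
        return cot_steps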
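A quick check of why `batch_first` matters here: `nn.TransformerEncoderLayer` defaults to (seq_len, batch, d_model) inputs, so a (batch, seq_len, state_dim) tensor fed to a layer built without `batch_first=True` is read with its batch axis as the sequence axis, and different states silently attend to each other. The snippet perturbs one state in a batch of three and checks whether another state's output changes; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(3, 1, 20)   # meant as (batch=3, seq_len=1, state_dim=20)
x_changed = x.clone()
x_changed[2] += 1.0         # perturb only the third state in the batch

results = {}
with torch.no_grad():
    for batch_first in (True, False):
        torch.manual_seed(1)  # identical initial weights for both variants
        enc = nn.TransformerEncoderLayer(d_model=20, nhead=4, dim_feedforward=64,
                                         dropout=0.0, batch_first=batch_first).eval()
        # Does changing state 2 alter the output for state 0?
        leaked = not torch.allclose(enc(x)[0], enc(x_changed)[0], atol=1e-6)
        results[batch_first] = leaked
        print(f"batch_first={batch_first}: state 0 affected by state 2 -> {leaked}")
```

Only the default layout lets the change leak across the batch; with `batch_first=True` each state is processed independently.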

Define the Actor-Critic Network

In [19]:
class ActorCritic(nn.Module):
    def __init__(self, cot_dim, action_dim):
        super().__init__()
        # Actor (policy)
        self.actor = nn.Sequential(
            nn.Linear(cot_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )
        # Critic (value)
        self.critic = nn.Sequential(
            nn.Linear(cot_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, cot_steps):
        action_logits = self.actor(cot_steps)
        value = self.critic(cot_steps)
        return action_logits, value

PPO Training Loop with CoT

In [20]:
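import torch
import torch.nn as nn
from torch.distributions import Categorical

def encode_state(state, grid_size):
    # Assumed helper: one-hot encode the agent's cell. The obstacle layout is fixed,
    # so the position is all that changes between states. state_dim = rows * cols
    # (20 for the 4x5 grid, divisible by n_heads=4).
    _, (row, col) = state
    x = torch.zeros(grid_size[0] * grid_size[1])
    x[row * grid_size[1] + col] = 1.0
    return x.view(1, 1, -1)  # (batch=1, seq_len=1, state_dim)

def compute_returns(rewards, gamma=0.99):
    # Discounted returns G_t = r_t + gamma * G_{t+1}, computed backwards
    returns, R = [], 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return torch.tensor(returns, dtype=torch.float32)

def ppo_train(env, cot_model, actor_critic, optimizer, epochs=10, clip_epsilon=0.2,
              gamma=0.99, max_steps=100):
    states, actions, old_log_probs, values, rewards = [], [], [], [], []

    # Collect one trajectory with CoT, using the current (old) policy
    state = env.reset()
    done = False
    while not done and len(rewards) < max_steps:
        x = encode_state(state, env.grid_size)
        with torch.no_grad():
            cot_steps = cot_model(x)                        # (1, 1, cot_dim)
            action_logits, value = actor_critic(cot_steps[:, -1, :])
            dist = Categorical(logits=action_logits)
            action = dist.sample()                          # (1,)
        next_state, reward, done = env.step(action.item())

        # Store data
        states.append(x)
        actions.append(action)
        old_log_probs.append(dist.log_prob(action))
        values.append(value.squeeze(-1))
        rewards.append(reward)
        state = next_state

    states = torch.cat(states)                # (T, 1, state_dim)
    actions = torch.cat(actions)              # (T,)
    old_log_probs = torch.cat(old_log_probs)  # (T,)
    values = torch.cat(values)                # (T,)

    # Compute returns and advantages (critic value as baseline), then normalize
    returns = compute_returns(rewards, gamma)
    advantages = returns - values
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # PPO updates over the same trajectory
    for _ in range(epochs):
        cot_steps = cot_model(states)                       # (T, 1, cot_dim)
        new_logits, new_values = actor_critic(cot_steps[:, -1, :])
        new_dist = Categorical(logits=new_logits)

        # Policy loss: clipped surrogate with ratio pi_new / pi_old
        ratio = (new_dist.log_prob(actions) - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value loss
        value_loss = nn.MSELoss()(new_values.squeeze(-1), returns)

        # Total loss and update (the optimizer must cover cot_model and actor_critic parameters)
        loss = policy_loss + 0.5 * value_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return sum(rewards)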

Hierarchical RL with CoT

Decompose tasks into high-level plans (CoT) and low-level actions.

Hierarchical Architecture
Meta-Controller: Generates CoT plans (e.g., "Go to Room A → Pick up Key").

Sub-Controller: Executes low-level actions (e.g., "Move left", "Grab").
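Before the neural version, here is the same decomposition with hand-written rules on this notebook's 4x5 grid: a meta-controller emits a CoT plan of (thought, waypoint) sub-goals, and a sub-controller takes primitive steps until each waypoint is reached. Both functions, `plan_waypoints` and `step_toward`, are hypothetical illustrations; in the neural cell that follows, the plan is a sequence of learned embeddings instead.

```python
# Same 4x5 layout as GridWorld: S = start, # = obstacle, G = goal
GRID = ["S....",
        ".#...",
        "..#..",
        "...#G"]

def plan_waypoints(goal):
    """Hypothetical hand-written meta-controller: a CoT plan of (thought, waypoint) pairs."""
    return [("Follow the top row to the top-right corner", (0, 4)),
            ("Go down the right edge to the goal", goal)]

def step_toward(pos, target):
    """Hypothetical low-level policy: one greedy step toward target, columns first."""
    r, c = pos
    if c != target[1]:
        return (r, c + (1 if target[1] > c else -1))
    if r != target[0]:
        return (r + (1 if target[0] > r else -1), c)
    return pos

pos, goal = (0, 0), (3, 4)
trace = []
for thought, waypoint in plan_waypoints(goal):
    while pos != waypoint:  # sub-controller runs until the sub-goal is reached
        pos = step_toward(pos, waypoint)
        assert GRID[pos[0]][pos[1]] != "#", "plan led into an obstacle"
        trace.append(pos)
    print(f"{thought}: reached {pos} after {len(trace)} steps")
```

The plan routes around all three obstacles, so the sub-controller never needs to reason about them; that division of labour is the point of the hierarchy.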

In [21]:
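import torch
import torch.nn as nn

class MetaController(nn.Module):
    """Maps a state to a sequence of plan-step embeddings (the high-level CoT plan)."""
    def __init__(self, state_dim, plan_dim, n_plan_steps=3, n_heads=4, n_layers=2):
        super().__init__()
        self.n_plan_steps, self.plan_dim = n_plan_steps, plan_dim
        self.embed = nn.Linear(state_dim, n_plan_steps * plan_dim)
        layer = nn.TransformerEncoderLayer(d_model=plan_dim, nhead=n_heads,
                                           dim_feedforward=128, batch_first=True)
        self.planner = nn.TransformerEncoder(layer, num_layers=n_layers)  # Generates plans

    def forward(self, state):
        # state: (batch, state_dim) -> plan: (batch, n_plan_steps, plan_dim)
        tokens = self.embed(state).view(-1, self.n_plan_steps, self.plan_dim)
        return self.planner(tokens)

class SubController(nn.Module):
    def __init__(self, plan_dim, action_dim):
        super().__init__()
        self.policy = nn.Linear(plan_dim, action_dim)  # Executes actions

    def forward(self, plan_step):
        return self.policy(plan_step)  # action logits for one plan step

# Illustrative rollout with untrained networks (no learning happens here)
env = GridWorld(grid_size=(4, 5))
rows, cols = env.grid_size
meta_controller = MetaController(state_dim=rows * cols, plan_dim=32)
sub_controller = SubController(plan_dim=32, action_dim=env.n_actions)

_, (row, col) = env.reset()
state = torch.zeros(1, rows * cols)
state[0, row * cols + col] = 1.0  # one-hot agent position

with torch.no_grad():
    plan = meta_controller(state)         # Meta-controller generates plan
    for step in plan[0]:                  # Sub-controller executes plan
        action = sub_controller(step).argmax().item()
        _, reward, done = env.step(action)
        if done:
            break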

CoT for Partial Observability

Use CoT to maintain memory in partially observable environments (e.g., robot navigation).
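A symbolic sketch of the idea on this notebook's grid, under the assumption that the agent sees only its four neighbouring cells: each notable observation is written down as a CoT note ("Observed obstacle at (r, c)"), and a belief map keeps everything seen so far, so the agent still knows about an obstacle after it has left view. `observe`, `memory` and `known` are hypothetical names; the LSTM-based module that follows plays the same role with learned embeddings.

```python
# Same 4x5 layout as GridWorld: S = start, # = obstacle, G = goal
GRID = ["S....",
        ".#...",
        "..#..",
        "...#G"]
NAMES = {"#": "obstacle", "G": "goal"}

def observe(pos):
    """Partial observation: only the 4 neighbouring cells are visible."""
    r, c = pos
    seen = {}
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nr, nc = r + dr, c + dc
        if 0 <= nr < len(GRID) and 0 <= nc < len(GRID[0]):
            seen[(nr, nc)] = GRID[nr][nc]
    return seen

memory = []  # CoT history: one note per newly observed landmark
known = {}   # belief map built from every observation so far
path = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 4), (2, 4)]
for pos in path:
    for cell, ch in observe(pos).items():
        if ch in NAMES and cell not in known:
            memory.append(f"Observed {NAMES[ch]} at {cell}")
        known[cell] = ch

print(memory)
print("obstacle (1, 1) visible now:", (1, 1) in observe(pos), "| remembered:", known[(1, 1)] == "#")
```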

Memory-Augmented CoT

In [None]:
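import torch.nn as nn

class CoTMemory(nn.Module):
    def __init__(self, input_dim, mem_dim, cot_dim=32):
        super().__init__()
        # generate_cot is sketched as a small learned layer (an assumption) that maps each
        # observation to a CoT embedding, in place of a note like "Observed door at (x,y)"
        self.generate_cot = nn.Sequential(nn.Linear(input_dim, cot_dim), nn.ReLU())
        self.lstm = nn.LSTM(cot_dim, mem_dim, batch_first=True)  # Track CoT history

    def forward(self, state, hidden=None):
        # state: (batch, seq_len, input_dim) partial observations
        # hidden: (h, c) from the previous call, so the memory persists across steps
        cot_step = self.generate_cot(state)
        output, hidden = self.lstm(cot_step, hidden)
        return output, hidden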