In [4]:
import torch
import torch.nn as nn

In [8]:
class SarsaQNet(nn.Module):
    """Feed-forward Q-network mapping a state to one Q-value per discrete action.

    Layout: ``Linear(s_dim, hidden[0]) -> act -> ... -> Linear(hidden[-1], a_dim)``.
    The activation follows every hidden layer; the output layer is left linear so
    the predicted Q-values are unbounded.

    Args:
        s_dim: Size of the input's last dimension (state features).
        a_dim: Number of discrete actions; size of the output's last dimension.
        hidden: Iterable of hidden-layer widths. May be empty, which gives a single
            linear map from state to Q-values.
        act: Hidden-layer activation. Either a case-insensitive name from
            ``SarsaQNet.ACTIVATIONS`` (e.g. ``'relu'``, ``'tanh'``) or an
            ``nn.Module`` instance that is used as-is. Defaults to ``'relu'``.

    Raises:
        ValueError: ``act`` is a string that is not a key of ``ACTIVATIONS``.
        TypeError: ``act`` is neither a string nor an ``nn.Module``.
    """

    # Activation name -> constructor. A fresh instance is built per network.
    ACTIVATIONS = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'elu': nn.ELU,
        'leaky_relu': nn.LeakyReLU,
        'gelu': nn.GELU,
    }

    def __init__(self, s_dim, a_dim, hidden, act='relu'):
        super().__init__()

        # Resolve (and validate) the activation before building any layers.
        act_fn = self._make_activation(act)
        self.layers = nn.ModuleList()
        prev = s_dim

        for h in hidden:
            self.layers.append(nn.Linear(prev, h))
            prev = h

        self.out = nn.Linear(prev, a_dim)
        self.act_fn = act_fn

    @classmethod
    def _make_activation(cls, act):
        """Turn ``act`` (a registered name or an ``nn.Module``) into an activation module."""
        if isinstance(act, nn.Module):
            return act
        if isinstance(act, str):
            factory = cls.ACTIVATIONS.get(act.lower())
            if factory is None:
                raise ValueError(
                    f"Unknown activation {act!r}; expected one of "
                    f"{sorted(cls.ACTIVATIONS)} or an nn.Module instance"
                )
            return factory()
        raise TypeError(f"act must be a str or nn.Module, got {type(act).__name__}")

    def forward(self, x):
        """Return Q-values for ``x``; ``nn.Linear`` maps the last dim ``s_dim -> a_dim``."""
        for layer in self.layers:
            x = self.act_fn(layer(x))

        return self.out(x)

In [11]:
# Generating Q values for a batch of random states.
# Seed first so the sampled states and the network's random init (and thus the
# printed outputs) are reproducible on Restart & Run All.
torch.manual_seed(0)

s_dim = 8
a_dim = 4
n_states = 4
states = torch.randn(n_states, s_dim)  # batch of n_states random states, shape (n_states, s_dim)
print(states.shape)
print(states)

q_net = SarsaQNet(s_dim=s_dim, a_dim=a_dim, hidden=[64, 64], act='relu')
q_values = q_net(states)
# One Q-value per (state, action) pair.
assert q_values.shape == (n_states, a_dim), q_values.shape
print(q_values.shape)
print(q_values)


torch.Size([4, 8])
tensor([[-0.0341, -1.7673,  1.6315, -1.2016,  1.2342, -0.5946,  0.9670, -0.5552],
        [ 0.5648, -0.1728,  0.5526,  0.6931,  0.6063, -0.2072,  0.0939, -0.6567],
        [ 1.4145, -0.8998,  0.3079,  0.6522, -0.9508, -0.0288, -0.0244, -1.3830],
        [-0.5737, -1.5074,  1.3948, -0.2741, -0.2908,  0.4586, -0.9738,  0.5688]])
torch.Size([4, 4])
tensor([[-0.1469, -0.0085, -0.0791,  0.0235],
        [ 0.0102,  0.0943,  0.0214, -0.0465],
        [ 0.1680,  0.1256, -0.0166,  0.0761],
        [-0.0286, -0.0851, -0.0384,  0.0174]], grad_fn=<AddmmBackward0>)


In [None]:
# Example: How gather() works for extracting Q-values
print("="*60)
print("EXAMPLE: Extracting Q(s,a) for specific actions")
print("="*60)

# Simulate 3 experiences with 4 possible actions each
batch_size = 3
action_dim = 4

# Q-values for all actions (output from Q-network), shape (batch_size, action_dim)
q_values = torch.tensor([
    [0.5, 1.2, -0.3, 0.8],  # Experience 0: Q-values for actions [0,1,2,3]
    [0.1, 0.7,  1.5, 0.2],  # Experience 1: Q-values for actions [0,1,2,3]
    [0.9, 0.3,  0.6, 1.1]   # Experience 2: Q-values for actions [0,1,2,3]
])

# Actions actually taken in each experience, shape (batch_size,).
# gather() needs an int64 index tensor, which torch.tensor infers from Python ints.
actions = torch.tensor([1, 2, 3])  # Took action 1, 2, and 3 respectively

# Keep the declared sizes in sync with the example data above.
assert q_values.shape == (batch_size, action_dim), q_values.shape
assert actions.shape == (batch_size,), actions.shape
assert bool(((actions >= 0) & (actions < action_dim)).all()), "action index out of range"

print(f"\nq_values shape: {q_values.shape}")
print(f"q_values:\n{q_values}\n")

print(f"actions shape: {actions.shape}")
print(f"actions: {actions}\n")

# Step 1: unsqueeze to add dimension -> (batch_size, 1), matching q_values' rank
print("-" * 60)
print("STEP 1: actions.unsqueeze(1)")
print("-" * 60)
actions_unsqueezed = actions.unsqueeze(1)
print(f"Shape: {actions_unsqueezed.shape}")
print(f"Values:\n{actions_unsqueezed}\n")

# Step 2: gather along dim 1 -> out[i, 0] = q_values[i, actions_unsqueezed[i, 0]]
print("-" * 60)
print("STEP 2: q_values.gather(1, actions.unsqueeze(1))")
print("-" * 60)
gathered = q_values.gather(1, actions_unsqueezed)
print(f"Shape: {gathered.shape}")
print(f"Values:\n{gathered}")
print("\nWhat happened:")
for i in range(batch_size):
    print(f"  Experience {i}: q_values[{i}, {actions[i]}] = {q_values[i, actions[i]]}")
print()

# Step 3: squeeze to remove extra dimension -> (batch_size,)
print("-" * 60)
print("STEP 3: .squeeze(1)")
print("-" * 60)
pred_q_values = gathered.squeeze(1)
print(f"Shape: {pred_q_values.shape}")
print(f"Values: {pred_q_values}\n")

# Complete operation in one line
print("="*60)
print("COMPLETE OPERATION (one line):")
print("="*60)
result = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
print("pred_q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)")
print(f"Result: {result}")
print(f"Shape: {result.shape}\n")

# Verify it matches manual (per-row) indexing
print("="*60)
print("VERIFICATION: Compare with manual indexing")
print("="*60)
manual_result = torch.stack([q_values[i, actions[i]] for i in range(batch_size)])
print(f"Manual indexing: {manual_result}")
print(f"Gather method:   {result}")
print(f"Are they equal? {torch.equal(manual_result, result)}")


EXAMPLE: Extracting Q(s,a) for specific actions

q_values shape: torch.Size([3, 4])
q_values:
tensor([[ 0.5000,  1.2000, -0.3000,  0.8000],
        [ 0.1000,  0.7000,  1.5000,  0.2000],
        [ 0.9000,  0.3000,  0.6000,  1.1000]])

actions shape: torch.Size([3])
actions: tensor([1, 2, 3])

------------------------------------------------------------
STEP 1: actions.unsqueeze(1)
------------------------------------------------------------
Shape: torch.Size([3, 1])
Values:
tensor([[1],
        [2],
        [3]])

------------------------------------------------------------
STEP 2: q_values.gather(1, actions.unsqueeze(1))
------------------------------------------------------------
Shape: torch.Size([3, 1])
Values:
tensor([[1.2000],
        [1.5000],
        [1.1000]])

What happened:
  Experience 0: q_values[0, 1] = 1.2000000476837158
  Experience 1: q_values[1, 2] = 1.5
  Experience 2: q_values[2, 3] = 1.100000023841858

------------------------------------------------------------
STE

: 