In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch
import torch.nn as nn

class CustomGRUCell(nn.Module):
    """Single GRU step built from three fully connected gate layers.

    Every gate layer maps the concatenation ``[x, h]`` (width
    ``input_size + hidden_size``) to ``hidden_size`` units.

    Args:
        input_size: Width of the per-step input ``x``.
        hidden_size: Width of the hidden state ``h``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        gate_in = input_size + hidden_size
        # Creation order (z, r, h) is preserved so seeded initialisation matches.
        self.W_z = nn.Linear(gate_in, hidden_size)
        self.W_r = nn.Linear(gate_in, hidden_size)
        self.W_h = nn.Linear(gate_in, hidden_size)

    def forward(self, x, h):
        """Advance the hidden state by one step.

        Args:
            x: Input for this step; concatenated with ``h`` along dim 1
                (presumably shape ``(batch, input_size)`` — confirm at caller).
            h: Previous hidden state (presumably ``(batch, hidden_size)``).

        Returns:
            Tuple ``(h_t, z_t)``: the new hidden state and the update-gate
            activations that produced it.
        """
        xh = torch.cat((x, h), 1)
        update = torch.sigmoid(self.W_z(xh))
        reset = torch.sigmoid(self.W_r(xh))
        # The candidate state only sees the reset-scaled previous state.
        candidate = torch.tanh(self.W_h(torch.cat((x, reset * h), 1)))
        # update -> 1 takes the candidate, update -> 0 keeps the old state.
        new_h = (1 - update) * h + update * candidate
        return new_h, update

class CustomGRU(nn.Module):
    """Unrolls a ``CustomGRUCell`` over the leading (time) dimension.

    Args:
        input_size: Width of each time step of the input.
        hidden_size: Width of the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(CustomGRU, self).__init__()
        self.hidden_size = hidden_size
        self.gru_cell = CustomGRUCell(input_size, hidden_size)

    def forward(self, x, h=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Time-major sequence of shape ``(seq_len, batch, input_size)``.
            h: Initial hidden state of shape ``(batch, hidden_size)``. When
                omitted (``None``), a zero state with ``x``'s dtype and device
                is used.

        Returns:
            Tuple ``(outputs, h, update_gates)``:
            ``outputs`` and ``update_gates`` stack the per-step hidden states
            and update-gate activations along dim 0, i.e. shape
            ``(seq_len, batch, hidden_size)``; ``h`` is the final hidden state
            (the initial state when ``seq_len == 0``).
        """
        seq_len, batch_size, _ = x.size()
        if h is None:
            h = x.new_zeros(batch_size, self.hidden_size)
        if seq_len == 0:
            # torch.stack rejects an empty list, so build empty outputs directly.
            empty = h.new_zeros(0, batch_size, self.hidden_size)
            return empty, h, empty.clone()

        outputs = []
        update_gates = []
        for x_t in x:  # iterates over dim 0, i.e. one time step at a time
            h, z_t = self.gru_cell(x_t, h)
            outputs.append(h)
            update_gates.append(z_t)
        return torch.stack(outputs, dim=0), h, torch.stack(update_gates, dim=0)


class GRUQNetwork(nn.Module):
    """Recurrent Q-network: a ``CustomGRU`` encoder followed by a linear head.

    Args:
        input_size: Width of each input time step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of Q-values produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gru = CustomGRU(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, h):
        """Compute Q-values for every step of ``x``.

        Args:
            x: Input sequence, forwarded unchanged to ``self.gru``.
            h: Initial hidden state, forwarded unchanged to ``self.gru``.

        Returns:
            Tuple ``(q_values, h, update_gates)``: ``self.fc`` applied to each
            GRU output, the GRU's final hidden state, and its update-gate
            activations passed through as-is.
        """
        hidden_seq, last_h, gates = self.gru(x, h)
        return self.fc(hidden_seq), last_h, gates

def generate_target_masks(batch_size, sequence_length, hidden_size, input_seq,
                          units_per_digit=10):
    """Build per-step target masks for the update-gate loss.

    For every ``(i, t)`` the active digit is ``input_seq[i, t].argmax()`` (the
    first maximum on ties). Hidden units
    ``[digit * units_per_digit, (digit + 1) * units_per_digit)`` are set to 1,
    clipped to ``hidden_size``; every other unit is 0.

    Args:
        batch_size: Number of leading rows of ``input_seq`` to use.
        sequence_length: Number of time steps of ``input_seq`` to use.
        hidden_size: Width of the mask's last dimension.
        input_seq: Batch-major tensor indexed as ``input_seq[i, t, :]`` whose
            last dimension holds per-digit scores (one-hot in this script).
        units_per_digit: Hidden units assigned to each digit (default 10,
            the original hard-coded width).

    Returns:
        Tensor of shape ``(batch_size, sequence_length, hidden_size)`` in the
        default floating dtype, on ``input_seq``'s device.

    Raises:
        ValueError: If ``units_per_digit`` is not positive.
        IndexError: If ``input_seq`` has fewer than ``batch_size`` rows or
            ``sequence_length`` steps.
    """
    if units_per_digit <= 0:
        raise ValueError(f"units_per_digit must be positive, got {units_per_digit}")
    if input_seq.size(0) < batch_size or input_seq.size(1) < sequence_length:
        raise IndexError(
            f"input_seq of shape {tuple(input_seq.shape)} is smaller than "
            f"({batch_size}, {sequence_length}, ...)")

    # One argmax call for all (i, t) instead of a Python loop with .item() syncs.
    digits = input_seq[:batch_size, :sequence_length].argmax(dim=-1)  # (batch, seq)
    # Unit u lies in digit d's slice  <=>  d*k <= u < (d+1)*k  <=>  u // k == d.
    unit_digit = torch.arange(hidden_size, device=input_seq.device) // units_per_digit
    return (unit_digit == digits.unsqueeze(-1)).to(torch.get_default_dtype())

def compute_update_gate_loss(update_gate, target_mask):
    """Squared-error penalty pulling update gates toward ``target_mask``.

    Where the mask is 1 a gate costs ``(1 - z)**2`` (pulled toward 1); where it
    is 0 it costs ``z**2`` (pulled toward 0). Fractional mask values blend the
    two terms linearly.

    Args:
        update_gate: Update-gate activations.
        target_mask: Tensor broadcastable against ``update_gate``.

    Returns:
        0-dim tensor: the penalty summed (not averaged) over all elements.
    """
    toward_one = target_mask * (1 - update_gate) ** 2
    toward_zero = (1 - target_mask) * update_gate ** 2
    return (toward_one + toward_zero).sum()

# Hyperparameters
input_size = 10  # One-hot encoded digits
hidden_size = 100  # generate_target_masks gives each of the 10 digits a 10-unit slice
output_size = 10  # Q-values for 10 possible actions
lambda_update = 0.1  # Weight of the update-gate loss in total_loss

# Initialize network and optimizer
net = GRUQNetwork(input_size, hidden_size, output_size)
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Dummy data
batch_size = 32
sequence_length = 5
# Time-major (seq, batch, input) one-hot sequences, as CustomGRU unpacks them.
x = torch.zeros(sequence_length, batch_size, input_size)
for i in range(batch_size):
    for t in range(sequence_length):
        # One randint draw per (i, t), in this order, keeps the RNG stream and
        # hence the generated data unchanged.
        x[t, i, torch.randint(0, input_size, (1,)).item()] = 1  # Random one-hot digits
target_q_values = torch.randn(sequence_length, batch_size, output_size)  # Dummy target Q-values

# The update-gate targets depend only on x, which never changes, so build them
# once here instead of re-running the mask construction every epoch.
# generate_target_masks indexes input_seq[i, t], hence the batch-major permute.
target_masks = generate_target_masks(batch_size, sequence_length, hidden_size, x.permute(1, 0, 2))

# Training loop
for epoch in range(100):
    h = torch.zeros(batch_size, hidden_size)  # Fresh zero hidden state each epoch
    optimizer.zero_grad()

    q_values, h, update_gates = net(x, h)

    # Compute Q-learning loss (dummy example)
    q_loss = ((q_values - target_q_values)**2).mean()

    # update_gates is time-major; permute to match target_masks' (batch, seq, hidden).
    update_loss = compute_update_gate_loss(update_gates.permute(1, 0, 2), target_masks)

    # NOTE(review): q_loss is a mean while update_loss is a sum over
    # batch * seq * hidden elements, so the update-gate term appears to
    # dominate total_loss at this lambda_update; consider normalising.
    total_loss = q_loss + lambda_update * update_loss

    # Backward pass and optimize
    total_loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {total_loss.item()}')


Epoch 1, Loss: 400.2486877441406
Epoch 2, Loss: 398.46087646484375
Epoch 3, Loss: 396.6715393066406
Epoch 4, Loss: 394.8627014160156
Epoch 5, Loss: 393.01690673828125
Epoch 6, Loss: 391.1184997558594
Epoch 7, Loss: 389.1527099609375
Epoch 8, Loss: 387.1055908203125
Epoch 9, Loss: 384.9633483886719
Epoch 10, Loss: 382.71209716796875
Epoch 11, Loss: 380.3382263183594
Epoch 12, Loss: 377.8282165527344
Epoch 13, Loss: 375.16912841796875
Epoch 14, Loss: 372.3486328125
Epoch 15, Loss: 369.35546875
Epoch 16, Loss: 366.17901611328125
Epoch 17, Loss: 362.8099365234375
Epoch 18, Loss: 359.2397155761719
Epoch 19, Loss: 355.461181640625
Epoch 20, Loss: 351.4690246582031
Epoch 21, Loss: 347.2602844238281
Epoch 22, Loss: 342.8346252441406
Epoch 23, Loss: 338.1951904296875
Epoch 24, Loss: 333.3492126464844
Epoch 25, Loss: 328.3089599609375
Epoch 26, Loss: 323.091796875
Epoch 27, Loss: 317.72149658203125
Epoch 28, Loss: 312.2276916503906
Epoch 29, Loss: 306.646484375
Epoch 30, Loss: 301.019287109375
E