<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Reinforcement_Learning_with_Graph_Neural_Networks_(GNNs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import torch.nn.functional as F

class GraphPolicyNetwork(nn.Module):
    """Map a matrix of per-node features to one vector of action logits.

    The two "gnn" layers are plain ``nn.Linear`` layers applied to every
    node row independently; ``forward`` receives only the node-feature
    tensor, so no edge or adjacency information is used. Node-level
    outputs are combined by averaging over dimension 0.

    Args:
        node_features: Size of each node's input feature vector.
        hidden_dim: Width of the first layer's output.
        output_dim: Width of the second layer and of the returned logits.
    """

    def __init__(self, node_features, hidden_dim, output_dim):
        super().__init__()
        # Construction order is kept (gnn1, gnn2, policy_head) so that a
        # fixed torch seed yields the same initial weights.
        self.gnn1 = nn.Linear(node_features, hidden_dim)
        self.gnn2 = nn.Linear(hidden_dim, output_dim)
        self.policy_head = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        """Return logits averaged over dimension 0 of the node outputs.

        Args:
            x: Node-feature tensor whose last dimension is ``node_features``.
                Presumably shaped ``(num_nodes, node_features)``; dimension 0
                is the one that gets averaged away.

        Returns:
            The ``policy_head`` output with dimension 0 reduced by a mean.
        """
        hidden = F.relu(self.gnn1(x))
        node_embeddings = F.relu(self.gnn2(hidden))
        node_logits = self.policy_head(node_embeddings)
        # Collapse the per-node outputs into one graph-level vector.
        return torch.mean(node_logits, dim=0)

def train_graph_rl(env, model, optimizer, num_episodes=1000, epsilon=0.1):
    """Train ``model`` on ``env`` with one policy-gradient update per step.

    At every step the model maps the current state to a vector of action
    logits. An action is picked epsilon-greedily (uniformly at random with
    probability ``epsilon``, otherwise the arg-max logit), applied to the
    environment, and the model is updated immediately on the loss
    ``-log_softmax(logits)[action] * reward``. The loss uses the immediate
    step reward, not a discounted return. After each episode the summed
    reward is printed.

    Args:
        env: Object with ``reset()`` returning a state tensor and
            ``step(action)`` returning ``(next_state, reward, done, info)``.
            States must be tensors because ``.to(device)`` is called on them.
        model: ``nn.Module`` mapping a state to a 1-D tensor of action
            logits. It is moved to CUDA when available, otherwise CPU.
        optimizer: Optimizer over ``model``'s parameters.
        num_episodes: Number of episodes to run.
        epsilon: Probability of choosing a uniformly random action.

    Returns:
        None. Progress is reported through ``print``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for episode in range(num_episodes):
        state = env.reset().to(device)
        done = False
        total_reward = 0

        while not done:
            policy = model(state)
            if random.random() < epsilon:
                action = random.randint(0, policy.size(-1) - 1)  # Exploration
            else:
                action = torch.argmax(policy).item()  # Exploitation

            next_state, reward, done, _ = env.step(action)
            next_state = next_state.to(device)
            total_reward += reward

            # Policy-gradient loss. log_softmax works in log space, so the
            # log-probability stays finite even when the chosen action's
            # probability would underflow to 0; log(softmax(.)) would give
            # -inf there and NaN gradients that corrupt the parameters.
            log_prob = F.log_softmax(policy, dim=-1)[action]
            loss = -log_prob * reward

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state

        print(f"Episode {episode + 1}, Total Reward: {total_reward}")

class DummyEnv:
    """Placeholder environment producing random graphs and one-step episodes."""

    def reset(self):
        """Return a random initial state: 10 nodes with 5 features per node."""
        initial_state = torch.rand(10, 5)
        return initial_state

    def step(self, action):
        """Return a random transition that always ends the episode.

        ``action`` is accepted but not used.

        Returns:
            A ``(next_state, reward, done, info)`` tuple: a fresh random
            ``(10, 5)`` tensor, a reward drawn by ``random.uniform(0, 1)``,
            ``True`` as the termination flag, and an empty info dict.
        """
        next_state = torch.rand(10, 5)
        reward = random.uniform(0, 1)
        return next_state, reward, True, {}

# Demo run: train a small policy on the random placeholder environment.
# Node features are 5-wide to match the (10, 5) states DummyEnv emits.
env = DummyEnv()
model = GraphPolicyNetwork(5, 16, 3)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
train_graph_rl(env=env, model=model, optimizer=optimizer)