In [74]:
!pip install gymnasium networkx torch-geometric



In [75]:
import torch
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

class GraphColoringEnv(gym.Env):
    def __init__(self, num_colors=3):
        super(GraphColoringEnv, self).__init__()
        self.num_colors = num_colors
        self.edge_index = self._create_edge_index()
        self.n_nodes = int(torch.max(self.edge_index)) + 1

        # Action space: (node_id, color_id)
        self.action_space = spaces.MultiDiscrete([self.n_nodes, self.num_colors])

        # Observation: PyG Data object with x=[node_color] and edge_index
        self.observation_space = None  # not used in PyG; handled by downstream model

        self.state = None
        self.reset()

    def _create_edge_index(self):
        # Define a fixed graph: (0-1-2 triangle), 1-3-4 chain
        edges = [
            [0, 1], [1, 0],
            [0, 2], [2, 0],
            [1, 2], [2, 1],
            [1, 3], [3, 1],
            [3, 4], [4, 3]
        ]
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # -1 for all uncolored nodes
        self.state = torch.full((self.n_nodes, 1), -1, dtype=torch.long)
        return self._get_obs(), {}

    def _get_obs(self):
        return Data(x=self.state.clone(), edge_index=self.edge_index)

    def step(self, action):
        node, color = action
        node = int(node)
        color = int(color)

        reward = 0
        done = False

        if self.state[node].item() != -1:
            print("Node already colored")
            reward = -1  # Node already colored
        else:
            # Check for conflict with neighbors
            neighbors = self.edge_index[1][self.edge_index[0] == node]
            neighbor_colors = self.state[neighbors]
            if (neighbor_colors == color).any():
                reward = -1
            else:
                self.state[node] = color
                reward = 1

        # Check if done (all nodes colored)
        done = (self.state != -1).all().item()
        if done and reward > 0:
            reward += 10  # Bonus for successful coloring

        return self._get_obs(), reward, done, False, {}

    def render(self):
        print("Node Colors:", self.state.view(-1).tolist())


In [76]:
env = GraphColoringEnv(num_colors=3)
obs, _ = env.reset()

done = False
total_reward = 0

while not done:
    env.render()
    # Choose a random uncolored node and a random color
    uncolored = (obs.x.view(-1) == -1).nonzero(as_tuple=True)[0]
    if len(uncolored) == 0:
        break
    node = uncolored[0].item()
    color = np.random.randint(0, env.num_colors)
    obs, reward, done, _, _ = env.step((node, color))
    total_reward += reward

print("\nFinal Colors:", obs.x.view(-1).tolist())
print("Total Reward:", total_reward)


Node Colors: [-1, -1, -1, -1, -1]
Node Colors: [1, -1, -1, -1, -1]
Node Colors: [1, -1, -1, -1, -1]
Node Colors: [1, 0, -1, -1, -1]
Node Colors: [1, 0, -1, -1, -1]
Node Colors: [1, 0, 2, -1, -1]
Node Colors: [1, 0, 2, 2, -1]
Node Colors: [1, 0, 2, 2, -1]

Final Colors: [1, 0, 2, 2, 0]
Total Reward: 12


In [77]:

class GraphColorQNet(torch.nn.Module):
    def __init__(self, num_of_features, hidden_dim=64, color_embedding_dim=64, num_colors=3):
        super().__init__()
        self.gnn1 = GCNConv(num_of_features, hidden_dim)
        self.gnn2 = GCNConv(hidden_dim, hidden_dim)

        self.color_embedding = torch.nn.Embedding(num_colors, color_embedding_dim)

        self.q_mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim + color_embedding_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)
        )



    def forward(self, data):
        # data.x shape: [num_nodes, num_of_features (only colour for now)]
        x = data.x.float()  # make sure it's float
        edge_index = data.edge_index

        h = self.gnn1(x, edge_index).relu()
        h = self.gnn2(h, edge_index).relu()
        return h  # node embeddings

    def get_q_values(self, node_embs, node_ids, color_ids):
        node_vecs = node_embs[node_ids]
        color_embeddings = self.color_embedding(color_ids)

        combined = torch.cat([node_vecs, color_embeddings], dim=-1)
        q_values = self.q_mlp(combined)
        return q_values





In [78]:
import torch
import random

def select_action(model, data, num_colors, epsilon=0.1):
    color_column = data.x[:, 0]
    uncolored_nodes = torch.where(color_column == -1)[0]

    if not len(uncolored_nodes):
      print("GRAPH COLORED!")
      return None

    colors = torch.arange(num_colors, device=data.x.device)
    node_ids, color_ids = torch.meshgrid(uncolored_nodes, colors, indexing='ij')
    node_ids = node_ids.flatten()
    color_ids = color_ids.flatten()

    # Step 3: Predict Q-values
    with torch.no_grad():
        node_embs = model(data)
        q_values = model.get_q_values(node_embs, node_ids, color_ids)

    # Step 4: Epsilon-greedy selection
    if random.random() < epsilon:
        # Randomly explore
        idx = random.randint(0, len(q_values) - 1)
    else:
        # Exploit: pick best Q-value
        idx = torch.argmax(q_values).item()

    selected_node = node_ids[idx].item()
    selected_color = color_ids[idx].item()
    selected_q = q_values[idx].item()

    return selected_node, selected_color, selected_q


In [79]:
env = GraphColoringEnv(num_colors=3)
data, _ = env.reset()  # PyG Data object

model = GraphColorQNet(num_of_features=1, hidden_dim=64, color_embedding_dim=64, num_colors=3)

# Step 3: Get node embeddings
node_embs = model(data)  # shape: [num_nodes, hidden_dim]

node_ids = torch.tensor([0, 1, 2])     # Which nodes to color
color_ids = torch.tensor([0, 1, 2])    # Try different colors
q_values = model.get_q_values(node_embs, node_ids, color_ids)

print("Node IDs:   ", node_ids.tolist())
print("Color IDs:  ", color_ids.tolist())
print("Q-values:   ", q_values.tolist())

obs, _ = env.reset()
action = select_action(model, obs, num_colors=3, epsilon=1)

if action:
    node, color, q = action
    print(f"Selected action: Color node {node} with color {color} (Q={q:.3f})")
else:
    print("No valid actions left.")


Node IDs:    [0, 1, 2]
Color IDs:   [0, 1, 2]
Q-values:    [[-0.04053190350532532], [-0.015579842031002045], [0.22975106537342072]]
Selected action: Color node 3 with color 1 (Q=-0.016)


### One episode

In [80]:
# Create environment and reset
env = GraphColoringEnv(num_colors=3)
data, _ = env.reset()

# Create model (fresh/random for now)
model = GraphColorQNet(num_of_features=1, hidden_dim=64, color_embedding_dim=64, num_colors=3)

# Episode parameters
epsilon = 0.5
done = False
total_reward = 0
step_count = 0

print("Starting Episode...\n")

while not done:
    env.render()

    # Select an action
    action = select_action(model, data, num_colors=env.num_colors, epsilon=epsilon)

    if action is None:
        print("❗ No valid actions left.")
        break

    node, color, q_val = action
    print(f"-- Step {step_count}: Color node {node} with color {color} (Q={q_val:.3f})")

    # Apply action in environment
    obs, reward, done, _, _ = env.step((node, color))
    data = obs

    total_reward += reward
    step_count += 1
    print(f"--- Reward: {reward}\n")

print("Episode finished.")
print(f"Total reward: {total_reward}")
print(f"Steps taken: {step_count}")
print(f"Solution:")
env.render()


Starting Episode...

Node Colors: [-1, -1, -1, -1, -1]
-- Step 0: Color node 4 with color 2 (Q=0.122)
--- Reward: 1

Node Colors: [-1, -1, -1, -1, 2]
-- Step 1: Color node 0 with color 2 (Q=0.122)
--- Reward: 1

Node Colors: [2, -1, -1, -1, 2]
-- Step 2: Color node 3 with color 0 (Q=0.003)
--- Reward: 1

Node Colors: [2, -1, -1, 0, 2]
-- Step 3: Color node 1 with color 2 (Q=0.129)
--- Reward: -1

Node Colors: [2, -1, -1, 0, 2]
-- Step 4: Color node 2 with color 0 (Q=-0.004)
--- Reward: 1

Node Colors: [2, -1, 0, 0, 2]
-- Step 5: Color node 1 with color 0 (Q=0.008)
--- Reward: -1

Node Colors: [2, -1, 0, 0, 2]
-- Step 6: Color node 1 with color 0 (Q=0.008)
--- Reward: -1

Node Colors: [2, -1, 0, 0, 2]
-- Step 7: Color node 1 with color 1 (Q=-0.001)
--- Reward: 11

Episode finished.
Total reward: 12
Steps taken: 8
Solution:
Node Colors: [2, 1, 0, 0, 2]


# Replay buffer + training

In [81]:
from collections import deque
import random

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, node_id, color_id, reward, next_state, done):
        self.buffer.append((state, node_id, color_id, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, node_ids, color_ids, rewards, next_states, dones = zip(*batch)
        return list(states), torch.tensor(node_ids), torch.tensor(color_ids), \
               torch.tensor(rewards, dtype=torch.float32), list(next_states), torch.tensor(dones, dtype=torch.bool)

    def __len__(self):
        return len(self.buffer)



def train_step(model, target_model, buffer, optimizer, batch_size, gamma, num_colors):
    if len(buffer) < batch_size:
        return None  # not enough data

    states, node_ids, color_ids, rewards, next_states, dones = buffer.sample(batch_size)

    # Get Q(s, a)
    q_vals = []
    for i in range(batch_size):
        node_embs = model(states[i])
        q_val = model.get_q_values(node_embs, node_ids[i].unsqueeze(0), color_ids[i].unsqueeze(0))
        q_vals.append(q_val)
    q_vals = torch.stack(q_vals).squeeze()

    # Compute target Q-values
    target_qs = []
    for i in range(batch_size):
        if dones[i]:
            target_q = rewards[i]
        else:
            next_node_embs = target_model(next_states[i])
            color_space = torch.arange(num_colors)
            next_node_ids, next_color_ids = torch.meshgrid(
                torch.where(next_states[i].x[:, 0] == -1)[0],
                color_space,
                indexing='ij'
            )
            next_node_ids = next_node_ids.flatten()
            next_color_ids = next_color_ids.flatten()
            if len(next_node_ids) == 0:
                max_q = 0.0
            else:
                q_vals_next = target_model.get_q_values(next_node_embs, next_node_ids, next_color_ids)
                max_q = torch.max(q_vals_next)
            target_q = rewards[i] + gamma * max_q
        target_qs.append(target_q)
    target_qs = torch.stack(target_qs)

    # Loss & backprop
    loss = torch.nn.functional.mse_loss(q_vals, target_qs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


### Hyperparameters

In [82]:
num_episodes = 300
epsilon_start = 1.0
epsilon_end = 0.05
epsilon_decay = 0.999
batch_size = 32
gamma = 0.99
learning_rate = 1e-3
target_update_freq = 10

In [83]:
env = GraphColoringEnv(num_colors=3)
model = GraphColorQNet(num_of_features=1, hidden_dim=64, color_embedding_dim=64, num_colors=3)
target_model = GraphColorQNet(num_of_features=1, hidden_dim=64, color_embedding_dim=64, num_colors=3)
target_model.load_state_dict(model.state_dict())  # start with same weights

buffer = ReplayBuffer(capacity=10000)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [88]:
epsilon = epsilon_start

for episode in range(1, num_episodes + 1):
    data, _ = env.reset()
    done = False
    total_reward = 0
    step_count = 0

    while not done:
        # Select action using current model
        action = select_action(model, data, num_colors=env.num_colors, epsilon=epsilon)

        if action is None:
            break

        node, color, _ = action
        print("Env: ")
        env.render()
        print(f"-- Selected node {node} with color {color}")

        # Save old state
        state = data

        # Step in env
        next_data, reward, done, _, _ = env.step((node, color))
        print(f"- Reward for it: {reward}")
        total_reward += reward

        # Save to buffer
        buffer.add(state, node, color, reward, next_data, done)

        # Update model
        loss = train_step(model, target_model, buffer, optimizer, batch_size, gamma, num_colors=env.num_colors)

        # Move to new state
        data = next_data
        step_count += 1

    # Epsilon decay
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    # Update target model
    if episode % target_update_freq == 0:
        target_model.load_state_dict(model.state_dict())

    # Logging
    print(f"🟢 Episode {episode:3d} | Steps: {step_count:2d} | Total Reward: {total_reward:3.0f} | Epsilon: {epsilon:.3f}")


Env: 
Node Colors: [-1, -1, -1, -1, -1]
-- Selected node 3 with color 2
- Reward for it: 1
Env: 
Node Colors: [-1, -1, -1, 2, -1]
-- Selected node 2 with color 1
- Reward for it: 1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 0 with color 1
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 1 with color 1
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 4 with color 2
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 1 with color 1
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 0 with color 1
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 1 with color 2
- Reward for it: -1
Env: 
Node Colors: [-1, -1, 1, 2, -1]
-- Selected node 4 with color 1
- Reward for it: 1
Env: 
Node Colors: [-1, -1, 1, 2, 1]
-- Selected node 0 with color 0
- Reward for it: 1
Env: 
Node Colors: [0, -1, 1, 2, 1]
-- Selected node 1 with color 1
- Reward for it: -1
Env: 
Node Col

KeyboardInterrupt: 