In [9]:
import gymnasium as gym
import numpy as np
import networkx as nx
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import torch
import time

# Check if MPS is available
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("mps" if mps_available else "cpu")
print(f"Using device: {device}")

class GraphEnv(gym.Env):
    def __init__(self, num_nodes, max_steps=1000):
        super(GraphEnv, self).__init__()
        self.num_nodes = num_nodes
        self.max_steps = max_steps
        self.current_step = 0
        self.graph = nx.Graph()
        self.action_space = gym.spaces.MultiDiscrete([num_nodes, num_nodes])
        self.observation_space = gym.spaces.Box(low=0, high=1, 
                                                shape=(num_nodes*num_nodes,), 
                                                dtype=np.int32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.graph = nx.Graph()
        self.graph.add_nodes_from(range(self.num_nodes))
        self.current_step = 0
        return self._get_observation(), {}

    def step(self, action):
        self.current_step += 1
        node1, node2 = action
        if node1 != node2:
            if self.graph.has_edge(node1, node2):
                self.graph.remove_edge(node1, node2)
            else:
                self.graph.add_edge(node1, node2)
        
        obs = self._get_observation()
        reward = self._calculate_reward()
        terminated = self._is_terminated()
        truncated = self.current_step >= self.max_steps
        return obs, reward, terminated, truncated, {}

    def _get_observation(self):
        return nx.to_numpy_array(self.graph).flatten()

    def _calculate_reward(self):
        start_time = time.time()
        mis = self.approximate_mis()
        mis_time = time.time() - start_time
        
        # Reward longer MIS computation times
        reward = mis_time * 1000  # Scale up the time for better gradients
        
        # Penalize for very sparse or very dense graphs
        edge_density = self.graph.number_of_edges() / (self.num_nodes * (self.num_nodes - 1) / 2)
        density_penalty = -abs(edge_density - 0.5) * 10  # Favor graphs with ~50% edge density
        
        return reward + density_penalty

    def _is_terminated(self):
        return self.current_step >= self.max_steps

    def approximate_mis(self):
        return nx.maximal_independent_set(self.graph)

def create_env(num_nodes):
    return lambda: GraphEnv(num_nodes)

def train_agent(env, total_timesteps=100000):
    model = PPO("MlpPolicy", env, verbose=1, 
                learning_rate=0.0003, 
                n_steps=2048, 
                batch_size=64, 
                n_epochs=10, 
                gamma=0.99, 
                gae_lambda=0.95, 
                clip_range=0.2, 
                ent_coef=0.01,
                device=device)
    
    model.learn(total_timesteps=total_timesteps)
    return model

def generate_graph_with_visualization(model, env, delay=0.5):
    obs, _ = env.reset()
    done = False
    step = 0
    
    plt.figure(figsize=(10, 8))
    
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        
        graph = env.envs[0].graph
        mis = nx.maximal_independent_set(graph)
        
        # Visualize the current state of the graph
        plt.clf()
        pos = nx.spring_layout(graph)
        nx.draw(graph, pos, with_labels=True, node_color='lightblue', node_size=700)
        nx.draw_networkx_nodes(graph, pos, nodelist=mis, node_color='red', node_size=800)
        
        edge_labels = {(u, v): f'{u}-{v}' for u, v in graph.edges()}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
        
        plt.title(f"Graph Generation Step {step}\nNodes: {graph.number_of_nodes()}, Edges: {graph.number_of_edges()}, MIS size: {len(mis)}")
        
        clear_output(wait=True)
        display(plt.gcf())
        
        plt.pause(delay)
        step += 1
    
    plt.close()
    return graph

def analyze_graph(graph):
    mis = nx.maximal_independent_set(graph)
    print(f"Nodes: {graph.number_of_nodes()}")
    print(f"Edges: {graph.number_of_edges()}")
    print(f"MIS size: {len(mis)}")
    print(f"Average clustering coefficient: {nx.average_clustering(graph)}")
    print(f"Is connected: {nx.is_connected(graph)}")

# Main execution
if __name__ == "__main__":
    NUM_NODES = 30
    env = DummyVecEnv([create_env(NUM_NODES)])
    
    print("Training agent...")
    model = train_agent(env, total_timesteps=200000)
    
    print("Generating graph with visualization...")
    final_graph = generate_graph_with_visualization(model, env)
    
    print("Analyzing generated graph:")
    analyze_graph(final_graph)

Using device: mps
Training agent...
Using mps device
-----------------------------
| time/              |      |
|    fps             | 142  |
|    iterations      | 1    |
|    time_elapsed    | 14   |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 121         |
|    iterations           | 2           |
|    time_elapsed         | 33          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010183027 |
|    clip_fraction        | 0.0918      |
|    clip_range           | 0.2         |
|    entropy_loss         | -6.8        |
|    explained_variance   | 0.00621     |
|    learning_rate        | 0.0003      |
|    loss                 | 281         |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0349     |
|    value_loss           | 518         |
-----------------------

ValueError: not enough values to unpack (expected 2, got 1)