# In [None]:  (Jupyter notebook cell marker)
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque, namedtuple
import random
from itertools import combinations
import time
import copy
import os

# =========================================
# GRAPH UTILITY FUNCTIONS
# =========================================

def create_graph(n, p=0.5, seed=None):
    """Create a random G(n, p) graph.

    Args:
        n: Number of nodes.
        p: Probability that any given pair of nodes is joined by an edge.
            Defaults to 0.5, the value this function previously hard-coded.
        seed: Optional random seed (or random state) forwarded to networkx
            for reproducible graphs. ``None`` keeps the previous behaviour of
            drawing from the global random state.

    Returns:
        The generated ``networkx`` graph.
    """
    return nx.fast_gnp_random_graph(n, p, seed=seed)

def create_state_pairs(n):
    """Create a pair of random graphs: an initial graph and a target graph.

    The target graph is re-drawn until it contains at least one edge.

    Args:
        n: Number of nodes in each graph. Must be at least 2, because a graph
            with fewer nodes cannot contain an edge.

    Returns:
        A two-element list ``[initial_graph, target_graph]``.

    Raises:
        ValueError: If ``n < 2``; the "target must have an edge" requirement
            could never be met and the re-draw loop would not terminate.
    """
    if n < 2:
        raise ValueError(
            f"n must be at least 2 so the target graph can have an edge, got {n}"
        )

    G1 = create_graph(n)
    G2 = create_graph(n)

    # Ensure G2 is not empty
    while G2.number_of_edges() == 0:
        G2 = create_graph(n)

    return [G1, G2]

def local_complementation(G, node):
    """Return a copy of ``G`` with the neighbourhood of ``node`` complemented.

    Every unordered pair of neighbours of ``node`` has its connecting edge
    toggled: an existing edge is removed, a missing one is added. The input
    graph is left untouched.
    """
    result = G.copy()

    # Snapshot the neighbourhood before any edge is toggled.
    neighborhood = list(result.neighbors(node))

    for u, v in combinations(neighborhood, 2):
        toggle = result.remove_edge if result.has_edge(u, v) else result.add_edge
        toggle(u, v)

    return result

def calculate_cost(operations, edge_cost=10, local_cost=1):
    """Return the total cost of a sequence of operations.

    Args:
        operations: Iterable of ``(op_type, param)`` pairs. It is traversed
            exactly once, so one-shot iterators and generators are supported.
        edge_cost: Cost charged per ``"edge"`` operation (default 10).
        local_cost: Cost charged per ``"local"`` operation (default 1).

    Returns:
        The summed cost. Operations whose type is neither ``"edge"`` nor
        ``"local"`` contribute nothing.
    """
    total = 0
    for op_type, _ in operations:
        if op_type == "edge":
            total += edge_cost
        elif op_type == "local":
            total += local_cost
    return total

def apply_operation(graph, op_type, param):
    """Apply a single operation and return the resulting graph.

    The input graph is never modified; a new graph is returned.

    Args:
        graph: Graph to transform.
        op_type: ``"edge"`` to toggle one edge; any other value is treated as
            a local complementation.
        param: For ``"edge"``, a pair ``(i, j)`` of endpoints; otherwise the
            node whose neighbourhood is complemented.

    Returns:
        The transformed copy of ``graph``.
    """
    if op_type == "edge":
        result = graph.copy()
        i, j = param
        if result.has_edge(i, j):
            result.remove_edge(i, j)
        else:
            result.add_edge(i, j)
        return result

    # op_type == "local": local_complementation already works on its own
    # copy, so copying here as well would only duplicate the work.
    return local_complementation(graph, param)

def graph_difference(G1, G2):
    """Return the summed absolute difference of the two adjacency matrices.

    The sum runs over the full adjacency matrices, so for undirected graphs
    each differing edge appears in two entries, (i, j) and (j, i), and
    contributes twice to the result.

    When both graphs contain the same set of nodes, the matrices are aligned
    on the node order of ``G1``; otherwise two graphs with identical edges
    but different node insertion order would be reported as different. If
    the node sets differ, each graph falls back to its own node order.

    Args:
        G1: First graph.
        G2: Second graph.

    Returns:
        The sum of ``|A1 - A2|`` over all matrix entries (a numpy float).
    """
    nodelist = list(G1.nodes())
    adj1 = nx.to_numpy_array(G1, nodelist=nodelist)

    if set(nodelist) == set(G2.nodes()):
        # Same nodes: force identical row/column ordering for both matrices.
        adj2 = nx.to_numpy_array(G2, nodelist=nodelist)
    else:
        # Different node sets: keep the positional comparison.
        adj2 = nx.to_numpy_array(G2)

    return np.sum(np.abs(adj1 - adj2))

def estimate_min_solution_cost(initial_graph, target_graph):
    """Heuristically estimate the cheapest transformation cost.

    The estimate starts from ``graph_difference(initial_graph, target_graph)``
    (a sum over adjacency-matrix entries, not a plain count of edges) and
    optimistically assumes that every group of three units of difference can
    be fixed by one local complementation (cost 1), while the remainder needs
    direct edge operations (cost 10 each).

    NOTE(review): this is a heuristic; nothing here proves it is a true lower
    bound on the solution cost.

    Returns:
        The estimate, never smaller than 3.
    """
    mismatch = graph_difference(initial_graph, target_graph)

    # Optimistic split: as many groups of three as possible go to LCs ...
    lc_count = mismatch // 3
    # ... and whatever is left over is paid for with direct edge toggles.
    leftover = mismatch - 3 * lc_count

    estimate = lc_count + 10 * leftover

    # Empirical floor on the estimate.
    return max(estimate, 3)

# =========================================
# GRAPH FEATURE EXTRACTION
# =========================================

def graph_to_features(G, target_G=None, max_nodes=25):
    """Convert a NetworkX graph to node features, adjacency matrix and mask.

    Nodes are looked up as ``0 .. n-1``, so the graph is expected to use
    those integer labels.

    Per-node features (3 columns):
        0. degree divided by the number of nodes,
        1. clustering coefficient (0 for nodes of degree <= 1),
        2. if a non-empty ``target_G`` is given: summed absolute difference
           between the node's adjacency row in ``G`` and in ``target_G``,
           divided by the number of nodes; otherwise 0.

    Args:
        G: Graph to encode.
        target_G: Optional graph to compare against for feature column 2.
        max_nodes: Size to zero-pad the outputs to.

    Returns:
        ``(features, adjacency, mask, n)`` where ``n`` is the real node count.
        When ``n < max_nodes`` the arrays are zero-padded to ``max_nodes`` and
        the mask is 1 for real nodes and 0 for padding. When
        ``n >= max_nodes`` the arrays are returned unpadded (size ``n``).
    """
    n = G.number_of_nodes()

    # Create adjacency matrix
    adj_matrix = nx.to_numpy_array(G)

    # The target adjacency does not depend on the node being processed, so it
    # is built once here rather than once per node.
    target_adj = None
    if target_G is not None and target_G.number_of_nodes() > 0:
        target_adj = nx.to_numpy_array(target_G)

    features = np.zeros((n, 3))
    for i in range(n):
        degree = G.degree(i)
        features[i, 0] = degree / n  # Normalized degree
        features[i, 1] = nx.clustering(G, i) if degree > 1 else 0

        if target_adj is not None:
            features[i, 2] = np.sum(np.abs(adj_matrix[i] - target_adj[i])) / n

    # Every real node is marked valid; padding rows (if any) stay at 0.
    target_mask = np.ones(n)

    # Pad to max_nodes for consistent tensor sizes
    if n < max_nodes:
        adj_pad = np.zeros((max_nodes, max_nodes))
        adj_pad[:n, :n] = adj_matrix

        feat_pad = np.zeros((max_nodes, features.shape[1]))
        feat_pad[:n] = features

        mask_pad = np.zeros(max_nodes)
        mask_pad[:n] = target_mask

        return feat_pad, adj_pad, mask_pad, n

    return features, adj_matrix, target_mask, n

# =========================================
# NEURAL NETWORK
# =========================================

class GraphTransformNet(nn.Module):
    """Network scoring graph-edit actions and estimating a state value.

    For a (current graph, target graph) pair it produces:
      * one logit per node (local complementation on that node),
      * one logit per node pair (toggling that edge), as a symmetric matrix,
      * one scalar state value per batch element.
    """

    def __init__(self, input_dim=3, hidden_dim=64, max_nodes=25):
        super().__init__()

        self.max_nodes = max_nodes

        # Graph embedding layers
        self.node_embed = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # For processing adjacency information
        self.graph_conv1 = nn.Linear(hidden_dim, hidden_dim)
        self.graph_conv2 = nn.Linear(hidden_dim, hidden_dim)

        # Combine current and target node information
        self.combine = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Action-specific heads
        self.local_head = nn.Linear(hidden_dim, 1)  # Score for local complementation on each node
        self.edge_head = nn.Linear(hidden_dim * 2, 1)  # Score for toggling each edge

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def _propagate(self, embeddings, adjacency):
        """Run the two message-passing layers over one graph."""
        for conv in (self.graph_conv1, self.graph_conv2):
            # Sum neighbour embeddings via the adjacency matrix, add the
            # node's own embedding, then transform.
            neighbor_sum = torch.bmm(adjacency, embeddings)
            embeddings = F.relu(conv(embeddings + neighbor_sum))
        return embeddings

    def forward(self, node_features, adjacency, target_features, target_adjacency, mask, num_nodes):
        """Compute action logits and the state value.

        Args:
            node_features: ``(batch, max_nodes, input_dim)`` current-graph features.
            adjacency: ``(batch, max_nodes, max_nodes)`` current-graph adjacency.
            target_features: Same layout as ``node_features`` for the target graph.
            target_adjacency: Same layout as ``adjacency`` for the target graph.
            mask: ``(batch, max_nodes)``; 1 for real nodes, 0 for padding.
            num_nodes: Number of leading node slots for which edge logits are
                computed (a single value shared by the whole batch). Slots at
                or beyond it, and anything masked out, get an edge logit of 0.

        Returns:
            ``(local_logits, edge_logits, value)`` with shapes
            ``(batch, max_nodes)``, ``(batch, max_nodes, max_nodes)`` and
            ``(batch, 1)``. Masked-out logits are 0 (not ``-inf``).
        """
        batch_size = node_features.size(0)

        # Embed the nodes of both graphs and propagate along their edges.
        # The same layers (shared weights) are used for both graphs.
        node_embeddings = self._propagate(self.node_embed(node_features), adjacency)
        target_embeddings = self._propagate(self.node_embed(target_features), target_adjacency)

        # Global graph representation: masked mean over real nodes.
        mask_column = mask.unsqueeze(-1)
        real_node_count = torch.sum(mask, dim=1, keepdim=True)
        graph_embed = torch.sum(node_embeddings * mask_column, dim=1) / real_node_count
        target_graph_embed = torch.sum(target_embeddings * mask_column, dim=1) / real_node_count

        # Graph-level summary of both graphs (input to the value head).
        combined = torch.cat([graph_embed, target_graph_embed], dim=1)

        # Per-node fusion of the current and target embeddings of each node.
        combined_node = self.combine(torch.cat([node_embeddings, target_embeddings], dim=2))

        # Local complementation logit for each node.
        local_logits = self.local_head(combined_node).squeeze(-1)

        # Edge logits for every unordered pair (i, j), i < j < num_nodes,
        # computed in one batched call instead of one call per pair.
        edge_logits = torch.zeros(
            (batch_size, self.max_nodes, self.max_nodes), device=node_features.device
        )
        n_valid = min(int(num_nodes), self.max_nodes)
        if n_valid > 1:
            pair_index = torch.triu_indices(
                n_valid, n_valid, offset=1, device=node_features.device
            )
            src, dst = pair_index[0], pair_index[1]
            pair_embed = torch.cat([combined_node[:, src], combined_node[:, dst]], dim=-1)
            pair_logits = self.edge_head(pair_embed).squeeze(-1).to(edge_logits.dtype)
            edge_logits[:, src, dst] = pair_logits
            edge_logits[:, dst, src] = pair_logits  # Symmetric

        # Mask out invalid nodes/edges
        mask_2d = mask.unsqueeze(2) * mask.unsqueeze(1)
        local_logits = local_logits * mask
        edge_logits = edge_logits * mask_2d

        # Get state value
        value = self.value_head(combined)

        return local_logits, edge_logits, value

# =========================================
# ENVIRONMENT
# =========================================

class GraphTransformationEnv:
    """Episodic environment for turning an initial graph into a target graph.

    Actions are ``("edge", (i, j))`` (toggle one edge) or ``("local", node)``
    (local complementation on a node). An episode ends when the current graph
    matches the target or after ``max_steps`` steps.
    """

    def __init__(self, initial_graph, target_graph):
        self.initial_graph = initial_graph
        self.target_graph = target_graph
        self.current_graph = initial_graph.copy()
        self.n_nodes = initial_graph.number_of_nodes()
        self.operations_history = []
        self.initial_difference = graph_difference(initial_graph, target_graph)
        self.step_count = 0
        self.max_steps = self.n_nodes * 3  # Allow more steps for larger graphs

    def reset(self):
        """Restore the initial graph, clear the history and return the state."""
        self.current_graph = self.initial_graph.copy()
        self.operations_history = []
        self.step_count = 0
        return self._get_state()

    def _get_state(self):
        """Return the state as a ``(current_graph, target_graph)`` tuple."""
        return self.current_graph, self.target_graph

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        Reward components:
          * 5 x the reduction in ``graph_difference`` to the target,
          * minus the operation cost (10 for "edge", 1 otherwise),
          * for a "local" action that changed more than two edges: +2 per
            changed edge, and another +3 per changed edge if the difference
            to the target decreased,
          * on reaching the target: +100, +500 / (cost + 1),
            +20 / (operations + 1), and +100 (cost < 15) or +50 (cost < 25),
          * otherwise +20 whenever the remaining difference is <= 2,
          * -30 if the step limit is hit without reaching the target.

        Args:
            action: ``(op_type, param)`` pair as produced by
                ``get_valid_actions``.

        Returns:
            ``(state, reward, done, info)`` where ``state`` is the
            ``(current_graph, target_graph)`` tuple and ``info`` is an empty
            dict.
        """
        op_type, param = action

        # apply_operation returns a new graph, so self.current_graph still
        # holds the pre-action state until it is reassigned below.
        new_graph = apply_operation(self.current_graph, op_type, param)

        # Record the operation
        self.operations_history.append((op_type, param))

        # Reward for moving closer to the target graph
        prev_diff = graph_difference(self.current_graph, self.target_graph)
        new_diff = graph_difference(new_graph, self.target_graph)
        diff_improvement = prev_diff - new_diff
        reward = diff_improvement * 5  # Stronger gradient

        # Operation cost penalties
        if op_type == "edge":
            reward -= 10  # High cost for edge operations
        else:
            reward -= 1   # Low cost for local complementation

        # Bonus for a local complementation that changes several edges at once
        if op_type == "local":
            affected_edges = self._count_affected_edges_by_LC(self.current_graph, new_graph)
            if affected_edges > 2:
                # Bonus reward proportional to effectiveness
                reward += affected_edges * 2
                # Extra bonus if the changes move us toward the target
                if diff_improvement > 0:
                    reward += affected_edges * 3

        # Update current graph
        self.current_graph = new_graph
        self.step_count += 1

        # Check if target reached (exact match)
        target_reached = (new_diff == 0)

        # graph_difference sums over the full adjacency matrix, so for
        # undirected graphs a single differing edge contributes 2: this
        # threshold means "at most one edge away from the target".
        very_close = (new_diff <= 2)

        # Step limit reached
        step_limit_reached = (self.step_count >= self.max_steps)

        # Determine if episode is done
        done = target_reached or step_limit_reached

        # Reward bonuses for completion
        if target_reached:
            # Base completion bonus
            reward += 100

            # Reward inversely proportional to solution cost
            solution_cost = calculate_cost(self.operations_history)
            reward += 500 / (solution_cost + 1)  # +1 to avoid division by zero

            # Bonus for solutions with fewer operations (parsimony)
            reward += 20 / (len(self.operations_history) + 1)

            # Special bonus for finding very efficient solutions
            if solution_cost < 15:  # Threshold for "very good" solutions
                reward += 100
            elif solution_cost < 25:  # Good solutions
                reward += 50
        elif very_close:
            reward += 20   # Partial bonus for getting very close

        # Penalty for running out of steps without success
        if step_limit_reached and not target_reached:
            reward -= 30

        return self._get_state(), reward, done, {}

    def _count_affected_edges_by_LC(self, prev_graph, new_graph):
        """Count the edges present in exactly one of the two graphs.

        Edges are compared as the ``(u, v)`` tuples reported by each graph's
        ``edges()``.
        """
        prev_edges = set(prev_graph.edges())
        new_edges = set(new_graph.edges())

        # Symmetric difference = edges that were added plus edges removed.
        return len(prev_edges ^ new_edges)

    def get_valid_actions(self):
        """List all actions available in the current state.

        Local complementation is offered only on nodes with at least one
        neighbour (it has no effect on isolated nodes); an edge toggle is
        offered for every unordered node pair.
        """
        # Local complementation on each non-isolated node
        valid_actions = [
            ("local", node)
            for node in range(self.n_nodes)
            if self.current_graph.degree(node) > 0
        ]

        # Edge toggle for each pair
        valid_actions.extend(
            ("edge", pair) for pair in combinations(range(self.n_nodes), 2)
        )

        return valid_actions

# =========================================
# EXPERIENCE REPLAY BUFFER
# =========================================

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity=10000):
        field_names = ['state', 'action', 'reward', 'next_state', 'done']
        self.buffer = deque(maxlen=capacity)  # oldest entries drop out first
        self.Experience = namedtuple('Experience', field_names)

    def add(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest if the buffer is full."""
        self.buffer.append(
            self.Experience(state, action, reward, next_state, done)
        )

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct transitions chosen at random."""
        how_many = min(batch_size, len(self.buffer))
        return random.sample(self.buffer, k=how_many)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# =========================================
# RL AGENT
# =========================================

class GraphTransformAgent:
    """Epsilon-greedy agent that learns to score graph-edit actions.

    Holds a ``GraphTransformNet`` model, an Adam optimizer and a replay
    buffer, and trains from replayed transitions with a value loss plus an
    advantage-weighted action loss.
    """

    def __init__(self, input_dim=3, hidden_dim=64, max_nodes=25, lr=0.001, gamma=0.99, device='cpu'):
        self.device = device
        self.max_nodes = max_nodes
        self.model = GraphTransformNet(input_dim, hidden_dim, max_nodes).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.epsilon_min = 0.01
        self.batch_size = 16
        self.memory = ReplayBuffer(capacity=50000)

    def _encode_state(self, current_graph, target_graph):
        """Encode a (current, target) graph pair as numpy arrays.

        Returns the arrays in the model's argument order, followed by the
        real node count of the current graph.
        """
        features, adjacency, mask, n_nodes = graph_to_features(
            current_graph, target_graph, self.max_nodes)
        target_features, target_adjacency, _, _ = graph_to_features(
            target_graph, current_graph, self.max_nodes)
        return features, adjacency, target_features, target_adjacency, mask, n_nodes

    def _as_float_tensor(self, values):
        """Convert an array-like to a float32 tensor on the agent's device."""
        return torch.as_tensor(np.asarray(values), dtype=torch.float32, device=self.device)

    def select_action(self, env, current_graph, target_graph):
        """Pick an action epsilon-greedily among the environment's valid actions.

        With probability ``epsilon`` a valid action is chosen uniformly at
        random; otherwise the valid action with the highest model logit is
        returned (the first one in case of ties).
        """
        valid_actions = env.get_valid_actions()

        # Random action with probability epsilon
        if random.random() < self.epsilon:
            return random.choice(valid_actions)

        self.model.eval()
        with torch.no_grad():
            *arrays, n_nodes = self._encode_state(current_graph, target_graph)
            # Add the batch dimension expected by the model.
            inputs = [self._as_float_tensor(array).unsqueeze(0) for array in arrays]
            local_logits, edge_logits, _ = self.model(*inputs, n_nodes)

        def score(action):
            kind, param = action
            if kind == "local":
                return local_logits[0, param].item()
            i, j = param
            return edge_logits[0, i, j].item()

        # Select action with highest score
        return max(valid_actions, key=score)

    def _prepare_batch(self, experiences):
        """Convert a list of experiences into batched tensors.

        Returns:
            A dict with keys ``'current_state'`` and ``'next_state'`` (each a
            tuple ``(features, adjacency, target_features, target_adjacency,
            mask, n_nodes_list)``), ``'actions'`` (list of actions),
            ``'rewards'`` and ``'dones'`` (1-D float tensors).
        """
        current_arrays = []
        next_arrays = []
        actions_batch = []
        rewards = []
        dones = []
        n_nodes_batch = []

        for exp in experiences:
            current_graph, target_graph = exp.state
            next_current_graph, next_target_graph = exp.next_state

            *current, n_nodes = self._encode_state(current_graph, target_graph)
            *following, _ = self._encode_state(next_current_graph, next_target_graph)

            current_arrays.append(current)
            next_arrays.append(following)
            actions_batch.append(exp.action)
            rewards.append(exp.reward)
            dones.append(float(exp.done))
            n_nodes_batch.append(n_nodes)

        # Transpose "list of per-sample arrays" into "one stacked tensor per input".
        current_tensors = tuple(self._as_float_tensor(column) for column in zip(*current_arrays))
        next_tensors = tuple(self._as_float_tensor(column) for column in zip(*next_arrays))

        return {
            'current_state': current_tensors + (n_nodes_batch,),
            'actions': actions_batch,
            'rewards': self._as_float_tensor(rewards),
            'next_state': next_tensors + (n_nodes_batch,),
            'dones': self._as_float_tensor(dones)
        }

    def _learn(self):
        """Run one optimisation step on a sampled batch and return the loss."""
        self.model.train()

        # Sample experiences
        experiences = self.memory.sample(self.batch_size)
        batch = self._prepare_batch(experiences)

        # Unpack batch
        features, adjacency, target_features, target_adjacency, mask, n_nodes = batch['current_state']
        actions = batch['actions']
        rewards = batch['rewards']
        next_features, next_adjacency, next_target_features, next_target_adjacency, next_mask, _ = batch['next_state']
        dones = batch['dones']

        # The model takes one node count for the whole batch. Use the largest
        # graph in the batch so every sample gets logits for all of its
        # edges; slots beyond a smaller graph's size are zeroed by its mask.
        batch_num_nodes = max(n_nodes)

        # Get predictions for current state
        local_logits, edge_logits, values = self.model(
            features, adjacency, target_features, target_adjacency, mask, batch_num_nodes)

        # Bootstrapped value targets from the next state
        with torch.no_grad():
            _, _, next_values = self.model(
                next_features, next_adjacency, next_target_features, next_target_adjacency,
                next_mask, batch_num_nodes)
            # Shape (batch, 1) so they broadcast against the value output.
            rewards = rewards.unsqueeze(1)
            dones = dones.unsqueeze(1)
            target_values = rewards + (1 - dones) * self.gamma * next_values

        # Value loss (critic)
        value_loss = F.mse_loss(values, target_values)

        # Action loss (actor): raise the logit of the taken action in
        # proportion to its advantage.
        policy_loss = torch.zeros(1, device=self.device)

        for i, (action_type, param) in enumerate(actions):
            # Advantage is treated as a constant (no gradient through it).
            advantage = (target_values[i] - values[i]).detach().item()

            if action_type == "local":
                node = param
                policy_loss -= local_logits[i, node] * advantage
            else:  # "edge"
                src, dst = param
                policy_loss -= edge_logits[i, src, dst] * advantage

        # Total loss
        loss = value_loss + policy_loss

        # Update model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    # Fast learning method placeholder - if you have an optimized version
    def _learn_fast(self):
        return self._learn()

    def train(self, env, num_episodes=1000, max_time=3600):
        """Train the agent on one environment.

        Args:
            env: Environment exposing ``reset``, ``step``,
                ``get_valid_actions``, ``current_graph``, ``target_graph``
                and ``operations_history``.
            num_episodes: Maximum number of episodes to run.
            max_time: Wall-clock budget in seconds; checked at the start of
                each episode.

        Returns:
            ``(best_solution, rewards_history)`` where ``best_solution`` is
            the cheapest operation list that reached the target (``None`` if
            none did) and ``rewards_history`` holds the total reward of each
            episode.
        """
        start_time = time.time()
        rewards_history = []
        best_solution = None
        best_solution_cost = float('inf')

        for episode in range(num_episodes):
            if time.time() - start_time > max_time:
                print(f"Time limit reached after {episode} episodes.")
                break

            # Reset environment
            state = env.reset()
            episode_reward = 0
            done = False

            while not done:
                action = self.select_action(env, *state)
                next_state, reward, done, _ = env.step(action)

                # Store experience
                self.memory.add(state, action, reward, next_state, done)

                # Learn once the buffer holds more than one batch
                if len(self.memory) > self.batch_size:
                    self._learn()

                state = next_state
                episode_reward += reward

            # Decay exploration rate
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

            # Record metrics
            rewards_history.append(episode_reward)

            # Check if this solution is the best so far
            if graph_difference(env.current_graph, env.target_graph) == 0:
                solution_cost = calculate_cost(env.operations_history)
                if solution_cost < best_solution_cost:
                    best_solution = env.operations_history.copy()
                    best_solution_cost = solution_cost
                    print(f"Episode {episode}: Found better solution with cost {best_solution_cost}")

            # Logging
            if episode % 10 == 0:
                recent = rewards_history[-10:]
                avg_reward = sum(recent) / len(recent)
                print(f"Episode {episode}: Avg reward = {avg_reward:.2f}, Epsilon = {self.epsilon:.2f}")
                if best_solution:
                    print(f"Best solution so far: cost {best_solution_cost}, {len(best_solution)} operations")

        return best_solution, rewards_history

# =========================================
# VISUALIZATION
# =========================================

def _apply_operation_in_place(graph, op_type, param):
    """Apply one operation for the visualiser; return ``(graph, description)``.

    Edge toggles modify ``graph`` directly; a local complementation returns a
    new graph object.
    """
    if op_type == "edge":
        i_node, j_node = param
        if graph.has_edge(i_node, j_node):
            graph.remove_edge(i_node, j_node)
            return graph, f"Remove edge ({i_node}, {j_node})"
        graph.add_edge(i_node, j_node)
        return graph, f"Add edge ({i_node}, {j_node})"

    # op_type == "local"
    return local_complementation(graph, param), f"Local comp. on node {param}"


def visualize_transformation(initial_graph, operations, max_steps_to_show=12):
    """Visualize the graph transformation process step by step.

    Draws the initial graph followed by the graph after each of the first
    ``max_steps_to_show`` operations, one subplot per step. Every operation,
    drawn or not, is applied, so the returned graph is always the result of
    the full sequence.

    Args:
        initial_graph: Starting graph; it is copied, not modified.
        operations: List of ``(op_type, param)`` pairs to apply in order.
        max_steps_to_show: Number of leading operations to draw.

    Returns:
        The graph obtained after applying all ``operations``.
    """
    current = initial_graph.copy()

    # Limit the number of steps to show
    operations_to_show = operations[:max_steps_to_show]
    hidden_operations = operations[max_steps_to_show:]
    n_steps = len(operations_to_show)

    # Nothing to draw step by step: show only the starting graph.
    if n_steps == 0:
        plt.figure(figsize=(6, 6))
        pos = nx.spring_layout(current, seed=42)
        nx.draw(current, pos, with_labels=True, node_color='lightblue',
                node_size=500, font_weight='bold')
        plt.title("Initial Graph" if operations else "Graph (No Operations)")
        plt.tight_layout()
        plt.show()

        # Operations may still exist even though none is displayed; apply
        # them so the returned graph reflects the whole sequence.
        for op_type, param in hidden_operations:
            current, _ = _apply_operation_in_place(current, op_type, param)
        return current

    # Grid with room for the initial graph plus one cell per shown step.
    if n_steps <= 2:
        rows, cols = 1, n_steps + 1  # One row: [initial] [op1] [op2]
    else:
        cols = 3
        rows = (n_steps + cols) // cols  # ceil((n_steps + 1) / cols)

    # Create the figure
    plt.figure(figsize=(cols*4, rows*4))

    # Show initial graph
    plt.subplot(rows, cols, 1)
    pos = nx.spring_layout(current, seed=42)  # Reused so nodes stay in place
    nx.draw(current, pos, with_labels=True, node_color='lightblue',
            node_size=500, font_weight='bold')
    plt.title("Initial Graph")

    # Apply and show each operation
    for step, (op_type, param) in enumerate(operations_to_show, start=1):
        current, op_text = _apply_operation_in_place(current, op_type, param)

        plt.subplot(rows, cols, step + 1)
        nx.draw(current, pos, with_labels=True, node_color='lightgreen',
                node_size=500, font_weight='bold')
        plt.title(f"Step {step}: {op_text}")

    # Apply the remaining operations without drawing them
    for op_type, param in hidden_operations:
        current, _ = _apply_operation_in_place(current, op_type, param)

    plt.tight_layout()
    plt.show()

    return current

# =========================================
# TRAINING FUNCTION
# =========================================

def _build_predefined_graph_pairs():
    """Return the fixed list of (initial_graph, target_graph) training pairs."""
    def edgeless(n):
        # Graph with nodes 0..n-1 and no edges
        g = nx.Graph()
        g.add_nodes_from(range(n))
        return g

    pairs = []

    # Empty 6-node graph -> 6-node graph containing the triangle 1-2-3
    target = edgeless(6)
    target.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (3, 5), (2, 4)])
    pairs.append((edgeless(6), target))

    # 5-node star centred on node 0 -> complete graph on 5 nodes
    star = edgeless(5)
    star.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
    complete = edgeless(5)
    complete.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4), (1, 4), (1, 2),
                             (1, 3), (2, 3), (2, 4), (3, 4)])
    pairs.append((star, complete))

    # Empty 5-node graph -> 5-cycle
    pairs.append((edgeless(5), nx.cycle_graph(5)))

    # Empty 8-node graph -> cubical graph
    pairs.append((edgeless(8), nx.cubical_graph(create_using=nx.Graph)))

    # Empty 6-node graph -> second hand-written 6-node target
    target = edgeless(6)
    target.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4), (3, 5), (1, 5)])
    pairs.append((edgeless(6), target))

    return pairs


def _rollout_reaches_target(agent, env, target_graph):
    """Play one episode without storing experience or learning.

    Uses the agent's current epsilon. Returns True when the environment's
    final graph matches target_graph.
    """
    state = env.reset()
    done = False
    while not done:
        action = agent.select_action(env, *state)
        state, _, done, _ = env.step(action)
    return graph_difference(env.current_graph, target_graph) == 0


def _describe_solution_path(initial_graph, solution_path):
    """Return human-readable labels for each operation of a solution path.

    The path is replayed on a copy of initial_graph so that an edge operation
    is labelled "Add" or "Remove" according to the graph state at that step,
    not according to the initial graph.
    """
    graph = initial_graph.copy()
    labels = []
    for op_type, param in solution_path:
        if op_type == "local":
            labels.append(f"LC on node {param}")
        else:
            u, v = param
            verb = 'Remove' if graph.has_edge(u, v) else 'Add'
            labels.append(f"{verb} edge ({u},{v})")
        graph = apply_operation(graph, op_type, param)
    return labels


def train_agent(n_nodes=4, num_graphs=20, max_episodes=2000, target_success_rate=0.9):
    """
    Train the RL agent on graph pairs with enhancements for finding optimal solutions.

    The agent is trained sequentially on a fixed set of predefined graph
    pairs. For every pair the best (cheapest) solution and all distinct
    solution paths are tracked, with periodic high-exploration phases and
    greedy verification runs used for early stopping. The trained model is
    saved to ``models/graph_transform_model_{n_nodes}nodes.pt``.

    Args:
        n_nodes: Passed to the agent as ``max_nodes`` and used in the saved
            model's file name and checkpoint.
        num_graphs: Currently unused; the training set is predefined.
        max_episodes: Maximum number of training episodes per graph pair.
        target_success_rate: Currently unused.

    Returns:
        The trained GraphTransformAgent.
    """
    graph_pairs = _build_predefined_graph_pairs()
    print(f"Training on {len(graph_pairs)} predefined graph pairs")

    # Create agent
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # NOTE(review): the predefined pairs contain up to 8 nodes while n_nodes
    # defaults to 4 - confirm the agent copes with graphs larger than max_nodes.
    agent = GraphTransformAgent(max_nodes=n_nodes, device=device)

    # Adjust agent parameters for better exploration
    agent.epsilon = 1.0        # Start fully exploratory
    agent.epsilon_decay = 0.998  # Slower decay for more exploration
    agent.epsilon_min = 0.05   # Slightly higher minimum exploration

    # Track training statistics
    all_training_stats = []
    overall_success_count = 0

    # Train on all graph pairs
    for pair_idx, (initial_graph, target_graph) in enumerate(graph_pairs):
        print(f"\nTraining on graph pair {pair_idx+1}/{len(graph_pairs)}...")

        # Estimate theoretical minimum solution cost
        est_min_cost = estimate_min_solution_cost(initial_graph, target_graph)
        print(f"Estimated minimum solution cost: {est_min_cost}")

        # Visualize the training pair (both graphs share one layout)
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        pos = nx.spring_layout(initial_graph, seed=42)
        nx.draw(initial_graph, pos, with_labels=True, node_color='lightblue',
                node_size=500, font_weight='bold')
        plt.title("Initial Graph")

        plt.subplot(1, 2, 2)
        nx.draw(target_graph, pos, with_labels=True, node_color='lightgreen',
                node_size=500, font_weight='bold')
        plt.title("Target Graph")

        plt.tight_layout()
        plt.show()

        env = GraphTransformationEnv(initial_graph, target_graph)

        # Track best solution for this graph pair
        best_solution = None
        best_solution_cost = float('inf')
        success_found = False
        consecutive_successes = 0  # Track consecutive successes for more reliable stopping

        # Track different solution paths found
        solution_archive = []  # List of (operations, cost) tuples

        # Variables to control exploration phases
        episodes_since_improvement = 0
        exploration_phase = False

        # Keeps 'episodes' in the stats well-defined when max_episodes <= 0
        episode = -1

        # Train until max episodes or success
        for episode in range(max_episodes):
            # Reset environment
            state = env.reset()
            episode_reward = 0
            done = False

            # Temporarily increase exploration if we're in exploration phase
            if exploration_phase:
                original_epsilon = agent.epsilon
                agent.epsilon = 0.8  # High exploration

            # Run episode
            while not done:
                action = agent.select_action(env, *state)
                next_state, reward, done, _ = env.step(action)
                agent.memory.add(state, action, reward, next_state, done)

                if len(agent.memory) > agent.batch_size:
                    agent._learn_fast()

                state = next_state
                episode_reward += reward

            # Reset epsilon if we were in exploration phase
            if exploration_phase:
                agent.epsilon = original_epsilon
                exploration_phase = False

            # Standard epsilon decay
            if episode > 1000 and success_found:
                agent.epsilon = max(0.01, agent.epsilon * 0.995)  # Lower floor, faster decay
            else:
                agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

            # Check if this episode found a solution
            if graph_difference(env.current_graph, target_graph) == 0:
                solution_cost = calculate_cost(env.operations_history)
                success_found = True
                consecutive_successes += 1

                # Record different solution paths (sequence params normalised to tuples)
                solution_path = tuple(
                    (op, tuple(param) if isinstance(param, (list, tuple)) else param)
                    for op, param in env.operations_history
                )

                # Check if this is a new solution path
                if all(solution_path != path for path, _ in solution_archive):
                    solution_archive.append((solution_path, solution_cost))
                    print(f"Episode {episode}: Found new solution path with cost {solution_cost}")

                # Check if this is a better solution
                if solution_cost < best_solution_cost:
                    best_solution = env.operations_history.copy()
                    best_solution_cost = solution_cost
                    print(f"Episode {episode}: Found better solution with cost {best_solution_cost}")
                    episodes_since_improvement = 0
                else:
                    episodes_since_improvement += 1
            else:
                consecutive_successes = 0
                episodes_since_improvement += 1

            # Log progress
            if episode % 50 == 0:
                print(f"Episode {episode}: Reward = {episode_reward:.2f}, Epsilon = {agent.epsilon:.3f}")
                if best_solution is not None:
                    print(f"Best solution so far: cost {best_solution_cost}, {len(best_solution)} operations")

            # Periodically enter exploration phase to find diverse solutions
            if success_found and episode > 200 and episode % 100 == 0:
                print(f"Entering exploration phase at episode {episode}")
                exploration_phase = True

            # If we've found a solution but it might not be optimal, try harder.
            # NOTE(review): once triggered this burst runs on every episode
            # until an improvement resets the counter.
            if success_found and best_solution_cost > est_min_cost * 1.2 and episodes_since_improvement > 200:
                print(f"Solution cost ({best_solution_cost}) still above estimated minimum ({est_min_cost})")
                print("Temporarily increasing exploration to find better solutions...")

                # Save current epsilon and temporarily increase exploration
                saved_epsilon = agent.epsilon
                agent.epsilon = 0.9

                # Run a set of focused exploration episodes (no learning)
                for _ in range(30):
                    if _rollout_reaches_target(agent, env, target_graph):
                        explore_cost = calculate_cost(env.operations_history)
                        if explore_cost < best_solution_cost:
                            best_solution = env.operations_history.copy()
                            best_solution_cost = explore_cost
                            print(f"Exploration found better solution with cost {best_solution_cost}")
                            episodes_since_improvement = 0

                # Restore epsilon
                agent.epsilon = saved_epsilon

            # Early stopping with verification
            if success_found and episode > 200:
                # Try 5 verification runs with epsilon=0 (no exploration)
                original_epsilon = agent.epsilon
                agent.epsilon = 0
                verification_successes = sum(
                    1 for _ in range(5)
                    if _rollout_reaches_target(agent, env, target_graph)
                )
                agent.epsilon = original_epsilon

                # Early stopping conditions:
                # 1. If we can reliably solve it with a cost close to theoretical minimum
                if verification_successes >= 4 and best_solution_cost <= est_min_cost * 1.2:
                    print(f"Solution verified in {verification_successes}/5 attempts!")
                    print(f"Cost ({best_solution_cost}) is within 20% of estimated minimum ({est_min_cost})")
                    break

                # 2. If we've had 3 consecutive successes and no improvement for a long time
                if consecutive_successes >= 3 and episodes_since_improvement > 500:
                    print(f"Early stopping after {consecutive_successes} consecutive successes")
                    print(f"No improvement for {episodes_since_improvement} episodes")
                    break

        # Record training stats for this graph pair.
        # "is not None" so that a valid zero-operation solution counts as success.
        if best_solution is not None:
            overall_success_count += 1
            all_training_stats.append({
                'graph_pair': pair_idx,
                'success': True,
                'episodes': episode + 1,
                'best_cost': best_solution_cost,
                'solution_length': len(best_solution),
                'est_min_cost': est_min_cost,
                'solution_archive': solution_archive
            })

            # Visualize the solution
            print("\nFinal solution:")
            visualize_transformation(initial_graph, best_solution)

            # Show all discovered solutions
            print(f"\nDiscovered {len(solution_archive)} different solution paths:")
            solution_archive.sort(key=lambda entry: entry[1])  # Sort by cost
            for rank, (solution_path, solution_cost) in enumerate(solution_archive[:5]):  # Show top 5
                print(f"Solution {rank+1}: Cost {solution_cost}")
                print(" → ".join(_describe_solution_path(initial_graph, solution_path)))

        else:
            all_training_stats.append({
                'graph_pair': pair_idx,
                'success': False,
                'episodes': episode + 1,
                'best_cost': float('inf'),
                'solution_length': 0,
                'est_min_cost': est_min_cost
            })
            print("No solution found for this graph pair.")

    # Print overall training results
    success_rate = overall_success_count / len(graph_pairs)
    print(f"\nTraining completed:")
    print(f"Success rate: {success_rate:.2f} ({overall_success_count}/{len(graph_pairs)} graph pairs)")

    if overall_success_count > 0:
        avg_cost = sum(stat['best_cost'] for stat in all_training_stats if stat['success']) / overall_success_count
        print(f"Average solution cost: {avg_cost:.2f}")

        # Calculate optimality ratio; pairs with a zero estimate are skipped
        # because the ratio is undefined for them.
        optimality_stats = [stat['best_cost'] / stat['est_min_cost']
                            for stat in all_training_stats
                            if stat['success'] and stat['est_min_cost'] > 0]
        if optimality_stats:
            avg_optimality = sum(optimality_stats) / len(optimality_stats)
            print(f"Average solution optimality: {avg_optimality:.2f}x estimated minimum")

    # Save the trained model
    model_dir = 'models'
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f'graph_transform_model_{n_nodes}nodes.pt')
    torch.save({
        'model_state_dict': agent.model.state_dict(),
        'n_nodes': n_nodes,
        'training_stats': all_training_stats,
        'success_rate': success_rate
    }, model_path)

    print(f"Model saved to {model_path}")

    return agent

# =========================================
# NOTEBOOK-COMPATIBLE MAIN FUNCTION
# =========================================

def run_notebook_training(n_nodes=8, num_graphs=5, max_episodes=2000):
    """Notebook entry point: forward the arguments to train_agent and return the agent."""
    return train_agent(
        n_nodes=n_nodes,
        num_graphs=num_graphs,
        max_episodes=max_episodes,
    )

def run_notebook_testing(model_path, initial_graph=None, target_graph=None):
    """Load a saved model and let it transform one graph pair greedily.

    If either graph is missing, a fresh random pair is generated with the
    node count stored in the checkpoint. Both graphs are plotted, the agent
    is run with epsilon 0 for at most 3 steps per node, and the outcome is
    printed.

    Returns:
        The environment's operations history when the target graph is
        reached, otherwise None.
    """
    # Restore the agent from the checkpoint
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(model_path, map_location=run_device)
    node_count = ckpt['n_nodes']

    agent = GraphTransformAgent(max_nodes=node_count, device=run_device)
    agent.model.load_state_dict(ckpt['model_state_dict'])
    agent.epsilon = 0  # Greedy policy: no exploration during testing

    # Fall back to a random pair unless both graphs were supplied
    if initial_graph is None or target_graph is None:
        initial_graph, target_graph = create_state_pairs(node_count)

    # Draw both graphs side by side using the same node positions
    plt.figure(figsize=(10, 5))
    layout = nx.spring_layout(initial_graph, seed=42)
    panels = (
        (1, initial_graph, 'lightblue', "Initial Graph"),
        (2, target_graph, 'lightcoral', "Target Graph"),
    )
    for slot, graph, colour, caption in panels:
        plt.subplot(1, 2, slot)
        nx.draw(graph, layout, with_labels=True, node_color=colour,
                node_size=500, font_weight='bold')
        plt.title(caption)

    plt.tight_layout()
    plt.show()

    # Roll the agent out under a step budget
    env = GraphTransformationEnv(initial_graph, target_graph)
    state = env.reset()
    step_budget = initial_graph.number_of_nodes() * 3
    steps = 0

    print("Finding solution...")
    for _ in range(step_budget):
        chosen = agent.select_action(env, *state)
        state, _, finished, _ = env.step(chosen)
        steps += 1
        if finished:
            break

    # Report the outcome
    remaining = graph_difference(env.current_graph, target_graph)
    if remaining != 0:
        print("❌ Failed to find solution")
        print(f"Remaining difference: {remaining}")
        return None

    print("✅ Solution found!")
    cost = calculate_cost(env.operations_history)
    print(f"Solution cost: {cost}")
    print(f"Number of steps: {steps}")

    # Display the solution
    visualize_transformation(initial_graph, env.operations_history)
    return env.operations_history

# For running directly
if __name__ == "__main__":
    # Check if running in notebook environment
    try:
        import IPython
        is_notebook = True
    except ImportError:
        is_notebook = False

    if is_notebook:
        print("Running in notebook environment")
        # Call run_notebook_training() or run_notebook_testing() directly in cells
    else:
        # Command-line execution
        import argparse
        parser = argparse.ArgumentParser(description='Graph State Preparation using RL')
        parser.add_argument('--mode', choices=['train', 'test'], default='train',
                            help='Program mode: train or test')
        parser.add_argument('--nodes', type=int, default=8,
                            help='Number of nodes in training graphs')
        parser.add_argument('--episodes', type=int, default=2000,
                            help='Number of training episodes')
        parser.add_argument('--model', type=str, default='models/graph_transform_model.pt',
                            help='Path for saving/loading the model')

        args = parser.parse_args()

        if args.mode == 'train':
            train_agent(n_nodes=args.nodes, num_graphs=5, max_episodes=args.episodes)
        else:
            # This is a placeholder - proper testing would require more code for graph input
            print("Test mode requires notebook environment or custom graphs.")

Running in notebook environment
