In [None]:


    def remove_entity_node_sub(self, position):
        """Remove the entity node at ``position`` using PyG's ``subgraph`` helper.

        Every other node keeps its features, edges and edge attributes. The
        surviving nodes are relabelled to a contiguous index range. If no entity
        node exists at ``position`` on ``self.entity_z_level``, a message is
        printed and the graph is left untouched.

        Args:
            position: Grid position of the entity node, in whatever form
                ``self.idx_manager`` expects (presumably an ``(x, y)`` tuple).
        """
        # Imported locally: ``subgraph`` is not among this file's visible
        # top-level imports, so the bare name was unresolved.
        from torch_geometric.utils import subgraph

        # Check existence *before* resolving the index, so the log line never
        # asks the idx manager for a node that may not exist.
        if not self.idx_manager.verify_node_exists(position, self.entity_z_level):
            print(f"No entity node at position {position} with z_level {self.entity_z_level}")
            return

        entity_idx = self.idx_manager.get_idx(position, self.entity_z_level)
        print(f"[remove_entity_node()] Removing entity node at position {position} and idx {entity_idx}")

        # Boolean mask that keeps every node except the one being removed; built
        # on the same device as the node features it will index.
        node_mask = torch.arange(self.graph.num_nodes, device=self.graph.x.device) != entity_idx

        # Drop all edges touching the removed node and relabel the survivors.
        new_edge_index, new_edge_attr = subgraph(
            node_mask,
            self.graph.edge_index,
            edge_attr=self.graph.edge_attr,
            relabel_nodes=True,
        )

        self.graph.edge_index = new_edge_index
        self.graph.edge_attr = new_edge_attr
        self.graph.x = self.graph.x[node_mask]

        # NOTE(review): relabelling shifts every node index above ``entity_idx``
        # down by one, but only the removed node's entry is dropped here —
        # confirm idx_manager accounts for that shift.
        self.idx_manager.remove_idx(position, self.entity_z_level)

In [None]:

    def remove_entity_node_0(self, position):
        """Remove the entity node at ``position`` by editing the graph tensors directly.

        The method deletes the node's feature row and drops every edge incident
        to it. It then shifts all larger node indices in ``edge_index`` down by
        one, so the edges stay consistent with the shortened ``x``. Finally it
        unregisters the node from ``self.idx_manager`` and calls
        ``self.verify_terrain_node_connections(position)``. If no entity node
        exists at ``position``, a message is printed and nothing changes.

        Args:
            position: Grid position of the entity node, in whatever form
                ``self.idx_manager`` expects (presumably an ``(x, y)`` tuple).
        """
        if not self.idx_manager.verify_node_exists(position, self.entity_z_level):
            print(f"No entity node at position {position} with z_level {self.entity_z_level}")
            return

        entity_idx = self.idx_manager.get_idx(position, self.entity_z_level)

        # Remove the node's row from the node feature matrix.
        self.graph.x = torch.cat([self.graph.x[:entity_idx], self.graph.x[entity_idx + 1:]], dim=0)

        # Drop incident edges FIRST. Renumbering before masking makes the old
        # node ``entity_idx + 1`` collide with ``entity_idx``, which would also
        # delete that innocent neighbour's edges.
        keep = ~(self.graph.edge_index == entity_idx).any(dim=0)
        edge_index = self.graph.edge_index[:, keep]  # advanced indexing -> copy
        if self.graph.edge_attr is not None:  # graphs built without edge_attr have None here
            self.graph.edge_attr = self.graph.edge_attr[keep]

        # Close the gap left by the removed node.
        edge_index[edge_index > entity_idx] -= 1
        self.graph.edge_index = edge_index

        # NOTE(review): node indices above ``entity_idx`` have shifted down by
        # one — confirm idx_manager accounts for that shift.
        self.idx_manager.remove_idx(position, self.entity_z_level)
        self.verify_terrain_node_connections(position)

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

class KG:
    """Knowledge graph of a grid environment, stored in a PyG ``Data`` object.

    Node features are ``[node_type, type_id, x, y]`` (float). ``node_type`` is
    ``terrain_node_type`` (0) or ``entity_node_type`` (1), and ``type_id`` is
    read from the environment's terrain/entity index grids. Node 0 is the
    player entity and node 1 is the terrain cell under the player. Terrain
    nodes are linked to neighbouring terrain nodes. Each entity found on a
    newly added terrain cell is linked to that terrain node and to the player.
    """

    def __init__(self, environment, vision_range, completion=1.0):
        self.environment = environment
        self.terrain_array = environment.terrain_index_grid
        self.entity_array = environment.entity_index_grid

        self.vision_range = vision_range
        self.player_pos = (self.environment.player.grid_x, self.environment.player.grid_y)
        print(f"Player position: {self.player_pos}")
        # Terrain positions in insertion order (kept for callers that read it).
        self.terrain_pos_list = [self.player_pos]
        # Terrain position -> graph node index. Node 0 is the player entity, so
        # the initial terrain node under the player is node 1. Entity nodes are
        # interleaved with terrain nodes, so positions in ``terrain_pos_list``
        # are NOT node indices — always resolve indices through this dict.
        self.terrain_pos_idx_dict = {self.player_pos: 1}
        self.max_terrain_nodes = self.terrain_array.size
        self.distance = self.get_distance(completion)  # Graph distance
        self.terrain_node_type = 0
        self.entity_node_type = 1
        # Initial nodes: the player entity (node 0) and the terrain under it (node 1).
        player_node_features = self.get_node_features(self.player_pos, self.entity_node_type)
        terrain_node_features = self.get_node_features(self.player_pos, self.terrain_node_type)
        # Bidirectional edge between the player and the initial terrain.
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)

        self.graph = Data(x=torch.cat([player_node_features, terrain_node_features], dim=0),
                          edge_index=edge_index)

        self.add_terrain_to_graph()
        self.visualise_graph()

    def get_node_features(self, coor, node_type):
        """Return a ``(1, 4)`` float tensor ``[node_type, type_id, x, y]`` for the cell at ``coor``.

        Raises:
            AssertionError: for an entity node on a cell whose entity id is not > 0.
            ValueError: for an unknown ``node_type``.
        """
        x, y = coor
        if node_type == self.terrain_node_type:
            return torch.tensor([[node_type, self.terrain_array[y, x], x, y]], dtype=torch.float)
        elif node_type == self.entity_node_type:
            assert self.entity_array[y, x] > 0, f"Invalid entity type at position ({x}, {y}), type: {self.entity_array[y, x], self.entity_array}"
            return torch.tensor([[node_type, self.entity_array[y, x], x, y]], dtype=torch.float)
        else:
            raise ValueError(f"Invalid node type: {node_type}")

    def add_terrain_node(self, position, wait=False):
        """Add a terrain node for ``position`` unless one already exists.

        If the cell holds an entity, an entity node is added as well and linked
        to the new terrain node and to the player.

        Args:
            position: ``(x, y)`` grid position.
            wait: If False, terrain edges are created immediately. If True, they
                are deferred, and ``(node_index, position)`` is returned so the
                caller can pass it to ``finalize_graph_edges``.

        Returns:
            ``(node_index, position)`` when a node was added with ``wait=True``,
            otherwise ``None``.
        """
        if position in self.terrain_pos_idx_dict:
            return None

        new_terrain_features = self.get_node_features(position, self.terrain_node_type)
        self.graph.x = torch.cat([self.graph.x, new_terrain_features], dim=0)
        new_node_index = len(self.graph.x) - 1

        self.terrain_pos_list.append(position)
        self.terrain_pos_idx_dict[position] = new_node_index

        if self.environment.entity_index_grid[position[1], position[0]] != 0:
            self.add_entity_node(new_node_index, position)

        if not wait:
            self.create_terrain_edges(new_node_index, position)
            return None
        return new_node_index, position

    def create_edge(self, node1, node2):
        """Append the directed edges ``node1 -> node2`` and ``node2 -> node1``."""
        new_edge = torch.tensor([[node1, node2], [node2, node1]], dtype=torch.long)
        self.graph.edge_index = torch.cat([self.graph.edge_index, new_edge], dim=1)

    def create_terrain_edges(self, terrain_idx, coor, skip=()):
        """Link terrain node ``terrain_idx`` at ``coor`` to each neighbouring terrain node in the graph.

        Args:
            terrain_idx: Graph node index of the terrain node at ``coor``.
            coor: ``(x, y)`` grid position.
            skip: Neighbour positions to ignore. ``finalize_graph_edges`` uses
                this so that an edge between two deferred nodes is created only once.
        """
        x, y = coor
        for neighbour in self.environment.get_neighbours(x, y):
            # assumes get_neighbours yields (x, y) pairs; normalise to a hashable tuple
            neighbour = tuple(neighbour)
            if neighbour in skip:
                continue
            neighbour_idx = self.terrain_pos_idx_dict.get(neighbour)
            if neighbour_idx is not None:
                self.create_edge(terrain_idx, neighbour_idx)

    def add_entity_node(self, terrain_idx, position):
        """Add an entity node for ``position``, linked both ways to ``terrain_idx`` and to the player (node 0)."""
        entity_features = self.get_node_features(position, self.entity_node_type)
        self.graph.x = torch.cat([self.graph.x, entity_features], dim=0)
        entity_idx = len(self.graph.x) - 1

        # edge_index is [2, num_edges]: row 0 holds sources, row 1 targets.
        # Columns: e->t, t->e, e->player, player->e (pairs stay adjacent, which
        # print_graph_connections relies on).
        new_edges = torch.tensor(
            [[entity_idx, terrain_idx, entity_idx, 0],
             [terrain_idx, entity_idx, 0, entity_idx]],
            dtype=torch.long,
        )
        self.graph.edge_index = torch.cat([self.graph.edge_index, new_edges], dim=1)

        # Debugging: Print connections of the entity node
        print(f"Entity Node {entity_idx} connections: Terrain {terrain_idx}, Player 0")

    def add_entity_nodes(self):
        """Add an entity node, linked only to the player, for every entity on a terrain cell in the graph.

        NOTE(review): not called within this class. Calling it after
        ``add_terrain_node`` has already added entities would duplicate them,
        including the player itself.
        """
        for y in range(self.entity_array.shape[0]):
            for x in range(self.entity_array.shape[1]):
                if (x, y) not in self.terrain_pos_idx_dict:
                    continue
                if self.entity_array[y, x] > 0:
                    entity_features = self.get_node_features((x, y), self.entity_node_type)
                    self.graph.x = torch.cat([self.graph.x, entity_features], dim=0)
                    entity_idx = len(self.graph.x) - 1

                    new_edges = torch.tensor([[entity_idx, 0], [0, entity_idx]], dtype=torch.long)
                    self.graph.edge_index = torch.cat([self.graph.edge_index, new_edges], dim=1)

                    # Debugging: Print connections of the entity node
                    print(f"Entity Node {entity_idx} connections: Player 0")

    def get_distance(self, completion):
        """Return the half-width of the square around the player that gets terrain nodes.

        ``completion`` is capped at 1 and scaled by the number of grid rows. The
        result is never smaller than ``vision_range``.
        """
        completion = min(completion, 1)
        return max(int(completion * self.terrain_array.shape[0]), self.vision_range)

    def finalize_graph_edges(self, t_nodes_to_calculate_edges):
        """Create terrain edges for nodes added with ``wait=True``, each undirected pair exactly once."""
        # A neighbour that is still pending will link back to this node when its
        # own turn comes. Skipping it here avoids creating every edge twice.
        pending = {position for _, position in t_nodes_to_calculate_edges}
        for node_index, position in t_nodes_to_calculate_edges:
            pending.discard(position)
            self.create_terrain_edges(node_index, position, skip=pending)

    def add_terrain_to_graph(self):
        """Add a terrain node for every in-bounds cell within ``self.distance`` of the player (square window)."""
        player_x, player_y = self.player_pos
        terrain_nodes_to_calculate_edges = []

        for y in range(player_y - self.distance, player_y + self.distance + 1):
            for x in range(player_x - self.distance, player_x + self.distance + 1):
                if not self.environment.within_bounds(x, y):
                    continue
                if (x, y) not in self.terrain_pos_idx_dict:
                    node_index, _ = self.add_terrain_node((x, y), wait=True)
                    terrain_nodes_to_calculate_edges.append((node_index, (x, y)))

        # Edges are created only after all nodes exist, so neighbours found later
        # in the scan can still be linked.
        self.finalize_graph_edges(terrain_nodes_to_calculate_edges)

        self.print_graph_connections()

    def visualise_graph(self, node_size=100, edge_color="tab:gray", show_ticks=True):
        """Draw the graph in 3D: each node at its grid (x, y), terrain at z=0 and entities at z=1.

        Terrain colours come from ``environment.terrain_colour_map``, given as
        0-255 RGB, with red as the fallback. Entity colours come from a local
        map, with grey as the fallback. Asserts that node 0 (the player) sits at
        ``player_pos`` on the entity level.
        """
        # RGB colours normalised to [0, 1]: green, grey (x2), red, brown, black
        entity_colour_map = {2: (0.13, 0.33, 0.16), 3: (0.61, 0.65, 0.62), 4: (0.61, 0.65, 0.62),
                             5: (0.78, 0.16, 0.12), 6: (0.46, 0.31, 0.04), 7: (0, 0, 0)}

        G = to_networkx(self.graph, to_undirected=True)

        # Positions come straight from node features; no layout algorithm is needed.
        pos = {}
        node_colors = []
        for node in G.nodes():
            features = self.graph.x[node]
            node_type = features[0].item()
            x, y = features[2].item(), features[3].item()
            z = 0 if node_type == self.terrain_node_type else 1
            pos[node] = (x, y, z)

            if node_type == self.terrain_node_type:
                terrain_type = int(features[1].item())
                color = self.environment.terrain_colour_map.get(terrain_type, (255, 0, 0))
                node_colors.append([color[0] / 255.0, color[1] / 255.0, color[2] / 255.0])
            elif node_type == self.entity_node_type:
                entity_type = int(features[1].item())
                print(f"Entity type: {entity_type}")
                # Look up by the Python int: a tensor key never matches an int key
                # (tensors hash by identity), which used to raise KeyError.
                color = entity_colour_map.get(entity_type)
                if color is None:
                    color = entity_colour_map[3]
                    print(f"Entity type {entity_type} not found in colour map. now it is {color}")
                node_colors.append(color)

        print(f"Player node at position ({pos[0][0]}, {pos[0][1]}, {pos[0][2]})")
        assert (pos[0][0], pos[0][1]) == self.player_pos and pos[0][2] == 1, "Player position does not match the graph position"

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        node_xyz = np.array([pos[v] for v in sorted(G)])
        edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])
        try:
            node_colors = np.array(node_colors)
        except ValueError:
            print("Printing node colors")
            print(node_colors)
        ax.scatter(*node_xyz.T, s=node_size, color=node_colors, edgecolor='w', depthshade=True)
        for vizedge in edge_xyz:
            ax.plot(*vizedge.T, color=edge_color)

        if show_ticks:
            # Tick labels spread across the data range
            ax.set_xticks(np.linspace(min(pos[n][0] for n in G.nodes()), max(pos[n][0] for n in G.nodes()), num=5))
            ax.set_yticks(np.linspace(min(pos[n][1] for n in G.nodes()), max(pos[n][1] for n in G.nodes()), num=5))
            ax.set_zticks([0, 1])  # Only two levels: 0 for terrain, 1 for entities
        else:
            ax.grid(False)
            ax.xaxis.set_ticks([])
            ax.yaxis.set_ticks([])
            ax.zaxis.set_ticks([])

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        plt.title("3D Graph Visualization")
        plt.show()

    def print_graph_connections(self):
        """Print one line per bidirectional edge pair (every other column of ``edge_index``), then the node features."""
        edge_index_np = self.graph.edge_index.numpy()
        for i in range(0, edge_index_np.shape[1], 2):  # Step by 2 to handle bi-directional edges
            print(f"Node {edge_index_np[0, i]} -> Node {edge_index_np[1, i]}")

        print(self.graph.x)

In [None]:
# Entity type id -> RGB colour normalised to [0, 1].
# 2 green, 3/4 grey, 5 red, 6 brown, 7 black. (A previous colour-name version
# of this map was immediately overwritten and has been removed as dead code.)
entity_colour_map = {
    2: (0.13, 0.33, 0.16),
    3: (0.61, 0.65, 0.62),
    4: (0.61, 0.65, 0.62),
    5: (0.78, 0.16, 0.12),
    6: (0.46, 0.31, 0.04),
    7: (0, 0, 0),
}

In [None]:
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

class KG:
    """Knowledge graph of a grid environment (compact variant), stored in a PyG ``Data`` object.

    Node 0 is the player entity and node 1 is the terrain cell beneath it.
    Further terrain nodes are added in row-major order, and each is linked to
    those of its 4-neighbours already in the graph. Node features are
    ``[node_type, type_id, x, y]`` (float).
    """

    def __init__(self, environment, vision_range, completion=1.0):
        self.environment = environment
        self.vision_range = vision_range  # NOTE(review): stored but not used by this variant
        self.completion = completion
        self.player_pos = (environment.player.grid_x, environment.player.grid_y)
        self.terrain_array = environment.terrain_index_grid
        self.entity_array = environment.entity_index_grid
        self.terrain_node_type = 0
        self.entity_node_type = 1
        self.max_terrain_nodes = self.terrain_array.size
        # Terrain positions in insertion order.
        self.terrain_pos_list = []
        # Terrain position -> graph node index. List positions are offset by the
        # player node, so they cannot be used as node indices.
        self.terrain_pos_idx_dict = {}
        self.initialize_graph()

    def initialize_graph(self):
        """Reset the graph to the player (node 0) plus its terrain (node 1), then add terrain and plot."""
        player_features = self.get_node_features(self.player_pos, self.entity_node_type)
        terrain_features = self.get_node_features(self.player_pos, self.terrain_node_type)
        self.graph = Data(x=torch.cat([player_features, terrain_features], dim=0),
                          edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long))
        # Register the initial terrain node so it is not re-added as a duplicate
        # and so its neighbours can link to it.
        self.terrain_pos_list = [self.player_pos]
        self.terrain_pos_idx_dict = {self.player_pos: 1}
        self.add_terrain_to_graph()
        self.visualize_graph()

    def get_node_features(self, coords, node_type):
        """Return a ``(1, 4)`` float tensor ``[node_type, type_id, x, y]`` for the cell at ``coords``."""
        x, y = coords
        array = self.terrain_array if node_type == self.terrain_node_type else self.entity_array
        return torch.tensor([[node_type, array[y, x], x, y]], dtype=torch.float)

    def add_terrain_node(self, position):
        """Add a terrain node for ``position`` (if new) and link it to neighbouring terrain nodes."""
        if position in self.terrain_pos_idx_dict:
            return
        features = self.get_node_features(position, self.terrain_node_type)
        self.graph.x = torch.cat([self.graph.x, features], dim=0)
        node_idx = len(self.graph.x) - 1
        self.terrain_pos_list.append(position)
        self.terrain_pos_idx_dict[position] = node_idx
        self.create_terrain_edges(node_idx, position)

    def add_terrain_to_graph(self):
        """Add terrain nodes for the first ``completion`` fraction of grid cells, in row-major order.

        Cells already in the graph are skipped. ``vision_range`` is not used here.
        """
        # Clamp to the total number of cells. The old clamp used len() of the
        # 2D array, which is only the row count.
        limit = min(int(self.completion * self.max_terrain_nodes), self.max_terrain_nodes)
        width = self.environment.width
        for i in range(limit):
            self.add_terrain_node((i % width, i // width))

    def create_terrain_edges(self, node_idx, coords):
        """Link node ``node_idx`` at ``coords`` to each in-bounds 4-neighbour already in the graph."""
        x, y = coords
        for neighbor in self.get_neighbors(x, y):
            neighbor_idx = self.terrain_pos_idx_dict.get(neighbor)
            if neighbor_idx is not None:
                self.create_edge(node_idx, neighbor_idx)

    def get_neighbors(self, x, y):
        """Return the in-bounds 4-neighbour coordinates of ``(x, y)``."""
        directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        return [(x + dx, y + dy) for dx, dy in directions if self.environment.within_bounds(x + dx, y + dy)]

    def create_edge(self, node1_idx, node2_idx):
        """Append the directed edges ``node1 -> node2`` and ``node2 -> node1``."""
        edge = torch.tensor([[node1_idx, node2_idx], [node2_idx, node1_idx]], dtype=torch.long)
        self.graph.edge_index = torch.cat([self.graph.edge_index, edge], dim=1)

    def visualize_graph(self):
        """Plot nodes at their grid (x, y): terrain (blue) at z=0, entities (red) at z=1."""
        G = to_networkx(self.graph, to_undirected=True)
        pos = {node: (data[2].item(), data[3].item(), 0 if data[0] == self.terrain_node_type else 1)
               for node, data in enumerate(self.graph.x)}
        colors = ['blue' if data[0] == self.terrain_node_type else 'red' for data in self.graph.x]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        for node, color in zip(pos, colors):
            ax.scatter(*pos[node], color=color, s=50)

        for edge in G.edges:
            points = np.array([pos[edge[0]], pos[edge[1]]])
            ax.plot(points[:, 0], points[:, 1], points[:, 2], 'gray')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        plt.title('Knowledge Graph Visualization')
        plt.show()

    def print_graph_connections(self):
        """Print every directed edge, then the node feature matrix."""
        for u, v in self.graph.edge_index.t():
            print(f"Node {u.item()} -> Node {v.item()}")

        print(f"Node features:\n{self.graph.x}")
