In [None]:
# RGB colour (components in the 0-1 range) per entity index.
# Keys are presumably the values stored in the environment's entity index
# grid -- TODO confirm against the environment module.
# An earlier named-colour version of this map was overwritten immediately
# after being defined; the names are kept as trailing comments for reference.
entity_colour_map = {
    2: (0.13, 0.33, 0.16),  # green
    3: (0.61, 0.65, 0.62),  # grey
    4: (0.61, 0.65, 0.62),  # grey
    5: (0.78, 0.16, 0.12),  # red
    6: (0.46, 0.31, 0.04),  # brown
    7: (0, 0, 0),           # black
}

In [None]:
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

class KG:
    """Knowledge graph of a grid environment, stored as a PyG ``Data`` object.

    Every row of ``graph.x`` is ``[node_type, grid_value, x, y]``.  Node 0 is
    the player (entity node) and node 1 is the terrain cell at the player's
    position; all further nodes are terrain cells.  Terrain nodes are linked
    to their 4-connected terrain neighbours with bidirectional edges.

    ``vision_range`` is stored on the instance but is not used by any method
    of this class.
    """

    def __init__(self, environment, vision_range, completion=1.0, visualize=True):
        """Build the graph from ``environment``.

        Args:
            environment: Object exposing ``player.grid_x``, ``player.grid_y``,
                ``terrain_index_grid``, ``entity_index_grid``, ``width`` and
                ``within_bounds(x, y)``.
            vision_range: Stored as ``self.vision_range``; currently unused.
            completion: Fraction of the grid cells (in row-major order) to
                add as terrain nodes; clamped to the grid size.
            visualize: When true (the default, matching the previous
                behaviour) the graph is plotted as soon as it is built.
        """
        self.environment = environment
        self.vision_range = vision_range
        self.completion = completion
        self._visualize_on_init = visualize
        self.player_pos = (environment.player.grid_x, environment.player.grid_y)
        self.terrain_array = environment.terrain_index_grid
        self.entity_array = environment.entity_index_grid
        self.terrain_node_type = 0
        self.entity_node_type = 1
        self.max_terrain_nodes = self.terrain_array.size
        # Positions of terrain nodes, in insertion order (public, kept a list).
        self.terrain_pos_list = []
        # position -> row index in graph.x, for O(1) membership/index lookup.
        self._terrain_node_index = {}
        self.initialize_graph()

    def initialize_graph(self):
        """Create the player node, its terrain node, then the terrain nodes."""
        player_features = self.get_node_features(self.player_pos, self.entity_node_type)
        terrain_features = self.get_node_features(self.player_pos, self.terrain_node_type)
        self.graph = Data(x=torch.cat([player_features, terrain_features], dim=0),
                          edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long))
        # Node 1 already is the terrain node for the player's cell: register
        # it so it is not added a second time and so neighbours can link to it.
        self.terrain_pos_list = [self.player_pos]
        self._terrain_node_index = {self.player_pos: 1}
        self.add_terrain_to_graph()
        if getattr(self, "_visualize_on_init", True):
            self.visualize_graph()

    def get_node_features(self, coords, node_type):
        """Return a ``(1, 4)`` float tensor ``[[node_type, grid_value, x, y]]``.

        ``grid_value`` is read from the terrain grid for terrain nodes and
        from the entity grid otherwise; grids are indexed ``[y, x]``.
        """
        x, y = coords
        array = self.terrain_array if node_type == self.terrain_node_type else self.entity_array
        return torch.tensor([[node_type, array[y, x], x, y]], dtype=torch.float)

    def add_terrain_to_graph(self):
        """Add the first ``completion`` fraction of grid cells as terrain nodes.

        Cells are visited in row-major order using ``environment.width``.
        """
        total = self.max_terrain_nodes
        # Clamp against the number of cells; ``len()`` of a 2-D grid would
        # only be its row count.
        limit = max(0, min(int(self.completion * total), total))
        width = self.environment.width
        for i in range(limit):
            y, x = divmod(i, width)
            self.add_terrain_node((x, y))

    def add_terrain_node(self, position):
        """Add a terrain node for ``position`` unless one already exists."""
        if position in self._terrain_node_index:
            return
        features = self.get_node_features(position, self.terrain_node_type)
        node_idx = self.graph.x.size(0)  # row index the new node will occupy
        self.graph.x = torch.cat([self.graph.x, features], dim=0)
        self._terrain_node_index[position] = node_idx
        self.terrain_pos_list.append(position)
        self.create_terrain_edges(node_idx, position)

    def create_terrain_edges(self, node_idx, coords):
        """Link node ``node_idx`` at ``coords`` to existing adjacent terrain nodes."""
        x, y = coords
        for neighbor in self.get_neighbors(x, y):
            # Use the real graph row index, not the position's rank in
            # terrain_pos_list (the player node shifts the numbering).
            neighbor_idx = self._terrain_node_index.get(neighbor)
            if neighbor_idx is not None:
                self.create_edge(node_idx, neighbor_idx)

    def get_neighbors(self, x, y):
        """Return the in-bounds 4-connected neighbour coordinates of ``(x, y)``."""
        directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        return [(x + dx, y + dy) for dx, dy in directions
                if self.environment.within_bounds(x + dx, y + dy)]

    def create_edge(self, node1_idx, node2_idx):
        """Append the two directed edges node1 -> node2 and node2 -> node1."""
        edge = torch.tensor([[node1_idx, node2_idx], [node2_idx, node1_idx]], dtype=torch.long)
        self.graph.edge_index = torch.cat([self.graph.edge_index, edge], dim=1)

    def visualize_graph(self):
        """Plot the graph in 3D: terrain nodes at z=0 (blue), others at z=1 (red)."""
        G = to_networkx(self.graph, to_undirected=True)
        pos = {}
        colors = []
        # Convert once to plain Python floats instead of per-element .item().
        for node, (node_type, _value, x, y) in enumerate(self.graph.x.tolist()):
            is_terrain = node_type == self.terrain_node_type
            pos[node] = (x, y, 0 if is_terrain else 1)
            colors.append('blue' if is_terrain else 'red')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        if pos:
            xs, ys, zs = zip(*pos.values())
            ax.scatter(xs, ys, zs, c=colors, s=50)

        for u, v in G.edges:
            points = np.array([pos[u], pos[v]])
            ax.plot(points[:, 0], points[:, 1], points[:, 2], 'gray')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        plt.title('Knowledge Graph Visualization')
        plt.show()

    def print_graph_connections(self):
        """Print every directed edge followed by the node feature matrix."""
        for u, v in self.graph.edge_index.t().tolist():
            print(f"Node {u} -> Node {v}")

        print(f"Node features:\n{self.graph.x}")
