# Neural Graph Memory: Full Demo (Colab Ready)
This notebook demonstrates:
- Text & image memory insertion
- Temporal & multimodal graph structure
- Associative retrieval
- Graph traversal
- Persistence (save/load)
- Visualization

In [None]:
# Uncomment if needed
# !pip install torch torchvision transformers sentence-transformers networkx matplotlib Pillow

In [None]:
import torch
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
import requests
from io import BytesIO
import numpy as np
import pickle

In [None]:
class MemoryNode:
    """A single item stored in the memory graph.

    Attributes:
        id: Integer node id; the same value is used as the node key in the graph.
        embedding: Vector produced by the text or image encoder.
        modality: Kind of item; this notebook uses 'text' and 'image'.
        metadata: Free-form dict describing the item (e.g. its text or URL).
    """

    def __init__(self, id, embedding, modality, metadata=None):
        # Any falsy metadata (None or an empty dict) is replaced by a fresh dict,
        # so instances never share one mutable default.
        if not metadata:
            metadata = {}
        self.id, self.embedding, self.modality = id, embedding, modality
        self.metadata = metadata

class NeuralGraphMemory:
    """Associative memory stored as a NetworkX graph of embedded items.

    Every inserted text or image becomes a graph node whose ``data`` attribute
    is a ``MemoryNode`` (embedding, modality, metadata). Each new node is linked
    to the previously inserted one by an edge with ``relation="temporal"``, so
    insertion order can be recovered by traversal.
    """

    def __init__(self):
        self.graph = nx.Graph()
        # Sentence embeddings for text nodes and for retrieval queries.
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # ``weights`` enum; fall back to the old flag on older releases.
        resnet_weights = getattr(models, "ResNet18_Weights", None)
        if resnet_weights is not None:
            self.image_encoder = models.resnet18(weights=resnet_weights.DEFAULT)
        else:
            self.image_encoder = models.resnet18(pretrained=True)
        # Replace the classification head so the forward pass returns the
        # penultimate feature vector instead of class logits.
        self.image_encoder.fc = torch.nn.Identity()
        self.image_encoder.eval()
        # ImageNet preprocessing (resize, crop to 224, normalize) for the ResNet.
        self.transform = T.Compose([
            T.Resize(256), T.CenterCrop(224),
            T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
        ])

    def add_text_node(self, text):
        """Embed ``text`` and append it to the graph; returns the new node id."""
        embedding = self.text_encoder.encode(text)
        return self._add_node(embedding, 'text', {"text": text})

    def add_image_node(self, img_url, timeout=10):
        """Download the image at ``img_url``, embed it, and append it to the graph.

        Args:
            img_url: HTTP(S) URL of the image.
            timeout: Seconds to wait for the download before giving up.

        Returns:
            The new node id.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(img_url, timeout=timeout)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        tensor = self.transform(image).unsqueeze(0)  # add a batch dimension of 1
        with torch.no_grad():
            embedding = self.image_encoder(tensor).squeeze().numpy()
        return self._add_node(embedding, 'image', {"url": img_url})

    def _add_node(self, embedding, modality, metadata):
        """Insert a node and link it temporally to the previously inserted one."""
        # Ids are sequential; this assumes nodes are never removed from the graph.
        node_id = len(self.graph)
        node = MemoryNode(node_id, embedding, modality, metadata)
        self.graph.add_node(node_id, data=node)
        if node_id > 0:
            self.graph.add_edge(node_id - 1, node_id, relation="temporal")
        return node_id

    def retrieve_nearest(self, query, top_k=3):
        """Return the ``top_k`` nodes most cosine-similar to a text query.

        The query is embedded with the text encoder, so it can only be compared
        with embeddings of the same shape. Nodes with a different shape (e.g. the
        ResNet image features) are skipped instead of making ``np.dot`` raise.
        A zero vector on either side scores 0.0 rather than NaN.

        Returns:
            List of ``(node_id, similarity)`` pairs, highest similarity first.
        """
        query_emb = np.asarray(self.text_encoder.encode(query))
        query_norm = np.linalg.norm(query_emb)
        similarities = []
        for nid, node in self.graph.nodes(data="data"):
            if node is None:
                continue
            embedding = np.asarray(node.embedding)
            if embedding.shape != query_emb.shape:
                continue
            denom = np.linalg.norm(embedding) * query_norm
            sim = float(np.dot(embedding, query_emb) / denom) if denom else 0.0
            similarities.append((nid, sim))
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:top_k]

    def traverse_temporal(self, start_id=0):
        """Depth-first node order from ``start_id`` following only temporal edges.

        Returns an empty list when the graph is empty.
        """
        if len(self.graph) == 0:
            return []
        # Restrict traversal to temporal links so other relation types added
        # later cannot change the replay order.
        temporal = nx.Graph()
        temporal.add_nodes_from(self.graph)
        temporal.add_edges_from(
            (u, v) for u, v, rel in self.graph.edges(data="relation")
            if rel == "temporal"
        )
        return list(nx.dfs_preorder_nodes(temporal, source=start_id))

    def visualize(self, seed=None):
        """Plot the graph, labelling each node with its id and a text snippet.

        Args:
            seed: Optional seed for ``spring_layout`` to get a reproducible layout.
        """
        node_data = nx.get_node_attributes(self.graph, "data")
        pos = nx.spring_layout(self.graph, seed=seed)
        plt.figure(figsize=(10, 6))
        # Draw without built-in labels; the combined "id: snippet" labels below
        # would otherwise be stacked on top of the numeric ids.
        nx.draw(self.graph, pos, with_labels=False, node_color='lightblue')
        node_texts = {
            nid: f"{nid}: {(node.metadata.get('text') or 'img')[:15]}"
            for nid, node in node_data.items()
        }
        nx.draw_networkx_labels(self.graph, pos, labels=node_texts, font_size=8)
        # Plain title: matplotlib's default font has no emoji glyphs.
        plt.title("Neural Graph Memory")
        plt.show()

    def save(self, path='ngm.pkl'):
        """Pickle the whole graph (nodes with their MemoryNode data, and edges) to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self.graph, f)

    def load(self, path='ngm.pkl'):
        """Replace the current graph with the one pickled at ``path``.

        Security: ``pickle.load`` can execute arbitrary code; only load files
        you created yourself (e.g. with ``save``), never untrusted downloads.
        """
        with open(path, 'rb') as f:
            self.graph = pickle.load(f)

## 🔧 Memory Construction

In [None]:
# Build the memory from scratch: a short temporal sequence of text events
# followed by one image node (the multimodal input). Later cells rely on `ngm`.
ngm = NeuralGraphMemory()

TEXT_MEMORIES = [
    "I took the dog for a walk in the park this morning.",
    "The dog chased a squirrel and came back covered in mud.",
    "After the walk I gave the dog a bath.",
    "In the afternoon I read a paper about graph neural networks.",
]
for memory_text in TEXT_MEMORIES:
    ngm.add_text_node(memory_text)

# Using a placeholder image of a neural graph for multimodal input
image_url = 'https://raw.githubusercontent.com/StuckInTheNet/neural-graph-memory/main/assets/sample_graph_image.png'
ngm.add_image_node(image_url)

# Show the image inline: a PIL image as the cell's last expression renders in
# Jupyter/Colab, whereas `Image.show()` opens an external viewer instead.
response = requests.get(image_url, timeout=10)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
image

In [None]:
# Plot the memory graph built above; each node is labelled with its text (or "img").
ngm.visualize()

## 🔍 Retrieval

In [None]:
# Ask a natural-language question and list the closest memories with their scores.
query = "What happened with the dog?"
nearest = ngm.retrieve_nearest(query)
header = f"Query: {query}\n\nTop Matches:"
print(header)
for match_id, match_score in nearest:
    match_node = ngm.graph.nodes[match_id]["data"]
    modality, metadata = match_node.modality, match_node.metadata
    print(f" - Node {match_id} ({modality}): {metadata} [Score: {match_score:.2f}]")

## 🔄 Temporal Traversal

In [None]:
# Replay the memories by walking the temporal links from the first node.
order = ngm.traverse_temporal()
print("Traversal order:")
for step_id in order:
    step_metadata = ngm.graph.nodes[step_id]["data"].metadata
    print(f" - Node {step_id}: {step_metadata}")

## 💾 Persistence

In [None]:
# Persist the whole graph (nodes with their embeddings/metadata, plus edges) via pickle.
ngm.save("ngm_demo.pkl")
print("✅ Saved memory graph")

In [None]:
# Round-trip check: a fresh instance loads the graph pickled above.
# NOTE: pickle.load can execute arbitrary code, so only load files you created.
ngm2 = NeuralGraphMemory()
ngm2.load("ngm_demo.pkl")
assert len(ngm2.graph) == len(ngm.graph), "loaded graph has a different node count"
assert ngm2.graph.number_of_edges() == ngm.graph.number_of_edges(), "loaded graph has a different edge count"
print("✅ Loaded memory graph with", len(ngm2.graph), "nodes")