# Neural Graph Memory Demo

This notebook demonstrates the capabilities of **Neural Graph Memory (NGM)** — a biologically-inspired graph-based memory architecture designed for long-term, multimodal retrieval in AI agents.

The notebook includes:
- Image and text embedding
- Graph-based memory construction
- Traversal and retrieval demos
- Multimodal query-response examples

All supporting assets are downloaded automatically at run time from the public GitHub repository and from Wikimedia Commons.


In [None]:

import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO


In [None]:

# Load the CLIP weights and the matching preprocessor from one shared checkpoint
clip_checkpoint = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_checkpoint)
clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)


In [None]:

# Download the architecture diagram and preprocess it for CLIP
# NOTE(review): a later cell fetches architecture.png from a differently named
# repository ("neural-graph-memory") — confirm which URL is current.
image_url = "https://raw.githubusercontent.com/StuckInTheNet/neural-graph-memory-Work-In-Progress-/main/assets/architecture.png"
# A timeout keeps the cell from hanging indefinitely on a stalled connection
response = requests.get(image_url, timeout=30)
# Stop here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
inputs = clip_processor(images=image, return_tensors="pt")


In [None]:

# Run the CLIP image encoder; gradients are not needed for inference
with torch.no_grad():
    image_features = clip_model.get_image_features(**inputs)

# Scale the features to unit L2 length along the last axis
feature_norm = image_features.norm(p=2, dim=-1, keepdim=True)
image_embedding = image_features / feature_norm

# Drop size-1 dimensions and hand the vector over to NumPy
image_embedding_np = image_embedding.squeeze().numpy()


In [None]:

# Add embedding to memory graph: the normalised CLIP vector is passed as a
# plain Python list under an explicit node id.
# NOTE(review): `ngm` is not created in any cell visible before this one —
# confirm that a memory object exposing
# `add_node(node_id, modality, content, embedding)` exists when this cell runs.
ngm.add_node(
    node_id="visual_architecture_memory",
    modality="image",
    content="Neural Graph Memory architecture diagram",
    embedding=image_embedding_np.tolist()
)


In [None]:
# Uncomment the next line to install the notebook's dependencies if they are missing
# !pip install torch torchvision transformers sentence-transformers networkx matplotlib Pillow

In [None]:
import torch
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
import requests
from io import BytesIO
import numpy as np
import pickle

In [None]:
class MemoryNode:
    """One stored memory: an identifier, its embedding, a modality tag and metadata."""

    def __init__(self, id, embedding, modality, metadata=None):
        """Record the given fields on the instance.

        Args:
            id: Identifier of the node.
            embedding: Embedding associated with the memory.
            modality: Label naming the kind of content.
            metadata: Optional extra information; any falsy value (``None``
                or an empty container) is replaced by a fresh empty dict.
        """
        self.id, self.embedding, self.modality = id, embedding, modality
        self.metadata = metadata if metadata else {}

class NeuralGraphMemory:
    """Graph-backed store of text and image memories with similarity lookup.

    Every node of ``self.graph`` carries a ``MemoryNode`` under the ``"data"``
    attribute.  Each newly inserted node is joined to the previously inserted
    one by an edge tagged ``relation="temporal"``.
    """

    def __init__(self):
        """Create an empty graph and load the text and image encoders."""
        self.graph = nx.Graph()
        # Id of the most recently inserted node (None while the graph is empty).
        self._tail_id = None
        self.text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
        # NOTE(review): newer torchvision releases reportedly deprecate
        # `pretrained=True` in favour of `weights=...`; left unchanged because
        # the targeted torchvision version is not visible here.
        self.image_encoder = models.resnet18(pretrained=True)
        # Swap the final `fc` layer for an identity so the model returns
        # features rather than class scores.
        self.image_encoder.fc = torch.nn.Identity()
        self.image_encoder.eval()
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

    def add_text_node(self, text):
        """Embed *text* with the text encoder, store it, and return the node id."""
        embedding = self.text_encoder.encode(text)
        return self._add_node(embedding, 'text', {"text": text})

    def add_image_node(self, img_url, timeout=30):
        """Download an image, embed it with the image encoder, and store it.

        Args:
            img_url: URL of the image to fetch.
            timeout: Seconds to wait on the server before giving up.

        Returns:
            The id of the new node.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.get(img_url, timeout=timeout)
        # Fail with the HTTP status rather than letting PIL parse an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        tensor = self.transform(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            embedding = self.image_encoder(tensor).squeeze().numpy()
        return self._add_node(embedding, 'image', {"url": img_url})

    def add_node(self, node_id, modality, content, embedding):
        """Store a memory whose embedding was computed by the caller.

        Args:
            node_id: Explicit id for the node.  If a node with this id already
                exists its data is replaced and no new edge is added.
            modality: Label naming the kind of content (e.g. ``"image+text"``).
            content: Description of the memory; kept under the ``"text"``
                metadata key so ``visualize`` can use it as the caption.
            embedding: Sequence of numbers; converted to a NumPy array.

        Returns:
            ``node_id``.
        """
        return self._add_node(
            np.asarray(embedding), modality, {"text": content}, node_id=node_id)

    def _latest_node_id(self):
        """Return the id of the most recently inserted node, or None if empty."""
        tail = getattr(self, "_tail_id", None)
        if tail is None or tail not in self.graph:
            # Cache is missing or stale (e.g. after `load`): fall back to the
            # last node in iteration order, which follows insertion order.
            tail = None
            for tail in self.graph:
                pass
        return tail

    def _add_node(self, embedding, modality, metadata, node_id=None):
        """Insert a ``MemoryNode`` and link it to the previous insertion.

        When *node_id* is omitted, the current node count is used as the id
        (bumped upwards if a caller-supplied id already occupies that value).
        """
        if node_id is None:
            node_id = len(self.graph)
            while node_id in self.graph:
                node_id += 1
        already_present = node_id in self.graph
        previous_id = self._latest_node_id()
        node = MemoryNode(node_id, embedding, modality, metadata)
        self.graph.add_node(node_id, data=node)
        if not already_present:
            if previous_id is not None:
                self.graph.add_edge(previous_id, node_id, relation="temporal")
            self._tail_id = node_id
        return node_id

    def retrieve_nearest(self, query, top_k=3):
        """Rank stored nodes by cosine similarity to the encoded *query*.

        Nodes whose embedding shape differs from the query embedding's are
        skipped: they were produced by a different encoder, so a dot product
        with the query is undefined (it used to raise).  Nodes without data
        and zero-length vectors are skipped as well.

        Args:
            query: Text to encode with the text encoder.
            top_k: Maximum number of results to return.

        Returns:
            A list of ``(node_id, similarity)`` pairs, best match first.
        """
        query_emb = np.asarray(self.text_encoder.encode(query))
        query_norm = np.linalg.norm(query_emb)
        if query_norm == 0:
            return []
        similarities = []
        for nid, data in self.graph.nodes(data="data"):
            if data is None:
                continue
            emb = np.asarray(data.embedding)
            if emb.shape != query_emb.shape:
                continue
            denom = np.linalg.norm(emb) * query_norm
            if denom == 0:
                continue
            similarities.append((nid, float(np.dot(emb, query_emb) / denom)))
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:top_k]

    def traverse_temporal(self, start_id=0):
        """Return node ids in depth-first preorder starting at *start_id*.

        Returns an empty list when *start_id* is not in the graph (including
        the empty-graph case) instead of raising.
        """
        if start_id not in self.graph:
            return []
        return list(nx.dfs_preorder_nodes(self.graph, source=start_id))

    def visualize(self):
        """Draw the graph, labelling each node with its id and a short caption."""
        nodes_with_data = nx.get_node_attributes(self.graph, "data")
        pos = nx.spring_layout(self.graph)
        plt.figure(figsize=(10, 6))
        # Ids are folded into the caption labels below; drawing the default id
        # labels too would put two overlapping labels on every node.
        nx.draw(self.graph, pos, with_labels=False, node_color='lightblue')
        node_texts = {
            nid: f"{nid}: {(node.metadata.get('text') or 'img')[:15]}"
            for nid, node in nodes_with_data.items()
        }
        nx.draw_networkx_labels(self.graph, pos, labels=node_texts, font_size=8)
        plt.title("🧠 Neural Graph Memory")
        plt.show()

    def save(self, path='ngm.pkl'):
        """Pickle the graph (not the encoders) to *path*."""
        with open(path, 'wb') as f:
            pickle.dump(self.graph, f)

    def load(self, path='ngm.pkl'):
        """Replace the current graph with the one pickled at *path*.

        SECURITY: unpickling can execute arbitrary code contained in the file.
        Only load files that were written by ``save`` from a trusted source.
        """
        with open(path, 'rb') as f:
            self.graph = pickle.load(f)
        # Forget the cached tail; it is re-derived from the loaded graph on the
        # next insertion.
        self._tail_id = None

## 🔧 Memory Construction

### Adding a Visual Memory Node
We now demonstrate how an AI agent might ingest a visual memory (e.g., a memory architecture diagram) into the graph memory system using CLIP embeddings.

In [None]:
from PIL import Image
import requests
from io import BytesIO

# Load your architecture image as an example of multimodal input
image_url = 'https://raw.githubusercontent.com/StuckInTheNet/neural-graph-memory/main/assets/architecture.png'
# A timeout keeps the cell from hanging indefinitely on a stalled connection
response = requests.get(image_url, timeout=30)
# Stop here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
image.show()

In [None]:
# Draw the current state of the memory graph.
# NOTE(review): `ngm` is not created in any visible cell — confirm it is
# instantiated (and populated) before this cell runs.
ngm.visualize()

## 🔍 Retrieval

In [None]:
# Ask the memory for the nodes most similar to a free-text query.
query = "What happened with the dog?"
# No `top_k` is passed, so the method's own default applies.
nearest = ngm.retrieve_nearest(query)
print(f"Query: {query}\n\nTop Matches:")
# Each result is a (node id, score) pair; look the node up to show its details.
for nid, score in nearest:
    node = ngm.graph.nodes[nid]["data"]
    print(f" - Node {nid} ({node.modality}): {node.metadata} [Score: {score:.2f}]")

## 🔄 Temporal Traversal

In [None]:
# Walk the graph from the method's default start node and list each node's metadata.
order = ngm.traverse_temporal()
print("Traversal order:")
for nid in order:
    # The stored memory object lives under the node's "data" attribute
    node = ngm.graph.nodes[nid]["data"]
    print(f" - Node {nid}: {node.metadata}")

## 💾 Persistence

In [None]:
# Persist the memory graph under an explicit file name instead of the default path.
ngm.save("ngm_demo.pkl")
print("✅ Saved memory graph")

In [None]:
# Build a second memory object and restore the graph stored in ngm_demo.pkl into it.
ngm2 = NeuralGraphMemory()
ngm2.load("ngm_demo.pkl")
# len() of the graph is used as the node count in the message
print("✅ Loaded memory graph with", len(ngm2.graph), "nodes")

## 🧠 Multimodal Memory Demonstrations
This section simulates how an AI agent stores, embeds, and retrieves rich multimodal memories.
Each example below demonstrates a different modality or modality combination processed into graph memory.

### 🐕 Image + Caption

In [None]:

# Load and display multimodal memory image
image_url = "https://upload.wikimedia.org/wikipedia/commons/9/91/Golden_Retriever_Carlos_%2810581910556%29.jpg"
# Wikimedia's User-Agent policy asks scripts to identify themselves; requests
# sent with a generic client string may be refused.
response = requests.get(
    image_url,
    headers={"User-Agent": "NeuralGraphMemoryDemo/0.1 (https://github.com/StuckInTheNet/neural-graph-memory)"},
    timeout=30,  # do not hang indefinitely on a stalled connection
)
# Stop here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
image.show()

# Corresponding caption for visual memory
caption = "Saw a golden retriever chasing a ball in Central Park."
print(f"Caption: {caption}")


In [None]:

# Preprocess and embed image (inference only, so gradients are disabled)
inputs_image = clip_processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_features = clip_model.get_image_features(**inputs_image)
# Scale to unit L2 length along the last axis
image_embedding = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

# Preprocess and embed caption; the caption is wrapped in a one-element list
inputs_text = clip_processor(text=[caption], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = clip_model.get_text_features(**inputs_text)
# Scale to unit L2 length along the last axis
text_embedding = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

# Fuse embeddings with an element-wise mean; the result is not re-normalised,
# so it is generally shorter than unit length
fused_embedding = (image_embedding + text_embedding) / 2
# Drop size-1 dimensions and convert to a NumPy array
fused_embedding = fused_embedding.squeeze().numpy()


In [None]:

# Store the multimodal memory node
ngm.add_node(
    node_id="memory_golden_retriever_central_park",
    modality="image+text",
    content=caption,
    embedding=fused_embedding.tolist()
)


### 🧠 Text + Audio (Simulated Voice Note)

In [None]:
# Simulated transcription of a voice note (no audio is processed here; the
# text below stands in for the transcript)
voice_note = 'Remember to review the memory architecture notes after today’s meeting.'
print(f"Transcribed voice note: {voice_note}")

# Embed and normalize: CLIP text features scaled to unit L2 length
inputs_text = clip_processor(text=[voice_note], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = clip_model.get_text_features(**inputs_text)
text_embedding = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

# Store the transcript with its embedding flattened to a plain list.
# NOTE(review): `ngm` is not created in any visible cell — confirm that a memory
# object exposing `add_node(node_id, modality, content, embedding)` exists when
# this cell runs.
ngm.add_node(
    node_id="memory_audio_text_meeting_review",
    modality="text",
    content=voice_note,
    embedding=text_embedding.squeeze().tolist()
)

### 🗽 Image + Timestamp + Context

In [None]:
image_url = 'https://upload.wikimedia.org/wikipedia/commons/6/63/Times_Square%2C_New_York_City_%28HDR%29.jpg'
# Wikimedia's User-Agent policy asks scripts to identify themselves; requests
# sent with a generic client string may be refused.
response = requests.get(
    image_url,
    headers={"User-Agent": "NeuralGraphMemoryDemo/0.1 (https://github.com/StuckInTheNet/neural-graph-memory)"},
    timeout=30,  # do not hang indefinitely on a stalled connection
)
# Stop here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
image.show()

# Free-text note carrying the place, scene and time of the memory
contextual_note = 'Visited Times Square, lots of people, bright screens — 7:45pm, Jan 15'
print(f"Note: {contextual_note}")

# Get embeddings (inference only, so gradients are disabled)
inputs_image = clip_processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_features = clip_model.get_image_features(**inputs_image)

inputs_text = clip_processor(text=[contextual_note], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = clip_model.get_text_features(**inputs_text)

# Normalize & fuse: unit-length vectors averaged element-wise (the mean itself
# is not re-normalised)
image_embedding = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
text_embedding = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
fused_embedding = (image_embedding + text_embedding) / 2

# NOTE(review): `ngm` is not created in any visible cell — confirm that a memory
# object exposing `add_node(node_id, modality, content, embedding)` exists when
# this cell runs.
ngm.add_node(
    node_id="memory_times_square_evening",
    modality="image+text",
    content=contextual_note,
    embedding=fused_embedding.squeeze().tolist()
)

### 🔁 Diagram + Interpretation

In [None]:
image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/Transformer.png/600px-Transformer.png'
# Wikimedia's User-Agent policy asks scripts to identify themselves; requests
# sent with a generic client string may be refused.
response = requests.get(
    image_url,
    headers={"User-Agent": "NeuralGraphMemoryDemo/0.1 (https://github.com/StuckInTheNet/neural-graph-memory)"},
    timeout=30,  # do not hang indefinitely on a stalled connection
)
# Stop here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
image.show()

# Free-text reading of the diagram, stored alongside the image
interpretation = 'This shows the self-attention mechanism used in GPT-based models.'
print(f"Interpretation: {interpretation}")

# Embed both (inference only, so gradients are disabled)
inputs_image = clip_processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_features = clip_model.get_image_features(**inputs_image)

inputs_text = clip_processor(text=[interpretation], return_tensors="pt", padding=True)
with torch.no_grad():
    text_features = clip_model.get_text_features(**inputs_text)

# Scale each to unit L2 length, then average element-wise (the mean itself is
# not re-normalised)
image_embedding = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
text_embedding = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
fused_embedding = (image_embedding + text_embedding) / 2

# NOTE(review): `ngm` is not created in any visible cell — confirm that a memory
# object exposing `add_node(node_id, modality, content, embedding)` exists when
# this cell runs.
ngm.add_node(
    node_id="memory_transformer_diagram",
    modality="image+text",
    content=interpretation,
    embedding=fused_embedding.squeeze().tolist()
)