In [None]:
# Memory-Augmented Neural Network Demo
This notebook demonstrates the functionality of a memory-augmented neural network that can store and retrieve information from a memory bank, implement attention-based memory access, calculate confidence scores for retrievals, and demonstrate learning over sequential data.


In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("C:/Users/sophy/graphfusion-challenge")
from model.memory import MemoryBank, ConfidenceScoring
from data.data_generator import generate_sequence_data

In [2]:
# Hyperparameters
seq_length = 10        # Length of sequences
feature_dim = 8        # Dimension of features
batch_size = 16        # Number of sequences in each batch
memory_size = 32       # Number of memory slots
learning_rate = 0.001  # Learning rate
num_epochs = 5         # Number of training epochs


In [3]:
# Generate sequential data
input_sequences, target_sequences = generate_sequence_data(seq_length, feature_dim, batch_size)

print("Input Sequences Shape:", input_sequences.shape)
print("Target Sequences Shape:", target_sequences.shape)


Input Sequences Shape: torch.Size([16, 10, 8])
Target Sequences Shape: torch.Size([16, 10, 8])


In [4]:
# Initialize Memory Bank and Confidence Scoring
memory_bank = MemoryBank(memory_size=memory_size, feature_dim=feature_dim)
confidence_scorer = ConfidenceScoring(feature_dim)
# After initializing MemoryBank and ConfidenceScoring
optimizer = torch.optim.Adam(
    list(memory_bank.parameters()) + list(confidence_scorer.parameters()),
    lr=learning_rate
)



In [5]:
# Define a simple Mean Squared Error loss function
loss_fn = torch.nn.MSELoss()

# Training loop
for epoch in range(num_epochs):
    for i in range(batch_size):
        # Extract the appropriate input sequence for the current index
        input_sequence = input_sequences[i]  # Shape (seq_length, feature_dim)
        
        # Reshape input_sequence to match the expected input data shape (1, feature_dim)
        input_sequence = input_sequence[-1].unsqueeze(0)  # Use the last timestep, shape (1, feature_dim)
        
        # Write to memory
        write_weights = torch.ones(1, memory_size) / memory_size  # Shape (1, memory_size)
        memory_bank.write(input_sequence, write_weights)
        
        # Retrieve data from memory
        retrieved_memory, _ = memory_bank.read(input_sequence)
        
        # Calculate loss
        loss = loss_fn(retrieved_memory, target_sequences[i])
        
        # Backpropagation (assuming optimizer is defined elsewhere)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/5, Loss: 0.1684
Epoch 2/5, Loss: 0.0838
Epoch 3/5, Loss: 0.1475
Epoch 4/5, Loss: 0.3403
Epoch 5/5, Loss: 0.6492


In [6]:
# Example of storing and retrieving data
test_input = input_sequences[0]  # Use the first input sequence as a test

# Print shape for debugging
print("Original Test Input Shape:", test_input.shape)

# Ensure write_weights has the correct shape for batch processing
write_weights = (torch.ones(test_input.size(0), memory_size) / memory_size)  # Shape (10, memory_size)

# Write to memory
memory_bank.write(test_input, write_weights)  # Ensure test_input is 2D (10, feature_dim)
retrieved_memory, _ = memory_bank.read(test_input)  # Ensure test_input is 2D (10, feature_dim)

print("Test Input:", test_input)
print("Retrieved Memory:", retrieved_memory)


Original Test Input Shape: torch.Size([10, 8])
Test Input: tensor([[0.9660, 0.0498, 0.4853, 0.7602, 0.0199, 0.6989, 0.8146, 0.5031],
        [0.3014, 0.4729, 0.9222, 0.6524, 0.1770, 0.6331, 0.4730, 0.0851],
        [0.9899, 0.7812, 0.8622, 0.6188, 0.1208, 0.7936, 0.4908, 0.0949],
        [0.6297, 0.8746, 0.2662, 0.6311, 0.3014, 0.3592, 0.0244, 0.9470],
        [0.3352, 0.1314, 0.9045, 0.8858, 0.6027, 0.6463, 0.4954, 0.1524],
        [0.0314, 0.3477, 0.9629, 0.2460, 0.9476, 0.9162, 0.0715, 0.1833],
        [0.8146, 0.8403, 0.0434, 0.5029, 0.0247, 0.8999, 0.8866, 0.9748],
        [0.7497, 0.3686, 0.7897, 0.7560, 0.1014, 0.5805, 0.9760, 0.0107],
        [0.3933, 0.4537, 0.3019, 0.4786, 0.5196, 0.3242, 0.0339, 0.5363],
        [0.9830, 0.9865, 0.0206, 0.5675, 0.4898, 0.0313, 0.6223, 0.2567]])
Retrieved Memory: tensor([[1.8219, 1.2503, 1.0000, 1.6300, 1.3497, 1.2919, 1.2620, 1.5660],
        [1.8219, 1.2503, 1.0000, 1.6300, 1.3497, 1.2919, 1.2620, 1.5660],
        [1.8219, 1.2503, 1.0000, 1