In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Define the LongTermMemory module
class LongTermMemory(nn.Module):
    """Slot memory that is nudged toward its inputs on every forward pass.

    The memory is a ``[memory_dim, input_dim]`` parameter. Each call to
    ``forward`` applies, independently to every slot ``m_k``, the
    "surprise"-weighted update

        m_k += learn_rate * mean over (batch, seq) of ||x - m_k|| * (x - m_k)

    and then returns the parameter itself, so downstream consumers can still
    backpropagate into it.

    Args:
        input_dim: Feature size of each input vector (and of each slot).
        memory_dim: Number of memory slots.
        learn_rate: Scale of the per-call memory update.
        init_std: Standard deviation of the random initial slot values.
            Slots must start out different. With identical slots, the update
            above gives every slot the same increment, so the slots would
            never differentiate. Pass ``0.0`` for an all-zero start.

    Raises:
        ValueError: If ``init_std`` is negative.
    """

    def __init__(self, input_dim, memory_dim, learn_rate=0.1, init_std=0.02):
        super().__init__()
        if init_std < 0:
            raise ValueError(f"init_std must be non-negative, got {init_std}")
        if init_std > 0:
            initial = torch.randn(memory_dim, input_dim) * init_std
        else:
            initial = torch.zeros(memory_dim, input_dim)
        self.memory = nn.Parameter(initial)
        self.learn_rate = learn_rate

    def forward(self, x):
        """Update the memory from ``x`` and return the memory parameter.

        Args:
            x: Tensor of shape ``[batch, seq_length, input_dim]``.

        Returns:
            ``self.memory`` (shape ``[memory_dim, input_dim]``) after the update.
        """
        # The update is written through ``.data`` and never backpropagated.
        # Computing it under no_grad avoids recording an autograd graph over
        # the large [batch, seq_length, memory_dim, input_dim] intermediates.
        with torch.no_grad():
            # [batch, seq_length, 1, input_dim] - [memory_dim, input_dim]
            # broadcasts to [batch, seq_length, memory_dim, input_dim].
            diff = x.unsqueeze(2) - self.memory
            # [batch, seq_length, memory_dim, 1]
            surprise = torch.norm(diff, dim=-1, keepdim=True)
            update = self.learn_rate * surprise * diff
            # NOTE(review): the effective step is learn_rate * ||x - m||. That
            # exceeds 1 for large-norm inputs and can overshoot; consider
            # normalising if inputs are high-dimensional.
            # Writing via ``.data`` (not an in-place op on the parameter)
            # leaves the parameter's autograd version counter untouched.
            self.memory.data += update.mean(dim=(0, 1))
        return self.memory

# Define the Titans model
class Titans(nn.Module):
    """Attend from an input sequence over a self-updating long-term memory.

    Each forward pass first updates the memory from the input, then uses the
    memory slots as keys/values for multi-head attention with the input as
    queries, and finally projects to ``hidden_dim`` with a ReLU.

    Args:
        input_dim: Feature size of the input; also the attention embed dim.
            Must be divisible by ``num_heads``.
        hidden_dim: Output feature size.
        memory_dim: Number of memory slots.
        num_heads: Number of attention heads (previously fixed at 4).
        learn_rate: Update scale forwarded to the memory module.
    """

    def __init__(self, input_dim, hidden_dim, memory_dim, num_heads=4, learn_rate=0.1):
        super().__init__()
        self.memory_module = LongTermMemory(input_dim, memory_dim, learn_rate=learn_rate)
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.fc = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        """Run one pass: update the memory, attend over it, then project.

        Args:
            x: Tensor of shape ``[batch, seq_length, input_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_length, hidden_dim]``.
        """
        # The memory module updates itself and returns its memory parameter.
        memory = self.memory_module(x)  # [memory_dim, input_dim]
        mem = memory.unsqueeze(0).expand(x.size(0), -1, -1)  # [batch, memory_dim, input_dim]
        # The attention weights are discarded, so skip computing them.
        attn_out, _ = self.attention(query=x, key=mem, value=mem, need_weights=False)
        return F.relu(self.fc(attn_out))  # [batch, seq_length, hidden_dim]

def generate_sequence(seq_length=50, input_dim=128, batch_size=1):
    """Sample a random input sequence from a standard normal distribution.

    Args:
        seq_length: Number of time steps.
        input_dim: Feature size of each step.
        batch_size: Number of sequences.

    Returns:
        A tensor of shape ``[batch_size, seq_length, input_dim]``.
    """
    shape = (batch_size, seq_length, input_dim)
    return torch.randn(shape)

def run_model(input_dim=128, hidden_dim=256, memory_dim=10, seq_length=50, batch_size=1):
    """Build a ``Titans`` model and run one forward pass on a random sequence.

    ``input_dim`` is passed to both the model and the sequence generator, so
    the two cannot drift apart. The defaults reproduce the original
    hard-coded configuration.

    Returns:
        ``(model, output)``. ``output`` has shape
        ``[batch_size, seq_length, hidden_dim]``. The forward pass also
        updates the model's memory in place.
    """
    model = Titans(input_dim=input_dim, hidden_dim=hidden_dim, memory_dim=memory_dim)
    input_seq = generate_sequence(
        seq_length=seq_length, input_dim=input_dim, batch_size=batch_size
    )
    output = model(input_seq)
    return model, output

def visualize_memory(model):
    """Show the model's current memory contents as a heatmap.

    Rows of the image are memory slots and columns are feature dimensions,
    following the ``[memory_dim, input_dim]`` layout of the memory parameter.
    Draws with the pyplot state machine and then calls ``plt.show()``.
    """
    # Detach from autograd and move to host so matplotlib gets a NumPy array.
    state = model.memory_module.memory.detach().cpu().numpy()
    image = plt.imshow(state, aspect='auto', cmap='viridis')
    plt.colorbar(image, label='Memory Activation')
    plt.title('Titans Memory Module Activation')
    plt.xlabel('Feature Dimensions')
    plt.ylabel('Memory Slots')
    plt.show()


In [None]:
# Demo: one forward pass (which also updates the memory), then inspect the results.
model, output = run_model()
visualize_memory(model)
print('Output shape:', output.size())