<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Differentiable_Neural_Computers_(DNC).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class DNC(nn.Module):
    """Simplified memory-augmented recurrent model (a toy DNC variant).

    An LSTM controller is run one time step at a time. At every step its input
    is the current sequence element concatenated with the mean of a single
    memory matrix shared by the whole batch. After the step, the first
    ``memory_dim`` controller features (averaged over the batch) are pushed onto
    the front of memory and the oldest row is dropped, so memory always keeps
    exactly ``memory_size`` rows.

    Memory is carried over between ``forward`` calls (it is never reset here)
    and is detached, so no gradient flows through it.

    Args:
        input_dim: Feature size of each input time step.
        hidden_dim: LSTM hidden size; must be >= ``memory_dim`` because the
            controller output is sliced to ``memory_dim`` for the memory write.
        memory_dim: Width of each memory row.
        memory_size: Number of memory rows (must be >= 1).
        output_dim: Feature size of each output time step.
        device: Device for the initial memory. ``None`` picks CUDA when it is
            available, else CPU. The memory is a registered buffer, so a later
            ``.to(...)`` on the module moves it as well.

    Raises:
        ValueError: If ``memory_size < 1`` or ``hidden_dim < memory_dim``.
    """

    def __init__(self, input_dim, hidden_dim, memory_dim, memory_size, output_dim, device=None):
        super().__init__()
        if memory_size < 1:
            raise ValueError(f"memory_size must be >= 1, got {memory_size}")
        if hidden_dim < memory_dim:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be >= memory_dim ({memory_dim}): "
                "the controller output is sliced to memory_dim for memory writes"
            )
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.controller = nn.LSTM(input_dim + memory_dim, hidden_dim)
        # Non-persistent buffer: moved by .to()/.cuda() with the module, but kept
        # out of state_dict so checkpoints have the same keys as before.
        self.register_buffer(
            'memory', torch.zeros(memory_size, memory_dim, device=device), persistent=False
        )
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.device = device

    def forward(self, x):
        """Run the controller over ``x`` and return per-step outputs.

        Args:
            x: Tensor indexed as ``(batch, seq_len, input_dim)`` (unpacked from
                ``x.size()`` below).

        Returns:
            Tensor of shape ``(batch, seq_len, output_dim)``.
        """
        batch_size, seq_len, _ = x.size()
        if seq_len == 0:
            return x.new_zeros(batch_size, 0, self.output_layer.out_features)

        # Initial controller state follows the input's device and dtype.
        hidden = x.new_zeros(1, batch_size, self.controller.hidden_size)
        cell = x.new_zeros(1, batch_size, self.controller.hidden_size)
        outputs = []

        for t in range(seq_len):
            # Memory is shared across the batch; broadcast its mean to every row.
            memory_mean = self.memory.mean(dim=0).unsqueeze(0).expand(batch_size, -1)
            input_and_memory = torch.cat([x[:, t], memory_mean], dim=1).unsqueeze(0)
            output, (hidden, cell) = self.controller(input_and_memory, (hidden, cell))
            step_output = output.squeeze(0)

            # Write ONE row (batch mean) so memory stays memory_size rows long;
            # writing all batch rows would grow memory by batch_size - 1 per step.
            write_row = step_output[:, :self.memory_dim].mean(dim=0, keepdim=True)
            # Build a new tensor instead of mutating in place, and detach it so
            # autograd history does not accumulate across steps and calls.
            self.memory = torch.cat([write_row, self.memory[:-1]], dim=0).detach()

            outputs.append(self.output_layer(step_output))

        return torch.stack(outputs, dim=1)

# Define model parameters
input_dim = 10
hidden_dim = 128  # must be >= memory_dim: controller output is sliced to memory_dim for memory writes
memory_dim = 32
memory_size = 50
output_dim = 10

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model, optimizer and loss
dnc = DNC(input_dim, hidden_dim, memory_dim, memory_size, output_dim, device).to(device)
optimizer = optim.Adam(dnc.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Anomaly detection is a debugging aid that slows down every backward pass;
# set this to True only while tracking down NaNs or in-place autograd errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Dummy data for training, allocated directly on the target device
x = torch.rand(32, 20, input_dim, device=device)
y = torch.rand(32, 20, output_dim, device=device)

# Training loop (full batch of dummy data each epoch).
# NOTE: dnc.memory is assigned inside forward and never reset, so each epoch
# starts from the memory left behind by the previous one.
dnc.train()
for epoch in range(50):
    optimizer.zero_grad()
    output = dnc(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")