In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Hyper-parameters shared by the model definition and the copy-task script.

# Number of time steps in each generated input sequence.
SEQ_LEN = 10

# Feature width of one time step (a single 0/1 value per step).
INPUT_SIZE = 1

# External memory layout: number of slots and width of each slot.
MEMORY_SIZE = 128
MEMORY_DIM = 20

# Width of the controller's hidden layer.
CONTROLLER_HIDDEN_SIZE = 100

# Controller network
class Controller(nn.Module):
    """LSTM followed by a linear projection of its final time step.

    Args:
        input_size: Feature width of each time step fed to the LSTM.
        hidden_size: Width of the LSTM hidden state.
        output_size: Width of the projected output vector.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a batch-first sequence ``x`` to one output vector per batch item.

        The LSTM is run from its default (zero) initial state; only the hidden
        state of the last time step is projected and returned.
        """
        hidden_states = self.rnn(x)[0]
        final_step = hidden_states[:, -1]
        return self.fc(final_step)

# Memory module
class Memory(nn.Module):
    """External memory matrix with soft (weighted) read and erase/add write.

    The memory starts as a ``(memory_size, memory_dim)`` matrix of small random
    values. ``write`` and ``read`` accept either unbatched vectors or vectors
    with leading batch dimensions; after a batched write the stored memory
    carries the same leading batch dimensions (one memory matrix per batch
    item). Call ``reset`` to return to the initial, unbatched contents.

    Args:
        memory_size: Number of memory slots (rows).
        memory_dim: Width of each slot (columns).
    """

    def __init__(self, memory_size, memory_dim):
        super(Memory, self).__init__()
        initial = torch.randn(memory_size, memory_dim) * 0.01
        # Non-persistent buffers follow the module through .to()/.cuda() but
        # are left out of state_dict(), so the set of checkpoint keys does not
        # change.
        self.register_buffer("initial_memory", initial, persistent=False)
        self.register_buffer("memory", initial.clone(), persistent=False)

    def reset(self):
        """Restore the memory to its initial contents and drop any autograd history."""
        self.memory = self.initial_memory.clone()

    def read(self, address):
        """Return the address-weighted sum of memory rows.

        Args:
            address: Weights over slots, shape ``(..., memory_size)``.

        Returns:
            Tensor of shape ``(..., memory_dim)``.
        """
        # (..., 1, N) @ (..., N, M) -> (..., 1, M); drop the singleton row axis.
        return torch.matmul(address.unsqueeze(-2), self.memory).squeeze(-2)

    def write(self, address, erase_vector, add_vector):
        """Erase then add content at the addressed slots.

        Each slot ``i`` is updated as
        ``m_i * (1 - address_i * erase_vector) + address_i * add_vector``.

        Args:
            address: Weights over slots, shape ``(..., memory_size)``.
            erase_vector: Per-column erase strengths, shape ``(..., memory_dim)``.
            add_vector: Per-column content to add, shape ``(..., memory_dim)``.
        """
        # Outer product over the last axes only, so leading batch dimensions
        # are kept separate instead of being flattened into the slot axis.
        weights = address.unsqueeze(-1)
        erase_matrix = weights * erase_vector.unsqueeze(-2)
        add_matrix = weights * add_vector.unsqueeze(-2)
        self.memory = self.memory * (1 - erase_matrix) + add_matrix

# Read-Write head
class Head(nn.Module):
    """Combined write/read head driven by a controller output vector.

    Args:
        memory_size: Number of memory slots the head addresses.
        memory_dim: Width of each memory slot.
        controller_size: Width of the control vector passed to ``forward``.
            Defaults to the module-level ``CONTROLLER_HIDDEN_SIZE`` when
            omitted, which is the value the layers were previously hard-wired to.
    """

    def __init__(self, memory_size, memory_dim, controller_size=None):
        super(Head, self).__init__()
        if controller_size is None:
            # Looked up at call time so the default tracks the module constant.
            controller_size = CONTROLLER_HIDDEN_SIZE
        self.memory_size = memory_size
        self.memory_dim = memory_dim
        self.addressing = nn.Linear(controller_size, memory_size)
        self.erase = nn.Linear(controller_size, memory_dim)
        self.add = nn.Linear(controller_size, memory_dim)

    def forward(self, control_vector, memory):
        """Write to ``memory`` and then read from it using the same address weights.

        Args:
            control_vector: Controller output, last dimension ``controller_size``.
            memory: Object exposing ``write(address, erase_vector, add_vector)``
                and ``read(address)``.

        Returns:
            Whatever ``memory.read`` returns for the computed address weights.
        """
        # Softmax makes the address weights a distribution over slots.
        address_weights = torch.softmax(self.addressing(control_vector), dim=-1)
        # Erase strengths lie in (0, 1); added content lies in (-1, 1).
        erase_vector = torch.sigmoid(self.erase(control_vector))
        add_vector = torch.tanh(self.add(control_vector))
        # The write happens first, so the read already sees this step's update.
        memory.write(address_weights, erase_vector, add_vector)
        read_data = memory.read(address_weights)
        return read_data

# Neural Turing Machine model
class NTM(nn.Module):
    """Simplified Neural Turing Machine: controller, one write/read head, external memory.

    Args:
        input_size: Feature width of each input time step (also the output width).
        memory_size: Number of memory slots.
        memory_dim: Width of each memory slot.
        controller_hidden_size: Width of the controller's hidden layer and of
            the control vector it emits. NOTE: the head is constructed from
            ``memory_size`` and ``memory_dim`` only, so this value must match
            the controller width the head's layers are built for.
    """

    def __init__(self, input_size, memory_size, memory_dim, controller_hidden_size):
        super(NTM, self).__init__()
        self.memory_dim = memory_dim
        self.controller = Controller(input_size + memory_dim, controller_hidden_size, controller_hidden_size)
        self.memory = Memory(memory_size, memory_dim)
        self.head = Head(memory_size, memory_dim)
        self.fc = nn.Linear(controller_hidden_size + memory_dim, input_size)
        # Snapshot of the freshly initialised memory. Kept as a non-persistent
        # buffer so it follows .to()/.cuda() without adding a state_dict key.
        self.register_buffer(
            "initial_memory", self.memory.memory.detach().clone(), persistent=False
        )

    def forward(self, x):
        """Run the model over a batch-first sequence.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, input_size)`` with one output per
            time step.
        """
        batch_size, seq_len, _ = x.size()

        # Start every sequence from the same graph-free memory. Without this
        # the memory tensor keeps the autograd history of every earlier call,
        # which grows without bound and is what previously forced the read
        # vector to be detached.
        self.memory.memory = self.initial_memory.clone()

        # Allocate the initial read vector with the model's dtype and device
        # and with the memory width this instance was constructed with.
        memory_output = self.fc.weight.new_zeros(batch_size, self.memory_dim)
        outputs = []

        for t in range(seq_len):
            controller_input = torch.cat([x[:, t, :], memory_output], dim=-1)
            # NOTE: the controller sees a length-1 sequence each step, so its
            # LSTM state is not carried between time steps; the read vector is
            # the only recurrent signal.
            control_vector = self.controller(controller_input.unsqueeze(1))
            # The read vector stays attached to the graph so the head's
            # addressing/erase/add layers receive gradients.
            memory_output = self.head(control_vector, self.memory)
            output = self.fc(torch.cat([control_vector, memory_output], dim=-1))
            outputs.append(output)

        return torch.stack(outputs, dim=1)

# Generate input and target sequences for the copying task
def generate_copy_task_data(seq_len, batch_size=1):
    """Build one random batch for the copy task.

    Args:
        seq_len: Number of time steps per sequence.
        batch_size: Number of sequences in the batch.

    Returns:
        ``(input_seq, target_seq)``: two float tensors of shape
        ``(batch_size, seq_len, INPUT_SIZE)`` holding the same random 0/1
        values; the target is an independent copy of the input.
    """
    shape = (batch_size, seq_len, INPUT_SIZE)
    random_bits = torch.randint(low=0, high=2, size=shape)
    input_seq = random_bits.float()
    return input_seq, input_seq.clone()

# Training example
def train_ntm(epochs=100, lr=0.001, log_every=10):
    """Train a fresh NTM on randomly generated copy-task sequences.

    Each iteration draws one new random sequence, so an "epoch" here is a
    single optimisation step on a single sample.

    Args:
        epochs: Number of optimisation steps to run.
        lr: Learning rate passed to Adam.
        log_every: Print the loss and sequences every this many steps;
            ``0`` or ``None`` disables the periodic printout.

    Returns:
        ``(ntm, input_seq, target_seq)``: the trained model and the last
        sample it was trained on. The two sequences are ``None`` when
        ``epochs`` is 0.
    """
    ntm = NTM(INPUT_SIZE, MEMORY_SIZE, MEMORY_DIM, CONTROLLER_HIDDEN_SIZE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(ntm.parameters(), lr=lr)

    # Bound up front so the return statement is valid even if the loop never runs.
    input_seq = target_seq = None

    for epoch in range(epochs):
        input_seq, target_seq = generate_copy_task_data(SEQ_LEN)
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        output_seq = ntm(input_seq)  # Forward pass
        loss = criterion(output_seq, target_seq)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model parameters

        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')
            print("Input sequence:")
            print(input_seq.squeeze().numpy())
            print("Expected Output sequence:")
            print(target_seq.squeeze().numpy())
            print("NTM Output sequence:")
            print(output_seq.detach().squeeze().numpy())
            print("=" * 50)

    print("Training complete.")
    return ntm, input_seq, target_seq

# Run training and test the NTM
ntm, input_seq, target_seq = train_ntm()
# Evaluation only: run the forward pass without recording an autograd graph
# instead of building one and discarding it with .detach().
with torch.no_grad():
    output_seq = ntm(input_seq)

# Show final results
def display_final_results(input_seq, target_seq, output_seq):
    """Print the input, the expected output, and the model output.

    The model output is printed twice: first as raw values, then rounded to
    the nearest integer with the sign removed.

    Args:
        input_seq: Tensor holding the input sequence.
        target_seq: Tensor holding the expected output sequence.
        output_seq: Tensor holding the model's output sequence.
    """
    for heading, sequence in (
        ("\nFinal Input sequence:", input_seq),
        ("\nFinal Expected Output sequence:", target_seq),
    ):
        print(heading)
        print(sequence.squeeze().numpy())

    print("\nFinal NTM Output sequence:")
    raw_values = output_seq.detach()
    print(raw_values.squeeze().numpy())

    # Taking the absolute value after rounding turns -0. into 0. in the printout.
    rounded = torch.round(output_seq).abs()
    print(rounded.detach().squeeze().numpy())

# Print the last training sample next to the trained model's prediction for it.
display_final_results(
    input_seq=input_seq,
    target_seq=target_seq,
    output_seq=output_seq,
)

Epoch 0, Loss: 0.5868085622787476
Input sequence:
[1. 1. 0. 0. 0. 1. 1. 0. 0. 1.]
Expected Output sequence:
[1. 1. 0. 0. 0. 1. 1. 0. 0. 1.]
NTM Output sequence:
[-0.08042132 -0.08032616 -0.08483903 -0.08474657 -0.08465458 -0.07995401
 -0.07985984 -0.08437518 -0.08428452 -0.07958482]
Epoch 10, Loss: 0.5584716200828552
Input sequence:
[1. 1. 0. 0. 1. 1. 1. 1. 1. 1.]
Expected Output sequence:
[1. 1. 0. 0. 1. 1. 1. 1. 1. 1.]
NTM Output sequence:
[0.15709254 0.16711842 0.11566921 0.11582641 0.16758935 0.16773868
 0.16788743 0.16803558 0.16818315 0.16833013]
Epoch 20, Loss: 0.21968059241771698
Input sequence:
[1. 1. 1. 1. 1. 1. 0. 0. 1. 0.]
Expected Output sequence:
[1. 1. 1. 1. 1. 1. 0. 0. 1. 0.]
NTM Output sequence:
[0.44373828 0.5064663  0.5067223  0.5069742  0.5072251  0.50747496
 0.37834144 0.37858915 0.5082485  0.37908646]
Epoch 30, Loss: 0.26068374514579773
Input sequence:
[0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
Expected Output sequence:
[0. 0. 1. 0. 0. 1. 0. 0. 0. 0.]
NTM Output sequence:
[0