<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/AGI_Prototype_Implementation_(PyTorch).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DifferentiableMemory(nn.Module):
    """
    A simple Neural Turing Machine–style memory module with
    content-based read/write heads.

    The memory matrix has shape [memory_size × vector_dim] and lives in the
    ``memory`` buffer.  It is stateful: ``write`` replaces it and ``read``
    consults it.  Every update rebinds the buffer to a *new* tensor instead
    of mutating the old one in place, so tensors that autograd saved during
    earlier reads/writes stay valid and ``backward()`` works across several
    write/read steps.

    Args:
        memory_size: number of memory rows (M).
        vector_dim: width of each memory row (D).
        controller_dim: size of the controller state passed to
            ``read`` / ``write``.
        num_read_heads: number of read heads.
        num_write_heads: number of write heads.
    """

    def __init__(self,
                 memory_size: int,
                 vector_dim: int,
                 controller_dim: int,
                 num_read_heads: int = 1,
                 num_write_heads: int = 1):
        super().__init__()
        self.memory_size = memory_size
        self.vector_dim = vector_dim
        self.controller_dim = controller_dim
        self.num_read_heads = num_read_heads
        self.num_write_heads = num_write_heads

        # Initialize memory to zeros (stateful, replaced by write())
        self.register_buffer('memory',
                             torch.zeros(memory_size, vector_dim))

        # Read‐key projection: controller_state → (num_read_heads × vector_dim)
        self.read_key_proj = nn.Linear(controller_dim,
                                       num_read_heads * vector_dim)

        # Write‐key projection: controller_state → (num_write_heads × vector_dim)
        self.write_key_proj = nn.Linear(controller_dim,
                                        num_write_heads * vector_dim)

        # Erase/Add/Gate projections for writes
        self.erase_proj = nn.Linear(controller_dim,
                                    num_write_heads * vector_dim)
        self.add_proj = nn.Linear(controller_dim,
                                  num_write_heads * vector_dim)
        self.write_gate_proj = nn.Linear(controller_dim,
                                         num_write_heads)

    def reset_memory(self):
        """Zero out memory before each new sequence (if desired)."""
        # Rebind to a fresh zero tensor (same shape/dtype/device) instead of
        # zero_(): the new tensor carries no autograd history from the
        # previous sequence, and the old tensor is left untouched for any
        # graph that still references it.
        self.memory = torch.zeros_like(self.memory)

    def content_addressing(self, keys: torch.Tensor) -> torch.Tensor:
        """
        Given keys of shape (heads, D), returns a weight matrix
        (heads × memory_size) via cosine‐similarity + softmax.
        Each row of the result sums to 1.
        """
        # Normalize rows to unit length so the dot product is a cosine
        mem_norm = F.normalize(self.memory, dim=1)          # [M × D]
        key_norm = F.normalize(keys, dim=1)                 # [H × D]

        # similarity[h, i] = key_norm[h] ⋅ mem_norm[i]
        similarity = torch.matmul(key_norm, mem_norm.t())   # [H × M]
        weights = F.softmax(similarity, dim=1)              # [H × M]
        return weights

    def read(self, controller_state: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        controller_state: [controller_dim]
        returns a tuple (read_vector, weights):
          read_vector: [heads × D] flattened → [heads*D]
          weights:     [heads × M] addressing weights used for the read
        """
        # 1) compute keys
        rk = self.read_key_proj(controller_state)                      # [H*D]
        rk = rk.view(self.num_read_heads, self.vector_dim)            # [H × D]

        # 2) content addressing
        w = self.content_addressing(rk)                               # [H × M]

        # 3) weighted sum of memory rows
        read_vecs = torch.matmul(w, self.memory)                      # [H × D]
        return read_vecs.view(-1), w                                  # ([H*D], [H × M])

    def write(self, controller_state: torch.Tensor):
        """
        controller_state: [controller_dim]
        Replaces self.memory with the updated matrix, applying per head, in order:
          m ← m * (1 – g * wᵀ erase) + g * wᵀ add
        where w ∈ ℝ^{H×M}, erase/add ∈ ℝ^{H×D}, g ∈ ℝ^{H}
        Returns nothing.
        """
        # 1) compute write keys & weights
        wk = self.write_key_proj(controller_state)                    # [H*D]
        wk = wk.view(self.num_write_heads, self.vector_dim)          # [H × D]
        w = self.content_addressing(wk)                               # [H × M]

        # 2) gates & erase/add
        erase = torch.sigmoid(self.erase_proj(controller_state))      # [H*D]
        erase = erase.view(self.num_write_heads, self.vector_dim)    # [H × D]

        add = torch.tanh(self.add_proj(controller_state))             # [H*D]
        add = add.view(self.num_write_heads, self.vector_dim)        # [H × D]

        g = torch.sigmoid(self.write_gate_proj(controller_state))     # [H]

        # 3) update memory
        # For each head h: m ← m * (1 - g_h * w_h.unsqueeze(1) * erase_h.unsqueeze(0))
        #                  + g_h * w_h.unsqueeze(1) * add_h.unsqueeze(0)
        # Every operation below is out-of-place, so no defensive clone is needed.
        updated = self.memory
        for h in range(self.num_write_heads):
            w_h = w[h].unsqueeze(1)       # [M × 1]
            e_h = erase[h].unsqueeze(0)   # [1 × D]
            a_h = add[h].unsqueeze(0)     # [1 × D]
            gate = g[h]

            updated = updated * (1 - gate * (w_h * e_h)) + gate * (w_h * a_h)

        # Rebind the buffer rather than copy_() into it: an in-place copy
        # would invalidate the memory tensor that earlier read()/addressing
        # calls saved for backward, and backward() over more than one step
        # would then raise an in-place-modification error.
        self.memory = updated


class AGI_Core(nn.Module):
    """
    An LSTM‐based controller that reads from and writes to
    a DifferentiableMemory at each time step.

    Args:
        input_dim: feature size of each input step.
        hidden_dim: LSTM hidden size; also the controller state size
            handed to the memory module.
        output_dim: feature size of each output step.
        memory_size: number of memory rows.
        vector_dim: width of each memory row.
        num_read_heads: number of read heads.
        num_write_heads: number of write heads.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int,
                 memory_size: int,
                 vector_dim: int,
                 num_read_heads: int = 1,
                 num_write_heads: int = 1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_read_heads = num_read_heads
        self.vector_dim = vector_dim

        # Memory module
        self.memory = DifferentiableMemory(
            memory_size,
            vector_dim,
            controller_dim=hidden_dim,
            num_read_heads=num_read_heads,
            num_write_heads=num_write_heads,
        )

        # Controller: LSTM expects [input + read_vec] per time step
        self.lstm = nn.LSTM(input_dim + num_read_heads * vector_dim,
                            hidden_dim,
                            batch_first=True)

        # Final output head
        self.fc = nn.Linear(hidden_dim + num_read_heads * vector_dim,
                            output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch=1, seq_len, input_dim]
        returns:
          outputs: [1, seq_len, output_dim]

        Memory and the LSTM state are reset at the start of every call, so
        each call processes one independent sequence.
        """
        batch_size, seq_len, _ = x.size()
        assert batch_size == 1, "Current impl assumes batch size = 1"

        # Initial states follow the module's parameter dtype (not the global
        # default) so the model still runs after .double() / .half().
        state_dtype = self.fc.weight.dtype

        # reset memory & controller states
        self.memory.reset_memory()
        h = torch.zeros(1, 1, self.hidden_dim,
                        device=x.device, dtype=state_dtype)
        c = torch.zeros(1, 1, self.hidden_dim,
                        device=x.device, dtype=state_dtype)

        # initial read vector = zeros
        read_vec = torch.zeros(1, self.num_read_heads * self.vector_dim,
                               device=x.device, dtype=state_dtype)

        outputs = []
        for t in range(seq_len):
            xi = x[:, t, :]                                      # [1 × input_dim]
            inp = torch.cat([xi, read_vec], dim=-1).unsqueeze(1) # [1,1,input+read]

            # Only the new (h, c) are needed; the per-step output equals h here.
            _, (h, c) = self.lstm(inp, (h, c))                   # h, c: [1,1,hidden]
            controller_state = h.squeeze(0).squeeze(0)           # [hidden_dim]

            # write then read, so the read already sees this step's write
            self.memory.write(controller_state)
            read_vec, _ = self.memory.read(controller_state)     # ([read_dim], [H×M])
            read_vec = read_vec.unsqueeze(0)                     # [1 × read_dim]

            # final output uses both controller_state and read_vec
            fc_in = torch.cat([controller_state.unsqueeze(0), read_vec],
                              dim=-1)                            # [1, hidden+read]
            out = self.fc(fc_in).unsqueeze(1)                    # [1,1,output_dim]
            outputs.append(out)

        return torch.cat(outputs, dim=1)


if __name__ == "__main__":
    # Demo: build a small model and push one random sequence through it.

    # Controller / I/O sizes
    input_dim, hidden_dim, output_dim = 10, 20, 5

    # External memory layout
    memory_size, vector_dim = 100, 16
    num_read_heads = num_write_heads = 1

    agi = AGI_Core(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        memory_size=memory_size,
        vector_dim=vector_dim,
        num_read_heads=num_read_heads,
        num_write_heads=num_write_heads,
    )

    # One random sequence of three steps: [batch=1, seq_len=3, input_dim]
    sample = torch.randn(1, 3, input_dim)
    out = agi(sample)

    print("AGI Output Shape:", out.shape)  # torch.Size([1, 3, 5])
    print(out)