<a href="https://colab.research.google.com/github/Leo-Lifeblood/Projects/blob/main/Example_Neural_Computer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Node(nn.Module):
    """Residual feed-forward unit that also produces a sigmoid "terminator" score.

    Every node carries a learnable ``key`` of shape ``(1, output_size)``.
    ``forward`` never reads ``key``; it is exposed for code outside the node.

    The batch-norm and linear layers are lazy, so their input sizes are
    inferred from the first tensor passed through them.

    Args:
        output_size: Width of the node output; the hidden layer is four
            times this wide.
    """

    def __init__(self, output_size):
        super().__init__()
        # Attribute order is preserved so parameter names, ``state_dict``
        # ordering and the printed module summary stay the same.
        self.key = nn.Parameter(torch.randn(1, output_size))
        self.norm = nn.LazyBatchNorm1d()
        self.fc = nn.LazyLinear(4 * output_size)
        self.fc2 = nn.LazyLinear(output_size)
        self.terminator = nn.LazyLinear(1)
        self.terminator_activation = nn.Sigmoid()
        self.activation = nn.SiLU()

    def forward(self, x):
        """Apply the block to ``x``.

        Args:
            x: Input tensor. It is added to an ``output_size``-wide tensor,
                so its feature width has to be compatible with that.

        Returns:
            A tuple ``(x + update, term)`` where ``update`` is the activated
            two-layer transform of the normalised input and ``term`` is the
            sigmoid of a single-unit projection of the hidden activation.
        """
        # Layers are invoked in the order fc -> fc2 -> terminator so the lazy
        # layers are materialised in the same sequence on the first call.
        hidden = self.activation(self.fc(self.norm(x)))
        update = self.activation(self.fc2(hidden))
        term = self.terminator_activation(self.terminator(hidden))
        return x + update, term



In [None]:
class Node(nn.Module):
    """One selectable processing unit: a residual MLP plus a halting signal.

    The learnable ``key`` (shape ``(1, output_size)``) is stored on the node
    but is not used by ``forward``.

    All layers with weights are lazy; their input widths are fixed by the
    first batch that flows through the node.

    Args:
        output_size: Number of features produced by the node. The
            intermediate layer has ``4 * output_size`` units.
    """

    def __init__(self, output_size):
        super().__init__()
        expanded_size = output_size * 4

        # Keep this registration order: it determines parameter naming,
        # ``state_dict`` ordering and the module's printed representation.
        self.key = nn.Parameter(torch.randn(1, output_size))
        self.norm = nn.LazyBatchNorm1d()
        self.fc = nn.LazyLinear(expanded_size)
        self.fc2 = nn.LazyLinear(output_size)
        self.terminator = nn.LazyLinear(1)
        self.terminator_activation = nn.Sigmoid()
        self.activation = nn.SiLU()

    def forward(self, x):
        """Return the residual output and the terminator score for ``x``.

        Args:
            x: Input tensor; it is summed with an ``output_size``-wide
                tensor, so the shapes must be compatible.

        Returns:
            ``(x + branch, gate)``: ``branch`` is norm -> fc -> SiLU -> fc2
            -> SiLU applied to ``x``; ``gate`` is the sigmoid of the
            ``terminator`` projection of the post-``fc`` activation.
        """
        normalised = self.norm(x)
        expanded = self.activation(self.fc(normalised))

        # fc2 runs before terminator, matching the order in which the lazy
        # layers are first initialised.
        branch = self.activation(self.fc2(expanded))
        gate = self.terminator_activation(self.terminator(expanded))

        return x + branch, gate

class NeuralNetwork(nn.Module):
    """Iteratively refine a state vector with a soft mixture of ``Node`` blocks.

    The input is projected into a *state* and an *action* vector. On every
    step the action vector is compared (dot product) with each node's
    ``key``; the softmax of those scores weights the outputs of all nodes
    applied to the current state. The weighted output is then fed through
    ``extender1`` / ``extender2`` and added to the state / action vectors.

    A running "terminator" value (product of the per-step weighted node
    terminator scores) controls how long the loop runs, see ``_iterate``.

    Args:
        num_options: Number of ``Node`` blocks in the reservoir.
        hidden_dim: Width of the state/action vectors and of every node.
    """

    # Hard upper bound on the number of refinement steps per call.
    _MAX_STEPS = 15
    # The loop keeps going while the number of counted steps is <= this value.
    _MAX_COUNTED_STEPS = 3
    # A step is counted when the batch mean of the running terminator exceeds this.
    _TERMINATOR_THRESHOLD = 0.1

    def __init__(self, num_options=6, hidden_dim=64):
        super().__init__()
        self.state = nn.LazyLinear(hidden_dim)
        self.action = nn.LazyLinear(hidden_dim)
        self.layer_reservoir = nn.ModuleList([Node(hidden_dim) for _ in range(num_options)])
        self.num_options = num_options
        self.hidden_dim = hidden_dim

        self.extender1 = nn.LazyLinear(hidden_dim)
        self.extender2 = nn.LazyLinear(hidden_dim)

    def forward(self, x):
        """Run the refinement loop on ``x`` and return the final state.

        Args:
            x: Input of shape ``(batch, features)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        return self._iterate(x, log_scores=False)

    def logging_forward(self, x):
        """Same computation as ``forward`` but print the routing weights each step.

        The weights are printed with shape ``(batch, 1, num_options)``.

        Args:
            x: Input of shape ``(batch, features)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        return self._iterate(x, log_scores=True)

    def _iterate(self, x, log_scores):
        """Shared implementation of ``forward`` and ``logging_forward``."""
        x_state = self.state(x)
        x_action = self.action(x)

        # new_ones keeps the accumulator on the same device and dtype as the
        # activations, so the model also works after .to(device) / .double().
        terminator = x_state.new_ones(x.shape[0], 1)

        # One key row per node -> (num_options, hidden_dim).
        heads = torch.cat([node.key for node in self.layer_reservoir], dim=0)

        # NOTE(review): scores are scaled by sqrt(num_options), not by the
        # key width as in standard scaled dot-product attention. Kept as is.
        scale = self.num_options ** 0.5

        end = 0
        step_count = 0

        while end <= self._MAX_COUNTED_STEPS and step_count < self._MAX_STEPS:
            # (batch, num_options) routing weights. Computed without any
            # squeeze so batch size 1 and num_options == 1 behave correctly.
            weights = torch.softmax((x_action @ heads.t()) / scale, dim=-1)

            if log_scores:
                print(weights.unsqueeze(1))

            node_outputs = [node(x_state) for node in self.layer_reservoir]
            terminators = torch.stack([term for _, term in node_outputs], dim=1)  # (B, N, 1)
            predictions = torch.stack([pred for pred, _ in node_outputs], dim=1)  # (B, N, H)

            mix = weights.unsqueeze(-1)  # (B, N, 1), broadcasts over features
            prediction = (predictions * mix).sum(dim=1)  # (B, H)
            term = (terminators * mix).sum(dim=1)  # (B, 1)

            # Out-of-place so no tensor recorded by autograd is mutated.
            terminator = terminator * term

            x_state = x_state + self.extender1(prediction)
            x_action = x_action + self.extender2(prediction)

            step_count += 1

            if terminator.mean() > self._TERMINATOR_THRESHOLD:
                end += 1

        return x_state





In [None]:
# Turn on autograd anomaly detection for the backward passes that follow.
# NOTE(review): PyTorch documents this mode as a debugging aid that slows
# execution; consider switching it off once training is known to work.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7db4a3ca4e50>

In [None]:
# Build the model with its default arguments (num_options=6, hidden_dim=64).
example = NeuralNetwork()

In [None]:
# Adam over all parameters of the model, learning rate 1e-3.
# NOTE(review): the model uses lazy layers and no data has been passed through
# it at this point. PyTorch's lazy-module documentation recommends a dry run
# before attaching an optimizer - confirm this ordering is intended.
optimizer = optim.Adam(example.parameters(), lr=0.001)

In [None]:
# Two random rows of 64 features; the logging passes below reuse this batch
# so the printed routing scores can be compared before and after training.
example_data = torch.randn(2, 64)

In [None]:
# Print the per-step routing scores for example_data with the model in eval
# mode. This is the first time data flows through the model, so the lazy
# layers get their concrete shapes here.
example.eval()
example.logging_forward(example_data)
# Back to training mode; as the last expression of the cell this also makes
# the notebook display the module summary.
example.train()

tensor([[[0.0031, 0.0046, 0.5644, 0.0080, 0.3446, 0.0753]],

        [[0.0501, 0.5371, 0.2608, 0.0573, 0.0613, 0.0332]]],
       grad_fn=<TransposeBackward0>)
tensor([[[0.0048, 0.0156, 0.3785, 0.0404, 0.4549, 0.1058]],

        [[0.2106, 0.6071, 0.0364, 0.0682, 0.0724, 0.0053]]],
       grad_fn=<TransposeBackward0>)
tensor([[[0.0046, 0.0795, 0.1308, 0.4704, 0.2859, 0.0289]],

        [[0.5558, 0.3604, 0.0049, 0.0435, 0.0342, 0.0012]]],
       grad_fn=<TransposeBackward0>)
tensor([[[1.7476e-04, 3.5297e-02, 1.9736e-03, 9.5894e-01, 3.5257e-03,
          9.0016e-05]],

        [[9.0791e-01, 7.2898e-02, 4.7926e-04, 1.1984e-02, 6.4962e-03,
          2.3455e-04]]], grad_fn=<TransposeBackward0>)
tensor([[[4.0653e-07, 3.2399e-03, 4.1318e-06, 9.9675e-01, 8.9683e-07,
          4.1396e-09]],

        [[9.9350e-01, 4.3677e-03, 3.2237e-05, 1.4163e-03, 6.6211e-04,
          2.6612e-05]]], grad_fn=<TransposeBackward0>)
tensor([[[2.2851e-11, 1.4482e-04, 3.7385e-09, 9.9986e-01, 3.3350e-12,
          2.9

NeuralNetwork(
  (state): Linear(in_features=64, out_features=64, bias=True)
  (action): Linear(in_features=64, out_features=64, bias=True)
  (layer_reservoir): ModuleList(
    (0-5): 6 x Node(
      (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (fc): Linear(in_features=64, out_features=256, bias=True)
      (fc2): Linear(in_features=256, out_features=64, bias=True)
      (terminator): Linear(in_features=256, out_features=1, bias=True)
      (terminator_activation): Sigmoid()
      (activation): SiLU()
    )
  )
  (extender1): Linear(in_features=64, out_features=64, bias=True)
  (extender2): Linear(in_features=64, out_features=64, bias=True)
)

In [None]:
num_examples = 2048
batch_size = 256

# Synthetic regression data: random inputs paired with independent random targets.
example_x = torch.randn(num_examples, 64)
example_y = torch.randn(num_examples, 64)

# Ten passes over the data in fixed, unshuffled mini-batches; the loss of
# every optimisation step is printed.
for epoch in range(10):
    for batch in range(0, num_examples, batch_size):
        rows = slice(batch, batch + batch_size)
        batch_x, batch_y = example_x[rows], example_y[rows]

        optimizer.zero_grad()
        output = example(batch_x)
        loss = F.mse_loss(output, batch_y)
        loss.backward()
        optimizer.step()

        print(loss.item())

111.86737823486328
109.95809936523438
110.28087615966797
108.80140686035156
112.84201049804688
103.87139892578125
108.24845886230469
107.38261413574219
100.46147155761719
98.75789642333984
99.04219818115234
98.05561065673828
101.63528442382812
93.55934143066406
97.55122375488281
96.94429779052734
90.48837280273438
89.01153564453125
89.25438690185547
88.72948455810547
91.90441131591797
84.5938491821289
88.26659393310547
87.88243865966797
81.86488342285156
80.5920181274414
80.76351928710938
80.64463806152344
83.4783706665039
76.81646728515625
80.2093505859375
80.04308319091797
74.39612579345703
73.30374908447266
73.3805160522461
73.6016616821289
76.15754699707031
70.02249145507812
73.16389465332031
73.23617553710938
67.90100860595703
66.95006561279297
66.9102554321289
67.4134521484375
69.7148208618164
64.03864288330078
66.95362854003906
67.2740249633789
62.18050765991211
61.365596771240234
61.17934036254883
61.921512603759766
64.00006103515625
58.7158203125
61.44645309448242
61.991462707

In [None]:
# Print the per-step routing scores for example_data again, with the model in
# eval mode, to see how the routing looks after the training loop.
example.eval()
example.logging_forward(example_data)
# Back to training mode; as the last expression of the cell this also makes
# the notebook display the module summary.
example.train()

tensor([[[8.1711e-04, 3.5442e-03, 9.2644e-01, 7.8048e-03, 3.7411e-02,
          2.3983e-02]],

        [[8.2024e-03, 4.8116e-01, 4.2774e-01, 4.0761e-02, 2.3652e-02,
          1.8481e-02]]], grad_fn=<TransposeBackward0>)
tensor([[[8.0713e-04, 2.6317e-02, 9.0784e-01, 4.2102e-02, 9.5266e-03,
          1.3407e-02]],

        [[1.4378e-02, 8.4126e-01, 9.4974e-02, 3.5612e-02, 1.1143e-02,
          2.6330e-03]]], grad_fn=<TransposeBackward0>)
tensor([[[2.5253e-04, 2.1797e-01, 4.9586e-01, 2.8407e-01, 6.7387e-04,
          1.1764e-03]],

        [[1.6084e-02, 9.3073e-01, 3.3705e-02, 1.6895e-02, 1.9873e-03,
          6.0294e-04]]], grad_fn=<TransposeBackward0>)
tensor([[[8.2765e-06, 3.7612e-01, 5.4226e-02, 5.6963e-01, 3.1949e-06,
          5.2411e-06]],

        [[2.0846e-02, 9.3585e-01, 3.4825e-02, 7.7912e-03, 3.3936e-04,
          3.5005e-04]]], grad_fn=<TransposeBackward0>)
tensor([[[7.8269e-08, 2.2780e-01, 3.6182e-03, 7.6859e-01, 1.6927e-09,
          2.6556e-09]],

        [[3.8427e-02, 8.4

NeuralNetwork(
  (state): Linear(in_features=64, out_features=64, bias=True)
  (action): Linear(in_features=64, out_features=64, bias=True)
  (layer_reservoir): ModuleList(
    (0-5): 6 x Node(
      (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (fc): Linear(in_features=64, out_features=256, bias=True)
      (fc2): Linear(in_features=256, out_features=64, bias=True)
      (terminator): Linear(in_features=256, out_features=1, bias=True)
      (terminator_activation): Sigmoid()
      (activation): SiLU()
    )
  )
  (extender1): Linear(in_features=64, out_features=64, bias=True)
  (extender2): Linear(in_features=64, out_features=64, bias=True)
)