In [3]:
import torch 
import torch.nn as nn
import numpy as np
from torch.distributions.categorical import Categorical


In [4]:
n = 128
m = 64
k = 3
o = 3

gamma = 0.1
tau = 0.01

In [5]:
C = nn.Parameter(torch.Tensor(n * n))
D = nn.Parameter(torch.Tensor(n * n, m))
F = nn.Parameter(torch.Tensor(m))
H = nn.Parameter(torch.Tensor(m, n * n))

W_in_1 = nn.Parameter(torch.Tensor(n, k))
W_in_2 = nn.Parameter(torch.Tensor(m, k))

with torch.no_grad():
    C.normal_(std = 1. / np.sqrt(n * n))
    D.normal_(std = 1. / np.sqrt(n * n))
    F.normal_(std = 1. / np.sqrt(m))
    H.normal_(std = 1. / np.sqrt(m))
    W_in_1.uniform_(-1. / np.sqrt(k), 1. / np.sqrt(k))
    W_in_2.uniform_(-1. / np.sqrt(k), 1. / np.sqrt(k))

In [29]:
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, neur_cnt, astr_cnt, in_size, out_size, gamma=0.1, tau=0.01):
        super().__init__()
        self.n = neur_cnt
        self.m = astr_cnt
        self.k = in_size
        self.o = out_size
        
        self.gamma = gamma
        self.tau = tau

        self.C = nn.Parameter(torch.Tensor(self.n * self.n))
        self.D = nn.Parameter(torch.Tensor(self.n * self.n, self.m))
        self.F = nn.Parameter(torch.Tensor(self.m))
        self.H = nn.Parameter(torch.Tensor(self.m, self.n * self.n))

        self.W_in_1 = nn.Parameter(torch.Tensor(self.n, self.k))
        self.W_in_2 = nn.Parameter(torch.Tensor(self.m, self.k))


        with torch.no_grad():
            self.C.normal_(std = 1. / np.sqrt(self.n * self.n))
            self.D.normal_(std = 1. / np.sqrt(self.n * self.n))
            self.F.normal_(std = 1. / np.sqrt(self.m))
            self.H.normal_(std = 1. / np.sqrt(self.m))
            self.W_in_1.uniform_(-1. / np.sqrt(self.k), 1. / np.sqrt(self.k))
            self.W_in_2.uniform_(-1. / np.sqrt(self.k), 1. / np.sqrt(self.k))
        

    def phi(self, x):
        return torch.sigmoid(x)

    def Phi(self, x):
        return (self.phi(x) @ self.phi(x).reshape(1, -1)).reshape(-1, 1)

    def psi(self, z):
        return torch.tanh(z)


    def forward(self, I, hidden=None):

        if hidden is None:
            x, W, z =  (torch.zeros(self.n, 1),
                        torch.zeros(self.n, self.n),
                        torch.zeros(self.m, 1))
        else :
            x, W, z = hidden
        
        x = (1 - self.gamma) * x + self.gamma * W @ self.phi(x) + (self.W_in_1 @ I)
        W = (1. - self.gamma) * W + self.gamma * (torch.diag(self.C) @ self.Phi(x) + self.D @ self.psi(z)).reshape(self.n, self.n)
        z = (1. - self.gamma * self.tau) * z + self.gamma * self.tau * (torch.diag(self.F) @ self.psi(z) + self.H @ self.Phi(x) + self.W_in_2 @ I)

        hidden = (x, W, z)
        
        return x, hidden
    


M = Model(neur_cnt=128, astr_cnt=64, in_size=2, out_size=3)

I = torch.ones((2, 1))
state = None
M.forward(I, state)

(tensor([[ 0.4506],
         [-0.3785],
         [-1.2354],
         [ 0.0955],
         [-0.6702],
         [-0.3804],
         [ 1.1943],
         [ 0.8052],
         [-0.3001],
         [-0.0360],
         [-0.1137],
         [-0.1782],
         [ 1.1300],
         [-0.6214],
         [ 0.2712],
         [-0.2148],
         [-0.3060],
         [-0.2194],
         [ 0.5025],
         [-0.1934],
         [ 0.9057],
         [-0.3288],
         [-0.6845],
         [-0.1833],
         [-0.5975],
         [ 0.3016],
         [-0.8432],
         [ 0.3582],
         [ 0.5525],
         [ 0.0683],
         [ 0.0277],
         [ 0.6777],
         [-0.3001],
         [-1.1370],
         [ 1.0836],
         [-0.5642],
         [-0.8647],
         [ 1.0996],
         [-0.3495],
         [ 1.0084],
         [ 1.2731],
         [-0.6271],
         [-1.0080],
         [ 0.3298],
         [ 1.0204],
         [-0.4690],
         [ 0.4809],
         [-0.2664],
         [ 0.7098],
         [ 0.9914],
