In [3]:
import torch

SEED = 0

# torch.manual_seed seeds the CPU generator; seeding CUDA explicitly as well is
# harmless and is silently ignored when CUDA is unavailable.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# State updates below are done by hand (no autograd-based training is visible),
# so turn off gradient tracking globally. The trailing ';' suppresses the repr
# of the returned context-manager object in the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f123b1ffb10>

In [None]:
class Layer:
    """Fully connected layer of leaky integrate-and-fire neurons with a
    decaying learning-signal buffer.

    All state is stored in unbatched tensors:

    * ``w``         -- ``(num_in, num_out)`` weight matrix
    * ``mem``       -- ``(num_out,)`` membrane potentials
    * ``threshold`` -- ``(num_out,)`` firing thresholds
    * ``ls``        -- ``(num_out,)`` learning signal, decayed by ``ls_decay``
    * ``in_trace``  -- ``(num_in,)`` input-spike trace, decayed by ``mem_decay``
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        threshold: float = 1.0,
        mem_decay: float = 0.99,
        ls_decay: float = 0.0,
        device: str = "cuda",
    ) -> None:
        """
        Args:
            num_in: number of input (presynaptic) neurons.
            num_out: number of neurons in this layer.
            threshold: initial firing threshold for every neuron.
            mem_decay: per-step multiplicative decay of ``mem`` and ``in_trace``.
            ls_decay: per-step multiplicative decay of ``ls``; with 0 each
                ``update_ls`` call fully replaces the previous signal.
            device: torch device that holds all state tensors.
        """
        self.num_in = num_in
        self.num_out = num_out

        self.mem_decay = mem_decay
        self.ls_decay = ls_decay

        # Sample with the CPU generator and then move, so a given
        # torch.manual_seed yields the same initial weights on every device.
        self.w = torch.randn(num_in, num_out).to(device)
        self.mem = torch.zeros(num_out).to(self.w)
        self.threshold = torch.ones(num_out).to(self.w) * threshold

        self.ls = torch.zeros_like(self.mem)

        self.in_trace = torch.zeros(num_in).to(self.w)

    def update_ls(
        self,
        learning_signal: torch.Tensor,
    ) -> None:
        """Decay the stored learning signal and add a new one.

        Args:
            learning_signal: ``(num_out,)`` tensor; cast to ``ls``'s dtype/device.

        Raises:
            AssertionError: if ``learning_signal`` has the wrong shape.
        """
        assert learning_signal.size() == (self.num_out,), f"Learning Signal update: expected size ({self.num_out},), got {learning_signal.size()} instead"
        learning_signal = learning_signal.to(self.ls)

        self.ls *= self.ls_decay
        self.ls += learning_signal

    def backward(
        self,
    ) -> torch.Tensor:
        """Project the current learning signal back onto this layer's inputs.

        Returns:
            ``(num_in,)`` tensor ``w @ ls``: for every input neuron, its outgoing
            weights weighted by the learning signal of each target neuron.
            Presumably fed into the previous layer's ``update_ls``.
        """
        # Same as (ls[None, :] * w).sum(dim=-1), without building the
        # (num_in, num_out) intermediate.
        return self.w @ self.ls

    def compute_delta_w(self) -> torch.Tensor:
        """Weight update implied by the current input trace and learning signal.

        Returns:
            ``(num_in, num_out)`` tensor with
            ``delta_w[i, j] = in_trace[i] * ls[j]``.

        The layer does not apply this itself. The caller chooses the learning
        rate and sign (NOTE(review): the sign convention is not established here).
        """
        return torch.outer(self.in_trace, self.ls)

    def forward(
        self,
        in_spikes: torch.Tensor,
    ) -> torch.Tensor:
        """Advance the layer by one time step.

        Args:
            in_spikes: ``(num_in,)`` tensor of input spikes, cast to ``w``'s
                dtype/device (presumably 0/1 values).

        Returns:
            ``(num_out,)`` float tensor, 1.0 where a neuron fired this step.

        Raises:
            AssertionError: if ``in_spikes`` has the wrong shape.
        """
        assert in_spikes.size() == (self.num_in,), f"Forward: expected size ({self.num_in},), got {in_spikes.size()} instead"
        in_spikes = in_spikes.to(self.w)

        # Input trace consumed by compute_delta_w; decays at the membrane rate.
        self.in_trace *= self.mem_decay
        self.in_trace += in_spikes

        # Leaky integration of the synaptic input current.
        current = in_spikes @ self.w
        self.mem *= self.mem_decay
        self.mem += current

        out_spikes = (self.mem >= self.threshold).float()

        # Soft reset: subtract the threshold rather than zeroing, so any
        # overshoot above threshold carries into the next step.
        self.mem -= out_spikes * self.threshold

        return out_spikes

# Fall back to CPU so the cell also runs on kernels without a GPU; Layer samples
# its weights on the CPU generator, so the values do not depend on the device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

n_in, n_out = 10, 10
test_layer = Layer(n_in, n_out, device=DEVICE)
in_spikes = torch.rand(n_in).round()  # each input spikes with probability ~0.5
out = test_layer.forward(in_spikes)

# Sanity checks: one binary spike value per output neuron.
assert out.shape == (n_out,)
assert ((out == 0) | (out == 1)).all()
out

tensor([0., 1., 1., 0., 1., 0., 0., 1., 1., 1.], device='cuda:0')
