In [59]:
import torch

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f100e5208f0>

In [60]:
class Layer:
    def __init__(
        self,
        num_in: int,
        num_out: int,
        threshold: float = 1,
        mem_decay: float = 0.99,
        ls_decay:float = 0,
        lr: float = 1e-5,
        device: str = "cuda",
    ) -> None:
        self.num_in = num_in
        self.num_out = num_out
        
        self.mem_decay = mem_decay
        self.ls_decay = ls_decay
        self.lr = lr
        
        self.w = torch.randn(num_in, num_out).to(device)
        self.mem = torch.zeros(num_out).to(self.w)
        self.threshold = torch.ones(num_out).to(self.w) * threshold
        
        self.ls = torch.zeros_like(self.mem)
        
        self.in_trace = torch.zeros(num_in).to(self.w)
    
    def update_ls(
        self,
        learning_signal: torch.Tensor,
    ) -> None:
        assert learning_signal.size() == (self.num_out,), f"Learning Signal update: expected size ({self.num_out},), got {learning_signal.size()} instead"
        learning_signal = learning_signal.to(self.ls)
        
        self.ls *= self.ls_decay
        self.ls += learning_signal
    
    def backward(
        self,
    ) -> torch.Tensor:
        ls = self.ls.unsqueeze(0)  # [1, out]
        ls = torch.broadcast_to(ls, (self.num_in, self.num_out))  # [in, out]
        
        input_tensor = self.in_trace.unsqueeze(-1)  # [in, 1]
        input_tensor = torch.broadcast_to(input_tensor, (self.num_in, self.num_out))  # [in, out]
        
        delta_w = ls * input_tensor  # d_ls/d_w = input
        delta_input = ls * self.w  # d_ls/d_in = weight
        
        self.w += delta_w * self.lr
        
        passed_ls = delta_input.sum(dim=-1)
        
        return passed_ls
    
    def forward(
        self,
        in_spikes: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert in_spikes.size() == (self.num_in,), f"Forward: expected size ({self.num_in},), got {in_spikes.size()} instead"
        in_spikes = in_spikes.to(self.w)
        
        back_pass_ls = self.backward()
        
        self.in_trace *= self.mem_decay
        self.in_trace += in_spikes
        
        current = in_spikes @ self.w
        self.mem *= self.mem_decay
        self.mem += current
        
        out_spikes = (self.mem >= self.threshold).float()
        
        self.mem -= out_spikes * self.threshold
        
        return out_spikes, back_pass_ls

n_in, n_out = 5, 2
test_layer = Layer(n_in, n_out)
in_spikes = torch.rand(n_in).round()
print(f"Incoming Spikes: {in_spikes}")
out_spikes, learning_signals = test_layer.forward(in_spikes)
print(f"Out Spikes: {out_spikes}")
print(f"Learning Signal: {learning_signals}")

Incoming Spikes: tensor([1., 1., 0., 1., 0.])
Out Spikes: tensor([0., 1.], device='cuda:0')
Learning Signal: tensor([0., 0., 0., 0., 0.], device='cuda:0')


In [None]:
class Network:
    def __init__(
        self,
    ) -> None:
        self.layers = [Layer(100,100) for _ in range(3)]
        self.layers[-1] = Layer(100,10)
    
    def forward(
        self,
        spike_input: torch.Tensor,
        expected_output: torch.Tensor,
    ) -> torch.Tensor:
        spike_input = spike_input.to(self.layers[0].w)
        expected_output = expected_output.to(self.layers[0].w)
        
        o, ls = self.layers[0].forward(spike_input)
        
        for i in range(1, len(self.layers)):
            o, ls = self.layers[i].forward(o)
            self.layers[i-1].update_ls(ls)
        
        ls = o - expected_output
        self.layers[-1].update_ls(ls)
        
        return ls

    def get_learning_signals(
        self,
    ) -> list[torch.Tensor]:
        ls = []        
        for layer in self.layers:
            ls.append(layer.ls)
        return ls

bpot = Network()

start_l1_weight = bpot.layers[0].w.detach().clone()

input_spikes = torch.rand(100).round()
expected_output = torch.rand(10).round()

num_timesteps = 10
for i in range(num_timesteps):
    ls = bpot.forward(input_spikes, expected_output)
    print(f"Timestep {i} - LS for last layer: {ls}, loss={ls.sum()}")


signals = bpot.get_learning_signals()
for index, signal in enumerate(signals):
    print(f"Signal for layer {index}: {signal[:5]}, loss={signal.sum()}")

end_l1_weight = bpot.layers[0].w.detach().clone()

print(f"Difference for training: {(start_l1_weight-end_l1_weight).sum()}")

Timestep 0 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0., -1.], device='cuda:0'), loss=2.0
Timestep 1 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.], device='cuda:0'), loss=3.0
Timestep 2 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.], device='cuda:0'), loss=3.0
Timestep 3 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0., -1.], device='cuda:0'), loss=2.0
Timestep 4 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.], device='cuda:0'), loss=3.0
Timestep 5 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.], device='cuda:0'), loss=3.0
Timestep 6 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1., -1.], device='cuda:0'), loss=3.0
Timestep 7 - LS for last layer: tensor([ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0., -1.], device='cuda:0'), loss=2.0
Timestep 8 - LS for last layer: tensor([ 1.,  0.,  0.,  