In [226]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
#import jax.tools.colab_tpu
#jax.tools.colab_tpu.setup_tpu()
from jax.scipy.signal import correlate

from functools import partial
#!pip install einops
from einops import rearrange, reduce, repeat
import matplotlib.pyplot as plt 
import numpy as onp
import jax.numpy as np
from jax.lax import scan
from jax import grad, jit, vmap, value_and_grad, lax
from jax import random
from jax.ops import index, index_add, index_update
# Generate key which is used to generate random numbers
key = random.PRNGKey(1)

In [272]:
class LSNN():
    def __init__(self, w_inp, w_rec, w_out, w_fb):
        self.w_inp = w_inp
        self.w_rec = w_rec
        self.w_out = w_out
        self.w_fb  = w_fb
        self.alpha = 0.95
        self.thr   = 0.1
        self.kappa = 0.95
        self.gamma = 0.95
        self.f0    = 12
        self.n_inp = 80
        self.n_rec = 100
        self.n_t   = 1000
        self.reg   = 12

        # Pre-comp
        self.alpha_conv = np.array([self.alpha ** (self.n_t - i - 1) for i in range(self.n_t)]).astype(float) # 1, 1, n_t
        self.kappa_conv = np.array([self.kappa ** (self.n_t - i - 1) for i in range(self.n_t)]).astype(float) # 1, 1, n_t

    @partial(jit, static_argnums=(0,))
    def calc_inp_trace(self, x, h):
        ''' Optimal implementation of e-trace per inp
            correlate in JAX/scipy is conv1d in PyTorch/TF.
            See https://discuss.pytorch.org/t/numpy-convolve-and-conv1d-in-pytorch/12172/4
        '''
        trace_in = repeat(vmap(correlate, in_axes=(0, None))(x, self.alpha_conv)[:,0:self.n_t], 'i t -> r i t', r=self.n_rec) # in, t
        trace_in = np.einsum('tr,rit->rit', h, trace_in)  # n_r, inp_dim, n_t
        trace_in = vmap(correlate, in_axes=(0, None))(trace_in.reshape(self.n_inp*self.n_rec, self.n_t), self.kappa_conv)[:,0:self.n_t].reshape(self.n_rec, self.n_inp, self.n_t)
        return trace_in

    @partial(jit, static_argnums=(0,))
    def calc_rec_trace(self, z, h):
        trace_rec = repeat(vmap(correlate, in_axes=(0, None))(z.T, self.alpha_conv)[:,0:self.n_t], 'i t -> r i t', r=self.n_rec) # in, t
        trace_rec = np.einsum('tr,rit->rit', h, trace_rec) # n_r, inp_dim, n_t
        trace_rec = vmap(correlate, in_axes=(0, None))(trace_rec.reshape(self.n_rec*self.n_rec, self.n_t), self.kappa_conv)[:,0:self.n_t].reshape(self.n_rec, self.n_rec, self.n_t)
        return trace_rec

    @partial(jit, static_argnums=(0,))
    def calc_out_trace(self, z):
        trace_out = vmap(correlate, in_axes=(0, None))(z.T, self.kappa_conv)[:,0:self.n_t]
        return trace_out

    @partial(jit, static_argnums=(0,))
    def calc_fr(self, z):
        fr = np.sum(z, axis=(0)) / (self.n_t * 1e-3) 
        reg_term = fr - self.f0
        return reg_term
    
    @partial(jit, static_argnums=(0,))
    def pseudo_der(self, v):
        return self.gamma * np.maximum(np.zeros_like(v), 1 - np.abs((v-self.thr)/self.thr))
    
    @partial(jit, static_argnums=(0,))
    def forward(self, x):

        # Reset diagonal
        rec_weight = (- self.w_rec * (np.eye(100) - 1)).T
        inp_weight = self.w_inp.T
        out_weight = self.w_out.T
        
        def f(carry, x):
            v_curr   = carry[0]
            z_curr   = carry[1]
            vo_curr  = carry[2]
            not_init = carry[3]

            v_next  = (self.alpha * v_curr + np.matmul(z_curr, rec_weight) + np.matmul(x, inp_weight) - z_curr * self.thr) * not_init
            z_next  = (v_next > self.thr).astype(np.float32) * not_init
            vo_next = (self.kappa * vo_curr + np.matmul(z_next, out_weight) ) * not_init

            carry = [v_next, z_next, vo_next,  True]
            y = [v_next, z_next, vo_next]
            return carry, y

        _, (v,z,vo) = scan(f, [np.zeros((100)), np.zeros((100)), np.zeros((1)), False], x.T)

        # Pseudo-derivative
        h = self.pseudo_der(v) # nt, nb, nrec

        # E-trace calculation
        trace_inp = self.calc_inp_trace(x, h)
        trace_rec = self.calc_rec_trace(z, h)
        trace_out = self.calc_out_trace(z)
        
        # Calc firing rate
        reg_term = self.calc_fr(z)
        return vo, trace_inp, trace_rec, trace_out, reg_term

    def acc_gradient(self, err, trace_in, trace_rec, trace_out, reg_term):
        L_loss = np.einsum('t o, o r -> r t', err, self.w_fb)
        L_reg  = repeat(reg_term, 'r -> r (t) ', t=self.n_t)

        L = L_loss + self.reg * L_reg

        w_in_grad  +=  np.clamp(-100, np.sum(np.einsum('xt,xyt->xyt', L, trace_in), axis=2), 100)
        w_rec_grad +=  np.clamp(-100, np.sum(np.einsum('xt,xyt->xyt', L, trace_rec), axis=2), 100)
        w_out_grad +=  np.clamp(-100, np.einsum('tbo,brt->or', err, trace_out), 100)
        return w_in_grad, w_rec_grad, w_out_grad

In [273]:
w_inp = random.normal(key, (100,80))
w_rec = random.normal(key, (100,100))
w_out = random.normal(key, (1,100))
w_fb  = random.normal(key, (1,100))

l = LSNN(w_inp, w_rec, w_out, w_fb)

x = (random.normal(key, (80,1000))>0)

In [274]:
import math
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

class Sinusoids(Dataset):
    def __init__(self, seq_length=1000, num_samples=2, num_inputs=80, input_freq=50):
        self.seq_length   = seq_length
        self.num_inputs   = num_inputs
        self.num_samples  = num_samples
        self.freq_list    = torch.tensor([1, 2, 3, 5]) # Hz
        self.dt           = 1e-3 # s
        self.t            = torch.arange(0, 1, self.dt) # s
        self.inp_freq     = input_freq

        # Random input
        self.x = (torch.rand(self.num_samples, self.num_inputs, self.seq_length) < self.dt * self.inp_freq).float()

        # Randomized output amplitude and phase
        amplitude_list = torch.FloatTensor(self.num_samples, len(self.freq_list)).uniform_(0.5, 2)
        phase_list = torch.FloatTensor(self.num_samples, len(self.freq_list)).uniform_(0, 2 * math.pi)

        # Normalized sum of sinusoids
        self.y = torch.zeros(self.num_samples, self.seq_length)
        for i in range(self.num_samples):
          summed_sinusoid = sum([amplitude_list[i, ix] * torch.sin(2*math.pi*f*self.t + phase_list[i, ix]) for ix, f in enumerate(self.freq_list)])
          self.y[i, :] = summed_sinusoid/torch.max(torch.abs(summed_sinusoid))

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx], self.y[idx]



sinusoid_dataset = Sinusoids(seq_length=1000, num_samples=5, num_inputs=80, input_freq=50)

# Parameters
train_percentage = 50
batch_size = 1

train_size = int(len(sinusoid_dataset) * train_percentage/100)
train_set, _ = random_split(sinusoid_dataset, [train_size, len(sinusoid_dataset)-train_size])

train_data = DataLoader(train_set, batch_size, shuffle=True)

In [275]:
epochs = 100

In [276]:
import numpy as onp
import time 
w_inp = random.normal(key, (100,80)) * 0.1
w_rec = random.normal(key, (100,100)) * 0.1
w_out = random.normal(key, (1,100)) * 0.1
w_fb  = random.normal(key, (1,100)) * 0.1

lsnn =  LSNN(w_inp, w_rec, w_out, w_fb)

mse_loss = nn.MSELoss()

for epoch in range(epochs):
    for _, (x, y) in enumerate(train_data):
        y = y.permute(1,0).unsqueeze(-1)
        x = np.array(onp.array(x.squeeze(0)))        
        y = np.array(onp.array(y))
        #t=time.time()
        yhat, trace_inp, trace_rec, trace_out, reg_term = lsnn.forward(x)
        #print(time.time()-t)
        #return trace_inp, trace_rec, trace_out, reg_term
        #trace_inp, trace_rec, trace_out, reg_term = lsnn.calculate_traces(x, h, z)
        lsnn.acc_gradient(yhat-y, trace_inp, trace_rec, trace_out, reg_term)
    if epoch%10 == 0:
        print(epoch)
    # Report


ValueError: Einstein sum subscript 't' does not contain the correct number of indices for operand 0.

In [293]:
from jax.scipy.signal import correlate

n_t = 1000
n_rec = 100 
n_b = 1
n_inp = 80

x = np.ones((80,1000))
h = np.ones((1000,100))

alpha_conv = np.array([0.95 ** (n_t - i - 1) for i in range(n_t)]) # 1, 1, n_t
kappa_conv = np.array([0.80 ** (n_t - i - 1) for i in range(n_t)]) # 1, 1, n_t
kappa_conv = index_update(kappa_conv, index[412],12.)
           
trace_in = repeat(vmap(correlate, in_axes=(0, None))(x, alpha_conv)[:,0:n_t], 'i t -> r i t', r=n_rec) # in, t
trace_in = np.einsum('tr,rit->rit', h, trace_in)  # n_r, inp_dim, n_t
trace_in = vmap(correlate, in_axes=(0, None))(trace_in.reshape(n_inp*n_rec, n_t), kappa_conv)[:,0:n_t].reshape(n_rec, n_inp, n_t)

%timeit vmap(correlate, in_axes=(0, None))(trace_in.reshape(n_inp*n_rec, n_t), kappa_conv)[:,0:n_t].reshape(n_rec, n_inp, n_t)

#plt.plot(trace_in[0,0,:])

17.9 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [285]:
import torch
import torch.nn.functional as F

x2 = torch.ones((1,80,1000))
h2 = torch.ones((n_t, n_b, n_rec))
alpha_conv2 =  torch.tensor([0.95 ** (n_t - i - 1) for i in range(n_t)]).float().view(1, 1, -1) # 1, 1, n_t
kappa_conv2 =  torch.tensor([0.80 ** (n_t - i - 1) for i in range(n_t)]).float().view(1, 1, -1) # 1, 1, n_t
kappa_conv2[0,0,412] = 12.

trace_in2 = F.conv1d(x2, alpha_conv2.expand( n_inp, -1, -1), padding= n_t, groups= n_inp)[:, :, 1: n_t + 1].unsqueeze(1).expand(-1,  n_rec, -1, -1) #n_b, n_rec, inp_dim, n_t
trace_in2 = torch.einsum('tbr,brit->brit', h2,  trace_in2)  # n_b, n_r, inp_dim, n_t
trace_in2 = F.conv1d( trace_in2.reshape( n_b,  n_inp *  n_rec,  n_t),kappa_conv2.expand( n_inp *  n_rec, -1, -1), padding= n_t, groups= n_inp *  n_rec)[:, :, 1: n_t + 1].reshape( n_b,  n_rec,  n_inp,  n_t)

%timeit F.conv1d( trace_in2.reshape( n_b,  n_inp *  n_rec,  n_t),kappa_conv2.expand( n_inp *  n_rec, -1, -1), padding= n_t, groups= n_inp *  n_rec)[:, :, 1: n_t + 1].reshape( n_b,  n_rec,  n_inp,  n_t)

#plt.plot(trace_in2[0,0,0,:])

1.93 s ± 96.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


(DeviceArray([[[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               ...,
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21]],
 
              [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
                1.7082366e-21, 2.1352967e-21, 2.5261975e-21],
               [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
               

In [120]:
np.exp(-1e-3 / 30e-3)

DeviceArray(0.96721613, dtype=float32)