In [None]:
import _pickle as cPickle
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Load the whole pickle once at notebook level; later cells index it by split
# name (data['train'], data['valid'], data['test']).
# NOTE(review): unpickling can execute arbitrary code -- only load a trusted copy.
with open('JSB Chorales.pickle', 'rb') as pickle_file:
    data = cPickle.load(pickle_file)

In [None]:
def convert_binary(piano_keys, lowest_key=21, num_keys=88):
  """Encode the keys sounding at one time step as a multi-hot vector.

  Args:
    piano_keys: iterable of integer key numbers (presumably MIDI pitches,
      where 21 would be the lowest piano key -- confirm against the data).
      May be empty, which yields an all-zero vector.
    lowest_key: key number mapped to index 0. Defaults to 21, the offset the
      encoding has always used.
    num_keys: length of the returned vector. Defaults to 88.

  Returns:
    A float tensor of shape (num_keys,) holding 1 at `key - lowest_key` for
    every key in `piano_keys` and 0 elsewhere.

  Raises:
    IndexError: if a key lies outside [lowest_key, lowest_key + num_keys).
  """
  piano_binary = torch.zeros(num_keys)

  for key in piano_keys:
    index = key - lowest_key
    # Without this check a key below `lowest_key` produces a negative index,
    # which tensor indexing silently wraps to the top of the keyboard.
    if not 0 <= index < num_keys:
      raise IndexError(
          f'key {key} is outside the supported range '
          f'[{lowest_key}, {lowest_key + num_keys - 1}]')
    piano_binary[index] = 1

  return piano_binary

In [None]:
class JSB_Chorales(Dataset):
    """One split of the JSB Chorales pickle as zero-padded next-step pairs.

    Every item is padded to the length of the longest sequence in the split,
    so all items of a split share one shape and can be batched directly.
    """

    def __init__(self, pickle_file_path='filepath', type_data='string'):
        """Load one split from the pickle.

        Args:
            pickle_file_path: path of the pickle file; it must unpickle to a
                mapping from split name to a list of sequences.
            type_data: key of the split to expose (the notebook uses 'train',
                'valid' and 'test').
        """
        # NOTE(review): unpickling can execute arbitrary code -- only point
        # this at a trusted file.
        with open(pickle_file_path, 'rb') as p:
            data = cPickle.load(p, encoding="latin1")

        self.sequences = data[type_data]
        self.lengths = [len(sentence) for sentence in self.sequences]
        # Padded length shared by every item. Computed once here rather than
        # on each __getitem__ call, which rescanned all lengths per item.
        self.max_length = max(self.lengths, default=0)

    def __len__(self):
        """Return the number of sequences in the split."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return `(inp, label, length)` for the sequence at `index`.

        Returns:
            inp: float tensor (max_length, 88); row t is the multi-hot
                encoding of time step t, rows past the sequence end are zero.
            label: float tensor (max_length, 88) with label[t] == inp[t + 1];
                its last row is zero.
            length: unpadded number of time steps of this sequence. Note that
                label[length - 1] is all zeros, because there is no following
                step to predict.
        """
        seq = self.sequences[index]
        inp = torch.zeros(self.max_length, 88)
        for step, keys in enumerate(seq):
            inp[step, :] = convert_binary(keys)

        # Targets are the inputs shifted one step towards the past.
        label = torch.zeros(self.max_length, 88)
        label[:-1, :] = inp[1:, :]

        return inp, label, self.lengths[index]

In [None]:
# All three splits are read from the same pickle and served as one batch each.
_PICKLE_PATH = 'JSB Chorales.pickle'


def _make_full_batch_loader(split, shuffle):
    """Build the dataset for `split` and a DataLoader that yields it as one batch."""
    dataset = JSB_Chorales(pickle_file_path=_PICKLE_PATH, type_data=split)
    loader = torch.utils.data.DataLoader(dataset,
                                         # Sized from the dataset itself, so this
                                         # cell no longer depends on the separately
                                         # loaded global `data`.
                                         batch_size=len(dataset),
                                         shuffle=shuffle,
                                         num_workers=1)
    return dataset, loader


train_dataset, train_dataloader = _make_full_batch_loader('train', shuffle=True)
valid_dataset, valid_dataloader = _make_full_batch_loader('valid', shuffle=False)
test_dataset, test_dataloader = _make_full_batch_loader('test', shuffle=False)

# **Using LSTMCell**

In [None]:
# Recurrent cell: 88 input features per time step, 400 hidden units.
lstm = nn.LSTMCell(input_size=88, hidden_size=400)
# Read-out layer mapping the 400-dimensional hidden state back to 88 outputs.
linear = nn.Linear(in_features=400, out_features=88)

In [None]:
# The read-out layer emits raw scores, so the loss applies the sigmoid itself.
loss_function = nn.BCEWithLogitsLoss()
# Each module gets its own Adam optimizer with the library-default settings.
optimizer_lstm = optim.Adam(params=lstm.parameters())
optimizer_linear = optim.Adam(params=linear.parameters())


In [None]:
def train_model(num_epochs):
  """Train the LSTMCell model and checkpoint the weights with the best validation loss.

  Works on notebook globals rather than arguments: the modules `lstm` and
  `linear`, the optimizers `optimizer_lstm` / `optimizer_linear`,
  `loss_function`, the train/valid dataloaders and datasets, and the history
  lists `loss_values_train` / `loss_values_valid` (one value appended to each
  per epoch).

  The loss of a batch is the sum over its sequences of `loss_function`
  evaluated on that sequence's first `length` time steps, so padding never
  contributes. The value printed and recorded for an epoch is that sum divided
  by the number of sequences in the split. It is no longer multiplied by the
  padded sequence length, so the numbers are comparable across splits.

  Args:
    num_epochs: number of passes over the training set.

  Returns:
    None. The best weights are written to 'lstm_best.pt' and
    'linear_lstm_best.pt'.
  """

  def batch_loss(inputs, labels, lengths, random_state):
    """Run one (batch, time, keys) batch through the model; return the summed per-sequence loss."""
    # The cell is stepped along dim 0, so move time in front of batch. This has
    # to be a permute: reshaping to the target shape keeps the memory order and
    # would mix different sequences into a single time step.
    inputs = inputs.permute(1, 0, 2)
    labels = labels.permute(1, 0, 2)

    # Size the recurrent state from the batch itself so any batch size works.
    batch_size = inputs.size(1)
    make_state = torch.randn if random_state else torch.zeros
    hx = make_state(batch_size, lstm.hidden_size)
    cx = make_state(batch_size, lstm.hidden_size)

    output_lstm = []
    for step in range(inputs.size(0)):
      hx, cx = lstm(inputs[step], (hx, cx))
      output_lstm.append(hx)
    output = linear(torch.stack(output_lstm, dim=0))

    loss = 0.0
    for i in range(batch_size):
      x = lengths[i]
      # Score only the unpadded prefix of sequence i.
      loss += loss_function(output[:x, i, :], labels[:x, i, :])
    return loss

  best_loss = float('inf')
  for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs}')
    print('-' * 10)

    lstm.train()
    linear.train()
    running_loss_train = 0.0
    for inputs, labels, lengths in train_dataloader:
      optimizer_lstm.zero_grad()
      optimizer_linear.zero_grad()

      # Training keeps the random initial state used so far.
      loss = batch_loss(inputs, labels, lengths, random_state=True)

      loss.backward()
      optimizer_lstm.step()
      optimizer_linear.step()
      running_loss_train += loss.item()

    epoch_loss_train = running_loss_train / len(train_dataset)
    loss_values_train.append(epoch_loss_train)
    print(f'Train Loss: {epoch_loss_train:.8f}')

    # Validation: no gradients, and a zero initial state so the score used for
    # model selection is deterministic and matches how test_model evaluates.
    lstm.eval()
    linear.eval()
    running_loss_valid = 0.0
    with torch.no_grad():
      for inputs, labels, lengths in valid_dataloader:
        loss = batch_loss(inputs, labels, lengths, random_state=False)
        running_loss_valid += loss.item()

    epoch_loss_valid = running_loss_valid / len(valid_dataset)
    loss_values_valid.append(epoch_loss_valid)
    print(f'Valid Loss: {epoch_loss_valid:.8f}')

    if epoch_loss_valid < best_loss:
      best_loss = epoch_loss_valid
      torch.save(lstm.state_dict(), 'lstm_best.pt')
      torch.save(linear.state_dict(), 'linear_lstm_best.pt')


   
# Loss histories that train_model appends to; they must exist before it runs.
loss_values_train, loss_values_valid = [], []
# train_model has no return statement, so model_ft is None; the trained
# weights live in the global modules and in the saved checkpoints.
model_ft = train_model(num_epochs=400)

In [None]:
def test_model():
    """Evaluate the global `lstm` + `linear` model on `test_dataloader` and print the loss.

    The recurrent state starts at zero. For each sequence, `loss_function` is
    evaluated on its first `length` time steps only, so padding never
    contributes. The printed value is the sum of those per-sequence losses
    over all batches divided by `len(test_dataset)`.

    Returns:
        None; the result is only printed.
    """
    lstm.eval()
    linear.eval()

    running_loss_test = 0.0
    with torch.no_grad():
        for inputs, labels, lengths in test_dataloader:
            # Batches arrive as (batch, time, keys) and the cell is stepped
            # along dim 0. This has to be a permute: reshaping to
            # (time, batch, keys) keeps the memory order and would mix
            # different sequences into a single time step.
            inputs = inputs.permute(1, 0, 2)
            labels = labels.permute(1, 0, 2)

            # Size the state from the batch instead of hard-coding the split size.
            batch_size = inputs.size(1)
            hx = torch.zeros(batch_size, lstm.hidden_size)
            cx = torch.zeros(batch_size, lstm.hidden_size)

            output_lstm = []
            for step in range(inputs.size(0)):
                hx, cx = lstm(inputs[step], (hx, cx))
                output_lstm.append(hx)
            output = linear(torch.stack(output_lstm, dim=0))

            loss = 0.0
            for i in range(batch_size):
                x = lengths[i]
                loss += loss_function(output[:x, i, :], labels[:x, i, :])

            # Accumulate across batches; the per-batch value used to be
            # overwritten and scaled by the padded sequence length.
            running_loss_test += loss.item()

    epoch_loss_test = running_loss_test / len(test_dataset)
    print(f'Test Loss: {epoch_loss_test:.8f}')

In [None]:
# Restore each module from its checkpoint file, then score the test split.
for module, checkpoint in ((lstm, "lstm_best.pt"), (linear, "linear_lstm_best.pt")):
    module.load_state_dict(torch.load(checkpoint))
test_model()

# **Using LMUCell**

In [None]:
!pip install nengolib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from nengolib.signal import Identity,cont2discrete
from nengolib.synapses import LegendreDelay
import numpy as np
from scipy.special import comb


class LMU(nn.Module):
    """Recurrent cell with a nonlinear hidden state and a fixed linear memory.

    At every step a scalar `u` is formed from the input, hidden state and
    memory; the memory is advanced by the discretised linear system
    `m' = m A^T + u B^T`, and the hidden state is recomputed from the input,
    the previous hidden state and the new memory.

    Args:
        input_size: number of input features (k).
        hidden_size: size of the hidden state (n).
        memory_size: order of the memory system (d).
        theta: divisor applied to the continuous-time matrices.
        matrix_type: 'pl' (Legendre), 'p' (Pade) or 'pb' (Bernstein).
        discretizer: method name handed to `cont2discrete`.
        nonlinearity: 'sigmoid' or 'tanh'; any other value selects ReLU.
        A_learnable: if True, the discretised A matrix becomes a parameter.
        B_learnable: if True, the discretised B matrix becomes a parameter.

    Raises:
        ValueError: if `matrix_type` is not one of the supported names.
    """

    _MATRIX_TYPES = ('pl', 'p', 'pb')

    def __init__(self, input_size, hidden_size, memory_size, theta, matrix_type='pl', discretizer='zoh', nonlinearity='sigmoid', A_learnable=False, B_learnable=False):
        super(LMU, self).__init__()

        # Fail early: an unknown type used to leave _A/_B unset and surfaced
        # only later as an unrelated AttributeError.
        if matrix_type not in self._MATRIX_TYPES:
            raise ValueError(
                f'unknown matrix_type {matrix_type!r}; expected one of {self._MATRIX_TYPES}')

        ### SIZE
        self.k = input_size
        self.n = hidden_size
        self.d = memory_size

        ### PARAMETERS
        self.Wx = nn.Parameter(torch.Tensor(self.n, self.k))
        self.Wh = nn.Parameter(torch.Tensor(self.n, self.n))
        self.Wm = nn.Parameter(torch.Tensor(self.n, self.d))
        self.ex = nn.Parameter(torch.Tensor(1, self.k))
        self.eh = nn.Parameter(torch.Tensor(1, self.n))
        self.em = nn.Parameter(torch.Tensor(1, self.d))

        ### MEMORY SYSTEM
        if matrix_type == 'pl':
            A, B = self._legendre_system(self.d, theta)
        elif matrix_type == 'p':
            A, B = self._pade_system(self.d, theta)
        else:
            A, B = self._bernstein_system(self.d, theta)
        C = np.ones((1, self.d))
        D = np.zeros((1,))
        # Discretised with a fixed step of 0.01.
        self._ss = cont2discrete((A, B, C, D), dt=0.01, method=discretizer)
        self._A = self._ss.A
        self._B = self._ss.B

        ### NON-LINEARITY
        self.nl = nonlinearity
        if self.nl == 'sigmoid':
            self.act = nn.Sigmoid()
        elif self.nl == 'tanh':
            self.act = nn.Tanh()
        else:
            self.act = nn.ReLU()

        ### INITIALIZATION
        # TODO: the original author marked the choice of initialiser for Wm as
        # unresolved.
        torch.nn.init.xavier_normal_(self.Wm)
        torch.nn.init.xavier_normal_(self.Wx)
        torch.nn.init.xavier_normal_(self.Wh)
        torch.nn.init.zeros_(self.em)
        bound = np.sqrt(3 / self.d)
        torch.nn.init.uniform_(self.ex, -bound, bound)
        torch.nn.init.uniform_(self.eh, -bound, bound)

        ### MEMORY MATRICES (buffers unless made learnable)
        self.register_buffer('AT', torch.Tensor(self._A))
        self.register_buffer('BT', torch.Tensor(self._B))
        if A_learnable:
            self.AT = nn.Parameter(self.AT)
        if B_learnable:
            self.BT = nn.Parameter(self.BT)

    @staticmethod
    def _legendre_system(order, theta):
        """Return the continuous-time (A, B) used for matrix_type 'pl'."""
        Q = np.arange(order, dtype=np.float64)
        R = (2 * Q + 1)[:, None] / theta
        j, i = np.meshgrid(Q, Q)
        A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
        B = (-1.0) ** Q[:, None] * R
        return A, B

    @staticmethod
    def _pade_system(order, theta):
        """Return the continuous-time (A, B) used for matrix_type 'p'."""
        Q = np.arange(order, dtype=np.float64)
        V = (order + Q + 1) * (order - Q) / (Q + 1) / theta
        A = np.zeros([order, order], dtype=np.float64)
        B = np.zeros([order, 1], dtype=np.float64)
        A[0, :] = -V[0]
        A[1:order, 0:order - 1] = np.diag(V[1:order])
        B[0] = V[0]
        return A, B

    @staticmethod
    def _bernstein_system(order, theta):
        """Return the continuous-time (A, B) used for matrix_type 'pb'.

        Built as M A_leg M_inv and M B_leg from the Legendre system. M and
        M_inv are presumably the Legendre/Bernstein basis-change matrices --
        TODO confirm against the derivation.
        """
        A_leg, B_leg = LMU._legendre_system(order, theta)
        M = np.zeros([order, order], dtype=np.float64)
        M_inv = np.zeros([order, order], dtype=np.float64)
        n = order - 1  # degree of polynomial
        for j in range(0, n + 1):
            for k in range(0, n + 1):
                lower = max(0, j + k - n)
                upper = min(j, k) + 1
                acc = 0.0
                for i in range(lower, upper):
                    acc = acc + ((-1.0) ** (k + i)) * (comb(k, i) ** 2) * comb(n - k, j - i)
                M[j, k] = acc / comb(n, j)

                acc = 0.0
                for i in range(0, j + 1):
                    acc = acc + (-1.0) ** (j + i) * comb(j, i) ** 2 / comb(n + j, k + i)
                M_inv[j, k] = (2 * j + 1) / (n + j + 1) * comb(n, k) * acc

        # Soft-clip every entry into (-10, 10).
        M = 10 * np.tanh(M / 10)
        M_inv = 10 * np.tanh(M_inv / 10)

        A = np.matmul(np.matmul(M, A_leg), M_inv)
        B = np.matmul(M, B_leg)
        return A, B

    def forward(self, x, hm):
        """Advance the cell by one time step.

        Args:
            x: input of shape (batch, input_size).
            hm: tuple `(h, m)` with h of shape (batch, hidden_size) and m of
                shape (batch, memory_size).

        Returns:
            `(new_h, new_m)` with the same shapes as `h` and `m`.
        """
        h, m = hm
        # Scalar write signal per batch element, shape (batch, 1).
        u = F.linear(x, self.ex) + F.linear(h, self.eh) + F.linear(m, self.em)
        new_m = F.linear(m, self.AT) + F.linear(u, self.BT)
        new_h = self.act(F.linear(x, self.Wx) + F.linear(h, self.Wh) + F.linear(new_m, self.Wm))

        return new_h, new_m


In [None]:
class ASSVMU(nn.Module):
    '''
    Also known as Linear LMU

    A memory cell driven by the Legendre state-space system, whose hidden
    state is additionally passed through a fully connected layer.
    '''
    def __init__(self, input_size, hidden_size, memory_size, theta, name='garbage', discretizer = 'zoh',nonlinearity='sigmoid', 
                        A_learnable = False, B_learnable = False, activate=False):
        """Build the cell.

        Args:
            input_size: number of input features (k).
            hidden_size: size of the hidden state (n).
            memory_size: order of the memory system (d).
            theta: divisor applied to the continuous-time matrices.
            name: accepted but not used by this class.
            discretizer: method name handed to `cont2discrete`.
            nonlinearity: 'sigmoid' or 'tanh'; any other value selects ReLU.
            A_learnable: if True, the discretised A matrix becomes a parameter.
            B_learnable: if True, the discretised B matrix becomes a parameter.
            activate: if True, the nonlinearity is applied again after the
                fully connected layer; otherwise that stage is an identity.
        """
        super(ASSVMU, self).__init__()

        # Sizes: k inputs, n hidden units, d memory dimensions.
        self.k, self.n, self.d = input_size, hidden_size, memory_size

        # Trainable weights (filled in by the initialisers further down).
        self.Wx = nn.Parameter(torch.Tensor(self.n, self.k))
        self.Wh = nn.Parameter(torch.Tensor(self.n, self.n))
        self.Wm = nn.Parameter(torch.Tensor(self.n, self.d))
        self.ex = nn.Parameter(torch.Tensor(1, self.k))
        self.eh = nn.Parameter(torch.Tensor(1, self.n))
        self.em = nn.Parameter(torch.Tensor(1, self.d))

        # Continuous-time Legendre system, discretised with a fixed step of 0.01.
        # (The original author left a "FIX??" marker on this section.)
        degrees = np.arange(self.d, dtype=np.float64)
        scale = (2 * degrees + 1)[:, None] / theta
        col, row = np.meshgrid(degrees, degrees)
        A = np.where(row < col, -1, (-1.0) ** (row - col + 1)) * scale
        B = (-1.0) ** degrees[:, None] * scale
        C = np.ones((1, self.d))
        D = np.zeros((1,))
        self._ss = cont2discrete((A, B, C, D), dt=0.01, method=discretizer)
        self._A = self._ss.A
        self._B = self._ss.B

        # Hidden-state nonlinearity; unrecognised names fall back to ReLU.
        self.nl = nonlinearity
        known_activations = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh}
        self.act = known_activations.get(self.nl, nn.ReLU)()

        # Fully connected layer applied to the hidden state.
        self.fc = nn.Linear(self.n, self.n)
        # LeakyReLU with slope 1.0 acts as the identity function.
        self.nn_act = self.act if activate else nn.LeakyReLU(1.0)

        # Initialisation (the choice for Wm was flagged as unresolved by the
        # original author).
        for weight in (self.Wm, self.Wx, self.Wh):
            torch.nn.init.xavier_normal_(weight)
        torch.nn.init.zeros_(self.em)
        bound = np.sqrt(3/self.d)
        for encoder in (self.ex, self.eh):
            torch.nn.init.uniform_(encoder, -bound, bound)

        # Memory matrices are buffers unless explicitly made learnable.
        self.register_buffer('AT', torch.Tensor(self._A))
        self.register_buffer('BT', torch.Tensor(self._B))
        if A_learnable:
            self.AT = nn.Parameter(self.AT)
        if B_learnable:
            self.BT = nn.Parameter(self.BT)

    def forward(self,x,hm):
        """Advance one step; `hm` is `(h, m)` and the result is `(new_h, new_m)`."""
        hidden, memory = hm
        # Scalar write signal per batch element.
        write = F.linear(x, self.ex) + F.linear(hidden, self.eh) + F.linear(memory, self.em)
        memory_next = F.linear(memory, self.AT) + F.linear(write, self.BT)
        pre_activation = (F.linear(x, self.Wx) + F.linear(hidden, self.Wh)
                          + F.linear(memory_next, self.Wm))
        hidden_next = self.nn_act(self.fc(self.act(pre_activation)))
        return hidden_next, memory_next

In [None]:
# Memory cell: 88 inputs, 400 hidden units, 44-dimensional memory, theta = 5.
lmu = ASSVMU(input_size=88, hidden_size=400, memory_size=44, theta=5)
# Read-out layer; this rebinds the name `linear` to a fresh, untrained layer.
linear = nn.Linear(in_features=400, out_features=88)

In [None]:
# The read-out layer emits raw scores, so the loss applies the sigmoid itself.
loss_function = nn.BCEWithLogitsLoss()
# Each module gets its own Adam optimizer with the library-default settings.
optimizer_lmu = optim.Adam(params=lmu.parameters())
optimizer_linear = optim.Adam(params=linear.parameters())

In [None]:
def train_model(num_epochs):
  """Train the `lmu` cell model and checkpoint the weights with the best validation loss.

  Works on notebook globals rather than arguments: the modules `lmu` and
  `linear`, the optimizers `optimizer_lmu` / `optimizer_linear`,
  `loss_function`, the train/valid dataloaders and datasets, and the history
  lists `lmu_loss_values_train` / `lmu_loss_values_valid` (one value appended
  to each per epoch).

  The loss of a batch is the sum over its sequences of `loss_function`
  evaluated on that sequence's first `length` time steps, so padding never
  contributes. The value printed and recorded for an epoch is that sum divided
  by the number of sequences in the split. It is no longer multiplied by the
  padded sequence length, so the numbers are comparable across splits.

  Args:
    num_epochs: number of passes over the training set.

  Returns:
    None. The best weights are written to 'lmu_best.pt' and
    'linear_lmu_best.pt'.
  """

  def batch_loss(inputs, labels, lengths):
    """Run one (batch, time, keys) batch through the model; return the summed per-sequence loss."""
    # The cell is stepped along dim 0, so move time in front of batch.
    inputs = inputs.permute(1, 0, 2)
    labels = labels.permute(1, 0, 2)

    # Size both states from the batch and the cell instead of hard-coding them.
    batch_size = inputs.size(1)
    hx = torch.zeros(batch_size, lmu.n)  # hidden state
    cx = torch.zeros(batch_size, lmu.d)  # memory state

    output_lmu = []
    for step in range(inputs.size(0)):
      hx, cx = lmu(inputs[step], (hx, cx))
      output_lmu.append(hx)
    output = linear(torch.stack(output_lmu, dim=0))

    loss = 0.0
    for i in range(batch_size):
      x = lengths[i]
      # Score only the unpadded prefix of sequence i.
      loss += loss_function(output[:x, i, :], labels[:x, i, :])
    return loss

  best_loss = float('inf')
  for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs}')
    print('-' * 10)

    lmu.train()
    # Validation switches `linear` to eval mode, so restore it every epoch.
    linear.train()
    running_loss_train = 0.0
    for inputs, labels, lengths in train_dataloader:
      optimizer_lmu.zero_grad()
      optimizer_linear.zero_grad()

      loss = batch_loss(inputs, labels, lengths)

      loss.backward()
      optimizer_lmu.step()
      optimizer_linear.step()
      running_loss_train += loss.item()

    epoch_loss_train = running_loss_train / len(train_dataset)
    lmu_loss_values_train.append(epoch_loss_train)
    print(f'Train Loss: {epoch_loss_train:.8f}')

    # Validation needs no gradients.
    lmu.eval()
    linear.eval()
    running_loss_valid = 0.0
    with torch.no_grad():
      for inputs, labels, lengths in valid_dataloader:
        loss = batch_loss(inputs, labels, lengths)
        running_loss_valid += loss.item()

    epoch_loss_valid = running_loss_valid / len(valid_dataset)
    lmu_loss_values_valid.append(epoch_loss_valid)
    print(f'Valid Loss: {epoch_loss_valid:.8f}')

    if epoch_loss_valid < best_loss:
      best_loss = epoch_loss_valid
      torch.save(lmu.state_dict(), 'lmu_best.pt')
      torch.save(linear.state_dict(), 'linear_lmu_best.pt')
  
# Loss histories that train_model appends to; they must exist before it runs.
lmu_loss_values_train, lmu_loss_values_valid = [], []
# train_model has no return statement, so model_ft is None; the trained
# weights live in the global modules and in the saved checkpoints.
model_ft = train_model(num_epochs=1000)

In [None]:
def test_model():
    """Evaluate the global `lmu` + `linear` model on `test_dataloader` and print the loss.

    Both recurrent states start at zero. For each sequence, `loss_function` is
    evaluated on its first `length` time steps only, so padding never
    contributes. The printed value is the sum of those per-sequence losses
    over all batches divided by `len(test_dataset)`.

    Returns:
        None; the result is only printed.
    """
    lmu.eval()
    linear.eval()

    running_loss_test = 0.0
    with torch.no_grad():
        for inputs, labels, lengths in test_dataloader:
            # Batches arrive as (batch, time, keys) and the cell is stepped
            # along dim 0. This has to be a permute: reshaping to
            # (time, batch, keys) keeps the memory order and would mix
            # different sequences into a single time step.
            inputs = inputs.permute(1, 0, 2)
            labels = labels.permute(1, 0, 2)

            # Size both states from the batch and the cell instead of
            # hard-coding the split size.
            batch_size = inputs.size(1)
            hx = torch.zeros(batch_size, lmu.n)  # hidden state
            cx = torch.zeros(batch_size, lmu.d)  # memory state

            output_lmu = []
            for step in range(inputs.size(0)):
                hx, cx = lmu(inputs[step], (hx, cx))
                output_lmu.append(hx)
            output = linear(torch.stack(output_lmu, dim=0))

            loss = 0.0
            for i in range(batch_size):
                x = lengths[i]
                loss += loss_function(output[:x, i, :], labels[:x, i, :])

            # Accumulate across batches; the value is no longer scaled by the
            # padded sequence length.
            running_loss_test += loss.item()

    epoch_loss_test = running_loss_test / len(test_dataset)
    print(f'Test Loss: {epoch_loss_test:.8f}')

In [None]:
# Build fresh modules with the training-time sizes so the checkpoints fit.
lmu = ASSVMU(input_size=88, hidden_size=400, memory_size=44, theta=5)
linear = nn.Linear(in_features=400, out_features=88)

# Restore each module from its checkpoint file, then score the test split.
for module, checkpoint in ((lmu, "lmu_best.pt"), (linear, "linear_lmu_best.pt")):
    module.load_state_dict(torch.load(checkpoint))

test_model()