In [1]:
import torch

import torch.nn as nn
import warnings
warnings.filterwarnings('ignore')

In [2]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    Delegates to ``torch.sigmoid`` rather than spelling the formula out:
    the explicit form overflows ``exp(-x)`` to ``inf`` for large negative
    inputs, which still yields 0 in the forward pass but produces ``NaN``
    gradients (``0 * inf``) in the backward pass.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x`` with values in ``[0, 1]``.
    """
    return torch.sigmoid(x)
def weight_init(net):
    """Re-initialise the weight of a ``Conv1d`` module in place.

    Meant to be passed to ``nn.Module.apply``: modules that are not
    ``nn.Conv1d`` are left untouched. Conv1d weights are filled with
    Kaiming-uniform values (fan-in mode, leaky-ReLU gain).

    Args:
        net: the module visited by ``apply``.
    """
    if not isinstance(net, nn.Conv1d):
        return
    nn.init.kaiming_uniform_(
        net.weight,
        mode='fan_in',
        nonlinearity='leaky_relu',
    )

class ReservoirNet(nn.Module):
    """Leaky-integrator reservoir with a closed-form ridge-regression readout.

    The input weights ``Win`` and the recurrent weights ``W`` are drawn at
    random once in ``__init__``; they are buffers, not ``nn.Parameter``s.
    ``forward`` drives the reservoir with a sequence, collects the states and
    solves for the readout ``Wout``; ``RCPred`` applies a given ``Wout`` to a
    new sequence.

    Args:
        inSize: number of input features per time step.
        resSize: number of reservoir units.
        a: leak rate blending the previous state with the new activation.
    """

    def __init__(self, inSize, resSize, a):
        super(ReservoirNet, self).__init__()
        self.inSize = inSize
        self.resSize = resSize
        self.a = a

        # Draw Win before W so the random stream is consumed in a fixed order.
        Win = (torch.rand([self.resSize, 1 + self.inSize]) - 0.5) * 2.4
        W = (torch.rand(self.resSize, self.resSize) - 0.5)
        # Win is uniform in [-1.2, 1.2); zero every entry whose magnitude
        # exceeds 0.6 to sparsify the input connections.
        Win[Win.abs() > 0.6] = 0
        # Rescale W so that its spectral radius becomes 1.25.
        self.rhoW = torch.linalg.eigvals(W).abs().max()
        W = W * (1.25 / self.rhoW)

        # Ridge regularisation strength used by forward().
        # NOTE(review): 1e-12 is far below float32 resolution relative to the
        # entries of X @ X.T, and that matrix is rank deficient whenever the
        # sequence is shorter than 1 + inSize + resSize -- consider a larger
        # value or float64 if the readout looks unstable.
        self.reg = 1e-12

        # Buffers (instead of plain attributes) so that .to()/.double()/.cuda()
        # move them together with the module. persistent=False keeps the
        # state_dict identical to what it was when these were plain tensors.
        self.register_buffer('Win', Win, persistent=False)
        self.register_buffer('W', W, persistent=False)
        self.register_buffer('one', torch.ones([1, 1]), persistent=False)

    def _collect_states(self, inputs):
        """Run the reservoir over ``inputs`` and return ``(X, final_state)``.

        ``inputs`` is indexed as ``(time, feature)``. ``X`` has one column per
        time step holding ``[1; u_t; x_t]`` (bias, input, reservoir state), and
        ``final_state`` is the ``(resSize, 1)`` state after the last step.
        """
        # Match the dtype/device of the reservoir weights.
        inputs = inputs.to(self.W)
        steps = inputs.size(0)
        n_drive = 1 + self.inSize

        # Bias row stacked on the transposed inputs: (1 + inSize, steps).
        drive = torch.cat((self.one.new_ones((1, steps)), inputs.T), dim=0)
        X = torch.zeros([n_drive + self.resSize, steps],
                        dtype=inputs.dtype, device=inputs.device)
        X[:n_drive] = drive

        # The input projection does not depend on the recurrent state, so it
        # is computed for every time step at once instead of inside the loop.
        in_proj = torch.matmul(self.Win, drive)

        x = torch.zeros((self.resSize, 1), dtype=inputs.dtype, device=inputs.device)
        for t in range(steps):
            activation = torch.sigmoid(in_proj[:, t:t + 1] + torch.matmul(self.W, x))
            x = (1 - self.a) * x + self.a * activation
            X[n_drive:, t] = x[:, 0]
        return X, x

    def RCPred(self, Wout, RCin):
        """Predict with a previously fitted readout.

        Args:
            Wout: readout matrix as returned by ``forward``.
            RCin: input sequence indexed as ``(time, feature)``.

        Returns:
            ``Wout @ X`` where ``X`` holds the collected states, one column
            per time step.
        """
        X, _ = self._collect_states(RCin)
        pred = Wout @ X
        return pred

    def forward(self, data, labels):
        """Fit the ridge-regression readout for ``data`` -> ``labels``.

        Args:
            data: input sequence indexed as ``(time, feature)``.
            labels: targets with one row per time step.

        Returns:
            The readout ``Wout`` (also stored on ``self.Wout``), solving
            ``Wout (X X^T + reg I) = labels^T X^T``.

        Raises:
            ValueError: if ``data`` and ``labels`` have different lengths.
        """
        if data.size(0) != labels.size(0):
            raise ValueError(
                "data and labels must have the same number of time steps, got "
                "%d and %d" % (data.size(0), labels.size(0)))

        # These attributes are kept so that code inspecting the module after
        # a call keeps working.
        self.U = data
        self.Yt = labels
        self.T = labels.size(0)
        self.X, self.x = self._collect_states(data)
        if self.T > 0:
            self.u = data[-1:, :].T

        n_features = self.X.size(0)
        targets = labels.to(self.X)
        gram = torch.matmul(self.X, self.X.T) + self.reg * torch.eye(
            n_features, dtype=self.X.dtype, device=self.X.device)
        # gram is symmetric, so Wout = Y^T X^T gram^-1 = (gram^-1 X Y)^T.
        # Solving the linear system avoids forming the explicit inverse.
        solution = torch.linalg.solve(gram, torch.matmul(self.X, targets))
        self.Wout = solution.T if solution.dim() == 2 else solution

        return self.Wout


class ConvNet(nn.Module):
    """Four-stage 1-D convolutional encoder with a flattened output.

    Every stage is a bias-free ``Conv1d`` followed by ``PReLU``. The first
    stage is max-pooled, the second and third are average-pooled (the third
    with stride 2) and the fourth is not pooled.

    Args:
        inSize: number of features per sample after flattening; it must equal
            ``channels * length`` of the last stage's output.
        filters: four output-channel counts, one per stage.
        kernel_size: kernel width shared by all four convolutions.
        in_channels: number of channels of the input signal. Defaults to 2,
            the value that used to be hard-coded.
    """

    def __init__(self, inSize, filters, kernel_size=3, in_channels=2):
        super(ConvNet, self).__init__()
        self.inSize = inSize
        self.filter1, self.filter2, self.filter3, self.filter4 = filters
        self.k = kernel_size
        self.in_channels = in_channels

        self.layer1 = nn.Sequential(nn.Conv1d(self.in_channels, self.filter1, kernel_size=self.k, bias=False),
                                    nn.PReLU(),
                                    nn.MaxPool1d(3, stride=1))

        self.layer2 = nn.Sequential(nn.Conv1d(self.filter1, self.filter2, kernel_size=self.k, bias=False),
                                    nn.PReLU(),
                                    nn.AvgPool1d(3, stride=1))

        self.layer3 = nn.Sequential(nn.Conv1d(self.filter2, self.filter3, kernel_size=self.k, bias=False),
                                    nn.PReLU(),
                                    nn.AvgPool1d(3, stride=2))

        self.layer4 = nn.Sequential(nn.Conv1d(self.filter3, self.filter4, kernel_size=self.k, bias=False),
                                    nn.PReLU())

    def forward_once(self, x):
        """Run the four stages on ``x`` (cast to float32) without flattening."""
        x = x.to(torch.float32)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def forward(self, input):
        """Encode ``input`` and flatten each sample to ``inSize`` features.

        Args:
            input: tensor laid out as ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, inSize)``.

        Raises:
            RuntimeError: if the encoder output does not hold exactly
                ``inSize`` values per sample.
        """
        ConvNetemb = self.forward_once(input)
        n_features = int(self.inSize)
        # Fail with an explicit message instead of the generic view() error.
        if ConvNetemb.shape[0] > 0 and ConvNetemb[0].numel() != n_features:
            raise RuntimeError(
                "ConvNet produced %d features per sample (shape %s) but inSize is %d"
                % (ConvNetemb[0].numel(), tuple(ConvNetemb.shape[1:]), n_features))
        out = ConvNetemb.view((ConvNetemb.shape[0], n_features))

        return out

In [4]:
# Smoke-test the convolutional encoder on random input.
# For length-272 signals the last stage yields 32 channels x 128 positions,
# i.e. 4096 features per sample.
inSize = 4096
Win_net = ConvNet(inSize, [4, 8, 16, 32]).apply(weight_init)
x = torch.rand(512, 2, 272)
Win_net = Win_net.eval()
x = Win_net(x)

In [26]:
# NOTE(review): the ConvNet smoke-test cell above leaves `x` with shape
# (512, inSize), so on a fresh Restart & Run All this indexes a single scalar.
# The 1-D tensor shown in the saved output must come from a different kernel
# state (note the execution count) -- re-run or drop this cell.
x[1,1]

tensor([-1.6124, -1.7749, -2.0579, -2.1327, -1.9999, -1.7906, -1.6315, -1.5286,
        -1.4008, -1.5771, -2.0293, -2.1057, -1.9073, -1.9054, -1.7923, -1.5839,
        -1.4060, -1.3411, -1.5463, -1.8125, -1.9610, -1.9915, -2.0922, -2.1072,
        -1.7374, -1.4982, -1.6652, -1.6949, -1.5828, -1.4917, -1.3611, -1.2176,
        -1.0786, -0.9373, -1.0824, -1.5156, -1.9107, -1.9566, -1.9245, -1.8493,
        -1.4848, -1.1928, -1.4472, -1.7742, -1.8519, -1.8886, -1.8564, -1.7596,
        -1.8152, -1.8930, -1.7366, -1.5446, -1.7070, -1.9440, -1.9394, -1.7257,
        -1.6307, -1.6631, -1.5741, -1.5366, -1.7615, -1.9319, -1.7608, -1.5288,
        -1.5126, -1.6639, -1.5470, -1.3340, -1.6021, -1.9232, -1.8519, -1.6863,
        -1.8354, -1.9819, -1.9949, -1.9938, -1.9733, -1.7432, -1.5178, -1.6967,
        -1.9960, -1.9663, -1.6199, -1.2945, -1.2026, -1.2204, -1.1856, -1.1178,
        -1.0387, -1.1667, -1.4382, -1.7191, -1.8202, -1.7145, -1.5945, -1.4249,
        -1.4844, -1.6555, -1.6285, -1.65

In [10]:
class ReservoirModel(nn.Module):
    """Embedding + ConvNet encoder feeding a reservoir readout.

    A token sequence is embedded and flattened to one channel, stacked with a
    signal channel, encoded by ``ConvNet`` and passed to ``ReservoirNet``,
    which fits and returns its readout matrix.

    Args:
        inSize: flattened feature size produced by the ConvNet and consumed
            by the reservoir.
        filters: four channel counts forwarded to ``ConvNet``.
        resSize: number of reservoir units.
        a: reservoir leak rate.
        vocab_size: number of rows in the embedding table.
        embedding_size: width of each embedding vector.
    """

    def __init__(self, inSize, filters, resSize, a, vocab_size=16, embedding_size=16):
        super(ReservoirModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.Win_net = ConvNet(inSize, filters)
        self.Win_net.apply(weight_init)
        self.Win_net = self.Win_net.eval()
        self.Wout_net = ReservoirNet(inSize, resSize, a)
        # Not used by forward(); kept so the attribute stays available.
        self.softmax = nn.Softmax(1)

    def forward(self, seq, sig, label):
        """Encode ``(seq, sig)`` and fit the reservoir readout against ``label``.

        Args:
            seq: token ids; cast to ``long`` before the embedding lookup.
            sig: signal whose per-sample flattened length must equal the
                flattened length of the embedded ``seq`` (the two are
                concatenated along the channel axis).
            label: targets with one row per sample.

        Returns:
            The readout matrix produced by ``ReservoirNet.forward``.
        """
        seq_emb = self.embed(seq.long())
        # Flatten every sample to one channel, e.g. (512, 1, 272) for
        # 17 tokens with 16-dimensional embeddings.
        seq_emb = seq_emb.reshape(seq_emb.shape[0], 1, -1)
        sig = sig.reshape(sig.shape[0], 1, -1)
        x = torch.cat((seq_emb, sig), dim=1).to(torch.float32)
        x = self.Win_net(x)
        # The original computed the readout but never returned it, so callers
        # always received None.
        return self.Wout_net(x, label)

In [None]:
# End-to-end smoke test of ReservoirModel on random data.
inSize = 4096
resSize = 2000
# Named `filters` rather than `filter` so the builtin is not shadowed for the
# rest of the notebook.
filters = [4, 8, 16, 32]
a = 0.3
# Token ids must be integers in [0, vocab_size); uniform floats in [0, 1)
# would all truncate to index 0 inside the model. 16 is the model's default
# vocab_size.
x = torch.randint(0, 16, (512, 1, 17))
y = torch.rand(512, 1, 272)
label = torch.rand(512, 1)
net = ReservoirModel(inSize, filters, resSize, a)
z = net(x, y, label)
