In [7]:
import torch.nn as nn
import torch

# Seed the global torch RNG so this random batch -- and every layer initialised in
# later cells -- is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Dummy batch laid out as Conv1d expects: (batch=3, channels=22, samples=15000).
# NOTE: `input` shadows Python's built-in input(); kept because later cells use it.
input = torch.randn(3, 22, 15000)
input.shape

torch.Size([3, 22, 15000])

In [8]:
# Scratch check of a kernel-2 / stride-2 Conv1d (padding defaults to 0): output length is floor(L / 2).
nn.Conv1d(22, 33, kernel_size=2, stride=2)

Conv1d(22, 33, kernel_size=(2,), stride=(2,))

In [22]:
class Block(nn.Module):
    """Inception-style 1-D conv block.

    Three parallel stride-2 convolutions (kernel sizes 2, 4 and 8) are applied to
    the same input, their outputs are concatenated along the channel axis and
    passed through a ReLU.

    Args:
        inplace: number of input channels. The name is kept for backward
            compatibility; it is the ``in_channels`` of every branch, not an
            in-place flag.
        out_channels: channels produced by each branch (default 32), so the block
            emits ``3 * out_channels`` channels.

    Input ``(batch, inplace, L)`` -> output ``(batch, 3 * out_channels, L // 2)``.
    The paddings (0, 1, 3) are chosen so all three branches yield the same length.
    """

    def __init__(self, inplace, out_channels=32):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv1d(in_channels=inplace, out_channels=out_channels, kernel_size=8, stride=2, padding=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        # Apply the non-linearity held in self.relu; without it, stacked Blocks
        # compose into a single linear map. ReLU is elementwise, so applying it
        # after the concat is the same as applying it to each branch.
        return self.relu(torch.cat([x1, x2, x3], dim=1))

In [23]:
# First block: 22 channels in -> 3 branches x 32 = 96 channels out, length halved.
block = Block(inplace=22)
out1 = block(input)
out1.size()

torch.Size([3, 96, 7500])

In [24]:
# Second block: 96 -> 96 channels, length halved again.
block = Block(inplace=96)
out2 = block(out1)
out2.size()

torch.Size([3, 96, 3750])

In [25]:
# Third block: 96 -> 96 channels, length halved a third time.
block = Block(inplace=96)
out3 = block(out2)
out3.size()

torch.Size([3, 96, 1875])

In [27]:
gru = nn.GRU(input_size=96, hidden_size=32, batch_first=True)
# A batch_first GRU wants (batch, time, features): swap the channel and time axes.
x = out3.transpose(1, 2)
print(x.shape)

output, hn = gru(x)
print(output.shape, hn.shape)

torch.Size([3, 1875, 96])
torch.Size([3, 1875, 32]) torch.Size([1, 3, 32])


In [29]:
gru1 = nn.GRU(input_size=96, hidden_size=32, batch_first=True)
gru2 = nn.GRU(input_size=32, hidden_size=32, batch_first=True)
# Run two stacked GRUs and concatenate both outputs along the feature axis.
gru_out1, _ = gru1(x)
gru_out2, _ = gru2(gru_out1)
gru_out = torch.cat((gru_out1, gru_out2), dim=-1)
gru_out.size()

torch.Size([3, 1875, 64])

In [31]:
# Third GRU consumes the 64-feature concat of the first two GRU outputs.
gru3 = nn.GRU(input_size=64, hidden_size=32, batch_first=True)
gru_out3, _ = gru3(gru_out)
gru_out3.size()


torch.Size([3, 1875, 32])

In [32]:
# NOTE(review): rebinds gru_out (now 96 features); re-running the previous cell after this one fails.
gru_out = torch.cat((gru_out1, gru_out2, gru_out3), dim=-1)
gru_out.size()

torch.Size([3, 1875, 96])

In [35]:
# Fourth GRU over the full 96-feature sequence (superseded by the time-collapsed variant below).
gru4 = nn.GRU(input_size=96, hidden_size=32, batch_first=True)
gru_out4, _ = gru4(gru_out)
gru_out4.size()

torch.Size([3, 1875, 32])

In [36]:
# Collapse the time axis with a learned weighted sum. nn.Linear acts on the last
# dimension, so move time last: (batch, time, feat) -> (batch, feat, time) -> (batch, feat, 1).
# in_features is taken from the data instead of the hard-coded 1875.
linear = nn.Linear(gru_out.size(1), 1)
linear_out = linear(gru_out.permute(0, 2, 1))
linear_out.shape

torch.Size([3, 96, 1])

In [37]:
# Feed the time-collapsed features back as a length-1 sequence: (batch, 1, 96).
# NOTE(review): redefines gru4 from the earlier cell.
gru4 = nn.GRU(input_size=96, hidden_size=32, batch_first=True)
gru_out4, _ = gru4(linear_out.transpose(1, 2))
gru_out4.size()

torch.Size([3, 1, 32])

In [52]:
class ChronoNet(nn.Module):
    """Conv + densely-concatenated GRU network with a single-output head.

    Three ``Block`` layers (each three parallel stride-2 Conv1d branches of 32
    channels) are followed by four GRUs. The third and fourth GRUs see the
    concatenation of earlier GRU outputs. A linear layer collapses the time axis
    before the last GRU.

    Args:
        in_channels: channels of the input signal (default 22).
        seq_len: samples per example (default 15000). Used to size ``gru_linear``.
            Every branch of a Block maps a length ``L`` to ``L // 2``, so three
            blocks leave ``seq_len // 8`` time steps.

    Raises:
        ValueError: if ``seq_len`` is too short to survive the three blocks.

    ``forward(x)`` takes ``(batch, in_channels, seq_len)`` and returns
    ``(batch, 1)`` raw scores (no sigmoid is applied).
    """

    def __init__(self, in_channels=22, seq_len=15000):
        super().__init__()
        # Each Block concatenates three 32-channel branches.
        block_channels = 3 * 32
        self.block1 = Block(in_channels)
        self.block2 = Block(block_channels)
        self.block3 = Block(block_channels)

        # Three stride-2 blocks: L -> L//2 -> L//4 -> L//8 (floor at each step).
        time_steps = seq_len // 8
        if time_steps < 1:
            raise ValueError(f"seq_len must be >= 8, got {seq_len}")

        self.gru1 = nn.GRU(input_size=block_channels, hidden_size=32, batch_first=True)
        self.gru2 = nn.GRU(input_size=32, hidden_size=32, batch_first=True)
        self.gru3 = nn.GRU(input_size=64, hidden_size=32, batch_first=True)
        self.gru4 = nn.GRU(input_size=96, hidden_size=32, batch_first=True)

        # Learned weighted sum over the time axis (previously hard-coded to 1875).
        self.gru_linear = nn.Linear(time_steps, 1)
        self.flattern = nn.Flatten()  # (sic) attribute name kept for compatibility
        self.fcl = nn.Linear(32, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)              # (batch, 96, T)

        x = x.permute(0, 2, 1)          # (batch, T, 96) for batch_first GRUs

        gru_out1, _ = self.gru1(x)      # (batch, T, 32)
        gru_out2, _ = self.gru2(gru_out1)
        gru_out = torch.cat([gru_out1, gru_out2], dim=2)             # (batch, T, 64)

        gru_out3, _ = self.gru3(gru_out)
        gru_out = torch.cat([gru_out1, gru_out2, gru_out3], dim=2)   # (batch, T, 96)
        # Collapse time: (batch, 96, T) -> (batch, 96, 1).
        linear_out = self.relu(self.gru_linear(gru_out.permute(0, 2, 1)))

        gru_out4, _ = self.gru4(linear_out.permute(0, 2, 1))         # (batch, 1, 32)

        x = self.flattern(gru_out4)     # (batch, 32)
        x = self.fcl(x)                 # (batch, 1)

        return x

In [53]:
model = ChronoNet()
out = model(input)
# The head is Linear(32, 1): expect one scalar score per example.
assert out.shape == (input.shape[0], 1), out.shape
out.shape