In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [51]:
"""# Building Encoder"""

class Encoder(nn.Module):
  def __init__(self,input_dim,embed_dim,hidden_dim,segment_dim,n_layers,dropout,segment_threshold,device):
    super().__init__()
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.segment_threshold = segment_threshold
    self.segment_dim = segment_dim
    self.device = device
    
    self.embedding = nn.Embedding(input_dim,embed_dim)
    self.rnn = nn.GRU(embed_dim,hidden_dim,n_layers,dropout=dropout,bidirectional=True)

    self.segmentRnn = nn.GRU(hidden_dim*2,segment_dim,n_layers,dropout=dropout)
    # self.fc = nn.Linear(hidden_dim*2,hidden_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self,input):

    #input = [src len, batch size]
    embedded = self.dropout(self.embedding(input))
    #embedded = [src len, batch size, emb dim]

    outputs, hidden = self.rnn(embedded)
    #outputs = [src len, batch size, hid dim * num directions]
    #hidden = [n layers * num directions, batch size, hid dim]
        
    segment_encoding, hidden = self.segment_rnn(outputs)
    #segment_encoding = [src len* (src len+1)/2, batch size, segment_dim*num_directions]
    #hidden = [n layers * num_directions, batch size, hid dim]

    # hidden = torch.tanh(self.fc(torch.cat((hidden[-2],hidden[-1]),dim=1)))

    return segment_encoding,hidden

  def segment_rnn(self,outputs):
    N = outputs.shape[0]
    print('outputs :', outputs.shape)
    batch_size = outputs.shape[1]
    dp_forward = torch.zeros(N, N, batch_size, self.segment_dim).to(self.device)
    dp_backward = torch.zeros(N, N, batch_size, self.segment_dim).to(self.device)
    print('dp_forward :', dp_forward.shape)
    print('dp_backward :', dp_backward.shape)

    for i in range(N):
      hidden_forward = torch.randn(self.n_layers, batch_size, self.hidden_dim).to(self.device)
      print('hidden_forward :', hidden_forward.shape)
      print('f condition :', i,',', min(N, i + self.segment_threshold))
      for j in range(i, min(N, i + self.segment_threshold)):
        
        # outputs[j] = [batch size, hidden_dim* num_direction]
        next_input = outputs[j].unsqueeze(0)
        print('f next_input :', next_input.shape)
        # next_input = [1, batch size, hidden_dim* num_direction]
        
        out, hidden_forward = self.segmentRnn(next_input,hidden_forward)
        print('f out :', out.shape)
        #out = [1, batch size, segment_dim]
        #hidden_forward = [n layers , batch size, hid dim]
        print('f out squeeze:', out.squeeze(0).shape)

        dp_forward[i][j] = out.squeeze(0)

    for i in range(N):
      hidden_backward = torch.randn(self.n_layers, batch_size, self.hidden_dim).to(self.device)
      print('hidden_backward :', hidden_backward.shape)
      print('b condition :',i,',', max(-1, i - self.segment_threshold))
      for j in range(i, max(-1, i - self.segment_threshold), -1):

        # outputs[j] = [batch size, hidden_dim* num_direction]
        next_input = outputs[j].unsqueeze(0)
        # next_input = [1, batch size, hidden_dim* num_direction]
        
        out, hidden_backward = self.segmentRnn(next_input,hidden_backward)
        #out = [1, batch size, segment_dim]
        #hidden_backward = [n layers , batch size, hid dim]
        
        dp_backward[j][i] = out.squeeze(0)
    
    dp = torch.cat((dp_forward,dp_backward),dim=3)
    dp = dp[torch.triu(torch.ones(N, N)) == 1]
    return dp,torch.cat((hidden_forward,hidden_backward),dim=2)

In [52]:
input_dim = 64
embed_dim = 128
hidden_dim = 128
segment_dim = 128
n_layers = 2
dropout = 0.4
segment_threshold = 4
device = 'cpu'

In [53]:
enc = Encoder(input_dim, embed_dim, hidden_dim, segment_dim, n_layers, dropout, segment_threshold, device)

In [54]:
enc

Encoder(
  (embedding): Embedding(64, 128)
  (rnn): GRU(128, 128, num_layers=2, dropout=0.4, bidirectional=True)
  (segmentRnn): GRU(256, 128, num_layers=2, dropout=0.4)
  (dropout): Dropout(p=0.4, inplace=False)
)

In [55]:
inputs = np.random.randint(50, size=(4,96))
inputs = torch.from_numpy(inputs)

In [56]:
enc(inputs)

outputs : torch.Size([4, 96, 256])
dp_forward : torch.Size([4, 4, 96, 128])
dp_backward : torch.Size([4, 4, 96, 128])
hidden_forward : torch.Size([2, 96, 128])
f condition : 0 , 4
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
hidden_forward : torch.Size([2, 96, 128])
f condition : 1 , 4
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size([96, 128])
f next_input : torch.Size([1, 96, 256])
f out : torch.Size([1, 96, 128])
f out squeeze: torch.Size

(tensor([[[ 1.6903, -1.0416,  0.1952,  ..., -0.2684,  0.8140,  0.0920],
          [-0.1308, -0.1469,  0.3786,  ...,  0.1485, -0.4135, -0.0331],
          [-0.3406, -0.1709, -0.5119,  ...,  0.1385, -0.2987,  0.0240],
          ...,
          [-0.1609,  0.6451,  0.2526,  ..., -0.2893, -0.2642, -0.0646],
          [-0.5268, -0.5851, -0.6335,  ...,  1.7021,  0.0222, -0.8598],
          [ 0.9743, -0.1840, -0.6940,  ...,  0.3823,  1.1701, -0.5797]],
 
         [[ 1.1594, -0.5827,  0.0364,  ..., -0.3593, -0.0909, -0.3255],
          [-0.0824,  0.0204,  0.2786,  ..., -0.2978, -0.4635, -0.4056],
          [-0.1682, -0.0154, -0.0692,  ..., -0.2043, -0.5647,  0.1498],
          ...,
          [-0.0562,  0.2657,  0.1794,  ..., -0.0520, -0.4269,  0.3653],
          [-0.3715, -0.1650, -0.5096,  ..., -0.4506,  0.1396, -0.0742],
          [ 0.7220, -0.2693, -0.3038,  ..., -0.2139, -0.8202, -0.0259]],
 
         [[ 0.5950, -0.3648,  0.0765,  ...,  0.0560, -0.3390, -0.1510],
          [ 0.0426,  0.0567,