In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

**HyperParameters**

In [None]:
# Data pipeline
max_len = 9              # length (in tokens) that input sequences are cut/padded to
train_batch_size = 128   # number of samples per training batch

# Word embedding
embedding_dim = 100      # size of each token embedding vector

# GRU
hidden_size = 64         # width of the GRU hidden state
num_layers = 1           # number of stacked GRU layers

**Num_Sequence**

In [None]:
class Num_sequence:
  '''Vocabulary mapping digit characters and special tags to integer ids and back.

  Ids 0-3 are reserved for UNK, PAD, SOS and EOS; the digits '0'..'9' take ids 4..13.
  '''
  UNK_TAG = 'UNK' # unknown token
  PAD_TAG = 'PAD' # padding
  SOS_TAG = 'SOS' # start of sequence
  EOS_TAG = 'EOS' # end of sequence

  UNK = 0
  PAD = 1
  SOS = 2
  EOS = 3

  def __init__(self) -> None:
    # token -> id
    self.dict = {
        self.PAD_TAG : self.PAD,
        self.UNK_TAG : self.UNK,
        self.SOS_TAG : self.SOS,
        self.EOS_TAG : self.EOS,
         }

    # digits get the next free ids, in order
    for i in range(10):
      self.dict[str(i)] = len(self.dict)

    # id -> token
    self.inverse_dict = {index: token for token, index in self.dict.items()}

  def transform(self, sentence, max_len, add_eos=False):
    '''sentence 2 number

    :param sentence: iterable of tokens (e.g. a list of digit characters)
    :param max_len: number of tokens kept/padded to, not counting EOS
    :param add_eos: True: result length = max_len + 1 (EOS is placed right after the tokens, before any padding)
                    False: result length = max_len
    :return: list of integer ids; tokens missing from the vocabulary map to UNK
    '''
    sentence = list(sentence)  # accept any iterable and never touch the caller's object

    if len(sentence) > max_len: # cut if sentence > max_len
      sentence = sentence[:max_len]

    # the length must be taken before EOS is appended, so padding is computed on real tokens only
    sentence_len = len(sentence)

    if add_eos:
      sentence = sentence + [self.EOS_TAG]

    if sentence_len < max_len: # add padding if sentence < max_len
      sentence = sentence + [self.PAD_TAG]*(max_len-sentence_len)

    return [self.dict.get(i, self.UNK) for i in sentence]

  def inverse_transform(self, indices):
    '''seq 2 sentence

    :param indices: iterable of integer ids
    :return: list of tokens; ids missing from the vocabulary map to UNK_TAG
    '''
    return [self.inverse_dict.get(i, self.UNK_TAG) for i in indices]

  def __len__(self):
    '''Vocabulary size (special tags + digits).'''
    return len(self.dict)

In [None]:
# Shared vocabulary instance used by collate_fn and by the Encoder/Decoder embedding layers
num_sequence = Num_sequence()

**Dataset**  
Prepare dataset and dataloader

1. In the targets of the samples, SOS and EOS are needed to mark the start and the end of the sequence.  
2. Add EOS to the target when transforming it into indices.  

In [None]:
class NumDataset(Dataset):
  '''Synthetic dataset of random non-negative integers.

  A sample's input is the list of decimal digit characters of one integer;
  its label is the same list with an extra '0' appended.
  '''

  def __init__(self, size=500000, high=10**8, seed=None):
    '''
    :param size: number of random integers to generate
    :param high: exclusive upper bound of the generated integers
    :param seed: optional seed for a private generator; when None the global
                 numpy random state is used (so np.random.seed still applies)
    '''
    rng = np.random if seed is None else np.random.RandomState(seed)
    # generate random number with numpy (integer bound instead of the float 1e8)
    self.data = rng.randint(0, int(high), size=[size])

  def __getitem__(self, index):
    '''
    :return: (input, label, input_length, label_length) where input and label are lists of single-character strings
    '''
    digits = list(str(self.data[index]))
    label = digits + ['0']
    return digits, label, len(digits), len(label)

  def __len__(self):
    return len(self.data)

In [None]:
def collate_fn(batch):
  '''
  Turn a list of raw samples into index tensors, longest input first.

  :param batch: [(input, label, input_length, label_length), (input, label, input_length, label_length)]
  :return: (input, target, input_length, target_length)
    input: LongTensor [batch_size, max_len], padded with PAD
    target: LongTensor [batch_size, max_len + 2]; label ids, then EOS, then PAD
    input_length: LongTensor [batch_size], in descending order
    target_length: LongTensor [batch_size], label lengths as stored in the samples (EOS not counted)
  '''

  # The encoder packs the batch by input length, so that is the key to sort on
  batch = sorted(batch, key=lambda sample: sample[2], reverse=True) # big -> small

  inputs, targets, input_lengths, target_lengths = zip(*batch)

  inputs = [num_sequence.transform(seq, max_len=max_len) for seq in inputs]
  # Targets carry EOS so the decoder can learn where a sequence stops
  targets = [num_sequence.transform(seq, max_len=max_len+1, add_eos=True) for seq in targets]

  return (
      torch.LongTensor(inputs),
      torch.LongTensor(targets),
      torch.LongTensor(input_lengths),
      torch.LongTensor(target_lengths),
  )

In [None]:
# Build the synthetic dataset once, then batch it through collate_fn
data_set = NumDataset()
train_data_loader = DataLoader(
    dataset=data_set,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

**Encoder**

Before using GRU, there are two APIs for accelerating the calculation.  
1. pad_packed_sequence(out, batch_first, padding_value) *unpack*
2. pack_padded_sequence(embedded, real_length, batch_first) *pack*
3. Before using the two APIs, sort the batch by sequence length in descending order.

In [None]:
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

In [None]:
class Encoder(nn.Module):
  '''Embeds padded index sequences and runs them through a GRU.'''

  def __init__(self):
    super().__init__()
    vocab_size = len(num_sequence)
    # PAD rows of the embedding table are tied to padding_idx
    self.embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=embedding_dim,
        padding_idx=num_sequence.PAD,
    )
    self.gru = nn.GRU(
        input_size=embedding_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
    )

  def forward(self, input, input_length):
    '''
    :param input: [batch_size, max_len] token indices
    :param input_length: real (unpadded) length of every sequence, in descending order
    :return: (output, hidden, output_length)
      output: [batch_size, seq_len, hidden_size]
      hidden: [num_layers, batch_size, hidden_size]
      output_length: lengths recovered when unpacking
    '''
    # [batch_size, max_len] -> [batch_size, max_len, embedding_dim], then pack so padding is skipped
    packed_input = pack_padded_sequence(self.embedding(input), input_length, batch_first=True)

    packed_output, hidden = self.gru(packed_input)

    # unpack back into a padded tensor plus the per-sequence lengths
    output, output_length = pad_packed_sequence(packed_output, batch_first=True)

    return output, hidden, output_length


In [None]:
# Smoke test: push the first batch through a fresh encoder and inspect the shapes
encoder = Encoder()
print(encoder)
input, target, input_length, target_length = next(iter(train_data_loader))
out, hidden, output_length = encoder(input, input_length)
print(out.size())
print(hidden.size())
print(output_length)

Encoder(
  (embedding): Embedding(14, 100, padding_idx=1)
  (gru): GRU(100, 64, batch_first=True)
)
torch.Size([128, 8, 64])
torch.Size([1, 128, 64])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7])


**Decoder**

1. The output of the decoder at each time step is a classification problem over the vocabulary. We choose the token with the highest probability. 
2. The output of the decoder is [batch_size, max_len, vocab_size].
3. Loss function: Cross Entropy

In [None]:
import torch.nn.functional as F

In [None]:
class Decoder(nn.Module):
  '''GRU decoder that greedily generates tokens, starting from the encoder's hidden state.'''

  def __init__(self):
    super().__init__()
    self.embedding = nn.Embedding(num_embeddings=len(num_sequence), embedding_dim=embedding_dim,padding_idx=num_sequence.PAD)
    self.gru = nn.GRU(input_size=embedding_dim,
                      hidden_size=hidden_size,
                      num_layers=num_layers,
                      batch_first=True)
    # projects a GRU output vector onto the vocabulary
    self.fc = nn.Linear(hidden_size, len(num_sequence))

  def forward(self, target, encoder_hidden):
    '''
    Decode max_len + 2 steps, feeding each step's most probable token back in as the next input.

    :param target: [batch_size, seq_len] target indices; only the batch size is read here
    :param encoder_hidden: [num_layers, batch_size, hidden_size] final hidden state of the encoder
    :return: (decoder_outputs, decoder_hidden)
      decoder_outputs: [batch_size, max_len + 2, vocab_size] log-probabilities
      decoder_hidden: [num_layers, batch_size, hidden_size]
    '''
    # 1. The encoder's hidden state is the decoder's initial hidden state
    decoder_hidden = encoder_hidden
    # 2. The first decoder input is SOS for every sample, size [batch_size, 1]
    batch_size = target.size(0)
    decoder_input = torch.full((batch_size, 1), num_sequence.SOS,
                               dtype=torch.int64, device=encoder_hidden.device)

    # 3. At every time step compute the output and the new hidden state,
    #    then use the predicted token as the input of the next step
    step_outputs = []
    for _ in range(max_len + 2):
      decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
      step_outputs.append(decoder_output_t)

      _, index = torch.topk(decoder_output_t, 1)  # index: [batch_size, 1]
      decoder_input = index

    # list of [batch_size, vocab_size] -> [batch_size, max_len + 2, vocab_size]
    decoder_outputs = torch.stack(step_outputs, dim=1)
    return decoder_outputs, decoder_hidden

  def forward_step(self, decoder_input, decoder_hidden):
    '''
    calculate output at each time stamp
    :param decoder_input: [batch_size, 1]
    :param decoder_hidden: [1, batch_size, hidden_size]
    :return: (output, decoder_hidden)
      output: [batch_size, vocab_size] log-probabilities
      decoder_hidden: [1, batch_size, hidden_size] updated hidden state
    '''
    decoder_input_embedded = self.embedding(decoder_input)  # [batch_size, 1, embedding_dim]

    # out: [batch_size, 1, hidden_size] It is 1 because a single time step is fed in (seq_len=1)
    # decoder_hidden: [1, batch_size, hidden_size]
    # The previous hidden state must be passed in, otherwise the GRU starts from zeros every step
    out, decoder_hidden = self.gru(decoder_input_embedded, decoder_hidden)

    out = out.squeeze(1) # [batch_size, hidden_size]
    out = self.fc(out)  # [batch_size, vocab_size]
    output = F.log_softmax(out, dim=-1) # [batch_size, vocab_size]

    return output, decoder_hidden
    

In [None]:
# Smoke test: run the first batch through a fresh encoder/decoder pair
encoder = Encoder()
decoder = Decoder()
print(encoder)
print(decoder)
for input, target, input_length, target_length in train_data_loader:
  # Keep this batch's own lengths and hidden state; printing `hidden`/`output_length`
  # left over from an earlier cell would show stale values (or fail on a fresh kernel)
  out, encoder_hidden, output_length = encoder(input, input_length)
  decoder(target, encoder_hidden)
  print(out.size())
  print(encoder_hidden.size())
  print(output_length)
  break

Encoder(
  (embedding): Embedding(14, 100, padding_idx=1)
  (gru): GRU(100, 64, batch_first=True)
)
Decoder(
  (embedding): Embedding(14, 100, padding_idx=1)
  (gru): GRU(100, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=14, bias=True)
)
torch.Size([128, 8, 64])
torch.Size([1, 128, 64])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7])
