In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

**HyperParameters**

In [None]:
# Number of samples per mini-batch; passed as batch_size to the training DataLoader.
train_batch_size = 128
# Fixed sequence length used by collate_fn: inputs are truncated/padded to
# max_len tokens and targets to max_len + 1 tokens.
max_len = 9

**Num_Sequence**

In [None]:
class Num_sequence:
  '''Vocabulary mapping digit characters and special tokens to integer ids and back.

  Ids 0-3 are reserved for the special tokens (UNK, PAD, SOS, EOS); the digit
  characters '0'-'9' take ids 4-13.
  '''
  UNK_TAG = 'UNK' # unknown token
  PAD_TAG = 'PAD' # padding token
  SOS_TAG = 'SOS' # start of sequence
  EOS_TAG = 'EOS' # end of sequence

  UNK = 0
  PAD = 1
  SOS = 2
  EOS = 3

  def __init__(self) -> None:
    '''Build the token -> id table (self.dict) and its inverse (self.inverse_dict).'''
    self.dict = {
        self.PAD_TAG : self.PAD,
        self.UNK_TAG : self.UNK,
        self.SOS_TAG : self.SOS,
        self.EOS_TAG : self.EOS,
         }

    # Digits are appended after the special tokens, so '0' -> 4, ..., '9' -> 13.
    for i in range(10):
      self.dict[str(i)] = len(self.dict)

    self.inverse_dict = {index: token for token, index in self.dict.items()}

  def transform(self, sentence, max_len, add_eos=False):
    '''Convert a sequence of tokens into a fixed-length list of ids.

    The sequence is truncated to max_len tokens, optionally followed by EOS,
    and then right-padded with PAD. Tokens missing from the vocabulary map
    to UNK. No SOS token is added.

    :param sentence: iterable of tokens, e.g. ['1', '2', '3']
    :param max_len: number of tokens kept/padded to (excluding the optional EOS)
    :param add_eos: True: result length = max_len + 1 (EOS included)
                    False: result length = max_len
    :return: list of integer ids
    '''
    # Normalise to a list so strings/tuples can be concatenated with the
    # EOS/PAD lists below, and cut anything longer than max_len.
    sentence = list(sentence)[:max_len]

    # The length must be measured before EOS is appended so that the amount of
    # padding does not depend on add_eos.
    sentence_len = len(sentence)

    if add_eos:
      sentence = sentence + [self.EOS_TAG]

    if sentence_len < max_len: # add padding if sentence < max_len
      sentence = sentence + [self.PAD_TAG]*(max_len-sentence_len)

    return [self.dict.get(token, self.UNK) for token in sentence]

  def inverse_transform(self, indices):
    '''Convert a sequence of ids back into a list of tokens.

    :param indices: iterable of integer ids
    :return: list of tokens; ids missing from the vocabulary map to UNK_TAG
    '''
    return [self.inverse_dict.get(i, self.UNK_TAG) for i in indices]

In [None]:
# Shared vocabulary instance; collate_fn reads it as a module-level global.
num_sequence = Num_sequence()

In [None]:
# Display the token -> id mapping to check the vocabulary.
num_sequence.dict

{'0': 4,
 '1': 5,
 '2': 6,
 '3': 7,
 '4': 8,
 '5': 9,
 '6': 10,
 '7': 11,
 '8': 12,
 '9': 13,
 'EOS': 3,
 'PAD': 1,
 'SOS': 2,
 'UNK': 0}

**Dataset**  
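Prepare the dataset and the data loader. Each sample is the digit sequence of a random integer, and its label is the same digits with a `'0'` appended.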

In [None]:
class NumDataset(Dataset):
  '''Synthetic dataset of random non-negative integers rendered as digit lists.

  The label of every sample is the input digits followed by a '0'.
  '''

  def __init__(self, size=500000, high=10**8, seed=None):
    '''
    :param size: number of random integers to generate
    :param high: exclusive upper bound of the generated integers
    :param seed: optional seed for reproducible data; None draws from
                 numpy's global random state (the previous behaviour)
    '''
    # A private RandomState is used only when a seed is given, so unseeded
    # construction still follows np.random.seed() calls made by the notebook.
    rng = np.random if seed is None else np.random.RandomState(seed)
    self.data = rng.randint(0, high, size=[size])

  def __getitem__(self, index):
    '''Return (input, label, input_length, label_length) for one sample.

    input and label are lists of single-character digit strings.
    '''
    digits = list(str(self.data[index]))
    label = digits + ['0']
    return digits, label, len(digits), len(label)

  def __len__(self):
    '''Number of samples in the dataset.'''
    return len(self.data)

In [None]:
def collate_fn(batch):
  '''Turn a list of dataset samples into padded id tensors.

  :param batch: [(input, label, input_length, label_length), ...]
  :return: (input, target, input_length, target_length) as LongTensors;
           samples are ordered by label_length, longest first. input rows are
           max_len ids long, target rows max_len + 1; the two length tensors
           carry the values from the samples unchanged.
  '''
  # Order the samples from the longest label to the shortest.
  ordered = sorted(batch, key=lambda sample: sample[3], reverse=True)
  raw_inputs, raw_targets, input_lengths, target_lengths = zip(*ordered)

  # Encode tokens as ids; the target gets one extra position.
  encoded_inputs = [num_sequence.transform(seq, max_len=max_len) for seq in raw_inputs]
  encoded_targets = [num_sequence.transform(seq, max_len=max_len+1) for seq in raw_targets]

  return (
      torch.LongTensor(encoded_inputs),
      torch.LongTensor(encoded_targets),
      torch.LongTensor(input_lengths),
      torch.LongTensor(target_lengths),
  )

In [None]:
# Build the synthetic dataset, then batch it in shuffled order with the custom collate function.
data_set = NumDataset()
train_data_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Preview only the first batch: the four tensors are printed in order
# (input, target, input_length, target_length) with a row of asterisks between them.
for input, target, input_length, target_length in train_data_loader:
  print(input, target, input_length, target_length, sep='\n' + '*'*10 + '\n')
  break

tensor([[ 6,  6,  7,  ..., 10,  9,  1],
        [ 6,  4,  7,  ...,  6,  8,  1],
        [ 5, 13,  6,  ...,  6, 13,  1],
        ...,
        [ 5, 13,  7,  ...,  9,  1,  1],
        [ 6,  8,  6,  ..., 13,  1,  1],
        [13,  5,  7,  ...,  1,  1,  1]])
**********
tensor([[ 6,  6,  7,  ...,  9,  4,  1],
        [ 6,  4,  7,  ...,  8,  4,  1],
        [ 5, 13,  6,  ..., 13,  4,  1],
        ...,
        [ 5, 13,  7,  ...,  4,  1,  1],
        [ 6,  8,  6,  ...,  4,  1,  1],
        [13,  5,  7,  ...,  1,  1,  1]])
**********
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 5])
**********
tensor([9, 9, 9, 9, 9, 9, 