In [1]:
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import sys
import string
import torch as t
import torch.nn as nn
from torch.autograd import Variable as V
import torch.optim as optim

SEED = 777  # fixed seed so weight initialisation and training curves are reproducible
t.manual_seed(SEED)

<torch._C.Generator at 0x7f3f2c725e30>

#### 1. Input data

In [2]:
#            0,   1,   2,   3,   4
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach hihell -> ihello
x_data = [0, 1, 0, 2, 3, 3]    # hihell
one_hot_lookup = [[1, 0, 0, 0, 0],  # 0
                  [0, 1, 0, 0, 0],  # 1
                  [0, 0, 1, 0, 0],  # 2
                  [0, 0, 0, 1, 0],  # 3
                  [0, 0, 0, 0, 1]]  # 4
y_data = [1, 0, 2, 3, 3, 4]    # ihello
x_one_hot = [one_hot_lookup[x] for x in x_data]

# As we have one batch of samples, we will change them to variables only once
inputs = V(t.Tensor(x_one_hot))
labels = V(t.LongTensor(y_data))

print("inputs size: ", inputs.size())
print("labels size: ", labels.size())


inputs size:  torch.Size([6, 5])
labels size:  torch.Size([6])


#### 2. Define model

In [3]:
num_classes = 5
input_size = 5   # one-hot size
hidden_size = 5  # output from the RNN. 5 to directly predict one-hot
batch_size = 1   # one sentence
sequence_length = 1 # One by one
num_layers = 1      # one-layer rnn

In [4]:
class Model(nn.Module):
    """Single-layer character RNN that is fed one character per forward call.

    Relies on the module-level constants input_size, hidden_size, num_layers,
    batch_size, sequence_length and num_classes defined earlier in the notebook.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.rnn = nn.RNN(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True)

    def forward(self, x, hidden):
        """Run one RNN step.

        Args:
            x: one input sample; reshaped to (batch_size, sequence_length, input_size).
            hidden: hidden state of shape (num_layers * num_directions, batch, hidden_size).
                batch_first=True only affects the input/output tensors, not the
                hidden state.

        Returns:
            (out, hidden): ``out`` flattened to (-1, num_classes), which requires
            hidden_size == num_classes, and the updated hidden state.
        """
        # Reshape input (batch first): (batch, seq_len, input_size)
        x = x.view(batch_size, sequence_length, input_size)

        # Propagate input through RNN; out is (batch, seq_len, hidden_size)
        out, hidden = self.rnn(x, hidden)
        return out.view(-1, num_classes), hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state on the same device as the model.

        The shape is (num_layers * num_directions, batch, hidden_size). nn.RNN
        expects this layout even when batch_first=True. A vanilla RNN has only a
        hidden state; there is no cell state, since that belongs to LSTMs.
        """
        h0 = t.zeros(num_layers, batch_size, hidden_size)
        # Follow the model's device instead of calling .cuda() unconditionally,
        # which fails on machines without a GPU.
        if next(self.parameters()).is_cuda:
            h0 = h0.cuda()
        return V(h0)

#### 3. Train model

In [5]:
# Instantiate RNN model
model = Model()
print(model)
use_cuda = t.cuda.is_available()
if use_cuda:
    model.cuda()

# Set loss and optimizer function
# CrossEntropyLoss = LogSoftmax + NLLLoss
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# The training data never changes, so move it to the GPU once instead of on
# every step of every epoch. New names leave the global `inputs`/`labels`
# untouched, so re-running this cell is safe.
if use_cuda:
    train_inputs, train_labels = inputs.cuda(), labels.cuda()
else:
    train_inputs, train_labels = inputs, labels

# Train the model
for epoch in range(100):
    optimizer.zero_grad()
    loss = 0
    hidden = model.init_hidden()

    sys.stdout.write("predicted string: ")

    # Feed one character per step, carrying the hidden state forward
    for x_step, y_step in zip(train_inputs, train_labels):
        output, hidden = model(x_step, hidden)
        _, idx = output.max(1)
        # .item() turns a one-element tensor into a Python number; indexing
        # with .data[0] raises IndexError on 0-dim tensors in current PyTorch.
        sys.stdout.write(idx2char[idx.item()])
        # CrossEntropyLoss needs a (N,) target to match the (N, C) output;
        # iterating a 1-D LongTensor can yield 0-dim elements.
        loss += criterion(output, y_step.view(1))
    print("Epoch: %d, Loss: %1.3f" % (epoch + 1, loss.item()))

    loss.backward()
    optimizer.step()

print("Learning finished!!!")


Model(
  (rnn): RNN(5, 5, batch_first=True)
)
predicted string: llllllEpoch: 1, Loss: 10.155
predicted string: llllllEpoch: 2, Loss: 9.995
predicted string: llllllEpoch: 3, Loss: 9.843
predicted string: llllllEpoch: 4, Loss: 9.702
predicted string: llllllEpoch: 5, Loss: 9.571
predicted string: llllllEpoch: 6, Loss: 9.452
predicted string: llllllEpoch: 7, Loss: 9.342
predicted string: llllllEpoch: 8, Loss: 9.238
predicted string: llllllEpoch: 9, Loss: 9.136
predicted string: llllllEpoch: 10, Loss: 9.033
predicted string: llllllEpoch: 11, Loss: 8.928
predicted string: llllllEpoch: 12, Loss: 8.822
predicted string: llllllEpoch: 13, Loss: 8.715
predicted string: llllllEpoch: 14, Loss: 8.607
predicted string: llllllEpoch: 15, Loss: 8.498
predicted string: llllllEpoch: 16, Loss: 8.389
predicted string: llllllEpoch: 17, Loss: 8.281
predicted string: llllllEpoch: 18, Loss: 8.172
predicted string: llllllEpoch: 19, Loss: 8.063
predicted string: llllllEpoch: 20, Loss: 7.954
predicted string: llll