In [3]:
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import time
import math
import numpy as np

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import torch.utils.data as Data
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from pyfile.text_loader import TextDataset
import pyfile.seq2seq_models as sm
from pyfile.seq2seq_models import str2tensor, EOS_token, SOS_token

In [8]:
# Hyper-parameters and constants shared by the cells below.
HIDDEN_SIZE = 100   # passed to both EncoderRNN and DecoderRNN as the hidden size
N_LAYERS = 1        # passed to both EncoderRNN and DecoderRNN as the layer count
BATCH_SIZE = 1      # the training loop only consumes element [0] of each batch
N_EPOCH = 100       # number of passes over the training DataLoader
N_CHARS = 128 # ASCII

# Show the sentinel tokens imported from pyfile.seq2seq_models.
print("EOS_token: ", EOS_token)   # 1
print("SOS_token: ", SOS_token)   # presumably SOS_token = chr(0) (prints as a non-visible character) -- confirm in seq2seq_models


EOS_token:  1
SOS_token:   


In [9]:
# Train for a given src and target
def train(src, target):
    """Run one teacher-forced optimisation step on a single (src, target) pair.

    Relies on the module-level ``encoder``, ``decoder``, ``criterion`` and
    ``optimizer`` objects, which must be defined before this is called.

    Args:
        src: Source string; converted with ``str2tensor``.
        target: Target string; converted with ``str2tensor(..., eos=True)``
            so the EOS token is appended.

    Returns:
        The summed per-step loss divided by ``len(target_var)`` (i.e. the
        average loss per decoded position, EOS included).
    """
    src_var = str2tensor(src)
    target_var = str2tensor(target, eos=True)    # Add the EOS token

    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(src_var, encoder_hidden)

    # The decoder starts from the encoder's final hidden state.
    hidden = encoder_hidden
    loss = 0

    for c in range(len(target_var)):
        # First, we feed SOS
        # others, use teacher forcing (feed the previous ground-truth token)
        token = target_var[c - 1] if c else str2tensor(SOS_token)
        output, hidden = decoder(token, hidden)
        # Compare against the encoded target (target_var), not the raw string:
        # `target[c]` is a Python character and is also one element shorter
        # than target_var because of the appended EOS.
        loss += criterion(output, target_var[c])

    # Clear stale gradients on both networks before back-propagating.
    encoder.zero_grad()
    decoder.zero_grad()
    loss.backward()
    optimizer.step()

    # `.item()` is the scalar accessor on PyTorch >= 0.4, where `.data[0]`
    # fails on 0-dim tensors; fall back to `.data[0]` on older releases.
    loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
    return loss_value / len(target_var)


In [10]:
# Simple test to show how our network works
def test():
    """Smoke-test the encoder/decoder pair and print what they produce.

    Encodes the word 'hello', prints the encoder outputs, then feeds each
    element of the encoded word 'pytorch' to the decoder in turn (starting
    from the encoder's hidden state) and prints the size of the decoder
    output and hidden state at every step.

    Uses the module-level ``encoder`` and ``decoder``.
    """
    hidden = encoder.init_hidden()
    source = str2tensor('hello')
    enc_out, hidden = encoder(source, hidden)
    print("encoder outputs: ", enc_out)

    # From here on `hidden` carries the decoder state, seeded by the encoder.
    target_seq = str2tensor('pytorch')
    n_steps = len(target_seq)
    step = 0
    while step < n_steps:
        dec_out, hidden = decoder(target_seq[step], hidden)
        print("decoder output size: ", dec_out.size())
        print("decoder hidden size: ", hidden.size())
        step += 1



In [11]:
# Translate the given input
def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature=0.9):
    """Generate an output string for ``enc_input`` by sampling from the decoder.

    Uses the module-level ``encoder`` and ``decoder``.

    Args:
        enc_input: Source string to encode.
        predict_len: Maximum number of characters to generate.
        temperature: Divisor applied to the decoder output before
            exponentiation; smaller values sharpen the sampling distribution.

    Returns:
        A tuple ``(enc_input, predicted)`` where ``predicted`` is the
        generated string (without the EOS token).
    """
    input_var = str2tensor(enc_input)

    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_var, encoder_hidden)

    # The decoder starts from the encoder's final hidden state.
    hidden = encoder_hidden

    predicted = ''
    dec_input = str2tensor(SOS_token)
    for _ in range(predict_len):
        output, hidden = decoder(dec_input, hidden)

        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        # Convert to a plain int so the EOS comparison and chr() behave the
        # same whether indexing yields a Python number or a 0-dim tensor.
        top_i = int(t.multinomial(output_dist, 1)[0])

        # Stop at the EOS (value comparison; `is` tests object identity and
        # is not a reliable way to compare numbers)
        if top_i == EOS_token:
            break

        predicted_char = chr(top_i)
        predicted += predicted_char

        # Feed the sampled character back in as the next decoder input.
        dec_input = str2tensor(predicted_char)

    return enc_input, predicted

In [None]:
# main
# Builds the models, optimiser and data loader, then runs the training loop.
# `encoder`, `decoder`, `optimizer` and `criterion` are module-level names that
# train(), test() and translate() read as globals, so this cell must run after
# those functions are defined and before any of them is called.

encoder = sm.EncoderRNN(N_CHARS, HIDDEN_SIZE, N_LAYERS)
decoder = sm.DecoderRNN(HIDDEN_SIZE, N_CHARS, N_LAYERS)

# Move both networks to the GPU when one is available.
if t.cuda.is_available():
    encoder.cuda()
    decoder.cuda()
print("encoder: ", encoder)
print("decoder: ", decoder)
# Quick smoke test of the untrained networks (prints outputs and sizes).
test()


# Optimizer and Loss
# A single Adam optimiser updates the parameters of both networks.
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=0.001)
criterion = nn.CrossEntropyLoss()

train_dataloader = Data.DataLoader(dataset=TextDataset(),
                                   batch_size=BATCH_SIZE,
                                   shuffle=True,
                                   num_workers=2)
print("Training for %d epochs..." % N_EPOCH)

for epoch in range(1, N_EPOCH + 1):
    for i, (srcs, targets) in enumerate(train_dataloader):
        # Only the first element of each batch is used.
        train_loss = train(srcs[0], targets[0])      # Batch is 1
        
        # Every 100 iterations, report progress and show sample translations.
        if i % 100 == 0:
            # `epoch / N_EPOCH` is true division (from __future__ import division).
            print("Epoch: (%d %d%%) Loss: %.4f" % 
                  (epoch, epoch / N_EPOCH * 100, train_loss))
            # Translate the current training source, then the default sentence.
            print(translate(srcs[0]), '\n')
            print(translate(), '\n')