In [9]:
import torch
import torch.nn as nn
import numpy as np
import glob
import string
import csv
import os
from matplotlib import pyplot as plt
import sys

# various helper functions
from torch_name_classifier_helpers import readLines
from torch_name_classifier_helpers import categoryFromOutput
from torch_name_classifier_helpers import textToTensor

# declare RNN
class RNN(nn.Module):
    """Single-layer character RNN that emits log-probabilities per step.

    At every step the current input and the previous hidden state are
    concatenated along dim 1. The result is fed to two linear layers: one
    produces the next hidden state, the other produces class scores, which
    are passed through LogSoftmax over dim 1.

    NOTE(review): the hidden update is purely linear (no activation is applied
    to ``i2h``'s output). Changing this would alter the behaviour of existing
    trained checkpoints.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        joint_width = input_size + hidden_size
        # Registration order (i2h, then i2o, then softmax) determines both the
        # order of parameter initialisation and the state_dict key order.
        self.i2h = nn.Linear(joint_width, hidden_size)
        self.i2o = nn.Linear(joint_width, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_layer, hidden_layer):
        """Advance one step.

        Returns ``(log_probs, next_hidden)``. Both arguments must agree on
        dim 0; presumably each is a single row, (1, input_size) and
        (1, hidden_size), as produced by ``initHidden`` (TODO confirm for
        inputs).
        """
        joint = torch.cat([input_layer, hidden_layer], dim=1)
        next_hidden = self.i2h(joint)
        log_probs = self.softmax(self.i2o(joint))
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape (1, hidden_size)."""
        return torch.zeros(1, self.hidden_size)

def predict(the_rnn, line_tensor):
    """Feed ``line_tensor`` through ``the_rnn`` step by step; return the last output.

    Args:
        the_rnn: model exposing ``initHidden()`` and callable as
            ``the_rnn(step_input, hidden) -> (output, hidden)`` (e.g. ``RNN``).
        line_tensor: tensor whose first dimension indexes the sequence steps;
            ``line_tensor[i]`` is the input for step ``i``. Presumably shaped
            (seq_len, 1, n_letters) as built by ``textToTensor`` (TODO confirm).

    Returns:
        The output produced by ``the_rnn`` at the final step.

    Raises:
        ValueError: if ``line_tensor`` has no steps (first dimension is 0),
            since there is then no output to return.
    """
    if line_tensor.size(0) == 0:
        raise ValueError("predict() requires a line_tensor with at least one step")

    hidden_layer = the_rnn.initHidden()
    output_layer = None
    # Iterating a tensor walks its first dimension, i.e. line_tensor[0], [1], ...
    for step_input in line_tensor:
        output_layer, hidden_layer = the_rnn(step_input, hidden_layer)
    return output_layer
        

def main():
    """Predict the language of origin for a fixed list of surnames.

    Steps:
      1. Collect the per-language name files matching ``data/names/*.txt``.
         Each file's basename (without extension) is a category.
      2. Write the category list to ``all_categories.csv`` as a single row.
      3. Build an ``RNN`` sized for the alphabet and the category count.
      4. Restore weights from ``mnist_names_model.pkl`` if that file exists.
      5. Print the predicted origin of each sample name.

    Exits the process with status -1 if no name files are found.
    """
    # glob pattern (not a regex) for the per-language name files
    fnames = 'data/names/*.txt'

    # characters that may occur in a name; also defines the RNN input width
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)

    # category_lines: language -> list of names read from that language's file.
    # all_categories: languages in discovery order; position == model output index.
    # NOTE(review): glob.glob returns files in filesystem-dependent, unsorted
    # order. This order must match the one used when the checkpoint was trained;
    # confirm against the training script before sorting it here.
    # NOTE(review): category_lines is not used further in this function.
    category_lines = {}
    all_categories = []
    for filename in glob.glob(fnames):
        # the basename of the file is the language
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        category_lines[category] = readLines(filename, all_letters)
    if not all_categories:
        print("No files found for glob pattern (" + fnames + ")", file=sys.stderr)
        sys.exit(-1)

    # number of languages (i.e. classes)
    n_categories = len(all_categories)

    # Write the categories as one CSV row. newline='' is required by the csv
    # module so its own line terminators are not translated (no blank rows on Windows).
    with open('all_categories.csv', 'w', newline='') as f:
        csv.writer(f).writerow(all_categories)

    # create the RNN
    n_input_neurons = n_letters
    n_hidden_neurons = 256
    n_output_neurons = n_categories
    MyRNN = RNN(n_input_neurons, n_hidden_neurons, n_output_neurons)

    # load checkpoint, if available
    checkpoint_file = 'mnist_names_model.pkl'
    if os.path.isfile(checkpoint_file):
        print("Resuming from checkpoint (" + checkpoint_file + ")")
        # map_location='cpu': the model lives on the CPU, and this lets a
        # checkpoint saved from GPU tensors load on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        MyRNN.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))
    else:
        print("No checkpoint found (" + checkpoint_file + "); predictions use untrained weights")
    MyRNN.eval()

    names = ["Popescu", "Fernandez", "Velenzuela", "Lovecraft", "Chambers",
             "Davies", "Paltrowski", "Sargiannis", "Ovechkin", "Fapp"]
    # inference only: skip building autograd graphs
    with torch.no_grad():
        for name in names:
            line_tensor = textToTensor(name, all_letters)
            output = predict(MyRNN, line_tensor)
            guess, _ = categoryFromOutput(output, all_categories)
            print("name = {:s}, origin = {:s}".format(name, guess))

# Run only when executed as a script or notebook cell (where __name__ is
# "__main__"), not when this module is imported.
if __name__ == "__main__":
    main()



Resuming from checkpoint (mnist_names_model.pkl)
name = Popescu, origin = Czech
name = Fernandez, origin = Portuguese
name = Velenzuela, origin = German
name = Lovecraft, origin = French
name = Chambers, origin = English
name = Davies, origin = Portuguese
name = Paltrowski, origin = Polish
name = Sargiannis, origin = Greek
name = Ovechkin, origin = Russian
name = Fapp, origin = German
