#### 1. Import and clean data

In [10]:
from io import open
import glob
import os

files_path = '../data/names/*.txt'

def find_files(path):
	"""Return the files matching the glob pattern `path`, sorted by path.

	glob.glob yields matches in arbitrary, filesystem-dependent order. The
	category list is later built in the order returned here, so sorting keeps
	the category order (and anything indexed by it) the same on every machine.
	"""
	return sorted(glob.glob(path))

print(find_files(files_path))


['../data/names\\Arabic.txt', '../data/names\\Chinese.txt', '../data/names\\Czech.txt', '../data/names\\Dutch.txt', '../data/names\\English.txt', '../data/names\\French.txt', '../data/names\\German.txt', '../data/names\\Greek.txt', '../data/names\\Irish.txt', '../data/names\\Italian.txt', '../data/names\\Japanese.txt', '../data/names\\Korean.txt', '../data/names\\Polish.txt', '../data/names\\Portuguese.txt', '../data/names\\Russian.txt', '../data/names\\Scottish.txt', '../data/names\\Spanish.txt', '../data/names\\Vietnamese.txt']


In [40]:
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

def unicode_to_ascii(s):
	"""Fold `s` to the characters of `all_letters`.

	The text is NFD-decomposed so accented letters split into a base letter
	plus combining marks (category 'Mn'). The marks are dropped, as is any
	other character not present in `all_letters`.
	"""
	kept = []
	for ch in unicodedata.normalize('NFD', s):
		is_combining_mark = unicodedata.category(ch) == 'Mn'
		if not is_combining_mark and ch in all_letters:
			kept.append(ch)
	return ''.join(kept)

print(unicode_to_ascii('Ślusàrski'))

Slusarski


In [41]:
category_lines = {}
all_categories = []

def read_lines(file_name):
	"""Read a UTF-8 file with one name per line and return the names ASCII-folded.

	Each line is passed through unicode_to_ascii. Names that end up empty
	(blank lines, an empty file, or names made only of characters outside
	all_letters) are dropped, because they would encode to zero-length
	tensors with nothing for the RNN to step over.
	"""
	# `with` closes the file handle; the bare open(...).read() leaked it.
	with open(file_name, encoding='utf-8') as f:
		lines = f.read().strip().split('\n')
	names = (unicode_to_ascii(line) for line in lines)
	return [name for name in names if name]

# One category per file; the category name is the file name without extension.
for file_name in find_files(files_path):
	category = os.path.splitext(os.path.basename(file_name))[0]
	all_categories.append(category)
	lines = read_lines(file_name)
	category_lines[category] = lines

n_categories = len(all_categories)
# Fail fast with a clear message instead of a KeyError on the lookup below.
assert n_categories > 0, f'no data files matched {files_path!r}'

print(category_lines['Arabic'][:5])

['Khoury', 'Nahas', 'Daher', 'Gerges', 'Nazari']


#### 2. Convert names to one-hot tensors

In [51]:
import torch
def letter_index(letter):
	"""Return the position of `letter` in `all_letters`, or -1 if it is absent.

	This uses str.find, so a multi-character or empty argument is matched as a
	substring (e.g. '' -> 0). The tensor helpers below use a checked variant.
	"""
	return all_letters.find(letter)

def _one_hot_index(letter):
	"""Return letter_index(letter), raising ValueError unless `letter` is one known character."""
	index = letter_index(letter)
	# Used as a tensor index, find()'s -1 would silently address the LAST slot
	# (the "'" column), and '' would map to slot 0. Reject both explicitly.
	if len(letter) != 1 or index < 0:
		raise ValueError(
			f"cannot one-hot encode {letter!r}: expected a single character from "
			f"all_letters (clean the text with unicode_to_ascii first)"
		)
	return index

def letter_to_tensor(letter):
	"""One-hot encode a single character as a (1, n_letters) float tensor.

	Raises ValueError if `letter` is not exactly one character of `all_letters`.
	"""
	tensor = torch.zeros(1, n_letters)
	tensor[0][_one_hot_index(letter)] = 1
	return tensor

def word_to_tensor(word):
	"""One-hot encode `word` as a (len(word), 1, n_letters) float tensor.

	Row i holds the one-hot vector for word[i]. The middle size-1 dimension
	makes each step word_tensor[i] a (1, n_letters) tensor, the shape fed to
	the RNN cell. Raises ValueError on any character outside `all_letters`.
	"""
	tensor = torch.zeros(len(word), 1, n_letters)
	for li, letter in enumerate(word):
		tensor[li][0][_one_hot_index(letter)] = 1
	return tensor

print(letter_index('K'))
print(letter_to_tensor('a'))
print(word_to_tensor('abc').size())

36
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])
torch.Size([3, 1, 57])


#### 3. RNN design

In [58]:
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
class RNN(nn.Module):
    """Single-layer Elman-style recurrent cell with a log-softmax readout.

    Each call to forward() advances one time step:
        hidden' = tanh(i2h(input) + h2h(hidden))
        output  = log_softmax(h2o(hidden'), dim=1)
    Inputs are 2-D (rows, features); LogSoftmax is taken over dim 1.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance one step.

        Args:
            input: (rows, input_size) tensor for the current step.
            hidden: (rows, hidden_size) previous hidden state.

        Returns:
            (output, hidden): (rows, output_size) log-probabilities and the
            new (rows, hidden_size) hidden state.
        """
        hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero (batch_size, hidden_size) initial hidden state.

        The tensor is created on the same device and with the same dtype as
        the model's weights, so it still matches after .to(device) or .double().
        """
        weight = self.h2h.weight
        return torch.zeros(batch_size, self.hidden_size,
                           device=weight.device, dtype=weight.dtype)
    
# Hidden-state width; the input and output widths come from the data cells above.
n_hidden = 128
rnn = RNN(
    input_size=n_letters,
    hidden_size=n_hidden,
    output_size=n_categories,
)
print(summary(rnn))

Layer (type:depth-idx)                   Param #
RNN                                      --
├─Linear: 1-1                            7,424
├─Linear: 1-2                            16,512
├─Linear: 1-3                            2,322
├─LogSoftmax: 1-4                        --
Total params: 26,258
Trainable params: 26,258
Non-trainable params: 0


In [68]:
# Sanity check: run one RNN step on the first letter of a sample name.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells (not visible here) may refer to it.
input = word_to_tensor('Attah')
# Use the model's own initializer rather than re-stating its hidden shape.
hidden = rnn.init_hidden()

output, next_hidden = rnn(input[0], hidden)
# NOTE(review): the saved output below looks stale. It has ~100 values, some
# positive, but the model emits 18 log-probabilities (all <= 0). Re-run the
# notebook top to bottom to refresh it.
print(output)

tensor([[ 0.1447, -0.1512, -0.0462,  0.0378, -0.0272,  0.0865, -0.1072,  0.0457,
          0.1263, -0.0115, -0.1232,  0.0299, -0.0021,  0.1349, -0.1821, -0.0454,
          0.0132,  0.1349,  0.1013, -0.2199, -0.2411,  0.0691, -0.0732, -0.0198,
         -0.0259,  0.0060,  0.0868,  0.0205,  0.0364, -0.0270, -0.1076, -0.1061,
          0.2526,  0.0429, -0.0818, -0.0126, -0.1232,  0.0953, -0.0693,  0.0078,
         -0.0837, -0.0430,  0.1228,  0.0030, -0.0246, -0.0420,  0.1188, -0.3048,
          0.1657,  0.0510, -0.0451,  0.1143, -0.1227, -0.0630,  0.0284, -0.1008,
          0.0566, -0.1749,  0.0127, -0.0451,  0.0439, -0.0319, -0.1743, -0.0397,
          0.0778, -0.1822,  0.0654, -0.0158, -0.0979, -0.0595, -0.0740, -0.2435,
          0.0022, -0.0104,  0.0071,  0.1444,  0.1011,  0.0457, -0.0607, -0.0658,
         -0.0062, -0.0786,  0.0271,  0.1777, -0.0894,  0.1612,  0.0932, -0.0561,
         -0.1343,  0.2906,  0.2672,  0.0531,  0.0243,  0.2434,  0.0251,  0.0076,
         -0.2629,  0.2194,  