In [98]:
from io import open
import glob
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

def findFiles(path):
    """Return the list of filesystem paths matching the glob pattern ``path``."""
    matches = glob.glob(path)
    return matches
import unicodedata
import string

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
print(all_letters)
data_path = '../../data/names/*.txt'


def load_names(pattern):
    """Read every file matching ``pattern`` into per-category name lists.

    The category name is the file's base name without its extension.

    Args:
        pattern: glob pattern selecting the UTF-8 text files to read,
            one name per line.

    Returns:
        A ``(categories, lang_lines, vocabs)`` tuple where ``categories`` is
        the list of category names in the order ``glob.glob`` yielded the
        files, ``lang_lines`` maps each category to the lines of *its own*
        file only, and ``vocabs`` is the stripped text of all files
        concatenated (used downstream to derive the character set).
    """
    names = []
    per_category = {}
    chunks = []
    # NOTE: glob.glob gives no ordering guarantee, so category indices can
    # differ between machines.
    for filename in glob.glob(pattern):
        category = os.path.splitext(os.path.basename(filename))[0]
        # Context manager so the handle is closed deterministically.
        with open(filename, encoding='utf-8') as handle:
            data = handle.read().strip()
        names.append(category)
        chunks.append(data)
        # Split this file's text, not the running concatenation; otherwise
        # each category would also contain all previously read names.
        # An empty file contributes no names rather than one empty name.
        per_category[category] = data.split('\n') if data else []
    return names, per_category, ''.join(chunks)


categories, lang_lines, vocabs = load_names(data_path)

n_category = len(categories)
print(n_category)
    
# Lookup tables over every distinct character in `vocabs`, numbered in
# sorted order: itoa maps index -> character, atoi maps character -> index.
itoa = {idx: ch for idx, ch in enumerate(sorted(set(vocabs)))}
atoi = {ch: idx for idx, ch in itoa.items()}

def to_char(idxs):
    """Translate an iterable of vocabulary indices into a list of characters.

    Each index is looked up in the module-level ``itoa`` table; an index
    that is not a key of ``itoa`` raises ``KeyError``.
    """
    chars = []
    for idx in idxs:
        chars.append(itoa[idx])
    return chars

def to_index(name):
    """Translate a string into the list of its characters' vocabulary indices.

    Each character is looked up in the module-level ``atoi`` table; a
    character that is not a key of ``atoi`` raises ``KeyError``.
    """
    indices = []
    for ch in name:
        indices.append(atoi[ch])
    return indices
    
n_vocabs = len(itoa)
print(atoi, len(itoa), n_letters)

# Encode every line of every category as a tensor of vocabulary indices,
# keeping the same category -> list layout as `lang_lines`.
category_lines = {
    category: [torch.tensor(to_index(line)) for line in lines]
    for category, lines in lang_lines.items()
}

abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
18
{'\n': 0, ' ': 1, "'": 2, ',': 3, '-': 4, '/': 5, '1': 6, ':': 7, 'A': 8, 'B': 9, 'C': 10, 'D': 11, 'E': 12, 'F': 13, 'G': 14, 'H': 15, 'I': 16, 'J': 17, 'K': 18, 'L': 19, 'M': 20, 'N': 21, 'O': 22, 'P': 23, 'Q': 24, 'R': 25, 'S': 26, 'T': 27, 'U': 28, 'V': 29, 'W': 30, 'X': 31, 'Y': 32, 'Z': 33, 'a': 34, 'b': 35, 'c': 36, 'd': 37, 'e': 38, 'f': 39, 'g': 40, 'h': 41, 'i': 42, 'j': 43, 'k': 44, 'l': 45, 'm': 46, 'n': 47, 'o': 48, 'p': 49, 'q': 50, 'r': 51, 's': 52, 't': 53, 'u': 54, 'v': 55, 'w': 56, 'x': 57, 'y': 58, 'z': 59, '\xa0': 60, 'Á': 61, 'É': 62, 'ß': 63, 'à': 64, 'á': 65, 'ã': 66, 'ä': 67, 'ç': 68, 'è': 69, 'é': 70, 'ê': 71, 'ì': 72, 'í': 73, 'ñ': 74, 'ò': 75, 'ó': 76, 'õ': 77, 'ö': 78, 'ù': 79, 'ú': 80, 'ü': 81, 'ą': 82, 'ł': 83, 'ń': 84, 'Ś': 85, 'Ż': 86, 'ż': 87} 88 57


In [74]:
# Shape check: one-hot encode the first Russian entry over the vocabulary,
# insert a singleton middle axis, and display the resulting shape.
(
    F.one_hot(category_lines['Russian'][0], num_classes=n_vocabs)
    .view(-1, 1, n_vocabs)
    .shape
)

torch.Size([6, 1, 88])

In [99]:
class RNN(nn.Module):
    """Single-step recurrent cell with a linear classification head.

    One call to ``forward`` advances the recurrence by one step:
    ``hidden' = tanh(i2h(input) + h2h(hidden))``; class scores are then read
    from ``h2o(hidden')``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the three linear maps.

        Args:
            input_size: number of features in each input row.
            hidden_size: width of the hidden state.
            output_size: number of output classes.
        """
        super().__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, targets=None):
        """Run one recurrent step and optionally compute the loss.

        Args:
            input: tensor whose last dimension is ``input_size``.
            hidden: previous hidden state whose last dimension is
                ``hidden_size``; it is added to ``i2h(input)``, so it must
                broadcast against that result.
            targets: optional class-index tensor passed to
                ``F.cross_entropy``. When ``None`` no loss is computed.

        Returns:
            ``((output, hidden), loss)`` where ``output`` holds softmax
            probabilities over dim 1, ``hidden`` is the new hidden state and
            ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        # torch.tanh replaces the deprecated F.tanh alias.
        hidden = torch.tanh(self.i2h(input) + self.h2h(hidden))
        logits = self.h2o(hidden)
        output = F.softmax(logits, dim=1)
        if targets is None:
            return (output, hidden), None
        # cross_entropy applies log-softmax itself, so it must receive the
        # raw logits; passing the probabilities would normalise twice.
        loss = F.cross_entropy(logits, targets)
        return (output, hidden), loss

    def _init_hidden(self):
        """Return an all-zero initial hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)

# Hidden-state width used for the model and for the zero initial state.
n_hidden = 128
rnn = RNN(input_size=n_vocabs, hidden_size=n_hidden, output_size=n_category)

In [113]:
# forward: run one step of the model on a dummy batch of 6 rows, without
# targets, so the returned loss `p` is None.
hidden = torch.zeros(1, n_hidden)
# Input width comes from n_vocabs instead of a literal 88 so this cell keeps
# working when the character vocabulary changes size.
input = torch.zeros(6, n_vocabs)
input[:, 0] = 1  # every row is one-hot on vocabulary index 0
(output, next_hidden), p = rnn(input, hidden)
print(output.shape, next_hidden.shape, p)

torch.Size([6, 18]) torch.Size([6, 128]) None


In [118]:
def fetch_label(output):
    """Return ``(category_name, category_index)`` for the top-scoring class.

    Takes the single highest-scoring index from ``output.topk(1)`` and uses
    only entry ``[0]`` of the resulting index tensor, i.e. only the first
    row when ``output`` holds several rows. The name is looked up in the
    module-level ``categories`` list.
    """
    winner = output.topk(1)[1][0].item()
    return categories[winner], winner

# Decode the predicted category for the `output` produced by the forward
# cell; fetch_label reads only entry [0] of the top-1 indices.
fetch_label(output)

('Arabic', 14)