In [1]:
# data: https://download.pytorch.org/tutorial/data.zip
import io
import os
import unicodedata
import string
import glob

import torch
import random

# Character vocabulary: lowercase letters, then uppercase letters, then the
# five extra symbols space, period, comma, semicolon and apostrophe.
ALL_LETTERS = string.ascii_lowercase + string.ascii_uppercase + " .,;'"
# Vocabulary size; used as the width of every one-hot letter vector.
N_LETTERS = len(ALL_LETTERS)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicode_to_ascii(s):
    """Return `s` reduced to the characters contained in ALL_LETTERS.

    The string is NFD-normalized first so that accented letters split into a
    base letter plus combining marks; the combining marks (Unicode category
    'Mn') are dropped, and so is every remaining character that is not part
    of ALL_LETTERS.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            # Combining mark left over from decomposition (accent, cedilla, ...).
            continue
        if ch in ALL_LETTERS:
            kept.append(ch)
    return ''.join(kept)

def load_data():
    """Read every `data/names/*.txt` file into per-category name lists.

    The category name is the file name without its directory and extension.
    Each line of a file is passed through `unicode_to_ascii`.

    Returns:
        (category_lines, all_categories) where `category_lines` maps a
        category name to its list of ASCII-normalized lines and
        `all_categories` lists the category names in the order the files
        were returned by `glob.glob` (that order is not sorted here, so it
        may differ between machines).
    """
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []

    def find_files(path):
        return glob.glob(path)

    # Read a file and split into lines
    def read_lines(filename):
        # The context manager closes the handle as soon as the text is read.
        with io.open(filename, encoding='utf-8') as handle:
            lines = handle.read().strip().split('\n')
        return [unicode_to_ascii(line) for line in lines]

    for filename in find_files('data/names/*.txt'):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)

        lines = read_lines(filename)
        category_lines[category] = lines

    return category_lines, all_categories



# Look up a letter's position in ALL_LETTERS, e.g. "a" = 0
def letter_to_index(letter):
    """Return the index of `letter` in ALL_LETTERS, or -1 if it is absent."""
    position = ALL_LETTERS.find(letter)
    return position

# Demonstration helper: encode one letter as a <1 x n_letters> one-hot Tensor
def letter_to_tensor(letter):
    """Return a (1, N_LETTERS) float tensor that is 1 at the letter's index.

    The index comes from `letter_to_index`; all other entries are 0.
    """
    one_hot = torch.zeros(1, N_LETTERS)
    one_hot[0, letter_to_index(letter)] = 1
    return one_hot

# Encode a line as a <line_length x 1 x n_letters> Tensor,
# i.e. a stack of one-hot letter vectors
def line_to_tensor(line):
    """Return a (len(line), 1, N_LETTERS) tensor with one one-hot row per letter.

    Row `k` is 1 at the index `letter_to_index` gives for the k-th character
    of `line` and 0 elsewhere.
    """
    one_hots = torch.zeros(len(line), 1, N_LETTERS)
    for position in range(len(line)):
        one_hots[position, 0, letter_to_index(line[position])] = 1
    return one_hots


def random_training_example(category_lines, all_categories):
    """Draw one random (category, line) pair and encode it as tensors.

    Args:
        category_lines: mapping from category name to a list of lines.
        all_categories: list of category names; a category's position in
            this list is used as its integer label.

    Returns:
        (category, line, category_tensor, line_tensor) where
        `category_tensor` is a 1-element long tensor holding the category's
        index in `all_categories` and `line_tensor` is `line_to_tensor(line)`.
    """

    def pick(items):
        # Uniform draw over valid positions; randint is inclusive at both ends.
        return items[random.randint(0, len(items) - 1)]

    category = pick(all_categories)
    line = pick(category_lines[category])
    label = torch.tensor([all_categories.index(category)], dtype=torch.long)
    return category, line, label, line_to_tensor(line)



# Demo of the helpers above; prints the alphabet, a normalization example,
# a few loaded names and two one-hot encodings.
if __name__ == '__main__':
    # The full vocabulary string.
    print(ALL_LETTERS)
    # Accent stripping example.
    print(unicode_to_ascii('Ślusàrski'))
    
    # Requires the data/names/*.txt files read by load_data to be present.
    category_lines, all_categories = load_data()
    # First five names of the 'Italian' category.
    print(category_lines['Italian'][:5])
    
    print(letter_to_tensor('J')) # one-hot, shape [1, 57]
    print(line_to_tensor('Jones')) # one one-hot row per letter, shape [5, 1, 57]

abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
Slusarski
['Abandonato', 'Abatangelo', 'Abatantuono', 'Abate', 'Abategiovanni']
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [2]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

In [3]:
class RNN(nn.Module):
    """One-step recurrent cell built from two linear layers.

    At every step the input vector and the previous hidden vector are
    concatenated along dim 1 and fed to two independent linear layers: `i2h`
    produces the next hidden vector and `i2o` produces the class scores,
    which are turned into log-probabilities with LogSoftmax over dim 1.
    No non-linearity is applied to the hidden vector.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the layers.

        Args:
            input_size: width of one input vector.
            hidden_size: width of the hidden vector.
            output_size: number of output classes.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers consume the concatenated [input, hidden] vector.
        combined_width = input_size + hidden_size
        self.i2h = nn.Linear(combined_width, hidden_size)
        self.i2o = nn.Linear(combined_width, output_size)
        # Normalizes across the class dimension (dim 1) of the output.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input_tensor, hidden_tensor):
        """Run one step.

        Args:
            input_tensor: tensor concatenated with `hidden_tensor` along dim 1.
            hidden_tensor: previous hidden vector (see `init_hidden`).

        Returns:
            (output, hidden): log-probabilities over the classes and the
            next hidden vector.
        """
        joined = torch.cat([input_tensor, hidden_tensor], dim=1)
        next_hidden = self.i2h(joined)
        log_probs = self.softmax(self.i2o(joined))
        return log_probs, next_hidden

    def init_hidden(self):
        """Return an all-zero hidden vector of shape (1, hidden_size)."""
        return torch.zeros((1, self.hidden_size))

In [4]:
# Load the per-category name lists and the list of category names.
category_lines,all_categories = load_data()

In [5]:
# Number of categories; used below as the RNN's output size.
n_categories = len(all_categories)

In [6]:
# Inspect the vocabulary size, which is used below as the RNN's input size.
N_LETTERS

57

In [7]:
# Width of the RNN hidden vector.
n_hidden = 128

In [8]:
# Model: one-hot letter in (N_LETTERS wide), n_hidden state, one score per category out.
rnn = RNN(N_LETTERS,n_hidden,n_categories)

In [9]:
# Sanity check, step 1: one-hot encode a single letter.
input_tensor = letter_to_tensor('A')

In [10]:
# Sanity check, step 2: start from the all-zero hidden state.
hidden_tensor = rnn.init_hidden()

In [11]:
# Sanity check, step 3: run a single forward step through the network.
output,next_hidden = rnn(input_tensor,hidden_tensor)

In [12]:
# Inspect the output shape: one row with one log-probability per category.
output.shape

torch.Size([1, 18])

In [13]:
# Inspect the hidden-state shape: one row, n_hidden columns.
next_hidden.shape

torch.Size([1, 128])

In [14]:
# Sanity check on a whole name: encode every letter of 'Ranuga', then feed only
# the first letter (input_tensor[0]) through the network from a zero hidden state.
# Note: this rebinds the notebook-level names input_tensor, hidden_tensor,
# output and next_hidden.
input_tensor = line_to_tensor('Ranuga')
hidden_tensor = rnn.init_hidden()
output,next_hidden = rnn(input_tensor[0],hidden_tensor)

In [15]:
def category_from_output(output):
    """Return the category name whose score is largest in `output`.

    The position of the maximum over the flattened `output` tensor is used
    as an index into the notebook-level `all_categories` list.
    """
    winner = torch.argmax(output)
    return all_categories[winner.item()]

In [16]:
# Show the category picked for the most recent `output` tensor.
category_from_output(output)

'Italian'

In [17]:
# Negative log-likelihood loss; the RNN's forward pass already ends in LogSoftmax.
criterion = nn.NLLLoss()

In [18]:
# Learning rate handed to the SGD optimizer.
lr = 0.005

In [19]:
# Plain SGD over all parameters of the RNN.
optimizer = torch.optim.SGD(rnn.parameters(),lr=lr)

In [20]:
# NOTE(review): `epochs` is not referenced by any cell visible in this notebook;
# the training loop is driven by `n_iters` instead. Confirm whether it can be removed.
epochs = 100

In [21]:
def train(line_tensor,category_tensor):
    """Run one SGD update on a single (name, category) example.

    The name is fed through the notebook-level `rnn` one letter at a time,
    starting from a zero hidden state; only the output after the last letter
    is compared against `category_tensor` with `criterion`.

    Args:
        line_tensor: per-letter inputs, iterated along dim 0.
        category_tensor: target class index tensor for `criterion`.

    Returns:
        (output, loss_value): the final-step output tensor and the loss as a
        Python float.
    """
    state = rnn.init_hidden()
    for letter_tensor in line_tensor:
        # Each step sees the current letter plus the state carried so far.
        output,state = rnn(letter_tensor,state)
    loss = criterion(output,category_tensor)
    # Clear gradients left from the previous example before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    return output,loss_value

In [22]:
# Running sum of losses since the last plot checkpoint, and the list of
# per-checkpoint averages. The training loop mutates both, so re-run this
# cell before re-running the loop if a fresh loss curve is wanted.
current_loss = 0
all_losses = []

In [23]:
# Record an averaged loss every `plot_steps` iterations and print a sample
# prediction every `print_steps` iterations.
plot_steps,print_steps = 1000,5000

In [24]:
# Total number of single-example training steps (one random name per step).
n_iters = 1000000

In [26]:
from tqdm import tqdm

In [None]:
# Training loop: one randomly drawn name per iteration.
# - every `plot_steps` iterations the mean loss of that window is appended to
#   `all_losses` and the running sum is reset;
# - every `print_steps` iterations the latest example and guess are reported.
for i in tqdm(range(n_iters)):
    category,line,category_tensor,line_tensor = random_training_example(category_lines,all_categories)
    output,loss = train(line_tensor,category_tensor)
    current_loss += loss
    # 1-based count of completed iterations, so the report reads 5000/..., not 4999/...
    step = i + 1
    if step % plot_steps == 0:
        all_losses.append(current_loss/plot_steps)
        current_loss = 0
    if step % print_steps == 0:
        guess = category_from_output(output)
        correct = 'CORRECT' if guess == category else f'WRONG (Guess : {guess} | Actual : {category})'
        # tqdm.write prints the message without breaking the active progress bar.
        tqdm.write(f'{step}/{n_iters} | {loss} {line} | {guess} {category} | {correct}')

  1%|          | 5152/1000000 [00:04<15:28, 1071.69it/s]

4999/1000000 | 2.704515218734741 Medeiros | Greek Portuguese | WRONG (Guess : Greek | Acctual : Portuguese)


  1%|          | 10145/1000000 [00:09<15:06, 1091.72it/s]

9999/1000000 | 0.8876245617866516 Panoulias | Greek Greek | CORRECT


  2%|▏         | 15178/1000000 [00:14<15:07, 1085.42it/s]

14999/1000000 | 2.4915952682495117 Kozumplikova | French Czech | WRONG (Guess : French | Acctual : Czech)


  2%|▏         | 20187/1000000 [00:18<15:05, 1081.80it/s]

19999/1000000 | 0.7686255574226379 Sheng | Chinese Chinese | CORRECT


  3%|▎         | 25197/1000000 [00:23<14:47, 1098.06it/s]

24999/1000000 | 1.4492430686950684 Adibekoff | Russian Russian | CORRECT


  3%|▎         | 30200/1000000 [00:27<14:34, 1108.88it/s]

29999/1000000 | 1.3047791719436646 Niall | Irish Irish | CORRECT


  4%|▎         | 35213/1000000 [00:32<15:01, 1070.28it/s]

34999/1000000 | 1.2435686588287354 Achteren | Dutch Dutch | CORRECT


  4%|▍         | 40169/1000000 [00:37<14:42, 1087.89it/s]

39999/1000000 | 1.6188135147094727 Cockle | English English | CORRECT


  5%|▍         | 45121/1000000 [00:41<14:46, 1076.65it/s]

44999/1000000 | 1.3436455726623535 Sebastiani | Japanese Italian | WRONG (Guess : Japanese | Acctual : Italian)


  5%|▌         | 50123/1000000 [00:46<14:35, 1085.19it/s]

49999/1000000 | 1.2966217994689941 Shalhoub | Arabic Arabic | CORRECT


  6%|▌         | 55185/1000000 [00:51<14:29, 1086.02it/s]

54999/1000000 | 4.061731338500977 Maly | Irish Czech | WRONG (Guess : Irish | Acctual : Czech)


  6%|▌         | 60166/1000000 [00:55<14:21, 1090.91it/s]

59999/1000000 | 2.2500391006469727 Arendonk | Italian Dutch | WRONG (Guess : Italian | Acctual : Dutch)


  7%|▋         | 65162/1000000 [01:00<14:16, 1090.93it/s]

64999/1000000 | 1.6926058530807495 Romeijn | Dutch Dutch | CORRECT


  7%|▋         | 70149/1000000 [01:04<14:09, 1095.09it/s]

69999/1000000 | 5.029425621032715 Ruzzier | German Italian | WRONG (Guess : German | Acctual : Italian)


  8%|▊         | 75116/1000000 [01:09<13:58, 1102.40it/s]

74999/1000000 | 0.5628539323806763 Azarola | Spanish Spanish | CORRECT


  8%|▊         | 80134/1000000 [01:13<14:09, 1083.46it/s]

79999/1000000 | 0.7785527110099792 Tamura | Japanese Japanese | CORRECT


  9%|▊         | 85192/1000000 [01:18<13:33, 1125.02it/s]

84999/1000000 | 0.23916968703269958 Iseya | Japanese Japanese | CORRECT


  9%|▉         | 90151/1000000 [01:22<13:46, 1100.91it/s]

89999/1000000 | 1.7265822887420654 Morgenstern | Dutch German | WRONG (Guess : Dutch | Acctual : German)


 10%|▉         | 95117/1000000 [01:27<13:30, 1116.89it/s]

94999/1000000 | 1.760408878326416 Gilder | Scottish English | WRONG (Guess : Scottish | Acctual : English)


 10%|█         | 100181/1000000 [01:31<13:09, 1139.34it/s]

99999/1000000 | 2.114992141723633 Dinko | Russian Czech | WRONG (Guess : Russian | Acctual : Czech)


 11%|█         | 105194/1000000 [01:36<13:49, 1078.53it/s]

104999/1000000 | 0.5543521046638489 Souza | Portuguese Portuguese | CORRECT


 11%|█         | 110132/1000000 [01:41<13:20, 1110.98it/s]

109999/1000000 | 2.6491169929504395 Roig | Korean Spanish | WRONG (Guess : Korean | Acctual : Spanish)


 12%|█▏        | 115101/1000000 [01:45<13:51, 1064.77it/s]

114999/1000000 | 4.23534631729126 Sano | Chinese Japanese | WRONG (Guess : Chinese | Acctual : Japanese)


 12%|█▏        | 120162/1000000 [01:50<14:19, 1023.84it/s]

119999/1000000 | 0.20072714984416962 Gomolka | Polish Polish | CORRECT


 13%|█▎        | 125109/1000000 [01:55<13:37, 1069.73it/s]

124999/1000000 | 0.019312094897031784 Sienkiewicz | Polish Polish | CORRECT


 13%|█▎        | 130201/1000000 [01:59<13:27, 1076.93it/s]

129999/1000000 | 0.7672692537307739 Dobrushin | Russian Russian | CORRECT


 14%|█▎        | 135170/1000000 [02:04<13:50, 1041.92it/s]

134999/1000000 | 1.7088594436645508 Korycan | English Czech | WRONG (Guess : English | Acctual : Czech)


 14%|█▍        | 140147/1000000 [02:09<13:44, 1043.32it/s]

139999/1000000 | 0.09278237074613571 Yeon | Korean Korean | CORRECT


 15%|█▍        | 145193/1000000 [02:14<13:08, 1084.44it/s]

144999/1000000 | 3.5566458702087402 Mullins | English French | WRONG (Guess : English | Acctual : French)


 15%|█▌        | 150120/1000000 [02:18<13:06, 1080.01it/s]

149999/1000000 | 1.3960727453231812 Bonaventura | Italian Spanish | WRONG (Guess : Italian | Acctual : Spanish)


 16%|█▌        | 155125/1000000 [02:26<19:54, 707.37it/s] 

154999/1000000 | 1.2059718370437622 Chong | Vietnamese Korean | WRONG (Guess : Vietnamese | Acctual : Korean)


 16%|█▌        | 160154/1000000 [02:32<14:00, 999.80it/s] 

159999/1000000 | 0.0005926521262153983 Stavropoulos | Greek Greek | CORRECT


 17%|█▋        | 165055/1000000 [02:37<16:58, 819.47it/s] 

164999/1000000 | 0.9164418578147888 Rowley | English English | CORRECT


 17%|█▋        | 170128/1000000 [02:43<14:12, 973.61it/s] 

169999/1000000 | 1.5025854110717773 Li | Vietnamese Korean | WRONG (Guess : Vietnamese | Acctual : Korean)


 18%|█▊        | 175152/1000000 [02:48<12:58, 1059.36it/s]

174999/1000000 | 2.4967734813690186 Botros | Portuguese Arabic | WRONG (Guess : Portuguese | Acctual : Arabic)


 18%|█▊        | 180144/1000000 [02:53<12:30, 1093.13it/s]

179999/1000000 | 0.0009098681039176881 Yoshizawa | Japanese Japanese | CORRECT


 19%|█▊        | 185141/1000000 [02:57<12:35, 1079.23it/s]

184999/1000000 | 0.12674649059772491 Moraitopoulos | Greek Greek | CORRECT


 19%|█▉        | 190194/1000000 [03:02<12:05, 1116.65it/s]

189999/1000000 | 0.6340600848197937 Mackay | Scottish Scottish | CORRECT


 20%|█▉        | 195160/1000000 [03:06<12:03, 1111.76it/s]

194999/1000000 | 0.29696041345596313 Ra | Korean Korean | CORRECT


 20%|██        | 200177/1000000 [03:11<12:09, 1096.76it/s]

199999/1000000 | 0.3747742772102356 Petimezas | Greek Greek | CORRECT


 21%|██        | 205191/1000000 [03:15<12:04, 1097.70it/s]

204999/1000000 | 4.0457844734191895 Day | Vietnamese English | WRONG (Guess : Vietnamese | Acctual : English)


 21%|██        | 210112/1000000 [03:20<12:04, 1090.53it/s]

209999/1000000 | 1.9477243423461914 Foong | German Chinese | WRONG (Guess : German | Acctual : Chinese)


 22%|██▏       | 215173/1000000 [03:25<11:58, 1092.89it/s]

214999/1000000 | 0.1851138472557068 Petrakis | Greek Greek | CORRECT


 22%|██▏       | 220222/1000000 [03:29<11:49, 1099.25it/s]

219999/1000000 | 0.9622392058372498 Henriques | Portuguese Portuguese | CORRECT


 23%|██▎       | 225129/1000000 [03:34<11:47, 1095.95it/s]

224999/1000000 | 2.095142364501953 Simon | Dutch Irish | WRONG (Guess : Dutch | Acctual : Irish)


 23%|██▎       | 230219/1000000 [03:39<11:39, 1100.08it/s]

229999/1000000 | 1.0537118911743164 Suchanka | Japanese Czech | WRONG (Guess : Japanese | Acctual : Czech)


 24%|██▎       | 235119/1000000 [03:43<11:57, 1065.36it/s]

234999/1000000 | 0.07350566238164902 Narvaez | Spanish Spanish | CORRECT


 24%|██▍       | 240190/1000000 [03:48<12:08, 1042.77it/s]

239999/1000000 | 0.8959550261497498 Goodridge | English English | CORRECT


 25%|██▍       | 245127/1000000 [03:53<11:34, 1086.48it/s]

244999/1000000 | 0.05467815324664116 Delacroix | French French | CORRECT


 25%|██▌       | 250172/1000000 [03:57<11:37, 1075.40it/s]

249999/1000000 | 0.2541266977787018 Doan | Vietnamese Vietnamese | CORRECT


 26%|██▌       | 255131/1000000 [04:02<11:35, 1071.04it/s]

254999/1000000 | 1.5777606964111328 Ling | Vietnamese Chinese | WRONG (Guess : Vietnamese | Acctual : Chinese)


 26%|██▌       | 260211/1000000 [04:07<11:32, 1068.73it/s]

259999/1000000 | 0.3632030785083771 Sugimura | Japanese Japanese | CORRECT


 27%|██▋       | 265121/1000000 [04:11<11:42, 1046.29it/s]

264999/1000000 | 0.1451493203639984 Jang | Korean Korean | CORRECT


 27%|██▋       | 270165/1000000 [04:16<11:50, 1026.64it/s]

269999/1000000 | 0.04359056428074837 Capitani | Italian Italian | CORRECT


 28%|██▊       | 275139/1000000 [04:21<11:15, 1073.22it/s]

274999/1000000 | 0.3309023380279541 Christodoulou | Greek Greek | CORRECT


 28%|██▊       | 280159/1000000 [04:26<11:24, 1052.22it/s]

279999/1000000 | 0.43636268377304077 Higoshi | Japanese Japanese | CORRECT


 29%|██▊       | 285175/1000000 [04:30<10:47, 1103.15it/s]

284999/1000000 | 0.22340157628059387 Taguchi | Japanese Japanese | CORRECT


 29%|██▉       | 290135/1000000 [04:35<10:49, 1093.15it/s]

289999/1000000 | 1.1387531757354736 Kramer | Dutch German | WRONG (Guess : Dutch | Acctual : German)


 30%|██▉       | 295182/1000000 [04:40<10:58, 1070.85it/s]

294999/1000000 | 3.1621508598327637 Fuse | German Japanese | WRONG (Guess : German | Acctual : Japanese)


 30%|██▉       | 299108/1000000 [04:44<11:18, 1032.77it/s]

In [None]:
# Loss curve: each point is the mean NLL loss over one window of
# `plot_steps` training iterations.
plt.figure()
plt.plot(all_losses)
plt.xlabel(f'Checkpoint (one per {plot_steps} iterations)')
plt.ylabel('Mean NLL loss')
plt.title('Training loss')
plt.show()

In [None]:
def predict(input_line):
    """Print the category the notebook-level `rnn` assigns to `input_line`.

    The name is echoed, normalized with `unicode_to_ascii` (the same
    normalization applied to the training data), encoded letter by letter
    and fed through the network from a zero hidden state. The category with
    the highest score after the last letter is printed. Gradients are
    disabled for the whole pass.

    Args:
        input_line: the name to classify.

    Returns:
        None. If nothing is left after normalization, a notice is printed
        instead of a guess.
    """
    print(f'\n {input_line}')
    with torch.no_grad():
        # Characters outside the vocabulary would otherwise get index -1 and
        # be encoded as the last vocabulary symbol.
        line_tensor = line_to_tensor(unicode_to_ascii(input_line))
        if line_tensor.size()[0] == 0:
            # No letters means no forward step, hence no output to classify.
            print('(no usable letters in input)')
            return
        hidden = rnn.init_hidden()
        for i in range(line_tensor.size()[0]):
            # Step over this call's own encoding, not a notebook-level tensor.
            output,hidden = rnn(line_tensor[i],hidden)
        guess = category_from_output(output)
        print(guess)

In [None]:
# Interactive loop: classify each entered name until the user types 'quit'.
while True:
    sentence = input('Input a Name : ')
    if sentence == 'quit':
        break
    predict(sentence)